Longfellow ZK 0290cb32
Loading...
Searching...
No Matches
compiler.h
1// Copyright 2025 Google LLC.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#ifndef PRIVACY_PROOFS_ZK_LIB_CIRCUITS_COMPILER_COMPILER_H_
16#define PRIVACY_PROOFS_ZK_LIB_CIRCUITS_COMPILER_COMPILER_H_
17
18#include <stddef.h>
19#include <stdint.h>
20
21#include <algorithm>
22#include <memory>
23#include <vector>
24
25#include "algebra/hash.h"
26#include "circuits/compiler/circuit_id.h"
27#include "circuits/compiler/node.h"
28#include "circuits/compiler/pdqhash.h"
29#include "circuits/compiler/schedule.h"
30#include "sumcheck/circuit.h"
31#include "sumcheck/quad.h"
32#include "util/panic.h"
33
34namespace proofs {
/*
QuadCircuit contains methods that facilitate defining circuits used to
express predicates that are to be proven or verified. This class allows one
to use basic arithmetic circuit operations (add, mul, input, assert0, ...)
to define the circuit on a set of abstract wire labels.

The "mkcircuit" compiler method then optimizes the circuit by applying all of
the basic tricks of constant propagation, common sub-expression elimination,
squashing layers into as few as possible, and grouping terms into quads.

Quads are a new form of gate (in contrast to the add and mul gates in most
sumcheck proof systems). Quads represent a "sum of quadratic terms" where
each term is w_l * w_r * v for two wire labels and a constant v.
*/
49template <class Field>
50class QuadCircuit {
51 public:
52 using Elt = typename Field::Elt;
53 using nodeinfo = NodeInfoF<Field>;
54 using node = NodeF<Field>;
55 using size_t_for_storage = term::size_t_for_storage;
56 using quad_corner_t = typename Quad<Field>::quad_corner_t;
57
58 const Field& f_;
59
60 public:
61 // Variables for informational purposes:
62 size_t ninput_;
63 size_t npub_input_; // number of public inputs, index of 1st private
64 size_t subfield_boundary_; // least wire not known to be in the subfield
65 size_t noutput_;
66
67 // set by the algebraic simplifiers in this file
68 size_t depth_;
69 size_t nwires_cse_eliminated_;
70 size_t nwires_not_needed_;
71
72 // set by the scheduler
73 size_t nwires_;
74 size_t nquad_terms_;
75 size_t nwires_overhead_;
76
77 explicit QuadCircuit(const Field& f)
78 : f_(f),
79 ninput_(0),
80 npub_input_(0),
81 subfield_boundary_(0),
82 noutput_(0),
83 depth_(0),
84 nwires_cse_eliminated_(0),
85 nwires_not_needed_(0),
86 nwires_(-1), // undefined until set in mkcircuit()
87 nquad_terms_(-1),
88 nwires_overhead_(-1) {
89 // make sure that Elt(0) is represented as index 0 in the constant
90 // table.
91 size_t ki0 = kstore(f.zero());
92 proofs::check(ki0 == 0, "ki0 == 0");
93 size_t ki1 = kstore(f.one());
94 proofs::check(ki1 == 1, "ki1 == 1");
95
96 // make sure node 0 exists, carrying input[0] = F.one()
97 input();
98 }
99
100 // Produce a linear term 1 * op0 that the compiler will not
101 // attempt to optimize to op0. The reason for this function
102 // is to implement linear terms such as a*x in the quadratic form
103 // a*x+b*x*y. Left to its own devices, the compiler peeks into x,
104 // and if x=k*z*w, it produces a term (a*k)*z*w in the previous
105 // layer, possibly destroying common subexpressions. linear(op)
106 // introduces an explicit multiplication by wire 0, which the
107 // compiler does not attempt to optimize away.
108 size_t linear(size_t op0) { return mul(0, op0); }
109 size_t linear(const Elt& k, size_t op0) { return mul(k, 0, op0); }
110
111 size_t mul(const Elt& k, size_t op) {
112 if (k == f_.zero()) {
113 return konst(k);
114 } else if (k == f_.one() || nodes_[op].zero()) {
115 return op;
116 } else {
117 return push_node(scale(k, op));
118 }
119 }
120
121 size_t mul(size_t op0, size_t op1) { return mul(f_.one(), op0, op1); }
122
123 size_t mul(const Elt& k, size_t op0, size_t op1) {
124 const auto& n0 = nodes_[op0];
125 const auto& n1 = nodes_[op1];
126
127 if (n0.zero()) {
128 return op0;
129 } else if (n0.constant()) {
130 // k * (k1 * op1) -> (k * k1) * op1
131 return mul(f_.mulf(k, kload(n0.terms[0].ki)), op1);
132 } else if (n0.linearp()) {
133 // k * ((k1 * op0) * op1) -> (k * k1) * op0 * op1
134 return mul(f_.mulf(k, kload(n0.terms[0].ki)), n0.terms[0].op1, op1);
135 } else if (n1.zero() || n1.constant() || n1.linearp()) {
136 return mul(k, op1, op0);
137 } else {
138 // general term k * op0 * op1
139 return push_node(node(kstore(k), op0, op1));
140 }
141 }
142
143 size_t add(size_t op0, size_t op1) {
144 const auto& n0 = nodes_[op0];
145 const auto& n1 = nodes_[op1];
146
147 if (n0.zero()) {
148 return op1;
149 } else if (n1.zero()) {
150 return op0;
151 } else {
152 // If the two addends are of different depth, do not merge
153 // them, which is accomplished by multiplying the shallower
154 // node by 1 and treating it as a single term of the final
155 // sum.
156 //
157 // Like many other "optimizations", this is a heuristic
158 // that may or may not work, but it seems to be uniformly
159 // beneficial or at least not harmful for all our circuits
160 // as of 2023-11-15.
161 if (n0.info.depth < n1.info.depth) {
162 op0 = linear(op0);
163 } else if (n1.info.depth < n0.info.depth) {
164 op1 = linear(op1);
165 }
166 return push_node(merge(op0, op1));
167 }
168 }
169 size_t sub(size_t op0, size_t op1) { return add(op0, mul(f_.mone(), op1)); }
170
171 size_t konst(const Elt& k) { return push_node(node(kstore(k), 0, 0)); }
172
173 // Generate a special node that asserts that op == 0.
174 // The node has the form 0*(1*op), which does not normally
175 // appear in circuits.
176 size_t assert0(size_t op) {
177 const node* n = &nodes_[op];
178 if (n->zero()) {
179 // Identically zero, so nothing to generate.
180 // More importantly, we cannot multiply OP by 1,
181 // since OP doesn't really exist.
182 return op;
183 } else if (n->linearp()) {
184 // n = k * (1 * op1).
185 //
186 // Reduce to assert0(op1), but handle the screw case k==0,
187 // which shouldn't happen but just in case...
188 if (n->terms[0].ki == 0) {
189 return op;
190 } else {
191 return assert0(n->terms[0].op1);
192 }
193 } else {
194 typename term::assert0_type_hack hack;
195 std::vector<term> terms;
196 terms.push_back(term(op, hack));
197 size_t n1 = push_node(node(terms));
198 nodes_[n1].info.is_assert0 = true;
199 return n1;
200 }
201 }
202
203 // Wrappers to avoid creating unnecessary wires. The
204 // compiler will discard them anyway, but they still take
205 // time and space.
206 size_t axpy(size_t y, const Elt& a, size_t x) {
207 if (a == f_.zero()) {
208 return y;
209 }
210 return add(y, linear(a, x));
211 }
212 size_t apy(size_t y, const Elt& a) {
213 if (a == f_.zero()) {
214 return y;
215 }
216 return add(y, konst(a));
217 }
218
219 size_t input() { return push_node(node(quad_corner_t(ninput_++))); }
220
221 // This function demarcates the end of the public inputs and beginning of
222 // private inputs. It can only be called once.
223 void private_input() {
224 proofs::check(
225 npub_input_ == 0,
226 "private_input can only be called once after setting public inputs");
227 npub_input_ = ninput_;
228 }
229
230 // This function demarcates the end of the private inputs in the
231 // subfield and beginning of the full-field private inputs. It can
232 // only be called once.
233 void begin_full_field() {
234 proofs::check(subfield_boundary_ == 0,
235 "begin_full_field() can only be called once");
236 subfield_boundary_ = ninput_;
237 }
238
239 size_t ninput() const { return ninput_; }
240
241 void output(size_t n, size_t wire_id) {
242 output_internal(n, quad_corner_t(wire_id));
243 }
244
245 std::unique_ptr<Circuit<Field>> mkcircuit(size_t nc) {
246 size_t depth_ub = compute_depth_ub();
247 fixup_last_layer_assertions(depth_ub);
248 compute_needed(depth_ub);
249
250 Scheduler<Field> sched(nodes_, f_);
251 std::unique_ptr<Circuit<Field>> c =
252 sched.mkcircuit(constants_, depth_ub, nc);
253
254 // re-export the scheduler telemetry
255 nwires_ = sched.nwires_;
256 nquad_terms_ = sched.nquad_terms_;
257 nwires_overhead_ = sched.nwires_overhead_;
258
259 c->ninputs = ninput();
260 c->npub_in = npub_input_;
261 c->subfield_boundary = subfield_boundary_;
262
263 circuit_id(c->id, *c, f_);
264 return c;
265 }
266
267 private:
268 void output_internal(size_t n, quad_corner_t wire_id) {
269 nodes_[n].info.is_output = true;
270 nodes_[n].info.desired_wire_id_for_output = wire_id;
271 noutput_++;
272 }
273
274 size_t push_node(node n) {
275 // common-subexpression elimination: if we have already seen a
276 // node equal to n, return that node.
277 uint64_t d = n.hash();
278
279 auto pred = [&](PdqHash::value_t op) { return n == nodes_[op]; };
280 if (size_t op = cse_.find(d, pred); op != PdqHash::kNil) {
281 // do not linear terms as eliminated by the CSE, since they are
282 // likely placeholder nodes absorbed by the next layer.
283 if (!n.linearp()) {
284 ++nwires_cse_eliminated_;
285 }
286 return op;
287 }
288
289 // compute the node depth, which has been so far uninitialized
290 n.info.depth = 0;
291 for (const auto& t : n.terms) {
292 n.info.depth = std::max<size_t>(
293 n.info.depth, 1 + std::max<size_t>(nodes_[t.op0].info.depth,
294 nodes_[t.op1].info.depth));
295 }
296
297 size_t nid = nodes_.size();
298 nodes_.push_back(n);
299
300 // record NID into the common-subexpression elimination table
301 cse_.insert(d, nid);
302
303 return nid;
304 }
305
306 node materialize_input(size_t op) {
307 if (nodes_[op].info.is_input) {
308 return node(/*kstore(f.one())=*/1, 0, op);
309 } else {
310 return /*a copy of*/ nodes_[op];
311 }
312 }
313
314 node scale(const Elt& k, size_t op) {
315 node n = materialize_input(op);
316 for (auto& t : n.terms) {
317 t.ki = kstore(f_.mulf(kload(t.ki), k));
318 }
319 return n;
320 }
321
322 void push_back_unless_zero(std::vector<term>& terms, const term& t) const {
323 if (t.ki != 0) {
324 terms.push_back(t);
325 }
326 }
327
328 node merge(size_t op0, size_t op1) {
329 const node n0 = materialize_input(op0);
330 const node n1 = materialize_input(op1);
331 const std::vector<term>& t0 = n0.terms;
332 const std::vector<term>& t1 = n1.terms;
333 std::vector<term> terms;
334 size_t i0 = 0, i1 = 0;
335 while (i0 < t0.size() && i1 < t1.size()) {
336 term t;
337 if (t0[i0].eqndx(t1[i1])) {
338 t = t0[i0];
339 t.ki = kstore(f_.addf(kload(t.ki), kload(t1[i1].ki)));
340 i0++;
341 i1++;
342 } else if (t0[i0].ltndx(t1[i1])) {
343 t = t0[i0++];
344 } else {
345 t = t1[i1++];
346 }
347 push_back_unless_zero(terms, t);
348 }
349
350 while (i0 < t0.size()) {
351 push_back_unless_zero(terms, t0[i0++]);
352 }
353
354 while (i1 < t1.size()) {
355 push_back_unless_zero(terms, t1[i1++]);
356 }
357
358 return node(terms);
359 }
360
// constants_[n] stores the n-th constant, once.
// Modulo collisions, constants_[consttab_[hash(k)]] == k
// for k \in Elt.
364 std::vector<Elt> constants_;
365 PdqHash consttab_;
366
367 std::vector<node> nodes_;
368 PdqHash cse_;
369
370 size_t kstore(const Elt& k) {
371 uint64_t d = elt_hash(k, f_);
372 auto pred = [&](PdqHash::value_t ki) { return k == constants_[ki]; };
373 size_t ki = consttab_.find(d, pred);
374
375 if (ki == PdqHash::kNil) {
376 ki = constants_.size();
377 constants_.push_back(k);
378 consttab_.insert(d, ki);
379 }
380 return ki;
381 }
382 Elt& kload(size_t ki) { return constants_[ki]; }
383
384 void mark_needed(size_t op, size_t depth_at_which_needed) {
385 nodeinfo* nfo = &nodes_[op].info;
386 nfo->is_needed = true;
387 nfo->max_needed_depth =
388 std::max<size_t>(depth_at_which_needed, nfo->max_needed_depth);
389
390 // If DEPTH_AT_WHICH_NEEDED > DEPTH + 1, we need a constant 1 at
391 // depth DEPTH_AT_WHICH_NEEDED-1 (and implicily any lower depths) in
392 // order to copy the node across levels.
393 if (depth_at_which_needed > nfo->depth + 1) {
394 nodeinfo* nfo0 = &nodes_[0].info;
395 nfo0->is_needed = true;
396 nfo0->max_needed_depth =
397 std::max<size_t>(depth_at_which_needed - 1, nfo0->max_needed_depth);
398 }
399 }
400
401 size_t compute_depth_ub() {
402 size_t r = 0;
403 for (auto& n : nodes_) {
404 if (n.info.is_output) {
405 r = std::max<size_t>(r, 1 + n.info.depth);
406 } else if (n.info.is_assert0) {
407 // Assertions of the form 0*(1*OP) contibute n.info.depth and
408 // not 1 + n.info.depth. If the assertion is in the last
409 // layer, it will be transformed in an output of OP at
410 // n.info.depth. If the assertion is not in the last layer,
411 // then it doesn't matter whether we use DEPTH or 1 + DEPTH.
412 if (n.linearp()) {
413 r = std::max<size_t>(r, n.info.depth);
414 } else {
415 r = std::max<size_t>(r, 1 + n.info.depth);
416 }
417 }
418 }
419 depth_ = r;
420 return r;
421 }
422
423 void fixup_last_layer_assertions(size_t depth_ub) {
424 // convert assertions in the last layer into outputs
425 for (auto& n : nodes_) {
426 if (!n.info.is_output && n.info.is_assert0 && n.info.depth == depth_ub &&
427 n.linearp()) {
428 n.info.is_assert0 = false;
429 output_internal(n.terms[0].op1, nodeinfo::kWireIdUndefined);
430 }
431 }
432 }
433
434 void compute_needed(size_t depth_ub) {
435 nwires_not_needed_ = 0;
436 for (size_t i = nodes_.size(); i-- > 0;) {
437 nodeinfo* nfo = &nodes_[i].info;
438
439 // mark all inputs as needed, to prevent ambiguity
440 // in the layout of the W[] vector.
441 if (nfo->is_input) {
442 mark_needed(i, 1);
443 }
444 // outputs are needed at depth_ub_
445 if (nfo->is_output) {
446 mark_needed(i, depth_ub);
447 }
448 // assertions are needed in the next layer
449 if (nfo->is_assert0) {
450 mark_needed(i, nfo->depth + 1);
451 }
452
453 if (nfo->is_needed) {
454 for (const auto& t : nodes_[i].terms) {
455 mark_needed(t.op0, nfo->depth);
456 mark_needed(t.op1, nfo->depth);
457 }
458 } else {
459 ++nwires_not_needed_;
460 }
461 }
462 }
463};
464
465} // namespace proofs
466
467#endif // PRIVACY_PROOFS_ZK_LIB_CIRCUITS_COMPILER_COMPILER_H_
Definition pdqhash.h:37
Definition quad.h:37
Definition schedule.h:35
Definition gf2_128.h:63
Definition node.h:115
Definition node.h:76
Definition node.h:30