Longfellow ZK 0290cb32
Loading...
Searching...
No Matches
schedule.h
1// Copyright 2025 Google LLC.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#ifndef PRIVACY_PROOFS_ZK_LIB_CIRCUITS_COMPILER_SCHEDULE_H_
16#define PRIVACY_PROOFS_ZK_LIB_CIRCUITS_COMPILER_SCHEDULE_H_
17
18#include <stddef.h>
19#include <stdint.h>
20
21#include <algorithm>
22#include <memory>
23#include <vector>
24
25#include "algebra/compare.h"
26#include "arrays/affine.h"
27#include "circuits/compiler/node.h"
28#include "sumcheck/circuit.h"
29#include "sumcheck/quad.h"
30#include "util/ceildiv.h"
31#include "util/panic.h"
32
33namespace proofs {
// Scheduler converts the compiler's DAG of nodes (NodeF) into a layered
// sumcheck Circuit.  The conversion has three phases:
//
//   1. order_by_layer():  replicate each needed DAG node at every depth
//      where a consumer exists, inserting explicit "copy wires" (1 * op)
//      to carry values across layers.
//   2. assign_wire_ids(): canonicalize wire ids layer by layer, so that
//      structurally identical circuits get identical (and, empirically,
//      better-compressing) wire assignments.
//   3. fill_layers():     emit the final Circuit and one Quad per layer.
//
// The public fields nwires_, nquad_terms_, and nwires_overhead_ are
// statistics populated as a side effect of mkcircuit().
template <class Field>
class Scheduler {
  using Elt = typename Field::Elt;
  using nodeinfo = NodeInfoF<Field>;
  using node = NodeF<Field>;
  using size_t_for_storage = term::size_t_for_storage;
  using quad_corner_t = typename Quad<Field>::quad_corner_t;

  const Field& f_;                  // field arithmetic for constants
  const std::vector<node>& nodes_;  // source DAG; owned by the caller

 public:
  // Statistics, valid after mkcircuit():
  size_t nwires_;           // total number of wires over all layers
  size_t nquad_terms_;      // total number of quad terms over all layers
  size_t nwires_overhead_;  // number of copy wires inserted

  Scheduler(const std::vector<node>& nodes, const Field& f)
      : f_(f),
        nodes_(nodes),
        nwires_(0),
        nquad_terms_(0),
        nwires_overhead_(0) {}

  // Builds the layered circuit.
  //
  // CONSTANTS is the constant pool; each term's coefficient index (t.ki)
  // refers into it.  DEPTH_UB is one past the maximum node depth, where
  // depth 0 holds the input nodes.  NC is the number of copies of the
  // circuit (c->nc).
  std::unique_ptr<Circuit<Field>> mkcircuit(const std::vector<Elt>& constants,
                                            size_t depth_ub, size_t nc) {
    std::unique_ptr<Circuit<Field>> c = std::make_unique<Circuit<Field>>();

    // number of layers and copies
    c->nl = depth_ub - 1;  // depth 0 = input nodes, not a "layer"
    c->nc = nc;
    c->logc = lg(nc);

    auto lnodes = order_by_layer(constants, depth_ub);

    // TODO [matteof 2025-03-12] ASSIGN_WIRE_IDS() renames LNODES in
    // order to sort it and assign LNODES[].DESIRED_WIRE ID. Then it
    // throws away the renamed LNODES. Then FILL_LAYERS() renames
    // LNODES again in order to produce the final quad. It would be
    // better to produce the quad directly in ASSIGN_WIRE_IDS(). Punt
    // for now, this is just a performance optimization of the
    // compiler anyway.
    //
    assign_wire_ids(lnodes);
    fill_layers(c.get(), depth_ub, lnodes);

    return c;
  }

 private:
  // per-layer representation of nodes and terms

  // One quadratic term k * wire[lop0] * wire[lop1], where lop0/lop1 index
  // wires in the previous layer.
  struct lterm {
    Elt k;
    quad_corner_t lop0, lop1;
  };

  // One wire in a layer: its (possibly still undefined) wire id and the
  // terms that compute it from the previous layer.
  struct lnode {
    quad_corner_t desired_wire_id;

    // Copy wires are forced to be distinct from wires in the
    // original dag, in order to avoid ambiguity in renaming.
    //
    // Copy wires are always of the form 1*op, which doesn't
    // normally appear in the dag because the algebraic simplifier
    // reduces it to op. However, one can in theory create such
    // a node by judicious use of linear(). Rather than
    // trying to figure out which circuits one is not allowed
    // to write, it seems simpler to just handle this case
    // uniformly.
    bool is_copy_wire;

    std::vector<lterm> lterms;

    lnode(quad_corner_t desired_wire_id, bool is_copy_wire,
          const std::vector<lterm>& lterms)
        : desired_wire_id(desired_wire_id),
          is_copy_wire(is_copy_wire),
          lterms(lterms) {}
  };

  // Returns the per-layer index (LOP) of DAG node OP at depth D.
  // LOP[op][0] corresponds to the node's own depth (n.info.depth), and
  // successive entries correspond to its copy wires at deeper layers.
  quad_corner_t lop_of_op_at_depth(
      const std::vector<std::vector<quad_corner_t>>& lop, size_t op,
      size_t d) const {
    const node& n = nodes_.at(op);
    return lop.at(op).at(d - n.info.depth);
  }

  // Convert the DAG of nodes into a layered dag of lnodes.
  // Skips nodes that are not needed or are identically zero, and inserts
  // copy wires (counted in nwires_overhead_) for every layer between a
  // node's depth and its maximum needed depth.
  std::vector<std::vector<lnode>> order_by_layer(
      const std::vector<Elt>& constants, size_t depth_ub) {
    // The source DAG is indexed by NODES_[OP].
    // The destination dag uses a two-dimensional indexing
    // scheme LNODES[D][LOP], where D is the depth.

    // A single value NODES_[OP] may be replicated multiple times in
    // LNODES. The mapping is maintained in array LOPS such that
    // LOPS[OP][D - D0] contains the LOP index of node OP at depth D.
    // D0 is the depth at which NODES_[OP] is first computed, and
    // there is no point in storing LOPS[OP] for D < D0.

    std::vector<std::vector<lnode>> lnodes(depth_ub);
    std::vector<std::vector<quad_corner_t>> lops(nodes_.size());

    nwires_overhead_ = 0;

    for (size_t op = 0; op < nodes_.size(); ++op) {
      const auto& n = nodes_[op];
      const nodeinfo& nfo = n.info;
      if (nfo.is_needed && !n.zero()) {
        size_t d = nfo.depth;

        // Allocate the LOP at depth D
        quad_corner_t lop = quad_corner_t(lnodes.at(d).size());
        lops.at(op).push_back(lop);

        // create a LOPS entry for depth D
        /*scope*/ {
          std::vector<lterm> lterms;
          for (const auto& t : n.terms) {
            // Operands are at depth D - 1 by construction of the DAG.
            lterm lt = {
                .k = constants.at(t.ki),
                .lop0 = lop_of_op_at_depth(lops, t.op0, d - 1),
                .lop1 = lop_of_op_at_depth(lops, t.op1, d - 1),
            };
            lterms.push_back(lt);
          }
          lnodes.at(d).push_back(lnode(nfo.desired_wire_id(d, depth_ub),
                                       /*is_copy_wire=*/false, lterms));
        }

        // create copy wires
        for (d = nfo.depth + 1; d < nfo.max_needed_depth; ++d) {
          quad_corner_t lop_dm1 = lop;

          // allocate the LOP at depth D
          lop = quad_corner_t(lnodes.at(d).size());
          lops.at(op).push_back(lop);

          std::vector<lterm> lterms;

          // Insert a multiplication by one of the layer
          // at the previous layer.
          lterm lt = {
              .k = f_.one(),
              .lop0 = quad_corner_t(0),
              .lop1 = lop_dm1,
          };
          lterms.push_back(lt);
          lnodes.at(d).push_back(lnode(nfo.desired_wire_id(d, depth_ub),
                                       /*is_copy_wire=*/true, lterms));
          ++nwires_overhead_;
        }  // for copy wires
      }  // if needed
    }  // for OP

    return lnodes;
  }

  //------------------------------------------------------------
  // canonical assignment of wire ids
  //------------------------------------------------------------
  //
  // The canonicalization order is a matter of convention.
  // We make some arbitrary choices that appear to interact
  // better with ZSTD compression. The label [ARBITRARY CHOICE]
  // denotes all places in the code where this occurs.
  //
  // An lterm whose operands have been renamed to the desired wire ids
  // of the previous layer.  The constructor normalizes operand order
  // (min first), so two terms that differ only in operand order compare
  // equal.
  class renamed_lterm {
   public:
    Elt k_;
    quad_corner_t rlop0_, rlop1_;

    // [ARBITRARY CHOICE] Consistent with corner::canonicalize() in
    // sumcheck/quad.h
    renamed_lterm(const Elt& k, quad_corner_t rlop0, quad_corner_t rlop1)
        : k_(k),
          rlop0_(std::min<quad_corner_t>(rlop0, rlop1)),
          rlop1_(std::max<quad_corner_t>(rlop0, rlop1)) {}

    bool operator==(const renamed_lterm& y) const {
      return rlop0_ == y.rlop0_ && rlop1_ == y.rlop1_ && k_ == y.k_;
    }

    // canonical order: lexicographic on (rlop0_, rlop1_, k_), with field
    // elements compared via elt_less_than().
    static bool compare(const renamed_lterm& a, const renamed_lterm& b,
                        const Field& F) {
      if (a.rlop0_ < b.rlop0_) return true;
      if (a.rlop0_ > b.rlop0_) return false;
      if (a.rlop1_ < b.rlop1_) return true;
      if (a.rlop1_ > b.rlop1_) return false;
      return elt_less_than(a.k_, b.k_, F);
    }
  };

  // An lnode whose terms have been renamed to previous-layer wire ids.
  // original_wire_index_ remembers the node's position in the un-renamed
  // layer so that assign_wire_ids() can write the computed id back.
  class renamed_lnode {
   public:
    quad_corner_t desired_wire_id_;
    quad_corner_t original_wire_index_;
    bool is_copy_wire_;
    std::vector<renamed_lterm> rlterms_;

    renamed_lnode(quad_corner_t desired_wire_id,
                  quad_corner_t original_wire_index, bool is_copy_wire,
                  const std::vector<renamed_lterm>& rlterms)
        : desired_wire_id_(desired_wire_id),
          original_wire_index_(original_wire_index),
          is_copy_wire_(is_copy_wire),
          rlterms_(rlterms) {}

    // Equality ignores desired_wire_id_ and original_wire_index_: two
    // nodes are "the same wire" iff they have the same terms and the
    // same copy-wire flag.
    bool operator==(const renamed_lnode& y) const {
      if (is_copy_wire_ != y.is_copy_wire_) return false;
      if (rlterms_.size() != y.rlterms_.size()) return false;
      size_t l = rlterms_.size();
      for (size_t i = 0; i < l; ++i) {
        if (!(rlterms_[i] == y.rlterms_[i])) return false;
      }
      return true;
    }

    // canonical order
    static bool compare(const renamed_lnode& ra, const renamed_lnode& rb,
                        const Field& F) {
      // Defined before undefined. This choice is mandated by the
      // fact that the range of defined wire id's starts at 0.
      if (ra.desired_wire_id_ != nodeinfo::kWireIdUndefined) {
        if (rb.desired_wire_id_ != nodeinfo::kWireIdUndefined) {
          return ra.desired_wire_id_ < rb.desired_wire_id_;
        } else {
          return true;
        }
      } else {
        if (rb.desired_wire_id_ != nodeinfo::kWireIdUndefined) {
          return false;
        }
        // else both undefined
      }

      // [ARBITRARY CHOICE] Lexicographic order on the reverse of the
      // terms array. This seems to compress much better than
      // the normal lexicographic order.
      for (size_t ia = ra.rlterms_.size(), ib = rb.rlterms_.size();
           ia-- > 0 && ib-- > 0;) {
        const renamed_lterm& rlta = ra.rlterms_[ia];
        const renamed_lterm& rltb = rb.rlterms_[ib];
        if (renamed_lterm::compare(rlta, rltb, F)) return true;
        if (renamed_lterm::compare(rltb, rlta, F)) return false;
      }

      // [ARBITRARY CHOICE] If the common suffixes are the same, the
      // shorter terms come first.
      if (ra.rlterms_.size() < rb.rlterms_.size()) return true;
      if (ra.rlterms_.size() > rb.rlterms_.size()) return false;

      // Nodes that were in the original dag come first.
      if (!ra.is_copy_wire_ && rb.is_copy_wire_) return true;
      if (!rb.is_copy_wire_ && ra.is_copy_wire_) return false;

      // equal, i.e., not less-than
      return false;
    }
  };

  // Returns true iff SORTED contains no two adjacent equal elements,
  // i.e., iff a sorted vector contains no duplicates.
  template <class T>
  bool uniq(const std::vector<T>& sorted) {
    for (size_t i = 0; i + 1 < sorted.size(); ++i) {
      if (sorted[i] == sorted[i + 1]) return false;
    }
    return true;
  }

  // Assigns a canonical desired_wire_id to every lnode in every layer,
  // proceeding from inputs (depth 0, which must already be fully defined)
  // towards outputs.  For each layer: rename the terms via the previous
  // layer's wire ids, sort into the canonical order, and assign ids in
  // sorted position order.  Nodes that already carry a desired wire id
  // must land on exactly that id (enforced by check()).
  void assign_wire_ids(std::vector<std::vector<lnode>>& lnodes) {
    // all inputs are expected to be defined already
    assert_all_desired_wire_id_defined(lnodes.at(0));

    for (size_t d = 1; d < lnodes.size(); ++d) {
      const std::vector<lnode>& lnodes_at_dm1 = lnodes.at(d - 1);
      const std::vector<lnode>& lnodes_at_d = lnodes.at(d);

      // Create a renamed clone of LNODES_AT_D, in which all
      // the LOP's are mapped to their desired wire id's
      // at the previous layer. We use different types
      // to avoid any possibility of confusion.
      std::vector<renamed_lnode> renamed_at_d;

      quad_corner_t original_wire_index(0);
      for (const lnode& ln : lnodes_at_d) {
        std::vector<renamed_lterm> rlterms;

        // rename all terms
        rlterms.reserve(ln.lterms.size());
        for (const lterm& lt : ln.lterms) {
          rlterms.push_back(renamed_lterm(
              lt.k,
              lnodes_at_dm1.at(static_cast<size_t>(lt.lop0)).desired_wire_id,
              lnodes_at_dm1.at(static_cast<size_t>(lt.lop1)).desired_wire_id));
        }

        // canonicalize the terms order
        std::sort(rlterms.begin(), rlterms.end(),
                  [&](const renamed_lterm& a, const renamed_lterm& b) {
                    return renamed_lterm::compare(a, b, f_);
                  });

        // Terms must be unique, otherwise the canonicalization is
        // ill-defined. Uniqueness is guaranteed by the algebraic
        // simplifier, but assert it for good measure.
        check(uniq(rlterms), "rlterms not unique");

        renamed_at_d.push_back(renamed_lnode(
            ln.desired_wire_id, original_wire_index, ln.is_copy_wire, rlterms));
        ++original_wire_index;
      }

      check(renamed_at_d.size() == lnodes_at_d.size(),
            "renamed_at_d.size() == lnodes_at_d.size()");

      std::sort(renamed_at_d.begin(), renamed_at_d.end(),
                [&](const renamed_lnode& a, const renamed_lnode& b) {
                  return renamed_lnode::compare(a, b, f_);
                });

      // Nodes must be unique, otherwise the canonicalization is
      // ill-defined.
      check(uniq(renamed_at_d), "renamed_at_d not unique");

      // Write the canonical ids back into LNODES via the saved
      // original_wire_index_.
      quad_corner_t wid(0);
      std::vector<lnode>& wlnodes_at_d = lnodes.at(d);

      for (const renamed_lnode& ln : renamed_at_d) {
        lnode& lnpi =
            wlnodes_at_d.at(static_cast<size_t>(ln.original_wire_index_));
        if (lnpi.desired_wire_id != nodeinfo::kWireIdUndefined) {
          // We must have computed the same wire id
          check(wid == lnpi.desired_wire_id, "wid == lnpi.desired_wire_id");
        } else {
          lnpi.desired_wire_id = wid;
        }
        wid++;
      }
    }
  }

  // Panics (via check()) if any node in LAYER lacks a desired wire id.
  // Used on the input layer, whose ids are fixed by the caller.
  void assert_all_desired_wire_id_defined(const std::vector<lnode>& layer) {
    for (const auto& ln : layer) {
      check(ln.desired_wire_id != nodeinfo::kWireIdUndefined,
            "ln.desired_wire_id != kWireIdUndefined");
    }
  }

  // Populates C with one Layer per depth (outputs first) and accumulates
  // the nwires_ statistic.  LNODES must contain exactly DEPTH_UB layers.
  void fill_layers(Circuit<Field>* c, size_t depth_ub,
                   const std::vector<std::vector<lnode>>& lnodes) {
    check(depth_ub == lnodes.size(), "depth_ub == lnodes.size()");

    corner_t nv = corner_t(lnodes.at(depth_ub - 1).size());

    nwires_ = nv;
    c->nv = nv;
    c->logv = lg(nv);

    // d-- > 1 (not 0) because depth 0 denotes input nodes, not a layer.
    // Sumcheck counts layers starting from the output, hence the loop
    // counts downwards.
    for (size_t d = depth_ub; d-- > 1;) {
      corner_t nw =
          corner_t(lnodes.at(d - 1).size());  // inputs[d] == outputs[d-1]
      nwires_ += nw;
      c->l.push_back(
          Layer<Field>{.nw = nw,
                       .logw = lg(nw),
                       .quad = mkquad(lnodes.at(d), lnodes.at(d - 1))});
    }
  }

  // Builds the Quad for one layer: one corner per lterm, with G set to
  // the output wire id and H to the two (previous-layer) operand wire
  // ids.  Also accumulates nquad_terms_.  The quad is canonicalized
  // before being returned.
  std::unique_ptr<const Quad<Field>> mkquad(
      const std::vector<lnode>& lnodes0,  // wires at this layer
      const std::vector<lnode>& lnodes1   // wires at the previous layer
  ) {
    size_t nterms0 = 0;
    for (const auto& ln0 : lnodes0) {
      nterms0 += ln0.lterms.size();
    }
    nquad_terms_ += nterms0;

    auto S = std::make_unique<Quad<Field>>(nterms0);
    size_t i = 0;
    for (const auto& ln0 : lnodes0) {
      for (const auto& lt : ln0.lterms) {
        S->c_[i++] = typename Quad<Field>::corner{
            .g = ln0.desired_wire_id,
            .h = {lnodes1.at(static_cast<size_t>(lt.lop0)).desired_wire_id,
                  lnodes1.at(static_cast<size_t>(lt.lop1)).desired_wire_id},
            .v = lt.k};
      }
    }
    S->canonicalize(f_);
    return S;
  }
};
430
431} // namespace proofs
432
433#endif // PRIVACY_PROOFS_ZK_LIB_CIRCUITS_COMPILER_SCHEDULE_H_
Definition quad.h:37
Definition circuit.h:45
Definition gf2_128.h:63
Definition circuit.h:30
Definition node.h:115
Definition node.h:76
Definition quad.h:51