39 using size_t_for_storage = term::size_t_for_storage;
43 const std::vector<node>& nodes_;
48 size_t nwires_overhead_;
50 Scheduler(
const std::vector<node>& nodes,
const Field& f)
55 nwires_overhead_(0) {}
57 std::unique_ptr<Circuit<Field>> mkcircuit(
const std::vector<Elt>& constants,
58 size_t depth_ub,
size_t nc) {
59 std::unique_ptr<Circuit<Field>> c = std::make_unique<Circuit<Field>>();
66 auto lnodes = order_by_layer(constants, depth_ub);
76 assign_wire_ids(lnodes);
77 fill_layers(c.get(), depth_ub, lnodes);
86 quad_corner_t lop0, lop1;
89 quad_corner_t desired_wire_id;
103 std::vector<lterm> lterms;
105 lnode(quad_corner_t desired_wire_id,
bool is_copy_wire,
106 const std::vector<lterm>& lterms)
107 : desired_wire_id(desired_wire_id),
108 is_copy_wire(is_copy_wire),
112 quad_corner_t lop_of_op_at_depth(
113 const std::vector<std::vector<quad_corner_t>>& lop,
size_t op,
115 const node& n = nodes_.at(op);
116 return lop.at(op).at(d - n.info.depth);
120 std::vector<std::vector<lnode>> order_by_layer(
121 const std::vector<Elt>& constants,
size_t depth_ub) {
132 std::vector<std::vector<lnode>> lnodes(depth_ub);
133 std::vector<std::vector<quad_corner_t>> lops(nodes_.size());
135 nwires_overhead_ = 0;
137 for (
size_t op = 0; op < nodes_.size(); ++op) {
138 const auto& n = nodes_[op];
139 const nodeinfo& nfo = n.info;
140 if (nfo.is_needed && !n.zero()) {
141 size_t d = nfo.depth;
144 quad_corner_t lop = quad_corner_t(lnodes.at(d).size());
145 lops.at(op).push_back(lop);
149 std::vector<lterm> lterms;
150 for (
const auto& t : n.terms) {
152 .k = constants.at(t.ki),
153 .lop0 = lop_of_op_at_depth(lops, t.op0, d - 1),
154 .lop1 = lop_of_op_at_depth(lops, t.op1, d - 1),
156 lterms.push_back(lt);
158 lnodes.at(d).push_back(lnode(nfo.desired_wire_id(d, depth_ub),
163 for (d = nfo.depth + 1; d < nfo.max_needed_depth; ++d) {
164 quad_corner_t lop_dm1 = lop;
167 lop = quad_corner_t(lnodes.at(d).size());
168 lops.at(op).push_back(lop);
170 std::vector<lterm> lterms;
176 .lop0 = quad_corner_t(0),
179 lterms.push_back(lt);
180 lnodes.at(d).push_back(lnode(nfo.desired_wire_id(d, depth_ub),
199 class renamed_lterm {
202 quad_corner_t rlop0_, rlop1_;
206 renamed_lterm(
const Elt& k, quad_corner_t rlop0, quad_corner_t rlop1)
208 rlop0_(std::min<quad_corner_t>(rlop0, rlop1)),
209 rlop1_(std::max<quad_corner_t>(rlop0, rlop1)) {}
211 bool operator==(
const renamed_lterm& y)
const {
212 return rlop0_ == y.rlop0_ && rlop1_ == y.rlop1_ && k_ == y.k_;
216 static bool compare(
const renamed_lterm& a,
const renamed_lterm& b,
218 if (a.rlop0_ < b.rlop0_)
return true;
219 if (a.rlop0_ > b.rlop0_)
return false;
220 if (a.rlop1_ < b.rlop1_)
return true;
221 if (a.rlop1_ > b.rlop1_)
return false;
222 return elt_less_than(a.k_, b.k_, F);
226 class renamed_lnode {
228 quad_corner_t desired_wire_id_;
229 quad_corner_t original_wire_index_;
231 std::vector<renamed_lterm> rlterms_;
233 renamed_lnode(quad_corner_t desired_wire_id,
234 quad_corner_t original_wire_index,
bool is_copy_wire,
235 const std::vector<renamed_lterm>& rlterms)
236 : desired_wire_id_(desired_wire_id),
237 original_wire_index_(original_wire_index),
238 is_copy_wire_(is_copy_wire),
241 bool operator==(
const renamed_lnode& y)
const {
242 if (is_copy_wire_ != y.is_copy_wire_)
return false;
243 if (rlterms_.size() != y.rlterms_.size())
return false;
244 size_t l = rlterms_.size();
245 for (
size_t i = 0; i < l; ++i) {
246 if (!(rlterms_[i] == y.rlterms_[i]))
return false;
252 static bool compare(
const renamed_lnode& ra,
const renamed_lnode& rb,
256 if (ra.desired_wire_id_ != nodeinfo::kWireIdUndefined) {
257 if (rb.desired_wire_id_ != nodeinfo::kWireIdUndefined) {
258 return ra.desired_wire_id_ < rb.desired_wire_id_;
263 if (rb.desired_wire_id_ != nodeinfo::kWireIdUndefined) {
272 for (
size_t ia = ra.rlterms_.size(), ib = rb.rlterms_.size();
273 ia-- > 0 && ib-- > 0;) {
274 const renamed_lterm& rlta = ra.rlterms_[ia];
275 const renamed_lterm& rltb = rb.rlterms_[ib];
276 if (renamed_lterm::compare(rlta, rltb, F))
return true;
277 if (renamed_lterm::compare(rltb, rlta, F))
return false;
282 if (ra.rlterms_.size() < rb.rlterms_.size())
return true;
283 if (ra.rlterms_.size() > rb.rlterms_.size())
return false;
286 if (!ra.is_copy_wire_ && rb.is_copy_wire_)
return true;
287 if (!rb.is_copy_wire_ && ra.is_copy_wire_)
return false;
295 bool uniq(
const std::vector<T>& sorted) {
296 for (
size_t i = 0; i + 1 < sorted.size(); ++i) {
297 if (sorted[i] == sorted[i + 1])
return false;
302 void assign_wire_ids(std::vector<std::vector<lnode>>& lnodes) {
304 assert_all_desired_wire_id_defined(lnodes.at(0));
306 for (
size_t d = 1; d < lnodes.size(); ++d) {
307 const std::vector<lnode>& lnodes_at_dm1 = lnodes.at(d - 1);
308 const std::vector<lnode>& lnodes_at_d = lnodes.at(d);
314 std::vector<renamed_lnode> renamed_at_d;
316 quad_corner_t original_wire_index(0);
317 for (
const lnode& ln : lnodes_at_d) {
318 std::vector<renamed_lterm> rlterms;
321 rlterms.reserve(ln.lterms.size());
322 for (
const lterm& lt : ln.lterms) {
323 rlterms.push_back(renamed_lterm(
325 lnodes_at_dm1.at(
static_cast<size_t>(lt.lop0)).desired_wire_id,
326 lnodes_at_dm1.at(
static_cast<size_t>(lt.lop1)).desired_wire_id));
330 std::sort(rlterms.begin(), rlterms.end(),
331 [&](
const renamed_lterm& a,
const renamed_lterm& b) {
332 return renamed_lterm::compare(a, b, f_);
338 check(uniq(rlterms),
"rlterms not unique");
340 renamed_at_d.push_back(renamed_lnode(
341 ln.desired_wire_id, original_wire_index, ln.is_copy_wire, rlterms));
342 ++original_wire_index;
345 check(renamed_at_d.size() == lnodes_at_d.size(),
346 "renamed_at_d.size() == lnodes_at_d.size()");
348 std::sort(renamed_at_d.begin(), renamed_at_d.end(),
349 [&](
const renamed_lnode& a,
const renamed_lnode& b) {
350 return renamed_lnode::compare(a, b, f_);
355 check(uniq(renamed_at_d),
"renamed_at_d not unique");
357 quad_corner_t wid(0);
358 std::vector<lnode>& wlnodes_at_d = lnodes.at(d);
360 for (
const renamed_lnode& ln : renamed_at_d) {
362 wlnodes_at_d.at(
static_cast<size_t>(ln.original_wire_index_));
363 if (lnpi.desired_wire_id != nodeinfo::kWireIdUndefined) {
365 check(wid == lnpi.desired_wire_id,
"wid == lnpi.desired_wire_id");
367 lnpi.desired_wire_id = wid;
374 void assert_all_desired_wire_id_defined(
const std::vector<lnode>& layer) {
375 for (
const auto& ln : layer) {
376 check(ln.desired_wire_id != nodeinfo::kWireIdUndefined,
377 "ln.desired_wire_id != kWireIdUndefined");
382 const std::vector<std::vector<lnode>>& lnodes) {
383 check(depth_ub == lnodes.size(),
"depth_ub == lnodes.size()");
385 corner_t nv = corner_t(lnodes.at(depth_ub - 1).size());
394 for (
size_t d = depth_ub; d-- > 1;) {
396 corner_t(lnodes.at(d - 1).size());
401 .quad = mkquad(lnodes.at(d), lnodes.at(d - 1))});
405 std::unique_ptr<const Quad<Field>> mkquad(
406 const std::vector<lnode>& lnodes0,
407 const std::vector<lnode>& lnodes1
410 for (
const auto& ln0 : lnodes0) {
411 nterms0 += ln0.lterms.size();
413 nquad_terms_ += nterms0;
415 auto S = std::make_unique<Quad<Field>>(nterms0);
417 for (
const auto& ln0 : lnodes0) {
418 for (
const auto& lt : ln0.lterms) {
420 .g = ln0.desired_wire_id,
421 .h = {lnodes1.at(
static_cast<size_t>(lt.lop0)).desired_wire_id,
422 lnodes1.at(
static_cast<size_t>(lt.lop1)).desired_wire_id},