Longfellow ZK 0290cb32
logic.h
1// Copyright 2025 Google LLC.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#ifndef PRIVACY_PROOFS_ZK_LIB_CIRCUITS_LOGIC_LOGIC_H_
16#define PRIVACY_PROOFS_ZK_LIB_CIRCUITS_LOGIC_LOGIC_H_
17
18#include <stddef.h>
19
20#include <array>
21#include <cstdint>
22#include <functional>
23#include <vector>
24
25#include "algebra/fp_generic.h"
26#include "gf2k/gf2_128.h"
27#include "util/panic.h"
28
29namespace proofs {
30/*
31 Arithmetization of boolean logic in a field.
32 This class builds logical and arithmetic operations (add, sub, mul, and,
33 or, xor, and so on) over bits in the arithmetic circuit model.
34 The class utilizes several optimizations, including changing from the {0,1}
35 basis to the {-1,1} basis for representing bits.
36 */
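//
// Minimal usage sketch (illustrative only; MyField and MyBackend stand for a
// concrete field and a backend chosen by the caller, and are not defined in
// this header):
//
//   MyField F(/*...*/);
//   MyBackend bk(/*...*/);
//   Logic<MyField, MyBackend> lg(&bk, F);
//   auto x = lg.input();          // requires a backend that supports input()
//   auto y = lg.input();
//   auto z = lg.land(&x, y);      // z = x AND y
//   lg.output(z, 0);              // requires a backend that supports output()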
37template <typename Field_, class Backend>
38class Logic {
39 public:
40 using Field = Field_; /* this class exports Field, Elt, and EltW */
41 using Elt = typename Field::Elt;
42 // an "Elt Wire", a wire carrying an Elt.
43 using EltW = typename Backend::V;
44
45 const Field& f_;
46
47 explicit Logic(const Backend* bk, const Field& F) : f_(F), bk_(bk) {}
48
49 //------------------------------------------------------------
50 // Arithmetic.
51
52 //
53 // Re-export field operations
54 Elt addf(const Elt& a, const Elt& b) const { return f_.addf(a, b); }
55 Elt mulf(const Elt& a, const Elt& b) const { return f_.mulf(a, b); }
56 Elt invertf(const Elt& a) const { return f_.invertf(a); }
57 Elt negf(const Elt& a) const { return f_.negf(a); }
58 Elt zero() const { return f_.zero(); }
59 Elt one() const { return f_.one(); }
60 Elt mone() const { return f_.mone(); }
61 Elt elt(uint64_t a) const { return f_.of_scalar(a); }
62
63 template <size_t N>
64 Elt elt(const char (&s)[N]) const {
65 return f_.of_string(s);
66 }
67
68 // To ensure deterministic behavior, the order of function calls that produce
69 // circuit wires must be well-defined at compile time.
70 // The C++ spec leaves the evaluation order of certain expressions
71 // unspecified. One such ambiguity is the order in which the arguments of a
72 // function call are evaluated: in f(creates_wire(x), creates_wire(y)),
73 // the two wires may be created in either order.
74 // To prevent this, every function call that creates wires takes at most
75 // one argument that is itself a function call. To enforce this property, we
76 // require that all but the last argument to a function be a const pointer.
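  // For example (illustrative sketch, not code from this library):
  //
  //   auto ab = lg.land(&a, b);    // creates the wire for a AND b
  //   auto cd = lg.lor(&c, d);     // creates the wire for c OR d
  //   auto r  = lg.lxor(&ab, cd);  // combine; only the last argument is a value
  //
  // Writing lg.lxor(lg.land(&a, b), lg.lor(&c, d)) would not compile, because
  // the first parameter must be a const pointer; this forces wire-creating
  // calls into separate statements with a deterministic order.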
77
78 // Re-export backend operations
79 EltW assert0(const EltW& a) const { return bk_->assert0(a); }
80 EltW add(const EltW* a, const EltW& b) const { return bk_->add(*a, b); }
81
82 EltW sub(const EltW* a, const EltW& b) const { return bk_->sub(*a, b); }
83
84 EltW mul(const EltW* a, const EltW& b) const { return bk_->mul(*a, b); }
85 EltW mul(const Elt& k, const EltW& b) const { return bk_->mul(k, b); }
86 EltW mul(const Elt& k, const EltW* a, const EltW& b) const {
87 return bk_->mul(k, a, b);
88 }
89
90 EltW ax(const Elt& a, const EltW& x) const { return bk_->ax(a, x); }
91 EltW axy(const Elt& a, const EltW* x, const EltW& y) const {
92 return bk_->axy(a, *x, y);
93 }
94 EltW axpy(const EltW* y, const Elt& a, const EltW& x) const {
95 return bk_->axpy(*y, a, x);
96 }
97 EltW apy(const EltW& y, const Elt& a) const { return bk_->apy(y, a); }
98
99 EltW konst(const Elt& a) const { return bk_->konst(a); }
100 EltW konst(uint64_t a) const { return konst(elt(a)); }
101
102 template <size_t N>
103 std::array<EltW, N> konst(const std::array<Elt, N>& a) const {
104 std::array<EltW, N> r;
105 for (size_t i = 0; i < N; ++i) {
106 r[i] = konst(a[i]);
107 }
108 return r;
109 }
110
111 //------------------------------------------------------------
112 // Boolean logic.
113 //
114 // We map TRUE to one() and FALSE to zero(). We call this convention
115 // the "standard basis".
116 //
117 // However, actual values on wires may use different conventions,
118 // e.g. -1 for TRUE and 1 for FALSE. To keep track of these changes,
119 // we represent boolean values as (c0, c1, x) where c0+c1*x is
120 // the value in the standard basis. c0 and c1
121 // are compile-time constants that can be manipulated, and
122 // x is a runtime value not known in advance.
123 //
124 // For example, let the "xor basis" denote the mapping FALSE -> 1,
125 // TRUE -> -1. In the xor basis, xor(a,b)=a*b. The output of the
126 // xor gate in the standard basis would be represented as 1/2 + (-1/2)*x
127 // where x=a*b is the wire value in the xor basis.
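  //
  // As a concrete check, for a and b in the standard basis the xor-basis
  // values are (1 - 2a) and (1 - 2b), and
  //   1/2 - (1/2)*(1 - 2a)*(1 - 2b) = a + b - 2ab,
  // which is exactly xor in the standard basis.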
128
129 // a "bit Wire", a wire carrying a bit
130 struct BitW {
131 Elt c0, c1;
132 EltW x;
133 BitW() = default;
134
135 // constructor in the standard basis
136 explicit BitW(const EltW& bv, const Field& F)
137 : BitW(F.zero(), F.one(), bv) {}
138
139 BitW(Elt c0_, Elt c1_, const EltW& x_) : c0(c0_), c1(c1_), x(x_) {}
140 };
141
142 // vectors of N bits
143 template <size_t N>
144 class bitvec : public std::array<BitW, N> {};
145
146 // Common sizes, publicly exported for convenience. The type names are
147 // intentionally lower-case to capture the spirit of basic "intx_t" types.
148 using v1 = bitvec<1>;
149 using v4 = bitvec<4>;
150 using v8 = bitvec<8>;
151 using v16 = bitvec<16>;
152 using v32 = bitvec<32>;
153 using v64 = bitvec<64>;
154 using v128 = bitvec<128>;
155 using v129 = bitvec<129>;
156 using v256 = bitvec<256>;
157
158 // Let v(x)=c0+c1*x. Return a representation of
159 // d0+d1*v(x)=(d0+d1*c0)+(d1*c1)*x without changing x.
160 // Does not involve the backend at all.
161 BitW rebase(const Elt& d0, const Elt& d1, const BitW& v) const {
162 return BitW(addf(d0, mulf(d1, v.c0)), mulf(d1, v.c1), v.x);
163 }
164
165 EltW eval(const BitW& v) const {
166 EltW r = ax(v.c1, v.x);
167 if (v.c0 != zero()) {
168 auto c0 = konst(v.c0);
169 r = add(&c0, r);
170 }
171 return r;
172 }
173
174 // return an EltW which is 0 iff v is 0
175 EltW assert0(const BitW& v) const {
176 auto e = eval(v);
177 return assert0(e);
178 }
179 // return an EltW which is 0 iff v is 1
180 EltW assert1(const BitW& v) const {
181 auto e = lnot(v);
182 return assert0(e);
183 }
184
185 // 0 iff a==b
186 EltW assert_eq(const EltW* a, const EltW& b) const {
187 return assert0(sub(a, b));
188 }
189 EltW assert_eq(const BitW* a, const BitW& b) const {
190 return assert0(lxor(a, b));
191 }
192 EltW assert_implies(const BitW* a, const BitW& b) const {
193 return assert1(limplies(a, b));
194 }
195
196 // special test for asserting that b \in {0,1} (i.e.,
197 // not some other field element).
198 EltW assert_is_bit(const BitW& b) const {
199 // b - b*b
200 // Seems to work better than b*(1-b)
201 // Equivalent to land(b,lnot(b)) but does not rely
202 // on the specific arithmetization.
203 auto eb = eval(b);
204 return assert_is_bit(eb);
205 }
206 EltW assert_is_bit(const EltW& v) const {
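    // Over a field, v*v == v iff v is 0 or 1, so v - v*v == 0 iff v is a bit.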
207 auto vvmv = sub(&v, mul(&v, v));
208 return assert0(vvmv);
209 }
210
211 // Constant bits, represented as (c0 = b, c1 = 0, x = 1), to allow for
212 // some compile-time constant folding
213 BitW bit(size_t b) const {
214 return BitW((b == 0) ? zero() : one(), zero(), konst(one()));
215 }
216
217 void bits(size_t n, BitW a[/*n*/], uint64_t x) const {
218 for (size_t i = 0; i < n; ++i) {
219 a[i] = bit((x >> i) & 1u);
220 }
221 }
222
223 // gates
224 BitW lnot(const BitW& x) const {
225 // lnot() is a pure representation change that does not
226 // involve actual circuit gates
227
228 // 1 - x in the standard basis
229 return rebase(one(), mone(), x);
230 }
231
232 BitW land(const BitW* a, const BitW& b) const {
233 // a * b in the standard basis
234 return mulv(a, b);
235 }
236
237 // special case of product of a logic value by a field
238 // element
239 EltW lmul(const BitW* a, const EltW& b) const {
240 // a * b in the standard basis
241 auto ab = mulv(a, BitW(b, f_));
242 return eval(ab);
243 }
244 EltW lmul(const EltW* b, const BitW& a) const { return lmul(&a, *b); }
245
246 BitW lxor(const BitW* a, const BitW& b) const {
247 return lxor_aux(*a, b, typename Field::TypeTag());
248 }
249 BitW lxor(const BitW* a, const BitW* b) const {
250 return lxor_aux(*a, *b, typename Field::TypeTag());
251 }
252
253 BitW lor(const BitW* a, const BitW& b) const {
254 auto na = lnot(*a);
255 auto nab = land(&na, lnot(b));
256 return lnot(nab);
257 }
258
259 // a => b
260 BitW limplies(const BitW* a, const BitW& b) const {
261 auto na = lnot(*a);
262 return lor(&na, b);
263 }
264
265 // OR of two quantities known to be mutually exclusive
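  // (When a and b cannot both be TRUE, OR coincides with addition, so this
  // costs a linear gate instead of a multiplication.)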
266 BitW lor_exclusive(const BitW* a, const BitW& b) const { return addv(*a, b); }
267
268 BitW lxor3(const BitW* a, const BitW* b, const BitW& c) const {
269 BitW p = lxor(a, b);
270 return lxor(&p, c);
271 }
272
273 // sha256 Ch(): (x & y) ^ (~x & z);
274 BitW lCh(const BitW* x, const BitW* y, const BitW& z) const {
275 auto xy = land(x, *y);
276 auto nx = lnot(*x);
277 return lor_exclusive(&xy, land(&nx, z));
278 }
279
280 // sha256 Maj(): (x & y) ^ (x & z) ^ (y & z);
281 BitW lMaj(const BitW* x, const BitW* y, const BitW& z) const {
282 // Interpret as x + y + z >= 2 and compute the carry
283 // for an adder in the (p, g) basis
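  // Equivalently, Maj(x,y,z) == (x & y) | ((x ^ y) & z); since g = x&y and
  // p = x^y cannot both be TRUE, the final OR can be lor_exclusive().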
284 BitW p = lxor(x, *y);
285 BitW g = land(x, *y);
286 return lor_exclusive(&g, land(&p, z));
287 }
288
289 // mux over logic values
290 BitW mux(const BitW* control, const BitW* iftrue, const BitW& iffalse) const {
291 auto cif = land(control, *iftrue);
292 auto nc = lnot(*control);
293 auto ciff = land(&nc, *iffalse);
294 return lor_exclusive(&cif, ciff);
295 }
296
297 // mux over backend values
298 EltW mux(const BitW* control, const EltW* iftrue, const EltW& iffalse) const {
299 auto cif = lmul(control, *iftrue);
300 auto nc = lnot(*control);
301 auto ciff = lmul(&nc, iffalse);
302 return add(&cif, ciff);
303 }
304
305 // sum_{i0 <= i < i1} f(i)
306 EltW add(size_t i0, size_t i1, const std::function<EltW(size_t)>& f) const {
307 if (i1 <= i0) {
308 return konst(0);
309 } else if (i1 == i0 + 1) {
310 return f(i0);
311 } else {
312 size_t im = i0 + (i1 - i0) / 2;
313 auto lh = add(i0, im, f);
314 auto rh = add(im, i1, f);
315 return add(&lh, rh);
316 }
317 }
318
319 // lor_exclusive_{i0 <= i < i1} f(i)
320 BitW lor_exclusive(size_t i0, size_t i1,
321 const std::function<BitW(size_t)>& f) const {
322 if (i1 <= i0) {
323 return bit(0);
324 } else if (i1 == i0 + 1) {
325 return f(i0);
326 } else {
327 size_t im = i0 + (i1 - i0) / 2;
328 auto lh = lor_exclusive(i0, im, f);
329 auto rh = lor_exclusive(im, i1, f);
330 return lor_exclusive(&lh, rh);
331 }
332 }
333
334 // and_{i0 <= i < i1} f(i)
335 BitW land(size_t i0, size_t i1, const std::function<BitW(size_t)>& f) const {
336 if (i1 <= i0) {
337 return bit(1);
338 } else if (i1 == i0 + 1) {
339 return f(i0);
340 } else {
341 size_t im = i0 + (i1 - i0) / 2;
342 auto lh = land(i0, im, f);
343 auto rh = land(im, i1, f);
344 return land(&lh, rh);
345 }
346 }
347
348 // or_{i0 <= i < i1} f(i)
349 BitW lor(size_t i0, size_t i1, const std::function<BitW(size_t)>& f) const {
350 if (i1 <= i0) {
351 return bit(0);
352 } else if (i1 == i0 + 1) {
353 return f(i0);
354 } else {
355 size_t im = i0 + (i1 - i0) / 2;
356 auto lh = lor(i0, im, f);
357 auto rh = lor(im, i1, f);
358 return lor(&lh, rh);
359 }
360 }
361
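  // OR of ANDs: for example, or_of_and({{a, b}, {c}}) computes (a AND b) OR c.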
362 BitW or_of_and(std::vector<std::vector<BitW>> clauses_of_ands) const {
363 std::vector<BitW> ands(clauses_of_ands.size());
364 for (size_t i = 0; i < clauses_of_ands.size(); ++i) {
365 auto ai = clauses_of_ands[i];
366 BitW res = land(0, ai.size(), [&](size_t i) { return ai[i]; });
367 ands[i] = res;
368 }
369 return lor(0, ands.size(), [&](size_t i) { return ands[i]; });
370 }
371
372 // prod_{i0 <= i < i1} f(i)
373 EltW mul(size_t i0, size_t i1, const std::function<EltW(size_t)>& f) const {
374 if (i1 <= i0) {
375 return konst(1);
376 } else if (i1 == i0 + 1) {
377 return f(i0);
378 } else {
379 size_t im = i0 + (i1 - i0) / 2;
380 auto lh = mul(i0, im, f);
381 auto rh = mul(im, i1, f);
382 return mul(&lh, rh);
383 }
384 }
385
386 // assert that a + b = c in constant depth
387 void assert_sum(size_t w, const BitW c[/*w*/], const BitW a[/*w*/],
388 const BitW b[/*w*/]) const {
389 // first step of generic_gp_add(): change the basis from
390 // (a, b) to (g, p):
391 std::vector<BitW> g(w), p(w), cy(w);
392 for (size_t i = 0; i < w; ++i) {
393 g[i] = land(&a[i], b[i]);
394 p[i] = lxor(&a[i], &b[i]);
395 }
396
397 // invert the last step of generic_gp_add(): derive
398 // cy[i - 1] (called g[i - 1] there) from
399 // c[i] and p[i].
400 assert_eq(&c[0], p[0]);
401 for (size_t i = 1; i < w; ++i) {
402 cy[i - 1] = lxor(&c[i], p[i]);
403 }
404
405 // Verify that applying ripple_scan to g[] produces cy[].
406 // Note that ripple_scan() operates in-place on g[]. Here, however, g[] is
407 // the input to ripple_scan(), and cy[] is the output.
408 assert_eq(&cy[0], g[0]);
409 for (size_t i = 1; i + 1 < w; ++i) {
410 auto cyp = land(&cy[i - 1], p[i]);
411 auto g_cyp = lor_exclusive(&g[i], cyp);
412 assert_eq(&cy[i], g_cyp);
413 }
414 }
415
416 // (carry, c) = a + b, returning the carry.
417 BitW ripple_carry_add(size_t w, BitW c[/*w*/], const BitW a[/*w*/],
418 const BitW b[/*w*/]) const {
419 return generic_gp_add(w, c, a, b, &Logic::ripple_scan);
420 }
421
422 // (carry, c) = a - b, returning the carry.
423 BitW ripple_carry_sub(size_t w, BitW c[/*w*/], const BitW a[/*w*/],
424 const BitW b[/*w*/]) const {
425 return generic_gp_sub(w, c, a, b, &Logic::ripple_scan);
426 }
427
428 BitW parallel_prefix_add(size_t w, BitW c[/*w*/], const BitW a[/*w*/],
429 const BitW b[/*w*/]) const {
430 return generic_gp_add(w, c, a, b, &Logic::sklansky_scan);
431 }
432
433 BitW parallel_prefix_sub(size_t w, BitW c[/*w*/], const BitW a[/*w*/],
434 const BitW b[/*w*/]) const {
435 return generic_gp_sub(w, c, a, b, &Logic::sklansky_scan);
436 }
437
438 // w x w -> 2w-bit multiplier c = a * b
439 void multiplier(size_t w, BitW c[/*2*w*/], const BitW a[/*w*/],
440 const BitW b[/*w*/]) const {
441 std::vector<BitW> t(w);
442 for (size_t i = 0; i < w; ++i) {
443 if (i == 0) {
444 for (size_t j = 0; j < w; ++j) {
445 c[j] = land(&a[0], b[j]);
446 }
447 c[w] = bit(0);
448 } else {
449 for (size_t j = 0; j < w; ++j) {
450 t[j] = land(&a[i], b[j]);
451 }
452 BitW carry = ripple_carry_add(w, c + i, t.data(), c + i);
453 c[i + w] = carry;
454 }
455 }
456 }
457
458 // w x w -> 2w-bit polynomial multiplier over gf2. c(x) = a(x) * b(x)
459 void gf2_polynomial_multiplier(size_t w, BitW c[/*2*w*/], const BitW a[/*w*/],
460 const BitW b[/*w*/]) const {
461 std::vector<BitW> t(w);
462 for (size_t k = 0; k < 2 * w; ++k) {
463 size_t n = 0;
464 for (size_t i = 0; i < w; ++i) {
465 if (k >= i && k - i < w) {
466 t[n++] = land(&a[i], b[k - i]);
467 }
468 }
469 c[k] = parity(0, n, t.data());
470 }
471 }
472
473 // w x w -> 2w-bit polynomial multiplier over gf2. c(x) = a(x) * b(x)
474 // via the Karatsuba recurrence. Requires w == 128, w == 64, or any w < 64.
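  // Recurrence: with a = a0 + a1*x^(w/2) and b = b0 + b1*x^(w/2),
  //   a*b = a0*b0 + ((a0+a1)*(b0+b1) + a0*b0 + a1*b1)*x^(w/2) + a1*b1*x^w,
  // where + is xor, so one w-bit product costs three (w/2)-bit products.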
475 void gf2_polynomial_multiplier_karat(size_t w, BitW c[/*2*w*/],
476 const BitW a[/*w*/],
477 const BitW b[/*w*/]) const {
478 check(w == 128 || w == 64 || w < 64, "input length is not a power of 2");
479 if (w < 64) {
480 gf2_polynomial_multiplier(w, c, a, b);
481 return;
482 } else {
483 // This branch runs only for w == 128 and the recursive w == 64 calls.
484 std::vector<BitW> a01(w / 2); /* a0 plus a1 */
485 std::vector<BitW> b01(w / 2); /* b0 plus b1 */
486 std::vector<BitW> ab01(w);
487 std::vector<BitW> a0b0(w);
488 std::vector<BitW> a1b1(w);
489
490 for (size_t i = 0; i < w / 2; ++i) {
491 a01[i] = lxor(&a[i], a[i + w / 2]);
492 b01[i] = lxor(&b[i], b[i + w / 2]);
493 }
494
495 gf2_polynomial_multiplier_karat(w / 2, &ab01[0], &a01[0], &b01[0]);
496 gf2_polynomial_multiplier_karat(w / 2, &a0b0[0], a, b);
497 gf2_polynomial_multiplier_karat(w / 2, &a1b1[0], a + w / 2, b + w / 2);
498
499 for (size_t i = 0; i < w; ++i) {
500 ab01[i] = lxor3(&ab01[i], &a0b0[i], a1b1[i]);
501 }
502
503 for (size_t i = 0; i < w/2; ++i) {
504 c[i] = a0b0[i];
505 c[i + w/2] = lxor(&a0b0[i + w/2], ab01[i]);
506 c[i + w] = lxor(&ab01[i + w/2], a1b1[i]);
507 c[i + 3*w/2] = a1b1[i + w/2];
508 }
509 }
510 }
511
512 // Performs field multiplication in GF2^128 defined by the irreducible
513 // x^128 + x^7 + x^2 + x + 1. This routine was generated by a sage script that
514 // computes a sparse matrix-vector mult via the powers of x^k mod p(x).
515 //
516 // def make_mulmod(F, n):
517 // r = F(1)
518 // gen = F.gen()
519 // nl = [[] for _ in range(n)]
520 // terms = 0
521 // for i in range(0, 2*n-1):
522 // for j, var in enumerate(r.polynomial().list()):
523 // if var == 1:
524 // nl[j].append(i)
525 // r = r * gen
526 // print(nl)
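  //
  // Each taps[i] below lists the coefficients of the raw 2n-1 term product
  // that contribute to output bit i after reduction mod p(x); gf2k_mul()
  // xors those coefficients together.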
527 void gf2_128_mul(v128& c, const v128 a, const v128 b) const {
528 const std::vector<uint16_t> taps[128] = {
529 {0, 128, 249, 254},
530 {1, 128, 129, 249, 250, 254},
531 {2, 128, 129, 130, 249, 250, 251, 254},
532 {3, 129, 130, 131, 250, 251, 252},
533 {4, 130, 131, 132, 251, 252, 253},
534 {5, 131, 132, 133, 252, 253, 254},
535 {6, 132, 133, 134, 253, 254},
536 {7, 128, 133, 134, 135, 249},
537 {8, 129, 134, 135, 136, 250},
538 {9, 130, 135, 136, 137, 251},
539 {10, 131, 136, 137, 138, 252},
540 {11, 132, 137, 138, 139, 253},
541 {12, 133, 138, 139, 140, 254},
542 {13, 134, 139, 140, 141},
543 {14, 135, 140, 141, 142},
544 {15, 136, 141, 142, 143},
545 {16, 137, 142, 143, 144},
546 {17, 138, 143, 144, 145},
547 {18, 139, 144, 145, 146},
548 {19, 140, 145, 146, 147},
549 {20, 141, 146, 147, 148},
550 {21, 142, 147, 148, 149},
551 {22, 143, 148, 149, 150},
552 {23, 144, 149, 150, 151},
553 {24, 145, 150, 151, 152},
554 {25, 146, 151, 152, 153},
555 {26, 147, 152, 153, 154},
556 {27, 148, 153, 154, 155},
557 {28, 149, 154, 155, 156},
558 {29, 150, 155, 156, 157},
559 {30, 151, 156, 157, 158},
560 {31, 152, 157, 158, 159},
561 {32, 153, 158, 159, 160},
562 {33, 154, 159, 160, 161},
563 {34, 155, 160, 161, 162},
564 {35, 156, 161, 162, 163},
565 {36, 157, 162, 163, 164},
566 {37, 158, 163, 164, 165},
567 {38, 159, 164, 165, 166},
568 {39, 160, 165, 166, 167},
569 {40, 161, 166, 167, 168},
570 {41, 162, 167, 168, 169},
571 {42, 163, 168, 169, 170},
572 {43, 164, 169, 170, 171},
573 {44, 165, 170, 171, 172},
574 {45, 166, 171, 172, 173},
575 {46, 167, 172, 173, 174},
576 {47, 168, 173, 174, 175},
577 {48, 169, 174, 175, 176},
578 {49, 170, 175, 176, 177},
579 {50, 171, 176, 177, 178},
580 {51, 172, 177, 178, 179},
581 {52, 173, 178, 179, 180},
582 {53, 174, 179, 180, 181},
583 {54, 175, 180, 181, 182},
584 {55, 176, 181, 182, 183},
585 {56, 177, 182, 183, 184},
586 {57, 178, 183, 184, 185},
587 {58, 179, 184, 185, 186},
588 {59, 180, 185, 186, 187},
589 {60, 181, 186, 187, 188},
590 {61, 182, 187, 188, 189},
591 {62, 183, 188, 189, 190},
592 {63, 184, 189, 190, 191},
593 {64, 185, 190, 191, 192},
594 {65, 186, 191, 192, 193},
595 {66, 187, 192, 193, 194},
596 {67, 188, 193, 194, 195},
597 {68, 189, 194, 195, 196},
598 {69, 190, 195, 196, 197},
599 {70, 191, 196, 197, 198},
600 {71, 192, 197, 198, 199},
601 {72, 193, 198, 199, 200},
602 {73, 194, 199, 200, 201},
603 {74, 195, 200, 201, 202},
604 {75, 196, 201, 202, 203},
605 {76, 197, 202, 203, 204},
606 {77, 198, 203, 204, 205},
607 {78, 199, 204, 205, 206},
608 {79, 200, 205, 206, 207},
609 {80, 201, 206, 207, 208},
610 {81, 202, 207, 208, 209},
611 {82, 203, 208, 209, 210},
612 {83, 204, 209, 210, 211},
613 {84, 205, 210, 211, 212},
614 {85, 206, 211, 212, 213},
615 {86, 207, 212, 213, 214},
616 {87, 208, 213, 214, 215},
617 {88, 209, 214, 215, 216},
618 {89, 210, 215, 216, 217},
619 {90, 211, 216, 217, 218},
620 {91, 212, 217, 218, 219},
621 {92, 213, 218, 219, 220},
622 {93, 214, 219, 220, 221},
623 {94, 215, 220, 221, 222},
624 {95, 216, 221, 222, 223},
625 {96, 217, 222, 223, 224},
626 {97, 218, 223, 224, 225},
627 {98, 219, 224, 225, 226},
628 {99, 220, 225, 226, 227},
629 {100, 221, 226, 227, 228},
630 {101, 222, 227, 228, 229},
631 {102, 223, 228, 229, 230},
632 {103, 224, 229, 230, 231},
633 {104, 225, 230, 231, 232},
634 {105, 226, 231, 232, 233},
635 {106, 227, 232, 233, 234},
636 {107, 228, 233, 234, 235},
637 {108, 229, 234, 235, 236},
638 {109, 230, 235, 236, 237},
639 {110, 231, 236, 237, 238},
640 {111, 232, 237, 238, 239},
641 {112, 233, 238, 239, 240},
642 {113, 234, 239, 240, 241},
643 {114, 235, 240, 241, 242},
644 {115, 236, 241, 242, 243},
645 {116, 237, 242, 243, 244},
646 {117, 238, 243, 244, 245},
647 {118, 239, 244, 245, 246},
648 {119, 240, 245, 246, 247},
649 {120, 241, 246, 247, 248},
650 {121, 242, 247, 248, 249},
651 {122, 243, 248, 249, 250},
652 {123, 244, 249, 250, 251},
653 {124, 245, 250, 251, 252},
654 {125, 246, 251, 252, 253},
655 {126, 247, 252, 253, 254},
656 {127, 248, 253, 254},
657 };
658 gf2k_mul(c.data(), a.data(), b.data(), taps, 128);
659 }
660
661 // Performs field multiplication in GF2^k using a sparse-matrix data structure.
662 void gf2k_mul(BitW c[/*w*/], const BitW a[/*w*/], const BitW b[/*w*/],
663 const std::vector<uint16_t> M[], size_t w) const {
664 std::vector<BitW> t(w * 2);
665 gf2_polynomial_multiplier_karat(w, t.data(), a, b);
666
667 std::vector<BitW> tmp(w);
668 for (size_t i = 0; i < w; ++i) {
669 size_t n = 0;
670 for (auto ti : M[i]) {
671 tmp[n++] = t[ti];
672 }
673 c[i] = parity(0, n, tmp.data());
674 }
675 }
676
677 // a == 0
678 BitW eq0(size_t w, const BitW a[/*w*/]) const { return eq0(0, w, a); }
679
680 // a == b
681 BitW eq(size_t w, const BitW a[/*w*/], const BitW b[/*w*/]) const {
682 return eq_reduce(0, w, a, b);
683 }
684
685 // a < b.
686 // Specialization of the subtractor for the case (a - b) < 0
687 BitW lt(size_t w, const BitW a[/*w*/], const BitW b[/*w*/]) const {
688 if (w == 0) {
689 return bit(0);
690 } else {
691 BitW xeq, xlt;
692 lt_reduce(0, w, &xeq, &xlt, a, b);
693 return xlt;
694 }
695 }
696
697 // a <= b
698 BitW leq(size_t w, const BitW a[/*w*/], const BitW b[/*w*/]) const {
699 auto blt = lt(w, b, a);
700 return lnot(blt);
701 }
702
703 // Parallel prefix of various kinds
704 template <class T>
705 void scan(const std::function<void(T*, const T&, const T&)>& op, T x[],
706 size_t i0, size_t i1, bool backward = false) const {
707 // generic Sklansky scan
708 if (i1 - i0 > 1) {
709 size_t im = i0 + (i1 - i0) / 2;
710 scan(op, x, i0, im, backward);
711 scan(op, x, im, i1, backward);
712 if (backward) {
713 for (size_t i = i0; i < im; ++i) {
714 op(&x[i], x[i], x[im]);
715 }
716 } else {
717 for (size_t i = im; i < i1; ++i) {
718 op(&x[i], x[im - 1], x[i]);
719 }
720 }
721 }
722 }
723
724 void scan_and(BitW x[], size_t i0, size_t i1, bool backward = false) const {
725 scan<BitW>(
726 [&](BitW* out, const BitW& l, const BitW& r) { *out = land(&l, r); }, x,
727 i0, i1, backward);
728 }
729
730 void scan_or(BitW x[], size_t i0, size_t i1, bool backward = false) const {
731 scan<BitW>(
732 [&](BitW* out, const BitW& l, const BitW& r) { *out = lor(&l, r); }, x,
733 i0, i1, backward);
734 }
735
736 void scan_xor(BitW x[], size_t i0, size_t i1, bool backward = false) const {
737 scan<BitW>(
738 [&](BitW* out, const BitW& l, const BitW& r) { *out = lxor(&l, r); }, x,
739 i0, i1, backward);
740 }
741
742 template <size_t I0, size_t I1, size_t N>
743 bitvec<I1 - I0> slice(const bitvec<N>& a) const {
744 bitvec<I1 - I0> r;
745 for (size_t i = I0; i < I1; ++i) {
746 r[i - I0] = a[i];
747 }
748 return r;
749 }
750
751 // Little-endian append of A and B. A[0] is the LSB, B starts at
752 // position [NA].
753 template <size_t NA, size_t NB>
754 bitvec<NA + NB> vappend(const bitvec<NA>& a, const bitvec<NB>& b) const {
755 bitvec<NA + NB> r;
756 for (size_t i = 0; i < NA; ++i) {
757 r[i] = a[i];
758 }
759 for (size_t i = 0; i < NB; ++i) {
760 r[i + NA] = b[i];
761 }
762 return r;
763 }
764
765 template <size_t N>
766 bool vequal(const bitvec<N>* a, const bitvec<N>& b) const {
767 for (size_t i = 0; i < N; ++i) {
768 auto eai = eval((*a)[i]);
769 auto ebi = eval(b[i]);
770 if (eai != ebi) return false;
771 }
772 return true;
773 }
774
775 template <size_t N>
776 bitvec<N> vbit(uint64_t x) const {
777 bitvec<N> r;
778 bits(N, r.data(), x);
779 return r;
780 }
781
782 // shorthands for the silly "template" notation
783 v8 vbit8(uint64_t x) const { return vbit<8>(x); }
784 v32 vbit32(uint64_t x) const { return vbit<32>(x); }
785
786 template <size_t N>
787 bitvec<N> vnot(const bitvec<N>& x) const {
788 bitvec<N> r;
789 for (size_t i = 0; i < N; ++i) {
790 r[i] = lnot(x[i]);
791 }
792 return r;
793 }
794
795 template <size_t N>
796 bitvec<N> vand(const bitvec<N>* a, const bitvec<N>& b) const {
797 bitvec<N> r;
798 for (size_t i = 0; i < N; ++i) {
799 r[i] = land(&(*a)[i], b[i]);
800 }
801 return r;
802 }
803
804 template <size_t N>
805 bitvec<N> vand(const BitW* a, const bitvec<N>& b) const {
806 bitvec<N> r;
807 for (size_t i = 0; i < N; ++i) {
808 r[i] = land(a, b[i]);
809 }
810 return r;
811 }
812
813 template <size_t N>
814 bitvec<N> vor(const bitvec<N>* a, const bitvec<N>& b) const {
815 bitvec<N> r;
816 for (size_t i = 0; i < N; ++i) {
817 r[i] = lor(&(*a)[i], b[i]);
818 }
819 return r;
820 }
821 template <size_t N>
822 bitvec<N> vor_exclusive(const bitvec<N>* a, const bitvec<N>& b) const {
823 bitvec<N> r;
824 for (size_t i = 0; i < N; ++i) {
825 r[i] = lor_exclusive(&(*a)[i], b[i]);
826 }
827 return r;
828 }
829 template <size_t N>
830 bitvec<N> vxor(const bitvec<N>* a, const bitvec<N>& b) const {
831 bitvec<N> r;
832 for (size_t i = 0; i < N; ++i) {
833 r[i] = lxor(&(*a)[i], b[i]);
834 }
835 return r;
836 }
837
838 template <size_t N>
839 bitvec<N> vCh(const bitvec<N>* x, const bitvec<N>* y,
840 const bitvec<N>& z) const {
841 bitvec<N> r;
842 for (size_t i = 0; i < N; ++i) {
843 r[i] = lCh(&(*x)[i], &(*y)[i], z[i]);
844 }
845 return r;
846 }
847 template <size_t N>
848 bitvec<N> vMaj(const bitvec<N>* x, const bitvec<N>* y,
849 const bitvec<N>& z) const {
850 bitvec<N> r;
851 for (size_t i = 0; i < N; ++i) {
852 r[i] = lMaj(&(*x)[i], &(*y)[i], z[i]);
853 }
854 return r;
855 }
856
857 template <size_t N>
858 bitvec<N> vxor3(const bitvec<N>* x, const bitvec<N>* y,
859 const bitvec<N>& z) const {
860 bitvec<N> r;
861 for (size_t i = 0; i < N; ++i) {
862 r[i] = lxor3(&(*x)[i], &(*y)[i], z[i]);
863 }
864 return r;
865 }
866
867 template <size_t N>
868 bitvec<N> vshr(const bitvec<N>& a, size_t shift, size_t b = 0) const {
869 bitvec<N> r;
870 for (size_t i = 0; i < N; ++i) {
871 if (i + shift < N) {
872 r[i] = a[i + shift];
873 } else {
874 r[i] = bit(b);
875 }
876 }
877 return r;
878 }
879
880 template <size_t N>
881 bitvec<N> vshl(const bitvec<N>& a, size_t shift, size_t b = 0) const {
882 bitvec<N> r;
883 for (size_t i = 0; i < N; ++i) {
884 if (i >= shift) {
885 r[i] = a[i - shift];
886 } else {
887 r[i] = bit(b);
888 }
889 }
890 return r;
891 }
892
893 template <size_t N>
894 bitvec<N> vrotr(const bitvec<N>& a, size_t b) const {
895 bitvec<N> r;
896 for (size_t i = 0; i < N; ++i) {
897 r[i] = a[(i + b) % N];
898 }
899 return r;
900 }
901
902 template <size_t N>
903 bitvec<N> vrotl(const bitvec<N>& a, size_t b) const {
904 bitvec<N> r;
905 for (size_t i = 0; i < N; ++i) {
906 r[(i + b) % N] = a[i];
907 }
908 return r;
909 }
910
911 template <size_t N>
912 bitvec<N> vadd(const bitvec<N>& a, const bitvec<N>& b) const {
913 bitvec<N> r;
914 (void)parallel_prefix_add(N, &r[0], &a[0], &b[0]);
915 return r;
916 }
917 template <size_t N>
918 bitvec<N> vadd(const bitvec<N>& a, uint64_t val) const {
919 return vadd(a, vbit<N>(val));
920 }
921
922 template <size_t N>
923 BitW veq(const bitvec<N>& a, const bitvec<N>& b) const {
924 return eq(N, a.data(), b.data());
925 }
926 template <size_t N>
927 BitW veq(const bitvec<N>& a, uint64_t val) const {
928 auto v = vbit<N>(val);
929 return veq(a, v);
930 }
931 template <size_t N>
932 BitW vlt(const bitvec<N>* a, const bitvec<N>& b) const {
933 return lt(N, (*a).data(), b.data());
934 }
935 template <size_t N>
936 BitW vlt(const bitvec<N>& a, uint64_t val) const {
937 auto v = vbit<N>(val);
938 return vlt(&a, v);
939 }
940 template <size_t N>
941 BitW vlt(uint64_t a, const bitvec<N>& b) const {
942 auto va = vbit<N>(a);
943 return vlt(&va, b);
944 }
945 template <size_t N>
946 BitW vleq(const bitvec<N>* a, const bitvec<N>& b) const {
947 return leq(N, (*a).data(), b.data());
948 }
949 template <size_t N>
950 BitW vleq(const bitvec<N>& a, uint64_t val) const {
951 auto v = vbit<N>(val);
952 return vleq(&a, v);
953 }
954
955 // (a ^ val) & mask == 0
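  // For example, veqmask(&a, 0xFF, val) is TRUE iff the low 8 bits of a
  // equal the low 8 bits of val.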
956 template <size_t N>
957 BitW veqmask(const bitvec<N>* a, uint64_t mask, const bitvec<N>& val) const {
958 auto r = vxor(a, val);
959 size_t n = pack(mask, N, &r[0]);
960 return eq0(0, n, &r[0]);
961 }
962
963 template <size_t N>
964 BitW veqmask(const bitvec<N>& a, uint64_t mask, uint64_t val) const {
965 auto v = vbit<N>(val);
966 return veqmask(&a, mask, v);
967 }
968
969 // I/O. This is a hack which only works if the backend supports
970 // bk_->{input,output}. Because C++ templates are lazily expanded,
971 // this class compiles even with backends that do not support I/O,
972 // as long as you don't expand vinput(), voutput().
973 BitW input() const { return BitW(bk_->input(), f_); }
974 void output(const BitW& x, size_t i) const { bk_->output(eval(x), i); }
975 size_t wire_id(const BitW& v) const { return bk_->wire_id(v.x); }
976 size_t wire_id(const EltW& x) const { return bk_->wire_id(x); }
977
978 template <size_t N>
979 bitvec<N> vinput() const {
980 bitvec<N> r;
981 for (size_t i = 0; i < N; ++i) {
982 r[i] = input();
983 }
984 return r;
985 }
986
987 template <size_t N>
988 void voutput(const bitvec<N>& x, size_t i0) const {
989 for (size_t i = 0; i < N; ++i) {
990 output(x[i], i + i0);
991 }
992 }
993
994 template <size_t N>
995 void vassert0(const bitvec<N>& x) const {
996 for (size_t i = 0; i < N; ++i) {
997 (void)assert0(x[i]);
998 }
999 }
1000
1001 template <size_t N>
1002 void vassert_eq(const bitvec<N>* x, const bitvec<N>& y) const {
1003 for (size_t i = 0; i < N; ++i) {
1004 (void)assert_eq(&(*x)[i], y[i]);
1005 }
1006 }
1007
1008 template <size_t N>
1009 void vassert_eq(const bitvec<N>& x, uint64_t y) const {
1010 auto v = vbit<N>(y);
1011 vassert_eq(&x, v);
1012 }
1013
1014 template <size_t N>
1015 void vassert_is_bit(const bitvec<N>& a) const {
1016 for (size_t i = 0; i < N; ++i) {
1017 (void)assert_is_bit(a[i]);
1018 }
1019 }
1020
1021 private:
1022 // return one quad gate for the product eval(a)*eval(b),
1023 // optimizing some "obvious" cases.
1024 BitW mulv(const BitW* a, const BitW& b) const {
1025 if (a->c1 == zero()) {
1026 return rebase(zero(), a->c0, b);
1027 } else if (b.c1 == zero()) {
1028 return mulv(&b, *a);
1029 } else {
1030 // Avoid creating the intermediate term 1 * a.x * b.x, which is
1031 // likely a useless node. Moreover, if we created two nodes
1032 // (k1 * a.x * b.x) and (k2 * a.x * b.x), the common subexpression
1033 // (a.x * b.x) would be detected and would confusingly increment the
1034 // common-subexpression counter.
1035 EltW x = axy(mulf(a->c1, b.c1), &a->x, b.x);
1036 x = axpy(&x, mulf(a->c0, b.c1), b.x);
1037 x = axpy(&x, mulf(a->c1, b.c0), a->x);
1038 x = apy(x, mulf(a->c0, b.c0));
1039 return BitW(x, f_);
1040 }
1041 }
1042
1043 BitW addv(const BitW& a, const BitW& b) const {
1044 if (a.c1 == zero()) {
1045 return BitW(addf(a.c0, b.c0), b.c1, b.x);
1046 } else if (b.c1 == zero()) {
1047 return addv(b, a);
1048 } else {
1049 EltW x = ax(a.c1, a.x);
1050 auto axb = ax(b.c1, b.x);
1051 x = add(&x, axb);
1052 x = apy(x, addf(a.c0, b.c0));
1053 return BitW(x, f_);
1054 }
1055 }
1056
1057 BitW lxor_aux(const BitW& a, const BitW& b, PrimeFieldTypeTag tt) const {
1058 // a * b in the xor basis TRUE -> -1, FALSE -> 1
1059 // map a, b from standard basis to xor basis
1060 Elt mtwo = f_.negf(f_.two());
1061 Elt half = f_.half();
1062 Elt mhalf = f_.negf(half);
1063
1064 BitW a1 = rebase(one(), mtwo, a);
1065 BitW b1 = rebase(one(), mtwo, b);
1066 BitW p = mulv(&a1, b1);
1067 return rebase(half, mhalf, p);
1068 }
1069 BitW lxor_aux(const BitW& a, const BitW& b, BinaryFieldTypeTag tt) const {
1070 return addv(a, b);
1071 }
1072
1073
1074 size_t pack(uint64_t mask, size_t n, BitW a[/*n*/]) const {
1075 size_t j = 0;
1076 for (size_t i = 0; i < n; ++i) {
1077 if (mask & 1) {
1078 a[j++] = a[i];
1079 }
1080 mask >>= 1;
1081 }
1082 return j;
1083 }
1084
1085 // carry-propagation equations
1086 // (g0, p0) + (g1, p1) = (g1 | (g0 & p1), p0 & p1)
1087 // Accumulate in-place into (g1, p1).
1088 //
1089 // We use the property that g1 and p1 are mutually exclusive (g1&p1
1090 // is false), and therefore g1 and (g0 & p1) are also mutually
1091 // exclusive.
1092 void gp_reduce(const BitW& g0, const BitW& p0, BitW* g1, BitW* p1) const {
1093 auto g0p1 = land(&g0, *p1);
1094 *g1 = lor_exclusive(g1, g0p1);
1095 *p1 = land(&p0, *p1);
1096 }
1097
1098 // ripple carry propagation
1099 void ripple_scan(std::vector<BitW>& g, std::vector<BitW>& p, size_t i0,
1100 size_t i1) const {
1101 for (size_t i = i0 + 1; i < i1; ++i) {
1102 gp_reduce(g[i - 1], p[i - 1], &g[i], &p[i]);
1103 }
1104 }
1105
1106 // parallel-prefix carry propagation, Sklansky-style [1960]
1107 void sklansky_scan(std::vector<BitW>& g, std::vector<BitW>& p, size_t i0,
1108 size_t i1) const {
1109 if (i1 - i0 > 1) {
1110 size_t im = i0 + (i1 - i0) / 2;
1111 sklansky_scan(g, p, i0, im);
1112 sklansky_scan(g, p, im, i1);
1113 for (size_t i = im; i < i1; ++i) {
1114 gp_reduce(g[im - 1], p[im - 1], &g[i], &p[i]);
1115 }
1116 }
1117 }
1118
1119 // generic add in generate/propagate form, parametrized
1120 // by the scan primitive.
1121 //
1122 // (carry, c) = a + b
1123 BitW generic_gp_add(size_t w, BitW c[/*w*/], const BitW a[/*w*/],
1124 const BitW b[/*w*/],
1125 void (Logic::*scan)(std::vector<BitW>& /*g*/,
1126 std::vector<BitW>& /*p*/,
1127 size_t /*i0*/, size_t /*i1*/)
1128 const) const {
1129 if (w == 0) {
1130 return bit(0);
1131 } else {
1132 std::vector<BitW> g(w), p(w);
1133 for (size_t i = 0; i < w; ++i) {
1134 g[i] = land(&a[i], b[i]);
1135 p[i] = lxor(&a[i], b[i]);
1136 c[i] = p[i];
1137 }
1138 (this->*scan)(g, p, 0, w);
1139 for (size_t i = 1; i < w; ++i) {
1140 c[i] = lxor(&c[i], g[i - 1]);
1141 }
1142 return g[w - 1];
1143 }
1144 }
1145
1146 BitW generic_gp_sub(size_t w, BitW c[/*w*/], const BitW a[/*w*/],
1147 const BitW b[/*w*/],
1148 void (Logic::*scan)(std::vector<BitW>& /*g*/,
1149 std::vector<BitW>& /*p*/,
1150 size_t /*i0*/, size_t /*i1*/)
1151 const) const {
1152 // implement as ~(~a + b)
1153 std::vector<BitW> t(w);
1154 for (size_t j = 0; j < w; ++j) {
1155 t[j] = lnot(a[j]);
1156 }
1157 BitW carry = generic_gp_add(w, c, t.data(), b, scan);
1158 for (size_t j = 0; j < w; ++j) {
1159 c[j] = lnot(c[j]);
1160 }
1161 return carry;
1162 }
1163
1164 // Recursion for the a < b comparison.
1165 // Let a = (a1, a0) and b = (b1, b0). Then:
1166 //
1167 // a == b iff a1 == b1 && a0 == b0
1168 // a < b iff a1 < b1 || (a1 == b1 && a0 < b0)
1169 void lt_reduce(size_t i0, size_t i1, BitW* xeq, BitW* xlt,
1170 const BitW a[/*w*/], const BitW b[/*w*/]) const {
1171 if (i1 - i0 > 1) {
1172 BitW eq0, eq1, lt0, lt1;
1173 size_t im = i0 + (i1 - i0) / 2;
1174 lt_reduce(i0, im, &eq0, &lt0, a, b);
1175 lt_reduce(im, i1, &eq1, &lt1, a, b);
1176 *xeq = land(&eq1, eq0);
1177 auto lt0_and_eq1 = land(&eq1, lt0);
1178 *xlt = lor_exclusive(&lt1, lt0_and_eq1);
1179 } else {
1180 auto axb = lxor(&a[i0], b[i0]);
1181 *xeq = lnot(axb);
1182 auto na = lnot(a[i0]);
1183 *xlt = land(&na, b[i0]);
1184 }
1185 }
1186
1187 BitW parity(size_t i0, size_t i1, const BitW a[]) const {
1188 if (i1 <= i0) {
1189 return bit(0);
1190 } else if (i1 == i0 + 1) {
1191 return a[i0];
1192 } else {
1193 size_t im = i0 + (i1 - i0) / 2;
1194 auto lp = parity(i0, im, a);
1195 auto rp = parity(im, i1, a);
1196 return lxor(&lp, rp);
1197 }
1198 }
1199
1200 BitW eq0(size_t i0, size_t i1, const BitW a[]) const {
1201 if (i1 <= i0) {
1202 return bit(1);
1203 } else if (i1 == i0 + 1) {
1204 return lnot(a[i0]);
1205 } else {
1206 size_t im = i0 + (i1 - i0) / 2;
1207 auto le = eq0(i0, im, a);
1208 auto re = eq0(im, i1, a);
1209 return land(&le, re);
1210 }
1211 }
1212
1213 BitW eq_reduce(size_t i0, size_t i1, const BitW a[], const BitW b[]) const {
1214 if (i1 <= i0) {
1215 return bit(1);
1216 } else if (i1 == i0 + 1) {
1217 return lnot(lxor(&a[i0], b[i0]));
1218 } else {
1219 size_t im = i0 + (i1 - i0) / 2;
1220 auto le = eq_reduce(i0, im, a, b);
1221 auto re = eq_reduce(im, i1, a, b);
1222 return land(&le, re);
1223 }
1224 }
1225
1226 const Backend* bk_;
1227};
1228
1229} // namespace proofs
1230
1231#endif // PRIVACY_PROOFS_ZK_LIB_CIRCUITS_LOGIC_LOGIC_H_