Longfellow ZK 0290cb32
logic.h
1// Copyright 2025 Google LLC.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#ifndef PRIVACY_PROOFS_ZK_LIB_CIRCUITS_LOGIC_LOGIC_H_
16#define PRIVACY_PROOFS_ZK_LIB_CIRCUITS_LOGIC_LOGIC_H_
17
18#include <stddef.h>
19
20#include <array>
21#include <cstdint>
22#include <functional>
23#include <vector>
24
25#include "algebra/fp_generic.h"
26#include "gf2k/gf2_128.h"
27#include "util/panic.h"
28
29namespace proofs {
30/*
31 Arithmetization of boolean logic in a field.
32 This class builds logical and arithmetic operations such as add, sub, mul,
33 and, or, xor, etc. over bits in the arithmetic circuit model.
34 The class utilizes several optimizations, including changing from the {0,1}
35 basis to the {-1,1} basis for representing bits.
36 */
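/*
 Example (sketch): given a concrete field F and a backend object `bk` that
 provides the operations re-exported below (konst, add, mul, assert0, ...),
 one might drive the class as follows. `Field`, `Backend`, and `bk` here
 stand for whatever types and objects the caller instantiates this template
 with.

   Logic<Field, Backend> L(&bk, F);
   auto t = L.bit(1);        // compile-time TRUE
   auto u = L.bit(0);        // compile-time FALSE
   auto v = L.land(&t, u);   // AND gate
   (void)L.assert0(v);       // constrain the result to be zero
*/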
37template <typename Field_, class Backend>
38class Logic {
39 public:
40 using Field = Field_; /* this class exports Field, Elt, and EltW */
41 using Elt = typename Field::Elt;
42 // an "Elt Wire", a wire carrying an Elt.
43 using EltW = typename Backend::V;
44
45 const Field& f_;
46
47 explicit Logic(const Backend* bk, const Field& F) : f_(F), bk_(bk) {}
48
49 //------------------------------------------------------------
50 // Arithmetic.
51
52 //
53 // Re-export field operations
54 Elt addf(const Elt& a, const Elt& b) const { return f_.addf(a, b); }
55 Elt mulf(const Elt& a, const Elt& b) const { return f_.mulf(a, b); }
56 Elt invertf(const Elt& a) const { return f_.invertf(a); }
57 Elt negf(const Elt& a) const { return f_.negf(a); }
58 Elt zero() const { return f_.zero(); }
59 Elt one() const { return f_.one(); }
60 Elt mone() const { return f_.mone(); }
61 Elt elt(uint64_t a) const { return f_.of_scalar(a); }
62
63 template <size_t N>
64 Elt elt(const char (&s)[N]) const {
65 return f_.of_string(s);
66 }
67
68 // To ensure deterministic behavior, the order of function calls that produce
69 // circuit wires must be well-defined at compile time.
70 // The C++ spec leaves the evaluation order of certain expressions unspecified.
71 // One such ambiguity arises in the order of function calls in an argument
72 // list. For example, the expression f(creates_wire(x), creates_wire(y))
73 // results in an ambiguous order.
74 // To help prevent this, all function calls that create wires can have at most
75 // one argument that is itself a function call. To enforce this property, we
76 // require that all but the last argument to a function be a const pointer.
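// For example, rather than nesting two wire-creating calls inside one
// argument list, the intended pattern is to name each intermediate wire:
//
//   EltW ab = mul(&a, b);    // first wire, created here
//   EltW cd = mul(&c, d);    // second wire, created next
//   EltW r  = add(&ab, cd);  // only the last argument is passed by value
//
// (Sketch; a, b, c, d stand for arbitrary EltW values.)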
77
78 // Re-export backend operations
79 EltW assert0(const EltW& a) const { return bk_->assert0(a); }
80 EltW add(const EltW* a, const EltW& b) const { return bk_->add(*a, b); }
81
82 EltW sub(const EltW* a, const EltW& b) const { return bk_->sub(*a, b); }
83
84 EltW mul(const EltW* a, const EltW& b) const { return bk_->mul(*a, b); }
85 EltW mul(const Elt& k, const EltW& b) const { return bk_->mul(k, b); }
86 EltW mul(const Elt& k, const EltW* a, const EltW& b) const {
87 return bk_->mul(k, a, b);
88 }
89
90 EltW ax(const Elt& a, const EltW& x) const { return bk_->ax(a, x); }
91 EltW axy(const Elt& a, const EltW* x, const EltW& y) const {
92 return bk_->axy(a, *x, y);
93 }
94 EltW axpy(const EltW* y, const Elt& a, const EltW& x) const {
95 return bk_->axpy(*y, a, x);
96 }
97 EltW apy(const EltW& y, const Elt& a) const { return bk_->apy(y, a); }
98
99 EltW konst(const Elt& a) const { return bk_->konst(a); }
100 EltW konst(uint64_t a) const { return konst(elt(a)); }
101
102 template <size_t N>
103 std::array<EltW, N> konst(const std::array<Elt, N>& a) const {
104 std::array<EltW, N> r;
105 for (size_t i = 0; i < N; ++i) {
106 r[i] = konst(a[i]);
107 }
108 return r;
109 }
110
111 //------------------------------------------------------------
112 // Boolean logic.
113 //
114 // We map TRUE to one() and FALSE to zero(). We call this convention
115 // the "standard basis".
116 //
117 // However, actual values on wires may use different conventions,
118 // e.g. -1 for TRUE and 1 for FALSE. To keep track of these changes,
119 // we represent boolean values as (c0, c1, x) where c0+c1*x is
120 // the value in the standard basis. c0 and c1
121 // are compile-time constants that can be manipulated, and
122 // x is a runtime value not known in advance.
123 //
124 // For example, let the "xor basis" denote the mapping FALSE -> 1,
125 // TRUE -> -1. In the xor basis, xor(a,b)=a*b. The output of the
126 // xor gate in the standard basis would be represented as 1/2 + (-1/2)*x
127 // where x=a*b is the wire value in the xor basis.
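// As a concrete check of the example above: for a = TRUE and b = FALSE the
// xor-basis values are a = -1 and b = 1, so x = a*b = -1 and
// 1/2 + (-1/2)*(-1) = 1, i.e. TRUE, which is indeed xor(TRUE, FALSE).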
128
129 // a "bit Wire", a wire carrying a bit
130 struct BitW {
131 Elt c0, c1;
132 EltW x;
133 BitW() = default;
134
135 // constructor in the standard basis
136 explicit BitW(const EltW& bv, const Field& F)
137 : BitW(F.zero(), F.one(), bv) {}
138
139 BitW(Elt c0_, Elt c1_, const EltW& x_) : c0(c0_), c1(c1_), x(x_) {}
140 };
141
142 // vectors of N bits
143 template <size_t N>
144 class bitvec : public std::array<BitW, N> {};
145
146 // Common sizes, publicly exported for convenience. The type names are
147 // intentionally lower-case to capture the spirit of basic "intx_t" types.
148 using v1 = bitvec<1>;
149 using v4 = bitvec<4>;
150 using v8 = bitvec<8>;
151 using v16 = bitvec<16>;
152 using v32 = bitvec<32>;
153 using v64 = bitvec<64>;
154 using v128 = bitvec<128>;
155 using v129 = bitvec<129>;
156 using v256 = bitvec<256>;
157
158 // Let v(x)=c0+c1*x. Return a representation of
159 // d0+d1*v(x)=(d0+d1*c0)+(d1*c1)*x without changing x.
160 // Does not involve the backend at all.
161 BitW rebase(const Elt& d0, const Elt& d1, const BitW& v) const {
162 return BitW(addf(d0, mulf(d1, v.c0)), mulf(d1, v.c1), v.x);
163 }
164
165 EltW eval(const BitW& v) const {
166 EltW r = ax(v.c1, v.x);
167 if (v.c0 != zero()) {
168 auto c0 = konst(v.c0);
169 r = add(&c0, r);
170 }
171 return r;
172 }
173
174 // compute in the circuit what F.of_scalar(sum_i v[i] 2^i) would compute
175 // outside the circuit
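// For example, with N = 4 and v = (1, 0, 1, 1) in little-endian order, the
// result is the wire value F.of_scalar(1 + 4 + 8) = F.of_scalar(13).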
176 template <size_t N>
177 EltW as_scalar(const bitvec<N>& v) const {
178 EltW r = konst(zero());
179 for (size_t i = 0; i < N; ++i) {
180 auto vi = eval(v[i]);
181 r = axpy(&r, f_.beta(i), vi);
182 }
183 return r;
184 }
185
186 // return an EltW which is 0 iff v is 0
187 EltW assert0(const BitW& v) const {
188 auto e = eval(v);
189 return assert0(e);
190 }
191 // return an EltW which is 0 iff v is 1
192 EltW assert1(const BitW& v) const {
193 auto e = lnot(v);
194 return assert0(e);
195 }
196
197 // 0 iff a==b
198 EltW assert_eq(const EltW* a, const EltW& b) const {
199 return assert0(sub(a, b));
200 }
201 EltW assert_eq(const BitW* a, const BitW& b) const {
202 return assert0(lxor(a, b));
203 }
204 EltW assert_implies(const BitW* a, const BitW& b) const {
205 return assert1(limplies(a, b));
206 }
207
208 // special test for asserting that b \in {0,1} (i.e.,
209 // not some other field element).
210 EltW assert_is_bit(const BitW& b) const {
211 // b - b*b
212 // Seems to work better than b*(1-b)
213 // Equivalent to land(b,lnot(b)) but does not rely
214 // on the specific arithmetization.
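// For example, over a prime field with p > 2, b = 2 gives 2 - 2*2 = -2 != 0,
// so only b = 0 and b = 1 satisfy the assertion.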
215 auto eb = eval(b);
216 return assert_is_bit(eb);
217 }
218 EltW assert_is_bit(const EltW& v) const {
219 auto vvmv = sub(&v, mul(&v, v));
220 return assert0(vvmv);
221 }
222
223 // Constant bits, represented in their own basis as b + 0*1 to allow for
224 // some compile-time constant folding.
225 BitW bit(size_t b) const {
226 return BitW((b == 0) ? zero() : one(), zero(), konst(one()));
227 }
228
229 void bits(size_t n, BitW a[/*n*/], uint64_t x) const {
230 for (size_t i = 0; i < n; ++i) {
231 a[i] = bit((x >> i) & 1u);
232 }
233 }
234
235 // gates
236 BitW lnot(const BitW& x) const {
237 // lnot() is a pure representation change that does not
238 // involve actual circuit gates
239
240 // 1 - x in the standard basis
241 return rebase(one(), mone(), x);
242 }
243
244 BitW land(const BitW* a, const BitW& b) const {
245 // a * b in the standard basis
246 return mulv(a, b);
247 }
248
249 // special case of product of a logic value by a field
250 // element
251 EltW lmul(const BitW* a, const EltW& b) const {
252 // a * b in the standard basis
253 auto ab = mulv(a, BitW(b, f_));
254 return eval(ab);
255 }
256 EltW lmul(const EltW* b, const BitW& a) const { return lmul(&a, *b); }
257
258 BitW lxor(const BitW* a, const BitW& b) const {
259 return lxor_aux(*a, b, typename Field::TypeTag());
260 }
261 BitW lxor(const BitW* a, const BitW* b) const {
262 return lxor_aux(*a, *b, typename Field::TypeTag());
263 }
264
265 BitW lor(const BitW* a, const BitW& b) const {
266 auto na = lnot(*a);
267 auto nab = land(&na, lnot(b));
268 return lnot(nab);
269 }
270
271 // a => b
272 BitW limplies(const BitW* a, const BitW& b) const {
273 auto na = lnot(*a);
274 return lor(&na, b);
275 }
276
277 // OR of two quantities known to be mutually exclusive
278 BitW lor_exclusive(const BitW* a, const BitW& b) const { return addv(*a, b); }
279
280 BitW lxor3(const BitW* a, const BitW* b, const BitW& c) const {
281 BitW p = lxor(a, b);
282 return lxor(&p, c);
283 }
284
285 // sha256 Ch(): (x & y) ^ (~x & z);
286 BitW lCh(const BitW* x, const BitW* y, const BitW& z) const {
287 auto xy = land(x, *y);
288 auto nx = lnot(*x);
289 return lor_exclusive(&xy, land(&nx, z));
290 }
291
292 // sha256 Maj(): (x & y) ^ (x & z) ^ (y & z);
293 BitW lMaj(const BitW* x, const BitW* y, const BitW& z) const {
294 // Interpret as x + y + z >= 2 and compute the carry
295 // for an adder in the (p, g) basis
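// For example, x = 1, y = 0, z = 1 gives p = 1, g = 0, and g | (p & z) = 1,
// matching Maj(1, 0, 1) = 1.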
296 BitW p = lxor(x, *y);
297 BitW g = land(x, *y);
298 return lor_exclusive(&g, land(&p, z));
299 }
300
301 // mux over logic values
302 BitW mux(const BitW* control, const BitW* iftrue, const BitW& iffalse) const {
303 auto cif = land(control, *iftrue);
304 auto nc = lnot(*control);
305 auto ciff = land(&nc, *iffalse);
306 return lor_exclusive(&cif, ciff);
307 }
308
309 // mux over backend values
310 EltW mux(const BitW* control, const EltW* iftrue, const EltW& iffalse) const {
311 auto cif = lmul(control, *iftrue);
312 auto nc = lnot(*control);
313 auto ciff = lmul(&nc, iffalse);
314 return add(&cif, ciff);
315 }
316
317 // sum_{i0 <= i < i1} f(i)
318 EltW add(size_t i0, size_t i1, const std::function<EltW(size_t)>& f) const {
319 if (i1 <= i0) {
320 return konst(0);
321 } else if (i1 == i0 + 1) {
322 return f(i0);
323 } else {
324 size_t im = i0 + (i1 - i0) / 2;
325 auto lh = add(i0, im, f);
326 auto rh = add(im, i1, f);
327 return add(&lh, rh);
328 }
329 }
330
331 // lor_exclusive_{i0 <= i < i1} f(i)
332 BitW lor_exclusive(size_t i0, size_t i1,
333 const std::function<BitW(size_t)>& f) const {
334 if (i1 <= i0) {
335 return bit(0);
336 } else if (i1 == i0 + 1) {
337 return f(i0);
338 } else {
339 size_t im = i0 + (i1 - i0) / 2;
340 auto lh = lor_exclusive(i0, im, f);
341 auto rh = lor_exclusive(im, i1, f);
342 return lor_exclusive(&lh, rh);
343 }
344 }
345
346 // and_{i0 <= i < i1} f(i)
347 BitW land(size_t i0, size_t i1, const std::function<BitW(size_t)>& f) const {
348 if (i1 <= i0) {
349 return bit(1);
350 } else if (i1 == i0 + 1) {
351 return f(i0);
352 } else {
353 size_t im = i0 + (i1 - i0) / 2;
354 auto lh = land(i0, im, f);
355 auto rh = land(im, i1, f);
356 return land(&lh, rh);
357 }
358 }
359
360 // or_{i0 <= i < i1} f(i)
361 BitW lor(size_t i0, size_t i1, const std::function<BitW(size_t)>& f) const {
362 if (i1 <= i0) {
363 return bit(0);
364 } else if (i1 == i0 + 1) {
365 return f(i0);
366 } else {
367 size_t im = i0 + (i1 - i0) / 2;
368 auto lh = lor(i0, im, f);
369 auto rh = lor(im, i1, f);
370 return lor(&lh, rh);
371 }
372 }
373
374 BitW or_of_and(std::vector<std::vector<BitW>> clauses_of_ands) const {
375 std::vector<BitW> ands(clauses_of_ands.size());
376 for (size_t i = 0; i < clauses_of_ands.size(); ++i) {
377 auto ai = clauses_of_ands[i];
378 BitW res = land(0, ai.size(), [&](size_t i) { return ai[i]; });
379 ands[i] = res;
380 }
381 return lor(0, ands.size(), [&](size_t i) { return ands[i]; });
382 }
383
384 // prod_{i0 <= i < i1} f(i)
385 EltW mul(size_t i0, size_t i1, const std::function<EltW(size_t)>& f) const {
386 if (i1 <= i0) {
387 return konst(1);
388 } else if (i1 == i0 + 1) {
389 return f(i0);
390 } else {
391 size_t im = i0 + (i1 - i0) / 2;
392 auto lh = mul(i0, im, f);
393 auto rh = mul(im, i1, f);
394 return mul(&lh, rh);
395 }
396 }
397
398 // assert that a + b = c in constant depth
399 void assert_sum(size_t w, const BitW c[/*w*/], const BitW a[/*w*/],
400 const BitW b[/*w*/]) const {
401 // first step of generic_gp_add(): change the basis from
402 // (a, b) to (g, p):
403 std::vector<BitW> g(w), p(w), cy(w);
404 for (size_t i = 0; i < w; ++i) {
405 g[i] = land(&a[i], b[i]);
406 p[i] = lxor(&a[i], &b[i]);
407 }
408
409 // invert the last step of generic_gp_add(): derive
410 // cy[i - 1] (called g[i - 1] there) from
411 // c[i] and p[i].
412 assert_eq(&c[0], p[0]);
413 for (size_t i = 1; i < w; ++i) {
414 cy[i - 1] = lxor(&c[i], p[i]);
415 }
416
417 // Verify that applying ripple_scan to g[] produces cy[].
418 // Note that ripple_scan() operates in-place on g[]. Here, however, g[] is
419 // the input to ripple_scan(), and cy[] is the output.
420 assert_eq(&cy[0], g[0]);
421 for (size_t i = 1; i + 1 < w; ++i) {
422 auto cyp = land(&cy[i - 1], p[i]);
423 auto g_cyp = lor_exclusive(&g[i], cyp);
424 assert_eq(&cy[i], g_cyp);
425 }
426 }
427
428 // (carry, c) = a + b, returning the carry.
429 BitW ripple_carry_add(size_t w, BitW c[/*w*/], const BitW a[/*w*/],
430 const BitW b[/*w*/]) const {
431 return generic_gp_add(w, c, a, b, &Logic::ripple_scan);
432 }
433
434 // (carry, c) = a - b, returning the carry.
435 BitW ripple_carry_sub(size_t w, BitW c[/*w*/], const BitW a[/*w*/],
436 const BitW b[/*w*/]) const {
437 return generic_gp_sub(w, c, a, b, &Logic::ripple_scan);
438 }
439
440 BitW parallel_prefix_add(size_t w, BitW c[/*w*/], const BitW a[/*w*/],
441 const BitW b[/*w*/]) const {
442 return generic_gp_add(w, c, a, b, &Logic::sklansky_scan);
443 }
444
445 BitW parallel_prefix_sub(size_t w, BitW c[/*w*/], const BitW a[/*w*/],
446 const BitW b[/*w*/]) const {
447 return generic_gp_sub(w, c, a, b, &Logic::sklansky_scan);
448 }
449
450 // w x w -> 2w-bit multiplier c = a * b
451 void multiplier(size_t w, BitW c[/*2*w*/], const BitW a[/*w*/],
452 const BitW b[/*w*/]) const {
453 std::vector<BitW> t(w);
454 for (size_t i = 0; i < w; ++i) {
455 if (i == 0) {
456 for (size_t j = 0; j < w; ++j) {
457 c[j] = land(&a[0], b[j]);
458 }
459 c[w] = bit(0);
460 } else {
461 for (size_t j = 0; j < w; ++j) {
462 t[j] = land(&a[i], b[j]);
463 }
464 BitW carry = ripple_carry_add(w, c + i, t.data(), c + i);
465 c[i + w] = carry;
466 }
467 }
468 }
469
470 // w x w -> 2w-bit polynomial multiplier over gf2. c(x) = a(x) * b(x)
471 void gf2_polynomial_multiplier(size_t w, BitW c[/*2*w*/], const BitW a[/*w*/],
472 const BitW b[/*w*/]) const {
473 std::vector<BitW> t(w);
474 for (size_t k = 0; k < 2 * w; ++k) {
475 size_t n = 0;
476 for (size_t i = 0; i < w; ++i) {
477 if (k >= i && k - i < w) {
478 t[n++] = land(&a[i], b[k - i]);
479 }
480 }
481 c[k] = parity(0, n, t.data());
482 }
483 }
484
485 // w x w -> 2w-bit polynomial multiplier over gf2. c(x) = a(x) * b(x)
486 // via the Karatsuba recurrence. Only works for w = 2^k.
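// Writing a(x) = a0(x) + a1(x)*x^(w/2) and b(x) = b0(x) + b1(x)*x^(w/2),
// the recurrence computed below is (all additions are XORs over GF(2)):
//   a*b = a0*b0
//       + ((a0 + a1)*(b0 + b1) + a0*b0 + a1*b1) * x^(w/2)
//       + a1*b1 * x^w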
487 void gf2_polynomial_multiplier_karat(size_t w, BitW c[/*2*w*/],
488 const BitW a[/*w*/],
489 const BitW b[/*w*/]) const {
490 check(w == 128 || w == 64 || w < 64, "input length is not a power of 2");
491 if (w < 64) {
492 gf2_polynomial_multiplier(w, c, a, b);
493 return;
494 } else {
495 // This branch is only reached for w = 64 or w = 128; smaller w uses the quadratic multiplier above.
496 std::vector<BitW> a01(w / 2); /* a0 plus a1 */
497 std::vector<BitW> b01(w / 2); /* b0 plus b1 */
498 std::vector<BitW> ab01(w);
499 std::vector<BitW> a0b0(w);
500 std::vector<BitW> a1b1(w);
501
502 for (size_t i = 0; i < w / 2; ++i) {
503 a01[i] = lxor(&a[i], a[i + w / 2]);
504 b01[i] = lxor(&b[i], b[i + w / 2]);
505 }
506
507 gf2_polynomial_multiplier_karat(w / 2, &ab01[0], &a01[0], &b01[0]);
508 gf2_polynomial_multiplier_karat(w / 2, &a0b0[0], a, b);
509 gf2_polynomial_multiplier_karat(w / 2, &a1b1[0], a + w / 2, b + w / 2);
510
511 for (size_t i = 0; i < w; ++i) {
512 ab01[i] = lxor3(&ab01[i], &a0b0[i], a1b1[i]);
513 }
514
515 for (size_t i = 0; i < w / 2; ++i) {
516 c[i] = a0b0[i];
517 c[i + w / 2] = lxor(&a0b0[i + w / 2], ab01[i]);
518 c[i + w] = lxor(&ab01[i + w / 2], a1b1[i]);
519 c[i + 3 * w / 2] = a1b1[i + w / 2];
520 }
521 }
522 }
523
524 // Performs field multiplication in GF2^128 defined by the irreducible
525 // x^128 + x^7 + x^2 + x + 1. This routine was generated by a sage script that
526 // computes a sparse matrix-vector mult via the powers of x^k mod p(x).
527 //
528 // def make_mulmod(F, n):
529 // r = F(1)
530 // gen = F.gen()
531 // nl = [[] for _ in range(n)]
532 // terms = 0
533 // for i in range(0, 2*n-1):
534 // for j, var in enumerate(r.polynomial().list()):
535 // if var == 1:
536 // nl[j].append(i)
537 // r = r * gen
538 // print(nl)
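// In the generated table, taps[j] lists the exponents i (0 <= i < 255) for
// which x^i mod p(x) has a nonzero coefficient at x^j. gf2k_mul() below
// reduces the 2w-term polynomial product t(x) by XOR-ing exactly the terms
// t[i], i in taps[j], into output bit c[j].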
539 void gf2_128_mul(v128& c, const v128 a, const v128 b) const {
540 const std::vector<uint16_t> taps[128] = {
541 {0, 128, 249, 254},
542 {1, 128, 129, 249, 250, 254},
543 {2, 128, 129, 130, 249, 250, 251, 254},
544 {3, 129, 130, 131, 250, 251, 252},
545 {4, 130, 131, 132, 251, 252, 253},
546 {5, 131, 132, 133, 252, 253, 254},
547 {6, 132, 133, 134, 253, 254},
548 {7, 128, 133, 134, 135, 249},
549 {8, 129, 134, 135, 136, 250},
550 {9, 130, 135, 136, 137, 251},
551 {10, 131, 136, 137, 138, 252},
552 {11, 132, 137, 138, 139, 253},
553 {12, 133, 138, 139, 140, 254},
554 {13, 134, 139, 140, 141},
555 {14, 135, 140, 141, 142},
556 {15, 136, 141, 142, 143},
557 {16, 137, 142, 143, 144},
558 {17, 138, 143, 144, 145},
559 {18, 139, 144, 145, 146},
560 {19, 140, 145, 146, 147},
561 {20, 141, 146, 147, 148},
562 {21, 142, 147, 148, 149},
563 {22, 143, 148, 149, 150},
564 {23, 144, 149, 150, 151},
565 {24, 145, 150, 151, 152},
566 {25, 146, 151, 152, 153},
567 {26, 147, 152, 153, 154},
568 {27, 148, 153, 154, 155},
569 {28, 149, 154, 155, 156},
570 {29, 150, 155, 156, 157},
571 {30, 151, 156, 157, 158},
572 {31, 152, 157, 158, 159},
573 {32, 153, 158, 159, 160},
574 {33, 154, 159, 160, 161},
575 {34, 155, 160, 161, 162},
576 {35, 156, 161, 162, 163},
577 {36, 157, 162, 163, 164},
578 {37, 158, 163, 164, 165},
579 {38, 159, 164, 165, 166},
580 {39, 160, 165, 166, 167},
581 {40, 161, 166, 167, 168},
582 {41, 162, 167, 168, 169},
583 {42, 163, 168, 169, 170},
584 {43, 164, 169, 170, 171},
585 {44, 165, 170, 171, 172},
586 {45, 166, 171, 172, 173},
587 {46, 167, 172, 173, 174},
588 {47, 168, 173, 174, 175},
589 {48, 169, 174, 175, 176},
590 {49, 170, 175, 176, 177},
591 {50, 171, 176, 177, 178},
592 {51, 172, 177, 178, 179},
593 {52, 173, 178, 179, 180},
594 {53, 174, 179, 180, 181},
595 {54, 175, 180, 181, 182},
596 {55, 176, 181, 182, 183},
597 {56, 177, 182, 183, 184},
598 {57, 178, 183, 184, 185},
599 {58, 179, 184, 185, 186},
600 {59, 180, 185, 186, 187},
601 {60, 181, 186, 187, 188},
602 {61, 182, 187, 188, 189},
603 {62, 183, 188, 189, 190},
604 {63, 184, 189, 190, 191},
605 {64, 185, 190, 191, 192},
606 {65, 186, 191, 192, 193},
607 {66, 187, 192, 193, 194},
608 {67, 188, 193, 194, 195},
609 {68, 189, 194, 195, 196},
610 {69, 190, 195, 196, 197},
611 {70, 191, 196, 197, 198},
612 {71, 192, 197, 198, 199},
613 {72, 193, 198, 199, 200},
614 {73, 194, 199, 200, 201},
615 {74, 195, 200, 201, 202},
616 {75, 196, 201, 202, 203},
617 {76, 197, 202, 203, 204},
618 {77, 198, 203, 204, 205},
619 {78, 199, 204, 205, 206},
620 {79, 200, 205, 206, 207},
621 {80, 201, 206, 207, 208},
622 {81, 202, 207, 208, 209},
623 {82, 203, 208, 209, 210},
624 {83, 204, 209, 210, 211},
625 {84, 205, 210, 211, 212},
626 {85, 206, 211, 212, 213},
627 {86, 207, 212, 213, 214},
628 {87, 208, 213, 214, 215},
629 {88, 209, 214, 215, 216},
630 {89, 210, 215, 216, 217},
631 {90, 211, 216, 217, 218},
632 {91, 212, 217, 218, 219},
633 {92, 213, 218, 219, 220},
634 {93, 214, 219, 220, 221},
635 {94, 215, 220, 221, 222},
636 {95, 216, 221, 222, 223},
637 {96, 217, 222, 223, 224},
638 {97, 218, 223, 224, 225},
639 {98, 219, 224, 225, 226},
640 {99, 220, 225, 226, 227},
641 {100, 221, 226, 227, 228},
642 {101, 222, 227, 228, 229},
643 {102, 223, 228, 229, 230},
644 {103, 224, 229, 230, 231},
645 {104, 225, 230, 231, 232},
646 {105, 226, 231, 232, 233},
647 {106, 227, 232, 233, 234},
648 {107, 228, 233, 234, 235},
649 {108, 229, 234, 235, 236},
650 {109, 230, 235, 236, 237},
651 {110, 231, 236, 237, 238},
652 {111, 232, 237, 238, 239},
653 {112, 233, 238, 239, 240},
654 {113, 234, 239, 240, 241},
655 {114, 235, 240, 241, 242},
656 {115, 236, 241, 242, 243},
657 {116, 237, 242, 243, 244},
658 {117, 238, 243, 244, 245},
659 {118, 239, 244, 245, 246},
660 {119, 240, 245, 246, 247},
661 {120, 241, 246, 247, 248},
662 {121, 242, 247, 248, 249},
663 {122, 243, 248, 249, 250},
664 {123, 244, 249, 250, 251},
665 {124, 245, 250, 251, 252},
666 {125, 246, 251, 252, 253},
667 {126, 247, 252, 253, 254},
668 {127, 248, 253, 254},
669 };
670 gf2k_mul(c.data(), a.data(), b.data(), taps, 128);
671 }
672
673 // Performs field multiplication in GF2^k using a sparse matrix data structure.
674 void gf2k_mul(BitW c[/*w*/], const BitW a[/*w*/], const BitW b[/*w*/],
675 const std::vector<uint16_t> M[], size_t w) const {
676 std::vector<BitW> t(w * 2);
677 gf2_polynomial_multiplier_karat(w, t.data(), a, b);
678
679 std::vector<BitW> tmp(w);
680 for (size_t i = 0; i < w; ++i) {
681 size_t n = 0;
682 for (auto ti : M[i]) {
683 tmp[n++] = t[ti];
684 }
685 c[i] = parity(0, n, tmp.data());
686 }
687 }
688
689 // a == 0
690 BitW eq0(size_t w, const BitW a[/*w*/]) const { return eq0(0, w, a); }
691
692 // a == b
693 BitW eq(size_t w, const BitW a[/*w*/], const BitW b[/*w*/]) const {
694 return eq_reduce(0, w, a, b);
695 }
696
697 // a < b.
698 // Specialization of the subtractor for the case (a - b) < 0
699 BitW lt(size_t w, const BitW a[/*w*/], const BitW b[/*w*/]) const {
700 if (w == 0) {
701 return bit(0);
702 } else {
703 BitW xeq, xlt;
704 lt_reduce(0, w, &xeq, &xlt, a, b);
705 return xlt;
706 }
707 }
708
709 // a <= b
710 BitW leq(size_t w, const BitW a[/*w*/], const BitW b[/*w*/]) const {
711 auto blt = lt(w, b, a);
712 return lnot(blt);
713 }
714
715 // Parallel prefix of various kinds
716 template <class T>
717 void scan(const std::function<void(T*, const T&, const T&)>& op, T x[],
718 size_t i0, size_t i1, bool backward = false) const {
719 // generic Sklansky scan
720 if (i1 - i0 > 1) {
721 size_t im = i0 + (i1 - i0) / 2;
722 scan(op, x, i0, im, backward);
723 scan(op, x, im, i1, backward);
724 if (backward) {
725 for (size_t i = i0; i < im; ++i) {
726 op(&x[i], x[i], x[im]);
727 }
728 } else {
729 for (size_t i = im; i < i1; ++i) {
730 op(&x[i], x[im - 1], x[i]);
731 }
732 }
733 }
734 }
735
736 void scan_and(BitW x[], size_t i0, size_t i1, bool backward = false) const {
737 scan<BitW>(
738 [&](BitW* out, const BitW& l, const BitW& r) { *out = land(&l, r); }, x,
739 i0, i1, backward);
740 }
741
742 void scan_or(BitW x[], size_t i0, size_t i1, bool backward = false) const {
743 scan<BitW>(
744 [&](BitW* out, const BitW& l, const BitW& r) { *out = lor(&l, r); }, x,
745 i0, i1, backward);
746 }
747
748 void scan_xor(BitW x[], size_t i0, size_t i1, bool backward = false) const {
749 scan<BitW>(
750 [&](BitW* out, const BitW& l, const BitW& r) { *out = lxor(&l, r); }, x,
751 i0, i1, backward);
752 }
753
754 template <size_t I0, size_t I1, size_t N>
755 bitvec<I1 - I0> slice(const bitvec<N>& a) const {
756 bitvec<I1 - I0> r;
757 for (size_t i = I0; i < I1; ++i) {
758 r[i - I0] = a[i];
759 }
760 return r;
761 }
762
763 // Little-endian append of A and B. A[0] is the LSB, B starts at
764 // position [NA].
765 template <size_t NA, size_t NB>
766 bitvec<NA + NB> vappend(const bitvec<NA>& a, const bitvec<NB>& b) const {
767 bitvec<NA + NB> r;
768 for (size_t i = 0; i < NA; ++i) {
769 r[i] = a[i];
770 }
771 for (size_t i = 0; i < NB; ++i) {
772 r[i + NA] = b[i];
773 }
774 return r;
775 }
776
777 template <size_t N>
778 bool vequal(const bitvec<N>* a, const bitvec<N>& b) const {
779 for (size_t i = 0; i < N; ++i) {
780 auto eai = eval((*a)[i]);
781 auto ebi = eval(b[i]);
782 if (eai != ebi) return false;
783 }
784 return true;
785 }
786
787 template <size_t N>
788 bitvec<N> vbit(uint64_t x) const {
789 bitvec<N> r;
790 bits(N, r.data(), x);
791 return r;
792 }
793
794 // shorthands for the silly "template" notation
795 v8 vbit8(uint64_t x) const { return vbit<8>(x); }
796 v32 vbit32(uint64_t x) const { return vbit<32>(x); }
797
798 template <size_t N>
799 bitvec<N> vnot(const bitvec<N>& x) const {
800 bitvec<N> r;
801 for (size_t i = 0; i < N; ++i) {
802 r[i] = lnot(x[i]);
803 }
804 return r;
805 }
806
807 template <size_t N>
808 bitvec<N> vand(const bitvec<N>* a, const bitvec<N>& b) const {
809 bitvec<N> r;
810 for (size_t i = 0; i < N; ++i) {
811 r[i] = land(&(*a)[i], b[i]);
812 }
813 return r;
814 }
815
816 template <size_t N>
817 bitvec<N> vand(const BitW* a, const bitvec<N>& b) const {
818 bitvec<N> r;
819 for (size_t i = 0; i < N; ++i) {
820 r[i] = land(a, b[i]);
821 }
822 return r;
823 }
824
825 template <size_t N>
826 bitvec<N> vor(const bitvec<N>* a, const bitvec<N>& b) const {
827 bitvec<N> r;
828 for (size_t i = 0; i < N; ++i) {
829 r[i] = lor(&(*a)[i], b[i]);
830 }
831 return r;
832 }
833 template <size_t N>
834 bitvec<N> vor_exclusive(const bitvec<N>* a, const bitvec<N>& b) const {
835 bitvec<N> r;
836 for (size_t i = 0; i < N; ++i) {
837 r[i] = lor_exclusive(&(*a)[i], b[i]);
838 }
839 return r;
840 }
841 template <size_t N>
842 bitvec<N> vxor(const bitvec<N>* a, const bitvec<N>& b) const {
843 bitvec<N> r;
844 for (size_t i = 0; i < N; ++i) {
845 r[i] = lxor(&(*a)[i], b[i]);
846 }
847 return r;
848 }
849
850 template <size_t N>
851 bitvec<N> vCh(const bitvec<N>* x, const bitvec<N>* y,
852 const bitvec<N>& z) const {
853 bitvec<N> r;
854 for (size_t i = 0; i < N; ++i) {
855 r[i] = lCh(&(*x)[i], &(*y)[i], z[i]);
856 }
857 return r;
858 }
859 template <size_t N>
860 bitvec<N> vMaj(const bitvec<N>* x, const bitvec<N>* y,
861 const bitvec<N>& z) const {
862 bitvec<N> r;
863 for (size_t i = 0; i < N; ++i) {
864 r[i] = lMaj(&(*x)[i], &(*y)[i], z[i]);
865 }
866 return r;
867 }
868
869 template <size_t N>
870 bitvec<N> vxor3(const bitvec<N>* x, const bitvec<N>* y,
871 const bitvec<N>& z) const {
872 bitvec<N> r;
873 for (size_t i = 0; i < N; ++i) {
874 r[i] = lxor3(&(*x)[i], &(*y)[i], z[i]);
875 }
876 return r;
877 }
878
879 template <size_t N>
880 bitvec<N> vshr(const bitvec<N>& a, size_t shift, size_t b = 0) const {
881 bitvec<N> r;
882 for (size_t i = 0; i < N; ++i) {
883 if (i + shift < N) {
884 r[i] = a[i + shift];
885 } else {
886 r[i] = bit(b);
887 }
888 }
889 return r;
890 }
891
892 template <size_t N>
893 bitvec<N> vshl(const bitvec<N>& a, size_t shift, size_t b = 0) const {
894 bitvec<N> r;
895 for (size_t i = 0; i < N; ++i) {
896 if (i >= shift) {
897 r[i] = a[i - shift];
898 } else {
899 r[i] = bit(b);
900 }
901 }
902 return r;
903 }
904
905 template <size_t N>
906 bitvec<N> vrotr(const bitvec<N>& a, size_t b) const {
907 bitvec<N> r;
908 for (size_t i = 0; i < N; ++i) {
909 r[i] = a[(i + b) % N];
910 }
911 return r;
912 }
913
914 template <size_t N>
915 bitvec<N> vrotl(const bitvec<N>& a, size_t b) const {
916 bitvec<N> r;
917 for (size_t i = 0; i < N; ++i) {
918 r[(i + b) % N] = a[i];
919 }
920 return r;
921 }
922
923 template <size_t N>
924 bitvec<N> vadd(const bitvec<N>& a, const bitvec<N>& b) const {
925 bitvec<N> r;
926 (void)parallel_prefix_add(N, &r[0], &a[0], &b[0]);
927 return r;
928 }
929 template <size_t N>
930 bitvec<N> vadd(const bitvec<N>& a, uint64_t val) const {
931 return vadd(a, vbit<N>(val));
932 }
933
934 template <size_t N>
935 BitW veq(const bitvec<N>& a, const bitvec<N>& b) const {
936 return eq(N, a.data(), b.data());
937 }
938 template <size_t N>
939 BitW veq(const bitvec<N>& a, uint64_t val) const {
940 auto v = vbit<N>(val);
941 return veq(a, v);
942 }
943 template <size_t N>
944 BitW vlt(const bitvec<N>* a, const bitvec<N>& b) const {
945 return lt(N, (*a).data(), b.data());
946 }
947 template <size_t N>
948 BitW vlt(const bitvec<N>& a, uint64_t val) const {
949 auto v = vbit<N>(val);
950 return vlt(&a, v);
951 }
952 template <size_t N>
953 BitW vlt(uint64_t a, const bitvec<N>& b) const {
954 auto va = vbit<N>(a);
955 return vlt(&va, b);
956 }
957 template <size_t N>
958 BitW vleq(const bitvec<N>* a, const bitvec<N>& b) const {
959 return leq(N, (*a).data(), b.data());
960 }
961 template <size_t N>
962 BitW vleq(const bitvec<N>& a, uint64_t val) const {
963 auto v = vbit<N>(val);
964 return vleq(&a, v);
965 }
966
967 // (a ^ val) & mask == 0
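// For example, with N = 8 and mask = 0x0f only bits 0..3 of (a ^ val) are
// required to be zero; the high nibble of a is ignored.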
968 template <size_t N>
969 BitW veqmask(const bitvec<N>* a, uint64_t mask, const bitvec<N>& val) const {
970 auto r = vxor(a, val);
971 size_t n = pack(mask, N, &r[0]);
972 return eq0(0, n, &r[0]);
973 }
974
975 template <size_t N>
976 BitW veqmask(const bitvec<N>& a, uint64_t mask, uint64_t val) const {
977 auto v = vbit<N>(val);
978 return veqmask(&a, mask, v);
979 }
980
981 // I/O. This is a hack which only works if the backend supports
982 // bk_->{input,output}. Because C++ templates are instantiated lazily,
983 // this class compiles even with backends that do not support I/O,
984 // as long as you never instantiate vinput() or voutput().
985 EltW eltw_input() const { return bk_->input(); }
986 BitW input() const { return BitW(bk_->input(), f_); }
987 void output(const BitW& x, size_t i) const { bk_->output(eval(x), i); }
988 size_t wire_id(const BitW& v) const { return bk_->wire_id(v.x); }
989 size_t wire_id(const EltW& x) const { return bk_->wire_id(x); }
990
991 template <size_t N>
992 bitvec<N> vinput() const {
993 bitvec<N> r;
994 for (size_t i = 0; i < N; ++i) {
995 r[i] = input();
996 }
997 return r;
998 }
999
1000 template <size_t N>
1001 void voutput(const bitvec<N>& x, size_t i0) const {
1002 for (size_t i = 0; i < N; ++i) {
1003 output(x[i], i + i0);
1004 }
1005 }
1006
1007 template <size_t N>
1008 void vassert0(const bitvec<N>& x) const {
1009 for (size_t i = 0; i < N; ++i) {
1010 (void)assert0(x[i]);
1011 }
1012 }
1013
1014 template <size_t N>
1015 void vassert_eq(const bitvec<N>* x, const bitvec<N>& y) const {
1016 for (size_t i = 0; i < N; ++i) {
1017 (void)assert_eq(&(*x)[i], y[i]);
1018 }
1019 }
1020
1021 template <size_t N>
1022 void vassert_eq(const bitvec<N>& x, uint64_t y) const {
1023 auto v = vbit<N>(y);
1024 vassert_eq(&x, v);
1025 }
1026
1027 template <size_t N>
1028 void vassert_is_bit(const bitvec<N>& a) const {
1029 for (size_t i = 0; i < N; ++i) {
1030 (void)assert_is_bit(a[i]);
1031 }
1032 }
1033
1034 private:
1035 // return one quad gate for the product eval(a)*eval(b),
1036 // optimizing some "obvious" cases.
1037 BitW mulv(const BitW* a, const BitW& b) const {
1038 if (a->c1 == zero()) {
1039 return rebase(zero(), a->c0, b);
1040 } else if (b.c1 == zero()) {
1041 return mulv(&b, *a);
1042 } else {
1043 // Avoid creating the intermediate term 1 * a.x * b.x, which is
1044 // likely a useless node. Moreover, two nodes (k1 * a.x * b.x)
1045 // and (k2 * a.x * b.x) would be recognized as sharing the common
1046 // subexpression (a.x * b.x), which would confusingly increment the
1047 // common-subexpression counter.
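// Expanded, the product is
//   eval(a)*eval(b) = (a.c0 + a.c1*a.x) * (b.c0 + b.c1*b.x)
//                   = (a.c1*b.c1)*a.x*b.x + (a.c0*b.c1)*b.x
//                     + (a.c1*b.c0)*a.x + a.c0*b.c0,
// which the axy/axpy/apy chain below evaluates term by term.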
1048 EltW x = axy(mulf(a->c1, b.c1), &a->x, b.x);
1049 x = axpy(&x, mulf(a->c0, b.c1), b.x);
1050 x = axpy(&x, mulf(a->c1, b.c0), a->x);
1051 x = apy(x, mulf(a->c0, b.c0));
1052 return BitW(x, f_);
1053 }
1054 }
1055
1056 BitW addv(const BitW& a, const BitW& b) const {
1057 if (a.c1 == zero()) {
1058 return BitW(addf(a.c0, b.c0), b.c1, b.x);
1059 } else if (b.c1 == zero()) {
1060 return addv(b, a);
1061 } else {
1062 EltW x = ax(a.c1, a.x);
1063 auto axb = ax(b.c1, b.x);
1064 x = add(&x, axb);
1065 x = apy(x, addf(a.c0, b.c0));
1066 return BitW(x, f_);
1067 }
1068 }
1069
1070 BitW lxor_aux(const BitW& a, const BitW& b, PrimeFieldTypeTag tt) const {
1071 // a * b in the xor basis TRUE -> -1, FALSE -> 1
1072 // map a, b from standard basis to xor basis
1073 Elt mtwo = f_.negf(f_.two());
1074 Elt half = f_.half();
1075 Elt mhalf = f_.negf(half);
1076
1077 BitW a1 = rebase(one(), mtwo, a);
1078 BitW b1 = rebase(one(), mtwo, b);
1079 BitW p = mulv(&a1, b1);
1080 return rebase(half, mhalf, p);
1081 }
1082 BitW lxor_aux(const BitW& a, const BitW& b, BinaryFieldTypeTag tt) const {
1083 return addv(a, b);
1084 }
1085
1086 size_t pack(uint64_t mask, size_t n, BitW a[/*n*/]) const {
1087 size_t j = 0;
1088 for (size_t i = 0; i < n; ++i) {
1089 if (mask & 1) {
1090 a[j++] = a[i];
1091 }
1092 mask >>= 1;
1093 }
1094 return j;
1095 }
1096
1097 // carry-propagation equations
1098 // (g0, p0) + (g1, p1) = (g1 | (g0 & p1), p0 & p1)
1099 // Accumulate in-place into (g1, p1).
1100 //
1101 // We use the property that g1 and p1 are mutually exclusive (g1&p1
1102 // is false), and therefore g1 and (g0 & p1) are also mutually
1103 // exclusive.
1104 void gp_reduce(const BitW& g0, const BitW& p0, BitW* g1, BitW* p1) const {
1105 auto g0p1 = land(&g0, *p1);
1106 *g1 = lor_exclusive(g1, g0p1);
1107 *p1 = land(&p0, *p1);
1108 }
1109
1110 // ripple carry propagation
1111 void ripple_scan(std::vector<BitW>& g, std::vector<BitW>& p, size_t i0,
1112 size_t i1) const {
1113 for (size_t i = i0 + 1; i < i1; ++i) {
1114 gp_reduce(g[i - 1], p[i - 1], &g[i], &p[i]);
1115 }
1116 }
1117
1118 // parallel-prefix carry propagation, Sklansky-style [1960]
1119 void sklansky_scan(std::vector<BitW>& g, std::vector<BitW>& p, size_t i0,
1120 size_t i1) const {
1121 if (i1 - i0 > 1) {
1122 size_t im = i0 + (i1 - i0) / 2;
1123 sklansky_scan(g, p, i0, im);
1124 sklansky_scan(g, p, im, i1);
1125 for (size_t i = im; i < i1; ++i) {
1126 gp_reduce(g[im - 1], p[im - 1], &g[i], &p[i]);
1127 }
1128 }
1129 }
1130
1131 // generic add in generate/propagate form, parametrized
1132 // by the scan primitive.
1133 //
1134 // (carry, c) = a + b
1135 BitW generic_gp_add(size_t w, BitW c[/*w*/], const BitW a[/*w*/],
1136 const BitW b[/*w*/],
1137 void (Logic::*scan)(std::vector<BitW>& /*g*/,
1138 std::vector<BitW>& /*p*/,
1139 size_t /*i0*/, size_t /*i1*/)
1140 const) const {
1141 if (w == 0) {
1142 return bit(0);
1143 } else {
1144 std::vector<BitW> g(w), p(w);
1145 for (size_t i = 0; i < w; ++i) {
1146 g[i] = land(&a[i], b[i]);
1147 p[i] = lxor(&a[i], b[i]);
1148 c[i] = p[i];
1149 }
1150 (this->*scan)(g, p, 0, w);
1151 for (size_t i = 1; i < w; ++i) {
1152 c[i] = lxor(&c[i], g[i - 1]);
1153 }
1154 return g[w - 1];
1155 }
1156 }
1157
1158 BitW generic_gp_sub(size_t w, BitW c[/*w*/], const BitW a[/*w*/],
1159 const BitW b[/*w*/],
1160 void (Logic::*scan)(std::vector<BitW>& /*g*/,
1161 std::vector<BitW>& /*p*/,
1162 size_t /*i0*/, size_t /*i1*/)
1163 const) const {
1164 // implement as ~(~a + b)
1165 std::vector<BitW> t(w);
1166 for (size_t j = 0; j < w; ++j) {
1167 t[j] = lnot(a[j]);
1168 }
1169 BitW carry = generic_gp_add(w, c, t.data(), b, scan);
1170 for (size_t j = 0; j < w; ++j) {
1171 c[j] = lnot(c[j]);
1172 }
1173 return carry;
1174 }
1175
1176 // Recursion for the a < b comparison.
1177 // Let a = (a1, a0) and b = (b1, b0). Then:
1178 //
1179 // a == b iff a1 == b1 && a0 == b0
1180 // a < b iff a1 < b1 || (a1 == b1 && a0 < b0)
1181 void lt_reduce(size_t i0, size_t i1, BitW* xeq, BitW* xlt,
1182 const BitW a[/*w*/], const BitW b[/*w*/]) const {
1183 if (i1 - i0 > 1) {
1184 BitW eq0, eq1, lt0, lt1;
1185 size_t im = i0 + (i1 - i0) / 2;
1186 lt_reduce(i0, im, &eq0, &lt0, a, b);
1187 lt_reduce(im, i1, &eq1, &lt1, a, b);
1188 *xeq = land(&eq1, eq0);
1189 auto lt0_and_eq1 = land(&eq1, lt0);
1190 *xlt = lor_exclusive(&lt1, lt0_and_eq1);
1191 } else {
1192 auto axb = lxor(&a[i0], b[i0]);
1193 *xeq = lnot(axb);
1194 auto na = lnot(a[i0]);
1195 *xlt = land(&na, b[i0]);
1196 }
1197 }
1198
1199 BitW parity(size_t i0, size_t i1, const BitW a[]) const {
1200 if (i1 <= i0) {
1201 return bit(0);
1202 } else if (i1 == i0 + 1) {
1203 return a[i0];
1204 } else {
1205 size_t im = i0 + (i1 - i0) / 2;
1206 auto lp = parity(i0, im, a);
1207 auto rp = parity(im, i1, a);
1208 return lxor(&lp, rp);
1209 }
1210 }
1211
1212 BitW eq0(size_t i0, size_t i1, const BitW a[]) const {
1213 if (i1 <= i0) {
1214 return bit(1);
1215 } else if (i1 == i0 + 1) {
1216 return lnot(a[i0]);
1217 } else {
1218 size_t im = i0 + (i1 - i0) / 2;
1219 auto le = eq0(i0, im, a);
1220 auto re = eq0(im, i1, a);
1221 return land(&le, re);
1222 }
1223 }
1224
1225 BitW eq_reduce(size_t i0, size_t i1, const BitW a[], const BitW b[]) const {
1226 if (i1 <= i0) {
1227 return bit(1);
1228 } else if (i1 == i0 + 1) {
1229 return lnot(lxor(&a[i0], b[i0]));
1230 } else {
1231 size_t im = i0 + (i1 - i0) / 2;
1232 auto le = eq_reduce(i0, im, a, b);
1233 auto re = eq_reduce(im, i1, a, b);
1234 return land(&le, re);
1235 }
1236 }
1237
1238 const Backend* bk_;
1239};
1240
1241} // namespace proofs
1242
1243#endif // PRIVACY_PROOFS_ZK_LIB_CIRCUITS_LOGIC_LOGIC_H_