Longfellow ZK 0290cb32
Loading...
Searching...
No Matches
cbor.h
1// Copyright 2025 Google LLC.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#ifndef PRIVACY_PROOFS_ZK_LIB_CIRCUITS_CBOR_PARSER_CBOR_H_
16#define PRIVACY_PROOFS_ZK_LIB_CIRCUITS_CBOR_PARSER_CBOR_H_
17
18#include <stddef.h>
19#include <stdint.h>
20
21#include <array>
22#include <vector>
23
24#include "circuits/cbor_parser/cbor_byte_decoder.h"
25#include "circuits/cbor_parser/cbor_constants.h"
26#include "circuits/cbor_parser/cbor_pluck.h"
27#include "circuits/cbor_parser/scan.h"
28#include "circuits/logic/counter.h"
29#include "circuits/logic/memcmp.h"
30#include "circuits/logic/routing.h"
31#include "util/panic.h"
32
33namespace proofs {
34template <class Logic, size_t IndexBits = CborConstants::kIndexBits>
35class Cbor {
36 public:
// Counter gadget instantiated over the same Logic as this class.
37 using CounterL = Counter<Logic>;
// Per-byte CBOR decoder gadget instantiated over the same Logic.
38 using CborBD = CborByteDecoder<Logic>;
// Types re-exported from Logic (and from the counter gadget for CEltW)
// so that the signatures below can name them directly.
39 using Field = typename Logic::Field;
40 using EltW = typename Logic::EltW;
41 using CEltW = typename CounterL::CEltW;
42 using BitW = typename Logic::BitW;
43 using v8 = typename Logic::v8;
// Width, in bits, of a position index; taken from the IndexBits template
// parameter.
44 static constexpr size_t kIndexBits = IndexBits;
// Number of parser counters (nesting levels), taken from CborConstants.
45 static constexpr size_t kNCounters = CborConstants::kNCounters;
// a bitvector with one bit per parser counter.
46 using bv_counters = typename Logic::template bitvec<kNCounters>;
47
48 // a bitvector that contains an index into the input
49 // (byte) array.
50 using vindex = typename Logic::template bitvec<kIndexBits>;
51
52 explicit Cbor(const Logic& l) : l_(l), ctr_(l), bd_(l), bp_(l) {}
53
55 EltW invprod_decode; // inverse of a certain product, see assert_decode()
56 CEltW cc0_counter; // initial value of counter[0]
57 EltW invprod_parse; // inverse of a certain product, see assert_parse()
58 };
59
61 EltW encoded_sel_header;
62 };
63
64 //------------------------------------------------------------
65 // Decoder (lexer)
66 //------------------------------------------------------------
// Per-position wires consumed by the assert_* and parse routines of this
// class.  One instance describes one input position.
67 struct decode {
68 // wires generated by the byte decoder given the input.
// (decode_all() fills this from bd_.decode_one_v8() on the input byte.)
69 typename CborBD::decode bd;
70
71 // wires generated by the lexer from witnesses
// (decode_all() fills this from bp_.pluckb() on the position's
// encoded_sel_header witness; true means "this position is a header".)
72 BitW header;
73 };
74
75 void assert_decode(size_t n, const decode ds[/*n*/],
76 const position_witness pw[/*n*/],
77 const global_witness& gw) const {
78 const Logic& L = l_; // shorthand
79 Scan<CounterL> SC(ctr_);
80
81 // -------------------------------------------------------------
82 // Decoder didn't fail
83 for (size_t i = 0; i < n; ++i) {
84 L.assert_implies(&ds[i].header, L.lnot(ds[i].bd.invalid));
85 }
86 // if LENGTH_PLUS_NEXT_V8 is TRUE in the last position,
87 // then the input is invalid.
88 L.assert_implies(&ds[n - 1].header,
89 L.lnot(ds[n - 1].bd.length_plus_next_v8));
90
91 // if COUNT_IS_NEXT_V8 is TRUE in the last position,
92 // then the input is invalid.
93 L.assert_implies(&ds[n - 1].header, L.lnot(ds[n - 1].bd.count_is_next_v8));
94
95 // -------------------------------------------------------------
96 // Headers are where they are supposed to be.
97 // First, compute the segmented scan
98 // slen[i] = header[i] ? length[i] : (slen[i-1] + mone[i])
99 std::vector<CEltW> mone(n);
100 std::vector<BitW> header(n);
101 std::vector<CEltW> length(n);
102 std::vector<CEltW> slen_next(n);
103
104 for (size_t i = 0; i + 1 < n; ++i) {
105 mone[i] = ctr_.mone();
106 header[i] = ds[i].header;
107 length[i] = ds[i].bd.length;
108 if (i + 1 < n) {
109 CEltW len_i =
110 ctr_.ite0(&ds[i].bd.length_plus_next_v8, ds[i + 1].bd.as_counter);
111 length[i] = ctr_.add(&length[i], len_i);
112 }
113 }
114
115 SC.add(n, slen_next.data(), header.data(), length.data(), mone.data());
116
117 // Now check the headers.
118 {
119 // "The first position is a header"
120 L.assert1(header[0]);
121 }
122
123 {
124 EltW one = L.konst(L.one());
125 CEltW mone_counter = ctr_.mone();
126
127 // "\A I : (SLEN_NEXT[I] == 1) IFF HEADER[I+1]"
128 {
129 // "\A I : HEADER[I+1] => (SLEN_NEXT[I] == 1)"
130 for (size_t i = 0; i + 1 < n; ++i) {
131 CEltW implies =
132 ctr_.ite0(&header[i + 1], ctr_.add(&slen_next[i], mone_counter));
133 ctr_.assert0(implies);
134 }
135 }
136 {
137 // "\A I : (SLEN_NEXT[I] == 1) => HEADER[i+1] "
138 // Verify via the invertibility of
139 //
140 // PROD_{I, L} HEADER[I+1] ? 1 : (SLEN_NEXT[I] - 1)
141 //
142 auto f = [&](size_t i) {
143 CEltW snm1 = ctr_.add(&slen_next[i], mone_counter);
144 return L.mux(&header[i + 1], &one, ctr_.znz_indicator(snm1));
145 };
146 EltW prod = L.mul(0, n - 1, f);
147 auto want_one = L.mul(&prod, gw.invprod_decode);
148 L.assert_eq(&want_one, one);
149 }
150 }
151 }
152
153 //------------------------------------------------------------
154 // Parser
155 //------------------------------------------------------------
156 using counters = std::array<CEltW, kNCounters>;
158 bv_counters sel;
159 counters c;
160 };
161
162 void parse(size_t n, parse_output ps[/*n*/], const decode ds[/*n*/],
163 const position_witness pw[/*n*/], const global_witness& gw) const {
164 std::vector<CEltW> ddss(n);
165 std::vector<BitW> SS(n);
166 std::vector<CEltW> AA(n);
167 std::vector<CEltW> BB(n);
168
169 const Logic& L = l_; // shorthand
170 Scan<CounterL> SC(ctr_);
171
172 for (size_t i = 0; i < n; ++i) {
173 ps[i].sel = bp_.pluckj(pw[i].encoded_sel_header);
174 }
175
176 CEltW mone = ctr_.mone();
177 for (size_t l = 0; l < kNCounters; ++l) {
178 for (size_t i = 0; i < n; ++i) {
179 // at the selected headers, decrement the level-L counter.
180 auto dp = L.land(&ds[i].header, ps[i].sel[l]);
181 ddss[i] = ctr_.ite0(&dp, mone);
182 }
183
184 if (l == 0) {
185 // do level-0 as an unsegmented parallel prefix
186 // on DDSS starting at CC0.
187 // One can achieve the same effect by using the segmented prefix
188 // after initializing SS and AA as follows:
189 //
190 // SS[0] = L.bit(1);
191 // AA[0] = gw.cc0_counter;
192 // for (size_t i = 1; i < n; ++i) {
193 // SS[i] = L.bit(0);
194 // AA[i] = L.konst(0);
195 // }
196 //
197 // The compiler is smart enough to constant-fold the segment
198 // SS[i] and produces the same circuit in both cases, but
199 // there is no point in wasting compiler time and the
200 // unsegmented prefix is more straightforward anyway.
201 //
202 // Note that AA, SS are uninitialized here. They will be initialized
203 // below for the next level.
204 ddss[0] = gw.cc0_counter;
205 SC.add(n, BB.data(), ddss.data());
206 } else {
207 SC.add(n, BB.data(), SS.data(), AA.data(), ddss.data());
208 }
209
210 // output the result of the parallel prefix
211 for (size_t i = 0; i < n; ++i) {
212 ps[i].c[l] = BB[i];
213 }
214
215 // prepare SS, AA for the next level
216 for (size_t i = 0; i < n; ++i) {
217 CEltW newc = ctr_.as_counter(ds[i].bd.tagp);
218 CEltW count = ds[i].bd.count_as_counter;
219 if (i + 1 < n) {
220 count = ctr_.mux(&ds[i].bd.count_is_next_v8, &ds[i + 1].bd.as_counter,
221 count);
222 }
223 newc = ctr_.add(&newc, ctr_.ite0(&ds[i].bd.itemsp, count));
224 newc = ctr_.add(&newc, ctr_.ite0(&ds[i].bd.mapp, count));
225 AA[i] = newc;
226
227 auto sel = L.land(&ps[i].sel[l], ds[i].header);
228 auto tag = L.lor(&ds[i].bd.tagp, ds[i].bd.itemsp);
229 SS[i] = L.land(&sel, tag);
230 }
231 }
232
233 // Assert that we don't want to start new segments at a level
234 // that does not exist.
235 for (size_t i = 0; i < n; ++i) {
236 L.assert0(SS[i]);
237 }
238 }
239
240 void assert_parse(size_t n, const decode ds[/*n*/],
241 const parse_output ps[/*n*/],
242 const global_witness& gw) const {
243 const Logic& L = l_; // shorthand
244
245 for (size_t i = 0; i < n; ++i) {
246 // "The SEL witnesses are mutually exclusive."
247 // The bit plucker guarantees that the SEL witnesses
248 // are bits, but in principle one could feed an
249 // out-of-domain input to the bit plucker that
250 // sets more than one bit.
251 // Another way to accomplish the same effect would
252 // be to range-check the input to the bit plucker.
253 for (size_t l = 0; l < kNCounters; ++l) {
254 for (size_t m = l + 1; m < kNCounters; ++m) {
255 L.assert0(L.land(&ps[i].sel[l], ps[i].sel[m]));
256 }
257 }
258
259 // "at a header, at least one SEL bit is set"
260 auto sum = L.bit(0);
261 for (size_t l = 0; l < kNCounters; ++l) {
262 // known to be exclusive by the test above
263 sum = L.lor_exclusive(&sum, ps[i].sel[l]);
264 }
265 L.assert_implies(&ds[i].header, sum);
266 }
267
268 // "All counters are zero at the end of the input"
269 // COUNTER[I][L] is the state of the parser at the end
270 // of position I, so COUNTER[N-1][L] is the final state.
271 for (size_t l = 0; l < kNCounters; ++l) {
272 ctr_.assert0(ps[n - 1].c[l]);
273 }
274
275 // SEL[0][0] is set. We implicitly define COUNTER[-1][L] to make
276 // this the correct choice.
277 L.assert1(ps[0].sel[0]);
278
279 for (size_t i = 0; i + 1 < n; ++i) {
280 // "If SEL[I+1][L] is set, then COUNTER[I][L] is the nonzero
281 // counter of maximal L. (COUNTER[I][L] contains the output
282 // counter of stage I, which affects SEL[I+1].) Here we check
283 // maximality: COUNTER[I][J]=0 for J>L. See below for
284 // SEL[I+1][L] => (COUNTER[I][L] != 0).
285 BitW b = ps[i + 1].sel[0];
286 for (size_t l = 1; l < kNCounters; ++l) {
287 // b => COUNTER[i][l] == 0
288 ctr_.assert0(ctr_.ite0(&b, ps[i].c[l]));
289 b = L.lor(&b, ps[i + 1].sel[l]);
290 }
291 }
292
293 // "SEL[I+1][L] => (COUNTER[I][L] != 0)"
294 // Check via the invertibility of
295 //
296 // PROD_{I, L} SEL[I+1][L] ? COUNTER[I][L] : 1
297 std::vector<EltW> prod(kNCounters);
298 auto one = L.konst(1);
299 for (size_t l = 0; l < kNCounters; ++l) {
300 auto f = [&](size_t i) {
301 EltW cc = ctr_.znz_indicator(ps[i].c[l]);
302 return L.mux(&ps[i + 1].sel[l], &cc, one);
303 };
304 prod[l] = L.mul(0, n - 1, f);
305 }
306
307 EltW p = L.mul(0, kNCounters, [&](size_t l) { return prod[l]; });
308 auto want_one = L.mul(&p, gw.invprod_parse);
309 L.assert_eq(&want_one, one);
310 }
311
312 //------------------------------------------------------------
313 // "J is the header of a string of length LEN containing BYTES"
314 //------------------------------------------------------------
315 void assert_text_at(size_t n, const vindex& j, size_t len,
316 const uint8_t bytes[/*len*/],
317 const decode ds[/*n*/]) const {
318 const Logic& L = l_; // shorthand
319 const Routing<Logic> R(L);
320
321 // we don't handle long strings
322 proofs::check(len < 24, "len < 24");
323
324 assert_header(n, j, ds);
325
326 std::vector<EltW> A(n);
327 for (size_t i = 0; i < n; ++i) {
328 A[i] = ds[i].bd.as_scalar;
329 }
330
331 // shift len+1 bytes, including the header.
332 std::vector<EltW> B(len + 1);
333 const EltW defaultA = L.konst(256); // a constant that cannot appear in A[]
334 R.shift(j, len + 1, B.data(), n, A.data(), defaultA, /*unroll=*/3);
335
336 size_t expected_header = (3 << 5) + len;
337 L.assert_eq(&B[0], L.konst(expected_header));
338 for (size_t i = 0; i < len; ++i) {
339 auto bi = L.konst(bytes[i]);
340 L.assert_eq(&B[i + 1], bi);
341 }
342 }
343
344 //------------------------------------------------------------
345 // "J is a header containing unsigned U."
346 //------------------------------------------------------------
347 void assert_unsigned_at(size_t n, const vindex& j, uint64_t u,
348 const decode ds[/*n*/]) const {
349 // only small u for now
350 proofs::check(u < 24, "u < 24");
351
352 size_t expected = (0 << 5) + u;
353 assert_atom_at(n, j, l_.konst(expected), ds);
354 }
355
356 //------------------------------------------------------------
357 // "J is a header containing negative U." (U >= 0, and
358 // CBOR distinguishes 0 from -0 apparently)
359 //------------------------------------------------------------
360 void assert_negative_at(size_t n, const vindex& j, uint64_t u,
361 const decode ds[/*n*/]) const {
362 // only small u for now
363 proofs::check(u < 24, "u < 24");
364
365 size_t expected = (1 << 5) + u;
366 assert_atom_at(n, j, l_.konst(expected), ds);
367 }
368
369 //------------------------------------------------------------
370 // "J is a header containing a boolean primitive (0xF4 or 0xF5)."
371 //
372 //------------------------------------------------------------
373 void assert_bool_at(size_t n, const vindex& j, bool val,
374 const decode ds[/*n*/]) const {
375 size_t expected = (7 << 5) + (val ? 21 : 20);
376 assert_atom_at(n, j, l_.konst(expected), ds);
377 }
378
379 // Helps assemble the checks for date assertions.
380 void date_helper(size_t n, const vindex& j, const decode ds[/*n*/],
381 std::vector<v8>& B /* size 22 */) const {
382 const Logic& L = l_; // shorthand
383 const Routing<Logic> R(L);
384 assert_header(n, j, ds);
385
386 std::vector<v8> A(n);
387 for (size_t i = 0; i < n; ++i) {
388 A[i] = ds[i].bd.as_bits;
389 }
390
391 const v8 defaultA =
392 L.template vbit<8>(0); // a constant that cannot appear in A[]
393 R.shift(j, 20 + 2, B.data(), n, A.data(), defaultA, /*unroll=*/3);
394
395 // Check for tag: date/time string.
396 L.vassert_eq(&B[0], L.template vbit<8>(0xc0));
397
398 // Check for string(20)
399 L.vassert_eq(&B[1], L.template vbit<8>(0x74));
400 }
401
402 //------------------------------------------------------------
403 // "J is a header containing date d < now." now is 20 bytes
404 // in the format 2023-11-01T09:00:00Z
405 //------------------------------------------------------------
406 void assert_date_before_at(size_t n, const vindex& j, const v8 now[/* 20 */],
407 const decode ds[/*n*/]) const {
408 const Logic& L = l_; // shorthand
409 const Memcmp<Logic> CMP(L);
410 std::vector<v8> B(20 + 2);
411 date_helper(n, j, ds, B);
412 auto lt = CMP.lt(20, &B[2], now);
413 L.assert1(lt);
414 }
415
416 //------------------------------------------------------------
417 // "J is a header containing date d > now." now is 20 bytes in the
418 // format 2023-11-01T09:00:00Z
419 // ------------------------------------------------------------
420 void assert_date_after_at(size_t n, const vindex& j, const v8 now[/* 20 */],
421 const decode ds[/*n*/]) const {
422 const Logic& L = l_; // shorthand
423 const Memcmp<Logic> CMP(L);
424 std::vector<v8> B(20 + 2);
425 date_helper(n, j, ds, B);
426 auto lt = CMP.lt(20, now, &B[2]);
427 L.assert1(lt);
428 }
429
430 //------------------------------------------------------------
431 // "J is a header containing represented by the byte EXPECTED in the
432 // input."
433 //------------------------------------------------------------
434 void assert_atom_at(size_t n, const vindex& j, const EltW& expected,
435 const decode ds[/*n*/]) const {
436 const Logic& L = l_; // shorthand
437 const Routing<Logic> R(L);
438
439 assert_header(n, j, ds);
440
441 std::vector<EltW> A(n);
442 for (size_t i = 0; i < n; ++i) {
443 A[i] = ds[i].bd.as_scalar;
444 }
445
446 EltW B[1];
447 size_t unroll = 3;
448 R.shift(j, 1, B, n, A.data(), L.konst(256), unroll);
449 L.assert_eq(&B[0], expected);
450 }
451
452 //------------------------------------------------------------
453 // "Position j contains a header"
454 //------------------------------------------------------------
455 void assert_header(size_t n, const vindex& j, const decode ds[/*n*/]) const {
456 const Logic& L = l_; // shorthand
457
458 L.vassert_is_bit(j);
459
460 // giant dot product since the veq(j, .) terms are mutually exclusive.
461 auto f = [&](size_t i) { return L.land(&ds[i].header, L.veq(j, i)); };
462 L.assert1(L.lor_exclusive(0, n, f));
463 }
464
465 //------------------------------------------------------------
466 // "A map starts at position j"
467 //------------------------------------------------------------
468 void assert_map_header(size_t n, const vindex& j,
469 const decode ds[/*n*/]) const {
470 const Logic& L = l_; // shorthand
471
472 L.vassert_is_bit(j);
473
474 // giant dot product since the veq(j, .) terms are mutually exclusive.
475 auto f = [&](size_t i) {
476 auto eq_ji = L.veq(j, i);
477 auto dsi = L.land(&ds[i].bd.mapp, ds[i].header);
478 return L.land(&eq_ji, dsi);
479 };
480 L.assert1(L.lor_exclusive(0, n, f));
481 }
482
483 //------------------------------------------------------------
484 // "Position M starts a map of level LEVEL. (K, V) are headers
485 // representing the J-th pair in that map"
486 //------------------------------------------------------------
487 void assert_map_entry(size_t n, const vindex& m, size_t level,
488 const vindex& k, const vindex& v, const vindex& j,
489 const decode ds[/*n*/],
490 const parse_output ps[/*n*/]) const {
491 const Logic& L = l_; // shorthand
492 const Routing<Logic> R(L);
493
494 assert_map_header(n, m, ds);
495 assert_header(n, k, ds);
496 assert_header(n, v, ds);
497
498 for (size_t l = 0; l < kNCounters; ++l) {
499 // Hack: temporarily treat CEltW as EltW so as to reuse
500 // the shifter.
501 std::vector<EltW> A(n);
502 for (size_t i = 0; i < n; ++i) {
503 A[i] = ps[i].c[l].e;
504 }
505
506 // Select counters[m], counters[k], and counters[v].
507 CEltW cm, ck, cv;
508
509 const size_t unroll = 3;
510 R.shift(m, 1, &cm.e, n, A.data(), L.konst(0), unroll);
511 R.shift(k, 1, &ck.e, n, A.data(), L.konst(0), unroll);
512 R.shift(v, 1, &cv.e, n, A.data(), L.konst(0), unroll);
513
514 if (l <= level) {
515 // Counters[L] must agree at the key, value, and root
516 // of the map.
517 ctr_.assert_eq(&cm, ck);
518 ctr_.assert_eq(&cm, cv);
519 } else if (l == level + 1) {
520 CEltW one = ctr_.as_counter(1);
521 CEltW two = ctr_.as_counter(2);
522 // LEVEL+1 counters must have the right number of decrements.
523 // Specifically, if the counter at the map is N, then the j-th
524 // key has N-(2*j+1) and the j-th value has N-(2*j+2)
525 CEltW jctr = ctr_.as_counter(j);
526 CEltW twoj = ctr_.add(&jctr, jctr);
527 ctr_.assert_eq(&cm, ctr_.add(&ck, ctr_.add(&twoj, one)));
528 ctr_.assert_eq(&cm, ctr_.add(&cv, ctr_.add(&twoj, two)));
529 } else {
530 // not sure if this is necessary, but all other counters
531 // of CM are supposed to be zero.
532 ctr_.assert0(cm);
533 }
534 }
535 }
536
537 //------------------------------------------------------------
538 // "JROOT is the first byte of the actual (unpadded) input and
539 // all previous bytes are 0"
540 //------------------------------------------------------------
541 void assert_input_starts_at(size_t n, const vindex& jroot,
542 const vindex& input_len,
543 const decode ds[/*n*/]) const {
544 const Logic& L = l_; // shorthand
545
546 L.assert1(L.vleq(input_len, n));
547 L.assert1(L.vlt(jroot, n));
548 auto tot = L.vadd(jroot, input_len);
549 L.vassert_eq(tot, n);
550
551 for (size_t i = 0; i < n; ++i) {
552 L.assert0(L.lmul(&ds[i].bd.as_scalar, L.vlt(i, jroot)));
553 }
554 }
555
556 //------------------------------------------------------------
557 // Utilities
558 //------------------------------------------------------------
559 // The circuit accepts up to N input positions, of which
560 // INPUT_LEN are actual input and the rest are ignored.
561 void decode_all(size_t n, decode ds[/*n*/], const v8 in[/*n*/],
562 const position_witness pw[/*n*/]) const {
563 for (size_t i = 0; i < n; ++i) {
564 ds[i].bd = bd_.decode_one_v8(in[i]);
565 ds[i].header = bp_.pluckb(pw[i].encoded_sel_header);
566 }
567 }
568
569 void decode_and_assert_decode(size_t n, decode ds[/*n*/], const v8 in[/*n*/],
570 const position_witness pw[/*n*/],
571 const global_witness& gw) const {
572 decode_all(n, ds, in, pw);
573 assert_decode(n, ds, pw, gw);
574 }
575
576 void decode_and_assert_decode_and_parse(size_t n, decode ds[/*n*/],
577 parse_output ps[/*n*/],
578 const v8 in[/*n*/],
579 const position_witness pw[/*n*/],
580 const global_witness& gw) const {
581 decode_and_assert_decode(n, ds, in, pw, gw);
582 parse(n, ps, ds, pw, gw);
583 assert_parse(n, ds, ps, gw);
584 }
585
586 private:
587 const Logic& l_;
588 const CounterL ctr_;
589 const CborBD bd_;
590 const CborPlucker<Logic, kNCounters> bp_;
591};
592} // namespace proofs
593
594#endif // PRIVACY_PROOFS_ZK_LIB_CIRCUITS_CBOR_PARSER_CBOR_H_
Definition cbor_byte_decoder.h:25
Definition logic.h:38
Definition scan.h:24
Definition cbor.h:67
Definition cbor.h:54
Definition cbor.h:157
Definition cbor.h:60
Definition cbor_byte_decoder.h:39
Definition logic.h:130