Longfellow ZK 0290cb32
Loading...
Searching...
No Matches
prover_layers.h
1// Copyright 2025 Google LLC.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#ifndef PRIVACY_PROOFS_ZK_LIB_SUMCHECK_PROVER_LAYERS_H_
16#define PRIVACY_PROOFS_ZK_LIB_SUMCHECK_PROVER_LAYERS_H_
17
18#include <stddef.h>
19
20#include <memory>
21#include <vector>
22
23#include "arrays/affine.h"
24#include "arrays/dense.h"
25#include "arrays/eqs.h"
26#include "sumcheck/circuit.h"
27#include "sumcheck/quad.h"
28#include "sumcheck/transcript_sumcheck.h"
29#include "util/panic.h"
30
31namespace proofs {
32
33// The high-level idea is partially described in chapter 4.6.7, "Leveraging
34// Data Parallelism for Further Speedups", of the book "Proofs, Arguments,
35// and Zero-Knowledge" by Justin Thaler.
36template <class Field>
37class ProverLayers {
38 using Elt = typename Field::Elt;
39
40 public:
41 using inputs = std::vector<std::unique_ptr<Dense<Field>>>;
42
43 explicit ProverLayers(const Field& f) : f_(f) {}
44
45 // Evaluate CIRCUIT on input wires W0. This function stores the
46 // input wires of each layer L into IN->at(L), and returns the
47 // final output. This asymmetry reflects the fact that for L
48 // layers there are L+1 meaningful sets of wires, and that the
49 // prover needs IN while the verifier needs the final output.
50 std::unique_ptr<Dense<Field>> eval_circuit(inputs* in,
51 const Circuit<Field>* circ,
52 std::unique_ptr<Dense<Field>> W0,
53 const Field& F) {
54 if (in == nullptr || circ == nullptr || W0 == nullptr) return nullptr;
55
56 std::unique_ptr<Dense<Field>> finalV;
57 size_t nl = circ->nl, nc = circ->nc;
58 check(nl >= 1, "nl >= 1");
59 check(nc >= 1, "nc >= 1");
60
61 Dense<Field>* W = W0.get();
62
63 in->resize(nl);
64 in->at(nl - 1).swap(W0);
65
66 // Allocate memory and evaluate layer on input W and output V
67 for (size_t l = nl; l-- > 0;) {
68 Dense<Field>* V;
69 if (l > 0) {
70 // input of layer l-1 = output of layer l
71 in->at(l - 1) = std::make_unique<Dense<Field>>(nc, circ->l[l - 1].nw);
72 V = in->at(l - 1).get();
73 } else {
74 // final output = output of layer 0
75 finalV = std::make_unique<Dense<Field>>(nc, circ->nv);
76 V = finalV.get();
77 }
78
79 bool ok = eval_quad(circ->l[l].quad.get(), V, W, F);
80 if (!ok) {
81 // Early exit in case of assertion failure.
82 // In this case IN is only partially allocated.
83 // To avoid ambiguities, free all memory that we may have allocated.
84 for (size_t i = 0; i < nl; ++i) {
85 in->at(i) = nullptr;
86 }
87 finalV = nullptr;
88
89 return /*finalV=*/nullptr;
90 }
91
92 W = V;
93 }
94
95 return finalV;
96 }
97
98 protected:
99 const Field& f_;
100
101 // A struct that collects the bindings generated while proving one
102 // layer, to serve as initial bindings for the next layer.
103 // This protected class must be defined before the public section.
104 struct bindings {
105 size_t logv;
106 Elt q[Proof<Field>::kMaxBindings];
107 Elt g[2][Proof<Field>::kMaxBindings];
108 };
109
110 // Generate proof for circuit, as a protected member, the caller must
111 // ensure that input parameters are valid.
112 void prove(Proof<Field>* pr, const Proof<Field>* pad,
113 const Circuit<Field>* circ, const inputs& in, ProofAux<Field>* aux,
114 bindings& bnd, TranscriptSumcheck<Field>& ts, const Field& F) {
115 size_t logc = circ->logc;
116 corner_t nc = circ->nc;
117
118 check(circ->logv <= Proof<Field>::kMaxBindings,
119 "CIRCUIT->logv <= kMaxBindings");
120 bnd.logv = circ->logv;
121
122 // obtain the initial Q and G[0] bindings from the verifier
123 ts.begin_circuit(bnd.q, bnd.g[0]);
124
125 // Duplicate the g[0] binding.
126 // In general, the prover step takes two claims G[0], G[1] on the output
127 // wires and reduces them to one claim on G[0] + alpha * G[1] for random
128 // alpha. However, in the first step, there is only one claim, so we
129 // need to make up G[1]. The code sets G[1] = G[0] and it doesn't affect
130 // soundness.
131 for (size_t i = 0; i < bnd.logv; ++i) {
132 bnd.g[1][i] = bnd.g[0][i];
133 }
134
135 for (size_t ly = 0; ly < circ->nl; ++ly) {
136 auto clr = &circ->l.at(ly);
137 Elt alpha, beta;
138 ts.begin_layer(alpha, beta, ly);
139 Eqs<Field> EQ(logc, nc, bnd.q, F);
140 auto QUAD = clr->quad->clone();
141 QUAD->bind_g(bnd.logv, bnd.g[0], bnd.g[1], alpha, beta, F);
142
143 layer(pr, pad, ts, bnd, ly, logc, clr->logw, &EQ, QUAD.get(),
144 in.at(ly).get(), F);
145
146 if (aux != nullptr) {
147 aux->bound_quad[ly] = QUAD->scalar();
148 }
149 }
150 }
151
152 private:
153 using index_t = typename Quad<Field>::index_t;
154 using CPoly = typename LayerProof<Field>::CPoly;
155 using WPoly = typename LayerProof<Field>::WPoly;
156 using FCPoly = typename LayerProof<Field>::FCPoly;
157 using FWPoly = typename LayerProof<Field>::FWPoly;
158
159 /*
160 Engage in single-layer sumcheck on
161
162 EQ[|c] QUAD[|r,l] W[r,c] W[l,c]
163
164 Bind c to C, r to R, and l to L (in that order). Store claims
165 W[R,C] and W[L,C] in the proof, and set BND to the new bindings for
166 the next layer.
167
168 logw: number of sumcheck rounds in r, l
169 logc: number of sumcheck rounds in c
170 */
171 void layer(Proof<Field>* pr, const Proof<Field>* pad,
172 TranscriptSumcheck<Field>& ts, bindings& bnd, size_t layer,
173 size_t logc, size_t logw, Eqs<Field>* EQ, Quad<Field>* QUAD,
174 Dense<Field>* W, const Field& F) {
175 check(EQ->n() == W->n0_, "EQ->n() == W->n0_");
176
177 check(logw <= Proof<Field>::kMaxBindings, "logw <= kMaxBindings");
178 bnd.logv = logw;
179
180 // Bind the C variables to Q.
181 // Note that binding C variables takes O(number_of_copies * circuit_size)
182 // while binding R, L takes O(circuit_size * log(circuit_size)). In most
183 // cases number_of_copies > log(circuit_size), so we don't have to
184 // optimize binding R, L.
185 for (size_t round = 0; round < logc; ++round) {
186 CPoly sum{};
187
188 // sum over r,l: QUAD[|r,l] EQ[|c] W[r,c] W[l,c]
189 for (index_t i = 0; i < QUAD->n_; i++) {
190 corner_t r(QUAD->c_[i].h[0]);
191 corner_t l(QUAD->c_[i].h[1]);
192
193 // sum over c: EQ[|c] W[r,c] W[l,c]
194 CPoly sumc{};
195
196 // n0_ is the copy dimension, n1_ is the wire dimension.
197 for (corner_t c = 0; c < W->n0_; c += 2) {
198 CPoly poly = cpoly_at_dense(EQ, c, 0, F)
199 .mul(cpoly_at_dense(W, c, r, F), F)
200 .mul(cpoly_at_dense(W, c, l, F), F);
201 sumc.add(poly, F);
202 }
203
204 sumc.mul_scalar(QUAD->c_[i].v, F);
205 sum.add(sumc, F);
206 }
207
208 Elt rnd = round_c(pr, pad, ts, layer, round, sum, F);
209 bnd.q[round] = rnd;
210
211 // bind the c variable in both EQ and W
212 EQ->bind(rnd, F);
213 W->bind(rnd, F);
214 }
215
216 Elt eq0 = EQ->scalar();
217
218 W->reshape(W->n1_);
219 check(W->n1_ == 1, "W->n1_ == 1");
220
221 auto Wclone = W->clone(); // keep alive until function end
222 Dense<Field>* WH[2] = {W, Wclone.get()}; // reuse W
223
224 for (size_t round = 0; round < logw; ++round) {
225 for (size_t hand = 0; hand < 2; hand++) {
226 // In SUM_{l,r} Q[l,r] W[l] W[r], first precompute QW[l] =
227 // SUM_{r} Q[l,r] W[r] as a dense array, and then compute
228 // SUM_{l} QW[l] W[l].
229 Dense<Field> QW(WH[hand]->n0_, 1);
230 QW.clear(F);
231 size_t ohand = 1 - hand;
232
233 // QW[l] = SUM_{r} Q[l,r] W[r]
234 for (index_t i = 0; i < QUAD->n_; ++i) {
235 corner_t p0(QUAD->c_[i].h[hand]);
236 corner_t p1(QUAD->c_[i].h[ohand]);
237 F.add(QW.v_[p0], F.mulf(QUAD->c_[i].v, WH[ohand]->v_[p1]));
238 }
239
240 // SUM_{l} QW[l] W[l].
241 WPoly sum{};
242 for (corner_t l = 0; l < QW.n0_; l += 2) {
243 WPoly poly = wpoly_at_dense(WH[hand], l, 0, F)
244 .mul(wpoly_at_dense(&QW, l, 0, F), F);
245 sum.add(poly, F);
246 }
247
248 sum.mul_scalar(eq0, F);
249 Elt rnd = round_h(pr, pad, ts, layer, hand, round, sum, F);
250 bnd.g[hand][round] = rnd;
251
252 // bind the r variable in W[hand] and QUAD
253 WH[hand]->bind(rnd, F);
254 QUAD->bind_h(rnd, hand, F);
255 }
256 }
257
258 QUAD->scalar(); // for the side effect of assertions
259 Elt WC[2] = {WH[0]->scalar(), WH[1]->scalar()};
260 end_layer(pr, pad, ts, layer, WC, F);
261 }
262
263 // Evaluate the quadratic form
264 //
265 // V[g,c] = QUAD[g|r,l] W[r,c] W[l,c]
266 //
267 // Returns false in the case the quad is an assert0 check that fails.
268 bool eval_quad(const Quad<Field>* quad, Dense<Field>* V,
269 const Dense<Field>* W, const Field& F) {
270 check(V->n0_ == W->n0_, "V->n0_ == W->n0_");
271 corner_t n0 = V->n0_;
272
273 V->clear(F);
274 for (index_t i = 0; i < quad->n_; i++) {
275 corner_t g(quad->c_[i].g);
276 corner_t r(quad->c_[i].h[0]);
277 corner_t l(quad->c_[i].h[1]);
278 for (corner_t c = 0; c < n0; ++c) {
279 auto x = quad->c_[i].v;
280 if (x == F.zero()) {
281 // assert that the computed W[l]W[r] is zero.
282 auto y = W->v_[n0 * l + c];
283 F.mul(y, W->v_[n0 * r + c]);
284 if (y != F.zero()) {
285 return false;
286 }
287 } else {
288 F.mul(x, W->v_[n0 * l + c]);
289 F.mul(x, W->v_[n0 * r + c]);
290 F.add(V->v_[n0 * g + c], x);
291 }
292 }
293 }
294 return true;
295 }
296
297 Elt /*R*/ round_c(Proof<Field>* pr, const Proof<Field>* pad,
298 TranscriptSumcheck<Field>& ts, size_t layer, size_t round,
299 CPoly poly, const Field& F) {
300 check(round <= Proof<Field>::kMaxBindings, "round <= kMaxBindings");
301
302 if (pad) {
303 poly.sub(pad->l[layer].cp[round], F);
304 }
305
306 pr->l[layer].cp[round] = poly;
307 return ts.round(poly);
308 }
309
310 Elt /*R*/ round_h(Proof<Field>* pr, const Proof<Field>* pad,
311 TranscriptSumcheck<Field>& ts, size_t layer, size_t hand,
312 size_t round, WPoly poly, const Field& F) {
313 check(round <= Proof<Field>::kMaxBindings, "round <= kMaxBindings");
314 if (pad) {
315 poly.sub(pad->l[layer].hp[hand][round], F);
316 }
317 pr->l[layer].hp[hand][round] = poly;
318 return ts.round(poly);
319 }
320
321 void end_layer(Proof<Field>* pr, const Proof<Field>* pad,
322 TranscriptSumcheck<Field>& ts, size_t layer, const Elt wc[2],
323 const Field& F) {
324 Elt tt[2] = {wc[0], wc[1]};
325 if (pad) {
326 F.sub(tt[0], pad->l[layer].wc[0]);
327 F.sub(tt[1], pad->l[layer].wc[1]);
328 }
329
330 pr->l[layer].wc[0] = tt[0];
331 pr->l[layer].wc[1] = tt[1];
332
333 ts.write(tt, 1, 2);
334 }
335
336 CPoly cpoly_at_dense(const Dense<Field>* D, corner_t p0, corner_t p1,
337 const Field& F) {
338 auto tmp = FCPoly::extend(D->t2_at_corners(p0, p1, F), F);
339 return CPoly(tmp);
340 }
341
342 WPoly wpoly_at_dense(const Dense<Field>* D, corner_t p0, corner_t p1,
343 const Field& F) {
344 auto tmp = FWPoly::extend(D->t2_at_corners(p0, p1, F), F);
345 return WPoly(tmp);
346 }
347};
348} // namespace proofs
349
350#endif // PRIVACY_PROOFS_ZK_LIB_SUMCHECK_PROVER_LAYERS_H_
Definition dense.h:37
Definition eqs.h:32
Definition transcript_sumcheck.h:32
Definition circuit.h:45
Definition gf2_128.h:63
Definition circuit.h:151
Definition circuit.h:132
Definition prover_layers.h:104