Longfellow ZK 0290cb32
Loading...
Searching...
No Matches
prover_layers.h
1// Copyright 2025 Google LLC.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#ifndef PRIVACY_PROOFS_ZK_LIB_SUMCHECK_PROVER_LAYERS_H_
16#define PRIVACY_PROOFS_ZK_LIB_SUMCHECK_PROVER_LAYERS_H_
17
18#include <stddef.h>
19
20#include <memory>
21#include <vector>
22
23#include "arrays/affine.h"
24#include "arrays/dense.h"
25#include "arrays/eqs.h"
26#include "sumcheck/circuit.h"
27#include "sumcheck/quad.h"
28#include "sumcheck/transcript_sumcheck.h"
29#include "util/panic.h"
30
31namespace proofs {
32
33// A high level idea is partially described in chapter 4.6.7 "Leveraging Data
34// Parallelism for Further Speedups" in the book "Proofs, Arguments, and
35// Zero-Knowledge" by Justin Thaler.
36template <class Field>
37class ProverLayers {
38 using Elt = typename Field::Elt;
39
40 public:
41 using inputs = std::vector<std::unique_ptr<Dense<Field>>>;
42
43 explicit ProverLayers(const Field& f) : f_(f) {}
44
45 // Evaluate CIRCUIT on input wires W0. This function stores the
46 // input wires of each layer L into IN->at(L), and returns the
47 // final output. This asymmetry reflects the fact that for L
48 // layers there are L+1 meaningful sets of wires, and that the
49 // prover needs IN while the verifier needs the final output.
50 std::unique_ptr<Dense<Field>> eval_circuit(inputs* in,
51 const Circuit<Field>* circ,
52 std::unique_ptr<Dense<Field>> W0,
53 const Field& F) {
54 if (in == nullptr || circ == nullptr || W0 == nullptr) return nullptr;
55
56 std::unique_ptr<Dense<Field>> finalV;
57 size_t nl = circ->nl, nc = circ->nc;
58 check(nl >= 1, "nl >= 1");
59 check(nc >= 1, "nc >= 1");
60
61 Dense<Field>* W = W0.get();
62
63 in->resize(nl);
64 in->at(nl - 1).swap(W0);
65
66 // Allocate memory and evaluate layer on input W and output V
67 for (size_t l = nl; l-- > 0;) {
68 Dense<Field>* V;
69 if (l > 0) {
70 // input of layer l-1 = output of layer l
71 in->at(l - 1) = std::make_unique<Dense<Field>>(nc, circ->l[l - 1].nw);
72 V = in->at(l - 1).get();
73 } else {
74 // final output = output of layer 0
75 finalV = std::make_unique<Dense<Field>>(nc, circ->nv);
76 V = finalV.get();
77 }
78
79 bool ok = eval_quad(circ->l[l].quad.get(), V, W, F);
80 if (!ok) {
81 // Early exit in case of assertion failure.
82 // In this case IN is only partially allocated.
83 // To avoid ambiguities, free all memory that we may have allocated.
84 for (size_t i = 0; i < nl; ++i) {
85 in->at(i) = nullptr;
86 }
87 finalV = nullptr;
88
89 return /*finalV=*/nullptr;
90 }
91
92 W = V;
93 }
94
95 return finalV;
96 }
97
98 protected:
99 const Field& f_;
100
101 // A struct that collects the bindings generated while proving one
102 // layer, to serve as initial bindings for the next layer.
103 // This protected class must be defined before the public section.
104 struct bindings {
105 size_t logv;
106 Elt q[Proof<Field>::kMaxBindings];
107 Elt g[2][Proof<Field>::kMaxBindings];
108 };
109
110 // Generate proof for circuit, as a protected member, the caller must
111 // ensure that input parameters are valid.
112 void prove(Proof<Field>* pr, const Proof<Field>* pad,
113 const Circuit<Field>* circ, const inputs& in, ProofAux<Field>* aux,
114 bindings& bnd, TranscriptSumcheck<Field>& ts, const Field& F) {
115 size_t logc = circ->logc;
116 corner_t nc = circ->nc;
117
118 check(circ->logv <= Proof<Field>::kMaxBindings,
119 "CIRCUIT->logv <= kMaxBindings");
120 bnd.logv = circ->logv;
121
122 // obtain the initial Q and G[0] bindings from the verifier
123 ts.begin_circuit(bnd.q, bnd.g[0]);
124
125 // Duplicate the g[0] binding.
126 // In general, the prover step takes two claims G[0], G[1] on the output
127 // wires and reduces them to one claim on G[0] + alpha * G[1] for random
128 // alpha. However, in the first step, there is only one claim, so we
129 // need to make up G[1]. The code sets G[1] = G[0] and it doesn't affect
130 // soundness.
131 for (size_t i = 0; i < bnd.logv; ++i) {
132 bnd.g[1][i] = bnd.g[0][i];
133 }
134
135 for (size_t ly = 0; ly < circ->nl; ++ly) {
136 auto clr = &circ->l.at(ly);
137 Elt alpha, beta;
138 ts.begin_layer(alpha, beta, ly);
139 Eqs<Field> EQ(logc, nc, bnd.q, F);
140 auto QUAD = clr->quad->clone();
141 QUAD->bind_g(bnd.logv, bnd.g[0], bnd.g[1], alpha, beta, F);
142
143 layer(pr, pad, ts, bnd, ly, logc, clr->logw, &EQ, QUAD.get(),
144 in.at(ly).get(), F);
145
146 if (aux != nullptr) {
147 aux->bound_quad[ly] = QUAD->scalar();
148 }
149 }
150 }
151
152 private:
153 using index_t = typename Quad<Field>::index_t;
154 using CPoly = typename LayerProof<Field>::CPoly;
155 using WPoly = typename LayerProof<Field>::WPoly;
156
157 /*
158 Engage in single-layer sumcheck on
159
160 EQ[|c] QUAD[|r,l] W[r,c] W[l,c]
161
162 Bind c to C, r to R, and l to L (in that order). Store claims
163 W[R,C] and W[L,C] in the proof, and set BND to the new bindings for
164 the next layer.
165
166 logw: number of sumcheck rounds in r, l
167 logc: number of sumcheck rounds in c
168 */
169 void layer(Proof<Field>* pr, const Proof<Field>* pad,
170 TranscriptSumcheck<Field>& ts, bindings& bnd, size_t layer,
171 size_t logc, size_t logw, Eqs<Field>* EQ, Quad<Field>* QUAD,
172 Dense<Field>* W, const Field& F) {
173 check(EQ->n() == W->n0_, "EQ->n() == W->n0_");
174
175 check(logw <= Proof<Field>::kMaxBindings, "logw <= kMaxBindings");
176 bnd.logv = logw;
177
178 // Bind the C variables to Q.
179 // Note that binding C variables takes O(number_of_copies * circuit_size)
180 // while binding R, L takes O(circuit_size * log(circuit_size)). In most
181 // cases number_of_copies > log(circuit_size), so we don't have to
182 // optimize binding R, L.
183 for (size_t round = 0; round < logc; ++round) {
184 CPoly sum{};
185
186 // sum over r,l: QUAD[|r,l] EQ[|c] W[r,c] W[l,c]
187 for (index_t i = 0; i < QUAD->n_; i++) {
188 corner_t r(QUAD->c_[i].h[0]);
189 corner_t l(QUAD->c_[i].h[1]);
190
191 // sum over c: EQ[|c] W[r,c] W[l,c]
192 CPoly sumc{};
193
194 // n0_ is the copy dimension, n1_ is the wire dimension.
195 for (corner_t c = 0; c < W->n0_; c += 2) {
196 CPoly poly = cpoly_at_dense(EQ, c, 0, F)
197 .mul(cpoly_at_dense(W, c, r, F), F)
198 .mul(cpoly_at_dense(W, c, l, F), F);
199 sumc.add(poly, F);
200 }
201
202 sumc.mul_scalar(QUAD->c_[i].v, F);
203 sum.add(sumc, F);
204 }
205
206 Elt rnd = round_c(pr, pad, ts, layer, round, sum, F);
207 bnd.q[round] = rnd;
208
209 // bind the c variable in both EQ and W
210 EQ->bind(rnd, F);
211 W->bind(rnd, F);
212 }
213
214 Elt eq0 = EQ->scalar();
215
216 W->reshape(W->n1_);
217 check(W->n1_ == 1, "W->n1_ == 1");
218
219 auto Wclone = W->clone(); // keep alive until function end
220 Dense<Field>* WH[2] = {W, Wclone.get()}; // reuse W
221
222 for (size_t round = 0; round < logw; ++round) {
223 for (size_t hand = 0; hand < 2; hand++) {
224 // In SUM_{l,r} Q[l,r] W[l] W[r], first precompute QW[l] =
225 // SUM_{r} Q[l,r] W[r] as a dense array, and then compute
226 // SUM_{l} QW[l] W[l].
227 Dense<Field> QW(WH[hand]->n0_, 1);
228 QW.clear(F);
229 size_t ohand = 1 - hand;
230
231 // QW[l] = SUM_{r} Q[l,r] W[r]
232 for (index_t i = 0; i < QUAD->n_; ++i) {
233 corner_t p0(QUAD->c_[i].h[hand]);
234 corner_t p1(QUAD->c_[i].h[ohand]);
235 F.add(QW.v_[p0], F.mulf(QUAD->c_[i].v, WH[ohand]->v_[p1]));
236 }
237
238 // SUM_{l} QW[l] W[l].
239 WPoly sum{};
240 for (corner_t l = 0; l < QW.n0_; l += 2) {
241 WPoly poly = wpoly_at_dense(WH[hand], l, 0, F)
242 .mul(wpoly_at_dense(&QW, l, 0, F), F);
243 sum.add(poly, F);
244 }
245
246 sum.mul_scalar(eq0, F);
247 Elt rnd = round_h(pr, pad, ts, layer, hand, round, sum, F);
248 bnd.g[hand][round] = rnd;
249
250 // bind the r variable in W[hand] and QUAD
251 WH[hand]->bind(rnd, F);
252 QUAD->bind_h(rnd, hand, F);
253 }
254 }
255
256 QUAD->scalar(); // for the side effect of assertions
257 Elt WC[2] = {WH[0]->scalar(), WH[1]->scalar()};
258 end_layer(pr, pad, ts, layer, WC, F);
259 }
260
261 // Evaluate the quadratic form
262 //
263 // V[g,c] = QUAD[g|r,l] W[r,c] W[l,c]
264 //
265 // Returns false in the case the quad is an assert0 check that fails.
266 bool eval_quad(const Quad<Field>* quad, Dense<Field>* V,
267 const Dense<Field>* W, const Field& F) {
268 check(V->n0_ == W->n0_, "V->n0_ == W->n0_");
269 corner_t n0 = V->n0_;
270
271 V->clear(F);
272 for (index_t i = 0; i < quad->n_; i++) {
273 corner_t g(quad->c_[i].g);
274 corner_t r(quad->c_[i].h[0]);
275 corner_t l(quad->c_[i].h[1]);
276 for (corner_t c = 0; c < n0; ++c) {
277 auto x = quad->c_[i].v;
278 if (x == F.zero()) {
279 // assert that the computed W[l]W[r] is zero.
280 auto y = W->v_[n0 * l + c];
281 F.mul(y, W->v_[n0 * r + c]);
282 if (y != F.zero()) {
283 return false;
284 }
285 } else {
286 F.mul(x, W->v_[n0 * l + c]);
287 F.mul(x, W->v_[n0 * r + c]);
288 F.add(V->v_[n0 * g + c], x);
289 }
290 }
291 }
292 return true;
293 }
294
295 Elt /*R*/ round_c(Proof<Field>* pr, const Proof<Field>* pad,
296 TranscriptSumcheck<Field>& ts, size_t layer, size_t round,
297 CPoly poly, const Field& F) {
298 check(round <= Proof<Field>::kMaxBindings, "round <= kMaxBindings");
299
300 if (pad) {
301 poly.sub(pad->l[layer].cp[round], F);
302 }
303
304 pr->l[layer].cp[round] = poly;
305 return ts.round(poly);
306 }
307
308 Elt /*R*/ round_h(Proof<Field>* pr, const Proof<Field>* pad,
309 TranscriptSumcheck<Field>& ts, size_t layer, size_t hand,
310 size_t round, WPoly poly, const Field& F) {
311 check(round <= Proof<Field>::kMaxBindings, "round <= kMaxBindings");
312 if (pad) {
313 poly.sub(pad->l[layer].hp[hand][round], F);
314 }
315 pr->l[layer].hp[hand][round] = poly;
316 return ts.round(poly);
317 }
318
319 void end_layer(Proof<Field>* pr, const Proof<Field>* pad,
320 TranscriptSumcheck<Field>& ts, size_t layer, const Elt wc[2],
321 const Field& F) {
322 Elt tt[2] = {wc[0], wc[1]};
323 if (pad) {
324 F.sub(tt[0], pad->l[layer].wc[0]);
325 F.sub(tt[1], pad->l[layer].wc[1]);
326 }
327
328 pr->l[layer].wc[0] = tt[0];
329 pr->l[layer].wc[1] = tt[1];
330
331 ts.write(tt, 1, 2);
332 }
333
334 CPoly cpoly_at_dense(const Dense<Field>* D, corner_t p0, corner_t p1,
335 const Field& F) {
336 return CPoly::extend(D->t2_at_corners(p0, p1, F), F);
337 }
338
339 WPoly wpoly_at_dense(const Dense<Field>* D, corner_t p0, corner_t p1,
340 const Field& F) {
341 return WPoly::extend(D->t2_at_corners(p0, p1, F), F);
342 }
343};
344} // namespace proofs
345
346#endif // PRIVACY_PROOFS_ZK_LIB_SUMCHECK_PROVER_LAYERS_H_
Definition dense.h:37
Definition eqs.h:32
Definition transcript_sumcheck.h:32
Definition circuit.h:45
Definition gf2_128.h:63
Definition circuit.h:149
Definition circuit.h:130
Definition prover_layers.h:104