Longfellow ZK 0290cb32
Loading...
Searching...
No Matches
flatsha256_circuit.h
// Copyright 2025 Google LLC.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15#ifndef PRIVACY_PROOFS_ZK_LIB_CIRCUITS_SHA_FLATSHA256_CIRCUIT_H_
16#define PRIVACY_PROOFS_ZK_LIB_CIRCUITS_SHA_FLATSHA256_CIRCUIT_H_
17
18#include <stddef.h>
19
20#include <cstdint>
21#include <vector>
22
23#include "circuits/compiler/compiler.h"
24#include "circuits/logic/bit_adder.h"
25#include "circuits/sha/sha256_constants.h"
26
27namespace proofs {
28// FlatSHA256Circuit
29//
30// Implements SHA256 hash function as an arithmetic circuit over the field F.
31// The circuit is flattened, meaning that the SHA round function has been
32// repeated in parallel instead of sequentially. As a result, the prover must
33// provide the intermediate round values as witnesses.
34//
35// This package does not have any external dependencies on a SHA256 library.
36//
37// There are two versions of this function, one with standard bit inputs, and
38// another with packed bit inputs. The later reduces the number of inputs at
39// the cost of increasing the depth and number of wires. For example, the
40// following shows the difference with pack parameter 2.
41//
42// FlatSHA256_Circuit.assert_transform_block
43// depth: 7 wires: 38029 in: 6657 out:128 use:30897 ovh:7132 t:166468 cse:9703
44// notn:113744
45//
46// FlatSHA256_Circuit.assert_transform_block_packed
47// depth: 9 wires: 65735 in: 3585 out:128 use:55486 ovh:10249 t:214653
48// cse:28135 notn:151504
49//
50//
51template <class Logic, class BitPlucker>
52class FlatSHA256Circuit {
53 public:
54 using v8 = typename Logic::v8;
55 using v256 = typename Logic::v256;
56 using v32 = typename Logic::v32;
57 using EltW = typename Logic::EltW;
58 using Field = typename Logic::Field;
59 using packed_v32 = typename BitPlucker::packed_v32;
60
61 const Logic& l_;
62 BitPlucker bp_; /* public, so caller can encode input */
63
64 struct BlockWitness {
65 packed_v32 outw[48];
66 packed_v32 oute[64];
67 packed_v32 outa[64];
68 packed_v32 h1[8];
69
70 static packed_v32 packed_input(QuadCircuit<typename Logic::Field>& Q) {
71 packed_v32 r;
72 for (size_t i = 0; i < r.size(); ++i) {
73 r[i] = Q.input();
74 }
75 return r;
76 }
77
79 for (size_t k = 0; k < 48; ++k) {
80 outw[k] = packed_input(Q);
81 }
82 for (size_t k = 0; k < 64; ++k) {
83 oute[k] = packed_input(Q);
84 outa[k] = packed_input(Q);
85 }
86 for (size_t k = 0; k < 8; ++k) {
87 h1[k] = packed_input(Q);
88 }
89 }
90 };
91
92 explicit FlatSHA256Circuit(const Logic& l) : l_(l), bp_(l_) {}
93
94 static packed_v32 packed_input(QuadCircuit<Field>& Q) {
95 packed_v32 r;
96 for (size_t i = 0; i < r.size(); ++i) {
97 r[i] = Q.input();
98 }
99 return r;
100 }
101
102 void assert_transform_block(const v32 in[16], const v32 H0[8],
103 const v32 outw[48], const v32 oute[64],
104 const v32 outa[64], const v32 H1[8]) const {
105 const Logic& L = l_; // shorthand
106 BitAdder<Logic, 32> BA(L);
107
108 std::vector<v32> w(64);
109 for (size_t i = 0; i < 16; ++i) {
110 w[i] = in[i];
111 }
112
113 for (size_t i = 16; i < 64; ++i) {
114 auto sw2 = sigma1(w[i - 2]);
115 auto sw15 = sigma0(w[i - 15]);
116 std::vector<v32> terms = {sw2, w[i - 7], sw15, w[i - 16]};
117 w[i] = outw[i - 16];
118 BA.assert_eqmod(w[i], BA.add(terms), 4);
119 }
120
121 v32 a = H0[0];
122 v32 b = H0[1];
123 v32 c = H0[2];
124 v32 d = H0[3];
125 v32 e = H0[4];
126 v32 f = H0[5];
127 v32 g = H0[6];
128 v32 h = H0[7];
129
130 for (size_t t = 0; t < 64; ++t) {
131 auto s1e = Sigma1(e);
132 auto ch = L.vCh(&e, &f, g);
133 auto rt = L.vbit32(kSha256Round[t]);
134 std::vector<v32> t1_terms = {h, s1e, ch, rt, w[t]};
135 EltW t1 = BA.add(t1_terms);
136 EltW sigma0 = BA.as_field_element(Sigma0(a));
137 EltW vmaj = BA.as_field_element(L.vMaj(&a, &b, c));
138 EltW t2 = BA.add(&sigma0, vmaj);
139
140 h = g;
141 g = f;
142 f = e;
143 e = oute[t];
144 EltW ed = BA.as_field_element(d);
145 BA.assert_eqmod(e, BA.add(&t1, ed), 6);
146 d = c;
147 c = b;
148 b = a;
149 a = outa[t];
150 BA.assert_eqmod(a, BA.add(&t1, t2), 7);
151 }
152
153 BA.assert_eqmod(H1[0], BA.add(H0[0], a), 2);
154 BA.assert_eqmod(H1[1], BA.add(H0[1], b), 2);
155 BA.assert_eqmod(H1[2], BA.add(H0[2], c), 2);
156 BA.assert_eqmod(H1[3], BA.add(H0[3], d), 2);
157 BA.assert_eqmod(H1[4], BA.add(H0[4], e), 2);
158 BA.assert_eqmod(H1[5], BA.add(H0[5], f), 2);
159 BA.assert_eqmod(H1[6], BA.add(H0[6], g), 2);
160 BA.assert_eqmod(H1[7], BA.add(H0[7], h), 2);
161 }
162
163 // Packed API.
164 // H0 not packed, all others packed
165 void assert_transform_block(const v32 in[16], const v32 H0[8],
166 const packed_v32 poutw[48],
167 const packed_v32 poute[64],
168 const packed_v32 pouta[64],
169 const packed_v32 pH1[8]) const {
170 std::vector<v32> H1(8);
171 std::vector<v32> outw(48);
172 std::vector<v32> oute(64), outa(64);
173 for (size_t i = 0; i < 8; ++i) {
174 H1[i] = bp_.unpack_v32(pH1[i]);
175 }
176 for (size_t i = 0; i < 48; ++i) {
177 outw[i] = bp_.unpack_v32(poutw[i]);
178 }
179 for (size_t i = 0; i < 64; ++i) {
180 oute[i] = bp_.unpack_v32(poute[i]);
181 outa[i] = bp_.unpack_v32(pouta[i]);
182 }
183 assert_transform_block(in, H0, outw.data(), oute.data(), outa.data(),
184 H1.data());
185 }
186
187 // all packed
188 void assert_transform_block(const v32 in[16], const packed_v32 pH0[8],
189 const packed_v32 poutw[48],
190 const packed_v32 poute[64],
191 const packed_v32 pouta[64],
192 const packed_v32 pH1[8]) const {
193 std::vector<v32> H0(8);
194 for (size_t i = 0; i < 8; ++i) {
195 H0[i] = bp_.unpack_v32(pH0[i]);
196 }
197 assert_transform_block(in, H0.data(), poutw, poute, pouta, pH1);
198 }
199
200 /* This method checks that the block witness corresponds to the iterated
201 computation of the sha block transform on the input.
202 */
203 void assert_message(size_t max, const v8& nb, const v8 in[/* 64*max */],
204 const BlockWitness bw[/*max*/]) const {
205 const Logic& L = l_; // shorthand
206 const packed_v32* H = nullptr;
207 std::vector<v32> tmp(16);
208
209 for (size_t b = 0; b < max; ++b) {
210 const v8* inb = &in[64 * b];
211 for (size_t i = 0; i < 16; ++i) {
212 // big-endian mapping of v8[4] into v32. The first
213 // argument of vappend() is the LSB, and thus +3 is
214 // the LSB and +0 is the MSB, hence big-endian.
215 tmp[i] = L.vappend(L.vappend(inb[4 * i + 3], inb[4 * i + 2]),
216 L.vappend(inb[4 * i + 1], inb[4 * i + 0]));
217 }
218 if (b == 0) {
219 v32 H0[8];
220 initial_context(H0);
221 assert_transform_block(tmp.data(), H0, bw[b].outw, bw[b].oute,
222 bw[b].outa, bw[b].h1);
223 } else {
224 assert_transform_block(tmp.data(), H, bw[b].outw, bw[b].oute,
225 bw[b].outa, bw[b].h1);
226 }
227 H = bw[b].h1;
228 }
229 }
230
231 /* This method checks that the block witness corresponds to the iterated
232 computation of the sha block transform on the prefix || input.
233 */
234 void assert_message_with_prefix(size_t max, const v8& nb,
235 const v8 in[/* < 64*max */],
236 const uint8_t prefix[/* len */], size_t len,
237 const BlockWitness bw[/*max*/]) const {
238 const Logic& L = l_; // shorthand
239 std::vector<v32> tmp(16);
240
241 std::vector<v8> bbuf(64 * max);
242 for (size_t i = 0; i < len; ++i) {
243 L.bits(8, bbuf[i].data(), prefix[i]);
244 }
245 for (size_t i = 0; i + len < 64 * max; ++i) {
246 bbuf[i + len] = in[i];
247 }
248
249 assert_message(max, nb, bbuf.data(), bw);
250 }
251
252 /* This method checks if H(in) == target. The method requires that in[]
253 contains exactly nb*64 bytes and has been padded according to the SHA256
254 specification.
255 */
256 void assert_message_hash(size_t max, const v8& nb, const v8 in[/* 64*max */],
257 const v256& target,
258 const BlockWitness bw[/*max*/]) const {
259 assert_message(max, nb, in, bw);
260 assert_hash(max, target, nb, bw);
261 }
262
263 // This method checks if H(prefix || in) == target.
264 // Since the prefix is hardcoded, the compiler can propagate constants
265 // and produce smaller circuits. As above, the method requires that in[]
266 // contains exactly nb*64 bytes and has been padded according to the SHA256
267 // specification. To use this method, compute the block_witness for the
268 // entire message as usual.
269 void assert_message_hash_with_prefix(size_t max, const v8& nb,
270 const v8 in[/* 64*max */],
271 const uint8_t prefix[/* len */],
272 size_t len, const v256& target,
273 const BlockWitness bw[/*max*/]) const {
274 assert_message_with_prefix(max, nb, in, prefix, len, bw);
275 assert_hash(max, target, nb, bw);
276 }
277
278 // Verifies that the nb_th element of the block witness is equal to e.
279 // The block witness keeps track of the intermediate output of each
280 // block transform. Therefore, this method can be used to verify that the
281 // prover knows a preimage that hashes to the desired e.
282 void assert_hash(size_t max, const v256& e, const v8& nb,
283 const BlockWitness bw[/*max*/]) const {
284 packed_v32 x[8];
285 for (size_t b = 0; b < max; ++b) {
286 auto bt = l_.veq(nb, b + 1); /* b is zero-indexed */
287 auto ebt = l_.eval(bt);
288 for (size_t i = 0; i < 8; ++i) {
289 for (size_t k = 0; k < bp_.kNv32Elts; ++k) {
290 if (b == 0) {
291 x[i][k] = l_.mul(&ebt, bw[b].h1[i][k]);
292 } else {
293 auto maybe_sha = l_.mul(&ebt, bw[b].h1[i][k]);
294 x[i][k] = l_.add(&x[i][k], maybe_sha);
295 }
296 }
297 }
298 }
299
300 // Unpack the hash into a v256 in reverse byte-order.
301 v256 mm;
302 for (size_t j = 0; j < 8; ++j) {
303 auto hj = bp_.unpack_v32(x[j]);
304 for (size_t k = 0; k < 32; ++k) {
305 mm[((7 - j) * 32 + k)] = hj[k];
306 }
307 }
308 l_.vassert_eq(&mm, e);
309 }
310
311 private:
312 void initial_context(v32 H[8]) const {
313 static const uint64_t initial[8] = {0x6a09e667u, 0xbb67ae85u, 0x3c6ef372u,
314 0xa54ff53au, 0x510e527fu, 0x9b05688cu,
315 0x1f83d9abu, 0x5be0cd19u};
316 for (size_t i = 0; i < 8; i++) {
317 H[i] = l_.template vbit<32>(initial[i]);
318 }
319 }
320
321 v32 Sigma0(const v32& x) const {
322 auto x2 = l_.vrotr(x, 2);
323 auto x13 = l_.vrotr(x, 13);
324 return l_.vxor3(&x2, &x13, l_.vrotr(x, 22));
325 }
326
327 v32 Sigma1(const v32& x) const {
328 auto x6 = l_.vrotr(x, 6);
329 auto x11 = l_.vrotr(x, 11);
330 return l_.vxor3(&x6, &x11, l_.vrotr(x, 25));
331 }
332
333 v32 sigma0(const v32& x) const {
334 auto x7 = l_.vrotr(x, 7);
335 auto x18 = l_.vrotr(x, 18);
336 return l_.vxor3(&x7, &x18, l_.vshr(x, 3));
337 }
338
339 v32 sigma1(const v32& x) const {
340 auto x17 = l_.vrotr(x, 17);
341 auto x19 = l_.vrotr(x, 19);
342 return l_.vxor3(&x17, &x19, l_.vshr(x, 10));
343 }
344};
345
346} // namespace proofs
347
348#endif // PRIVACY_PROOFS_ZK_LIB_CIRCUITS_SHA_FLATSHA256_CIRCUIT_H_
Definition bit_plucker.h:86
Definition logic.h:38
Definition compiler.h:50
Definition flatsha256_circuit.h:64