Longfellow ZK 0290cb32
Loading...
Searching...
No Matches
jwt_witness.h
1// Copyright 2025 Google LLC.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#ifndef PRIVACY_PROOFS_ZK_LIB_CIRCUITS_JWT_JWT_WITNESS_H_
16#define PRIVACY_PROOFS_ZK_LIB_CIRCUITS_JWT_JWT_WITNESS_H_
17
18#include <cstddef>
19#include <cstdint>
20#include <cstdio>
21#include <string>
22#include <vector>
23
24#include "arrays/dense.h"
25#include "circuits/base64/decode_util.h"
26#include "circuits/ecdsa/verify_witness.h"
27#include "circuits/jwt/jwt_constants.h"
28#include "circuits/logic/bit_plucker_encoder.h"
29#include "circuits/sha/flatsha256_witness.h"
30#include "util/crypto.h"
31#include "util/log.h"
32
33namespace proofs {
34
/* This struct allows a verifier to express which attribute and value the prover
 * must claim. The prover must show that the JWT payload contains the JSON
 * fragment "<id>":"<value>".
 *
 * Only the first id_len bytes of `id` and the first value_len bytes of `value`
 * are meaningful, so id_len must not exceed 32 and value_len must not exceed 64.
 */
struct OpenedAttribute {
  uint8_t id[32];     // Attribute name bytes (not NUL-terminated).
  uint8_t value[64];  // Attribute value bytes (not NUL-terminated).
  size_t id_len, value_len;  // Number of valid bytes in `id` / `value`.
};
42
43template <class Field>
44bool fill_attribute(DenseFiller<Field>& filler, const OpenedAttribute& attr,
45 const Field& F, size_t version) {
46 std::vector<uint8_t> vbuf;
47 vbuf.push_back('"');
48 vbuf.insert(vbuf.end(), attr.id, attr.id + attr.id_len);
49 vbuf.push_back('"');
50 vbuf.push_back(':');
51 vbuf.push_back('"');
52 vbuf.insert(vbuf.end(), attr.value, attr.value + attr.value_len);
53 vbuf.push_back('"');
54 for (size_t i = 0; i < 128; ++i) {
55 if (i < vbuf.size()) {
56 filler.push_back(vbuf[i], 8, F);
57 } else {
58 filler.push_back(0, 8, F);
59 }
60 }
61 filler.push_back(vbuf.size(), 8, F);
62 return true;
63}
64
65
66template <class EC, class ScalarField, size_t SHABlocks>
67class JWTWitness {
68 constexpr static size_t kMaxSHABlocks = SHABlocks;
69 using Field = typename EC::Field;
70 using Elt = typename Field::Elt;
71 using Nat = typename Field::N;
72 using EcdsaWitness = VerifyWitness3<EC, ScalarField>;
73 const EC& ec_;
74
75 public:
76 Elt e_, dpkx_, dpky_;
77 EcdsaWitness sig_;
78 EcdsaWitness kb_sig_;
79
80 uint8_t preimage_[64 * kMaxSHABlocks];
81 uint8_t e_bits_[256];
82 FlatSHA256Witness::BlockWitness sha_bw_[kMaxSHABlocks];
83 uint8_t numb_; /* Number of the correct sha block. */
84 uint8_t na_; /* Number of attributes. */
85 size_t payload_ind_, payload_len_;
86 std::vector<size_t> attr_ind_;
87
88 struct Jws {
89 std::string msg;
90 std::string payload;
91 size_t payload_len, payload_ind;
92 Nat ne, nr, ns;
93 Elt e, r, s;
94 };
95
96 bool parse_jws(std::string jwt, Jws& jws) {
97 size_t dot = jwt.find_first_of('.');
98 if (dot == std::string::npos) {
99 log(ERROR, "JWT is not well-formed");
100 return false;
101 }
102 size_t dot2 = jwt.find_first_of('.', dot + 1);
103 if (dot2 == std::string::npos) {
104 log(ERROR, "JWT is not in the format of header.payload.signature");
105 return false;
106 }
107 auto hdr = jwt.substr(0, dot);
108 auto pld = jwt.substr(dot + 1, dot2 - dot - 1);
109 auto sig = jwt.substr(dot2 + 1);
110 jws.msg = jwt.substr(0, dot2);
111 jws.payload = pld;
112 jws.payload_ind = dot + 1;
113 jws.payload_len = pld.size();
114
115 uint8_t hash[kSHA256DigestSize];
116 SHA256 sha;
117 sha.Update((const uint8_t*)jws.msg.data(), dot2);
118 sha.DigestData(hash);
119 jws.ne = nat_from_be(hash);
120
121 std::vector<uint8_t> sigb;
122 sigb.reserve(ec_.f_.kBytes * 2);
123 if (!base64_decode_url(sig, sigb) || sigb.size() < ec_.f_.kBytes * 2) {
124 log(ERROR, "signature is not in the format of base64url");
125 return false;
126 }
127 jws.nr = nat_from_be(&sigb[0]);
128 jws.ns = nat_from_be(&sigb[ec_.f_.kBytes]);
129
130 jws.e = ec_.f_.to_montgomery(jws.ne);
131 jws.r = ec_.f_.to_montgomery(jws.nr);
132 jws.s = ec_.f_.to_montgomery(jws.ns);
133
134 return true;
135 }
136
137 explicit JWTWitness(const EC& ec, const ScalarField& Fn)
138 : ec_(ec), sig_(Fn, ec), kb_sig_(Fn, ec) {}
139
140 void fill_witness(DenseFiller<Field>& filler) const {
141 filler.push_back(e_);
142 filler.push_back(dpkx_);
143 filler.push_back(dpky_);
144 sig_.fill_witness(filler);
145 kb_sig_.fill_witness(filler);
146
147 // Write the message.
148 for (size_t i = 0; i < 64 * kMaxSHABlocks; ++i) {
149 filler.push_back(preimage_[i], 8, ec_.f_);
150 }
151
152 for (size_t i = 0; i < 256; ++i) {
153 filler.push_back(e_bits_[i], 1, ec_.f_);
154 }
155
156 for (size_t j = 0; j < kMaxSHABlocks; ++j) {
157 fill_sha(filler, sha_bw_[j]);
158 }
159
160 filler.push_back(numb_, 8, ec_.f_);
161
162 for (size_t i = 0; i < na_; ++i) {
163 filler.push_back(attr_ind_[i], kJWTIndexBits, ec_.f_);
164 }
165
166 filler.push_back(payload_ind_, kJWTIndexBits, ec_.f_);
167 filler.push_back(payload_len_, kJWTIndexBits, ec_.f_);
168 }
169
170 void fill_sha(DenseFiller<Field>& filler,
171 const FlatSHA256Witness::BlockWitness& bw) const {
172 BitPluckerEncoder<Field, kSHAJWTPluckerBits> BPENC(ec_.f_);
173 for (size_t k = 0; k < 48; ++k) {
174 filler.push_back(BPENC.mkpacked_v32(bw.outw[k]));
175 }
176 for (size_t k = 0; k < 64; ++k) {
177 filler.push_back(BPENC.mkpacked_v32(bw.oute[k]));
178 filler.push_back(BPENC.mkpacked_v32(bw.outa[k]));
179 }
180 for (size_t k = 0; k < 8; ++k) {
181 filler.push_back(BPENC.mkpacked_v32(bw.h1[k]));
182 }
183 }
184
185 // Transform from u32 be (i.e., be[0] is the most significant nibble)
186 // into nat form, which requires first converting to le byte order.
187 Nat nat_from_u32(const uint32_t be[]) const {
188 uint8_t tmp[Nat::kBytes];
189 const size_t top = Nat::kBytes / 4;
190 for (size_t i = 0; i < Nat::kBytes; ++i) {
191 tmp[i] = (be[top - i / 4 - 1] >> ((i % 4) * 8)) & 0xff;
192 }
193 return Nat::of_bytes(tmp);
194 }
195
196 // Transform from u8 be (i.e., be[31] is the most significant byte) into
197 // nat form, which requires first converting to le byte order.
198 Nat nat_from_be(const uint8_t be[/* Nat::kBytes */]) {
199 uint8_t tmp[Nat::kBytes];
200 // Transform into byte-wise le representation.
201 for (size_t i = 0; i < Nat::kBytes; ++i) {
202 tmp[i] = be[Nat::kBytes - i - 1];
203 }
204 return Nat::of_bytes(tmp);
205 }
206
207 bool compute_witness(std::string jwt, Elt pkX, Elt pkY,
208 std::vector<OpenedAttribute> attrs) {
209 size_t tilde = jwt.find_first_of('~');
210 if (tilde == std::string::npos) {
211 log(ERROR, "JWT is not in the format of header.payload.signature~kb");
212 return false;
213 }
214 auto id = jwt.substr(0, tilde);
215 auto kb = jwt.substr(tilde + 1);
216 Jws id_jws;
217 if (!parse_jws(id, id_jws)) {
218 return false;
219 }
220
221 if (id_jws.msg.size() > kMaxSHABlocks * 64 - 9) {
222 log(INFO, "JWT payload bytes is too large");
223 return false;
224 }
225
226 FlatSHA256Witness::transform_and_witness_message(
227 id_jws.msg.size(), reinterpret_cast<const uint8_t*>(id_jws.msg.data()),
228 kMaxSHABlocks, numb_, preimage_, sha_bw_);
229
230 e_ = id_jws.e;
231 payload_ind_ = id_jws.payload_ind;
232 payload_len_ = id_jws.payload_len;
233 if (!sig_.compute_witness(pkX, pkY, id_jws.ne, id_jws.nr, id_jws.ns)) {
234 log(ERROR, "signature verification failed");
235 return false;
236 }
237
238 for (size_t i = 0; i < 256; ++i) {
239 e_bits_[i] = id_jws.ne.bit(i);
240 }
241
242 // Find the positions of each of the attributes.
243 na_ = attrs.size();
244 std::vector<uint8_t> payload;
245 payload.reserve(id_jws.payload.size());
246 if (!base64_decode_url(id_jws.payload, payload)) {
247 log(ERROR, "JWT payload is not in the format of base64url");
248 return false;
249 }
250 std::string str((const char*)payload.data(), payload.size());
251 for (size_t i = 0; i < na_; ++i) {
252 std::string idm =
253 "\"" + std::string((const char*)attrs[i].id, attrs[i].id_len) +
254 "\":\"" +
255 std::string((const char*)attrs[i].value, attrs[i].value_len) + "\"";
256 size_t ind = str.find(idm, 0);
257 if (ind == std::string::npos) {
258 log(ERROR, "Could not find attribute %s", idm.c_str());
259 return false;
260 }
261 attr_ind_.push_back(ind);
262 }
263
264 // Find device public key in payload.
265 std::string cnf_prefix =
266 "\"cnf\":{\"jwk\":{\"kty\":\"EC\",\"crv\":\"P-256\",\"x\":\"";
267 size_t x_ind = str.find(cnf_prefix.data(), 0, cnf_prefix.size());
268 if (x_ind == std::string::npos) {
269 log(ERROR, "Could not find device public key in payload");
270 return false;
271 }
272 size_t y_ind = str.find("\",\"y\":\"", x_ind + cnf_prefix.size());
273 if (y_ind == std::string::npos) {
274 log(ERROR, "Could not find device public key in payload");
275 return false;
276 }
277 std::string x = str.substr(x_ind + cnf_prefix.size(), 43);
278 std::string y = str.substr(y_ind + 7, 43);
279 std::vector<uint8_t> dpkx, dpky;
280 dpkx.reserve(65); dpky.reserve(65);
281 if (!base64_decode_url(x, dpkx)) {
282 log(ERROR, "CNF:dpkx payload is not in the format of base64url");
283 return false;
284 }
285 if (!base64_decode_url(y, dpky)) {
286 log(ERROR, "CNF:dpky payload is not in the format of base64url");
287 return false;
288 }
289 Nat nx = nat_from_be(dpkx.data());
290 Nat ny = nat_from_be(dpky.data());
291 dpkx_ = ec_.f_.to_montgomery(nx);
292 dpky_ = ec_.f_.to_montgomery(ny);
293
294 // Process the key binding portion
295 if (kb.empty()) {
296 log(ERROR, "kb portion is missing");
297 return false;
298 }
299 Jws kb_jws;
300 if (!parse_jws(kb, kb_jws)) {
301 log(ERROR, "kb jws parsing failed");
302 return false;
303 }
304 if (!kb_sig_.compute_witness(dpkx_, dpky_, kb_jws.ne, kb_jws.nr,
305 kb_jws.ns)) {
306 log(ERROR, "kb signature verification failed");
307 return false;
308 }
309 return true;
310 }
311};
312
313} // namespace proofs
314
315#endif // PRIVACY_PROOFS_ZK_LIB_CIRCUITS_JWT_JWT_WITNESS_H_
Definition dense.h:153
Definition verify_witness.h:30
Definition flatsha256_witness.h:27
Definition gf2_128.h:63
Definition jwt_witness.h:88
Definition jwt_witness.h:37