// Field type of the curve EC (ec_.f_ is used below for Montgomery
// conversion and as the filler's field argument).
44 using Field =
typename EC::Field;
// Integer type associated with Field; used via Nat::kBytes, Nat::of_bytes
// and Nat::bit for hash/signature conversions.
46 using Nat =
typename Field::N;
// Byte buffer of 64 * kMaxJWTSHABlocks bytes written by
// FlatSHA256Witness::transform_and_witness_message for the
// "header.payload" message; pushed into the filler one byte (8 bits) each.
54 uint8_t preimage_[64 * kMaxJWTSHABlocks];
// payload_ind_: offset of the (base64url) payload inside the JWT string,
// i.e. one past the first '.'; payload_len_: length of that encoded payload.
59 size_t payload_ind_, payload_len_;
// Per opened attribute, appended by compute_witness: position of the
// attribute id inside the decoded payload, the id length, and the value
// length.
60 std::vector<size_t> attr_ind_;
61 std::vector<size_t> attr_id_len_;
62 std::vector<size_t> attr_value_len_;
64 explicit JWTWitness(
const EC& ec,
const ScalarField& Fn)
65 : ec_(ec), sig_(Fn, ec) {}
// Fragment of the witness-filling method (its signature lies outside this
// view). Appends to `filler`, in this order: the signature sub-witness, the
// SHA preimage bytes, the 256 bits of e, one SHA block witness per block
// slot, the block count, one (index, id length, value length) triple per
// attribute, and finally the payload index and length.
// NOTE(review): the order presumably has to match the circuit's input
// layout exactly - TODO confirm before reordering anything.
71 sig_.fill_witness(filler);
// Every preimage byte, including unused trailing block slots. The second
// push_back argument appears to be the bit width of the value.
74 for (
size_t i = 0; i < 64 * kMaxJWTSHABlocks; ++i) {
75 filler.push_back(preimage_[i], 8, ec_.f_);
// The 256 bits of e, one bit per entry (e_bits_ is set from ne.bit(i)).
78 for (
size_t i = 0; i < 256; ++i) {
79 filler.push_back(e_bits_[i], 1, ec_.f_);
// All kMaxJWTSHABlocks block witnesses, not only the first numb_.
82 for (
size_t j = 0; j < kMaxJWTSHABlocks; ++j) {
83 fill_sha(filler, sha_bw_[j]);
86 filler.push_back(numb_, 8, ec_.f_);
// NOTE(review): indexes attr_*_[i] for i < na_, so these vectors must hold
// at least na_ entries. Id/value lengths are pushed with width 8, so they
// presumably must be < 256; no such check is visible in this chunk - confirm.
88 for (
size_t i = 0; i < na_; ++i) {
89 filler.push_back(attr_ind_[i], kJWTIndexBits, ec_.f_);
90 filler.push_back(attr_id_len_[i], 8, ec_.f_);
91 filler.push_back(attr_value_len_[i], 8, ec_.f_);
94 filler.push_back(payload_ind_, kJWTIndexBits, ec_.f_);
95 filler.push_back(payload_len_, kJWTIndexBits, ec_.f_);
// Fragment of the per-block SHA-256 helper (its signature lies outside
// this view). Packs each 32-bit value of the block witness `bw` with
// BPENC.mkpacked_v32 and appends it to `filler`.
// 48 outw words - presumably the message-schedule words beyond the 16
// input words; TODO confirm against FlatSHA256Witness.
101 for (
size_t k = 0; k < 48; ++k) {
102 filler.push_back(BPENC.mkpacked_v32(bw.outw[k]));
// 64 (oute, outa) pairs, interleaved per index k.
104 for (
size_t k = 0; k < 64; ++k) {
105 filler.push_back(BPENC.mkpacked_v32(bw.oute[k]));
106 filler.push_back(BPENC.mkpacked_v32(bw.outa[k]));
// The 8 h1 words; compute_witness derives e from sha_bw_[numb_ - 1].h1.
108 for (
size_t k = 0; k < 8; ++k) {
109 filler.push_back(BPENC.mkpacked_v32(bw.h1[k]));
115 Nat nat_from_u32(
const uint32_t be[])
const {
116 uint8_t tmp[Nat::kBytes];
117 const size_t top = Nat::kBytes / 4;
118 for (
size_t i = 0; i < Nat::kBytes; ++i) {
119 tmp[i] = (be[top - i / 4 - 1] >> ((i % 4) * 8)) & 0xff;
121 return Nat::of_bytes(tmp);
126 Nat nat_from_be(
const uint8_t be[]) {
127 uint8_t tmp[Nat::kBytes];
129 for (
size_t i = 0; i < Nat::kBytes; ++i) {
130 tmp[i] = be[Nat::kBytes - i - 1];
132 return Nat::of_bytes(tmp);
135 bool compute_witness(std::string jwt, Elt pkX, Elt pkY,
136 std::vector<OpenedAttribute> attrs) {
137 size_t dot = jwt.find_first_of(
'.');
138 size_t dot2 = jwt.find_first_of(
'.', dot + 1);
139 if (dot == std::string::npos || dot2 == std::string::npos) {
140 log(ERROR,
"JWT is not in the format of header.payload.signature");
143 auto hdr = jwt.substr(0, dot);
144 auto pld = jwt.substr(dot + 1, dot2 - dot - 1);
145 auto rest = jwt.substr(dot2 + 1);
146 auto msg = jwt.substr(0, dot2);
147 payload_len_ = pld.size();
148 payload_ind_ = dot + 1;
150 if (payload_len_ > kMaxJWTSHABlocks * 64) {
151 log(ERROR,
"JWT payload is too large");
155 size_t tilde = rest.find_first_of(
'~');
156 if (tilde == std::string::npos) {
157 log(ERROR,
"JWT is not in the format of header.payload.signature~epoch");
160 auto sig = rest.substr(0, tilde);
161 auto claims = rest.substr(tilde + 1);
163 FlatSHA256Witness::transform_and_witness_message(
164 msg.size(),
reinterpret_cast<const uint8_t*
>(msg.data()),
165 kMaxJWTSHABlocks, numb_, preimage_, sha_bw_);
167 Nat ne = nat_from_u32(sha_bw_[numb_ - 1].h1);
168 e_ = ec_.f_.to_montgomery(ne);
170 std::vector<uint8_t> sigb;
171 sigb.reserve(ec_.f_.kBytes * 2);
172 if (!base64_decode_url(sig, sigb) || sigb.size() < ec_.f_.kBytes * 2) {
173 log(ERROR,
"signature is not in the format of base64url");
176 Nat nr = nat_from_be(&sigb[0]);
177 Nat ns = nat_from_be(&sigb[ec_.f_.kBytes]);
179 r_ = ec_.f_.to_montgomery(nr);
180 s_ = ec_.f_.to_montgomery(ns);
181 if (!sig_.compute_witness(pkX, pkY, ne, nr, ns)) {
182 log(ERROR,
"signature verification failed");
186 for (
size_t i = 0; i < 256; ++i) {
187 e_bits_[i] = ne.bit(i);
192 std::vector<uint8_t> payload;
193 payload.reserve(pld.size());
194 if (!base64_decode_url(pld, payload)) {
195 log(ERROR,
"JWT payload is not in the format of base64url");
198 std::string str((
const char*)payload.data(), payload.size());
199 for (
size_t i = 0; i < na_; ++i) {
200 size_t ind = str.find((
const char*)attrs[i].
id, 0, attrs[i].id_len);
201 if (ind == std::string::npos) {
202 log(ERROR,
"Could not find attribute %.*s", attrs[i].id_len,
206 size_t vstart = ind + attrs[i].id_len + 3;
208 str.find((
const char*)attrs[i].value, vstart, attrs[i].value_len);
209 if (vind == std::string::npos || vind != vstart) {
210 log(ERROR,
"Could not find attribute value %.*s", attrs[i].value_len,
214 attr_ind_.push_back(ind);
215 attr_id_len_.push_back(attrs[i].id_len);
216 attr_value_len_.push_back(attrs[i].value_len);