Longfellow ZK 0290cb32
Loading...
Searching...
No Matches
convolution.h
1// Copyright 2025 Google LLC.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#ifndef PRIVACY_PROOFS_ZK_LIB_ALGEBRA_CONVOLUTION_H_
16#define PRIVACY_PROOFS_ZK_LIB_ALGEBRA_CONVOLUTION_H_
17
18#include <stddef.h>
19
20#include <cstdint>
21#include <memory>
22#include <vector>
23
24#include "algebra/blas.h"
25#include "algebra/fft.h"
26#include "algebra/rfft.h"
27
/*
All of the classes in this package compute convolutions.
That is, given input arrays of field elements x, y, with |x|=n, |y|=m,
these methods compute the first m entries of

  z[k] = \sum_{i=0}^{n-1} x[i] y[k-i]

SlowConvolution uses an O(n*m) method and is intended for testing and
validation.

FFTConvolution and FFTExtConvolution first zero-pad x and y to the smallest
power of 2 that is at least m, and then use FFT algorithms to compute the
same result in O(N log N) time, where N is the padded length.

The const Field& objects that are passed have lifetimes that exceed the call
durations and can be safely passed by const reference.
*/
43
44namespace proofs {
45
// Returns the smallest power of 2 that is at least n.
// In particular, choose_padding(0) == choose_padding(1) == 1.
//
// NOTE(review): if n is larger than the greatest power of 2 representable in
// size_t, the doubling wraps around to 0 and the loop never terminates.
// Callers are presumably expected to pass sizes far below that bound.
static size_t choose_padding(const size_t n) {
  size_t p = 1;
  while (p < n) {
    p *= 2;
  }
  return p;
}
54
55template <class Field>
56class FFTConvolution {
57 using Elt = typename Field::Elt;
58
59 public:
60 FFTConvolution(size_t n, size_t m, const Field& f, const Elt omega,
61 uint64_t omega_order, const Elt y[/*m*/])
62 : f_(f),
63 omega_(omega),
64 omega_order_(omega_order),
65 n_(n),
66 m_(m),
67 padding_(choose_padding(m)),
68 y_fft_(padding_, f_.zero()) {
69 Blas<Field>::copy(m, &y_fft_[0], 1, y, 1);
70 FFT<Field>::fftf(&y_fft_[0], padding_, omega_, omega_order_, f_);
71
72 // Pre-scale Y by 1/N to compensate for the scaling in FFTB(FFTF(.))
73 Blas<Field>::scale(padding_, &y_fft_[0], 1,
74 f_.invertf(f_.of_scalar(padding_)), f_);
75 }
76
77 // Computes (first m entries of) convolution of x with y, outputs in z:
78 // z[k] = \sum_{i=0}^{n-1} x[i] y[k-i].
79 // Note that y has already been FFT'd and divided by padding_ in constructor
80 void convolution(const Elt x[/*n_*/], Elt z[/*m_*/]) const {
81 std::vector<Elt> x_fft(padding_, f_.zero());
82 Blas<Field>::copy(n_, &x_fft[0], 1, x, 1);
83 FFT<Field>::fftf(&x_fft[0], padding_, omega_, omega_order_, f_);
84 // Pointwise multiplication.
85 for (size_t i = 0; i < padding_; ++i) {
86 f_.mul(x_fft[i], y_fft_[i]);
87 }
88 // Backward fft.
89 FFT<Field>::fftb(&x_fft[0], padding_, omega_, omega_order_, f_);
90 Blas<Field>::copy(m_, z, 1, &x_fft[0], 1);
91 }
92
93 private:
94 const Field& f_;
95 const Elt omega_;
96 const uint64_t omega_order_;
97
98 // n is the number of points input
99 size_t n_;
100 size_t m_; // total number of points output (points in + new points out)
101 size_t padding_;
102
103 // fft(y[i]) / padding
104 // padded with zeroes to the next power of 2 at least m.
105 std::vector<Elt> y_fft_;
106};
107
108template <class Field>
109class FFTConvolutionFactory {
110 using Elt = typename Field::Elt;
111
112 public:
113 using Convolver = FFTConvolution<Field>;
114 FFTConvolutionFactory(const Field& f, const Elt omega, uint64_t omega_order)
115 : f_(f), omega_(omega), omega_order_(omega_order) {}
116
117 std::unique_ptr<const Convolver> make(size_t n, size_t m,
118 const Elt y[/*m*/]) const {
119 return std::make_unique<const Convolver>(n, m, f_, omega_, omega_order_, y);
120 }
121
122 private:
123 const Field& f_;
124 const Elt omega_;
125 const uint64_t omega_order_;
126};
127
128template <class Field, class FieldExt>
129class FFTExtConvolution {
130 using Elt = typename Field::Elt;
131 using EltExt = typename FieldExt::Elt;
132
133 public:
134 FFTExtConvolution(size_t n, size_t m, const Field& f, const FieldExt& f_ext,
135 const EltExt omega, uint64_t omega_order,
136 const Elt y[/*m*/])
137 : f_(f),
138 f_ext_(f_ext),
139 omega_(omega),
140 omega_order_(omega_order),
141 n_(n),
142 m_(m),
143 padding_(choose_padding(m)),
144 y_fft_(padding_, f_.zero()) {
145 Blas<Field>::copy(m, &y_fft_[0], 1, y, 1);
146 RFFT<FieldExt>::r2hc(&y_fft_[0], padding_, omega_, omega_order_, f_ext_);
147
148 // Pre-scale Y by 1/N to compensate for the scaling in HC2R(R2HC(.))
149 Blas<Field>::scale(padding_, &y_fft_[0], 1,
150 f_.invertf(f_.of_scalar(padding_)), f_);
151 }
152
153 // Computes (first m entries of) convolution of x with y, stores in z:
154 // z[k] = \sum_{i=0}^{n-1} x[i] y[k-i].
155 // Note that y has already been FFT'd and divided by padding_ in constructor
156 void convolution(const Elt x[/*n_*/], Elt z[/*m_*/]) const {
157 std::vector<Elt> x_fft(padding_, f_.zero());
158 Blas<Field>::copy(n_, &x_fft[0], 1, x, 1);
159 RFFT<FieldExt>::r2hc(&x_fft[0], padding_, omega_, omega_order_, f_ext_);
160
161 // Pointwise multiplication
162 {
163 size_t i;
164 f_.mul(x_fft[0], y_fft_[0]); // DC is real
165 for (i = 1; i + i < padding_; ++i) {
166 RFFT<FieldExt>::cmul(&x_fft[i], &x_fft[padding_ - i], x_fft[i],
167 x_fft[padding_ - i], y_fft_[i],
168 y_fft_[padding_ - i], f_);
169 }
170 f_.mul(x_fft[i], y_fft_[i]); // Nyquist is real
171 }
172
173 // Backward FFT.
174 RFFT<FieldExt>::hc2r(&x_fft[0], padding_, omega_, omega_order_, f_ext_);
175 Blas<Field>::copy(m_, z, 1, &x_fft[0], 1);
176 }
177
178 private:
179 const Field& f_;
180 const FieldExt& f_ext_;
181 const EltExt omega_;
182 const uint64_t omega_order_;
183
184 // n is the number of points input in x
185 size_t n_;
186 size_t m_; // total number of points output in convolution
187 size_t padding_;
188
189 // fft(y[i]) / padding
190 // padded with zeroes to the next power of 2 at least m.
191 std::vector<Elt> y_fft_;
192};
193
194template <class Field, class FieldExt>
195class FFTExtConvolutionFactory {
196 using Elt = typename Field::Elt;
197 using EltExt = typename FieldExt::Elt;
198
199 public:
200 using Convolver = FFTExtConvolution<Field, FieldExt>;
201
202 FFTExtConvolutionFactory(const Field& f, const FieldExt& f_ext,
203 const EltExt omega, uint64_t omega_order)
204 : f_(f), f_ext_(f_ext), omega_(omega), omega_order_(omega_order) {}
205
206 std::unique_ptr<const Convolver> make(size_t n, size_t m,
207 const Elt y[/*m*/]) const {
208 return std::make_unique<const Convolver>(n, m, f_, f_ext_, omega_,
209 omega_order_, y);
210 }
211
212 private:
213 const Field& f_;
214 const FieldExt& f_ext_;
215 const EltExt omega_;
216 const uint64_t omega_order_;
217};
218} // namespace proofs
219
220#endif // PRIVACY_PROOFS_ZK_LIB_ALGEBRA_CONVOLUTION_H_
Definition convolution.h:56
Definition convolution.h:129
Definition gf2_128.h:63