Skip to content

Commit 239c2b8

Browse files
committed
4-base NTT
1 parent a649af5 commit 239c2b8

File tree

1 file changed

+171
-92
lines changed

1 file changed

+171
-92
lines changed

Convolution/Convolution.java

Lines changed: 171 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
/**
22
* Convolution.
33
*
4-
* @verified https://atcoder.jp/contests/practice2/tasks/practice2_f
5-
* @verified https://judge.yosupo.jp/problem/convolution_mod_1000000007
4+
* @verified https://atcoder.jp/contests/practice2/submissions/24449847
5+
* @verified https://judge.yosupo.jp/submission/53841
66
*/
77
class Convolution {
88
/**
@@ -78,6 +78,74 @@ private static int ceilPow2(int n) {
7878
return x;
7979
}
8080

81+
private static class FftInfo {
82+
private static int bsfConstexpr(int n) {
83+
int x = 0;
84+
while ((n & (1 << x)) == 0) x++;
85+
return x;
86+
}
87+
88+
private static long inv(long a, long mod) {
89+
long b = mod;
90+
long p = 1, q = 0;
91+
while (b > 0) {
92+
long c = a / b;
93+
long d;
94+
d = a;
95+
a = b;
96+
b = d % b;
97+
d = p;
98+
p = q;
99+
q = d - c * q;
100+
}
101+
return p < 0 ? p + mod : p;
102+
}
103+
104+
private final int rank2;
105+
public final long[] root;
106+
public final long[] iroot;
107+
public final long[] rate2;
108+
public final long[] irate2;
109+
public final long[] rate3;
110+
public final long[] irate3;
111+
112+
public FftInfo(int g, int mod) {
113+
rank2 = bsfConstexpr(mod - 1);
114+
root = new long[rank2 + 1];
115+
iroot = new long[rank2 + 1];
116+
rate2 = new long[Math.max(0, rank2 - 2 + 1)];
117+
irate2 = new long[Math.max(0, rank2 - 2 + 1)];
118+
rate3 = new long[Math.max(0, rank2 - 3 + 1)];
119+
irate3 = new long[Math.max(0, rank2 - 3 + 1)];
120+
121+
root[rank2] = pow(g, (mod - 1) >> rank2, mod);
122+
iroot[rank2] = inv(root[rank2], mod);
123+
for (int i = rank2 - 1; i >= 0; i--) {
124+
root[i] = root[i + 1] * root[i + 1] % mod;
125+
iroot[i] = iroot[i + 1] * iroot[i + 1] % mod;
126+
}
127+
128+
{
129+
long prod = 1, iprod = 1;
130+
for (int i = 0; i <= rank2 - 2; i++) {
131+
rate2[i] = root[i + 2] * prod % mod;
132+
irate2[i] = iroot[i + 2] * iprod % mod;
133+
prod = prod * iroot[i + 2] % mod;
134+
iprod = iprod * root[i + 2] % mod;
135+
}
136+
}
137+
{
138+
long prod = 1, iprod = 1;
139+
for (int i = 0; i <= rank2 - 3; i++) {
140+
rate3[i] = root[i + 3] * prod % mod;
141+
irate3[i] = iroot[i + 3] * iprod % mod;
142+
prod = prod * iroot[i + 3] % mod;
143+
iprod = iprod * root[i + 3] % mod;
144+
}
145+
}
146+
}
147+
};
148+
81149
/**
82150
* Garner's algorithm.
83151
*
@@ -104,87 +172,65 @@ private static long garner(long[] c, int[] mods) {
104172
return cnst[n - 1];
105173
}
106174

107-
/**
108-
* Pre-calculation for NTT.
109-
*
110-
* @param mod NTT Prime.
111-
* @param g Primitive root of mod.
112-
* @return Pre-calculation table.
113-
*/
114-
private static long[] sumE(int mod, int g) {
115-
long[] sum_e = new long[30];
116-
long[] es = new long[30];
117-
long[] ies = new long[30];
118-
int cnt2 = Integer.numberOfTrailingZeros(mod - 1);
119-
long e = pow(g, (mod - 1) >> cnt2, mod);
120-
long ie = pow(e, mod - 2, mod);
121-
for (int i = cnt2; i >= 2; i--) {
122-
es[i - 2] = e;
123-
ies[i - 2] = ie;
124-
e = e * e % mod;
125-
ie = ie * ie % mod;
126-
}
127-
long now = 1;
128-
for (int i = 0; i <= cnt2 - 2; i++) {
129-
sum_e[i] = es[i] * now % mod;
130-
now = now * ies[i] % mod;
131-
}
132-
return sum_e;
133-
}
134-
135-
/**
136-
* Pre-calculation for inverse NTT.
137-
*
138-
* @param mod Mod.
139-
* @param g Primitive root of mod.
140-
* @return Pre-calculation table.
141-
*/
142-
private static long[] sumIE(int mod, int g) {
143-
long[] sum_ie = new long[30];
144-
long[] es = new long[30];
145-
long[] ies = new long[30];
146-
147-
int cnt2 = Integer.numberOfTrailingZeros(mod - 1);
148-
long e = pow(g, (mod - 1) >> cnt2, mod);
149-
long ie = pow(e, mod - 2, mod);
150-
for (int i = cnt2; i >= 2; i--) {
151-
es[i - 2] = e;
152-
ies[i - 2] = ie;
153-
e = e * e % mod;
154-
ie = ie * ie % mod;
155-
}
156-
long now = 1;
157-
for (int i = 0; i <= cnt2 - 2; i++) {
158-
sum_ie[i] = ies[i] * now % mod;
159-
now = now * es[i] % mod;
160-
}
161-
return sum_ie;
162-
}
163-
164175
/**
165176
* Inverse NTT.
166177
*
167178
* @param a Target array.
168-
* @param sumIE Pre-calculation table.
179+
* @param g Primitive root of mod.
169180
* @param mod NTT Prime.
170181
*/
171-
private static void butterflyInv(long[] a, long[] sumIE, int mod) {
182+
private static void butterflyInv(long[] a, int g, int mod) {
172183
int n = a.length;
173184
int h = ceilPow2(n);
174185

175-
for (int ph = h; ph >= 1; ph--) {
176-
int w = 1 << (ph - 1), p = 1 << (h - ph);
177-
long inow = 1;
178-
for (int s = 0; s < w; s++) {
179-
int offset = s << (h - ph + 1);
180-
for (int i = 0; i < p; i++) {
181-
long l = a[i + offset];
182-
long r = a[i + offset + p];
183-
a[i + offset] = (l + r) % mod;
184-
a[i + offset + p] = (mod + l - r) * inow % mod;
186+
FftInfo info = new FftInfo(g, mod);
187+
188+
int len = h; // a[i, i+(n>>len), i+2*(n>>len), ..] is transformed
189+
while (len > 0) {
190+
if (len == 1) {
191+
int p = 1 << (h - len);
192+
long irot = 1;
193+
for (int s = 0; s < (1 << (len - 1)); s++) {
194+
int offset = s << (h - len + 1);
195+
for (int i = 0; i < p; i++) {
196+
long l = a[i + offset];
197+
long r = a[i + offset + p];
198+
a[i + offset] = (l + r) % mod;
199+
a[i + offset + p] = (mod + l - r) % mod * irot % mod;
200+
}
201+
if (s + 1 != (1 << (len - 1))) {
202+
irot *= info.irate2[Integer.numberOfTrailingZeros(~s)];
203+
irot %= mod;
204+
}
185205
}
186-
int x = Integer.numberOfTrailingZeros(~s);
187-
inow = inow * sumIE[x] % mod;
206+
len--;
207+
} else {
208+
// 4-base
209+
int p = 1 << (h - len);
210+
long irot = 1, iimag = info.iroot[2];
211+
for (int s = 0; s < (1 << (len - 2)); s++) {
212+
long irot2 = irot * irot % mod;
213+
long irot3 = irot2 * irot % mod;
214+
int offset = s << (h - len + 2);
215+
for (int i = 0; i < p; i++) {
216+
long a0 = 1L * a[i + offset + 0 * p];
217+
long a1 = 1L * a[i + offset + 1 * p];
218+
long a2 = 1L * a[i + offset + 2 * p];
219+
long a3 = 1L * a[i + offset + 3 * p];
220+
221+
long a2na3iimag = 1L * (mod + a2 - a3) % mod * iimag % mod;
222+
223+
a[i + offset] = (a0 + a1 + a2 + a3) % mod;
224+
a[i + offset + 1 * p] = (a0 + (mod - a1) + a2na3iimag) % mod * irot % mod;
225+
a[i + offset + 2 * p] = (a0 + a1 + (mod - a2) + (mod - a3)) % mod * irot2 % mod;
226+
a[i + offset + 3 * p] = (a0 + (mod - a1) + (mod - a2na3iimag)) % mod * irot3 % mod;
227+
}
228+
if (s + 1 != (1 << (len - 2))) {
229+
irot *= info.irate3[Integer.numberOfTrailingZeros(~s)];
230+
irot %= mod;
231+
}
232+
}
233+
len -= 2;
188234
}
189235
}
190236
}
@@ -193,26 +239,61 @@ private static void butterflyInv(long[] a, long[] sumIE, int mod) {
193239
* NTT (forward transform).
194240
*
195241
* @param a Target array.
196-
* @param sumE Pre-calculation table.
242+
* @param g Primitive root of mod.
197243
* @param mod NTT Prime.
198244
*/
199-
private static void butterfly(long[] a, long[] sumE, int mod) {
245+
private static void butterfly(long[] a, int g, int mod) {
200246
int n = a.length;
201247
int h = ceilPow2(n);
202248

203-
for (int ph = 1; ph <= h; ph++) {
204-
int w = 1 << (ph - 1), p = 1 << (h - ph);
205-
long now = 1;
206-
for (int s = 0; s < w; s++) {
207-
int offset = s << (h - ph + 1);
208-
for (int i = 0; i < p; i++) {
209-
long l = a[i + offset];
210-
long r = a[i + offset + p] * now % mod;
211-
a[i + offset] = (l + r) % mod;
212-
a[i + offset + p] = (l - r + mod) % mod;
249+
FftInfo info = new FftInfo(g, mod);
250+
251+
int len = 0; // a[i, i+(n>>len), i+2*(n>>len), ..] is transformed
252+
while (len < h) {
253+
if (h - len == 1) {
254+
int p = 1 << (h - len - 1);
255+
long rot = 1;
256+
for (int s = 0; s < (1 << len); s++) {
257+
int offset = s << (h - len);
258+
for (int i = 0; i < p; i++) {
259+
long l = a[i + offset];
260+
long r = a[i + offset + p] * rot % mod;
261+
a[i + offset] = (l + r) % mod;
262+
a[i + offset + p] = (l + mod - r) % mod;
263+
}
264+
if (s + 1 != (1 << len)) {
265+
rot *= info.rate2[Integer.numberOfTrailingZeros(~s)];
266+
rot %= mod;
267+
}
268+
}
269+
len++;
270+
} else {
271+
// 4-base
272+
int p = 1 << (h - len - 2);
273+
long rot = 1, imag = info.root[2];
274+
for (int s = 0; s < (1 << len); s++) {
275+
long rot2 = rot * rot % mod;
276+
long rot3 = rot2 * rot % mod;
277+
int offset = s << (h - len);
278+
for (int i = 0; i < p; i++) {
279+
long mod2 = 1L * mod * mod;
280+
long a0 = 1L * a[i + offset];
281+
long a1 = 1L * a[i + offset + p] * rot % mod;
282+
long a2 = 1L * a[i + offset + 2 * p] * rot2 % mod;
283+
long a3 = 1L * a[i + offset + 3 * p] * rot3 % mod;
284+
long a1na3imag = 1L * (a1 + mod2 - a3) % mod * imag % mod;
285+
long na2 = mod2 - a2;
286+
a[i + offset] = (a0 + a2 + a1 + a3) % mod;
287+
a[i + offset + 1 * p] = (a0 + a2 + (2 * mod2 - (a1 + a3))) % mod;
288+
a[i + offset + 2 * p] = (a0 + na2 + a1na3imag) % mod;
289+
a[i + offset + 3 * p] = (a0 + na2 + (mod2 - a1na3imag)) % mod;
290+
}
291+
if (s + 1 != (1 << len)) {
292+
rot *= info.rate3[Integer.numberOfTrailingZeros(~s)];
293+
rot %= mod;
294+
}
213295
}
214-
int x = Integer.numberOfTrailingZeros(~s);
215-
now = now * sumE[x] % mod;
296+
len += 2;
216297
}
217298
}
218299
}
@@ -241,15 +322,13 @@ public static long[] convolution(long[] a, long[] b, int mod) {
241322
}
242323

243324
int g = primitiveRoot(mod);
244-
long[] sume = sumE(mod, g);
245-
long[] sumie = sumIE(mod, g);
246325

247-
butterfly(a, sume, mod);
248-
butterfly(b, sume, mod);
326+
butterfly(a, g, mod);
327+
butterfly(b, g, mod);
249328
for (int i = 0; i < z; i++) {
250329
a[i] = a[i] * b[i] % mod;
251330
}
252-
butterflyInv(a, sumie, mod);
331+
butterflyInv(a, g, mod);
253332
a = java.util.Arrays.copyOf(a, n + m - 1);
254333

255334
long iz = pow(z, mod - 2, mod);

0 commit comments

Comments
 (0)