Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
263 changes: 171 additions & 92 deletions Convolution/Convolution.java
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
/**
* Convolution.
*
* @verified https://atcoder.jp/contests/practice2/tasks/practice2_f
* @verified https://judge.yosupo.jp/problem/convolution_mod_1000000007
* @verified https://atcoder.jp/contests/practice2/submissions/24449847
* @verified https://judge.yosupo.jp/submission/53841
*/
class Convolution {
/**
Expand Down Expand Up @@ -78,6 +78,74 @@ private static int ceilPow2(int n) {
return x;
}

/**
 * Pre-computed root tables used by the radix-2 / radix-4 butterflies.
 *
 * <p>With {@code rank2} being the number of trailing zero bits of {@code mod - 1}:
 * <ul>
 *   <li>{@code root[rank2] = g^((mod - 1) >> rank2)} and {@code root[i] = root[i + 1]^2};</li>
 *   <li>{@code iroot[rank2]} is the modular inverse of {@code root[rank2]} and
 *       {@code iroot[i] = iroot[i + 1]^2};</li>
 *   <li>{@code rate2[i] = root[i + 2] * iroot[2] * ... * iroot[i + 1]}, and {@code irate2[i]}
 *       is the same expression with {@code root} and {@code iroot} swapped;</li>
 *   <li>{@code rate3[i] = root[i + 3] * iroot[3] * ... * iroot[i + 2]}, and {@code irate3[i]}
 *       is the same expression with {@code root} and {@code iroot} swapped.</li>
 * </ul>
 * Every product is reduced modulo {@code mod}.
 * NOTE(review): the tables are only meaningful when {@code mod} is prime and {@code g} is a
 * primitive root of it; neither is checked here.
 */
private static class FftInfo {
    /**
     * Index of the lowest set bit of {@code n}.
     *
     * <p>Delegates to the JDK instead of scanning bit by bit; the manual scan never terminated
     * for {@code n == 0}, whereas this returns 32 in that case.
     *
     * @param n value to inspect.
     * @return number of trailing zero bits of {@code n}.
     */
    private static int bsfConstexpr(int n) {
        return Integer.numberOfTrailingZeros(n);
    }

    /**
     * Modular inverse via the extended Euclidean algorithm.
     *
     * @param a value to invert; must be coprime to {@code mod} for the result to be an inverse.
     * @param mod modulus.
     * @return {@code x} in {@code [0, mod)} with {@code a * x == 1 (mod mod)}.
     */
    private static long inv(long a, long mod) {
        long b = mod;
        // Invariant: p * a0 == a and q * a0 == b (mod mod), where a0 is the original argument.
        long p = 1;
        long q = 0;
        while (b > 0) {
            long quotient = a / b;

            long nextB = a % b;
            a = b;
            b = nextB;

            long nextQ = p - quotient * q;
            p = q;
            q = nextQ;
        }
        return p < 0 ? p + mod : p;
    }

    /**
     * Fills {@code rate[i] = root[i + skip] * iroot[skip] * ... * iroot[i + skip - 1]} and the
     * mirrored {@code irate} table. Does nothing when the tables are empty.
     */
    private static void fillRates(
            long[] root, long[] iroot, int skip, long[] rate, long[] irate, int mod) {
        long prod = 1;
        long iprod = 1;
        for (int i = 0; i < rate.length; i++) {
            rate[i] = root[i + skip] * prod % mod;
            irate[i] = iroot[i + skip] * iprod % mod;
            prod = prod * iroot[i + skip] % mod;
            iprod = iprod * root[i + skip] % mod;
        }
    }

    private final int rank2;
    public final long[] root;
    public final long[] iroot;
    public final long[] rate2;
    public final long[] irate2;
    public final long[] rate3;
    public final long[] irate3;

    /**
     * Builds all tables for the given modulus.
     *
     * @param g base of the root tables (expected to be a primitive root of {@code mod}).
     * @param mod modulus of the transform.
     */
    public FftInfo(int g, int mod) {
        rank2 = bsfConstexpr(mod - 1);
        root = new long[rank2 + 1];
        iroot = new long[rank2 + 1];
        // rate2 covers i = 0 .. rank2 - 2, rate3 covers i = 0 .. rank2 - 3 (empty if negative).
        rate2 = new long[Math.max(0, rank2 - 1)];
        irate2 = new long[Math.max(0, rank2 - 1)];
        rate3 = new long[Math.max(0, rank2 - 2)];
        irate3 = new long[Math.max(0, rank2 - 2)];

        root[rank2] = pow(g, (mod - 1) >> rank2, mod);
        iroot[rank2] = inv(root[rank2], mod);
        // Each lower entry is the square of the one above it.
        for (int i = rank2 - 1; i >= 0; i--) {
            root[i] = root[i + 1] * root[i + 1] % mod;
            iroot[i] = iroot[i + 1] * iroot[i + 1] % mod;
        }

        fillRates(root, iroot, 2, rate2, irate2, mod);
        fillRates(root, iroot, 3, rate3, irate3, mod);
    }
}

/**
* Garner's algorithm.
*
Expand All @@ -104,87 +172,65 @@ private static long garner(long[] c, int[] mods) {
return cnst[n - 1];
}

/**
 * Pre-calculation for NTT.
 *
 * <p>Entry {@code i} of the result is {@code es[i] * ies[0] * ... * ies[i - 1] (mod mod)}, where
 * {@code es[cnt2 - 2] = g^((mod - 1) >> cnt2)}, each lower {@code es} entry is the square of the
 * one above it, {@code ies} holds the matching inverses, and {@code cnt2} is the number of
 * trailing zero bits of {@code mod - 1}. Entries {@code 0 .. cnt2 - 2} are filled; the rest of
 * the 30-element array stays zero.
 *
 * @param mod NTT Prime.
 * @param g Primitive root of mod.
 * @return Pre-calculation table.
 */
private static long[] sumE(int mod, int g) {
    long[] table = new long[30];
    long[] es = new long[30];
    long[] ies = new long[30];

    int cnt2 = Integer.numberOfTrailingZeros(mod - 1);
    long e = pow(g, (mod - 1) >> cnt2, mod);
    // Inverse of e through exponent mod - 2 (relies on mod being prime).
    long ie = pow(e, mod - 2, mod);
    for (int i = cnt2; i >= 2; i--) {
        es[i - 2] = e;
        ies[i - 2] = ie;
        e = e * e % mod;
        ie = ie * ie % mod;
    }

    long now = 1;
    // Inclusive bound: es[cnt2 - 2] is populated above and its table entry is read by a
    // transform of length 2^cnt2, so it must not be left at zero.
    for (int i = 0; i <= cnt2 - 2; i++) {
        table[i] = es[i] * now % mod;
        now = now * ies[i] % mod;
    }
    return table;
}

/**
 * Pre-calculation for inverse NTT.
 *
 * <p>Entry {@code i} of the result is {@code ies[i] * es[0] * ... * es[i - 1] (mod mod)}, where
 * {@code es[cnt2 - 2] = g^((mod - 1) >> cnt2)}, each lower {@code es} entry is the square of the
 * one above it, {@code ies} holds the matching inverses, and {@code cnt2} is the number of
 * trailing zero bits of {@code mod - 1}. Entries {@code 0 .. cnt2 - 2} are filled; the rest of
 * the 30-element array stays zero.
 *
 * @param mod Mod.
 * @param g Primitive root of mod.
 * @return Pre-calculation table.
 */
private static long[] sumIE(int mod, int g) {
    long[] table = new long[30];
    long[] es = new long[30];
    long[] ies = new long[30];

    int cnt2 = Integer.numberOfTrailingZeros(mod - 1);
    long e = pow(g, (mod - 1) >> cnt2, mod);
    // Inverse of e through exponent mod - 2 (relies on mod being prime).
    long ie = pow(e, mod - 2, mod);
    for (int i = cnt2; i >= 2; i--) {
        es[i - 2] = e;
        ies[i - 2] = ie;
        e = e * e % mod;
        ie = ie * ie % mod;
    }

    long now = 1;
    // Inclusive bound: ies[cnt2 - 2] is populated above and its table entry is read by a
    // transform of length 2^cnt2, so it must not be left at zero.
    for (int i = 0; i <= cnt2 - 2; i++) {
        table[i] = ies[i] * now % mod;
        now = now * es[i] % mod;
    }
    return table;
}

/**
* Inverse NTT.
*
* @param a Target array.
* @param sumIE Pre-calculation table.
* @param g Primitive root of mod.
* @param mod NTT Prime.
*/
private static void butterflyInv(long[] a, long[] sumIE, int mod) {
private static void butterflyInv(long[] a, int g, int mod) {
int n = a.length;
int h = ceilPow2(n);

for (int ph = h; ph >= 1; ph--) {
int w = 1 << (ph - 1), p = 1 << (h - ph);
long inow = 1;
for (int s = 0; s < w; s++) {
int offset = s << (h - ph + 1);
for (int i = 0; i < p; i++) {
long l = a[i + offset];
long r = a[i + offset + p];
a[i + offset] = (l + r) % mod;
a[i + offset + p] = (mod + l - r) * inow % mod;
FftInfo info = new FftInfo(g, mod);

int len = h; // a[i, i+(n>>len), i+2*(n>>len), ..] is transformed
while (len > 0) {
if (len == 1) {
int p = 1 << (h - len);
long irot = 1;
for (int s = 0; s < (1 << (len - 1)); s++) {
int offset = s << (h - len + 1);
for (int i = 0; i < p; i++) {
long l = a[i + offset];
long r = a[i + offset + p];
a[i + offset] = (l + r) % mod;
a[i + offset + p] = (mod + l - r) % mod * irot % mod;
}
if (s + 1 != (1 << (len - 1))) {
irot *= info.irate2[Integer.numberOfTrailingZeros(~s)];
irot %= mod;
}
}
int x = Integer.numberOfTrailingZeros(~s);
inow = inow * sumIE[x] % mod;
len--;
} else {
// 4-base
int p = 1 << (h - len);
long irot = 1, iimag = info.iroot[2];
for (int s = 0; s < (1 << (len - 2)); s++) {
long irot2 = irot * irot % mod;
long irot3 = irot2 * irot % mod;
int offset = s << (h - len + 2);
for (int i = 0; i < p; i++) {
long a0 = 1L * a[i + offset + 0 * p];
long a1 = 1L * a[i + offset + 1 * p];
long a2 = 1L * a[i + offset + 2 * p];
long a3 = 1L * a[i + offset + 3 * p];

long a2na3iimag = 1L * (mod + a2 - a3) % mod * iimag % mod;

a[i + offset] = (a0 + a1 + a2 + a3) % mod;
a[i + offset + 1 * p] = (a0 + (mod - a1) + a2na3iimag) % mod * irot % mod;
a[i + offset + 2 * p] = (a0 + a1 + (mod - a2) + (mod - a3)) % mod * irot2 % mod;
a[i + offset + 3 * p] = (a0 + (mod - a1) + (mod - a2na3iimag)) % mod * irot3 % mod;
}
if (s + 1 != (1 << (len - 2))) {
irot *= info.irate3[Integer.numberOfTrailingZeros(~s)];
irot %= mod;
}
}
len -= 2;
}
}
}
Expand All @@ -193,26 +239,61 @@ private static void butterflyInv(long[] a, long[] sumIE, int mod) {
* Inverse NTT.
*
* @param a Target array.
* @param sumE Pre-calculation table.
* @param g Primitive root of mod.
* @param mod NTT Prime.
*/
private static void butterfly(long[] a, long[] sumE, int mod) {
private static void butterfly(long[] a, int g, int mod) {
int n = a.length;
int h = ceilPow2(n);

for (int ph = 1; ph <= h; ph++) {
int w = 1 << (ph - 1), p = 1 << (h - ph);
long now = 1;
for (int s = 0; s < w; s++) {
int offset = s << (h - ph + 1);
for (int i = 0; i < p; i++) {
long l = a[i + offset];
long r = a[i + offset + p] * now % mod;
a[i + offset] = (l + r) % mod;
a[i + offset + p] = (l - r + mod) % mod;
FftInfo info = new FftInfo(g, mod);

int len = 0; // a[i, i+(n>>len), i+2*(n>>len), ..] is transformed
while (len < h) {
if (h - len == 1) {
int p = 1 << (h - len - 1);
long rot = 1;
for (int s = 0; s < (1 << len); s++) {
int offset = s << (h - len);
for (int i = 0; i < p; i++) {
long l = a[i + offset];
long r = a[i + offset + p] * rot % mod;
a[i + offset] = (l + r) % mod;
a[i + offset + p] = (l + mod - r) % mod;
}
if (s + 1 != (1 << len)) {
rot *= info.rate2[Integer.numberOfTrailingZeros(~s)];
rot %= mod;
}
}
len++;
} else {
// 4-base
int p = 1 << (h - len - 2);
long rot = 1, imag = info.root[2];
for (int s = 0; s < (1 << len); s++) {
long rot2 = rot * rot % mod;
long rot3 = rot2 * rot % mod;
int offset = s << (h - len);
for (int i = 0; i < p; i++) {
long mod2 = 1L * mod * mod;
long a0 = 1L * a[i + offset];
long a1 = 1L * a[i + offset + p] * rot % mod;
long a2 = 1L * a[i + offset + 2 * p] * rot2 % mod;
long a3 = 1L * a[i + offset + 3 * p] * rot3 % mod;
long a1na3imag = 1L * (a1 + mod2 - a3) % mod * imag % mod;
long na2 = mod2 - a2;
a[i + offset] = (a0 + a2 + a1 + a3) % mod;
a[i + offset + 1 * p] = (a0 + a2 + (2 * mod2 - (a1 + a3))) % mod;
a[i + offset + 2 * p] = (a0 + na2 + a1na3imag) % mod;
a[i + offset + 3 * p] = (a0 + na2 + (mod2 - a1na3imag)) % mod;
}
if (s + 1 != (1 << len)) {
rot *= info.rate3[Integer.numberOfTrailingZeros(~s)];
rot %= mod;
}
}
int x = Integer.numberOfTrailingZeros(~s);
now = now * sumE[x] % mod;
len += 2;
}
}
}
Expand Down Expand Up @@ -241,15 +322,13 @@ public static long[] convolution(long[] a, long[] b, int mod) {
}

int g = primitiveRoot(mod);
long[] sume = sumE(mod, g);
long[] sumie = sumIE(mod, g);

butterfly(a, sume, mod);
butterfly(b, sume, mod);
butterfly(a, g, mod);
butterfly(b, g, mod);
for (int i = 0; i < z; i++) {
a[i] = a[i] * b[i] % mod;
}
butterflyInv(a, sumie, mod);
butterflyInv(a, g, mod);
a = java.util.Arrays.copyOf(a, n + m - 1);

long iz = pow(z, mod - 2, mod);
Expand Down