11/**
22 * Convolution.
33 *
4- * @verified https://atcoder.jp/contests/practice2/tasks/practice2_f
5- * @verified https://judge.yosupo.jp/problem/convolution_mod_1000000007
4+ * @verified https://atcoder.jp/contests/practice2/submissions/24449847
5+ * @verified https://judge.yosupo.jp/submission/53841
66 */
77class Convolution {
88 /**
@@ -78,6 +78,74 @@ private static int ceilPow2(int n) {
7878 return x ;
7979 }
8080
81+ private static class FftInfo {
82+ private static int bsfConstexpr (int n ) {
83+ int x = 0 ;
84+ while ((n & (1 << x )) == 0 ) x ++;
85+ return x ;
86+ }
87+
88+ private static long inv (long a , long mod ) {
89+ long b = mod ;
90+ long p = 1 , q = 0 ;
91+ while (b > 0 ) {
92+ long c = a / b ;
93+ long d ;
94+ d = a ;
95+ a = b ;
96+ b = d % b ;
97+ d = p ;
98+ p = q ;
99+ q = d - c * q ;
100+ }
101+ return p < 0 ? p + mod : p ;
102+ }
103+
104+ private final int rank2 ;
105+ public final long [] root ;
106+ public final long [] iroot ;
107+ public final long [] rate2 ;
108+ public final long [] irate2 ;
109+ public final long [] rate3 ;
110+ public final long [] irate3 ;
111+
112+ public FftInfo (int g , int mod ) {
113+ rank2 = bsfConstexpr (mod - 1 );
114+ root = new long [rank2 + 1 ];
115+ iroot = new long [rank2 + 1 ];
116+ rate2 = new long [Math .max (0 , rank2 - 2 + 1 )];
117+ irate2 = new long [Math .max (0 , rank2 - 2 + 1 )];
118+ rate3 = new long [Math .max (0 , rank2 - 3 + 1 )];
119+ irate3 = new long [Math .max (0 , rank2 - 3 + 1 )];
120+
121+ root [rank2 ] = pow (g , (mod - 1 ) >> rank2 , mod );
122+ iroot [rank2 ] = inv (root [rank2 ], mod );
123+ for (int i = rank2 - 1 ; i >= 0 ; i --) {
124+ root [i ] = root [i + 1 ] * root [i + 1 ] % mod ;
125+ iroot [i ] = iroot [i + 1 ] * iroot [i + 1 ] % mod ;
126+ }
127+
128+ {
129+ long prod = 1 , iprod = 1 ;
130+ for (int i = 0 ; i <= rank2 - 2 ; i ++) {
131+ rate2 [i ] = root [i + 2 ] * prod % mod ;
132+ irate2 [i ] = iroot [i + 2 ] * iprod % mod ;
133+ prod = prod * iroot [i + 2 ] % mod ;
134+ iprod = iprod * root [i + 2 ] % mod ;
135+ }
136+ }
137+ {
138+ long prod = 1 , iprod = 1 ;
139+ for (int i = 0 ; i <= rank2 - 3 ; i ++) {
140+ rate3 [i ] = root [i + 3 ] * prod % mod ;
141+ irate3 [i ] = iroot [i + 3 ] * iprod % mod ;
142+ prod = prod * iroot [i + 3 ] % mod ;
143+ iprod = iprod * root [i + 3 ] % mod ;
144+ }
145+ }
146+ }
147+ };
148+
81149 /**
82150 * Garner's algorithm.
83151 *
@@ -104,87 +172,65 @@ private static long garner(long[] c, int[] mods) {
104172 return cnst [n - 1 ];
105173 }
106174
107- /**
108- * Pre-calculation for NTT.
109- *
110- * @param mod NTT Prime.
111- * @param g Primitive root of mod.
112- * @return Pre-calculation table.
113- */
114- private static long [] sumE (int mod , int g ) {
115- long [] sum_e = new long [30 ];
116- long [] es = new long [30 ];
117- long [] ies = new long [30 ];
118- int cnt2 = Integer .numberOfTrailingZeros (mod - 1 );
119- long e = pow (g , (mod - 1 ) >> cnt2 , mod );
120- long ie = pow (e , mod - 2 , mod );
121- for (int i = cnt2 ; i >= 2 ; i --) {
122- es [i - 2 ] = e ;
123- ies [i - 2 ] = ie ;
124- e = e * e % mod ;
125- ie = ie * ie % mod ;
126- }
127- long now = 1 ;
128- for (int i = 0 ; i <= cnt2 - 2 ; i ++) {
129- sum_e [i ] = es [i ] * now % mod ;
130- now = now * ies [i ] % mod ;
131- }
132- return sum_e ;
133- }
134-
135- /**
136- * Pre-calculation for inverse NTT.
137- *
138- * @param mod Mod.
139- * @param g Primitive root of mod.
140- * @return Pre-calculation table.
141- */
142- private static long [] sumIE (int mod , int g ) {
143- long [] sum_ie = new long [30 ];
144- long [] es = new long [30 ];
145- long [] ies = new long [30 ];
146-
147- int cnt2 = Integer .numberOfTrailingZeros (mod - 1 );
148- long e = pow (g , (mod - 1 ) >> cnt2 , mod );
149- long ie = pow (e , mod - 2 , mod );
150- for (int i = cnt2 ; i >= 2 ; i --) {
151- es [i - 2 ] = e ;
152- ies [i - 2 ] = ie ;
153- e = e * e % mod ;
154- ie = ie * ie % mod ;
155- }
156- long now = 1 ;
157- for (int i = 0 ; i <= cnt2 - 2 ; i ++) {
158- sum_ie [i ] = ies [i ] * now % mod ;
159- now = now * es [i ] % mod ;
160- }
161- return sum_ie ;
162- }
163-
164175 /**
165176 * Inverse NTT.
166177 *
167178 * @param a Target array.
168- * @param sumIE Pre-calculation table .
179+ * @param g Primitive root of mod .
169180 * @param mod NTT Prime.
170181 */
171- private static void butterflyInv (long [] a , long [] sumIE , int mod ) {
182+ private static void butterflyInv (long [] a , int g , int mod ) {
172183 int n = a .length ;
173184 int h = ceilPow2 (n );
174185
175- for (int ph = h ; ph >= 1 ; ph --) {
176- int w = 1 << (ph - 1 ), p = 1 << (h - ph );
177- long inow = 1 ;
178- for (int s = 0 ; s < w ; s ++) {
179- int offset = s << (h - ph + 1 );
180- for (int i = 0 ; i < p ; i ++) {
181- long l = a [i + offset ];
182- long r = a [i + offset + p ];
183- a [i + offset ] = (l + r ) % mod ;
184- a [i + offset + p ] = (mod + l - r ) * inow % mod ;
186+ FftInfo info = new FftInfo (g , mod );
187+
188+ int len = h ; // a[i, i+(n>>len), i+2*(n>>len), ..] is transformed
189+ while (len > 0 ) {
190+ if (len == 1 ) {
191+ int p = 1 << (h - len );
192+ long irot = 1 ;
193+ for (int s = 0 ; s < (1 << (len - 1 )); s ++) {
194+ int offset = s << (h - len + 1 );
195+ for (int i = 0 ; i < p ; i ++) {
196+ long l = a [i + offset ];
197+ long r = a [i + offset + p ];
198+ a [i + offset ] = (l + r ) % mod ;
199+ a [i + offset + p ] = (mod + l - r ) % mod * irot % mod ;
200+ }
201+ if (s + 1 != (1 << (len - 1 ))) {
202+ irot *= info .irate2 [Integer .numberOfTrailingZeros (~s )];
203+ irot %= mod ;
204+ }
185205 }
186- int x = Integer .numberOfTrailingZeros (~s );
187- inow = inow * sumIE [x ] % mod ;
206+ len --;
207+ } else {
208+ // 4-base
209+ int p = 1 << (h - len );
210+ long irot = 1 , iimag = info .iroot [2 ];
211+ for (int s = 0 ; s < (1 << (len - 2 )); s ++) {
212+ long irot2 = irot * irot % mod ;
213+ long irot3 = irot2 * irot % mod ;
214+ int offset = s << (h - len + 2 );
215+ for (int i = 0 ; i < p ; i ++) {
216+ long a0 = 1L * a [i + offset + 0 * p ];
217+ long a1 = 1L * a [i + offset + 1 * p ];
218+ long a2 = 1L * a [i + offset + 2 * p ];
219+ long a3 = 1L * a [i + offset + 3 * p ];
220+
221+ long a2na3iimag = 1L * (mod + a2 - a3 ) % mod * iimag % mod ;
222+
223+ a [i + offset ] = (a0 + a1 + a2 + a3 ) % mod ;
224+ a [i + offset + 1 * p ] = (a0 + (mod - a1 ) + a2na3iimag ) % mod * irot % mod ;
225+ a [i + offset + 2 * p ] = (a0 + a1 + (mod - a2 ) + (mod - a3 )) % mod * irot2 % mod ;
226+ a [i + offset + 3 * p ] = (a0 + (mod - a1 ) + (mod - a2na3iimag )) % mod * irot3 % mod ;
227+ }
228+ if (s + 1 != (1 << (len - 2 ))) {
229+ irot *= info .irate3 [Integer .numberOfTrailingZeros (~s )];
230+ irot %= mod ;
231+ }
232+ }
233+ len -= 2 ;
188234 }
189235 }
190236 }
@@ -193,26 +239,61 @@ private static void butterflyInv(long[] a, long[] sumIE, int mod) {
193239 * NTT (forward transform).
194240 *
195241 * @param a Target array.
196- * @param sumE Pre-calculation table .
242+ * @param g Primitive root of mod .
197243 * @param mod NTT Prime.
198244 */
199- private static void butterfly (long [] a , long [] sumE , int mod ) {
245+ private static void butterfly (long [] a , int g , int mod ) {
200246 int n = a .length ;
201247 int h = ceilPow2 (n );
202248
203- for (int ph = 1 ; ph <= h ; ph ++) {
204- int w = 1 << (ph - 1 ), p = 1 << (h - ph );
205- long now = 1 ;
206- for (int s = 0 ; s < w ; s ++) {
207- int offset = s << (h - ph + 1 );
208- for (int i = 0 ; i < p ; i ++) {
209- long l = a [i + offset ];
210- long r = a [i + offset + p ] * now % mod ;
211- a [i + offset ] = (l + r ) % mod ;
212- a [i + offset + p ] = (l - r + mod ) % mod ;
249+ FftInfo info = new FftInfo (g , mod );
250+
251+ int len = 0 ; // a[i, i+(n>>len), i+2*(n>>len), ..] is transformed
252+ while (len < h ) {
253+ if (h - len == 1 ) {
254+ int p = 1 << (h - len - 1 );
255+ long rot = 1 ;
256+ for (int s = 0 ; s < (1 << len ); s ++) {
257+ int offset = s << (h - len );
258+ for (int i = 0 ; i < p ; i ++) {
259+ long l = a [i + offset ];
260+ long r = a [i + offset + p ] * rot % mod ;
261+ a [i + offset ] = (l + r ) % mod ;
262+ a [i + offset + p ] = (l + mod - r ) % mod ;
263+ }
264+ if (s + 1 != (1 << len )) {
265+ rot *= info .rate2 [Integer .numberOfTrailingZeros (~s )];
266+ rot %= mod ;
267+ }
268+ }
269+ len ++;
270+ } else {
271+ // 4-base
272+ int p = 1 << (h - len - 2 );
273+ long rot = 1 , imag = info .root [2 ];
274+ for (int s = 0 ; s < (1 << len ); s ++) {
275+ long rot2 = rot * rot % mod ;
276+ long rot3 = rot2 * rot % mod ;
277+ int offset = s << (h - len );
278+ for (int i = 0 ; i < p ; i ++) {
279+ long mod2 = 1L * mod * mod ;
280+ long a0 = 1L * a [i + offset ];
281+ long a1 = 1L * a [i + offset + p ] * rot % mod ;
282+ long a2 = 1L * a [i + offset + 2 * p ] * rot2 % mod ;
283+ long a3 = 1L * a [i + offset + 3 * p ] * rot3 % mod ;
284+ long a1na3imag = 1L * (a1 + mod2 - a3 ) % mod * imag % mod ;
285+ long na2 = mod2 - a2 ;
286+ a [i + offset ] = (a0 + a2 + a1 + a3 ) % mod ;
287+ a [i + offset + 1 * p ] = (a0 + a2 + (2 * mod2 - (a1 + a3 ))) % mod ;
288+ a [i + offset + 2 * p ] = (a0 + na2 + a1na3imag ) % mod ;
289+ a [i + offset + 3 * p ] = (a0 + na2 + (mod2 - a1na3imag )) % mod ;
290+ }
291+ if (s + 1 != (1 << len )) {
292+ rot *= info .rate3 [Integer .numberOfTrailingZeros (~s )];
293+ rot %= mod ;
294+ }
213295 }
214- int x = Integer .numberOfTrailingZeros (~s );
215- now = now * sumE [x ] % mod ;
296+ len += 2 ;
216297 }
217298 }
218299 }
@@ -241,15 +322,13 @@ public static long[] convolution(long[] a, long[] b, int mod) {
241322 }
242323
243324 int g = primitiveRoot (mod );
244- long [] sume = sumE (mod , g );
245- long [] sumie = sumIE (mod , g );
246325
247- butterfly (a , sume , mod );
248- butterfly (b , sume , mod );
326+ butterfly (a , g , mod );
327+ butterfly (b , g , mod );
249328 for (int i = 0 ; i < z ; i ++) {
250329 a [i ] = a [i ] * b [i ] % mod ;
251330 }
252- butterflyInv (a , sumie , mod );
331+ butterflyInv (a , g , mod );
253332 a = java .util .Arrays .copyOf (a , n + m - 1 );
254333
255334 long iz = pow (z , mod - 2 , mod );
0 commit comments