From c98d6209fa4f0457ca699f2ebcf1b333a46bf48f Mon Sep 17 00:00:00 2001 From: Naoto Mizuno Date: Sun, 20 Sep 2020 16:18:07 +0900 Subject: [PATCH] Port convolution --- README.md | 5 +- README_ja.md | 5 +- atcoder/convolution.py | 181 ++++++++++++++++++++++++++++ atcoder/modint.py | 3 + example/convolution_practice.py | 19 +++ example/convolution_practice_int.py | 19 +++ 6 files changed, 224 insertions(+), 8 deletions(-) create mode 100644 atcoder/convolution.py create mode 100644 example/convolution_practice.py create mode 100644 example/convolution_practice_int.py diff --git a/README.md b/README.md index d01024e..d87473b 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,7 @@ ac-library-python is a Python port of [AtCoder Library (ACL)](https://atcoder.jp #### Math ++ convolution + math + modint @@ -37,10 +38,6 @@ ac-library-python is a Python port of [AtCoder Library (ACL)](https://atcoder.jp + string -#### Math - -+ convolution - ## Install ``` diff --git a/README_ja.md b/README_ja.md index 8187e20..6069e95 100644 --- a/README_ja.md +++ b/README_ja.md @@ -18,6 +18,7 @@ ac-library-pythonは、[AtCoder Library (ACL)](https://atcoder.jp/posts/517)のP #### 数学 ++ convolution + math + modint @@ -35,10 +36,6 @@ ac-library-pythonは、[AtCoder Library (ACL)](https://atcoder.jp/posts/517)のP + string -#### 数学 - -+ convolution - ## インストール ``` diff --git a/atcoder/convolution.py b/atcoder/convolution.py new file mode 100644 index 0000000..87f6856 --- /dev/null +++ b/atcoder/convolution.py @@ -0,0 +1,181 @@ +import typing + +import atcoder._bit +import atcoder._math +from atcoder.modint import ModContext, Modint + + +_sum_e = {} # _sum_e[i] = ies[0] * ... * ies[i - 1] * es[i] + + +def _butterfly(a: typing.List[Modint]) -> None: + g = atcoder._math._primitive_root(a[0].mod()) + n = len(a) + h = atcoder._bit._ceil_pow2(n) + + if a[0].mod() not in _sum_e: + es = [Modint(0)] * 30 # es[i]^(2^(2+i)) == 1 + ies = [Modint(0)] * 30 + cnt2 = atcoder._bit._bsf(a[0].mod() - 1) + e = Modint(g) ** ((a[0].mod() - 1) >> cnt2) + ie = e.inv() + for i in range(cnt2, 1, -1): + # e^(2^i) == 1 + es[i - 2] = e + ies[i - 2] = ie + e = e * e + ie = ie * ie + sum_e = [Modint(0)] * 30 + now = Modint(1) + for i in range(cnt2 - 2): + sum_e[i] = es[i] * now + now *= ies[i] + _sum_e[a[0].mod()] = sum_e + else: + sum_e = _sum_e[a[0].mod()] + + for ph in range(1, h + 1): + w = 1 << (ph - 1) + p = 1 << (h - ph) + now = Modint(1) + for s in range(w): + offset = s << (h - ph + 1) + for i in range(p): + left = a[i + offset] + right = a[i + offset + p] * now + a[i + offset] = left + right + a[i + offset + p] = left - right + now *= sum_e[atcoder._bit._bsf(~s)] + + +_sum_ie = {} # _sum_ie[i] = es[0] * ... * es[i - 1] * ies[i] + + +def _butterfly_inv(a: typing.List[Modint]) -> None: + g = atcoder._math._primitive_root(a[0].mod()) + n = len(a) + h = atcoder._bit._ceil_pow2(n) + + if a[0].mod() not in _sum_ie: + es = [0] * 30 # es[i]^(2^(2+i)) == 1 + ies = [0] * 30 + cnt2 = atcoder._bit._bsf(a[0].mod() - 1) + e = Modint(g) ** ((a[0].mod() - 1) >> cnt2) + ie = e.inv() + for i in range(cnt2, 1, -1): + # e^(2^i) == 1 + es[i - 2] = e + ies[i - 2] = ie + e = e * e + ie = ie * ie + sum_ie = [0] * 30 + now = Modint(1) + for i in range(cnt2 - 2): + sum_ie[i] = ies[i] * now + now *= es[i] + _sum_ie[a[0].mod()] = sum_ie + else: + sum_ie = _sum_ie[a[0].mod()] + + for ph in range(h, 0, -1): + w = 1 << (ph - 1) + p = 1 << (h - ph) + inow = Modint(1) + for s in range(w): + offset = s << (h - ph + 1) + for i in range(p): + left = a[i + offset] + right = a[i + offset + p] + a[i + offset] = left + right + a[i + offset + p] = Modint( + (a[0].mod() + left.val() - right.val()) * inow.val()) + inow *= sum_ie[atcoder._bit._bsf(~s)] + + +def convolution_mod(a: typing.List[Modint], + b: typing.List[Modint]) -> typing.List[Modint]: + n = len(a) + m = len(b) + + if n == 0 or m == 0: + return [] + + if min(n, m) <= 60: + if n < m: + n, m = m, n + a, b = b, a + ans = [Modint(0) for _ in range(n + m - 1)] + for i in range(n): + for j in range(m): + ans[i + j] += a[i] * b[j] + return ans + + z = 1 << atcoder._bit._ceil_pow2(n + m - 1) + + while len(a) < z: + a.append(Modint(0)) + _butterfly(a) + + while len(b) < z: + b.append(Modint(0)) + _butterfly(b) + + for i in range(z): + a[i] *= b[i] + _butterfly_inv(a) + a = a[:n + m - 1] + + iz = Modint(z).inv() + for i in range(n + m - 1): + a[i] *= iz + + return a + + +def convolution(mod: int, a: typing.List[typing.Any], + b: typing.List[typing.Any]) -> typing.List[typing.Any]: + n = len(a) + m = len(b) + + if n == 0 or m == 0: + return [] + + with ModContext(mod): + a2 = list(map(Modint, a)) + b2 = list(map(Modint, b)) + + return list(map(lambda c: c.val(), convolution_mod(a2, b2))) + + +def convolution_int( + a: typing.List[int], b: typing.List[int]) -> typing.List[int]: + n = len(a) + m = len(b) + + if n == 0 or m == 0: + return [] + + mod1 = 754974721 # 2^24 + mod2 = 167772161 # 2^25 + mod3 = 469762049 # 2^26 + m2m3 = mod2 * mod3 + m1m3 = mod1 * mod3 + m1m2 = mod1 * mod2 + m1m2m3 = mod1 * mod2 * mod3 + + i1 = atcoder._math._inv_gcd(mod2 * mod3, mod1)[1] + i2 = atcoder._math._inv_gcd(mod1 * mod3, mod2)[1] + i3 = atcoder._math._inv_gcd(mod1 * mod2, mod3)[1] + + c1 = convolution(mod1, a, b) + c2 = convolution(mod2, a, b) + c3 = convolution(mod3, a, b) + + c = [0] * (n + m - 1) + for i in range(n + m - 1): + c[i] += (c1[i] * i1) % mod1 * m2m3 + c[i] += (c2[i] * i2) % mod2 * m1m3 + c[i] += (c3[i] * i3) % mod3 * m1m2 + c[i] %= m1m2m3 + + return c diff --git a/atcoder/modint.py b/atcoder/modint.py index f91e746..5ac8545 100644 --- a/atcoder/modint.py +++ b/atcoder/modint.py @@ -32,6 +32,9 @@ def __init__(self, v: int = 0) -> None: else: self._v = v % self._mod + def mod(self) -> int: + return self._mod + def val(self) -> int: return self._v diff --git a/example/convolution_practice.py b/example/convolution_practice.py new file mode 100644 index 0000000..9ade528 --- /dev/null +++ b/example/convolution_practice.py @@ -0,0 +1,19 @@ +# https://atcoder.jp/contests/practice2/tasks/practice2_f + +import sys + +from atcoder.convolution import convolution + + +def main() -> None: + n, m = map(int, sys.stdin.readline().split()) + a = list(map(int, sys.stdin.readline().split())) + b = list(map(int, sys.stdin.readline().split())) + + c = convolution(998244353, a, b) + + print(' '.join(map(str, c))) + + +if __name__ == '__main__': + main() diff --git a/example/convolution_practice_int.py b/example/convolution_practice_int.py new file mode 100644 index 0000000..e0985a2 --- /dev/null +++ b/example/convolution_practice_int.py @@ -0,0 +1,19 @@ +# https://atcoder.jp/contests/practice2/tasks/practice2_f + +import sys + +from atcoder.convolution import convolution_int + + +def main() -> None: + n, m = map(int, sys.stdin.readline().split()) + a = list(map(int, sys.stdin.readline().split())) + b = list(map(int, sys.stdin.readline().split())) + + c = convolution_int(a, b) + + print(' '.join([str(ci % 998244353) for ci in c])) + + +if __name__ == '__main__': + main()