From e8fd0d8934e1741e441ce31a19015dd34457d8ad Mon Sep 17 00:00:00 2001 From: Naoto Mizuno Date: Fri, 11 Sep 2020 08:50:17 +0900 Subject: [PATCH 1/8] tmp --- .gitignore | 2 + README.md | 1 + atcoder/_bit.py | 15 +++ atcoder/_math.py | 142 ++++++++++++++++++++++++++++ atcoder/_queue.py | 35 +++++++ atcoder/_scc.py | 98 +++++++++++++++++++ atcoder/convolution.py | 210 +++++++++++++++++++++++++++++++++++++++++ atcoder/fenwicktree.py | 30 ++++++ atcoder/math.py | 99 +++++++++++++++++++ atcoder/scc.py | 17 ++++ atcoder/twosat.py | 35 +++++++ 11 files changed, 684 insertions(+) create mode 100644 README.md create mode 100644 atcoder/_bit.py create mode 100644 atcoder/_math.py create mode 100644 atcoder/_queue.py create mode 100644 atcoder/_scc.py create mode 100644 atcoder/convolution.py create mode 100644 atcoder/fenwicktree.py create mode 100644 atcoder/math.py create mode 100644 atcoder/scc.py create mode 100644 atcoder/twosat.py diff --git a/.gitignore b/.gitignore index b6e4761..5ef91b0 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ +*.hpp + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/README.md b/README.md new file mode 100644 index 0000000..cb390ad --- /dev/null +++ b/README.md @@ -0,0 +1 @@ +作業用ブランチです。一切テストしてないのでバグだらけだと思います。使わないでください。 diff --git a/atcoder/_bit.py b/atcoder/_bit.py new file mode 100644 index 0000000..9d1c939 --- /dev/null +++ b/atcoder/_bit.py @@ -0,0 +1,15 @@ +def _ceil_pow2(n: int) -> int: + x = 0 + while (1 << x) < n: + x += 1 + + return x + + +def _bsf(n: int) -> int: + x = 0 + while n % 2 == 0: + x += 1 + n //= 2 + + return x diff --git a/atcoder/_math.py b/atcoder/_math.py new file mode 100644 index 0000000..ab34fb3 --- /dev/null +++ b/atcoder/_math.py @@ -0,0 +1,142 @@ +import typing + + +class Barrett: + ''' + Fast moduler by barrett reduction + Reference: https://en.wikipedia.org/wiki/Barrett_reduction + NOTE: reconsider after Ice Lake + ''' + + def __init__(self, m: int) -> None: + self._m = m + self._im = ((1 << 64) - 1) / m + 1 + + def umod(self) -> int: + return self._m + + def mul(self, a: int, b: int) -> int: + ''' + [1] m = 1 + a = b = im = 0, so okay + + [2] m >= 2 + im = ceil(2^64 / m) + -> im * m = 2^64 + r (0 <= r < m) + let z = a*b = c*m + d (0 <= c, d < m) + a*b * im = (c*m + d) * im = c*(im*m) + d*im = c*2^64 + c*r + d*im + c*r + d*im < m*m + m*im < m*m + 2^64 + m <= 2^64 + m*(m+1) < 2^64 * 2 + ((ab * im) >> 64) == c or c + 1 + ''' + + z = a * b + x = (z * self._im) >> 64 + v = z - x * self._m + if self._m <= v: + v += self._m + return v + + +def _is_prime(n: int) -> bool: + ''' + Reference: + M. Forisek and J. Jancina, + Fast Primality Testing for Integers That Fit into a Machine Word + ''' + + if n <= 1: + return False + if n == 2 or n == 7 or n == 61: + return True + if n % 2 == 0: + return False + + d = n - 1 + while d % 2 == 0: + d /= 2 + + for a in (2, 7, 61): + t = d + y = pow(a, t, n) + while t != n - 1 and y != 1 and y != n - 1: + y = y * y % n + t <<= 1 + if y != n - 1 and t % 2 == 0: + return False + return True + + +def _inv_gcd(a: int, b: int) -> typing.Tuple[int, int]: + a %= b + if a == 0: + return (b, 0) + + # Contracts: + # [1] s - m0 * a = 0 (mod b) + # [2] t - m1 * a = 0 (mod b) + # [3] s * |m1| + t * |m0| <= b + s = b + t = a + m0 = 0 + m1 = 1 + + while t: + u = s // t + s -= t * u + m0 -= m1 * u # |m1 * u| <= |m1| * s <= b + + # [3]: + # (s - t * u) * |m1| + t * |m0 - m1 * u| + # <= s * |m1| - t * u * |m1| + t * (|m0| + |m1| * u) + # = s * |m1| + t * |m0| <= b + + s, t = t, s + m0, m1 = m1, m0 + + # by [3]: |m0| <= b/g + # by g != b: |m0| < b/g + if m0 < 0: + m0 += b // s + + return (s, m0) + + +def _primitive_root(m: int) -> int: + if m == 2: + return 1 + if m == 167772161: + return 3 + if m == 469762049: + return 3 + if m == 754974721: + return 11 + if m == 998244353: + return 3 + + divs = [2] + [0] * 19 + cnt = 1 + x = (m - 1) // 2 + while x % 2 == 0: + x //= 2 + + i = 3 + while i * i <= x: + if x % i == 0: + divs[cnt] = i + cnt += 1 + while x % i == 0: + x //= i + i += 2 + + if x > 1: + divs[cnt] = x + cnt += 1 + + g = 2 + while True: + for i in range(cnt): + if pow(g, (m - 1) // divs[i], m) == 1: + break + else: + return g + g += 1 diff --git a/atcoder/_queue.py b/atcoder/_queue.py new file mode 100644 index 0000000..61e9e2e --- /dev/null +++ b/atcoder/_queue.py @@ -0,0 +1,35 @@ +import typing + + +class SimpleQueue: + def __init__(self): + self._payload = [] + self._pos = 0 + self._size = 0 + + def reserve(self, n: int) -> None: + self._payload += [None] * (n - len(self._payload)) + + def __len__(self) -> int: + return self._size + + def empty(self) -> bool: + return self._size == 0 + + def push(self, t: typing.Any) -> None: + i = self._pos + self._size + if len(self._payload) <= i: + self.reserve(i + 1) + self._payload[i] = t + + def front(self) -> typing.Any: + return self._payload[self._pos] + + def clear(self) -> None: + self._payload = [] + self._pos = 0 + self._size = 0 + + def pop(self) -> None: + self._pos += 1 + self._size -= 1 diff --git a/atcoder/_scc.py b/atcoder/_scc.py new file mode 100644 index 0000000..5ab8e11 --- /dev/null +++ b/atcoder/_scc.py @@ -0,0 +1,98 @@ +import copy +import typing + + +class CSR: + def __init__( + self, n: int, edges: typing.List[typing.Tuple[int, int]]) -> None: + self.start = [0] * (n + 1) + self.elist = [0] * len(edges) + + for e in edges: + self.start[e[0] + 1] += 1 + + for i in range(1, n + 1): + self.start[i] += self.start[i - 1] + + counter = copy.deepcopy(self.start) + for e in edges: + self.elist[counter[e[0]]] = e[1] + counter[e[0]] += 1 + + +class SCCGraph: + ''' + Reference: + R. Tarjan, + Depth-First Search and Linear Graph Algorithms + ''' + + def __init__(self, n: int) -> None: + self._n = n + self._edges = [] + + def num_vertices(self) -> int: + return self._n + + def add_edge(self, from_vertex: int, to_vertex: int) -> None: + self._edges.append((from_vertex, to_vertex)) + + def scc_ids(self) -> typing.Tuple[int, typing.List[int]]: + g = CSR(self._n, self._edges) + now_ord = 0 + group_num = 0 + visited = [] + low = [0] * self._n + order = [-1] * self._n + ids = [0] * self._n + + def dfs(v: int) -> None: + nonlocal now_ord + nonlocal group_num + nonlocal visited + nonlocal low + nonlocal order + nonlocal ids + + low[v] = now_ord + order[v] = now_ord + now_ord += 1 + visited.append(v) + for i in range(g.start[v], g.start[v + 1]): + to = g.elist[i] + if order[to] == -1: + dfs(to) + low[v] = min(low[v], low[to]) + else: + low[v] = min(low[v], order[to]) + + if low[v] == order[v]: + while True: + u = visited[-1] + visited.pop() + order[u] = self._n + ids[u] = group_num + if u == v: + break + group_num += 1 + + for i in range(self._n): + if order[i] == -1: + dfs(i) + + for i in range(self._n): + ids[i] = group_num - 1 - ids[i] + + return (group_num, ids) + + def scc(self) -> typing.List[typing.List[int]]: + ids = self.scc_ids() + group_num = ids[0] + counts = [0] * group_num + for x in ids[1]: + counts[x] += 1 + groups = [[] for _ in range(group_num)] + for i in range(self._n): + groups[ids[1][i]].append(i) + + return groups diff --git a/atcoder/convolution.py b/atcoder/convolution.py new file mode 100644 index 0000000..b559a5e --- /dev/null +++ b/atcoder/convolution.py @@ -0,0 +1,210 @@ +import typing + +import atcoder._bit +import atcoder._math +import atcoder.modint +from atcoder.modint import modint + + +_sum_e = {} # _sum_e[i] = ies[0] * ... * ies[i - 1] * es[i] + + +def _butterfly(a: typing.List[modint]) -> None: + g = atcoder._math._primitive_root(a[0].mod()) + n = len(a) + h = atcoder._bit._ceil_pow2(n) + + if a[0].mod() not in _sum_e: + es = [0] * 30 # es[i]^(2^(2+i)) == 1 + ies = [0] * 30 + cnt2 = atcoder._bit._bsf(a[0].mod() - 1) + e = modint(g).pow((a[0].mod() - 1) >> cnt2) + ie = e.inv() + for i in range(cnt2, 1, -1): + # e^(2^i) == 1 + es[i - 2] = e + ies[i - 2] = ie + e *= e + ie *= ie + sum_e = [0] * 30 + now = modint(1) + for i in range(cnt2 - 2): + sum_e[i] = es[i] * now + now *= ies[i] + _sum_e[a[0].mod()] = sum_e + else: + sum_e = _sum_e[a[0].mod()] + + for ph in (1, h + 1): + w = 1 << (ph - 1) + p = 1 << (h - ph) + now = modint(1) + for s in range(w): + offset = s << (h - ph + 1) + for i in range(p): + left = a[i + offset] + right = a[i + offset + p] * now + a[i + offset] = left + right + a[i + offset + p] = left - right + now *= sum_e[atcoder._bit._bsf(~s)] + + +_sum_ie = {} # _sum_ie[i] = es[0] * ... * es[i - 1] * ies[i] + + +def _butterfly_inv(a: typing.List[modint]) -> None: + g = atcoder._math._primitive_root(a[0].mod()) + n = len(a) + h = atcoder._bit.ceil_pow2(n) + + if a[0].mod() not in _sum_ie: + es = [0] * 30 # es[i]^(2^(2+i)) == 1 + ies = [0] * 30 + cnt2 = atcoder._bit._bsf(a[0].mod() - 1) + e = modint(g).pow((a[0].mod() - 1) >> cnt2) + ie = e.inv() + for i in range(cnt2, 1, -1): + # e^(2^i) == 1 + es[i - 2] = e + ies[i - 2] = ie + e *= e + ie *= ie + sum_ie = [0] * 30 + now = modint(1) + for i in range(cnt2 - 2): + sum_ie[i] = ies[i] * now + now *= es[i] + _sum_ie[a[0].mod()] = sum_ie + else: + sum_ie = _sum_ie[a[0].mod()] + + for ph in range(h, 0, -1): + w = 1 << (ph - 1) + p = 1 << (h - ph) + inow = modint(1) + for s in range(w): + offset = s << (h - ph + 1) + for i in range(p): + left = a[i + offset] + right = a[i + offset + p] + a[i + offset] = left + right + a[i + offset + p] = (a[0].mod() + + left.val() - right.val()) * inow.val() + inow *= sum_ie[atcoder._bit._bsf(~s)] + + +def convolution_mod(a: typing.List[modint], + b: typing.List[modint]) -> typing.List[modint]: + n = len(a) + m = len(b) + + if n == 0 or m == 0: + return [] + + if min(n, m) <= 60: + if n < m: + n, m = m, n + a, b = b, a + ans = [0] * (n + m - 1) + for i in range(n): + for j in range(m): + ans[i + j] += a[i] * b[j] + return ans + + z = 1 << atcoder._bit._ceil_pow2(n + m - 1) + + while len(a) < z: + a.append(modint(0)) + _butterfly(a) + + while len(b) < z: + b.append(modint(0)) + _butterfly(b) + + for i in range(z): + a[i] *= b[i] + _butterfly_inv(a) + a = a[:n + m - 1] + + iz = modint(z).inv() + for i in range(n + m - 1): + a[i] *= iz + + return a + + +def convolution(mod: int, a: typing.List[typing.Any], + b: typing.List[typing.Any]) -> typing.List[typing.Any]: + n = len(a) + m = len(b) + + if n == 0 or m == 0: + return [] + + atcoder.modint.set_mod(mod) + + a2 = map(modint, a) + b2 = map(modint, b) + + return list(map(lambda c: c.val(), convolution_mod(a2, b2))) + + +def convolution_ll( + a: typing.List[int], b: typing.List[int]) -> typing.List[int]: + n = len(a) + m = len(b) + + if n == 0 or m == 0: + return [] + + mod1 = 754974721 # 2^24 + mod2 = 167772161 # 2^25 + mod3 = 469762049 # 2^26 + m2m3 = mod2 * mod3 + m1m3 = mod1 * mod3 + m1m2 = mod1 * mod2 + m1m2m3 = mod1 * mod2 * mod3 + + i1 = atcoder._math._inv_gcd(mod2 * mod3, mod1)[1] + i2 = atcoder._math._inv_gcd(mod1 * mod3, mod2)[1] + i3 = atcoder._math._inv_gcd(mod1 * mod2, mod3)[1] + + c1 = convolution(mod1, a, b) + c2 = convolution(mod2, a, b) + c3 = convolution(mod3, a, b) + + c = [0] * (n + m - 1) + for i in range(n + m - 1): + x = 0 + x += (c1[i] * i1) % mod1 * m2m3 + x += (c2[i] * i2) % mod2 * m1m3 + x += (c3[i] * i3) % mod3 * m1m2 + + ''' + B = 2^63, -B <= x, r(real value) < B + (x, x - M, x - 2M, or x - 3M) = r (mod 2B) + r = c1[i] (mod MOD1) + focus on MOD1 + r = x, x - M', x - 2M', x - 3M' (M' = M % 2^64) (mod 2B) + r = x, + x - M' + (0 or 2B), + x - 2M' + (0, 2B or 4B), + x - 3M' + (0, 2B, 4B or 6B) (without mod!) + (r - x) = 0, (0) + - M' + (0 or 2B), (1) + -2M' + (0 or 2B or 4B), (2) + -3M' + (0 or 2B or 4B or 6B) (3) (mod MOD1) + we checked that + ((1) mod MOD1) mod 5 = 2 + ((2) mod MOD1) mod 5 = 3 + ((3) mod MOD1) mod 5 = 4 + ''' + + diff = c1[i] - x % mod1 + if diff < 0: + diff += mod1 + offset = [0, 0, m1m2m3, 2 * m1m2m3, 3 * m1m2m3] + x -= offset[diff % 5] + c[i] = x + + return c diff --git a/atcoder/fenwicktree.py b/atcoder/fenwicktree.py new file mode 100644 index 0000000..a6e1549 --- /dev/null +++ b/atcoder/fenwicktree.py @@ -0,0 +1,30 @@ +import typing + + +class FenwickTree: + '''Reference: https://en.wikipedia.org/wiki/Fenwick_tree''' + + def __init__(self, n: int = 0) -> None: + self._n = 0 + self.data = [0] * n + + def add(self, p: int, x: typing.Any) -> None: + assert 0 <= p < self._n + + p += 1 + while p <= self._n: + self.data[p - 1] += x + p += p & -p + + def sum(self, left: int, right: int) -> typing.Any: + assert 0 <= left <= right <= self._n + + return sum(right) - sum(left) + + def _sum(self, r: int) -> typing.Any: + s = 0 + while r > 0: + s += self.data[r - 1] + r -= r & -r + + return s diff --git a/atcoder/math.py b/atcoder/math.py new file mode 100644 index 0000000..e61a1b3 --- /dev/null +++ b/atcoder/math.py @@ -0,0 +1,99 @@ +import typing + +import atcoder._math + + +def pow_mod(x: int, n: int, m: int) -> int: + assert 0 <= n and 1 <= m + + return pow(x, n, m) + + +def inv_mod(x: int, m: int) -> int: + assert 1 <= m + + z = atcoder._math._inv_gcd(x, m) + + assert z[0] == 1 + + return z[1] + + +def crt(r: typing.List[int], m: typing.List[int]) -> typing.Tuple[int, int]: + assert len(r) == len(m) + + n = len(r) + + # Contracts: 0 <= r0 < m0 + r0 = 0 + m0 = 1 + for i in range(n): + assert 1 <= m[i] + r1 = r[i] % m[i] + m1 = m[i] + if m0 < m1: + r0, r1 = r1, r0 + m0, m1 = m1, m0 + if m0 % m1 == 0: + if r0 % m1 != r1: + return (0, 0) + continue + + # assume: m0 > m1, lcm(m0, m1) >= 2 * max(m0, m1) + + ''' + (r0, m0), (r1, m1) -> (r2, m2 = lcm(m0, m1)); + r2 % m0 = r0 + r2 % m1 = r1 + -> (r0 + x*m0) % m1 = r1 + -> x*u0*g % (u1*g) = (r1 - r0) (u0*g = m0, u1*g = m1) + -> x = (r1 - r0) / g * inv(u0) (mod u1) + ''' + + # im = inv(u0) (mod u1) (0 <= im < u1) + g, im = atcoder._math._inv_gcd(m0, m1) + + u1 = m1 // g + # |r1 - r0| < (m0 + m1) <= lcm(m0, m1) + if (r1 - r0) % g: + return (0, 0) + + # u1 * u1 <= m1 * m1 / g / g <= m0 * m1 / g = lcm(m0, m1) + x = (r1 - r0) // g % u1 * im % u1 + + ''' + |r0| + |m0 * x| + < m0 + m0 * (u1 - 1) + = m0 + m0 * m1 / g - m0 + = lcm(m0, m1) + ''' + + r0 += x * m0 + m0 *= u1 # -> lcm(m0, m1) + if r0 < 0: + r0 += m0 + + return (r0, m0) + + +def floor_sum(n: int, m: int, a: int, b: int) -> int: + ans = 0 + + if a >= m: + ans += (n - 1) * n * (a // m) // 2 + a %= m + + if b >= m: + ans += n * (b // m) + b %= m + + y_max = (a * n + b) // m + x_max = y_max * m - b + + if y_max == 0: + return ans + + ans += (n - (x_max + a - 1) // a) * y_max + ans += floor_sum(y_max, a, m, (a - x_max % a) % a) + + return ans diff --git a/atcoder/scc.py b/atcoder/scc.py new file mode 100644 index 0000000..9ff67ac --- /dev/null +++ b/atcoder/scc.py @@ -0,0 +1,17 @@ +import typing + +import atcoder._scc + + +class SCCGraph: + def __init__(self, n: int) -> None: + self._internal = atcoder._scc.SCCGraph(n) + + def add_edge(self, from_vertex: int, to_vertex: int) -> None: + n = self._internal.num_vertices() + assert 0 <= from_vertex < n + assert 0 <= to_vertex < n + self._internal.add_edge(from_vertex, to_vertex) + + def scc(self) -> typing.List[typing.List[int]]: + return self._internal.scc() diff --git a/atcoder/twosat.py b/atcoder/twosat.py new file mode 100644 index 0000000..b93a6b4 --- /dev/null +++ b/atcoder/twosat.py @@ -0,0 +1,35 @@ +import typing + +import atcoder._scc + + +class TwoSAT: + ''' + Reference: + B. Aspvall, M. Plass, and R. Tarjan, + A Linear-Time Algorithm for Testing the Truth of Certain Quantified Boolean + Formulas + ''' + + def __init__(self, n: int = 0) -> None: + self._n = n + self._answer = [False] * n + self._scc = atcoder._scc.scc_graph(2 * n) + + def add_clause(self, i: int, f: bool, j: int, g: bool) -> None: + assert 0 <= i < self._n + assert 0 <= j < self._n + + self._scc.add_edge(2 * i + (0 if f else 1), 2 * j + (1 if g else 0)) + self._scc.add_edge(2 * j + (0 if g else 1), 2 * i + (1 if f else 0)) + + def satisfiable(self) -> bool: + scc_id = self._scc.scc_ids()[1] + for i in range(self._n): + if scc_id[2 * i] == scc_id[2 * i + 1]: + return False + self._answer[i] = scc_id[2 * i] < scc_id[2 * i + 1] + return True + + def answer(self) -> typing.List[bool]: + return self._answer From 4c2252e65770ce15f993c83c5d1ac549d320394b Mon Sep 17 00:00:00 2001 From: Naoto Mizuno Date: Fri, 11 Sep 2020 08:51:37 +0900 Subject: [PATCH 2/8] private variable --- atcoder/dsu.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/atcoder/dsu.py b/atcoder/dsu.py index ad73361..f2693ac 100644 --- a/atcoder/dsu.py +++ b/atcoder/dsu.py @@ -11,7 +11,7 @@ class DSU: def __init__(self, n: int = 0): self._n = n - self.parent_or_size = [-1] * n + self._parent_or_size = [-1] * n def merge(self, a: int, b: int) -> int: assert 0 <= a < self._n @@ -23,11 +23,11 @@ def merge(self, a: int, b: int) -> int: if x == y: return x - if -self.parent_or_size[x] < -self.parent_or_size[y]: + if -self._parent_or_size[x] < -self._parent_or_size[y]: x, y = y, x - self.parent_or_size[x] += self.parent_or_size[y] - self.parent_or_size[y] = x + self._parent_or_size[x] += self._parent_or_size[y] + self._parent_or_size[y] = x return x @@ -40,16 +40,16 @@ def same(self, a: int, b: int) -> bool: def leader(self, a: int) -> int: assert 0 <= a < self._n - if self.parent_or_size[a] < 0: + if self._parent_or_size[a] < 0: return a - self.parent_or_size[a] = self.leader(self.parent_or_size[a]) - return self.parent_or_size[a] + self._parent_or_size[a] = self.leader(self._parent_or_size[a]) + return self._parent_or_size[a] def size(self, a: int) -> int: assert 0 <= a < self._n - return -self.parent_or_size[self.leader(a)] + return -self._parent_or_size[self.leader(a)] def groups(self) -> typing.List[typing.List[int]]: leader_buf = [self.leader(i) for i in range(self._n)] From f633f7ef1fddd295c9efaf377b4d41599ed00262 Mon Sep 17 00:00:00 2001 From: Naoto Mizuno Date: Sat, 12 Sep 2020 11:03:14 +0900 Subject: [PATCH 3/8] modint --- atcoder/modint.py | 119 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 119 insertions(+) create mode 100644 atcoder/modint.py diff --git a/atcoder/modint.py b/atcoder/modint.py new file mode 100644 index 0000000..49e5657 --- /dev/null +++ b/atcoder/modint.py @@ -0,0 +1,119 @@ +from __future__ import annotations +import copy +import typing + +import atcoder._math + + +class ModContext: + context = [] + + def __init__(self, mod: int) -> None: + assert 1 <= mod + + self.mod = mod + + def __enter__(self) -> None: + self.context.append(atcoder._math.barrett(self.mod)) + + def __exit__(self, exc_type: typing.Any, exc_value: typing.Any, + traceback: typing.Any) -> None: + self.context.pop() + + @classmethod + def get_bt(cls) -> int: + return cls.context[-1] + + +class Modint: + def __init__(self, v: int = 0) -> None: + self._bt = ModContext.get_bt() + if v == 0: + self._v = 0 + else: + self._v = v % self.mod() + + def mod(self) -> int: + return self._bt.umod() + + def val(self) -> int: + return self._v + + def __iadd__(self, rhs: Modint) -> Modint: + self._v += rhs._v + if self._v >= self.mod(): + self._v -= self.mod() + return self + + def __isub__(self, rhs: Modint) -> Modint: + self._v -= rhs._v + if self._v < 0: + self._v += self.mod() + return self + + def __imul__(self, rhs: Modint) -> Modint: + self._v = self._bt.mul(self._v, rhs._v) + return self + + def __ifloordiv__(self, rhs: Modint) -> Modint: + self *= rhs.inv() + return self + + def __pos__(self) -> Modint: + return self + + def __neg__(self) -> Modint: + return Modint() - self + + def __pow__(self, n: int) -> Modint: + assert 0 <= n + + x = self + r = 1 + + while n: + if n & 1: + r *= x + x = x * x + n >>= 1 + + return r + + def inv(self) -> Modint: + eg = atcoder._math._inv_gcd(self._v, self.mod()) + + assert eg[0] == 1 + + return eg[1] + + def __add__(self, rhs: Modint) -> Modint: + result = copy.deepcopy(self) + result += rhs + return result + + def __sub__(self, rhs: Modint) -> Modint: + result = copy.deepcopy(self) + result -= rhs + return result + + def __mul__(self, rhs: Modint) -> Modint: + result = copy.deepcopy(self) + result *= rhs + return result + + def __floordiv__(self, rhs: Modint) -> Modint: + result = copy.deepcopy(self) + result //= rhs + return result + + def __eq__(self, rhs: Modint) -> bool: + return self._v == rhs._v + + def __ne__(self, rhs: Modint) -> bool: + return self._v != rhs._v + + +def raw(v: int) -> Modint: + x = Modint() + x._v = v + return x From 3044e7af2efadaa60fb37ff83b77ab70b04368a4 Mon Sep 17 00:00:00 2001 From: Naoto Mizuno Date: Sat, 12 Sep 2020 14:02:51 +0900 Subject: [PATCH 4/8] string --- atcoder/string.py | 250 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 250 insertions(+) create mode 100644 atcoder/string.py diff --git a/atcoder/string.py b/atcoder/string.py new file mode 100644 index 0000000..07b74f5 --- /dev/null +++ b/atcoder/string.py @@ -0,0 +1,250 @@ +import copy +import functools +import typing + + +def _sa_naive(s: typing.List[int]) -> typing.List[int]: + sa = list(range(len(s))) + return sorted(sa, key=lambda i: s[i:]) + + +def _sa_doubling(s: typing.List[int]) -> typing.List[int]: + n = len(s) + sa = range(n) + rnk = copy.deepcopy(s) + tmp = [0] * n + k = 1 + while k < n: + def cmp(x: int, y: int) -> bool: + if rnk[x] != rnk[y]: + return rnk[x] < rnk[y] + rx = rnk[x + k] if x + k < n else -1 + ry = rnk[y + k] if y + k < n else -1 + return rx < ry + sa.sort(key=functools.cmp_to_key(cmp)) + tmp[sa[0]] = 0 + for i in range(1, n): + tmp[sa[i]] = tmp[sa[i - 1]] + (1 if cmp(sa[i - 1], sa[i]) else 0) + tmp, rnk = rnk, tmp + k *= 2 + return sa + + +def _sa_is(s: typing.List[int], upper: int) -> typing.List[int]: + ''' + SA-IS, linear-time suffix array construction + Reference: + G. Nong, S. Zhang, and W. H. Chan, + Two Efficient Algorithms for Linear Time Suffix Array Construction + ''' + + threshold_naive = 10 + threshold_doubling = 40 + + n = len(s) + + if n == 0: + return [] + if n == 1: + return [0] + if n == 2: + if s[0] < s[1]: + return [0, 1] + else: + return [1, 0] + + if n < threshold_naive: + return _sa_naive(s) + if n < threshold_doubling: + return _sa_doubling(s) + + sa = [0] * n + ls = [False] * n + for i in range(n - 2, -1, -1): + if s[i] == s[i + 1]: + ls[i] = ls[i + 1] + else: + ls[i] = s[i] < s[i + 1] + + sum_l = [0] * (upper + 1) + sum_s = [0] * (upper + 1) + for i in range(n): + if not ls[i]: + sum_s[s[i]] += 1 + else: + sum_l[s[i] + 1] += 1 + for i in range(upper + 1): + sum_s[i] += sum_l[i] + if i < upper: + sum_l[i + 1] += sum_s[i] + + def induce(lms: typing.List[int]) -> None: + nonlocal sa + sa = [-1] * n + + buf = copy.deepcopy(sum_s) + for d in lms: + if d == n: + continue + sa[buf[s[d]]] = d + buf[s[d]] += 1 + + buf = copy.deepcopy(sum_l) + sa[buf[s[n - 1]]] = n - 1 + buf[s[n - 1]] += 1 + for i in range(n): + v = sa[i] + if v >= 1 and not ls[v - 1]: + sa[buf[s[v - 1]]] = v - 1 + buf[s[v - 1]] += 1 + + buf = copy.deepcopy(sum_l) + for i in range(n - 1, -1, -1): + v = sa[i] + if v >= 1 and ls[v - 1]: + buf[s[v - 1] + 1] -= 1 + sa[buf[s[v - 1] + 1]] = v - 1 + + lms_map = [-1] * (n + 1) + m = 0 + for i in range(1, n): + if not ls[i - 1] and ls[i]: + lms_map[i] = m + m += 1 + lms = [] + for i in range(1, n): + if not ls[i - 1] and ls[i]: + lms.append(i) + + induce(lms) + + if m: + sorted_lms = [] + for v in sa: + if lms_map[v] != -1: + sorted_lms.append(v) + rec_s = [0] * m + rec_upper = 0 + rec_s[lms_map[sorted_lms[0]]] = 0 + for i in range(1, m): + left = sorted_lms[i - 1] + right = sorted_lms[i] + if lms_map[left] + 1 < m: + end_l = lms[lms_map[left] + 1] + else: + end_l = n + if lms_map[right] + 1 < m: + end_r = lms[lms_map[right] + 1] + else: + end_r = n + + same = True + if end_l - left != end_r - right: + same = False + else: + while left < end_l: + if s[left] != s[right]: + break + left += 1 + right += 1 + if left == n or s[left] != s[right]: + same = False + + if not same: + rec_upper += 1 + rec_s[lms_map[sorted_lms[i]]] = rec_upper + + rec_sa = _sa_is(rec_s, rec_upper) + + for i in range(m): + sorted_lms[i] = lms[rec_sa[i]] + induce(sorted_lms) + + return sa + + +def suffix_array(s: typing.Union[str, typing.List[int]], + upper: typing.Optional[int] = None) -> typing.List[int]: + if isinstance(s, str): + return _sa_is([ord(c) for c in s], 255) + elif upper is None: + n = len(s) + idx = list(range(n)) + idx.sort(key=functools.cmp_to_key(lambda l, r: s[l] < s[r])) + s2 = [0] * n + now = 0 + for i in range(n): + if i and s[idx[i - 1]] != s[idx[i]]: + now += 1 + s2[idx[i]] = now + return _sa_is(s2, now) + else: + assert 0 <= upper + for d in s: + assert 0 <= d <= upper + + return _sa_is(s, upper) + + +def lcp_array(s: typing.Union[str, typing.List[int]], + sa: typing.List[int]) -> typing.List[int]: + ''' + Reference: + T. Kasai, G. Lee, H. Arimura, S. Arikawa, and K. Park, + Linear-Time Longest-Common-Prefix Computation in Suffix Arrays and Its + Applications + ''' + + if isinstance(s, str): + return lcp_array([ord(c) for c in s], sa) + + n = len(s) + assert n >= 1 + + rnk = [0] * n + for i in range(n): + rnk[sa[i]] = i + + lcp = [0] * (n - 1) + h = 0 + for i in range(n): + if h > 0: + h -= 1 + if rnk[i] == 0: + continue + j = sa[rnk[i] - 1] + while j + h < n and i + h < n: + if s[j + h] != s[i + h]: + break + h += 1 + lcp[rnk[i] - 1] = h + + return lcp + + +def z_algorithm(s: typing.Union[str, typing.List[int]]) -> typing.List[int]: + ''' + Reference: + D. Gusfield, + Algorithms on Strings, Trees, and Sequences: Computer Science and + Computational Biology + ''' + + if isinstance(s, str): + return z_algorithm([ord(c) for c in s]) + + n = len(s) + if n == 0: + return [] + + z = [0] * n + j = 0 + for i in range(1, n): + z[i] = 0 if j + z[j] <= i else min(j + z[j] - i, z[i - j]) + while i + z[i] < n and s[z[i]] == s[i + z[i]]: + z[i] += 1 + if j + z[j] < i + z[i]: + j = i + z[0] = n + + return z From 0a041677b21162bd1b639f9ebad712fdd4a242ac Mon Sep 17 00:00:00 2001 From: Naoto Mizuno Date: Sat, 12 Sep 2020 14:39:22 +0900 Subject: [PATCH 5/8] segtree --- atcoder/segtree.py | 118 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 118 insertions(+) create mode 100644 atcoder/segtree.py diff --git a/atcoder/segtree.py b/atcoder/segtree.py new file mode 100644 index 0000000..a045f37 --- /dev/null +++ b/atcoder/segtree.py @@ -0,0 +1,118 @@ +import typing + +import atcoder._bit + + +class Segtree: + def __init__(self, + op: typing.Callable[[typing.Any, typing.Any], typing.Any], + e: typing.Any, + v: typing.Union[int, typing.List[typing.Any]]) -> None: + self._op = op + self._e = e + + if isinstance(v, int): + v = [e] * v + + self._n = len(v) + self._log = atcoder._bit._ceil_pow2(self._n) + self._size = 1 << self._log + self._d = [e for _ in range(2 * self._size)] + + for i in range(self._n): + self._d[self._size + i] = v[i] + for i in range(self._size - 1, 0, -1): + self._update(i) + + def set(self, p: int, x: typing.Any) -> None: + assert 0 <= p < self._n + + p += self._size + self._d[p] = x + for i in range(1, self._log + 1): + self._update(p >> i) + + def get(self, p: int) -> typing.Any: + assert 0 <= p < self._n + + return self._d[p + self._size] + + def prod(self, left: int, right: int) -> typing.Any: + assert 0 <= left <= right <= self._n + sml = self._e + smr = self._e + left += self._size + right += self._size + + while left < right: + if left & 1: + sml = self._op(sml, self._d[left]) + left += 1 + if right & 1: + right -= 1 + smr = self._op(self._d[right], smr) + left >>= 1 + right >>= 1 + + return self._op(sml, smr) + + def all_prod(self) -> typing.Any: + return self._d[1] + + def max_right(self, left: int, + f: typing.Callable[[typing.Any], bool]) -> int: + assert 0 <= left <= self._n + assert f(self._e) + + if left == self._n: + return self._n + + left += self._size + sm = self._e + + first = True + while first or (left & -left) != left: + first = False + while left % 2 == 0: + left >>= 1 + if not f(self._op(sm, self._d[left])): + while left < self._size: + left = 2 * left + if f(self._op(sm, self._d[left])): + sm = self._op(sm, self._d[left]) + left += 1 + return left - self._size + sm = self._op(sm, self._d[left]) + left += 1 + + return self._n + + def min_left(self, right: int, + f: typing.Callable[[typing.Any], bool]) -> int: + assert 0 <= right <= self._n + assert f(self._e) + + if right == 0: + return 0 + + right += self._size + sm = self._e + + first = True + while first or (right & -right) != right: + right -= 1 + while right > 1 and right % 2: + right >>= 1 + if not f(self._op(self._d[right], sm)): + while right < self._size: + right = 2 * right + 1 + if f(self._op(self._d[right], sm)): + sm = self._op(self._d[right], sm) + right -= 1 + return right + 1 - self._size + sm = self._op(self._d[right], sm) + + return 0 + + def _update(self, k: int) -> None: + self._d[k] = self._op(self._d[2 * k], self._d[2 * k + 1]) From a4376a67b9fe22195fb5753bc34ffdd04282d51d Mon Sep 17 00:00:00 2001 From: Naoto Mizuno Date: Sat, 12 Sep 2020 15:00:04 +0900 Subject: [PATCH 6/8] update --- atcoder/string.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/atcoder/string.py b/atcoder/string.py index 07b74f5..93567d7 100644 --- a/atcoder/string.py +++ b/atcoder/string.py @@ -196,7 +196,7 @@ def lcp_array(s: typing.Union[str, typing.List[int]], ''' if isinstance(s, str): - return lcp_array([ord(c) for c in s], sa) + s = [ord(c) for c in s] n = len(s) assert n >= 1 @@ -231,7 +231,7 @@ def z_algorithm(s: typing.Union[str, typing.List[int]]) -> typing.List[int]: ''' if isinstance(s, str): - return z_algorithm([ord(c) for c in s]) + s = [ord(c) for c in s] n = len(s) if n == 0: From 0ca2cbc6f9a07bdcf66234071e77b96338dd3b9a Mon Sep 17 00:00:00 2001 From: Naoto Mizuno Date: Sat, 12 Sep 2020 15:08:51 +0900 Subject: [PATCH 7/8] update --- atcoder/segtree.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/atcoder/segtree.py b/atcoder/segtree.py index a045f37..1443f90 100644 --- a/atcoder/segtree.py +++ b/atcoder/segtree.py @@ -17,7 +17,7 @@ def __init__(self, self._n = len(v) self._log = atcoder._bit._ceil_pow2(self._n) self._size = 1 << self._log - self._d = [e for _ in range(2 * self._size)] + self._d = [e] * range(2 * self._size) for i in range(self._n): self._d[self._size + i] = v[i] From 9e000eb0b8f493183faad04eb53853410b38494d Mon Sep 17 00:00:00 2001 From: Kato Hiroki Date: Sat, 12 Sep 2020 19:44:12 +0900 Subject: [PATCH 8/8] Add pytest and tests for dsu (#5) --- requirements.txt | 8 +++++ tests/__init__.py | 0 tests/conftest.py | 8 +++++ tests/test_dsu.py | 79 +++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 95 insertions(+) create mode 100644 requirements.txt create mode 100644 tests/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/test_dsu.py diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..7d3a2bc --- /dev/null +++ b/requirements.txt @@ -0,0 +1,8 @@ +autopep8 +flake8 +numpy == 1.18.2 +pep8-naming +pyflakes +pytest +pytest-watch +scipy == 1.4.1 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..8395053 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,8 @@ +import os +import sys + +# Add path. +# See: +# https://www.magata.net/memo/index.php?pytest%C6%FE%CC%E7#y2046859 +path = os.path.dirname(os.path.abspath(__file__)) + '/../atcoder/' +sys.path.append(os.path.abspath(path)) diff --git a/tests/test_dsu.py b/tests/test_dsu.py new file mode 100644 index 0000000..68cc029 --- /dev/null +++ b/tests/test_dsu.py @@ -0,0 +1,79 @@ +import pytest + +from atcoder.dsu import DSU + + +@pytest.fixture +def dsu(): + return DSU(5) + + +class TestDsu(object): + + def test_merge(self, dsu): + ''' + dsu.merge(vertex a, vertex b) is expected to be in the same group. + + GIVEN an initialized dsu object + WHEN vertex 0 and 1 are merged + THEN vertex 0 and 1 are the same group. + ''' + + is_same = dsu.same(0, 1) + assert is_same is False + + dsu.merge(0, 1) + is_same = dsu.same(0, 1) + assert is_same is True + + def test_size(self, dsu): + ''' + dsu.size(vertex a) is expected to get size of vertex a. + + GIVEN an initialized dsu object + WHEN vertex 0, 1 and 2 are merged + THEN size of vertex 0 is 3. + ''' + + dsu.merge(0, 1) + dsu.merge(0, 2) + assert dsu.size(0) == 3 + + is_same = dsu.same(0, 3) + assert is_same is False + + is_same = dsu.same(0, 4) + assert is_same is False + + def test_leader(self, dsu): + ''' + dsu.leader(vertex a) is expected to return the representative of the + connected component that contains the vertex a. + + GIVEN an initialized dsu object + WHEN vertex 0, 1 and 2 are merged + THEN vertex 1 and 2 belong to vertex 0. + ''' + + dsu.merge(0, 1) + dsu.merge(0, 2) + + assert dsu.leader(1) == 0 + assert dsu.leader(2) == 0 + assert dsu.leader(3) != 0 + assert dsu.leader(4) != 0 + + def test_groups(self, dsu): + ''' + dsu.groups() is expected to return the list of the graph that divided + into connected components. + + GIVEN an initialized dsu object + WHEN vertex 0, 1 and 2 are merged + THEN returns [[0, 1, 2], [3], [4]] + ''' + + dsu.merge(0, 1) + dsu.merge(0, 2) + + assert dsu.groups() == [[0, 1, 2], [3], [4]]