diff --git a/README.md b/README.md index c27da45..d01024e 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,7 @@ ac-library-python is a Python port of [AtCoder Library (ACL)](https://atcoder.jp + [Fenwick Tree](https://github.com/atcoder/ac-library/blob/master/document_en/fenwicktree.md) + segtree ++ lazysegtree #### Math @@ -34,7 +35,6 @@ ac-library-python is a Python port of [AtCoder Library (ACL)](https://atcoder.jp #### Data structure -+ lazysegtree + string #### Math diff --git a/README_ja.md b/README_ja.md index 8f7c3ff..8187e20 100644 --- a/README_ja.md +++ b/README_ja.md @@ -14,6 +14,7 @@ ac-library-pythonは、[AtCoder Library (ACL)](https://atcoder.jp/posts/517)のP + [Fenwick Tree](https://github.com/atcoder/ac-library/blob/master/document_ja/fenwicktree.md) + segtree ++ lazysegtree #### 数学 @@ -32,7 +33,6 @@ ac-library-pythonは、[AtCoder Library (ACL)](https://atcoder.jp/posts/517)のP #### データ構造 -+ lazysegtree + string #### 数学 diff --git a/atcoder/lazysegtree.py b/atcoder/lazysegtree.py new file mode 100644 index 0000000..2427131 --- /dev/null +++ b/atcoder/lazysegtree.py @@ -0,0 +1,204 @@ +import typing + +import atcoder._bit + + +class LazySegTree: + def __init__( + self, + op: typing.Callable[[typing.Any, typing.Any], typing.Any], + e: typing.Any, + mapping: typing.Callable[[typing.Any, typing.Any], typing.Any], + composition: typing.Callable[[typing.Any, typing.Any], typing.Any], + id_: typing.Any, + v: typing.Union[int, typing.List[typing.Any]]) -> None: + self._op = op + self._e = e + self._mapping = mapping + self._composition = composition + self._id = id_ + + if isinstance(v, int): + v = [e] * v + + self._n = len(v) + self._log = atcoder._bit._ceil_pow2(self._n) + self._size = 1 << self._log + self._d = [e] * (2 * self._size) + self._lz = [self._id] * self._size + for i in range(self._n): + self._d[self._size + i] = v[i] + for i in range(self._size - 1, 0, -1): + self._update(i) + + def set(self, p: int, x: typing.Any) -> None: + assert 0 <= p < self._n + + p += self._size + for i in range(self._log, 0, -1): + self._push(p >> i) + self._d[p] = x + for i in range(1, self._log + 1): + self._update(p >> i) + + def get(self, p: int) -> typing.Any: + assert 0 <= p < self._n + + p += self._size + for i in range(self._log, 0, -1): + self._push(p >> i) + return self._d[p] + + def prod(self, left: int, right: int) -> typing.Any: + assert 0 <= left <= right <= self._n + + if left == right: + return self._e + + left += self._size + right += self._size + + for i in range(self._log, 0, -1): + if ((left >> i) << i) != left: + self._push(left >> i) + if ((right >> i) << i) != right: + self._push(right >> i) + + sml = self._e + smr = self._e + while left < right: + if left & 1: + sml = self._op(sml, self._d[left]) + left += 1 + if right & 1: + right -= 1 + smr = self._op(self._d[right], smr) + left >>= 1 + right >>= 1 + + return self._op(sml, smr) + + def all_prod(self) -> typing.Any: + return self._d[1] + + def apply(self, left: int, right: typing.Optional[int] = None, + f: typing.Optional[typing.Any] = None): + assert f is not None + + if right is None: + p = left + assert 0 <= left < self._n + + p += self._size + for i in range(self._log, 0, -1): + self._push(p >> i) + self._d[p] = self._mapping(f, self._d[p]) + for i in range(1, self._log + 1): + self._update(p >> i) + else: + assert 0 <= left <= right <= self._n + if left == right: + return + + left += self._size + right += self._size + + for i in range(self._log, 0, -1): + if ((left >> i) << i) != left: + self._push(left >> i) + if ((right >> i) << i) != right: + self._push((right - 1) >> i) + + l2 = left + r2 = right + while left < right: + if left & 1: + self._all_apply(left, f) + left += 1 + if right & 1: + right -= 1 + self._all_apply(right, f) + left >>= 1 + right >>= 1 + left = l2 + right = r2 + + for i in range(1, self._log + 1): + if ((left >> i) << i) != left: + self._update(left >> i) + if ((right >> i) << i) != right: + self._update((right - 1) >> i) + + def max_right( + self, left: int, g: typing.Callable[[typing.Any], bool]) -> int: + assert 0 <= left <= self._n + assert g(self._e) + + if left == self._n: + return self._n + + left += self._size + for i in range(self._log, 0, -1): + self._push(left >> i) + + sm = self._e + first = True + while first or (left & -left) != left: + first = False + while left % 2 == 0: + left >>= 1 + if not g(self._op(sm, self._d[left])): + while left < self._size: + self._push(left) + left *= 2 + if g(self._op(sm, self._d[left])): + sm = self._op(sm, self._d[left]) + left += 1 + return left - self._size + sm = self._op(sm, self._d[left]) + left += 1 + + return self._n + + def min_left(self, right: int, g: typing.Any) -> int: + assert 0 <= right <= self._n + assert g(self._e) + + if right == 0: + return 0 + + right += self._size + for i in range(self._log, 0, -1): + self._push((right - 1) >> i) + + sm = self._e + first = True + while first or (right & -right) != right: + first = False + right -= 1 + while right > 1 and right % 2: + right >>= 1 + if not g(self._op(self._d[right], sm)): + while right < self._size: + self._push(right) + right = 2 * right + 1 + if g(self._op(self._d[right], sm)): + sm = self._op(self._d[right], sm) + right -= 1 + return right + 1 - self._size + sm = self._op(self._d[right], sm) + + return 0 + + def _update(self, k: int) -> None: + self._d[k] = self._op(self._d[2 * k], self._d[2 * k + 1]) + + def _all_apply(self, k: int, f: typing.Any) -> None: + self._d[k] = self._mapping(f, self._d[k]) + if k < self._size: + self._lz[k] = self._composition(f, self._lz[k]) + + def _push(self, k: int) -> None: + self._all_apply(2 * k, self._lz[k]) + self._all_apply(2 * k + 1, self._lz[k]) + self._lz[k] = self._id diff --git a/atcoder/modint.py b/atcoder/modint.py index fd039d4..f91e746 100644 --- a/atcoder/modint.py +++ b/atcoder/modint.py @@ -1,5 +1,4 @@ from __future__ import annotations -import copy import typing import atcoder._math @@ -37,27 +36,36 @@ def val(self) -> int: return self._v def __iadd__(self, rhs: typing.Union[Modint, int]) -> Modint: - rhs = self._asmodint(rhs) - self._v += rhs._v + if isinstance(rhs, Modint): + self._v += rhs._v + else: + self._v += rhs if self._v >= self._mod: self._v -= self._mod return self def __isub__(self, rhs: typing.Union[Modint, int]) -> Modint: - rhs = self._asmodint(rhs) - self._v -= rhs._v + if isinstance(rhs, Modint): + self._v -= rhs._v + else: + self._v -= rhs if self._v < 0: self._v += self._mod return self def __imul__(self, rhs: typing.Union[Modint, int]) -> Modint: - rhs = self._asmodint(rhs) - self._v = self._v * rhs._v % self._mod + if isinstance(rhs, Modint): + self._v = self._v * rhs._v % self._mod + else: + self._v = self._v * rhs % self._mod return self def __ifloordiv__(self, rhs: typing.Union[Modint, int]) -> Modint: - rhs = self._asmodint(rhs) - self *= rhs.inv() + if isinstance(rhs, Modint): + inv = rhs.inv()._v + else: + inv = atcoder._math._inv_gcd(rhs, self._mod)[1] + self._v = self._v * inv % self._mod return self def __pos__(self) -> Modint: @@ -79,42 +87,47 @@ def inv(self) -> Modint: return Modint(eg[1]) def __add__(self, rhs: typing.Union[Modint, int]) -> Modint: - rhs = self._asmodint(rhs) - result = copy.deepcopy(self) - result += rhs - return result + if isinstance(rhs, Modint): + result = self._v + rhs._v + if result >= self._mod: + result -= self._mod + return raw(result) + else: + return Modint(self._v + rhs) def __sub__(self, rhs: typing.Union[Modint, int]) -> Modint: - rhs = self._asmodint(rhs) - result = copy.deepcopy(self) - result -= rhs - return result + if isinstance(rhs, Modint): + result = self._v - rhs._v + if result < 0: + result += self._mod + return raw(result) + else: + return Modint(self._v - rhs) def __mul__(self, rhs: typing.Union[Modint, int]) -> Modint: - rhs = self._asmodint(rhs) - result = copy.deepcopy(self) - result *= rhs - return result + if isinstance(rhs, Modint): + return Modint(self._v * rhs._v) + else: + return Modint(self._v * rhs) def __floordiv__(self, rhs: typing.Union[Modint, int]) -> Modint: - rhs = self._asmodint(rhs) - result = copy.deepcopy(self) - result //= rhs - return result + if isinstance(rhs, Modint): + inv = rhs.inv()._v + else: + inv = atcoder._math._inv_gcd(rhs, self._mod)[1] + return Modint(self._v * inv) def __eq__(self, rhs: typing.Union[Modint, int]) -> bool: - rhs = self._asmodint(rhs) - return self._v == rhs._v + if isinstance(rhs, Modint): + return self._v == rhs._v + else: + return self._v == rhs def __ne__(self, rhs: typing.Union[Modint, int]) -> bool: - rhs = self._asmodint(rhs) - return self._v != rhs._v - - def _asmodint(self, rhs: typing.Union[Modint, int]) -> Modint: if isinstance(rhs, Modint): - return rhs + return self._v != rhs._v else: - return Modint(rhs) + return self._v != rhs def raw(v: int) -> Modint: diff --git a/example/lazysegtree_practice_k.py b/example/lazysegtree_practice_k.py new file mode 100644 index 0000000..34fc55c --- /dev/null +++ b/example/lazysegtree_practice_k.py @@ -0,0 +1,41 @@ +# https://atcoder.jp/contests/practice2/tasks/practice2_k + +import sys + +from atcoder.lazysegtree import LazySegTree +from atcoder.modint import ModContext, Modint + + +def main() -> None: + with ModContext(998244353): + n, q = map(int, sys.stdin.readline().split()) + a = [(Modint(ai), 1) for ai in map(int, sys.stdin.readline().split())] + + def op(x: (Modint, int), y: (Modint, int)) -> (Modint, int): + return x[0] + y[0], x[1] + y[1] + + e = Modint(0), 0 + + def mapping(x: (Modint, Modint), y: (Modint, int)) -> (Modint, int): + return x[0] * y[0] + x[1] * y[1], y[1] + + def composition(x: (Modint, Modint), + y: (Modint, Modint)) -> (Modint, Modint): + return x[0] * y[0], x[0] * y[1] + x[1] + + id_ = Modint(1), Modint(0) + + lazy_segtree = LazySegTree(op, e, mapping, composition, id_, a) + + for _ in range(q): + t, *inputs = map(int, sys.stdin.readline().split()) + if t == 0: + l, r, b, c = inputs + lazy_segtree.apply(l, r, (Modint(b), Modint(c))) + else: + l, r = inputs + print(lazy_segtree.prod(l, r)[0].val()) + + +if __name__ == '__main__': + main() diff --git a/example/lazysegtree_practice_k_wo_modint.py b/example/lazysegtree_practice_k_wo_modint.py new file mode 100644 index 0000000..bc75e8c --- /dev/null +++ b/example/lazysegtree_practice_k_wo_modint.py @@ -0,0 +1,40 @@ +# https://atcoder.jp/contests/practice2/tasks/practice2_k + +import sys + +from atcoder.lazysegtree import LazySegTree + + +def main() -> None: + mod = 998244353 + + n, q = map(int, sys.stdin.readline().split()) + a = [(ai, 1) for ai in map(int, sys.stdin.readline().split())] + + def op(x: (int, int), y: (int, int)) -> (int, int): + return (x[0] + y[0]) % mod, x[1] + y[1] + + e = 0, 0 + + def mapping(x: (int, int), y: (int, int)) -> (int, int): + return (x[0] * y[0] + x[1] * y[1]) % mod, y[1] + + def composition(x: (int, int), y: (int, int)) -> (int, int): + return (x[0] * y[0]) % mod, (x[0] * y[1] + x[1]) % mod + + id_ = 1, 0 + + lazy_segtree = LazySegTree(op, e, mapping, composition, id_, a) + + for _ in range(q): + t, *inputs = map(int, sys.stdin.readline().split()) + if t == 0: + l, r, b, c = inputs + lazy_segtree.apply(l, r, (b, c)) + else: + l, r = inputs + print(lazy_segtree.prod(l, r)[0]) + + +if __name__ == '__main__': + main() diff --git a/example/lazysegtree_practice_l.py b/example/lazysegtree_practice_l.py new file mode 100644 index 0000000..6c17418 --- /dev/null +++ b/example/lazysegtree_practice_l.py @@ -0,0 +1,43 @@ +# https://atcoder.jp/contests/practice2/tasks/practice2_l + +import sys + +from atcoder.lazysegtree import LazySegTree + + +def main() -> None: + n, q = map(int, sys.stdin.readline().split()) + a = [] + for x in map(int, sys.stdin.readline().split()): + if x == 0: + a.append((1, 0, 0)) + else: + a.append((0, 1, 0)) + + def op(x: (int, int, int), y: (int, int, int)) -> (int, int, int): + return x[0] + y[0], x[1] + y[1], x[2] + y[2] + x[1] * y[0] + + e = (0, 0, 0) + + def mapping(x: bool, y: (int, int, int)) -> (int, int, int): + if not x: + return y + return y[1], y[0], y[0] * y[1] - y[2] + + def composition(x: bool, y: bool) -> bool: + return (x and not y) or (not x and y) + + id_ = False + + lazy_segtree = LazySegTree(op, e, mapping, composition, id_, a) + for _ in range(q): + t, left, right = map(int, sys.stdin.readline().split()) + left -= 1 + if t == 1: + lazy_segtree.apply(left, right, True) + else: + print(lazy_segtree.prod(left, right)[2]) + + +if __name__ == '__main__': + main()