diff --git a/atcoder/dsu.py b/atcoder/dsu.py index 46998f8..d6cbcc9 100644 --- a/atcoder/dsu.py +++ b/atcoder/dsu.py @@ -3,13 +3,14 @@ class DSU: ''' - Implement (union by size) + (path compression) + Implement (union by size) + (path halving) + Reference: Zvi Galil and Giuseppe F. Italiano, Data structures and algorithms for disjoint set union problems ''' - def __init__(self, n: int = 0): + def __init__(self, n: int = 0) -> None: self._n = n self.parent_or_size = [-1] * n @@ -40,11 +41,17 @@ def same(self, a: int, b: int) -> bool: def leader(self, a: int) -> int: assert 0 <= a < self._n - if self.parent_or_size[a] < 0: - return a - - self.parent_or_size[a] = self.leader(self.parent_or_size[a]) - return self.parent_or_size[a] + parent = self.parent_or_size[a] + while parent >= 0: + if self.parent_or_size[parent] < 0: + return parent + self.parent_or_size[a], a, parent = ( + self.parent_or_size[parent], + self.parent_or_size[parent], + self.parent_or_size[self.parent_or_size[parent]] + ) + + return a def size(self, a: int) -> int: assert 0 <= a < self._n diff --git a/setup.cfg b/setup.cfg index fb00f9a..68bba96 100644 --- a/setup.cfg +++ b/setup.cfg @@ -12,7 +12,9 @@ lint = flake8 pep8-naming mypy -test = pytest +test = + pytest + pytest-benchmark docs = sphinx sphinx_rtd_theme diff --git a/tests/test_dsu.py b/tests/test_dsu.py index 77ec59f..ba83c3a 100644 --- a/tests/test_dsu.py +++ b/tests/test_dsu.py @@ -1,22 +1,19 @@ from itertools import combinations import pytest +import random from typing import List, Tuple +from pytest_benchmark.fixture import BenchmarkFixture + from atcoder.dsu import DSU class TestDsu: - def _dsu(self, n: int = 5) -> DSU: - return DSU(n) - - def _get_all_pairs(self, n: int = 5) -> List[Tuple[int, ...]]: - return list(combinations(range(n), 2)) - def test_initial_status(self) -> None: - dsu = self._dsu() + dsu = DSU(5) - for i, j in self._get_all_pairs(): + for i, j in combinations(range(5), 2): assert not dsu.same(i, j) for index in range(5): @@ -26,7 +23,7 @@ def test_initial_status(self) -> None: assert dsu.groups() == [[0], [1], [2], [3], [4]] def test_merge(self) -> None: - dsu = self._dsu() + dsu = DSU(5) assert not dsu.same(0, 1) @@ -41,7 +38,7 @@ def test_merge(self) -> None: dsu.merge(i, j) def test_merge_elements_of_same_group(self) -> None: - dsu = self._dsu() + dsu = DSU(5) assert not dsu.same(0, 1) @@ -57,7 +54,7 @@ def test_merge_elements_of_same_group(self) -> None: dsu.same(i, j) def test_size(self) -> None: - dsu = self._dsu() + dsu = DSU(5) dsu.merge(0, 1) assert dsu.size(0) == 2 @@ -70,7 +67,7 @@ def test_size(self) -> None: dsu.size(i) def test_leader(self) -> None: - dsu = self._dsu() + dsu = DSU(5) dsu.merge(0, 1) dsu.merge(0, 2) @@ -87,9 +84,24 @@ def test_leader(self) -> None: dsu.leader(i) def test_groups(self) -> None: - dsu = self._dsu() + dsu = DSU(5) dsu.merge(0, 1) dsu.merge(0, 2) assert dsu.groups() == [[0, 1, 2], [3], [4]] + + def _merge_benchmark(self, dsu: DSU, pairs: List[Tuple[int, int]]) -> None: + for i, j in pairs: + dsu.merge(i, j) + + def test_benchmark(self, benchmark: BenchmarkFixture) -> None: + random.seed(0) + n = 100000 + + dsu = DSU(n) + pairs = [] + for _ in range(1000000): + pairs.append((random.randrange(0, n), random.randrange(0, n))) + + benchmark(self._merge_benchmark, dsu, pairs)