Skip to content

Commit e4dc4c5

Browse files
authored
refactor: improve hand divider performance a bit (#185)
* refactor: improve hand divider performance a bit * fix linter * address pr comments * address pr comments
1 parent a10fc80 commit e4dc4c5

File tree

3 files changed

+46
-39
lines changed

3 files changed

+46
-39
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ The following methods are now available as static methods:
9393
- Yakuhai detection (hatsu, haku, chun, winds) now uses `has_pon_or_kan_of()` instead of counting triplets. Behavior changes for invalid hands with two or more identical triplets of the same tile.
9494
- Fixed an issue where `KokushiMusou.is_condition_met()` would return `None` if the condition was not met. It now consistently returns a `bool` value. Remove any `None` checks in the code that relied on the previous behavior.
9595
- `Shanten.calculate_shanten()` and `Shanten.calculate_shanten_for_regular_hand()` now raises `ValueError` instead of `assert` when the number of tiles is 15 or more.
96+
- `HandDivider.divide_hand()` now determines block type from `Meld.type` instead of inferring it from `Meld.tiles`. Behavior may differ for invalid `Meld.tiles` or inconsistent `Meld.type` and `Meld.tiles` combinations.
9697

9798
## What's Changed
9899
- Placeholder. It would be filled on release automatically

mahjong/hand_calculating/divider.py

Lines changed: 43 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,37 @@
11
from collections.abc import Collection, Sequence
22
from dataclasses import dataclass
3-
from enum import Enum
4-
from functools import lru_cache, total_ordering
3+
from enum import IntEnum
4+
from functools import lru_cache
55
from typing import Literal
66

77
from mahjong.meld import Meld
8-
from mahjong.utils import is_chi, is_kan, is_pon
98

109

11-
class _BlockType(Enum):
10+
class _BlockType(IntEnum):
1211
QUAD = 0
1312
TRIPLET = 1
1413
PAIR = 2
1514
SEQUENCE = 3
1615

1716

18-
@total_ordering
19-
@dataclass(frozen=True)
20-
class _Block: # noqa: PLW1641 __hash__ is automatically implemented
21-
ty: _BlockType
17+
@dataclass(frozen=True, order=True)
18+
class _Block:
2219
tile_34: int
23-
24-
def __eq__(self, other: object) -> bool:
25-
if not isinstance(other, _Block):
26-
return NotImplemented
27-
return (self.tile_34, self.ty.value) == (other.tile_34, other.ty.value)
28-
29-
def __lt__(self, other: object) -> bool:
30-
if not isinstance(other, _Block):
31-
return NotImplemented
32-
return (self.tile_34, self.ty.value) < (other.tile_34, other.ty.value)
20+
ty: _BlockType
3321

3422
@classmethod
3523
def from_meld(cls, meld: Meld) -> "_Block":
36-
tiles_34 = meld.tiles_34
37-
if is_chi(tiles_34):
38-
return cls(_BlockType.SEQUENCE, tiles_34[0])
39-
if is_pon(tiles_34):
40-
return cls(_BlockType.TRIPLET, tiles_34[0])
41-
if is_kan(tiles_34):
42-
return cls(_BlockType.QUAD, tiles_34[0])
43-
msg = f"invalid meld type: {meld.type}, tiles: {tiles_34}"
44-
raise RuntimeError(msg)
24+
tile_34 = meld.tiles_34[0]
25+
match meld.type:
26+
case Meld.CHI:
27+
return cls(tile_34, _BlockType.SEQUENCE)
28+
case Meld.PON:
29+
return cls(tile_34, _BlockType.TRIPLET)
30+
case Meld.KAN | Meld.SHOUMINKAN:
31+
return cls(tile_34, _BlockType.QUAD)
32+
case _:
33+
msg = f"invalid meld type: {meld.type}, tiles: {meld.tiles_34}"
34+
raise RuntimeError(msg)
4535

4636
@property
4737
def tiles_34(self) -> list[int]:
@@ -126,21 +116,28 @@ def _divide_hand_impl(pure_hand: tuple[int, ...], melds: tuple[_Block, ...]) ->
126116

127117
@staticmethod
128118
def _decompose_chiitoitsu(pure_hand: list[int]) -> list[_Block]:
129-
blocks = [_Block(_BlockType.PAIR, i) for i, count in enumerate(pure_hand) if count == 2]
119+
blocks = [_Block(i, _BlockType.PAIR) for i, count in enumerate(pure_hand) if count == 2]
130120
return blocks if len(blocks) == 7 else []
131121

132122
@staticmethod
133123
def _decompose_single_color_hand(single_color_hand: list[int], suit: Literal[0, 9, 18]) -> list[list[_Block]]:
134-
combinations = HandDivider._decompose_single_color_hand_without_pair(single_color_hand, [], 0, suit)
124+
remaining = sum(single_color_hand)
125+
combinations = HandDivider._decompose_single_color_hand_without_pair(single_color_hand, [], 0, suit, remaining)
135126

136127
if not combinations:
137128
for pair in range(9):
138129
if single_color_hand[pair] < 2:
139130
continue
140131

141132
single_color_hand[pair] -= 2
142-
blocks = [_Block(_BlockType.PAIR, suit + pair)]
143-
comb = HandDivider._decompose_single_color_hand_without_pair(single_color_hand, blocks, 0, suit)
133+
blocks = [_Block(suit + pair, _BlockType.PAIR)]
134+
comb = HandDivider._decompose_single_color_hand_without_pair(
135+
single_color_hand,
136+
blocks,
137+
0,
138+
suit,
139+
remaining - 2,
140+
)
144141
single_color_hand[pair] += 2
145142

146143
if not comb:
@@ -156,25 +153,33 @@ def _decompose_single_color_hand_without_pair(
156153
blocks: list[_Block],
157154
i: int,
158155
suit: Literal[0, 9, 18],
156+
remaining: int,
159157
) -> list[list[_Block]]:
160158
if i == 9:
161-
return [blocks] if sum(single_color_hand) == 0 else []
159+
return [blocks] if remaining == 0 else []
162160

163161
if single_color_hand[i] == 0:
164-
return HandDivider._decompose_single_color_hand_without_pair(single_color_hand, blocks, i + 1, suit)
162+
return HandDivider._decompose_single_color_hand_without_pair(
163+
single_color_hand,
164+
blocks,
165+
i + 1,
166+
suit,
167+
remaining,
168+
)
165169

166170
combinations: list[list[_Block]] = []
167171

168172
if i < 7 and single_color_hand[i] >= 1 and single_color_hand[i + 1] >= 1 and single_color_hand[i + 2] >= 1:
169173
single_color_hand[i] -= 1
170174
single_color_hand[i + 1] -= 1
171175
single_color_hand[i + 2] -= 1
172-
new_blocks = [*blocks, _Block(_BlockType.SEQUENCE, suit + i)]
176+
new_blocks = [*blocks, _Block(suit + i, _BlockType.SEQUENCE)]
173177
new_combination = HandDivider._decompose_single_color_hand_without_pair(
174178
single_color_hand,
175179
new_blocks,
176180
i,
177181
suit,
182+
remaining - 3,
178183
)
179184
combinations.extend(new_combination)
180185
single_color_hand[i + 2] += 1
@@ -183,12 +188,13 @@ def _decompose_single_color_hand_without_pair(
183188

184189
if single_color_hand[i] >= 3:
185190
single_color_hand[i] -= 3
186-
new_blocks = [*blocks, _Block(_BlockType.TRIPLET, suit + i)]
191+
new_blocks = [*blocks, _Block(suit + i, _BlockType.TRIPLET)]
187192
new_combination = HandDivider._decompose_single_color_hand_without_pair(
188193
single_color_hand,
189194
new_blocks,
190195
i + 1,
191196
suit,
197+
remaining - 3,
192198
)
193199
combinations.extend(new_combination)
194200
single_color_hand[i] += 3
@@ -206,10 +212,10 @@ def _decompose_honors_hand(honors_hand: list[int]) -> list[_Block]:
206212
case 2:
207213
if has_pair:
208214
return []
209-
blocks.append(_Block(_BlockType.PAIR, 27 + i))
215+
blocks.append(_Block(27 + i, _BlockType.PAIR))
210216
has_pair = True
211217
case 3:
212-
blocks.append(_Block(_BlockType.TRIPLET, 27 + i))
218+
blocks.append(_Block(27 + i, _BlockType.TRIPLET))
213219
case _:
214220
return []
215221

tests/hand_calculating/tests_hand_dividing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def test_fix_not_correct_kan_handling() -> None:
102102
ids=["int", "str", "float", "None"],
103103
)
104104
def test_block_eq_with_non_block_returns_not_implemented(other: object) -> None:
105-
block = _Block(_BlockType.TRIPLET, _string_to_34_tile(man="1"))
105+
block = _Block(tile_34=_string_to_34_tile(man="1"), ty=_BlockType.TRIPLET)
106106
assert block.__eq__(other) is NotImplemented
107107

108108

@@ -112,7 +112,7 @@ def test_block_eq_with_non_block_returns_not_implemented(other: object) -> None:
112112
ids=["int", "str", "float", "None"],
113113
)
114114
def test_block_lt_with_non_block_returns_not_implemented(other: object) -> None:
115-
block = _Block(_BlockType.TRIPLET, _string_to_34_tile(man="1"))
115+
block = _Block(tile_34=_string_to_34_tile(man="1"), ty=_BlockType.TRIPLET)
116116
assert block.__lt__(other) is NotImplemented
117117

118118

0 commit comments

Comments
 (0)