Skip to content

Commit ac86f91

Browse files
authored
perf(common): improve equality caching by explicitly invalidating the entry on __del__ (ibis-project#8708)
Previously we used a rather complicated mechanism to implement a global equality cache for operation nodes, involving tricky weak-reference tracking and registering callbacks to invalidate cache entries. While this greatly improved the overall performance of ibis internals, we can have a simpler and more lightweight implementation by storing the equality comparison results in a nested `dict[object_id, dict[object_id, bool]]` data structure, which allows both quick lookups and quick deletions: each object's entries are explicitly removed from the cache in its `__del__` method. The caching is also specialized to a pair of objects, in contrast to the previous `WeakCache` implementation, which supported an arbitrary number of key elements and therefore required multiple iterations over the key tuple.
1 parent 3d52904 commit ac86f91

9 files changed

Lines changed: 93 additions & 104 deletions

File tree

ibis/common/bases.py

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
from typing import TYPE_CHECKING, Any
66
from weakref import WeakValueDictionary
77

8-
from ibis.common.caching import WeakCache
9-
108
if TYPE_CHECKING:
119
from collections.abc import Mapping
1210

@@ -141,41 +139,38 @@ class Comparable(Abstract):
141139
142140
Since the class holds a global cache of comparison results, it is important
143141
to make sure that the instances are not kept alive longer than necessary.
144-
This is done automatically by using weak references for the compared objects.
145142
"""
146143

147-
__cache__ = WeakCache()
148-
149-
def __eq__(self, other) -> bool:
150-
try:
151-
return self.__cached_equals__(other)
152-
except TypeError:
153-
return NotImplemented
144+
__cache__ = {}
154145

155146
@abstractmethod
156147
def __equals__(self, other) -> bool: ...
157148

158-
def __cached_equals__(self, other) -> bool:
149+
def __eq__(self, other) -> bool:
159150
if self is other:
160151
return True
161152

162153
# type comparison should be cheap
163154
if type(self) is not type(other):
164155
return False
165156

166-
# reduce space required for commutative operation
167-
if id(self) < id(other):
168-
key = (self, other)
169-
else:
170-
key = (other, self)
171-
157+
id1 = id(self)
158+
id2 = id(other)
172159
try:
173-
result = self.__cache__[key]
160+
return self.__cache__[id1][id2]
174161
except KeyError:
175162
result = self.__equals__(other)
176-
self.__cache__[key] = result
177-
178-
return result
163+
self.__cache__.setdefault(id1, {})[id2] = result
164+
self.__cache__.setdefault(id2, {})[id1] = result
165+
return result
166+
167+
def __del__(self):
168+
id1 = id(self)
169+
for id2 in self.__cache__.pop(id1, ()):
170+
eqs2 = self.__cache__[id2]
171+
del eqs2[id1]
172+
if not eqs2:
173+
del self.__cache__[id2]
179174

180175

181176
class SlottedMeta(AbstractMeta):

ibis/common/caching.py

Lines changed: 1 addition & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,13 @@
11
from __future__ import annotations
22

33
import functools
4-
import weakref
54
from collections import Counter, defaultdict
6-
from collections.abc import MutableMapping
7-
from typing import TYPE_CHECKING, Any, Callable
5+
from typing import Any, Callable
86

97
from bidict import bidict
108

119
from ibis.common.exceptions import IbisError
1210

13-
if TYPE_CHECKING:
14-
from collections.abc import Iterator
15-
1611

1712
def memoize(func: Callable) -> Callable:
1813
"""Memoize a function."""
@@ -31,51 +26,6 @@ def wrapper(*args, **kwargs):
3126
return wrapper
3227

3328

34-
class WeakCache(MutableMapping):
35-
__slots__ = ("_data",)
36-
_data: dict
37-
38-
def __init__(self):
39-
object.__setattr__(self, "_data", {})
40-
41-
def __setattr__(self, name, value):
42-
raise TypeError(f"can't set {name}")
43-
44-
def __len__(self) -> int:
45-
return len(self._data)
46-
47-
def __iter__(self) -> Iterator[Any]:
48-
return iter(self._data)
49-
50-
def __setitem__(self, key, value) -> None:
51-
# construct an alternative representation of the key using the id()
52-
# of the key's components, this prevents infinite recursions
53-
identifiers = tuple(id(item) for item in key)
54-
55-
# create a function which removes the key from the cache
56-
def callback(ref_):
57-
return self._data.pop(identifiers, None)
58-
59-
# create weak references for the key's components with the callback
60-
# to remove the cache entry if any of the key's components gets
61-
# garbage collected
62-
refs = tuple(weakref.ref(item, callback) for item in key)
63-
64-
self._data[identifiers] = (value, refs)
65-
66-
def __getitem__(self, key):
67-
identifiers = tuple(id(item) for item in key)
68-
value, _ = self._data[identifiers]
69-
return value
70-
71-
def __delitem__(self, key):
72-
identifiers = tuple(id(item) for item in key)
73-
del self._data[identifiers]
74-
75-
def __repr__(self):
76-
return f"{self.__class__.__name__}({self._data})"
77-
78-
7929
class RefCountedCache:
8030
"""A cache with reference-counted keys.
8131

ibis/common/grounds.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
AbstractMeta,
2323
Comparable,
2424
Final,
25+
Hashable,
2526
Immutable,
2627
Singleton,
2728
)

ibis/common/tests/test_bases.py

Lines changed: 59 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
Singleton,
1818
Slotted,
1919
)
20-
from ibis.common.caching import WeakCache
2120

2221

2322
def test_classes_are_based_on_abstract():
@@ -79,9 +78,19 @@ def __init__(self, a, b):
7978
assert copy.deepcopy(foo) is foo
8079

8180

81+
class Cache(dict):
82+
def setpair(self, a, b, value):
83+
a, b = id(a), id(b)
84+
self.setdefault(a, {})[b] = value
85+
self.setdefault(b, {})[a] = value
86+
87+
def getpair(self, a, b):
88+
return self.get(id(a), {}).get(id(b))
89+
90+
8291
class Node(Comparable):
8392
# override the default cache object
84-
__cache__ = WeakCache()
93+
__cache__ = Cache()
8594
__slots__ = ("name",)
8695
num_equal_calls = 0
8796

@@ -107,14 +116,6 @@ def cache():
107116
assert not cache
108117

109118

110-
def pair(a, b):
111-
# for same ordering with comparable
112-
if id(a) < id(b):
113-
return (a, b)
114-
else:
115-
return (b, a)
116-
117-
118119
def test_comparable_basic(cache):
119120
a = Node(name="a")
120121
b = Node(name="a")
@@ -133,28 +134,48 @@ def test_comparable_caching(cache):
133134
d = Node(name="d")
134135
e = Node(name="e")
135136

136-
cache[pair(a, b)] = True
137-
cache[pair(a, c)] = False
138-
cache[pair(c, d)] = True
139-
cache[pair(b, d)] = False
140-
assert len(cache) == 4
137+
cache.setpair(a, b, True)
138+
cache.setpair(a, c, False)
139+
cache.setpair(c, d, True)
140+
cache.setpair(b, d, False)
141+
expected = {
142+
id(a): {id(b): True, id(c): False},
143+
id(b): {id(a): True, id(d): False},
144+
id(c): {id(a): False, id(d): True},
145+
id(d): {id(c): True, id(b): False},
146+
}
147+
assert cache == expected
141148

142149
assert a == b
150+
assert b == a
143151
assert a != c
152+
assert c != a
144153
assert c == d
154+
assert d == c
145155
assert b != d
156+
assert d != b
146157
assert Node.num_equal_calls == 0
158+
assert cache == expected
147159

148160
# no cache hit
149-
assert pair(a, e) not in cache
161+
assert cache.getpair(a, e) is None
150162
assert a != e
163+
assert cache.getpair(a, e) is False
151164
assert Node.num_equal_calls == 1
152-
assert len(cache) == 5
165+
expected = {
166+
id(a): {id(b): True, id(c): False, id(e): False},
167+
id(b): {id(a): True, id(d): False},
168+
id(c): {id(a): False, id(d): True},
169+
id(d): {id(c): True, id(b): False},
170+
id(e): {id(a): False},
171+
}
172+
assert cache == expected
153173

154174
# run only once
155175
assert e != a
156176
assert Node.num_equal_calls == 1
157-
assert pair(a, e) in cache
177+
assert cache.getpair(a, e) is False
178+
assert cache == expected
158179

159180

160181
def test_comparable_garbage_collection(cache):
@@ -163,16 +184,29 @@ def test_comparable_garbage_collection(cache):
163184
c = Node(name="c")
164185
d = Node(name="d")
165186

166-
cache[pair(a, b)] = True
167-
cache[pair(a, c)] = False
168-
cache[pair(c, d)] = True
169-
cache[pair(b, d)] = False
187+
cache.setpair(a, b, True)
188+
cache.setpair(a, c, False)
189+
cache.setpair(c, d, True)
190+
cache.setpair(b, d, False)
170191

171-
assert weakref.getweakrefcount(a) == 2
192+
assert cache.getpair(a, c) is False
193+
assert cache.getpair(c, d) is True
172194
del c
173-
assert weakref.getweakrefcount(a) == 1
195+
assert cache == {
196+
id(a): {id(b): True},
197+
id(b): {id(a): True, id(d): False},
198+
id(d): {id(b): False},
199+
}
200+
201+
assert cache.getpair(a, b) is True
202+
assert cache.getpair(b, d) is False
174203
del b
175-
assert weakref.getweakrefcount(a) == 0
204+
assert cache == {}
205+
206+
assert a != d
207+
assert cache == {id(a): {id(d): False}, id(d): {id(a): False}}
208+
del a
209+
assert cache == {}
176210

177211

178212
def test_comparable_cache_reuse(cache):

ibis/common/tests/test_graph_benchmarks.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import Optional
3+
from typing import Any, Optional
44

55
import pytest
66
from typing_extensions import Self
@@ -19,9 +19,10 @@ class MyNode(Concrete, Node):
1919
d: frozendict[str, int]
2020
e: Optional[Self] = None
2121
f: tuple[Self, ...] = ()
22+
g: Any = None
2223

2324

24-
def generate_node(depth):
25+
def generate_node(depth, g=None):
2526
# generate a nested node object with the given depth
2627
if depth == 0:
2728
return MyNode(10, "20", c=(30, 40), d=frozendict(e=50, f=60))
@@ -32,6 +33,7 @@ def generate_node(depth):
3233
d=frozendict(e=5, f=6),
3334
e=generate_node(0),
3435
f=(generate_node(depth - 1), generate_node(0)),
36+
g=g,
3537
)
3638

3739

@@ -62,3 +64,12 @@ def test_replace_mapping(benchmark):
6264
node = generate_node(500)
6365
subs = {generate_node(1): generate_node(0)}
6466
benchmark(node.replace, subs)
67+
68+
69+
def test_equality_caching(benchmark):
70+
node = generate_node(150)
71+
other = generate_node(150)
72+
assert node == other
73+
assert other == node
74+
assert node is not other
75+
benchmark.pedantic(node.__eq__, args=[other], iterations=100, rounds=200)

ibis/common/tests/test_grounds.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -417,9 +417,7 @@ def __equals__(self, other):
417417
assert a != c
418418
assert c != a
419419
assert a.__equals__(b)
420-
assert a.__cached_equals__(b)
421420
assert not a.__equals__(c)
422-
assert not a.__cached_equals__(c)
423421

424422

425423
def test_maintain_definition_order():

ibis/expr/datatypes/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def equals(self, other):
140140
raise TypeError(
141141
f"invalid equality comparison between DataType and {type(other)}"
142142
)
143-
return super().__cached_equals__(other)
143+
return self == other
144144

145145
def cast(self, other, **kwargs):
146146
# TODO(kszucs): remove it or deprecate it?

ibis/expr/operations/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def equals(self, other) -> bool:
2323
raise TypeError(
2424
f"invalid equality comparison between Node and {type(other)}"
2525
)
26-
return self.__cached_equals__(other)
26+
return self == other
2727

2828
# Avoid custom repr for performance reasons
2929
__repr__ = object.__repr__

ibis/expr/schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def equals(self, other: Schema) -> bool:
8989
raise TypeError(
9090
f"invalid equality comparison between Schema and {type(other)}"
9191
)
92-
return self.__cached_equals__(other)
92+
return self == other
9393

9494
@classmethod
9595
def from_tuples(

0 commit comments

Comments
 (0)