1515
1616import logging
1717from collections import OrderedDict
18+ from typing import Any , Generic , Optional , TypeVar , Union , overload
19+
20+ import attr
21+ from typing_extensions import Literal
1822
1923from synapse .config import cache as cache_config
2024from synapse .metrics .background_process_metrics import run_as_background_process
25+ from synapse .util import Clock
2126from synapse .util .caches import register_cache
2227
2328logger = logging .getLogger (__name__ )
2429
2530
26- SENTINEL = object ()
31+ SENTINEL = object () # type: Any
32+
2733
34+ T = TypeVar ("T" )
35+ KT = TypeVar ("KT" )
36+ VT = TypeVar ("VT" )
2837
29- class ExpiringCache :
38+
39+ class ExpiringCache (Generic [KT , VT ]):
3040 def __init__ (
3141 self ,
32- cache_name ,
33- clock ,
34- max_len = 0 ,
35- expiry_ms = 0 ,
36- reset_expiry_on_get = False ,
37- iterable = False ,
42+ cache_name : str ,
43+ clock : Clock ,
44+ max_len : int = 0 ,
45+ expiry_ms : int = 0 ,
46+ reset_expiry_on_get : bool = False ,
47+ iterable : bool = False ,
3848 ):
3949 """
4050 Args:
41- cache_name (str) : Name of this cache, used for logging.
42- clock (Clock)
43- max_len (int) : Max size of dict. If the dict grows larger than this
51+ cache_name: Name of this cache, used for logging.
52+ clock
53+ max_len: Max size of dict. If the dict grows larger than this
4454 then the oldest items get automatically evicted. Default is 0,
4555 which indicates there is no max limit.
46- expiry_ms (int) : How long before an item is evicted from the cache
56+ expiry_ms: How long before an item is evicted from the cache
4757 in milliseconds. Default is 0, indicating items never get
4858 evicted based on time.
49- reset_expiry_on_get (bool) : If true, will reset the expiry time for
59+ reset_expiry_on_get: If true, will reset the expiry time for
5060 an item on access. Defaults to False.
51- iterable (bool) : If true, the size is calculated by summing the
61+ iterable: If true, the size is calculated by summing the
5262 sizes of all entries, rather than the number of entries.
5363 """
5464 self ._cache_name = cache_name
@@ -62,7 +72,7 @@ def __init__(
6272 self ._expiry_ms = expiry_ms
6373 self ._reset_expiry_on_get = reset_expiry_on_get
6474
65- self ._cache = OrderedDict ()
75+ self ._cache = OrderedDict () # type: OrderedDict[KT, _CacheEntry]
6676
6777 self .iterable = iterable
6878
@@ -79,12 +89,12 @@ def f():
7989
8090 self ._clock .looping_call (f , self ._expiry_ms / 2 )
8191
82- def __setitem__ (self , key , value ) :
92+ def __setitem__ (self , key : KT , value : VT ) -> None :
8393 now = self ._clock .time_msec ()
8494 self ._cache [key ] = _CacheEntry (now , value )
8595 self .evict ()
8696
87- def evict (self ):
97+ def evict (self ) -> None :
8898 # Evict if there are now too many items
8999 while self ._max_size and len (self ) > self ._max_size :
90100 _key , value = self ._cache .popitem (last = False )
@@ -93,7 +103,7 @@ def evict(self):
93103 else :
94104 self .metrics .inc_evictions ()
95105
96- def __getitem__ (self , key ) :
106+ def __getitem__ (self , key : KT ) -> VT :
97107 try :
98108 entry = self ._cache [key ]
99109 self .metrics .inc_hits ()
@@ -106,7 +116,7 @@ def __getitem__(self, key):
106116
107117 return entry .value
108118
109- def pop (self , key , default = SENTINEL ):
119+ def pop (self , key : KT , default : T = SENTINEL ) -> Union [ VT , T ] :
110120 """Removes and returns the value with the given key from the cache.
111121
112122 If the key isn't in the cache then `default` will be returned if
@@ -115,29 +125,40 @@ def pop(self, key, default=SENTINEL):
115125 Identical functionality to `dict.pop(..)`.
116126 """
117127
118- value = self ._cache .pop (key , default )
128+ value = self ._cache .pop (key , SENTINEL )
129+ # The key was not found.
119130 if value is SENTINEL :
120- raise KeyError (key )
131+ if default is SENTINEL :
132+ raise KeyError (key )
133+ return default
121134
122- return value
135+ return value . value
123136
124- def __contains__ (self , key ) :
137+ def __contains__ (self , key : KT ) -> bool :
125138 return key in self ._cache
126139
127- def get (self , key , default = None ):
140+ @overload
141+ def get (self , key : KT , default : Literal [None ] = None ) -> Optional [VT ]:
142+ ...
143+
144+ @overload
145+ def get (self , key : KT , default : T ) -> Union [VT , T ]:
146+ ...
147+
148+ def get (self , key : KT , default : Optional [T ] = None ) -> Union [VT , Optional [T ]]:
128149 try :
129150 return self [key ]
130151 except KeyError :
131152 return default
132153
133- def setdefault (self , key , value ) :
154+ def setdefault (self , key : KT , value : VT ) -> VT :
134155 try :
135156 return self [key ]
136157 except KeyError :
137158 self [key ] = value
138159 return value
139160
140- def _prune_cache (self ):
161+ def _prune_cache (self ) -> None :
141162 if not self ._expiry_ms :
142163 # zero expiry time means don't expire. This should never get called
143164 # since we have this check in start too.
@@ -166,7 +187,7 @@ def _prune_cache(self):
166187 len (self ),
167188 )
168189
169- def __len__ (self ):
190+ def __len__ (self ) -> int :
170191 if self .iterable :
171192 return sum (len (entry .value ) for entry in self ._cache .values ())
172193 else :
@@ -190,9 +211,7 @@ def set_cache_factor(self, factor: float) -> bool:
190211 return False
191212
192213
214+ @attr .s (slots = True )
193215class _CacheEntry :
194- __slots__ = ["time" , "value" ]
195-
196- def __init__ (self , time , value ):
197- self .time = time
198- self .value = value
216+ time = attr .ib (type = int )
217+ value = attr .ib ()
0 commit comments