1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- from synapse .storage .database import make_tuple_comparison_clause
15+ from typing import Callable , Tuple
16+ from unittest .mock import Mock , call
17+
18+ from twisted .test .proto_helpers import MemoryReactor
19+
20+ from synapse .server import HomeServer
21+ from synapse .storage .database import (
22+ DatabasePool ,
23+ LoggingTransaction ,
24+ make_tuple_comparison_clause ,
25+ )
26+ from synapse .util import Clock
1627
1728from tests import unittest
1829
@@ -22,3 +33,94 @@ def test_native_tuple_comparison(self):
2233 clause , args = make_tuple_comparison_clause ([("a" , 1 ), ("b" , 2 )])
2334 self .assertEqual (clause , "(a,b) > (?,?)" )
2435 self .assertEqual (args , [1 , 2 ])
36+
37+
38+ class CallbacksTestCase (unittest .HomeserverTestCase ):
39+ """Tests for transaction callbacks."""
40+
41+ def prepare (self , reactor : MemoryReactor , clock : Clock , hs : HomeServer ) -> None :
42+ self .store = hs .get_datastores ().main
43+ self .db_pool : DatabasePool = self .store .db_pool
44+
45+ def _run_interaction (
46+ self , func : Callable [[LoggingTransaction ], object ]
47+ ) -> Tuple [Mock , Mock ]:
48+ """Run the given function in a database transaction, with callbacks registered.
49+
50+ Args:
51+ func: The function to be run in a transaction. The transaction will be
52+ retried if `func` raises an `OperationalError`.
53+
54+ Returns:
55+ Two mocks, which were registered as an `after_callback` and an
56+ `exception_callback` respectively, on every transaction attempt.
57+ """
58+ after_callback = Mock ()
59+ exception_callback = Mock ()
60+
61+ def _test_txn (txn : LoggingTransaction ) -> None :
62+ txn .call_after (after_callback , 123 , 456 , extra = 789 )
63+ txn .call_on_exception (exception_callback , 987 , 654 , extra = 321 )
64+ func (txn )
65+
66+ try :
67+ self .get_success_or_raise (
68+ self .db_pool .runInteraction ("test_transaction" , _test_txn )
69+ )
70+ except Exception :
71+ pass
72+
73+ return after_callback , exception_callback
74+
75+ def test_after_callback (self ) -> None :
76+ """Test that the after callback is called when a transaction succeeds."""
77+ after_callback , exception_callback = self ._run_interaction (lambda txn : None )
78+
79+ after_callback .assert_called_once_with (123 , 456 , extra = 789 )
80+ exception_callback .assert_not_called ()
81+
82+ def test_exception_callback (self ) -> None :
83+ """Test that the exception callback is called when a transaction fails."""
84+ _test_txn = Mock (side_effect = ZeroDivisionError )
85+ after_callback , exception_callback = self ._run_interaction (_test_txn )
86+
87+ after_callback .assert_not_called ()
88+ exception_callback .assert_called_once_with (987 , 654 , extra = 321 )
89+
90+ def test_failed_retry (self ) -> None :
91+ """Test that the exception callback is called for every failed attempt."""
92+ # Always raise an `OperationalError`.
93+ _test_txn = Mock (side_effect = self .db_pool .engine .module .OperationalError )
94+ after_callback , exception_callback = self ._run_interaction (_test_txn )
95+
96+ after_callback .assert_not_called ()
97+ exception_callback .assert_has_calls (
98+ [
99+ call (987 , 654 , extra = 321 ),
100+ call (987 , 654 , extra = 321 ),
101+ call (987 , 654 , extra = 321 ),
102+ call (987 , 654 , extra = 321 ),
103+ call (987 , 654 , extra = 321 ),
104+ call (987 , 654 , extra = 321 ),
105+ ]
106+ )
107+ self .assertEqual (exception_callback .call_count , 6 ) # no additional calls
108+
109+ def test_successful_retry (self ) -> None :
110+ """Test callbacks for a failed transaction followed by a successful attempt."""
111+ # Raise an `OperationalError` on the first attempt only.
112+ _test_txn = Mock (
113+ side_effect = [self .db_pool .engine .module .OperationalError , None ]
114+ )
115+ after_callback , exception_callback = self ._run_interaction (_test_txn )
116+
117+ # Calling both `after_callback`s when the first attempt failed is rather
118+ # surprising (#12184). Let's document the behaviour in a test.
119+ after_callback .assert_has_calls (
120+ [
121+ call (123 , 456 , extra = 789 ),
122+ call (123 , 456 , extra = 789 ),
123+ ]
124+ )
125+ self .assertEqual (after_callback .call_count , 2 ) # no additional calls
126+ exception_callback .assert_not_called ()
0 commit comments