Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit dea5779

Browse files
authored
Add tests for database transaction callbacks (#12198)
Signed-off-by: Sean Quah <seanq@element.io>
1 parent 5dd949b commit dea5779

File tree

2 files changed

+104
-1
lines changed

2 files changed

+104
-1
lines changed

changelog.d/12198.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add tests for database transaction callbacks.

tests/storage/test_database.py

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,18 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from synapse.storage.database import make_tuple_comparison_clause
15+
from typing import Callable, Tuple
16+
from unittest.mock import Mock, call
17+
18+
from twisted.test.proto_helpers import MemoryReactor
19+
20+
from synapse.server import HomeServer
21+
from synapse.storage.database import (
22+
DatabasePool,
23+
LoggingTransaction,
24+
make_tuple_comparison_clause,
25+
)
26+
from synapse.util import Clock
1627

1728
from tests import unittest
1829

@@ -22,3 +33,94 @@ def test_native_tuple_comparison(self):
2233
clause, args = make_tuple_comparison_clause([("a", 1), ("b", 2)])
2334
self.assertEqual(clause, "(a,b) > (?,?)")
2435
self.assertEqual(args, [1, 2])
36+
37+
38+
class CallbacksTestCase(unittest.HomeserverTestCase):
39+
"""Tests for transaction callbacks."""
40+
41+
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
42+
self.store = hs.get_datastores().main
43+
self.db_pool: DatabasePool = self.store.db_pool
44+
45+
def _run_interaction(
46+
self, func: Callable[[LoggingTransaction], object]
47+
) -> Tuple[Mock, Mock]:
48+
"""Run the given function in a database transaction, with callbacks registered.
49+
50+
Args:
51+
func: The function to be run in a transaction. The transaction will be
52+
retried if `func` raises an `OperationalError`.
53+
54+
Returns:
55+
Two mocks, which were registered as an `after_callback` and an
56+
`exception_callback` respectively, on every transaction attempt.
57+
"""
58+
after_callback = Mock()
59+
exception_callback = Mock()
60+
61+
def _test_txn(txn: LoggingTransaction) -> None:
62+
txn.call_after(after_callback, 123, 456, extra=789)
63+
txn.call_on_exception(exception_callback, 987, 654, extra=321)
64+
func(txn)
65+
66+
try:
67+
self.get_success_or_raise(
68+
self.db_pool.runInteraction("test_transaction", _test_txn)
69+
)
70+
except Exception:
71+
pass
72+
73+
return after_callback, exception_callback
74+
75+
def test_after_callback(self) -> None:
76+
"""Test that the after callback is called when a transaction succeeds."""
77+
after_callback, exception_callback = self._run_interaction(lambda txn: None)
78+
79+
after_callback.assert_called_once_with(123, 456, extra=789)
80+
exception_callback.assert_not_called()
81+
82+
def test_exception_callback(self) -> None:
83+
"""Test that the exception callback is called when a transaction fails."""
84+
_test_txn = Mock(side_effect=ZeroDivisionError)
85+
after_callback, exception_callback = self._run_interaction(_test_txn)
86+
87+
after_callback.assert_not_called()
88+
exception_callback.assert_called_once_with(987, 654, extra=321)
89+
90+
def test_failed_retry(self) -> None:
91+
"""Test that the exception callback is called for every failed attempt."""
92+
# Always raise an `OperationalError`.
93+
_test_txn = Mock(side_effect=self.db_pool.engine.module.OperationalError)
94+
after_callback, exception_callback = self._run_interaction(_test_txn)
95+
96+
after_callback.assert_not_called()
97+
exception_callback.assert_has_calls(
98+
[
99+
call(987, 654, extra=321),
100+
call(987, 654, extra=321),
101+
call(987, 654, extra=321),
102+
call(987, 654, extra=321),
103+
call(987, 654, extra=321),
104+
call(987, 654, extra=321),
105+
]
106+
)
107+
self.assertEqual(exception_callback.call_count, 6) # no additional calls
108+
109+
def test_successful_retry(self) -> None:
110+
"""Test callbacks for a failed transaction followed by a successful attempt."""
111+
# Raise an `OperationalError` on the first attempt only.
112+
_test_txn = Mock(
113+
side_effect=[self.db_pool.engine.module.OperationalError, None]
114+
)
115+
after_callback, exception_callback = self._run_interaction(_test_txn)
116+
117+
# Calling both `after_callback`s when the first attempt failed is rather
118+
# surprising (#12184). Let's document the behaviour in a test.
119+
after_callback.assert_has_calls(
120+
[
121+
call(123, 456, extra=789),
122+
call(123, 456, extra=789),
123+
]
124+
)
125+
self.assertEqual(after_callback.call_count, 2) # no additional calls
126+
exception_callback.assert_not_called()

0 commit comments

Comments
 (0)