Skip to content

Commit d78d8cf

Browse files
authored
Merge pull request #205 from boxine/fix_connection_confusion
Fix issue where only the last database's connection was used in QueryAsserter
2 parents 7a7feba + e6dc0c0 commit d78d8cf

File tree

10 files changed

+137
-23
lines changed

10 files changed

+137
-23
lines changed

README.md

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ ModelField, FormField and validators for GTIN/UPC/EAN numbers
152152

153153
#### bx_django_utils.dbperf.query_recorder
154154

155-
* [`SQLQueryRecorder()`](https://github.com/boxine/bx_django_utils/blob/master/bx_django_utils/dbperf/query_recorder.py#L95-L176) - A context manager that allows recording SQL queries executed during its lifetime.
155+
* [`SQLQueryRecorder()`](https://github.com/boxine/bx_django_utils/blob/master/bx_django_utils/dbperf/query_recorder.py#L106-L185) - A context manager that allows recording SQL queries executed during its lifetime.
156156

157157
### bx_django_utils.feature_flags
158158

@@ -261,7 +261,7 @@ Utilities / helper for writing tests.
261261

262262
#### bx_django_utils.test_utils.assert_queries
263263

264-
* [`AssertQueries()`](https://github.com/boxine/bx_django_utils/blob/master/bx_django_utils/test_utils/assert_queries.py#L34-L287) - Assert executed database queries: Check table names, duplicate/similar Queries.
264+
* [`AssertQueries()`](https://github.com/boxine/bx_django_utils/blob/master/bx_django_utils/test_utils/assert_queries.py#L34-L295) - Assert executed database queries: Check table names, duplicate/similar Queries.
265265

266266
#### bx_django_utils.test_utils.cache
267267

@@ -383,7 +383,15 @@ apt-get install pipx
383383
pipx install uv
384384
```
385385

386-
Clone the project and just use our `Makefile` e.g.:
386+
you should be able to then do
387+
```bash
388+
make install
389+
make playwright-install
390+
```
391+
392+
and validate everything works with `make test`
393+
394+
For other options, you can check out our makefile:
387395

388396
```bash
389397
~$ git clone https://github.com/boxine/bx_django_utils.git

bx_django_utils/admin_utils/tests/test_log_entry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
class LogEntryTestCase(TestCase):
1717
"""""" # noqa - don't add in README
1818

19+
databases = ['default', 'second']
1920
maxDiff = None
2021

2122
def test_basic(self):

bx_django_utils/dbperf/query_recorder.py

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from collections import defaultdict
22
from collections.abc import Callable, Iterable
3+
from functools import partial
34
from pprint import saferepr
45

56
from django.db import connections
@@ -92,6 +93,16 @@ def dump(self, aggregate_queries=True):
9293
return results
9394

9495

96+
def _get_cursor_wrapper(*, cursor, connection, logger, collect_stacktrace, query_explain):
97+
return RecordingCursorWrapper(
98+
cursor(),
99+
connection,
100+
logger,
101+
collect_stacktrace=collect_stacktrace,
102+
query_explain=query_explain,
103+
)
104+
105+
95106
class SQLQueryRecorder:
96107
"""
97108
A context manager that allows recording SQL queries executed during its lifetime.
@@ -135,26 +146,23 @@ def __enter__(self):
135146
connection._recording_cursor = connection.cursor
136147
connection._recording_chunked_cursor = connection.chunked_cursor
137148

138-
def cursor():
139-
return RecordingCursorWrapper(
140-
connection._recording_cursor(), # noqa:B023
141-
connection, # noqa:B023
142-
self.logger,
143-
collect_stacktrace=self.collect_stacktrace,
144-
query_explain=self.query_explain,
145-
)
146-
147-
def chunked_cursor():
148-
return RecordingCursorWrapper(
149-
connection._recording_chunked_cursor(), # noqa:B023
150-
connection, # noqa:B023
151-
self.logger,
152-
collect_stacktrace=self.collect_stacktrace,
153-
query_explain=self.query_explain,
154-
)
155-
156-
connection.cursor = cursor
157-
connection.chunked_cursor = chunked_cursor
149+
common_kwargs = {
150+
'connection': connection,
151+
'logger': self.logger,
152+
'collect_stacktrace': self.collect_stacktrace,
153+
'query_explain': self.query_explain
154+
}
155+
156+
connection.cursor = partial(
157+
_get_cursor_wrapper,
158+
cursor=connection._recording_cursor,
159+
**common_kwargs
160+
)
161+
connection.chunked_cursor = partial(
162+
_get_cursor_wrapper,
163+
cursor=connection._recording_chunked_cursor,
164+
**common_kwargs
165+
)
158166

159167
self.running = True
160168
return self
@@ -166,6 +174,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
166174

167175
# undo the cursor wrapping so the connection is 'clean' again
168176
del connection._recording_cursor
177+
del connection._recording_chunked_cursor
169178
del connection.cursor
170179
del connection.chunked_cursor
171180

bx_django_utils/test_utils/assert_queries.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,14 @@ def assert_table_names(self, *expected_table_names):
175175
))
176176
raise AssertionError(self.build_error_message(f'Table names does not match:\n{diff}'))
177177

178+
def assert_databases_touched(self, *expected_database_aliases):
179+
touched_keys = self.logger._databases.keys()
180+
if set(expected_database_aliases) != touched_keys:
181+
raise AssertionError(
182+
self.build_error_message(f'Not all expected tables were touched, '
183+
f'expected: {expected_database_aliases}, touched: {touched_keys}')
184+
)
185+
178186
def assert_table_counts(self, table_counts: Counter | dict, exclude: tuple[str, ...] | None = None):
179187
if not isinstance(table_counts, Counter):
180188
table_counts = Counter(table_counts)

bx_django_utils_tests/test_app/migrations/0001_initial.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,17 @@ class Migration(migrations.Migration):
4040
),
4141
],
4242
),
43+
migrations.CreateModel(
44+
name='ColorFieldTestModelSecondary',
45+
fields=[
46+
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
47+
('required_color', bx_django_utils.models.color_field.ColorModelField(max_length=7)),
48+
(
49+
'optional_color',
50+
bx_django_utils.models.color_field.ColorModelField(blank=True, max_length=7, null=True),
51+
),
52+
],
53+
),
4354
migrations.CreateModel(
4455
name='ConnectedUniqueSlugModel1',
4556
fields=[

bx_django_utils_tests/test_app/models.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,11 @@ class ColorFieldTestModel(models.Model):
4646
optional_color = ColorModelField(blank=True, null=True)
4747

4848

49+
class ColorFieldTestModelSecondary(models.Model):
50+
required_color = ColorModelField()
51+
optional_color = ColorModelField(blank=True, null=True)
52+
53+
4954
class StoreSaveModel(models.Model):
5055
name = models.CharField(max_length=64)
5156
_save_calls = threading.local()
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
2+
3+
class MultiDBRouter:
4+
5+
def _model_name_to_db(self, model_name: str):
6+
if 'secondary' in model_name.lower():
7+
return 'second'
8+
else:
9+
return 'default'
10+
11+
def db_for_read(self, model, **hints):
12+
return self._model_name_to_db(model.__name__)
13+
14+
def db_for_write(self, model, **hints):
15+
return self._model_name_to_db(model.__name__)
16+
17+
def allow_migrate(self, db, app_label, model_name=None, **hints):
18+
if model_name:
19+
return self._model_name_to_db(model_name) == db
20+
else:
21+
return 'default'

bx_django_utils_tests/test_project/settings.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,13 @@
9898
'default': {
9999
'ENGINE': 'django.db.backends.sqlite3',
100100
'NAME': str(BASE_DIR / 'db.sqlite3'),
101+
},
102+
'second': {
103+
'ENGINE': 'django.db.backends.sqlite3',
104+
'NAME': str(BASE_DIR / 'db_second.sqlite3'),
101105
}
102106
}
107+
DATABASE_ROUTERS = ['bx_django_utils_tests.test_project.routers.MultiDBRouter']
103108
DEFAULT_AUTO_FIELD = 'django.db.models.AutoField'
104109

105110
# Password validation

bx_django_utils_tests/tests/test_assert_queries.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from django.test import TestCase
1111

1212
from bx_django_utils.test_utils.assert_queries import AssertQueries
13+
from bx_django_utils_tests.test_app.models import ColorFieldTestModelSecondary
1314

1415

1516
def make_database_queries(count=1):
@@ -22,6 +23,7 @@ def make_database_queries2(count=1):
2223

2324

2425
class AssertQueriesTestCase(TestCase):
26+
databases = ['default', 'second']
2527

2628
def get_instance(self):
2729
with AssertQueries() as queries:
@@ -89,6 +91,7 @@ def test_assert_table_counts(self):
8991
queries = self.get_instance()
9092
queries.assert_table_counts(Counter(auth_permission=1))
9193
queries.assert_table_counts({'auth_permission': 1})
94+
queries.assert_databases_touched('default')
9295

9396
with self.assertRaises(AssertionError) as err:
9497
queries.assert_table_counts(Counter(auth_permission=2, foo=1, bar=3))
@@ -101,6 +104,18 @@ def test_assert_table_counts(self):
101104
assert 'Captured queries were:\n' in msg
102105
assert '1. SELECT "auth_permission"' in msg
103106

107+
def test_assert_two_touched_dbs(self):
108+
with AssertQueries(databases=['default', 'second']) as queries:
109+
Permission.objects.all().first()
110+
with self.assertRaises(AssertionError):
111+
# only 1 touched.
112+
queries.assert_databases_touched('default', 'second')
113+
114+
with AssertQueries(databases=['default', 'second']) as queries:
115+
Permission.objects.all().first()
116+
ColorFieldTestModelSecondary.objects.count()
117+
queries.assert_databases_touched('default', 'second')
118+
104119
def test_assert_table_counts_exclude(self):
105120
with AssertQueries() as queries:
106121
Permission.objects.all().first()
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from django.db import OperationalError
2+
from django.test import TestCase
3+
from model_bakery import baker
4+
5+
from bx_django_utils_tests.test_app.models import ColorFieldTestModel, ColorFieldTestModelSecondary
6+
7+
8+
class DatabaseRoutersTestCase(TestCase):
9+
databases = ['default', 'second']
10+
11+
def test_database_router(self):
12+
baker.make(
13+
ColorFieldTestModel,
14+
pk=1,
15+
required_color='#000002',
16+
optional_color='#000003',
17+
)
18+
with self.assertRaises(OperationalError):
19+
ColorFieldTestModel.objects.using("second").count()
20+
self.assertEqual(ColorFieldTestModel.objects.using("default").count(), 1)
21+
22+
def test_database_router_secondary(self):
23+
baker.make(
24+
ColorFieldTestModelSecondary,
25+
pk=1,
26+
required_color='#000002',
27+
optional_color='#000003',
28+
)
29+
with self.assertRaises(OperationalError):
30+
ColorFieldTestModelSecondary.objects.using("default").count()
31+
self.assertEqual(ColorFieldTestModelSecondary.objects.using("second").count(), 1)

0 commit comments

Comments
 (0)