Skip to content

Commit 2a418a1

Browse files
committed
uuid and boolean converter
1 parent cb942da commit 2a418a1

File tree

3 files changed

+39
-4
lines changed

3 files changed

+39
-4
lines changed

piccolo/engine/mysql.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@
22

33
import contextvars
44
import json
5+
import uuid
56
from collections.abc import Sequence
67
from dataclasses import dataclass
78
from typing import TYPE_CHECKING, Any, Mapping, Optional, Union
89

10+
from pymysql.constants import FIELD_TYPE
11+
from pymysql.converters import conversions
912
from typing_extensions import Self
1013

1114
from piccolo.engine.base import (
@@ -32,16 +35,45 @@
3235
from piccolo.table import Table
3336

3437

38+
# converters and formaters
3539
def backticks_format(querysting: str) -> str:
3640
return querysting.replace('"', "`")
3741

3842

39-
def converter(value: list) -> str:
43+
def convert_list(value: list) -> str:
4044
if isinstance(value, list):
4145
return json.dumps(value)
4246
return value
4347

4448

49+
def convert_bool(value: int) -> bool:
50+
return bool(int(value)) if value is not None else None
51+
52+
53+
def convert_uuid(value: Any) -> Union[str, uuid.UUID]:
54+
if isinstance(value, (bytes, bytearray)):
55+
value = value.decode()
56+
value = value.strip()
57+
# check if string is uuid string
58+
if len(value) == 36 and value.count("-") == 4:
59+
try:
60+
return uuid.UUID(value)
61+
except ValueError:
62+
return value
63+
return value
64+
65+
66+
converters = conversions.copy()
67+
custom_decoders: dict[str, Any] = {
68+
FIELD_TYPE.STRING: convert_uuid,
69+
FIELD_TYPE.VAR_STRING: convert_uuid,
70+
FIELD_TYPE.VARCHAR: convert_uuid,
71+
FIELD_TYPE.CHAR: convert_uuid,
72+
FIELD_TYPE.TINY: convert_bool,
73+
}
74+
converters.update(custom_decoders)
75+
76+
4577
@dataclass
4678
class AsyncBatch(BaseBatch):
4779
connection: Connection
@@ -282,6 +314,8 @@ def __init__(
282314
self.current_transaction = contextvars.ContextVar(
283315
f"mysql_current_transaction_{db_name}", default=None
284316
)
317+
# converters
318+
config["conv"] = converters
285319

286320
super().__init__(
287321
engine_type="mysql",
@@ -392,7 +426,7 @@ async def _run_in_new_connection(
392426
args = []
393427
connection = await self.get_new_connection()
394428
# convert lists
395-
params = tuple(converter(arg) for arg in args)
429+
params = tuple(convert_list(arg) for arg in args)
396430
try:
397431
async with connection.cursor() as cursor:
398432
await cursor.execute(query, params)

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ line_length = 79
1111
module = [
1212
"asyncpg.*",
1313
"aiomysql.*",
14+
"pymysql.*",
1415
"colorama",
1516
"dateutil",
1617
"IPython",

tests/apps/fixtures/commands/test_load_dump_mysql.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def _run_comparison(self, table_class_names: list[str]):
116116
{
117117
"id": 1,
118118
"bigint_col": 1,
119-
"boolean_col": 1,
119+
"boolean_col": True,
120120
"bytea_col": b"hello",
121121
"date_col": datetime.date(2021, 1, 1),
122122
"foreignkey_col": 1,
@@ -131,7 +131,7 @@ def _run_comparison(self, table_class_names: list[str]):
131131
"text_col": "hello",
132132
"timestamp_col": datetime.datetime(2021, 1, 1, 0, 0),
133133
"timestamptz_col": datetime.datetime(2021, 1, 1, 0, 0),
134-
"uuid_col": "12783854-c012-4c15-8183-8eecb46f2c4e",
134+
"uuid_col": uuid.UUID("12783854-c012-4c15-8183-8eecb46f2c4e"),
135135
"varchar_col": "hello",
136136
"unique_col": "hello",
137137
"null_col": None,

0 commit comments

Comments
 (0)