|
2 | 2 |
|
3 | 3 | import contextvars |
4 | 4 | import json |
| 5 | +import uuid |
5 | 6 | from collections.abc import Sequence |
6 | 7 | from dataclasses import dataclass |
7 | 8 | from typing import TYPE_CHECKING, Any, Mapping, Optional, Union |
8 | 9 |
|
| 10 | +from pymysql.constants import FIELD_TYPE |
| 11 | +from pymysql.converters import conversions |
9 | 12 | from typing_extensions import Self |
10 | 13 |
|
11 | 14 | from piccolo.engine.base import ( |
|
32 | 35 | from piccolo.table import Table |
33 | 36 |
|
34 | 37 |
|
| 38 | +# converters and formaters |
35 | 39 | def backticks_format(querysting: str) -> str: |
36 | 40 | return querysting.replace('"', "`") |
37 | 41 |
|
38 | 42 |
|
39 | | -def converter(value: list) -> str: |
| 43 | +def convert_list(value: list) -> str: |
40 | 44 | if isinstance(value, list): |
41 | 45 | return json.dumps(value) |
42 | 46 | return value |
43 | 47 |
|
44 | 48 |
|
| 49 | +def convert_bool(value: int) -> bool: |
| 50 | + return bool(int(value)) if value is not None else None |
| 51 | + |
| 52 | + |
| 53 | +def convert_uuid(value: Any) -> Union[str, uuid.UUID]: |
| 54 | + if isinstance(value, (bytes, bytearray)): |
| 55 | + value = value.decode() |
| 56 | + value = value.strip() |
| 57 | + # check if string is uuid string |
| 58 | + if len(value) == 36 and value.count("-") == 4: |
| 59 | + try: |
| 60 | + return uuid.UUID(value) |
| 61 | + except ValueError: |
| 62 | + return value |
| 63 | + return value |
| 64 | + |
| 65 | + |
| 66 | +converters = conversions.copy() |
| 67 | +custom_decoders: dict[str, Any] = { |
| 68 | + FIELD_TYPE.STRING: convert_uuid, |
| 69 | + FIELD_TYPE.VAR_STRING: convert_uuid, |
| 70 | + FIELD_TYPE.VARCHAR: convert_uuid, |
| 71 | + FIELD_TYPE.CHAR: convert_uuid, |
| 72 | + FIELD_TYPE.TINY: convert_bool, |
| 73 | +} |
| 74 | +converters.update(custom_decoders) |
| 75 | + |
| 76 | + |
45 | 77 | @dataclass |
46 | 78 | class AsyncBatch(BaseBatch): |
47 | 79 | connection: Connection |
@@ -282,6 +314,8 @@ def __init__( |
282 | 314 | self.current_transaction = contextvars.ContextVar( |
283 | 315 | f"mysql_current_transaction_{db_name}", default=None |
284 | 316 | ) |
| 317 | + # converters |
| 318 | + config["conv"] = converters |
285 | 319 |
|
286 | 320 | super().__init__( |
287 | 321 | engine_type="mysql", |
@@ -392,7 +426,7 @@ async def _run_in_new_connection( |
392 | 426 | args = [] |
393 | 427 | connection = await self.get_new_connection() |
394 | 428 | # convert lists |
395 | | - params = tuple(converter(arg) for arg in args) |
| 429 | + params = tuple(convert_list(arg) for arg in args) |
396 | 430 | try: |
397 | 431 | async with connection.cursor() as cursor: |
398 | 432 | await cursor.execute(query, params) |
|
0 commit comments