diff --git a/piccolo/testing/model_builder.py b/piccolo/testing/model_builder.py index b493f9afa..7e25d7e10 100644 --- a/piccolo/testing/model_builder.py +++ b/piccolo/testing/model_builder.py @@ -173,10 +173,16 @@ def _randomize_attribute(cls, column: Column) -> Any: random_value = RandomBuilder.next_datetime(tz_aware=tz_aware) elif column.value_type == list: length = RandomBuilder.next_int(maximum=10) - base_type = cast(Array, column).base_column.value_type - random_value = [ - cls.__DEFAULT_MAPPER[base_type]() for _ in range(length) - ] + if column._meta.choices: + random_value = [ + RandomBuilder.next_enum(column._meta.choices) + for _ in range(length) + ] + else: + base_type = cast(Array, column).base_column.value_type + random_value = [ + cls.__DEFAULT_MAPPER[base_type]() for _ in range(length) + ] elif column._meta.choices: random_value = RandomBuilder.next_enum(column._meta.choices) else: diff --git a/tests/testing/test_model_builder.py b/tests/testing/test_model_builder.py index dba0dc791..b1d07376a 100644 --- a/tests/testing/test_model_builder.py +++ b/tests/testing/test_model_builder.py @@ -1,4 +1,5 @@ import asyncio +import enum import json import unittest @@ -30,9 +31,14 @@ class TableWithArrayField(Table): + class Choices(enum.Enum): + a = "a" + b = "b" + strings = Array(Varchar(30)) integers = Array(Integer()) floats = Array(Real()) + choices = Array(Varchar(), choices=Choices) class TableWithDecimal(Table): @@ -104,6 +110,16 @@ def test_choices(self): ["s", "l", "m"], ) + def test_array_choices(self): + """ + Make sure that ``ModelBuilder`` generates arrays where each array + element is a valid choice. + """ + instance = ModelBuilder.build_sync(TableWithArrayField) + for value in instance.choices: + # Will raise an exception if the enum value isn't found: + TableWithArrayField.Choices[value] + def test_datetime(self): """ Make sure that ``ModelBuilder`` generates timezone aware datetime