diff --git a/dpdata/data_type.py b/dpdata/data_type.py index c689e99b..1dd49ea8 100644 --- a/dpdata/data_type.py +++ b/dpdata/data_type.py @@ -32,6 +32,13 @@ class DataError(Exception): """Data is not correct.""" +def _dtype_name(dtype) -> str: + """Return a readable name for a dtype that may be a type or a tuple of types.""" + if isinstance(dtype, tuple): + return ", ".join(t.__name__ for t in dtype) + return dtype.__name__ + + class DataType: """DataType represents a type of data, like coordinates, energies, etc. @@ -96,7 +103,7 @@ def __repr__(self) -> str: string representation """ return ( - f"DataType(name='{self.name}', dtype={self.dtype.__name__}, " + f"DataType(name='{self.name}', dtype={_dtype_name(self.dtype)}, " f"shape={self.shape}, required={self.required}, " f"deepmd_name='{self.deepmd_name}')" ) @@ -145,7 +152,7 @@ def check(self, system: System): pass elif not isinstance(data, self.dtype): raise DataError( - f"Type of {self.name} is {type(data).__name__}, but expected {self.dtype.__name__}" + f"Type of {self.name} is {type(data).__name__}, but expected {_dtype_name(self.dtype)}" ) # check shape if self.shape is not None: diff --git a/tests/test_custom_data_type.py b/tests/test_custom_data_type.py index b26b0f17..72daaa00 100644 --- a/tests/test_custom_data_type.py +++ b/tests/test_custom_data_type.py @@ -7,7 +7,7 @@ import numpy as np import dpdata -from dpdata.data_type import Axis, DataType +from dpdata.data_type import Axis, DataError, DataType class TestDataType(unittest.TestCase): @@ -41,6 +41,31 @@ def test_repr(self): ) self.assertEqual(repr(dt), expected) + def test_repr_tuple_dtype(self): + """Regression test for #989: repr must handle a tuple dtype. + + ``check`` accepts ``dtype`` as a ``type`` or a ``tuple[type]``, so repr + (and the type-mismatch error) must not assume ``dtype`` has ``__name__``. + """ + dt = DataType("test", (list, np.ndarray), shape=(Axis.NFRAMES, 3)) + expected = ( + "DataType(name='test', dtype=list, ndarray, " + "shape=(, 3), required=True, " + "deepmd_name='test')" + ) + self.assertEqual(repr(dt), expected) + + def test_check_tuple_dtype_error_message(self): + """Regression test for #989: a type mismatch against a tuple dtype must + raise a readable DataError, not AttributeError on ``tuple.__name__``. + """ + dt = DataType("test", (list, np.ndarray), shape=(Axis.NFRAMES,)) + system = dpdata.System() + system.data["test"] = "not a list or ndarray" + with self.assertRaises(DataError) as cm: + dt.check(system) + self.assertIn("expected list, ndarray", str(cm.exception)) + def test_register_same_data_type_no_warning(self): """Test registering identical DataType instances should not warn.""" dt1 = DataType("test_same", np.ndarray, shape=(Axis.NFRAMES, 3))