Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions dpdata/data_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ class DataError(Exception):
"""Data is not correct."""


def _dtype_name(dtype) -> str:
"""Return a readable name for a dtype that may be a type or a tuple of types."""
if isinstance(dtype, tuple):
return ", ".join(t.__name__ for t in dtype)
return dtype.__name__


class DataType:
"""DataType represents a type of data, like coordinates, energies, etc.

Expand Down Expand Up @@ -96,7 +103,7 @@ def __repr__(self) -> str:
string representation
"""
return (
f"DataType(name='{self.name}', dtype={self.dtype.__name__}, "
f"DataType(name='{self.name}', dtype={_dtype_name(self.dtype)}, "
f"shape={self.shape}, required={self.required}, "
f"deepmd_name='{self.deepmd_name}')"
)
Expand Down Expand Up @@ -145,7 +152,7 @@ def check(self, system: System):
pass
elif not isinstance(data, self.dtype):
raise DataError(
f"Type of {self.name} is {type(data).__name__}, but expected {self.dtype.__name__}"
f"Type of {self.name} is {type(data).__name__}, but expected {_dtype_name(self.dtype)}"
)
# check shape
if self.shape is not None:
Expand Down
27 changes: 26 additions & 1 deletion tests/test_custom_data_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np

import dpdata
from dpdata.data_type import Axis, DataType
from dpdata.data_type import Axis, DataError, DataType


class TestDataType(unittest.TestCase):
Expand Down Expand Up @@ -41,6 +41,31 @@ def test_repr(self):
)
self.assertEqual(repr(dt), expected)

def test_repr_tuple_dtype(self):
"""Regression test for #989: repr must handle a tuple dtype.

``check`` accepts ``dtype`` as a ``type`` or a ``tuple[type]``, so repr
(and the type-mismatch error) must not assume ``dtype`` has ``__name__``.
"""
dt = DataType("test", (list, np.ndarray), shape=(Axis.NFRAMES, 3))
expected = (
"DataType(name='test', dtype=list, ndarray, "
"shape=(<Axis.NFRAMES: 'nframes'>, 3), required=True, "
"deepmd_name='test')"
)
self.assertEqual(repr(dt), expected)

def test_check_tuple_dtype_error_message(self):
"""Regression test for #989: a type mismatch against a tuple dtype must
raise a readable DataError, not AttributeError on ``tuple.__name__``.
"""
dt = DataType("test", (list, np.ndarray), shape=(Axis.NFRAMES,))
system = dpdata.System()
system.data["test"] = "not a list or ndarray"
with self.assertRaises(DataError) as cm:
dt.check(system)
self.assertIn("expected list, ndarray", str(cm.exception))

def test_register_same_data_type_no_warning(self):
"""Test registering identical DataType instances should not warn."""
dt1 = DataType("test_same", np.ndarray, shape=(Axis.NFRAMES, 3))
Expand Down
Loading