diff --git a/deepmd/descriptor/se_atten.py b/deepmd/descriptor/se_atten.py index b0c65108e5..dce5a0b434 100644 --- a/deepmd/descriptor/se_atten.py +++ b/deepmd/descriptor/se_atten.py @@ -120,6 +120,11 @@ class DescrptSeAtten(DescrptSeA): When using stripped type embedding, whether to dot smooth factor on the network output of type embedding to keep the network smooth, instead of setting `set_davg_zero` to be True. Default value will be True in `se_atten_v2` descriptor. + + Raises + ------ + ValueError + if ntypes is 0. """ def __init__( @@ -178,6 +183,8 @@ def __init__( assert Version(TF_VERSION) > Version( "2" ), "se_atten only support tensorflow version 2.0 or higher." + if ntypes == 0: + raise ValueError("`model/type_map` is not set or empty!") self.stripped_type_embedding = stripped_type_embedding self.smooth = smooth_type_embdding self.ntypes = ntypes diff --git a/deepmd/model/model.py b/deepmd/model/model.py index 9ae5eacf4f..dddf6e2702 100644 --- a/deepmd/model/model.py +++ b/deepmd/model/model.py @@ -531,7 +531,7 @@ def __init__( self.descrpt = descriptor else: self.descrpt = Descriptor( - **descriptor, ntypes=len(type_map), spin=self.spin + **descriptor, ntypes=len(self.get_type_map()), spin=self.spin ) if isinstance(fitting_net, Fitting): diff --git a/deepmd/model/multi.py b/deepmd/model/multi.py index b0aa11a109..e111508bc4 100644 --- a/deepmd/model/multi.py +++ b/deepmd/model/multi.py @@ -122,7 +122,7 @@ def __init__( else: self.descrpt = Descriptor( **descriptor, - ntypes=len(type_map), + ntypes=len(self.get_type_map()), multi_task=True, spin=self.spin, )