diff --git a/simpeg/directives/_save_geoh5.py b/simpeg/directives/_save_geoh5.py index 0096bc6b7d..4cd3cb8aa2 100644 --- a/simpeg/directives/_save_geoh5.py +++ b/simpeg/directives/_save_geoh5.py @@ -77,7 +77,7 @@ def get_names( base_name += f"_{component}" channel_name = base_name - if channel: + if len(channel) > 0: channel_name += f"_{channel}" if self.label is not None: @@ -86,6 +86,17 @@ def get_names( return channel_name, base_name + @staticmethod + def _channel_label(channel: int, label: str | float | None) -> str: + """ + Format the channel label. + """ + if isinstance(label, str) and len(label) > 1: + return label + elif isinstance(label, float): + return f"[{channel}]" + return "" + @abstractmethod def write(self, iteration: int, values: list[np.ndarray] = None): # flake8: noqa """ @@ -263,8 +274,9 @@ def write(self, iteration: int, values: list[np.ndarray] = None): # flake8: noq if self.sorting is not None: values = values[self.sorting] + label = self._channel_label(ii, channel) channel_name, base_name = self.get_names( - component, channel, iteration + component, label, iteration ) data = h5_object.add_data( @@ -278,9 +290,7 @@ def write(self, iteration: int, values: list[np.ndarray] = None): # flake8: noq # Re-assign the data type if channel not in self.data_type[component].keys(): self.data_type[component][channel] = data.entity_type - type_name = f"{self._attribute_type}_{component}" - if channel: - type_name += f"_{channel}" + type_name = f"{self._attribute_type}_{component}" + f"_{label}" data.entity_type.name = type_name else: data.entity_type = w_s.find_type( @@ -443,10 +453,10 @@ def write(self, iteration: int, **_): for component in self.components: properties = [] - for channel in self.channels: - + for ii, channel in enumerate(self.channels): + label = self._channel_label(ii, channel) channel_name, base_name = self.get_names( - component, channel, iteration + component, label, iteration ) children = [ child @@ -490,7 +500,7 @@ def __init__( super().__init__(h5_object, **kwargs) def get_names( - self, component: str, channel: str, iteration: int + self, component: str, channel: int | None, iteration: int ) -> tuple[str, str]: """ Format the data and property_group name. @@ -545,7 +555,7 @@ def write(self, iteration: int, values: list[np.ndarray] | None = None): """ petro_model = self.get_values(values) petro_model = self.apply_transformations(petro_model).flatten() - channel_name, base_name = self.get_names("petrophysics", "", iteration) + channel_name, _ = self.get_names("petrophysics", "", iteration) with fetch_active_workspace(self._geoh5, mode="r+") as w_s: h5_object = w_s.get_entity(self.h5_object)[0] data = h5_object.add_data(