# Blocks with a fixed, dedicated position in the ordering.
_BLOCK_ORDER = {
    "options": 0,
    "dimensions": 1,
    "griddata": 2,
    "packagedata": 3,
}


def block_sort_key(item) -> int:
    """
    Sort key for the blocks of an input definition.
    The order is:

    0. options
    1. dimensions
    2. griddata
    3. packagedata
    4. any block whose name contains "period"
    5. everything else

    Parameters
    ----------
    item
        A ``(name, block)`` pair, e.g. one element of ``dict.items()``.
        Only the name is inspected.

    Returns
    -------
    int
        The rank of the block; lower ranks sort first.
    """
    name, _ = item
    rank = _BLOCK_ORDER.get(name)
    if rank is not None:
        return rank
    # substring match, so names that merely contain "period" rank here too
    return 4 if "period" in name else 5
+ """ + return dict( + sorted( + {k: v for k, v in dfn.items() if k not in Dfn.__annotations__}.items(), # type: ignore + key=block_sort_key, + ) + ) class Field(TypedDict): @@ -125,12 +142,12 @@ class Field(TypedDict): name: str type: FieldType - shape: Any | None = None - block: str | None = None - default: Any | None = None - children: Optional["Fields"] = None - description: str | None = None - reader: Reader = "urword" + shape: Any | None + block: str | None + default: Any | None + children: Optional["Fields"] + description: str | None + reader: Reader class Ref(TypedDict): @@ -188,14 +205,14 @@ class Dfn(TypedDict): """ name: str - advanced: bool = False - multi: bool = False - parent: str | None = None - ref: Ref | None = None - sln: Sln | None = None - fkeys: Dfns | None = None - - @staticmethod + advanced: bool + multi: bool + parent: str | None + ref: Ref | None + sln: Sln | None + fkeys: Dfns | None + + @staticmethod # type: ignore[misc] def _load_v1_flat(f, common: dict | None = None) -> tuple[Mapping, list[str]]: field = {} flat = [] @@ -269,7 +286,7 @@ def _load_v1_flat(f, common: dict | None = None) -> tuple[Mapping, list[str]]: # the point of the OMD is to losslessly handle duplicate variable names return OMD(flat), meta - @classmethod + @classmethod # type: ignore[misc] def _load_v1(cls, f, name, **kwargs) -> "Dfn": """ Temporary load routine for the v1 DFN format. @@ -279,6 +296,41 @@ def _load_v1(cls, f, name, **kwargs) -> "Dfn": refs = kwargs.pop("refs", {}) flat, meta = Dfn._load_v1_flat(f, **kwargs) + def _convert_period_block(block: Block) -> Block: + """ + Convert a period block recarray to individual arrays, one per column. + + Extracts recarray fields and creates separate array variables. Gives + each an appropriate grid- or tdis-aligned shape as opposed to sparse + list shape in terms of maxbound as previously. 
+ """ + + fields = list(block.values()) + if fields[0]["type"] == "recarray": + assert len(fields) == 1 + recarray_name = fields[0]["name"] + item = next(iter(fields[0]["children"].values())) + columns = item["children"] + else: + recarray_name = None + columns = block + block.pop(recarray_name, None) + cellid = columns.pop("cellid", None) + for col_name, column in columns.items(): + col_copy = column.copy() + old_dims = col_copy.get("shape") + if old_dims: + old_dims = old_dims[1:-1].split(",") + new_dims = ["nper"] + if cellid: + new_dims.append("nnodes") + if old_dims: + new_dims.extend([dim for dim in old_dims if dim != "maxbound"]) + col_copy["shape"] = f"({', '.join(new_dims)})" + block[col_name] = col_copy + + return block + def _convert_field(var: dict[str, Any]) -> Field: """ Convert an input field specification from its representation @@ -300,7 +352,7 @@ def _load(field) -> Field: # stay a string except default values, which we'll # try to parse as arbitrary literals below, and at # some point types, once we introduce type hinting - field = {k: _try_parse_bool(v) for k, v in field.items()} + field = {k: try_parse_bool(v) for k, v in field.items()} _name = field.pop("name") _type = field.pop("type", None) @@ -308,7 +360,7 @@ def _load(field) -> Field: shape = None if shape == "" else shape block = field.pop("block", None) default = field.pop("default", None) - default = _try_literal_eval(default) if _type != "string" else default + default = try_literal_eval(default) if _type != "string" else default description = field.pop("description", "") reader = field.pop("reader", "urword") ref = refs.get(_name, None) @@ -343,7 +395,7 @@ def _item() -> Field: name=_name, type="record", block=block, - fields=_fields(), + children=_fields(), description=description.replace( "is the list of", "is the record of" ), @@ -368,7 +420,7 @@ def _item() -> Field: name=first["name"] if single else _name, type=item_type, block=block, - fields=first["fields"] if single else 
fields, + children=first["children"] if single else fields, description=description.replace( "is the list of", f"is the {item_type} of" ), @@ -411,15 +463,16 @@ def _fields() -> Fields: ) if _type.startswith("recarray"): - var_["item"] = _item() + item = _item() + var_["children"] = {item["name"]: item} var_["type"] = "recarray" elif _type.startswith("keystring"): - var_["choices"] = _choices() + var_["children"] = _choices() var_["type"] = "keystring" elif _type.startswith("record"): - var_["fields"] = _fields() + var_["children"] = _fields() var_["type"] = "record" # for now, we can tell a var is an array if its type @@ -453,7 +506,7 @@ def _fields() -> Fields: return var_ - return dict(sorted(_load(var).items(), key=_field_attr_sort_key)) + return dict(sorted(_load(var).items(), key=field_attr_sort_key)) # load top-level fields. any nested # fields will be loaded recursively @@ -469,11 +522,10 @@ def _fields() -> Fields: for block_name, block in groupby(fields.values(), lambda v: v["block"]) } - # mark transient blocks - transient_index_vars = flat.getlist("iper") - for transient_index in transient_index_vars: - transient_block = transient_index["block"] - blocks[transient_block]["transient_block"] = True + # if there's a period block, extract distinct arrays from + # the recarray-style definition + if (period_block := blocks.get("period", None)) is not None: + blocks["period"] = _convert_period_block(period_block) # remove unneeded variable attributes def remove_attrs(path, key, value): @@ -559,14 +611,14 @@ def _rest(): **blocks, ) - @classmethod + @classmethod # type: ignore[misc] def _load_v2(cls, f, name) -> "Dfn": data = tomli.load(f) if name and name != data.get("name", None): raise ValueError(f"Name mismatch, expected {name}") return cls(**data) - @classmethod + @classmethod # type: ignore[misc] def load( cls, f, @@ -585,7 +637,7 @@ def load( else: raise ValueError(f"Unsupported version, expected one of {version.__args__}") - @staticmethod + 
@staticmethod # type: ignore[misc] def _load_all_v1(dfndir: PathLike) -> Dfns: paths: list[Path] = [ p for p in dfndir.glob("*.dfn") if p.stem not in ["common", "flopy"] @@ -617,7 +669,7 @@ def _load_all_v1(dfndir: PathLike) -> Dfns: return dfns - @staticmethod + @staticmethod # type: ignore[misc] def _load_all_v2(dfndir: PathLike) -> Dfns: paths: list[Path] = [ p for p in dfndir.glob("*.toml") if p.stem not in ["common", "flopy"] @@ -630,7 +682,7 @@ def _load_all_v2(dfndir: PathLike) -> Dfns: return dfns - @staticmethod + @staticmethod # type: ignore[misc] def load_all(dfndir: PathLike, version: FormatVersion = 1) -> Dfns: """Load all component definitions from the given directory.""" if version == 1: @@ -640,7 +692,7 @@ def load_all(dfndir: PathLike, version: FormatVersion = 1) -> Dfns: else: raise ValueError(f"Unsupported version, expected one of {version.__args__}") - @staticmethod + @staticmethod # type: ignore[misc] def load_tree(dfndir: PathLike, version: FormatVersion = 2) -> dict: """Load all definitions and return as hierarchical tree.""" dfns = Dfn.load_all(dfndir, version) @@ -664,8 +716,8 @@ def infer_tree(dfns: dict[str, Dfn]) -> dict: if root_name != "sim": raise ValueError(f"Root component must be named 'sim', found '{root_name}'") - def add_children(node_name: str) -> dict: - node = dfns[node_name].copy() + def add_children(node_name: str) -> dict[str, Any]: + node = dict(dfns[node_name]) children = [ name for name, dfn in dfns.items() if dfn.get("parent") == node_name ] @@ -684,7 +736,7 @@ def get_dfns( if verbose: print(f"Downloading MODFLOW 6 repository from {url}") with tempfile.TemporaryDirectory() as tmp: - dl_path = download_and_unzip(url, tmp, verbose=verbose) + dl_path = download_and_unzip(url, Path(tmp), verbose=verbose) contents = list(dl_path.glob("modflow6-*")) proj_path = next(iter(contents), None) if not proj_path: diff --git a/modflow_devtools/misc.py b/modflow_devtools/misc.py index e9d40929..65bf614b 100644 --- 
def try_literal_eval(value: Any) -> Any:
    """
    Try to parse a string as a Python literal. If this fails,
    return the value unaltered.

    Parameters
    ----------
    value
        The value to parse. Anything that is not a string is
        returned as is, without attempting to parse it.

    Returns
    -------
    Any
        The parsed literal, or the original value if it is not a
        string or cannot be parsed as a literal.
    """
    # local import: keeps this helper self-contained regardless of
    # what the enclosing module imports at the top of the file
    from ast import literal_eval

    if not isinstance(value, str):
        return value
    try:
        return literal_eval(value)
    # literal_eval can raise any of these on malformed input,
    # e.g. TypeError for an unhashable dict key such as "{[1]: 2}"
    except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError):
        return value