diff --git a/autotest/test_dfn.py b/autotest/test_dfn.py index e5158c8e..0611ee08 100644 --- a/autotest/test_dfn.py +++ b/autotest/test_dfn.py @@ -2,23 +2,25 @@ import pytest -from modflow_devtools.dfn import Dfn, get_dfns +from modflow_devtools.dfn import _load_common, load, load_flat +from modflow_devtools.dfn.fetch import fetch_dfns from modflow_devtools.dfn2toml import convert from modflow_devtools.markers import requires_pkg PROJ_ROOT = Path(__file__).parents[1] DFN_DIR = PROJ_ROOT / "autotest" / "temp" / "dfn" TOML_DIR = DFN_DIR / "toml" -VERSIONS = {1: DFN_DIR, 2: TOML_DIR} +SPEC_DIRS = {1: DFN_DIR, 2: TOML_DIR} MF6_OWNER = "MODFLOW-ORG" MF6_REPO = "modflow6" MF6_REF = "develop" +EMPTY_DFNS = {"exg-gwfgwe", "exg-gwfgwt", "exg-gwfprt", "sln-ems"} def pytest_generate_tests(metafunc): if "dfn_name" in metafunc.fixturenames: if not any(DFN_DIR.glob("*.dfn")): - get_dfns(MF6_OWNER, MF6_REPO, MF6_REF, DFN_DIR, verbose=True) + fetch_dfns(MF6_OWNER, MF6_REPO, MF6_REF, DFN_DIR, verbose=True) dfn_names = [ dfn.stem for dfn in DFN_DIR.glob("*.dfn") @@ -28,12 +30,10 @@ def pytest_generate_tests(metafunc): if "toml_name" in metafunc.fixturenames: convert(DFN_DIR, TOML_DIR) - dfn_paths = list(DFN_DIR.glob("*.dfn")) - assert all( - (TOML_DIR / f"{dfn.stem.replace('-nam', '')}.toml").is_file() - for dfn in dfn_paths - if "common" not in dfn.stem - ) + expected_toml_paths = [ + dfn for dfn in DFN_DIR.glob("*.dfn") if "common" not in dfn.stem + ] + assert all(toml_path.exists() for toml_path in expected_toml_paths) toml_names = [toml.stem for toml in TOML_DIR.glob("*.toml")] metafunc.parametrize("toml_name", toml_names, ids=toml_names) @@ -44,82 +44,73 @@ def test_load_v1(dfn_name): (DFN_DIR / "common.dfn").open() as common_file, (DFN_DIR / f"{dfn_name}.dfn").open() as dfn_file, ): - common, _ = Dfn._load_v1_flat(common_file) - dfn = Dfn.load(dfn_file, name=dfn_name, common=common) - assert any(dfn) + common = _load_common(common_file) + dfn = load(dfn_file, name=dfn_name, 
format="dfn", common=common) + assert any(dfn.fields) == (dfn.name not in EMPTY_DFNS) @requires_pkg("boltons") def test_load_v2(toml_name): with (TOML_DIR / f"{toml_name}.toml").open(mode="rb") as toml_file: - toml = Dfn.load(toml_file, name=toml_name, version=2) - assert any(toml) + dfn = load(toml_file, name=toml_name, format="toml") + assert any(dfn.fields) == (dfn.name not in EMPTY_DFNS) @requires_pkg("boltons") -@pytest.mark.parametrize("version", list(VERSIONS.keys())) -def test_load_all(version): - dfns = Dfn.load_all(VERSIONS[version], version=version) - assert any(dfns) - +@pytest.mark.parametrize("schema_version", list(SPEC_DIRS.keys())) +def test_load_all(schema_version): + dfns = load_flat(path=SPEC_DIRS[schema_version]) + for dfn in dfns.values(): + assert any(dfn.fields) == (dfn.name not in EMPTY_DFNS) -@requires_pkg("boltons") -def test_load_tree(): - import tempfile +@requires_pkg("boltons", "tomli") +def test_convert(function_tmpdir): import tomli - with tempfile.TemporaryDirectory() as tmp_dir: - tmp_path = Path(tmp_dir) - convert(DFN_DIR, tmp_path) - - # Test file conversion and naming - assert (tmp_path / "sim.toml").exists() - assert (tmp_path / "gwf.toml").exists() - assert not (tmp_path / "sim-nam.toml").exists() - - # Test parent relationships in files - with (tmp_path / "sim.toml").open("rb") as f: - sim_data = tomli.load(f) - assert sim_data["name"] == "sim" - assert "parent" not in sim_data - - with (tmp_path / "gwf.toml").open("rb") as f: - gwf_data = tomli.load(f) - assert gwf_data["name"] == "gwf" - assert gwf_data["parent"] == "sim" - - # Test hierarchy enforcement and completeness - dfns = Dfn.load_all(tmp_path, version=2) - roots = [name for name, dfn in dfns.items() if not dfn.get("parent")] - assert len(roots) == 1 - assert roots[0] == "sim" - - for dfn in dfns.values(): - parent = dfn.get("parent") - if parent: - assert parent in dfns - - # Test tree building and navigation - tree = Dfn.load_tree(tmp_path, version=2) - assert 
"sim" in tree - assert tree["sim"]["name"] == "sim" - - for model_type in ["gwf", "gwt", "gwe"]: - if model_type in tree["sim"]: - assert tree["sim"][model_type]["name"] == model_type - assert tree["sim"][model_type]["parent"] == "sim" - - if "gwf" in tree["sim"]: - gwf_packages = [ - k - for k in tree["sim"]["gwf"].keys() - if k.startswith("gwf-") and isinstance(tree["sim"]["gwf"][k], dict) - ] - assert len(gwf_packages) > 0 - - if "gwf-dis" in tree["sim"]["gwf"]: - dis = tree["sim"]["gwf"]["gwf-dis"] - assert dis["name"] == "gwf-dis" - assert dis["parent"] == "gwf" - assert "options" in dis or "dimensions" in dis + convert(DFN_DIR, function_tmpdir) + + assert (function_tmpdir / "sim-nam.toml").exists() + assert (function_tmpdir / "gwf-nam.toml").exists() + + with (function_tmpdir / "sim-nam.toml").open("rb") as f: + sim_data = tomli.load(f) + assert sim_data["name"] == "sim-nam" + assert sim_data["schema_version"] == "2" + assert "parent" not in sim_data + + with (function_tmpdir / "gwf-nam.toml").open("rb") as f: + gwf_data = tomli.load(f) + assert gwf_data["name"] == "gwf-nam" + assert gwf_data["parent"] == "sim-nam" + assert gwf_data["schema_version"] == "2" + + dfns = load_flat(function_tmpdir) + roots = [] + for dfn in dfns.values(): + if dfn.parent: + assert dfn.parent in dfns + else: + roots.append(dfn.name) + assert len(roots) == 1 + root = dfns[roots[0]] + assert root.name == "sim-nam" + + models = root.children or {} + for mdl in models: + assert models[mdl].name == mdl + assert models[mdl].parent == "sim-nam" + + if gwf := models.get("gwf-nam", None): + pkgs = gwf.children or {} + pkgs = { + k: v + for k, v in pkgs.items() + if k.startswith("gwf-") and isinstance(v, dict) + } + assert len(pkgs) > 0 + if dis := pkgs.get("gwf-dis", None): + assert dis.name == "gwf-dis" + assert dis.parent == "gwf" + assert "options" in (dis.blocks or {}) + assert "dimensions" in (dis.blocks or {}) diff --git a/modflow_devtools/dfn.py b/modflow_devtools/dfn.py deleted 
file mode 100644 index 9a543ec0..00000000 --- a/modflow_devtools/dfn.py +++ /dev/null @@ -1,794 +0,0 @@ -""" -MODFLOW 6 definition file tools. Includes types for field -and component specification, a parser for the original -DFN format as well as for TOML definition files, and -a function to fetch DFNs from the MF6 repository. -""" - -import shutil -import tempfile -from ast import literal_eval -from collections.abc import Mapping -from itertools import groupby -from os import PathLike -from pathlib import Path -from typing import ( - Any, - Literal, - Optional, - TypedDict, -) -from warnings import warn - -import tomli -from boltons.dictutils import OMD -from boltons.iterutils import remap - -from modflow_devtools.download import download_and_unzip -from modflow_devtools.misc import try_literal_eval - - -def try_parse_bool(value: Any) -> Any: - """ - Try to parse a boolean from a string as represented - in a DFN file, otherwise return the value unaltered. - """ - if isinstance(value, str): - value = value.lower() - if value in ["true", "false"]: - return value == "true" - return value - - -def field_attr_sort_key(item) -> int: - """ - Sort key for input field attributes. The order is: - -1. block - 0. name - 1. type - 2. shape - 3. default - 4. reader - 5. optional - 6. longname - 7. 
description - """ - - k, _ = item - if k == "block": - return -1 - if k == "name": - return 0 - if k == "type": - return 1 - if k == "shape": - return 2 - if k == "default": - return 3 - if k == "reader": - return 4 - if k == "optional": - return 5 - if k == "longname": - return 6 - if k == "description": - return 7 - return 8 - - -def block_sort_key(item) -> int: - k, _ = item - if k == "options": - return 0 - elif k == "dimensions": - return 1 - elif k == "griddata": - return 2 - elif k == "packagedata": - return 3 - elif "period" in k: - return 4 - else: - return 5 - - -FormatVersion = Literal[1, 2] -"""DFN format version number.""" - - -FieldType = Literal[ - "keyword", - "integer", - "double precision", - "string", - "record", - "recarray", - "keystring", -] - - -Reader = Literal[ - "urword", - "u1ddbl", - "u2ddbl", - "readarray", -] - - -SCALAR_TYPES = ("keyword", "integer", "double precision", "string") -_SCALAR_TYPES = SCALAR_TYPES # allow backwards compat; imported by flopy - - -Dfns = dict[str, "Dfn"] -Fields = dict[str, "Field"] -Block = Fields -Blocks = dict[str, Block] - - -def get_blocks(dfn: "Dfn") -> Blocks: - """ - Extract blocks from an input definition. - """ - - def _is_block(item: tuple[str, Any]) -> bool: - k, _v = item - return k not in Dfn.__annotations__ - - return dict( - sorted( - {k: v for k, v in dfn.items() if _is_block((k, v))}.items(), # type: ignore - key=block_sort_key, - ) - ) - - -def get_fields(dfn: "Dfn") -> Fields: - """ - Extract a flat dictionary of fields from an input definition. - Only top-level fields are included, i.e. subfields of records - or recarrays are not included. 
- """ - fields = {} - for block in get_blocks(dfn).values(): - for field in block.values(): - if field["name"] in fields: - warn(f"Duplicate field name {field['name']} in {dfn['name']}") - fields[field["name"]] = field - return fields - - -class Field(TypedDict): - """A field specification.""" - - name: str - type: FieldType - shape: Any | None - block: str | None - default: Any | None - children: Optional["Fields"] - description: str | None - reader: Reader - - -class Ref(TypedDict): - """ - A foreign-key-like reference between a file input variable - in a referring input component and another input component - referenced by it. Previously known as a "subpackage". - - A `Dfn` with a nonempty `ref` can be referred to by other - component definitions, via a filepath variable which acts - as a foreign key. If such a variable is detected when any - component is loaded, the component's `__init__` method is - modified, such that the variable named `val`, residing in - the referenced component, replaces the variable with name - `key` in the referencing component, i.e., the foreign key - filepath variable, This forces a referencing component to - accept a subcomponent's data directly, as if it were just - a variable, rather than indirectly, with the subcomponent - loaded up from a file identified by the filepath variable. - """ - - key: str - val: str - abbr: str - param: str - parent: str - description: str | None - - -class Sln(TypedDict): - """ - A solution package specification. - """ - - abbr: str - pattern: str - - -class Dfn(TypedDict): - """ - MODFLOW 6 input definition. An input definition - specifies a component in an MF6 simulation, e.g. - a model or package. A component contains input - variables, and may contain other metadata such - as foreign key references to other components - (i.e. subpackages), package-specific metadata - (e.g. for solutions), advanced package status, - and whether the component is a multi-package. 
- - An input definition must have a name. Other top- - level keys are blocks, which must be mappings of - `str` to `Field`, and metadata, of which only a - limited set of keys are allowed. Block names and - metadata keys may not overlap. - """ - - name: str - advanced: bool - multi: bool - parent: str | None - ref: Ref | None - sln: Sln | None - - @staticmethod # type: ignore[misc] - def _load_v1_flat(f, common: dict | None = None) -> tuple[Mapping, list[str]]: - field = {} - flat = [] - meta = [] - common = common or {} - - for line in f: - # remove whitespace/etc from the line - line = line.strip() - - # record context name and flopy metadata - # attributes, skip all other comment lines - if line.startswith("#"): - _, sep, tail = line.partition("flopy") - if sep == "flopy": - if ( - "multi-package" in tail - or "solution_package" in tail - or "subpackage" in tail - or "parent" in tail - ): - meta.append(tail.strip()) - _, sep, tail = line.partition("package-type") - if sep == "package-type": - meta.append(f"package-type {tail.strip()}") - continue - - # if we hit a newline and the parameter dict - # is nonempty, we've reached the end of its - # block of attributes - if not any(line): - if any(field): - flat.append((field["name"], field)) - field = {} - continue - - # split the attribute's key and value and - # store it in the parameter dictionary - key, _, value = line.partition(" ") - if key == "default_value": - key = "default" - field[key] = value - - # make substitutions from common variable definitions, - # remove backslashes, TODO: generate/insert citations. 
- descr = field.get("description", None) - if descr: - descr = descr.replace("\\", "").replace("``", "'").replace("''", "'") - _, replace, tail = descr.strip().partition("REPLACE") - if replace: - key, _, subs = tail.strip().partition(" ") - subs = literal_eval(subs) - cmmn = common.get(key, None) - if cmmn is None: - warn( - "Can't substitute description text, " - f"common variable not found: {key}" - ) - else: - descr = cmmn.get("description", "") - if any(subs): - descr = descr.replace("\\", "").replace( - "{#1}", subs["{#1}"] - ) - field["description"] = descr - - # add the final parameter - if any(field): - flat.append((field["name"], field)) - - # the point of the OMD is to losslessly handle duplicate variable names - return OMD(flat), meta - - @classmethod # type: ignore[misc] - def _load_v1(cls, f, name, **kwargs) -> "Dfn": - """ - Temporary load routine for the v1 DFN format. - """ - - flat, meta = Dfn._load_v1_flat(f, **kwargs) - - def _convert_recarray_block(block: Block, block_name: str) -> Block: - """ - Convert a recarray block to individual arrays, one per column. - - Extract recarray fields and create separate array variables. For period - blocks, give each an appropriate grid- or time-aligned shape (nper, nnodes). - For other blocks, uses the declared dimensions directly. 
- """ - - fields = list(block.values()) - if fields[0]["type"] == "recarray": - assert len(fields) == 1 - recarray_field = fields[0] - recarray_name = recarray_field["name"] - item = next(iter(recarray_field["children"].values())) - columns = item["children"] - - # Get the original recarray shape to determine base dimensions - recarray_shape = recarray_field.get("shape") - if recarray_shape: - # Parse shape like "(nexg)" or "(maxbound)" - base_dims = recarray_shape[1:-1].split(",") - base_dims = [dim.strip() for dim in base_dims if dim.strip()] - else: - base_dims = [] - else: - recarray_name = None - columns = block - base_dims = [] - - # Remove the original recarray field - block.pop(recarray_name, None) - - # Handle cellid specially - it indicates spatial indexing - cellid = columns.pop("cellid", None) - - for col_name, column in columns.items(): - col_copy = column.copy() - old_dims = col_copy.get("shape") - if old_dims: - old_dims = old_dims[1:-1].split(",") - old_dims = [dim.strip() for dim in old_dims if dim.strip()] - else: - old_dims = [] - - # Determine new dimensions based on block type - if block_name == "period": - # Period blocks get time + spatial dimensions - new_dims = ["nper"] - if cellid: - new_dims.append("nnodes") - # Add any additional dimensions, excluding maxbound - if old_dims: - new_dims.extend([dim for dim in old_dims if dim != "maxbound"]) - else: - # Non-period blocks use declared dimensions - new_dims = [] - if base_dims: - # Use the dimensions from the recarray shape - # Only drop maxbound if there are other meaningful dimensions - filtered_base_dims = [ - dim for dim in base_dims if dim != "maxbound" - ] - if filtered_base_dims: - new_dims.extend(filtered_base_dims) - else: - # Keep maxbound if no other dimensions are available - new_dims.extend(base_dims) - # Add any column-specific dimensions - if old_dims: - filtered_old_dims = [ - dim for dim in old_dims if dim != "maxbound" - ] - if filtered_old_dims: - 
new_dims.extend(filtered_old_dims) - else: - # Keep maxbound if no other dimensions are available - new_dims.extend(old_dims) - - if new_dims: - col_copy["shape"] = f"({', '.join(new_dims)})" - else: - # Scalar field - col_copy["shape"] = None - - block[col_name] = col_copy - - return block - - def _convert_field(var: dict[str, Any]) -> Field: - """ - Convert an input field specification from its representation - in a v1 format definition file to the v2 (structured) format. - - Notes - ----- - If the field does not have a `default` attribute, it will - default to `False` if it is a keyword, otherwise to `None`. - - A filepath field whose name functions as a foreign key - for a separate context will be given a reference to it. - """ - - def _load(field) -> Field: - field = field.copy() - - # parse booleans from strings. everything else can - # stay a string except default values, which we'll - # try to parse as arbitrary literals below, and at - # some point types, once we introduce type hinting - field = {k: try_parse_bool(v) for k, v in field.items()} - - _name = field.pop("name") - _type = field.pop("type", None) - shape = field.pop("shape", None) - shape = None if shape == "" else shape - block = field.pop("block", None) - default = field.pop("default", None) - default = try_literal_eval(default) if _type != "string" else default - description = field.pop("description", "") - reader = field.pop("reader", "urword") - - def _item() -> Field: - """Load list item.""" - - item_names = _type.split()[1:] - item_types = [ - v["type"] - for v in flat.values(multi=True) - if v["name"] in item_names and v.get("in_record", False) - ] - n_item_names = len(item_names) - if n_item_names < 1: - raise ValueError(f"Missing list definition: {_type}") - - # explicit record - if n_item_names == 1 and ( - item_types[0].startswith("record") - or item_types[0].startswith("keystring") - ): - return _convert_field(next(iter(flat.getlist(item_names[0])))) - - # implicit simple record (no 
children) - if all(t in SCALAR_TYPES for t in item_types): - return Field( - name=_name, - type="record", - block=block, - children=_fields(), - description=description.replace( - "is the list of", "is the record of" - ), - reader=reader, - **field, - ) - - # implicit complex record (has children) - fields = { - v["name"]: _convert_field(v) - for v in flat.values(multi=True) - if v["name"] in item_names and v.get("in_record", False) - } - first = next(iter(fields.values())) - single = len(fields) == 1 - item_type = ( - "keystring" - if single and "keystring" in first["type"] - else "record" - ) - return Field( - name=first["name"] if single else _name, - type=item_type, - block=block, - children=first["children"] if single else fields, - description=description.replace( - "is the list of", f"is the {item_type} of" - ), - reader=reader, - **field, - ) - - def _choices() -> Fields: - """Load keystring (union) choices.""" - names = _type.split()[1:] - return { - v["name"]: _convert_field(v) - for v in flat.values(multi=True) - if v["name"] in names and v.get("in_record", False) - } - - def _fields() -> Fields: - """Load record fields.""" - names = _type.split()[1:] - fields = {} - for name in names: - v = flat.get(name, None) - if ( - not v - or not v.get("in_record", False) - or v["type"].startswith("record") - ): - continue - fields[name] = _convert_field(v) - return fields - - var_ = Field( - name=_name, - shape=shape, - block=block, - description=description, - default=default, - reader=reader, - **field, - ) - - if _type.startswith("recarray"): - item = _item() - var_["children"] = {item["name"]: item} - var_["type"] = "recarray" - - elif _type.startswith("keystring"): - var_["children"] = _choices() - var_["type"] = "keystring" - - elif _type.startswith("record"): - var_["children"] = _fields() - var_["type"] = "record" - - # for now, we can tell a var is an array if its type - # is scalar and it has a shape. 
once we have proper - # typing, this can be read off the type itself. - elif shape is not None and _type not in SCALAR_TYPES: - raise TypeError(f"Unsupported array type: {_type}") - - else: - var_["type"] = _type - - return var_ - - return dict(sorted(_load(var).items(), key=field_attr_sort_key)) - - # load top-level fields. any nested - # fields will be loaded recursively - fields = { - field["name"]: _convert_field(field) - for field in flat.values(multi=True) - if not field.get("in_record", False) - } - - # group variables by block - blocks = { - block_name: {v["name"]: v for v in block} - for block_name, block in groupby(fields.values(), lambda v: v["block"]) - } - - # extract distinct arrays from recarray-style definitions in all blocks - for block_name, block in blocks.items(): - # Check if this block contains any recarray fields - has_recarray = any(field["type"] == "recarray" for field in block.values()) - if has_recarray: - blocks[block_name] = _convert_recarray_block(block, block_name) - - # remove unneeded variable attributes - def remove_attrs(path, key, value): - if key in ["in_record", "tagged", "preserve_case"]: - return False - return True - - blocks = remap(blocks, visit=remove_attrs) - - def _advanced() -> bool | None: - return any("package-type advanced" in m for m in meta) - - def _multi() -> bool: - return any("multi-package" in m for m in meta) - - def _sln() -> Sln | None: - sln = next( - iter( - m - for m in meta - if isinstance(m, str) and m.startswith("solution_package") - ), - None, - ) - if sln: - abbr, pattern = sln.split()[1:] - return Sln(abbr=abbr, pattern=pattern) - return None - - def _sub() -> Ref | None: - def _parent(): - line = next( - iter( - m for m in meta if isinstance(m, str) and m.startswith("parent") - ), - None, - ) - if not line: - return None - split = line.split() - return split[1] - - def _rest(): - line = next( - iter( - m for m in meta if isinstance(m, str) and m.startswith("subpac") - ), - None, - ) - if not 
line: - return None - _, key, abbr, param, val = line.split() - matches = [v for v in fields.values() if v["name"] == val] - if not any(matches): - descr = None - else: - if len(matches) > 1: - warn(f"Multiple matches for referenced variable {val}") - match = matches[0] - descr = match["description"] - - return { - "key": key, - "val": val, - "abbr": abbr, - "param": param, - "description": descr, - } - - parent = _parent() - rest = _rest() - if parent and rest: - return Ref(parent=parent, **rest) - return None - - sln = _sln() - multi = ( - _multi() - or sln is not None - or ("nam" in name and "sim" not in name) - or name.startswith("exg-") - ) - - return cls( - name=name, - advanced=_advanced(), - multi=multi, - sln=sln, - ref=_sub(), - **blocks, - ) - - @classmethod # type: ignore[misc] - def _load_v2(cls, f, name) -> "Dfn": - data = tomli.load(f) - if name and name != data.get("name", None): - raise ValueError(f"Name mismatch, expected {name}") - return cls(**data) - - @classmethod # type: ignore[misc] - def load( - cls, - f, - name: str | None = None, - version: FormatVersion = 1, - **kwargs, - ) -> "Dfn": - """ - Load a component definition from a definition file. 
- """ - - if version == 1: - return cls._load_v1(f, name, **kwargs) - elif version == 2: - return cls._load_v2(f, name) - else: - raise ValueError(f"Unsupported version, expected one of {version.__args__}") - - @staticmethod # type: ignore[misc] - def _load_all_v1(dfndir: PathLike) -> Dfns: - paths: list[Path] = [ - p for p in dfndir.glob("*.dfn") if p.stem not in ["common", "flopy"] - ] - - # load common variables - common_path: Path | None = dfndir / "common.dfn" - if not common_path.is_file: - common = None - else: - with common_path.open() as f: - common, _ = Dfn._load_v1_flat(f) - - # load definitions - dfns: Dfns = {} - for path in paths: - with path.open() as f: - dfn = Dfn.load(f, name=path.stem, common=common) - dfns[path.stem] = dfn - - return dfns - - @staticmethod # type: ignore[misc] - def _load_all_v2(dfndir: PathLike) -> Dfns: - paths: list[Path] = [ - p for p in dfndir.glob("*.toml") if p.stem not in ["common", "flopy"] - ] - dfns: Dfns = {} - for path in paths: - with path.open(mode="rb") as f: - dfn = Dfn.load(f, name=path.stem, version=2) - dfns[path.stem] = dfn - - return dfns - - @staticmethod # type: ignore[misc] - def load_all(dfndir: PathLike, version: FormatVersion = 1) -> Dfns: - """Load all component definitions from the given directory.""" - if version == 1: - return Dfn._load_all_v1(dfndir) - elif version == 2: - return Dfn._load_all_v2(dfndir) - else: - raise ValueError(f"Unsupported version, expected one of {version.__args__}") - - @staticmethod # type: ignore[misc] - def load_tree(dfndir: PathLike, version: FormatVersion = 2) -> dict: - """Load all definitions and return as hierarchical tree.""" - dfns = Dfn.load_all(dfndir, version) - return infer_tree(dfns) - - -def infer_tree(dfns: dict[str, Dfn]) -> dict: - """Infer the component hierarchy from definitions. - - Enforces single root requirement - must be exactly one component - with no parent, and it must be named 'sim'. 
- """ - roots = [name for name, dfn in dfns.items() if not dfn.get("parent")] - - if len(roots) != 1: - raise ValueError( - f"Expected exactly one root component, found {len(roots)}: {roots}" - ) - - root_name = roots[0] - if root_name != "sim": - raise ValueError(f"Root component must be named 'sim', found '{root_name}'") - - def add_children(node_name: str) -> dict[str, Any]: - node = dict(dfns[node_name]) - children = [ - name for name, dfn in dfns.items() if dfn.get("parent") == node_name - ] - for child in children: - node[child] = add_children(child) - return node - - return {root_name: add_children(root_name)} - - -def get_dfns( - owner: str, repo: str, ref: str, outdir: str | PathLike, verbose: bool = False -): - """Fetch definition files from the MODFLOW 6 repository.""" - url = f"https://github.com/{owner}/{repo}/archive/{ref}.zip" - if verbose: - print(f"Downloading MODFLOW 6 repository from {url}") - with tempfile.TemporaryDirectory() as tmp: - dl_path = download_and_unzip(url, Path(tmp), verbose=verbose) - contents = list(dl_path.glob("modflow6-*")) - proj_path = next(iter(contents), None) - if not proj_path: - raise ValueError(f"Missing proj dir in {dl_path}, found {contents}") - if verbose: - print("Copying dfns from download dir to output dir") - shutil.copytree( - proj_path / "doc" / "mf6io" / "mf6ivar" / "dfn", outdir, dirs_exist_ok=True - ) diff --git a/modflow_devtools/dfn/__init__.py b/modflow_devtools/dfn/__init__.py new file mode 100644 index 00000000..c532d099 --- /dev/null +++ b/modflow_devtools/dfn/__init__.py @@ -0,0 +1,523 @@ +""" +MODFLOW 6 definition file tools. 
+""" + +from abc import ABC, abstractmethod +from dataclasses import asdict, dataclass, replace +from itertools import groupby +from os import PathLike +from pathlib import Path +from typing import ( + Literal, + cast, +) + +import tomli +from boltons.dictutils import OMD +from boltons.iterutils import remap +from packaging.version import Version + +from modflow_devtools.dfn.parse import ( + is_advanced_package, + is_multi_package, + parse_dfn, + try_parse_bool, + try_parse_parent, +) +from modflow_devtools.dfn.schema.block import Block, Blocks +from modflow_devtools.dfn.schema.field import SCALAR_TYPES, Field, Fields +from modflow_devtools.dfn.schema.ref import Ref +from modflow_devtools.dfn.schema.v1 import FieldV1 +from modflow_devtools.dfn.schema.v2 import FieldV2 +from modflow_devtools.misc import drop_none_or_empty, try_literal_eval + +__all__ = [ + "SCALAR_TYPES", + "Block", + "Blocks", + "Dfn", + "Dfns", + "Field", + "FieldV1", + "FieldV2", + "Fields", + "Ref", + "load", + "load_flat", + "load_tree", + "map", + "to_flat", + "to_tree", +] + + +Format = Literal["dfn", "toml"] +"""DFN serialization format.""" + + +Dfns = dict[str, "Dfn"] + + +@dataclass +class Dfn: + """ + MODFLOW 6 input component definition. + """ + + schema_version: Version + name: str + parent: str | None = None + advanced: bool = False + multi: bool = False + ref: Ref | None = None + blocks: Blocks | None = None + children: Dfns | None = None + + @property + def fields(self) -> Fields: + """ + A combined map of fields from all blocks. + + Only top-level fields are included, no subfields of composites + such as records or recarrays. + """ + fields = [] + for block in (self.blocks or {}).values(): + for field in block.values(): + fields.append((field.name, field)) + + # for now return a multidict to support duplicate field names. + # TODO: change to normal dict after deprecating v1 schema + return OMD(fields) + + +class SchemaMap(ABC): + @abstractmethod + def map(self, dfn: Dfn) -> Dfn: ... 
+ + +class MapV1To2(SchemaMap): + @staticmethod + def map_period_block(dfn: Dfn, block: Block) -> Block: + """ + Convert a period block recarray to individual arrays, one per column. + + Extracts recarray fields and creates separate array variables. Gives + each an appropriate grid- or tdis-aligned shape as opposed to sparse + list shape in terms of maxbound as previously. + """ + + block = dict(block) + fields = list(block.values()) + if fields[0].type == "recarray": + assert len(fields) == 1 + recarray_name = fields[0].name + block.pop(recarray_name, None) + item = next(iter((fields[0].children or {}).values())) + columns = dict(item.children or {}) + else: + recarray_name = None + columns = block + + cellid = columns.pop("cellid", None) + for col_name, column in columns.items(): + old_dims = column.shape + if old_dims: + old_dims = old_dims[1:-1].split(",") # type: ignore + new_dims = ["nper"] + if cellid: + new_dims.append("nnodes") + if old_dims: + new_dims.extend([dim for dim in old_dims if dim != "maxbound"]) + block[col_name] = replace(column, shape=f"({', '.join(new_dims)})") + + return block + + @staticmethod + def map_field(dfn: Dfn, field: Field) -> Field: + """ + Convert an input field specification from its representation + in a v1 format definition file to the v2 (structured) format. + + Notes + ----- + If the field does not have a `default` attribute, it will + default to `False` if it is a keyword, otherwise to `None`. + + A filepath field whose name functions as a foreign key + for a separate context will be given a reference to it. + """ + + fields = cast(OMD, dfn.fields) + + def _map_field(_field) -> Field: + field_dict = asdict(_field) + # parse booleans from strings. 
everything else can + # stay a string except default values, which we'll + # try to parse as arbitrary literals below, and at + # some point types, once we introduce type hinting + field_dict = {k: try_parse_bool(v) for k, v in field_dict.items()} + _name = field_dict.pop("name") + _type = field_dict.pop("type", None) + shape = field_dict.pop("shape", None) + shape = None if shape == "" else shape + block = field_dict.pop("block", None) + default = field_dict.pop("default", None) + default = try_literal_eval(default) if _type != "string" else default + description = field_dict.pop("description", "") + + def _row_field() -> Field: + """Parse a table's record (row) field""" + item_names = _type.split()[1:] + item_types = [ + f.type + for f in fields.values(multi=True) + if f.name in item_names and f.in_record + ] + n_item_names = len(item_names) + if n_item_names < 1: + raise ValueError(f"Missing list definition: {_type}") + + # explicit record or keystring + if n_item_names == 1 and ( + item_types[0].startswith("record") + or item_types[0].startswith("keystring") + ): + return MapV1To2.map_field( + dfn, next(iter(fields.getlist(item_names[0]))) + ) + + # implicit record with all scalar fields + if all(t in SCALAR_TYPES for t in item_types): + children = _record_fields() + return FieldV2.from_dict( + { + **field_dict, + "name": _name, + "type": "record", + "block": block, + "children": children, + "description": description.replace( + "is the list of", "is the record of" + ), + } + ) + + # implicit record with composite fields + children = { + f.name: MapV1To2.map_field(dfn, f) + for f in fields.values(multi=True) + if f.name in item_names and f.in_record + } + first = next(iter(children.values())) + if not first.type: + raise ValueError(f"Missing type for field: {first.name}") + single = len(children) == 1 + item_type = ( + "keystring" if single and "keystring" in first.type else "record" + ) + return FieldV2.from_dict( + { + "name": first.name if single else _name, 
+ "type": item_type, + "block": block, + "children": first.children if single else children, + "description": description.replace( + "is the list of", f"is the {item_type} of" + ), + **field_dict, + } + ) + + def _union_fields() -> Fields: + """Parse a union's fields""" + names = _type.split()[1:] + return { + f.name: MapV1To2.map_field(dfn, f) + for f in fields.values(multi=True) + if f.name in names and f.in_record + } + + def _record_fields() -> Fields: + """Parse a record's fields""" + names = _type.split()[1:] + return { + f.name: _map_field(f) + for f in fields.values(multi=True) + if f.name in names + and f.in_record + and not f.type.startswith("record") + } + + _field = FieldV2.from_dict( + { + "name": _name, + "shape": shape, + "block": block, + "description": description, + "default": default, + **field_dict, + } + ) + + if _type.startswith("recarray"): + child = _row_field() + _field.children = {child.name: child} + _field.type = "recarray" + + elif _type.startswith("keystring"): + _field.children = _union_fields() + _field.type = "keystring" + + elif _type.startswith("record"): + _field.children = _record_fields() + _field.type = "record" + + # for now, we can tell a var is an array if its type + # is scalar and it has a shape. once we have proper + # typing, this can be read off the type itself. 
+ elif shape is not None and _type not in SCALAR_TYPES: + raise TypeError(f"Unsupported array type: {_type}") + + else: + _field.type = _type + + return _field + + return _map_field(field) + + @staticmethod + def map_blocks(dfn: Dfn) -> Blocks: + fields = { + field.name: MapV1To2.map_field(dfn, field) + for field in cast(OMD, dfn.fields).values(multi=True) + if not field.in_record # type: ignore + } + block_dicts = { + block_name: {f.name: f for f in block} + for block_name, block in groupby(fields.values(), lambda f: f.block) + } + blocks = {} + + # Handle period blocks specially + if (period_block := block_dicts.get("period", None)) is not None: + blocks["period"] = MapV1To2.map_period_block(dfn, period_block) + + for block_name, block_data in block_dicts.items(): + if block_name != "period": + blocks[block_name] = block_data + + def remove_attrs(path, key, value): + # remove unneeded variable attributes + if key in ["in_record", "tagged", "preserve_case"]: + return False + return True + + return remap(blocks, visit=remove_attrs) + + def map(self, dfn: Dfn) -> Dfn: + if dfn.schema_version == (v2 := Version("2")): + return dfn + + return Dfn( + name=dfn.name, + advanced=dfn.advanced, + multi=dfn.multi, + ref=dfn.ref, + blocks=MapV1To2.map_blocks(dfn), + schema_version=v2, + parent=dfn.parent, + ) + + +def map( + dfn: Dfn, + schema_version: str | Version = "2", +) -> Dfn: + """Map a MODFLOW 6 specification to another schema version.""" + if dfn.schema_version == schema_version: + return dfn + elif Version(str(schema_version)) == Version("1"): + raise NotImplementedError("Mapping to schema version 1 is not implemented yet.") + elif Version(str(schema_version)) == Version("2"): + return MapV1To2().map(dfn) + raise ValueError(f"Unsupported schema version: {schema_version}. 
def load(f, format: str = "dfn", **kwargs) -> Dfn:
    """
    Load a MODFLOW 6 definition file.

    Parameters
    ----------
    f : readable file-like
        Stream to read the definition from. It is handed to `parse_dfn`
        when `format` is "dfn" and to `tomli.load` when it is "toml".
    format : str, optional
        Either "dfn" (the default) or "toml".
    **kwargs
        `name` is the component name: required for "dfn", optional for
        "toml". For "dfn", the remaining keyword arguments (e.g. `common`)
        are forwarded to `parse_dfn`.

    Returns
    -------
    Dfn
        The loaded definition. DFN files are tagged schema version 1;
        TOML files use their own `schema_version`, defaulting to 2.

    Raises
    ------
    KeyError
        If `format` is "dfn" and no `name` was given.
    ValueError
        If `format` is not supported, or if a TOML file declares a name
        different from the `name` passed by the caller.
    """
    if format == "dfn":
        name = kwargs.pop("name")
        fields, meta = parse_dfn(f, **kwargs)
        # NOTE(review): if `groupby` is itertools-style it only merges
        # *consecutive* fields, so this relies on the fields of a block
        # being contiguous in the file -- confirm.
        blocks = {
            block_name: {field["name"]: FieldV1.from_dict(field) for field in block}
            for block_name, block in groupby(
                fields.values(), lambda field: field["block"]
            )
        }
        return Dfn(
            name=name,
            schema_version=Version("1"),
            parent=try_parse_parent(meta),
            advanced=is_advanced_package(meta),
            multi=is_multi_package(meta),
            blocks=blocks,
        )

    if format == "toml":
        data = tomli.load(f)

        # Pop the caller-supplied name exactly once. Popping it inline as
        # the default of `data.pop("name", ...)` would consume it before the
        # consistency check below could ever see it.
        expected_name = kwargs.pop("name", None)

        dfn_fields = {
            "name": data.pop("name", expected_name),
            "schema_version": Version(str(data.pop("schema_version", "2"))),
            "parent": data.pop("parent", None),
            "advanced": data.pop("advanced", False),
            "multi": data.pop("multi", False),
            "ref": data.pop("ref", None),
        }

        if expected_name is not None and dfn_fields["name"] != expected_name:
            raise ValueError(
                f"DFN name mismatch: {expected_name} != {dfn_fields['name']}"
            )

        # Every remaining top-level table is a block; tables nested inside
        # it are fields, anything else is kept as a plain value.
        blocks = {}
        for section_name, section_data in data.items():
            if isinstance(section_data, dict):
                block_fields = {}
                for field_name, field_data in section_data.items():
                    if isinstance(field_data, dict):
                        block_fields[field_name] = FieldV2.from_dict(field_data)
                    else:
                        block_fields[field_name] = field_data
                blocks[section_name] = block_fields  # type: ignore

        dfn_fields["blocks"] = blocks if blocks else None

        return Dfn(**dfn_fields)

    raise ValueError(f"Unsupported format: {format}. Expected 'dfn' or 'toml'.")


def _load_common(f) -> Fields:
    """Parse the common-variable definitions from `f`, discarding metadata."""
    common, _ = parse_dfn(f)
    return common
def load_tree(path: str | PathLike) -> Dfn:
    """
    Load a structured MODFLOW 6 specification from a directory of
    definition files.

    The directory is first loaded as a flat spec and then linked into a
    hierarchy; the single root component (the simulation) is returned,
    with its models and packages attached as children.
    """
    flat = load_flat(path)
    return to_tree(flat)


def to_tree(dfns: Dfns) -> Dfn:
    """
    Build the MODFLOW 6 input component hierarchy from a flat spec.

    Each component's `parent` is inferred from its name, then every
    component is attached to its parent's `children`. The input is
    expected to be in the v2 schema already (use `map()` beforehand).

    Returns the root component; exactly one component may lack a parent.

    Raises `NotImplementedError` for a v1 spec and `ValueError` for an
    unknown schema version or a spec without exactly one root.
    """

    def set_parent(dfn):
        # Work on a dict copy so the parent can be assigned by key.
        spec = asdict(dfn)
        component = spec["name"]
        if component != "sim-nam":
            # Name files, exchanges, solutions and utilities hang off the
            # simulation; everything else hangs off its model's name file.
            if component.endswith("-nam") or component.startswith(
                ("exg-", "sln-", "utl-")
            ):
                spec["parent"] = "sim-nam"
            elif "-" in component:
                spec["parent"] = f"{component.split('-')[0]}-nam"
        return Dfn(**remap(spec, visit=drop_none_or_empty))

    linked = {name: set_parent(dfn) for name, dfn in dfns.items()}

    # The schema version of the first component stands for the whole spec.
    first = next(iter(linked.values()), None)
    schema_version = str(first.schema_version if first else Version("1"))

    if schema_version == "1":
        raise NotImplementedError("Tree inference from v1 schema not implemented")
    if schema_version != "2":
        raise ValueError(
            f"Unsupported schema version: {schema_version}. Expected 1 or 2."
        )

    root_names = [name for name, dfn in linked.items() if dfn.parent is None]
    nroots = len(root_names)
    if nroots != 1:
        raise ValueError(f"Expected one root component, found {nroots}")

    def _build_tree(node_name: str) -> Dfn:
        node = linked[node_name]
        child_names = [
            name for name, dfn in linked.items() if dfn.parent == node_name
        ]
        if any(child_names):
            node.children = {name: _build_tree(name) for name in child_names}
        return node

    return _build_tree(root_names[0])
def fetch_dfns(
    owner: str, repo: str, ref: str, outdir: str | PathLike, verbose: bool = False
):
    """
    Fetch definition files from the MODFLOW 6 repository.

    Downloads the archive of `owner/repo` at `ref` into a temporary
    directory and copies its `doc/mf6io/mf6ivar/dfn` folder into `outdir`
    (existing files in `outdir` are allowed).

    Parameters
    ----------
    owner : str
        GitHub account or organization owning the repository.
    repo : str
        Repository name.
    ref : str
        Branch, tag or commit to download.
    outdir : str or PathLike
        Directory to copy the definition files into.
    verbose : bool, optional
        Print progress messages, by default False.

    Raises
    ------
    ValueError
        If the unpacked archive contains no project directory.
    """
    url = f"https://github.com/{owner}/{repo}/archive/{ref}.zip"
    if verbose:
        print(f"Downloading MODFLOW 6 repository archive from {url}")
    with TemporaryDirectory() as tmp:
        dl_path = download_and_unzip(url, Path(tmp), verbose=verbose)
        # The archive unpacks into a "<repo>-<ref>" directory, so look for
        # the requested repository name first; the historical hard-coded
        # "modflow6-*" pattern is kept as a fallback.
        contents = list(dl_path.glob(f"{repo}-*")) or list(
            dl_path.glob("modflow6-*")
        )
        proj_path = next(iter(contents), None)
        if not proj_path:
            # Report what was actually unpacked; the glob result is always
            # empty on this path.
            found = sorted(p.name for p in dl_path.iterdir())
            raise ValueError(f"Missing proj dir in {dl_path}, found {found}")
        if verbose:
            print("Copying dfns from download dir to output dir")
        # Copy while the temporary directory still exists.
        copytree(
            proj_path / "doc" / "mf6io" / "mf6ivar" / "dfn", outdir, dirs_exist_ok=True
        )


get_dfns = fetch_dfns  # alias for backward compatibility
description + """ + + k, _ = item + if k == "block": + return -1 + if k == "name": + return 0 + if k == "type": + return 1 + if k == "shape": + return 2 + if k == "default": + return 3 + if k == "reader": + return 4 + if k == "optional": + return 5 + if k == "longname": + return 6 + if k == "description": + return 7 + return 8 + + +def try_parse_bool(value: Any) -> Any: + """ + Try to parse a boolean from a string as represented + in a DFN file, otherwise return the value unaltered. + 1. `"true"` -> `True` + 2. `"false"` -> `False` + 3. anything else -> `value` + """ + if isinstance(value, str): + value = value.lower() + if value in ["true", "false"]: + return value == "true" + return value + + +def try_parse_parent(meta: list[str]) -> str | None: + """ + Try to parse a component's parent component name from its metadata. + Return `None` if it has no parent specified. + """ + line = next( + iter(m for m in meta if isinstance(m, str) and m.startswith("parent")), + None, + ) + if not line: + return None + split = line.split() + return split[1] + + +def is_advanced_package(meta: list[str]) -> bool: + """Determine if the component is an advanced package from its metadata.""" + return any("package-type advanced" in m for m in meta) + + +def is_multi_package(meta: list[str]) -> bool: + """Determine if the component is a multi-package from its metadata.""" + return any("multi-package" in m for m in meta) + + +def parse_dfn(f, common: dict | None = None) -> tuple[OMD, list[str]]: + """ + Parse a DFN file into an ordered dict of fields and a list of metadata. + + Parameters + ---------- + f : readable file-like + A file-like object to read the DFN file from. + common : dict, optional + A dictionary of common variable definitions to use for + description substitutions, by default None. + + Returns + ------- + tuple[OMD, list[str]] + A tuple containing an ordered multi-dict of fields and a list of metadata. 
+ + Notes + ----- + A DFN file consists of field definitions (each as a set of attributes) and a + number of comment lines either a) containing metadata about the component or + b) delimiting variables into blocks. This parser reads the file line-by-line + and saves component metadata and field attributes, ignoring block delimiters; + There is a `block` attribute on each field anyway so delimiters are unneeded. + + The returned ordered multi-dict (OMD) maps names to dicts of their attributes, + with duplicate field names allowed. This is important because some DFN files + have fields with the same name defined multiple times for different purposes + (e.g., an `auxiliary` options block keyword, and column in the period block). + + """ + + common = common or {} + field: dict = {} + fields: list = [] + metadata: list = [] + + for line in f: + # parse metadata line + if (line := line.strip()).startswith("#"): + _, sep, tail = line.partition("flopy") + if sep == "flopy": + if ( + "multi-package" in tail + or "solution_package" in tail + or "subpackage" in tail + or "parent" in tail + ): + metadata.append(tail.strip()) + _, sep, tail = line.partition("package-type") + if sep == "package-type": + metadata.append(f"package-type {tail.strip()}") + continue + + # if we hit a newline and the field has attributes, + # we've reached the end of the field. Save it. + if not any(line): + if any(field): + fields.append((field["name"], field)) + field = {} + continue + + # parse field attribute + key, _, value = line.partition(" ") + if key == "default_value": + key = "default" + field[key] = value + + # if this is the description attribute, substitute + # from common variable definitions if needed. drop + # backslashes too, TODO: generate/insert citations. 
+ if key == "description": + descr = value.replace("\\", "").replace("``", "'").replace("''", "'") + _, replace, tail = descr.strip().partition("REPLACE") + if replace: + key, _, subs = tail.strip().partition(" ") + subs = literal_eval(subs) + cmmn = common.get(key, None) + if cmmn is None: + warn( + "Can't substitute description text, " + f"common variable not found: {key}" + ) + else: + descr = cmmn["description"] + if any(subs): + descr = descr.replace("\\", "").replace("{#1}", subs["{#1}"]) # type: ignore + field["description"] = descr + + # Save the last field if needed. + if any(field): + fields.append((field["name"], field)) + + return OMD(fields), metadata diff --git a/modflow_devtools/dfn/schema/block.py b/modflow_devtools/dfn/schema/block.py new file mode 100644 index 00000000..ed0f32af --- /dev/null +++ b/modflow_devtools/dfn/schema/block.py @@ -0,0 +1,22 @@ +from collections.abc import Mapping + +from modflow_devtools.dfn.schema.field import Fields + +Block = Fields +Blocks = Mapping[str, Block] + + +def block_sort_key(item) -> int: + k, _ = item + if k == "options": + return 0 + elif k == "dimensions": + return 1 + elif k == "griddata": + return 2 + elif k == "packagedata": + return 3 + elif "period" in k: + return 4 + else: + return 5 diff --git a/modflow_devtools/dfn/schema/field.py b/modflow_devtools/dfn/schema/field.py new file mode 100644 index 00000000..d48bde68 --- /dev/null +++ b/modflow_devtools/dfn/schema/field.py @@ -0,0 +1,39 @@ +from collections.abc import Mapping +from dataclasses import dataclass +from typing import Any, Literal + +SCALAR_TYPES = ("keyword", "integer", "double precision", "string") + +Fields = Mapping[str, "Field"] + +FieldType = Literal[ + "keyword", + "integer", + "double precision", + "string", + "record", + "recarray", + "keystring", +] + + +Reader = Literal[ + "urword", + "u1ddbl", + "u2ddbl", + "readarray", +] + + +@dataclass(kw_only=True) +class Field: + name: str + type: str | None = None + block: str | None = 
@dataclass
class Ref:
    """
    Foreign-key-like link from one input component to another.

    The referring component holds a file path input variable; that
    variable points at the input of the referenced (target) component.
    """

    # name of the file path field in the referring component
    key: str
    # name of the target component
    tgt: str
@dataclass(kw_only=True)
class FieldV2(Field):
    """An input field in schema version 2; adds nothing to `Field`."""

    @classmethod
    def from_dict(cls, d: dict) -> "FieldV2":
        """
        Create a FieldV2 instance from a dictionary.

        Only entries whose key is an annotated attribute of this class or
        of `Field` are passed to the constructor; all others are ignored.
        """
        known = [*cls.__annotations__, *Field.__annotations__]
        accepted = {}
        for attr, val in d.items():
            if attr in known:
                accepted[attr] = val
        return cls(**accepted)
schema_version: str = "2") -> None: indir = Path(indir).expanduser().absolute() outdir = Path(outdir).expanduser().absolute() outdir.mkdir(exist_ok=True, parents=True) - for dfn in Dfn.load_all(indir).values(): - dfn_name = dfn["name"] - # Determine new filename and parent relationship - if dfn_name == "sim-nam": - filename = "sim.toml" - dfn = dfn.copy() - dfn["name"] = "sim" - # No parent - this is root - elif dfn_name.endswith("-nam"): - # Model name files: gwf-nam -> gwf.toml, parent = "sim" - model_type = dfn_name[:-4] # Remove "-nam" - filename = f"{model_type}.toml" - dfn = dfn.copy() - dfn["name"] = model_type - dfn["parent"] = "sim" - elif dfn_name.startswith("exg-"): - # Exchanges: parent = "sim" - filename = f"{dfn_name}.toml" - dfn = dfn.copy() - dfn["parent"] = "sim" - elif dfn_name.startswith("sln-"): - # Solutions: parent = "sim" - filename = f"{dfn_name}.toml" - dfn = dfn.copy() - dfn["parent"] = "sim" - elif dfn_name.startswith("utl-"): - # Utilities: parent = "sim" - filename = f"{dfn_name}.toml" - dfn = dfn.copy() - dfn["parent"] = "sim" - elif "-" in dfn_name: - # Packages: gwf-dis -> parent = "gwf" - model_type = dfn_name.split("-")[0] - filename = f"{dfn_name}.toml" - dfn = dfn.copy() - dfn["parent"] = model_type - else: - # Default case - filename = f"{dfn_name}.toml" + dfns = { + name: map(dfn, schema_version=schema_version) + for name, dfn in load_flat(indir).items() + } + tree = to_tree(dfns) + flat = to_flat(tree) + for dfn_name, dfn in flat.items(): + with Path.open(outdir / f"{dfn_name}.toml", "wb") as f: + # TODO if we start using c/attrs, swap out + # all this for a custom unstructuring hook + dfn_dict = asdict(dfn) + dfn_dict["schema_version"] = str(dfn_dict["schema_version"]) + if dfn_dict.get("blocks"): + blocks = dfn_dict.pop("blocks") + for block_name, block_fields in blocks.items(): + if block_name not in dfn_dict: + dfn_dict[block_name] = {} + for field_name, field_data in block_fields.items(): + 
def drop_none_or_empty(path, key, value):
    """
    Drop dictionary items with None or empty values.
    For use with `boltons.iterutils.remap`.

    Parameters
    ----------
    path, key
        Location of the item; unused, present to satisfy the signature
        `remap` calls its `visit` function with.
    value
        The value to inspect.

    Returns
    -------
    bool
        `False` (drop the item) if `value` is `None` or a sized iterable
        of length zero such as `""`, `[]`, `()` or `{}`; otherwise `True`.
    """
    if value is None:
        return False
    # Emptiness is decided by size, not by the truthiness of the items:
    # `[0]` or `[False]` are not empty and must be kept. Iterables without
    # a length (e.g. generators) are kept as they are rather than being
    # consumed just to look inside them.
    if isinstance(value, Iterable) and hasattr(value, "__len__"):
        return len(value) > 0
    return True