Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 107 additions & 55 deletions modflow_devtools/dfn.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,10 @@
from boltons.iterutils import remap

from modflow_devtools.download import download_and_unzip
from modflow_devtools.misc import try_literal_eval

# TODO: use dataclasses instead of typed dicts? static
# methods on typed dicts are evidently not allowed
# mypy: ignore-errors


def _try_literal_eval(value: str) -> Any:
"""
Try to parse a string as a literal. If this fails,
return the value unaltered.
"""
try:
return literal_eval(value)
except (SyntaxError, ValueError):
return value


def _try_parse_bool(value: Any) -> Any:
def try_parse_bool(value: Any) -> Any:
"""
Try to parse a boolean from a string as represented
in a DFN file, otherwise return the value unaltered.
Expand All @@ -54,7 +40,7 @@ def _try_parse_bool(value: Any) -> Any:
return value


def _field_attr_sort_key(item) -> int:
def field_attr_sort_key(item) -> int:
"""
Sort key for input field attributes. The order is:
-1. block
Expand Down Expand Up @@ -90,6 +76,22 @@ def _field_attr_sort_key(item) -> int:
return 8


def block_sort_key(item) -> int:
k, _ = item
if k == "options":
return 0
elif k == "dimensions":
return 1
elif k == "griddata":
return 2
elif k == "packagedata":
return 3
elif "period" in k:
return 4
else:
return 5


FormatVersion = Literal[1, 2]
"""DFN format version number."""

Expand All @@ -113,24 +115,39 @@ def _field_attr_sort_key(item) -> int:
]


_SCALAR_TYPES = FieldType.__args__[:4]
_SCALAR_TYPES = ("keyword", "integer", "double precision", "string")


Dfns = dict[str, "Dfn"]
Fields = dict[str, "Field"]
Block = Fields
Blocks = dict[str, Block]


def get_blocks(dfn: "Dfn") -> Blocks:
"""
Extract blocks from an input definition. Any entry whose key
is not explicitly defined in `Dfn` is a block.
"""
return dict(
sorted(
{k: v for k, v in dfn.items() if k not in Dfn.__annotations__}.items(), # type: ignore
key=block_sort_key,
)
)


class Field(TypedDict):
"""A field specification."""

name: str
type: FieldType
shape: Any | None = None
block: str | None = None
default: Any | None = None
children: Optional["Fields"] = None
description: str | None = None
reader: Reader = "urword"
shape: Any | None
block: str | None
default: Any | None
children: Optional["Fields"]
description: str | None
reader: Reader


class Ref(TypedDict):
Expand Down Expand Up @@ -188,14 +205,14 @@ class Dfn(TypedDict):
"""

name: str
advanced: bool = False
multi: bool = False
parent: str | None = None
ref: Ref | None = None
sln: Sln | None = None
fkeys: Dfns | None = None

@staticmethod
advanced: bool
multi: bool
parent: str | None
ref: Ref | None
sln: Sln | None
fkeys: Dfns | None

@staticmethod # type: ignore[misc]
def _load_v1_flat(f, common: dict | None = None) -> tuple[Mapping, list[str]]:
field = {}
flat = []
Expand Down Expand Up @@ -269,7 +286,7 @@ def _load_v1_flat(f, common: dict | None = None) -> tuple[Mapping, list[str]]:
# the point of the OMD is to losslessly handle duplicate variable names
return OMD(flat), meta

@classmethod
@classmethod # type: ignore[misc]
def _load_v1(cls, f, name, **kwargs) -> "Dfn":
"""
Temporary load routine for the v1 DFN format.
Expand All @@ -279,6 +296,41 @@ def _load_v1(cls, f, name, **kwargs) -> "Dfn":
refs = kwargs.pop("refs", {})
flat, meta = Dfn._load_v1_flat(f, **kwargs)

def _convert_period_block(block: Block) -> Block:
"""
Convert a period block recarray to individual arrays, one per column.

Extracts recarray fields and creates separate array variables. Gives
each an appropriate grid- or tdis-aligned shape as opposed to sparse
list shape in terms of maxbound as previously.
"""

fields = list(block.values())
if fields[0]["type"] == "recarray":
assert len(fields) == 1
recarray_name = fields[0]["name"]
item = next(iter(fields[0]["children"].values()))
columns = item["children"]
else:
recarray_name = None
columns = block
block.pop(recarray_name, None)
cellid = columns.pop("cellid", None)
for col_name, column in columns.items():
col_copy = column.copy()
old_dims = col_copy.get("shape")
if old_dims:
old_dims = old_dims[1:-1].split(",")
new_dims = ["nper"]
if cellid:
new_dims.append("nnodes")
if old_dims:
new_dims.extend([dim for dim in old_dims if dim != "maxbound"])
col_copy["shape"] = f"({', '.join(new_dims)})"
block[col_name] = col_copy

return block

def _convert_field(var: dict[str, Any]) -> Field:
"""
Convert an input field specification from its representation
Expand All @@ -300,15 +352,15 @@ def _load(field) -> Field:
# stay a string except default values, which we'll
# try to parse as arbitrary literals below, and at
# some point types, once we introduce type hinting
field = {k: _try_parse_bool(v) for k, v in field.items()}
field = {k: try_parse_bool(v) for k, v in field.items()}

_name = field.pop("name")
_type = field.pop("type", None)
shape = field.pop("shape", None)
shape = None if shape == "" else shape
block = field.pop("block", None)
default = field.pop("default", None)
default = _try_literal_eval(default) if _type != "string" else default
default = try_literal_eval(default) if _type != "string" else default
description = field.pop("description", "")
reader = field.pop("reader", "urword")
ref = refs.get(_name, None)
Expand Down Expand Up @@ -343,7 +395,7 @@ def _item() -> Field:
name=_name,
type="record",
block=block,
fields=_fields(),
children=_fields(),
description=description.replace(
"is the list of", "is the record of"
),
Expand All @@ -368,7 +420,7 @@ def _item() -> Field:
name=first["name"] if single else _name,
type=item_type,
block=block,
fields=first["fields"] if single else fields,
children=first["children"] if single else fields,
description=description.replace(
"is the list of", f"is the {item_type} of"
),
Expand Down Expand Up @@ -411,15 +463,16 @@ def _fields() -> Fields:
)

if _type.startswith("recarray"):
var_["item"] = _item()
item = _item()
var_["children"] = {item["name"]: item}
var_["type"] = "recarray"

elif _type.startswith("keystring"):
var_["choices"] = _choices()
var_["children"] = _choices()
var_["type"] = "keystring"

elif _type.startswith("record"):
var_["fields"] = _fields()
var_["children"] = _fields()
var_["type"] = "record"

# for now, we can tell a var is an array if its type
Expand Down Expand Up @@ -453,7 +506,7 @@ def _fields() -> Fields:

return var_

return dict(sorted(_load(var).items(), key=_field_attr_sort_key))
return dict(sorted(_load(var).items(), key=field_attr_sort_key))

# load top-level fields. any nested
# fields will be loaded recursively
Expand All @@ -469,11 +522,10 @@ def _fields() -> Fields:
for block_name, block in groupby(fields.values(), lambda v: v["block"])
}

# mark transient blocks
transient_index_vars = flat.getlist("iper")
for transient_index in transient_index_vars:
transient_block = transient_index["block"]
blocks[transient_block]["transient_block"] = True
# if there's a period block, extract distinct arrays from
# the recarray-style definition
if (period_block := blocks.get("period", None)) is not None:
blocks["period"] = _convert_period_block(period_block)

# remove unneeded variable attributes
def remove_attrs(path, key, value):
Expand Down Expand Up @@ -559,14 +611,14 @@ def _rest():
**blocks,
)

@classmethod
@classmethod # type: ignore[misc]
def _load_v2(cls, f, name) -> "Dfn":
data = tomli.load(f)
if name and name != data.get("name", None):
raise ValueError(f"Name mismatch, expected {name}")
return cls(**data)

@classmethod
@classmethod # type: ignore[misc]
def load(
cls,
f,
Expand All @@ -585,7 +637,7 @@ def load(
else:
raise ValueError(f"Unsupported version, expected one of {version.__args__}")

@staticmethod
@staticmethod # type: ignore[misc]
def _load_all_v1(dfndir: PathLike) -> Dfns:
paths: list[Path] = [
p for p in dfndir.glob("*.dfn") if p.stem not in ["common", "flopy"]
Expand Down Expand Up @@ -617,7 +669,7 @@ def _load_all_v1(dfndir: PathLike) -> Dfns:

return dfns

@staticmethod
@staticmethod # type: ignore[misc]
def _load_all_v2(dfndir: PathLike) -> Dfns:
paths: list[Path] = [
p for p in dfndir.glob("*.toml") if p.stem not in ["common", "flopy"]
Expand All @@ -630,7 +682,7 @@ def _load_all_v2(dfndir: PathLike) -> Dfns:

return dfns

@staticmethod
@staticmethod # type: ignore[misc]
def load_all(dfndir: PathLike, version: FormatVersion = 1) -> Dfns:
"""Load all component definitions from the given directory."""
if version == 1:
Expand All @@ -640,7 +692,7 @@ def load_all(dfndir: PathLike, version: FormatVersion = 1) -> Dfns:
else:
raise ValueError(f"Unsupported version, expected one of {version.__args__}")

@staticmethod
@staticmethod # type: ignore[misc]
def load_tree(dfndir: PathLike, version: FormatVersion = 2) -> dict:
"""Load all definitions and return as hierarchical tree."""
dfns = Dfn.load_all(dfndir, version)
Expand All @@ -664,8 +716,8 @@ def infer_tree(dfns: dict[str, Dfn]) -> dict:
if root_name != "sim":
raise ValueError(f"Root component must be named 'sim', found '{root_name}'")

def add_children(node_name: str) -> dict:
node = dfns[node_name].copy()
def add_children(node_name: str) -> dict[str, Any]:
node = dict(dfns[node_name])
children = [
name for name, dfn in dfns.items() if dfn.get("parent") == node_name
]
Expand All @@ -684,7 +736,7 @@ def get_dfns(
if verbose:
print(f"Downloading MODFLOW 6 repository from {url}")
with tempfile.TemporaryDirectory() as tmp:
dl_path = download_and_unzip(url, tmp, verbose=verbose)
dl_path = download_and_unzip(url, Path(tmp), verbose=verbose)
contents = list(dl_path.glob("modflow6-*"))
proj_path = next(iter(contents), None)
if not proj_path:
Expand Down
16 changes: 16 additions & 0 deletions modflow_devtools/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,3 +565,19 @@ def try_get_enum_value(v: Any) -> Any:
of an enumeration, otherwise return it unaltered.
"""
return v.value if isinstance(v, Enum) else v


# TODO: use dataclasses instead of typed dicts? static
# methods on typed dicts are evidently not allowed
# mypy: ignore-errors


def try_literal_eval(value: str) -> Any:
"""
Try to parse a string as a literal. If this fails,
return the value unaltered.
"""
try:
return literal_eval(value)
except (SyntaxError, ValueError):
return value