Skip to content
150 changes: 150 additions & 0 deletions dpdata/formats/xyz/_unit_convert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
"""Unit conversion helpers for extended XYZ (extxyz) format.

This module provides a table-driven approach to convert energy, force, and
stress/pressure values from various unit systems commonly found in extxyz
files into dpdata's internal units:

- Energy: eV
- Force: eV/angstrom
- Stress/Pressure: eV/angstrom^3 (before volume multiplication to get virial)
"""

from __future__ import annotations

from dpdata.unit import EnergyConversion, ForceConversion, PressureConversion

# ---------------------------------------------------------------------------
# Unit alias mapping tables
# Keys are LOWERCASE; values are canonical names recognized by dpdata.unit
# ---------------------------------------------------------------------------

# Lowercase energy-unit spellings accepted in a header, grouped by the
# canonical name each one resolves to. The group order fixes the key order
# of the resulting dict.
_ENERGY_UNIT_MAP: dict[str, str] = {
    alias: canonical
    for canonical, aliases in (
        ("eV", ("ev",)),
        ("hartree", ("hartree", "ha")),
        ("rydberg", ("ry", "rydberg")),
        ("kcal_mol", ("kcal/mol", "kcal_mol")),
        ("kJ_mol", ("kj/mol", "kj_mol")),
    )
    for alias in aliases
}

# Lowercase length-unit spellings accepted in a header, grouped by the
# canonical name each one resolves to. The group order fixes the key order
# of the resulting dict.
_LENGTH_UNIT_MAP: dict[str, str] = {
    alias: canonical
    for canonical, aliases in (
        ("angstrom", ("angstrom", "ang", "ang.", "a")),
        ("bohr", ("bohr",)),
        ("nm", ("nm",)),
    )
    for alias in aliases
}

# Lowercase stress/pressure-unit spellings accepted in a header, grouped by
# the canonical name each one resolves to. The group order fixes the key
# order of the resulting dict.
_PRESSURE_UNIT_MAP: dict[str, str] = {
    alias: canonical
    for canonical, aliases in (
        ("GPa", ("gpa",)),
        ("kbar", ("kbar",)),
        ("bar", ("bar",)),
        ("eV/angstrom^3", ("ev/angstrom^3", "ev/ang^3", "ev/a^3")),
        ("hartree/bohr^3", ("ha/bohr^3", "hartree/bohr^3")),
    )
    for alias in aliases
}

# Destination units of every conversion in this module, i.e. the units the
# module docstring lists as dpdata's internal ones.
_INTERNAL_ENERGY: str = "eV"
_INTERNAL_FORCE: str = "eV/angstrom"
_INTERNAL_PRESSURE: str = "eV/angstrom^3"


def _parse_force_unit(raw: str) -> tuple[str, str]:
    """Break a force unit such as ``"hartree/bohr"`` into its two components.

    Parameters
    ----------
    raw : str
        Composite ``<energy>/<length>`` unit string. Known energy aliases are
        matched literally, so the string is compared as given.

    Returns
    -------
    tuple of (str, str)
        The energy component followed by the length component. Neither is
        validated here; the length component is simply the remaining text.

    Raises
    ------
    ValueError
        If the string cannot be divided into two non-empty components.

    Examples
    --------
    >>> _parse_force_unit("kcal/mol/angstrom")
    ('kcal/mol', 'angstrom')
    >>> _parse_force_unit("hartree/bohr")
    ('hartree', 'bohr')
    >>> _parse_force_unit("ev/ang")
    ('ev', 'ang')
    """
    # Energy aliases may contain "/" themselves (e.g. "kcal/mol"), so the
    # known aliases are tried as leading text, longest alias first.
    aliases_longest_first = sorted(_ENERGY_UNIT_MAP.keys(), key=len, reverse=True)
    for energy_alias in aliases_longest_first:
        head = energy_alias + "/"
        if not raw.startswith(head):
            continue
        remainder = raw[len(head) :]
        if remainder:
            return energy_alias, remainder

    # No known alias leads the string: divide at the final "/" instead and
    # leave it to the caller to reject unknown components.
    energy_text, slash, length_text = raw.rpartition("/")
    if slash and energy_text and length_text:
        return energy_text, length_text
    raise ValueError(f"Cannot parse force unit string: '{raw}'")


def _get_unit_factor(unit_str: str | None, quantity: str) -> float:
    """Return the multiplicative factor to convert from the given unit to dpdata internals.

    Parameters
    ----------
    unit_str : str or None
        The unit string read from the extxyz header (e.g. "hartree", "kcal/mol/angstrom").
        Matching is case-insensitive and ignores surrounding whitespace.
        If None, returns 1.0 (assumes data is already in internal units).
    quantity : str
        One of "energy", "force", or "stress".

    Returns
    -------
    float
        Conversion factor such that ``value_internal = value_file * factor``.

    Raises
    ------
    ValueError
        If the quantity type is invalid (checked even when ``unit_str`` is
        None), or if the unit string is not recognized for that quantity.
    """
    # Validate the quantity before anything else: otherwise a misspelled
    # quantity would silently yield 1.0 whenever the header carries no unit.
    if quantity not in ("energy", "force", "stress"):
        raise ValueError(
            f"Unknown quantity type: '{quantity}'. Must be 'energy', 'force', or 'stress'."
        )

    if unit_str is None:
        return 1.0

    # The alias tables are keyed by lowercase spellings.
    key = unit_str.lower().strip()

    if quantity == "energy":
        canonical = _ENERGY_UNIT_MAP.get(key)
        if canonical is None:
            raise ValueError(
                f"Unsupported energy unit: '{unit_str}'. "
                f"Supported: {list(_ENERGY_UNIT_MAP.keys())}"
            )
        return EnergyConversion(canonical, _INTERNAL_ENERGY).value()

    if quantity == "force":
        # A force unit is "<energy>/<length>"; each half is resolved through
        # its own alias table before being recombined in canonical spelling.
        e_part, l_part = _parse_force_unit(key)
        e_canonical = _ENERGY_UNIT_MAP.get(e_part)
        l_canonical = _LENGTH_UNIT_MAP.get(l_part)
        if e_canonical is None:
            raise ValueError(
                f"Unsupported energy part in force unit: '{e_part}' "
                f"(from '{unit_str}'). Supported: {list(_ENERGY_UNIT_MAP.keys())}"
            )
        if l_canonical is None:
            raise ValueError(
                f"Unsupported length part in force unit: '{l_part}' "
                f"(from '{unit_str}'). Supported: {list(_LENGTH_UNIT_MAP.keys())}"
            )
        src_unit = f"{e_canonical}/{l_canonical}"
        return ForceConversion(src_unit, _INTERNAL_FORCE).value()

    # Only "stress" remains after the validation above.
    canonical = _PRESSURE_UNIT_MAP.get(key)
    if canonical is None:
        raise ValueError(
            f"Unsupported stress/pressure unit: '{unit_str}'. "
            f"Supported: {list(_PRESSURE_UNIT_MAP.keys())}"
        )
    return PressureConversion(canonical, _INTERNAL_PRESSURE).value()
28 changes: 23 additions & 5 deletions dpdata/formats/xyz/quip_gap_xyz.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import numpy as np

from dpdata.formats.xyz._unit_convert import _get_unit_factor
from dpdata.periodic_table import Element

# Possible keys for the energy field in the extxyz comment line,
Expand All @@ -24,7 +25,7 @@
_STRESS_KEYS = ("stress", "stresses")


def _parse_stress_to_virials(stress_str, cell, stress_sign=-1):
def _parse_stress_to_virials(stress_str, cell, stress_sign=-1, stress_factor=1.0):
"""Convert a stress field string to virial tensor.

Parameters
Expand All @@ -38,6 +39,10 @@ def _parse_stress_to_virials(stress_str, cell, stress_sign=-1):
Sign convention for ``virial = stress_sign * volume * stress``.
Default ``-1`` follows the ASE convention where
``virial = -V * stress`` (stress in eV/angstrom^3).
stress_factor : float
Multiplicative factor to convert stress from file units to
eV/angstrom^3 before computing the virial. Default 1.0
(assumes stress is already in eV/angstrom^3).

Returns
-------
Expand All @@ -57,7 +62,7 @@ def _parse_stress_to_virials(stress_str, cell, stress_sign=-1):
f"stress field must have 6 (Voigt) or 9 (3x3) values, got {len(vals)}"
)
volume = abs(np.linalg.det(cell))
virials = stress_sign * volume * stress
virials = stress_sign * volume * stress * stress_factor
return np.array([virials])


Expand Down Expand Up @@ -249,13 +254,26 @@ def handle_single_xyz_frame(lines, stress_sign=-1, **kwargs):
stress_raw = field_dict[skey]
break

# --- unit conversion factors ---
# Read optional unit metadata from extxyz header.
# When absent, factor is 1.0 (assumes dpdata internal units).
e_unit_str = field_dict.get("energy-unit") or field_dict.get("energy_unit")
f_unit_str = field_dict.get("force-unit") or field_dict.get("force_unit")
s_unit_str = field_dict.get("stress-unit") or field_dict.get("stress_unit")

e_factor = _get_unit_factor(e_unit_str, "energy")
f_factor = _get_unit_factor(f_unit_str, "force")
s_factor = _get_unit_factor(s_unit_str, "stress")

if virial_raw is not None:
virials = np.array(
[np.array(list(filter(bool, virial_raw.split(" ")))).reshape(3, 3)]
).astype(np.float64)
# Note: virial values are assumed in eV (no unit header for virial itself;
# if stress-unit is given it only applies to stress fields).
elif stress_raw is not None:
virials = _parse_stress_to_virials(
stress_raw, cells, stress_sign=stress_sign
stress_raw, cells, stress_sign=stress_sign, stress_factor=s_factor
)

# --- energy (try several common keys) ---
Expand All @@ -275,8 +293,8 @@ def handle_single_xyz_frame(lines, stress_sign=-1, **kwargs):
info_dict["atom_numbs"] = list(type_num_array[:, 1].astype(int))
info_dict["atom_types"] = np.array(atom_type_list).astype(int)
info_dict["coords"] = np.array([coords_array]).astype(np.float64)
info_dict["energies"] = np.array([energy_value]).astype(np.float64)
info_dict["forces"] = np.array([force_array]).astype(np.float64)
info_dict["energies"] = np.array([energy_value], dtype=np.float64) * e_factor
info_dict["forces"] = np.array([force_array], dtype=np.float64) * f_factor
if virials is not None:
info_dict["virials"] = virials
info_dict["orig"] = np.zeros(3)
Expand Down
Loading
Loading