|
1 | | -from os import getcwd |
| 1 | +from contextlib import contextmanager |
| 2 | +from importlib.resources import as_file, files |
2 | 3 | from pathlib import Path |
3 | 4 | from tomllib import load |
4 | | -from typing import Any |
| 5 | +from typing import Any, Iterator |
5 | 6 |
|
6 | | -ROOT_DIR = Path(getcwd()).resolve() |
7 | 7 |
|
8 | | - |
9 | | -def load_value_from_toml( |
10 | | - keys: list[str], file_path: Path = ROOT_DIR / "pyproject.toml", default: Any = None |
11 | | -) -> Any: |
| 8 | +@contextmanager |
| 9 | +def find_pyproject_toml() -> Iterator[Path]: |
12 | 10 | """ |
13 | | - Load a nested value from a TOML file. |
| 11 | + Finds the pyproject.toml file and yields its Path. |
| 12 | + Uses a context manager to handle temporary files created by as_file(). |
| 13 | + """ |
| 14 | + try: |
| 15 | + # Get the Traversable object for the pyproject.toml file. |
| 16 | + resource_path = files(__package__).joinpath("pyproject.toml") |
| 17 | + if resource_path.is_file(): |
| 18 | + # Use as_file() to get a Path object to the resource. |
| 19 | + # This is done within a context manager for proper cleanup. |
| 20 | + with as_file(resource_path) as file_path: |
| 21 | + yield file_path |
| 22 | + return |
| 23 | + except (ImportError, TypeError, AttributeError): |
| 24 | + # Fallback for non-packaged or single-script scenarios. |
| 25 | + pass |
14 | 26 |
|
15 | | - Args: |
16 | | - keys: List of nested keys to traverse. |
17 | | - file_path: Path to the TOML file. |
18 | | - default: Value to return if keys not found. |
| 27 | + # Fallback logic |
| 28 | + current_dir = Path(__file__).resolve() |
| 29 | + for parent in current_dir.parents: |
| 30 | + potential_path = parent / "pyproject.toml" |
| 31 | + if potential_path.is_file(): |
| 32 | + yield potential_path |
| 33 | + return |
19 | 34 |
|
20 | | - Returns: |
21 | | - The value from the TOML file. |
| 35 | + raise FileNotFoundError("Could not find pyproject.toml") |
22 | 36 |
|
23 | | - Raises: |
24 | | - FileNotFoundError: If the file doesn't exist and no default is provided. |
25 | | - ValueError: If the keys are missing and no default is provided. |
26 | | - """ |
27 | | - if not file_path.exists(): |
28 | | - if default is not None: |
29 | | - return default |
30 | | - raise FileNotFoundError(f"{file_path} not found") |
31 | 37 |
|
32 | | - try: |
| 38 | +def load_value_from_toml(keys: list[str], default: Any = None) -> Any: |
| 39 | + """ |
| 40 | + Load a nested value from a TOML file. |
| 41 | + """ |
| 42 | + with find_pyproject_toml() as file_path: |
33 | 43 | with file_path.open("rb") as f: |
34 | 44 | data = load(f) |
35 | | - for key in keys: |
36 | | - data = data[key] |
37 | | - return data |
38 | | - except KeyError: |
39 | | - if default is not None: |
40 | | - return default |
41 | | - raise ValueError(f"Keys {'.'.join(keys)} not found in {file_path}") |
42 | | - except Exception as e: |
43 | | - raise ValueError(f"Error reading {file_path}: {e}") from e |
| 45 | + |
| 46 | + try: |
| 47 | + for key in keys: |
| 48 | + data = data[key] |
| 49 | + return data |
| 50 | + except KeyError: |
| 51 | + if default is not None: |
| 52 | + return default |
| 53 | + raise ValueError(f"Keys {'.'.join(keys)} not found in {file_path}") |
| 54 | + except Exception as e: |
| 55 | + raise ValueError(f"Error reading {file_path}: {e}") from e |
0 commit comments