diff --git a/CHANGELOG.md b/CHANGELOG.md index 4be1a16f7..9df2ee03c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,6 +31,9 @@ and this project adheres to [Semantic Versioning](http://semver.org/). Attention: The newest changes should be on top --> ### Added + +- ENH: Add persistent caching for ThrustCurve API (#881) + ENH: Compatibility with MERRA-2 atmosphere reanalysis files [#825](https://github.com/RocketPy-Team/RocketPy/pull/825) - ENH: Enable only radial burning [#815](https://github.com/RocketPy-Team/RocketPy/pull/815) diff --git a/docs/user/motors/genericmotor.rst b/docs/user/motors/genericmotor.rst index 8c5b40703..0706ffcba 100644 --- a/docs/user/motors/genericmotor.rst +++ b/docs/user/motors/genericmotor.rst @@ -109,17 +109,29 @@ note that the user can still provide the parameters manually if needed. The ``load_from_thrustcurve_api`` method ---------------------------------------- -The ``GenericMotor`` class provides a convenience loader that downloads a temporary +The ``GenericMotor`` class provides a convenience loader that downloads an `.eng` file from the ThrustCurve.org public API and builds a ``GenericMotor`` instance from it. This is useful when you know a motor designation (for example -``"M1670"``) but do not want to manually download and -save the `.eng` file. +``"M1670"``) but do not want to manually download and save the `.eng` file. + +The method also includes automatic caching for faster repeated usage. +Downloaded `.eng` files are stored in the user's RocketPy cache folder +(``~/.rocketpy_cache``). When a subsequent request is made for the same motor, +the cached copy is used instead of performing another network request. + +You can bypass the cache by setting ``no_cache=True``: + +- ``no_cache=False`` (default): + Use a cached file if available; otherwise download and store it. + +- ``no_cache=True``: + Always fetch a fresh version from the API and overwrite the cache. .. note:: - This method performs network requests to the ThrustCurve API. Use it only - when you have network access. For automated testing or reproducible runs, - prefer using local `.eng` files. + This method performs network requests to the ThrustCurve API unless a cached + version exists. For automated testing or fully reproducible workflows, prefer + local `.eng` files or set ``no_cache=True`` explicitly. Example ------- @@ -128,8 +140,19 @@ Example from rocketpy.motors import GenericMotor - # Build a motor by name (requires network access) + # Build a motor by name (requires network access unless cached) motor = GenericMotor.load_from_thrustcurve_api("M1670") - # Use the motor as usual + # Print the motor information + motor.info() + +Using the no_cache option +------------------------- + +If you want to force RocketPy to ignore the cache and download a fresh copy +every time, use: + +.. jupyter-execute:: + + motor = GenericMotor.load_from_thrustcurve_api("M1670", no_cache=True) motor.info() diff --git a/rocketpy/motors/motor.py b/rocketpy/motors/motor.py index c81c713d4..733653712 100644 --- a/rocketpy/motors/motor.py +++ b/rocketpy/motors/motor.py @@ -6,6 +6,7 @@ from abc import ABC, abstractmethod from functools import cached_property from os import path, remove +from pathlib import Path import numpy as np import requests @@ -15,8 +16,11 @@ from ..prints.motor_prints import _MotorPrints from ..tools import parallel_axis_theorem_from_com, tuple_handler - # pylint: disable=too-many-public-methods +# ThrustCurve API cache +CACHE_DIR = Path.home() / ".rocketpy_cache" + + class Motor(ABC): """Abstract class to specify characteristics and useful operations for motors. Cannot be instantiated. @@ -1918,7 +1922,7 @@ def load_from_rse_file( ) @staticmethod - def _call_thrustcurve_api(name: str): + def _call_thrustcurve_api(name: str, no_cache: bool = False): # pylint: disable=too-many-statements """ Download a .eng file from the ThrustCurve API based on the given motor name. @@ -1929,6 +1933,8 @@ def _call_thrustcurve_api(name: str): The motor name according to the API (e.g., "Cesaroni_M1670" or "M1670"). Both manufacturer-prefixed and shorthand names are commonly used; if multiple motors match the search, the first result is used. + no_cache : bool, optional + If True, forces a new API fetch even if the motor is cached. Returns ------- @@ -1941,9 +1947,31 @@ def _call_thrustcurve_api(name: str): If no motor is found or if the downloaded .eng data is missing. requests.exceptions.RequestException If a network or HTTP error occurs during the API call. + + Notes + ----- + - The cache prevents multiple network requests for the same motor name across sessions. + - Cached files are stored in `~/.rocketpy_cache` and reused unless `no_cache=True`. + - Filenames are sanitized to avoid invalid characters. """ - base_url = "https://www.thrustcurve.org/api/v1" + try: + CACHE_DIR.mkdir(exist_ok=True) + except OSError as e: + warnings.warn(f"Could not create cache directory: {e}. Caching disabled.") + no_cache = True + # File path in the cache + safe_name = re.sub(r"[^A-Za-z0-9_.-]", "_", name) + cache_file = CACHE_DIR / f"{safe_name}.eng.b64" + if not no_cache and cache_file.exists(): + try: + return cache_file.read_text() + except (OSError, UnicodeDecodeError) as e: + warnings.warn( + f"Failed to read cached motor file '{cache_file}': {e}. " + "Fetching fresh data from API." + ) + base_url = "https://www.thrustcurve.org/api/v1" # Step 1. Search motor response = requests.get(f"{base_url}/search.json", params={"commonName": name}) response.raise_for_status() @@ -1979,10 +2007,20 @@ def _call_thrustcurve_api(name: str): raise ValueError( f"Downloaded .eng data for motor '{name}' is empty or invalid." ) + if not no_cache: + try: + cache_file.write_text(data_base64) + except (OSError, PermissionError) as e: + warnings.warn( + f"Could not write to cache file '{cache_file}': {e}. " + "Continuing without caching.", + RuntimeWarning, + ) + return data_base64 @staticmethod - def load_from_thrustcurve_api(name: str, **kwargs): + def load_from_thrustcurve_api(name: str, no_cache: bool = False, **kwargs): """ Creates a Motor instance by downloading a .eng file from the ThrustCurve API based on the given motor name. @@ -2010,7 +2048,7 @@ def load_from_thrustcurve_api(name: str, **kwargs): If a network or HTTP error occurs during the API call. """ - data_base64 = GenericMotor._call_thrustcurve_api(name) + data_base64 = GenericMotor._call_thrustcurve_api(name, no_cache=no_cache) data_bytes = base64.b64decode(data_base64) # Step 3. Create the motor from the .eng file diff --git a/tests/unit/motors/test_genericmotor.py b/tests/unit/motors/test_genericmotor.py index 3d0fbd766..7387189fb 100644 --- a/tests/unit/motors/test_genericmotor.py +++ b/tests/unit/motors/test_genericmotor.py @@ -1,4 +1,5 @@ import base64 +import pathlib import numpy as np import pytest @@ -6,6 +7,7 @@ import scipy.integrate from rocketpy import Function, Motor +from rocketpy.motors.motor import GenericMotor BURN_TIME = (2, 7) @@ -333,3 +335,108 @@ def test_load_from_thrustcurve_api(monkeypatch, generic_motor): ) with pytest.raises(ValueError, match=msg): type(generic_motor).load_from_thrustcurve_api("FakeMotor") + + +def test_thrustcurve_api_cache(monkeypatch, tmp_path): + """Tests that ThrustCurve API caching works correctly.""" + + eng_path = "data/motors/cesaroni/Cesaroni_M1670.eng" + with open(eng_path, "rb") as f: + encoded = base64.b64encode(f.read()).decode("utf-8") + + search_json = {"results": [{"motorId": "12345"}]} + download_json = {"results": [{"data": encoded}]} + + # Patch requests.get to return mocked API responses + monkeypatch.setattr(requests, "get", _mock_get(search_json, download_json)) + + # Patch the module-level CACHE_DIR to use the tmp_path + monkeypatch.setattr("rocketpy.motors.motor.CACHE_DIR", tmp_path) + + # First call writes to cache + motor1 = GenericMotor.load_from_thrustcurve_api("M1670") + cache_file = tmp_path / "M1670.eng.b64" + assert cache_file.exists() + + # Second call reads from cache; API should not be called + monkeypatch.setattr( + requests, + "get", + lambda *args, **kwargs: (_ for _ in ()).throw( + RuntimeError("API should not be called") + ), + ) + motor2 = GenericMotor.load_from_thrustcurve_api("M1670") + assert motor2.thrust.y_array == pytest.approx(motor1.thrust.y_array) + + # Bypass cache with no_cache=True + monkeypatch.setattr(requests, "get", _mock_get(search_json, download_json)) + motor3 = GenericMotor.load_from_thrustcurve_api("M1670", no_cache=True) + assert motor3.thrust.y_array == pytest.approx(motor1.thrust.y_array) + + +def test_thrustcurve_api_cache_robustness(monkeypatch, tmp_path): # pylint: disable=too-many-statements + """ + Tests exception handling for cache operations to ensure 100% coverage. + Simulates OS errors for mkdir, write, and read operations. + """ + + # 1. Setup Mock API to return success + eng_path = "data/motors/cesaroni/Cesaroni_M1670.eng" + with open(eng_path, "rb") as f: + encoded = base64.b64encode(f.read()).decode("utf-8") + + search_json = {"results": [{"motorId": "12345"}]} + download_json = {"results": [{"data": encoded}]} + monkeypatch.setattr(requests, "get", _mock_get(search_json, download_json)) + + # Point cache to tmp_path so we don't mess with real home + monkeypatch.setattr("rocketpy.motors.motor.CACHE_DIR", tmp_path) + + # CASE 1: mkdir fails -> should warn and continue (disable caching) + original_mkdir = pathlib.Path.mkdir + + def mock_mkdir_fail(self, *args, **kwargs): + if self == tmp_path: + raise OSError("Simulated mkdir error") + return original_mkdir(self, *args, **kwargs) + + monkeypatch.setattr(pathlib.Path, "mkdir", mock_mkdir_fail) + + with pytest.warns(UserWarning, match="Could not create cache directory"): + GenericMotor.load_from_thrustcurve_api("M1670") + + # Reset mkdir logic for next test + monkeypatch.setattr(pathlib.Path, "mkdir", original_mkdir) + + # CASE 2: write_text fails -> should warn and continue + original_write = pathlib.Path.write_text + + def mock_write_fail(self, *args, **kwargs): + if "M1670.eng.b64" in str(self): + raise OSError("Simulated write error") + return original_write(self, *args, **kwargs) + + monkeypatch.setattr(pathlib.Path, "write_text", mock_write_fail) + + with pytest.warns(RuntimeWarning, match="Could not write to cache file"): + GenericMotor.load_from_thrustcurve_api("M1670") + + # Reset write logic + monkeypatch.setattr(pathlib.Path, "write_text", original_write) + + # CASE 3: read_text fails (corrupt file) -> should warn and fetch fresh + cache_file = tmp_path / "M1670.eng.b64" + cache_file.write_text("corrupted_data") + + original_read = pathlib.Path.read_text + + def mock_read_fail(self, *args, **kwargs): + if self == cache_file: + raise UnicodeDecodeError("utf-8", b"", 0, 1, "bad") + return original_read(self, *args, **kwargs) + + monkeypatch.setattr(pathlib.Path, "read_text", mock_read_fail) + + with pytest.warns(UserWarning, match="Failed to read cached motor file"): + GenericMotor.load_from_thrustcurve_api("M1670")