Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
ENH: provide from_dict classmethods for decoding basic classes.
  • Loading branch information
phmbressan committed Sep 22, 2024
commit 40eaf3c46136d586b0090749514faca2f95701e7
103 changes: 100 additions & 3 deletions rocketpy/_encoders.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
"""Defines a custom JSON encoder for RocketPy objects."""

import base64
import json
from datetime import datetime
from importlib import import_module

import dill
import numpy as np


Expand Down Expand Up @@ -33,17 +36,111 @@ def default(self, o):
elif isinstance(o, np.ndarray):
return o.tolist()
elif isinstance(o, datetime):
return o.isoformat()
return [o.year, o.month, o.day, o.hour]
elif hasattr(o, "__iter__") and not isinstance(o, str):
return list(o)
elif hasattr(o, "to_dict"):
return o.to_dict()
encoding = o.to_dict()

encoding["signature"] = get_class_signature(o)

return encoding

elif hasattr(o, "__dict__"):
exception_set = {"prints", "plots"}
return {
encoding = {
key: value
for key, value in o.__dict__.items()
if key not in exception_set
}

if "rocketpy" in o.__class__.__module__ and not any(
subclass in o.__class__.__name__
for subclass in ["FlightPhase", "TimeNode"]
):
encoding["signature"] = get_class_signature(o)

return encoding
else:
return super().default(o)


class RocketPyDecoder(json.JSONDecoder):
"""Custom JSON decoder for RocketPy objects. It defines how to decode
different types of objects from a JSON supported format."""

def __init__(self, *args, **kwargs):
super().__init__(object_hook=self.object_hook, *args, **kwargs)

def object_hook(self, obj):
if "signature" in obj:
signature = obj.pop("signature")

try:
class_ = get_class_from_signature(signature)

if hasattr(class_, "from_dict"):
return class_.from_dict(obj)
else:
# Filter keyword arguments
kwargs = {
key: value
for key, value in obj.items()
if key in class_.__init__.__code__.co_varnames
}

return class_(**kwargs)
except ImportError: # AttributeException
return obj
else:
return obj


def get_class_signature(obj):
class_ = obj.__class__

return f"{class_.__module__}.{class_.__name__}"


def get_class_from_signature(signature):
module_name, class_name = signature.rsplit(".", 1)

module = import_module(module_name)

return getattr(module, class_name)


def to_hex_encode(obj, encoder=base64.b85encode):
"""Converts an object to hex representation using dill.

Parameters
----------
obj : object
Object to be converted to hex.
encoder : callable, optional
Function to encode the bytes. Default is base64.b85encode.

Returns
-------
bytes
Object converted to bytes.
"""
return encoder(dill.dumps(obj)).hex()


def from_hex_decode(obj_bytes, decoder=base64.b85decode):
"""Converts an object from hex representation using dill.

Parameters
----------
obj_bytes : str
Hex string to be converted to object.
decoder : callable, optional
Function to decode the bytes. Default is base64.b85decode.

Returns
-------
object
Object converted from bytes.
"""
return dill.loads(decoder(bytes.fromhex(obj_bytes)))
64 changes: 64 additions & 0 deletions rocketpy/environment/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -2853,6 +2853,70 @@ def decimal_degrees_to_arc_seconds(angle):
arc_seconds = (remainder * 60 - arc_minutes) * 60
return degrees, arc_minutes, arc_seconds

def to_dict(self):
return {
"gravity": self.gravity,
"date": self.date,
"latitude": self.latitude,
"longitude": self.longitude,
"elevation": self.elevation,
"datum": self.datum,
"timezone": self.timezone,
"_max_expected_height": self.max_expected_height,
"atmospheric_model_type": self.atmospheric_model_type,
"pressure": self.pressure,
"temperature": self.temperature,
"wind_velocity_x": self.wind_velocity_x,
"wind_velocity_y": self.wind_velocity_y,
"wind_heading": self.wind_heading,
"wind_direction": self.wind_direction,
"wind_speed": self.wind_speed,
}

@classmethod
def from_dict(cls, data):
environment = cls(
gravity=data["gravity"],
date=data["date"],
latitude=data["latitude"],
longitude=data["longitude"],
elevation=data["elevation"],
datum=data["datum"],
timezone=data["timezone"],
max_expected_height=data["_max_expected_height"],
)

atmospheric_model = data["atmospheric_model_type"]

if atmospheric_model == "standard_atmosphere":
environment.set_atmospheric_model("standard_atmosphere")
elif atmospheric_model == "custom_atmosphere":
environment.set_atmospheric_model(
type="custom_atmosphere",
pressure=data["pressure"],
temperature=data["temperature"],
wind_u=data["wind_velocity_x"],
wind_v=data["wind_velocity_y"],
)
else:
environment.__set_pressure_function(data["pressure"])
environment.__set_barometric_height_function(data["temperature"])
environment.__set_temperature_function(data["temperature"])
environment.__set_wind_velocity_x_function(data["wind_velocity_x"])
environment.__set_wind_velocity_y_function(data["wind_velocity_y"])
environment.__set_wind_heading_function(data["wind_heading"])
environment.__set_wind_direction_function(data["wind_direction"])
environment.__set_wind_speed_function(data["wind_speed"])
environment.elevation = data["elevation"]
environment.max_expected_height = data["_max_expected_height"]

if atmospheric_model != "ensemble":
environment.calculate_density_profile()
environment.calculate_speed_of_sound_profile()
environment.calculate_dynamic_viscosity()

return environment


if __name__ == "__main__":
import doctest
Expand Down
10 changes: 4 additions & 6 deletions rocketpy/mathutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
carefully as it may impact all the rest of the project.
"""

import base64
import warnings
import zlib
from bisect import bisect_left
from collections.abc import Iterable
from copy import deepcopy
Expand All @@ -25,6 +23,8 @@
RBFInterpolator,
)

from rocketpy._encoders import from_hex_decode, to_hex_encode

# Numpy 1.x compatibility,
# TODO: remove these lines when all dependencies support numpy>=2.0.0
if np.lib.NumpyVersion(np.__version__) >= "2.0.0b1":
Expand Down Expand Up @@ -3401,7 +3401,7 @@ def to_dict(self):
source = self.source

if callable(source):
source = zlib.compress(base64.b85encode(dill.dumps(source))).hex()
source = to_hex_encode(source)

return {
"source": source,
Expand All @@ -3423,9 +3423,7 @@ def from_dict(cls, func_dict):
"""
source = func_dict["source"]
if func_dict["interpolation"] is None and func_dict["extrapolation"] is None:
source = dill.loads(
base64.b85decode(zlib.decompress(bytes.fromhex(source)))
)
source = from_hex_decode(source)

return cls(
source=source,
Expand Down
24 changes: 24 additions & 0 deletions rocketpy/motors/motor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1481,6 +1481,30 @@ def all_info(self):
self.prints.all()
self.plots.all()

@classmethod
def from_dict(cls, data):
return cls(
thrust_source=data["thrust_source"],
burn_time=data["_burn_time"],
chamber_radius=data["chamber_radius"],
chamber_height=data["chamber_height"],
chamber_position=data["chamber_position"],
propellant_initial_mass=data["propellant_initial_mass"],
nozzle_radius=data["nozzle_radius"],
dry_mass=data["dry_mass"],
center_of_dry_mass_position=data["center_of_dry_mass_position"],
dry_inertia=(
data["dry_I_11"],
data["dry_I_22"],
data["dry_I_33"],
data["dry_I_12"],
data["dry_I_13"],
data["dry_I_23"],
),
nozzle_position=data["nozzle_position"],
interpolation_method=data["interpolate"],
)


class EmptyMotor:
"""Class that represents an empty motor with no mass and no thrust."""
Expand Down
29 changes: 29 additions & 0 deletions rocketpy/motors/solid_motor.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,3 +738,32 @@ def all_info(self):
"""Prints out all data and graphs available about the SolidMotor."""
self.prints.all()
self.plots.all()

@classmethod
def from_dict(cls, data):
return cls(
thrust_source=data["thrust_source"],
dry_mass=data["dry_mass"],
dry_inertia=(
data["dry_I_11"],
data["dry_I_22"],
data["dry_I_33"],
data["dry_I_12"],
data["dry_I_13"],
data["dry_I_23"],
),
nozzle_radius=data["nozzle_radius"],
grain_number=data["grain_number"],
grain_density=data["grain_density"],
grain_outer_radius=data["grain_outer_radius"],
grain_initial_inner_radius=data["grain_initial_inner_radius"],
grain_initial_height=data["grain_initial_height"],
grain_separation=data["grain_separation"],
grains_center_of_mass_position=data["grains_center_of_mass_position"],
center_of_dry_mass_position=data["center_of_dry_mass_position"],
nozzle_position=data["nozzle_position"],
burn_time=data["_burn_time"],
throat_radius=data["throat_radius"],
interpolation_method=data["interpolate"],
coordinate_system_orientation=data["coordinate_system_orientation"],
)
12 changes: 12 additions & 0 deletions rocketpy/rocket/aero_surface/fins/elliptical_fins.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,3 +316,15 @@ def info(self):
def all_info(self):
self.prints.all()
self.plots.all()

@classmethod
def from_dict(cls, data):
return cls(
n=data["_n"],
root_chord=data["_root_chord"],
span=data["_span"],
rocket_radius=data["_rocket_radius"],
cant_angle=data["_cant_angle"],
airfoil=data["_airfoil"],
name=data["name"],
)
13 changes: 13 additions & 0 deletions rocketpy/rocket/aero_surface/fins/trapezoidal_fins.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,3 +347,16 @@ def info(self):
def all_info(self):
self.prints.all()
self.plots.all()

@classmethod
def from_dict(cls, data):
return cls(
n=data["_n"],
root_chord=data["_root_chord"],
tip_chord=data["_tip_chord"],
span=data["_span"],
rocket_radius=data["_rocket_radius"],
cant_angle=data["_cant_angle"],
airfoil=data["_airfoil"],
name=data["name"],
)
12 changes: 12 additions & 0 deletions rocketpy/rocket/aero_surface/nose_cone.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,3 +515,15 @@ def all_info(self):
"""
self.prints.all()
self.plots.all()

@classmethod
def from_dict(cls, data):
return cls(
length=data["_length"],
kind=data["_kind"],
base_radius=data["_base_radius"],
bluffness=data["_bluffness"],
rocket_radius=data["_rocket_radius"],
power=data["_power"],
name=data["name"],
)
10 changes: 10 additions & 0 deletions rocketpy/rocket/aero_surface/tail.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,13 @@ def info(self):
def all_info(self):
self.prints.all()
self.plots.all()

@classmethod
def from_dict(cls, data):
return cls(
top_radius=data["_top_radius"],
bottom_radius=data["_bottom_radius"],
length=data["_length"],
rocket_radius=data["_rocket_radius"],
name=data["name"],
)
Loading