diff --git a/src/openlifu/bf/focal_patterns/focal_pattern.py b/src/openlifu/bf/focal_patterns/focal_pattern.py index e6f07ce9..9e4a5854 100644 --- a/src/openlifu/bf/focal_patterns/focal_pattern.py +++ b/src/openlifu/bf/focal_patterns/focal_pattern.py @@ -10,9 +10,10 @@ class FocalPattern(ABC): """ Abstract base class for representing a focal pattern - :ivar target_pressure: Target pressure of the focal pattern in Pa + :ivar target_pressure: Target pressure of the focal pattern in given units """ - target_pressure: float = 1.0 # Pa + target_pressure: float = 1.0 + units: str = "Pa" @abstractmethod def get_targets(self, target: Point): diff --git a/src/openlifu/bf/get_beamwidth.py b/src/openlifu/bf/get_beamwidth.py index 82daaed9..6649a979 100644 --- a/src/openlifu/bf/get_beamwidth.py +++ b/src/openlifu/bf/get_beamwidth.py @@ -98,7 +98,7 @@ def get_beamwidth(vol: DataArray, coords_units: str, focus: Point, cutoff: float inlier_hull = ConvexHull(inlier_points) except QhullError: # If convex hull creation fails (e.g., too few points), add jitter and try again - logging.warning("Invalid inliers, attempting to add jitter to create a valid volume...") #TODO: should be using self.logger + logging.warning("Invalid inliers, attempting to add jitter to create a valid volume...") minmax_coords = np.array([(np.min(coords[i]), np.max(coords[i])) for i in range(len(coords))]) #TODO: min-max should be from coords.extent coords_shape = tuple([len(coords[i]) for i in range(len(coords))]) dx = np.mean(np.diff(minmax_coords) / (np.array(coords_shape) - 1)) diff --git a/src/openlifu/plan/__init__.py b/src/openlifu/plan/__init__.py index 22450b6f..6b91ddd6 100644 --- a/src/openlifu/plan/__init__.py +++ b/src/openlifu/plan/__init__.py @@ -1,9 +1,14 @@ from .protocol import Protocol from .run import Run from .solution import Solution +from .solution_analysis import SolutionAnalysis, SolutionAnalysisOptions +from .target_constraints import TargetConstraints __all__ = [ "Protocol", "Solution", "Run", + "SolutionAnalysis", + "SolutionAnalysisOptions", + "TargetConstraints" ] diff --git a/src/openlifu/plan/protocol.py b/src/openlifu/plan/protocol.py index b9969fab..8b399bcf 100644 --- a/src/openlifu/plan/protocol.py +++ b/src/openlifu/plan/protocol.py @@ -1,11 +1,27 @@ import json +import logging +import math +from copy import deepcopy from dataclasses import asdict, dataclass, field +from datetime import datetime +from enum import Enum from pathlib import Path -from typing import Any, Dict +from typing import Any, Dict, List, Optional, Tuple +import numpy as np import xarray as xa from openlifu import bf, geo, seg, sim, xdc +from openlifu.db.session import Session +from openlifu.geo import Point +from openlifu.plan.solution import Solution +from openlifu.plan.solution_analysis import SolutionAnalysis, SolutionAnalysisOptions +from openlifu.plan.target_constraints import TargetConstraints +from openlifu.sim import run_simulation +from openlifu.util.json import PYFUSEncoder +from openlifu.xdc import Transducer + +OnPulseMismatchAction = Enum("OnPulseMismatchAction", ["ERROR", "ROUND", "ROUNDUP", "ROUNDDOWN"]) @dataclass @@ -20,9 +36,12 @@ class Protocol: delay_method: bf.DelayMethod = field(default_factory=bf.delay_methods.Direct) apod_method: bf.ApodizationMethod = field(default_factory=bf.apod_methods.Uniform) seg_method: seg.SegmentationMethod = field(default_factory=seg.seg_methods.Water) - param_constraints: dict = field(default_factory=dict) - target_constraints: dict = field(default_factory=dict) - analysis_options: dict = field(default_factory=dict) + param_constraints: dict = field(default_factory=dict) #TODO: this seems to be used only in `plan.check_analysis` but not called anywhere + target_constraints: List[TargetConstraints] = field(default_factory=list) + analysis_options: SolutionAnalysisOptions = field(default_factory=SolutionAnalysisOptions) + + def __post_init__(self): + self.logger = logging.getLogger(__name__) @staticmethod def from_dict(d : Dict[str,Any]) -> "Protocol": @@ -36,6 +55,11 @@ def from_dict(d : Dict[str,Any]) -> "Protocol": if "materials" in d: seg_method_dict["materials"] = seg.Material.from_dict(d.pop("materials")) d["seg_method"] = seg.SegmentationMethod.from_dict(seg_method_dict) + d['param_constraints'] = d.get("param_constraints", {}) + if "target_constraints" in d: + d['target_constraints'] = [TargetConstraints.from_dict(d_tc) for d_tc in d.get("target_constraints", {})] + if "analysis_options" in d: + d['analysis_options'] = SolutionAnalysisOptions.from_dict(d.get("analysis_options")) return Protocol(**d) def to_dict(self): @@ -80,17 +104,198 @@ def to_json(self, compact:bool) -> str: Returns: A json string representing the complete Protocol object. """ if compact: - return json.dumps(self.to_dict(), separators=(',', ':')) + return json.dumps(self.to_dict(), separators=(',', ':'), cls=PYFUSEncoder) else: - return json.dumps(self.to_dict(), indent=4) + return json.dumps(self.to_dict(), indent=4, cls=PYFUSEncoder) - def to_file(self, filename): + def to_file(self, filename: str): """ Save the protocol to a file - :param filename: Name of the file + Args: + filename: Name of the file """ Path(filename).parent.parent.mkdir(exist_ok=True) Path(filename).parent.mkdir(exist_ok=True) with open(filename, 'w') as file: file.write(self.to_json(compact=False)) + + + def check_target(self, target: Point): + """ + Check if a target is within bounds, raising an exception if it isn't. + + Args: + target: The geo.Point target to check. + """ + if isinstance(target, list): + raise ValueError(f"Input target {target} not supposed to be a list!") + + # check if target position is within target_constraints defined bounds. + for target_constraint in self.target_constraints: + pos = target.get_position( + dim=target_constraint.dim, + units=target_constraint.units + ) + target_constraint.check_bounds(pos) + + def fix_pulse_mismatch(self, on_pulse_mismatch: OnPulseMismatchAction, foci: List[Point]): + """Fix the protocol sequence pulse count in-place given a pulse_mismatch action.""" + if on_pulse_mismatch is OnPulseMismatchAction.ERROR: + raise ValueError(f"Pulse Count {self.sequence.pulse_count} is not a multiple of the number of foci {len(foci)}") + else: + if on_pulse_mismatch is OnPulseMismatchAction.ROUND: + self.sequence.pulse_count = round(self.sequence.pulse_count / len(foci)) * len(foci) + elif on_pulse_mismatch is OnPulseMismatchAction.ROUNDUP: + self.sequence.pulse_count = math.ceil(self.sequence.pulse_count / len(foci)) * len(foci) + elif on_pulse_mismatch is OnPulseMismatchAction.ROUNDDOWN: + self.sequence.pulse_count = math.floor(self.sequence.pulse_count / len(foci)) * len(foci) + self.logger.warning( + f"Pulse Count {self.sequence.pulse_count} is not a multiple of the number of foci {len(foci)}." + f"Rounding to {self.sequence.pulse_count}." + ) + + def calc_solution( + self, + target: Point, + transducer: Transducer, + volume: Optional[xa.DataArray] = None, + session: Optional[Session] = None, + simulate: bool = True, + scale: bool = True, + sim_options: Optional[sim.SimSetup] = None, + analysis_options: Optional[SolutionAnalysisOptions] = None, + on_pulse_mismatch: OnPulseMismatchAction = OnPulseMismatchAction.ERROR + ) -> Tuple[Solution, xa.DataArray, SolutionAnalysis]: + """Calculate the solution and aggregated k-wave simulation outputs. + + Method that computes the delays and apodizations for each focus in the treatment plan, + simulates the resulting pressure field to adjust transmit pressures to reach target pressures, + and then analyzes the resulting pressure field to compute the resulting acoustic parameters. + + Args: + target: The target Point. + Target is expected to be in the simulation grid coordinates (lat, ele, ax). + transducer: A Transducer item. + volume: xa.DataArray + The subject scan (Default: None). + It is expected to be in the simulation grid coordinates (lat, ele, ax). + If None, a default simulation grid will be used. + session: db.Session + A session used to define solution_id (Default: None). + simulate: bool + Enable solution simulation (Default: true). + scale: bool + Triggers solution and simulation scaling to the requested pressure (Default: true). + sim_options : sim.SimSetup + The options for the k-wave simulation (Default: self.sim_setup). + analysis_options: plan.solution.SolutionAnalysisOptions + The options for the solution analysis (Default: self.analysis_options). + on_pulse_mismatch: plan.protocol.OnPulseMismatchAction + An action to take if the number of pulses in the sequence does not match + the number of foci (Default: OnPulseMismatchAction.ERROR). + + Returns: + solution: Solution + simulation_result_aggregated: xa.Dataset + If simulation is enabled, then this is the resulting aggregated + output (max pressure and mean intensity over all foci). + scaled_solution_analysis: SolutionAnalysis + This is the resulting rescaled analysis, if scale is enabled. + """ + if sim_options is None: + sim_options = self.sim_setup + if analysis_options is None: + analysis_options = self.analysis_options + # check before if target is within bounds + self.check_target(target) + params = sim_options.setup_sim_scene(self.seg_method, volume=volume) + + delays_to_stack: List[np.ndarray] = [] + apodizations_to_stack: List[np.ndarray] = [] + simulation_outputs_to_stack: List[xa.Dataset] = [] + simulation_output_stacked: xa.Dataset = xa.Dataset() + simulation_result_aggregated: xa.Dataset = xa.Dataset() + scaled_solution_analysis: SolutionAnalysis = SolutionAnalysis() + foci: List[Point] = self.focal_pattern.get_targets(target) + simulation_cycles = np.max([np.round(self.pulse.duration * self.pulse.frequency), 20]) + + # updating solution sequence if pulse mismatch + if (self.sequence.pulse_count % len(foci)) != 0: + self.fix_pulse_mismatch(on_pulse_mismatch, foci) + # run simulation and aggregate the results + for focus in foci: + self.logger.info(f"Beamform for focus {focus}...") + delays, apodization = self.beamform(arr=transducer, target=focus, params=params) + simulation_output_xarray = None + if simulate: + self.logger.info(f"Simulate for focus {focus}...") + simulation_output_xarray, _ = run_simulation( + arr=transducer, + params=params, + delays=delays, + apod= apodization, + freq = self.pulse.frequency, + cycles = simulation_cycles, + dt=sim_options.dt, + t_end=sim_options.t_end, + amplitude = 1, + gpu = False + ) + delays_to_stack.append(delays) + apodizations_to_stack.append(apodization) + simulation_outputs_to_stack.append(simulation_output_xarray) + if simulate: + simulation_output_stacked = xa.concat( + [ + sim.assign_coords(focal_point_index=i) + for i, sim in enumerate(simulation_outputs_to_stack) + ], + dim='focal_point_index', + ) + # instantiate and return the solution + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f") + solution_id = timestamp + if session is not None: + solution_id = f"{session.id}_{solution_id}" + solution = Solution( + id=solution_id, + name=f"Solution {timestamp}", + protocol_id=self.id, + transducer_id=transducer.id, + delays=np.stack(delays_to_stack, axis=0), + apodizations=np.stack(apodizations_to_stack, axis=0), + pulse=self.pulse, + sequence=self.sequence, + foci=foci, + target=target, + simulation_result=simulation_output_stacked, + approved=False, + description= ( + f"A solution computed for the {self.name} protocol with transducer {transducer.name}" + f" for target {target.id}." + f" This solution was created for the session {session.id} for subject {session.subject_id}." if session is not None else "" + ) + ) + # optionally scale the solution with simulation result + if scale: + if not simulate: + self.logger.error(msg=f"Cannot scale solution {solution.id} if simulation is not enabled!") + raise ValueError(f"Cannot scale solution {solution.id} if simulation is not enabled!") + self.logger.info(f"Scaling solution {solution.id}...") + #TODO can analysis be an attribute of solution ? + scaled_solution_analysis = solution.scale(transducer, self.focal_pattern, analysis_options=analysis_options) + + if simulate: + # Finally the resulting pressure is max-aggregated and intensity is mean-aggregated, over all focus points . + pnp_aggregated = solution.simulation_result['p_min'].max(dim="focal_point_index") + ppp_aggregated = solution.simulation_result['p_max'].max(dim="focal_point_index") + # TODO: Ensure this mean is weighted by the number of times each point is focused on, once openlifu supports hitting points different numbers of times + intensity_aggregated = solution.simulation_result['ita'].mean(dim="focal_point_index") + simulation_result_aggregated = deepcopy(solution.simulation_result) + simulation_result_aggregated = simulation_result_aggregated.drop_dims("focal_point_index") + simulation_result_aggregated['p_min'] = pnp_aggregated + simulation_result_aggregated['p_max'] = ppp_aggregated + simulation_result_aggregated['ita'] = intensity_aggregated + + return solution, simulation_result_aggregated, scaled_solution_analysis diff --git a/src/openlifu/plan/solution.py b/src/openlifu/plan/solution.py index 7560fbb2..7aec6f70 100644 --- a/src/openlifu/plan/solution.py +++ b/src/openlifu/plan/solution.py @@ -3,16 +3,18 @@ from dataclasses import asdict, dataclass, field from datetime import datetime from pathlib import Path -from typing import List, Optional +from typing import List, Optional, Tuple import numpy as np import xarray as xa from openlifu.bf import Pulse, Sequence, mask_focus +from openlifu.bf.focal_patterns import FocalPattern from openlifu.bf.mask_focus import MaskOp from openlifu.geo import Point +from openlifu.plan.solution_analysis import SolutionAnalysis, SolutionAnalysisOptions from openlifu.util.json import PYFUSEncoder -from openlifu.util.units import rescale_data_arr +from openlifu.util.units import getunitconversion, rescale_data_arr from openlifu.xdc import Transducer @@ -23,41 +25,6 @@ def _construct_nc_filepath_from_json_filepath(json_filepath:Path) -> Path: return nc_filepath -@dataclass -class SolutionAnalysis: - mainlobe_pnp_MPa: list[float] = field(default_factory=list) - mainlobe_isppa_Wcm2: list[float] = field(default_factory=list) - mainlobe_ispta_mWcm2: list[float] = field(default_factory=list) - beamwidth_lat_3dB_mm: list[float] = field(default_factory=list) - beamwidth_ax_3dB_mm: list[float] = field(default_factory=list) - beamwidth_lat_6dB_mm: list[float] = field(default_factory=list) - beamwidth_ax_6dB_mm: list[float] = field(default_factory=list) - sidelobe_pnp_MPa: list[float] = field(default_factory=list) - sidelobe_isppa_Wcm2: list[float] = field(default_factory=list) - global_pnp_MPa: list[float] = field(default_factory=list) - global_isppa_Wcm2: list[float] = field(default_factory=list) - p0_Pa: list[float] = field(default_factory=list) - TIC: float = None - power_W: float = None - MI: float = None - global_ispta_mWcm2: float = None - - -@dataclass -class SolutionOptions: - standoff_sound_speed: float = 1500.0 - standoff_density: float = 1000.0 - ref_sound_speed: float = 1500.0 - ref_density: float = 1000.0 - focus_diameter: float = 0.5 - mainlobe_aspect_ratio: tuple[float, float, float] = (1., 1., 5.) - mainlobe_radius: float = 2.5e-3 - beamwidth_radius: float = 5e-3 - sidelobe_radius: float = 3e-3 - sidelobe_zmin: float = 1e-3 - distance_units: str = "m" - - @dataclass class Solution: """ @@ -120,11 +87,11 @@ def num_foci(self) -> int: """Get the number of foci""" return len(self.foci) - def analyze(self, transducer: Transducer, options: SolutionOptions = SolutionOptions()) -> SolutionAnalysis: + def analyze(self, transducer: Transducer, options: SolutionAnalysisOptions = SolutionAnalysisOptions()) -> SolutionAnalysis: """Analyzes the treatment solution. Args: - transducer: A Transducer item. #TODO: this should be instantiated at the database level, not here ? + transducer: A Transducer item. options: A struct for solution analysis options. Returns: A struct containing the results of the analysis. @@ -176,7 +143,7 @@ def analyze(self, transducer: Transducer, options: SolutionOptions = SolutionOpt # get focus region masks (for mainlobe, sidelobe and beamwidth) mainlobe_mask = mask_focus( - self.simulation_result, #TODO: Original code uses coords, but too complicated to maniplulate a Coordinates class + self.simulation_result, foc, options.mainlobe_radius, mask_op=MaskOp.LESS_EQUAL, @@ -264,6 +231,69 @@ def analyze(self, transducer: Transducer, options: SolutionOptions = SolutionOpt return solution_analysis + def compute_scaling_factors( + self, + focal_pattern: FocalPattern, + analysis: SolutionAnalysis + ) -> Tuple[np.ndarray, float, float]: + """ + + Compute the scaling factors used to re-scale the apodizations, simulation results and pulse amplitude. + + Args: + focal_pattern: FocalPattern + analysis: SolutionAnalysis + + Returns: + apod_factors: A np.ndarray apodization factors + v0: A float representing the original pulse amplitude + v1: A float representing the new pulse amplitude + """ + scaling_factors = np.zeros(self.num_foci()) + + for i in range(self.num_foci()): + focal_pattern_pressure_in_MPa = focal_pattern.target_pressure * getunitconversion(focal_pattern.units, "MPa") + scaling_factors[i] = focal_pattern_pressure_in_MPa / analysis.mainlobe_pnp_MPa[i] + max_scaling = np.max(scaling_factors) + v0 = self.pulse.amplitude + v1 = v0 * max_scaling + apod_factors = scaling_factors / max_scaling + + return apod_factors, v0, v1 + + def scale( + self, + transducer: Transducer, + focal_pattern: FocalPattern, + analysis_options: SolutionAnalysisOptions = SolutionAnalysisOptions() + ) -> SolutionAnalysis: + """ + Scale the solution in-place to match the target pressure. + + Args: + transducer: xdc.Transducer + focal_pattern: FocalPattern + analysis_options: plan.solution.SolutionAnalysisOptions + + Returns: + analysis_scaled: the resulting plan.solution.SolutionAnalysis from scaled solution + """ + analysis = self.analyze(transducer, options=analysis_options) + + apod_factors, v0, v1 = self.compute_scaling_factors(focal_pattern, analysis) + + for i in range(self.num_foci()): + scaling = v1/v0*apod_factors[i] + self.simulation_result['p_min'][i].data *= scaling + self.simulation_result['p_max'][i].data *= scaling + self.simulation_result['ita'][i].data *= scaling**2 + self.apodizations[i] = self.apodizations[i]*apod_factors[i] + self.pulse.amplitude = v1 + + analysis_scaled = self.analyze(transducer, options=analysis_options) + + return analysis_scaled + def get_pulsetrain_dutycycle(self) -> float: """ Compute the pulsetrain dutycycle given a sequence. diff --git a/src/openlifu/plan/solution_analysis.py b/src/openlifu/plan/solution_analysis.py new file mode 100644 index 00000000..5cd944d8 --- /dev/null +++ b/src/openlifu/plan/solution_analysis.py @@ -0,0 +1,39 @@ +from dataclasses import dataclass, field +from typing import Optional, Tuple + +from openlifu.io.dict_conversion import DictMixin + + +@dataclass +class SolutionAnalysis(DictMixin): + mainlobe_pnp_MPa: list[float] = field(default_factory=list) + mainlobe_isppa_Wcm2: list[float] = field(default_factory=list) + mainlobe_ispta_mWcm2: list[float] = field(default_factory=list) + beamwidth_lat_3dB_mm: list[float] = field(default_factory=list) + beamwidth_ax_3dB_mm: list[float] = field(default_factory=list) + beamwidth_lat_6dB_mm: list[float] = field(default_factory=list) + beamwidth_ax_6dB_mm: list[float] = field(default_factory=list) + sidelobe_pnp_MPa: list[float] = field(default_factory=list) + sidelobe_isppa_Wcm2: list[float] = field(default_factory=list) + global_pnp_MPa: list[float] = field(default_factory=list) + global_isppa_Wcm2: list[float] = field(default_factory=list) + p0_Pa: list[float] = field(default_factory=list) + TIC: Optional[float] = None + power_W: Optional[float] = None + MI: Optional[float] = None + global_ispta_mWcm2: Optional[float] = None + + +@dataclass +class SolutionAnalysisOptions(DictMixin): + standoff_sound_speed: float = 1500.0 + standoff_density: float = 1000.0 + ref_sound_speed: float = 1500.0 + ref_density: float = 1000.0 + focus_diameter: float = 0.5 + mainlobe_aspect_ratio: Tuple[float, float, float] = (1., 1., 5.) + mainlobe_radius: float = 2.5e-3 + beamwidth_radius: float = 5e-3 + sidelobe_radius: float = 3e-3 + sidelobe_zmin: float = 1e-3 + distance_units: str = "m" diff --git a/src/openlifu/plan/target_constraints.py b/src/openlifu/plan/target_constraints.py new file mode 100644 index 00000000..53ba0f5d --- /dev/null +++ b/src/openlifu/plan/target_constraints.py @@ -0,0 +1,37 @@ +import logging +from dataclasses import dataclass + +from openlifu.io.dict_conversion import DictMixin + + +@dataclass +class TargetConstraints(DictMixin): + """ A class for storing target constraints. + + Target constraints are used to define the acceptable range of + positions for a target. For example, a target constraint could + be used to define the acceptable range of values for the x position + of a target. + """ + + dim: str = "x" + """The dimension ID being constrained""" + + name: str = "dim" + """The name of the dimension being constrained""" + + units: str = "m" + """The units of the dimension being constrained""" + + min: float = float("-inf") + """The minimum value of the dimension""" + + max: float = float("inf") + """The maximum value of the dimension""" + + def check_bounds(self, pos: float): + """Check if the given position is within bounds.""" + + if (pos < self.min) or (pos > self.max): + logging.error(msg=f"The position {pos} at dimension {self.name} is not within bounds [{self.min}, {self.max}]!") + raise ValueError(f"The position {pos} at dimension {self.name} is not within bounds [{self.min}, {self.max}]!") diff --git a/src/openlifu/seg/seg_methods/__init__.py b/src/openlifu/seg/seg_methods/__init__.py index 99a1cfa0..6abd1b94 100644 --- a/src/openlifu/seg/seg_methods/__init__.py +++ b/src/openlifu/seg/seg_methods/__init__.py @@ -1,6 +1,5 @@ from . import seg_method from .seg_method import SegmentationMethod, UniformSegmentation -from .segment_mri import SegmentMRI from .tissue import Tissue from .water import Water @@ -10,5 +9,4 @@ "UniformSegmentation", "Water", "Tissue", - "SegmentMRI", ] diff --git a/src/openlifu/seg/seg_methods/segment_mri.py b/src/openlifu/seg/seg_methods/segment_mri.py deleted file mode 100644 index f054fcc7..00000000 --- a/src/openlifu/seg/seg_methods/segment_mri.py +++ /dev/null @@ -1,11 +0,0 @@ -from dataclasses import dataclass - -import xarray as xa - -from openlifu.seg.seg_methods.seg_method import SegmentationMethod - - -@dataclass -class SegmentMRI(SegmentationMethod): - def _segment(self, volume: xa.DataArray): - raise NotImplementedError diff --git a/src/openlifu/sim/sim_setup.py b/src/openlifu/sim/sim_setup.py index 2611f39e..f6ed3e12 100644 --- a/src/openlifu/sim/sim_setup.py +++ b/src/openlifu/sim/sim_setup.py @@ -5,7 +5,9 @@ import numpy as np import xarray as xa +from openlifu.geo import Point from openlifu.io.dict_conversion import DictMixin +from openlifu.seg import SegmentationMethod from openlifu.util.units import getunitconversion from openlifu.xdc import Transducer @@ -100,12 +102,38 @@ def get_max_distance(self, arr: Transducer, units: Optional[str] = None): def get_size(self, dims: Optional[str]=None): dims = self.dims if dims is None else dims - n = [int(np.round(np.diff(ext)/self.spacing))+1 for ext in [self.x_extent, self.y_extent, self.z_extent]] + n = [int((np.round(np.diff(ext)/self.spacing)).item())+1 for ext in [self.x_extent, self.y_extent, self.z_extent]] return np.array([n[self.dims.index(dim)] for dim in dims]).squeeze() def get_spacing(self, units: Optional[str] = None): units = self.units if units is None else units return getunitconversion(self.units, units)*self.spacing - def transform_scene(self, scene, id: Optional[str] = None, name: Optional[str] = None, units: Optional[str] = None): - raise NotImplementedError + def setup_sim_scene( + self, + seg_method: SegmentationMethod, + volume: Optional[xa.DataArray] = None + ) -> Tuple[xa.DataArray, Transducer, Point]: + """ Prepare a simulation scene composed of a simulation grid + + Setup a simulation scene with a simulation grid including physical properties. + A segmentation is performed to detect the medium, so we can assign + physical properties to each voxel, later used by the ultrasound simulation. + This assume that the input volume is resampled to the geo-referenced simulation grid (lon, lat, ele). + + Args: + seg_method: seg.SegmentationMethod + volume: xa.DataArray + Optional volume to be used for simulation grid definition (Default: None). + The volume is assumed to be resampled on sim grid coordinates. + + Returns + params: The xa.DataArray simulation grid with physical properties for each voxel + """ + if volume is None: + sim_coords = self.get_coords() + params = seg_method.ref_params(sim_coords) + else: + params = seg_method.seg_params(volume) + + return params diff --git a/src/openlifu/util/json.py b/src/openlifu/util/json.py index a65cc7a5..78203d29 100644 --- a/src/openlifu/util/json.py +++ b/src/openlifu/util/json.py @@ -6,6 +6,7 @@ from openlifu.db.subject import Subject from openlifu.geo import Point +from openlifu.plan.solution_analysis import SolutionAnalysisOptions from openlifu.seg.material import Material from openlifu.xdc.element import Element from openlifu.xdc.transducer import Transducer @@ -31,6 +32,8 @@ def default(self, obj): return obj.to_dict() if isinstance(obj, Subject): return obj.to_dict() + if isinstance(obj, SolutionAnalysisOptions): + return obj.to_dict() return super().default(obj) def to_json(obj, filename): diff --git a/src/openlifu/util/units.py b/src/openlifu/util/units.py index d03c3b34..a22d89f6 100644 --- a/src/openlifu/util/units.py +++ b/src/openlifu/util/units.py @@ -1,7 +1,6 @@ import numpy as np from xarray import Dataset -#TODO: use Pint (https://github.com/hgrecco/pint) instead to manage physics units in python def getunittype(unit): unit = unit.lower() diff --git a/src/openlifu/xdc/transducer.py b/src/openlifu/xdc/transducer.py index d2dcd957..77c498ac 100644 --- a/src/openlifu/xdc/transducer.py +++ b/src/openlifu/xdc/transducer.py @@ -153,8 +153,6 @@ def rescale(self, units): if self.units != units: for element in self.elements: element.rescale(units) - scl = getunitconversion(self.units, units) - self.matrix[0:3, 3] *= scl self.units = units def to_dict(self): diff --git a/tests/resources/example_db/protocols/example_protocol/example_protocol.json b/tests/resources/example_db/protocols/example_protocol/example_protocol.json index a039be52..7f4362a0 100644 --- a/tests/resources/example_db/protocols/example_protocol/example_protocol.json +++ b/tests/resources/example_db/protocols/example_protocol/example_protocol.json @@ -16,6 +16,7 @@ }, "focal_pattern": { "target_pressure": 1.0e6, + "units": "Pa", "class": "SinglePoint" }, "delay_method": { @@ -87,7 +88,19 @@ "t_end": 0, "options": {} }, - "param_constraints": {}, - "target_constraints": {}, - "analysis_options": {} + "param_constraints": [], + "target_constraints": [], + "analysis_options": { + "standoff_sound_speed": 1500.0, + "standoff_density": 1000.0, + "ref_sound_speed": 1500.0, + "ref_density": 1000.0, + "focus_diameter": 0.5, + "mainlobe_aspect_ratio": [1.0, 1.0, 5.0], + "mainlobe_radius": 2.5e-3, + "beamwidth_radius": 5e-3, + "sidelobe_radius": 3e-3, + "sidelobe_zmin": 1e-3, + "distance_units": "m" + } } diff --git a/tests/test_protocol.py b/tests/test_protocol.py index d8f73058..3233fe82 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -1,14 +1,31 @@ +import logging from pathlib import Path import pytest -from openlifu import Protocol +from openlifu import Protocol, Transducer +from openlifu.bf.focal_patterns import Wheel +from openlifu.db import Session +from openlifu.plan.protocol import OnPulseMismatchAction +from openlifu.plan.target_constraints import TargetConstraints @pytest.fixture() def example_protocol() -> Protocol: return Protocol.from_file(Path(__file__).parent/'resources/example_db/protocols/example_protocol/example_protocol.json') +@pytest.fixture() +def example_transducer() -> Transducer: + return Transducer.from_file(Path(__file__).parent/"resources/example_db/transducers/example_transducer/example_transducer.json") + +@pytest.fixture() +def example_session() -> Session: + return Session.from_file(Path(__file__).parent/"resources/example_db/subjects/example_subject/sessions/example_session/example_session.json") + +@pytest.fixture() +def example_wheel_pattern() -> Wheel: + return Wheel(num_spokes=6) + @pytest.mark.parametrize("compact_representation", [True, False]) def test_serialize_deserialize_protocol(example_protocol : Protocol, compact_representation: bool): assert example_protocol.from_json(example_protocol.to_json(compact_representation)) == example_protocol @@ -16,3 +33,56 @@ def test_serialize_deserialize_protocol(example_protocol : Protocol, compact_rep def test_default_protocol(): """Ensure it is possible to construct a default protocol""" Protocol() + +@pytest.mark.parametrize( + "target_constraints", + [ + [ + TargetConstraints(dim="P", units="mm", min=0.0, max=float("inf")), + ], + [ + TargetConstraints(dim="P", units="m", min=-0.001, max=0.0), + ], + [ + TargetConstraints(dim="L", units="mm", min=-100.0, max=0.0), + TargetConstraints(dim="P", units="mm", min=-100.0, max=0.0), + TargetConstraints(dim="S", units="mm", min=-100.0, max=-10.0), + ] + ] +) +def test_check_target(example_protocol: Protocol, example_session: Session, target_constraints: TargetConstraints): + """Ensure that the target can be correctly verified.""" + example_protocol.target_constraints = target_constraints + with pytest.raises(ValueError, match="not within bounds"): + example_protocol.check_target(example_session.targets[0]) + +@pytest.mark.parametrize("on_pulse_mismatch", [ + OnPulseMismatchAction.ERROR, + OnPulseMismatchAction.ROUND, + OnPulseMismatchAction.ROUNDUP, + OnPulseMismatchAction.ROUNDDOWN + ] + ) +def test_fix_pulse_mismatch( + example_protocol: Protocol, + example_session: Session, + example_wheel_pattern: Wheel, + on_pulse_mismatch: OnPulseMismatchAction + ): + """Test if sequence is correctly fixed for all pulse mismatch actions.""" + logging.disable(logging.CRITICAL) + + target = example_session.targets[0] + foci = example_wheel_pattern.get_targets(target) + num_foci = len(foci) + if on_pulse_mismatch is OnPulseMismatchAction.ERROR: + with pytest.raises(ValueError, match="not a multiple of the number of foci"): + example_protocol.fix_pulse_mismatch(on_pulse_mismatch, foci) + else: + example_protocol.fix_pulse_mismatch(on_pulse_mismatch, foci) + if on_pulse_mismatch is OnPulseMismatchAction.ROUND: + assert example_protocol.sequence.pulse_count == num_foci + elif on_pulse_mismatch is OnPulseMismatchAction.ROUNDUP: + assert example_protocol.sequence.pulse_count == 2*num_foci + elif on_pulse_mismatch is OnPulseMismatchAction.ROUNDDOWN: + assert example_protocol.sequence.pulse_count == num_foci diff --git a/tests/test_seg_method.py b/tests/test_seg_method.py index c8513b55..18b5b3fa 100644 --- a/tests/test_seg_method.py +++ b/tests/test_seg_method.py @@ -1,3 +1,4 @@ + import pytest from openlifu.seg import Material, SegmentationMethod diff --git a/tests/test_solution.py b/tests/test_solution.py index 88b21ba9..7dc013d7 100644 --- a/tests/test_solution.py +++ b/tests/test_solution.py @@ -7,6 +7,7 @@ from helpers import dataclasses_are_equal from openlifu import Point, Pulse, Sequence, Solution, Transducer +from openlifu.bf.focal_patterns import SinglePoint from openlifu.xdc.element import Element @@ -26,6 +27,11 @@ def example_transducer() -> Transducer: ) +@pytest.fixture() +def example_focal_pattern_single() -> SinglePoint: + return SinglePoint(target_pressure=1.0e6, units="Pa") + + @pytest.fixture() def example_solution() -> Solution: rng = np.random.default_rng(147) @@ -118,8 +124,3 @@ def test_num_foci(example_solution:Solution): assert len(example_solution.simulation_result['focal_point_index']) == num_foci assert example_solution.delays.shape[0] == num_foci assert example_solution.apodizations.shape[0] == num_foci - - -def test_solution_analysis(example_solution: Solution, example_transducer: Transducer): - """Test that a solution output can be analyzed.""" - example_solution.analyze(example_transducer)