diff --git a/db_dvc.dvc b/db_dvc.dvc index 609a445f..10ee7c1d 100644 --- a/db_dvc.dvc +++ b/db_dvc.dvc @@ -1,6 +1,6 @@ outs: -- md5: 02aa9b3bcbf45464ed81942af6dd21c7.dir - size: 21806971 +- md5: 272d36e8069e0b15bd61da82a8a4c5b9.dir + size: 21808282 nfiles: 36 hash: md5 path: db_dvc diff --git a/src/openlifu/plan/solution.py b/src/openlifu/plan/solution.py index 5b9209c0..7560fbb2 100644 --- a/src/openlifu/plan/solution.py +++ b/src/openlifu/plan/solution.py @@ -3,7 +3,7 @@ from dataclasses import asdict, dataclass, field from datetime import datetime from pathlib import Path -from typing import Optional +from typing import List, Optional import numpy as np import xarray as xa @@ -82,16 +82,22 @@ class Solution: """Description of this solution""" delays: Optional[np.ndarray] = None - """Vector of time delays to steer the beam""" + """Vectors of time delays to steer the beam. Shape is (number of foci, number of transducer elements).""" apodizations: Optional[np.ndarray] = None - """Vector of apodizations to steer the beam""" + """Vectors of apodizations to steer the beam. Shape is (number of foci, number of transducer elements).""" + pulse: Pulse = field(default_factory=Pulse) """Pulse to send to the transducer when running sonication""" + sequence: Sequence = field(default_factory=Sequence) """Pulse sequence to use when running sonication""" - focus: Optional[Point] = None - """Point that is being focused on in this Solution; part of the focal pattern of the target""" + + foci: List[Point] = field(default_factory=list) + """Points that are focused on in this Solution due to the focal pattern around the target. + Each item in this list is a unique point from the focal pattern, and the pulse sequence is + what determines how many times each point will be used. + """ # there was "target_id" in the matlab software, but here we do not have the concept of a target ID. # I believe this was only needed in the matlab software because solutions were organized by target rather @@ -110,16 +116,9 @@ class Solution: """Approval state of this solution as a sonication plan. `True` means the user has provided some kind of confirmation that the solution is safe and acceptable to be executed.""" - def num_foci(self): + def num_foci(self) -> int: """Get the number of foci""" - if isinstance(self.focus, list): - nfoc = len(self.focus) - elif isinstance(self.focus, Point): - nfoc = 1 - else: - raise ValueError("Cannot get number of foci for types other than Point.") - - return nfoc + return len(self.foci) def analyze(self, transducer: Transducer, options: SolutionOptions = SolutionOptions()) -> SolutionAnalysis: """Analyzes the treatment solution. @@ -166,10 +165,7 @@ def analyze(self, transducer: Transducer, options: SolutionOptions = SolutionOpt # power_W = np.zeros(self.num_foci()) # TIC = np.zeros(self.num_foci()) for focus_index in range(self.num_foci()): - if isinstance(self.focus, list): - foc = self.focus[focus_index] - elif isinstance(self.focus, Point): - foc = self.focus + foc = self.foci[focus_index] # output_signal = [] # output_signal = np.zeros((transducer.numelements(), len(input_signal))) # for i in range(transducer.numelements()): @@ -203,8 +199,8 @@ def analyze(self, transducer: Transducer, options: SolutionOptions = SolutionOpt # distance=options.beamwidth_radius, # options=mask_options) - pk = np.max(pnp_MPa.data * mainlobe_mask) #TODO: pnp_MPa supposed to be a list for each focus: pnp_MPa(focus_index) - solution_analysis.mainlobe_pnp_MPa = pk + pk = np.max(pnp_MPa.data[focus_index] * mainlobe_mask) #TODO: pnp_MPa supposed to be a list for each focus: pnp_MPa(focus_index) + solution_analysis.mainlobe_pnp_MPa += [pk] # thresh_m3dB = pk*10**(-3 / 20) # thresh_m6dB = pk*10**(-6 / 20) @@ -357,10 +353,13 @@ def from_json(json_string : str, simulation_result: Optional[xa.Dataset]=None) - solution_dict["delays"] = np.array(solution_dict["delays"]) if solution_dict["apodizations"] is not None: solution_dict["apodizations"] = np.array(solution_dict["apodizations"], ndmin=2) + solution_dict["apodizations"] = np.array(solution_dict["apodizations"], ndmin=2) solution_dict["pulse"] = Pulse.from_dict(solution_dict["pulse"]) solution_dict["sequence"] = Sequence.from_dict(solution_dict["sequence"]) - if solution_dict["focus"] is not None: - solution_dict["focus"] = Point.from_dict(solution_dict["focus"]) #TODO: Solution analysis needs a list, to interface with FocalPattern ? + solution_dict["foci"] = [ + Point.from_dict(focus_dict) + for focus_dict in solution_dict["foci"] + ] if solution_dict["target"] is not None: solution_dict["target"] = Point.from_dict(solution_dict["target"]) diff --git a/src/openlifu/util/units.py b/src/openlifu/util/units.py index 88c9e0eb..d03c3b34 100644 --- a/src/openlifu/util/units.py +++ b/src/openlifu/util/units.py @@ -211,11 +211,12 @@ def rescale_coords(data_arr: Dataset, units: str) -> Dataset: rescaled = data_arr.copy(deep=True) for coord_key in data_arr.coords: curr_coord_attrs = rescaled[coord_key].attrs - curr_coord_units = curr_coord_attrs['units'] - scale = getunitconversion(curr_coord_units, units) - curr_coord_rescaled = scale*rescaled[coord_key].data - rescaled = rescaled.assign_coords({coord_key: (coord_key, curr_coord_rescaled, curr_coord_attrs)}) - rescaled[coord_key].attrs['units'] = units + if 'units' in curr_coord_attrs: + curr_coord_units = curr_coord_attrs['units'] + scale = getunitconversion(curr_coord_units, units) + curr_coord_rescaled = scale*rescaled[coord_key].data + rescaled = rescaled.assign_coords({coord_key: (coord_key, curr_coord_rescaled, curr_coord_attrs)}) + rescaled[coord_key].attrs['units'] = units return rescaled @@ -235,7 +236,8 @@ def get_ndgrid_from_arr(data_arr: Dataset) -> np.ndarray: ordered_key = data_arr[first_data_key].dims all_coord = [] for coord_key in ordered_key: - all_coord += [data_arr.coords[coord_key].data] + if 'units' in data_arr[coord_key].attrs: + all_coord += [data_arr.coords[coord_key].data] ndgrid = np.stack(np.meshgrid(*all_coord, indexing="ij"), axis=-1) return ndgrid diff --git a/tests/resources/example_db/subjects/example_subject/sessions/example_session/solutions/example_solution/example_solution.json b/tests/resources/example_db/subjects/example_subject/sessions/example_session/solutions/example_solution/example_solution.json index e5ffcabe..66a0b39b 100644 --- a/tests/resources/example_db/subjects/example_subject/sessions/example_session/solutions/example_solution/example_solution.json +++ b/tests/resources/example_db/subjects/example_subject/sessions/example_session/solutions/example_solution/example_solution.json @@ -6,32 +6,36 @@ "created_on": "2024-01-30T09:18:11", "description": "Example plan created 30-Jan-2024 09:16:02", "delays": [ - 7.139258974920841e-7, 1.164095583074107e-6, 1.4321977043056202e-6, - 1.5143736925332023e-6, 1.4094191657896087e-6, 1.1188705305994071e-6, - 6.46894718323139e-7, 0.0, 1.2683760653622907e-6, 1.7251583088871662e-6, - 1.9972746364594766e-6, 2.0806926330138847e-6, 1.974152793990993e-6, - 1.6792617729461087e-6, 1.2003736682835394e-6, 5.442792226777689e-7, - 1.6425243839760022e-6, 2.1038804054306437e-6, 2.378775260158848e-6, - 2.4630532471222807e-6, 2.3554157329476133e-6, 2.0575192329568598e-6, - 1.5738505533232074e-6, 9.114003043615498e-7, 1.8309933020916417e-6, - 2.29468834620898e-6, 2.571004767271362e-6, 2.655722845121427e-6, - 2.5475236155907678e-6, 2.248089500472459e-6, 1.7619762095269915e-6, - 1.0962779929139644e-6, 1.8309933020916417e-6, 2.29468834620898e-6, - 2.571004767271362e-6, 2.655722845121427e-6, 2.5475236155907678e-6, - 2.248089500472459e-6, 1.7619762095269915e-6, 1.0962779929139644e-6, - 1.6425243839760022e-6, 2.1038804054306437e-6, 2.378775260158848e-6, - 2.4630532471222807e-6, 2.3554157329476133e-6, 2.0575192329568598e-6, - 1.5738505533232074e-6, 9.114003043615498e-7, 1.2683760653622907e-6, - 1.7251583088871662e-6, 1.9972746364594766e-6, 2.0806926330138847e-6, - 1.974152793990993e-6, 1.6792617729461087e-6, 1.2003736682835394e-6, - 5.442792226777689e-7, 7.139258974920841e-7, 1.164095583074107e-6, - 1.4321977043056202e-6, 1.5143736925332023e-6, 1.4094191657896087e-6, - 1.1188705305994071e-6, 6.46894718323139e-7, 0.0 + [ + 7.139258974920841e-7, 1.164095583074107e-6, 1.4321977043056202e-6, + 1.5143736925332023e-6, 1.4094191657896087e-6, 1.1188705305994071e-6, + 6.46894718323139e-7, 0.0, 1.2683760653622907e-6, 1.7251583088871662e-6, + 1.9972746364594766e-6, 2.0806926330138847e-6, 1.974152793990993e-6, + 1.6792617729461087e-6, 1.2003736682835394e-6, 5.442792226777689e-7, + 1.6425243839760022e-6, 2.1038804054306437e-6, 2.378775260158848e-6, + 2.4630532471222807e-6, 2.3554157329476133e-6, 2.0575192329568598e-6, + 1.5738505533232074e-6, 9.114003043615498e-7, 1.8309933020916417e-6, + 2.29468834620898e-6, 2.571004767271362e-6, 2.655722845121427e-6, + 2.5475236155907678e-6, 2.248089500472459e-6, 1.7619762095269915e-6, + 1.0962779929139644e-6, 1.8309933020916417e-6, 2.29468834620898e-6, + 2.571004767271362e-6, 2.655722845121427e-6, 2.5475236155907678e-6, + 2.248089500472459e-6, 1.7619762095269915e-6, 1.0962779929139644e-6, + 1.6425243839760022e-6, 2.1038804054306437e-6, 2.378775260158848e-6, + 2.4630532471222807e-6, 2.3554157329476133e-6, 2.0575192329568598e-6, + 1.5738505533232074e-6, 9.114003043615498e-7, 1.2683760653622907e-6, + 1.7251583088871662e-6, 1.9972746364594766e-6, 2.0806926330138847e-6, + 1.974152793990993e-6, 1.6792617729461087e-6, 1.2003736682835394e-6, + 5.442792226777689e-7, 7.139258974920841e-7, 1.164095583074107e-6, + 1.4321977043056202e-6, 1.5143736925332023e-6, 1.4094191657896087e-6, + 1.1188705305994071e-6, 6.46894718323139e-7, 0.0 + ] ], "apodizations": [ - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 + [ + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 + ] ], "pulse": { "frequency": 500000, @@ -44,15 +48,17 @@ "pulse_train_interval": 1, "pulse_train_count": 1 }, - "focus": { - "id": "example_target", - "name": "Example Target", - "color": [1.0, 0.0, 0.0], - "radius": 0.001, - "position": [0.0, -0.0022437460888595447, 0.05518120697745499], - "dims": ["lat", "ele", "ax"], - "units": "m" - }, + "foci": [ + { + "id": "example_target", + "name": "Example Target", + "color": [1.0, 0.0, 0.0], + "radius": 0.001, + "position": [0.0, -0.0022437460888595447, 0.05518120697745499], + "dims": ["lat", "ele", "ax"], + "units": "m" + } + ], "target": { "id": "example_target", "name": "Example Target", diff --git a/tests/resources/example_db/subjects/example_subject/sessions/example_session/solutions/example_solution/example_solution.nc b/tests/resources/example_db/subjects/example_subject/sessions/example_session/solutions/example_solution/example_solution.nc index 70853f8a..1dd26100 100644 Binary files a/tests/resources/example_db/subjects/example_subject/sessions/example_session/solutions/example_solution/example_solution.nc and b/tests/resources/example_db/subjects/example_subject/sessions/example_session/solutions/example_solution/example_solution.nc differ diff --git a/tests/test_database.py b/tests/test_database.py index dda23b73..380819b3 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -283,6 +283,11 @@ def test_load_solution(example_database:Database, example_session:Session): assert example_solution.name == "Example Solution" assert "p_min" in example_solution.simulation_result.data_vars # ensure the xarray dataset got loaded too + # ensure the simulation and beamform data was loaded for all foci + assert len(example_solution.simulation_result['focal_point_index']) == len(example_solution.foci) + assert example_solution.delays.shape[0] == len(example_solution.foci) + assert example_solution.apodizations.shape[0] == len(example_solution.foci) + def test_write_solution(example_database:Database, example_session:Session): solution = Solution(name="bleh", id='new_solution') diff --git a/tests/test_solution.py b/tests/test_solution.py index 275c7b8b..88b21ba9 100644 --- a/tests/test_solution.py +++ b/tests/test_solution.py @@ -36,34 +36,35 @@ def example_solution() -> Solution: transducer_id="trans_456", created_on=datetime(2024, 1, 1, 12, 0), description="This is a test solution for a unit test.", - delays=np.array([0.0, 1.0, 2.0, 3.0]), - apodizations=np.array([0.5, 0.75, 1.0, 0.85]), + delays=np.array([[0.0, 1.0, 2.0, 3.0]]), + apodizations=np.array([[0.5, 0.75, 1.0, 0.85]]), pulse=Pulse(frequency=42), sequence=Sequence(pulse_count=27), - focus=Point(id="test_focus_point"), + foci=[Point(id="test_focus_point")], target=Point(id="test_target_point"), simulation_result=xa.Dataset( { 'p_min': xa.DataArray( - data=rng.random((3, 2, 3)), - dims=["x", "y", "z"], + data=rng.random((1, 3, 2, 3)), + dims=["focal_point_index", "x", "y", "z"], attrs={'units': "Pa"} ), 'p_max': xa.DataArray( - data=rng.random((3, 2, 3)), - dims=["x", "y", "z"], + data=rng.random((1, 3, 2, 3)), + dims=["focal_point_index", "x", "y", "z"], attrs={'units': "Pa"} ), 'ita': xa.DataArray( - data=rng.random((3, 2, 3)), - dims=["x", "y", "z"], + data=rng.random((1, 3, 2, 3)), + dims=["focal_point_index", "x", "y", "z"], attrs={'units': "W/cm^2"} ) }, coords={ 'x': xa.DataArray(dims=["x"], data=np.linspace(0, 1, 3), attrs={'units': "m"}), 'y': xa.DataArray(dims=["y"], data=np.linspace(0, 1, 2), attrs={'units': "m"}), - 'z': xa.DataArray(dims=["z"], data=np.linspace(0, 1, 3), attrs={'units': "m"}) + 'z': xa.DataArray(dims=["z"], data=np.linspace(0, 1, 3), attrs={'units': "m"}), + 'focal_point_index': [0] } ), ) @@ -110,6 +111,15 @@ def test_save_load_solution_custom_dataset_filepath(example_solution: Solution, assert dataclasses_are_equal(Solution.from_files(json_filepath, nc_filepath), example_solution) +def test_num_foci(example_solution:Solution): + """Ensure that the number of foci in the test solution matches the number of foci provided in the simuluation and beamform data.""" + num_foci = example_solution.num_foci() + assert len(example_solution.foci) == num_foci + assert len(example_solution.simulation_result['focal_point_index']) == num_foci + assert example_solution.delays.shape[0] == num_foci + assert example_solution.apodizations.shape[0] == num_foci + + def test_solution_analysis(example_solution: Solution, example_transducer: Transducer): """Test that a solution output can be analyzed.""" example_solution.analyze(example_transducer)