diff --git a/src/openlifu/plan/protocol.py b/src/openlifu/plan/protocol.py index b9969fab..f3ee315b 100644 --- a/src/openlifu/plan/protocol.py +++ b/src/openlifu/plan/protocol.py @@ -1,11 +1,20 @@ import json from dataclasses import asdict, dataclass, field +from datetime import datetime from pathlib import Path -from typing import Any, Dict +from typing import TYPE_CHECKING, Any, Dict, Tuple +import numpy as np import xarray as xa from openlifu import bf, geo, seg, sim, xdc +from openlifu.geo import Point +from openlifu.plan.solution import Solution +from openlifu.sim import run_simulation +from openlifu.xdc import Transducer + +if TYPE_CHECKING: + from openlifu.db import Session @dataclass @@ -94,3 +103,76 @@ def to_file(self, filename): Path(filename).parent.mkdir(exist_ok=True) with open(filename, 'w') as file: file.write(self.to_json(compact=False)) + + def calc_solution( + self, + transducer:Transducer, + volume:xa.DataArray, + target: Point, + session:"Optional[Session]"=None, # useful in solution id + ) -> Tuple[Solution, xa.DataArray, xa.DataArray]: + params = self.seg_method.seg_params(volume) + delays_to_stack : List[np.ndarray] = [] + apodizations_to_stack : List[np.ndarray] = [] + simulation_outputs_to_stack : List[xa.Dataset] = [] + target_pattern_points : List[Point] = self.focal_pattern.get_targets(target) + for focus_point in target_pattern_points: + delays, apodization = self.beamform(arr=transducer, target=focus_point, params=params) + + simulation_output_xarray, simulation_output_kwave = run_simulation( + arr=transducer, + params=params, + delays=delays, + apod= apodization, + freq = self.pulse.frequency, + cycles = np.max([np.round(self.pulse.duration * self.pulse.frequency), 20]), + dt=self.sim_setup.dt, + t_end=self.sim_setup.t_end, + amplitude = 1, + gpu = False + ) + + delays_to_stack.append(delays) + apodizations_to_stack.append(apodization) + simulation_outputs_to_stack.append(simulation_output_xarray) + + simulation_output_stacked = xa.concat( + [ + sim.assign_coords(focal_point_index=i) + for i,sim in enumerate(simulation_outputs_to_stack) + ], + dim='focal_point_index', + ) + + # Peak negative pressure volume, a simulation output. This is max-aggregated over all focus points. + pnp_aggregated = simulation_output_stacked['p_min'].max(dim="focal_point_index") + + # Mean-aggregate the intensity over the focus points + # TODO: Ensure this mean is weighted by the number of times each point is focused on, once openlifu supports hitting points different numbers of times + intensity_aggregated = simulation_output_stacked['ita'].mean(dim="focal_point_index") + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f") + solution_id = timestamp + if session is not None: + solution_id = f"{session.id}_{solution_id}" + solution = Solution( + id=solution_id, + name=f"Solution {timestamp}", + protocol_id=self.id, + transducer_id=transducer.id, + delays=np.stack(delays_to_stack, axis=0), + apodizations=np.stack(apodizations_to_stack, axis=0), + pulse=self.pulse, # TODO This pulse needs to be scaled via a port of scale_solution from matlab!! + sequence=self.sequence, # TODO is it correct to set the sequence the same as the protocol's here? + foci=target_pattern_points, + target=target, + simulation_result=simulation_output_stacked, + approved=False, + description= ( + f"A solution computed for the {self.name} protocol with transducer {transducer.name}" + f" for subject volume [TODO]" # TODO put volume ID here if it is not None, once Sadhana's PR #123 is merged + f" for target {target.id}." + f" This solution was created for the session {session.id} for subject {session.subject_id}." if session is not None else "" + ) + ) + return solution, pnp_aggregated, intensity_aggregated