Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
f8901da
Port over solution creation and scaling code (#126)
ltetrel Nov 4, 2024
57bb9d1
Protocol objects are now correctly json serializable.
ltetrel Nov 5, 2024
0b327a1
Fixed numpy deprecation warning.
ltetrel Nov 5, 2024
f0be40e
Refactoring `Protocol` so `scale_solution` is now a member of the `So…
ltetrel Nov 5, 2024
dbef2cf
Merge branch 'main' into issue_126
ltetrel Nov 5, 2024
edb303c
Fixed simulation output variable handling if simulation not enabled.
ltetrel Nov 5, 2024
150d310
\'setup_sim_scene\' is now the minimal implementation
ltetrel Nov 6, 2024
f787c42
Removing unused references related to segmentation and resampling fro…
ltetrel Nov 6, 2024
580ad0c
transform argument not needed currently since we use a default water …
ltetrel Nov 6, 2024
06942c2
Tests for check_targets
ltetrel Nov 6, 2024
df1ca69
` example_volume` is left commented for reference
ltetrel Nov 6, 2024
5b25419
If we need to convert a transform, we want to use the method form `Tr…
ltetrel Nov 6, 2024
911d535
Checking for pulse mismatch and updating accordingly the `Protocol` s…
ltetrel Nov 6, 2024
5ef5737
Refactored `Solution.scale` and added associated tests.
ltetrel Nov 6, 2024
2cdc725
rationalizing constants in `Solution.scale` using the new `FocalPatte…
ltetrel Nov 6, 2024
1826c05
Clarified docstring for `Protocol.calc_solution`
ltetrel Nov 6, 2024
b706923
Improved docstring for `Protocol.check_target`
ltetrel Nov 6, 2024
5dcf6ae
Fixed typing if None default
ltetrel Nov 6, 2024
fb3a0ac
now segmenting volume as well if volume is defined
ltetrel Nov 6, 2024
ea524b6
Refactored `setup_sim_scene` so it can also segment a volume.
ltetrel Nov 6, 2024
0670968
fix sim setup call
ltetrel Nov 6, 2024
0e2b74a
logging now through an attribute of `Protocol`
ltetrel Nov 6, 2024
2b519b0
better testing coverage for target verification
ltetrel Nov 6, 2024
8f55507
pytest raise instead of fail
ltetrel Nov 6, 2024
c26600d
test formating and better condition test
ltetrel Nov 7, 2024
9ef9f02
Re-worked fix_pulse_mismatch test for better rationalization
ltetrel Nov 7, 2024
3311c86
Some tests will be handled in the future in #152 since they require s…
ltetrel Nov 7, 2024
a1b39a6
Cleaning old nifti to xarray tests
ltetrel Nov 7, 2024
912a733
Adding `Solution.compute_scaling_factors\' to further break down and …
ltetrel Nov 7, 2024
15a6d05
Cleaning some TODOs
ltetrel Nov 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/openlifu/bf/focal_patterns/focal_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@ class FocalPattern(ABC):
"""
Abstract base class for representing a focal pattern

:ivar target_pressure: Target pressure of the focal pattern in Pa
:ivar target_pressure: Target pressure of the focal pattern in given units
"""
target_pressure: float = 1.0 # Pa
target_pressure: float = 1.0
units: str = "Pa"

@abstractmethod
def get_targets(self, target: Point):
Expand Down
2 changes: 1 addition & 1 deletion src/openlifu/bf/get_beamwidth.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def get_beamwidth(vol: DataArray, coords_units: str, focus: Point, cutoff: float
inlier_hull = ConvexHull(inlier_points)
except QhullError:
# If convex hull creation fails (e.g., too few points), add jitter and try again
logging.warning("Invalid inliers, attempting to add jitter to create a valid volume...") #TODO: should be using self.logger
logging.warning("Invalid inliers, attempting to add jitter to create a valid volume...")
minmax_coords = np.array([(np.min(coords[i]), np.max(coords[i])) for i in range(len(coords))]) #TODO: min-max should be from coords.extent
coords_shape = tuple([len(coords[i]) for i in range(len(coords))])
dx = np.mean(np.diff(minmax_coords) / (np.array(coords_shape) - 1))
Expand Down
5 changes: 5 additions & 0 deletions src/openlifu/plan/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from .protocol import Protocol
from .run import Run
from .solution import Solution
from .solution_analysis import SolutionAnalysis, SolutionAnalysisOptions
from .target_constraints import TargetConstraints

__all__ = [
"Protocol",
"Solution",
"Run",
"SolutionAnalysis",
"SolutionAnalysisOptions",
"TargetConstraints"
]
221 changes: 213 additions & 8 deletions src/openlifu/plan/protocol.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,27 @@
import json
import logging
import math
from copy import deepcopy
from dataclasses import asdict, dataclass, field
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Any, Dict
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import xarray as xa

from openlifu import bf, geo, seg, sim, xdc
from openlifu.db.session import Session
from openlifu.geo import Point
from openlifu.plan.solution import Solution
from openlifu.plan.solution_analysis import SolutionAnalysis, SolutionAnalysisOptions
from openlifu.plan.target_constraints import TargetConstraints
from openlifu.sim import run_simulation
from openlifu.util.json import PYFUSEncoder
from openlifu.xdc import Transducer

OnPulseMismatchAction = Enum("OnPulseMismatchAction", ["ERROR", "ROUND", "ROUNDUP", "ROUNDDOWN"])


@dataclass
Expand All @@ -20,9 +36,12 @@ class Protocol:
delay_method: bf.DelayMethod = field(default_factory=bf.delay_methods.Direct)
apod_method: bf.ApodizationMethod = field(default_factory=bf.apod_methods.Uniform)
seg_method: seg.SegmentationMethod = field(default_factory=seg.seg_methods.Water)
param_constraints: dict = field(default_factory=dict)
target_constraints: dict = field(default_factory=dict)
analysis_options: dict = field(default_factory=dict)
param_constraints: dict = field(default_factory=dict) #TODO: this seems to be used only in `plan.check_analysis` but not called anywhere
target_constraints: List[TargetConstraints] = field(default_factory=list)
analysis_options: SolutionAnalysisOptions = field(default_factory=SolutionAnalysisOptions)

def __post_init__(self):
self.logger = logging.getLogger(__name__)

@staticmethod
def from_dict(d : Dict[str,Any]) -> "Protocol":
Expand All @@ -36,6 +55,11 @@ def from_dict(d : Dict[str,Any]) -> "Protocol":
if "materials" in d:
seg_method_dict["materials"] = seg.Material.from_dict(d.pop("materials"))
d["seg_method"] = seg.SegmentationMethod.from_dict(seg_method_dict)
d['param_constraints'] = d.get("param_constraints", {})
if "target_constraints" in d:
d['target_constraints'] = [TargetConstraints.from_dict(d_tc) for d_tc in d.get("target_constraints", {})]
if "analysis_options" in d:
d['analysis_options'] = SolutionAnalysisOptions.from_dict(d.get("analysis_options"))
return Protocol(**d)

def to_dict(self):
Expand Down Expand Up @@ -80,17 +104,198 @@ def to_json(self, compact:bool) -> str:
Returns: A json string representing the complete Protocol object.
"""
if compact:
return json.dumps(self.to_dict(), separators=(',', ':'))
return json.dumps(self.to_dict(), separators=(',', ':'), cls=PYFUSEncoder)
else:
return json.dumps(self.to_dict(), indent=4)
return json.dumps(self.to_dict(), indent=4, cls=PYFUSEncoder)

def to_file(self, filename):
def to_file(self, filename: str):
"""
Save the protocol to a file

:param filename: Name of the file
Args:
filename: Name of the file
"""
Path(filename).parent.parent.mkdir(exist_ok=True)
Path(filename).parent.mkdir(exist_ok=True)
with open(filename, 'w') as file:
file.write(self.to_json(compact=False))


def check_target(self, target: Point):
"""
Check if a target is within bounds, raising an exception if it isn't.

Args:
target: The geo.Point target to check.
"""
if isinstance(target, list):
raise ValueError(f"Input target {target} not supposed to be a list!")

# check if target position is within target_constraints defined bounds.
for target_constraint in self.target_constraints:
pos = target.get_position(
dim=target_constraint.dim,
units=target_constraint.units
)
target_constraint.check_bounds(pos)

def fix_pulse_mismatch(self, on_pulse_mismatch: OnPulseMismatchAction, foci: List[Point]):
"""Fix the protocol sequence pulse count in-place given a pulse_mismatch action."""
if on_pulse_mismatch is OnPulseMismatchAction.ERROR:
raise ValueError(f"Pulse Count {self.sequence.pulse_count} is not a multiple of the number of foci {len(foci)}")
else:
if on_pulse_mismatch is OnPulseMismatchAction.ROUND:
self.sequence.pulse_count = round(self.sequence.pulse_count / len(foci)) * len(foci)
elif on_pulse_mismatch is OnPulseMismatchAction.ROUNDUP:
self.sequence.pulse_count = math.ceil(self.sequence.pulse_count / len(foci)) * len(foci)
elif on_pulse_mismatch is OnPulseMismatchAction.ROUNDDOWN:
self.sequence.pulse_count = math.floor(self.sequence.pulse_count / len(foci)) * len(foci)
self.logger.warning(
f"Pulse Count {self.sequence.pulse_count} is not a multiple of the number of foci {len(foci)}."
f"Rounding to {self.sequence.pulse_count}."
)

def calc_solution(
self,
target: Point,
transducer: Transducer,
volume: Optional[xa.DataArray] = None,
session: Optional[Session] = None,
simulate: bool = True,
Comment thread
ebrahimebrahim marked this conversation as resolved.
scale: bool = True,
sim_options: Optional[sim.SimSetup] = None,
analysis_options: Optional[SolutionAnalysisOptions] = None,
on_pulse_mismatch: OnPulseMismatchAction = OnPulseMismatchAction.ERROR
) -> Tuple[Solution, xa.DataArray, SolutionAnalysis]:
"""Calculate the solution and aggregated k-wave simulation outputs.

Method that computes the delays and apodizations for each focus in the treatment plan,
simulates the resulting pressure field to adjust transmit pressures to reach target pressures,
and then analyzes the resulting pressure field to compute the resulting acoustic parameters.

Args:
target: The target Point.
Target is expected to be in the simulation grid coordinates (lat, ele, ax).
transducer: A Transducer item.
volume: xa.DataArray
The subject scan (Default: None).
It is expected to be in the simulation grid coordinates (lat, ele, ax).
If None, a default simulation grid will be used.
session: db.Session
A session used to define solution_id (Default: None).
simulate: bool
Enable solution simulation (Default: true).
scale: bool
Triggers solution and simulation scaling to the requested pressure (Default: true).
sim_options : sim.SimSetup
The options for the k-wave simulation (Default: self.sim_setup).
analysis_options: plan.solution.SolutionAnalysisOptions
The options for the solution analysis (Default: self.analysis_options).
on_pulse_mismatch: plan.protocol.OnPulseMismatchAction
An action to take if the number of pulses in the sequence does not match
the number of foci (Default: OnPulseMismatchAction.ERROR).

Returns:
solution: Solution
simulation_result_aggregated: xa.Dataset
If simulation is enabled, then this is the resulting aggregated
output (max pressure and mean intensity over all foci).
scaled_solution_analysis: SolutionAnalysis
This is the resulting rescaled analysis, if scale is enabled.
"""
if sim_options is None:
sim_options = self.sim_setup
if analysis_options is None:
analysis_options = self.analysis_options
# check before if target is within bounds
self.check_target(target)
params = sim_options.setup_sim_scene(self.seg_method, volume=volume)

delays_to_stack: List[np.ndarray] = []
apodizations_to_stack: List[np.ndarray] = []
simulation_outputs_to_stack: List[xa.Dataset] = []
simulation_output_stacked: xa.Dataset = xa.Dataset()
simulation_result_aggregated: xa.Dataset = xa.Dataset()
scaled_solution_analysis: SolutionAnalysis = SolutionAnalysis()
foci: List[Point] = self.focal_pattern.get_targets(target)
simulation_cycles = np.max([np.round(self.pulse.duration * self.pulse.frequency), 20])

# updating solution sequence if pulse mismatch
if (self.sequence.pulse_count % len(foci)) != 0:
self.fix_pulse_mismatch(on_pulse_mismatch, foci)
# run simulation and aggregate the results
for focus in foci:
self.logger.info(f"Beamform for focus {focus}...")
delays, apodization = self.beamform(arr=transducer, target=focus, params=params)
simulation_output_xarray = None
if simulate:
self.logger.info(f"Simulate for focus {focus}...")
simulation_output_xarray, _ = run_simulation(
arr=transducer,
params=params,
delays=delays,
apod= apodization,
freq = self.pulse.frequency,
cycles = simulation_cycles,
dt=sim_options.dt,
t_end=sim_options.t_end,
amplitude = 1,
gpu = False
)
delays_to_stack.append(delays)
apodizations_to_stack.append(apodization)
simulation_outputs_to_stack.append(simulation_output_xarray)
if simulate:
simulation_output_stacked = xa.concat(
[
sim.assign_coords(focal_point_index=i)
for i, sim in enumerate(simulation_outputs_to_stack)
],
dim='focal_point_index',
)
# instantiate and return the solution
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
solution_id = timestamp
if session is not None:
solution_id = f"{session.id}_{solution_id}"
solution = Solution(
id=solution_id,
name=f"Solution {timestamp}",
protocol_id=self.id,
transducer_id=transducer.id,
delays=np.stack(delays_to_stack, axis=0),
apodizations=np.stack(apodizations_to_stack, axis=0),
pulse=self.pulse,
sequence=self.sequence,
foci=foci,
target=target,
simulation_result=simulation_output_stacked,
approved=False,
description= (
f"A solution computed for the {self.name} protocol with transducer {transducer.name}"
f" for target {target.id}."
f" This solution was created for the session {session.id} for subject {session.subject_id}." if session is not None else ""
)
)
# optionally scale the solution with simulation result
if scale:
if not simulate:
self.logger.error(msg=f"Cannot scale solution {solution.id} if simulation is not enabled!")
raise ValueError(f"Cannot scale solution {solution.id} if simulation is not enabled!")
self.logger.info(f"Scaling solution {solution.id}...")
#TODO can analysis be an attribute of solution ?
scaled_solution_analysis = solution.scale(transducer, self.focal_pattern, analysis_options=analysis_options)
Comment thread
ebrahimebrahim marked this conversation as resolved.

if simulate:
# Finally the resulting pressure is max-aggregated and intensity is mean-aggregated, over all focus points .
pnp_aggregated = solution.simulation_result['p_min'].max(dim="focal_point_index")
ppp_aggregated = solution.simulation_result['p_max'].max(dim="focal_point_index")
# TODO: Ensure this mean is weighted by the number of times each point is focused on, once openlifu supports hitting points different numbers of times
intensity_aggregated = solution.simulation_result['ita'].mean(dim="focal_point_index")
simulation_result_aggregated = deepcopy(solution.simulation_result)
simulation_result_aggregated = simulation_result_aggregated.drop_dims("focal_point_index")
simulation_result_aggregated['p_min'] = pnp_aggregated
simulation_result_aggregated['p_max'] = ppp_aggregated
simulation_result_aggregated['ita'] = intensity_aggregated

return solution, simulation_result_aggregated, scaled_solution_analysis
Loading