Skip to content
Merged
4 changes: 2 additions & 2 deletions src/openlifu/db/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
import shutil
from enum import Enum
from pathlib import Path
from typing import Dict, List, Union
from typing import Dict, List

import h5py

from openlifu.nav.photoscan import Photoscan, load_data_from_photoscan
from openlifu.plan import Protocol, Run, Solution
from openlifu.util.json import PYFUSEncoder
from openlifu.util.types import PathLike
from openlifu.xdc import Transducer, TransducerArray
from openlifu.xdc.util import load_transducer_from_file

Expand All @@ -22,7 +23,6 @@
from .user import User

OnConflictOpts = Enum('OnConflictOpts', ['ERROR', 'OVERWRITE', 'SKIP'])
PathLike = Union[str, os.PathLike]

class Database:
def __init__(self, path: str | None = None):
Expand Down
54 changes: 19 additions & 35 deletions src/openlifu/nav/photoscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
import numpy as np
import onnxruntime as ort
import OpenEXR
import requests
import trimesh
import vtk
from PIL import Image
from vtk.util.numpy_support import numpy_to_vtk

from openlifu.util.annotations import OpenLIFUFieldData
from openlifu.util.assets import download_and_install_modnet, get_modnet_path

logger_meshrecon = logging.getLogger("MeshRecon")
logger_meshroom = logging.getLogger("Meshroom")
Expand Down Expand Up @@ -300,6 +300,7 @@ def run_reconstruction(
locations: List[Tuple[float, float, float]] | None = None,
return_durations: bool = False,
progress_callback : Callable[[int,str],None] | None = None,
download_masking_model: bool = True,
) -> Tuple[Photoscan, Path] | Tuple[Photoscan, Path, Dict[str, float]]:
"""Run Meshroom with the given images and pipeline.
Args:
Expand All @@ -322,6 +323,8 @@ def run_reconstruction(
return_durations (bool): If True, also return a dictionary mapping node names to durations in seconds.
progress_callback: An optional function that will be called to report progress. The function should accept two arguments:
an integer progress value from 0 to 100 followed by a string message describing the step currently being worked on.
download_masking_model: Whether to auto-download the masking model weights if they are not present;
only relevant if use_masks is enabled.

Returns:
Union[Tuple[Photoscan, Path], Tuple[Photoscan, Path, Dict[str, float]]]:
Expand Down Expand Up @@ -432,7 +435,7 @@ def progress_callback(progress_percent : int, step_description : str): # noqa: A
start_time = time.perf_counter()
masks_dir = temp_dir / "masks"
masks_dir.mkdir(parents=True, exist_ok=True)
make_masks(new_paths, masks_dir)
make_masks(new_paths, masks_dir, download_model=download_masking_model)
command.append( f"PrepareDenseScene_1.masksFolders=['{masks_dir.as_posix()}']" )
durations["MaskCreation"] = time.perf_counter() - start_time

Expand Down Expand Up @@ -597,38 +600,6 @@ def inverse_transform(img):

return inverse_transform(image) if inverse else transform(image)


def get_modnet_path() -> Path:
"""Get the MODNet checkpoint path. Download it if not present.
"""
package = "openlifu.nav.modnet_checkpoints"
filename = "modnet_photographic_portrait_matting.onnx"
url = "https://data.kitware.com/api/v1/file/67feb2cb31a330568827ab32/download"
try:
# Try to find the checkpoint in the package
resource_path = importlib.resources.files(package) / filename
if resource_path.is_file():
logger_meshrecon.info(f"Found existing MODNet checkpoint at {resource_path}")
return resource_path
except (FileNotFoundError, ModuleNotFoundError):
pass

# Fallback: Download the checkpoint
base_dir = Path(importlib.resources.files(package))
full_path = base_dir / filename
logger_meshrecon.info(f"MODNet checkpoint not found. Downloading from {url}...")
response = requests.get(url, stream=True, timeout=(10, 300))
if response.status_code == 200:
with open(full_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
if chunk:
f.write(chunk)
logger_meshrecon.info(f"Downloaded MODNet checkpoint to {full_path}")
else:
raise RuntimeError(f"Failed to download MODNet checkpoint: {response.status_code} - {response.text}")

return full_path

def preprocess_image_modnet(image: np.ndarray, ref_size: int = 512) -> np.ndarray:
"""
Preprocess an input image for MODNet inference.
Expand Down Expand Up @@ -675,7 +646,7 @@ def preprocess_image_modnet(image: np.ndarray, ref_size: int = 512) -> np.ndarra
return image


def make_masks(image_paths: list[Path], output_dir: Path, threshold: float = 0.01) -> None:
def make_masks(image_paths: list[Path], output_dir: Path, threshold: float = 0.01, download_model=True) -> None:
"""
Runs MODNet on a list of image paths and saves the output masks.

Expand All @@ -687,10 +658,23 @@ def make_masks(image_paths: list[Path], output_dir: Path, threshold: float = 0.0
image_paths (List[str]): List of input image file paths.
output_dir (str): Directory where the output masks will be saved.
threshold (float): Threshold to binarize the soft segmentation output.
download_model (bool): Whether to auto-download the model weights if they are not present.
"""

# Load the ONNX model

ckpt_path = get_modnet_path()
if not ckpt_path.exists():
if download_model:
logger_meshrecon.info(f"Downloading MODNet checkpoint to {ckpt_path}")
download_and_install_modnet()
else:
raise FileNotFoundError(f"MODNet checkpoint not found at {ckpt_path}. Install it using an appropirate utility in openlifu.util.assets.")
else:
logger_meshrecon.info(f"Found existing MODNet checkpoint at {ckpt_path}")

session = ort.InferenceSession(ckpt_path, providers=["CPUExecutionProvider"]) # or CUDAExecutionProvider

for image_path in image_paths:
image = Image.open(image_path)
exif = image.getexif()
Expand Down
20 changes: 10 additions & 10 deletions src/openlifu/sim/kwave_if.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,15 @@
import logging
from typing import List

import kwave
import kwave.data
import numpy as np
import xarray as xa
from kwave.kgrid import kWaveGrid
from kwave.kmedium import kWaveMedium
from kwave.ksensor import kSensor
from kwave.ksource import kSource
from kwave.kspaceFirstOrder3D import kspaceFirstOrder3D
from kwave.options.simulation_execution_options import SimulationExecutionOptions
from kwave.options.simulation_options import SimulationOptions
from kwave.utils.kwave_array import kWaveArray

from openlifu import xdc
from openlifu.util.units import getunitconversion


def get_kgrid(coords: xa.Coordinates, t_end = 0, dt = 0, sound_speed_ref=1500, cfl=0.5):
from kwave.kgrid import kWaveGrid
units = [coords[dim].attrs['units'] for dim in coords.dims]
if not all(unit == units[0] for unit in units):
raise ValueError("All coordinates must have the same units")
Expand All @@ -40,6 +31,9 @@ def get_karray(arr: xdc.Transducer,
upsampling_rate: int = 5,
translation: List[float] = [0.,0.,0.],
rotation: List[float] = [0.,0.,0.]):
import kwave
import kwave.data
from kwave.utils.kwave_array import kWaveArray
karray = kWaveArray(bli_tolerance=bli_tolerance, upsampling_rate=upsampling_rate,
single_precision=True)
for el in arr.elements:
Expand All @@ -53,6 +47,7 @@ def get_karray(arr: xdc.Transducer,
return karray

def get_medium(params: xa.Dataset, ref_values_only: bool = False):
from kwave.kmedium import kWaveMedium
if ref_values_only:
medium = kWaveMedium(sound_speed=params['sound_speed'].attrs['ref_value'],
density=params['density'].attrs['ref_value'],
Expand All @@ -68,11 +63,13 @@ def get_medium(params: xa.Dataset, ref_values_only: bool = False):
return medium

def get_sensor(kgrid, record=['p_max','p_min']):
from kwave.ksensor import kSensor
sensor_mask = np.ones([kgrid.Nx, kgrid.Ny, kgrid.Nz])
sensor = kSensor(sensor_mask, record=record)
return sensor

def get_source(kgrid, karray, source_sig):
from kwave.ksource import kSource
source = kSource()
logging.info("Getting binary mask")
source.p_mask = karray.get_array_binary_mask(kgrid)
Expand All @@ -95,6 +92,9 @@ def run_simulation(arr: xdc.Transducer,
gpu: bool = True,
ref_values_only: bool = False
):
from kwave.kspaceFirstOrder3D import kspaceFirstOrder3D
from kwave.options.simulation_execution_options import SimulationExecutionOptions
from kwave.options.simulation_options import SimulationOptions
delays = delays if delays is not None else np.zeros(arr.numelements())
apod = apod if apod is not None else np.ones(arr.numelements())
kgrid = get_kgrid(params.coords, dt=dt, t_end=t_end, cfl=cfl)
Expand Down
170 changes: 170 additions & 0 deletions src/openlifu/util/assets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
"""Utilities for downloading and installing assets that openlifu needs."""

from __future__ import annotations

import ast
import importlib
import shutil
import sys
import tempfile
from pathlib import Path
from types import ModuleType

import requests

from openlifu.util.types import PathLike


def install_asset(destination:PathLike, path_to_asset:PathLike|None=None, url_to_asset:str|None=None) -> None:
"""Install a file to a location if it isn't already there.

Downloads if a `url_to_asset` is provided, and copies if a local `path_to_asset` is provided.
Does nothing if the `destination` already exists.

Args:
destination: The path where the asset should end up. If this already exists then the function will do nothing.
path_to_asset: Local filepath; if provided then the asset will be copied from here to `destination`.
Required if url_to_asset is not provided.
url_to_asset: Web URL to the asset; if provided then the asset will be downloaded and saved to `destination`.
Required if path_to_asset is not provided.
"""
destination = Path(destination)

if destination.exists():
return

destination.parent.mkdir(parents=True, exist_ok=True)

if path_to_asset is not None:
path_to_asset = Path(path_to_asset)
shutil.copy2(path_to_asset, destination)
elif url_to_asset is not None:
temp_file_path = None
try:
response = requests.get(url_to_asset, stream=True, timeout=(10, 300))
response.raise_for_status()
with tempfile.NamedTemporaryFile(mode='wb', dir=destination.parent, delete=False) as f:
temp_file_path = f.name
for chunk in response.iter_content(chunk_size=8192):
if chunk:
f.write(chunk)
shutil.move(temp_file_path, destination)
finally:
if temp_file_path is not None and Path(temp_file_path).exists():
Path(temp_file_path).unlink()
else:
raise ValueError("Either path_to_asset or url_to_asset must be provided.")


def get_modnet_path() -> Path:
"""Get the MODNet checkpoint path.
It may or may not exist; see `download_and_install_modnet` and `install_modnet_from_file`.
If `get_modnet_path().exists()` is False, then use one of those two options to install.
"""
package = "openlifu.nav.modnet_checkpoints"
filename = "modnet_photographic_portrait_matting.onnx"
base_dir = Path(importlib.resources.files(package))
return base_dir / filename

def download_and_install_modnet() -> Path:
"""Download and install the MODNet checkpoint. Returns path to installed MODNet checkpoint."""
url = "https://data.kitware.com/api/v1/file/67feb2cb31a330568827ab32/download"
modnet_path = get_modnet_path()
install_asset(modnet_path, url_to_asset=url)
return modnet_path

def install_modnet_from_file(path_to_modnet_file:PathLike) -> Path:
"""Copy MODNet checkpoint to the appropriate place for openlifu to use it.
Returns path to installed MODNet checkpoint."""
modnet_path = get_modnet_path()
install_asset(modnet_path, path_to_asset=path_to_modnet_file)
return modnet_path

def _import_without_calls(pkg: str, banned_calls:list[str], register=False) -> ModuleType:
"""Import `pkg` but strip any top-level statements that call a banned function.

It is simplistic: it is looking at the syntax tree and stripping out any node that
has a banned function call in any of its descendent nodes. There are lots of ways to break
this if there is enough misdirection in a banned function call. The point of this is just
to help handle a specific issue we have with kwave's binary download.

Args:
pkg: The name of the package to import
banned_calls: A list of functions to import
register: Whether to add the module in global import registry.
Doing so makes any future imports of the module via the usual `import`
statement end up referring to the version imported here.

Returns the module.
"""
spec = importlib.util.find_spec(pkg)
if not spec or not spec.submodule_search_locations:
raise ImportError(f"Can't find package {pkg!r}")

init_path = Path(spec.submodule_search_locations[0]) / "__init__.py"
src = init_path.read_text(encoding="utf-8")
tree = ast.parse(src, filename=str(init_path))

# this function tells whether a top level statement tries to call a banned function anywhere inside it
def stmt_calls_banned(stmt: ast.stmt) -> bool:
for node in ast.walk(stmt):
if isinstance(node, ast.Call):
f = node.func
if isinstance(f, ast.Name) and f.id in banned_calls:
return True
if isinstance(f, ast.Attribute) and f.attr in banned_calls:
return True
return False

tree.body = [s for s in tree.body if not stmt_calls_banned(s)] # strip out offending top level statements
code = compile(tree, str(init_path), "exec")

module = ModuleType(pkg) # create a blank module object
module.__file__ = str(init_path)
module.__package__ = pkg
g = module.__dict__ # build up the context in which we will execute the module code
g["__name__"] = pkg
g["__file__"] = str(init_path)
exec(code, g, g)

if register:
sys.modules[pkg] = module
return module

def _import_kwave_inertly() -> ModuleType:
"""Import kwave without allowing it to install binaries"""
return _import_without_calls("kwave", banned_calls=["install_binaries"])

def get_kwave_paths() -> list[tuple[Path, str]]:
"""Get a list of paths and urls to kwave binaries for this platform.

Each item in the list is a pair consisting of the install path of a needed binary, followed by a download url for that binary.
"""
kwave = _import_kwave_inertly()
paths : list[tuple[str, str]] = []
for url_list in kwave.URL_DICT[kwave.PLATFORM].values():
for url in url_list:
_, filename = url.split("/")[-2:]
paths.append((Path(kwave.BINARY_PATH) / filename, url))
return paths

def download_and_install_kwave_assets() -> None:
"""Download and install the binaries needed by kwave for this platform"""
for install_path, url in get_kwave_paths():
install_asset(destination=install_path, url_to_asset=url)

def install_kwave_asset_from_file(path_to_kwave_binary:PathLike) -> Path:
"""Copy kwave binary file to the appropriate place for kwave to use it.
The filename is used to identify which binary it is.
Returns the path to the installed (i.e. copied) binary.
"""
path_to_kwave_binary = Path(path_to_kwave_binary)
kwave_paths = get_kwave_paths()
for install_path, _ in kwave_paths:
if path_to_kwave_binary.name == install_path.name:
install_asset(destination=install_path, path_to_asset=path_to_kwave_binary)
return install_path
raise ValueError(
f"The filename {path_to_kwave_binary.name} was not recognized as one of the binaries kwave is looking for: "
+ ", ".join([str(install_path) for install_path, _ in kwave_paths])
)
7 changes: 7 additions & 0 deletions src/openlifu/util/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""Custom types defined for openlifu"""
from __future__ import annotations

import os
from typing import Union

PathLike = Union[str,os.PathLike]
4 changes: 1 addition & 3 deletions src/openlifu/xdc/util.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from __future__ import annotations

import json
import os
from typing import Union

from openlifu.util.types import PathLike
from openlifu.xdc.transducer import Transducer
from openlifu.xdc.transducerarray import TransducerArray

PathLike = Union[str,os.PathLike]

def load_transducer_from_file(transducer_filepath : PathLike, convert_array:bool = True) -> Transducer|TransducerArray:
"""Load a Transducer or TransducerArray from file, depending on the "type" field in the file.
Expand Down
Loading
Loading