Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions geoapps-assets/FlinFlon_dcip.geoh5
Git LFS file not shown
64 changes: 19 additions & 45 deletions geoapps/inversion/components/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ def __init__(self, workspace: Workspace, params: InversionBaseParams):
self.radar: np.ndarray | None = None
self.locations: np.ndarray | None = None
self.mask: np.ndarray | None = None
self.global_map: np.ndarray | None = None
self.indices: np.ndarray | None = None
self.vector: bool | None = None
self.n_blocks: int | None = None
Expand All @@ -113,12 +112,16 @@ def _initialize(self) -> None:
self.offset, self.radar = self.params.offset()
self.locations = super().get_locations(self.params.data_object)

if self.angle is not None and self.angle != 0:
raise ValueError("Mesh is rotated.")
self.mask = np.ones(len(self.locations), dtype=bool)
if self.radar is not None:
if any(np.isnan(self.radar)):
self.mask[np.isnan(self.radar)] = False
if (
getattr(self.params, "line_id", None) is not None
and getattr(self.params, "line_object", None) is not None
):
self.mask = self.params.line_object.values == self.params.line_id
else:
self.mask = np.ones(len(self.locations), dtype=bool)

if self.radar is not None and any(np.isnan(self.radar)):
self.mask[np.isnan(self.radar)] = False

self.observed = self.filter(self.observed)
self.radar = self.filter(self.radar)
Expand All @@ -127,8 +130,8 @@ def _initialize(self) -> None:
self.normalizations = self.get_normalizations()
self.observed = self.normalize(self.observed)
self.uncertainties = self.normalize(self.uncertainties, absolute=True)
self.locations = self.apply_transformations(self.locations)
self.entity = self.write_entity()
self.params.data_object = self.entity
self.locations = super().get_locations(self.entity)
self.survey, self.local_index, _ = self.create_survey()

Expand Down Expand Up @@ -163,25 +166,8 @@ def drape_locations(self, locations: np.ndarray) -> np.ndarray:

def filter(self, a):
"""Remove vertices based on mask property."""
if (
self.params.inversion_type
in [
"direct current pseudo 3d",
"direct current 3d",
"direct current 2d",
"induced polarization 3d",
"induced polarization 2d",
"induced polarization pseudo 3d",
]
and self.indices is None
):
ab_ind = np.where(np.any(self.mask[self.params.data_object.cells], axis=1))[
0
]
self.indices = ab_ind

if self.indices is None:
self.indices = np.where(self.mask)
self.indices = np.where(self.mask)[0]

a = super().filter(a, mask=self.indices)

Expand Down Expand Up @@ -251,11 +237,10 @@ def save_data(self, entity):
else:
for component in data:
dnorm = data[component] / self.normalizations[None][component]
if "2d" in self.params.inversion_type:
dnorm = self._embed_2d(dnorm)
data_dict[component] = entity.add_data(
{f"{basename}_{component}": {"values": dnorm}}
)

if not self.params.forward_only:
self._observed_data_types[component] = data_dict[
component
Expand All @@ -265,20 +250,14 @@ def save_data(self, entity):
/ self.normalizations[None][component]
)
uncerts[np.isinf(uncerts)] = np.nan
if "2d" in self.params.inversion_type:
uncerts = self._embed_2d(uncerts)

uncert_dict[component] = entity.add_data(
{f"Uncertainties_{component}": {"values": uncerts}}
)

if "direct current" in self.params.inversion_type:
apparent_property = data[component].copy()
apparent_property[self.global_map] *= self.transformations[
"apparent resistivity"
]

if "2d" in self.params.inversion_type:
apparent_property = self._embed_2d(apparent_property)
apparent_property *= self.transformations["apparent resistivity"]

data_dict["apparent_resistivity"] = entity.add_data(
{
Expand All @@ -299,8 +278,7 @@ def apply_transformations(self, locations: np.ndarray):
locations = self.displace(locations, self.offset)
if self.radar is not None:
locations = self.drape(locations, self.radar)
if self.is_rotated:
locations = super().rotate(locations)

return locations

def displace(self, locs: np.ndarray, offset: np.ndarray) -> np.ndarray:
Expand Down Expand Up @@ -503,12 +481,6 @@ def observed_data_types(self):
"""
return self._observed_data_types

def _embed_2d(self, data):
ind = np.ones_like(data, dtype=bool)
ind[self.global_map] = False
data[ind] = np.nan
return data

@staticmethod
def check_tensor(channels):
tensor_components = ["xx", "xy", "xz", "yx", "zx", "yy", "zz", "zy", "yz"]
Expand All @@ -531,5 +503,7 @@ def update_params(self, data_dict, uncert_dict):
setattr(self.params, f"{comp}_uncertainty", uncert_dict[comp])

if getattr(self.params, "line_object", None) is not None:
new_line = self.params.line_object.copy(parent=self.entity)
new_line = self.params.line_object.copy(
parent=self.entity, values=self.params.line_object.values[self.mask]
)
self.params.line_object = new_line
11 changes: 1 addition & 10 deletions geoapps/inversion/components/factories/directives_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,18 +435,9 @@ def assemble_data_keywords_dcip(
"association": "CELL",
}

if sorting is not None and "2d" not in self.factory_type:
if sorting is not None:
kwargs["sorting"] = np.hstack(sorting)

if "2d" in self.factory_type:

def transform_2d(x):
expanded_data = np.array([np.nan] * len(inversion_object.indices))
expanded_data[inversion_object.global_map] = x[sorting]
return expanded_data

kwargs["transforms"].insert(0, transform_2d)

if is_dc and name == "Apparent Resistivity":
kwargs["transforms"].insert(
0,
Expand Down
86 changes: 28 additions & 58 deletions geoapps/inversion/components/factories/entity_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from geoapps.inversion.components.data import InversionData

import numpy as np
from geoh5py.objects import Curve, Grid2D
from geoh5py.objects import CurrentElectrode, Curve, Grid2D, Points, PotentialElectrode

from geoapps.inversion.components.factories.abstract_factory import AbstractFactory

Expand All @@ -35,13 +35,9 @@ def factory_type(self):
def concrete_object(self):
"""Returns a geoh5py object to be constructed by the build method."""
if "current" in self.factory_type or "polarization" in self.factory_type:
from geoh5py.objects import CurrentElectrode, PotentialElectrode

return (PotentialElectrode, CurrentElectrode)
return PotentialElectrode, CurrentElectrode

elif isinstance(self.params.data_object, Grid2D):
from geoh5py.objects import Points

return Points

else:
Expand All @@ -50,48 +46,7 @@ def concrete_object(self):
def build(self, inversion_data: InversionData):
"""Constructs geoh5py object for provided inversion type."""

if "current" in self.factory_type or "polarization" in self.factory_type:
entity = self._build_dcip(inversion_data)
else:
entity = self._build(inversion_data)

return entity

def _build_dcip(self, inversion_data: InversionData):
PotentialElectrode, CurrentElectrode = self.concrete_object
workspace = inversion_data.workspace

# Trim down receivers
rx_obj = self.params.data_object
rcv_ind = np.where(np.any(inversion_data.mask[rx_obj.cells], axis=1))[0]
rcv_locations, rcv_cells = EntityFactory._prune_from_indices(rx_obj, rcv_ind)
uni_src_ids, src_ids = np.unique(
rx_obj.ab_cell_id.values[rcv_ind], return_inverse=True
)
ab_cell_id = np.arange(1, uni_src_ids.shape[0] + 1)[src_ids]
entity = PotentialElectrode.create(
workspace,
name="Data",
parent=self.params.out_group,
vertices=inversion_data.apply_transformations(rcv_locations),
cells=rcv_cells,
)
entity.ab_cell_id = ab_cell_id
# Trim down sources
tx_obj = rx_obj.current_electrodes
src_ind = np.hstack(
[np.where(tx_obj.ab_cell_id.values == ind)[0] for ind in uni_src_ids]
)
src_locations, src_cells = EntityFactory._prune_from_indices(tx_obj, src_ind)
new_currents = CurrentElectrode.create(
workspace,
name="Data (currents)",
parent=self.params.out_group,
vertices=inversion_data.apply_transformations(src_locations),
cells=src_cells,
)
new_currents.add_default_ab_cell_id()
entity.current_electrodes = new_currents
entity = self._build(inversion_data)

return entity

Expand All @@ -102,11 +57,27 @@ def _build(self, inversion_data: InversionData):
)

else:
entity = self.params.data_object.copy(
parent=self.params.out_group,
copy_children=False,
vertices=inversion_data.locations,
)
kwargs = {
"parent": self.params.out_group,
"copy_children": False,
}

if np.any(~inversion_data.mask):
if isinstance(self.params.data_object, PotentialElectrode):
active_poles = np.zeros(
self.params.data_object.n_vertices, dtype=bool
)
active_poles[
self.params.data_object.cells[inversion_data.mask, :].ravel()
] = True
kwargs.update(
{"mask": active_poles, "cell_mask": inversion_data.mask}
)
else:
kwargs.update({"mask": inversion_data.mask})

entity = self.params.data_object.copy(**kwargs)
entity.vertices = inversion_data.apply_transformations(entity.vertices)

if getattr(entity, "transmitters", None) is not None:
entity.transmitters.vertices = inversion_data.apply_transformations(
Expand All @@ -116,11 +87,10 @@ def _build(self, inversion_data: InversionData):
if tx_freq:
tx_freq[0].copy(parent=entity.transmitters)

if np.any(~inversion_data.mask):
entity.remove_vertices(~inversion_data.mask)

if getattr(entity, "transmitters", None) is not None:
entity.transmitters.remove_vertices(~inversion_data.mask)
if getattr(entity, "current_electrodes", None) is not None:
entity.current_electrodes.vertices = inversion_data.apply_transformations(
entity.current_electrodes.vertices
)

return entity

Expand Down
7 changes: 0 additions & 7 deletions geoapps/inversion/components/factories/receiver_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,6 @@ def assemble_arguments(

args = []

if getattr(self.params.mesh, "rotation", None):
locations = rotate_xyz(
locations,
self.params.mesh.origin.tolist(),
-1 * self.params.mesh.rotation[0],
)

if (
"direct current" in self.factory_type
or "induced polarization" in self.factory_type
Expand Down
33 changes: 4 additions & 29 deletions geoapps/inversion/components/factories/survey_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
import SimPEG.electromagnetics.time_domain as tdem
from scipy.interpolate import interp1d

from geoapps.utils.surveys import extract_dcip_survey

from .receiver_factory import ReceiversFactory
from .simpeg_factory import SimPEGFactory
from .source_factory import SourcesFactory
Expand Down Expand Up @@ -227,12 +225,7 @@ def _add_data(self, survey, data, local_index, channel):
survey.std = uncertainty_vec

else:
index_map = (
data.global_map[local_index]
if data.global_map is not None
else local_index
)
local_data = {k: v[index_map] for k, v in data.observed.items()}
local_data = {k: v[local_index] for k, v in data.observed.items()}
local_uncertainties = {
k: v[local_index] for k, v in data.uncertainties.items()
}
Expand Down Expand Up @@ -274,31 +267,18 @@ def _dcip_arguments(self, data=None, local_index=None):

receiver_entity = data.entity
if "2d" in self.factory_type:
receiver_entity = extract_dcip_survey(
self.params.geoh5,
receiver_entity,
self.params.line_object.values,
self.params.line_id,
)
self.local_index = np.arange(receiver_entity.n_cells)
data.global_map = [
k for k in receiver_entity.children if k.name == "Global Map"
][0].values

source_ids, order = np.unique(
receiver_entity.ab_cell_id.values[self.local_index], return_index=True
)
currents = receiver_entity.current_electrodes

if "2d" in self.params.inversion_type:
receiver_locations = receiver_entity.vertices
source_locations = currents.vertices
if local_index is not None:
receiver_locations = data.drape_locations(receiver_locations)
source_locations = data.drape_locations(source_locations)

receiver_locations = data.drape_locations(receiver_entity.vertices)
source_locations = data.drape_locations(currents.vertices)
else:
receiver_locations = data.locations
receiver_locations = receiver_entity.vertices
source_locations = currents.vertices

# TODO hook up tile_spatial to handle local_index handling
Expand Down Expand Up @@ -332,11 +312,6 @@ def _dcip_arguments(self, data=None, local_index=None):

self.local_index = np.hstack(self.local_index)

if "2d" in self.factory_type:
current_entity = receiver_entity.current_electrodes
self.params.geoh5.remove_entity(receiver_entity)
self.params.geoh5.remove_entity(current_entity)

return [sources]

def _tdem_arguments(self, data=None, local_index=None, mesh=None):
Expand Down
Loading