Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 28 additions & 7 deletions simpeg/maps/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1140,23 +1140,28 @@ def __init__(self, *args):
and isinstance(arg[0], str)
and
# TODO: this should be extended to a slice.
isinstance(arg[1], (int, np.integer))
isinstance(arg[1], (int, np.integer, Projection))
), (
"Each wire needs to be a tuple: (name, length). "
"Each wire needs to be a tuple: (name, length) or (name, Projection). "
"You provided: {}".format(arg)
)

self._nP = int(np.sum([w[1] for w in args]))
start = 0
maps = []
for arg in args:
wire = Projection(self.nP, slice(start, start + arg[1]))

if isinstance(arg[1], (int, np.integer)):
wire = Projection(self.nP, slice(start, start + arg[1]))
start += arg[1]
else:
wire = arg[1]

setattr(self, arg[0], wire)
maps += [(arg[0], wire)]
start += arg[1]
self.maps = maps

self._tuple = namedtuple("Model", [w[0] for w in args])
self._nP = maps[0][1].nP
self._tuple = namedtuple("Model", [name for name, _ in args])
self._projection = sp.vstack([wire.P for _, wire in self.maps])

def __mul__(self, val):
assert isinstance(val, np.ndarray)
Expand All @@ -1176,6 +1181,22 @@ def nP(self):
"""
return self._nP

def deriv(self, m):
"""
Derivative of the mapping with respect to the input parameters

Parameters
----------
m : (n_param, ) numpy.ndarray
The model for which the gradient is evaluated.

Returns
-------
(n_param, ) numpy.ndarray
The Gradient of the mapping function evaluated for the model provided.
"""
return self._projection


class TileMap(IdentityMap):
"""
Expand Down