Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions devito/operations/interpolators.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,34 @@ def wrapper(interp, *args, **kwargs):
return wrapper


def check_coords(func):
@wraps(func)
def wrapper(interp, *args, **kwargs):
inputs = args + as_tuple(kwargs.get('expr', ()))

# SubFunction of the SparseFunction use to create the interpolator
sfunc = interp.sfunction

# SubFunctions found in the arguments of the interpolation/injection operation
a_sfuncs = {f for f in retrieve_functions(inputs)
if f.is_SparseFunction} - {sfunc}
if not a_sfuncs:
# Only uses the the interpolator's SparseFunction, so no need to check
return func(interp, *args, **kwargs)

# Check that it uses the same coordinates as the interpolator's SparseFunction
subfuncs = {getattr(sfunc, s, None) for s in sfunc._sub_functions}
for f in a_sfuncs:
for s in f._sub_functions:
if getattr(f, s, None) not in subfuncs:
raise ValueError(f"Interpolation/injection with {sfunc}"
f"requires {f} "
f"to use the same {s} as {sfunc}")

return func(interp, *args, **kwargs)
return wrapper


def _extract_subdomain(variables):
"""
Check if any of the variables provided are defined on a SubDomain
Expand Down Expand Up @@ -322,6 +350,7 @@ def _interp_idx(self, variables, implicit_dims=None, pos_only=(), subdomain=None
return idx_subs, temps

@check_radius
@check_coords
def interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None):
"""
Generate equations interpolating an arbitrary expression into ``self``.
Expand All @@ -342,6 +371,7 @@ def interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None)
return Interpolation(expr, increment, implicit_dims, self_subs, self)

@check_radius
@check_coords
def inject(self, field, expr, implicit_dims=None):
"""
Generate equations injecting an arbitrary expression into a field.
Expand Down
13 changes: 4 additions & 9 deletions devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,11 +603,6 @@ def _prepare_arguments(self, autotune=None, estimate_memory=False, **kwargs):
if i.is_Derived and i.parent in nodes]
toposort = DAG(nodes, edges).topological_sort()

futures = {}
for d in reversed(toposort):
if set(d._arg_names).intersection(kwargs):
futures.update(d._arg_values(self._dspace[d], args={}, **kwargs))

# Prepare to process data-carriers
args = kwargs['args'] = ReducerMap()

Expand Down Expand Up @@ -637,9 +632,6 @@ def _prepare_arguments(self, autotune=None, estimate_memory=False, **kwargs):
for k, v in p._arg_values(estimate_memory=estimate_memory, **kwargs).items():
if k not in args:
args[k] = v
elif k in futures:
# An explicit override is later going to set `args[k]`
pass
elif k in kwargs:
# User is in control
# E.g., given a ConditionalDimension `t_sub` with factor `fact`
Expand All @@ -652,8 +644,11 @@ def _prepare_arguments(self, autotune=None, estimate_memory=False, **kwargs):
f"`{k}={v}`, while `{k}={args[k]}` is expected. Perhaps "
f"you forgot to override `{p}`?"
)
else:
args[k] = args.unique(k, candidate=v)

args = kwargs['args'] = args.reduce_all()
args.reduce_inplace()
kwargs['args'] = args

for i in discretizations:
args.update(i._arg_values(**kwargs))
Expand Down
4 changes: 2 additions & 2 deletions devito/tools/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def update(self, values):
else:
self.extend(values)

def unique(self, key):
def unique(self, key, candidate=None):
"""
Returns a unique value for a given key, if such a value
exists, and raises a ``ValueError`` if it does not.
Expand All @@ -150,7 +150,7 @@ def unique(self, key):
Key for which to retrieve a unique value.
"""
candidates = self.getall(key)
candidates = [c for c in candidates if c is not None]
candidates = [c for c in candidates + [candidate] if c is not None]
if not candidates:
return None

Expand Down
8 changes: 5 additions & 3 deletions devito/types/dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,12 +318,12 @@ def _arg_values(self, interval, grid=None, args=None, **kwargs):
# may represent sets of legal values. If that's the case, here we just
# pick one. Note that we sort for determinism
try:
loc_minv = loc_minv.stop
loc_minv = loc_minv.start
except AttributeError:
with suppress(TypeError):
loc_minv = sorted(loc_minv).pop(0)
try:
loc_maxv = loc_maxv.stop
loc_maxv = loc_maxv.stop - 1
except AttributeError:
with suppress(TypeError):
loc_maxv = sorted(loc_maxv).pop(0)
Expand Down Expand Up @@ -1041,7 +1041,9 @@ def _arg_defaults(self, _min=None, size=None, alias=None):
raise ValueError(f"Incompatible size for ConditionalDimension "
f"{self.name}: {size} < {size0}")
else:
defaults[dim.parent.max_name] = range(d0, d0 + factor*size - 1)
# Given a factor the last time index is factor*(size - 1)
# The maximum allowed value is then factor*size - 1
defaults[dim.parent.max_name] = range(d0, d0 + factor*size)

return defaults

Expand Down
15 changes: 15 additions & 0 deletions tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1255,3 +1255,18 @@ def test_inject_subdomain_mpi(self, mode):
assert data1 == None # noqa
assert data2 == None # noqa
assert data3 == None # noqa


def test_wrong_coords():
grid = Grid(shape=(11, 11))
s = SparseFunction(name='src', npoint=1, grid=grid)
s2 = SparseFunction(name='src2', npoint=1, grid=grid)
u = Function(name='u', grid=grid)

with pytest.raises(ValueError) as vinfo:
s.inject(u, expr=s2)
assert "Interpolation/injection with" in str(vinfo.value)

with pytest.raises(ValueError) as vinfo:
s.interpolate(u + s2)
assert "Interpolation/injection with" in str(vinfo.value)
28 changes: 28 additions & 0 deletions tests/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1315,6 +1315,34 @@ def test_loose_kwargs(self):
# But the following should work perfectly fine
op.arguments(x_size=2, y_size=2)

@pytest.mark.parametrize('vfact', [1, 3, 4])
def test_apply_args_consitency(self, vfact):
nt = 201
grid = Grid(shape=(11, 11, 11))
time = grid.time_dim

u = TimeFunction(name='u', grid=grid, time_order=2, space_order=4)
rec = SparseTimeFunction(name='rec', grid=grid, npoint=1, nt=nt)

factor = Constant(name='factor', value=vfact, dtype=np.int32)
time_sub = ConditionalDimension(name='t_sub', parent=time, factor=factor)
usave = TimeFunction(name='usave', grid=grid, space_order=4, time_order=0,
save=nt, time_dim=time_sub)

eqns = [
Eq(u.forward, u + 1),
Eq(usave, u),
] + rec.interpolate(expr=u)

op = Operator(eqns, opt='noop')
args0 = op.arguments(time_m=0, time_M=nt-2)
args1 = op.arguments(time_m=0, time_M=nt-2, rec=rec, usave=usave)

for k, v in args0.items():
assert k in args1
if isinstance(v, int):
assert args1[k] == v


@skipif('device')
class TestDeclarator:
Expand Down
Loading