Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions devito/operations/interpolators.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,29 @@ def wrapper(interp, *args, **kwargs):
return wrapper


def check_coords(func):
@wraps(func)
def wrapper(interp, *args, **kwargs):
inputs = args + as_tuple(kwargs.get('expr', ()))
# Subfunction of the SparseFunction use to create the interpolator

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can use some blank lines here

sfunc = interp.sfunction
# Subfunctions found in the arguments of the interpolation/injection operation

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SubFunctions

a_sfuncs = {f for f in retrieve_functions(inputs)
if f.is_SparseFunction} - {sfunc}
if a_sfuncs:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to avoid one nesting level, I suggest to return early and de-indent (ultra-nitpick)

# Check that is uses the same coordinates as the interpolator's SparseFunction

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo

subfuncs = {getattr(sfunc, s, None) for s in sfunc._sub_functions}
for f in a_sfuncs:
for s in f._sub_functions:
if getattr(f, s, None) not in subfuncs:
raise ValueError(f"Interpolation/injection with {sfunc}"
f"requires {f} "
f"to use the same {s} as {sfunc}")

return func(interp, *args, **kwargs)
return wrapper


def _extract_subdomain(variables):
"""
Check if any of the variables provided are defined on a SubDomain
Expand Down Expand Up @@ -322,6 +345,7 @@ def _interp_idx(self, variables, implicit_dims=None, pos_only=(), subdomain=None
return idx_subs, temps

@check_radius
@check_coords
def interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None):
"""
Generate equations interpolating an arbitrary expression into ``self``.
Expand All @@ -342,6 +366,7 @@ def interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None)
return Interpolation(expr, increment, implicit_dims, self_subs, self)

@check_radius
@check_coords
def inject(self, field, expr, implicit_dims=None):
"""
Generate equations injecting an arbitrary expression into a field.
Expand Down
12 changes: 3 additions & 9 deletions devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,11 +603,6 @@ def _prepare_arguments(self, autotune=None, estimate_memory=False, **kwargs):
if i.is_Derived and i.parent in nodes]
toposort = DAG(nodes, edges).topological_sort()

futures = {}
for d in reversed(toposort):
if set(d._arg_names).intersection(kwargs):
futures.update(d._arg_values(self._dspace[d], args={}, **kwargs))

# Prepare to process data-carriers
args = kwargs['args'] = ReducerMap()

Expand Down Expand Up @@ -637,9 +632,6 @@ def _prepare_arguments(self, autotune=None, estimate_memory=False, **kwargs):
for k, v in p._arg_values(estimate_memory=estimate_memory, **kwargs).items():
if k not in args:
args[k] = v
elif k in futures:
# An explicit override is later going to set `args[k]`
pass
elif k in kwargs:
# User is in control
# E.g., given a ConditionalDimension `t_sub` with factor `fact`
Expand All @@ -652,8 +644,10 @@ def _prepare_arguments(self, autotune=None, estimate_memory=False, **kwargs):
f"`{k}={v}`, while `{k}={args[k]}` is expected. Perhaps "
f"you forgot to override `{p}`?"
)
else:
args[k] = args.unique(k, candidate=v)

args = kwargs['args'] = args.reduce_all()
kwargs['args'] = args.reduce_inplace()

for i in discretizations:
args.update(i._arg_values(**kwargs))
Expand Down
6 changes: 4 additions & 2 deletions devito/tools/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def update(self, values):
else:
self.extend(values)

def unique(self, key):
def unique(self, key, candidate=None):
"""
Returns a unique value for a given key, if such a value
exists, and raises a ``ValueError`` if it does not.
Expand All @@ -150,7 +150,7 @@ def unique(self, key):
Key for which to retrieve a unique value.
"""
candidates = self.getall(key)
candidates = [c for c in candidates if c is not None]
candidates = [c for c in candidates + [candidate] if c is not None]
if not candidates:
return None

Expand Down Expand Up @@ -223,6 +223,8 @@ def reduce_inplace(self):
for k, v in self.reduce_all().items():
self[k] = v

return self


class DefaultOrderedDict(OrderedDict):
# Source: http://stackoverflow.com/a/6190500/562769
Expand Down
8 changes: 5 additions & 3 deletions devito/types/dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,12 +318,12 @@ def _arg_values(self, interval, grid=None, args=None, **kwargs):
# may represent sets of legal values. If that's the case, here we just
# pick one. Note that we sort for determinism
try:
loc_minv = loc_minv.stop
loc_minv = loc_minv.start
except AttributeError:
with suppress(TypeError):
loc_minv = sorted(loc_minv).pop(0)
try:
loc_maxv = loc_maxv.stop
loc_maxv = loc_maxv.stop - 1
except AttributeError:
with suppress(TypeError):
loc_maxv = sorted(loc_maxv).pop(0)
Expand Down Expand Up @@ -1041,7 +1041,9 @@ def _arg_defaults(self, _min=None, size=None, alias=None):
raise ValueError(f"Incompatible size for ConditionalDimension "
f"{self.name}: {size} < {size0}")
else:
defaults[dim.parent.max_name] = range(d0, d0 + factor*size - 1)
# Given a factor the last time index is factor*(size - 1)
# The maximum allowed value is then factor*size - 1
defaults[dim.parent.max_name] = range(d0, d0 + factor*size)

return defaults

Expand Down
15 changes: 15 additions & 0 deletions tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1255,3 +1255,18 @@ def test_inject_subdomain_mpi(self, mode):
assert data1 == None # noqa
assert data2 == None # noqa
assert data3 == None # noqa


def test_wrong_coords():
grid = Grid(shape=(11, 11))
s = SparseFunction(name='src', npoint=1, grid=grid)
s2 = SparseFunction(name='src2', npoint=1, grid=grid)
u = Function(name='u', grid=grid)

with pytest.raises(ValueError) as vinfo:
s.inject(u, expr=s2)
assert "Interpolation/injection with" in str(vinfo.value)

with pytest.raises(ValueError) as vinfo:
s.interpolate(u + s2)
assert "Interpolation/injection with" in str(vinfo.value)
28 changes: 28 additions & 0 deletions tests/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1315,6 +1315,34 @@ def test_loose_kwargs(self):
# But the following should work perfectly fine
op.arguments(x_size=2, y_size=2)

@pytest.mark.parametrize('vfact', [1, 3, 4])
def test_apply_args_consitency(self, vfact):
nt = 201
grid = Grid(shape=(11, 11, 11))
time = grid.time_dim

u = TimeFunction(name='u', grid=grid, time_order=2, space_order=4)
rec = SparseTimeFunction(name='rec', grid=grid, npoint=1, nt=nt)

factor = Constant(name='factor', value=vfact, dtype=np.int32)
time_sub = ConditionalDimension(name='t_sub', parent=time, factor=factor)
usave = TimeFunction(name='usave', grid=grid, space_order=4, time_order=0,
save=nt, time_dim=time_sub)

eqns = [
Eq(u.forward, u + 1),
Eq(usave, u),
] + rec.interpolate(expr=u)

op = Operator(eqns, opt='noop')
args0 = op.arguments(time_m=0, time_M=nt-2)
args1 = op.arguments(time_m=0, time_M=nt-2, rec=rec, usave=usave)

for k, v in args0.items():
assert k in args1
if isinstance(v, int):
assert args1[k] == v


@skipif('device')
class TestDeclarator:
Expand Down
Loading