From 6ded0a0e3d6e8dbabce057943ca000e13b6ba59c Mon Sep 17 00:00:00 2001 From: mloubout Date: Tue, 14 Apr 2026 07:40:26 -0400 Subject: [PATCH 1/4] compiler: fix operator arg processing and subsampling size --- devito/operator/operator.py | 10 +--------- devito/tools/data_structures.py | 4 ++-- devito/types/dimension.py | 3 ++- tests/test_operator.py | 28 ++++++++++++++++++++++++++++ 4 files changed, 33 insertions(+), 12 deletions(-) diff --git a/devito/operator/operator.py b/devito/operator/operator.py index 5ee6e5db35..4e9d643ea1 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -603,11 +603,6 @@ def _prepare_arguments(self, autotune=None, estimate_memory=False, **kwargs): if i.is_Derived and i.parent in nodes] toposort = DAG(nodes, edges).topological_sort() - futures = {} - for d in reversed(toposort): - if set(d._arg_names).intersection(kwargs): - futures.update(d._arg_values(self._dspace[d], args={}, **kwargs)) - # Prepare to process data-carriers args = kwargs['args'] = ReducerMap() @@ -637,15 +632,12 @@ def _prepare_arguments(self, autotune=None, estimate_memory=False, **kwargs): for k, v in p._arg_values(estimate_memory=estimate_memory, **kwargs).items(): if k not in args: args[k] = v - elif k in futures: - # An explicit override is later going to set `args[k]` - pass elif k in kwargs: # User is in control # E.g., given a ConditionalDimension `t_sub` with factor `fact` # and a TimeFunction `usave(t_sub, x, y)`, an override for # `fact` is supplied w/o overriding `usave`; that's legal - pass + args[k] = args.unique(k, candidate=v) elif is_integer(args[k]) and not contains_val(args[k], v): raise InvalidArgument( f"Default `{p}` is incompatible with other args as " diff --git a/devito/tools/data_structures.py b/devito/tools/data_structures.py index e1007f8205..d875878d02 100644 --- a/devito/tools/data_structures.py +++ b/devito/tools/data_structures.py @@ -139,7 +139,7 @@ def update(self, values): else: self.extend(values) - def unique(self, key): + def unique(self, key, candidate=None): """ Returns a unique value for a given key, if such a value exists, and raises a ``ValueError`` if it does not. @@ -150,7 +150,7 @@ def unique(self, key): Key for which to retrieve a unique value. """ candidates = self.getall(key) - candidates = [c for c in candidates if c is not None] + candidates = [c for c in candidates + [candidate] if c is not None] if not candidates: return None diff --git a/devito/types/dimension.py b/devito/types/dimension.py index 4010472bef..aec0a90e8c 100644 --- a/devito/types/dimension.py +++ b/devito/types/dimension.py @@ -1041,7 +1041,8 @@ def _arg_defaults(self, _min=None, size=None, alias=None): raise ValueError(f"Incompatible size for ConditionalDimension " f"{self.name}: {size} < {size0}") else: - defaults[dim.parent.max_name] = range(d0, d0 + factor*size - 1) + # Given a factor the last time index is factor*(size - 1) + defaults[dim.parent.max_name] = range(d0, d0 + factor*(size - 1) + 1) return defaults diff --git a/tests/test_operator.py b/tests/test_operator.py index 2e98a7603d..e8d8b55bed 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -1315,6 +1315,34 @@ def test_loose_kwargs(self): # But the following should work perfectly fine op.arguments(x_size=2, y_size=2) + @pytest.mark.parametrize('vfact', [1, 3, 4]) + def test_apply_args_consitency(self, vfact): + nt = 201 + grid = Grid(shape=(11, 11, 11)) + time = grid.time_dim + + u = TimeFunction(name='u', grid=grid, time_order=2, space_order=4) + rec = SparseTimeFunction(name='rec', grid=grid, npoint=1, nt=nt) + + factor = Constant(name='factor', value=vfact, dtype=np.int32) + time_sub = ConditionalDimension(name='t_sub', parent=time, factor=factor) + usave = TimeFunction(name='usave', grid=grid, space_order=4, time_order=0, + save=nt, time_dim=time_sub) + + eqns = [ + Eq(u.forward, u + 1), + Eq(usave, u), + ] + rec.interpolate(expr=u) + + op = Operator(eqns, opt='noop') + args0 = op.arguments(time_m=0, time_M=nt-2) + args1 = op.arguments(time_m=0, time_M=nt-2, rec=rec, usave=usave) + + for k, v in args0.items(): + assert k in args1 + if isinstance(v, int): + assert args1[k] == v + @skipif('device') class TestDeclarator: From 56d462614cfa8780be3bc13a08c88ce30e22bcf9 Mon Sep 17 00:00:00 2001 From: mloubout Date: Tue, 14 Apr 2026 09:30:13 -0400 Subject: [PATCH 2/4] compiler: prevent premature reducermap->dict conversion --- devito/operator/operator.py | 7 ++++--- devito/tools/data_structures.py | 2 ++ devito/types/dimension.py | 3 ++- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/devito/operator/operator.py b/devito/operator/operator.py index 4e9d643ea1..fb726cce5a 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -637,15 +637,16 @@ def _prepare_arguments(self, autotune=None, estimate_memory=False, **kwargs): # E.g., given a ConditionalDimension `t_sub` with factor `fact` # and a TimeFunction `usave(t_sub, x, y)`, an override for # `fact` is supplied w/o overriding `usave`; that's legal - args[k] = args.unique(k, candidate=v) - elif is_integer(args[k]) and not contains_val(args[k], v): + continue + if is_integer(args[k]) and not contains_val(args[k], v): raise InvalidArgument( f"Default `{p}` is incompatible with other args as " f"`{k}={v}`, while `{k}={args[k]}` is expected. Perhaps " f"you forgot to override `{p}`?" ) + args[k] = args.unique(k, candidate=v) - args = kwargs['args'] = args.reduce_all() + kwargs['args'] = args.reduce_inplace() for i in discretizations: args.update(i._arg_values(**kwargs)) diff --git a/devito/tools/data_structures.py b/devito/tools/data_structures.py index d875878d02..755d6de7cc 100644 --- a/devito/tools/data_structures.py +++ b/devito/tools/data_structures.py @@ -223,6 +223,8 @@ def reduce_inplace(self): for k, v in self.reduce_all().items(): self[k] = v + return self + class DefaultOrderedDict(OrderedDict): # Source: http://stackoverflow.com/a/6190500/562769 diff --git a/devito/types/dimension.py b/devito/types/dimension.py index aec0a90e8c..2f22765146 100644 --- a/devito/types/dimension.py +++ b/devito/types/dimension.py @@ -1042,7 +1042,8 @@ def _arg_defaults(self, _min=None, size=None, alias=None): f"{self.name}: {size} < {size0}") else: # Given a factor the last time index is factor*(size - 1) - defaults[dim.parent.max_name] = range(d0, d0 + factor*(size - 1) + 1) + # The maximum allowed value is then factor*size - 1 + defaults[dim.parent.max_name] = range(d0, d0 + factor*size) return defaults From 17458063fa125fb3b7dc348fbaf6bf55b1c5dc58 Mon Sep 17 00:00:00 2001 From: mloubout Date: Tue, 14 Apr 2026 10:41:58 -0400 Subject: [PATCH 3/4] api: fix bounds from range --- devito/operator/operator.py | 7 ++++--- devito/types/dimension.py | 4 ++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/devito/operator/operator.py b/devito/operator/operator.py index fb726cce5a..5890d24f49 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -637,14 +637,15 @@ def _prepare_arguments(self, autotune=None, estimate_memory=False, **kwargs): # E.g., given a ConditionalDimension `t_sub` with factor `fact` # and a TimeFunction `usave(t_sub, x, y)`, an override for # `fact` is supplied w/o overriding `usave`; that's legal - continue - if is_integer(args[k]) and not contains_val(args[k], v): + pass + elif is_integer(args[k]) and not contains_val(args[k], v): raise InvalidArgument( f"Default `{p}` is incompatible with other args as " f"`{k}={v}`, while `{k}={args[k]}` is expected. Perhaps " f"you forgot to override `{p}`?" ) - args[k] = args.unique(k, candidate=v) + else: + args[k] = args.unique(k, candidate=v) kwargs['args'] = args.reduce_inplace() diff --git a/devito/types/dimension.py b/devito/types/dimension.py index 2f22765146..41c1f90b88 100644 --- a/devito/types/dimension.py +++ b/devito/types/dimension.py @@ -318,12 +318,12 @@ def _arg_values(self, interval, grid=None, args=None, **kwargs): # may represent sets of legal values. If that's the case, here we just # pick one. Note that we sort for determinism try: - loc_minv = loc_minv.stop + loc_minv = loc_minv.start except AttributeError: with suppress(TypeError): loc_minv = sorted(loc_minv).pop(0) try: - loc_maxv = loc_maxv.stop + loc_maxv = loc_maxv.stop - 1 except AttributeError: with suppress(TypeError): loc_maxv = sorted(loc_maxv).pop(0) From ea5f763049f0fff421fa734b51372edd84ffbfc3 Mon Sep 17 00:00:00 2001 From: mloubout Date: Tue, 14 Apr 2026 11:41:21 -0400 Subject: [PATCH 4/4] api: enforce valid coordinates for inject/interp --- devito/operations/interpolators.py | 30 ++++++++++++++++++++++++++++++ devito/operator/operator.py | 3 ++- devito/tools/data_structures.py | 2 -- tests/test_interpolation.py | 15 +++++++++++++++ 4 files changed, 47 insertions(+), 3 deletions(-) diff --git a/devito/operations/interpolators.py b/devito/operations/interpolators.py index 1ac838c050..160b33fdee 100644 --- a/devito/operations/interpolators.py +++ b/devito/operations/interpolators.py @@ -34,6 +34,34 @@ def wrapper(interp, *args, **kwargs): return wrapper +def check_coords(func): + @wraps(func) + def wrapper(interp, *args, **kwargs): + inputs = args + as_tuple(kwargs.get('expr', ())) + + # SubFunction of the SparseFunction use to create the interpolator + sfunc = interp.sfunction + + # SubFunctions found in the arguments of the interpolation/injection operation + a_sfuncs = {f for f in retrieve_functions(inputs) + if f.is_SparseFunction} - {sfunc} + if not a_sfuncs: + # Only uses the the interpolator's SparseFunction, so no need to check + return func(interp, *args, **kwargs) + + # Check that it uses the same coordinates as the interpolator's SparseFunction + subfuncs = {getattr(sfunc, s, None) for s in sfunc._sub_functions} + for f in a_sfuncs: + for s in f._sub_functions: + if getattr(f, s, None) not in subfuncs: + raise ValueError(f"Interpolation/injection with {sfunc}" + f"requires {f} " + f"to use the same {s} as {sfunc}") + + return func(interp, *args, **kwargs) + return wrapper + + def _extract_subdomain(variables): """ Check if any of the variables provided are defined on a SubDomain @@ -322,6 +350,7 @@ def _interp_idx(self, variables, implicit_dims=None, pos_only=(), subdomain=None return idx_subs, temps @check_radius + @check_coords def interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None): """ Generate equations interpolating an arbitrary expression into ``self``. @@ -342,6 +371,7 @@ def interpolate(self, expr, increment=False, self_subs=None, implicit_dims=None) return Interpolation(expr, increment, implicit_dims, self_subs, self) @check_radius + @check_coords def inject(self, field, expr, implicit_dims=None): """ Generate equations injecting an arbitrary expression into a field. diff --git a/devito/operator/operator.py b/devito/operator/operator.py index 5890d24f49..58fab20495 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -647,7 +647,8 @@ def _prepare_arguments(self, autotune=None, estimate_memory=False, **kwargs): else: args[k] = args.unique(k, candidate=v) - kwargs['args'] = args.reduce_inplace() + args.reduce_inplace() + kwargs['args'] = args for i in discretizations: args.update(i._arg_values(**kwargs)) diff --git a/devito/tools/data_structures.py b/devito/tools/data_structures.py index 755d6de7cc..d875878d02 100644 --- a/devito/tools/data_structures.py +++ b/devito/tools/data_structures.py @@ -223,8 +223,6 @@ def reduce_inplace(self): for k, v in self.reduce_all().items(): self[k] = v - return self - class DefaultOrderedDict(OrderedDict): # Source: http://stackoverflow.com/a/6190500/562769 diff --git a/tests/test_interpolation.py b/tests/test_interpolation.py index 8f59ca0076..7c605a599b 100644 --- a/tests/test_interpolation.py +++ b/tests/test_interpolation.py @@ -1255,3 +1255,18 @@ def test_inject_subdomain_mpi(self, mode): assert data1 == None # noqa assert data2 == None # noqa assert data3 == None # noqa + + +def test_wrong_coords(): + grid = Grid(shape=(11, 11)) + s = SparseFunction(name='src', npoint=1, grid=grid) + s2 = SparseFunction(name='src2', npoint=1, grid=grid) + u = Function(name='u', grid=grid) + + with pytest.raises(ValueError) as vinfo: + s.inject(u, expr=s2) + assert "Interpolation/injection with" in str(vinfo.value) + + with pytest.raises(ValueError) as vinfo: + s.interpolate(u + s2) + assert "Interpolation/injection with" in str(vinfo.value)