From 6b904d98214a0d14afa165b4e5832a3806abe8a8 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 2 Apr 2025 10:47:26 +0100 Subject: [PATCH 1/3] compiler: Fix Weights reconstruction --- devito/finite_differences/differentiable.py | 4 ++-- devito/passes/clusters/unevaluate.py | 24 +-------------------- devito/symbolics/manipulation.py | 23 +++++++++++++++++--- tests/test_pickle.py | 22 ++++++++++++++++++- 4 files changed, 44 insertions(+), 29 deletions(-) diff --git a/devito/finite_differences/differentiable.py b/devito/finite_differences/differentiable.py index 39458424d4..a6533c0705 100644 --- a/devito/finite_differences/differentiable.py +++ b/devito/finite_differences/differentiable.py @@ -746,8 +746,8 @@ def __init_finalize__(self, *args, **kwargs): assert isinstance(weights, (list, tuple, np.ndarray)) # Normalize `weights` - from devito.symbolics import pow_to_mul # noqa, sigh - weights = tuple(pow_to_mul(sympy.sympify(i)) for i in weights) + from devito.symbolics import pow_to_mul, unevaluate # noqa, sigh + weights = tuple(unevaluate(pow_to_mul(sympy.sympify(i))) for i in weights) kwargs['scope'] = kwargs.get('scope', 'stack') kwargs['initvalue'] = weights diff --git a/devito/passes/clusters/unevaluate.py b/devito/passes/clusters/unevaluate.py index cec8e9e770..b895e1eafa 100644 --- a/devito/passes/clusters/unevaluate.py +++ b/devito/passes/clusters/unevaluate.py @@ -1,8 +1,5 @@ -import sympy - from devito.ir import cluster_pass -from devito.symbolics import reuse_if_untouched, q_leaf -from devito.symbolics.unevaluation import Add, Mul, Pow +from devito.symbolics import unevaluate as _unevaluate __all__ = ['unevaluate'] @@ -12,22 +9,3 @@ def unevaluate(cluster): exprs = [_unevaluate(e) for e in cluster.exprs] return cluster.rebuild(exprs=exprs) - - -mapper = { - sympy.Add: Add, - sympy.Mul: Mul, - sympy.Pow: Pow -} - - -def _unevaluate(expr): - if q_leaf(expr): - return expr - - args = [_unevaluate(a) for a in expr.args] - - try: - return mapper[expr.func](*args) - except KeyError: - return reuse_if_untouched(expr, args) diff --git a/devito/symbolics/manipulation.py b/devito/symbolics/manipulation.py index bf795cb86d..6ac0e221b0 100644 --- a/devito/symbolics/manipulation.py +++ b/devito/symbolics/manipulation.py @@ -13,7 +13,9 @@ from devito.symbolics.extended_sympy import DefFunction, rfunc from devito.symbolics.queries import q_leaf from devito.symbolics.search import retrieve_indexed, retrieve_functions -from devito.symbolics.unevaluation import Mul as UMul +from devito.symbolics.unevaluation import ( + Add as UnevalAdd, Mul as UnevalMul, Pow as UnevalPow +) from devito.tools import as_list, as_tuple, flatten, split, transitive_closure from devito.types.basic import Basic, Indexed from devito.types.array import ComponentAccess @@ -22,7 +24,7 @@ __all__ = ['xreplace_indices', 'pow_to_mul', 'indexify', 'subs_op_args', 'normalize_args', 'uxreplace', 'Uxmapper', 'subs_if_composite', - 'reuse_if_untouched', 'evalrel', 'flatten_args'] + 'reuse_if_untouched', 'evalrel', 'flatten_args', 'unevaluate'] def uxreplace(expr, rule): @@ -338,7 +340,7 @@ def pow_to_mul(expr): # but at least we traverse the base looking for other Pows return expr.func(pow_to_mul(base), exp, evaluate=False) elif exp > 0: - return UMul(*[pow_to_mul(base)]*int(exp), evaluate=False) + return UnevalMul(*[pow_to_mul(base)]*int(exp), evaluate=False) elif exp < 0: # Reciprocal powers become inverse of the negative power # for example Pow(expr, -2) becomes Pow(expr * expr, -1) @@ -502,3 +504,18 @@ def evalrel(func=min, input=None, assumptions=None): except TypeError: pass return rfunc(func, *input) + + +uneval_mapper = {Add: UnevalAdd, Mul: UnevalMul, Pow: UnevalPow} + + +def unevaluate(expr): + if q_leaf(expr): + return expr + + args = [unevaluate(a) for a in expr.args] + + try: + return uneval_mapper[expr.func](*args) + except KeyError: + return reuse_if_untouched(expr, args) diff --git a/tests/test_pickle.py b/tests/test_pickle.py index df8c3a79b2..0ccab826ad 100644 --- a/tests/test_pickle.py +++ b/tests/test_pickle.py @@ -12,6 +12,7 @@ PrecomputedSparseTimeFunction, SubDomain) from devito.ir import Backward, Forward, GuardFactor, GuardBound, GuardBoundNext from devito.data import LEFT, OWNED +from devito.finite_differences.differentiable import Weights from devito.finite_differences.tools import direct, transpose, left, right, centered from devito.mpi.halo_scheme import Halo from devito.mpi.routines import (MPIStatusObject, MPIMsgEnriched, MPIRequestObject, @@ -19,7 +20,7 @@ from devito.types import (Array, CustomDimension, Symbol as dSymbol, Scalar, PointerArray, Lock, PThreadArray, SharedData, Timer, DeviceID, NPThreads, ThreadID, TempFunction, Indirection, - FIndexed) + FIndexed, StencilDimension) from devito.types.basic import BoundSymbol, AbstractSymbol from devito.tools import EnrichedTuple from devito.symbolics import (IntDiv, ListInitializer, FieldFromPointer, @@ -416,6 +417,25 @@ def test_findexed(self, pickle): assert new_fi.indices == (x+1, y, z-2) assert new_fi.strides_map == fi.strides_map + def test_weights_to_array(self, pickle): + grid = Grid(shape=(3, 3, 3)) + x, y, z = grid.dimensions + h_x = x.spacing + + i = StencilDimension('i0', 0, 2) + w = Weights(name='w0', dimensions=i, + initvalue=[1/(h_x**2), 2/(h_x**2), 3/(h_x**2)]) + a = Array(name='w0', dimensions=w.dimensions, initvalue=w.initvalue, + scope='stack') + + pkl_a = pickle.dumps(a) + new_a = pickle.loads(pkl_a) + + # Weights optimizes `initvalue` by turning pows into muls. This test checks + # that the optimization is correctly carried over to the pickled object + # (in practice, the optimized expressions must have been frozen) + assert a.initvalue == new_a.initvalue + def test_symbolics(self, pickle): a = Symbol('a') From 58f1c70654766e38dda594297f886ca2fe638fd6 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 2 Apr 2025 10:47:43 +0100 Subject: [PATCH 2/3] compiler: Fix ComponentAccess pickling --- devito/types/array.py | 11 ++++++++--- tests/test_pickle.py | 16 +++++++++++++++- tests/test_symbolics.py | 18 ++++++++++++++++++ 3 files changed, 41 insertions(+), 4 deletions(-) diff --git a/devito/types/array.py b/devito/types/array.py index 105cdd21da..6a1bc9eb90 100644 --- a/devito/types/array.py +++ b/devito/types/array.py @@ -4,7 +4,7 @@ import numpy as np from sympy import Expr, cacheit -from devito.tools import (Reconstructable, as_tuple, c_restrict_void_p, +from devito.tools import (Pickable, as_tuple, c_restrict_void_p, dtype_to_ctype, dtypes_vector_mapper, is_integer) from devito.types.basic import AbstractFunction, LocalType from devito.types.utils import CtypesFactory, DimensionTuple @@ -518,10 +518,11 @@ def handles(self): return self.components -class ComponentAccess(Expr, Reconstructable): +class ComponentAccess(Expr, Pickable): _component_names = ('x', 'y', 'z', 'w') + __rargs__ = ('arg',) __rkwargs__ = ('index',) def __new__(cls, arg, index=0, **kwargs): @@ -543,7 +544,7 @@ def __str__(self): __repr__ = __str__ - func = Reconstructable._rebuild + func = Pickable._rebuild def _sympystr(self, printer): return str(self) @@ -552,6 +553,10 @@ def _sympystr(self, printer): def base(self): return self.args[0] + @property + def arg(self): + return self.base + @property def index(self): return self._index diff --git a/tests/test_pickle.py b/tests/test_pickle.py index 0ccab826ad..8502a7ca09 100644 --- a/tests/test_pickle.py +++ b/tests/test_pickle.py @@ -20,7 +20,7 @@ from devito.types import (Array, CustomDimension, Symbol as dSymbol, Scalar, PointerArray, Lock, PThreadArray, SharedData, Timer, DeviceID, NPThreads, ThreadID, TempFunction, Indirection, - FIndexed, StencilDimension) + FIndexed, ComponentAccess, StencilDimension) from devito.types.basic import BoundSymbol, AbstractSymbol from devito.tools import EnrichedTuple from devito.symbolics import (IntDiv, ListInitializer, FieldFromPointer, @@ -417,6 +417,20 @@ def test_findexed(self, pickle): assert new_fi.indices == (x+1, y, z-2) assert new_fi.strides_map == fi.strides_map + def test_component_access(self, pickle): + grid = Grid(shape=(3, 3, 3)) + x, y, z = grid.dimensions + + f = Function(name='f', grid=grid) + + ca = ComponentAccess(f.indexify(), 1) + + pkl_ca = pickle.dumps(ca) + new_ca = pickle.loads(pkl_ca) + + assert new_ca.index == 1 + assert new_ca.function.name == f.name + def test_weights_to_array(self, pickle): grid = Grid(shape=(3, 3, 3)) x, y, z = grid.dimensions diff --git a/tests/test_symbolics.py b/tests/test_symbolics.py index 977de57f2b..7754975b25 100644 --- a/tests/test_symbolics.py +++ b/tests/test_symbolics.py @@ -483,6 +483,24 @@ def test_findexed(): assert new_fi.strides_map == strides_map +def test_component_access(): + grid = Grid(shape=(3, 3, 3)) + x, y, z = grid.dimensions + + f = Function(name='f', grid=grid) + + cf0 = ComponentAccess(f.indexify(), 0) + cf1 = ComponentAccess(f.indexify(), 1) + + assert ccode(cf0) == 'f[x][y][z].x' + assert ccode(cf1) == 'f[x][y][z].y' + + # Reconstruction + cf2 = cf1.func(*cf1.args) + assert cf2.index == cf1.index + assert cf2 == cf1 + + def test_canonical_ordering_of_weights(): grid = Grid(shape=(3, 3, 3)) x, y, z = grid.dimensions From cb43942ea4706e754af11d10f10a95a3ffa2846d Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 2 Apr 2025 12:18:19 +0100 Subject: [PATCH 3/3] compiler: Postpone weights unevaluation --- devito/finite_differences/differentiable.py | 4 ++-- devito/passes/iet/definitions.py | 4 ++-- tests/test_pickle.py | 24 ++------------------- 3 files changed, 6 insertions(+), 26 deletions(-) diff --git a/devito/finite_differences/differentiable.py b/devito/finite_differences/differentiable.py index a6533c0705..39458424d4 100644 --- a/devito/finite_differences/differentiable.py +++ b/devito/finite_differences/differentiable.py @@ -746,8 +746,8 @@ def __init_finalize__(self, *args, **kwargs): assert isinstance(weights, (list, tuple, np.ndarray)) # Normalize `weights` - from devito.symbolics import pow_to_mul, unevaluate # noqa, sigh - weights = tuple(unevaluate(pow_to_mul(sympy.sympify(i))) for i in weights) + from devito.symbolics import pow_to_mul # noqa, sigh + weights = tuple(pow_to_mul(sympy.sympify(i)) for i in weights) kwargs['scope'] = kwargs.get('scope', 'stack') kwargs['initvalue'] = weights diff --git a/devito/passes/iet/definitions.py b/devito/passes/iet/definitions.py index 3cc5446d57..57f7fca54f 100644 --- a/devito/passes/iet/definitions.py +++ b/devito/passes/iet/definitions.py @@ -16,7 +16,7 @@ from devito.passes.iet.engine import iet_pass from devito.passes.iet.langbase import LangBB from devito.symbolics import (Byref, DefFunction, FieldFromPointer, IndexedPointer, - SizeOf, VOID, pow_to_mul) + SizeOf, VOID, pow_to_mul, unevaluate) from devito.tools import as_mapper, as_list, as_tuple, filter_sorted, flatten from devito.types import (Array, ComponentAccess, CustomDimension, DeviceMap, DeviceRM, Eq, Symbol) @@ -119,7 +119,7 @@ def _alloc_array_on_global_mem(self, site, obj, storage): # Create input array name = '%s_init' % obj.name - initvalue = np.array([pow_to_mul(i) for i in obj.initvalue]) + initvalue = np.array([unevaluate(pow_to_mul(i)) for i in obj.initvalue]) src = Array(name=name, dtype=obj.dtype, dimensions=obj.dimensions, space='host', scope='stack', initvalue=initvalue) diff --git a/tests/test_pickle.py b/tests/test_pickle.py index 8502a7ca09..833141d09e 100644 --- a/tests/test_pickle.py +++ b/tests/test_pickle.py @@ -1,7 +1,7 @@ import ctypes import pickle as pickle0 -import cloudpickle as pickle1 +import cloudpickle as pickle1 import pytest import numpy as np from sympy import Symbol @@ -12,7 +12,6 @@ PrecomputedSparseTimeFunction, SubDomain) from devito.ir import Backward, Forward, GuardFactor, GuardBound, GuardBoundNext from devito.data import LEFT, OWNED -from devito.finite_differences.differentiable import Weights from devito.finite_differences.tools import direct, transpose, left, right, centered from devito.mpi.halo_scheme import Halo from devito.mpi.routines import (MPIStatusObject, MPIMsgEnriched, MPIRequestObject, @@ -20,7 +19,7 @@ from devito.types import (Array, CustomDimension, Symbol as dSymbol, Scalar, PointerArray, Lock, PThreadArray, SharedData, Timer, DeviceID, NPThreads, ThreadID, TempFunction, Indirection, - FIndexed, ComponentAccess, StencilDimension) + FIndexed, ComponentAccess) from devito.types.basic import BoundSymbol, AbstractSymbol from devito.tools import EnrichedTuple from devito.symbolics import (IntDiv, ListInitializer, FieldFromPointer, @@ -431,25 +430,6 @@ def test_component_access(self, pickle): assert new_ca.index == 1 assert new_ca.function.name == f.name - def test_weights_to_array(self, pickle): - grid = Grid(shape=(3, 3, 3)) - x, y, z = grid.dimensions - h_x = x.spacing - - i = StencilDimension('i0', 0, 2) - w = Weights(name='w0', dimensions=i, - initvalue=[1/(h_x**2), 2/(h_x**2), 3/(h_x**2)]) - a = Array(name='w0', dimensions=w.dimensions, initvalue=w.initvalue, - scope='stack') - - pkl_a = pickle.dumps(a) - new_a = pickle.loads(pkl_a) - - # Weights optimizes `initvalue` by turning pows into muls. This test checks - # that the optimization is correctly carried over to the pickled object - # (in practice, the optimized expressions must have been frozen) - assert a.initvalue == new_a.initvalue - def test_symbolics(self, pickle): a = Symbol('a')