From f1af25b8d38d525ea00cf9f5f5164f93ede7f4e9 Mon Sep 17 00:00:00 2001 From: Vincenzo Pandolfo Date: Mon, 1 Aug 2016 12:53:51 +0100 Subject: [PATCH 1/3] Optimized time_substitution method --- devito/propagator.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/devito/propagator.py b/devito/propagator.py index 3c325608cd..65d42e2c3a 100644 --- a/devito/propagator.py +++ b/devito/propagator.py @@ -553,13 +553,7 @@ def get_time_stepping(self): return body - def time_substitutions(self, sympy_expr): - """This method checks through the sympy_expr to replace the time index with a cyclic index - but only for variables which are not being saved in the time domain - - :param sympy_expr: The Sympy expression to process - :returns: The expression after the substitutions - """ + def _time_substitutions(self, sympy_expr, subs_dict): if isinstance(sympy_expr, Indexed): array_term = sympy_expr @@ -570,12 +564,24 @@ def time_substitutions(self, sympy_expr): if not self.save_vars[str(array_term.base.label)]: array_term = array_term.xreplace(self.t_replace) - return array_term + return (subs_dict, array_term) else: for arg in sympy_expr.args: - sympy_expr = sympy_expr.subs(arg, self.time_substitutions(arg)) + subs_dict, value = self._time_substitutions(arg, subs_dict) + subs_dict[arg] = value + + return (subs_dict, sympy_expr) + + def time_substitutions(self, sympy_expr): + """This method checks through the sympy_expr to replace the time index with a cyclic index + but only for variables which are not being saved in the time domain + + :param sympy_expr: The Sympy expression to process + :returns: The expression after the substitutions + """ + subs_dict, sympy_expr = self._time_substitutions(sympy_expr, {}) - return sympy_expr + return sympy_expr.subs(subs_dict, simultaneous=True) def add_time_loop_stencil(self, stencil, before=False): """Add a statement either before or after the main spatial loop, but still inside the time loop. From abcaeaa4d300c8ba6753c65fcc71b91187830928 Mon Sep 17 00:00:00 2001 From: Vincenzo Pandolfo Date: Mon, 1 Aug 2016 13:36:57 +0100 Subject: [PATCH 2/3] Changed time_substitution to use postorder_traversal --- devito/propagator.py | 34 ++++++++++++++-------------------- 1 file changed, 14 insertions(+), 20 deletions(-) diff --git a/devito/propagator.py b/devito/propagator.py index 65d42e2c3a..876238e870 100644 --- a/devito/propagator.py +++ b/devito/propagator.py @@ -6,6 +6,7 @@ import numpy as np from sympy import Indexed, IndexedBase, symbols from sympy.abc import t, x, y, z +from sympy.utilities.iterables import postorder_traversal import cgen_wrapper as cgen from codeprinter import ccode @@ -553,25 +554,6 @@ def get_time_stepping(self): return body - def _time_substitutions(self, sympy_expr, subs_dict): - if isinstance(sympy_expr, Indexed): - array_term = sympy_expr - - if not str(array_term.base.label) in self.save_vars: - raise ValueError("Invalid variable '%s' in sympy expression. Did you add it to the operator's params?" % - str(array_term.base.label)) - - if not self.save_vars[str(array_term.base.label)]: - array_term = array_term.xreplace(self.t_replace) - - return (subs_dict, array_term) - else: - for arg in sympy_expr.args: - subs_dict, value = self._time_substitutions(arg, subs_dict) - subs_dict[arg] = value - - return (subs_dict, sympy_expr) - def time_substitutions(self, sympy_expr): """This method checks through the sympy_expr to replace the time index with a cyclic index but only for variables which are not being saved in the time domain @@ -579,7 +561,19 @@ def time_substitutions(self, sympy_expr): :param sympy_expr: The Sympy expression to process :returns: The expression after the substitutions """ - subs_dict, sympy_expr = self._time_substitutions(sympy_expr, {}) + subs_dict = {} + + for arg in postorder_traversal(sympy_expr): + if isinstance(arg, Indexed): + array_term = arg + + if not str(array_term.base.label) in self.save_vars: + raise ValueError("Invalid variable '%s' in sympy expression. Did you add it to the operator's params?" % str(array_term.base.label)) + + if not self.save_vars[str(array_term.base.label)]: + array_term = array_term.xreplace(self.t_replace) + + subs_dict[arg] = array_term return sympy_expr.subs(subs_dict, simultaneous=True) From ca2612ada25458f63e935e6d84615a124cef2e32 Mon Sep 17 00:00:00 2001 From: Vincenzo Pandolfo Date: Mon, 1 Aug 2016 15:33:58 +0100 Subject: [PATCH 3/3] Using xreplace instead of subs --- devito/propagator.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/devito/propagator.py b/devito/propagator.py index 876238e870..d7d7acdda9 100644 --- a/devito/propagator.py +++ b/devito/propagator.py @@ -568,14 +568,15 @@ def time_substitutions(self, sympy_expr): array_term = arg if not str(array_term.base.label) in self.save_vars: - raise ValueError("Invalid variable '%s' in sympy expression. Did you add it to the operator's params?" % str(array_term.base.label)) + raise ValueError( + "Invalid variable '%s' in sympy expression. Did you add it to the operator's params?" + % str(array_term.base.label) + ) if not self.save_vars[str(array_term.base.label)]: - array_term = array_term.xreplace(self.t_replace) + subs_dict[arg] = array_term.xreplace(self.t_replace) - subs_dict[arg] = array_term - - return sympy_expr.subs(subs_dict, simultaneous=True) + return sympy_expr.xreplace(subs_dict) def add_time_loop_stencil(self, stencil, before=False): """Add a statement either before or after the main spatial loop, but still inside the time loop.