From d36342b61f0fbe1836dda6d47ef59defdd033ab1 Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Fri, 6 Jun 2025 14:46:03 +0100 Subject: [PATCH 01/14] compiler: Pull some aspects of DAG construction out and cache --- devito/ir/support/basic.py | 4 ++++ devito/passes/clusters/misc.py | 4 +++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/devito/ir/support/basic.py b/devito/ir/support/basic.py index 1b680b13a9..52e883e743 100644 --- a/devito/ir/support/basic.py +++ b/devito/ir/support/basic.py @@ -834,6 +834,10 @@ def __init__(self, exprs, rules=None): self.rules = as_tuple(rules) assert all(callable(i) for i in self.rules) + @cached_property + def thingy(self): + return any(i.cause for i in self.d_anti_gen()) + @memoized_generator def writes_gen(self): """ diff --git a/devito/passes/clusters/misc.py b/devito/passes/clusters/misc.py index d81f2d93bf..9a88114c2b 100644 --- a/devito/passes/clusters/misc.py +++ b/devito/passes/clusters/misc.py @@ -362,7 +362,9 @@ def is_cross(source, sink): # (intuitively, "the loop nests are to be kept separated") # * All ClusterGroups between `cg0` and `cg1` must precede `cg1` # * All ClusterGroups after `cg1` cannot precede `cg1` - if any(i.cause & prefix for i in scope.d_anti_gen()): + # FIXME: Attach to the scope + # if any(i.cause & prefix for i in scope.d_anti_gen()): + if prefix and scope.thingy: for cg2 in cgroups[n:cgroups.index(cg1)]: dag.add_edge(cg2, cg1) for cg2 in cgroups[cgroups.index(cg1)+1:]: From 90ba30ea64146b398c5166e2d659ba5213a5094e Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Fri, 6 Jun 2025 16:43:33 +0100 Subject: [PATCH 02/14] compiler: Tweaks to CTemp purging in CSE --- devito/ir/clusters/algorithms.py | 4 ++++ devito/ir/support/basic.py | 2 ++ devito/operator/operator.py | 6 ++++-- devito/passes/clusters/cse.py | 32 ++++++++++++++++++++++++++++++-- devito/passes/clusters/misc.py | 3 +-- devito/types/basic.py | 1 + 6 files changed, 42 insertions(+), 6 deletions(-) diff --git 
a/devito/ir/clusters/algorithms.py b/devito/ir/clusters/algorithms.py index 5323539dd4..6bbecaaf0d 100644 --- a/devito/ir/clusters/algorithms.py +++ b/devito/ir/clusters/algorithms.py @@ -122,6 +122,7 @@ class Schedule(QueueStateful): @timed_pass(name='schedule') def process(self, clusters): + # from IPython import embed; embed() return self._process_fatd(clusters, 1) def callback(self, clusters, prefix, backlog=None, known_break=None): @@ -156,6 +157,7 @@ def callback(self, clusters, prefix, backlog=None, known_break=None): # Schedule Clusters over different IterationSpaces if this increases # parallelism for i in range(1, len(clusters)): + # FIXME: This eats a lot of time (four seconds each time) if self._break_for_parallelism(scope, candidates, i): return self.callback(clusters[:i], prefix, clusters[i:] + backlog, candidates | known_break) @@ -191,6 +193,8 @@ def callback(self, clusters, prefix, backlog=None, known_break=None): def _break_for_parallelism(self, scope, candidates, i): # `test` will be True if there's at least one data-dependence that would # break parallelism + + # TODO: Can this loop be made to short-circuit? test = False for d in scope.d_from_access_gen(scope.a_query(i)): if d.is_local or d.is_storage_related(candidates): diff --git a/devito/ir/support/basic.py b/devito/ir/support/basic.py index 52e883e743..a82b0120f4 100644 --- a/devito/ir/support/basic.py +++ b/devito/ir/support/basic.py @@ -363,6 +363,7 @@ def distance(self, other): # Case 1: `sit` is an IterationInterval with statically known # trip count. E.g. 
it ranges from 0 to 3; `other` performs a # constant access at 4 + # TODO: This case represents the majority of time constructing a DAG for v in (self[n], other[n]): try: if bool(v < sit.symbolic_min or v > sit.symbolic_max): @@ -834,6 +835,7 @@ def __init__(self, exprs, rules=None): self.rules = as_tuple(rules) assert all(callable(i) for i in self.rules) + # FIXME: Should be put somewhere sensible @cached_property def thingy(self): return any(i.cause for i in self.d_anti_gen()) diff --git a/devito/operator/operator.py b/devito/operator/operator.py index 21a216dace..acc4133078 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -967,8 +967,10 @@ def _emit_build_profiling(self): tot = timings.pop('op-compile') perf(f"Operator `{self.name}` generated in {fround(tot):.2f} s") - max_hotspots = 3 - threshold = 20. + # max_hotspots = 3 + # threshold = 20. + max_hotspots = 300 + threshold = 0.5 def _emit_timings(timings, indent=''): timings.pop('total', None) diff --git a/devito/passes/clusters/cse.py b/devito/passes/clusters/cse.py index 705bd5423f..55f7b923de 100644 --- a/devito/passes/clusters/cse.py +++ b/devito/passes/clusters/cse.py @@ -25,6 +25,8 @@ class CTemp(Temp): """ A cluster-level Temp, similar to Temp, ensured to have different priority """ + is_CTemp = True + ordering_of_classes.insert(ordering_of_classes.index('Temp') + 1, 'CTemp') @@ -220,13 +222,39 @@ def _compact(exprs, exclude): `for (i = ...) { a = b; for (j = a ...) ... }`. Hence, this routine only targets CTemps. 
""" + # FIXME: Can use is_CTemp rather than isinstance candidates = [e for e in exprs if isinstance(e.lhs, CTemp) and e.lhs not in exclude] mapper = {e.lhs: e.rhs for e in candidates if q_leaf(e.rhs)} - mapper.update({e.lhs: e.rhs for e in candidates - if sum([i.rhs.count(e.lhs) for i in exprs]) == 1}) + # FIXME: Move this to searches as retrieve_ctemps + from devito.symbolics.search import search + + def q_ctemp(expr): + try: + return expr.is_CTemp + except AttributeError: + return False + + # Find all the CTemps in expressions without removing duplicates + # ctemps = search(exprs, q_ctemp, 'all', 'dfs') + # I think it was more like + ctemps = search([e.rhs for e in exprs], q_ctemp, 'all', 'dfs') + + # print(ctemps, len(ctemps), len(set(ctemps)), len(candidates)) + + # FIXME: This line is kinda slow. I should find some way to replace it. + # FIXME: Specifically sum([i.rhs.count(e.lhs) for i in exprs]) == 1 is slow as hell + # mapper.update({e.lhs: e.rhs for e in candidates + # if sum([i.rhs.count(e.lhs) for i in exprs]) == 1}) + + # If there are ctemps in the expressions, then add any to the mapper which only + # appear once + # TODO: Double check this is exactly the prior behaviour? 
+ if ctemps: + mapper.update({e.lhs: e.rhs for e in candidates + if ctemps.count(e.lhs) == 1}) processed = [] for e in exprs: diff --git a/devito/passes/clusters/misc.py b/devito/passes/clusters/misc.py index 9a88114c2b..336c5a6161 100644 --- a/devito/passes/clusters/misc.py +++ b/devito/passes/clusters/misc.py @@ -362,8 +362,7 @@ def is_cross(source, sink): # (intuitively, "the loop nests are to be kept separated") # * All ClusterGroups between `cg0` and `cg1` must precede `cg1` # * All ClusterGroups after `cg1` cannot precede `cg1` - # FIXME: Attach to the scope - # if any(i.cause & prefix for i in scope.d_anti_gen()): + # FIXME: This is a terrible variable name if prefix and scope.thingy: for cg2 in cgroups[n:cgroups.index(cg1)]: dag.add_edge(cg2, cg1) diff --git a/devito/types/basic.py b/devito/types/basic.py index 8c7e960fb2..76067fc287 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -298,6 +298,7 @@ class Basic(CodeSymbol): is_Object = False is_LocalObject = False is_LocalType = False + is_CTemp = False # Created by the user is_Input = False From d165087b4406a44d594621397041cd58d3e0ef5e Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Mon, 9 Jun 2025 13:59:10 +0100 Subject: [PATCH 03/14] compiler: Tidy up new searches and functionality tweaks --- devito/ir/clusters/algorithms.py | 1 - devito/ir/support/basic.py | 9 ++++----- devito/passes/clusters/cse.py | 26 ++++---------------------- devito/passes/clusters/misc.py | 4 ++-- devito/symbolics/queries.py | 7 +++++++ devito/symbolics/search.py | 9 +++++++-- 6 files changed, 24 insertions(+), 32 deletions(-) diff --git a/devito/ir/clusters/algorithms.py b/devito/ir/clusters/algorithms.py index 6bbecaaf0d..53d126accc 100644 --- a/devito/ir/clusters/algorithms.py +++ b/devito/ir/clusters/algorithms.py @@ -122,7 +122,6 @@ class Schedule(QueueStateful): @timed_pass(name='schedule') def process(self, clusters): - # from IPython import embed; embed() return self._process_fatd(clusters, 1) def 
callback(self, clusters, prefix, backlog=None, known_break=None): diff --git a/devito/ir/support/basic.py b/devito/ir/support/basic.py index a82b0120f4..6801e1d0e4 100644 --- a/devito/ir/support/basic.py +++ b/devito/ir/support/basic.py @@ -835,11 +835,6 @@ def __init__(self, exprs, rules=None): self.rules = as_tuple(rules) assert all(callable(i) for i in self.rules) - # FIXME: Should be put somewhere sensible - @cached_property - def thingy(self): - return any(i.cause for i in self.d_anti_gen()) - @memoized_generator def writes_gen(self): """ @@ -1132,6 +1127,10 @@ def d_anti(self): """Anti (or "write-after-read") dependences.""" return DependenceGroup(self.d_anti_gen()) + @cached_property + def has_antidependencies(self): + return any(i.cause for i in self.d_anti_gen()) + @memoized_generator def d_output_gen(self): """Generate the output (or "write-after-write") dependences.""" diff --git a/devito/passes/clusters/cse.py b/devito/passes/clusters/cse.py index 55f7b923de..046585654d 100644 --- a/devito/passes/clusters/cse.py +++ b/devito/passes/clusters/cse.py @@ -13,6 +13,7 @@ from devito.finite_differences.differentiable import IndexDerivative from devito.ir import Cluster, Scope, cluster_pass from devito.symbolics import estimate_cost, q_leaf, q_terminal +from devito.symbolics.search import retrieve_ctemps from devito.symbolics.manipulation import _uxreplace from devito.tools import DAG, as_list, as_tuple, frozendict, extract_dtype from devito.types import Eq, Symbol, Temp @@ -222,36 +223,17 @@ def _compact(exprs, exclude): `for (i = ...) { a = b; for (j = a ...) ... }`. Hence, this routine only targets CTemps. 
""" - # FIXME: Can use is_CTemp rather than isinstance candidates = [e for e in exprs if isinstance(e.lhs, CTemp) and e.lhs not in exclude] mapper = {e.lhs: e.rhs for e in candidates if q_leaf(e.rhs)} - # FIXME: Move this to searches as retrieve_ctemps - from devito.symbolics.search import search - - def q_ctemp(expr): - try: - return expr.is_CTemp - except AttributeError: - return False - - # Find all the CTemps in expressions without removing duplicates - # ctemps = search(exprs, q_ctemp, 'all', 'dfs') - # I think it was more like - ctemps = search([e.rhs for e in exprs], q_ctemp, 'all', 'dfs') - - # print(ctemps, len(ctemps), len(set(ctemps)), len(candidates)) - - # FIXME: This line is kinda slow. I should find some way to replace it. - # FIXME: Specifically sum([i.rhs.count(e.lhs) for i in exprs]) == 1 is slow as hell - # mapper.update({e.lhs: e.rhs for e in candidates - # if sum([i.rhs.count(e.lhs) for i in exprs]) == 1}) + # Find all the CTemps in expression right-hand-sides without removing duplicates + ctemps = retrieve_ctemps([e.rhs for e in exprs]) # If there are ctemps in the expressions, then add any to the mapper which only # appear once - # TODO: Double check this is exactly the prior behaviour? 
+ # TODO: Double check this is exactly the prior behaviour if ctemps: mapper.update({e.lhs: e.rhs for e in candidates if ctemps.count(e.lhs) == 1}) diff --git a/devito/passes/clusters/misc.py b/devito/passes/clusters/misc.py index 336c5a6161..ad6855a71d 100644 --- a/devito/passes/clusters/misc.py +++ b/devito/passes/clusters/misc.py @@ -362,8 +362,8 @@ def is_cross(source, sink): # (intuitively, "the loop nests are to be kept separated") # * All ClusterGroups between `cg0` and `cg1` must precede `cg1` # * All ClusterGroups after `cg1` cannot precede `cg1` - # FIXME: This is a terrible variable name - if prefix and scope.thingy: + # TODO: Check that this is indeed what the attribute does + if prefix and scope.has_antidependencies: for cg2 in cgroups[n:cgroups.index(cg1)]: dag.add_edge(cg2, cg1) for cg2 in cgroups[cgroups.index(cg1)+1:]: diff --git a/devito/symbolics/queries.py b/devito/symbolics/queries.py index 4b7f04253b..c39c0afafc 100644 --- a/devito/symbolics/queries.py +++ b/devito/symbolics/queries.py @@ -32,6 +32,13 @@ def q_symbol(expr): return False +def q_ctemp(expr): + try: + return expr.is_CTemp + except AttributeError: + return False + + def q_comp_acc(expr): return isinstance(expr, ComponentAccess) diff --git a/devito/symbolics/search.py b/devito/symbolics/search.py index 9d57cf8135..b26246b8d7 100644 --- a/devito/symbolics/search.py +++ b/devito/symbolics/search.py @@ -1,7 +1,7 @@ import sympy from devito.symbolics.queries import (q_indexed, q_function, q_terminal, q_leaf, - q_symbol, q_dimension, q_derivative) + q_symbol, q_ctemp, q_dimension, q_derivative) from devito.tools import as_tuple __all__ = ['retrieve_indexed', 'retrieve_functions', 'retrieve_function_carriers', @@ -155,10 +155,15 @@ def retrieve_functions(exprs, mode='all', deep=False): def retrieve_symbols(exprs, mode='all'): - """Shorthand to retrieve the Scalar in ``exprs``.""" + """Shorthand to retrieve the Scalar in `exprs`.""" return search(exprs, q_symbol, mode, 'dfs') +def 
retrieve_ctemps(exprs, mode='all'): + """Shorthand to retrieve the CTemps in `exprs`""" + return search(exprs, q_ctemp, mode, 'dfs') + + def retrieve_function_carriers(exprs, mode='all'): """ Shorthand to retrieve the DiscreteFunction carriers in ``exprs``. An From 1c96965c1bbf9d6e65ce43b5c7d476440002aafc Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Mon, 9 Jun 2025 14:32:59 +0100 Subject: [PATCH 04/14] compiler: Use counter --- devito/passes/clusters/cse.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/devito/passes/clusters/cse.py b/devito/passes/clusters/cse.py index 046585654d..c184cba1e4 100644 --- a/devito/passes/clusters/cse.py +++ b/devito/passes/clusters/cse.py @@ -1,4 +1,4 @@ -from collections import defaultdict +from collections import defaultdict, Counter from functools import cached_property, singledispatch import numpy as np @@ -230,13 +230,14 @@ def _compact(exprs, exclude): # Find all the CTemps in expression right-hand-sides without removing duplicates ctemps = retrieve_ctemps([e.rhs for e in exprs]) + ctemp_count = Counter(ctemps) # If there are ctemps in the expressions, then add any to the mapper which only # appear once # TODO: Double check this is exactly the prior behaviour if ctemps: mapper.update({e.lhs: e.rhs for e in candidates - if ctemps.count(e.lhs) == 1}) + if ctemp_count[e.lhs] == 1}) processed = [] for e in exprs: From 201511120d89fe19a21de7edc0680c75416e5c84 Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Mon, 9 Jun 2025 14:35:58 +0100 Subject: [PATCH 05/14] compiler: Revert incorrect modifications to DAG construction --- devito/ir/support/basic.py | 5 ----- devito/passes/clusters/misc.py | 5 +++-- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/devito/ir/support/basic.py b/devito/ir/support/basic.py index 6801e1d0e4..1b680b13a9 100644 --- a/devito/ir/support/basic.py +++ b/devito/ir/support/basic.py @@ -363,7 +363,6 @@ def distance(self, other): # Case 1: `sit` is an IterationInterval 
with statically known # trip count. E.g. it ranges from 0 to 3; `other` performs a # constant access at 4 - # TODO: This case represents the majority of time constructing a DAG for v in (self[n], other[n]): try: if bool(v < sit.symbolic_min or v > sit.symbolic_max): @@ -1127,10 +1126,6 @@ def d_anti(self): """Anti (or "write-after-read") dependences.""" return DependenceGroup(self.d_anti_gen()) - @cached_property - def has_antidependencies(self): - return any(i.cause for i in self.d_anti_gen()) - @memoized_generator def d_output_gen(self): """Generate the output (or "write-after-write") dependences.""" diff --git a/devito/passes/clusters/misc.py b/devito/passes/clusters/misc.py index ad6855a71d..dd1615ca4a 100644 --- a/devito/passes/clusters/misc.py +++ b/devito/passes/clusters/misc.py @@ -362,8 +362,9 @@ def is_cross(source, sink): # (intuitively, "the loop nests are to be kept separated") # * All ClusterGroups between `cg0` and `cg1` must precede `cg1` # * All ClusterGroups after `cg1` cannot precede `cg1` - # TODO: Check that this is indeed what the attribute does - if prefix and scope.has_antidependencies: + + # FIXME: Slow + if any(i.cause & prefix for i in scope.d_anti_gen()): for cg2 in cgroups[n:cgroups.index(cg1)]: dag.add_edge(cg2, cg1) for cg2 in cgroups[cgroups.index(cg1)+1:]: From ca79bee44401fd1fc4d146bf93ea81c3e62dc97b Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Mon, 9 Jun 2025 17:20:39 +0100 Subject: [PATCH 06/14] compiler: Further misc improvements (minor) --- devito/ir/clusters/algorithms.py | 2 ++ devito/ir/support/basic.py | 7 ++++++- devito/passes/clusters/misc.py | 3 +-- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/devito/ir/clusters/algorithms.py b/devito/ir/clusters/algorithms.py index 53d126accc..0ac12a58df 100644 --- a/devito/ir/clusters/algorithms.py +++ b/devito/ir/clusters/algorithms.py @@ -157,6 +157,7 @@ def callback(self, clusters, prefix, backlog=None, known_break=None): # parallelism for i in range(1, 
len(clusters)): # FIXME: This eats a lot of time (four seconds each time) + # FIXME: Pull scope out of this if self._break_for_parallelism(scope, candidates, i): return self.callback(clusters[:i], prefix, clusters[i:] + backlog, candidates | known_break) @@ -194,6 +195,7 @@ def _break_for_parallelism(self, scope, candidates, i): # break parallelism # TODO: Can this loop be made to short-circuit? + # TODO: Most of the time is burned in d_from_access_gen test = False for d in scope.d_from_access_gen(scope.a_query(i)): if d.is_local or d.is_storage_related(candidates): diff --git a/devito/ir/support/basic.py b/devito/ir/support/basic.py index 1b680b13a9..09e015df86 100644 --- a/devito/ir/support/basic.py +++ b/devito/ir/support/basic.py @@ -317,6 +317,8 @@ def lex_le(self, other): def lex_lt(self, other): return self.timestamp < other.timestamp + # NOTE: This is called a lot with the same arguments - memoize yields mild speedup + @memoized_meth def distance(self, other): """ Compute the distance from ``self`` to ``other``. @@ -365,7 +367,9 @@ def distance(self, other): # constant access at 4 for v in (self[n], other[n]): try: - if bool(v < sit.symbolic_min or v > sit.symbolic_max): + # NOTE: Split the boolean to make the conditional short circuit + # more frequently for mild speedup + if bool(v < sit.symbolic_min) or bool(v > sit.symbolic_max): return Vector(S.ImaginaryUnit) except TypeError: pass @@ -1170,6 +1174,7 @@ def d_from_access_gen(self, accesses): Generate all flow, anti, and output dependences involving any of the given TimedAccess objects. 
""" + # FIXME: This seems to be a hotspot accesses = as_tuple(accesses) for d in self.d_all_gen(): for i in accesses: diff --git a/devito/passes/clusters/misc.py b/devito/passes/clusters/misc.py index dd1615ca4a..1acc36ad00 100644 --- a/devito/passes/clusters/misc.py +++ b/devito/passes/clusters/misc.py @@ -352,8 +352,7 @@ def is_cross(source, sink): v = len(cg0.exprs) return t0 < v <= t1 or t1 < v <= t0 - for cg1 in cgroups[n+1:]: - n1 = cgroups.index(cg1) + for n1, cg1 in enumerate(cgroups[n+1:], start=n+1): # A Scope to compute all cross-ClusterGroup anti-dependences scope = Scope(exprs=cg0.exprs + cg1.exprs, rules=is_cross) From ad69ef75bf7f468b9501a30840baa54dc608ca6a Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Tue, 10 Jun 2025 11:08:21 +0100 Subject: [PATCH 07/14] compiler: Add a check to pre-empt expensive symbolic comparisons before try-except --- devito/ir/clusters/algorithms.py | 5 ----- devito/ir/support/basic.py | 26 ++++++++++++++++---------- devito/operator/operator.py | 6 ++---- devito/passes/clusters/cse.py | 1 - devito/passes/clusters/misc.py | 1 - 5 files changed, 18 insertions(+), 21 deletions(-) diff --git a/devito/ir/clusters/algorithms.py b/devito/ir/clusters/algorithms.py index 0ac12a58df..5323539dd4 100644 --- a/devito/ir/clusters/algorithms.py +++ b/devito/ir/clusters/algorithms.py @@ -156,8 +156,6 @@ def callback(self, clusters, prefix, backlog=None, known_break=None): # Schedule Clusters over different IterationSpaces if this increases # parallelism for i in range(1, len(clusters)): - # FIXME: This eats a lot of time (four seconds each time) - # FIXME: Pull scope out of this if self._break_for_parallelism(scope, candidates, i): return self.callback(clusters[:i], prefix, clusters[i:] + backlog, candidates | known_break) @@ -193,9 +191,6 @@ def callback(self, clusters, prefix, backlog=None, known_break=None): def _break_for_parallelism(self, scope, candidates, i): # `test` will be True if there's at least one data-dependence that 
would # break parallelism - - # TODO: Can this loop be made to short-circuit? - # TODO: Most of the time is burned in d_from_access_gen test = False for d in scope.d_from_access_gen(scope.a_query(i)): if d.is_local or d.is_storage_related(candidates): diff --git a/devito/ir/support/basic.py b/devito/ir/support/basic.py index 09e015df86..693ae06f05 100644 --- a/devito/ir/support/basic.py +++ b/devito/ir/support/basic.py @@ -317,7 +317,7 @@ def lex_le(self, other): def lex_lt(self, other): return self.timestamp < other.timestamp - # NOTE: This is called a lot with the same arguments - memoize yields mild speedup + # Note: memoization yields mild compiler speedup @memoized_meth def distance(self, other): """ @@ -365,14 +365,21 @@ def distance(self, other): # Case 1: `sit` is an IterationInterval with statically known # trip count. E.g. it ranges from 0 to 3; `other` performs a # constant access at 4 - for v in (self[n], other[n]): - try: - # NOTE: Split the boolean to make the conditional short circuit - # more frequently for mild speedup - if bool(v < sit.symbolic_min) or bool(v > sit.symbolic_max): - return Vector(S.ImaginaryUnit) - except TypeError: - pass + + # To avoid evaluating expensive symbolic Lt or Gt operations, + # we pre-empt such operations by checking if the values to be compared + # to are symbolic, and skip this case if not. + if not any(isinstance(i, sympy.core.Basic) + for i in (sit.symbolic_min, sit.symbolic_max)): + + for v in (self[n], other[n]): + try: + # Note: Boolean is split to make the conditional short + # circuit more frequently for mild speedup + if bool(v < sit.symbolic_min) or bool(v > sit.symbolic_max): + return Vector(S.ImaginaryUnit) + except TypeError: + pass # Case 2: `sit` is an IterationInterval over a local SubDimension # and `other` performs a constant access @@ -1174,7 +1181,6 @@ def d_from_access_gen(self, accesses): Generate all flow, anti, and output dependences involving any of the given TimedAccess objects.
""" - # FIXME: This seems to be a hotspot accesses = as_tuple(accesses) for d in self.d_all_gen(): for i in accesses: diff --git a/devito/operator/operator.py b/devito/operator/operator.py index acc4133078..21a216dace 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -967,10 +967,8 @@ def _emit_build_profiling(self): tot = timings.pop('op-compile') perf(f"Operator `{self.name}` generated in {fround(tot):.2f} s") - # max_hotspots = 3 - # threshold = 20. - max_hotspots = 300 - threshold = 0.5 + max_hotspots = 3 + threshold = 20. def _emit_timings(timings, indent=''): timings.pop('total', None) diff --git a/devito/passes/clusters/cse.py b/devito/passes/clusters/cse.py index c184cba1e4..119667b659 100644 --- a/devito/passes/clusters/cse.py +++ b/devito/passes/clusters/cse.py @@ -234,7 +234,6 @@ def _compact(exprs, exclude): # If there are ctemps in the expressions, then add any to the mapper which only # appear once - # TODO: Double check this is exactly the prior behaviour if ctemps: mapper.update({e.lhs: e.rhs for e in candidates if ctemp_count[e.lhs] == 1}) diff --git a/devito/passes/clusters/misc.py b/devito/passes/clusters/misc.py index 1acc36ad00..ae95ecb4b6 100644 --- a/devito/passes/clusters/misc.py +++ b/devito/passes/clusters/misc.py @@ -362,7 +362,6 @@ def is_cross(source, sink): # * All ClusterGroups between `cg0` and `cg1` must precede `cg1` # * All ClusterGroups after `cg1` cannot precede `cg1` - # FIXME: Slow if any(i.cause & prefix for i in scope.d_anti_gen()): for cg2 in cgroups[n:cgroups.index(cg1)]: dag.add_edge(cg2, cg1) From c725b2f579654b1d0aae35898db9a27131affebd Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Wed, 11 Jun 2025 14:19:23 +0100 Subject: [PATCH 08/14] misc: Refactoring and misc code style improvements --- devito/ir/support/basic.py | 5 +++-- devito/passes/clusters/cse.py | 12 ++++++++---- devito/passes/clusters/misc.py | 1 - devito/symbolics/queries.py | 7 ------- devito/symbolics/search.py | 7 +------
devito/types/basic.py | 1 - 6 files changed, 12 insertions(+), 21 deletions(-) diff --git a/devito/ir/support/basic.py b/devito/ir/support/basic.py index 693ae06f05..54512ecb87 100644 --- a/devito/ir/support/basic.py +++ b/devito/ir/support/basic.py @@ -317,7 +317,8 @@ def lex_le(self, other): def lex_lt(self, other): return self.timestamp < other.timestamp - # Note: memoization yields mild compiler speedup + # Note: memoization yields mild compiler speedup. Will need to be made + # thread-safe for multithreading the compiler. @memoized_meth def distance(self, other): """ @@ -369,7 +370,7 @@ def distance(self, other): # To avoid evaluating expensive symbolic Lt or Gt operations, # we pre-empt such operations by checking if the values to be compared # to are symbolic, and skip this case if not. - if not any(isinstance(i, sympy.core.Basic) + if not any(isinstance(i, sympy.Basic) for i in (sit.symbolic_min, sit.symbolic_max)): for v in (self[n], other[n]): diff --git a/devito/passes/clusters/cse.py b/devito/passes/clusters/cse.py index 119667b659..d7f5b92b5e 100644 --- a/devito/passes/clusters/cse.py +++ b/devito/passes/clusters/cse.py @@ -13,7 +13,7 @@ from devito.finite_differences.differentiable import IndexDerivative from devito.ir import Cluster, Scope, cluster_pass from devito.symbolics import estimate_cost, q_leaf, q_terminal -from devito.symbolics.search import retrieve_ctemps +from devito.symbolics.search import search from devito.symbolics.manipulation import _uxreplace from devito.tools import DAG, as_list, as_tuple, frozendict, extract_dtype from devito.types import Eq, Symbol, Temp @@ -26,11 +26,15 @@ class CTemp(Temp): """ A cluster-level Temp, similar to Temp, ensured to have different priority """ - is_CTemp = True ordering_of_classes.insert(ordering_of_classes.index('Temp') + 1, 'CTemp') +def retrieve_ctemps(exprs, mode='all'): + """Shorthand to retrieve the CTemps in `exprs`""" + return search(exprs, lambda expr: isinstance(expr, CTemp), mode, 'dfs') 
+ + @cluster_pass def cse(cluster, sregistry=None, options=None, **kwargs): """ @@ -229,12 +233,12 @@ def _compact(exprs, exclude): mapper = {e.lhs: e.rhs for e in candidates if q_leaf(e.rhs)} # Find all the CTemps in expression right-hand-sides without removing duplicates - ctemps = retrieve_ctemps([e.rhs for e in exprs]) - ctemp_count = Counter(ctemps) + ctemps = retrieve_ctemps(e.rhs for e in exprs) # If there are ctemps in the expressions, then add any to the mapper which only # appear once if ctemps: + ctemp_count = Counter(ctemps) mapper.update({e.lhs: e.rhs for e in candidates if ctemp_count[e.lhs] == 1}) diff --git a/devito/passes/clusters/misc.py b/devito/passes/clusters/misc.py index ae95ecb4b6..52c66eaf43 100644 --- a/devito/passes/clusters/misc.py +++ b/devito/passes/clusters/misc.py @@ -361,7 +361,6 @@ def is_cross(source, sink): # (intuitively, "the loop nests are to be kept separated") # * All ClusterGroups between `cg0` and `cg1` must precede `cg1` # * All ClusterGroups after `cg1` cannot precede `cg1` - if any(i.cause & prefix for i in scope.d_anti_gen()): for cg2 in cgroups[n:cgroups.index(cg1)]: dag.add_edge(cg2, cg1) diff --git a/devito/symbolics/queries.py b/devito/symbolics/queries.py index c39c0afafc..4b7f04253b 100644 --- a/devito/symbolics/queries.py +++ b/devito/symbolics/queries.py @@ -32,13 +32,6 @@ def q_symbol(expr): return False -def q_ctemp(expr): - try: - return expr.is_CTemp - except AttributeError: - return False - - def q_comp_acc(expr): return isinstance(expr, ComponentAccess) diff --git a/devito/symbolics/search.py b/devito/symbolics/search.py index b26246b8d7..fceb8a1e18 100644 --- a/devito/symbolics/search.py +++ b/devito/symbolics/search.py @@ -1,7 +1,7 @@ import sympy from devito.symbolics.queries import (q_indexed, q_function, q_terminal, q_leaf, - q_symbol, q_ctemp, q_dimension, q_derivative) + q_symbol, q_dimension, q_derivative) from devito.tools import as_tuple __all__ = ['retrieve_indexed', 'retrieve_functions', 
'retrieve_function_carriers', @@ -159,11 +159,6 @@ def retrieve_symbols(exprs, mode='all'): return search(exprs, q_symbol, mode, 'dfs') -def retrieve_ctemps(exprs, mode='all'): - """Shorthand to retrieve the CTemps in `exprs`""" - return search(exprs, q_ctemp, mode, 'dfs') - - def retrieve_function_carriers(exprs, mode='all'): """ Shorthand to retrieve the DiscreteFunction carriers in ``exprs``. An diff --git a/devito/types/basic.py b/devito/types/basic.py index 76067fc287..8c7e960fb2 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -298,7 +298,6 @@ class Basic(CodeSymbol): is_Object = False is_LocalObject = False is_LocalType = False - is_CTemp = False # Created by the user is_Input = False From 89cde25a2fd69dbcbf1776aa7275f1a051874be0 Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Wed, 11 Jun 2025 16:45:08 +0100 Subject: [PATCH 09/14] misc: Use less esoteric comment phrasing --- devito/passes/clusters/cse.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/devito/passes/clusters/cse.py b/devito/passes/clusters/cse.py index d7f5b92b5e..f658ca3b9f 100644 --- a/devito/passes/clusters/cse.py +++ b/devito/passes/clusters/cse.py @@ -235,8 +235,8 @@ def _compact(exprs, exclude): # Find all the CTemps in expression right-hand-sides without removing duplicates ctemps = retrieve_ctemps(e.rhs for e in exprs) - # If there are ctemps in the expressions, then add any to the mapper which only - # appear once + # If there are ctemps in the expressions, then add any that only appear once to + # the mapper if ctemps: ctemp_count = Counter(ctemps) mapper.update({e.lhs: e.rhs for e in candidates From 3d6eff0c2b6310e3bcebe70e454f51334618e7ba Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Wed, 11 Jun 2025 17:58:17 +0100 Subject: [PATCH 10/14] misc: Add comparisons that give up upon encountering symbolic arguments --- devito/ir/support/basic.py | 29 ++++++++++------------------- devito/tools/utils.py | 32 +++++++++++++++++++++++++++++++- 2 
files changed, 41 insertions(+), 20 deletions(-) diff --git a/devito/ir/support/basic.py b/devito/ir/support/basic.py index 54512ecb87..b3ec1bae4b 100644 --- a/devito/ir/support/basic.py +++ b/devito/ir/support/basic.py @@ -11,7 +11,8 @@ q_constant, q_comp_acc, q_affine, q_routine, search, uxreplace) from devito.tools import (Tag, as_mapper, as_tuple, is_integer, filter_sorted, - flatten, memoized_meth, memoized_generator) + flatten, memoized_meth, memoized_generator, smart_gt, + smart_lt) from devito.types import (ComponentAccess, Dimension, DimensionTuple, Fence, CriticalRegion, Function, Symbol, Temp, TempArray, TBArray) @@ -317,9 +318,6 @@ def lex_le(self, other): def lex_lt(self, other): return self.timestamp < other.timestamp - # Note: memoization yields mild compiler speedup. Will need to be made - # thread-safe for multithreading the compiler. - @memoized_meth def distance(self, other): """ Compute the distance from ``self`` to ``other``. @@ -366,21 +364,14 @@ def distance(self, other): # Case 1: `sit` is an IterationInterval with statically known # trip count. E.g. it ranges from 0 to 3; `other` performs a # constant access at 4 - - # To avoid evaluating expensive symbolic Lt or Gt operations, - # we pre-empt such operations by checking if the values to be compared - # to are symbolic, and skip this case if not. - if not any(isinstance(i, sympy.Basic) - for i in (sit.symbolic_min, sit.symbolic_max)): - - for v in (self[n], other[n]): - try: - # Note: Boolean is split to make the conditional short - # circuit more frequently for mild speedup - if bool(v < sit.symbolic_min) or bool(v > sit.symbolic_max): - return Vector(S.ImaginaryUnit) - except TypeError: - pass + for v in (self[n], other[n]): + # Note: To avoid evaluating expensive symbolic Lt or Gt operations, + # we pre-empt such operations by checking if the values to be compared + # to are symbolic, and skip this case if not. 
+ # Note: Boolean is split to make the conditional short + # circuit more frequently for mild speedup. + if smart_lt(v, sit.symbolic_min) or smart_gt(v, sit.symbolic_max): + return Vector(S.ImaginaryUnit) # Case 2: `sit` is an IterationInterval over a local SubDimension # and `other` performs a constant access diff --git a/devito/tools/utils.py b/devito/tools/utils.py index 0a28de16a8..4d6367e812 100644 --- a/devito/tools/utils.py +++ b/devito/tools/utils.py @@ -12,7 +12,8 @@ 'roundm', 'powerset', 'invert', 'flatten', 'single_or', 'filter_ordered', 'as_mapper', 'filter_sorted', 'pprint', 'sweep', 'all_equal', 'as_list', 'indices_to_slices', 'indices_to_sections', 'transitive_closure', - 'humanbytes', 'contains_val', 'sorted_priority', 'as_set', 'is_number'] + 'humanbytes', 'contains_val', 'sorted_priority', 'as_set', 'is_number', + 'smart_lt', 'smart_gt'] def prod(iterable, initial=1): @@ -346,3 +347,32 @@ def key(i): return (v, str(type(i))) return sorted(items, key=key, reverse=True) + + +def avoid_symbolic_relations(func): + """ + Decorator to avoid calculating a relation symbolically if doing so may be slow. + In the case that one of the values being compared is symbolic, just give up + and return False. 
+ """ + def wrapper(a, b): + if any(isinstance(expr, sympy.Basic) for expr in (a, b)): + # An argument is symbolic, so give up and assume False + return False + + try: + return func(a, b) + except TypeError: + return False + + return wrapper + + +@avoid_symbolic_relations +def smart_lt(a, b): + return bool(a < b) + + +@avoid_symbolic_relations +def smart_gt(a, b): + return bool(a > b) From 8022ca380845a8cda55d98753d0e2417e341ec93 Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Wed, 11 Jun 2025 18:12:08 +0100 Subject: [PATCH 11/14] misc: Make avoid_symbolic decorator more generic --- devito/tools/utils.py | 38 ++++++++++++++++++++++---------------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/devito/tools/utils.py b/devito/tools/utils.py index 4d6367e812..9f7974407e 100644 --- a/devito/tools/utils.py +++ b/devito/tools/utils.py @@ -1,6 +1,6 @@ from collections import OrderedDict from collections.abc import Iterable -from functools import reduce +from functools import reduce, wraps from itertools import chain, combinations, groupby, product, zip_longest from operator import attrgetter, mul import types @@ -349,30 +349,36 @@ def key(i): return sorted(items, key=key, reverse=True) -def avoid_symbolic_relations(func): +def avoid_symbolic(default_val): """ - Decorator to avoid calculating a relation symbolically if doing so may be slow. - In the case that one of the values being compared is symbolic, just give up - and return False. + Decorator to avoid calling a function where doing so will result in symbolic + computation being performed. For use if symbolic computation may be slow. In + the case that an arg is symbolic, just give up and return a default value. 
""" - def wrapper(a, b): - if any(isinstance(expr, sympy.Basic) for expr in (a, b)): - # An argument is symbolic, so give up and assume False - return False + def _avoid_symbolic(func): + @wraps(func) + def wrapper(*args): + if any(isinstance(expr, sympy.Basic) for expr in args): + # An argument is symbolic, so give up and assume default + return default_val - try: - return func(a, b) - except TypeError: - return False + try: + return func(*args) + except TypeError: + return default_val + + return wrapper - return wrapper + return _avoid_symbolic -@avoid_symbolic_relations +@avoid_symbolic(False) def smart_lt(a, b): + """An Lt that gives up and returns False if supplied a symbolic argument""" return bool(a < b) -@avoid_symbolic_relations +@avoid_symbolic(False) def smart_gt(a, b): + """A Gt that gives up and returns False if supplied a symbolic argument""" return bool(a > b) From d5e39cbe1814afeead5f998dfe1eb7171fcf6a9a Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Thu, 12 Jun 2025 09:34:59 +0100 Subject: [PATCH 12/14] misc: Tweak avoid_symbolic decorator --- devito/tools/utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/devito/tools/utils.py b/devito/tools/utils.py index 9f7974407e..99f8464de7 100644 --- a/devito/tools/utils.py +++ b/devito/tools/utils.py @@ -349,7 +349,7 @@ def key(i): return sorted(items, key=key, reverse=True) -def avoid_symbolic(default_val): +def avoid_symbolic(default=None): """ Decorator to avoid calling a function where doing so will result in symbolic computation being performed. For use if symbolic computation may be slow. 
In @@ -360,25 +360,25 @@ def _avoid_symbolic(func): def wrapper(*args): if any(isinstance(expr, sympy.Basic) for expr in args): # An argument is symbolic, so give up and assume default - return default_val + return default try: return func(*args) except TypeError: - return default_val + return default return wrapper return _avoid_symbolic -@avoid_symbolic(False) +@avoid_symbolic(default=False) def smart_lt(a, b): """An Lt that gives up and returns False if supplied a symbolic argument""" return bool(a < b) -@avoid_symbolic(False) +@avoid_symbolic(default=False) def smart_gt(a, b): """A Gt that gives up and returns False if supplied a symbolic argument""" return bool(a > b) From 373990a5e529319a2255b69db48fbbbe7f6e9c45 Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Thu, 12 Jun 2025 09:42:45 +0100 Subject: [PATCH 13/14] misc: Update comment --- devito/ir/support/basic.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/devito/ir/support/basic.py b/devito/ir/support/basic.py index b3ec1bae4b..2848fb4d64 100644 --- a/devito/ir/support/basic.py +++ b/devito/ir/support/basic.py @@ -365,9 +365,8 @@ def distance(self, other): # trip count. E.g. it ranges from 0 to 3; `other` performs a # constant access at 4 for v in (self[n], other[n]): - # Note: To avoid evaluating expensive symbolic Lt or Gt operations, - # we pre-empt such operations by checking if the values to be compared - # to are symbolic, and skip this case if not. + # Note: Uses smart_ comparisons to avoid evaluating expensive + # symbolic Lt or Gt operations. # Note: Boolean is split to make the conditional short # circuit more frequently for mild speedup. 
if smart_lt(v, sit.symbolic_min) or smart_gt(v, sit.symbolic_max): From 613f1fe29e2c0b3660f3d545a779cb91e601d8de Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Thu, 12 Jun 2025 09:50:28 +0100 Subject: [PATCH 14/14] misc: Refactor and tidy decorator --- devito/tools/utils.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/devito/tools/utils.py b/devito/tools/utils.py index 99f8464de7..546c5cd49f 100644 --- a/devito/tools/utils.py +++ b/devito/tools/utils.py @@ -362,10 +362,7 @@ def wrapper(*args): # An argument is symbolic, so give up and assume default return default - try: - return func(*args) - except TypeError: - return default + return func(*args) return wrapper @@ -375,10 +372,10 @@ def wrapper(*args): @avoid_symbolic(default=False) def smart_lt(a, b): """An Lt that gives up and returns False if supplied a symbolic argument""" - return bool(a < b) + return a < b @avoid_symbolic(default=False) def smart_gt(a, b): """A Gt that gives up and returns False if supplied a symbolic argument""" - return bool(a > b) + return a > b