diff --git a/examples/advection/surface.py b/examples/advection/surface.py index bc06efc5d..0e986f31b 100644 --- a/examples/advection/surface.py +++ b/examples/advection/surface.py @@ -31,7 +31,6 @@ import pyopencl as cl import pyopencl.tools as cl_tools -from arraycontext import thaw from grudge.array_context import PyOpenCLArrayContext from meshmode.dof_array import flatten @@ -62,7 +61,7 @@ def __init__(self, actx, dcoll, order, visualize=True): import matplotlib.pyplot as pt self.fig = pt.figure(figsize=(8, 8), dpi=300) - x = thaw(dcoll.discr_from_dd(dof_desc.DD_VOLUME).nodes(), actx) + x = actx.thaw(dcoll.discr_from_dd(dof_desc.DD_VOLUME).nodes()) self.x = actx.to_numpy(flatten(actx.np.arctan2(x[1], x[0]))) elif self.ambient_dim == 3: from grudge.shortcuts import make_visualizer @@ -174,7 +173,7 @@ def main(ctx_factory, dim=2, order=4, use_quad=False, visualize=False): # {{{ Surface advection operator # velocity field - x = thaw(dcoll.nodes(), actx) + x = actx.thaw(dcoll.nodes()) c = make_obj_array([-x[1], x[0], 0.0])[:dim] def f_initial_condition(x): @@ -238,7 +237,7 @@ def rhs(t, u): df = dof_desc.DOFDesc(FACE_RESTR_INTERIOR) face_discr = dcoll.discr_from_dd(df) - face_normal = thaw(dcoll.normal(dd=df), actx) + face_normal = actx.thaw(dcoll.normal(dd=df)) from meshmode.discretization.visualization import make_visualizer vis = make_visualizer(actx, face_discr) diff --git a/examples/advection/var-velocity.py b/examples/advection/var-velocity.py index 289f1274d..de1b45354 100644 --- a/examples/advection/var-velocity.py +++ b/examples/advection/var-velocity.py @@ -29,7 +29,6 @@ import pyopencl as cl import pyopencl.tools as cl_tools -from arraycontext import thaw from grudge.array_context import PyOpenCLArrayContext from meshmode.dof_array import flatten @@ -61,7 +60,7 @@ def __init__(self, actx, dcoll, order, visualize=True, ylim=None): self.ylim = ylim volume_discr = dcoll.discr_from_dd(dof_desc.DD_VOLUME) - self.x = actx.to_numpy(flatten(thaw(volume_discr.nodes()[0], actx))) + self.x = actx.to_numpy(flatten(actx.thaw(volume_discr.nodes()[0]))) else: from grudge.shortcuts import make_visualizer self.vis = make_visualizer(dcoll) @@ -168,7 +167,7 @@ def zero_inflow_bc(dtag, t=0): from grudge.models.advection import VariableCoefficientAdvectionOperator - x = thaw(dcoll.nodes(), actx) + x = actx.thaw(dcoll.nodes()) # velocity field if dim == 1: diff --git a/examples/advection/weak.py b/examples/advection/weak.py index 433f9092e..3470fdd60 100644 --- a/examples/advection/weak.py +++ b/examples/advection/weak.py @@ -30,7 +30,6 @@ import pyopencl as cl import pyopencl.tools as cl_tools -from arraycontext import thaw from grudge.array_context import PyOpenCLArrayContext from meshmode.dof_array import flatten @@ -60,7 +59,7 @@ def __init__(self, actx, dcoll, order, visualize=True, ylim=None): self.ylim = ylim volume_discr = dcoll.discr_from_dd(dof_desc.DD_VOLUME) - self.x = actx.to_numpy(flatten(thaw(volume_discr.nodes()[0], actx))) + self.x = actx.to_numpy(flatten(actx.thaw(volume_discr.nodes()[0]))) else: from grudge.shortcuts import make_visualizer self.vis = make_visualizer(dcoll) @@ -152,13 +151,13 @@ def u_analytic(x, t=0): dcoll, c, inflow_u=lambda t: u_analytic( - thaw(dcoll.nodes(dd=BTAG_ALL), actx), + actx.thaw(dcoll.nodes(dd=BTAG_ALL)), t=t ), flux_type=flux_type ) - nodes = thaw(dcoll.nodes(), actx) + nodes = actx.thaw(dcoll.nodes()) u = u_analytic(nodes, t=0) def rhs(t, u): diff --git a/examples/euler/acoustic_pulse.py b/examples/euler/acoustic_pulse.py index 2672dda12..779062910 100644 --- a/examples/euler/acoustic_pulse.py +++ b/examples/euler/acoustic_pulse.py @@ -28,7 +28,6 @@ import pyopencl as cl import pyopencl.tools as cl_tools -from arraycontext import thaw, freeze from grudge.array_context import ( PyOpenCLArrayContext, PytatoPyOpenCLArrayContext @@ -175,7 +174,7 @@ def rhs(t, q): cn = 0.5*(order + 1)**2 dt = cfl * actx.to_numpy(h_min_from_volume(dcoll)) / cn - fields = acoustic_pulse_condition(thaw(dcoll.nodes(), actx)) + fields = acoustic_pulse_condition(actx.thaw(dcoll.nodes())) logger.info("Timestep size: %g", dt) @@ -204,7 +203,7 @@ def rhs(t, q): ) assert norm_q < 5 - fields = thaw(freeze(fields, actx), actx) + fields = actx.thaw(actx.freeze(fields)) fields = rk4_step(fields, t, dt, compiled_rhs) t += dt step += 1 diff --git a/examples/euler/vortex.py b/examples/euler/vortex.py index 0e4cf7d76..9f00743e5 100644 --- a/examples/euler/vortex.py +++ b/examples/euler/vortex.py @@ -26,8 +26,6 @@ import pyopencl as cl import pyopencl.tools as cl_tools -from arraycontext import thaw, freeze - from grudge.array_context import PytatoPyOpenCLArrayContext, PyOpenCLArrayContext from grudge.models.euler import ( vortex_initial_condition, @@ -111,7 +109,7 @@ def rhs(t, q): compiled_rhs = actx.compile(rhs) - fields = vortex_initial_condition(thaw(dcoll.nodes(), actx)) + fields = vortex_initial_condition(actx.thaw(dcoll.nodes())) from grudge.dt_utils import h_min_from_volume @@ -147,7 +145,7 @@ def rhs(t, q): ) assert norm_q < 200 - fields = thaw(freeze(fields, actx), actx) + fields = actx.thaw(actx.freeze(fields)) fields = rk4_step(fields, t, dt, compiled_rhs) t += dt step += 1 diff --git a/examples/geometry.py b/examples/geometry.py index f79cd6bb9..442bbcfff 100644 --- a/examples/geometry.py +++ b/examples/geometry.py @@ -30,7 +30,6 @@ import pyopencl as cl import pyopencl.tools as cl_tools -from arraycontext import thaw from grudge.array_context import PyOpenCLArrayContext from grudge import DiscretizationCollection, shortcuts @@ -51,9 +50,9 @@ def main(write_output=True): dcoll = DiscretizationCollection(actx, mesh, order=4) - nodes = thaw(dcoll.nodes(), actx) - bdry_nodes = thaw(dcoll.nodes(dd=BTAG_ALL), actx) - bdry_normals = thaw(dcoll.normal(dd=BTAG_ALL), actx) + nodes = actx.thaw(dcoll.nodes()) + bdry_nodes = actx.thaw(dcoll.nodes(dd=BTAG_ALL)) + bdry_normals = actx.thaw(dcoll.normal(dd=BTAG_ALL)) if write_output: vis = shortcuts.make_visualizer(dcoll) diff --git a/examples/hello-grudge.py b/examples/hello-grudge.py index bf93fd08b..cfb724115 100644 --- a/examples/hello-grudge.py +++ b/examples/hello-grudge.py @@ -14,7 +14,6 @@ import grudge.op as op from meshmode.mesh.generation import generate_box_mesh from meshmode.array_context import PyOpenCLArrayContext -from arraycontext import thaw from grudge.dof_desc import DTAG_BOUNDARY, FACE_RESTR_INTERIOR @@ -43,7 +42,7 @@ def left_boundary_condition(x, t): def flux(dcoll, u_tpair): dd = u_tpair.dd velocity = np.array([2 * np.pi]) - normal = thaw(dcoll.normal(dd), actx) + normal = actx.thaw(dcoll.normal(dd)) v_dot_n = np.dot(velocity, normal) u_upwind = actx.np.where(v_dot_n > 0, @@ -55,8 +54,8 @@ def flux(dcoll, u_tpair): left_bndry = DTAG_BOUNDARY("left") right_bndry = DTAG_BOUNDARY("right") -x_vol = thaw(dcoll.nodes(), actx) -x_bndry = thaw(dcoll.discr_from_dd(left_bndry).nodes(), actx) +x_vol = actx.thaw(dcoll.nodes()) +x_bndry = actx.thaw(dcoll.discr_from_dd(left_bndry).nodes()) uh = initial_condition(x_vol) diff --git a/examples/maxwell/cavities.py b/examples/maxwell/cavities.py index 90a515bdf..3d581c18a 100644 --- a/examples/maxwell/cavities.py +++ b/examples/maxwell/cavities.py @@ -28,7 +28,6 @@ import pyopencl as cl import pyopencl.tools as cl_tools -from arraycontext import thaw from grudge.array_context import PyOpenCLArrayContext from grudge.shortcuts import set_up_rk4 @@ -84,7 +83,7 @@ def cavity_mode(x, t=0): else: return get_rectangular_cavity_mode(actx, x, t, 1, (2, 3)) - fields = cavity_mode(thaw(dcoll.nodes(), actx), t=0) + fields = cavity_mode(actx.thaw(dcoll.nodes()), t=0) maxwell_operator.check_bc_coverage(mesh) diff --git a/examples/wave/var-propagation-speed.py b/examples/wave/var-propagation-speed.py index 407dd06b2..9929f6dbf 100644 --- a/examples/wave/var-propagation-speed.py +++ b/examples/wave/var-propagation-speed.py @@ -28,7 +28,6 @@ import pyopencl as cl import pyopencl.tools as cl_tools -from arraycontext import thaw from grudge.array_context import PyOpenCLArrayContext from grudge.shortcuts import set_up_rk4 @@ -63,7 +62,7 @@ def source_f(actx, dcoll, t=0): source_center = np.array([0.1, 0.22, 0.33])[:dcoll.dim] source_width = 0.05 source_omega = 3 - nodes = thaw(dcoll.nodes(), actx) + nodes = actx.thaw(dcoll.nodes()) source_center_dist = flat_obj_array( [nodes[i] - source_center[i] for i in range(dcoll.dim)] ) @@ -75,7 +74,7 @@ def source_f(actx, dcoll, t=0): ) ) - x = thaw(dcoll.nodes(), actx) + x = actx.thaw(dcoll.nodes()) ones = dcoll.zeros(actx) + 1 c = actx.np.where(np.dot(x, x) < 0.15, 0.1 * ones, 0.2 * ones) diff --git a/examples/wave/wave-min-mpi.py b/examples/wave/wave-min-mpi.py index c58ed841d..6c56353bd 100644 --- a/examples/wave/wave-min-mpi.py +++ b/examples/wave/wave-min-mpi.py @@ -28,7 +28,6 @@ import pyopencl as cl import pyopencl.tools as cl_tools -from arraycontext import thaw from grudge.array_context import MPIPyOpenCLArrayContext from grudge.shortcuts import set_up_rk4 @@ -88,7 +87,7 @@ def source_f(actx, dcoll, t=0): source_center = np.array([0.1, 0.22, 0.33])[:dcoll.dim] source_width = 0.05 source_omega = 3 - nodes = thaw(dcoll.nodes(), actx) + nodes = actx.thaw(dcoll.nodes()) source_center_dist = flat_obj_array( [nodes[i] - source_center[i] for i in range(dcoll.dim)] ) diff --git a/examples/wave/wave-op-mpi.py b/examples/wave/wave-op-mpi.py index e33240443..8c23336d0 100644 --- a/examples/wave/wave-op-mpi.py +++ b/examples/wave/wave-op-mpi.py @@ -30,7 +30,6 @@ import pyopencl.tools as cl_tools from arraycontext import ( - thaw, with_container_arithmetic, dataclass_array_container ) @@ -73,12 +72,12 @@ def array_context(self): return self.u.array_context -def wave_flux(dcoll, c, w_tpair): +def wave_flux(actx, dcoll, c, w_tpair): u = w_tpair.u v = w_tpair.v dd = w_tpair.dd - normal = thaw(dcoll.normal(dd), u.int.array_context) + normal = actx.thaw(dcoll.normal(dd)) flux_weak = WaveState( u=v.avg @ normal, @@ -99,7 +98,7 @@ class _WaveStateTag: pass -def wave_operator(dcoll, c, w, quad_tag=None): +def wave_operator(actx, dcoll, c, w, quad_tag=None): dd_base = as_dofdesc("vol") dd_vol = DOFDesc("vol", quad_tag) dd_faces = DOFDesc("all_faces", quad_tag) @@ -135,13 +134,14 @@ def interp_to_surf_quad(utpair): dcoll, dd_faces, wave_flux( + actx, dcoll, c=c, w_tpair=op.bdry_trace_pair(dcoll, dd_btag, interior=dir_bval, exterior=dir_bc) ) + sum( - wave_flux(dcoll, c=c, w_tpair=interp_to_surf_quad(tpair)) + wave_flux(actx, dcoll, c=c, w_tpair=interp_to_surf_quad(tpair)) for tpair in op.interior_trace_pairs(dcoll, w, comm_tag=_WaveStateTag) ) @@ -165,7 +165,7 @@ def bump(actx, dcoll, t=0): source_width = 0.05 source_omega = 3 - nodes = thaw(dcoll.nodes(), actx) + nodes = actx.thaw(dcoll.nodes()) center_dist = flat_obj_array([ nodes[i] - source_center[i] for i in range(dcoll.dim) @@ -258,7 +258,7 @@ def main(ctx_factory, dim=2, order=3, vis = make_visualizer(dcoll) def rhs(t, w): - return wave_operator(dcoll, c=c, w=w, quad_tag=quad_tag) + return wave_operator(actx, dcoll, c=c, w=w, quad_tag=quad_tag) compiled_rhs = actx.compile(rhs) diff --git a/examples/wave/wave-op-var-velocity.py b/examples/wave/wave-op-var-velocity.py index 9dace53c5..43c72eff9 100644 --- a/examples/wave/wave-op-var-velocity.py +++ b/examples/wave/wave-op-var-velocity.py @@ -29,7 +29,6 @@ import pyopencl as cl import pyopencl.tools as cl_tools -from arraycontext import thaw from grudge.array_context import PyOpenCLArrayContext from pytools.obj_array import flat_obj_array @@ -48,14 +47,14 @@ # {{{ wave equation bits -def wave_flux(dcoll, c, w_tpair): +def wave_flux(actx, dcoll, c, w_tpair): dd = w_tpair.dd dd_quad = dd.with_discr_tag(DISCR_TAG_QUAD) u = w_tpair[0] v = w_tpair[1:] - normal = thaw(dcoll.normal(dd), u.int.array_context) + normal = actx.thaw(dcoll.normal(dd)) flux_weak = flat_obj_array( np.dot(v.avg, normal), @@ -76,7 +75,7 @@ def wave_flux(dcoll, c, w_tpair): return op.project(dcoll, dd_quad, dd_allfaces_quad, c_quad*flux_quad) -def wave_operator(dcoll, c, w): +def wave_operator(actx, dcoll, c, w): u = w[0] v = w[1:] @@ -104,13 +103,14 @@ def wave_operator(dcoll, c, w): dcoll, dd_allfaces_quad, wave_flux( + actx, dcoll, c=c, w_tpair=op.bdry_trace_pair(dcoll, BTAG_ALL, interior=dir_bval, exterior=dir_bc) ) + sum( - wave_flux(dcoll, c=c, w_tpair=tpair) + wave_flux(actx, dcoll, c=c, w_tpair=tpair) for tpair in op.interior_trace_pairs(dcoll, w) ) ) @@ -135,7 +135,7 @@ def bump(actx, dcoll, t=0, width=0.05, center=None): center = center[:dcoll.dim] source_omega = 3 - nodes = thaw(dcoll.nodes(), actx) + nodes = actx.thaw(dcoll.nodes()) center_dist = flat_obj_array([ nodes[i] - center[i] for i in range(dcoll.dim) @@ -189,7 +189,7 @@ def main(ctx_factory, dim=2, order=3, visualize=False): vis = make_visualizer(dcoll) def rhs(t, w): - return wave_operator(dcoll, c=c, w=w) + return wave_operator(actx, dcoll, c=c, w=w) logger.info("dt = %g", dt) diff --git a/grudge/array_context.py b/grudge/array_context.py index 171016bfe..7178a51ba 100644 --- a/grudge/array_context.py +++ b/grudge/array_context.py @@ -40,6 +40,7 @@ PyOpenCLArrayContext as _PyOpenCLArrayContextBase, PytatoPyOpenCLArrayContext as _PytatoPyOpenCLArrayContextBase) from pyrsistent import pmap +from warnings import warn import logging logger = logging.getLogger(__name__) @@ -56,17 +57,38 @@ # (https://github.com/kaushikcfd/loopy/tree/pytato-array-context-transforms) from loopy.codegen.result import get_idis_for_kernel # noqa except ImportError: - from warnings import warn + # warn("Your loopy and meshmode branches are mismatched. " + # "Please make sure that you have the " + # "https://github.com/kaushikcfd/loopy/tree/pytato-array-context-transforms " # noqa + # "branch of loopy.") + _HAVE_SINGLE_GRID_WORK_BALANCING = False + else: + _HAVE_SINGLE_GRID_WORK_BALANCING = True + +except ImportError: + _HAVE_SINGLE_GRID_WORK_BALANCING = False + +try: + # FIXME: temporary workaround while FusionContractorArrayContext + # is not available in meshmode's main branch + from meshmode.array_context import FusionContractorArrayContext + + try: + # Crude check if we have the correct loopy branch + # (https://github.com/kaushikcfd/loopy/tree/pytato-array-context-transforms) + from loopy.transform.loop_fusion import get_kennedy_unweighted_fusion_candidates # noqa + except ImportError: warn("Your loopy and meshmode branches are mismatched. " "Please make sure that you have the " "https://github.com/kaushikcfd/loopy/tree/pytato-array-context-transforms " # noqa "branch of loopy.") - _HAVE_SINGLE_GRID_WORK_BALANCING = False + _HAVE_FUSION_ACTX = False else: - _HAVE_SINGLE_GRID_WORK_BALANCING = True + _HAVE_FUSION_ACTX = True except ImportError: - _HAVE_SINGLE_GRID_WORK_BALANCING = False + _HAVE_FUSION_ACTX = False + from arraycontext.pytest import ( _PytestPyOpenCLArrayContextFactoryWithClass, @@ -74,7 +96,7 @@ register_pytest_array_context_factory) from arraycontext import ArrayContext from arraycontext.container import ArrayContainer -from arraycontext.impl.pytato.compile import LazilyCompilingFunctionCaller +from arraycontext.impl.pytato.compile import LazilyPyOpenCLCompilingFunctionCaller if TYPE_CHECKING: import pytato as pt @@ -96,7 +118,6 @@ def __init__(self, queue: "pyopencl.CommandQueue", force_device_scalars: bool = False) -> None: if allocator is None: - from warnings import warn warn("No memory allocator specified, please pass one. " "(Preferably a pyopencl.tools.MemoryPool in order " "to reduce device allocations)") @@ -116,7 +137,6 @@ class PytatoPyOpenCLArrayContext(_PytatoPyOpenCLArrayContextBase): """ def __init__(self, queue, allocator=None): if allocator is None: - from warnings import warn warn("No memory allocator specified, please pass one. " "(Preferably a pyopencl.tools.MemoryPool in order " "to reduce device allocations)") @@ -131,16 +151,41 @@ class MPIBasedArrayContext: # {{{ distributed + pytato -class _DistributedLazilyCompilingFunctionCaller(LazilyCompilingFunctionCaller): +class _DistributedLazilyPyOpenCLCompilingFunctionCaller( + LazilyPyOpenCLCompilingFunctionCaller): def _dag_to_compiled_func(self, dict_of_named_arrays, input_id_to_name_in_program, output_id_to_name_in_program, output_template): - from pytato.transform import deduplicate_data_wrappers - dict_of_named_arrays = deduplicate_data_wrappers(dict_of_named_arrays) + import pytato as pt + + from pytools import ProcessLogger - from pytato import find_distributed_partition - distributed_partition = find_distributed_partition(dict_of_named_arrays) + with ProcessLogger(logger, "deduplicate_data_wrappers[pre-partition]"): + dict_of_named_arrays = pt.transform.deduplicate_data_wrappers( + dict_of_named_arrays) + + with ProcessLogger(logger, "materialize_with_mpms[pre-partition]"): + dict_of_named_arrays = pt.transform.materialize_with_mpms( + dict_of_named_arrays) + + # FIXME: Remove the import failure handling once this is in upstream grudge + try: + # pytest: disable=no-name-in-module,import-error + from meshmode.pytato_utils import unify_discretization_entity_tags + except ImportError: + from warnings import warn + warn("'unify_discretization_entity_tags' is unavailable in meshmode, " + "skipping. Certain array contexts may require this " + "transformation for acceptable results.") + + else: + with ProcessLogger(logger, + "transform_dag.infer_axes_tags[pre-partition]"): + dict_of_named_arrays = unify_discretization_entity_tags( + dict_of_named_arrays) + + distributed_partition = pt.find_distributed_partition(dict_of_named_arrays) # {{{ turn symbolic tags into globally agreed-upon integers @@ -175,7 +220,7 @@ def _dag_to_compiled_func(self, dict_of_named_arrays, part_id_to_prg[part.pid], part_prg_name_to_tags, part_prg_name_to_axes - ) = self._dag_to_transformed_loopy_prg(d) + ) = self._dag_to_transformed_pytato_prg(d) assert not (set(name_in_program_to_tags.keys()) & set(part_prg_name_to_tags.keys())) @@ -201,8 +246,8 @@ def _dag_to_compiled_func(self, dict_of_named_arrays, class _DistributedCompiledFunction: """ A callable which captures the :class:`pytato.target.BoundProgram` resulting - from calling :attr:`~LazilyCompilingFunctionCaller.f` with a given set of - input types, and generating :mod:`loopy` IR from it. + from calling :attr:`~LazilyPyOpenCLCompilingFunctionCaller.f` with a given + set of input types, and generating :mod:`loopy` IR from it. .. attribute:: pytato_program @@ -210,8 +255,9 @@ class _DistributedCompiledFunction: A mapping from input id to the placeholder name in :attr:`CompiledFunction.pytato_program`. Input id is represented as the - position of :attr:`~LazilyCompilingFunctionCaller.f`'s argument augmented - with the leaf array's key if the argument is an array container. + position of :attr:`~LazilyPyOpenCLCompilingFunctionCaller.f`'s argument + augmented with the leaf array's key if the argument is an array + container. .. attribute:: output_id_to_name_in_program @@ -243,10 +289,10 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer: representation. """ - from arraycontext.impl.pytato.compile import _args_to_cl_buffers + from arraycontext.impl.pytato.compile import _args_to_device_buffers from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array from arraycontext.impl.pytato.utils import get_cl_axes_from_pt_axes - input_args_for_prg = _args_to_cl_buffers( + input_args_for_prg = _args_to_device_buffers( self.actx, self.input_id_to_name_in_program, arg_id_to_arg) from pytato.distributed import execute_distributed_partition @@ -274,7 +320,6 @@ def __init__( self, mpi_communicator, queue, *, mpi_base_tag, allocator=None ) -> None: if allocator is None: - from warnings import warn warn("No memory allocator specified, please pass one. " "(Preferably a pyopencl.tools.MemoryPool in order " "to reduce device allocations)") @@ -287,7 +332,7 @@ def __init__( # FIXME: implement distributed-aware freeze def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: - return _DistributedLazilyCompilingFunctionCaller(self, f) + return _DistributedLazilyPyOpenCLCompilingFunctionCaller(self, f) def clone(self): # type-ignore-reason: 'DistributedLazyArrayContext' has no 'queue' member @@ -356,6 +401,18 @@ class MPISingleGridWorkBalancingPytatoArrayContext( else: MPIPytatoArrayContext = MPIBasePytatoPyOpenCLArrayContext + +if _HAVE_FUSION_ACTX: + class MPIFusionContractorArrayContext( + MPIPytatoArrayContextBase, FusionContractorArrayContext): + """ + .. autofunction:: __init__ + """ + + MPIPytatoArrayContext = MPIFusionContractorArrayContext +else: + MPIPytatoArrayContext = MPIBasePytatoPyOpenCLArrayContext + # }}} @@ -388,31 +445,60 @@ class PytestPyOpenCLArrayContextFactoryWithHostScalars( # {{{ actx selection + +def _get_single_grid_pytato_actx_class(distributed: bool) -> Type[ArrayContext]: + if not _HAVE_SINGLE_GRID_WORK_BALANCING: + warn("No device-parallel actx available, execution will be slow. " + "Please make sure you have the right branches for loopy " + "(https://github.com/kaushikcfd/loopy/tree/pytato-array-context-transforms) " # noqa + "and meshmode " + "(https://github.com/kaushikcfd/meshmode/tree/pytato-array-context-transforms).") # noqa + # lazy, non-distributed + if not distributed: + if _HAVE_SINGLE_GRID_WORK_BALANCING: + actx_class = SingleGridWorkBalancingPytatoArrayContext + else: + actx_class = PytatoPyOpenCLArrayContext + # distributed+lazy: + if _HAVE_SINGLE_GRID_WORK_BALANCING: + actx_class = MPISingleGridWorkBalancingPytatoArrayContext + else: + actx_class = MPIBasePytatoPyOpenCLArrayContext + + return actx_class + + def get_reasonable_array_context_class( - lazy: bool = True, distributed: bool = True + lazy: bool = True, distributed: bool = True, + fusion: Optional[bool] = None, ) -> Type[ArrayContext]: """Returns a reasonable :class:`PyOpenCLArrayContext` currently supported given the constraints of *lazy* and *distributed*.""" + if fusion is None: + fusion = lazy + if lazy: - if not _HAVE_SINGLE_GRID_WORK_BALANCING: - from warnings import warn - warn("No device-parallel actx available, execution will be slow. " - "Please make sure you have the right branches for loopy " - "(https://github.com/kaushikcfd/loopy/tree/pytato-array-context-transforms) " # noqa - "and meshmode " - "(https://github.com/kaushikcfd/meshmode/tree/pytato-array-context-transforms).") # noqa - # lazy, non-distributed - if not distributed: - if _HAVE_SINGLE_GRID_WORK_BALANCING: - actx_class = SingleGridWorkBalancingPytatoArrayContext + if fusion: + if not _HAVE_FUSION_ACTX: + warn("No device-parallel actx available, execution will be slow. " + "Please make sure you have the right branches for loopy " + "(https://github.com/kaushikcfd/loopy/tree/pytato-array-context-transforms) " # noqa + "and meshmode " + "(https://github.com/kaushikcfd/meshmode/tree/pytato-array-context-transforms).") # noqa + # lazy+fusion, non-distributed + + if _HAVE_FUSION_ACTX: + if distributed: + actx_class = MPIFusionContractorArrayContext + else: + actx_class = FusionContractorArrayContext else: - actx_class = PytatoPyOpenCLArrayContext - # distributed+lazy: - if _HAVE_SINGLE_GRID_WORK_BALANCING: - actx_class = MPISingleGridWorkBalancingPytatoArrayContext + actx_class = _get_single_grid_pytato_actx_class(distributed) else: - actx_class = MPIBasePytatoPyOpenCLArrayContext + actx_class = _get_single_grid_pytato_actx_class(distributed) else: + if fusion: + raise ValueError("No eager actx's support op-fusion.") if distributed: actx_class = MPIPyOpenCLArrayContext else: @@ -422,7 +508,7 @@ def get_reasonable_array_context_class( "device-parallel=%r", actx_class.__name__, lazy, distributed, # eager is always device-parallel: - (_HAVE_SINGLE_GRID_WORK_BALANCING or not lazy)) + (_HAVE_SINGLE_GRID_WORK_BALANCING or _HAVE_FUSION_ACTX or not lazy)) return actx_class # }}} diff --git a/grudge/discretization.py b/grudge/discretization.py index 3a86b7060..43bd24226 100644 --- a/grudge/discretization.py +++ b/grudge/discretization.py @@ -734,10 +734,9 @@ def normal(self, dd): :arg dd: a :class:`~grudge.dof_desc.DOFDesc` as the surface discretization. :returns: an object array of frozen :class:`~meshmode.dof_array.DOFArray`\ s. """ - from arraycontext import freeze from grudge.geometry import normal - return freeze(normal(self._setup_actx, self, dd)) + return self._setup_actx.freeze(normal(self._setup_actx, self, dd)) # }}} diff --git a/grudge/dt_utils.py b/grudge/dt_utils.py index e245ad259..afda657f0 100644 --- a/grudge/dt_utils.py +++ b/grudge/dt_utils.py @@ -45,8 +45,11 @@ import numpy as np -from arraycontext import ArrayContext, thaw, freeze, Scalar -from meshmode.transform_metadata import FirstAxisIsElementsTag +from arraycontext import ArrayContext, Scalar, tag_axes +from meshmode.transform_metadata import (FirstAxisIsElementsTag, + DiscretizationDOFAxisTag, + DiscretizationFaceAxisTag, + DiscretizationElementAxisTag) from grudge.dof_desc import DD_VOLUME, DOFDesc, as_dofdesc from grudge.discretization import DiscretizationCollection @@ -90,7 +93,7 @@ def characteristic_lengthscales( @memoize_in(dcoll, (characteristic_lengthscales, "compute_characteristic_lengthscales")) def _compute_characteristic_lengthscales(): - return freeze( + return actx.freeze( DOFArray( actx, data=tuple( @@ -99,12 +102,12 @@ def _compute_characteristic_lengthscales(): cng * geo_facts for cng, geo_facts in zip( dt_non_geometric_factors(dcoll), - thaw(dt_geometric_factors(dcoll), actx) + actx.thaw(dt_geometric_factors(dcoll)) ) ) ) ) - return thaw(_compute_characteristic_lengthscales(), actx) + return actx.thaw(_compute_characteristic_lengthscales()) @memoize_on_first_arg @@ -267,7 +270,7 @@ def dt_geometric_factors( if dcoll.dim == 1: # Inscribed "circle" radius is half the cell size - return freeze(cell_vols/2) + return actx.freeze(cell_vols/2) dd_face = DOFDesc("all_faces", dd.discretization_tag) face_discr = dcoll.discr_from_dd(dd_face) @@ -287,15 +290,18 @@ def dt_geometric_factors( data=tuple( actx.einsum( "fej->e", - face_ae_i.reshape( - vgrp.mesh_el_group.nfaces, - vgrp.nelements, - face_ae_i.shape[-1]), + tag_axes(actx, { + 0: DiscretizationFaceAxisTag(), + 1: DiscretizationElementAxisTag(), + 2: DiscretizationDOFAxisTag() + }, + face_ae_i.reshape( + vgrp.mesh_el_group.nfaces, + vgrp.nelements, + face_ae_i.shape[-1])), tagged=(FirstAxisIsElementsTag(),)) - for vgrp, face_ae_i in zip(volm_discr.groups, face_areas) - ) - ) + for vgrp, face_ae_i in zip(volm_discr.groups, face_areas))) else: surface_areas = DOFArray( actx, @@ -316,20 +322,17 @@ def dt_geometric_factors( tagged=(FirstAxisIsElementsTag(),)) for vgrp, afgrp, face_ae_i in zip(volm_discr.groups, - face_discr.groups, - face_areas) + face_discr.groups, + face_areas) ) ) - return freeze(DOFArray( + return actx.freeze(DOFArray( actx, data=tuple( - actx.einsum("e,ei->ei", - 1/sae_i, - cv_i, - tagged=(FirstAxisIsElementsTag(),)) * dcoll.dim - - for cv_i, sae_i in zip(cell_vols, surface_areas) + actx.einsum("e,ei->ei", 1/sae_i, cv_i, + tagged=(FirstAxisIsElementsTag(),)) * dcoll.dim + for cv_i, sae_i, vgrp in zip(cell_vols, surface_areas, volm_discr.groups) ) )) diff --git a/grudge/geometry/metrics.py b/grudge/geometry/metrics.py index 492d8b7f7..9110201be 100644 --- a/grudge/geometry/metrics.py +++ b/grudge/geometry/metrics.py @@ -60,7 +60,7 @@ import numpy as np -from arraycontext import thaw, freeze, ArrayContext +from arraycontext import ArrayContext, tag_axes from meshmode.dof_array import DOFArray from grudge import DiscretizationCollection @@ -70,6 +70,10 @@ DD_VOLUME, DOFDesc, DISCR_TAG_BASE ) +from meshmode.transform_metadata import (DiscretizationAmbientDimAxisTag, + DiscretizationTopologicalDimAxisTag) + + from pymbolic.geometric_algebra import MultiVector from pytools.obj_array import make_obj_array @@ -169,7 +173,7 @@ def forward_metric_nth_derivative( vec = num_reference_derivative( dcoll.discr_from_dd(inner_dd), flat_ref_axes, - thaw(dcoll.discr_from_dd(inner_dd).nodes(), actx)[xyz_axis] + actx.thaw(dcoll.discr_from_dd(inner_dd).nodes())[xyz_axis] ) return _geometry_to_quad_if_requested( @@ -513,17 +517,21 @@ def _inv_surf_metric_deriv(): multiplier = 1 mat = actx.np.stack([ - actx.np.stack([ - multiplier - * inverse_surface_metric_derivative(actx, dcoll, - rst_axis, xyz_axis, dd=dd, - _use_geoderiv_connection=_use_geoderiv_connection) - for rst_axis in range(dcoll.dim)]) - for xyz_axis in range(dcoll.ambient_dim)]) + actx.np.stack([ + multiplier + * inverse_surface_metric_derivative( + actx, dcoll, + rst_axis, xyz_axis, dd=dd, + _use_geoderiv_connection=_use_geoderiv_connection) + for rst_axis in range(dcoll.dim)]) + for xyz_axis in range(dcoll.ambient_dim)]) - return freeze(mat, actx) + return actx.freeze(tag_axes(actx, { + 0: DiscretizationAmbientDimAxisTag(), + 1: DiscretizationTopologicalDimAxisTag()}, + mat)) - return thaw(_inv_surf_metric_deriv(), actx) + return actx.thaw(_inv_surf_metric_deriv()) def _signed_face_ones( @@ -541,13 +549,13 @@ def _signed_face_ones( actx, dtype=dcoll.real_dtype ) + 1 - from arraycontext import to_numpy, from_numpy, thaw + from arraycontext import to_numpy, from_numpy _signed_face_ones_numpy = to_numpy(signed_ones, actx) for igrp, grp in enumerate(all_faces_conn.groups): for batch in grp.batches: - i = to_numpy(thaw(batch.to_element_indices, actx), actx) + i = to_numpy(actx.thaw(batch.to_element_indices), actx) grp_field = _signed_face_ones_numpy[igrp].reshape(-1) grp_field[i] = \ (2.0 * (batch.to_element_face % 2) - 1.0) * grp_field[i] @@ -636,9 +644,9 @@ def _area_elements(): actx, dcoll, dd=dd, _use_geoderiv_connection=_use_geoderiv_connection).norm_squared()) - return freeze(result, actx) + return actx.freeze(result) - return thaw(_area_elements(), actx) + return actx.thaw(_area_elements()) # }}} @@ -743,10 +751,9 @@ def _normal(): result = mv / actx.np.sqrt(mv.norm_squared()) - return freeze(result, actx) + return actx.freeze(result) - n = _normal() - return thaw(n, actx) + return actx.thaw(_normal()) def normal(actx: ArrayContext, dcoll: DiscretizationCollection, dd, diff --git a/grudge/models/advection.py b/grudge/models/advection.py index aacd2fcfd..cfe1a4920 100644 --- a/grudge/models/advection.py +++ b/grudge/models/advection.py @@ -30,8 +30,6 @@ import grudge.op as op import types -from arraycontext.container.traversal import thaw - from grudge.models import HyperbolicOperator @@ -43,7 +41,7 @@ def advection_weak_flux(dcoll, flux_type, u_tpair, velocity): """ actx = u_tpair.int.array_context dd = u_tpair.dd - normal = thaw(dcoll.normal(dd), actx) + normal = actx.thaw(dcoll.normal(dd)) v_dot_n = np.dot(velocity, normal) flux_type = flux_type.lower() @@ -92,7 +90,7 @@ class StrongAdvectionOperator(AdvectionOperatorBase): def flux(self, u_tpair): actx = u_tpair.int.array_context dd = u_tpair.dd - normal = thaw(self.dcoll.normal(dd), actx) + normal = actx.thaw(self.dcoll.normal(dd)) v_dot_normal = np.dot(self.v, normal) return u_tpair.int * v_dot_normal - self.weak_flux(u_tpair) @@ -285,7 +283,7 @@ def v_dot_n_tpair(actx, dcoll, velocity, trace_dd): from grudge.trace_pair import TracePair from meshmode.discretization.connection import FACE_RESTR_INTERIOR - normal = thaw(dcoll.normal(trace_dd.with_discr_tag(None)), actx) + normal = actx.thaw(dcoll.normal(trace_dd.with_discr_tag(None))) v_dot_n = velocity.dot(normal) i = op.project(dcoll, trace_dd.with_discr_tag(None), trace_dd, v_dot_n) diff --git a/grudge/models/em.py b/grudge/models/em.py index e14341cfa..7bc952437 100644 --- a/grudge/models/em.py +++ b/grudge/models/em.py @@ -28,7 +28,7 @@ """ -from arraycontext import thaw, get_container_context_recursively +from arraycontext import get_container_context_recursively from grudge.models import HyperbolicOperator @@ -121,7 +121,7 @@ def flux(self, wtpair): """ actx = get_container_context_recursively(wtpair) - normal = thaw(self.dcoll.normal(wtpair.dd), actx) + normal = actx.thaw(self.dcoll.normal(wtpair.dd)) if self.fixed_material: e, h = self.split_eh(wtpair) @@ -222,7 +222,7 @@ def absorbing_bc(self, w): """ actx = get_container_context_recursively(w) - absorb_normal = thaw(self.dcoll.normal(dd=self.absorb_tag), actx) + absorb_normal = actx.thaw(self.dcoll.normal(dd=self.absorb_tag)) e, h = self.split_eh(w) diff --git a/grudge/models/euler.py b/grudge/models/euler.py index 98acbc1ef..f4d6f8f4c 100644 --- a/grudge/models/euler.py +++ b/grudge/models/euler.py @@ -50,7 +50,6 @@ from dataclasses import dataclass from arraycontext import ( - thaw, dataclass_array_container, with_container_arithmetic ) @@ -191,7 +190,7 @@ def boundary_tpair( return TracePair( dd_bc, interior=op.project(dcoll, dd_base, dd_bc, state), - exterior=self.prescribed_state(thaw(dcoll.nodes(dd_bc), actx), t=t) + exterior=self.prescribed_state(actx.thaw(dcoll.nodes(dd_bc)), t=t) ) @@ -204,7 +203,7 @@ def boundary_tpair( state: ConservedEulerField, t=0): actx = state.array_context dd_base = as_dofdesc("vol").with_discr_tag(DISCR_TAG_BASE) - nhat = thaw(dcoll.normal(dd_bc), actx) + nhat = actx.thaw(dcoll.normal(dd_bc)) interior = op.project(dcoll, dd_base, dd_bc, state) return TracePair( @@ -271,7 +270,7 @@ def euler_numerical_flux( exterior=euler_volume_flux(dcoll, q_rr, gamma=gamma) ) num_flux = flux_tpair.avg - normal = thaw(dcoll.normal(dd_intfaces), actx) + normal = actx.thaw(dcoll.normal(dd_intfaces)) if lf_stabilization: from arraycontext import outer diff --git a/grudge/models/wave.py b/grudge/models/wave.py index 209575ba3..ff8ee57e8 100644 --- a/grudge/models/wave.py +++ b/grudge/models/wave.py @@ -28,8 +28,6 @@ import numpy as np -from arraycontext import thaw, freeze - from grudge.models import HyperbolicOperator from meshmode.mesh import BTAG_ALL, BTAG_NONE @@ -93,7 +91,7 @@ def flux(self, wtpair): u = wtpair[0] v = wtpair[1:] actx = u.int.array_context - normal = thaw(self.dcoll.normal(wtpair.dd), actx) + normal = actx.thaw(self.dcoll.normal(wtpair.dd)) central_flux_weak = -self.c*flat_obj_array( np.dot(v.avg, normal), @@ -136,7 +134,7 @@ def operator(self, t, w): neu_bc = flat_obj_array(neu_u, -neu_v) # radiation BCs ------------------------------------------------------- - rad_normal = thaw(dcoll.normal(dd=self.radiation_tag), actx) + rad_normal = actx.thaw(dcoll.normal(dd=self.radiation_tag)) rad_u = op.project(dcoll, "vol", self.radiation_tag, u) rad_v = op.project(dcoll, "vol", self.radiation_tag, v) @@ -214,7 +212,7 @@ def __init__(self, dcoll, c, source_f=None, neumann_tag=BTAG_NONE, radiation_tag=BTAG_NONE): """ - :arg c: a thawed (with *actx*) :class:`~meshmode.dof_array.DOFArray` + :arg c: a frozen :class:`~meshmode.dof_array.DOFArray` representing the propogation speed of the wave. """ @@ -223,11 +221,13 @@ def __init__(self, dcoll, c, source_f=None, actx = c.array_context self.dcoll = dcoll - self.c = freeze(c) + self.c = c self.source_f = source_f ones = dcoll.zeros(actx) + 1 - self.sign = freeze(actx.np.where(actx.np.greater(c, 0), ones, -ones)) + thawed_c = dcoll._setup_actx.thaw(c) + self.sign = dcoll._setup_actx.freeze( + actx.np.where(actx.np.greater(thawed_c, 0), ones, -ones)) self.dirichlet_tag = dirichlet_tag self.neumann_tag = neumann_tag @@ -242,7 +242,7 @@ def flux(self, wtpair): u = wtpair[1] v = wtpair[2:] actx = u.int.array_context - normal = thaw(self.dcoll.normal(wtpair.dd), actx) + normal = actx.thaw(self.dcoll.normal(wtpair.dd)) flux_central_weak = -0.5 * flat_obj_array( np.dot(v.int*c.int + v.ext*c.ext, normal), @@ -266,7 +266,7 @@ def operator(self, t, w): v = w[1:] actx = u.array_context - c = thaw(self.c, actx) + c = actx.thaw(self.c) flux_w = flat_obj_array(c, w) @@ -294,13 +294,12 @@ def operator(self, t, w): neu_bc = flat_obj_array(neu_c, neu_u, -neu_v) # radiation BCs ------------------------------------------------------- - rad_normal = thaw(dcoll.normal(dd=self.radiation_tag), actx) + rad_normal = actx.thaw(dcoll.normal(dd=self.radiation_tag)) rad_c = op.project(dcoll, "vol", self.radiation_tag, c) rad_u = op.project(dcoll, "vol", self.radiation_tag, u) rad_v = op.project(dcoll, "vol", self.radiation_tag, v) - rad_sign = op.project(dcoll, "vol", self.radiation_tag, - thaw(self.sign, actx)) + rad_sign = op.project(dcoll, "vol", self.radiation_tag, actx.thaw(self.sign)) rad_bc = flat_obj_array( rad_c, @@ -345,7 +344,7 @@ def check_bc_coverage(self, mesh): self.radiation_tag]) def max_characteristic_velocity(self, actx, **kwargs): - return actx.np.abs(thaw(self.c, actx)) + return actx.np.abs(actx.thaw(self.c)) # }}} diff --git a/grudge/op.py b/grudge/op.py index d544d7ca6..0f4d63052 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -57,13 +57,16 @@ """ -from arraycontext import ArrayContext, map_array_container +from arraycontext import ArrayContext, map_array_container, tag_axes from arraycontext.container import ArrayOrContainerT from functools import partial from meshmode.dof_array import DOFArray -from meshmode.transform_metadata import FirstAxisIsElementsTag +from meshmode.transform_metadata import (FirstAxisIsElementsTag, + DiscretizationDOFAxisTag, + DiscretizationElementAxisTag, + DiscretizationFaceAxisTag) from grudge.discretization import DiscretizationCollection @@ -168,8 +171,7 @@ def _single_axis_derivative_kernel( get_diff_mat( actx, out_element_group=out_grp, - in_element_group=in_grp - ), + in_element_group=in_grp), vec_i, arg_names=("inv_jac_t", "ref_stiffT_mat", "vec", ), tagged=(FirstAxisIsElementsTag(),)) @@ -222,12 +224,10 @@ def _reference_derivative_matrices(actx: ArrayContext, def get_ref_derivative_mats(grp): from meshmode.discretization.poly_element import diff_matrices return actx.freeze( - actx.from_numpy( - np.asarray( - [dfmat for dfmat in diff_matrices(grp)] - ) - ) - ) + actx.tag_axis( + 1, DiscretizationDOFAxisTag(), + actx.from_numpy( + np.asarray([dfmat for dfmat in diff_matrices(grp)])))) return get_ref_derivative_mats(out_element_group) @@ -346,12 +346,10 @@ def get_ref_stiffness_transpose_mat(out_grp, in_grp): mmat = mass_matrix(out_grp) return actx.freeze( - actx.from_numpy( - np.asarray( - [dmat.T @ mmat.T for dmat in diff_matrices(out_grp)] - ) - ) - ) + actx.tag_axis(1, DiscretizationDOFAxisTag(), + actx.from_numpy( + np.asarray( + [dmat.T @ mmat.T for dmat in diff_matrices(out_grp)])))) from modepy import vandermonde basis = out_grp.basis_obj() @@ -568,13 +566,11 @@ def get_ref_mass_mat(out_grp, in_grp): weights = in_grp.quadrature_rule().weights return actx.freeze( - actx.from_numpy( - np.asarray( - np.einsum("j,ik,jk->ij", weights, vand_inv_t, o_vand), - order="C" - ) - ) - ) + actx.tag_axis(0, DiscretizationDOFAxisTag(), + actx.from_numpy( + np.asarray( + np.einsum("j,ik,jk->ij", weights, vand_inv_t, o_vand), + order="C")))) return get_ref_mass_mat(out_element_group, in_element_group) @@ -598,15 +594,15 @@ def _apply_mass_operator( actx, data=tuple( actx.einsum("ij,ej,ej->ei", - reference_mass_matrix( - actx, - out_element_group=out_grp, - in_element_group=in_grp - ), - ae_i, - vec_i, - arg_names=("mass_mat", "jac", "vec"), - tagged=(FirstAxisIsElementsTag(),)) + reference_mass_matrix( + actx, + out_element_group=out_grp, + in_element_group=in_grp + ), + ae_i, + vec_i, + arg_names=("mass_mat", "jac", "vec"), + tagged=(FirstAxisIsElementsTag(),)) for in_grp, out_grp, ae_i, vec_i in zip( in_discr.groups, out_discr.groups, area_elements, vec) @@ -664,13 +660,11 @@ def get_ref_inv_mass_mat(grp): basis = grp.basis_obj() return actx.freeze( - actx.from_numpy( - np.asarray( - inverse_mass_matrix(basis.functions, grp.unit_nodes), - order="C" - ) - ) - ) + actx.tag_axis(0, DiscretizationDOFAxisTag(), + actx.from_numpy( + np.asarray( + inverse_mass_matrix(basis.functions, grp.unit_nodes), + order="C")))) return get_ref_inv_mass_mat(element_group) @@ -695,21 +689,15 @@ def _apply_inverse_mass_operator( discr = dcoll.discr_from_dd(dd_in) inv_area_elements = 1./area_element(actx, dcoll, dd=dd_in, _use_geoderiv_connection=actx.supports_nonscalar_broadcasting) - group_data = [] - for grp, jac_inv, vec_i in zip(discr.groups, inv_area_elements, vec): - - ref_mass_inverse = reference_inverse_mass_matrix(actx, - element_group=grp) - - group_data.append( + group_data = [ # Based on https://arxiv.org/pdf/1608.03836.pdf # true_Minv ~ ref_Minv * ref_M * (1/jac_det) * ref_Minv actx.einsum("ei,ij,ej->ei", jac_inv, - ref_mass_inverse, + reference_inverse_mass_matrix(actx, element_group=grp), vec_i, tagged=(FirstAxisIsElementsTag(),)) - ) + for grp, jac_inv, vec_i in zip(discr.groups, inv_area_elements, vec)] return DOFArray(actx, data=tuple(group_data)) @@ -840,7 +828,12 @@ def get_ref_face_mass_mat(face_grp, vol_grp): vol_grp.unit_nodes, ) - return actx.freeze(actx.from_numpy(matrix)) + return actx.freeze( + tag_axes(actx, { + 0: DiscretizationDOFAxisTag(), + 2: DiscretizationDOFAxisTag() + }, + actx.from_numpy(matrix))) return get_ref_face_mass_mat(face_element_group, vol_element_group) @@ -867,27 +860,27 @@ def _apply_face_mass_operator(dcoll: DiscretizationCollection, dd, vec): data=tuple( actx.einsum("ifj,fej,fej->ei", reference_face_mass_matrix( - actx, - face_element_group=afgrp, - vol_element_group=vgrp, - dtype=dtype), - surf_ae_i.reshape( + actx, + face_element_group=afgrp, + vol_element_group=vgrp, + dtype=dtype), + actx.tag_axis(1, DiscretizationElementAxisTag(), + surf_ae_i.reshape( vgrp.mesh_el_group.nfaces, vgrp.nelements, - surf_ae_i.shape[-1]), - vec_i.reshape( + surf_ae_i.shape[-1])), + actx.tag_axis(0, DiscretizationFaceAxisTag(), + vec_i.reshape( vgrp.mesh_el_group.nfaces, vgrp.nelements, - afgrp.nunit_dofs), + afgrp.nunit_dofs)), arg_names=("ref_face_mass_mat", "jac_surf", "vec"), tagged=(FirstAxisIsElementsTag(),)) for vgrp, afgrp, vec_i, surf_ae_i in zip(volm_discr.groups, face_discr.groups, vec, - surf_area_elements) - ) - ) + surf_area_elements))) def face_mass(dcoll: DiscretizationCollection, *args) -> ArrayOrContainerT: diff --git a/grudge/reductions.py b/grudge/reductions.py index 1efeff4e9..936467222 100644 --- a/grudge/reductions.py +++ b/grudge/reductions.py @@ -72,6 +72,7 @@ from pytools import memoize_in from meshmode.dof_array import DOFArray +from meshmode.transform_metadata import DiscretizationDOFAxisTag import numpy as np import grudge.dof_desc as dof_desc @@ -364,13 +365,12 @@ def elementwise_prg(): "iel": ConcurrentElementInameTag(), "idof": ConcurrentDOFInameTag()}) - return DOFArray( - actx, - data=tuple( - actx.call_loopy(elementwise_prg(), operand=vec_i)["result"] - for vec_i in vec - ) - ) + return actx.tag_axis(1, DiscretizationDOFAxisTag(), + DOFArray( + actx, + data=tuple( + actx.call_loopy(elementwise_prg(), operand=vec_i)["result"] + for vec_i in vec))) def elementwise_sum( diff --git a/grudge/trace_pair.py b/grudge/trace_pair.py index 1fd5404bd..a7e5bdc60 100644 --- a/grudge/trace_pair.py +++ b/grudge/trace_pair.py @@ -419,7 +419,8 @@ def communicate_single_array(key, local_bdry_ary): local_bdry_ary, dest_rank=remote_rank, comm_tag=ary_tag, stapled_to=make_distributed_recv( src_rank=remote_rank, comm_tag=ary_tag, - shape=local_bdry_ary.shape, dtype=local_bdry_ary.dtype)) + shape=local_bdry_ary.shape, dtype=local_bdry_ary.dtype, + axes=local_bdry_ary.axes)) from arraycontext.container.traversal import rec_keyed_map_array_container self.remote_data = rec_keyed_map_array_container( diff --git a/test/test_dt_utils.py b/test/test_dt_utils.py index 63eb4b4c3..95d7e4e35 100644 --- a/test/test_dt_utils.py +++ b/test/test_dt_utils.py @@ -24,8 +24,6 @@ import numpy as np -from arraycontext import thaw - from grudge.array_context import ( PytestPyOpenCLArrayContextFactory, PytestPytatoPyOpenCLArrayContextFactory @@ -76,7 +74,7 @@ def test_geometric_factors_regular_refinement(actx_factory, name): dcoll = DiscretizationCollection(actx, mesh, order=builder.order) min_factors.append( actx.to_numpy( - op.nodal_min(dcoll, "vol", thaw(dt_geometric_factors(dcoll), actx))) + op.nodal_min(dcoll, "vol", actx.thaw(dt_geometric_factors(dcoll)))) ) # Resolution is doubled each refinement, so the ratio of consecutive @@ -88,7 +86,7 @@ def test_geometric_factors_regular_refinement(actx_factory, name): # Make sure it works with empty meshes mesh = builder.get_mesh(0, builder.mesh_order) dcoll = DiscretizationCollection(actx, mesh, order=builder.order) - factors = thaw(dt_geometric_factors(dcoll), actx) # noqa: F841 + factors = actx.thaw(dt_geometric_factors(dcoll)) # noqa: F841 @pytest.mark.parametrize("name", ["interval", "box2d", "box3d"]) diff --git a/test/test_euler_model.py b/test/test_euler_model.py index 3f5174e26..13a479d48 100644 --- a/test/test_euler_model.py +++ b/test/test_euler_model.py @@ -27,7 +27,6 @@ from grudge.array_context import PytestPyOpenCLArrayContextFactory from arraycontext import ( pytest_generate_tests_for_array_contexts, - thaw, freeze ) pytest_generate_tests = pytest_generate_tests_for_array_contexts( [PytestPyOpenCLArrayContextFactory]) @@ -82,7 +81,7 @@ def test_euler_vortex_convergence(actx_factory, order): discr_tag_to_group_factory=discr_tag_to_group_factory ) h_max = actx.to_numpy(h_max_from_volume(dcoll, dim=dcoll.ambient_dim)) - nodes = thaw(dcoll.nodes(), actx) + nodes = actx.thaw(dcoll.nodes()) # }}} @@ -115,7 +114,7 @@ def rhs(t, q): t = 0.0 last_q = None while t < final_time: - fields = thaw(freeze(fields, actx), actx) + fields = actx.thaw(actx.freeze(fields)) fields = rk4_step(fields, t, dt, compiled_rhs) t += dt logger.info("[%04d] t = %.5f", step, t) diff --git a/test/test_grudge.py b/test/test_grudge.py index 6a42167ae..ce0b199b8 100644 --- a/test/test_grudge.py +++ b/test/test_grudge.py @@ -31,8 +31,6 @@ pytest_generate_tests = pytest_generate_tests_for_array_contexts( [PytestPyOpenCLArrayContextFactory]) -from arraycontext.container.traversal import thaw - from meshmode import _acf # noqa: F401 from meshmode.dof_array import flat_norm @@ -92,12 +90,12 @@ def f(x): return actx.np.sin(x[0])**2 volm_disc = dcoll.discr_from_dd(dof_desc.DD_VOLUME) - x_volm = thaw(volm_disc.nodes(), actx) + x_volm = actx.thaw(volm_disc.nodes()) f_volm = f(x_volm) ones_volm = volm_disc.zeros(actx) + 1 quad_disc = dcoll.discr_from_dd(dd_quad) - x_quad = thaw(quad_disc.nodes(), actx) + x_quad = actx.thaw(quad_disc.nodes()) f_quad = f(x_quad) ones_quad = quad_disc.zeros(actx) + 1 @@ -271,7 +269,7 @@ def f(x): return actx.np.cos(4.0 * x[0]) dd = dof_desc.DD_VOLUME - x_volm = thaw(volume_discr.nodes(), actx) + x_volm = actx.thaw(volume_discr.nodes()) f_volm = f(x_volm) f_inv = op.inverse_mass( dcoll, op.mass(dcoll, dd, f_volm) @@ -342,7 +340,7 @@ def test_face_normal_surface(actx_factory, mesh_name): ) surf_normal = surf_normal / actx.np.sqrt(sum(surf_normal**2)) - face_normal_i = thaw(dcoll.normal(df), actx) + face_normal_i = actx.thaw(dcoll.normal(df)) face_normal_e = dcoll.opposite_face_connection()(face_normal_i) if mesh.ambient_dim == 3: @@ -408,7 +406,7 @@ def df(x, axis): dcoll = DiscretizationCollection(actx, mesh, order=4) volume_discr = dcoll.discr_from_dd(dof_desc.DD_VOLUME) - x = thaw(volume_discr.nodes(), actx) + x = actx.thaw(volume_discr.nodes()) for axis in range(dim): df_num = op.local_grad(dcoll, f(x, axis))[axis] @@ -450,7 +448,7 @@ def test_2d_gauss_theorem(actx_factory): dcoll = DiscretizationCollection(actx, mesh, order=2) volm_disc = dcoll.discr_from_dd(dof_desc.DD_VOLUME) - x_volm = thaw(volm_disc.nodes(), actx) + x_volm = actx.thaw(volm_disc.nodes()) def f(x): return flat_obj_array( @@ -462,7 +460,7 @@ def f(x): int_1 = op.integral(dcoll, "vol", op.local_div(dcoll, f_volm)) prj_f = op.project(dcoll, "vol", BTAG_ALL, f_volm) - normal = thaw(dcoll.normal(BTAG_ALL), actx) + normal = actx.thaw(dcoll.normal(BTAG_ALL)) int_2 = op.integral(dcoll, BTAG_ALL, prj_f.dot(normal)) assert abs(int_1 - int_2) < 1e-13 @@ -564,14 +562,14 @@ def f(x): ambient_dim = dcoll.ambient_dim # variables - f_num = f(thaw(dcoll.nodes(dd=dd), actx)) - f_quad_num = f(thaw(dcoll.nodes(dd=dq), actx)) + f_num = f(actx.thaw(dcoll.nodes(dd=dd))) + f_quad_num = f(actx.thaw(dcoll.nodes(dd=dq))) from grudge.geometry import normal, summed_curvature kappa = summed_curvature(actx, dcoll, dd=dq) normal = normal(actx, dcoll, dd=dq) - face_normal = thaw(dcoll.normal(df), actx) + face_normal = actx.thaw(dcoll.normal(df)) face_f = op.project(dcoll, dd, df, f_num) # operators @@ -713,12 +711,12 @@ def u_analytic(x, t=0): "weak": WeakAdvectionOperator}[op_type] adv_operator = op_class(dcoll, v, inflow_u=lambda t: u_analytic( - thaw(dcoll.nodes(dd=BTAG_ALL), actx), + actx.thaw(dcoll.nodes(dd=BTAG_ALL)), t=t ), flux_type=flux_type) - nodes = thaw(dcoll.nodes(), actx) + nodes = actx.thaw(dcoll.nodes()) u = u_analytic(nodes, t=0) def rhs(t, u): @@ -812,7 +810,7 @@ def test_convergence_maxwell(actx_factory, order): def analytic_sol(x, t=0): return get_rectangular_cavity_mode(actx, x, t, 1, (1, 2, 2)) - nodes = thaw(dcoll.nodes(), actx) + nodes = actx.thaw(dcoll.nodes()) fields = analytic_sol(nodes, t=0) from grudge.models.em import MaxwellOperator @@ -909,7 +907,7 @@ def conv_test(descr, use_quad): discr_tag_to_group_factory=discr_tag_to_group_factory ) - nodes = thaw(dcoll.nodes(), actx) + nodes = actx.thaw(dcoll.nodes()) def zero_inflow(dtag, t=0): dd = dof_desc.DOFDesc(dtag, qtag) @@ -959,7 +957,7 @@ def test_bessel(actx_factory): dcoll = DiscretizationCollection(actx, mesh, order=3) - nodes = thaw(dcoll.nodes(), actx) + nodes = actx.thaw(dcoll.nodes()) r = actx.np.sqrt(nodes[0]**2 + nodes[1]**2) # FIXME: Bessel functions need to brought out of the symbolic @@ -990,7 +988,7 @@ def test_norm_real(actx_factory, p): a=(0,)*dim, b=(1,)*dim, nelements_per_axis=(8,)*dim, order=1) dcoll = DiscretizationCollection(actx, mesh, order=4) - nodes = thaw(dcoll.nodes(), actx) + nodes = actx.thaw(dcoll.nodes()) norm = op.norm(dcoll, nodes[0], p) if p == 2: @@ -1011,7 +1009,7 @@ def test_norm_complex(actx_factory, p): a=(0,)*dim, b=(1,)*dim, nelements_per_axis=(8,)*dim, order=1) dcoll = DiscretizationCollection(actx, mesh, order=4) - nodes = thaw(dcoll.nodes(), actx) + nodes = actx.thaw(dcoll.nodes()) norm = op.norm(dcoll, (1 + 1j)*nodes[0], p) if p == 2: @@ -1032,7 +1030,7 @@ def test_norm_obj_array(actx_factory, p): a=(0,)*dim, b=(1,)*dim, nelements_per_axis=(8,)*dim, order=1) dcoll = DiscretizationCollection(actx, mesh, order=4) - nodes = thaw(dcoll.nodes(), actx) + nodes = actx.thaw(dcoll.nodes()) norm = op.norm(dcoll, nodes, p) diff --git a/test/test_modal_connections.py b/test/test_modal_connections.py index 4f4354b26..b633036f3 100644 --- a/test/test_modal_connections.py +++ b/test/test_modal_connections.py @@ -25,7 +25,6 @@ from arraycontext import pytest_generate_tests_for_array_contexts pytest_generate_tests = pytest_generate_tests_for_array_contexts( [PytestPyOpenCLArrayContextFactory]) -from arraycontext import thaw from meshmode.discretization.poly_element import ( # Simplex group factories @@ -76,7 +75,7 @@ def f(x): dd_modal = dof_desc.DD_VOLUME_MODAL dd_volume = dof_desc.DD_VOLUME - x_nodal = thaw(dcoll.discr_from_dd(dd_volume).nodes()[0], actx) + x_nodal = actx.thaw(dcoll.discr_from_dd(dd_volume).nodes()[0]) nodal_f = f(x_nodal) # Map nodal coefficients of f to modal coefficients @@ -120,7 +119,7 @@ def f(x): dd_quad = dof_desc.DOFDesc(dof_desc.DTAG_VOLUME_ALL, dof_desc.DISCR_TAG_QUAD) - x_quad = thaw(dcoll.discr_from_dd(dd_quad).nodes()[0], actx) + x_quad = actx.thaw(dcoll.discr_from_dd(dd_quad).nodes()[0]) quad_f = f(x_quad) # Map nodal coefficients of f to modal coefficients diff --git a/test/test_mpi_communication.py b/test/test_mpi_communication.py index 8e053f6fd..47100b415 100644 --- a/test/test_mpi_communication.py +++ b/test/test_mpi_communication.py @@ -32,7 +32,6 @@ import sys from grudge.array_context import MPIPyOpenCLArrayContext, MPIPytatoArrayContext -from arraycontext.container.traversal import thaw, freeze logger = logging.getLogger(__name__) logging.basicConfig() @@ -136,7 +135,7 @@ def _test_func_comparison_mpi_communication_entrypoint(actx): dcoll = DiscretizationCollection(actx, local_mesh, order=5) - x = thaw(dcoll.nodes(), actx) + x = actx.thaw(dcoll.nodes()) myfunc = actx.np.sin(np.dot(x, [2, 3])) from grudge.dof_desc import as_dofdesc @@ -214,7 +213,7 @@ def source_f(actx, dcoll, t=0): source_center = np.array([0.1, 0.22, 0.33])[:dcoll.dim] source_width = 0.05 source_omega = 3 - nodes = thaw(dcoll.nodes(), actx) + nodes = actx.thaw(dcoll.nodes()) source_center_dist = flat_obj_array( [nodes[i] - source_center[i] for i in range(dcoll.dim)] ) @@ -282,7 +281,7 @@ def rhs(t, w): for step in range(nsteps): t = step*dt fields = rk4_step(fields, t=t, h=dt, f=compiled_rhs) - fields = thaw(freeze(fields, actx), actx) + fields = actx.thaw(actx.freeze(fields)) norm = actx.to_numpy(op.norm(dcoll, fields, 2)) logger.info("[%04d] t = %.5e |u| = %.5e elapsed %.5e", diff --git a/test/test_op.py b/test/test_op.py index 364a988f5..0eaac4f0e 100644 --- a/test/test_op.py +++ b/test/test_op.py @@ -37,8 +37,6 @@ pytest_generate_tests = pytest_generate_tests_for_array_contexts( [PytestPyOpenCLArrayContextFactory]) -from arraycontext.container.traversal import thaw - import logging logger = logging.getLogger(__name__) @@ -89,7 +87,7 @@ def grad_f(x): result[dim-1] = result[dim-1] * (-np.pi/2*actx.np.sin(np.pi/2*x[dim-1])) return result - x = thaw(dcoll.nodes(), actx) + x = actx.thaw(dcoll.nodes()) if vectorize: u = make_obj_array([(i+1)*f(x) for i in range(dim)]) @@ -99,7 +97,7 @@ def grad_f(x): def get_flux(u_tpair): dd = u_tpair.dd dd_allfaces = dd.with_dtag("all_faces") - normal = thaw(dcoll.normal(dd), actx) + normal = actx.thaw(dcoll.normal(dd)) u_avg = u_tpair.avg if vectorize: if nested: @@ -213,7 +211,7 @@ def div_f(x): result = result + deriv return result - x = thaw(dcoll.nodes(), actx) + x = actx.thaw(dcoll.nodes()) if vectorize: u = make_obj_array([(i+1)*f(x) for i in range(dim)]) @@ -225,7 +223,7 @@ def div_f(x): def get_flux(u_tpair): dd = u_tpair.dd dd_allfaces = dd.with_dtag("all_faces") - normal = thaw(dcoll.normal(dd), actx) + normal = actx.thaw(dcoll.normal(dd)) flux = u_tpair.avg @ normal return op.project(dcoll, dd, dd_allfaces, flux) diff --git a/test/test_reductions.py b/test/test_reductions.py index d7f2f0b8c..9e7387bad 100644 --- a/test/test_reductions.py +++ b/test/test_reductions.py @@ -27,7 +27,6 @@ from dataclasses import dataclass from arraycontext import ( - thaw, with_container_arithmetic, dataclass_array_container, pytest_generate_tests_for_array_contexts @@ -70,7 +69,7 @@ def test_nodal_reductions(actx_factory, mesh_size, with_initial): mesh = builder.get_mesh(mesh_size, builder.mesh_order) dcoll = DiscretizationCollection(actx, mesh, order=builder.order) - x = thaw(dcoll.nodes(), actx) + x = actx.thaw(dcoll.nodes()) def f(x): return -actx.np.sin(10*x[0]) @@ -137,7 +136,7 @@ def test_elementwise_reductions(actx_factory): nelements = 4 mesh = builder.get_mesh(nelements, builder.mesh_order) dcoll = DiscretizationCollection(actx, mesh, order=builder.order) - x = thaw(dcoll.nodes(), actx) + x = actx.thaw(dcoll.nodes()) def f(x): return actx.np.sin(x[0]) @@ -197,7 +196,7 @@ def test_nodal_reductions_with_container(actx_factory): mesh = builder.get_mesh(4, builder.mesh_order) dcoll = DiscretizationCollection(actx, mesh, order=builder.order) - x = thaw(dcoll.nodes(), actx) + x = actx.thaw(dcoll.nodes()) def f(x): return -actx.np.sin(10*x[0]) * actx.np.cos(2*x[1]) @@ -245,7 +244,7 @@ def test_elementwise_reductions_with_container(actx_factory): nelements = 4 mesh = builder.get_mesh(nelements, builder.mesh_order) dcoll = DiscretizationCollection(actx, mesh, order=builder.order) - x = thaw(dcoll.nodes(), actx) + x = actx.thaw(dcoll.nodes()) def f(x): return actx.np.sin(x[0]) * actx.np.sin(x[1])