From 42ab3180c591f3b015882f1d337fed7845fa5d72 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Wed, 17 Nov 2021 19:08:33 -0600 Subject: [PATCH] add discretization entity tags to array axes --- grudge/dt_utils.py | 35 ++++++++++++++++----- grudge/geometry/metrics.py | 28 +++++++++++------ grudge/op.py | 64 ++++++++++++++++++++++++-------------- 3 files changed, 87 insertions(+), 40 deletions(-) diff --git a/grudge/dt_utils.py b/grudge/dt_utils.py index b41c683cd..7342d7be6 100644 --- a/grudge/dt_utils.py +++ b/grudge/dt_utils.py @@ -46,7 +46,10 @@ import numpy as np from arraycontext import ArrayContext, thaw, freeze, DeviceScalar -from meshmode.transform_metadata import FirstAxisIsElementsTag +from meshmode.transform_metadata import (FirstAxisIsElementsTag, + DiscretizationDOFAxisTag, + DiscretizationFaceAxisTag, + DiscretizationElementAxisTag) from grudge.dof_desc import DD_VOLUME, DOFDesc, as_dofdesc from grudge.discretization import DiscretizationCollection @@ -287,11 +290,25 @@ def dt_geometric_factors( data=tuple( actx.einsum( "fej->e", - face_ae_i.reshape( - vgrp.mesh_el_group.nfaces, vgrp.nelements, -1), + actx.tag_axis( + 0, + DiscretizationFaceAxisTag.from_group(vgrp), + actx.tag_axis( + 1, + DiscretizationElementAxisTag.from_group(vgrp), + actx.tag_axis( + 2, + DiscretizationDOFAxisTag.from_group(fgrp), + face_ae_i.reshape( + vgrp.mesh_el_group.nfaces, + vgrp.nelements, + -1 + )))), tagged=(FirstAxisIsElementsTag(),)) - for vgrp, face_ae_i in zip(volm_discr.groups, face_areas) + for vgrp, fgrp, face_ae_i in zip(volm_discr.groups, + face_discr.groups, + face_areas) ) ) else: @@ -312,8 +329,8 @@ def dt_geometric_factors( tagged=(FirstAxisIsElementsTag(),)) for vgrp, afgrp, face_ae_i in zip(volm_discr.groups, - face_discr.groups, - face_areas) + face_discr.groups, + face_areas) ) ) @@ -322,10 +339,12 @@ def dt_geometric_factors( data=tuple( actx.einsum("e,ei->ei", 1/sae_i, - cv_i, + actx.tag_axis(1, + DiscretizationDOFAxisTag.from_group(vgrp), + cv_i), tagged=(FirstAxisIsElementsTag(),)) * dcoll.dim - for cv_i, sae_i in zip(cell_vols, surface_areas) + for cv_i, sae_i, vgrp in zip(cell_vols, surface_areas, volm_discr.groups) ) )) diff --git a/grudge/geometry/metrics.py b/grudge/geometry/metrics.py index 492d8b7f7..dbcbc6945 100644 --- a/grudge/geometry/metrics.py +++ b/grudge/geometry/metrics.py @@ -70,6 +70,10 @@ DD_VOLUME, DOFDesc, DISCR_TAG_BASE ) +from meshmode.transform_metadata import (DiscretizationPhysicalDimAxisTag, + DiscretizationRefDimAxisTag) + + from pymbolic.geometric_algebra import MultiVector from pytools.obj_array import make_obj_array @@ -513,15 +517,21 @@ def _inv_surf_metric_deriv(): multiplier = 1 mat = actx.np.stack([ - actx.np.stack([ - multiplier - * inverse_surface_metric_derivative(actx, dcoll, - rst_axis, xyz_axis, dd=dd, - _use_geoderiv_connection=_use_geoderiv_connection) - for rst_axis in range(dcoll.dim)]) - for xyz_axis in range(dcoll.ambient_dim)]) - - return freeze(mat, actx) + actx.tag_axis(0, + DiscretizationRefDimAxisTag((dcoll.dim,)), + actx.np.stack([ + multiplier + * inverse_surface_metric_derivative( + actx, dcoll, + rst_axis, xyz_axis, dd=dd, + _use_geoderiv_connection=_use_geoderiv_connection) + for rst_axis in range(dcoll.dim)])) + for xyz_axis in range(dcoll.ambient_dim)]) + + return freeze(actx.tag_axis(0, + DiscretizationPhysicalDimAxisTag((dcoll.dim,)), + mat), + actx) return thaw(_inv_surf_metric_deriv(), actx) diff --git a/grudge/op.py b/grudge/op.py index 70bd1caf7..4f3958b20 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -56,7 +56,10 @@ from functools import partial from meshmode.dof_array import DOFArray -from meshmode.transform_metadata import FirstAxisIsElementsTag +from meshmode.transform_metadata import (FirstAxisIsElementsTag, + DiscretizationDOFAxisTag, + DiscretizationElementAxisTag, + DiscretizationFaceAxisTag) from grudge.discretization import DiscretizationCollection @@ -114,11 +117,13 @@ def _single_axis_derivative_kernel( # r for rst axis actx.einsum("rej,rij,ej->ei" if metric_in_matvec else "rei,rij,ej->ei", ijm_i[xyz_axis], - get_diff_mat( - actx, - out_element_group=out_grp, - in_element_group=in_grp - ), + actx.tag_axis(1, + DiscretizationDOFAxisTag.from_group(out_grp), + get_diff_mat( + actx, + out_element_group=out_grp, + in_element_group=in_grp + )), vec_i, arg_names=("inv_jac_t", "ref_stiffT_mat", "vec", ), tagged=(FirstAxisIsElementsTag(),)) @@ -602,10 +607,14 @@ def _apply_mass_operator( actx, data=tuple( actx.einsum("ij,ej,ej->ei", - reference_mass_matrix( - actx, - out_element_group=out_grp, - in_element_group=in_grp + actx.tag_axis( + 0, + DiscretizationDOFAxisTag.from_group(out_grp), + reference_mass_matrix( + actx, + out_element_group=out_grp, + in_element_group=in_grp + ) ), ae_i, vec_i, @@ -870,19 +879,28 @@ def _apply_face_mass_operator(dcoll: DiscretizationCollection, dd, vec): actx, data=tuple( actx.einsum("ifj,fej,fej->ei", - reference_face_mass_matrix( - actx, - face_element_group=afgrp, - vol_element_group=vgrp, - dtype=dtype), - surf_ae_i.reshape( - vgrp.mesh_el_group.nfaces, - vgrp.nelements, - -1), - vec_i.reshape( - vgrp.mesh_el_group.nfaces, - vgrp.nelements, - afgrp.nunit_dofs), + actx.tag_axis(0, + DiscretizationDOFAxisTag.from_group(vgrp), + actx.tag_axis( + 2, + DiscretizationDOFAxisTag.from_group(afgrp), + reference_face_mass_matrix( + actx, + face_element_group=afgrp, + vol_element_group=vgrp, + dtype=dtype))), + actx.tag_axis(1, + DiscretizationElementAxisTag.from_group(vgrp), + surf_ae_i.reshape( + vgrp.mesh_el_group.nfaces, + vgrp.nelements, + -1)), + actx.tag_axis(0, + DiscretizationFaceAxisTag.from_group(vgrp), + vec_i.reshape( + vgrp.mesh_el_group.nfaces, + vgrp.nelements, + afgrp.nunit_dofs)), arg_names=("ref_face_mass_mat", "jac_surf", "vec"), tagged=(FirstAxisIsElementsTag(),))