Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 27 additions & 8 deletions grudge/dt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,10 @@
import numpy as np

from arraycontext import ArrayContext, thaw, freeze, DeviceScalar
from meshmode.transform_metadata import FirstAxisIsElementsTag
from meshmode.transform_metadata import (FirstAxisIsElementsTag,
DiscretizationDOFAxisTag,
DiscretizationFaceAxisTag,
DiscretizationElementAxisTag)

from grudge.dof_desc import DD_VOLUME, DOFDesc, as_dofdesc
from grudge.discretization import DiscretizationCollection
Expand Down Expand Up @@ -287,11 +290,25 @@ def dt_geometric_factors(
data=tuple(
actx.einsum(
"fej->e",
face_ae_i.reshape(
vgrp.mesh_el_group.nfaces, vgrp.nelements, -1),
actx.tag_axis(
0,
DiscretizationFaceAxisTag.from_group(vgrp),
actx.tag_axis(
1,
DiscretizationElementAxisTag.from_group(vgrp),
actx.tag_axis(
2,
DiscretizationDOFAxisTag.from_group(fgrp),
face_ae_i.reshape(
vgrp.mesh_el_group.nfaces,
vgrp.nelements,
-1
)))),
tagged=(FirstAxisIsElementsTag(),))

for vgrp, face_ae_i in zip(volm_discr.groups, face_areas)
for vgrp, fgrp, face_ae_i in zip(volm_discr.groups,
face_discr.groups,
face_areas)
)
)
else:
Expand All @@ -312,8 +329,8 @@ def dt_geometric_factors(
tagged=(FirstAxisIsElementsTag(),))

for vgrp, afgrp, face_ae_i in zip(volm_discr.groups,
face_discr.groups,
face_areas)
face_discr.groups,
face_areas)
)
)

Expand All @@ -322,10 +339,12 @@ def dt_geometric_factors(
data=tuple(
actx.einsum("e,ei->ei",
1/sae_i,
cv_i,
actx.tag_axis(1,
DiscretizationDOFAxisTag.from_group(vgrp),
cv_i),
tagged=(FirstAxisIsElementsTag(),)) * dcoll.dim

for cv_i, sae_i in zip(cell_vols, surface_areas)
for cv_i, sae_i, vgrp in zip(cell_vols, surface_areas, volm_discr.groups)
)
))

Expand Down
28 changes: 19 additions & 9 deletions grudge/geometry/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@
DD_VOLUME, DOFDesc, DISCR_TAG_BASE
)

from meshmode.transform_metadata import (DiscretizationPhysicalDimAxisTag,
DiscretizationRefDimAxisTag)


from pymbolic.geometric_algebra import MultiVector

from pytools.obj_array import make_obj_array
Expand Down Expand Up @@ -513,15 +517,21 @@ def _inv_surf_metric_deriv():
multiplier = 1

mat = actx.np.stack([
actx.np.stack([
multiplier
* inverse_surface_metric_derivative(actx, dcoll,
rst_axis, xyz_axis, dd=dd,
_use_geoderiv_connection=_use_geoderiv_connection)
for rst_axis in range(dcoll.dim)])
for xyz_axis in range(dcoll.ambient_dim)])

return freeze(mat, actx)
actx.tag_axis(0,
DiscretizationRefDimAxisTag((dcoll.dim,)),
actx.np.stack([
multiplier
* inverse_surface_metric_derivative(
actx, dcoll,
rst_axis, xyz_axis, dd=dd,
_use_geoderiv_connection=_use_geoderiv_connection)
for rst_axis in range(dcoll.dim)]))
for xyz_axis in range(dcoll.ambient_dim)])

return freeze(actx.tag_axis(0,
DiscretizationPhysicalDimAxisTag((dcoll.dim,)),
mat),
actx)

return thaw(_inv_surf_metric_deriv(), actx)

Expand Down
64 changes: 41 additions & 23 deletions grudge/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,10 @@
from functools import partial

from meshmode.dof_array import DOFArray
from meshmode.transform_metadata import FirstAxisIsElementsTag
from meshmode.transform_metadata import (FirstAxisIsElementsTag,
DiscretizationDOFAxisTag,
DiscretizationElementAxisTag,
DiscretizationFaceAxisTag)

from grudge.discretization import DiscretizationCollection

Expand Down Expand Up @@ -114,11 +117,13 @@ def _single_axis_derivative_kernel(
# r for rst axis
actx.einsum("rej,rij,ej->ei" if metric_in_matvec else "rei,rij,ej->ei",
ijm_i[xyz_axis],
get_diff_mat(
actx,
out_element_group=out_grp,
in_element_group=in_grp
),
actx.tag_axis(1,
DiscretizationDOFAxisTag.from_group(out_grp),
get_diff_mat(
actx,
out_element_group=out_grp,
in_element_group=in_grp
)),
vec_i,
arg_names=("inv_jac_t", "ref_stiffT_mat", "vec", ),
tagged=(FirstAxisIsElementsTag(),))
Expand Down Expand Up @@ -602,10 +607,14 @@ def _apply_mass_operator(
actx,
data=tuple(
actx.einsum("ij,ej,ej->ei",
reference_mass_matrix(
actx,
out_element_group=out_grp,
in_element_group=in_grp
actx.tag_axis(
0,
DiscretizationDOFAxisTag.from_group(out_grp),
reference_mass_matrix(
actx,
out_element_group=out_grp,
in_element_group=in_grp
)
),
ae_i,
vec_i,
Expand Down Expand Up @@ -870,19 +879,28 @@ def _apply_face_mass_operator(dcoll: DiscretizationCollection, dd, vec):
actx,
data=tuple(
actx.einsum("ifj,fej,fej->ei",
reference_face_mass_matrix(
actx,
face_element_group=afgrp,
vol_element_group=vgrp,
dtype=dtype),
surf_ae_i.reshape(
vgrp.mesh_el_group.nfaces,
vgrp.nelements,
-1),
vec_i.reshape(
vgrp.mesh_el_group.nfaces,
vgrp.nelements,
afgrp.nunit_dofs),
actx.tag_axis(0,
DiscretizationDOFAxisTag.from_group(vgrp),
actx.tag_axis(
2,
DiscretizationDOFAxisTag.from_group(afgrp),
reference_face_mass_matrix(
actx,
face_element_group=afgrp,
vol_element_group=vgrp,
dtype=dtype))),
actx.tag_axis(1,
DiscretizationElementAxisTag.from_group(vgrp),
surf_ae_i.reshape(
vgrp.mesh_el_group.nfaces,
vgrp.nelements,
-1)),
actx.tag_axis(0,
DiscretizationFaceAxisTag.from_group(vgrp),
vec_i.reshape(
vgrp.mesh_el_group.nfaces,
vgrp.nelements,
afgrp.nunit_dofs)),
arg_names=("ref_face_mass_mat", "jac_surf", "vec"),
tagged=(FirstAxisIsElementsTag(),))

Expand Down