diff --git a/test/test_misc.py b/test/test_misc.py index 3c415e0dc..88c9f3852 100644 --- a/test/test_misc.py +++ b/test/test_misc.py @@ -20,9 +20,13 @@ THE SOFTWARE. """ +import sys +from dataclasses import dataclass +from typing import Any, Callable + import numpy as np import numpy.linalg as la -import sys + import sumpy.toys as t import sumpy.symbolic as sym @@ -31,8 +35,6 @@ from pyopencl.tools import ( # noqa pytest_generate_tests_for_pyopencl as pytest_generate_tests) -from pytools import Record - from sumpy.kernel import (LaplaceKernel, HelmholtzKernel, BiharmonicKernel, YukawaKernel, StokesletKernel, StressletKernel, ElasticityKernel, LineOfCompressionKernel, ExpressionKernel) @@ -191,39 +193,20 @@ def approx_convergence_factor(orders, errors): return np.exp(poly[0]) -class P2E2E2PTestCase(Record): +@dataclass +class P2E2E2PTestCase: + source: np.ndarray + target: np.ndarray + center1: np.ndarray + center2: np.ndarray + expansion1: Callable[..., Any] + expansion2: Callable[..., Any] + conv_factor: str @property def dim(self): return len(self.source) - @staticmethod - def eval(expr, source, center1, center2, target): - from pymbolic import parse, evaluate - context = { - "s": source, - "c1": center1, - "c2": center2, - "t": target, - "norm": la.norm} - - return evaluate(parse(expr), context) - - def __init__(self, - source, center1, center2, target, expansion1, expansion2, conv_factor): - - if isinstance(conv_factor, str): - conv_factor = self.eval(conv_factor, source, center1, center2, target) - - Record.__init__(self, - source=source, - center1=center1, - center2=center2, - target=target, - expansion1=expansion1, - expansion2=expansion2, - conv_factor=conv_factor) - P2E2E2P_TEST_CASES = ( # local to local, 3D @@ -271,7 +254,16 @@ def test_toy_p2e2e2p(ctx_factory, case): src = case.source.reshape(dim, -1) tgt = case.target.reshape(dim, -1) - if not 0 <= case.conv_factor <= 1: + from pymbolic import parse, evaluate + case_conv_factor = evaluate(parse(case.conv_factor), { + "s": case.source, + "c1": case.center1, + "c2": case.center2, + "t": case.target, + "norm": la.norm, + }) + + if not 0 <= case_conv_factor <= 1: raise ValueError( "convergence factor not in valid range: %e" % case.conv_factor) @@ -296,8 +288,8 @@ def test_toy_p2e2e2p(ctx_factory, case): errors.append(np.abs(pot_actual - pot_p2e2e2p)) conv_factor = approx_convergence_factor(1 + np.array(ORDERS_P2E2E2P), errors) - assert conv_factor <= min(1, case.conv_factor * (1 + RTOL_P2E2E2P)), \ - (conv_factor, case.conv_factor * (1 + RTOL_P2E2E2P)) + assert conv_factor <= min(1, case_conv_factor * (1 + RTOL_P2E2E2P)), \ + (conv_factor, case_conv_factor * (1 + RTOL_P2E2E2P)) def test_cse_matvec():