Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 26 additions & 34 deletions test/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,13 @@
THE SOFTWARE.
"""

import sys
from dataclasses import dataclass
from typing import Any, Callable

import numpy as np
import numpy.linalg as la
import sys

import sumpy.toys as t
import sumpy.symbolic as sym

Expand All @@ -31,8 +35,6 @@
from pyopencl.tools import ( # noqa
pytest_generate_tests_for_pyopencl as pytest_generate_tests)

from pytools import Record

from sumpy.kernel import (LaplaceKernel, HelmholtzKernel,
BiharmonicKernel, YukawaKernel, StokesletKernel, StressletKernel,
ElasticityKernel, LineOfCompressionKernel, ExpressionKernel)
Expand Down Expand Up @@ -191,39 +193,20 @@ def approx_convergence_factor(orders, errors):
return np.exp(poly[0])


class P2E2E2PTestCase(Record):
@dataclass
class P2E2E2PTestCase:
source: np.ndarray
target: np.ndarray
center1: np.ndarray
center2: np.ndarray
expansion1: Callable[..., Any]
expansion2: Callable[..., Any]
conv_factor: str

@property
def dim(self):
return len(self.source)

@staticmethod
def eval(expr, source, center1, center2, target):
from pymbolic import parse, evaluate
context = {
"s": source,
"c1": center1,
"c2": center2,
"t": target,
"norm": la.norm}

return evaluate(parse(expr), context)

def __init__(self,
source, center1, center2, target, expansion1, expansion2, conv_factor):

if isinstance(conv_factor, str):
conv_factor = self.eval(conv_factor, source, center1, center2, target)

Record.__init__(self,
source=source,
center1=center1,
center2=center2,
target=target,
expansion1=expansion1,
expansion2=expansion2,
conv_factor=conv_factor)


P2E2E2P_TEST_CASES = (
# local to local, 3D
Expand Down Expand Up @@ -271,7 +254,16 @@ def test_toy_p2e2e2p(ctx_factory, case):
src = case.source.reshape(dim, -1)
tgt = case.target.reshape(dim, -1)

if not 0 <= case.conv_factor <= 1:
from pymbolic import parse, evaluate
case_conv_factor = evaluate(parse(case.conv_factor), {
"s": case.source,
"c1": case.center1,
"c2": case.center2,
"t": case.target,
"norm": la.norm,
})

if not 0 <= case_conv_factor <= 1:
raise ValueError(
"convergence factor not in valid range: %e" % case.conv_factor)

Expand All @@ -296,8 +288,8 @@ def test_toy_p2e2e2p(ctx_factory, case):
errors.append(np.abs(pot_actual - pot_p2e2e2p))

conv_factor = approx_convergence_factor(1 + np.array(ORDERS_P2E2E2P), errors)
assert conv_factor <= min(1, case.conv_factor * (1 + RTOL_P2E2E2P)), \
(conv_factor, case.conv_factor * (1 + RTOL_P2E2E2P))
assert conv_factor <= min(1, case_conv_factor * (1 + RTOL_P2E2E2P)), \
(conv_factor, case_conv_factor * (1 + RTOL_P2E2E2P))


def test_cse_matvec():
Expand Down