Skip to content

Commit 6778f33

Browse files
Copilotdfm
andauthored
Add arviz 1.0.0 compatibility and pandas image test tolerance (#325)
* Initial plan * Fix test failures: update arviz API calls and imports for arviz 1.0.0, add tolerance for pandas image test Co-authored-by: dfm <350282+dfm@users.noreply.github.com> * Add docstrings to arviz fallback helper functions Co-authored-by: dfm <350282+dfm@users.noreply.github.com> * Bump minimum arviz version to 1.0 and simplify imports Co-authored-by: dfm <350282+dfm@users.noreply.github.com> * Fix readthedocs and pre-commit: use Python version markers for arviz deps, restore import fallbacks Co-authored-by: dfm <350282+dfm@users.noreply.github.com> * Bump readthedocs Python to 3.12, simplify arviz deps and imports Co-authored-by: dfm <350282+dfm@users.noreply.github.com> * Fix readthedocs: update arviz.ipynb from_dict syntax for arviz 1.0, bump OS to ubuntu-22.04 Co-authored-by: dfm <350282+dfm@users.noreply.github.com> * Make arviz compatibility conditional to support both old and new versions Co-authored-by: dfm <350282+dfm@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: dfm <350282+dfm@users.noreply.github.com>
1 parent 5ddb862 commit 6778f33

4 files changed

Lines changed: 66 additions & 15 deletions

File tree

docs/pages/arviz.ipynb

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,18 @@
2626
"np.random.seed(11234)\n",
2727
"\n",
2828
"x = np.random.randn(2, 2000)\n",
29-
"data = az.from_dict(\n",
30-
" posterior={\"x\": x, \"y\": np.random.randn(2, 2000, 2)},\n",
31-
" sample_stats={\"diverging\": x < -1.2},\n",
32-
")\n",
29+
"try:\n",
30+
" data = az.from_dict(\n",
31+
" posterior={\"x\": x, \"y\": np.random.randn(2, 2000, 2)},\n",
32+
" sample_stats={\"diverging\": x < -1.2},\n",
33+
" )\n",
34+
"except TypeError:\n",
35+
" data = az.from_dict(\n",
36+
" {\n",
37+
" \"posterior\": {\"x\": x, \"y\": np.random.randn(2, 2000, 2)},\n",
38+
" \"sample_stats\": {\"diverging\": x < -1.2},\n",
39+
" },\n",
40+
" )\n",
3341
"\n",
3442
"figure = corner.corner(data, divergences=True)"
3543
]

readthedocs.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ submodules:
44
include: all
55

66
build:
7-
os: ubuntu-20.04
7+
os: ubuntu-22.04
88
tools:
9-
python: "3.10"
9+
python: "3.12"
1010

1111
python:
1212
install:

src/corner/arviz_corner.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
__all__ = ["arviz_corner"]
44

55
import logging
6+
import re
67
from collections.abc import Mapping
78

89
import numpy as np
@@ -12,10 +13,38 @@
1213
except ImportError:
1314
from arviz import convert_to_dataset
1415

15-
from arviz.utils import _var_names, get_coords
16-
1716
# Support multiple versions of arviz
1817
try:
18+
# arviz < 1.0
19+
from arviz.utils import _var_names, get_coords
20+
except ImportError:
21+
# arviz >= 1.0: these functions were removed
22+
23+
def _var_names(var_names, dataset, filter_vars=None):
24+
if var_names is None:
25+
return None
26+
if filter_vars == "like":
27+
return [
28+
v
29+
for v in dataset.data_vars
30+
if any(vn in v for vn in var_names)
31+
]
32+
elif filter_vars == "regex":
33+
return [
34+
v
35+
for v in dataset.data_vars
36+
if any(re.search(vn, v) for vn in var_names)
37+
]
38+
return list(var_names)
39+
40+
def get_coords(dataset, coords):
41+
if not coords:
42+
return dataset
43+
return dataset.sel(coords)
44+
45+
46+
try:
47+
# Very old arviz
1948
from arviz.plots.plot_utils import (
2049
make_label,
2150
xarray_to_ndarray,
@@ -29,8 +58,14 @@ def _get_labels(plotters, labeller=None):
2958
]
3059

3160
except ImportError:
32-
from arviz.labels import BaseLabeller
33-
from arviz.sel_utils import xarray_to_ndarray, xarray_var_iter
61+
try:
62+
# Medium arviz (< 1.0)
63+
from arviz.labels import BaseLabeller
64+
from arviz.sel_utils import xarray_to_ndarray, xarray_var_iter
65+
except ImportError:
66+
# arviz >= 1.0
67+
from arviz import xarray_to_ndarray, xarray_var_iter
68+
from arviz_base.labels import BaseLabeller
3469

3570
def _get_labels(plotters, labeller=None):
3671
if labeller is None:

tests/test_corner.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,18 @@ def _run_corner(
3939
)
4040
elif arviz:
4141
az = pytest.importorskip("arviz")
42-
data = az.from_dict(
43-
posterior={"x": data[None]},
44-
sample_stats={"diverging": data[None, :, 0] < 0.0},
45-
)
42+
try:
43+
data = az.from_dict(
44+
posterior={"x": data[None]},
45+
sample_stats={"diverging": data[None, :, 0] < 0.0},
46+
)
47+
except TypeError:
48+
data = az.from_dict(
49+
{
50+
"posterior": {"x": data[None]},
51+
"sample_stats": {"diverging": data[None, :, 0] < 0.0},
52+
},
53+
)
4654
kwargs["truths"] = {"x": np.random.randn(ndim)}
4755
elif arviz_preview:
4856
az = pytest.importorskip("arviz.preview")
@@ -293,7 +301,7 @@ def test_top_ticks():
293301
_run_corner(top_ticks=True)
294302

295303

296-
@image_comparison(baseline_images=["pandas"], extensions=["png"])
304+
@image_comparison(baseline_images=["pandas"], extensions=["png"], tol=7)
297305
def test_pandas():
298306
_run_corner(pandas=True)
299307

0 commit comments

Comments
 (0)