Skip to content

Commit 88591ac

Browse files
authored
Handle non homogeneous arrays (#318)
1 parent 59576c3 commit 88591ac

File tree

3 files changed

+37
-1
lines changed

3 files changed

+37
-1
lines changed

ultraplot/internals/inputs.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,13 +150,17 @@ def _to_numpy_array(data, strip_units=False):
150150
data = data.data # support pint quantities that get unit-stripped later
151151
elif isinstance(data, (DataFrame, Series, Index)):
152152
data = data.values
153+
153154
if Quantity is not ndarray and isinstance(data, Quantity):
154155
if strip_units:
155156
return np.atleast_1d(data.magnitude)
156157
else:
157158
return np.atleast_1d(data.magnitude) * data.units
158-
else:
159+
try:
159160
return np.atleast_1d(data) # natively preserves masked arrays
161+
except (TypeError, ValueError):
162+
# handle non-homogeneous data
163+
return np.array(data, dtype=object)
160164

161165

162166
def _to_masked_array(data, *, copy=False):

ultraplot/tests/test_inputs.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import ultraplot as uplt, pytest, numpy as np
2+
3+
4+
@pytest.mark.parametrize(
5+
"data, dtype",
6+
[
7+
([1, 2, 3], int),
8+
([[1, 2], [1, 2, 3]], object),
9+
(["hello", 1], np.dtype("<U21")), # will convert 1 to string
10+
([["hello"], 1], object), # non-homogeneous # mixed types
11+
],
12+
)
13+
def test_to_numpy_array(data, dtype):
14+
"""
15+
Test that to_numpy_array works with various data types.
16+
"""
17+
arr = uplt.internals.inputs._to_numpy_array(data)
18+
assert arr.dtype == dtype, f"Expected dtype {dtype}, got {arr.dtype}"

ultraplot/tests/test_plot.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,3 +433,17 @@ def test_color_parsing_for_none():
433433
for artist in ax[0].collections:
434434
assert artist.get_facecolor().shape[0] == 0
435435
uplt.close(fig)
436+
437+
438+
@pytest.mark.mpl_image_compare
439+
def test_inhomogeneous_violin(rng):
440+
"""
441+
Test that inhomogeneous violin plots work correctly.
442+
"""
443+
fig, ax = uplt.subplots()
444+
data = [rng.normal(size=100), np.random.normal(size=200)]
445+
violins = ax.violinplot(data, vert=True, labels=["A", "B"])
446+
assert len(violins) == 2
447+
for violin in violins:
448+
assert violin.get_paths() # Ensure paths are created
449+
return fig

0 commit comments

Comments
 (0)