Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion ultraplot/internals/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,13 +150,17 @@ def _to_numpy_array(data, strip_units=False):
data = data.data # support pint quantities that get unit-stripped later
elif isinstance(data, (DataFrame, Series, Index)):
data = data.values

if Quantity is not ndarray and isinstance(data, Quantity):
if strip_units:
return np.atleast_1d(data.magnitude)
else:
return np.atleast_1d(data.magnitude) * data.units
else:
try:
return np.atleast_1d(data) # natively preserves masked arrays
except (TypeError, ValueError):
# handle non-homogeneous data
return np.array(data, dtype=object)


def _to_masked_array(data, *, copy=False):
Expand Down
18 changes: 18 additions & 0 deletions ultraplot/tests/test_inputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import ultraplot as uplt, pytest, numpy as np


@pytest.mark.parametrize(
"data, dtype",
[
([1, 2, 3], int),
([[1, 2], [1, 2, 3]], object),
(["hello", 1], np.dtype("<U21")), # will convert 1 to string
([["hello"], 1], object), # non-homogeneous # mixed types
],
)
def test_to_numpy_array(data, dtype):
"""
Test that to_numpy_array works with various data types.
"""
arr = uplt.internals.inputs._to_numpy_array(data)
assert arr.dtype == dtype, f"Expected dtype {dtype}, got {arr.dtype}"
14 changes: 14 additions & 0 deletions ultraplot/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,3 +433,17 @@ def test_color_parsing_for_none():
for artist in ax[0].collections:
assert artist.get_facecolor().shape[0] == 0
uplt.close(fig)


@pytest.mark.mpl_image_compare
def test_inhomogeneous_violin(rng):
"""
Test that inhomogeneous violin plots work correctly.
"""
fig, ax = uplt.subplots()
data = [rng.normal(size=100), np.random.normal(size=200)]
violins = ax.violinplot(data, vert=True, labels=["A", "B"])
assert len(violins) == 2
for violin in violins:
assert violin.get_paths() # Ensure paths are created
return fig