Skip to content

Commit 894533c

Browse files
committed
Add tests/test_heatmap_reshape.py
1 parent 18ba271 commit 894533c

File tree

1 file changed

+44
-178
lines changed

1 file changed

+44
-178
lines changed

tests/test_heatmap_reshape.py

Lines changed: 44 additions & 178 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Test reshape_data_for_heatmap() function."""
1+
"""Test reshape_data_for_heatmap() for common use cases."""
22

33
import numpy as np
44
import pandas as pd
@@ -9,204 +9,70 @@
99

1010

1111
@pytest.fixture
12-
def regular_timeseries():
13-
"""Create regular time series data (hourly for 3 days)."""
14-
time = pd.date_range('2024-01-01', periods=72, freq='h', name='time')
15-
data = np.random.rand(72) * 100
12+
def hourly_week_data():
13+
"""Typical use case: hourly data for a week."""
14+
time = pd.date_range('2024-01-01', periods=168, freq='h')
15+
data = np.random.rand(168) * 100
1616
return xr.DataArray(data, dims=['time'], coords={'time': time}, name='power')
1717

1818

19-
@pytest.fixture
20-
def irregular_timeseries():
21-
"""Create irregular time series data with missing timestamps."""
22-
time = pd.date_range('2024-01-01', periods=240, freq='5min', name='time')
23-
data = np.random.rand(240) * 100
24-
da = xr.DataArray(data, dims=['time'], coords={'time': time}, name='temperature')
25-
# Drop random 30% of data points to create irregularity
26-
np.random.seed(42)
27-
keep_indices = np.random.choice(240, int(240 * 0.7), replace=False)
28-
keep_indices.sort()
29-
return da.isel(time=keep_indices)
30-
31-
32-
@pytest.fixture
33-
def multidim_timeseries():
34-
"""Create multi-dimensional time series (time × scenario × period)."""
35-
time = pd.date_range('2024-01-01', periods=48, freq='h', name='time')
36-
scenarios = ['base', 'high', 'low']
37-
periods = [2024, 2030]
38-
data = np.random.rand(48, 3, 2) * 100
39-
return xr.DataArray(
40-
data,
41-
dims=['time', 'scenario', 'period'],
42-
coords={'time': time, 'scenario': scenarios, 'period': periods},
43-
name='demand',
44-
)
45-
46-
47-
class TestBasicReshaping:
48-
"""Test basic reshaping functionality."""
49-
50-
def test_daily_hourly_reshape(self, regular_timeseries):
51-
"""Test reshaping into days × hours."""
52-
result = reshape_data_for_heatmap(regular_timeseries, reshape_time=('D', 'h'))
53-
54-
assert result.dims == ('timestep', 'timeframe')
55-
assert result.sizes['timeframe'] == 3 # 3 days
56-
assert result.sizes['timestep'] == 24 # 24 hours per day
57-
assert result.name == 'power'
58-
59-
def test_weekly_daily_reshape(self, regular_timeseries):
60-
"""Test reshaping into weeks × days."""
61-
result = reshape_data_for_heatmap(regular_timeseries, reshape_time=('W', 'D'))
62-
63-
assert result.dims == ('timestep', 'timeframe')
64-
assert 'timeframe' in result.dims
65-
assert 'timestep' in result.dims
66-
67-
def test_monthly_daily_reshape(self):
68-
"""Test reshaping into months × days."""
69-
time = pd.date_range('2024-01-01', periods=90, freq='D', name='time')
70-
data = np.random.rand(90) * 100
71-
da = xr.DataArray(data, dims=['time'], coords={'time': time}, name='monthly_data')
72-
73-
result = reshape_data_for_heatmap(da, reshape_time=('MS', 'D'))
74-
75-
assert result.dims == ('timestep', 'timeframe')
76-
assert result.sizes['timeframe'] == 3 # ~3 months
77-
assert result.name == 'monthly_data'
78-
79-
def test_no_reshape(self, regular_timeseries):
80-
"""Test that reshape_time=None returns data unchanged."""
81-
result = reshape_data_for_heatmap(regular_timeseries, reshape_time=None)
82-
83-
# Should return the same data
84-
xr.testing.assert_equal(result, regular_timeseries)
85-
86-
87-
class TestFillMethods:
88-
"""Test different fill methods for irregular data."""
89-
90-
def test_forward_fill(self, irregular_timeseries):
91-
"""Test forward fill for missing values."""
92-
result = reshape_data_for_heatmap(irregular_timeseries, reshape_time=('D', 'h'), fill='ffill')
93-
94-
assert result.dims == ('timestep', 'timeframe')
95-
# Should have no NaN values with ffill (except possibly first values)
96-
nan_count = np.isnan(result.values).sum()
97-
total_count = result.values.size
98-
assert nan_count < total_count * 0.1 # Less than 10% NaN
99-
100-
def test_backward_fill(self, irregular_timeseries):
101-
"""Test backward fill for missing values."""
102-
result = reshape_data_for_heatmap(irregular_timeseries, reshape_time=('D', 'h'), fill='bfill')
103-
104-
assert result.dims == ('timestep', 'timeframe')
105-
# Should have no NaN values with bfill (except possibly last values)
106-
nan_count = np.isnan(result.values).sum()
107-
total_count = result.values.size
108-
assert nan_count < total_count * 0.1 # Less than 10% NaN
109-
110-
def test_no_fill(self, irregular_timeseries):
111-
"""Test that fill=None does not automatically fill missing values."""
112-
result = reshape_data_for_heatmap(irregular_timeseries, reshape_time=('D', 'h'), fill=None)
113-
114-
assert result.dims == ('timestep', 'timeframe')
115-
# Note: Whether NaN values appear depends on whether data covers full time range
116-
# Just verify the function completes without error and returns correct dims
117-
assert result.sizes['timestep'] >= 1
118-
assert result.sizes['timeframe'] >= 1
119-
120-
121-
class TestMultidimensionalData:
122-
"""Test handling of multi-dimensional data."""
123-
124-
def test_multidim_basic_reshape(self, multidim_timeseries):
125-
"""Test reshaping multi-dimensional data."""
126-
result = reshape_data_for_heatmap(multidim_timeseries, reshape_time=('D', 'h'))
127-
128-
# Should preserve extra dimensions
129-
assert 'timeframe' in result.dims
130-
assert 'timestep' in result.dims
131-
assert 'scenario' in result.dims
132-
assert 'period' in result.dims
133-
assert result.sizes['scenario'] == 3
134-
assert result.sizes['period'] == 2
19+
def test_daily_hourly_pattern():
20+
"""Most common use case: reshape hourly data into days × hours for daily patterns."""
21+
time = pd.date_range('2024-01-01', periods=72, freq='h')
22+
data = np.random.rand(72) * 100
23+
da = xr.DataArray(data, dims=['time'], coords={'time': time})
13524

136-
def test_multidim_with_selection(self, multidim_timeseries):
137-
"""Test reshaping after selecting from multi-dimensional data."""
138-
# Select single scenario and period
139-
selected = multidim_timeseries.sel(scenario='base', period=2024)
140-
result = reshape_data_for_heatmap(selected, reshape_time=('D', 'h'))
25+
result = reshape_data_for_heatmap(da, reshape_time=('D', 'h'))
14126

142-
# Should only have timeframe and timestep dimensions
143-
assert result.dims == ('timestep', 'timeframe')
144-
assert 'scenario' not in result.dims
145-
assert 'period' not in result.dims
27+
assert 'timeframe' in result.dims and 'timestep' in result.dims
28+
assert result.sizes['timeframe'] == 3 # 3 days
29+
assert result.sizes['timestep'] == 24 # 24 hours
14630

14731

148-
class TestEdgeCases:
149-
"""Test edge cases and error handling."""
32+
def test_weekly_daily_pattern(hourly_week_data):
33+
"""Common use case: reshape hourly data into weeks × days."""
34+
result = reshape_data_for_heatmap(hourly_week_data, reshape_time=('W', 'D'))
15035

151-
def test_single_timeframe(self):
152-
"""Test with data that fits in a single timeframe."""
153-
time = pd.date_range('2024-01-01', periods=12, freq='h', name='time')
154-
data = np.random.rand(12) * 100
155-
da = xr.DataArray(data, dims=['time'], coords={'time': time}, name='short_data')
36+
assert 'timeframe' in result.dims and 'timestep' in result.dims
15637

157-
result = reshape_data_for_heatmap(da, reshape_time=('D', 'h'))
15838

159-
assert result.dims == ('timestep', 'timeframe')
160-
assert result.sizes['timeframe'] == 1 # Only 1 day
161-
assert result.sizes['timestep'] == 12 # 12 hours
39+
def test_with_irregular_data():
40+
"""Real-world use case: data with missing timestamps needs filling."""
41+
time = pd.date_range('2024-01-01', periods=100, freq='15min')
42+
data = np.random.rand(100)
43+
# Randomly drop 30% to simulate real data gaps
44+
keep = np.sort(np.random.choice(100, 70, replace=False)) # Must be sorted
45+
da = xr.DataArray(data[keep], dims=['time'], coords={'time': time[keep]})
16246

163-
def test_preserves_name(self, regular_timeseries):
164-
"""Test that the data name is preserved."""
165-
result = reshape_data_for_heatmap(regular_timeseries, reshape_time=('D', 'h'))
47+
result = reshape_data_for_heatmap(da, reshape_time=('h', 'min'), fill='ffill')
16648

167-
assert result.name == regular_timeseries.name
49+
assert 'timeframe' in result.dims and 'timestep' in result.dims
50+
# Should handle irregular data without errors
16851

169-
def test_different_frequencies(self):
170-
"""Test various time frequency combinations."""
171-
time = pd.date_range('2024-01-01', periods=168, freq='h', name='time')
172-
data = np.random.rand(168) * 100
173-
da = xr.DataArray(data, dims=['time'], coords={'time': time}, name='week_data')
17452

175-
# Test week × hour
176-
result = reshape_data_for_heatmap(da, reshape_time=('W', 'h'))
177-
assert result.dims == ('timestep', 'timeframe')
53+
def test_multidimensional_scenarios():
54+
"""Use case: data with scenarios/periods that need to be preserved."""
55+
time = pd.date_range('2024-01-01', periods=48, freq='h')
56+
scenarios = ['base', 'high']
57+
data = np.random.rand(48, 2) * 100
17858

179-
# Test week × day
180-
result = reshape_data_for_heatmap(da, reshape_time=('W', 'D'))
181-
assert result.dims == ('timestep', 'timeframe')
59+
da = xr.DataArray(data, dims=['time', 'scenario'], coords={'time': time, 'scenario': scenarios}, name='demand')
18260

61+
result = reshape_data_for_heatmap(da, reshape_time=('D', 'h'))
18362

184-
class TestDataIntegrity:
185-
"""Test that data values are preserved correctly."""
63+
# Should preserve scenario dimension
64+
assert 'scenario' in result.dims
65+
assert result.sizes['scenario'] == 2
18666

187-
def test_values_preserved(self, regular_timeseries):
188-
"""Test that no data values are lost or corrupted."""
189-
result = reshape_data_for_heatmap(regular_timeseries, reshape_time=('D', 'h'))
190-
191-
# Flatten and compare non-NaN values
192-
original_values = regular_timeseries.values
193-
reshaped_values = result.values.flatten()
194-
195-
# All original values should be present (allowing for reordering)
196-
# Compare sums as a simple integrity check
197-
assert np.isclose(np.nansum(original_values), np.nansum(reshaped_values), rtol=1e-10)
19867

199-
def test_coordinate_alignment(self, regular_timeseries):
200-
"""Test that time coordinates are properly aligned."""
201-
result = reshape_data_for_heatmap(regular_timeseries, reshape_time=('D', 'h'))
68+
def test_no_reshape_returns_unchanged():
69+
"""Use case: when reshape_time=None, return data as-is."""
70+
time = pd.date_range('2024-01-01', periods=24, freq='h')
71+
da = xr.DataArray(np.random.rand(24), dims=['time'], coords={'time': time})
20272

203-
# Check that coordinates exist
204-
assert 'timeframe' in result.coords
205-
assert 'timestep' in result.coords
73+
result = reshape_data_for_heatmap(da, reshape_time=None)
20674

207-
# Check coordinate sizes match dimensions
208-
assert len(result.coords['timeframe']) == result.sizes['timeframe']
209-
assert len(result.coords['timestep']) == result.sizes['timestep']
75+
xr.testing.assert_equal(result, da)
21076

21177

21278
if __name__ == '__main__':

0 commit comments

Comments
 (0)