Skip to content

Commit b13ef4b

Browse files
Merge pull request #24 from ESMValGroup/add_months_attn
Add months mixing
2 parents efe4169 + 02e21c5 commit b13ef4b

File tree

7 files changed

+523
-1553
lines changed

7 files changed

+523
-1553
lines changed

.github/workflows/deploy-documentation.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,6 @@ jobs:
2323
- name: Install dependencies
2424
run: |
2525
python -m pip install --upgrade pip
26-
python -m pip install .[docs]
26+
python -m pip install --group docs
2727
- name: Deploy docs
2828
run: mkdocs gh-deploy --force

climanet/dataset.py

Lines changed: 76 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import numpy as np
2+
from .utils import add_month_day_dims
13
import xarray as xr
24
import torch
35
from torch.utils.data import Dataset
@@ -17,38 +19,54 @@ def __init__(
1719
patch_size: Tuple[int, int] = (16, 16),
1820
overlap: int = 0,
1921
):
20-
self.daily_da = daily_da
21-
self.monthly_da = monthly_da
22-
self.land_mask = land_mask
23-
self.time_dim = time_dim
2422
self.spatial_dims = spatial_dims
2523
self.patch_size = patch_size
2624
self.overlap = overlap
2725

28-
# Group daily data
29-
# Create "YYYY-MM" string labels
30-
daily_labels = self.daily_da[time_dim].dt.strftime("%Y-%m")
31-
monthly_labels = self.monthly_da[time_dim].dt.strftime("%Y-%m")
26+
# Check that the input data has the expected dimensions
27+
if time_dim not in daily_da.dims or time_dim not in monthly_da.dims:
28+
raise ValueError(f"Time dimension '{time_dim}' not found in input data")
29+
for dim in spatial_dims:
30+
if dim not in daily_da.dims or dim not in monthly_da.dims:
31+
raise ValueError(f"Spatial dimension '{dim}' not found in input data")
32+
33+
# Reshape daily → (M, T=31, H, W), monthly → (M, H, W),
34+
# and get padded_days_mask → (M, T=31)
35+
daily_mt, monthly_m, padded_days_mask = add_month_day_dims(
36+
daily_da, monthly_da, time_dim=time_dim
37+
)
38+
39+
# Convert to numpy once — all __getitem__ calls use these
40+
self.daily_np = daily_mt.to_numpy().copy() # (M, T=31, H, W) float
41+
self.monthly_np = monthly_m.to_numpy().copy() # (M, H, W) float
42+
self.padded_mask_np = padded_days_mask.to_numpy().copy() # (M, T=31) bool
43+
44+
if land_mask is not None:
45+
lm = land_mask.to_numpy().copy()
46+
if lm.ndim == 3:
47+
lm = lm.squeeze(0) # (1, H, W) → (H, W)
48+
self.land_mask_np = lm
49+
else:
50+
self.land_mask_np = None
51+
52+
# Precompute the NaN mask before filling NaNs
53+
# daily_nan_mask: True wherever the daily value is NaN (land is excluded later, per patch)
54+
self.daily_nan_mask = np.isnan(self.daily_np) # (M, T=31, H, W)
3255

33-
# Group daily indices by month label
34-
daily_groups = daily_labels.groupby(daily_labels).groups
56+
# Fill NaNs with 0 in-place
57+
np.nan_to_num(self.daily_np, copy=False, nan=0.0)
3558

36-
self.month_to_days = {}
37-
for month_idx, period in enumerate(monthly_labels.values):
38-
self.month_to_days[month_idx] = daily_groups.get(period, [])
39-
if len(self.month_to_days[month_idx]) == 0:
40-
raise ValueError(f"No daily data found for month index {month_idx}")
59+
# Precompute padded_days_mask as a tensor (same for all patches)
60+
self.padded_days_tensor = torch.from_numpy(self.padded_mask_np).bool()
4161

4262
# Precompute lazy index mapping for patches
43-
dim_y, dim_x = self.spatial_dims
4463
self.stride = self.patch_size[0] - self.overlap
45-
self.n_i = (
46-
self.daily_da.sizes[dim_y] - self.patch_size[0]
47-
) // self.stride + 1 # number of horizontal patches
48-
self.n_j = (
49-
self.daily_da.sizes[dim_x] - self.patch_size[1]
50-
) // self.stride + 1 # number of vertical patches
51-
self.total_len = len(self.monthly_da[time_dim]) * self.n_i * self.n_j
64+
H, W = self.daily_np.shape[2], self.daily_np.shape[3]
65+
self.n_i = (H - self.patch_size[0]) // self.stride + 1
66+
self.n_j = (W - self.patch_size[1]) // self.stride + 1
67+
68+
# Total length is only spatial patches (all months included in each sample)
69+
self.total_len = self.n_i * self.n_j
5270

5371
def __len__(self):
5472
return self.total_len
@@ -58,63 +76,44 @@ def __getitem__(self, idx):
5876
if idx < 0 or idx >= self.total_len:
5977
raise IndexError("Index out of range")
6078

61-
dim_y, dim_x = self.spatial_dims
62-
per_t = self.n_i * self.n_j
63-
t, rem = divmod(idx, per_t)
64-
i_idx, j_idx = divmod(rem, self.n_j)
79+
i_idx, j_idx = divmod(idx, self.n_j)
6580
i = i_idx * self.stride
6681
j = j_idx * self.stride
67-
68-
# Extract spatial patch
69-
y_slice = slice(i, i + self.patch_size[0])
70-
x_slice = slice(j, j + self.patch_size[1])
71-
72-
# Get daily data (all days in month)
73-
# Assuming monthly timestamp corresponds to days in that month
74-
daily_patch = self.daily_da.isel(
75-
{
76-
self.time_dim: self.month_to_days[t],
77-
dim_y: y_slice,
78-
dim_x: x_slice,
79-
}
80-
).to_numpy() # shape: (T, H, W)
81-
82-
# Add channel dim → (C=1, T, H, W)
83-
daily_patch = torch.from_numpy(daily_patch).float().unsqueeze(0)
84-
85-
# Get monthly target
86-
monthly_patch = self.monthly_da.isel(
87-
{
88-
self.time_dim: t,
89-
dim_y: y_slice,
90-
dim_x: x_slice,
91-
}
92-
).to_numpy()
93-
monthly_patch = torch.from_numpy(monthly_patch).float()
94-
95-
if self.land_mask is not None:
96-
land_mask_patch = self.land_mask.isel(
97-
{dim_y: y_slice, dim_x: x_slice}
98-
).to_numpy()
99-
land_mask_patch = torch.from_numpy(land_mask_patch).bool() # (H,W)
82+
ph, pw = self.patch_size
83+
84+
# Extract spatial patch via numpy slicing — faster than xarray indexing
85+
daily_patch = self.daily_np[:, :, i : i + ph, j : j + pw] # (M, T, H, W)
86+
monthly_patch = self.monthly_np[:, i : i + ph, j : j + pw] # (M, H, W)
87+
daily_nan_mask = self.daily_nan_mask[
88+
:, :, i : i + ph, j : j + pw
89+
] # (M, T, H, W)
90+
91+
if self.land_mask_np is not None:
92+
land_patch = self.land_mask_np[i : i + ph, j : j + pw] # (H, W)
93+
land_tensor = torch.from_numpy(land_patch.copy()).bool()
10094
else:
101-
# No land mask → all ocean (False)
102-
land_mask_patch = torch.zeros(
103-
self.patch_size[0], self.patch_size[1], dtype=torch.bool
104-
)
105-
106-
daily_mask_patch = torch.isnan(daily_patch) & (~land_mask_patch)
107-
108-
# Replace NaNs in daily data with zeros (after creating mask)
109-
daily_patch = torch.nan_to_num(daily_patch, nan=0.0)
95+
land_tensor = torch.zeros(ph, pw, dtype=torch.bool)
96+
97+
# Convert to tensors (from_numpy shares memory with the NumPy array; .float() copies only if the dtype is not already float32)
98+
# (1, M, T, H, W)
99+
daily_tensor = torch.from_numpy(daily_patch).float().unsqueeze(0)
100+
# (M, H, W)
101+
monthly_tensor = torch.from_numpy(monthly_patch).float()
102+
# (1, M, T, H, W)
103+
daily_nan_mask = torch.from_numpy(daily_nan_mask).unsqueeze(0)
104+
105+
# daily_mask: NaN locations that are NOT land
106+
# Reshape land_tensor for broadcasting: (H, W) → (1, 1, 1, H, W)
107+
daily_mask_tensor = daily_nan_mask & (
108+
~land_tensor.unsqueeze(0).unsqueeze(0).unsqueeze(0)
109+
)
110110

111111
# Assemble the sample dictionary
112-
sample = {
113-
"daily_patch": daily_patch, # (C=1, T, H, W)
114-
"monthly_patch": monthly_patch, # (H, W)
115-
"daily_mask_patch": daily_mask_patch, # (C=1, T, H, W)
116-
"land_mask_patch": land_mask_patch.squeeze(), # (H,W)
117-
"coords": (t, i, j),
112+
return {
113+
"daily_patch": daily_tensor, # (C=1, M, T=31, H, W)
114+
"monthly_patch": monthly_tensor, # (M, H, W)
115+
"daily_mask_patch": daily_mask_tensor, # (C=1, M, T=31, H, W)
116+
"land_mask_patch": land_tensor, # (H,W) True=Land
117+
"padded_days_mask": self.padded_days_tensor, # (M, T=31) True=padded
118+
"coords": (i, j),
118119
}
119-
120-
return sample

0 commit comments

Comments
 (0)