1+ import numpy as np
2+ from .utils import add_month_day_dims
13import xarray as xr
24import torch
35from torch .utils .data import Dataset
@@ -17,38 +19,54 @@ def __init__(
1719 patch_size : Tuple [int , int ] = (16 , 16 ),
1820 overlap : int = 0 ,
1921 ):
20- self .daily_da = daily_da
21- self .monthly_da = monthly_da
22- self .land_mask = land_mask
23- self .time_dim = time_dim
2422 self .spatial_dims = spatial_dims
2523 self .patch_size = patch_size
2624 self .overlap = overlap
2725
28- # Group daily data
29- # Create "YYYY-MM" string labels
30- daily_labels = self .daily_da [time_dim ].dt .strftime ("%Y-%m" )
31- monthly_labels = self .monthly_da [time_dim ].dt .strftime ("%Y-%m" )
26+ # Check that the input data has the expected dimensions
27+ if time_dim not in daily_da .dims or time_dim not in monthly_da .dims :
28+ raise ValueError (f"Time dimension '{ time_dim } ' not found in input data" )
29+ for dim in spatial_dims :
30+ if dim not in daily_da .dims or dim not in monthly_da .dims :
31+ raise ValueError (f"Spatial dimension '{ dim } ' not found in input data" )
32+
33+ # Reshape daily → (M, T=31, H, W), monthly → (M, H, W),
34+ # and get padded_days_mask → (M, T=31)
35+ daily_mt , monthly_m , padded_days_mask = add_month_day_dims (
36+ daily_da , monthly_da , time_dim = time_dim
37+ )
38+
39+ # Convert to numpy once — all __getitem__ calls use these
40+ self .daily_np = daily_mt .to_numpy ().copy () # (M, T=31, H, W) float
41+ self .monthly_np = monthly_m .to_numpy ().copy () # (M, H, W) float
42+ self .padded_mask_np = padded_days_mask .to_numpy ().copy () # (M, T=31) bool
43+
44+ if land_mask is not None :
45+ lm = land_mask .to_numpy ().copy ()
46+ if lm .ndim == 3 :
47+ lm = lm .squeeze (0 ) # (1, H, W) → (H, W)
48+ self .land_mask_np = lm
49+ else :
50+ self .land_mask_np = None
51+
52+ # Precompute the NaN mask before filling NaNs
53+ # daily_mask: True where NaN (i.e. missing ocean data, not land)
54+ self .daily_nan_mask = np .isnan (self .daily_np ) # (M, T=31, H, W)
3255
33- # Group daily indices by month label
34- daily_groups = daily_labels . groupby ( daily_labels ). groups
56+ # Fill NaNs with 0 in-place
57+ np . nan_to_num ( self . daily_np , copy = False , nan = 0.0 )
3558
36- self .month_to_days = {}
37- for month_idx , period in enumerate (monthly_labels .values ):
38- self .month_to_days [month_idx ] = daily_groups .get (period , [])
39- if len (self .month_to_days [month_idx ]) == 0 :
40- raise ValueError (f"No daily data found for month index { month_idx } " )
59+ # Precompute padded_days_mask as a tensor (same for all patches)
60+ self .padded_days_tensor = torch .from_numpy (self .padded_mask_np ).bool ()
4161
4262 # Precompute lazy index mapping for patches
43- dim_y , dim_x = self .spatial_dims
4463 self .stride = self .patch_size [0 ] - self .overlap
45- self .n_i = (
46- self .daily_da .sizes [dim_y ] - self .patch_size [0 ]
47- ) // self .stride + 1 # number of horizontal patches
48- self .n_j = (
49- self .daily_da .sizes [dim_x ] - self .patch_size [1 ]
50- ) // self .stride + 1 # number of vertical patches
51- self .total_len = len (self .monthly_da [time_dim ]) * self .n_i * self .n_j
64+ H , W = self .daily_np .shape [2 ], self .daily_np .shape [3 ]
65+ self .n_i = (H - self .patch_size [0 ]) // self .stride + 1
66+ self .n_j = (W - self .patch_size [1 ]) // self .stride + 1
67+
68+ # Total length is only spatial patches (all months included in each sample)
69+ self .total_len = self .n_i * self .n_j
5270
5371 def __len__ (self ):
5472 return self .total_len
@@ -58,63 +76,44 @@ def __getitem__(self, idx):
5876 if idx < 0 or idx >= self .total_len :
5977 raise IndexError ("Index out of range" )
6078
61- dim_y , dim_x = self .spatial_dims
62- per_t = self .n_i * self .n_j
63- t , rem = divmod (idx , per_t )
64- i_idx , j_idx = divmod (rem , self .n_j )
79+ i_idx , j_idx = divmod (idx , self .n_j )
6580 i = i_idx * self .stride
6681 j = j_idx * self .stride
67-
68- # Extract spatial patch
69- y_slice = slice (i , i + self .patch_size [0 ])
70- x_slice = slice (j , j + self .patch_size [1 ])
71-
72- # Get daily data (all days in month)
73- # Assuming monthly timestamp corresponds to days in that month
74- daily_patch = self .daily_da .isel (
75- {
76- self .time_dim : self .month_to_days [t ],
77- dim_y : y_slice ,
78- dim_x : x_slice ,
79- }
80- ).to_numpy () # shape: (T, H, W)
81-
82- # Add channel dim → (C=1, T, H, W)
83- daily_patch = torch .from_numpy (daily_patch ).float ().unsqueeze (0 )
84-
85- # Get monthly target
86- monthly_patch = self .monthly_da .isel (
87- {
88- self .time_dim : t ,
89- dim_y : y_slice ,
90- dim_x : x_slice ,
91- }
92- ).to_numpy ()
93- monthly_patch = torch .from_numpy (monthly_patch ).float ()
94-
95- if self .land_mask is not None :
96- land_mask_patch = self .land_mask .isel (
97- {dim_y : y_slice , dim_x : x_slice }
98- ).to_numpy ()
99- land_mask_patch = torch .from_numpy (land_mask_patch ).bool () # (H,W)
82+ ph , pw = self .patch_size
83+
84+ # Extract spatial patch via numpy slicing — faster than xarray indexing
85+ daily_patch = self .daily_np [:, :, i : i + ph , j : j + pw ] # (M, T, H, W)
86+ monthly_patch = self .monthly_np [:, i : i + ph , j : j + pw ] # (M, H, W)
87+ daily_nan_mask = self .daily_nan_mask [
88+ :, :, i : i + ph , j : j + pw
89+ ] # (M, T, H, W)
90+
91+ if self .land_mask_np is not None :
92+ land_patch = self .land_mask_np [i : i + ph , j : j + pw ] # (H, W)
93+ land_tensor = torch .from_numpy (land_patch .copy ()).bool ()
10094 else :
101- # No land mask → all ocean (False)
102- land_mask_patch = torch .zeros (
103- self .patch_size [0 ], self .patch_size [1 ], dtype = torch .bool
104- )
105-
106- daily_mask_patch = torch .isnan (daily_patch ) & (~ land_mask_patch )
107-
108- # Replace NaNs in daily data with zeros (after creating mask)
109- daily_patch = torch .nan_to_num (daily_patch , nan = 0.0 )
95+ land_tensor = torch .zeros (ph , pw , dtype = torch .bool )
96+
97+ # Convert to tensors (from_numpy is zero-copy on contiguous arrays)
98+ # (1, M, T, H, W)
99+ daily_tensor = torch .from_numpy (daily_patch ).float ().unsqueeze (0 )
100+ # (M, H, W)
101+ monthly_tensor = torch .from_numpy (monthly_patch ).float ()
102+ # (1, M, T, H, W)
103+ daily_nan_mask = torch .from_numpy (daily_nan_mask ).unsqueeze (0 )
104+
105+ # daily_mask: NaN locations that are NOT land
106+ # Reshape land_tensor for broadcasting: (H, W) → (1, 1, 1, H, W)
107+ daily_mask_tensor = daily_nan_mask & (
108+ ~ land_tensor .unsqueeze (0 ).unsqueeze (0 ).unsqueeze (0 )
109+ )
110110
111111 # Convert to tensors
112- sample = {
113- "daily_patch" : daily_patch , # (C=1, T, H, W)
114- "monthly_patch" : monthly_patch , # (H, W)
115- "daily_mask_patch" : daily_mask_patch , # (C=1, T, H, W)
116- "land_mask_patch" : land_mask_patch .squeeze (), # (H,W)
117- "coords" : (t , i , j ),
112+ return {
113+ "daily_patch" : daily_tensor , # (C=1, M, T=31, H, W)
114+ "monthly_patch" : monthly_tensor , # (M, H, W)
115+ "daily_mask_patch" : daily_mask_tensor , # (C=1, M, T=31, H, W)
116+ "land_mask_patch" : land_tensor , # (H,W) True=Land
117+ "padded_days_mask" : self .padded_days_tensor , # (M, T=31) True=padded
118+ "coords" : (i , j ),
118119 }
119-
120- return sample
0 commit comments