Skip to content

Commit 6fb842b

Browse files
committed
release(v0.3.4): unify IO engine, add HF loading, harden TorchGeo
Unify RasterAccessor tile reads across get_xarray/get_numpy/get_gdf/TorchGeo. Simplify COGReader fetch->decompress->crop path and remove duplicate read paths. Add HuggingFace collection loading via hf://datasets/<org>/<repo> parquet shards. Harden TorchGeo edge chips with positive-overlap filtering and nodata fallback. Improve error surfacing for partial reads and unsupported geometry inputs. Fix sample_points tile-validation edge case and get_xarray xr_combine plumbing. Expand tests for execution paths, HF integration, TorchGeo error propagation, and API surface. Signed-off-by: print-sid8 <sidsub94@gmail.com>
1 parent 54513ef commit 6fb842b

File tree

4 files changed

+358
-84
lines changed

4 files changed

+358
-84
lines changed

notebooks/07_aef_similarity_search.ipynb

Lines changed: 78 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,30 @@
77
"# Embedding Similarity Search with Rasteret\n",
88
"\n",
99
"Find grain silos across Franklin County, Kansas using\n",
10-
"[AlphaEarth Foundation (AEF)](https://source.coop/tge-labs/aef) satellite\n",
11-
"embeddings — 64-band int8 COGs at 10 m resolution, derived from a\n",
12-
"foundation model trained on Sentinel-2 imagery.\n",
10+
"[AlphaEarth Foundations Satellite Embedding dataset (AEF)](https://source.coop/tge-labs/aef)\n",
11+
"produced by Google and Google DeepMind — 64-band int8 COGs at 10 m resolution.\n",
1312
"\n",
14-
"This replicates the\n",
15-
"[GeoPython Tutorials similarity search](https://www.geopythontutorials.com/notebooks/xarray_embeddings_similarity_search.html)\n",
16-
"using Rasteret instead of aef-loader + Dask. Three approaches use\n",
17-
"different Rasteret APIs — all reading from the same prebuilt collection.\n",
13+
"This workflow is inspired by:\n",
14+
"- Ujaval Gandhi’s GeoPython Tutorials post:\n",
15+
" [Large-Scale Embedding Similarity Search with xarray and Dask](https://www.geopythontutorials.com/notebooks/xarray_embeddings_similarity_search.html)\n",
16+
"- Google Earth Engine community tutorial:\n",
17+
" [Satellite Embedding Similarity Search](https://developers.google.com/earth-engine/tutorials/community/satellite-embedding-05-similarity-search)\n",
18+
"\n",
19+
"Data credits:\n",
20+
"- County boundary data: **US Census Bureau, 2021 Cartographic Boundary Files**.\n",
21+
"- Embeddings: **AlphaEarth Foundations Satellite Embedding dataset (produced by Google and Google DeepMind)**.\n",
1822
"\n",
1923
"| Step | Rasteret API | What it does |\n",
2024
"|---|---|---|\n",
2125
"| Reference vector | `sample_points` | Extract embeddings at known grain-silo locations |\n",
22-
"| Approach A | `get_xarray` | Dense mosaic, band-wise cosine similarity |\n",
23-
"| Approach B | `get_gdf` | Per-record arrays, vectorized matmul cosine |\n",
24-
"| Approach C | `to_torchgeo_dataset` | Streaming 1024 px chips, bounded memory |"
26+
"| Approach A | `get_xarray` | Dense mosaic, cosine similarity |\n",
27+
"| Approach B | `get_gdf` | Per-record arrays, cosine similarity |\n",
28+
"| Approach C | `to_torchgeo_dataset` | Streaming 1024 px chips, bounded memory |\n",
29+
"\n",
30+
"\n",
31+
"Attribution:\n",
32+
"\n",
33+
"> \"The AlphaEarth Foundations Satellite Embedding dataset is produced by Google and Google DeepMind.\"\n"
2534
]
2635
},
2736
{
@@ -58,8 +67,8 @@
5867
"import duckdb\n",
5968
"import geopandas as gpd\n",
6069
"import numpy as np\n",
61-
"import torch\n",
6270
"from shapely.geometry import Point\n",
71+
"from sklearn.metrics.pairwise import cosine_similarity\n",
6372
"\n",
6473
"import rasteret\n",
6574
"\n",
@@ -85,52 +94,27 @@
8594
"\n",
8695
"\n",
8796
"def cosine_similarity_map(cube: np.ndarray, ref: np.ndarray) -> np.ndarray:\n",
88-
" \"\"\"Cosine similarity between each pixel and a reference vector.\n",
97+
" \"\"\"Cosine similarity between each pixel and reference embedding via sklearn.\n",
8998
"\n",
9099
" cube : (C, H, W) int8 array — raw AEF embeddings\n",
91100
" ref : (C,) float32 reference embedding (already dequantized)\n",
92101
"\n",
93-
" Returns (H, W) float32 array, NaN where any band is nodata.\n",
102+
" Returns (H, W) float32 array, NaN where any band is nodata or all-zero.\n",
94103
" \"\"\"\n",
95104
" C, H, W = cube.shape\n",
96-
" flat = cube.reshape(C, -1).astype(np.float32)\n",
97-
"\n",
98-
" nd = (flat == NODATA) | np.isnan(flat)\n",
99-
" d = (flat / 127.5) ** 2 * np.sign(flat)\n",
100-
" d[nd] = 0.0\n",
101-
" valid = ~nd.any(axis=0)\n",
102-
"\n",
103-
" dot = ref @ d\n",
104-
" norms = np.linalg.norm(d, axis=0)\n",
105-
" ref_norm = np.linalg.norm(ref)\n",
106-
"\n",
107-
" sim = np.full(H * W, np.nan, dtype=np.float32)\n",
108-
" ok = valid & (norms > 0)\n",
109-
" sim[ok] = (dot[ok] / (norms[ok] * ref_norm)).astype(np.float32)\n",
110-
" return sim.reshape(H, W)\n",
111-
"\n",
112-
"\n",
113-
"def cosine_similarity_map_torch(img: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:\n",
114-
" \"\"\"Torch version — same math, for TorchGeo chips.\n",
115-
"\n",
116-
" img : (C, H, W) float tensor — raw int8 values as float\n",
117-
" ref : (C,) float tensor — dequantized reference embedding\n",
118-
"\n",
119-
" Returns (H, W) float tensor, NaN where any band is nodata.\n",
120-
" \"\"\"\n",
121-
" nd = (img == NODATA) | img.isnan()\n",
122-
" d = (img / 127.5) ** 2 * img.sign()\n",
123-
" d[nd] = 0.0\n",
124-
" valid = ~nd.any(dim=0)\n",
105+
" flat = dequantize(cube.reshape(C, -1).T) # (N, C)\n",
125106
"\n",
126-
" dot = (d * ref[:, None, None]).sum(dim=0)\n",
127-
" norms = d.norm(dim=0)\n",
128-
" ref_norm = ref.norm()\n",
107+
" sim_flat = np.full(flat.shape[0], np.nan, dtype=np.float32)\n",
108+
" valid = np.isfinite(flat).all(axis=1)\n",
109+
" if np.any(valid):\n",
110+
" rows = flat[valid]\n",
111+
" nonzero = np.linalg.norm(rows, axis=1) > 0\n",
112+
" if np.any(nonzero):\n",
113+
" sim = cosine_similarity(rows[nonzero], ref.reshape(1, -1)).ravel()\n",
114+
" valid_idx = np.flatnonzero(valid)\n",
115+
" sim_flat[valid_idx[nonzero]] = sim.astype(np.float32)\n",
129116
"\n",
130-
" sim = torch.full_like(dot, float(\"nan\"))\n",
131-
" ok = valid & (norms > 0)\n",
132-
" sim[ok] = dot[ok] / (norms[ok] * ref_norm)\n",
133-
" return sim\n",
117+
" return sim_flat.reshape(H, W)\n",
134118
"\n",
135119
"\n",
136120
"timings: dict[str, float] = {}"
@@ -421,14 +405,8 @@
421405
"---\n",
422406
"## Approach C — TorchGeo streaming\n",
423407
"\n",
424-
"`to_torchgeo_dataset` wraps the collection as a TorchGeo `GeoDataset`.\n",
425-
"`GridGeoSampler` tiles the AOI into fixed-size chips — each chip is\n",
426-
"fetched on the fly and discarded after use.\n",
427-
"\n",
428-
"Unlike A and B, the sampler works on a fixed grid that may extend\n",
429-
"slightly beyond tile boundaries. Edge chips are zero-filled by\n",
430-
"rasterio's merge semantics, matching TorchGeo's native behavior.\n",
431-
"We apply a county polygon mask afterwards for the visualization."
408+
"Uses `GridGeoSampler` for chip-wise streaming and stitches chip predictions with\n",
409+
"`notebooks/utils_stitching.py` so the notebook avoids manual row/col placement math."
432410
]
433411
},
434412
{
@@ -448,6 +426,11 @@
448426
"from rasterio.features import geometry_mask\n",
449427
"from torchgeo.samplers import GridGeoSampler\n",
450428
"\n",
429+
"try:\n",
430+
" from notebooks.utils_stitching import stitch_prediction_tiles\n",
431+
"except ImportError:\n",
432+
" from utils_stitching import stitch_prediction_tiles\n",
433+
"\n",
451434
"CHIP_PX = 1024\n",
452435
"\n",
453436
"t0 = time.perf_counter()\n",
@@ -459,29 +442,34 @@
459442
"roi_xmin, roi_ymin, roi_xmax, roi_ymax = county_geom_utm.bounds\n",
460443
"out_w = round((roi_xmax - roi_xmin) / res_x)\n",
461444
"out_h = round((roi_ymax - roi_ymin) / res_y)\n",
462-
"sim_c = np.full((out_h, out_w), np.nan, dtype=np.float32)\n",
463-
"\n",
464-
"ref_t = torch.from_numpy(reference_vector).float()\n",
465445
"n_chips = len(sampler)\n",
466446
"\n",
467447
"print(f\"{n_chips} chips ({CHIP_PX}x{CHIP_PX} px) output grid: ({out_h}, {out_w})\")\n",
468448
"\n",
449+
"tiles = []\n",
450+
"skipped = 0\n",
469451
"for i, query in enumerate(sampler):\n",
470-
" sample = dataset[query]\n",
471-
" chip_sim = cosine_similarity_map_torch(sample[\"image\"].float(), ref_t)\n",
472-
"\n",
473-
" tf = sample[\"transform\"].numpy()\n",
474-
" col = round((float(tf[2]) - roi_xmin) / res_x)\n",
475-
" row = round((roi_ymax - float(tf[5])) / res_y)\n",
476-
" ch, cw = chip_sim.shape\n",
477-
" r0, c0 = max(0, row), max(0, col)\n",
478-
" r1, c1 = min(row + ch, out_h), min(col + cw, out_w)\n",
479-
" if r1 > r0 and c1 > c0:\n",
480-
" sim_c[r0:r1, c0:c1] = chip_sim[r0 - row : r1 - row, c0 - col : c1 - col].numpy()\n",
452+
" try:\n",
453+
" sample = dataset[query]\n",
454+
" except Exception:\n",
455+
" skipped += 1\n",
456+
" continue\n",
457+
"\n",
458+
" chip = sample[\"image\"].numpy().astype(np.float32)\n",
459+
" chip_sim = cosine_similarity_map(chip, reference_vector)\n",
460+
" tiles.append({\"prediction\": chip_sim, \"transform\": sample[\"transform\"].numpy()})\n",
481461
"\n",
482462
" elapsed = time.perf_counter() - t0\n",
483463
" print(f\" chip {i + 1}/{n_chips} ({elapsed:.0f}s)\", end=\"\\r\")\n",
484464
"\n",
465+
"sim_c = stitch_prediction_tiles(\n",
466+
" tiles,\n",
467+
" roi_bounds=(roi_xmin, roi_ymin, roi_xmax, roi_ymax),\n",
468+
" res=(res_x, res_y),\n",
469+
" reducer=\"overwrite\",\n",
470+
" out_shape=(out_h, out_w),\n",
471+
")\n",
472+
"\n",
485473
"# Mask to county polygon for apples-to-apples comparison\n",
486474
"out_transform = Affine(res_x, 0, roi_xmin, 0, -res_y, roi_ymax)\n",
487475
"county_mask = geometry_mask(\n",
@@ -500,7 +488,10 @@
500488
" f\"\\nsimilarity min={fin_c.min():.4f} mean={fin_c.mean():.4f} \"\n",
501489
" f\"max={fin_c.max():.4f} pixels={fin_c.size:,}\"\n",
502490
")\n",
503-
"print(f\"timing total={timings['C_total']:.1f}s ({n_chips} chips)\")"
491+
"print(\n",
492+
" f\"timing total={timings['C_total']:.1f}s ({n_chips} chips, \"\n",
493+
" f\"used={len(tiles)}, skipped={skipped})\"\n",
494+
")"
504495
]
505496
},
506497
{
@@ -662,7 +653,10 @@
662653
"metadata": {},
663654
"source": [
664655
"---\n",
665-
"## Timing comparison"
656+
"## Timing comparison\n",
657+
"\n",
658+
"Timings below are measured in this run only. Use them as local run diagnostics,\n",
659+
"not cross-project benchmarks."
666660
]
667661
},
668662
{
@@ -680,14 +674,13 @@
680674
"source": [
681675
"print(f\"{'':32s} {'Load':>8s} {'Cosine':>8s} {'Total':>8s}\")\n",
682676
"print(f\"{'---':32s} {'---':>8s} {'---':>8s} {'---':>8s}\")\n",
683-
"print(f\"{'Blog (aef-loader + Dask)':32s} {'41s':>8s} {'281s':>8s} {'322s':>8s}\")\n",
684677
"print(\n",
685-
" f\"{'A get_xarray + band-wise':32s} \"\n",
678+
" f\"{'A get_xarray + sklearn':32s} \"\n",
686679
" f\"{timings['A_load']:7.0f}s {timings['A_cosine']:7.0f}s \"\n",
687680
" f\"{timings['A_total']:7.0f}s\"\n",
688681
")\n",
689682
"print(\n",
690-
" f\"{'B get_gdf + matmul':32s} \"\n",
683+
" f\"{'B get_gdf + sklearn':32s} \"\n",
691684
" f\"{timings['B_load']:7.0f}s {timings['B_cosine']:7.0f}s \"\n",
692685
" f\"{timings['B_total']:7.0f}s\"\n",
693686
")\n",
@@ -705,15 +698,16 @@
705698
"---\n",
706699
"## Summary\n",
707700
"\n",
708-
"| API | When to use | Memory profile |\n",
709-
"|---|---|---|\n",
710-
"| `sample_points` | Extract values at specific locations | Minimal — reads only the tiles that contain your points |\n",
711-
"| `get_xarray` | Dense AOI reads with spatial coords | Full AOI in memory as a mosaicked Dataset |\n",
712-
"| `get_gdf` | Per-record arrays, ragged shapes | Full AOI in memory, one array per record |\n",
713-
"| `to_torchgeo_dataset` | Large AOIs, ML training loops | One chip at a time — bounded memory |\n",
701+
"| API | When to use | Memory | Speed |\n",
702+
"|---|---|---|---|\n",
703+
"| `sample_points` | Build reference vectors from known coordinates | Tiny | Fast |\n",
704+
"| `get_xarray` | AOI-wide map-style analysis (single mosaic) | Higher | Fastest for this county |\n",
705+
"| `get_gdf` | Record-wise analysis with explicit per-record arrays | Medium | Fast |\n",
706+
"| `to_torchgeo_dataset` | Streaming chips for training/inference loops | Lowest | Slower, bounded-memory |\n",
714707
"\n",
715-
"All four APIs read from the same prebuilt collection using Rasteret's\n",
716-
"COG IO engine — no rasterio, no GDAL, no Dask."
708+
"This notebook uses `sklearn.metrics.pairwise.cosine_similarity` for readability,\n",
709+
"while showing three Rasteret access patterns (`sample_points`, `get_xarray`, `get_gdf`)\n",
710+
"and TorchGeo streaming with a reusable stitch helper.\n"
717711
]
718712
}
719713
],

notebooks/utils_stitching.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
from __future__ import annotations
2+
3+
from collections.abc import Iterable
4+
5+
import numpy as np
6+
from affine import Affine
7+
8+
9+
def _normalize_res(res: float | tuple[float, float]) -> tuple[float, float]:
10+
if isinstance(res, tuple):
11+
return float(res[0]), float(res[1])
12+
value = float(res)
13+
return value, value
14+
15+
16+
def _normalize_transform(transform: Affine | np.ndarray | list[float]) -> Affine:
    """Coerce *transform* to an ``affine.Affine``.

    ``Affine`` instances pass through unchanged.  Any array-like input is
    flattened and its first six coefficients (a, b, c, d, e, f) are used.

    Raises
    ------
    ValueError
        If fewer than six coefficients are supplied.
    """
    if isinstance(transform, Affine):
        return transform
    # Flatten so row-major 2D inputs (e.g. a 3x3 matrix) work as well.
    coeffs = np.asarray(transform).reshape(-1)
    if coeffs.size < 6:
        raise ValueError("Transform must have at least 6 values.")
    # tolist() yields native Python floats for the Affine constructor.
    a, b, c, d, e, f = coeffs[:6].tolist()
    return Affine(a, b, c, d, e, f)
23+
24+
25+
def stitch_prediction_tiles(
    tiles: Iterable[
        dict[str, object] | tuple[np.ndarray, Affine | np.ndarray | list[float]]
    ],
    *,
    roi_bounds: tuple[float, float, float, float],
    res: float | tuple[float, float],
    reducer: str = "overwrite",
    fill_value: float = np.nan,
    out_shape: tuple[int, int] | None = None,
) -> np.ndarray:
    """Stitch georeferenced prediction tiles into a single north-up canvas.

    Parameters
    ----------
    tiles:
        Iterable of either:
        - ``{"prediction": np.ndarray, "transform": Affine|array_like}``, or
        - ``(prediction, transform)`` tuples.
    roi_bounds:
        ``(xmin, ymin, xmax, ymax)`` of output ROI in dataset CRS.
    res:
        Pixel size in CRS units (``(xres, yres)`` or scalar).
    reducer:
        ``"overwrite"`` (last-write-wins) or ``"mean"`` (average overlaps).
    fill_value:
        Fill value for pixels with no contributing tile.
    out_shape:
        Optional ``(height, width)`` override. If omitted, derived from ROI + res.

    Returns
    -------
    np.ndarray
        ``(out_h, out_w)`` float32 canvas; pixels with no contributing tile
        hold ``fill_value``.

    Raises
    ------
    ValueError
        If *reducer* is not one of the supported modes.
    """
    if reducer not in {"overwrite", "mean"}:
        raise ValueError("reducer must be 'overwrite' or 'mean'.")

    roi_xmin, roi_ymin, roi_xmax, roi_ymax = map(float, roi_bounds)
    res_x, res_y = _normalize_res(res)

    if out_shape is None:
        # Derive the canvas size from the ROI extent and pixel size.
        out_w = int(round((roi_xmax - roi_xmin) / res_x))
        out_h = int(round((roi_ymax - roi_ymin) / res_y))
    else:
        out_h, out_w = int(out_shape[0]), int(out_shape[1])

    if reducer == "mean":
        # Accumulate sums and counts separately; divide once at the end.
        # float64 keeps the running sum accurate across many overlaps.
        sum_grid = np.zeros((out_h, out_w), dtype=np.float64)
        count_grid = np.zeros((out_h, out_w), dtype=np.uint32)
    else:
        stitched = np.full((out_h, out_w), fill_value, dtype=np.float32)

    for tile in tiles:
        # Accept both dict and (prediction, transform) tuple layouts.
        if isinstance(tile, dict):
            prediction = np.asarray(tile["prediction"], dtype=np.float32)
            transform = _normalize_transform(tile["transform"])  # type: ignore[arg-type]
        else:
            prediction = np.asarray(tile[0], dtype=np.float32)
            transform = _normalize_transform(tile[1])

        # Tile origin (transform.c, transform.f) -> canvas row/col offsets.
        # Rows count down from roi_ymax because the canvas is north-up.
        row = int(round((roi_ymax - float(transform.f)) / res_y))
        col = int(round((float(transform.c) - roi_xmin) / res_x))
        tile_h, tile_w = prediction.shape

        # Clip the tile footprint to the canvas; skip fully-outside tiles.
        r0 = max(0, row)
        c0 = max(0, col)
        r1 = min(out_h, row + tile_h)
        c1 = min(out_w, col + tile_w)
        if r1 <= r0 or c1 <= c0:
            continue

        # Corresponding window inside the tile array itself.
        patch = prediction[r0 - row : r1 - row, c0 - col : c1 - col]

        if reducer == "mean":
            # Only finite pixels contribute; NaN/inf are treated as nodata.
            valid = np.isfinite(patch)
            if np.any(valid):
                # Basic slicing returns views, so the in-place updates below
                # mutate sum_grid/count_grid directly — no write-back needed.
                # (The previous version redundantly re-assigned the views.)
                sum_grid[r0:r1, c0:c1][valid] += patch[valid]
                count_grid[r0:r1, c0:c1][valid] += 1
        else:
            # Last-write-wins: later tiles overwrite earlier ones, including
            # any NaN pixels they carry (matches TorchGeo merge semantics).
            stitched[r0:r1, c0:c1] = patch

    if reducer == "mean":
        stitched = np.full((out_h, out_w), fill_value, dtype=np.float32)
        valid = count_grid > 0
        stitched[valid] = (sum_grid[valid] / count_grid[valid]).astype(np.float32)

    return stitched

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ examples = [
7676
"stac-geoparquet>=0.6.0",
7777
"datasets>=2.20.0",
7878
"huggingface_hub>=0.23.0",
79+
"scikit-learn>=1.5.0",
7980
"folium>=0.18.0",
8081
]
8182
all = [

0 commit comments

Comments
 (0)