|
7 | 7 | "# Embedding Similarity Search with Rasteret\n", |
8 | 8 | "\n", |
9 | 9 | "Find grain silos across Franklin County, Kansas using\n", |
10 | | - "[AlphaEarth Foundation (AEF)](https://source.coop/tge-labs/aef) satellite\n", |
11 | | - "embeddings — 64-band int8 COGs at 10 m resolution, derived from a\n", |
12 | | - "foundation model trained on Sentinel-2 imagery.\n", |
| 10 | + "[AlphaEarth Foundations Satellite Embedding dataset (AEF)](https://source.coop/tge-labs/aef)\n", |
| 11 | + "produced by Google and Google DeepMind — 64-band int8 COGs at 10 m resolution.\n", |
13 | 12 | "\n", |
14 | | - "This replicates the\n", |
15 | | - "[GeoPython Tutorials similarity search](https://www.geopythontutorials.com/notebooks/xarray_embeddings_similarity_search.html)\n", |
16 | | - "using Rasteret instead of aef-loader + Dask. Three approaches use\n", |
17 | | - "different Rasteret APIs — all reading from the same prebuilt collection.\n", |
| 13 | + "This workflow is inspired by:\n", |
| 14 | + "- Ujaval Gandhi’s GeoPython Tutorials post:\n", |
| 15 | + " [Large-Scale Embedding Similarity Search with xarray and Dask](https://www.geopythontutorials.com/notebooks/xarray_embeddings_similarity_search.html)\n", |
| 16 | + "- Google Earth Engine community tutorial:\n", |
| 17 | + " [Satellite Embedding Similarity Search](https://developers.google.com/earth-engine/tutorials/community/satellite-embedding-05-similarity-search)\n", |
| 18 | + "\n", |
| 19 | + "Data credits:\n", |
| 20 | + "- County boundary data: **US Census Bureau, 2021 Cartographic Boundary Files**.\n", |
| 21 | + "- Embeddings: **AlphaEarth Foundations Satellite Embedding dataset (produced by Google and Google DeepMind)**.\n", |
18 | 22 | "\n", |
19 | 23 | "| Step | Rasteret API | What it does |\n", |
20 | 24 | "|---|---|---|\n", |
21 | 25 | "| Reference vector | `sample_points` | Extract embeddings at known grain-silo locations |\n", |
22 | | - "| Approach A | `get_xarray` | Dense mosaic, band-wise cosine similarity |\n", |
23 | | - "| Approach B | `get_gdf` | Per-record arrays, vectorized matmul cosine |\n", |
24 | | - "| Approach C | `to_torchgeo_dataset` | Streaming 1024 px chips, bounded memory |" |
| 26 | + "| Approach A | `get_xarray` | Dense mosaic, cosine similarity |\n", |
| 27 | + "| Approach B | `get_gdf` | Per-record arrays, cosine similarity |\n", |
| 28 | + "| Approach C | `to_torchgeo_dataset` | Streaming 1024 px chips, bounded memory |\n", |
| 29 | + "\n", |
| 30 | + "\n", |
| 31 | + "Attribution:\n", |
| 32 | + "\n", |
| 33 | + "> \"The AlphaEarth Foundations Satellite Embedding dataset is produced by Google and Google DeepMind.\"\n" |
25 | 34 | ] |
26 | 35 | }, |
27 | 36 | { |
|
58 | 67 | "import duckdb\n", |
59 | 68 | "import geopandas as gpd\n", |
60 | 69 | "import numpy as np\n", |
61 | | - "import torch\n", |
62 | 70 | "from shapely.geometry import Point\n", |
| 71 | + "from sklearn.metrics.pairwise import cosine_similarity\n", |
63 | 72 | "\n", |
64 | 73 | "import rasteret\n", |
65 | 74 | "\n", |
|
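The rewritten `cosine_similarity_map` in the next hunk calls a `dequantize` helper that is defined elsewhere in the notebook and never appears in this diff. A minimal sketch of what that helper is assumed to do, reconstructed from the inline math the removed code used (`(v / 127.5) ** 2 * np.sign(v)`); the `NODATA` sentinel of `-128` is an assumption, not confirmed by this diff:

```python
import numpy as np

NODATA = -128  # assumed int8 nodata sentinel; the notebook defines its own NODATA constant


def dequantize(q: np.ndarray) -> np.ndarray:
    """Map raw AEF int8 embedding values to floats, NaN where nodata.

    Mirrors the squared-magnitude mapping the removed inline code used:
    (v / 127.5) ** 2 * sign(v), with nodata pixels set to NaN so that
    downstream np.isfinite masks can find them.
    """
    qf = q.astype(np.float32)
    out = (qf / 127.5) ** 2 * np.sign(qf)
    out[qf == NODATA] = np.nan
    return out
```

This keeps the sign of each band while compressing magnitudes into roughly [-1, 1], matching the dequantization the old in-place code performed before computing norms.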
85 | 94 | "\n", |
86 | 95 | "\n", |
87 | 96 | "def cosine_similarity_map(cube: np.ndarray, ref: np.ndarray) -> np.ndarray:\n", |
88 | | - " \"\"\"Cosine similarity between each pixel and a reference vector.\n", |
| 97 | + "    \"\"\"Cosine similarity between each pixel and a reference embedding, via sklearn.\n", |
89 | 98 | "\n", |
90 | 99 | " cube : (C, H, W) int8 array — raw AEF embeddings\n", |
91 | 100 | " ref : (C,) float32 reference embedding (already dequantized)\n", |
92 | 101 | "\n", |
93 | | - " Returns (H, W) float32 array, NaN where any band is nodata.\n", |
| 102 | + "    Returns (H, W) float32 array; NaN where any band is nodata or the pixel's embedding is all-zero.\n", |
94 | 103 | " \"\"\"\n", |
95 | 104 | " C, H, W = cube.shape\n", |
96 | | - " flat = cube.reshape(C, -1).astype(np.float32)\n", |
97 | | - "\n", |
98 | | - " nd = (flat == NODATA) | np.isnan(flat)\n", |
99 | | - " d = (flat / 127.5) ** 2 * np.sign(flat)\n", |
100 | | - " d[nd] = 0.0\n", |
101 | | - " valid = ~nd.any(axis=0)\n", |
102 | | - "\n", |
103 | | - " dot = ref @ d\n", |
104 | | - " norms = np.linalg.norm(d, axis=0)\n", |
105 | | - " ref_norm = np.linalg.norm(ref)\n", |
106 | | - "\n", |
107 | | - " sim = np.full(H * W, np.nan, dtype=np.float32)\n", |
108 | | - " ok = valid & (norms > 0)\n", |
109 | | - " sim[ok] = (dot[ok] / (norms[ok] * ref_norm)).astype(np.float32)\n", |
110 | | - " return sim.reshape(H, W)\n", |
111 | | - "\n", |
112 | | - "\n", |
113 | | - "def cosine_similarity_map_torch(img: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:\n", |
114 | | - " \"\"\"Torch version — same math, for TorchGeo chips.\n", |
115 | | - "\n", |
116 | | - " img : (C, H, W) float tensor — raw int8 values as float\n", |
117 | | - " ref : (C,) float tensor — dequantized reference embedding\n", |
118 | | - "\n", |
119 | | - " Returns (H, W) float tensor, NaN where any band is nodata.\n", |
120 | | - " \"\"\"\n", |
121 | | - " nd = (img == NODATA) | img.isnan()\n", |
122 | | - " d = (img / 127.5) ** 2 * img.sign()\n", |
123 | | - " d[nd] = 0.0\n", |
124 | | - " valid = ~nd.any(dim=0)\n", |
| 105 | + " flat = dequantize(cube.reshape(C, -1).T) # (N, C)\n", |
125 | 106 | "\n", |
126 | | - " dot = (d * ref[:, None, None]).sum(dim=0)\n", |
127 | | - " norms = d.norm(dim=0)\n", |
128 | | - " ref_norm = ref.norm()\n", |
| 107 | + " sim_flat = np.full(flat.shape[0], np.nan, dtype=np.float32)\n", |
| 108 | + " valid = np.isfinite(flat).all(axis=1)\n", |
| 109 | + " if np.any(valid):\n", |
| 110 | + " rows = flat[valid]\n", |
| 111 | + " nonzero = np.linalg.norm(rows, axis=1) > 0\n", |
| 112 | + " if np.any(nonzero):\n", |
| 113 | + " sim = cosine_similarity(rows[nonzero], ref.reshape(1, -1)).ravel()\n", |
| 114 | + " valid_idx = np.flatnonzero(valid)\n", |
| 115 | + " sim_flat[valid_idx[nonzero]] = sim.astype(np.float32)\n", |
129 | 116 | "\n", |
130 | | - " sim = torch.full_like(dot, float(\"nan\"))\n", |
131 | | - " ok = valid & (norms > 0)\n", |
132 | | - " sim[ok] = dot[ok] / (norms[ok] * ref_norm)\n", |
133 | | - " return sim\n", |
| 117 | + " return sim_flat.reshape(H, W)\n", |
134 | 118 | "\n", |
135 | 119 | "\n", |
136 | 120 | "timings: dict[str, float] = {}" |
|
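The trickiest part of the hunk above is the double masking: similarities are computed only on the compacted `rows[nonzero]`, then scattered back to their original flat positions via `np.flatnonzero(valid)[nonzero]`. A self-contained sketch of the same pattern on a tiny `(N, C)` matrix — the function name here is illustrative, not from the notebook:

```python
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity


def masked_cosine(flat: np.ndarray, ref: np.ndarray) -> np.ndarray:
    """Cosine similarity of each row of flat (N, C) against ref (C,).

    Rows containing NaN, or whose vector is all-zero, come back as NaN,
    mirroring the valid/nonzero masking in cosine_similarity_map.
    """
    sim = np.full(flat.shape[0], np.nan, dtype=np.float32)
    valid = np.isfinite(flat).all(axis=1)          # drop rows with any NaN band
    if valid.any():
        rows = flat[valid]
        nonzero = np.linalg.norm(rows, axis=1) > 0  # drop zero-norm rows
        if nonzero.any():
            s = cosine_similarity(rows[nonzero], ref.reshape(1, -1)).ravel()
            # scatter compacted results back to original flat positions
            sim[np.flatnonzero(valid)[nonzero]] = s.astype(np.float32)
    return sim
```

The scatter step is why the hunk needs `np.flatnonzero(valid)` rather than boolean indexing alone: `nonzero` is aligned to the compacted `rows`, not to the full flat array.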
421 | 405 | "---\n", |
422 | 406 | "## Approach C — TorchGeo streaming\n", |
423 | 407 | "\n", |
424 | | - "`to_torchgeo_dataset` wraps the collection as a TorchGeo `GeoDataset`.\n", |
425 | | - "`GridGeoSampler` tiles the AOI into fixed-size chips — each chip is\n", |
426 | | - "fetched on the fly and discarded after use.\n", |
427 | | - "\n", |
428 | | - "Unlike A and B, the sampler works on a fixed grid that may extend\n", |
429 | | - "slightly beyond tile boundaries. Edge chips are zero-filled by\n", |
430 | | - "rasterio's merge semantics, matching TorchGeo's native behavior.\n", |
431 | | - "We apply a county polygon mask afterwards for the visualization." |
 | 408 | + "Uses `GridGeoSampler` for chip-wise streaming, then stitches chip predictions with\n", |
 | 409 | + "the `stitch_prediction_tiles` helper from `notebooks/utils_stitching.py`, avoiding manual row/col placement math." |
432 | 410 | ] |
433 | 411 | }, |
434 | 412 | { |
|
448 | 426 | "from rasterio.features import geometry_mask\n", |
449 | 427 | "from torchgeo.samplers import GridGeoSampler\n", |
450 | 428 | "\n", |
| 429 | + "try:\n", |
| 430 | + " from notebooks.utils_stitching import stitch_prediction_tiles\n", |
| 431 | + "except ImportError:\n", |
| 432 | + " from utils_stitching import stitch_prediction_tiles\n", |
| 433 | + "\n", |
451 | 434 | "CHIP_PX = 1024\n", |
452 | 435 | "\n", |
453 | 436 | "t0 = time.perf_counter()\n", |
|
459 | 442 | "roi_xmin, roi_ymin, roi_xmax, roi_ymax = county_geom_utm.bounds\n", |
460 | 443 | "out_w = round((roi_xmax - roi_xmin) / res_x)\n", |
461 | 444 | "out_h = round((roi_ymax - roi_ymin) / res_y)\n", |
462 | | - "sim_c = np.full((out_h, out_w), np.nan, dtype=np.float32)\n", |
463 | | - "\n", |
464 | | - "ref_t = torch.from_numpy(reference_vector).float()\n", |
465 | 445 | "n_chips = len(sampler)\n", |
466 | 446 | "\n", |
467 | 447 | "print(f\"{n_chips} chips ({CHIP_PX}x{CHIP_PX} px) output grid: ({out_h}, {out_w})\")\n", |
468 | 448 | "\n", |
| 449 | + "tiles = []\n", |
| 450 | + "skipped = 0\n", |
469 | 451 | "for i, query in enumerate(sampler):\n", |
470 | | - " sample = dataset[query]\n", |
471 | | - " chip_sim = cosine_similarity_map_torch(sample[\"image\"].float(), ref_t)\n", |
472 | | - "\n", |
473 | | - " tf = sample[\"transform\"].numpy()\n", |
474 | | - " col = round((float(tf[2]) - roi_xmin) / res_x)\n", |
475 | | - " row = round((roi_ymax - float(tf[5])) / res_y)\n", |
476 | | - " ch, cw = chip_sim.shape\n", |
477 | | - " r0, c0 = max(0, row), max(0, col)\n", |
478 | | - " r1, c1 = min(row + ch, out_h), min(col + cw, out_w)\n", |
479 | | - " if r1 > r0 and c1 > c0:\n", |
480 | | - " sim_c[r0:r1, c0:c1] = chip_sim[r0 - row : r1 - row, c0 - col : c1 - col].numpy()\n", |
| 452 | + " try:\n", |
| 453 | + " sample = dataset[query]\n", |
| 454 | + " except Exception:\n", |
| 455 | + " skipped += 1\n", |
| 456 | + " continue\n", |
| 457 | + "\n", |
| 458 | + " chip = sample[\"image\"].numpy().astype(np.float32)\n", |
| 459 | + " chip_sim = cosine_similarity_map(chip, reference_vector)\n", |
| 460 | + " tiles.append({\"prediction\": chip_sim, \"transform\": sample[\"transform\"].numpy()})\n", |
481 | 461 | "\n", |
482 | 462 | " elapsed = time.perf_counter() - t0\n", |
483 | 463 | " print(f\" chip {i + 1}/{n_chips} ({elapsed:.0f}s)\", end=\"\\r\")\n", |
484 | 464 | "\n", |
| 465 | + "sim_c = stitch_prediction_tiles(\n", |
| 466 | + " tiles,\n", |
| 467 | + " roi_bounds=(roi_xmin, roi_ymin, roi_xmax, roi_ymax),\n", |
| 468 | + " res=(res_x, res_y),\n", |
| 469 | + " reducer=\"overwrite\",\n", |
| 470 | + " out_shape=(out_h, out_w),\n", |
| 471 | + ")\n", |
| 472 | + "\n", |
485 | 473 | "# Mask to county polygon for apples-to-apples comparison\n", |
486 | 474 | "out_transform = Affine(res_x, 0, roi_xmin, 0, -res_y, roi_ymax)\n", |
487 | 475 | "county_mask = geometry_mask(\n", |
|
500 | 488 | " f\"\\nsimilarity min={fin_c.min():.4f} mean={fin_c.mean():.4f} \"\n", |
501 | 489 | " f\"max={fin_c.max():.4f} pixels={fin_c.size:,}\"\n", |
502 | 490 | ")\n", |
503 | | - "print(f\"timing total={timings['C_total']:.1f}s ({n_chips} chips)\")" |
| 491 | + "print(\n", |
| 492 | + " f\"timing total={timings['C_total']:.1f}s ({n_chips} chips, \"\n", |
| 493 | + " f\"used={len(tiles)}, skipped={skipped})\"\n", |
| 494 | + ")" |
504 | 495 | ] |
505 | 496 | }, |
506 | 497 | { |
|
662 | 653 | "metadata": {}, |
663 | 654 | "source": [ |
664 | 655 | "---\n", |
665 | | - "## Timing comparison" |
| 656 | + "## Timing comparison\n", |
| 657 | + "\n", |
 | 658 | + "Timings below were measured in this run only; treat them as local diagnostics,\n", |
 | 659 | + "not as cross-project benchmarks." |
666 | 660 | ] |
667 | 661 | }, |
668 | 662 | { |
|
680 | 674 | "source": [ |
681 | 675 | "print(f\"{'':32s} {'Load':>8s} {'Cosine':>8s} {'Total':>8s}\")\n", |
682 | 676 | "print(f\"{'---':32s} {'---':>8s} {'---':>8s} {'---':>8s}\")\n", |
683 | | - "print(f\"{'Blog (aef-loader + Dask)':32s} {'41s':>8s} {'281s':>8s} {'322s':>8s}\")\n", |
684 | 677 | "print(\n", |
685 | | - " f\"{'A get_xarray + band-wise':32s} \"\n", |
| 678 | + " f\"{'A get_xarray + sklearn':32s} \"\n", |
686 | 679 | " f\"{timings['A_load']:7.0f}s {timings['A_cosine']:7.0f}s \"\n", |
687 | 680 | " f\"{timings['A_total']:7.0f}s\"\n", |
688 | 681 | ")\n", |
689 | 682 | "print(\n", |
690 | | - " f\"{'B get_gdf + matmul':32s} \"\n", |
| 683 | + " f\"{'B get_gdf + sklearn':32s} \"\n", |
691 | 684 | " f\"{timings['B_load']:7.0f}s {timings['B_cosine']:7.0f}s \"\n", |
692 | 685 | " f\"{timings['B_total']:7.0f}s\"\n", |
693 | 686 | ")\n", |
|
705 | 698 | "---\n", |
706 | 699 | "## Summary\n", |
707 | 700 | "\n", |
708 | | - "| API | When to use | Memory profile |\n", |
709 | | - "|---|---|---|\n", |
710 | | - "| `sample_points` | Extract values at specific locations | Minimal — reads only the tiles that contain your points |\n", |
711 | | - "| `get_xarray` | Dense AOI reads with spatial coords | Full AOI in memory as a mosaicked Dataset |\n", |
712 | | - "| `get_gdf` | Per-record arrays, ragged shapes | Full AOI in memory, one array per record |\n", |
713 | | - "| `to_torchgeo_dataset` | Large AOIs, ML training loops | One chip at a time — bounded memory |\n", |
| 701 | + "| API | When to use | Memory | Speed |\n", |
| 702 | + "|---|---|---|---|\n", |
| 703 | + "| `sample_points` | Build reference vectors from known coordinates | Tiny | Fast |\n", |
| 704 | + "| `get_xarray` | AOI-wide map-style analysis (single mosaic) | Higher | Fastest for this county |\n", |
| 705 | + "| `get_gdf` | Record-wise analysis with explicit per-record arrays | Medium | Fast |\n", |
| 706 | + "| `to_torchgeo_dataset` | Streaming chips for training/inference loops | Lowest | Slower, bounded-memory |\n", |
714 | 707 | "\n", |
715 | | - "All four APIs read from the same prebuilt collection using Rasteret's\n", |
716 | | - "COG IO engine — no rasterio, no GDAL, no Dask." |
| 708 | + "This notebook uses `sklearn.metrics.pairwise.cosine_similarity` for readability,\n", |
| 709 | + "while showing three Rasteret access patterns (`sample_points`, `get_xarray`, `get_gdf`)\n", |
| 710 | + "and TorchGeo streaming with a reusable stitch helper.\n" |
717 | 711 | ] |
718 | 712 | } |
719 | 713 | ], |
|