From 1c12b39e9a39cac43b3d311158c79d8e74e2a14b Mon Sep 17 00:00:00 2001 From: Lyn Long Date: Thu, 4 Jun 2026 11:55:21 +1000 Subject: [PATCH 1/3] add estimate CO download with sse service --- .../core/model/enumeration/ProcessIdEnum.java | 1 + .../server/core/service/DasService.java | 54 ++++++++ .../aodn/ogcapi/server/processes/RestApi.java | 12 +- .../ogcapi/server/processes/RestServices.java | 127 ++++++++++++++++++ 4 files changed, 193 insertions(+), 1 deletion(-) diff --git a/server/src/main/java/au/org/aodn/ogcapi/server/core/model/enumeration/ProcessIdEnum.java b/server/src/main/java/au/org/aodn/ogcapi/server/core/model/enumeration/ProcessIdEnum.java index a7e3966b..48221df2 100644 --- a/server/src/main/java/au/org/aodn/ogcapi/server/core/model/enumeration/ProcessIdEnum.java +++ b/server/src/main/java/au/org/aodn/ogcapi/server/core/model/enumeration/ProcessIdEnum.java @@ -7,6 +7,7 @@ public enum ProcessIdEnum { DOWNLOAD_DATASET("download"), DOWNLOAD_WFS_SSE("downloadWfs"), DOWNLOAD_WFS_ESTIMATE("estimateWfsDownload"), + DOWNLOAD_CO_ESTIMATE("estimateCOdownload"), UNKNOWN(""); private final String value; diff --git a/server/src/main/java/au/org/aodn/ogcapi/server/core/service/DasService.java b/server/src/main/java/au/org/aodn/ogcapi/server/core/service/DasService.java index 38676654..9664982d 100644 --- a/server/src/main/java/au/org/aodn/ogcapi/server/core/service/DasService.java +++ b/server/src/main/java/au/org/aodn/ogcapi/server/core/service/DasService.java @@ -14,6 +14,7 @@ import java.net.URLEncoder; import java.util.HashMap; +import java.util.List; import java.util.Map; @Service("DataAccessService") @@ -79,6 +80,59 @@ public byte[] getLatestWaveBuoySites(){ return httpClient.exchange(waveBuoysUrlTemplate, HttpMethod.GET,httpEntity,byte[].class).getBody(); } + /** + * Call the data-access-service cloud-optimised size estimate endpoint. + * + * POST /api/v1/das/data/{uuid}/estimate_size with a JSON body matching + * EstimateSizeRequest. The endpoint is multi-key and aggregates server-side: + * a null/"*" keys list means "all keys of the uuid" (same as the batch + * download). Returns the raw JSON response body so the SSE layer can forward + * it to the frontend unchanged. + */ + public String estimateCloudOptimisedDownloadSize( + String uuid, + List keys, + String startDate, + String endDate, + Object multiPolygon, + List columns, + String outputFormat) { + + String url = UriComponentsBuilder.fromUriString(dasConfig.host + "/api/v1/das/data/{uuid}/estimate_size") + .encode() + .toUriString(); + + // Body mirrors EstimateSizeRequest. Send the raw frontend date strings + // (or "non-specified" when null) so data-access-service applies the same + // resolve/supply/trim chain the batch download uses. + Map body = new HashMap<>(); + body.put("keys", keys); // null => all keys of the uuid + body.put("start_date", startDate != null ? startDate : "non-specified"); + body.put("end_date", endDate != null ? endDate : "non-specified"); + body.put("output_format", outputFormat); + // multi_polygon is accepted as a GeoJSON object or string; forward as-is. + if (multiPolygon != null) { + body.put("multi_polygon", multiPolygon); + } + // columns is not sent today (frontend doesn't subset columns yet, and the + // batch download grabs all variables), keeping the estimate aligned. + if (columns != null) { + body.put("columns", columns); + } + + HttpHeaders headers = new HttpHeaders(); + headers.set(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE); + headers.set(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE); + headers.set("X-API-KEY", dasConfig.secret); + + HttpEntity> entity = new HttpEntity<>(body, headers); + + Map uriVars = new HashMap<>(); + uriVars.put("uuid", uuid); + + return httpClient.exchange(url, HttpMethod.POST, entity, String.class, uriVars).getBody(); + } + public boolean isCollectionSupported(String collectionId){ final String waveBuoyRealtimeCollectionID = "b299cdcd-3dee-48aa-abdd-e0fcdbb9cadc"; return waveBuoyRealtimeCollectionID.contentEquals(collectionId); diff --git a/server/src/main/java/au/org/aodn/ogcapi/server/processes/RestApi.java b/server/src/main/java/au/org/aodn/ogcapi/server/processes/RestApi.java index 4fe97baf..27ec97ff 100644 --- a/server/src/main/java/au/org/aodn/ogcapi/server/processes/RestApi.java +++ b/server/src/main/java/au/org/aodn/ogcapi/server/processes/RestApi.java @@ -131,17 +131,27 @@ public SseEmitter executeSse( try { String uuid = DatasetDownloadEnums.Parameter.UUID.getStringInput(body); String layerName = DatasetDownloadEnums.Parameter.LAYER_NAME.getStringInput(body); + String key = DatasetDownloadEnums.Parameter.KEY.getStringInput(body); String startDate = DatasetDownloadEnums.Parameter.START_DATE.getStringInput(body); String endDate = DatasetDownloadEnums.Parameter.END_DATE.getStringInput(body); String outputFormat = DatasetDownloadEnums.Parameter.OUTPUT_FORMAT.getStringInput(body); Object multiPolygon = DatasetDownloadEnums.Parameter.MULTI_POLYGON.getObjectInput(body); List fields = DatasetDownloadEnums.Parameter.FIELDS.getListInput(body); + ProcessIdEnum id = ProcessIdEnum.fromString(processId); + + // Cloud-optimised (zarr/parquet) size estimate uses uuid + key, not a + // WFS layer name, so handle it before the WFS-specific validation below. + if (id == ProcessIdEnum.DOWNLOAD_CO_ESTIMATE) { + return restServices.estimateCloudOptimisedDownloadWithSse( + uuid, key, startDate, endDate, multiPolygon, null, outputFormat + ); + } + if(uuid == null || layerName == null) { throw new IllegalArgumentException("UUID and LayerName cannot null"); } - ProcessIdEnum id = ProcessIdEnum.fromString(processId); SseEmitter emitter; switch (id) { diff --git a/server/src/main/java/au/org/aodn/ogcapi/server/processes/RestServices.java b/server/src/main/java/au/org/aodn/ogcapi/server/processes/RestServices.java index 80fea273..0710394c 100644 --- a/server/src/main/java/au/org/aodn/ogcapi/server/processes/RestServices.java +++ b/server/src/main/java/au/org/aodn/ogcapi/server/processes/RestServices.java @@ -2,6 +2,7 @@ import au.org.aodn.ogcapi.server.core.exception.wfs.WfsErrorHandler; import au.org.aodn.ogcapi.server.core.model.enumeration.DatasetDownloadEnums; +import au.org.aodn.ogcapi.server.core.service.DasService; import au.org.aodn.ogcapi.server.core.service.geoserver.wfs.DownloadWfsDataService; import au.org.aodn.ogcapi.server.core.util.EmailUtils; import com.fasterxml.jackson.core.JsonProcessingException; @@ -20,6 +21,7 @@ import java.io.IOException; import java.math.BigInteger; import java.nio.charset.StandardCharsets; +import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -36,6 +38,9 @@ public class RestServices { @Autowired private DownloadWfsDataService downloadWfsDataService; + @Autowired + private DasService dasService; + public RestServices(BatchClient batchClient, ObjectMapper objectMapper) { this.batchClient = batchClient; this.objectMapper = objectMapper; @@ -340,4 +345,126 @@ public SseEmitter downloadWfsDataWithSse(String uuid, }); return emitter; } + + /** + * Estimate the download size of a cloud-optimised (zarr/parquet) subset over SSE. + *

+ * Opens an SSE stream, calls the data-access-service estimate endpoint with the + * same parameters that drive the batch download (so the estimate matches what + * the real download would produce), forwards the size estimate as a single + * event, then closes the stream. + */ + public SseEmitter estimateCloudOptimisedDownloadWithSse(String uuid, + String key, + String startDate, + String endDate, + Object multiPolygon, + List columns, + String outputFormat) { + + final SseEmitter emitter = new SseEmitter(0L); + + // Set up references for resources that need to be cleaned up + AtomicReference> keepAliveTaskRef = new AtomicReference<>(); + AtomicReference keepAliveExecutorRef = new AtomicReference<>(); + + // Set up cleanup function to clear up resources + Runnable cleanupResources = () -> { + try { + ScheduledFuture keepAliveTask = keepAliveTaskRef.get(); + if (keepAliveTask != null && !keepAliveTask.isCancelled()) { + keepAliveTask.cancel(false); + } + + ScheduledExecutorService keepAliveExecutor = keepAliveExecutorRef.get(); + if (keepAliveExecutor != null && !keepAliveExecutor.isShutdown()) { + keepAliveExecutor.shutdown(); + } + } catch (Exception e) { + log.error("Error during cleanup for CO estimate UUID: {}", uuid, e); + } + }; + + emitter.onCompletion(() -> { + log.info("CO estimate SSE stream completion"); + cleanupResources.run(); + }); + + emitter.onTimeout(() -> { + log.warn("CO estimate SSE stream timed out"); + cleanupResources.run(); + }); + + emitter.onError(throwable -> WfsErrorHandler.handleError((Exception) throwable, uuid, emitter, cleanupResources)); + + // Validate parameters + if (uuid == null) { + IllegalArgumentException exception = new IllegalArgumentException("Uuid is required"); + WfsErrorHandler.handleError(exception, uuid, emitter, cleanupResources); + return emitter; + } + + // batch-style key handling: comma-separated string -> list; null/blank/"*" + // -> null, which the data-access-service expands to all keys of the uuid. + final List keys = (key == null || key.isBlank() || key.equals("*")) + ? null + : Arrays.stream(key.split(",")).map(String::trim).toList(); + + // Start async estimate with SSE progress updates + CompletableFuture.runAsync(() -> { + try { + // STEP 1: Send connection established event + emitter.send(SseEmitter.event() + .name("connection-established") + .data(Map.of( + "status", "connected", + "message", "Starting cloud-optimised size estimate for UUID: " + uuid, + "timestamp", System.currentTimeMillis() + ))); + + // STEP 2: Start keep-alive mechanism while data-access-service computes the estimate + ScheduledExecutorService keepAliveExecutor = Executors.newSingleThreadScheduledExecutor(); + ScheduledFuture keepAliveTask = keepAliveExecutor.scheduleAtFixedRate(() -> { + try { + emitter.send(SseEmitter.event() + .name("keep-alive") + .data(Map.of( + "status", "estimating", + "message", "Estimating download size...", + "timestamp", System.currentTimeMillis() + ))); + } catch (Exception e) { + WfsErrorHandler.handleError(e, uuid, emitter, cleanupResources); + } + }, 20, 20, TimeUnit.SECONDS); + + keepAliveTaskRef.set(keepAliveTask); + keepAliveExecutorRef.set(keepAliveExecutor); + + // STEP 3: Call the data-access-service estimate endpoint and forward the result + try { + String estimateJson = dasService.estimateCloudOptimisedDownloadSize( + uuid, keys, startDate, endDate, multiPolygon, columns, outputFormat + ); + emitter.send(SseEmitter.event() + .name("estimate-complete") + .data(estimateJson)); + } catch (Exception e) { + log.warn("Cloud-optimised size estimation failed for UUID {}: {}", uuid, e.getMessage()); + emitter.send(SseEmitter.event() + .name("estimate-failed") + .data(Map.of( + "message", "Size estimation failed: " + e.getMessage(), + "timestamp", System.currentTimeMillis() + ))); + } finally { + emitter.complete(); + } + } catch (Exception e) { + WfsErrorHandler.handleError(e, uuid, emitter, cleanupResources); + } + }); + + return emitter; + } } From f53cac2828574669eb72ffa8737420a449987df8 Mon Sep 17 00:00:00 2001 From: Lyn Long Date: Fri, 12 Jun 2026 11:37:34 +1000 Subject: [PATCH 2/3] refactor to extract SSE wrapper --- .../core/exception/wfs/WfsErrorHandler.java | 9 +- .../core/model/enumeration/SseEventName.java | 28 ++ .../server/core/service/DasService.java | 45 +-- .../geoserver/wfs/DownloadWfsDataService.java | 13 +- .../server/core/service/sse/SseSession.java | 91 +++++ .../core/service/sse/SseStreamHandler.java | 71 ++++ .../aodn/ogcapi/server/processes/RestApi.java | 73 ++-- .../ogcapi/server/processes/RestServices.java | 365 ++++++------------ .../wfs/DownloadWfsDataServiceTest.java | 4 +- 9 files changed, 388 insertions(+), 311 deletions(-) create mode 100644 server/src/main/java/au/org/aodn/ogcapi/server/core/model/enumeration/SseEventName.java create mode 100644 server/src/main/java/au/org/aodn/ogcapi/server/core/service/sse/SseSession.java create mode 100644 server/src/main/java/au/org/aodn/ogcapi/server/core/service/sse/SseStreamHandler.java diff --git a/server/src/main/java/au/org/aodn/ogcapi/server/core/exception/wfs/WfsErrorHandler.java b/server/src/main/java/au/org/aodn/ogcapi/server/core/exception/wfs/WfsErrorHandler.java index 7af55e9d..fac2ec7f 100644 --- a/server/src/main/java/au/org/aodn/ogcapi/server/core/exception/wfs/WfsErrorHandler.java +++ b/server/src/main/java/au/org/aodn/ogcapi/server/core/exception/wfs/WfsErrorHandler.java @@ -2,6 +2,7 @@ import au.org.aodn.ogcapi.server.core.exception.GeoserverFieldsNotFoundException; import au.org.aodn.ogcapi.server.core.exception.UnauthorizedServerException; +import au.org.aodn.ogcapi.server.core.model.enumeration.SseEventName; import lombok.extern.slf4j.Slf4j; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; @@ -50,7 +51,7 @@ public static void handleError(Exception e, String uuid, SseEmitter emitter, Run case VALIDATION_ERROR -> { log.warn("Invalid parameter error for UUID {}: {}", uuid, e.getMessage()); emitter.send(SseEmitter.event() - .name("error") + .name(SseEventName.ERROR.getValue()) .data(Map.of( "message", "Invalid parameter error", "timestamp", System.currentTimeMillis() @@ -61,7 +62,7 @@ public static void handleError(Exception e, String uuid, SseEmitter emitter, Run case UNAUTHORIZED_SERVER_ERROR -> { log.warn("Unauthorized wfs server for UUID {}", uuid, e); emitter.send(SseEmitter.event() - .name("error") + .name(SseEventName.ERROR.getValue()) .data(Map.of( "message", "Unauthorized wfs server", "timestamp", System.currentTimeMillis() @@ -72,7 +73,7 @@ public static void handleError(Exception e, String uuid, SseEmitter emitter, Run case DOWNLOADABLE_FIELDS_ERROR -> { log.warn("No downloadable fields found for UUID {}", uuid, e); emitter.send(SseEmitter.event() - .name("error") + .name(SseEventName.ERROR.getValue()) .data(Map.of( "message", "No downloadable fields found", "timestamp", System.currentTimeMillis() @@ -83,7 +84,7 @@ public static void handleError(Exception e, String uuid, SseEmitter emitter, Run case UNKNOWN_ERROR -> { log.warn("Unknown error for UUID {}", uuid, e); emitter.send(SseEmitter.event() - .name("error") + .name(SseEventName.ERROR.getValue()) .data(Map.of( "message", "Unknown error: " + e.getMessage(), "timestamp", System.currentTimeMillis() diff --git a/server/src/main/java/au/org/aodn/ogcapi/server/core/model/enumeration/SseEventName.java b/server/src/main/java/au/org/aodn/ogcapi/server/core/model/enumeration/SseEventName.java new file mode 100644 index 00000000..5879f17a --- /dev/null +++ b/server/src/main/java/au/org/aodn/ogcapi/server/core/model/enumeration/SseEventName.java @@ -0,0 +1,28 @@ +package au.org.aodn.ogcapi.server.core.model.enumeration; + +import lombok.Getter; + +/** + * Names of the SSE events sent by the download / size-estimate streams. + *

+ * The frontend identifies events by these literal names, so they form a wire + * contract — change a value only together with the portal (and keep them + * aligned with the data-access-service sse_wrapper vocabulary). + */ +@Getter +public enum SseEventName { + CONNECTION_ESTABLISHED("connection-established"), + KEEP_ALIVE("keep-alive"), + DOWNLOAD_STARTED("download-started"), + FILE_CHUNK("file-chunk"), + DOWNLOAD_COMPLETE("download-complete"), + ESTIMATE_COMPLETE("estimate-complete"), + ESTIMATE_FAILED("estimate-failed"), + ERROR("error"); + + private final String value; + + SseEventName(String value) { + this.value = value; + } +} diff --git a/server/src/main/java/au/org/aodn/ogcapi/server/core/service/DasService.java b/server/src/main/java/au/org/aodn/ogcapi/server/core/service/DasService.java index 9664982d..0615da6b 100644 --- a/server/src/main/java/au/org/aodn/ogcapi/server/core/service/DasService.java +++ b/server/src/main/java/au/org/aodn/ogcapi/server/core/service/DasService.java @@ -36,57 +36,53 @@ public void init() { httpEntity = new HttpEntity<>(headers); } - public byte[] getWaveBuoys(String from, String to){ + public byte[] getWaveBuoys(String from, String to) { String waveBuoysUrlTemplate = UriComponentsBuilder.fromUriString(dasConfig.host + "/api/v1/das/data/feature-collection/wave-buoy") - .queryParam("start_date","{start_date}") - .queryParam("end_date","{end_date}") + .queryParam("start_date", "{start_date}") + .queryParam("end_date", "{end_date}") .encode() .toUriString(); - Map params = new HashMap<>(); + Map params = new HashMap<>(); params.put("start_date", from); - params.put("end_date",to); + params.put("end_date", to); - return httpClient.exchange(waveBuoysUrlTemplate, HttpMethod.GET,httpEntity,byte[].class,params).getBody(); + return httpClient.exchange(waveBuoysUrlTemplate, HttpMethod.GET, httpEntity, byte[].class, params).getBody(); } - public byte[] getWaveBuoysLatestDate(){ + public byte[] getWaveBuoysLatestDate() { String waveBuoysUrlTemplate = UriComponentsBuilder.fromUriString(dasConfig.host + "/api/v1/das/data/feature-collection/wave-buoy/latest") .encode() .toUriString(); - return httpClient.exchange(waveBuoysUrlTemplate, HttpMethod.GET,httpEntity,byte[].class).getBody(); + return httpClient.exchange(waveBuoysUrlTemplate, HttpMethod.GET, httpEntity, byte[].class).getBody(); } - public byte[] getWaveBuoyData(String from, String to, String buoy){ + public byte[] getWaveBuoyData(String from, String to, String buoy) { String encodedBuoy = URLEncoder.encode(buoy, java.nio.charset.StandardCharsets.UTF_8); String waveBuoyDataUrlTemplate = UriComponentsBuilder.fromUriString(dasConfig.host + "/api/v1/das/data/feature-collection/wave-buoy/" + encodedBuoy) - .queryParam("start_date","{start_date}") - .queryParam("end_date","{end_date}") + .queryParam("start_date", "{start_date}") + .queryParam("end_date", "{end_date}") .encode() .toUriString(); - Map params = new HashMap<>(); + Map params = new HashMap<>(); params.put("start_date", from); - params.put("end_date",to); + params.put("end_date", to); - return httpClient.exchange(waveBuoyDataUrlTemplate, HttpMethod.GET,httpEntity,byte[].class,params).getBody(); + return httpClient.exchange(waveBuoyDataUrlTemplate, HttpMethod.GET, httpEntity, byte[].class, params).getBody(); } - public byte[] getLatestWaveBuoySites(){ + public byte[] getLatestWaveBuoySites() { String waveBuoysUrlTemplate = UriComponentsBuilder.fromUriString(dasConfig.host + "/api/v1/das/data/feature-collection/wave-buoy/all") .encode() .toUriString(); - return httpClient.exchange(waveBuoysUrlTemplate, HttpMethod.GET,httpEntity,byte[].class).getBody(); + return httpClient.exchange(waveBuoysUrlTemplate, HttpMethod.GET, httpEntity, byte[].class).getBody(); } /** * Call the data-access-service cloud-optimised size estimate endpoint. - * - * POST /api/v1/das/data/{uuid}/estimate_size with a JSON body matching - * EstimateSizeRequest. The endpoint is multi-key and aggregates server-side: - * a null/"*" keys list means "all keys of the uuid" (same as the batch - * download). Returns the raw JSON response body so the SSE layer can forward + * Returns the raw JSON response body so the SSE layer can forward * it to the frontend unchanged. */ public String estimateCloudOptimisedDownloadSize( @@ -102,11 +98,8 @@ public String estimateCloudOptimisedDownloadSize( .encode() .toUriString(); - // Body mirrors EstimateSizeRequest. Send the raw frontend date strings - // (or "non-specified" when null) so data-access-service applies the same - // resolve/supply/trim chain the batch download uses. Map body = new HashMap<>(); - body.put("keys", keys); // null => all keys of the uuid + body.put("keys", keys); body.put("start_date", startDate != null ? startDate : "non-specified"); body.put("end_date", endDate != null ? endDate : "non-specified"); body.put("output_format", outputFormat); @@ -133,7 +126,7 @@ public String estimateCloudOptimisedDownloadSize( return httpClient.exchange(url, HttpMethod.POST, entity, String.class, uriVars).getBody(); } - public boolean isCollectionSupported(String collectionId){ + public boolean isCollectionSupported(String collectionId) { final String waveBuoyRealtimeCollectionID = "b299cdcd-3dee-48aa-abdd-e0fcdbb9cadc"; return waveBuoyRealtimeCollectionID.contentEquals(collectionId); } diff --git a/server/src/main/java/au/org/aodn/ogcapi/server/core/service/geoserver/wfs/DownloadWfsDataService.java b/server/src/main/java/au/org/aodn/ogcapi/server/core/service/geoserver/wfs/DownloadWfsDataService.java index 08485203..a7779349 100644 --- a/server/src/main/java/au/org/aodn/ogcapi/server/core/service/geoserver/wfs/DownloadWfsDataService.java +++ b/server/src/main/java/au/org/aodn/ogcapi/server/core/service/geoserver/wfs/DownloadWfsDataService.java @@ -1,6 +1,7 @@ package au.org.aodn.ogcapi.server.core.service.geoserver.wfs; import au.org.aodn.ogcapi.server.core.configuration.CacheConfig; +import au.org.aodn.ogcapi.server.core.model.enumeration.SseEventName; import au.org.aodn.ogcapi.server.core.model.ogc.FeatureRequest; import au.org.aodn.ogcapi.server.core.util.DatetimeUtils; import com.fasterxml.jackson.databind.JsonNode; @@ -191,9 +192,9 @@ public BigInteger estimateDownloadSize( } /** - * Execute WFS request with SSE support + * Call the WFS server and stream the downloaded data to the client over SSE */ - public void executeWfsRequestWithSse( + public void streamWfsDataWithSse( String wfsRequestUrl, String uuid, String layerName, @@ -213,7 +214,7 @@ public void executeWfsRequestWithSse( // Send download started confirmation emitter.send(SseEmitter.event() - .name("download-started") + .name(SseEventName.DOWNLOAD_STARTED.getValue()) .data(Map.of( "message", "WFS server responded, starting data stream...", "timestamp", System.currentTimeMillis() @@ -236,7 +237,7 @@ public void executeWfsRequestWithSse( String encodedData = Base64.getEncoder().encodeToString(chunkBytes); emitter.send(SseEmitter.event() - .name("file-chunk") + .name(SseEventName.FILE_CHUNK.getValue()) .data(Map.of( "chunkNumber", ++chunkNumber, "data", encodedData, @@ -254,7 +255,7 @@ public void executeWfsRequestWithSse( if (chunkBuffer.size() > 0) { String encodedData = Base64.getEncoder().encodeToString(chunkBuffer.toByteArray()); emitter.send(SseEmitter.event() - .name("file-chunk") + .name(SseEventName.FILE_CHUNK.getValue()) .data(Map.of( "chunkNumber", ++chunkNumber, "data", encodedData, @@ -266,7 +267,7 @@ public void executeWfsRequestWithSse( // Send completion event emitter.send(SseEmitter.event() - .name("download-complete") + .name(SseEventName.DOWNLOAD_COMPLETE.getValue()) .data(Map.of( "totalBytes", totalBytes, "totalChunks", chunkNumber, diff --git a/server/src/main/java/au/org/aodn/ogcapi/server/core/service/sse/SseSession.java b/server/src/main/java/au/org/aodn/ogcapi/server/core/service/sse/SseSession.java new file mode 100644 index 00000000..e2b52d06 --- /dev/null +++ b/server/src/main/java/au/org/aodn/ogcapi/server/core/service/sse/SseSession.java @@ -0,0 +1,91 @@ +package au.org.aodn.ogcapi.server.core.service.sse; + +import au.org.aodn.ogcapi.server.core.exception.wfs.WfsErrorHandler; +import au.org.aodn.ogcapi.server.core.model.enumeration.SseEventName; +import lombok.Getter; +import lombok.extern.slf4j.Slf4j; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; + +import java.io.IOException; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; + +/** + * A single SSE stream's runtime state: the underlying {@link SseEmitter}, an + * optional keep-alive ticker, and the cleanup of those resources. + *

+ * Created and managed by {@link SseStreamHandler}; the work lambda receives one + * to send events and (optionally) start a keep-alive. + */ +@Slf4j +public class SseSession { + + private final String contextId; + + @Getter + private final SseEmitter emitter; + + private final AtomicReference> keepAliveTaskRef = new AtomicReference<>(); + private final AtomicReference keepAliveExecutorRef = new AtomicReference<>(); + + public SseSession(String contextId, SseEmitter emitter) { + this.contextId = contextId; + this.emitter = emitter; + } + + /** + * Send a named SSE event with the given payload. + */ + public void send(SseEventName eventName, Object data) throws IOException { + emitter.send(SseEmitter.event().name(eventName.getValue()).data(data)); + } + + /** + * Start sending a {@code keep-alive} event every {@code intervalSeconds}. The + * payload is recomputed each tick by {@code payloadSupplier} so callers can + * reflect changing state (e.g. whether an upstream server has responded yet). + */ + public void startKeepAlive(long intervalSeconds, Supplier payloadSupplier) { + ScheduledExecutorService keepAliveExecutor = Executors.newSingleThreadScheduledExecutor(); + ScheduledFuture keepAliveTask = keepAliveExecutor.scheduleAtFixedRate(() -> { + try { + send(SseEventName.KEEP_ALIVE, payloadSupplier.get()); + } catch (Exception e) { + WfsErrorHandler.handleError(e, contextId, emitter, this::cleanup); + } + }, intervalSeconds, intervalSeconds, TimeUnit.SECONDS); + + keepAliveTaskRef.set(keepAliveTask); + keepAliveExecutorRef.set(keepAliveExecutor); + } + + /** + * Complete the stream, closing the connection to the client. + */ + public void complete() { + emitter.complete(); + } + + /** + * Cancel the keep-alive task and shut down its executor. Idempotent. + */ + public void cleanup() { + try { + ScheduledFuture keepAliveTask = keepAliveTaskRef.get(); + if (keepAliveTask != null && !keepAliveTask.isCancelled()) { + keepAliveTask.cancel(false); + } + + ScheduledExecutorService keepAliveExecutor = keepAliveExecutorRef.get(); + if (keepAliveExecutor != null && !keepAliveExecutor.isShutdown()) { + keepAliveExecutor.shutdown(); + } + } catch (Exception e) { + log.error("Error during cleanup for SSE stream: {}", contextId, e); + } + } +} diff --git a/server/src/main/java/au/org/aodn/ogcapi/server/core/service/sse/SseStreamHandler.java b/server/src/main/java/au/org/aodn/ogcapi/server/core/service/sse/SseStreamHandler.java new file mode 100644 index 00000000..d341871b --- /dev/null +++ b/server/src/main/java/au/org/aodn/ogcapi/server/core/service/sse/SseStreamHandler.java @@ -0,0 +1,71 @@ +package au.org.aodn.ogcapi.server.core.service.sse; + +import au.org.aodn.ogcapi.server.core.exception.wfs.WfsErrorHandler; +import lombok.extern.slf4j.Slf4j; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; + +import java.util.concurrent.CompletableFuture; + +/** + * Shared scaffolding for the long-running SSE endpoints (WFS download / estimate, + * cloud-optimised estimate). It owns the boilerplate that every stream needs — + * emitter creation, lifecycle callbacks, resource cleanup, and error handling — + * so callers only supply the actual work. + */ +@Slf4j +public class SseStreamHandler { + + private SseStreamHandler() { + } + + /** + * Work executed against an {@link SseSession}. Allowed to throw so callers can + * let {@code emitter.send(...)} (which throws {@link java.io.IOException}) + * propagate to the shared error handler. + */ + @FunctionalInterface + public interface SseWork { + void run(SseSession session) throws Exception; + } + + /** + * Create an SSE stream and run {@code work} asynchronously against it. + *

+ * A never-timing-out {@link SseEmitter} is created, lifecycle callbacks are + * wired to clean up the keep-alive resources, and any exception from the work + * (including validation errors thrown at the start) is routed through + * {@link WfsErrorHandler}. The work is responsible for completing the stream + * once its result has been sent. + * + * @param contextId identifier (e.g. uuid) used for logging and error handling + * @param work the per-stream logic: send events, optionally start keep-alive + * @return the emitter to return from the controller + */ + public static SseEmitter stream(String contextId, SseWork work) { + final SseEmitter emitter = new SseEmitter(0L); + final SseSession session = new SseSession(contextId, emitter); + + emitter.onCompletion(() -> { + log.info("SSE stream completion for {}", contextId); + session.cleanup(); + }); + + emitter.onTimeout(() -> { + log.warn("SSE stream timed out for {}", contextId); + session.cleanup(); + }); + + emitter.onError(throwable -> + WfsErrorHandler.handleError((Exception) throwable, contextId, emitter, session::cleanup)); + + CompletableFuture.runAsync(() -> { + try { + work.run(session); + } catch (Exception e) { + WfsErrorHandler.handleError(e, contextId, emitter, session::cleanup); + } + }); + + return emitter; + } +} diff --git a/server/src/main/java/au/org/aodn/ogcapi/server/processes/RestApi.java b/server/src/main/java/au/org/aodn/ogcapi/server/processes/RestApi.java index 27ec97ff..13968895 100644 --- a/server/src/main/java/au/org/aodn/ogcapi/server/processes/RestApi.java +++ b/server/src/main/java/au/org/aodn/ogcapi/server/processes/RestApi.java @@ -9,7 +9,6 @@ import au.org.aodn.ogcapi.server.core.model.enumeration.DatasetDownloadEnums; import au.org.aodn.ogcapi.server.core.model.enumeration.InlineResponseKeyEnum; import au.org.aodn.ogcapi.server.core.model.enumeration.ProcessIdEnum; -import au.org.aodn.ogcapi.server.core.model.ogc.FeatureRequest; import io.swagger.v3.oas.annotations.Parameter; import io.swagger.v3.oas.annotations.enums.ParameterIn; import io.swagger.v3.oas.annotations.media.Schema; @@ -38,8 +37,8 @@ public class RestApi implements ProcessesApi { // cause exception thrown sometimes. So i re-declared the produces value here @RequestMapping( value = "/processes/{processID}/execution", - produces = { MediaType.APPLICATION_JSON_VALUE, MediaType.TEXT_HTML_VALUE }, - consumes = { MediaType.APPLICATION_JSON_VALUE }, + produces = {MediaType.APPLICATION_JSON_VALUE, MediaType.TEXT_HTML_VALUE}, + consumes = {MediaType.APPLICATION_JSON_VALUE}, method = RequestMethod.POST ) public ResponseEntity execute( @@ -118,8 +117,8 @@ public ResponseEntity getProcesses() { */ @RequestMapping( value = "/processes/{processID}/execution", - produces = { MediaType.TEXT_EVENT_STREAM_VALUE }, - consumes = { MediaType.APPLICATION_JSON_VALUE }, + produces = {MediaType.TEXT_EVENT_STREAM_VALUE}, + consumes = {MediaType.APPLICATION_JSON_VALUE}, method = RequestMethod.POST ) public SseEmitter executeSse( @@ -140,41 +139,43 @@ public SseEmitter executeSse( ProcessIdEnum id = ProcessIdEnum.fromString(processId); - // Cloud-optimised (zarr/parquet) size estimate uses uuid + key, not a - // WFS layer name, so handle it before the WFS-specific validation below. - if (id == ProcessIdEnum.DOWNLOAD_CO_ESTIMATE) { - return restServices.estimateCloudOptimisedDownloadWithSse( - uuid, key, startDate, endDate, multiPolygon, null, outputFormat - ); - } - - if(uuid == null || layerName == null) { - throw new IllegalArgumentException("UUID and LayerName cannot null"); - } - SseEmitter emitter; switch (id) { - case DOWNLOAD_WFS_SSE: + case DOWNLOAD_WFS_SSE: { + emitter = restServices.downloadWfsDataWithSse( + uuid, + startDate, + endDate, + multiPolygon, + fields, + layerName, + outputFormat + ); + break; + } case DOWNLOAD_WFS_ESTIMATE: { - if(FeatureRequest.GeoServerOutputFormat.fromString(outputFormat) == FeatureRequest.GeoServerOutputFormat.UNKNOWN) { - emitter = new SseEmitter(0L); - emitter.completeWithError(new BadRequestException( - String.format("Missing output format [%s]", processId) - )); - } - else { - emitter = restServices.downloadWfsDataWithSse( - uuid, - startDate, - endDate, - multiPolygon, - fields, - layerName, - outputFormat, - id == ProcessIdEnum.DOWNLOAD_WFS_ESTIMATE - ); - } + emitter = restServices.estimateWfsDownloadWithSse( + uuid, + startDate, + endDate, + multiPolygon, + fields, + layerName, + outputFormat + ); + break; + } + case DOWNLOAD_CO_ESTIMATE: { + emitter = restServices.estimateCloudOptimisedDownloadWithSse( + uuid, + key, + startDate, + endDate, + multiPolygon, + null, + outputFormat + ); break; } default: { diff --git a/server/src/main/java/au/org/aodn/ogcapi/server/processes/RestServices.java b/server/src/main/java/au/org/aodn/ogcapi/server/processes/RestServices.java index 0710394c..d10b71be 100644 --- a/server/src/main/java/au/org/aodn/ogcapi/server/processes/RestServices.java +++ b/server/src/main/java/au/org/aodn/ogcapi/server/processes/RestServices.java @@ -1,9 +1,11 @@ package au.org.aodn.ogcapi.server.processes; -import au.org.aodn.ogcapi.server.core.exception.wfs.WfsErrorHandler; import au.org.aodn.ogcapi.server.core.model.enumeration.DatasetDownloadEnums; +import au.org.aodn.ogcapi.server.core.model.enumeration.SseEventName; +import au.org.aodn.ogcapi.server.core.model.ogc.FeatureRequest; import au.org.aodn.ogcapi.server.core.service.DasService; import au.org.aodn.ogcapi.server.core.service.geoserver.wfs.DownloadWfsDataService; +import au.org.aodn.ogcapi.server.core.service.sse.SseStreamHandler; import au.org.aodn.ogcapi.server.core.util.EmailUtils; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; @@ -25,9 +27,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicReference; @Slf4j public class RestServices { @@ -104,7 +104,7 @@ public ResponseEntity downloadData( if (polygons == null || polygons.toString().isEmpty()) { throw new IllegalArgumentException("Polygons parameter should now be null. If users didn't specify polygons, a 'non-specified' should be sent."); - // String (e.g. "non-specified") is working weird with function ObjectMapper.writeValueAsString(), so handle it separately + // String (e.g. "non-specified") is working weird with function ObjectMapper.writeValueAsString(), so handle it separately } else if (polygons.toString().equals("non-specified")) { parameters.put(DatasetDownloadEnums.Parameter.MULTI_POLYGON.getValue(), polygons.toString()); } else { @@ -203,156 +203,100 @@ public SseEmitter downloadWfsDataWithSse(String uuid, Object multiPolygon, List fields, String layerName, - String outputFormat, - boolean estimateSizeOnly) { - - final SseEmitter emitter = new SseEmitter(0L); - - // Set up references for resources that need to be cleaned up - AtomicReference> keepAliveTaskRef = new AtomicReference<>(); - AtomicReference keepAliveExecutorRef = new AtomicReference<>(); - - // Set up cleanup function to clear up resources - Runnable cleanupWfsResources = () -> { - try { - ScheduledFuture keepAliveTask = keepAliveTaskRef.get(); - if (keepAliveTask != null && !keepAliveTask.isCancelled()) { - keepAliveTask.cancel(false); - } - - ScheduledExecutorService keepAliveExecutor = keepAliveExecutorRef.get(); - if (keepAliveExecutor != null && !keepAliveExecutor.isShutdown()) { - keepAliveExecutor.shutdown(); - } - } catch (Exception e) { - log.error("Error during cleanup for UUID: {}", uuid, e); - } - }; - - // Set up emitter callbacks - emitter.onCompletion(() -> { - log.info("WFS SSE stream completion"); - cleanupWfsResources.run(); - }); + String outputFormat) { + + return SseStreamHandler.stream(uuid, session -> { + validateWfsSseInputs(uuid, layerName, outputFormat); + + SseEmitter emitter = session.getEmitter(); + + // STEP 1: Send connection established event + session.send(SseEventName.CONNECTION_ESTABLISHED, Map.of( + "message", "Starting WFS download for UUID: " + uuid, + "timestamp", System.currentTimeMillis() + )); + + // STEP 2: Start keep-alive mechanism for WFS server wait time. The + // payload reflects whether the WFS server has started responding. + AtomicBoolean wfsServerResponded = new AtomicBoolean(false); + session.startKeepAlive(20, () -> Map.of( + "message", wfsServerResponded.get() ? + "WFS data streaming in progress..." : "Waiting for WFS server response...", + "timestamp", System.currentTimeMillis() + )); + + // STEP 3: Do preparation work: Collection lookup from Elasticsearch, WFS validation, Field retrieval, URL building + String wfsRequestUrl = downloadWfsDataService.prepareWfsRequestUrl( + uuid, startDate, endDate, multiPolygon, fields, layerName, outputFormat, -1L, false + ); - emitter.onTimeout(() -> { - log.warn("WFS SSE stream timed out"); - cleanupWfsResources.run(); + // STEP 4: Make the WFS call: Streaming the response directly to client via SSE + downloadWfsDataService.streamWfsDataWithSse( + wfsRequestUrl, + uuid, + layerName, + outputFormat, + emitter, + wfsServerResponded + ); }); + } - emitter.onError(throwable -> WfsErrorHandler.handleError((Exception) throwable, uuid, emitter, cleanupWfsResources)); - - // Validate parameters - if (uuid == null || layerName == null || layerName.trim().isEmpty()) { - IllegalArgumentException exception = new IllegalArgumentException("Layer name and Uuid are required"); - WfsErrorHandler.handleError(exception, uuid, emitter, cleanupWfsResources); - return emitter; - } - - // Start async download with SSE progress updates - CompletableFuture.runAsync(() -> { + /** + * Estimate the download size of a WFS (GeoServer) subset over SSE. + */ + public SseEmitter estimateWfsDownloadWithSse(String uuid, + String startDate, + String endDate, + Object multiPolygon, + List fields, + String layerName, + String outputFormat) { + + return SseStreamHandler.stream(uuid, session -> { + validateWfsSseInputs(uuid, layerName, outputFormat); + + // STEP 1: Send connection established event + session.send(SseEventName.CONNECTION_ESTABLISHED, Map.of( + "message", "Starting WFS download size estimate for UUID: " + uuid, + "timestamp", System.currentTimeMillis() + )); + + // STEP 2: Start keep-alive mechanism while waiting for the WFS server + session.startKeepAlive(20, () -> Map.of( + "message", "Waiting for WFS server response...", + "timestamp", System.currentTimeMillis() + )); + + // STEP 3: Compute the size estimate and forward it as a single event try { - // STEP 1: Send connection established event - emitter.send(SseEmitter.event() - .name("connection-established") - .data(Map.of( - "status", "connected", - "message", "Starting WFS download for UUID: " + uuid, - "timestamp", System.currentTimeMillis() - ))); - - // STEP 2: Start keep-alive mechanism for WFS server wait time - ScheduledExecutorService keepAliveExecutor = Executors.newSingleThreadScheduledExecutor(); - AtomicBoolean wfsServerResponded = new AtomicBoolean(false); - - // Send keep-alive every 20 seconds - ScheduledFuture keepAliveTask = keepAliveExecutor.scheduleAtFixedRate(() -> { - try { - String status = wfsServerResponded.get() ? "streaming" : "waiting-for-wfs-server"; - emitter.send(SseEmitter.event() - .name("keep-alive") - .data(Map.of( - "status", status, - "timestamp", System.currentTimeMillis(), - "message", wfsServerResponded.get() ? - "WFS data streaming in progress..." : "Waiting for WFS server response..." - ))); - } catch (Exception e) { - WfsErrorHandler.handleError(e, uuid, emitter, cleanupWfsResources); - } - }, 20, 20, TimeUnit.SECONDS); - - keepAliveTaskRef.set(keepAliveTask); - keepAliveExecutorRef.set(keepAliveExecutor); - - emitter.send(SseEmitter.event() - .name("wfs-request-ready") - .data(Map.of( - "message", "Connecting to WFS server...", - "timestamp", System.currentTimeMillis() - ))); - - if(!estimateSizeOnly) { - // STEP 3: Do preparation work: Collection lookup from Elasticsearch, WFS validation, Field retrieval, URL building - String wfsRequestUrl = downloadWfsDataService.prepareWfsRequestUrl( - uuid, startDate, endDate, multiPolygon, fields, layerName, outputFormat, -1L, false - ); - - // STEP 4: Make the WFS call: Streaming the response directly to client via SSE - downloadWfsDataService.executeWfsRequestWithSse( - wfsRequestUrl, - uuid, - layerName, - outputFormat, - emitter, - wfsServerResponded - ); - } - else { - try { - BigInteger est = downloadWfsDataService.estimateDownloadSize( - uuid, - layerName, - startDate, - endDate, - multiPolygon, - fields, - outputFormat - ); - emitter.send(SseEmitter.event() - .name(est != null ? "estimate-complete" : "estimate-failed") - .data(Map.of( - "size", est != null ? est : "", - "timestamp", System.currentTimeMillis() - ))); - } - catch(Exception e) { - log.warn("Unexpected error during size estimation for UUID {}: {}", uuid, e.getMessage()); - emitter.send(SseEmitter.event() - .name("estimate-failed") - .data(Map.of( - "message", "Size estimation failed: " + e.getMessage(), - "timestamp", System.currentTimeMillis() - ))); - } - finally { - emitter.complete(); - } - } + BigInteger est = downloadWfsDataService.estimateDownloadSize( + uuid, + layerName, + startDate, + endDate, + multiPolygon, + fields, + outputFormat + ); + session.send(est != null ? SseEventName.ESTIMATE_COMPLETE : SseEventName.ESTIMATE_FAILED, Map.of( + "size", est != null ? est : "", + "timestamp", System.currentTimeMillis() + )); } catch (Exception e) { - WfsErrorHandler.handleError(e, uuid, emitter, cleanupWfsResources); + log.warn("Unexpected error during size estimation for UUID {}: {}", uuid, e.getMessage()); + session.send(SseEventName.ESTIMATE_FAILED, Map.of( + "message", "Size estimation failed: " + e.getMessage(), + "timestamp", System.currentTimeMillis() + )); + } finally { + session.complete(); } }); - return emitter; } /** * Estimate the download size of a cloud-optimised (zarr/parquet) subset over SSE. - *

- * Opens an SSE stream, calls the data-access-service estimate endpoint with the - * same parameters that drive the batch download (so the estimate matches what - * the real download would produce), forwards the size estimate as a single - * event, then closes the stream. */ public SseEmitter estimateCloudOptimisedDownloadWithSse(String uuid, String key, @@ -362,109 +306,56 @@ public SseEmitter estimateCloudOptimisedDownloadWithSse(String uuid, List columns, String outputFormat) { - final SseEmitter emitter = new SseEmitter(0L); - - // Set up references for resources that need to be cleaned up - AtomicReference> keepAliveTaskRef = new AtomicReference<>(); - AtomicReference keepAliveExecutorRef = new AtomicReference<>(); - - // Set up cleanup function to clear up resources - Runnable cleanupResources = () -> { - try { - ScheduledFuture keepAliveTask = keepAliveTaskRef.get(); - if (keepAliveTask != null && !keepAliveTask.isCancelled()) { - keepAliveTask.cancel(false); - } - - ScheduledExecutorService keepAliveExecutor = keepAliveExecutorRef.get(); - if (keepAliveExecutor != null && !keepAliveExecutor.isShutdown()) { - keepAliveExecutor.shutdown(); - } - } catch (Exception e) { - log.error("Error during cleanup for CO estimate UUID: {}", uuid, e); + return SseStreamHandler.stream(uuid, session -> { + // Validate parameters + if (uuid == null || outputFormat == null) { + throw new IllegalArgumentException("Missing uuid or output format"); } - }; - - emitter.onCompletion(() -> { - log.info("CO estimate SSE stream completion"); - cleanupResources.run(); - }); - - emitter.onTimeout(() -> { - log.warn("CO estimate SSE stream timed out"); - cleanupResources.run(); - }); - emitter.onError(throwable -> WfsErrorHandler.handleError((Exception) throwable, uuid, emitter, cleanupResources)); + // batch-style key handling: comma-separated string -> list; null/blank/"*" -> null + List keys = (key == null || key.isBlank() || key.equals("*")) + ? null + : Arrays.stream(key.split(",")).map(String::trim).toList(); - // Validate parameters - if (uuid == null) { - IllegalArgumentException exception = new IllegalArgumentException("Uuid is required"); - WfsErrorHandler.handleError(exception, uuid, emitter, cleanupResources); - return emitter; - } + // STEP 1: Send connection established event + session.send(SseEventName.CONNECTION_ESTABLISHED, Map.of( + "message", "Starting cloud-optimised size estimate for UUID: " + uuid, + "timestamp", System.currentTimeMillis() + )); - // batch-style key handling: comma-separated string -> list; null/blank/"*" - // -> null, which the data-access-service expands to all keys of the uuid. - final List keys = (key == null || key.isBlank() || key.equals("*")) - ? null - : Arrays.stream(key.split(",")).map(String::trim).toList(); + // STEP 2: Start keep-alive mechanism while data-access-service computes the estimate + session.startKeepAlive(20, () -> Map.of( + "message", "Estimating download size...", + "timestamp", System.currentTimeMillis() + )); - // Start async estimate with SSE progress updates - CompletableFuture.runAsync(() -> { + // STEP 3: Call the data-access-service estimate endpoint and forward the result try { - // STEP 1: Send connection established event - emitter.send(SseEmitter.event() - .name("connection-established") - .data(Map.of( - "status", "connected", - "message", "Starting cloud-optimised size estimate for UUID: " + uuid, - "timestamp", System.currentTimeMillis() - ))); - - // STEP 2: Start keep-alive mechanism while data-access-service computes the estimate - ScheduledExecutorService keepAliveExecutor = Executors.newSingleThreadScheduledExecutor(); - ScheduledFuture keepAliveTask = keepAliveExecutor.scheduleAtFixedRate(() -> { - try { - emitter.send(SseEmitter.event() - .name("keep-alive") - .data(Map.of( - "status", "estimating", - "message", "Estimating download size...", - "timestamp", System.currentTimeMillis() - ))); - } catch (Exception e) { - WfsErrorHandler.handleError(e, uuid, emitter, cleanupResources); - } - }, 20, 20, TimeUnit.SECONDS); - - keepAliveTaskRef.set(keepAliveTask); - keepAliveExecutorRef.set(keepAliveExecutor); - - // STEP 3: Call the data-access-service estimate endpoint and forward the result - try { - String estimateJson = dasService.estimateCloudOptimisedDownloadSize( - uuid, keys, startDate, endDate, multiPolygon, columns, outputFormat - ); - emitter.send(SseEmitter.event() - .name("estimate-complete") - .data(estimateJson)); - } catch (Exception e) { - log.warn("Cloud-optimised size estimation failed for UUID {}: {}", uuid, e.getMessage()); - emitter.send(SseEmitter.event() - .name("estimate-failed") - .data(Map.of( - "message", "Size estimation failed: " + e.getMessage(), - "timestamp", System.currentTimeMillis() - ))); - } finally { - emitter.complete(); - } + String estimateJson = dasService.estimateCloudOptimisedDownloadSize( + uuid, keys, startDate, endDate, multiPolygon, columns, outputFormat + ); + session.send(SseEventName.ESTIMATE_COMPLETE, estimateJson); } catch (Exception e) { - WfsErrorHandler.handleError(e, uuid, emitter, cleanupResources); + log.warn("Cloud-optimised size estimation failed for UUID {}: {}", uuid, e.getMessage()); + session.send(SseEventName.ESTIMATE_FAILED, Map.of( + "message", "Size estimation failed: " + e.getMessage(), + "timestamp", System.currentTimeMillis() + )); + } finally { + session.complete(); } }); + } - return emitter; + /** + * Shared input validation for the two WFS SSE flows. + */ + private void validateWfsSseInputs(String uuid, String layerName, String outputFormat) { + if (uuid == null || layerName == null || layerName.trim().isEmpty()) { + throw new IllegalArgumentException("Layer name and Uuid are required"); + } + if (FeatureRequest.GeoServerOutputFormat.fromString(outputFormat) == FeatureRequest.GeoServerOutputFormat.UNKNOWN) { + throw new IllegalArgumentException(String.format("Missing output format [%s]", outputFormat)); + } } } diff --git a/server/src/test/java/au/org/aodn/ogcapi/server/service/wfs/DownloadWfsDataServiceTest.java b/server/src/test/java/au/org/aodn/ogcapi/server/service/wfs/DownloadWfsDataServiceTest.java index 00dedfbe..1dc96d71 100644 --- a/server/src/test/java/au/org/aodn/ogcapi/server/service/wfs/DownloadWfsDataServiceTest.java +++ b/server/src/test/java/au/org/aodn/ogcapi/server/service/wfs/DownloadWfsDataServiceTest.java @@ -84,7 +84,7 @@ void verifyDecodeTextCorrectlyForSSE() throws Exception { }).when(emitter).send(any(SseEmitter.SseEventBuilder.class)); - service.executeWfsRequestWithSse( + service.streamWfsDataWithSse( "http://mock/wfs?...", "uuid-123", "layer:test", @@ -194,7 +194,7 @@ void verifyDecodeBinaryCorrectlyForSSE() throws Exception { }).when(emitter).send(any(SseEmitter.SseEventBuilder.class)); - service.executeWfsRequestWithSse( + service.streamWfsDataWithSse( "http://mock/wfs?...", "uuid-123", "layer:test", From 5b3feee6e725594c348db37e2193d8c7d03a0a06 Mon Sep 17 00:00:00 2001 From: Lyn Long Date: Wed, 17 Jun 2026 10:14:29 +1000 Subject: [PATCH 3/3] add test --- .../server/core/service/DasServiceTest.java | 111 ++++++++ .../server/processes/RestApiSseTest.java | 266 ++++++++++++++++++ 2 files changed, 377 insertions(+) create mode 100644 server/src/test/java/au/org/aodn/ogcapi/server/core/service/DasServiceTest.java create mode 100644 server/src/test/java/au/org/aodn/ogcapi/server/processes/RestApiSseTest.java diff --git a/server/src/test/java/au/org/aodn/ogcapi/server/core/service/DasServiceTest.java b/server/src/test/java/au/org/aodn/ogcapi/server/core/service/DasServiceTest.java new file mode 100644 index 00000000..9808358b --- /dev/null +++ b/server/src/test/java/au/org/aodn/ogcapi/server/core/service/DasServiceTest.java @@ -0,0 +1,111 @@ +package au.org.aodn.ogcapi.server.core.service; + +import au.org.aodn.ogcapi.server.core.configuration.DASConfig; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.web.client.HttpClientErrorException; +import org.springframework.web.client.RestTemplate; + +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +public class DasServiceTest { + + @Mock + private RestTemplate httpClient; + + private DasService dasService; + + @BeforeEach + public void setUp() { + DASConfig dasConfig = new DASConfig(); + dasConfig.host = "http://das-test-host"; + dasConfig.secret = "test-secret"; + + dasService = new DasService(); + dasService.dasConfig = dasConfig; + dasService.httpClient = httpClient; + } + + private HttpEntity> callAndCaptureEntity( + String uuid, List keys, String startDate, String endDate, + Object multiPolygon, List columns, String outputFormat) { + + when(httpClient.exchange(anyString(), eq(HttpMethod.POST), any(HttpEntity.class), eq(String.class), anyMap())) + .thenReturn(ResponseEntity.ok("{\"estimated_output_bytes\":123}")); + + String result = dasService.estimateCloudOptimisedDownloadSize( + uuid, keys, startDate, endDate, multiPolygon, columns, outputFormat); + assertEquals("{\"estimated_output_bytes\":123}", result, "Raw das JSON should be returned unchanged"); + + var entityCaptor = org.mockito.ArgumentCaptor.forClass(HttpEntity.class); + var urlCaptor = org.mockito.ArgumentCaptor.forClass(String.class); + var uriVarsCaptor = org.mockito.ArgumentCaptor.forClass(Map.class); + verify(httpClient).exchange(urlCaptor.capture(), eq(HttpMethod.POST), entityCaptor.capture(), eq(String.class), uriVarsCaptor.capture()); + + assertEquals("http://das-test-host/api/v1/das/data/{uuid}/estimate_size", urlCaptor.getValue()); + assertEquals(uuid, uriVarsCaptor.getValue().get("uuid")); + + return (HttpEntity>) entityCaptor.getValue(); + } + + @Test + public void testEstimatePostsJsonBodyWithApiKey() { + HttpEntity> entity = callAndCaptureEntity( + "test-uuid", List.of("a.zarr", "b.zarr"), "2023-01-01", "2023-01-31", + "non-specified", null, "netcdf"); + + HttpHeaders headers = entity.getHeaders(); + assertEquals("test-secret", headers.getFirst("X-API-KEY")); + assertEquals(MediaType.APPLICATION_JSON_VALUE, headers.getFirst(HttpHeaders.CONTENT_TYPE)); + + Map body = entity.getBody(); + assertNotNull(body); + assertEquals(List.of("a.zarr", "b.zarr"), body.get("keys")); + assertEquals("2023-01-01", body.get("start_date")); + assertEquals("2023-01-31", body.get("end_date")); + assertEquals("netcdf", body.get("output_format")); + assertEquals("non-specified", body.get("multi_polygon")); + assertFalse(body.containsKey("columns"), "columns must be omitted when not provided"); + } + + @Test + public void testEstimateNullDatesBecomeNonSpecifiedAndOptionalFieldsOmitted() { + HttpEntity> entity = callAndCaptureEntity( + "test-uuid", null, null, null, null, null, "csv"); + + Map body = entity.getBody(); + assertNotNull(body); + assertNull(body.get("keys"), "null keys means all keys of the uuid"); + assertEquals("non-specified", body.get("start_date")); + assertEquals("non-specified", body.get("end_date")); + assertEquals("csv", body.get("output_format")); + assertFalse(body.containsKey("multi_polygon"), "multi_polygon must be omitted when null"); + assertFalse(body.containsKey("columns"), "columns must be omitted when null"); + } + + @Test + public void testEstimateNon2xxPropagates() { + when(httpClient.exchange(anyString(), eq(HttpMethod.POST), any(HttpEntity.class), eq(String.class), anyMap())) + .thenThrow(HttpClientErrorException.create(HttpStatus.NOT_FOUND, "Not Found", HttpHeaders.EMPTY, null, null)); + + assertThrows(HttpClientErrorException.class, () -> + dasService.estimateCloudOptimisedDownloadSize( + "bad-uuid", List.of("missing.zarr"), null, null, null, null, "netcdf")); + } +} diff --git a/server/src/test/java/au/org/aodn/ogcapi/server/processes/RestApiSseTest.java b/server/src/test/java/au/org/aodn/ogcapi/server/processes/RestApiSseTest.java new file mode 100644 index 00000000..333051d7 --- /dev/null +++ b/server/src/test/java/au/org/aodn/ogcapi/server/processes/RestApiSseTest.java @@ -0,0 +1,266 @@ +package au.org.aodn.ogcapi.server.processes; + +import au.org.aodn.ogcapi.server.core.model.enumeration.DatasetDownloadEnums; +import au.org.aodn.ogcapi.server.core.model.enumeration.ProcessIdEnum; +import au.org.aodn.ogcapi.server.core.service.DasService; +import au.org.aodn.ogcapi.server.core.service.geoserver.wfs.DownloadWfsDataService; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.http.MediaType; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.test.util.ReflectionTestUtils; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.setup.MockMvcBuilders; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; +import software.amazon.awssdk.services.batch.BatchClient; + +import java.math.BigInteger; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.request; + +/** + * Drives the SSE execution endpoint end-to-end (RestApi dispatch + RestServices + * SSE flows) with the downstream services mocked, asserting the events written + * to the stream. + */ +@ExtendWith(MockitoExtension.class) +public class RestApiSseTest { + + @Mock + private BatchClient batchClient; + + @Mock + private DownloadWfsDataService downloadWfsDataService; + + @Mock + private DasService dasService; + + private MockMvc mockMvc; + + private final ObjectMapper objectMapper = new ObjectMapper(); + + @BeforeEach + public void setUp() { + RestServices restServices = new RestServices(batchClient, objectMapper); + ReflectionTestUtils.setField(restServices, "downloadWfsDataService", downloadWfsDataService); + ReflectionTestUtils.setField(restServices, "dasService", dasService); + + RestApi restApi = new RestApi(); + ReflectionTestUtils.setField(restApi, "restServices", restServices); + + mockMvc = MockMvcBuilders.standaloneSetup(restApi).build(); + } + + private MockHttpServletResponse postSse(String processId, Map inputs) throws Exception { + String body = objectMapper.writeValueAsString(Map.of("inputs", inputs)); + return mockMvc.perform(post("/api/v1/ogc/processes/{processID}/execution", processId) + .contentType(MediaType.APPLICATION_JSON) + .accept(MediaType.TEXT_EVENT_STREAM) + .content(body)) + .andExpect(request().asyncStarted()) + .andReturn() + .getResponse(); + } + + /** + * The SSE work runs on a separate thread, so poll the mock response until the + * expected marker shows up (or time out and let the caller's assert fail with + * the content collected so far). + */ + private String awaitContent(MockHttpServletResponse response, String expectedMarker) throws Exception { + long deadline = System.currentTimeMillis() + 5000; + String content = response.getContentAsString(); + while (System.currentTimeMillis() < deadline && !content.contains(expectedMarker)) { + Thread.sleep(50); + content = response.getContentAsString(); + } + return content; + } + + private static void assertEventOrder(String content, String earlierEvent, String laterEvent) { + int earlier = content.indexOf("event:" + earlierEvent); + int later = content.indexOf("event:" + laterEvent); + assertTrue(earlier >= 0, "Missing event [" + earlierEvent + "] in: " + content); + assertTrue(later > earlier, "Event [" + laterEvent + "] should come after [" + earlierEvent + "] in: " + content); + } + + // ---------- estimateCOdownload ---------- + + @Test + public void testEstimateCODownloadHappyPathSplitsKeyCsv() throws Exception { + String dasJson = "{\"estimated_output_bytes\":12345}"; + when(dasService.estimateCloudOptimisedDownloadSize(any(), any(), any(), any(), any(), any(), any())) + .thenReturn(dasJson); + + Map inputs = new HashMap<>(); + inputs.put(DatasetDownloadEnums.Parameter.UUID.getValue(), "test-uuid"); + inputs.put(DatasetDownloadEnums.Parameter.KEY.getValue(), "a.zarr, b.zarr"); + inputs.put(DatasetDownloadEnums.Parameter.START_DATE.getValue(), "2023-01-01"); + inputs.put(DatasetDownloadEnums.Parameter.END_DATE.getValue(), "2023-01-31"); + inputs.put(DatasetDownloadEnums.Parameter.MULTI_POLYGON.getValue(), "non-specified"); + inputs.put(DatasetDownloadEnums.Parameter.OUTPUT_FORMAT.getValue(), "netcdf"); + + MockHttpServletResponse response = postSse(ProcessIdEnum.DOWNLOAD_CO_ESTIMATE.getValue(), inputs); + String content = awaitContent(response, "event:estimate-complete"); + + assertEventOrder(content, "connection-established", "estimate-complete"); + assertTrue(content.contains(dasJson), "das JSON should be forwarded unchanged in: " + content); + + @SuppressWarnings("unchecked") + ArgumentCaptor> keysCaptor = ArgumentCaptor.forClass(List.class); + verify(dasService).estimateCloudOptimisedDownloadSize( + eq("test-uuid"), keysCaptor.capture(), eq("2023-01-01"), eq("2023-01-31"), + eq("non-specified"), isNull(), eq("netcdf")); + assertEquals(List.of("a.zarr", "b.zarr"), keysCaptor.getValue(), "key CSV must be split and trimmed"); + } + + @Test + public void testEstimateCODownloadWildcardKeyMeansAllKeys() throws Exception { + when(dasService.estimateCloudOptimisedDownloadSize(any(), any(), any(), any(), any(), any(), any())) + .thenReturn("{}"); + + Map inputs = new HashMap<>(); + inputs.put(DatasetDownloadEnums.Parameter.UUID.getValue(), "test-uuid"); + inputs.put(DatasetDownloadEnums.Parameter.KEY.getValue(), "*"); + inputs.put(DatasetDownloadEnums.Parameter.OUTPUT_FORMAT.getValue(), "netcdf"); + + MockHttpServletResponse response = postSse(ProcessIdEnum.DOWNLOAD_CO_ESTIMATE.getValue(), inputs); + awaitContent(response, "event:estimate-complete"); + + verify(dasService).estimateCloudOptimisedDownloadSize( + eq("test-uuid"), isNull(), any(), any(), any(), isNull(), eq("netcdf")); + } + + @Test + public void testEstimateCODownloadDasFailureEmitsEstimateFailed() throws Exception { + when(dasService.estimateCloudOptimisedDownloadSize(any(), any(), any(), any(), any(), any(), any())) + .thenThrow(new RuntimeException("das returned 404")); + + Map inputs = new HashMap<>(); + inputs.put(DatasetDownloadEnums.Parameter.UUID.getValue(), "test-uuid"); + inputs.put(DatasetDownloadEnums.Parameter.OUTPUT_FORMAT.getValue(), "netcdf"); + + MockHttpServletResponse response = postSse(ProcessIdEnum.DOWNLOAD_CO_ESTIMATE.getValue(), inputs); + String content = awaitContent(response, "event:estimate-failed"); + + assertEventOrder(content, "connection-established", "estimate-failed"); + assertTrue(content.contains("das returned 404"), "Failure reason should be forwarded in: " + content); + } + + @Test + public void testEstimateCODownloadMissingUuidEmitsError() throws Exception { + Map inputs = new HashMap<>(); + inputs.put(DatasetDownloadEnums.Parameter.OUTPUT_FORMAT.getValue(), "netcdf"); + + MockHttpServletResponse response = postSse(ProcessIdEnum.DOWNLOAD_CO_ESTIMATE.getValue(), inputs); + String content = awaitContent(response, "event:error"); + + assertTrue(content.contains("event:error"), "Validation failure should emit error event in: " + content); + verifyNoInteractions(dasService); + } + + // ---------- estimateWfsDownload ---------- + + @Test + public void testEstimateWfsDownloadHappyPath() throws Exception { + when(downloadWfsDataService.estimateDownloadSize(any(), any(), any(), any(), any(), any(), any())) + .thenReturn(BigInteger.valueOf(98765)); + + Map inputs = new HashMap<>(); + inputs.put(DatasetDownloadEnums.Parameter.UUID.getValue(), "test-uuid"); + inputs.put(DatasetDownloadEnums.Parameter.LAYER_NAME.getValue(), "test-layer"); + inputs.put(DatasetDownloadEnums.Parameter.OUTPUT_FORMAT.getValue(), "text/csv"); + + MockHttpServletResponse response = postSse(ProcessIdEnum.DOWNLOAD_WFS_ESTIMATE.getValue(), inputs); + String content = awaitContent(response, "event:estimate-complete"); + + assertEventOrder(content, "connection-established", "estimate-complete"); + assertTrue(content.contains("98765"), "Estimated size should be in the payload: " + content); + verify(downloadWfsDataService).estimateDownloadSize( + eq("test-uuid"), eq("test-layer"), any(), any(), any(), any(), eq("text/csv")); + verify(downloadWfsDataService, never()).streamWfsDataWithSse(any(), any(), any(), any(), any(), any()); + } + + @Test + public void testEstimateWfsDownloadServiceFailureEmitsEstimateFailed() throws Exception { + when(downloadWfsDataService.estimateDownloadSize(any(), any(), any(), any(), any(), any(), any())) + .thenThrow(new RuntimeException("geoserver down")); + + Map inputs = new HashMap<>(); + inputs.put(DatasetDownloadEnums.Parameter.UUID.getValue(), "test-uuid"); + inputs.put(DatasetDownloadEnums.Parameter.LAYER_NAME.getValue(), "test-layer"); + inputs.put(DatasetDownloadEnums.Parameter.OUTPUT_FORMAT.getValue(), "text/csv"); + + MockHttpServletResponse response = postSse(ProcessIdEnum.DOWNLOAD_WFS_ESTIMATE.getValue(), inputs); + String content = awaitContent(response, "event:estimate-failed"); + + assertEventOrder(content, "connection-established", "estimate-failed"); + } + + @Test + public void testEstimateWfsDownloadUnknownOutputFormatEmitsError() throws Exception { + Map inputs = new HashMap<>(); + inputs.put(DatasetDownloadEnums.Parameter.UUID.getValue(), "test-uuid"); + inputs.put(DatasetDownloadEnums.Parameter.LAYER_NAME.getValue(), "test-layer"); + inputs.put(DatasetDownloadEnums.Parameter.OUTPUT_FORMAT.getValue(), "bogus-format"); + + MockHttpServletResponse response = postSse(ProcessIdEnum.DOWNLOAD_WFS_ESTIMATE.getValue(), inputs); + String content = awaitContent(response, "event:error"); + + assertTrue(content.contains("event:error"), "Unknown output format should emit error event in: " + content); + verifyNoInteractions(downloadWfsDataService); + } + + // ---------- downloadWfs ---------- + + @Test + public void testDownloadWfsRoutesToStreaming() throws Exception { + when(downloadWfsDataService.prepareWfsRequestUrl(any(), any(), any(), any(), any(), any(), any(), anyLong(), anyBoolean())) + .thenReturn("http://geoserver/wfs?request=GetFeature"); + // Complete the stream when the (mocked) streaming call is reached, ending the SSE. + doAnswer(invocation -> { + ((SseEmitter) invocation.getArgument(4)).complete(); + return null; + }).when(downloadWfsDataService).streamWfsDataWithSse(any(), any(), any(), any(), any(), any()); + + Map inputs = new HashMap<>(); + inputs.put(DatasetDownloadEnums.Parameter.UUID.getValue(), "test-uuid"); + inputs.put(DatasetDownloadEnums.Parameter.LAYER_NAME.getValue(), "test-layer"); + inputs.put(DatasetDownloadEnums.Parameter.OUTPUT_FORMAT.getValue(), "text/csv"); + + MockHttpServletResponse response = postSse(ProcessIdEnum.DOWNLOAD_WFS_SSE.getValue(), inputs); + String content = awaitContent(response, "event:connection-established"); + + assertTrue(content.contains("event:connection-established"), "Missing connection event in: " + content); + verify(downloadWfsDataService, timeout(5000)).streamWfsDataWithSse( + eq("http://geoserver/wfs?request=GetFeature"), eq("test-uuid"), eq("test-layer"), + eq("text/csv"), any(SseEmitter.class), any()); + verify(downloadWfsDataService, never()).estimateDownloadSize(any(), any(), any(), any(), any(), any(), any()); + } + + // ---------- dispatch ---------- + + @Test + public void testUnknownProcessIdHitsDefaultBranch() throws Exception { + Map inputs = new HashMap<>(); + inputs.put(DatasetDownloadEnums.Parameter.UUID.getValue(), "test-uuid"); + inputs.put(DatasetDownloadEnums.Parameter.LAYER_NAME.getValue(), "test-layer"); + inputs.put(DatasetDownloadEnums.Parameter.OUTPUT_FORMAT.getValue(), "text/csv"); + + postSse("no-such-process", inputs); + + verifyNoInteractions(downloadWfsDataService, dasService); + } +}