diff --git a/server/src/main/java/au/org/aodn/ogcapi/server/core/service/geoserver/wfs/DownloadWfsDataService.java b/server/src/main/java/au/org/aodn/ogcapi/server/core/service/geoserver/wfs/DownloadWfsDataService.java index d3793c65..08485203 100644 --- a/server/src/main/java/au/org/aodn/ogcapi/server/core/service/geoserver/wfs/DownloadWfsDataService.java +++ b/server/src/main/java/au/org/aodn/ogcapi/server/core/service/geoserver/wfs/DownloadWfsDataService.java @@ -6,7 +6,9 @@ import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.cache.annotation.Cacheable; +import org.springframework.context.annotation.Lazy; import org.springframework.http.*; import org.springframework.web.client.RestTemplate; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; @@ -27,6 +29,10 @@ public class DownloadWfsDataService { protected final ObjectMapper objectMapper; protected static final int SAMPLES_SIZE = 500; // A not too small sample for download size estimation + @Autowired + @Lazy + protected DownloadWfsDataService self; + public DownloadWfsDataService( WfsServer wfsServer, RestTemplate restTemplate, @@ -88,13 +94,61 @@ public String prepareWfsRequestUrl( } /** - * We just need to estimate the download size, the way we do it is issue two query: - * a. Issue a query and get the number or record hit - * b. Issue a query with data download but then limit the records size, and do a liner interpolation + * Unfiltered total feature count for a layer + * Cached per (uuid, layerName) + */ + @Cacheable(CacheConfig.DOWNLOADABLE_SIZE) + public BigInteger getUnfilteredRecordCount(String uuid, String layerName) { + String countUrl = prepareWfsRequestUrl( + uuid, null, null, null, null, layerName, "application/json", 1L, false + ); + + ResponseEntity response = restTemplate.exchange(countUrl, HttpMethod.GET, pretendUserEntity, String.class); + if (response.getStatusCode().is2xxSuccessful() && response.getBody() != null) { + try { + JsonNode root = objectMapper.readTree(response.getBody()); + if (!root.has("totalFeatures")) { + throw new RuntimeException("GeoServer GeoJSON response missing totalFeatures field"); + } + return BigInteger.valueOf(root.get("totalFeatures").asLong()); + } catch (IOException e) { + log.error("Failed to parse unfiltered count response for {}/{}", uuid, layerName, e); + } + } + return null; + } + + /** + * Average bytes per record for the layer in the requested output format. + * Issues an unfiltered sample download so the result can be reused across calls + * Cached per (uuid, layerName, outputFormat). + */ + @Cacheable(CacheConfig.DOWNLOADABLE_SIZE) + public BigInteger getBytesPerRecord(String uuid, String layerName, String outputFormat) { + BigInteger totalCount = self.getUnfilteredRecordCount(uuid, layerName); + if (totalCount == null || totalCount.equals(BigInteger.ZERO)) { + return BigInteger.ZERO; + } + + long sampleSize = totalCount.longValue() < SAMPLES_SIZE ? totalCount.longValue() : SAMPLES_SIZE; + + String sampleUrl = prepareWfsRequestUrl( + uuid, null, null, null, null, layerName, outputFormat, sampleSize, false + ); + + ResponseEntity bytes = restTemplate.exchange(sampleUrl, HttpMethod.GET, pretendUserEntity, byte[].class); + if (bytes.getStatusCode().is2xxSuccessful() && bytes.getBody() != null) { + return BigInteger.valueOf(bytes.getBody().length).divide(BigInteger.valueOf(sampleSize)); + } + return null; + } + + /** + * Estimate download size for the user's subset. Runs the inherently subset-dependent + * count query, then multiplies by the cached bytes-per-record sample. * * @return The estimated file size */ - @Cacheable(CacheConfig.DOWNLOADABLE_SIZE) public BigInteger estimateDownloadSize( String uuid, String layerName, @@ -104,7 +158,7 @@ public BigInteger estimateDownloadSize( List fields, String outputFormat) throws IllegalArgumentException { - // Get total feature count via GeoJSON response + // Subset-filtered count — not cacheable here because the subset would explode the key space. String countUrl = prepareWfsRequestUrl( uuid, startDate, endDate, multiPolygon, fields, layerName, "application/json", 1L, false ); @@ -118,26 +172,17 @@ public BigInteger estimateDownloadSize( throw new RuntimeException("GeoServer GeoJSON response missing totalFeatures field"); } BigInteger featureCount = BigInteger.valueOf(root.get("totalFeatures").asLong()); - log.debug("Total record hits {}", featureCount); + log.debug("Subset record hits {}", featureCount); if (featureCount.equals(BigInteger.ZERO)) { return BigInteger.ZERO; } - // In case the records we have is smaller than our predefined SAMPLES_SIZE, we use smaller one. - long sampleSize = featureCount.longValue() < SAMPLES_SIZE ? featureCount.longValue() : SAMPLES_SIZE; - - // Download a small sample to measure bytes per record in the requested output format - String sampleUrl = prepareWfsRequestUrl( - uuid, startDate, endDate, multiPolygon, fields, layerName, outputFormat, sampleSize, false - ); - - ResponseEntity bytes = restTemplate.exchange(sampleUrl, HttpMethod.GET, pretendUserEntity, byte[].class); - if (bytes.getStatusCode().is2xxSuccessful() && bytes.getBody() != null) { - return featureCount - .multiply(BigInteger.valueOf(bytes.getBody().length)) - .divide(BigInteger.valueOf(sampleSize)); + BigInteger bytesPerRecord = self.getBytesPerRecord(uuid, layerName, outputFormat); + if (bytesPerRecord == null) { + return null; } + return featureCount.multiply(bytesPerRecord); } catch (IOException e) { log.error("Fail to get feature count for estimate", e); } diff --git a/server/src/test/java/au/org/aodn/ogcapi/server/core/service/geoserver/wfs/DownloadWfsDataServiceTest.java b/server/src/test/java/au/org/aodn/ogcapi/server/core/service/geoserver/wfs/DownloadWfsDataServiceTest.java index ba6b2d21..505dfe6a 100644 --- a/server/src/test/java/au/org/aodn/ogcapi/server/core/service/geoserver/wfs/DownloadWfsDataServiceTest.java +++ b/server/src/test/java/au/org/aodn/ogcapi/server/core/service/geoserver/wfs/DownloadWfsDataServiceTest.java @@ -72,6 +72,8 @@ public void setUp() { downloadWfsDataService = new DownloadWfsDataService( wfsServer, restTemplate, pretendUserEntity, 16384, new ObjectMapper() ); + + downloadWfsDataService.self = downloadWfsDataService; } /** @@ -268,6 +270,7 @@ public void verifyRequestUrlGenerateCorrect() { ); assertEquals("https://test.com/geoserver/wfs?VERSION=1.0.0&typeName=test:layer&SERVICE=WFS&REQUEST=GetFeature&outputFormat=shape-zip&cql_filter=((timestamp DURING 2024-01-01T00:00:00Z/2024-12-31T23:59:59Z))", result, "Correct url 1"); } + /** * Make sure the url generated contains the correct polygon * @@ -302,6 +305,7 @@ public void verifyRequestUrlGenerateCorrectWithPolygon() throws JsonProcessingEx result, "Correct url 1"); } + /** * Verify estimate size on success request */ @@ -315,25 +319,28 @@ void shouldReturnEstimatedSizeWhenBothRequestsSucceed() { List fields = List.of("name", "area"); String format = "application/json"; - // 1. Count response: GeoJSON with totalFeatures (1 record requested, but totalFeatures = full count) + // 1. Count response: GeoJSON with totalFeatures (1 record requested, but totalFeatures = full count). + // Returned for BOTH the subset-filtered count and the unfiltered count probe inside + // getBytesPerRecord — both URLs use maxFeatures=1. String countJson = "{\"totalFeatures\": 227193, \"features\": []}"; ResponseEntity countResponse = new ResponseEntity<>(countJson, HttpStatus.OK); - // 2. Sample response (small payload in requested format) - byte[] sampleBytes = "fake data".getBytes(); + // 2. Sample response. Use a payload >= SAMPLES_SIZE so bytesPerRecord = sampleBytes / 500 + // yields a non-zero integer (10 bytes/record here). + byte[] sampleBytes = new byte[DownloadWfsDataService.SAMPLES_SIZE * 10]; ResponseEntity sampleResponse = new ResponseEntity<>(sampleBytes, HttpStatus.OK); doReturn(countResponse) .when(restTemplate).exchange( - argThat((String url) -> url != null && url.contains("maxFeatures=1")), - eq(HttpMethod.GET), - any(HttpEntity.class), - eq(String.class)); + argThat((String url) -> url != null && url.contains("maxFeatures=1")), + eq(HttpMethod.GET), + any(HttpEntity.class), + eq(String.class)); doReturn(sampleResponse) .when(restTemplate).exchange( - argThat((String url) -> url != null && url.contains("maxFeatures=" + DownloadWfsDataService.SAMPLES_SIZE)), - eq(HttpMethod.GET), any(), eq(byte[].class)); + argThat((String url) -> url != null && url.contains("maxFeatures=" + DownloadWfsDataService.SAMPLES_SIZE)), + eq(HttpMethod.GET), any(), eq(byte[].class)); doReturn(Optional.of("http://dummy.com/wfs")) .when(wfsServer).getFeatureServerUrl(eq(uuid), anyString()); @@ -350,26 +357,45 @@ void shouldReturnEstimatedSizeWhenBothRequestsSucceed() { BigInteger size = downloadWfsDataService.estimateDownloadSize( uuid, layer, start, end, multiPolygon, fields, format); - // Should call with maxFeatures=1 to get totalFeatures count via GeoJSON + // Subset-filtered count (carries the cql_filter built from start/end dates) verify(restTemplate).exchange( - argThat((String url) -> url != null && url.contains("maxFeatures=1") && url.contains("outputFormat=application")), + argThat((String url) -> url != null + && url.contains("maxFeatures=1") + && url.contains("outputFormat=application") + && url.contains("cql_filter")), eq(HttpMethod.GET), any(), eq(String.class) ); - // Should also call with maxFeatures=500 to sample bytes for size interpolation + // Unfiltered count probe issued inside getBytesPerRecord — same maxFeatures=1 + // pattern but without cql_filter. Acceptance criterion: sample/count path ignores subsetting. verify(restTemplate).exchange( - argThat((String url) -> url != null && url.contains("maxFeatures=" + DownloadWfsDataService.SAMPLES_SIZE)), + argThat((String url) -> url != null + && url.contains("maxFeatures=1") + && url.contains("outputFormat=application") + && !url.contains("cql_filter")), + eq(HttpMethod.GET), + any(), + eq(String.class) + ); + + // Sample download with maxFeatures=500, also without subset params. + verify(restTemplate).exchange( + argThat((String url) -> url != null + && url.contains("maxFeatures=" + DownloadWfsDataService.SAMPLES_SIZE) + && !url.contains("cql_filter")), eq(HttpMethod.GET), any(), eq(byte[].class) ); - // totalFeatures=227193, sampleBytes=9 bytes, SAMPLES_SIZE=500 - long expected = 227193L * sampleBytes.length / DownloadWfsDataService.SAMPLES_SIZE; + // bytesPerRecord = sampleBytes.length / SAMPLES_SIZE; total = featureCount * bytesPerRecord + long bytesPerRecord = sampleBytes.length / DownloadWfsDataService.SAMPLES_SIZE; + long expected = 227193L * bytesPerRecord; assertEquals(BigInteger.valueOf(expected), size, "Size match"); } + @Test void shouldReturnZeroWhenTotalFeaturesIsZero() { String uuid = "lyr-123"; @@ -385,10 +411,10 @@ void shouldReturnZeroWhenTotalFeaturesIsZero() { doReturn(countResponse) .when(restTemplate).exchange( - argThat((String url) -> url != null && url.contains("maxFeatures=1")), - eq(HttpMethod.GET), - any(HttpEntity.class), - eq(String.class)); + argThat((String url) -> url != null && url.contains("maxFeatures=1")), + eq(HttpMethod.GET), + any(HttpEntity.class), + eq(String.class)); doReturn(Optional.of("http://dummy.com/wfs")) .when(wfsServer).getFeatureServerUrl(eq(uuid), anyString()); @@ -490,8 +516,53 @@ void returnsNullWhenParserThrowsException() { .when(wfsServer).getDownloadableFields(eq(uuid), any(WfsServer.WfsFeatureRequest.class)); BigInteger size = downloadWfsDataService.estimateDownloadSize( - uuid, layer, start, end, multiPolygon, fields, format); + uuid, layer, start, end, multiPolygon, fields, format); assertNull(size, "Size should be null when JSON parsing fails"); } + + @Test + void sampleRequestIgnoresSubsetFilter() throws JsonProcessingException { + String uuid = "lyr-123"; + String layer = "test:layer"; + String start = "2024-01-01"; + String end = "2024-12-31"; + Object multiPolygon = new ObjectMapper().readValue( + "{ \"type\": \"MultiPolygon\", \"coordinates\": [[[[0,0],[1,0],[1,1],[0,1],[0,0]]]] }", + HashMap.class + ); + List fields = List.of("name", "area"); + String format = "text/csv"; + + String countJson = "{\"totalFeatures\": 1000, \"features\": []}"; + ResponseEntity countResponse = new ResponseEntity<>(countJson, HttpStatus.OK); + byte[] sampleBytes = new byte[DownloadWfsDataService.SAMPLES_SIZE * 4]; + ResponseEntity sampleResponse = new ResponseEntity<>(sampleBytes, HttpStatus.OK); + + doReturn(countResponse) + .when(restTemplate).exchange( + argThat((String url) -> url != null && url.contains("maxFeatures=1")), + eq(HttpMethod.GET), any(HttpEntity.class), eq(String.class)); + doReturn(sampleResponse) + .when(restTemplate).exchange( + argThat((String url) -> url != null + && url.contains("maxFeatures=" + DownloadWfsDataService.SAMPLES_SIZE)), + eq(HttpMethod.GET), any(), eq(byte[].class)); + + doReturn(Optional.of("http://dummy.com/wfs")) + .when(wfsServer).getFeatureServerUrl(eq(uuid), anyString()); + doReturn(createTestWFSFieldModel()) + .when(wfsServer).getDownloadableFields(eq(uuid), any(WfsServer.WfsFeatureRequest.class)); + + downloadWfsDataService.estimateDownloadSize(uuid, layer, start, end, multiPolygon, fields, format); + + // Sample URL must NOT carry the subset filter (no cql_filter, no DURING, no INTERSECTS). + verify(restTemplate).exchange( + argThat((String url) -> url != null + && url.contains("maxFeatures=" + DownloadWfsDataService.SAMPLES_SIZE) + && !url.contains("cql_filter") + && !url.contains("DURING") + && !url.contains("INTERSECTS")), + eq(HttpMethod.GET), any(), eq(byte[].class)); + } }