From 6f8eb5bda10d3346789faab7531e29cafd00a87c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jesse=20Tu=C4=9Flu?= Date: Sun, 3 May 2026 22:52:08 -0700 Subject: [PATCH] fix: ensure query timeout is respected at router --- .../server/AsyncQueryForwardingServlet.java | 50 +++++++- .../AsyncQueryForwardingServletTest.java | 112 ++++++++++++++++++ 2 files changed, 159 insertions(+), 3 deletions(-) diff --git a/services/src/main/java/org/apache/druid/server/AsyncQueryForwardingServlet.java b/services/src/main/java/org/apache/druid/server/AsyncQueryForwardingServlet.java index 4054e6249b1e..866960d5083d 100644 --- a/services/src/main/java/org/apache/druid/server/AsyncQueryForwardingServlet.java +++ b/services/src/main/java/org/apache/druid/server/AsyncQueryForwardingServlet.java @@ -44,6 +44,7 @@ import org.apache.druid.query.DruidMetrics; import org.apache.druid.query.GenericQueryMetricsFactory; import org.apache.druid.query.Query; +import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryInterruptedException; import org.apache.druid.query.QueryMetrics; import org.apache.druid.query.QueryToolChestWarehouse; @@ -435,9 +436,6 @@ protected void sendProxyRequest( Request proxyRequest ) { - proxyRequest.timeout(httpClientConfig.getReadTimeout().getMillis(), TimeUnit.MILLISECONDS); - proxyRequest.idleTimeout(httpClientConfig.getReadTimeout().getMillis(), TimeUnit.MILLISECONDS); - byte[] avaticaQuery = (byte[]) clientRequest.getAttribute(AVATICA_QUERY_ATTRIBUTE); if (avaticaQuery != null) { proxyRequest.body(new BytesRequestContent(avaticaQuery)); @@ -451,6 +449,10 @@ protected void sendProxyRequest( setProxyRequestContent(proxyRequest, clientRequest, sqlQuery); } + final long proxyTimeoutMillis = resolveProxyTimeoutMillis(query, sqlQuery); + proxyRequest.timeout(proxyTimeoutMillis, TimeUnit.MILLISECONDS); + proxyRequest.idleTimeout(proxyTimeoutMillis, TimeUnit.MILLISECONDS); + // Since we can't see the request object on the remote side, we can't check whether the remote side actually // performed an authorization check here, so always set this to true for the proxy servlet. // If the remote node failed to perform an authorization check, PreResponseAuthorizationCheckFilter @@ -480,6 +482,48 @@ protected void sendProxyRequest( ); } + /** + * Resolves the proxy request timeout as min(query timeout, druid.router.http.readTimeout). + * Falls back to readTimeout when no per-query timeout is set or the value is unusable + * (Avatica requests, missing/invalid context value, or {@link QueryContexts#NO_TIMEOUT}). + */ + @VisibleForTesting + long resolveProxyTimeoutMillis(@Nullable Query query, @Nullable SqlQuery sqlQuery) + { + final long readTimeoutMillis = httpClientConfig.getReadTimeout().getMillis(); + final Long queryTimeoutMillis; + if (query != null) { + long t = query.context().getTimeout(QueryContexts.NO_TIMEOUT); + queryTimeoutMillis = t > 0 ? t : null; + } else if (sqlQuery != null) { + queryTimeoutMillis = extractSqlTimeoutMillis(sqlQuery); + } else { + queryTimeoutMillis = null; + } + return queryTimeoutMillis == null ? readTimeoutMillis : Math.min(queryTimeoutMillis, readTimeoutMillis); + } + + @Nullable + private static Long extractSqlTimeoutMillis(SqlQuery sqlQuery) + { + Object raw = sqlQuery.getContext().get(QueryContexts.TIMEOUT_KEY); + if (raw == null) { + return null; + } + try { + final long t; + if (raw instanceof Number) { + t = ((Number) raw).longValue(); + } else { + t = Long.parseLong(raw.toString()); + } + return t > 0 ? t : null; + } + catch (NumberFormatException ignored) { + return null; + } + } + private void setProxyRequestContent(Request proxyRequest, HttpServletRequest clientRequest, Object content) { final ObjectMapper objectMapper = (ObjectMapper) clientRequest.getAttribute(OBJECTMAPPER_ATTRIBUTE); diff --git a/services/src/test/java/org/apache/druid/server/AsyncQueryForwardingServletTest.java b/services/src/test/java/org/apache/druid/server/AsyncQueryForwardingServletTest.java index c3b50514e826..1fbcc8f6ea8d 100644 --- a/services/src/test/java/org/apache/druid/server/AsyncQueryForwardingServletTest.java +++ b/services/src/test/java/org/apache/druid/server/AsyncQueryForwardingServletTest.java @@ -1206,6 +1206,118 @@ private static Map asMap(String json, ObjectMapper mapper) throw return mapper.readValue(json, JacksonUtils.TYPE_REFERENCE_MAP_STRING_OBJECT); } + @Test + public void testResolveProxyTimeoutMillis() + { + final long readTimeoutMillis = 900_000L; + final DruidHttpClientConfig httpClientConfig = Mockito.mock(DruidHttpClientConfig.class); + Mockito.when(httpClientConfig.getReadTimeout()).thenReturn(org.joda.time.Duration.millis(readTimeoutMillis)); + + final AsyncQueryForwardingServlet servlet = new AsyncQueryForwardingServlet( + new MapQueryToolChestWarehouse(ImmutableMap.of()), + TestHelper.makeJsonMapper(), + TestHelper.makeSmileMapper(), + null, + null, + httpClientConfig, + NoopServiceEmitter.instance(), + NoopRequestLogger.instance(), + new DefaultGenericQueryMetricsFactory(), + new AuthenticatorMapper(ImmutableMap.of()), + new Properties(), + new ServerConfig() + ); + + // No query, no sqlQuery -> readTimeout + Assert.assertEquals(readTimeoutMillis, servlet.resolveProxyTimeoutMillis(null, null)); + + // Native query with shorter timeout -> query timeout wins + final TimeseriesQuery shortQuery = Druids.newTimeseriesQueryBuilder() + .dataSource("test") + .intervals("2000/3000") + .granularity(Granularities.ALL) + .context(ImmutableMap.of("timeout", 30_000)) + .build(); + Assert.assertEquals(30_000L, servlet.resolveProxyTimeoutMillis(shortQuery, null)); + + // Native query with longer timeout -> readTimeout wins + final TimeseriesQuery longQuery = Druids.newTimeseriesQueryBuilder() + .dataSource("test") + .intervals("2000/3000") + .granularity(Granularities.ALL) + .context(ImmutableMap.of("timeout", 1_800_000)) + .build(); + Assert.assertEquals(readTimeoutMillis, servlet.resolveProxyTimeoutMillis(longQuery, null)); + + // Native query with no timeout context -> readTimeout + final TimeseriesQuery noTimeoutQuery = Druids.newTimeseriesQueryBuilder() + .dataSource("test") + .intervals("2000/3000") + .granularity(Granularities.ALL) + .build(); + Assert.assertEquals(readTimeoutMillis, servlet.resolveProxyTimeoutMillis(noTimeoutQuery, null)); + + // SQL query with shorter timeout (Number) -> query timeout wins + final SqlQuery shortSql = new SqlQuery( + "SELECT 1", + ResultFormat.OBJECT, + false, + false, + false, + ImmutableMap.of("timeout", 45_000), + null + ); + Assert.assertEquals(45_000L, servlet.resolveProxyTimeoutMillis(null, shortSql)); + + // SQL query with timeout as String -> parsed + final SqlQuery stringSql = new SqlQuery( + "SELECT 1", + ResultFormat.OBJECT, + false, + false, + false, + ImmutableMap.of("timeout", "60000"), + null + ); + Assert.assertEquals(60_000L, servlet.resolveProxyTimeoutMillis(null, stringSql)); + + // SQL query with longer timeout -> readTimeout wins + final SqlQuery longSql = new SqlQuery( + "SELECT 1", + ResultFormat.OBJECT, + false, + false, + false, + ImmutableMap.of("timeout", 1_800_000), + null + ); + Assert.assertEquals(readTimeoutMillis, servlet.resolveProxyTimeoutMillis(null, longSql)); + + // SQL query with invalid timeout -> readTimeout + final SqlQuery invalidSql = new SqlQuery( + "SELECT 1", + ResultFormat.OBJECT, + false, + false, + false, + ImmutableMap.of("timeout", "not-a-number"), + null + ); + Assert.assertEquals(readTimeoutMillis, servlet.resolveProxyTimeoutMillis(null, invalidSql)); + + // SQL query with no context -> readTimeout + final SqlQuery noCtxSql = new SqlQuery( + "SELECT 1", + ResultFormat.OBJECT, + false, + false, + false, + ImmutableMap.of(), + null + ); + Assert.assertEquals(readTimeoutMillis, servlet.resolveProxyTimeoutMillis(null, noCtxSql)); + } + private static class TestServer implements org.apache.druid.client.selector.Server {