From 648faf146542e7d067d9531976b883397c325f61 Mon Sep 17 00:00:00 2001 From: Jesse Tuglu Date: Tue, 21 Jan 2025 22:04:10 -0800 Subject: [PATCH] Avoid OOM-killing query if large result-level cache population fails for query Currently, result-level caching which attempts to allocate a large enough buffer to store query results will overflow the Integer.MAX_INT capacity. ByteArrayOutputStream materializes this case as an OutOfMemoryError, which is not caught and terminates the node. This limits the allocated buffer for storing query results to whatever is set in `CacheConfig.getResultLevelCacheLimit()`. --- .../apache/druid/io/LimitedOutputStream.java | 7 +- .../query/ResultLevelCachingQueryRunner.java | 25 ++++--- .../ResultLevelCachingQueryRunnerTest.java | 66 ++++++++++++++++--- 3 files changed, 78 insertions(+), 20 deletions(-) diff --git a/processing/src/main/java/org/apache/druid/io/LimitedOutputStream.java b/processing/src/main/java/org/apache/druid/io/LimitedOutputStream.java index 6d27abb42739..fd5691c1bb3c 100644 --- a/processing/src/main/java/org/apache/druid/io/LimitedOutputStream.java +++ b/processing/src/main/java/org/apache/druid/io/LimitedOutputStream.java @@ -28,7 +28,7 @@ /** * An {@link OutputStream} that limits how many bytes can be written. Throws {@link IOException} if the limit - * is exceeded. + * is exceeded. *Not* thread-safe. */ public class LimitedOutputStream extends OutputStream { @@ -88,6 +88,11 @@ public void close() throws IOException out.close(); } + public OutputStream get() + { + return out; + } + private void plus(final int n) throws IOException { written += n; diff --git a/server/src/main/java/org/apache/druid/query/ResultLevelCachingQueryRunner.java b/server/src/main/java/org/apache/druid/query/ResultLevelCachingQueryRunner.java index dedfb0028b77..aa531f93c922 100644 --- a/server/src/main/java/org/apache/druid/query/ResultLevelCachingQueryRunner.java +++ b/server/src/main/java/org/apache/druid/query/ResultLevelCachingQueryRunner.java @@ -30,6 +30,7 @@ import org.apache.druid.client.cache.Cache; import org.apache.druid.client.cache.Cache.NamedKey; import org.apache.druid.client.cache.CacheConfig; +import org.apache.druid.io.LimitedOutputStream; import org.apache.druid.java.util.common.RE; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.guava.Sequence; @@ -152,6 +153,8 @@ public void after(boolean isDone, Throwable thrown) // The resultset identifier and its length is cached along with the resultset resultLevelCachePopulator.populateResults(); log.debug("Cache population complete for query %s", query.getId()); + } else { // thrown == null && !resultLevelCachePopulator.isShouldPopulate() + log.error("Failed (gracefully) to populate result level cache for query %s", query.getId()); } resultLevelCachePopulator.stopPopulating(); } @@ -233,8 +236,8 @@ private ResultLevelCachePopulator createResultLevelCachePopulator( try { // Save the resultSetId and its length resultLevelCachePopulator.cacheObjectStream.write(ByteBuffer.allocate(Integer.BYTES) - .putInt(resultSetId.length()) - .array()); + .putInt(resultSetId.length()) + .array()); resultLevelCachePopulator.cacheObjectStream.write(StringUtils.toUtf8(resultSetId)); } catch (IOException ioe) { @@ -255,7 +258,7 @@ private class ResultLevelCachePopulator private final Cache.NamedKey key; private final CacheConfig cacheConfig; @Nullable - private ByteArrayOutputStream cacheObjectStream; + private LimitedOutputStream cacheObjectStream; private ResultLevelCachePopulator( Cache cache, @@ -270,7 +273,14 @@ private ResultLevelCachePopulator( this.serialiers = mapper.getSerializerProviderInstance(); this.key = key; this.cacheConfig = cacheConfig; - this.cacheObjectStream = shouldPopulate ? new ByteArrayOutputStream() : null; + this.cacheObjectStream = shouldPopulate ? new LimitedOutputStream( + new ByteArrayOutputStream(), + cacheConfig.getResultLevelCacheLimit(), limit -> StringUtils.format( + "resultLevelCacheLimit[%,d] exceeded. " + + "Max ResultLevelCacheLimit for cache exceeded. Result caching failed.", + limit + ) + ) : null; } boolean isShouldPopulate() @@ -289,12 +299,8 @@ private void cacheResultEntry( ) { Preconditions.checkNotNull(cacheObjectStream, "cacheObjectStream"); - int cacheLimit = cacheConfig.getResultLevelCacheLimit(); try (JsonGenerator gen = mapper.getFactory().createGenerator(cacheObjectStream)) { JacksonUtils.writeObjectUsingSerializerProvider(gen, serialiers, cacheFn.apply(resultEntry)); - if (cacheLimit > 0 && cacheObjectStream.size() > cacheLimit) { - stopPopulating(); - } } catch (IOException ex) { log.error(ex, "Failed to retrieve entry to be cached. Result Level caching will not be performed!"); @@ -304,7 +310,8 @@ private void cacheResultEntry( public void populateResults() { - byte[] cachedResults = Preconditions.checkNotNull(cacheObjectStream, "cacheObjectStream").toByteArray(); + byte[] cachedResults = ((ByteArrayOutputStream) Preconditions.checkNotNull(cacheObjectStream, "cacheObjectStream") + .get()).toByteArray(); cache.put(key, cachedResults); } } diff --git a/server/src/test/java/org/apache/druid/query/ResultLevelCachingQueryRunnerTest.java b/server/src/test/java/org/apache/druid/query/ResultLevelCachingQueryRunnerTest.java index 6245509465c1..3cb4ae528e67 100644 --- a/server/src/test/java/org/apache/druid/query/ResultLevelCachingQueryRunnerTest.java +++ b/server/src/test/java/org/apache/druid/query/ResultLevelCachingQueryRunnerTest.java @@ -39,6 +39,7 @@ public class ResultLevelCachingQueryRunnerTest extends QueryRunnerBasedOnClusteredClientTestBase { private Cache cache; + private static final int DEFAULT_CACHE_ENTRY_MAX_SIZE = Integer.MAX_VALUE; @Before public void setup() @@ -58,7 +59,7 @@ public void testNotPopulateAndNotUse() prepareCluster(10); final Query> query = timeseriesQuery(BASE_SCHEMA_INFO.getDataInterval()); final ResultLevelCachingQueryRunner> queryRunner1 = createQueryRunner( - newCacheConfig(false, false), + newCacheConfig(false, false, DEFAULT_CACHE_ENTRY_MAX_SIZE), query ); @@ -72,7 +73,7 @@ public void testNotPopulateAndNotUse() Assert.assertEquals(0, cache.getStats().getNumMisses()); final ResultLevelCachingQueryRunner> queryRunner2 = createQueryRunner( - newCacheConfig(false, false), + newCacheConfig(false, false, DEFAULT_CACHE_ENTRY_MAX_SIZE), query ); @@ -93,7 +94,7 @@ public void testPopulateAndNotUse() prepareCluster(10); final Query> query = timeseriesQuery(BASE_SCHEMA_INFO.getDataInterval()); final ResultLevelCachingQueryRunner> queryRunner1 = createQueryRunner( - newCacheConfig(true, false), + newCacheConfig(true, false, DEFAULT_CACHE_ENTRY_MAX_SIZE), query ); @@ -107,7 +108,7 @@ public void testPopulateAndNotUse() Assert.assertEquals(0, cache.getStats().getNumMisses()); final ResultLevelCachingQueryRunner> queryRunner2 = createQueryRunner( - newCacheConfig(true, false), + newCacheConfig(true, false, DEFAULT_CACHE_ENTRY_MAX_SIZE), query ); @@ -128,7 +129,7 @@ public void testNotPopulateAndUse() prepareCluster(10); final Query> query = timeseriesQuery(BASE_SCHEMA_INFO.getDataInterval()); final ResultLevelCachingQueryRunner> queryRunner1 = createQueryRunner( - newCacheConfig(false, false), + newCacheConfig(false, false, DEFAULT_CACHE_ENTRY_MAX_SIZE), query ); @@ -142,7 +143,7 @@ public void testNotPopulateAndUse() Assert.assertEquals(0, cache.getStats().getNumMisses()); final ResultLevelCachingQueryRunner> queryRunner2 = createQueryRunner( - newCacheConfig(false, true), + newCacheConfig(false, true, DEFAULT_CACHE_ENTRY_MAX_SIZE), query ); @@ -163,7 +164,7 @@ public void testPopulateAndUse() prepareCluster(10); final Query> query = timeseriesQuery(BASE_SCHEMA_INFO.getDataInterval()); final ResultLevelCachingQueryRunner> queryRunner1 = createQueryRunner( - newCacheConfig(true, true), + newCacheConfig(true, true, DEFAULT_CACHE_ENTRY_MAX_SIZE), query ); @@ -177,7 +178,7 @@ public void testPopulateAndUse() Assert.assertEquals(1, cache.getStats().getNumMisses()); final ResultLevelCachingQueryRunner> queryRunner2 = createQueryRunner( - newCacheConfig(true, true), + newCacheConfig(true, true, DEFAULT_CACHE_ENTRY_MAX_SIZE), query ); @@ -192,6 +193,41 @@ public void testPopulateAndUse() Assert.assertEquals(1, cache.getStats().getNumMisses()); } + @Test + public void testNoPopulateIfEntrySizeExceedsMaximum() + { + prepareCluster(10); + final Query> query = timeseriesQuery(BASE_SCHEMA_INFO.getDataInterval()); + final ResultLevelCachingQueryRunner> queryRunner1 = createQueryRunner( + newCacheConfig(true, true, 128), + query + ); + + final Sequence> sequence1 = queryRunner1.run( + QueryPlus.wrap(query), + responseContext() + ); + final List> results1 = sequence1.toList(); + Assert.assertEquals(0, cache.getStats().getNumHits()); + Assert.assertEquals(0, cache.getStats().getNumEntries()); + Assert.assertEquals(1, cache.getStats().getNumMisses()); + + final ResultLevelCachingQueryRunner> queryRunner2 = createQueryRunner( + newCacheConfig(true, true, DEFAULT_CACHE_ENTRY_MAX_SIZE), + query + ); + + final Sequence> sequence2 = queryRunner2.run( + QueryPlus.wrap(query), + responseContext() + ); + final List> results2 = sequence2.toList(); + Assert.assertEquals(results1, results2); + Assert.assertEquals(0, cache.getStats().getNumHits()); + Assert.assertEquals(1, cache.getStats().getNumEntries()); + Assert.assertEquals(2, cache.getStats().getNumMisses()); + } + @Test public void testPopulateCacheWhenQueryThrowExceptionShouldNotCache() { @@ -206,7 +242,7 @@ public void testPopulateCacheWhenQueryThrowExceptionShouldNotCache() final Query> query = timeseriesQuery(BASE_SCHEMA_INFO.getDataInterval()); final ResultLevelCachingQueryRunner> queryRunner = createQueryRunner( - newCacheConfig(true, false), + newCacheConfig(true, false, DEFAULT_CACHE_ENTRY_MAX_SIZE), query ); @@ -249,7 +285,11 @@ private ResultLevelCachingQueryRunner createQueryRunner( ); } - private CacheConfig newCacheConfig(boolean populateResultLevelCache, boolean useResultLevelCache) + private CacheConfig newCacheConfig( + boolean populateResultLevelCache, + boolean useResultLevelCache, + int resultLevelCacheLimit + ) { return new CacheConfig() { @@ -264,6 +304,12 @@ public boolean isUseResultLevelCache() { return useResultLevelCache; } + + @Override + public int getResultLevelCacheLimit() + { + return resultLevelCacheLimit; + } }; } }