From 56dfb127cc6774f893108b8bba2e987b38c2a49e Mon Sep 17 00:00:00 2001
From: Jesse Tuglu
Date: Tue, 21 Jan 2025 22:04:10 -0800
Subject: [PATCH] Avoid OOM-killing the node when result-level cache
 population fails for a large query result

Currently, result-level cache population may try to buffer a result set
larger than the Integer.MAX_VALUE bytes a ByteArrayOutputStream can hold.
ByteArrayOutputStream surfaces this case as an OutOfMemoryError, which is
not caught and terminates the node.

This change caps the buffer used to store query results at
`CacheConfig.getResultLevelCacheLimit()`, so an oversized cache entry
aborts cache population instead of killing the process.
---
 .../apache/druid/io/LimitedOutputStream.java  | 14 ++--
 .../query/ResultLevelCachingQueryRunner.java  | 25 ++++---
 .../ResultLevelCachingQueryRunnerTest.java    | 66 ++++++++++++++++---
 3 files changed, 82 insertions(+), 23 deletions(-)

diff --git a/processing/src/main/java/org/apache/druid/io/LimitedOutputStream.java b/processing/src/main/java/org/apache/druid/io/LimitedOutputStream.java
index 6d27abb427394..f2d7112a68d7b 100644
--- a/processing/src/main/java/org/apache/druid/io/LimitedOutputStream.java
+++ b/processing/src/main/java/org/apache/druid/io/LimitedOutputStream.java
@@ -24,18 +24,19 @@
 
 import java.io.IOException;
 import java.io.OutputStream;
+import java.util.concurrent.atomic.AtomicLong;
 import java.util.function.Function;
 
 /**
  * An {@link OutputStream} that limits how many bytes can be written. Throws {@link IOException} if the limit
- * is exceeded.
+ * is exceeded. *Not* thread-safe.
  */
 public class LimitedOutputStream extends OutputStream
 {
   private final OutputStream out;
   private final long limit;
   private final Function<Long, String> exceptionMessageFn;
-  long written;
+  AtomicLong written;
 
   /**
    * Create a bytes-limited output stream.
@@ -48,6 +49,7 @@ public LimitedOutputStream(OutputStream out, long limit, Function<Long, String> exceptionMessageFn)
   {
     this.out = out;
     this.limit = limit;
+    this.written = new AtomicLong(0);
     this.exceptionMessageFn = exceptionMessageFn;
 
     if (limit < 0) {
@@ -88,10 +90,14 @@ public void close() throws IOException
     out.close();
   }
 
+  public OutputStream get()
+  {
+    return out;
+  }
+
   private void plus(final int n) throws IOException
   {
-    written += n;
-    if (written > limit) {
+    if (written.addAndGet(n) > limit) {
       throw new IOE(exceptionMessageFn.apply(limit));
     }
   }
diff --git a/server/src/main/java/org/apache/druid/query/ResultLevelCachingQueryRunner.java b/server/src/main/java/org/apache/druid/query/ResultLevelCachingQueryRunner.java
index dedfb0028b777..6fed344dccc02 100644
--- a/server/src/main/java/org/apache/druid/query/ResultLevelCachingQueryRunner.java
+++ b/server/src/main/java/org/apache/druid/query/ResultLevelCachingQueryRunner.java
@@ -30,6 +30,7 @@
 import org.apache.druid.client.cache.Cache;
 import org.apache.druid.client.cache.Cache.NamedKey;
 import org.apache.druid.client.cache.CacheConfig;
+import org.apache.druid.io.LimitedOutputStream;
 import org.apache.druid.java.util.common.RE;
 import org.apache.druid.java.util.common.StringUtils;
 import org.apache.druid.java.util.common.guava.Sequence;
@@ -152,6 +153,8 @@ public void after(boolean isDone, Throwable thrown)
               // The resultset identifier and its length is cached along with the resultset
               resultLevelCachePopulator.populateResults();
               log.debug("Cache population complete for query %s", query.getId());
+            } else { // thrown == null && !resultLevelCachePopulator.isShouldPopulate()
+              log.error("Failed (and recovered) to populate result level cache for query %s", query.getId());
             }
             resultLevelCachePopulator.stopPopulating();
           }
@@ -233,8 +236,8 @@ private ResultLevelCachePopulator createResultLevelCachePopulator(
     try {
       // Save the resultSetId and its length
       resultLevelCachePopulator.cacheObjectStream.write(ByteBuffer.allocate(Integer.BYTES)
-                                                                  .putInt(resultSetId.length())
-                                                                  .array());
+                                                                   .putInt(resultSetId.length())
+                                                                   .array());
       resultLevelCachePopulator.cacheObjectStream.write(StringUtils.toUtf8(resultSetId));
     }
     catch (IOException ioe) {
@@ -255,7 +258,7 @@ private class ResultLevelCachePopulator
     private final Cache.NamedKey key;
     private final CacheConfig cacheConfig;
     @Nullable
-    private ByteArrayOutputStream cacheObjectStream;
+    private LimitedOutputStream cacheObjectStream;
 
     private ResultLevelCachePopulator(
         Cache cache,
@@ -270,7 +273,14 @@ private ResultLevelCachePopulator(
       this.serialiers = mapper.getSerializerProviderInstance();
       this.key = key;
       this.cacheConfig = cacheConfig;
-      this.cacheObjectStream = shouldPopulate ? new ByteArrayOutputStream() : null;
+      this.cacheObjectStream = shouldPopulate ? new LimitedOutputStream(
+          new ByteArrayOutputStream(),
+          cacheConfig.getResultLevelCacheLimit(), limit -> StringUtils.format(
+              "resultLevelCacheLimit[%,d] exceeded. " +
+              "Max ResultLevelCacheLimit for cache exceeded. Result caching failed.",
+              limit
+          )
+      ) : null;
     }
 
     boolean isShouldPopulate()
@@ -289,12 +299,8 @@ private void cacheResultEntry(
     )
     {
       Preconditions.checkNotNull(cacheObjectStream, "cacheObjectStream");
-      int cacheLimit = cacheConfig.getResultLevelCacheLimit();
       try (JsonGenerator gen = mapper.getFactory().createGenerator(cacheObjectStream)) {
         JacksonUtils.writeObjectUsingSerializerProvider(gen, serialiers, cacheFn.apply(resultEntry));
-        if (cacheLimit > 0 && cacheObjectStream.size() > cacheLimit) {
-          stopPopulating();
-        }
       }
       catch (IOException ex) {
         log.error(ex, "Failed to retrieve entry to be cached. Result Level caching will not be performed!");
Result Level caching will not be performed!"); @@ -304,7 +310,8 @@ private void cacheResultEntry( public void populateResults() { - byte[] cachedResults = Preconditions.checkNotNull(cacheObjectStream, "cacheObjectStream").toByteArray(); + byte[] cachedResults = ((ByteArrayOutputStream) Preconditions.checkNotNull(cacheObjectStream, "cacheObjectStream") + .get()).toByteArray(); cache.put(key, cachedResults); } } diff --git a/server/src/test/java/org/apache/druid/query/ResultLevelCachingQueryRunnerTest.java b/server/src/test/java/org/apache/druid/query/ResultLevelCachingQueryRunnerTest.java index 6245509465c10..3cb4ae528e67a 100644 --- a/server/src/test/java/org/apache/druid/query/ResultLevelCachingQueryRunnerTest.java +++ b/server/src/test/java/org/apache/druid/query/ResultLevelCachingQueryRunnerTest.java @@ -39,6 +39,7 @@ public class ResultLevelCachingQueryRunnerTest extends QueryRunnerBasedOnClusteredClientTestBase { private Cache cache; + private static final int DEFAULT_CACHE_ENTRY_MAX_SIZE = Integer.MAX_VALUE; @Before public void setup() @@ -58,7 +59,7 @@ public void testNotPopulateAndNotUse() prepareCluster(10); final Query> query = timeseriesQuery(BASE_SCHEMA_INFO.getDataInterval()); final ResultLevelCachingQueryRunner> queryRunner1 = createQueryRunner( - newCacheConfig(false, false), + newCacheConfig(false, false, DEFAULT_CACHE_ENTRY_MAX_SIZE), query ); @@ -72,7 +73,7 @@ public void testNotPopulateAndNotUse() Assert.assertEquals(0, cache.getStats().getNumMisses()); final ResultLevelCachingQueryRunner> queryRunner2 = createQueryRunner( - newCacheConfig(false, false), + newCacheConfig(false, false, DEFAULT_CACHE_ENTRY_MAX_SIZE), query ); @@ -93,7 +94,7 @@ public void testPopulateAndNotUse() prepareCluster(10); final Query> query = timeseriesQuery(BASE_SCHEMA_INFO.getDataInterval()); final ResultLevelCachingQueryRunner> queryRunner1 = createQueryRunner( - newCacheConfig(true, false), + newCacheConfig(true, false, DEFAULT_CACHE_ENTRY_MAX_SIZE), query ); @@ -107,7 +108,7 @@ public void testPopulateAndNotUse() Assert.assertEquals(0, cache.getStats().getNumMisses()); final ResultLevelCachingQueryRunner> queryRunner2 = createQueryRunner( - newCacheConfig(true, false), + newCacheConfig(true, false, DEFAULT_CACHE_ENTRY_MAX_SIZE), query ); @@ -128,7 +129,7 @@ public void testNotPopulateAndUse() prepareCluster(10); final Query> query = timeseriesQuery(BASE_SCHEMA_INFO.getDataInterval()); final ResultLevelCachingQueryRunner> queryRunner1 = createQueryRunner( - newCacheConfig(false, false), + newCacheConfig(false, false, DEFAULT_CACHE_ENTRY_MAX_SIZE), query ); @@ -142,7 +143,7 @@ public void testNotPopulateAndUse() Assert.assertEquals(0, cache.getStats().getNumMisses()); final ResultLevelCachingQueryRunner> queryRunner2 = createQueryRunner( - newCacheConfig(false, true), + newCacheConfig(false, true, DEFAULT_CACHE_ENTRY_MAX_SIZE), query ); @@ -163,7 +164,7 @@ public void testPopulateAndUse() prepareCluster(10); final Query> query = timeseriesQuery(BASE_SCHEMA_INFO.getDataInterval()); final ResultLevelCachingQueryRunner> queryRunner1 = createQueryRunner( - newCacheConfig(true, true), + newCacheConfig(true, true, DEFAULT_CACHE_ENTRY_MAX_SIZE), query ); @@ -177,7 +178,7 @@ public void testPopulateAndUse() Assert.assertEquals(1, cache.getStats().getNumMisses()); final ResultLevelCachingQueryRunner> queryRunner2 = createQueryRunner( - newCacheConfig(true, true), + newCacheConfig(true, true, DEFAULT_CACHE_ENTRY_MAX_SIZE), query ); @@ -192,6 +193,41 @@ public void testPopulateAndUse() 
     Assert.assertEquals(1, cache.getStats().getNumMisses());
   }
 
+  @Test
+  public void testNoPopulateIfEntrySizeExceedsMaximum()
+  {
+    prepareCluster(10);
+    final Query<Result<TimeseriesResultValue>> query = timeseriesQuery(BASE_SCHEMA_INFO.getDataInterval());
+    final ResultLevelCachingQueryRunner<Result<TimeseriesResultValue>> queryRunner1 = createQueryRunner(
+        newCacheConfig(true, true, 128),
+        query
+    );
+
+    final Sequence<Result<TimeseriesResultValue>> sequence1 = queryRunner1.run(
+        QueryPlus.wrap(query),
+        responseContext()
+    );
+    final List<Result<TimeseriesResultValue>> results1 = sequence1.toList();
+    Assert.assertEquals(0, cache.getStats().getNumHits());
+    Assert.assertEquals(0, cache.getStats().getNumEntries());
+    Assert.assertEquals(1, cache.getStats().getNumMisses());
+
+    final ResultLevelCachingQueryRunner<Result<TimeseriesResultValue>> queryRunner2 = createQueryRunner(
+        newCacheConfig(true, true, DEFAULT_CACHE_ENTRY_MAX_SIZE),
+        query
+    );
+
+    final Sequence<Result<TimeseriesResultValue>> sequence2 = queryRunner2.run(
+        QueryPlus.wrap(query),
+        responseContext()
+    );
+    final List<Result<TimeseriesResultValue>> results2 = sequence2.toList();
+    Assert.assertEquals(results1, results2);
+    Assert.assertEquals(0, cache.getStats().getNumHits());
+    Assert.assertEquals(1, cache.getStats().getNumEntries());
+    Assert.assertEquals(2, cache.getStats().getNumMisses());
+  }
+
   @Test
   public void testPopulateCacheWhenQueryThrowExceptionShouldNotCache()
   {
@@ -206,7 +242,7 @@ public void testPopulateCacheWhenQueryThrowExceptionShouldNotCache()
 
     final Query<Result<TimeseriesResultValue>> query = timeseriesQuery(BASE_SCHEMA_INFO.getDataInterval());
     final ResultLevelCachingQueryRunner<Result<TimeseriesResultValue>> queryRunner = createQueryRunner(
-        newCacheConfig(true, false),
+        newCacheConfig(true, false, DEFAULT_CACHE_ENTRY_MAX_SIZE),
         query
     );
 
@@ -249,7 +285,11 @@ private <T> ResultLevelCachingQueryRunner<T> createQueryRunner(
     );
   }
 
-  private CacheConfig newCacheConfig(boolean populateResultLevelCache, boolean useResultLevelCache)
+  private CacheConfig newCacheConfig(
+      boolean populateResultLevelCache,
+      boolean useResultLevelCache,
+      int resultLevelCacheLimit
+  )
   {
     return new CacheConfig()
     {
@@ -264,6 +304,12 @@ public boolean isUseResultLevelCache()
       {
         return useResultLevelCache;
       }
+
+      @Override
+      public int getResultLevelCacheLimit()
+      {
+        return resultLevelCacheLimit;
+      }
     };
   }
 }
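
Postscript (illustration, not part of the patch): the core idea above is to
make an in-memory buffer fail fast with a catchable IOException once a byte
cap is crossed, rather than letting ByteArrayOutputStream grow until the JVM
throws OutOfMemoryError. Below is a minimal, self-contained sketch of that
pattern; the class and method names (LimitDemo, BoundedStream) are invented
for this sketch and condensed relative to the real
org.apache.druid.io.LimitedOutputStream.

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;

// Sketch only: a trimmed-down analogue of a byte-limited OutputStream wrapper.
public class LimitDemo
{
  static final class BoundedStream extends OutputStream
  {
    private final OutputStream out;
    private final long limit;
    private long written;

    BoundedStream(OutputStream out, long limit)
    {
      this.out = out;
      this.limit = limit;
    }

    @Override
    public void write(int b) throws IOException
    {
      plus(1);
      out.write(b);
    }

    @Override
    public void write(byte[] b, int off, int len) throws IOException
    {
      plus(len);
      out.write(b, off, len);
    }

    // Check the cap before forwarding bytes to the delegate, so an oversized
    // write fails with a catchable IOException instead of an OutOfMemoryError.
    private void plus(int n) throws IOException
    {
      written += n;
      if (written > limit) {
        throw new IOException("Byte limit [" + limit + "] exceeded");
      }
    }
  }

  public static void main(String[] args)
  {
    BoundedStream stream = new BoundedStream(new ByteArrayOutputStream(), 128);
    try {
      stream.write(new byte[256], 0, 256); // exceeds the 128-byte cap
    }
    catch (IOException e) {
      // The caller recovers: cache population is abandoned, the query still succeeds.
      System.out.println("Recovered from: " + e.getMessage());
    }
  }
}

Because the check runs before bytes reach the backing buffer, population
aborts on the first oversized write; this mirrors the patch, where the
wrapped ByteArrayOutputStream is only unwrapped (via get()) in
populateResults() for entries that stayed under the limit.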