diff --git a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/SpillOutputStream.java b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/SpillOutputStream.java new file mode 100644 index 000000000000..8d20bea534d4 --- /dev/null +++ b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/SpillOutputStream.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.groupby.epinephelinae; + +import javax.annotation.Nullable; + +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.IOException; +import java.io.OutputStream; + +/** + * OutputStream that starts buffering in a heap byte array and switches to a disk file via + * {@link LimitedTemporaryStorage} once the written bytes exceed the threshold. This avoids + * the createFile/delete round-trip for small spills while bounding peak extra heap to the + * threshold size. + */ +public class SpillOutputStream extends OutputStream +{ + private static final int INITIAL_BUFFER_SIZE = 4096; + + private final LimitedTemporaryStorage temporaryStorage; + private final long threshold; + @Nullable + private ByteArrayOutputStream memoryBuffer; + private LimitedTemporaryStorage.LimitedOutputStream fileOut; + private boolean thresholdExceeded; + + SpillOutputStream(LimitedTemporaryStorage temporaryStorage, long threshold) + { + this.temporaryStorage = temporaryStorage; + this.threshold = threshold; + this.memoryBuffer = new ByteArrayOutputStream((int) Math.min(threshold, INITIAL_BUFFER_SIZE)); + } + + @Override + public void write(int b) throws IOException + { + checkThreshold(1); + if (fileOut != null) { + fileOut.write(b); + } else { + memoryBuffer.write(b); + } + } + + @Override + public void write(byte[] b, int off, int len) throws IOException + { + checkThreshold(len); + if (fileOut != null) { + fileOut.write(b, off, len); + } else { + memoryBuffer.write(b, off, len); + } + } + + @Override + public void flush() throws IOException + { + if (fileOut != null) { + fileOut.flush(); + } + } + + @Override + public void close() throws IOException + { + if (fileOut != null) { + fileOut.close(); + } + } + + boolean isInMemory() + { + return fileOut == null; + } + + byte[] toByteArray() + { + return memoryBuffer.toByteArray(); + } + + File getFile() + { + return fileOut.getFile(); + } + + private void checkThreshold(int count) throws IOException + { + if (!thresholdExceeded && memoryBuffer.size() + count > threshold) { + thresholdExceeded = true; + switchToDisk(); + } + } + + private void switchToDisk() throws IOException + { + final LimitedTemporaryStorage.LimitedOutputStream out = temporaryStorage.createFile(); + try { + memoryBuffer.writeTo(out); + } + catch (IOException e) { + out.close(); + throw e; + } + fileOut = out; + memoryBuffer = null; + } +} diff --git a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/SpillingGrouper.java b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/SpillingGrouper.java index 7af71f926737..96e55907b21f 100644 --- a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/SpillingGrouper.java +++ b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/SpillingGrouper.java @@ -47,6 +47,7 @@ import java.io.FileInputStream; import java.io.IOException; import java.io.InputStream; +import java.io.OutputStream; import java.nio.ByteBuffer; import java.nio.file.Files; import java.util.ArrayList; @@ -374,28 +375,24 @@ public CloseableIterator> iterator(final boolean sorted) private void spill() throws IOException { - // Stream directly to a temp file first, then check the file size. If the file is small - // (serialized size much smaller than the pre-allocated buffer, e.g. HLL sketches in List mode), - // read it back into memory for batching to avoid creating thousands of tiny disk files. - // If the file is already large enough, keep it on disk as-is. - final File file; + final SpillOutputStream spillOut = new SpillOutputStream(temporaryStorage, minSpillFileSize); try (CloseableIterator> iterator = grouper.iterator(true)) { - file = spill(iterator); + serializeToStream(iterator, spillOut); } + pendingDictionaryEntries.addAll(keySerde.getDictionary()); grouper.reset(); - final long fileSize = file.length(); - if (fileSize < minSpillFileSize) { - pendingSpillRuns.add(Files.readAllBytes(file.toPath())); - pendingSpillBytes += fileSize; - temporaryStorage.delete(file); + if (spillOut.isInMemory()) { + final byte[] bytes = spillOut.toByteArray(); + pendingSpillRuns.add(bytes); + pendingSpillBytes += bytes.length; if (pendingSpillBytes >= minSpillFileSize) { flushPendingRunsToDisk(); } } else { - files.add(file); + files.add(spillOut.getFile()); dictionaryFiles.add(spill(pendingDictionaryEntries.iterator())); pendingDictionaryEntries.clear(); } @@ -483,20 +480,24 @@ public Entry apply(Entry entry) ); } - private File spill(Iterator iterator) throws IOException + private void serializeToStream(Iterator iterator, OutputStream out) throws IOException { try ( - final LimitedTemporaryStorage.LimitedOutputStream out = temporaryStorage.createFile(); final LZ4BlockOutputStream compressedOut = new LZ4BlockOutputStream(out); final JsonGenerator jsonGenerator = spillMapper.getFactory().createGenerator(compressedOut) ) { final SerializerProvider serializers = spillMapper.getSerializerProviderInstance(); - while (iterator.hasNext()) { BaseQuery.checkInterrupted(); JacksonUtils.writeObjectUsingSerializerProvider(jsonGenerator, serializers, iterator.next()); } + } + } + private File spill(Iterator iterator) throws IOException + { + try (final LimitedTemporaryStorage.LimitedOutputStream out = temporaryStorage.createFile()) { + serializeToStream(iterator, out); return out.getFile(); } } diff --git a/processing/src/test/java/org/apache/druid/query/groupby/epinephelinae/SpillOutputStreamTest.java b/processing/src/test/java/org/apache/druid/query/groupby/epinephelinae/SpillOutputStreamTest.java new file mode 100644 index 000000000000..66a19b0a8a08 --- /dev/null +++ b/processing/src/test/java/org/apache/druid/query/groupby/epinephelinae/SpillOutputStreamTest.java @@ -0,0 +1,271 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.query.groupby.epinephelinae; + +import org.apache.druid.query.groupby.GroupByStatsProvider; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.io.IOException; +import java.nio.file.Files; +import java.util.Arrays; + +public class SpillOutputStreamTest +{ + @Rule + public TemporaryFolder temporaryFolder = new TemporaryFolder(); + + @Test + public void testSmallWriteStaysInMemory() throws IOException + { + try (SpillOutputStream out = makeStream(1024)) { + out.write(new byte[]{1, 2, 3}); + Assert.assertTrue(out.isInMemory()); + Assert.assertArrayEquals(new byte[]{1, 2, 3}, out.toByteArray()); + } + } + + @Test + public void testExactlyAtThresholdStaysInMemory() throws IOException + { + try (SpillOutputStream out = makeStream(4)) { + out.write(new byte[]{1, 2, 3, 4}); + Assert.assertTrue(out.isInMemory()); + Assert.assertArrayEquals(new byte[]{1, 2, 3, 4}, out.toByteArray()); + } + } + + @Test + public void testExceedingThresholdSwitchesToDisk() throws IOException + { + try (SpillOutputStream out = makeStream(4)) { + out.write(new byte[]{1, 2, 3, 4, 5}); + Assert.assertFalse(out.isInMemory()); + Assert.assertTrue(out.getFile().exists()); + byte[] fileContent = Files.readAllBytes(out.getFile().toPath()); + Assert.assertArrayEquals(new byte[]{1, 2, 3, 4, 5}, fileContent); + } + } + + @Test + public void testSwitchesToDiskOnSecondWrite() throws IOException + { + try (SpillOutputStream out = makeStream(4)) { + out.write(new byte[]{1, 2}); + Assert.assertTrue(out.isInMemory()); + + out.write(new byte[]{3, 4, 5}); + Assert.assertFalse(out.isInMemory()); + byte[] fileContent = Files.readAllBytes(out.getFile().toPath()); + Assert.assertArrayEquals(new byte[]{1, 2, 3, 4, 5}, fileContent); + } + } + + @Test + public void testSingleByteWriteStaysInMemory() throws IOException + { + try (SpillOutputStream out = makeStream(1024)) { + out.write(42); + Assert.assertTrue(out.isInMemory()); + Assert.assertArrayEquals(new byte[]{42}, out.toByteArray()); + } + } + + @Test + public void testSingleByteWriteTriggersSwitch() throws IOException + { + try (SpillOutputStream out = makeStream(2)) { + out.write(1); + out.write(2); + Assert.assertTrue(out.isInMemory()); + + out.write(3); + Assert.assertFalse(out.isInMemory()); + byte[] fileContent = Files.readAllBytes(out.getFile().toPath()); + Assert.assertArrayEquals(new byte[]{1, 2, 3}, fileContent); + } + } + + @Test + public void testDataIntegrityAcrossSwitch() throws IOException + { + try (SpillOutputStream out = makeStream(10)) { + byte[] beforeSwitch = new byte[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + byte[] afterSwitch = new byte[]{11, 12, 13, 14, 15}; + out.write(beforeSwitch); + Assert.assertTrue(out.isInMemory()); + + out.write(afterSwitch); + Assert.assertFalse(out.isInMemory()); + out.flush(); + + byte[] expected = new byte[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + byte[] fileContent = Files.readAllBytes(out.getFile().toPath()); + Assert.assertArrayEquals(expected, fileContent); + } + } + + @Test + public void testWriteWithOffsetAndLength() throws IOException + { + try (SpillOutputStream out = makeStream(1024)) { + byte[] data = new byte[]{0, 0, 1, 2, 3, 0, 0}; + out.write(data, 2, 3); + Assert.assertTrue(out.isInMemory()); + Assert.assertArrayEquals(new byte[]{1, 2, 3}, out.toByteArray()); + } + } + + @Test + public void testWriteWithOffsetAndLengthTriggersDiskSwitch() throws IOException + { + try (SpillOutputStream out = makeStream(2)) { + byte[] data = new byte[]{0, 1, 2, 3, 0}; + out.write(data, 1, 3); + Assert.assertFalse(out.isInMemory()); + byte[] fileContent = Files.readAllBytes(out.getFile().toPath()); + Assert.assertArrayEquals(new byte[]{1, 2, 3}, fileContent); + } + } + + @Test + public void testLargeWrite() throws IOException + { + try (SpillOutputStream out = makeStream(100)) { + byte[] data = new byte[10_000]; + Arrays.fill(data, (byte) 0xAB); + out.write(data); + Assert.assertFalse(out.isInMemory()); + out.flush(); + byte[] fileContent = Files.readAllBytes(out.getFile().toPath()); + Assert.assertArrayEquals(data, fileContent); + } + } + + @Test + public void testZeroThresholdAlwaysGoesToDisk() throws IOException + { + try (SpillOutputStream out = makeStream(0)) { + out.write(new byte[]{1}); + Assert.assertFalse(out.isInMemory()); + byte[] fileContent = Files.readAllBytes(out.getFile().toPath()); + Assert.assertArrayEquals(new byte[]{1}, fileContent); + } + } + + @Test + public void testEmptyStreamIsInMemory() throws IOException + { + try (SpillOutputStream out = makeStream(1024)) { + Assert.assertTrue(out.isInMemory()); + Assert.assertArrayEquals(new byte[0], out.toByteArray()); + } + } + + @Test + public void testMultipleWritesAccumulateInMemory() throws IOException + { + try (SpillOutputStream out = makeStream(1024)) { + out.write(new byte[]{1, 2}); + out.write(new byte[]{3, 4}); + out.write(5); + Assert.assertTrue(out.isInMemory()); + Assert.assertArrayEquals(new byte[]{1, 2, 3, 4, 5}, out.toByteArray()); + } + } + + @Test + public void testMultipleWritesAfterDiskSwitch() throws IOException + { + try (SpillOutputStream out = makeStream(4)) { + out.write(new byte[]{1, 2, 3, 4, 5}); + Assert.assertFalse(out.isInMemory()); + + out.write(new byte[]{6, 7}); + out.write(8); + out.flush(); + + byte[] fileContent = Files.readAllBytes(out.getFile().toPath()); + Assert.assertArrayEquals(new byte[]{1, 2, 3, 4, 5, 6, 7, 8}, fileContent); + } + } + + @Test + public void testDiskStorageBytesTracked() throws IOException + { + LimitedTemporaryStorage storage = makeStorage(1024 * 1024); + + try (SpillOutputStream out = new SpillOutputStream(storage, 4)) { + out.write(new byte[]{1, 2, 3, 4, 5}); + Assert.assertFalse(out.isInMemory()); + out.flush(); + Assert.assertTrue(storage.currentSize() > 0); + } + } + + @Test(expected = NullPointerException.class) + public void testToByteArrayThrowsAfterDiskSwitch() throws IOException + { + try (SpillOutputStream out = makeStream(4)) { + out.write(new byte[]{1, 2, 3, 4, 5}); + Assert.assertFalse(out.isInMemory()); + out.toByteArray(); + } + } + + @Test(expected = NullPointerException.class) + public void testGetFileThrowsWhenInMemory() throws IOException + { + try (SpillOutputStream out = makeStream(1024)) { + out.write(new byte[]{1, 2, 3}); + Assert.assertTrue(out.isInMemory()); + out.getFile(); + } + } + + @Test(expected = TemporaryStorageFullException.class) + public void testDiskStorageLimitEnforced() throws IOException + { + LimitedTemporaryStorage storage = makeStorage(10); + + try (SpillOutputStream out = new SpillOutputStream(storage, 4)) { + byte[] data = new byte[100]; + Arrays.fill(data, (byte) 1); + out.write(data); + } + } + + private SpillOutputStream makeStream(long threshold) throws IOException + { + return new SpillOutputStream(makeStorage(1024 * 1024), threshold); + } + + private LimitedTemporaryStorage makeStorage(long maxBytes) throws IOException + { + return new LimitedTemporaryStorage( + temporaryFolder.newFolder(), + maxBytes, + 100, + new GroupByStatsProvider.PerQueryStats() + ); + } +} diff --git a/processing/src/test/java/org/apache/druid/query/groupby/epinephelinae/SpillingGrouperTest.java b/processing/src/test/java/org/apache/druid/query/groupby/epinephelinae/SpillingGrouperTest.java index 0fa77f988ada..3ec121568234 100644 --- a/processing/src/test/java/org/apache/druid/query/groupby/epinephelinae/SpillingGrouperTest.java +++ b/processing/src/test/java/org/apache/druid/query/groupby/epinephelinae/SpillingGrouperTest.java @@ -286,10 +286,43 @@ public void testResetClearsPendingState() throws IOException } } + @Test + public void testSmallSpillsStayInMemoryUntilFlush() throws IOException + { + final File storageDir = temporaryFolder.newFolder(); + final LimitedTemporaryStorage temporaryStorage = + new LimitedTemporaryStorage(storageDir, 1024 * 1024, 100, new GroupByStatsProvider.PerQueryStats()); + + // Use a large minSpillFileSize so individual spills (which are tiny with a 50-byte buffer) + // stay in memory via SpillOutputStream and never create temp files. + final long largeMinSpillFileSize = 1024 * 1024L; + try (SpillingGrouper grouper = makeGrouper(50, temporaryStorage, largeMinSpillFileSize)) { + // Aggregate enough keys to trigger multiple spills, but not enough to exceed + // minSpillFileSize in total pending bytes. + for (int i = 0; i < 20; i++) { + Assert.assertTrue(grouper.aggregate(new IntKey(i)).isOk()); + } + + // No files should have been created — all spills are below the threshold and + // pending bytes haven't reached minSpillFileSize yet. + Assert.assertEquals( + "small spills should stay in memory without creating any temp files", + 0, + temporaryStorage.currentFileCount() + ); + Assert.assertEquals(0, storageDir.listFiles().length); + + // Results should still be correct when iterated. + assertResultsCorrect(grouper, 20, 1); + } + } + @Test public void testDiskFull() throws IOException { - try (SpillingGrouper grouper = makeGrouper(50, temporaryFolder.newFolder(), 10, 100)) { + // Use a small minSpillFileSize so pending runs flush to disk frequently, where the + // 500-byte maxStorageBytes limit will be hit. + try (SpillingGrouper grouper = makeGrouper(50, temporaryFolder.newFolder(), 500, 100, 100)) { AggregateResult lastResult = AggregateResult.ok(); for (int i = 0; i < 10000 && lastResult.isOk(); i++) { lastResult = grouper.aggregate(new IntKey(i)); @@ -345,10 +378,22 @@ private SpillingGrouper makeGrouper( long maxStorageBytes, int maxFileCount ) + { + return makeGrouper(bufferSize, storageDir, maxStorageBytes, maxFileCount, 1024 * 1024L); + } + + private SpillingGrouper makeGrouper( + int bufferSize, + File storageDir, + long maxStorageBytes, + int maxFileCount, + long minSpillFileSize + ) { return makeGrouper( bufferSize, - new LimitedTemporaryStorage(storageDir, maxStorageBytes, maxFileCount, new GroupByStatsProvider.PerQueryStats()) + new LimitedTemporaryStorage(storageDir, maxStorageBytes, maxFileCount, new GroupByStatsProvider.PerQueryStats()), + minSpillFileSize ); } @@ -356,6 +401,15 @@ private SpillingGrouper makeGrouper( int bufferSize, LimitedTemporaryStorage temporaryStorage ) + { + return makeGrouper(bufferSize, temporaryStorage, 1024 * 1024L); + } + + private SpillingGrouper makeGrouper( + int bufferSize, + LimitedTemporaryStorage temporaryStorage, + long minSpillFileSize + ) { final GroupByTestColumnSelectorFactory columnSelectorFactory = GrouperTestUtil.newColumnSelectorFactory(); columnSelectorFactory.setRow(new MapBasedRow(0, ImmutableMap.of("value", 1L))); @@ -374,7 +428,7 @@ private SpillingGrouper makeGrouper( null, false, bufferSize, - 1024 * 1024L, + minSpillFileSize, new GroupByStatsProvider.PerQueryStats() ); grouper.init();