diff --git a/docs/changelog/130510.yaml b/docs/changelog/130510.yaml new file mode 100644 index 0000000000000..01426b6b8e4e3 --- /dev/null +++ b/docs/changelog/130510.yaml @@ -0,0 +1,5 @@ +pr: 130510 +summary: Add fast path for single value in VALUES aggregator +area: ES|QL +type: enhancement +issues: [] diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java index 51195578ac363..8a52433fc0629 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java @@ -11,6 +11,7 @@ import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.BytesRefHash; +import org.elasticsearch.common.util.IntArray; import org.elasticsearch.common.util.LongLongHash; import org.elasticsearch.compute.aggregation.blockhash.BlockHash; import org.elasticsearch.compute.ann.Aggregator; @@ -76,7 +77,7 @@ public static GroupingAggregatorFunction.AddInput wrapAddInput( } public static void combine(GroupingState state, int groupId, BytesRef v) { - state.values.add(groupId, BlockHash.hashOrdToGroup(state.bytes.add(v))); + state.addValue(groupId, Math.toIntExact(BlockHash.hashOrdToGroup(state.bytes.add(v)))); } public static void combineIntermediate(GroupingState state, int groupId, BytesRefBlock values, int valuesPosition) { @@ -90,6 +91,14 @@ public static void combineIntermediate(GroupingState state, int groupId, BytesRe public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) { BytesRef scratch = new BytesRef(); + if (statePosition >= state.firstValues.size()) { + return; + } + int firstValue = state.firstValues.get(statePosition) - 1; + if (firstValue < 0) { + return; + } + combine(current, currentGroupId, state.bytes.get(firstValue, scratch)); for (int id = 0; id < state.values.size(); id++) { if (state.values.getKey1(id) == statePosition) { long value = state.values.getKey2(id); @@ -146,23 +155,29 @@ public void close() { * collector operation. But at least it's fairly simple. */ public static class GroupingState implements GroupingAggregatorState { - final LongLongHash values; + private final BigArrays bigArrays; + private final LongLongHash values; + private IntArray firstValues; // the first value ordinal+1 collected in each group, 0 means no value BytesRefHash bytes; private GroupingState(BigArrays bigArrays) { + this.bigArrays = bigArrays; LongLongHash _values = null; BytesRefHash _bytes = null; + IntArray _firstValues = null; try { _values = new LongLongHash(1, bigArrays); _bytes = new BytesRefHash(1, bigArrays); - + _firstValues = bigArrays.newIntArray(1); values = _values; bytes = _bytes; + firstValues = _firstValues; _values = null; _bytes = null; + _firstValues = null; } finally { - Releasables.closeExpectNoException(_values, _bytes); + Releasables.closeExpectNoException(_values, _bytes, _firstValues); } } @@ -176,7 +191,7 @@ public void toIntermediate(Block[] blocks, int offset, IntVector selected, Drive * groups. This is the implementation of the final and intermediate results of the agg. */ Block toBlock(BlockFactory blockFactory, IntVector selected) { - if (values.size() == 0) { + if (bytes.size() == 0) { return blockFactory.newConstantNullBlock(selected.getPositionCount()); } @@ -220,11 +235,13 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5 */ int total = 0; - for (int s = 0; s < selected.getPositionCount(); s++) { - int group = selected.getInt(s); - int count = -selectedCounts[group]; - selectedCounts[group] = total; - total += count; + if (values.size() > 0) { + for (int s = 0; s < selected.getPositionCount(); s++) { + int group = selected.getInt(s); + int count = -selectedCounts[group]; + selectedCounts[group] = total; + total += count; + } } /* @@ -256,7 +273,7 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { ids[selectedCounts[group]++] = id; } } - if (OrdinalBytesRefBlock.isDense(selected.getPositionCount(), Math.toIntExact(values.size()))) { + if (OrdinalBytesRefBlock.isDense(firstValues.size() + values.size(), bytes.size())) { return buildOrdinalOutputBlock(blockFactory, selected, selectedCounts, ids); } else { return buildOutputBlock(blockFactory, selected, selectedCounts, ids); @@ -266,6 +283,20 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { } } + void addValue(int groupId, int valueOrdinal) { + if (groupId < firstValues.size()) { + final int curr = firstValues.get(groupId) - 1; + if (curr == -1) { + firstValues.set(groupId, valueOrdinal + 1); + } else if (curr != valueOrdinal) { + values.add(groupId, valueOrdinal); + } + } else { + firstValues = bigArrays.grow(firstValues, groupId + 1); + firstValues.set(groupId, valueOrdinal + 1); + } + } + Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { /* * Insert the ids in order. @@ -275,20 +306,26 @@ Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] sele int start = 0; for (int s = 0; s < selected.getPositionCount(); s++) { int group = selected.getInt(s); - int end = selectedCounts[group]; - int count = end - start; - switch (count) { - case 0 -> builder.appendNull(); - case 1 -> append(builder, ids[start], scratch); - default -> { + int firstValue = group < firstValues.size() ? firstValues.get(group) - 1 : -1; + if (firstValue == -1) { + assert selectedCounts[group] == start : selectedCounts[group] + " != " + start; + builder.appendNull(); + } else { + int end = selectedCounts[group]; + int count = end - start; + if (count == 0) { + builder.appendBytesRef(bytes.get(firstValue, scratch)); + } else { builder.beginPositionEntry(); + builder.appendBytesRef(bytes.get(firstValue, scratch)); for (int i = start; i < end; i++) { append(builder, ids[i], scratch); } builder.endPositionEntry(); } + start = end; } - start = end; + } return builder.build(); } @@ -304,20 +341,25 @@ Block buildOrdinalOutputBlock(BlockFactory blockFactory, IntVector selected, int int start = 0; for (int s = 0; s < selected.getPositionCount(); s++) { int group = selected.getInt(s); - int end = selectedCounts[group]; - int count = end - start; - switch (count) { - case 0 -> builder.appendNull(); - case 1 -> builder.appendInt(Math.toIntExact(values.getKey2(ids[start]))); - default -> { + int firstValue = group < firstValues.size() ? firstValues.get(group) - 1 : -1; + if (firstValue == -1) { + assert selectedCounts[group] == start : selectedCounts[group] + " != " + start; + builder.appendNull(); + } else { + int end = selectedCounts[group]; + int count = end - start; + if (count == 0) { + builder.appendInt(firstValue); + } else { builder.beginPositionEntry(); + builder.appendInt(firstValue); for (int i = start; i < end; i++) { builder.appendInt(Math.toIntExact(values.getKey2(ids[i]))); } builder.endPositionEntry(); } + start = end; } - start = end; } ordinals = builder.build(); dict = blockFactory.newBytesRefArrayVector(dictArray, Math.toIntExact(dictArray.size())); @@ -343,7 +385,7 @@ public void enableGroupIdTracking(SeenGroupIds seen) { @Override public void close() { - Releasables.closeExpectNoException(values, bytes); + Releasables.closeExpectNoException(values, bytes, firstValues); } } } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java index f5b0d519dd890..62f729b16e656 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java @@ -180,11 +180,13 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5 */ int total = 0; - for (int s = 0; s < selected.getPositionCount(); s++) { - int group = selected.getInt(s); - int count = -selectedCounts[group]; - selectedCounts[group] = total; - total += count; + if (values.size() > 0) { + for (int s = 0; s < selected.getPositionCount(); s++) { + int group = selected.getInt(s); + int count = -selectedCounts[group]; + selectedCounts[group] = total; + total += count; + } } /* diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java index 4cfbf329a895d..7bb7444e000e7 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java @@ -186,11 +186,13 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5 */ int total = 0; - for (int s = 0; s < selected.getPositionCount(); s++) { - int group = selected.getInt(s); - int count = -selectedCounts[group]; - selectedCounts[group] = total; - total += count; + if (values.size() > 0) { + for (int s = 0; s < selected.getPositionCount(); s++) { + int group = selected.getInt(s); + int count = -selectedCounts[group]; + selectedCounts[group] = total; + total += count; + } } /* diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java index 38e5ad99cf581..1428b05a53ac8 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java @@ -186,11 +186,13 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5 */ int total = 0; - for (int s = 0; s < selected.getPositionCount(); s++) { - int group = selected.getInt(s); - int count = -selectedCounts[group]; - selectedCounts[group] = total; - total += count; + if (values.size() > 0) { + for (int s = 0; s < selected.getPositionCount(); s++) { + int group = selected.getInt(s); + int count = -selectedCounts[group]; + selectedCounts[group] = total; + total += count; + } } /* diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java index 4bfc230d7e1f7..eb9f7beb19908 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java @@ -180,11 +180,13 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5 */ int total = 0; - for (int s = 0; s < selected.getPositionCount(); s++) { - int group = selected.getInt(s); - int count = -selectedCounts[group]; - selectedCounts[group] = total; - total += count; + if (values.size() > 0) { + for (int s = 0; s < selected.getPositionCount(); s++) { + int group = selected.getInt(s); + int count = -selectedCounts[group]; + selectedCounts[group] = total; + total += count; + } } /* diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregators.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregators.java index 78a083b8daac7..c8094b348ca8f 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregators.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregators.java @@ -55,7 +55,7 @@ public void add(int positionOffset, IntArrayBlock groupIds) { int valuesStart = ordinalIds.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + ordinalIds.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { - state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(v))); + state.addValue(groupId, hashIds.getInt(ordinalIds.getInt(v))); } } } @@ -77,7 +77,7 @@ public void add(int positionOffset, IntBigArrayBlock groupIds) { int valuesStart = ordinalIds.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + ordinalIds.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { - state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(v))); + state.addValue(groupId, hashIds.getInt(ordinalIds.getInt(v))); } } } @@ -93,7 +93,7 @@ public void add(int positionOffset, IntVector groupIds) { int valuesStart = ordinalIds.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + ordinalIds.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { - state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(v))); + state.addValue(groupId, hashIds.getInt(ordinalIds.getInt(v))); } } } @@ -135,7 +135,7 @@ public void add(int positionOffset, IntArrayBlock groupIds) { int groupEnd = groupStart + groupIds.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groupIds.getInt(g); - state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset))); + state.addValue(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset))); } } } @@ -150,7 +150,7 @@ public void add(int positionOffset, IntBigArrayBlock groupIds) { int groupEnd = groupStart + groupIds.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groupIds.getInt(g); - state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset))); + state.addValue(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset))); } } } @@ -159,7 +159,7 @@ public void add(int positionOffset, IntBigArrayBlock groupIds) { public void add(int positionOffset, IntVector groupIds) { for (int groupPosition = 0; groupPosition < groupIds.getPositionCount(); groupPosition++) { int groupId = groupIds.getInt(groupPosition); - state.values.add(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset))); + state.addValue(groupId, hashIds.getInt(ordinalIds.getInt(groupPosition + positionOffset))); } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st index 67f32fc4a4d4e..2c899105516b9 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st @@ -14,6 +14,7 @@ import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.util.BigArrays; $if(BytesRef)$ import org.elasticsearch.common.util.BytesRefHash; +import org.elasticsearch.common.util.IntArray; $else$ import org.elasticsearch.common.util.LongHash; $endif$ @@ -118,7 +119,7 @@ $if(long)$ $elseif(double)$ state.values.add(groupId, Double.doubleToLongBits(v)); $elseif(BytesRef)$ - state.values.add(groupId, BlockHash.hashOrdToGroup(state.bytes.add(v))); + state.addValue(groupId, Math.toIntExact(BlockHash.hashOrdToGroup(state.bytes.add(v)))); $elseif(int)$ /* * Encode the groupId and value into a single long - @@ -152,6 +153,14 @@ $endif$ public static void combineStates(GroupingState current, int currentGroupId, GroupingState state, int statePosition) { $if(BytesRef)$ BytesRef scratch = new BytesRef(); + if (statePosition >= state.firstValues.size()) { + return; + } + int firstValue = state.firstValues.get(statePosition) - 1; + if (firstValue < 0) { + return; + } + combine(current, currentGroupId, state.bytes.get(firstValue, scratch)); $endif$ for (int id = 0; id < state.values.size(); id++) { $if(long||BytesRef)$ @@ -259,7 +268,9 @@ $if(long||double)$ private final LongLongHash values; $elseif(BytesRef)$ - final LongLongHash values; + private final BigArrays bigArrays; + private final LongLongHash values; + private IntArray firstValues; // the first value ordinal+1 collected in each group, 0 means no value BytesRefHash bytes; $elseif(int||float)$ @@ -270,19 +281,23 @@ $endif$ $if(long||double)$ values = new LongLongHash(1, bigArrays); $elseif(BytesRef)$ + this.bigArrays = bigArrays; LongLongHash _values = null; BytesRefHash _bytes = null; + IntArray _firstValues = null; try { _values = new LongLongHash(1, bigArrays); _bytes = new BytesRefHash(1, bigArrays); - + _firstValues = bigArrays.newIntArray(1); values = _values; bytes = _bytes; + firstValues = _firstValues; _values = null; _bytes = null; + _firstValues = null; } finally { - Releasables.closeExpectNoException(_values, _bytes); + Releasables.closeExpectNoException(_values, _bytes, _firstValues); } $elseif(int||float)$ values = new LongHash(1, bigArrays); @@ -299,7 +314,7 @@ $endif$ * groups. This is the implementation of the final and intermediate results of the agg. */ Block toBlock(BlockFactory blockFactory, IntVector selected) { - if (values.size() == 0) { + if ($if(BytesRef)$bytes$else$values$endif$.size() == 0) { return blockFactory.newConstantNullBlock(selected.getPositionCount()); } @@ -348,11 +363,13 @@ $endif$ * Then the total is 9 and the counts array will contain 0, 3, -2, 4, 5 */ int total = 0; - for (int s = 0; s < selected.getPositionCount(); s++) { - int group = selected.getInt(s); - int count = -selectedCounts[group]; - selectedCounts[group] = total; - total += count; + if (values.size() > 0) { + for (int s = 0; s < selected.getPositionCount(); s++) { + int group = selected.getInt(s); + int count = -selectedCounts[group]; + selectedCounts[group] = total; + total += count; + } } /* @@ -390,7 +407,7 @@ $endif$ } } $if(BytesRef)$ - if (OrdinalBytesRefBlock.isDense(selected.getPositionCount(), Math.toIntExact(values.size()))) { + if (OrdinalBytesRefBlock.isDense(firstValues.size() + values.size(), bytes.size())) { return buildOrdinalOutputBlock(blockFactory, selected, selectedCounts, ids); } else { return buildOutputBlock(blockFactory, selected, selectedCounts, ids); @@ -403,13 +420,56 @@ $endif$ } } +$if(BytesRef)$ + void addValue(int groupId, int valueOrdinal) { + if (groupId < firstValues.size()) { + final int curr = firstValues.get(groupId) - 1; + if (curr == -1) { + firstValues.set(groupId, valueOrdinal + 1); + } else if (curr != valueOrdinal) { + values.add(groupId, valueOrdinal); + } + } else { + firstValues = bigArrays.grow(firstValues, groupId + 1); + firstValues.set(groupId, valueOrdinal + 1); + } + } +$endif$ + Block buildOutputBlock(BlockFactory blockFactory, IntVector selected, int[] selectedCounts, int[] ids) { /* * Insert the ids in order. */ $if(BytesRef)$ BytesRef scratch = new BytesRef(); -$endif$ + try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(selected.getPositionCount())) { + int start = 0; + for (int s = 0; s < selected.getPositionCount(); s++) { + int group = selected.getInt(s); + int firstValue = group < firstValues.size() ? firstValues.get(group) - 1 : -1; + if (firstValue == -1) { + assert selectedCounts[group] == start : selectedCounts[group] + " != " + start; + builder.appendNull(); + } else { + int end = selectedCounts[group]; + int count = end - start; + if (count == 0) { + builder.appendBytesRef(bytes.get(firstValue, scratch)); + } else { + builder.beginPositionEntry(); + builder.appendBytesRef(bytes.get(firstValue, scratch)); + for (int i = start; i < end; i++) { + append(builder, ids[i], scratch); + } + builder.endPositionEntry(); + } + start = end; + } + + } + return builder.build(); + } +$else$ try ($Type$Block.Builder builder = blockFactory.new$Type$BlockBuilder(selected.getPositionCount())) { int start = 0; for (int s = 0; s < selected.getPositionCount(); s++) { @@ -431,6 +491,7 @@ $endif$ } return builder.build(); } +$endif$ } $if(BytesRef)$ @@ -444,20 +505,25 @@ $if(BytesRef)$ int start = 0; for (int s = 0; s < selected.getPositionCount(); s++) { int group = selected.getInt(s); - int end = selectedCounts[group]; - int count = end - start; - switch (count) { - case 0 -> builder.appendNull(); - case 1 -> builder.appendInt(Math.toIntExact(values.getKey2(ids[start]))); - default -> { + int firstValue = group < firstValues.size() ? firstValues.get(group) - 1 : -1; + if (firstValue == -1) { + assert selectedCounts[group] == start : selectedCounts[group] + " != " + start; + builder.appendNull(); + } else { + int end = selectedCounts[group]; + int count = end - start; + if (count == 0) { + builder.appendInt(firstValue); + } else { builder.beginPositionEntry(); + builder.appendInt(firstValue); for (int i = start; i < end; i++) { builder.appendInt(Math.toIntExact(values.getKey2(ids[i]))); } builder.endPositionEntry(); } + start = end; } - start = end; } ordinals = builder.build(); dict = blockFactory.newBytesRefArrayVector(dictArray, Math.toIntExact(dictArray.size())); @@ -501,7 +567,7 @@ $endif$ @Override public void close() { $if(BytesRef)$ - Releasables.closeExpectNoException(values, bytes); + Releasables.closeExpectNoException(values, bytes, firstValues); $else$ values.close(); $endif$