Skip to content

Commit

Permalink
Optimize lookup table in join operator
Browse files Browse the repository at this point in the history
  • Loading branch information
Jackie-Jiang committed Feb 2, 2025
1 parent be26325 commit 79e007a
Show file tree
Hide file tree
Showing 8 changed files with 558 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ protected abstract List<Object[]> buildJoinedRows(TransferableBlock leftBlock)

protected abstract List<Object[]> buildNonMatchRightRows();

// TODO: Optimize this to avoid unnecessary object copy.
protected Object[] joinRow(@Nullable Object[] leftRow, @Nullable Object[] rightRow) {
Object[] resultRow = new Object[_resultColumnSize];
if (leftRow != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,36 +31,64 @@
import org.apache.pinot.query.planner.plannode.JoinNode;
import org.apache.pinot.query.runtime.blocks.TransferableBlock;
import org.apache.pinot.query.runtime.blocks.TransferableBlockUtils;
import org.apache.pinot.query.runtime.operator.join.DoubleLookupTable;
import org.apache.pinot.query.runtime.operator.join.FloatLookupTable;
import org.apache.pinot.query.runtime.operator.join.IntLookupTable;
import org.apache.pinot.query.runtime.operator.join.LongLookupTable;
import org.apache.pinot.query.runtime.operator.join.LookupTable;
import org.apache.pinot.query.runtime.operator.join.ObjectLookupTable;
import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
import org.apache.pinot.spi.utils.BooleanUtils;
import org.apache.pinot.spi.utils.CommonConstants.MultiStageQueryRunner.JoinOverFlowMode;


/**
* The {@code HashJoinOperator} implements the hash join algorithm with join keys. The right table is materialized into
* a hash-based lookup table.
*/
// TODO: Support memory size based resource limit.
@SuppressWarnings("unchecked")
public class HashJoinOperator extends BaseJoinOperator {
private static final String EXPLAIN_NAME = "HASH_JOIN";
private static final int INITIAL_HEURISTIC_SIZE = 16;

// Placeholder for BitSet in _matchedRightRows when all keys are unique in the right table.
private static final BitSet BIT_SET_PLACEHOLDER = new BitSet(0);

private final KeySelector<?> _leftKeySelector;
private final KeySelector<?> _rightKeySelector;
private final Map<Object, ArrayList<Object[]>> _rightTable;
private final LookupTable _rightTable;
// Track matched right rows for right join and full join to output non-matched right rows.
// TODO: Revisit whether we should use IntList or RoaringBitmap for smaller memory footprint.
// TODO: Optimize this
private final Map<Object, BitSet> _matchedRightRows;

/**
 * Creates a hash join operator.
 *
 * <p>The left join keys are read once and reused for the precondition check, the left key selector and the choice
 * of lookup table. Every {@code final} field is assigned exactly once.
 *
 * @param context execution context passed through to the base join operator
 * @param leftInput operator producing the left side of the join
 * @param leftSchema schema of the left input; also used to pick the lookup table for the join key
 * @param rightInput operator producing the right side of the join
 * @param node join plan node providing the left/right join keys
 * @throws IllegalStateException if the join node has no left join keys
 */
public HashJoinOperator(OpChainExecutionContext context, MultiStageOperator leftInput, DataSchema leftSchema,
    MultiStageOperator rightInput, JoinNode node) {
  super(context, leftInput, leftSchema, rightInput, node);
  List<Integer> leftKeys = node.getLeftKeys();
  Preconditions.checkState(!leftKeys.isEmpty(), "Hash join operator requires join keys");
  _leftKeySelector = KeySelectorFactory.getKeySelector(leftKeys);
  _rightKeySelector = KeySelectorFactory.getKeySelector(node.getRightKeys());
  // NOTE(review): the lookup table is chosen from the LEFT key column; this presumes the right key column has the
  // same stored type - confirm against the planner.
  _rightTable = createLookupTable(leftKeys, leftSchema);
  // Only allocated when unmatched right rows must be emitted; null otherwise.
  _matchedRightRows = needUnmatchedRightRows() ? new HashMap<>() : null;
}

/**
 * Picks the lookup table implementation for the given join keys.
 *
 * <p>A single join key whose stored type is INT, LONG, FLOAT or DOUBLE gets the matching type-specific lookup
 * table. Multi-column keys and every other stored type use {@link ObjectLookupTable}.
 *
 * @param joinKeys column indices of the join keys within {@code schema}
 * @param schema schema used to resolve the stored type of the join key column
 * @return a new, empty lookup table
 */
private static LookupTable createLookupTable(List<Integer> joinKeys, DataSchema schema) {
  if (joinKeys.size() <= 1) {
    // Single key: specialize on the stored type where a type-specific table exists.
    switch (schema.getColumnDataType(joinKeys.get(0)).getStoredType()) {
      case INT:
        return new IntLookupTable();
      case LONG:
        return new LongLookupTable();
      case FLOAT:
        return new FloatLookupTable();
      case DOUBLE:
        return new DoubleLookupTable();
      default:
        break;
    }
  }
  // Multi-column key, or a stored type without a specialized table.
  return new ObjectLookupTable();
}

@Override
public String toExplainString() {
return EXPLAIN_NAME;
Expand All @@ -71,41 +99,35 @@ protected void buildRightTable()
throws ProcessingException {
LOGGER.trace("Building hash table for join operator");
long startTime = System.currentTimeMillis();
int numRowsInHashTable = 0;
int numRows = 0;
TransferableBlock rightBlock = _rightInput.nextBlock();
while (!TransferableBlockUtils.isEndOfStream(rightBlock)) {
List<Object[]> container = rightBlock.getContainer();
List<Object[]> rows = rightBlock.getContainer();
// Row based overflow check.
if (container.size() + numRowsInHashTable > _maxRowsInJoin) {
if (rows.size() + numRows > _maxRowsInJoin) {
if (_joinOverflowMode == JoinOverFlowMode.THROW) {
throwProcessingExceptionForJoinRowLimitExceeded(
"Cannot build in memory hash table for join operator, reached number of rows limit: " + _maxRowsInJoin);
} else {
// Just fill up the buffer.
int remainingRows = _maxRowsInJoin - numRowsInHashTable;
container = container.subList(0, remainingRows);
int remainingRows = _maxRowsInJoin - numRows;
rows = rows.subList(0, remainingRows);
_statMap.merge(StatKey.MAX_ROWS_IN_JOIN_REACHED, true);
// setting only the rightTableOperator to be early terminated and awaits EOS block next.
_rightInput.earlyTerminate();
}
}
// put all the rows into corresponding hash collections keyed by the key selector function.
for (Object[] row : container) {
ArrayList<Object[]> hashCollection =
_rightTable.computeIfAbsent(_rightKeySelector.getKey(row), k -> new ArrayList<>(INITIAL_HEURISTIC_SIZE));
int size = hashCollection.size();
if ((size & size - 1) == 0 && size < _maxRowsInJoin && size < Integer.MAX_VALUE / 2) { // is power of 2
hashCollection.ensureCapacity(Math.min(size << 1, _maxRowsInJoin));
}
hashCollection.add(row);
for (Object[] row : rows) {
_rightTable.addRow(_rightKeySelector.getKey(row), row);
}
numRowsInHashTable += container.size();
numRows += rows.size();
sampleAndCheckInterruption();
rightBlock = _rightInput.nextBlock();
}
if (rightBlock.isErrorBlock()) {
_upstreamErrorBlock = rightBlock;
} else {
_rightTable.finish();
_isRightTableBuilt = true;
_rightSideStats = rightBlock.getQueryStats();
assert _rightSideStats != null;
Expand All @@ -123,69 +145,99 @@ protected List<Object[]> buildJoinedRows(TransferableBlock leftBlock)
case ANTI:
return buildJoinedDataBlockAnti(leftBlock);
default: { // INNER, LEFT, RIGHT, FULL
return buildJoinedDataBlockDefault(leftBlock);
if (_rightTable.isKeysUnique()) {
return buildJoinedDataBlockUniqueKeys(leftBlock);
} else {
return buildJoinedDataBlockDuplicateKeys(leftBlock);
}
}
}
}

private List<Object[]> buildJoinedDataBlockDefault(TransferableBlock leftBlock)
private List<Object[]> buildJoinedDataBlockUniqueKeys(TransferableBlock leftBlock)
throws ProcessingException {
List<Object[]> container = leftBlock.getContainer();
ArrayList<Object[]> rows = new ArrayList<>(container.size());
List<Object[]> leftRows = leftBlock.getContainer();
ArrayList<Object[]> rows = new ArrayList<>(leftRows.size());

for (Object[] leftRow : container) {
for (Object[] leftRow : leftRows) {
Object key = _leftKeySelector.getKey(leftRow);
// NOTE: Empty key selector will always give same hash code.
List<Object[]> rightRows = _rightTable.get(key);
if (rightRows == null) {
if (needUnmatchedLeftRows()) {
if (isMaxRowsLimitReached(rows.size())) {
break;
}
rows.add(joinRow(leftRow, null));
}
continue;
}
boolean hasMatchForLeftRow = false;
int numRightRows = rightRows.size();
rows.ensureCapacity(rows.size() + numRightRows);
boolean maxRowsLimitReached = false;
for (int i = 0; i < numRightRows; i++) {
Object[] rightRow = rightRows.get(i);
// TODO: Optimize this to avoid unnecessary object copy.
Object[] rightRow = (Object[]) _rightTable.lookup(key);
if (rightRow == null) {
handleUnmatchedLeftRow(leftRow, rows);
} else {
Object[] resultRow = joinRow(leftRow, rightRow);
if (_nonEquiEvaluators.isEmpty() || _nonEquiEvaluators.stream()
.allMatch(evaluator -> BooleanUtils.isTrueInternalValue(evaluator.apply(resultRow)))) {
if (matchNonEquiConditions(resultRow)) {
if (isMaxRowsLimitReached(rows.size())) {
maxRowsLimitReached = true;
break;
}
rows.add(resultRow);
hasMatchForLeftRow = true;
if (_matchedRightRows != null) {
_matchedRightRows.computeIfAbsent(key, k -> new BitSet(numRightRows)).set(i);
_matchedRightRows.put(key, BIT_SET_PLACEHOLDER);
}
} else {
handleUnmatchedLeftRow(leftRow, rows);
}
}
if (maxRowsLimitReached) {
break;
}
if (!hasMatchForLeftRow && needUnmatchedLeftRows()) {
if (isMaxRowsLimitReached(rows.size())) {
}

return rows;
}

/**
 * Joins a left block against the right table when at least one key in the right table maps to multiple rows.
 *
 * <p>For each left row, every right row with the same key is joined and kept if it passes the non-equi
 * conditions. The index of each matched right row is recorded in {@code _matchedRightRows} when that map is in
 * use. A left row without any surviving match is passed to {@link #handleUnmatchedLeftRow} exactly once.
 * Processing of the block stops as soon as the row limit check reports that the limit is reached for a matched row.
 *
 * @param leftBlock block of left rows to join
 * @return the joined rows produced for this block
 * @throws ProcessingException propagated from the row limit check
 */
private List<Object[]> buildJoinedDataBlockDuplicateKeys(TransferableBlock leftBlock)
    throws ProcessingException {
  List<Object[]> leftRows = leftBlock.getContainer();
  List<Object[]> rows = new ArrayList<>(leftRows.size());

  for (Object[] leftRow : leftRows) {
    Object key = _leftKeySelector.getKey(leftRow);
    // With duplicate keys the lookup table stores a list of right rows per key.
    List<Object[]> rightRows = (List<Object[]>) _rightTable.lookup(key);
    if (rightRows == null) {
      handleUnmatchedLeftRow(leftRow, rows);
    } else {
      boolean maxRowsLimitReached = false;
      boolean hasMatchForLeftRow = false;
      int numRightRows = rightRows.size();
      for (int i = 0; i < numRightRows; i++) {
        Object[] resultRow = joinRow(leftRow, rightRows.get(i));
        if (matchNonEquiConditions(resultRow)) {
          if (isMaxRowsLimitReached(rows.size())) {
            maxRowsLimitReached = true;
            break;
          }
          rows.add(resultRow);
          hasMatchForLeftRow = true;
          if (_matchedRightRows != null) {
            // Track which right rows under this key have matched, by their index in the list.
            _matchedRightRows.computeIfAbsent(key, k -> new BitSet(numRightRows)).set(i);
          }
        }
      }
      if (maxRowsLimitReached) {
        break;
      }
      // Emit the unmatched-left row only when no right row survived the non-equi conditions. Adding it
      // unconditionally here would produce spurious null-padded rows for every key hit.
      if (!hasMatchForLeftRow) {
        handleUnmatchedLeftRow(leftRow, rows);
      }
    }
  }

  return rows;
}

/**
 * Appends the null-padded row for a left row that found no match, if this join needs unmatched left rows.
 *
 * <p>Nothing is added when unmatched left rows are not needed, or when the row limit check reports that the limit
 * is already reached for the current size of {@code rows}.
 *
 * @param leftRow the left row without a match
 * @param rows output rows collected so far for the current block; may be appended to
 * @throws ProcessingException propagated from the row limit check
 */
private void handleUnmatchedLeftRow(Object[] leftRow, List<Object[]> rows)
    throws ProcessingException {
  if (!needUnmatchedLeftRows()) {
    return;
  }
  if (!isMaxRowsLimitReached(rows.size())) {
    rows.add(joinRow(leftRow, null));
  }
}

private List<Object[]> buildJoinedDataBlockSemi(TransferableBlock leftBlock) {
List<Object[]> container = leftBlock.getContainer();
List<Object[]> rows = new ArrayList<>(container.size());
List<Object[]> leftRows = leftBlock.getContainer();
List<Object[]> rows = new ArrayList<>(leftRows.size());

for (Object[] leftRow : container) {
for (Object[] leftRow : leftRows) {
Object key = _leftKeySelector.getKey(leftRow);
// SEMI-JOIN only checks existence of the key
if (_rightTable.containsKey(key)) {
Expand All @@ -197,10 +249,10 @@ private List<Object[]> buildJoinedDataBlockSemi(TransferableBlock leftBlock) {
}

private List<Object[]> buildJoinedDataBlockAnti(TransferableBlock leftBlock) {
List<Object[]> container = leftBlock.getContainer();
List<Object[]> rows = new ArrayList<>(container.size());
List<Object[]> leftRows = leftBlock.getContainer();
List<Object[]> rows = new ArrayList<>(leftRows.size());

for (Object[] leftRow : container) {
for (Object[] leftRow : leftRows) {
Object key = _leftKeySelector.getKey(leftRow);
// ANTI-JOIN only checks non-existence of the key
if (!_rightTable.containsKey(key)) {
Expand All @@ -214,18 +266,27 @@ private List<Object[]> buildJoinedDataBlockAnti(TransferableBlock leftBlock) {
@Override
protected List<Object[]> buildNonMatchRightRows() {
List<Object[]> rows = new ArrayList<>();
for (Map.Entry<Object, ArrayList<Object[]>> entry : _rightTable.entrySet()) {
List<Object[]> rightRows = entry.getValue();
BitSet matchedIndices = _matchedRightRows.get(entry.getKey());
if (matchedIndices == null) {
for (Object[] rightRow : rightRows) {
if (_rightTable.isKeysUnique()) {
for (Map.Entry<Object, Object[]> entry : _rightTable.entrySet()) {
Object[] rightRow = entry.getValue();
if (!_matchedRightRows.containsKey(entry.getKey())) {
rows.add(joinRow(null, rightRow));
}
} else {
int numRightRows = rightRows.size();
int unmatchedIndex = 0;
while ((unmatchedIndex = matchedIndices.nextClearBit(unmatchedIndex)) < numRightRows) {
rows.add(joinRow(null, rightRows.get(unmatchedIndex++)));
}
} else {
for (Map.Entry<Object, ArrayList<Object[]>> entry : _rightTable.entrySet()) {
List<Object[]> rightRows = entry.getValue();
BitSet matchedIndices = _matchedRightRows.get(entry.getKey());
if (matchedIndices == null) {
for (Object[] rightRow : rightRows) {
rows.add(joinRow(null, rightRow));
}
} else {
int numRightRows = rightRows.size();
int unmatchedIndex = 0;
while ((unmatchedIndex = matchedIndices.nextClearBit(unmatchedIndex)) < numRightRows) {
rows.add(joinRow(null, rightRows.get(unmatchedIndex++)));
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.pinot.query.runtime.operator.join;

import it.unimi.dsi.fastutil.doubles.Double2ObjectMap;
import it.unimi.dsi.fastutil.doubles.Double2ObjectOpenHashMap;
import java.util.Map;
import java.util.Set;
import javax.annotation.Nullable;


/**
 * The {@code DoubleLookupTable} is a lookup table for double keys, backed by a primitive-keyed fastutil map.
 *
 * <p>Keys are passed in as {@link Object} and cast to {@code double}, so callers must supply {@code Double} keys.
 */
@SuppressWarnings("unchecked")
public class DoubleLookupTable extends LookupTable {
  private final Double2ObjectOpenHashMap<Object> _lookupTable = new Double2ObjectOpenHashMap<>(INITIAL_CAPACITY);

  /**
   * Adds a row under the given key. The stored value is computed by {@code calculateValue} from the new row and the
   * value already stored for the key, if any.
   *
   * @param key join key; must be a {@code Double}
   * @param row row to add
   */
  @Override
  public void addRow(Object key, Object[] row) {
    _lookupTable.compute((double) key, (k, v) -> calculateValue(row, v));
  }

  /**
   * Finalizes the table after all rows are added. When the keys are not unique, every entry is passed to
   * {@code convertValueToList}; when they are unique, the entries are left untouched.
   */
  @Override
  public void finish() {
    if (!_keysUnique) {
      for (Double2ObjectMap.Entry<Object> entry : _lookupTable.double2ObjectEntrySet()) {
        convertValueToList(entry);
      }
    }
  }

  /**
   * Returns whether any row was added under the given key.
   *
   * @param key join key; must be a {@code Double}
   */
  @Override
  public boolean containsKey(Object key) {
    return _lookupTable.containsKey((double) key);
  }

  /**
   * Returns the value stored for the given key, or {@code null} when the key is absent.
   *
   * <p>The value is returned as {@link Object} without a cast. Judging by how {@code HashJoinOperator} consumes
   * it, it is a single row ({@code Object[]}) when keys are unique and a list of rows otherwise, so casting to
   * {@code Object[]} here would fail with a {@code ClassCastException} for tables with duplicate keys.
   *
   * @param key join key; must be a {@code Double}
   */
  @Nullable
  @Override
  public Object lookup(Object key) {
    return _lookupTable.get((double) key);
  }

  /**
   * Returns the entries of the underlying map as a raw entry set, so that callers can iterate over them without
   * depending on the primitive-keyed entry type.
   */
  @SuppressWarnings("rawtypes")
  @Override
  public Set<Map.Entry> entrySet() {
    return (Set) _lookupTable.double2ObjectEntrySet();
  }
}
Loading

0 comments on commit 79e007a

Please sign in to comment.