From b17359efdf97bc60e2a96241dade610b91d449d2 Mon Sep 17 00:00:00 2001 From: Jessica Priebe Date: Sat, 30 Aug 2025 15:11:08 +0200 Subject: [PATCH] Add OOCMatrixBlockTracker --- .../ooc/AggregateUnaryOOCInstruction.java | 52 +++++++--------- .../ooc/MatrixVectorBinaryOOCInstruction.java | 28 +++------ .../instructions/ooc/OOCInstruction.java | 60 +++++++++++++++++++ 3 files changed, 92 insertions(+), 48 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java index a656cd337cd..7bbb5f2e54f 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java @@ -75,14 +75,11 @@ public void processInstruction( ExecutionContext ec ) { int blen = ConfigurationManager.getBlocksize(); if (aggun.isRowAggregate() || aggun.isColAggregate()) { - // intermediate state per aggregation index - HashMap aggs = new HashMap<>(); // partial aggregates - HashMap corrs = new HashMap<>(); // correction blocks - HashMap cnt = new HashMap<>(); // processed block count per agg idx - DataCharacteristics chars = ec.getDataCharacteristics(input1.getName()); // number of blocks to process per aggregation idx (row or column dim) - long nBlocks = aggun.isRowAggregate()? chars.getNumColBlocks() : chars.getNumRowBlocks(); + long emissionThreshold = aggun.isRowAggregate()? 
chars.getNumColBlocks() : chars.getNumRowBlocks(); + OOCMatrixBlockTracker aggTracker = new OOCMatrixBlockTracker(emissionThreshold); + HashMap corrs = new HashMap<>(); // correction blocks LocalTaskQueue qOut = new LocalTaskQueue<>(); ec.getMatrixObject(output).setStreamHandle(qOut); @@ -94,9 +91,8 @@ public void processInstruction( ExecutionContext ec ) { while((tmp = q.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { long idx = aggun.isRowAggregate() ? tmp.getIndexes().getRowIndex() : tmp.getIndexes().getColumnIndex(); - if(aggs.containsKey(idx)) { - // update existing partial aggregate for this idx - MatrixBlock ret = aggs.get(idx); + MatrixBlock ret = aggTracker.get(idx); + if(ret != null) { MatrixBlock corr = corrs.get(idx); // aggregation @@ -105,9 +101,10 @@ public void processInstruction( ExecutionContext ec ) { OperationsOnMatrixValues.incrementalAggregation(ret, _aop.existsCorrection() ? corr : null, ltmp, _aop, true); - aggs.replace(idx, ret); - corrs.replace(idx, corr); - cnt.replace(idx, cnt.get(idx) + 1); + if (!aggTracker.putAndIncrementCount(idx, ret)){ + corrs.replace(idx, corr); + continue; + } } else { // first block for this idx - init aggregate and correction @@ -115,7 +112,7 @@ public void processInstruction( ExecutionContext ec ) { int rows = tmp.getValue().getNumRows(); int cols = tmp.getValue().getNumColumns(); int extra = _aop.correction.getNumRemovedRowsColumns(); - MatrixBlock ret = aggun.isRowAggregate()? new MatrixBlock(rows, 1 + extra, false) : new MatrixBlock(1 + extra, cols, false); + ret = aggun.isRowAggregate()? new MatrixBlock(rows, 1 + extra, false) : new MatrixBlock(1 + extra, cols, false); MatrixBlock corr = aggun.isRowAggregate()? new MatrixBlock(rows, 1 + extra, false) : new MatrixBlock(1 + extra, cols, false); // aggregation @@ -124,25 +121,22 @@ public void processInstruction( ExecutionContext ec ) { OperationsOnMatrixValues.incrementalAggregation(ret, _aop.existsCorrection() ? 
corr : null, ltmp, _aop, true); - aggs.put(idx, ret); - corrs.put(idx, corr); - cnt.put(idx, 1); + if(emissionThreshold > 1){ + aggTracker.putAndIncrementCount(idx, ret); + corrs.put(idx, corr); + continue; + } } - if(cnt.get(idx) == nBlocks) { - // all input blocks for this idx processed - emit aggregated block - MatrixBlock ret = aggs.get(idx); - // drop correction row/col - ret.dropLastRowsOrColumns(_aop.correction); - MatrixIndexes midx = aggun.isRowAggregate()? new MatrixIndexes(tmp.getIndexes().getRowIndex(), 1) : new MatrixIndexes(1, tmp.getIndexes().getColumnIndex()); - IndexedMatrixValue tmpOut = new IndexedMatrixValue(midx, ret); + // all input blocks for this idx processed - emit aggregated block + ret.dropLastRowsOrColumns(_aop.correction); + MatrixIndexes midx = aggun.isRowAggregate()? new MatrixIndexes(tmp.getIndexes().getRowIndex(), 1) : new MatrixIndexes(1, tmp.getIndexes().getColumnIndex()); + IndexedMatrixValue tmpOut = new IndexedMatrixValue(midx, ret); - qOut.enqueueTask(tmpOut); - // drop intermediate states - aggs.remove(idx); - corrs.remove(idx); - cnt.remove(idx); - } + qOut.enqueueTask(tmpOut); + // drop intermediate states + aggTracker.remove(idx); + corrs.remove(idx); } qOut.closeInput(); } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java index 5e2d36d9df3..f82b3653acf 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java @@ -82,7 +82,8 @@ public void processInstruction( ExecutionContext ec ) { } // number of colBlocks for early block output - long nBlocks = min.getDataCharacteristics().getNumColBlocks(); + long emissionThreshold = min.getDataCharacteristics().getNumColBlocks(); + OOCMatrixBlockTracker aggTracker = new 
OOCMatrixBlockTracker(emissionThreshold); LocalTaskQueue qIn = min.getStreamHandle(); LocalTaskQueue qOut = new LocalTaskQueue<>(); @@ -95,8 +96,6 @@ public void processInstruction( ExecutionContext ec ) { pool.submit(() -> { IndexedMatrixValue tmp = null; try { - HashMap partialResults = new HashMap<>(); - HashMap cnt = new HashMap<>(); while((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { MatrixBlock matrixBlock = (MatrixBlock) tmp.getValue(); long rowIndex = tmp.getIndexes().getRowIndex(); @@ -108,31 +107,22 @@ public void processInstruction( ExecutionContext ec ) { matrixBlock, vectorSlice, new MatrixBlock(), (AggregateBinaryOperator) _optr); // for single column block, no aggregation neeeded - if( min.getNumColumns() <= min.getBlocksize() ) { + if(emissionThreshold == 1) { qOut.enqueueTask(new IndexedMatrixValue(tmp.getIndexes(), partialResult)); } else { // aggregation - MatrixBlock currAgg = partialResults.get(rowIndex); + MatrixBlock currAgg = aggTracker.get(rowIndex); if (currAgg == null) { - partialResults.put(rowIndex, partialResult); - cnt.put(rowIndex, 1); + aggTracker.putAndIncrementCount(rowIndex, partialResult); } else { - currAgg.binaryOperationsInPlace(plus, partialResult); - int newCnt = cnt.get(rowIndex) + 1; - - if(newCnt == nBlocks){ + currAgg = currAgg.binaryOperations(plus, partialResult); + if (aggTracker.putAndIncrementCount(rowIndex, currAgg)){ // early block output: emit aggregated block MatrixIndexes idx = new MatrixIndexes(rowIndex, 1L); - MatrixBlock result = partialResults.get(rowIndex); - qOut.enqueueTask(new IndexedMatrixValue(idx, result)); - partialResults.remove(rowIndex); - cnt.remove(rowIndex); - } - else { - // maintain aggregation counts if not output-ready yet - cnt.replace(rowIndex, newCnt); + qOut.enqueueTask(new IndexedMatrixValue(idx, currAgg)); + aggTracker.remove(rowIndex); } } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java 
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
index d3c2dfcbd77..fa192a9ecde 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
@@ -24,8 +24,11 @@
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 
+import java.util.HashMap;
+
 public abstract class OOCInstruction extends Instruction {
 	protected static final Log LOG = LogFactory.getLog(OOCInstruction.class.getName());
 
@@ -82,4 +85,89 @@ public void postprocessInstruction(ExecutionContext ec) {
 		if(DMLScript.LINEAGE_DEBUGGER)
 			ec.maintainLineageDebuggerInfo(this);
 	}
+
+	/**
+	 * Tracks one block and one processed-block count per index to enable
+	 * early emission once the count of an index reaches the emission threshold.
+	 * <p>
+	 * Uses plain {@link HashMap}s and performs no synchronization.
+	 */
+	public static class OOCMatrixBlockTracker {
+		private final long emissionThreshold;
+		private final HashMap<Long, MatrixBlock> blocks;
+		private final HashMap<Long, Integer> cnt;
+
+		/**
+		 * Creates an empty tracker.
+		 *
+		 * @param emissionThreshold count at which an index is reported as ready to emit
+		 */
+		public OOCMatrixBlockTracker(long emissionThreshold) {
+			this.emissionThreshold = emissionThreshold;
+			this.blocks = new HashMap<>();
+			this.cnt = new HashMap<>();
+		}
+
+		/**
+		 * Adds or updates a block for the given index and updates its internal count.
+		 *
+		 * @param idx block index
+		 * @param block MatrixBlock
+		 * @return true if the block count reached the threshold (ready to emit), false otherwise
+		 */
+		public boolean putAndIncrementCount(Long idx, MatrixBlock block) {
+			blocks.put(idx, block);
+			return incrementCount(idx);
+		}
+
+		/**
+		 * Increments the count of the given index without touching its block.
+		 * An index without a stored count is treated as having count 0.
+		 *
+		 * @param idx block index
+		 * @return true if the block count reached the threshold (ready to emit), false otherwise
+		 */
+		public boolean incrementCount(Long idx) {
+			int newCnt = cnt.getOrDefault(idx, 0) + 1;
+			if (newCnt == emissionThreshold) {
+				// the stored count is deliberately not updated in this case;
+				// the entry only disappears through a later remove(idx)
+				return true;
+			}
+			cnt.put(idx, newCnt);
+			return false;
+		}
+
+		/**
+		 * Stores a block for the given index and sets its count to 0.
+		 *
+		 * @param idx block index
+		 * @param block MatrixBlock
+		 */
+		public void putAndInitCount(Long idx, MatrixBlock block) {
+			blocks.put(idx, block);
+			cnt.put(idx, 0);
+		}
+
+		/**
+		 * Returns the block currently stored for the given index.
+		 *
+		 * @param idx block index
+		 * @return the stored block, or null if no block is stored for idx
+		 */
+		public MatrixBlock get(Long idx) {
+			return blocks.get(idx);
+		}
+
+		/**
+		 * Drops the block and the count stored for the given index.
+		 *
+		 * @param idx block index
+		 */
+		public void remove(Long idx) {
+			blocks.remove(idx);
+			cnt.remove(idx);
+		}
+	}
 }