Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,11 @@ public void processInstruction( ExecutionContext ec ) {
int blen = ConfigurationManager.getBlocksize();

if (aggun.isRowAggregate() || aggun.isColAggregate()) {
// intermediate state per aggregation index
HashMap<Long, MatrixBlock> aggs = new HashMap<>(); // partial aggregates
HashMap<Long, MatrixBlock> corrs = new HashMap<>(); // correction blocks
HashMap<Long, Integer> cnt = new HashMap<>(); // processed block count per agg idx

DataCharacteristics chars = ec.getDataCharacteristics(input1.getName());
// number of blocks to process per aggregation idx (row or column dim)
long nBlocks = aggun.isRowAggregate()? chars.getNumColBlocks() : chars.getNumRowBlocks();
long emissionThreshold = aggun.isRowAggregate()? chars.getNumColBlocks() : chars.getNumRowBlocks();
OOCMatrixBlockTracker aggTracker = new OOCMatrixBlockTracker(emissionThreshold);
HashMap<Long, MatrixBlock> corrs = new HashMap<>(); // correction blocks

LocalTaskQueue<IndexedMatrixValue> qOut = new LocalTaskQueue<>();
ec.getMatrixObject(output).setStreamHandle(qOut);
Expand All @@ -94,9 +91,8 @@ public void processInstruction( ExecutionContext ec ) {
while((tmp = q.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) {
long idx = aggun.isRowAggregate() ?
tmp.getIndexes().getRowIndex() : tmp.getIndexes().getColumnIndex();
if(aggs.containsKey(idx)) {
// update existing partial aggregate for this idx
MatrixBlock ret = aggs.get(idx);
MatrixBlock ret = aggTracker.get(idx);
if(ret != null) {
MatrixBlock corr = corrs.get(idx);

// aggregation
Expand All @@ -105,17 +101,18 @@ public void processInstruction( ExecutionContext ec ) {
OperationsOnMatrixValues.incrementalAggregation(ret,
_aop.existsCorrection() ? corr : null, ltmp, _aop, true);

aggs.replace(idx, ret);
corrs.replace(idx, corr);
cnt.replace(idx, cnt.get(idx) + 1);
if (!aggTracker.putAndIncrementCount(idx, ret)){
corrs.replace(idx, corr);
continue;
}
}
else {
// first block for this idx - init aggregate and correction
// TODO avoid corr block for inplace incremental aggregation
int rows = tmp.getValue().getNumRows();
int cols = tmp.getValue().getNumColumns();
int extra = _aop.correction.getNumRemovedRowsColumns();
MatrixBlock ret = aggun.isRowAggregate()? new MatrixBlock(rows, 1 + extra, false) : new MatrixBlock(1 + extra, cols, false);
ret = aggun.isRowAggregate()? new MatrixBlock(rows, 1 + extra, false) : new MatrixBlock(1 + extra, cols, false);
MatrixBlock corr = aggun.isRowAggregate()? new MatrixBlock(rows, 1 + extra, false) : new MatrixBlock(1 + extra, cols, false);

// aggregation
Expand All @@ -124,25 +121,22 @@ public void processInstruction( ExecutionContext ec ) {
OperationsOnMatrixValues.incrementalAggregation(ret,
_aop.existsCorrection() ? corr : null, ltmp, _aop, true);

aggs.put(idx, ret);
corrs.put(idx, corr);
cnt.put(idx, 1);
if(emissionThreshold > 1){
aggTracker.putAndIncrementCount(idx, ret);
corrs.put(idx, corr);
continue;
}
}

if(cnt.get(idx) == nBlocks) {
// all input blocks for this idx processed - emit aggregated block
MatrixBlock ret = aggs.get(idx);
// drop correction row/col
ret.dropLastRowsOrColumns(_aop.correction);
MatrixIndexes midx = aggun.isRowAggregate()? new MatrixIndexes(tmp.getIndexes().getRowIndex(), 1) : new MatrixIndexes(1, tmp.getIndexes().getColumnIndex());
IndexedMatrixValue tmpOut = new IndexedMatrixValue(midx, ret);
// all input blocks for this idx processed - emit aggregated block
ret.dropLastRowsOrColumns(_aop.correction);
MatrixIndexes midx = aggun.isRowAggregate()? new MatrixIndexes(tmp.getIndexes().getRowIndex(), 1) : new MatrixIndexes(1, tmp.getIndexes().getColumnIndex());
IndexedMatrixValue tmpOut = new IndexedMatrixValue(midx, ret);

qOut.enqueueTask(tmpOut);
// drop intermediate states
aggs.remove(idx);
corrs.remove(idx);
cnt.remove(idx);
}
qOut.enqueueTask(tmpOut);
// drop intermediate states
aggTracker.remove(idx);
corrs.remove(idx);
}
qOut.closeInput();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ public void processInstruction( ExecutionContext ec ) {
}

// number of colBlocks for early block output
long nBlocks = min.getDataCharacteristics().getNumColBlocks();
long emissionThreshold = min.getDataCharacteristics().getNumColBlocks();
OOCMatrixBlockTracker aggTracker = new OOCMatrixBlockTracker(emissionThreshold);

LocalTaskQueue<IndexedMatrixValue> qIn = min.getStreamHandle();
LocalTaskQueue<IndexedMatrixValue> qOut = new LocalTaskQueue<>();
Expand All @@ -95,8 +96,6 @@ public void processInstruction( ExecutionContext ec ) {
pool.submit(() -> {
IndexedMatrixValue tmp = null;
try {
HashMap<Long, MatrixBlock> partialResults = new HashMap<>();
HashMap<Long, Integer> cnt = new HashMap<>();
while((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) {
MatrixBlock matrixBlock = (MatrixBlock) tmp.getValue();
long rowIndex = tmp.getIndexes().getRowIndex();
Expand All @@ -108,31 +107,22 @@ public void processInstruction( ExecutionContext ec ) {
matrixBlock, vectorSlice, new MatrixBlock(), (AggregateBinaryOperator) _optr);

// for a single column block, no aggregation needed
if( min.getNumColumns() <= min.getBlocksize() ) {
if(emissionThreshold == 1) {
qOut.enqueueTask(new IndexedMatrixValue(tmp.getIndexes(), partialResult));
}
else {
// aggregation
MatrixBlock currAgg = partialResults.get(rowIndex);
MatrixBlock currAgg = aggTracker.get(rowIndex);
if (currAgg == null) {
partialResults.put(rowIndex, partialResult);
cnt.put(rowIndex, 1);
aggTracker.putAndIncrementCount(rowIndex, partialResult);
}
else {
currAgg.binaryOperationsInPlace(plus, partialResult);
int newCnt = cnt.get(rowIndex) + 1;

if(newCnt == nBlocks){
currAgg = currAgg.binaryOperations(plus, partialResult);
if (aggTracker.putAndIncrementCount(rowIndex, currAgg)){
// early block output: emit aggregated block
MatrixIndexes idx = new MatrixIndexes(rowIndex, 1L);
MatrixBlock result = partialResults.get(rowIndex);
qOut.enqueueTask(new IndexedMatrixValue(idx, result));
partialResults.remove(rowIndex);
cnt.remove(rowIndex);
}
else {
// maintain aggregation counts if not output-ready yet
cnt.replace(rowIndex, newCnt);
qOut.enqueueTask(new IndexedMatrixValue(idx, currAgg));
aggTracker.remove(rowIndex);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,11 @@
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.Operator;

import java.util.HashMap;

public abstract class OOCInstruction extends Instruction {
protected static final Log LOG = LogFactory.getLog(OOCInstruction.class.getName());

Expand Down Expand Up @@ -82,4 +85,61 @@ public void postprocessInstruction(ExecutionContext ec) {
if(DMLScript.LINEAGE_DEBUGGER)
ec.maintainLineageDebuggerInfo(this);
}

/**
 * Tracks blocks and their counts to enable early emission
 * once all blocks for a given index are processed.
 *
 * <p>For every index the tracker keeps the most recently stored block and the
 * number of increments seen so far. When an increment makes the count equal to
 * the emission threshold, the increment methods return {@code true} and the
 * stored count is left untouched; callers are expected to emit the block and
 * then call {@link #remove(Long)} to drop the state for that index.
 *
 * <p>The class performs no synchronization of its own.
 */
public static class OOCMatrixBlockTracker {
	/** Number of increments per index at which a block is reported as ready to emit. */
	private final long emissionThreshold;
	/** Latest block stored per index. */
	private final HashMap<Long, MatrixBlock> blocks;
	/** Number of increments recorded per index (only counts below the threshold are stored). */
	private final HashMap<Long, Integer> cnt;

	/**
	 * Creates an empty tracker.
	 * @param emissionThreshold number of increments after which an index is ready to emit
	 */
	public OOCMatrixBlockTracker(long emissionThreshold) {
		this.emissionThreshold = emissionThreshold;
		this.blocks = new HashMap<>();
		this.cnt = new HashMap<>();
	}

	/**
	 * Adds or updates a block for the given index and updates its internal count.
	 * @param idx block index
	 * @param block MatrixBlock
	 * @return true if the block count reached the threshold (ready to emit), false otherwise
	 */
	public boolean putAndIncrementCount(Long idx, MatrixBlock block) {
		blocks.put(idx, block);
		return incrementCount(idx);
	}

	/**
	 * Increments the count of the given index without touching its stored block.
	 * An index without a stored count is treated as having a count of zero, which
	 * matches {@link #putAndIncrementCount(Long, MatrixBlock)}.
	 * @param idx block index
	 * @return true if the block count reached the threshold (ready to emit), false otherwise
	 */
	public boolean incrementCount(Long idx) {
		// getOrDefault avoids a NullPointerException on unboxing for untracked indexes
		int newCnt = cnt.getOrDefault(idx, 0) + 1;
		if (newCnt == emissionThreshold) {
			// threshold reached: the count is intentionally not stored, the caller removes the index
			return true;
		}
		cnt.put(idx, newCnt);
		return false;
	}

	/**
	 * Stores a block for the given index and resets its count to zero.
	 * @param idx block index
	 * @param block MatrixBlock
	 */
	public void putAndInitCount(Long idx, MatrixBlock block) {
		blocks.put(idx, block);
		cnt.put(idx, 0);
	}

	/**
	 * Returns the block currently stored for the given index.
	 * @param idx block index
	 * @return the stored MatrixBlock, or null if no block is stored for this index
	 */
	public MatrixBlock get(Long idx) {
		return blocks.get(idx);
	}

	/**
	 * Drops the stored block and count of the given index.
	 * @param idx block index
	 */
	public void remove(Long idx) {
		blocks.remove(idx);
		cnt.remove(idx);
	}
}
}
Loading