Skip to content

Commit df26e67

Browse files
jessicapriebe authored and mboehm7 committed
[SYSTEMDS-3918] New out-of-core tmp aggregation util
Closes #2318.
1 parent 186e499 commit df26e67

File tree

3 files changed

+89
-49
lines changed

3 files changed

+89
-49
lines changed

src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java

Lines changed: 26 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -75,14 +75,11 @@ public void processInstruction( ExecutionContext ec ) {
7575
int blen = ConfigurationManager.getBlocksize();
7676

7777
if (aggun.isRowAggregate() || aggun.isColAggregate()) {
78-
// intermediate state per aggregation index
79-
HashMap<Long, MatrixBlock> aggs = new HashMap<>(); // partial aggregates
80-
HashMap<Long, MatrixBlock> corrs = new HashMap<>(); // correction blocks
81-
HashMap<Long, Integer> cnt = new HashMap<>(); // processed block count per agg idx
82-
8378
DataCharacteristics chars = ec.getDataCharacteristics(input1.getName());
8479
// number of blocks to process per aggregation idx (row or column dim)
85-
long nBlocks = aggun.isRowAggregate()? chars.getNumColBlocks() : chars.getNumRowBlocks();
80+
long emitThreshold = aggun.isRowAggregate()? chars.getNumColBlocks() : chars.getNumRowBlocks();
81+
OOCMatrixBlockTracker aggTracker = new OOCMatrixBlockTracker(emitThreshold);
82+
HashMap<Long, MatrixBlock> corrs = new HashMap<>(); // correction blocks
8683

8784
LocalTaskQueue<IndexedMatrixValue> qOut = new LocalTaskQueue<>();
8885
ec.getMatrixObject(output).setStreamHandle(qOut);
@@ -94,9 +91,8 @@ public void processInstruction( ExecutionContext ec ) {
9491
while((tmp = q.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) {
9592
long idx = aggun.isRowAggregate() ?
9693
tmp.getIndexes().getRowIndex() : tmp.getIndexes().getColumnIndex();
97-
if(aggs.containsKey(idx)) {
98-
// update existing partial aggregate for this idx
99-
MatrixBlock ret = aggs.get(idx);
94+
MatrixBlock ret = aggTracker.get(idx);
95+
if(ret != null) {
10096
MatrixBlock corr = corrs.get(idx);
10197

10298
// aggregation
@@ -105,17 +101,18 @@ public void processInstruction( ExecutionContext ec ) {
105101
OperationsOnMatrixValues.incrementalAggregation(ret,
106102
_aop.existsCorrection() ? corr : null, ltmp, _aop, true);
107103

108-
aggs.replace(idx, ret);
109-
corrs.replace(idx, corr);
110-
cnt.replace(idx, cnt.get(idx) + 1);
104+
if (!aggTracker.putAndIncrementCount(idx, ret)){
105+
corrs.replace(idx, corr);
106+
continue;
107+
}
111108
}
112109
else {
113110
// first block for this idx - init aggregate and correction
114111
// TODO avoid corr block for inplace incremental aggregation
115112
int rows = tmp.getValue().getNumRows();
116113
int cols = tmp.getValue().getNumColumns();
117114
int extra = _aop.correction.getNumRemovedRowsColumns();
118-
MatrixBlock ret = aggun.isRowAggregate()? new MatrixBlock(rows, 1 + extra, false) : new MatrixBlock(1 + extra, cols, false);
115+
ret = aggun.isRowAggregate()? new MatrixBlock(rows, 1 + extra, false) : new MatrixBlock(1 + extra, cols, false);
119116
MatrixBlock corr = aggun.isRowAggregate()? new MatrixBlock(rows, 1 + extra, false) : new MatrixBlock(1 + extra, cols, false);
120117

121118
// aggregation
@@ -124,25 +121,24 @@ public void processInstruction( ExecutionContext ec ) {
124121
OperationsOnMatrixValues.incrementalAggregation(ret,
125122
_aop.existsCorrection() ? corr : null, ltmp, _aop, true);
126123

127-
aggs.put(idx, ret);
128-
corrs.put(idx, corr);
129-
cnt.put(idx, 1);
124+
if(emitThreshold > 1){
125+
aggTracker.putAndIncrementCount(idx, ret);
126+
corrs.put(idx, corr);
127+
continue;
128+
}
130129
}
131130

132-
if(cnt.get(idx) == nBlocks) {
133-
// all input blocks for this idx processed - emit aggregated block
134-
MatrixBlock ret = aggs.get(idx);
135-
// drop correction row/col
136-
ret.dropLastRowsOrColumns(_aop.correction);
137-
MatrixIndexes midx = aggun.isRowAggregate()? new MatrixIndexes(tmp.getIndexes().getRowIndex(), 1) : new MatrixIndexes(1, tmp.getIndexes().getColumnIndex());
138-
IndexedMatrixValue tmpOut = new IndexedMatrixValue(midx, ret);
139-
140-
qOut.enqueueTask(tmpOut);
141-
// drop intermediate states
142-
aggs.remove(idx);
143-
corrs.remove(idx);
144-
cnt.remove(idx);
145-
}
131+
// all input blocks for this idx processed - emit aggregated block
132+
ret.dropLastRowsOrColumns(_aop.correction);
133+
MatrixIndexes midx = aggun.isRowAggregate() ?
134+
new MatrixIndexes(tmp.getIndexes().getRowIndex(), 1) :
135+
new MatrixIndexes(1, tmp.getIndexes().getColumnIndex());
136+
IndexedMatrixValue tmpOut = new IndexedMatrixValue(midx, ret);
137+
138+
qOut.enqueueTask(tmpOut);
139+
// drop intermediate states
140+
aggTracker.remove(idx);
141+
corrs.remove(idx);
146142
}
147143
qOut.closeInput();
148144
}

src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,8 @@ public void processInstruction( ExecutionContext ec ) {
8282
}
8383

8484
// number of colBlocks for early block output
85-
long nBlocks = min.getDataCharacteristics().getNumColBlocks();
85+
long emitThreshold = min.getDataCharacteristics().getNumColBlocks();
86+
OOCMatrixBlockTracker aggTracker = new OOCMatrixBlockTracker(emitThreshold);
8687

8788
LocalTaskQueue<IndexedMatrixValue> qIn = min.getStreamHandle();
8889
LocalTaskQueue<IndexedMatrixValue> qOut = new LocalTaskQueue<>();
@@ -95,8 +96,6 @@ public void processInstruction( ExecutionContext ec ) {
9596
pool.submit(() -> {
9697
IndexedMatrixValue tmp = null;
9798
try {
98-
HashMap<Long, MatrixBlock> partialResults = new HashMap<>();
99-
HashMap<Long, Integer> cnt = new HashMap<>();
10099
while((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) {
101100
MatrixBlock matrixBlock = (MatrixBlock) tmp.getValue();
102101
long rowIndex = tmp.getIndexes().getRowIndex();
@@ -108,31 +107,22 @@ public void processInstruction( ExecutionContext ec ) {
108107
matrixBlock, vectorSlice, new MatrixBlock(), (AggregateBinaryOperator) _optr);
109108

110109
// for single column block, no aggregation needed
111-
if( min.getNumColumns() <= min.getBlocksize() ) {
110+
if(emitThreshold == 1) {
112111
qOut.enqueueTask(new IndexedMatrixValue(tmp.getIndexes(), partialResult));
113112
}
114113
else {
115114
// aggregation
116-
MatrixBlock currAgg = partialResults.get(rowIndex);
115+
MatrixBlock currAgg = aggTracker.get(rowIndex);
117116
if (currAgg == null) {
118-
partialResults.put(rowIndex, partialResult);
119-
cnt.put(rowIndex, 1);
117+
aggTracker.putAndIncrementCount(rowIndex, partialResult);
120118
}
121119
else {
122-
currAgg.binaryOperationsInPlace(plus, partialResult);
123-
int newCnt = cnt.get(rowIndex) + 1;
124-
125-
if(newCnt == nBlocks){
120+
currAgg = currAgg.binaryOperations(plus, partialResult);
121+
if (aggTracker.putAndIncrementCount(rowIndex, currAgg)){
126122
// early block output: emit aggregated block
127123
MatrixIndexes idx = new MatrixIndexes(rowIndex, 1L);
128-
MatrixBlock result = partialResults.get(rowIndex);
129-
qOut.enqueueTask(new IndexedMatrixValue(idx, result));
130-
partialResults.remove(rowIndex);
131-
cnt.remove(rowIndex);
132-
}
133-
else {
134-
// maintain aggregation counts if not output-ready yet
135-
cnt.replace(rowIndex, newCnt);
124+
qOut.enqueueTask(new IndexedMatrixValue(idx, currAgg));
125+
aggTracker.remove(rowIndex);
136126
}
137127
}
138128
}

src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,11 @@
2424
import org.apache.sysds.api.DMLScript;
2525
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
2626
import org.apache.sysds.runtime.instructions.Instruction;
27+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
2728
import org.apache.sysds.runtime.matrix.operators.Operator;
2829

30+
import java.util.HashMap;
31+
2932
public abstract class OOCInstruction extends Instruction {
3033
protected static final Log LOG = LogFactory.getLog(OOCInstruction.class.getName());
3134

@@ -82,4 +85,55 @@ public void postprocessInstruction(ExecutionContext ec) {
8285
if(DMLScript.LINEAGE_DEBUGGER)
8386
ec.maintainLineageDebuggerInfo(this);
8487
}
88+
89+
/**
90+
* Tracks blocks and their counts to enable early emission
91+
* once all blocks for a given index are processed.
92+
*/
93+
public static class OOCMatrixBlockTracker {
94+
private final long _emitThreshold;
95+
private final HashMap<Long, MatrixBlock> _blocks;
96+
private final HashMap<Long, Integer> _cnts;
97+
98+
public OOCMatrixBlockTracker(long emitThreshold) {
99+
_emitThreshold = emitThreshold;
100+
_blocks = new HashMap<>();
101+
_cnts = new HashMap<>();
102+
}
103+
104+
/**
105+
* Adds or updates a block for the given index and updates its internal count.
106+
* @param idx block index
107+
* @param block MatrixBlock
108+
* @return true if the block count reached the threshold (ready to emit), false otherwise
109+
*/
110+
public boolean putAndIncrementCount(Long idx, MatrixBlock block) {
111+
_blocks.put(idx, block);
112+
int newCnt = _cnts.getOrDefault(idx, 0) + 1;
113+
if( newCnt < _emitThreshold )
114+
_cnts.put(idx, newCnt);
115+
return newCnt == _emitThreshold;
116+
}
117+
118+
public boolean incrementCount(Long idx) {
119+
int newCnt = _cnts.get(idx) + 1;
120+
if( newCnt < _emitThreshold )
121+
_cnts.put(idx, newCnt);
122+
return newCnt == _emitThreshold;
123+
}
124+
125+
public void putAndInitCount(Long idx, MatrixBlock block) {
126+
_blocks.put(idx, block);
127+
_cnts.put(idx, 0);
128+
}
129+
130+
public MatrixBlock get(Long idx) {
131+
return _blocks.get(idx);
132+
}
133+
134+
public void remove(Long idx) {
135+
_blocks.remove(idx);
136+
_cnts.remove(idx);
137+
}
138+
}
85139
}

0 commit comments

Comments
 (0)