Skip to content

Commit e51db27

Browse files
j143mboehm7
authored andcommitted
[SYSTEMDS-3907] New out-of-core TEE operator (resettable stream)
Closes #2321. Closes #2317. Closes #2315.
1 parent 7c504d0 commit e51db27

File tree

15 files changed

+566
-24
lines changed

15 files changed

+566
-24
lines changed

src/main/java/org/apache/sysds/common/InstructionType.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,5 +89,5 @@ public enum InstructionType {
8989
PMM,
9090
MatrixReshape,
9191
Write,
92-
Init,
92+
Init, Tee,
9393
}

src/main/java/org/apache/sysds/common/Opcodes.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,7 @@ public enum Opcodes {
220220
READ("read", InstructionType.Variable),
221221
WRITE("write", InstructionType.Variable, InstructionType.Write),
222222
CREATEVAR("createvar", InstructionType.Variable),
223+
TEE("tee", InstructionType.Tee),
223224

224225
//Reorg instruction opcodes
225226
TRANSPOSE("r'", InstructionType.Reorg),

src/main/java/org/apache/sysds/common/Types.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -831,7 +831,8 @@ public enum OpOpData {
831831
PERSISTENTREAD, PERSISTENTWRITE,
832832
TRANSIENTREAD, TRANSIENTWRITE,
833833
FUNCTIONOUTPUT,
834-
SQLREAD, FEDERATED;
834+
SQLREAD, FEDERATED,
835+
TEE;
835836

836837
public boolean isTransient() {
837838
return this == TRANSIENTREAD || this == TRANSIENTWRITE;
@@ -856,6 +857,7 @@ public String toString() {
856857
case FUNCTIONOUTPUT: return "FunOut";
857858
case SQLREAD: return Opcodes.SQL.toString();
858859
case FEDERATED: return "Fed";
860+
case TEE: return "Tee";
859861
default: return "Invalid";
860862
}
861863
}

src/main/java/org/apache/sysds/hops/DataOp.java

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import org.apache.sysds.common.Types.ExecType;
3939
import org.apache.sysds.lops.LopsException;
4040
import org.apache.sysds.lops.Sql;
41+
import org.apache.sysds.lops.Tee;
4142
import org.apache.sysds.parser.DataExpression;
4243
import static org.apache.sysds.parser.DataExpression.FED_RANGES;
4344
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType;
@@ -73,7 +74,7 @@ public class DataOp extends Hop {
7374
private DataOp() {
7475
//default constructor for clone
7576
}
76-
77+
7778
/**
7879
* READ operation for Matrix w/ dim1, dim2.
7980
* This constructor does not support any expression in parameters
@@ -251,62 +252,63 @@ public Lop constructLops()
251252

252253
ExecType et = optFindExecType();
253254
Lop l = null;
254-
255+
255256
// construct lops for all input parameters
256257
HashMap<String, Lop> inputLops = new HashMap<>();
257258
for (Entry<String, Integer> cur : _paramIndexMap.entrySet()) {
258259
inputLops.put(cur.getKey(), getInput().get(cur.getValue()).constructLops());
259260
}
260261

261262
// Create the lop
262-
switch(_op)
263-
{
263+
switch (_op) {
264264
case TRANSIENTREAD:
265-
l = new Data(_op, null, inputLops, getName(), null,
266-
getDataType(), getValueType(), getFileFormat());
265+
l = new Data(_op, null, inputLops, getName(), null,
266+
getDataType(), getValueType(), getFileFormat());
267267
setOutputDimensions(l);
268268
break;
269-
269+
270270
case PERSISTENTREAD:
271-
l = new Data(_op, null, inputLops, getName(), null,
272-
getDataType(), getValueType(), getFileFormat());
271+
l = new Data(_op, null, inputLops, getName(), null,
272+
getDataType(), getValueType(), getFileFormat());
273273
l.getOutputParameters().setDimensions(getDim1(), getDim2(), _inBlocksize, getNnz(), getUpdateType());
274274
break;
275-
275+
276276
case PERSISTENTWRITE:
277277
case FUNCTIONOUTPUT:
278-
l = new Data(_op, getInput().get(0).constructLops(), inputLops, getName(), null,
279-
getDataType(), getValueType(), getFileFormat());
280-
((Data)l).setExecType(et);
278+
l = new Data(_op, getInput().get(0).constructLops(), inputLops, getName(), null,
279+
getDataType(), getValueType(), getFileFormat());
280+
((Data) l).setExecType(et);
281281
setOutputDimensions(l);
282282
break;
283-
283+
284284
case TRANSIENTWRITE:
285285
l = new Data(_op, getInput().get(0).constructLops(), inputLops, getName(), null,
286-
getDataType(), getValueType(), getFileFormat());
286+
getDataType(), getValueType(), getFileFormat());
287287
setOutputDimensions(l);
288288
break;
289-
289+
290290
case SQLREAD:
291291
l = new Sql(inputLops, getDataType(), getValueType());
292292
break;
293-
293+
294294
case FEDERATED:
295295
l = new Federated(inputLops, getDataType(), getValueType());
296296
break;
297-
297+
298+
case TEE:
299+
l = new Tee(getInput(0).constructLops(), getDataType(), getValueType());
300+
break;
301+
298302
default:
299303
throw new LopsException("Invalid operation type for Data LOP: " + _op);
300304
}
301-
302305
setLineNumbers(l);
303306
setLops(l);
304307

305308
//add reblock/checkpoint lops if necessary
306309
constructAndSetLopsDataFlowProperties();
307310

308311
return getLops();
309-
310312
}
311313

312314
public void setFileFormat(FileFormat ft) {

src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites)
7777
//add static HOP DAG rewrite rules
7878
_dagRuleSet.add( new RewriteRemoveReadAfterWrite() ); //dependency: before blocksize
7979
_dagRuleSet.add( new RewriteBlockSizeAndReblock() );
80+
_dagRuleSet.add( new RewriteInjectOOCTee() );
8081
if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION )
8182
_dagRuleSet.add( new RewriteRemoveUnnecessaryCasts() );
8283
if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION )
Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.hops.rewrite;
21+
22+
import org.apache.sysds.api.DMLScript;
23+
import org.apache.sysds.common.Types;
24+
import org.apache.sysds.common.Types.OpOpData;
25+
import org.apache.sysds.hops.DataOp;
26+
import org.apache.sysds.hops.Hop;
27+
28+
import java.util.ArrayList;
29+
import java.util.HashMap;
30+
import java.util.HashSet;
31+
import java.util.List;
32+
import java.util.Map;
33+
import java.util.Set;
34+
35+
/**
36+
* This Rewrite rule injects a Tee Operator for specific Out-Of-Core (OOC) patterns
37+
* where a value or an intermediate result is shared twice. Since for OOC we data streams
38+
* can only be consumed once.
39+
*
40+
* <p>
41+
* Pattern identified {@code t(X) %*% X}, where the data {@code X} will be shared by
42+
* {@code t(X)} and {@code %*%} multiplication.
43+
* </p>
44+
*
45+
* The rewrite uses a stable two-pass approach:
46+
* 1. <b>Find candidates (Read-Only):</b> Traverse the entire HOP DAG to identify candidates
47+
* the fit the target pattern.
48+
* 2. <b>Apply Rewrites (Modification):</b> Iterate over the collected candidate and put
49+
* {@code TeeOp}, and safely rewire the graph.
50+
*/
51+
public class RewriteInjectOOCTee extends HopRewriteRule {
52+
53+
private static final Set<Long> rewrittenHops = new HashSet<>();
54+
private static final Map<Long, Hop> handledHop = new HashMap<>();
55+
56+
// Maintain a list of candidates to rewrite in the second pass
57+
private final List<Hop> rewriteCandidates = new ArrayList<>();
58+
59+
/**
60+
* Handle a generic (last-level) hop DAG with multiple roots.
61+
*
62+
* @param roots high-level operator roots
63+
* @param state program rewrite status
64+
* @return list of high-level operators
65+
*/
66+
@Override
67+
public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) {
68+
if (roots == null) {
69+
return null;
70+
}
71+
72+
// Clear candidates for this pass
73+
rewriteCandidates.clear();
74+
75+
// PASS 1: Identify candidates without modifying the graph
76+
for (Hop root : roots) {
77+
root.resetVisitStatus();
78+
findRewriteCandidates(root);
79+
}
80+
81+
// PASS 2: Apply rewrites to identified candidates
82+
for (Hop candidate : rewriteCandidates) {
83+
applyTopDownTeeRewrite(candidate);
84+
}
85+
86+
return roots;
87+
}
88+
89+
/**
90+
* Handle a predicate hop DAG with exactly one root.
91+
*
92+
* @param root high-level operator root
93+
* @param state program rewrite status
94+
* @return high-level operator
95+
*/
96+
@Override
97+
public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
98+
if (root == null) {
99+
return null;
100+
}
101+
102+
// Clear candidates for this pass
103+
rewriteCandidates.clear();
104+
105+
// PASS 1: Identify candidates without modifying the graph
106+
root.resetVisitStatus();
107+
findRewriteCandidates(root);
108+
109+
// PASS 2: Apply rewrites to identified candidates
110+
for (Hop candidate : rewriteCandidates) {
111+
applyTopDownTeeRewrite(candidate);
112+
}
113+
114+
return root;
115+
}
116+
117+
/**
118+
* First pass: Find candidates for rewrite without modifying the graph.
119+
* This method traverses the graph and identifies nodes that need to be
120+
* rewritten based on the transpose-matrix multiply pattern.
121+
*
122+
* @param hop current hop being examined
123+
*/
124+
private void findRewriteCandidates(Hop hop) {
125+
if (hop.isVisited()) {
126+
return;
127+
}
128+
129+
// Mark as visited to avoid processing the same hop multiple times
130+
hop.setVisited(true);
131+
132+
// Recursively traverse the graph (depth-first)
133+
for (Hop input : hop.getInput()) {
134+
findRewriteCandidates(input);
135+
}
136+
137+
// Check if this hop is a candidate for OOC Tee injection
138+
if (DMLScript.USE_OOC
139+
&& hop.getDataType().isMatrix()
140+
&& !HopRewriteUtils.isData(hop, OpOpData.TEE)
141+
&& hop.getParent().size() > 1)
142+
{
143+
rewriteCandidates.add(hop);
144+
}
145+
}
146+
147+
/**
148+
* Second pass: Apply the TeeOp transformation to a candidate hop.
149+
* This safely rewires the graph by creating a TeeOp node and placeholders.
150+
*
151+
* @param sharedInput the hop to be rewritten
152+
*/
153+
private void applyTopDownTeeRewrite(Hop sharedInput) {
154+
// Only process if not already handled
155+
if (handledHop.containsKey(sharedInput.getHopID())) {
156+
return;
157+
}
158+
159+
// Take a defensive copy of consumers before modifying the graph
160+
ArrayList<Hop> consumers = new ArrayList<>(sharedInput.getParent());
161+
162+
// Create the new TeeOp with the original hop as input
163+
DataOp teeOp = new DataOp("tee_out_" + sharedInput.getName(),
164+
sharedInput.getDataType(), sharedInput.getValueType(), Types.OpOpData.TEE, null,
165+
sharedInput.getDim1(), sharedInput.getDim2(), sharedInput.getNnz(), sharedInput.getBlocksize());
166+
HopRewriteUtils.addChildReference(teeOp, sharedInput);
167+
168+
// Rewire the graph: replace original connections with TeeOp outputs
169+
for (Hop consumer : consumers) {
170+
HopRewriteUtils.replaceChildReference(consumer, sharedInput, teeOp);
171+
}
172+
173+
// Record that we've handled this hop
174+
handledHop.put(sharedInput.getHopID(), teeOp);
175+
rewrittenHops.add(sharedInput.getHopID());
176+
}
177+
}

src/main/java/org/apache/sysds/lops/Lop.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ public enum Type {
6363
PlusMult, MinusMult, //CP
6464
SpoofFused, //CP/SP generated fused operator
6565
Sql, //CP sql read
66-
Federated //FED federated read
66+
Federated, //FED federated read
67+
Tee, //OOC Tee operator
6768
}
6869

6970

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.lops;
21+
22+
import org.apache.sysds.common.Types;
23+
import org.apache.sysds.common.Types.DataType;
24+
import org.apache.sysds.common.Types.ValueType;
25+
import org.apache.sysds.runtime.instructions.InstructionUtils;
26+
27+
public class Tee extends Lop {
28+
29+
public static final String OPCODE = "tee";
30+
/**
31+
* Constructor to be invoked by base class.
32+
*
33+
* @param input1 lop type
34+
* @param dt data type of the output
35+
* @param vt value type of the output
36+
*/
37+
public Tee(Lop input1, DataType dt, ValueType vt) {
38+
super(Lop.Type.Tee, dt, vt);
39+
this.addInput(input1);
40+
input1.addOutput(this);
41+
lps.setProperties(inputs, Types.ExecType.OOC);
42+
}
43+
44+
@Override
45+
public String toString() {
46+
return "Operation = Tee";
47+
}
48+
49+
@Override
50+
public String getInstructions(String input1, String output) {
51+
// This method generates the instruction string: OOC°tee°input°output1°output2...
52+
String ret = InstructionUtils.concatOperands(
53+
getExecType().name(), OPCODE,
54+
getInputs().get(0).prepInputOperand(input1),
55+
prepOutputOperand(output)
56+
);
57+
58+
return ret;
59+
}
60+
}

0 commit comments

Comments
 (0)