diff --git a/src/main/java/org/apache/sysds/common/InstructionType.java b/src/main/java/org/apache/sysds/common/InstructionType.java index 1980dd7984d..f7c2bb88f25 100644 --- a/src/main/java/org/apache/sysds/common/InstructionType.java +++ b/src/main/java/org/apache/sysds/common/InstructionType.java @@ -88,5 +88,5 @@ public enum InstructionType { PMM, MatrixReshape, Write, - Init, + Init, Tee, } diff --git a/src/main/java/org/apache/sysds/common/Opcodes.java b/src/main/java/org/apache/sysds/common/Opcodes.java index 64a6c7dd27e..7e096906c0d 100644 --- a/src/main/java/org/apache/sysds/common/Opcodes.java +++ b/src/main/java/org/apache/sysds/common/Opcodes.java @@ -219,6 +219,7 @@ public enum Opcodes { READ("read", InstructionType.Variable), WRITE("write", InstructionType.Variable, InstructionType.Write), CREATEVAR("createvar", InstructionType.Variable), + TEE("tee", InstructionType.Tee), //Reorg instruction opcodes TRANSPOSE("r'", InstructionType.Reorg), diff --git a/src/main/java/org/apache/sysds/hops/TeeOp.java b/src/main/java/org/apache/sysds/hops/TeeOp.java new file mode 100644 index 00000000000..113d839e2bb --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/TeeOp.java @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops; + +import org.apache.sysds.common.Types.ExecType; +import org.apache.sysds.lops.Lop; +import org.apache.sysds.lops.Tee; +import org.apache.sysds.runtime.meta.DataCharacteristics; + +import java.util.ArrayList; + + +public class TeeOp extends Hop { + + private final ArrayList _outputs = new ArrayList<>(); + + private TeeOp() { + // default constructor + } + + /** + * Takes in a single Hop input and gives two outputs + * + * @param input + */ + public TeeOp(Hop input) { + super(input.getName(), input.getDataType(), input._valueType); + + // add single input for this hop + getInput().add(0, input); + input.getParent().add(this); + + // output variables list to feed tee output into +// for (Hop out: outputs) { +// _outputs.add(out); +// } + + // This characteristics are same as the input + refreshSizeInformation(); + } + + @Override + public boolean allowsAllExecTypes() { + return false; + } + + /** + * Computes the output matrix characteristics (rows, cols, nnz) based on worst-case output + * and/or input estimates. Should return null if dimensions are unknown. + * + * @param memo memory table + * @return output characteristics + */ + @Override + protected DataCharacteristics inferOutputCharacteristics(MemoTable memo) { + return null; + } + + @Override + public Lop constructLops() { + // return already created Lops + if (getLops() != null) { + return getLops(); + } + + Tee teeLop = new Tee(getInput().get(0).constructLops(), + getDataType(), getValueType()); + setOutputDimensions(teeLop); + setLineNumbers(teeLop); + setLops(teeLop); + + return getLops(); + } + + @Override + protected ExecType optFindExecType(boolean transitive) { + return ExecType.OOC; + } + + @Override + public String getOpString() { + return "tee"; + } + + /** + * In memory-based optimizer mode (see OptimizerUtils.isMemoryBasedOptLevel()), + * the exectype is determined by checking this method as well as memory budget of this Hop. + * Please see findExecTypeByMemEstimate for more detail. + *

+ * This method is necessary because not all operator are supported efficiently + * on GPU (for example: operations on frames and scalar as well as operations such as table). + * + * @return true if the Hop is eligible for GPU Exectype. + */ + @Override + public boolean isGPUEnabled() { + return false; + } + + /** + * Computes the hop-specific output memory estimate in bytes. Should be 0 if not + * applicable. + * + * @param dim1 dimension 1 + * @param dim2 dimension 2 + * @param nnz number of non-zeros + * @return memory estimate + */ + @Override + protected double computeOutputMemEstimate(long dim1, long dim2, long nnz) { + return 0; + } + + /** + * Computes the hop-specific intermediate memory estimate in bytes. Should be 0 if not + * applicable. + * + * @param dim1 dimension 1 + * @param dim2 dimension 2 + * @param nnz number of non-zeros + * @return memory estimate + */ + @Override + protected double computeIntermediateMemEstimate(long dim1, long dim2, long nnz) { + return 0; + } + + /** + * Update the output size information for this hop. + */ + @Override + public void refreshSizeInformation() { + Hop input1 = getInput().get(0); + setDim1(input1.getDim1()); + setDim2(input1.getDim2()); + setNnz(input1.getNnz()); + setBlocksize(input1.getBlocksize()); + } + + @Override + public Object clone() throws CloneNotSupportedException { + return null; + } + + @Override + public boolean compare(Hop that) { + return false; + } + + public Hop getOutput(int index) { + return _outputs.get(index); + } +} diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java index 874ddae0347..c2602dba510 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java @@ -77,6 +77,7 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites) //add static HOP DAG rewrite rules _dagRuleSet.add( new RewriteRemoveReadAfterWrite() ); //dependency: before blocksize _dagRuleSet.add( new RewriteBlockSizeAndReblock() ); + _dagRuleSet.add( new RewriteInjectOOCTee() ); if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION ) _dagRuleSet.add( new RewriteRemoveUnnecessaryCasts() ); if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION ) diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java new file mode 100644 index 00000000000..7a3cd95c85c --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteInjectOOCTee.java @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewrite; + +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.common.Types; +import org.apache.sysds.hops.*; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * This Rewrite rule injects a Tee Operator for specific Out-Of-Core (OOC) patterns + * where a value or an intermediate result is shared twice. Since for OOC we data streams + * can only be consumed once. + * + *

+ * Pattern identified {@code t(X) %*% X}, where the data {@code X} will be shared by + * {@code t(X)} and {@code %*%} multiplication. + *

+ * + * The rewrite uses a stable two-pass approach: + * 1. Find candidates (Read-Only): Traverse the entire HOP DAG to identify candidates + * the fit the target pattern. + * 2. Apply Rewrites (Modification): Iterate over the collected candidate and put + * {@code TeeOp}, and safely rewire the graph. + */ +public class RewriteInjectOOCTee extends HopRewriteRule { + + private static final Set rewrittenHops = new HashSet<>(); + private static final Map handledHop = new HashMap<>(); + + // Maintain a list of candidates to rewrite in the second pass + private final List rewriteCandidates = new ArrayList<>(); + + /** + * Handle a generic (last-level) hop DAG with multiple roots. + * + * @param roots high-level operator roots + * @param state program rewrite status + * @return list of high-level operators + */ + @Override + public ArrayList rewriteHopDAGs(ArrayList roots, ProgramRewriteStatus state) { + if (roots == null) { + return null; + } + + // Clear candidates for this pass + rewriteCandidates.clear(); + + // PASS 1: Identify candidates without modifying the graph + for (Hop root : roots) { + root.resetVisitStatus(); + findRewriteCandidates(root); + } + + // PASS 2: Apply rewrites to identified candidates + for (Hop candidate : rewriteCandidates) { + applyTopDownTeeRewrite(candidate); + } + + return roots; + } + + /** + * Handle a predicate hop DAG with exactly one root. + * + * @param root high-level operator root + * @param state program rewrite status + * @return high-level operator + */ + @Override + public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) { + if (root == null) { + return null; + } + + // Clear candidates for this pass + rewriteCandidates.clear(); + + // PASS 1: Identify candidates without modifying the graph + root.resetVisitStatus(); + findRewriteCandidates(root); + + // PASS 2: Apply rewrites to identified candidates + for (Hop candidate : rewriteCandidates) { + applyTopDownTeeRewrite(candidate); + } + + return root; + } + + /** + * First pass: Find candidates for rewrite without modifying the graph. + * This method traverses the graph and identifies nodes that need to be + * rewritten based on the transpose-matrix multiply pattern. + * + * @param hop current hop being examined + */ + private void findRewriteCandidates(Hop hop) { + if (hop.isVisited()) { + return; + } + + // Mark as visited to avoid processing the same hop multiple times + hop.setVisited(true); + + // Recursively traverse the graph (depth-first) + for (Hop input : hop.getInput()) { + findRewriteCandidates(input); + } + + // Check if this hop is a candidate for OOC Tee injection + if (isRewriteCandidate(hop)) { + rewriteCandidates.add(hop); + } + } + + /** + * Check if a hop should be considered for rewrite. + * + * @param hop the hop to check + * @return true if the hop meets all criteria for rewrite + */ + private boolean isRewriteCandidate(Hop hop) { + // Skip if already handled + if (rewrittenHops.contains(hop.getHopID()) || handledHop.containsKey(hop.getHopID())) { + return false; + } + + boolean multipleConsumers = hop.getParent().size() > 1; + boolean isNotAlreadyTee = isNotAlreadyTee(hop); + boolean isOOCEnabled = DMLScript.USE_OOC; + boolean isTransposeMM = isTranposePattern(hop); + boolean isMatrix = hop.getDataType() == Types.DataType.MATRIX; + + return isOOCEnabled && multipleConsumers && isNotAlreadyTee && isTransposeMM && isMatrix; + } + + /** + * Second pass: Apply the TeeOp transformation to a candidate hop. + * This safely rewires the graph by creating a TeeOp node and placeholders. + * + * @param sharedInput the hop to be rewritten + */ + private void applyTopDownTeeRewrite(Hop sharedInput) { + // Only process if not already handled + if (handledHop.containsKey(sharedInput.getHopID())) { + return; + } + + // Take a defensive copy of consumers before modifying the graph + ArrayList consumers = new ArrayList<>(sharedInput.getParent()); + + // Create the new TeeOp with the original hop as input + TeeOp teeOp = new TeeOp(sharedInput); + + // Rewire the graph: replace original connections with TeeOp outputs + int i = 0; + for (Hop consumer : consumers) { + Hop placeholder = new DataOp("tee_out_" + sharedInput.getName() + "_" + i, + sharedInput.getDataType(), + sharedInput.getValueType(), + Types.OpOpData.TRANSIENTWRITE, + null, + sharedInput.getDim1(), + sharedInput.getDim2(), + sharedInput.getNnz(), + sharedInput.getBlocksize() + ); + + // Copy metadata + placeholder.setBeginLine(sharedInput.getBeginLine()); + placeholder.setBeginColumn(sharedInput.getBeginColumn()); + placeholder.setEndLine(sharedInput.getEndLine()); + placeholder.setEndColumn(sharedInput.getEndColumn()); + + // Connect placeholder to TeeOp and consumer + HopRewriteUtils.addChildReference(placeholder, teeOp); + HopRewriteUtils.replaceChildReference(consumer, sharedInput, placeholder); + + i++; + } + + // Record that we've handled this hop + handledHop.put(sharedInput.getHopID(), teeOp); + rewrittenHops.add(sharedInput.getHopID()); + } + + private boolean isNotAlreadyTee(Hop hop) { + if (hop.getParent().size() > 1) { + for (Hop consumer : hop.getParent()) { + if (consumer instanceof TeeOp) { + return false; + } + } + } + return true; + } + + private boolean isTranposePattern (Hop hop) { + boolean hasTransposeConsumer = false; // t(X) + boolean hasMatrixMultiplyConsumer = false; // %*% + + for (Hop parent: hop.getParent()) { + String opString = parent.getOpString(); + if (parent instanceof ReorgOp) { + if (opString.contains("r'") || opString.contains("transpose")) { + hasTransposeConsumer = true; + } + } + else if (parent instanceof AggBinaryOp) + if (opString.contains("*") || opString.contains("ba+*")) { + hasMatrixMultiplyConsumer = true; + } + } + return hasTransposeConsumer && hasMatrixMultiplyConsumer; + } +} diff --git a/src/main/java/org/apache/sysds/lops/Lop.java b/src/main/java/org/apache/sysds/lops/Lop.java index 447201a5fd3..bb0ae2309e7 100644 --- a/src/main/java/org/apache/sysds/lops/Lop.java +++ b/src/main/java/org/apache/sysds/lops/Lop.java @@ -38,7 +38,7 @@ public abstract class Lop { protected static final Log LOG = LogFactory.getLog(Lop.class.getName()); - + public enum Type { Data, DataGen, //CP/MR read/write/datagen ReBlock, CSVReBlock, //MR reblock operations @@ -63,7 +63,8 @@ public enum Type { PlusMult, MinusMult, //CP SpoofFused, //CP/SP generated fused operator Sql, //CP sql read - Federated //FED federated read + Federated, //FED federated read + Tee, //OOC Tee operator } diff --git a/src/main/java/org/apache/sysds/lops/Tee.java b/src/main/java/org/apache/sysds/lops/Tee.java new file mode 100644 index 00000000000..a9ce7ff970b --- /dev/null +++ b/src/main/java/org/apache/sysds/lops/Tee.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.lops; + +import org.apache.sysds.common.Types; +import org.apache.sysds.common.Types.DataType; +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.instructions.InstructionUtils; + +public class Tee extends Lop { + + public static final String OPCODE = "tee"; + /** + * Constructor to be invoked by base class. + * + * @param input1 lop type + * @param dt data type of the output + * @param vt value type of the output + */ + public Tee(Lop input1, DataType dt, ValueType vt) { + super(Lop.Type.Tee, dt, vt); + this.addInput(input1); + input1.addOutput(this); + lps.setProperties(inputs, Types.ExecType.OOC); + } + + @Override + public String toString() { + return "Operation = Tee"; + } + + @Override + public String getInstructions(String input1, String outputs) { + + String[] out = outputs.split(Lop.OPERAND_DELIMITOR); + String output2 = outputs + "_copy"; + + // This method generates the instruction string: OOC°tee°input°output1°output2... + String ret = InstructionUtils.concatOperands( + getExecType().name(), OPCODE, + getInputs().get(0).prepInputOperand(input1), + prepOutputOperand(out[0]), + prepOutputOperand(out[1]) + ); + + return ret; + } +} diff --git a/src/main/java/org/apache/sysds/lops/compile/Dag.java b/src/main/java/org/apache/sysds/lops/compile/Dag.java index b26c539e9a8..99c5bbaff58 100644 --- a/src/main/java/org/apache/sysds/lops/compile/Dag.java +++ b/src/main/java/org/apache/sysds/lops/compile/Dag.java @@ -488,7 +488,7 @@ private void generateControlProgramJobs(List execNodes, markedNodes.add(node); continue; } - + // output scalar instructions and mark nodes for deletion if (!node.isDataExecLocation()) { @@ -548,6 +548,18 @@ else if ( node.getType() == Lop.Type.FunctionCallCP ) outputs[count++] = out.getOutputParameters().getLabel(); inst_string = node.getInstructions(inputs, outputs); } + else if ( node.getType() == Type.Tee ) { + String input = node.getInputs().get(0).getOutputParameters().getLabel(); + + ArrayList outputs = new ArrayList<>(); + for( Lop out : node.getOutputs() ) { + outputs.add(out.getOutputParameters().getLabel()); + } + + String packedOutputs = String.join(Lop.OPERAND_DELIMITOR, outputs); + + inst_string = node.getInstructions(input, packedOutputs); + } else if (node.getType() == Lop.Type.Nary) { String[] inputs = new String[node.getInputs().size()]; int count = 0; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java index 9b1165b819b..94e1a7ff180 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java @@ -29,6 +29,8 @@ import org.apache.sysds.runtime.instructions.ooc.ReblockOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.UnaryOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.MatrixVectorBinaryOOCInstruction; +import org.apache.sysds.runtime.instructions.ooc.TransposeOOCInstruction; +import org.apache.sysds.runtime.instructions.ooc.TeeOOCInstruction; public class OOCInstructionParser extends InstructionParser { protected static final Log LOG = LogFactory.getLog(OOCInstructionParser.class.getName()); @@ -60,6 +62,10 @@ public static OOCInstruction parseSingleInstruction(InstructionType ooctype, Str case AggregateBinary: case MAPMM: return MatrixVectorBinaryOOCInstruction.parseInstruction(str); + case Reorg: + return TransposeOOCInstruction.parseInstruction(str); + case Tee: + return TeeOOCInstruction.parseInstruction(str); default: throw new DMLRuntimeException("Invalid OOC Instruction Type: " + ooctype); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java index d3c2dfcbd77..c76fd4e4a4b 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java @@ -30,7 +30,7 @@ public abstract class OOCInstruction extends Instruction { protected static final Log LOG = LogFactory.getLog(OOCInstruction.class.getName()); public enum OOCType { - Reblock, AggregateUnary, Binary, Unary, MAPMM, AggregateBinary + Reblock, AggregateUnary, Binary, Unary, MAPMM, Reorg, AggregateBinary, Tee } protected final OOCInstruction.OOCType _ooctype; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java new file mode 100644 index 00000000000..0248ef78b1b --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.instructions.ooc; + +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; +import org.apache.sysds.runtime.instructions.InstructionUtils; +import org.apache.sysds.runtime.instructions.cp.CPOperand; +import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; +import org.apache.sysds.runtime.util.CommonThreadPool; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.ExecutorService; + +public class TeeOOCInstruction extends ComputationOOCInstruction { + + private final List _outputs; + + protected TeeOOCInstruction(OOCType type, CPOperand in1, CPOperand out, CPOperand out2, String opcode, String istr) { + super(type, null, in1, out, opcode, istr); + _outputs = Arrays.asList(out, out2); + } + + public static TeeOOCInstruction parseInstruction(String str) { + String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); + InstructionUtils.checkNumFields(parts, 3); + String opcode = parts[0]; + CPOperand in1 = new CPOperand(parts[1]); + CPOperand out = new CPOperand(parts[2]); + CPOperand out2 = new CPOperand(parts[3]); + + return new TeeOOCInstruction(OOCType.Tee, in1, out, out2, opcode, str); + } + + public void processInstruction( ExecutionContext ec ) { + + // Create thread and process the tee operation + MatrixObject min = ec.getMatrixObject(input1); + LocalTaskQueue qIn = min.getStreamHandle(); + +// MatrixObject min = ec.getMatrixObject(input1); +// LocalTaskQueue qIn = min.getStreamHandle(); + List> qOuts = new ArrayList<>(); + for (CPOperand out : _outputs) { + MatrixObject mout = ec.createMatrixObject(min.getDataCharacteristics()); + ec.setVariable(out.getName(), mout); + LocalTaskQueue qOut = new LocalTaskQueue<>(); + mout.setStreamHandle(qOut); + qOuts.add(qOut); + } + + ExecutorService pool = CommonThreadPool.get(); + try { + pool.submit(() -> { + IndexedMatrixValue tmp = null; + try { + while ((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { + + for (int i = 0; i < qOuts.size(); i++) { + qOuts.get(i).enqueueTask(new IndexedMatrixValue(tmp)); + } + } + for (LocalTaskQueue qOut : qOuts) { + qOut.closeInput(); + } + } + catch(Exception ex) { + throw new DMLRuntimeException(ex); + } + }); + } catch (Exception ex) { + throw new DMLRuntimeException(ex); + } finally { + pool.shutdown(); + } + } +} diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TransposeOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TransposeOOCInstruction.java new file mode 100644 index 00000000000..3fe9e2439de --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TransposeOOCInstruction.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.instructions.ooc; + +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; +import org.apache.sysds.runtime.functionobjects.SwapIndex; +import org.apache.sysds.runtime.instructions.InstructionUtils; +import org.apache.sysds.runtime.instructions.cp.CPOperand; +import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.data.MatrixIndexes; +import org.apache.sysds.runtime.matrix.operators.ReorgOperator; +import org.apache.sysds.runtime.matrix.operators.UnaryOperator; +import org.apache.sysds.runtime.util.CommonThreadPool; + +import java.util.concurrent.ExecutorService; + +public class TransposeOOCInstruction extends ComputationOOCInstruction { + + protected TransposeOOCInstruction(OOCType type, ReorgOperator op, CPOperand in1, CPOperand out, String opcode, String istr) { + super(type, op, in1, out, opcode, istr); + + } + + public static TransposeOOCInstruction parseInstruction(String str) { + String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); + InstructionUtils.checkNumFields(parts, 2); + String opcode = parts[0]; + CPOperand in1 = new CPOperand(parts[1]); + CPOperand out = new CPOperand(parts[2]); + + ReorgOperator reorg = new ReorgOperator(SwapIndex.getSwapIndexFnObject()); + return new TransposeOOCInstruction(OOCType.Reorg, reorg, in1, out, opcode, str); + } + + public void processInstruction( ExecutionContext ec ) { + + // Create thread and process the transpose operation + MatrixObject min = ec.getMatrixObject(input1); + LocalTaskQueue qIn = min.getStreamHandle(); + LocalTaskQueue qOut = new LocalTaskQueue<>(); + ec.getMatrixObject(output).setStreamHandle(qOut); + + + ExecutorService pool = CommonThreadPool.get(); + try { + pool.submit(() -> { + IndexedMatrixValue tmp = null; + try { + while ((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { + MatrixBlock inBlock = (MatrixBlock)tmp.getValue(); + long oldRowIdx = tmp.getIndexes().getRowIndex(); + long oldColIdx = tmp.getIndexes().getColumnIndex(); + + MatrixBlock outBlock = inBlock.reorgOperations((ReorgOperator) _optr, new MatrixBlock(), -1, -1, -1); + qOut.enqueueTask(new IndexedMatrixValue(new MatrixIndexes(oldColIdx, oldRowIdx), outBlock)); + } + qOut.closeInput(); + } + catch(Exception ex) { + throw new DMLRuntimeException(ex); + } + }); + } catch (Exception ex) { + throw new DMLRuntimeException(ex); + } finally { + pool.shutdown(); + } + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/TeeTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/TeeTest.java new file mode 100644 index 00000000000..16e60288538 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/ooc/TeeTest.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.ooc; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.common.Types.ExecMode; +import org.apache.sysds.common.Types.FileFormat; +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.instructions.Instruction; +import org.apache.sysds.runtime.io.MatrixWriter; +import org.apache.sysds.runtime.io.MatrixWriterFactory; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.util.DataConverter; +import org.apache.sysds.runtime.util.HDFSTool; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.io.IOException; + +public class TeeTest extends AutomatedTestBase { + + private static final String TEST_NAME = "Tee"; + private static final String TEST_DIR = "functions/ooc/"; + private static final String TEST_CLASS_DIR = TEST_DIR + TeeTest.class.getSimpleName() + "/"; + private static final String INPUT_NAME = "X"; + private static final String OUTPUT_NAME = "res"; + private final static double eps = 1e-10; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME)); + } + + @Test + public void testTeeNoRewrite() { + testTeeOperation(false); + } + + @Test + public void testTeeRewrite() { + testTeeOperation(true); + } + + + public void testTeeOperation(boolean rewrite) + { + ExecMode platformOld = setExecMode(ExecMode.SINGLE_NODE); + boolean oldRewrite = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrite; + + try { + getAndLoadTestConfiguration(TEST_NAME); + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-explain", "-stats", "-ooc", + "-args", input(INPUT_NAME), output(OUTPUT_NAME)}; + + int rows = 1000, cols = 1000; + MatrixBlock mb = MatrixBlock.randOperations(rows, cols, 1.0, -1, 1, "uniform", 7); + MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(FileFormat.BINARY); + writer.writeMatrixToHDFS(mb, input(INPUT_NAME), rows, cols, 1000, rows*cols); + HDFSTool.writeMetaDataFile(input(INPUT_NAME+".mtd"), ValueType.FP64, + new MatrixCharacteristics(rows,cols,1000,rows*cols), FileFormat.BINARY); + + runTest(true, false, null, -1); + + + double[][] C1 = readMatrix(output(OUTPUT_NAME), FileFormat.BINARY, rows, cols, 1000, 1000); + double result = 0.0; + for(int i = 0; i < cols; i++) { // verify the results with Java + for(int j = 0; j < cols; j++) { + double expected = 0.0; + for (int k = 0; k < rows; k++) { + expected += mb.get(k, i) * mb.get(k, j); + } + result = C1[i][j]; + Assert.assertEquals( "value mismatch at cell ("+i+","+j+")",expected, result, eps); + } + } + + String prefix = Instruction.OOC_INST_PREFIX; + Assert.assertTrue("OOC wasn't used for RBLK", + heavyHittersContainsString(prefix + Opcodes.RBLK)); + Assert.assertTrue("OOC wasn't used for TEE", + heavyHittersContainsString(prefix + "tee")); + } + catch(Exception ex) { + Assert.fail(ex.getMessage()); + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldRewrite; + resetExecMode(platformOld); + } + } + + private static double[][] readMatrix( String fname, FileFormat fmt, + long rows, long cols, int brows, int bcols ) + throws IOException + { + MatrixBlock mb = DataConverter.readMatrixFromHDFS(fname, fmt, rows, cols, brows, bcols); + double[][] C = DataConverter.convertToDoubleMatrix(mb); + return C; + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/TransposeTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/TransposeTest.java new file mode 100644 index 00000000000..4ea1f888ee6 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/ooc/TransposeTest.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.ooc; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.instructions.Instruction; +import org.apache.sysds.runtime.io.MatrixWriter; +import org.apache.sysds.runtime.io.MatrixWriterFactory; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.util.DataConverter; +import org.apache.sysds.runtime.util.HDFSTool; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.io.IOException; + +public class TransposeTest extends AutomatedTestBase { + private final static String TEST_NAME1 = "Transpose"; + private final static String TEST_DIR = "functions/ooc/"; + private final static String TEST_CLASS_DIR = TEST_DIR + TransposeTest.class.getSimpleName() + "/"; + private final static double eps = 1e-10; + private static final String INPUT_NAME = "X"; + private static final String OUTPUT_NAME = "res"; + + private final static int rows = 1000; + private final static int cols_wide = 1000; + private final static int cols_skinny = 500; + + private final static double sparsity1 = 0.7; + private final static double sparsity2 = 0.1; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1); + addTestConfiguration(TEST_NAME1, config); + } + + @Test + public void testTranspose1() { + runTransposeTest(cols_wide, false); + } + +// @Test +// public void testTranspose2() { +// runTransposeTest(cols_skinny, false); +// } + + private void runTransposeTest(int cols, boolean sparse ) + { + Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE); + + try + { + getAndLoadTestConfiguration(TEST_NAME1); + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME1 + ".dml"; + programArgs = new String[]{"-explain", "-stats", "-ooc", + "-args", input(INPUT_NAME), output(OUTPUT_NAME)}; + + // 1. Generate the data as MatrixBlock object + double[][] A_data = getRandomMatrix(rows, cols, 0, 1, sparse?sparsity2:sparsity1, 10); + + // 2. Convert the double arrays to MatrixBlock object + MatrixBlock A_mb = DataConverter.convertToMatrixBlock(A_data); + + // 3. Create a binary matrix writer + MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY); + + // 4. Write matrix A to a binary SequenceFile + writer.writeMatrixToHDFS(A_mb, input(INPUT_NAME), rows, cols, 1000, A_mb.getNonZeros()); + HDFSTool.writeMetaDataFile(input(INPUT_NAME + ".mtd"), Types.ValueType.FP64, + new MatrixCharacteristics(rows, cols, 1000, A_mb.getNonZeros()), Types.FileFormat.BINARY); + + boolean exceptionExpected = false; + runTest(true, exceptionExpected, null, -1); + + double[][] C1 = readMatrix(output(OUTPUT_NAME), Types.FileFormat.BINARY, rows, cols, 1000, 1000); + double result = 0.0; + for(int i = 0; i < rows; i++) { // verify the results with Java + double expected = 0.0; + for(int j = 0; j < cols; j++) { + expected = A_mb.get(i, j); + result = C1[j][i]; + Assert.assertEquals(expected, result, eps); + } + + } + + String prefix = Instruction.OOC_INST_PREFIX; + Assert.assertTrue("OOC wasn't used for RBLK", + heavyHittersContainsString(prefix + Opcodes.RBLK)); + Assert.assertTrue("OOC wasn't used for TRANSPOSE", + heavyHittersContainsString(prefix + Opcodes.TRANSPOSE)); + } + catch (IOException e) { + throw new RuntimeException(e); + } + finally { + resetExecMode(platformOld); + } + } + + private static double[][] readMatrix(String fname, Types.FileFormat fmt, long rows, long cols, int brows, int bcols ) + throws IOException + { + MatrixBlock mb = DataConverter.readMatrixFromHDFS(fname, fmt, rows, cols, brows, bcols); + double[][] C = DataConverter.convertToDoubleMatrix(mb); + return C; + } +} diff --git a/src/test/scripts/functions/ooc/Tee.dml b/src/test/scripts/functions/ooc/Tee.dml new file mode 100644 index 00000000000..e6faabfc7a1 --- /dev/null +++ b/src/test/scripts/functions/ooc/Tee.dml @@ -0,0 +1,27 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# Read the input matrix as a stream +X = read($1); + +res = t(X) %*% X; + +write(res, $2, format="binary"); diff --git a/src/test/scripts/functions/ooc/Transpose.dml b/src/test/scripts/functions/ooc/Transpose.dml new file mode 100644 index 00000000000..9b38939a2e1 --- /dev/null +++ b/src/test/scripts/functions/ooc/Transpose.dml @@ -0,0 +1,27 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# Read the input matrix as a stream +X = read($1); + +res = t(X); + +write(res, $2, format="binary");