diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java index e1099f715b7..350fc8de3b6 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/LocalTaskQueue.java @@ -23,6 +23,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.sysds.runtime.DMLRuntimeException; /** * This class provides a way of dynamic task distribution to multiple workers @@ -43,7 +44,8 @@ public class LocalTaskQueue public static final Object NO_MORE_TASKS = null; //object to signal NO_MORE_TASKS private LinkedList _data = null; - private boolean _closedInput = false; + private boolean _closedInput = false; + private DMLRuntimeException _failure = null; private static final Log LOG = LogFactory.getLog(LocalTaskQueue.class.getName()); public LocalTaskQueue() @@ -61,11 +63,14 @@ public LocalTaskQueue() public synchronized void enqueueTask( T t ) throws InterruptedException { - while( _data.size() + 1 > MAX_SIZE ) + while( _data.size() + 1 > MAX_SIZE && _failure == null ) { LOG.warn("MAX_SIZE of task queue reached."); wait(); //max constraint reached, wait for read } + + if ( _failure != null ) + throw _failure; _data.addLast( t ); @@ -82,13 +87,16 @@ public synchronized void enqueueTask( T t ) public synchronized T dequeueTask() throws InterruptedException { - while( _data.isEmpty() ) + while( _data.isEmpty() && _failure == null ) { if( !_closedInput ) wait(); // wait for writers else return (T)NO_MORE_TASKS; } + + if ( _failure != null ) + throw _failure; T t = _data.removeFirst(); @@ -111,6 +119,13 @@ public synchronized boolean isProcessed() { return _closedInput && _data.isEmpty(); } + public synchronized void propagateFailure(DMLRuntimeException failure) { + if (_failure == null) { + _failure = failure; + notifyAll(); + } + } + @Override public synchronized String toString() { diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java index c01fb3fa376..8c8a64b0225 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java @@ -90,9 +90,8 @@ public void processInstruction( ExecutionContext ec ) { LocalTaskQueue qOut = new LocalTaskQueue<>(); ec.getMatrixObject(output).setStreamHandle(qOut); - ExecutorService pool = CommonThreadPool.get(); - try { - pool.submit(() -> { + + submitOOCTask(() -> { IndexedMatrixValue tmp = null; try { while((tmp = q.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { @@ -152,12 +151,7 @@ public void processInstruction( ExecutionContext ec ) { catch(Exception ex) { throw new DMLRuntimeException(ex); } - }); - } catch (Exception ex) { - throw new DMLRuntimeException(ex); - } finally { - pool.shutdown(); - } + }, q, qOut); } // full aggregation else { diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java index fe76e60b9eb..82ad12ae554 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java @@ -70,9 +70,7 @@ public void processInstruction( ExecutionContext ec ) { LocalTaskQueue qOut = new LocalTaskQueue<>(); ec.getMatrixObject(output).setStreamHandle(qOut); - ExecutorService pool = CommonThreadPool.get(); - try { - pool.submit(() -> { + submitOOCTask(() -> { IndexedMatrixValue tmp = null; try { while((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { @@ -86,10 +84,6 @@ public void processInstruction( ExecutionContext ec ) { catch(Exception ex) { throw new DMLRuntimeException(ex); } - }); - } - finally { - pool.shutdown(); - } + }, qIn, qOut); } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java index ae84e4b5419..c1d1ed6ace7 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java @@ -90,10 +90,7 @@ public void processInstruction( ExecutionContext ec ) { BinaryOperator plus = InstructionUtils.parseBinaryOperator(Opcodes.PLUS.toString()); ec.getMatrixObject(output).setStreamHandle(qOut); - ExecutorService pool = CommonThreadPool.get(); - try { - // Core logic: background thread - pool.submit(() -> { + submitOOCTask(() -> { IndexedMatrixValue tmp = null; try { while((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { @@ -134,12 +131,6 @@ public void processInstruction( ExecutionContext ec ) { finally { qOut.closeInput(); } - }); - } catch (Exception e) { - throw new DMLRuntimeException(e); - } - finally { - pool.shutdown(); - } + }, qIn, qOut); } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java index d55d1ee5948..0d159492891 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java @@ -22,12 +22,16 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.api.DMLScript; +import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; import org.apache.sysds.runtime.instructions.Instruction; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.operators.Operator; +import org.apache.sysds.runtime.util.CommonThreadPool; import java.util.HashMap; +import java.util.concurrent.ExecutorService; public abstract class OOCInstruction extends Instruction { protected static final Log LOG = LogFactory.getLog(OOCInstruction.class.getName()); @@ -86,6 +90,37 @@ public void postprocessInstruction(ExecutionContext ec) { ec.maintainLineageDebuggerInfo(this); } + protected void submitOOCTask(Runnable r, LocalTaskQueue... queues) { + ExecutorService pool = CommonThreadPool.get(); + try { + pool.submit(oocTask(r, queues)); + } + catch (Exception ex) { + throw new DMLRuntimeException(ex); + } + finally { + pool.shutdown(); + } + } + + private Runnable oocTask(Runnable r, LocalTaskQueue... queues) { + return () -> { + try { + r.run(); + } + catch (Exception ex) { + DMLRuntimeException re = ex instanceof DMLRuntimeException ? (DMLRuntimeException) ex : new DMLRuntimeException(ex); + + for (LocalTaskQueue q : queues) { + q.propagateFailure(re); + } + + // Rethrow to ensure proper future handling + throw re; + } + }; + } + /** * Tracks blocks and their counts to enable early emission * once all blocks for a given index are processed. diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReblockOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReblockOOCInstruction.java index 9a7059be513..06386c5d66c 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReblockOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReblockOOCInstruction.java @@ -79,13 +79,7 @@ public void processInstruction(ExecutionContext ec) { //create queue, spawn thread for asynchronous reading, and return LocalTaskQueue q = new LocalTaskQueue(); - ExecutorService pool = CommonThreadPool.get(); - try { - pool.submit(() -> readBinaryBlock(q, min.getFileName())); - } - finally { - pool.shutdown(); - } + submitOOCTask(() -> readBinaryBlock(q, min.getFileName()), q); MatrixObject mout = ec.getMatrixObject(output); mout.setStreamHandle(q); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TransposeOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TransposeOOCInstruction.java index 212d0d5c56a..fce5408960e 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TransposeOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TransposeOOCInstruction.java @@ -60,10 +60,7 @@ public void processInstruction( ExecutionContext ec ) { LocalTaskQueue qOut = new LocalTaskQueue<>(); ec.getMatrixObject(output).setStreamHandle(qOut); - - ExecutorService pool = CommonThreadPool.get(); - try { - pool.submit(() -> { + submitOOCTask(() -> { IndexedMatrixValue tmp = null; try { while ((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { @@ -79,11 +76,6 @@ public void processInstruction( ExecutionContext ec ) { catch(Exception ex) { throw new DMLRuntimeException(ex); } - }); - } catch (Exception ex) { - throw new DMLRuntimeException(ex); - } finally { - pool.shutdown(); - } + }, qIn, qOut); } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/UnaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/UnaryOOCInstruction.java index 13cd5463ed4..63f42f5bf15 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/UnaryOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/UnaryOOCInstruction.java @@ -61,9 +61,7 @@ public void processInstruction( ExecutionContext ec ) { ec.getMatrixObject(output).setStreamHandle(qOut); - ExecutorService pool = CommonThreadPool.get(); - try { - pool.submit(() -> { + submitOOCTask(() -> { IndexedMatrixValue tmp = null; try { while ((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { @@ -77,11 +75,6 @@ public void processInstruction( ExecutionContext ec ) { catch(Exception ex) { throw new DMLRuntimeException(ex); } - }); - } catch (Exception ex) { - throw new DMLRuntimeException(ex); - } finally { - pool.shutdown(); - } + }, qIn, qOut); } } diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/OOCExceptionHandlingTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/OOCExceptionHandlingTest.java new file mode 100644 index 00000000000..3bd32d7eff7 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/ooc/OOCExceptionHandlingTest.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.ooc; + +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.io.MatrixWriter; +import org.apache.sysds.runtime.io.MatrixWriterFactory; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.util.DataConverter; +import org.apache.sysds.runtime.util.HDFSTool; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; + +import java.io.IOException; + +public class OOCExceptionHandlingTest extends AutomatedTestBase { + private final static String TEST_NAME1 = "OOCExceptionHandling"; + private final static String TEST_DIR = "functions/ooc/"; + private final static String TEST_CLASS_DIR = TEST_DIR + OOCExceptionHandlingTest.class.getSimpleName() + "/"; + private static final String INPUT_NAME = "X"; + private static final String INPUT_NAME_2 = "Y"; + private static final String OUTPUT_NAME = "res"; + + private final static int rows = 1000; + private final static int cols = 1000; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1); + addTestConfiguration(TEST_NAME1, config); + } + + @Test + public void runOOCExceptionHandlingTest1() { + runOOCExceptionHandlingTest(500); + } + + @Test + public void runOOCExceptionHandlingTest2() { + runOOCExceptionHandlingTest(750); + } + + + private void runOOCExceptionHandlingTest(int misalignVals) { + Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE); + + try { + getAndLoadTestConfiguration(TEST_NAME1); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME1 + ".dml"; + programArgs = new String[] {"-explain", "-stats", "-ooc", "-args", input(INPUT_NAME), input(INPUT_NAME_2), output(OUTPUT_NAME)}; + + // 1. Generate the data in-memory as MatrixBlock objects + double[][] A_data = getRandomMatrix(rows, cols, 1, 2, 1, 7); + double[][] B_data = getRandomMatrix(rows, 1, 1, 2, 1, 7); + + // 2. Convert the double arrays to MatrixBlock objects + MatrixBlock A_mb = DataConverter.convertToMatrixBlock(A_data); + MatrixBlock B_mb = DataConverter.convertToMatrixBlock(B_data); + + // 3. Create a binary matrix writer + MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY); + + // 4. Write matrix A to a binary SequenceFile + + // Here, we write two faulty matrices which will only be recognized at runtime + writer.writeMatrixToHDFS(A_mb, input(INPUT_NAME), rows, cols, misalignVals, A_mb.getNonZeros()); + HDFSTool.writeMetaDataFile(input(INPUT_NAME + ".mtd"), Types.ValueType.FP64, + new MatrixCharacteristics(rows, cols, 1000, A_mb.getNonZeros()), Types.FileFormat.BINARY); + + writer.writeMatrixToHDFS(B_mb, input(INPUT_NAME_2), rows, 1, 1000, B_mb.getNonZeros()); + HDFSTool.writeMetaDataFile(input(INPUT_NAME_2 + ".mtd"), Types.ValueType.FP64, + new MatrixCharacteristics(rows, 1, 1000, B_mb.getNonZeros()), Types.FileFormat.BINARY); + + runTest(true, true, null, -1); + } + catch(IOException e) { + throw new RuntimeException(e); + } + finally { + resetExecMode(platformOld); + } + } +} diff --git a/src/test/scripts/functions/ooc/OOCExceptionHandling.dml b/src/test/scripts/functions/ooc/OOCExceptionHandling.dml new file mode 100644 index 00000000000..6b7dc6038e2 --- /dev/null +++ b/src/test/scripts/functions/ooc/OOCExceptionHandling.dml @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# Read the input matrix as a stream +X = read($1); +b = read($2); + +OOC = colSums(X %*% b); + +write(OOC, $3, format="binary");