From 3fc3c03503d52d52040d47381e69cbb7fa1809e9 Mon Sep 17 00:00:00 2001 From: Jannik Lindemann Date: Mon, 27 Oct 2025 18:12:33 +0100 Subject: [PATCH 1/6] Improved exception handling for OOC instructions using failure propagation in LocalTaskQueue --- .../java/org/apache/sysds/api/DMLScript.java | 4 + .../controlprogram/parfor/LocalTaskQueue.java | 32 ++++- .../ooc/AggregateUnaryOOCInstruction.java | 122 +++++++++--------- .../ooc/BinaryOOCInstruction.java | 34 ++--- .../ooc/MatrixVectorBinaryOOCInstruction.java | 81 ++++++------ .../instructions/ooc/OOCInstruction.java | 34 +++++ .../ooc/ReblockOOCInstruction.java | 8 +- .../ooc/TransposeOOCInstruction.java | 38 +++--- .../instructions/ooc/UnaryOOCInstruction.java | 35 ++--- .../ooc/OOCExceptionHandlingTest.java | 87 +++++++++++++ .../functions/ooc/OOCExceptionHandling.dml | 28 ++++ 11 files changed, 320 insertions(+), 183 deletions(-) create mode 100644 src/test/java/org/apache/sysds/test/functions/ooc/OOCExceptionHandlingTest.java create mode 100644 src/test/scripts/functions/ooc/OOCExceptionHandling.dml diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java b/src/main/java/org/apache/sysds/api/DMLScript.java index 65805b5c2ed..3a2cf1a16e2 100644 --- a/src/main/java/org/apache/sysds/api/DMLScript.java +++ b/src/main/java/org/apache/sysds/api/DMLScript.java @@ -69,6 +69,7 @@ import org.apache.sysds.runtime.controlprogram.federated.FederatedWorker; import org.apache.sysds.runtime.controlprogram.federated.monitoring.FederatedMonitoringServer; import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.CoordinatorModel; +import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; import org.apache.sysds.runtime.controlprogram.parfor.util.IDHandler; import org.apache.sysds.runtime.instructions.gpu.context.GPUContextPool; import org.apache.sysds.runtime.io.IOUtilFunctions; @@ -443,6 +444,9 @@ private static void execute(String dmlScriptStr, String fnameOptConfig, Map public static final int MAX_SIZE = 100000; //main memory constraint public static final Object NO_MORE_TASKS = null; //object to signal NO_MORE_TASKS + private static volatile DMLRuntimeException FAILURE = null; private LinkedList _data = null; - private boolean _closedInput = false; + private boolean _closedInput = false; private static final Log LOG = LogFactory.getLog(LocalTaskQueue.class.getName()); public LocalTaskQueue() @@ -61,11 +63,14 @@ public LocalTaskQueue() public synchronized void enqueueTask( T t ) throws InterruptedException { - while( _data.size() + 1 > MAX_SIZE ) + while( _data.size() + 1 > MAX_SIZE && FAILURE == null ) { LOG.warn("MAX_SIZE of task queue reached."); wait(); //max constraint reached, wait for read } + + if ( FAILURE != null ) + throw FAILURE; _data.addLast( t ); @@ -82,13 +87,16 @@ public synchronized void enqueueTask( T t ) public synchronized T dequeueTask() throws InterruptedException { - while( _data.isEmpty() ) + while( _data.isEmpty() && FAILURE == null ) { if( !_closedInput ) wait(); // wait for writers else return (T)NO_MORE_TASKS; } + + if ( FAILURE != null ) + throw FAILURE; T t = _data.removeFirst(); @@ -111,6 +119,10 @@ public synchronized boolean isProcessed() { return _closedInput && _data.isEmpty(); } + public synchronized void notifyFailure() { + notifyAll(); + } + @Override public synchronized String toString() { @@ -135,4 +147,18 @@ public synchronized String toString() return sb.toString(); } + + public static boolean failGlobally(DMLRuntimeException ex) { + // Only register the first failure + if (FAILURE == null) { + FAILURE = ex; + return true; + } + + return false; + } + + public static void resetFailures() { + FAILURE = null; + } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java index c01fb3fa376..221023b43aa 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java @@ -90,74 +90,68 @@ public void processInstruction( ExecutionContext ec ) { LocalTaskQueue qOut = new LocalTaskQueue<>(); ec.getMatrixObject(output).setStreamHandle(qOut); - ExecutorService pool = CommonThreadPool.get(); - try { - pool.submit(() -> { - IndexedMatrixValue tmp = null; - try { - while((tmp = q.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { - long idx = aggun.isRowAggregate() ? - tmp.getIndexes().getRowIndex() : tmp.getIndexes().getColumnIndex(); - MatrixBlock ret = aggTracker.get(idx); - if(ret != null) { - MatrixBlock corr = corrs.get(idx); - - // aggregation - MatrixBlock ltmp = (MatrixBlock) ((MatrixBlock) tmp.getValue()) - .aggregateUnaryOperations(aggun, new MatrixBlock(), blen, tmp.getIndexes()); - OperationsOnMatrixValues.incrementalAggregation(ret, - _aop.existsCorrection() ? corr : null, ltmp, _aop, true); - - if (!aggTracker.putAndIncrementCount(idx, ret)){ - corrs.replace(idx, corr); - continue; - } + + submitOOCTask(() -> { + IndexedMatrixValue tmp = null; + try { + while((tmp = q.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { + long idx = aggun.isRowAggregate() ? + tmp.getIndexes().getRowIndex() : tmp.getIndexes().getColumnIndex(); + MatrixBlock ret = aggTracker.get(idx); + if(ret != null) { + MatrixBlock corr = corrs.get(idx); + + // aggregation + MatrixBlock ltmp = (MatrixBlock) ((MatrixBlock) tmp.getValue()) + .aggregateUnaryOperations(aggun, new MatrixBlock(), blen, tmp.getIndexes()); + OperationsOnMatrixValues.incrementalAggregation(ret, + _aop.existsCorrection() ? corr : null, ltmp, _aop, true); + + if (!aggTracker.putAndIncrementCount(idx, ret)){ + corrs.replace(idx, corr); + continue; } - else { - // first block for this idx - init aggregate and correction - // TODO avoid corr block for inplace incremental aggregation - int rows = tmp.getValue().getNumRows(); - int cols = tmp.getValue().getNumColumns(); - int extra = _aop.correction.getNumRemovedRowsColumns(); - ret = aggun.isRowAggregate()? new MatrixBlock(rows, 1 + extra, false) : new MatrixBlock(1 + extra, cols, false); - MatrixBlock corr = aggun.isRowAggregate()? new MatrixBlock(rows, 1 + extra, false) : new MatrixBlock(1 + extra, cols, false); - - // aggregation - MatrixBlock ltmp = (MatrixBlock) ((MatrixBlock) tmp.getValue()).aggregateUnaryOperations( - aggun, new MatrixBlock(), blen, tmp.getIndexes()); - OperationsOnMatrixValues.incrementalAggregation(ret, - _aop.existsCorrection() ? corr : null, ltmp, _aop, true); - - if(emitThreshold > 1){ - aggTracker.putAndIncrementCount(idx, ret); - corrs.put(idx, corr); - continue; - } + } + else { + // first block for this idx - init aggregate and correction + // TODO avoid corr block for inplace incremental aggregation + int rows = tmp.getValue().getNumRows(); + int cols = tmp.getValue().getNumColumns(); + int extra = _aop.correction.getNumRemovedRowsColumns(); + ret = aggun.isRowAggregate()? new MatrixBlock(rows, 1 + extra, false) : new MatrixBlock(1 + extra, cols, false); + MatrixBlock corr = aggun.isRowAggregate()? new MatrixBlock(rows, 1 + extra, false) : new MatrixBlock(1 + extra, cols, false); + + // aggregation + MatrixBlock ltmp = (MatrixBlock) ((MatrixBlock) tmp.getValue()).aggregateUnaryOperations( + aggun, new MatrixBlock(), blen, tmp.getIndexes()); + OperationsOnMatrixValues.incrementalAggregation(ret, + _aop.existsCorrection() ? corr : null, ltmp, _aop, true); + + if(emitThreshold > 1){ + aggTracker.putAndIncrementCount(idx, ret); + corrs.put(idx, corr); + continue; } - - // all input blocks for this idx processed - emit aggregated block - ret.dropLastRowsOrColumns(_aop.correction); - MatrixIndexes midx = aggun.isRowAggregate() ? - new MatrixIndexes(tmp.getIndexes().getRowIndex(), 1) : - new MatrixIndexes(1, tmp.getIndexes().getColumnIndex()); - IndexedMatrixValue tmpOut = new IndexedMatrixValue(midx, ret); - - qOut.enqueueTask(tmpOut); - // drop intermediate states - aggTracker.remove(idx); - corrs.remove(idx); } - qOut.closeInput(); - } - catch(Exception ex) { - throw new DMLRuntimeException(ex); + + // all input blocks for this idx processed - emit aggregated block + ret.dropLastRowsOrColumns(_aop.correction); + MatrixIndexes midx = aggun.isRowAggregate() ? + new MatrixIndexes(tmp.getIndexes().getRowIndex(), 1) : + new MatrixIndexes(1, tmp.getIndexes().getColumnIndex()); + IndexedMatrixValue tmpOut = new IndexedMatrixValue(midx, ret); + + qOut.enqueueTask(tmpOut); + // drop intermediate states + aggTracker.remove(idx); + corrs.remove(idx); } - }); - } catch (Exception ex) { - throw new DMLRuntimeException(ex); - } finally { - pool.shutdown(); - } + qOut.closeInput(); + } + catch(Exception ex) { + throw new DMLRuntimeException(ex); + } + }, q, qOut); } // full aggregation else { diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java index fe76e60b9eb..248c1206b18 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java @@ -70,26 +70,20 @@ public void processInstruction( ExecutionContext ec ) { LocalTaskQueue qOut = new LocalTaskQueue<>(); ec.getMatrixObject(output).setStreamHandle(qOut); - ExecutorService pool = CommonThreadPool.get(); - try { - pool.submit(() -> { - IndexedMatrixValue tmp = null; - try { - while((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { - IndexedMatrixValue tmpOut = new IndexedMatrixValue(); - tmpOut.set(tmp.getIndexes(), - tmp.getValue().scalarOperations(sc_op, new MatrixBlock())); - qOut.enqueueTask(tmpOut); - } - qOut.closeInput(); + submitOOCTask(() -> { + IndexedMatrixValue tmp = null; + try { + while((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { + IndexedMatrixValue tmpOut = new IndexedMatrixValue(); + tmpOut.set(tmp.getIndexes(), + tmp.getValue().scalarOperations(sc_op, new MatrixBlock())); + qOut.enqueueTask(tmpOut); } - catch(Exception ex) { - throw new DMLRuntimeException(ex); - } - }); - } - finally { - pool.shutdown(); - } + qOut.closeInput(); + } + catch(Exception ex) { + throw new DMLRuntimeException(ex); + } + }, qIn, qOut); } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java index ae84e4b5419..953bc9f2430 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java @@ -90,56 +90,47 @@ public void processInstruction( ExecutionContext ec ) { BinaryOperator plus = InstructionUtils.parseBinaryOperator(Opcodes.PLUS.toString()); ec.getMatrixObject(output).setStreamHandle(qOut); - ExecutorService pool = CommonThreadPool.get(); - try { - // Core logic: background thread - pool.submit(() -> { - IndexedMatrixValue tmp = null; - try { - while((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { - MatrixBlock matrixBlock = (MatrixBlock) tmp.getValue(); - long rowIndex = tmp.getIndexes().getRowIndex(); - long colIndex = tmp.getIndexes().getColumnIndex(); - MatrixBlock vectorSlice = partitionedVector.get(colIndex); - - // Now, call the operation with the correct, specific operator. - MatrixBlock partialResult = matrixBlock.aggregateBinaryOperations( - matrixBlock, vectorSlice, new MatrixBlock(), (AggregateBinaryOperator) _optr); - - // for single column block, no aggregation neeeded - if(emitThreshold == 1) { - qOut.enqueueTask(new IndexedMatrixValue(tmp.getIndexes(), partialResult)); + submitOOCTask(() -> { + IndexedMatrixValue tmp = null; + try { + while((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { + MatrixBlock matrixBlock = (MatrixBlock) tmp.getValue(); + long rowIndex = tmp.getIndexes().getRowIndex(); + long colIndex = tmp.getIndexes().getColumnIndex(); + MatrixBlock vectorSlice = partitionedVector.get(colIndex); + + // Now, call the operation with the correct, specific operator. + MatrixBlock partialResult = matrixBlock.aggregateBinaryOperations( + matrixBlock, vectorSlice, new MatrixBlock(), (AggregateBinaryOperator) _optr); + + // for single column block, no aggregation neeeded + if(emitThreshold == 1) { + qOut.enqueueTask(new IndexedMatrixValue(tmp.getIndexes(), partialResult)); + } + else { + // aggregation + MatrixBlock currAgg = aggTracker.get(rowIndex); + if (currAgg == null) { + aggTracker.putAndIncrementCount(rowIndex, partialResult); } else { - // aggregation - MatrixBlock currAgg = aggTracker.get(rowIndex); - if (currAgg == null) { - aggTracker.putAndIncrementCount(rowIndex, partialResult); - } - else { - currAgg = currAgg.binaryOperations(plus, partialResult); - if (aggTracker.putAndIncrementCount(rowIndex, currAgg)){ - // early block output: emit aggregated block - MatrixIndexes idx = new MatrixIndexes(rowIndex, 1L); - qOut.enqueueTask(new IndexedMatrixValue(idx, currAgg)); - aggTracker.remove(rowIndex); - } + currAgg = currAgg.binaryOperations(plus, partialResult); + if (aggTracker.putAndIncrementCount(rowIndex, currAgg)){ + // early block output: emit aggregated block + MatrixIndexes idx = new MatrixIndexes(rowIndex, 1L); + qOut.enqueueTask(new IndexedMatrixValue(idx, currAgg)); + aggTracker.remove(rowIndex); } } } } - catch(Exception ex) { - throw new DMLRuntimeException(ex); - } - finally { - qOut.closeInput(); - } - }); - } catch (Exception e) { - throw new DMLRuntimeException(e); - } - finally { - pool.shutdown(); - } + } + catch(Exception ex) { + throw new DMLRuntimeException(ex); + } + finally { + qOut.closeInput(); + } + }, qIn, qOut); } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java index d55d1ee5948..f9f9e631ce6 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java @@ -22,12 +22,16 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.api.DMLScript; +import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; import org.apache.sysds.runtime.instructions.Instruction; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.operators.Operator; +import org.apache.sysds.runtime.util.CommonThreadPool; import java.util.HashMap; +import java.util.concurrent.ExecutorService; public abstract class OOCInstruction extends Instruction { protected static final Log LOG = LogFactory.getLog(OOCInstruction.class.getName()); @@ -86,6 +90,36 @@ public void postprocessInstruction(ExecutionContext ec) { ec.maintainLineageDebuggerInfo(this); } + protected void submitOOCTask(Runnable r, LocalTaskQueue... queues) { + ExecutorService pool = CommonThreadPool.get(); + try { + pool.submit(oocTask(r, queues)); + } + catch (Exception ex) { + throw new DMLRuntimeException(ex); + } + finally { + pool.shutdown(); + } + } + + private Runnable oocTask(Runnable r, LocalTaskQueue... queues) { + return () -> { + try { + r.run(); + } + catch (Exception ex) { + DMLRuntimeException re = ex instanceof DMLRuntimeException ? (DMLRuntimeException) ex : new DMLRuntimeException(ex); + + LocalTaskQueue.failGlobally(re); + + for (LocalTaskQueue q : queues) { + q.notifyFailure(); + } + } + }; + } + /** * Tracks blocks and their counts to enable early emission * once all blocks for a given index are processed. diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReblockOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReblockOOCInstruction.java index 9a7059be513..06386c5d66c 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReblockOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ReblockOOCInstruction.java @@ -79,13 +79,7 @@ public void processInstruction(ExecutionContext ec) { //create queue, spawn thread for asynchronous reading, and return LocalTaskQueue q = new LocalTaskQueue(); - ExecutorService pool = CommonThreadPool.get(); - try { - pool.submit(() -> readBinaryBlock(q, min.getFileName())); - } - finally { - pool.shutdown(); - } + submitOOCTask(() -> readBinaryBlock(q, min.getFileName()), q); MatrixObject mout = ec.getMatrixObject(output); mout.setStreamHandle(q); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TransposeOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TransposeOOCInstruction.java index 212d0d5c56a..e9cf3616f20 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TransposeOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TransposeOOCInstruction.java @@ -60,30 +60,22 @@ public void processInstruction( ExecutionContext ec ) { LocalTaskQueue qOut = new LocalTaskQueue<>(); ec.getMatrixObject(output).setStreamHandle(qOut); + submitOOCTask(() -> { + IndexedMatrixValue tmp = null; + try { + while ((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { + MatrixBlock inBlock = (MatrixBlock)tmp.getValue(); + long oldRowIdx = tmp.getIndexes().getRowIndex(); + long oldColIdx = tmp.getIndexes().getColumnIndex(); - ExecutorService pool = CommonThreadPool.get(); - try { - pool.submit(() -> { - IndexedMatrixValue tmp = null; - try { - while ((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { - MatrixBlock inBlock = (MatrixBlock)tmp.getValue(); - long oldRowIdx = tmp.getIndexes().getRowIndex(); - long oldColIdx = tmp.getIndexes().getColumnIndex(); - - MatrixBlock outBlock = inBlock.reorgOperations((ReorgOperator) _optr, new MatrixBlock(), -1, -1, -1); - qOut.enqueueTask(new IndexedMatrixValue(new MatrixIndexes(oldColIdx, oldRowIdx), outBlock)); - } - qOut.closeInput(); - } - catch(Exception ex) { - throw new DMLRuntimeException(ex); + MatrixBlock outBlock = inBlock.reorgOperations((ReorgOperator) _optr, new MatrixBlock(), -1, -1, -1); + qOut.enqueueTask(new IndexedMatrixValue(new MatrixIndexes(oldColIdx, oldRowIdx), outBlock)); } - }); - } catch (Exception ex) { - throw new DMLRuntimeException(ex); - } finally { - pool.shutdown(); - } + qOut.closeInput(); + } + catch(Exception ex) { + throw new DMLRuntimeException(ex); + } + }, qIn, qOut); } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/UnaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/UnaryOOCInstruction.java index 13cd5463ed4..51ff3925b93 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/UnaryOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/UnaryOOCInstruction.java @@ -61,27 +61,20 @@ public void processInstruction( ExecutionContext ec ) { ec.getMatrixObject(output).setStreamHandle(qOut); - ExecutorService pool = CommonThreadPool.get(); - try { - pool.submit(() -> { - IndexedMatrixValue tmp = null; - try { - while ((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { - IndexedMatrixValue tmpOut = new IndexedMatrixValue(); - tmpOut.set(tmp.getIndexes(), - tmp.getValue().unaryOperations(uop, new MatrixBlock())); - qOut.enqueueTask(tmpOut); - } - qOut.closeInput(); + submitOOCTask(() -> { + IndexedMatrixValue tmp = null; + try { + while ((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { + IndexedMatrixValue tmpOut = new IndexedMatrixValue(); + tmpOut.set(tmp.getIndexes(), + tmp.getValue().unaryOperations(uop, new MatrixBlock())); + qOut.enqueueTask(tmpOut); } - catch(Exception ex) { - throw new DMLRuntimeException(ex); - } - }); - } catch (Exception ex) { - throw new DMLRuntimeException(ex); - } finally { - pool.shutdown(); - } + qOut.closeInput(); + } + catch(Exception ex) { + throw new DMLRuntimeException(ex); + } + }, qIn, qOut); } } diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/OOCExceptionHandlingTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/OOCExceptionHandlingTest.java new file mode 100644 index 00000000000..a94faf87449 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/ooc/OOCExceptionHandlingTest.java @@ -0,0 +1,87 @@ +package org.apache.sysds.test.functions.ooc; + +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.io.MatrixWriter; +import org.apache.sysds.runtime.io.MatrixWriterFactory; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.util.DataConverter; +import org.apache.sysds.runtime.util.HDFSTool; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; + +import java.io.IOException; + +public class OOCExceptionHandlingTest extends AutomatedTestBase { + private final static String TEST_NAME1 = "OOCExceptionHandling"; + private final static String TEST_DIR = "functions/ooc/"; + private final static String TEST_CLASS_DIR = TEST_DIR + OOCExceptionHandlingTest.class.getSimpleName() + "/"; + private static final String INPUT_NAME = "X"; + private static final String INPUT_NAME_2 = "Y"; + private static final String OUTPUT_NAME = "res"; + + private final static int rows = 1000; + private final static int cols = 1000; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1); + addTestConfiguration(TEST_NAME1, config); + } + + @Test + public void runOOCExceptionHandlingTest1() { + runOOCExceptionHandlingTest(500); + } + + @Test + public void runOOCExceptionHandlingTest2() { + runOOCExceptionHandlingTest(750); + } + + + private void runOOCExceptionHandlingTest(int misalignVals) { + Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE); + + try { + getAndLoadTestConfiguration(TEST_NAME1); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME1 + ".dml"; + programArgs = new String[] {"-explain", "-stats", "-ooc", "-args", input(INPUT_NAME), input(INPUT_NAME_2), output(OUTPUT_NAME)}; + + // 1. Generate the data in-memory as MatrixBlock objects + double[][] A_data = getRandomMatrix(rows, cols, 1, 2, 1, 7); + double[][] B_data = getRandomMatrix(rows, 1, 1, 2, 1, 7); + + // 2. Convert the double arrays to MatrixBlock objects + MatrixBlock A_mb = DataConverter.convertToMatrixBlock(A_data); + MatrixBlock B_mb = DataConverter.convertToMatrixBlock(B_data); + + // 3. Create a binary matrix writer + MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY); + + // 4. Write matrix A to a binary SequenceFile + + // Here, we write two faulty matrices which will only be recognized at runtime + writer.writeMatrixToHDFS(A_mb, input(INPUT_NAME), rows, cols, misalignVals, A_mb.getNonZeros()); + HDFSTool.writeMetaDataFile(input(INPUT_NAME + ".mtd"), Types.ValueType.FP64, + new MatrixCharacteristics(rows, cols, 1000, A_mb.getNonZeros()), Types.FileFormat.BINARY); + + writer.writeMatrixToHDFS(B_mb, input(INPUT_NAME_2), rows, 1, 1000, B_mb.getNonZeros()); + HDFSTool.writeMetaDataFile(input(INPUT_NAME_2 + ".mtd"), Types.ValueType.FP64, + new MatrixCharacteristics(rows, 1, 1000, B_mb.getNonZeros()), Types.FileFormat.BINARY); + + runTest(true, true, null, -1); + } + catch(IOException e) { + throw new RuntimeException(e); + } + finally { + resetExecMode(platformOld); + } + } +} diff --git a/src/test/scripts/functions/ooc/OOCExceptionHandling.dml b/src/test/scripts/functions/ooc/OOCExceptionHandling.dml new file mode 100644 index 00000000000..6b7dc6038e2 --- /dev/null +++ b/src/test/scripts/functions/ooc/OOCExceptionHandling.dml @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# Read the input matrix as a stream +X = read($1); +b = read($2); + +OOC = colSums(X %*% b); + +write(OOC, $3, format="binary"); From 3cc44c782311160d568719154d2bb73c8739c6fb Mon Sep 17 00:00:00 2001 From: Jannik Lindemann Date: Mon, 27 Oct 2025 18:30:38 +0100 Subject: [PATCH 2/6] Improve diff readability by adjusting tabs --- .../ooc/AggregateUnaryOOCInstruction.java | 110 +++++++++--------- .../ooc/BinaryOOCInstruction.java | 24 ++-- .../ooc/MatrixVectorBinaryOOCInstruction.java | 68 +++++------ .../ooc/TransposeOOCInstruction.java | 26 ++--- .../instructions/ooc/UnaryOOCInstruction.java | 24 ++-- 5 files changed, 126 insertions(+), 126 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java index 221023b43aa..8c8a64b0225 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java @@ -92,65 +92,65 @@ public void processInstruction( ExecutionContext ec ) { ec.getMatrixObject(output).setStreamHandle(qOut); submitOOCTask(() -> { - IndexedMatrixValue tmp = null; - try { - while((tmp = q.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { - long idx = aggun.isRowAggregate() ? - tmp.getIndexes().getRowIndex() : tmp.getIndexes().getColumnIndex(); - MatrixBlock ret = aggTracker.get(idx); - if(ret != null) { - MatrixBlock corr = corrs.get(idx); - - // aggregation - MatrixBlock ltmp = (MatrixBlock) ((MatrixBlock) tmp.getValue()) - .aggregateUnaryOperations(aggun, new MatrixBlock(), blen, tmp.getIndexes()); - OperationsOnMatrixValues.incrementalAggregation(ret, - _aop.existsCorrection() ? corr : null, ltmp, _aop, true); - - if (!aggTracker.putAndIncrementCount(idx, ret)){ - corrs.replace(idx, corr); - continue; + IndexedMatrixValue tmp = null; + try { + while((tmp = q.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { + long idx = aggun.isRowAggregate() ? + tmp.getIndexes().getRowIndex() : tmp.getIndexes().getColumnIndex(); + MatrixBlock ret = aggTracker.get(idx); + if(ret != null) { + MatrixBlock corr = corrs.get(idx); + + // aggregation + MatrixBlock ltmp = (MatrixBlock) ((MatrixBlock) tmp.getValue()) + .aggregateUnaryOperations(aggun, new MatrixBlock(), blen, tmp.getIndexes()); + OperationsOnMatrixValues.incrementalAggregation(ret, + _aop.existsCorrection() ? corr : null, ltmp, _aop, true); + + if (!aggTracker.putAndIncrementCount(idx, ret)){ + corrs.replace(idx, corr); + continue; + } } - } - else { - // first block for this idx - init aggregate and correction - // TODO avoid corr block for inplace incremental aggregation - int rows = tmp.getValue().getNumRows(); - int cols = tmp.getValue().getNumColumns(); - int extra = _aop.correction.getNumRemovedRowsColumns(); - ret = aggun.isRowAggregate()? new MatrixBlock(rows, 1 + extra, false) : new MatrixBlock(1 + extra, cols, false); - MatrixBlock corr = aggun.isRowAggregate()? new MatrixBlock(rows, 1 + extra, false) : new MatrixBlock(1 + extra, cols, false); - - // aggregation - MatrixBlock ltmp = (MatrixBlock) ((MatrixBlock) tmp.getValue()).aggregateUnaryOperations( - aggun, new MatrixBlock(), blen, tmp.getIndexes()); - OperationsOnMatrixValues.incrementalAggregation(ret, - _aop.existsCorrection() ? corr : null, ltmp, _aop, true); - - if(emitThreshold > 1){ - aggTracker.putAndIncrementCount(idx, ret); - corrs.put(idx, corr); - continue; + else { + // first block for this idx - init aggregate and correction + // TODO avoid corr block for inplace incremental aggregation + int rows = tmp.getValue().getNumRows(); + int cols = tmp.getValue().getNumColumns(); + int extra = _aop.correction.getNumRemovedRowsColumns(); + ret = aggun.isRowAggregate()? new MatrixBlock(rows, 1 + extra, false) : new MatrixBlock(1 + extra, cols, false); + MatrixBlock corr = aggun.isRowAggregate()? new MatrixBlock(rows, 1 + extra, false) : new MatrixBlock(1 + extra, cols, false); + + // aggregation + MatrixBlock ltmp = (MatrixBlock) ((MatrixBlock) tmp.getValue()).aggregateUnaryOperations( + aggun, new MatrixBlock(), blen, tmp.getIndexes()); + OperationsOnMatrixValues.incrementalAggregation(ret, + _aop.existsCorrection() ? corr : null, ltmp, _aop, true); + + if(emitThreshold > 1){ + aggTracker.putAndIncrementCount(idx, ret); + corrs.put(idx, corr); + continue; + } } - } - // all input blocks for this idx processed - emit aggregated block - ret.dropLastRowsOrColumns(_aop.correction); - MatrixIndexes midx = aggun.isRowAggregate() ? - new MatrixIndexes(tmp.getIndexes().getRowIndex(), 1) : - new MatrixIndexes(1, tmp.getIndexes().getColumnIndex()); - IndexedMatrixValue tmpOut = new IndexedMatrixValue(midx, ret); - - qOut.enqueueTask(tmpOut); - // drop intermediate states - aggTracker.remove(idx); - corrs.remove(idx); + // all input blocks for this idx processed - emit aggregated block + ret.dropLastRowsOrColumns(_aop.correction); + MatrixIndexes midx = aggun.isRowAggregate() ? + new MatrixIndexes(tmp.getIndexes().getRowIndex(), 1) : + new MatrixIndexes(1, tmp.getIndexes().getColumnIndex()); + IndexedMatrixValue tmpOut = new IndexedMatrixValue(midx, ret); + + qOut.enqueueTask(tmpOut); + // drop intermediate states + aggTracker.remove(idx); + corrs.remove(idx); + } + qOut.closeInput(); + } + catch(Exception ex) { + throw new DMLRuntimeException(ex); } - qOut.closeInput(); - } - catch(Exception ex) { - throw new DMLRuntimeException(ex); - } }, q, qOut); } // full aggregation diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java index 248c1206b18..82ad12ae554 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java @@ -71,19 +71,19 @@ public void processInstruction( ExecutionContext ec ) { ec.getMatrixObject(output).setStreamHandle(qOut); submitOOCTask(() -> { - IndexedMatrixValue tmp = null; - try { - while((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { - IndexedMatrixValue tmpOut = new IndexedMatrixValue(); - tmpOut.set(tmp.getIndexes(), - tmp.getValue().scalarOperations(sc_op, new MatrixBlock())); - qOut.enqueueTask(tmpOut); + IndexedMatrixValue tmp = null; + try { + while((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { + IndexedMatrixValue tmpOut = new IndexedMatrixValue(); + tmpOut.set(tmp.getIndexes(), + tmp.getValue().scalarOperations(sc_op, new MatrixBlock())); + qOut.enqueueTask(tmpOut); + } + qOut.closeInput(); + } + catch(Exception ex) { + throw new DMLRuntimeException(ex); } - qOut.closeInput(); - } - catch(Exception ex) { - throw new DMLRuntimeException(ex); - } }, qIn, qOut); } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java index 953bc9f2430..c1d1ed6ace7 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java @@ -91,46 +91,46 @@ public void processInstruction( ExecutionContext ec ) { ec.getMatrixObject(output).setStreamHandle(qOut); submitOOCTask(() -> { - IndexedMatrixValue tmp = null; - try { - while((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { - MatrixBlock matrixBlock = (MatrixBlock) tmp.getValue(); - long rowIndex = tmp.getIndexes().getRowIndex(); - long colIndex = tmp.getIndexes().getColumnIndex(); - MatrixBlock vectorSlice = partitionedVector.get(colIndex); - - // Now, call the operation with the correct, specific operator. - MatrixBlock partialResult = matrixBlock.aggregateBinaryOperations( - matrixBlock, vectorSlice, new MatrixBlock(), (AggregateBinaryOperator) _optr); - - // for single column block, no aggregation neeeded - if(emitThreshold == 1) { - qOut.enqueueTask(new IndexedMatrixValue(tmp.getIndexes(), partialResult)); - } - else { - // aggregation - MatrixBlock currAgg = aggTracker.get(rowIndex); - if (currAgg == null) { - aggTracker.putAndIncrementCount(rowIndex, partialResult); + IndexedMatrixValue tmp = null; + try { + while((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { + MatrixBlock matrixBlock = (MatrixBlock) tmp.getValue(); + long rowIndex = tmp.getIndexes().getRowIndex(); + long colIndex = tmp.getIndexes().getColumnIndex(); + MatrixBlock vectorSlice = partitionedVector.get(colIndex); + + // Now, call the operation with the correct, specific operator. + MatrixBlock partialResult = matrixBlock.aggregateBinaryOperations( + matrixBlock, vectorSlice, new MatrixBlock(), (AggregateBinaryOperator) _optr); + + // for single column block, no aggregation neeeded + if(emitThreshold == 1) { + qOut.enqueueTask(new IndexedMatrixValue(tmp.getIndexes(), partialResult)); } else { - currAgg = currAgg.binaryOperations(plus, partialResult); - if (aggTracker.putAndIncrementCount(rowIndex, currAgg)){ - // early block output: emit aggregated block - MatrixIndexes idx = new MatrixIndexes(rowIndex, 1L); - qOut.enqueueTask(new IndexedMatrixValue(idx, currAgg)); - aggTracker.remove(rowIndex); + // aggregation + MatrixBlock currAgg = aggTracker.get(rowIndex); + if (currAgg == null) { + aggTracker.putAndIncrementCount(rowIndex, partialResult); + } + else { + currAgg = currAgg.binaryOperations(plus, partialResult); + if (aggTracker.putAndIncrementCount(rowIndex, currAgg)){ + // early block output: emit aggregated block + MatrixIndexes idx = new MatrixIndexes(rowIndex, 1L); + qOut.enqueueTask(new IndexedMatrixValue(idx, currAgg)); + aggTracker.remove(rowIndex); + } } } } } - } - catch(Exception ex) { - throw new DMLRuntimeException(ex); - } - finally { - qOut.closeInput(); - } + catch(Exception ex) { + throw new DMLRuntimeException(ex); + } + finally { + qOut.closeInput(); + } }, qIn, qOut); } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TransposeOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TransposeOOCInstruction.java index e9cf3616f20..fce5408960e 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TransposeOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TransposeOOCInstruction.java @@ -61,21 +61,21 @@ public void processInstruction( ExecutionContext ec ) { ec.getMatrixObject(output).setStreamHandle(qOut); submitOOCTask(() -> { - IndexedMatrixValue tmp = null; - try { - while ((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { - MatrixBlock inBlock = (MatrixBlock)tmp.getValue(); - long oldRowIdx = tmp.getIndexes().getRowIndex(); - long oldColIdx = tmp.getIndexes().getColumnIndex(); + IndexedMatrixValue tmp = null; + try { + while ((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { + MatrixBlock inBlock = (MatrixBlock)tmp.getValue(); + long oldRowIdx = tmp.getIndexes().getRowIndex(); + long oldColIdx = tmp.getIndexes().getColumnIndex(); - MatrixBlock outBlock = inBlock.reorgOperations((ReorgOperator) _optr, new MatrixBlock(), -1, -1, -1); - qOut.enqueueTask(new IndexedMatrixValue(new MatrixIndexes(oldColIdx, oldRowIdx), outBlock)); + MatrixBlock outBlock = inBlock.reorgOperations((ReorgOperator) _optr, new MatrixBlock(), -1, -1, -1); + qOut.enqueueTask(new IndexedMatrixValue(new MatrixIndexes(oldColIdx, oldRowIdx), outBlock)); + } + qOut.closeInput(); + } + catch(Exception ex) { + throw new DMLRuntimeException(ex); } - qOut.closeInput(); - } - catch(Exception ex) { - throw new DMLRuntimeException(ex); - } }, qIn, qOut); } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/UnaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/UnaryOOCInstruction.java index 51ff3925b93..63f42f5bf15 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/UnaryOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/UnaryOOCInstruction.java @@ -62,19 +62,19 @@ public void processInstruction( ExecutionContext ec ) { submitOOCTask(() -> { - IndexedMatrixValue tmp = null; - try { - while ((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { - IndexedMatrixValue tmpOut = new IndexedMatrixValue(); - tmpOut.set(tmp.getIndexes(), - tmp.getValue().unaryOperations(uop, new MatrixBlock())); - qOut.enqueueTask(tmpOut); + IndexedMatrixValue tmp = null; + try { + while ((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { + IndexedMatrixValue tmpOut = new IndexedMatrixValue(); + tmpOut.set(tmp.getIndexes(), + tmp.getValue().unaryOperations(uop, new MatrixBlock())); + qOut.enqueueTask(tmpOut); + } + qOut.closeInput(); + } + catch(Exception ex) { + throw new DMLRuntimeException(ex); } - qOut.closeInput(); - } - catch(Exception ex) { - throw new DMLRuntimeException(ex); - } }, qIn, qOut); } } From 0324985bda68478c7df702a3111bdef828c355f6 Mon Sep 17 00:00:00 2001 From: Jannik Lindemann Date: Mon, 27 Oct 2025 19:34:22 +0100 Subject: [PATCH 3/6] Add license --- .../ooc/OOCExceptionHandlingTest.java | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/OOCExceptionHandlingTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/OOCExceptionHandlingTest.java index a94faf87449..3bd32d7eff7 100644 --- a/src/test/java/org/apache/sysds/test/functions/ooc/OOCExceptionHandlingTest.java +++ b/src/test/java/org/apache/sysds/test/functions/ooc/OOCExceptionHandlingTest.java @@ -1,3 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + package org.apache.sysds.test.functions.ooc; import org.apache.sysds.common.Types; From 88bc2b97127f2c97f115ca2ed161653fb9f41429 Mon Sep 17 00:00:00 2001 From: Jannik Lindemann Date: Mon, 27 Oct 2025 19:49:08 +0100 Subject: [PATCH 4/6] Scope failures locally --- .../java/org/apache/sysds/api/DMLScript.java | 3 -- .../controlprogram/parfor/LocalTaskQueue.java | 35 +++++++------------ .../instructions/ooc/OOCInstruction.java | 4 +-- 3 files changed, 13 insertions(+), 29 deletions(-) diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java b/src/main/java/org/apache/sysds/api/DMLScript.java index 3a2cf1a16e2..a51646e0e78 100644 --- a/src/main/java/org/apache/sysds/api/DMLScript.java +++ b/src/main/java/org/apache/sysds/api/DMLScript.java @@ -444,9 +444,6 @@ private static void execute(String dmlScriptStr, String fnameOptConfig, Map public static final int MAX_SIZE = 100000; //main memory constraint public static final Object NO_MORE_TASKS = null; //object to signal NO_MORE_TASKS - private static volatile DMLRuntimeException FAILURE = null; private LinkedList _data = null; private boolean _closedInput = false; + private DMLRuntimeException _failure = null; private static final Log LOG = LogFactory.getLog(LocalTaskQueue.class.getName()); public LocalTaskQueue() @@ -63,14 +63,14 @@ public LocalTaskQueue() public synchronized void enqueueTask( T t ) throws InterruptedException { - while( _data.size() + 1 > MAX_SIZE && FAILURE == null ) + while( _data.size() + 1 > MAX_SIZE && _failure == null ) { LOG.warn("MAX_SIZE of task queue reached."); wait(); //max constraint reached, wait for read } - if ( FAILURE != null ) - throw FAILURE; + if ( _failure != null ) + throw _failure; _data.addLast( t ); @@ -87,7 +87,7 @@ public synchronized void enqueueTask( T t ) public synchronized T dequeueTask() throws InterruptedException { - while( _data.isEmpty() && FAILURE == null ) + while( _data.isEmpty() && _failure == null ) { if( !_closedInput ) wait(); // wait for writers @@ -95,8 +95,8 @@ public synchronized T dequeueTask() return (T)NO_MORE_TASKS; } - if ( FAILURE != null ) - throw FAILURE; + if ( _failure != null ) + throw _failure; T t = _data.removeFirst(); @@ -119,8 +119,11 @@ public synchronized boolean isProcessed() { return _closedInput && _data.isEmpty(); } - public synchronized void notifyFailure() { - notifyAll(); + public synchronized void propagateFailure(DMLRuntimeException failure) { + if (_failure == null) { + _failure = failure; + notifyAll(); + } } @Override @@ -147,18 +150,4 @@ public synchronized String toString() return sb.toString(); } - - public static boolean failGlobally(DMLRuntimeException ex) { - // Only register the first failure - if (FAILURE == null) { - FAILURE = ex; - return true; - } - - return false; - } - - public static void resetFailures() { - FAILURE = null; - } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java index f9f9e631ce6..139b4f0517d 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java @@ -111,10 +111,8 @@ private Runnable oocTask(Runnable r, LocalTaskQueue... queues) { catch (Exception ex) { DMLRuntimeException re = ex instanceof DMLRuntimeException ? (DMLRuntimeException) ex : new DMLRuntimeException(ex); - LocalTaskQueue.failGlobally(re); - for (LocalTaskQueue q : queues) { - q.notifyFailure(); + q.propagateFailure(re); } } }; From 2aa51767fba1d5501eaeb68410c0741b3de620a0 Mon Sep 17 00:00:00 2001 From: Jannik Lindemann Date: Mon, 27 Oct 2025 19:57:05 +0100 Subject: [PATCH 5/6] Remove unused import --- src/main/java/org/apache/sysds/api/DMLScript.java | 1 - 1 file changed, 1 deletion(-) diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java b/src/main/java/org/apache/sysds/api/DMLScript.java index a51646e0e78..65805b5c2ed 100644 --- a/src/main/java/org/apache/sysds/api/DMLScript.java +++ b/src/main/java/org/apache/sysds/api/DMLScript.java @@ -69,7 +69,6 @@ import org.apache.sysds.runtime.controlprogram.federated.FederatedWorker; import org.apache.sysds.runtime.controlprogram.federated.monitoring.FederatedMonitoringServer; import org.apache.sysds.runtime.controlprogram.federated.monitoring.models.CoordinatorModel; -import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; import org.apache.sysds.runtime.controlprogram.parfor.util.IDHandler; import org.apache.sysds.runtime.instructions.gpu.context.GPUContextPool; import org.apache.sysds.runtime.io.IOUtilFunctions; From 8eef61534f276e69ea484a1d41466c524b937fd7 Mon Sep 17 00:00:00 2001 From: Jannik Lindemann Date: Mon, 27 Oct 2025 20:11:13 +0100 Subject: [PATCH 6/6] Rethrow exception to ensure proper future handling --- .../apache/sysds/runtime/instructions/ooc/OOCInstruction.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java index 139b4f0517d..0d159492891 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java @@ -114,6 +114,9 @@ private Runnable oocTask(Runnable r, LocalTaskQueue... queues) { for (LocalTaskQueue q : queues) { q.propagateFailure(re); } + + // Rethrow to ensure proper future handling + throw re; } }; }