From bc870f3a8f5f45e994675ee538e5c587aa9acf5d Mon Sep 17 00:00:00 2001 From: Adrian Trapletti Date: Sun, 1 Sep 2024 18:33:46 +0200 Subject: [PATCH] Improve setup precondition checks and documentation (#60) * Add error message * Rewrite condition similar to ajc * Fix bug * Improve javadoc; Backward compatible change in the method signature: if A and b are provided, they should be non-null, if they are not provided, they should be null; Add more precondition checks and improve existing * Fix tests * Improve javadocs --- .../java/com/ustermetrics/ecos4j/Model.java | 74 ++++++++++++------- .../com/ustermetrics/ecos4j/ModelTest.java | 6 +- 2 files changed, 52 insertions(+), 28 deletions(-) diff --git a/src/main/java/com/ustermetrics/ecos4j/Model.java b/src/main/java/com/ustermetrics/ecos4j/Model.java index 7abc580..6f19222 100644 --- a/src/main/java/com/ustermetrics/ecos4j/Model.java +++ b/src/main/java/com/ustermetrics/ecos4j/Model.java @@ -58,50 +58,74 @@ public static String version() { * exponential cone. * * @param l the dimension of the positive orthant. - * @param q the dimensions of each cone. + * @param q the dimensions of the second-order cones. * @param nExC the number of exponential cones. * @param gpr the sparse G matrix data (Column Compressed Storage CCS). * @param gjc the sparse G matrix column index (CCS). * @param gir the sparse G matrix row index (CCS). * @param c the cost function weights. * @param h the right-hand-side of the cone constraints. - * @param apr the sparse A matrix data (CCS). - * @param ajc the sparse A matrix column index (CCS). - * @param air the sparse A matrix row index (CCS). - * @param b the right-hand-side of the equalities. + * @param apr the (optional) sparse A matrix data (CCS). + * @param ajc the (optional) sparse A matrix column index (CCS). + * @param air the (optional) sparse A matrix row index (CCS). + * @param b the (optional) right-hand-side of the equalities. * @see ECOS */ public void setup(long l, long @NonNull [] q, long nExC, double @NonNull [] gpr, long @NonNull [] gjc, - long @NonNull [] gir, double @NonNull [] c, double @NonNull [] h, double @NonNull [] apr, - long @NonNull [] ajc, long @NonNull [] air, double @NonNull [] b) { + long @NonNull [] gir, double @NonNull [] c, double @NonNull [] h, double[] apr, long[] ajc, + long[] air, double[] b) { checkState(stage == Stage.NEW, "Model must be in stage new"); - checkArgument(apr.length == 0 && ajc.length == 0 && air.length == 0 && b.length == 0 - || apr.length > 0 && ajc.length > 0 && air.length > 0 && b.length > 0); - val nonNegErrMsg = "%s must be non-negative"; - checkArgument(l >= 0, nonNegErrMsg, "l"); - checkArgument(nExC >= 0, nonNegErrMsg, "nExC"); + checkArgument(l >= 0, "dimension of the positive orthant l must be non-negative"); + val nCones = q.length; + checkArgument(nCones == 0 || Arrays.stream(q).allMatch(d -> d > 0), + "second-order cone dimensions q must be empty or each dimension q[i] must be positive"); + checkArgument(nExC >= 0, "number of exponential cones nExC must be non-negative"); + val nnzG = gpr.length; + checkArgument(nnzG > 0, "number of non-zero elements in G (gpr.length) must be positive"); + checkArgument(nnzG == gir.length, + "number of non-zero elements in G (gpr.length) must be equal to the number of elements in the row " + + "index of G (gir.length)"); + val nColsG = gjc.length - 1; + checkArgument(nColsG > 0, "number of columns of G (gjc.length - 1) must be positive"); n = c.length; + checkArgument(n > 0, "number of variables x (c.length) must be positive"); m = h.length; - p = b.length; - val nCones = q.length; - + checkArgument(m > 0, "dimension of all cones (h.length) must be positive"); checkArgument(m == l + Arrays.stream(q).sum() + 3 * nExC, - "Length of h must be equal to the sum of l, q, and 3*nExC"); - checkArgument(n == gjc.length - 1, "Length of c must be equal to the length of gjc minus one"); - checkArgument(ajc.length == 0 || ajc.length == n + 1, - "ajc has zero length or must be equal to the length of c plus one"); + "dimension of all cones (h.length) must be equal to the sum of the positive orthant dimension l, the " + + "second-order cone dimensions q[i], and three times the number of exponential cones 3 * nExC"); + checkArgument(nColsG == n, "number of columns of G (gjc.length - 1) must be equal to the number of variables " + + "x (c.length)"); + + checkArgument(apr != null && ajc != null && air != null && b != null || apr == null && ajc == null && air == null && b == null, + "A (apr, ajc, air) and b must be supplied (all non-null) or omitted (all null) together"); + if (apr != null) { + val nnzA = apr.length; + checkArgument(nnzA > 0, "number of non-zero elements in A (apr.length) must be positive"); + checkArgument(nnzA == air.length, + "number of non-zero elements in A (apr.length) must be equal to the number of elements in the row" + + " index of A (air.length)"); + val nColsA = ajc.length - 1; + checkArgument(nColsA > 0, "number of columns of A (ajc.length - 1) must be positive"); + p = b.length; + checkArgument(p > 0, "number of equalities (b.length) must be positive"); + checkArgument(nColsA == n, "number of columns of A (ajc.length - 1) must be equal to the number of " + + "variables x (c.length)"); + } else { + p = 0; + } val qSeg = arena.allocateFrom(C_LONG_LONG, q); val gprSeg = arena.allocateFrom(C_DOUBLE, gpr); val gjcSeg = arena.allocateFrom(C_LONG_LONG, gjc); val girSeg = arena.allocateFrom(C_LONG_LONG, gir); - val aprSeg = apr.length > 0 ? arena.allocateFrom(C_DOUBLE, apr) : NULL; - val ajcSeg = ajc.length > 0 ? arena.allocateFrom(C_LONG_LONG, ajc) : NULL; - val airSeg = air.length > 0 ? arena.allocateFrom(C_LONG_LONG, air) : NULL; + val aprSeg = apr != null ? arena.allocateFrom(C_DOUBLE, apr) : NULL; + val ajcSeg = ajc != null ? arena.allocateFrom(C_LONG_LONG, ajc) : NULL; + val airSeg = air != null ? arena.allocateFrom(C_LONG_LONG, air) : NULL; val cSeg = arena.allocateFrom(C_DOUBLE, c); val hSeg = arena.allocateFrom(C_DOUBLE, h); - val bSeg = b.length > 0 ? arena.allocateFrom(C_DOUBLE, b) : NULL; + val bSeg = b != null ? arena.allocateFrom(C_DOUBLE, b) : NULL; workSeg = ECOS_setup(n, m, p, l, nCones, qSeg, nExC, gprSeg, gjcSeg, girSeg, aprSeg, ajcSeg, airSeg, cSeg, hSeg, bSeg).reinterpret(pwork.sizeof(), arena, null); @@ -119,7 +143,7 @@ public void setup(long l, long @NonNull [] q, long nExC, double @NonNull [] gpr, * without equality constraint, i.e. {@code apr}, {@code ajc}, {@code air}, and {@code b} are empty arrays. * * @param l the dimension of the positive orthant. - * @param q the dimensions of each cone. + * @param q the dimensions of the second-order cones. * @param nExC the number of exponential cones. * @param gpr the sparse G matrix data (Column Compressed Storage CCS). * @param gjc the sparse G matrix column index (CCS). @@ -129,7 +153,7 @@ public void setup(long l, long @NonNull [] q, long nExC, double @NonNull [] gpr, */ public void setup(long l, long @NonNull [] q, long nExC, double @NonNull [] gpr, long @NonNull [] gjc, long @NonNull [] gir, double @NonNull [] c, double @NonNull [] h) { - setup(l, q, nExC, gpr, gjc, gir, c, h, new double[]{}, new long[]{}, new long[]{}, new double[]{}); + setup(l, q, nExC, gpr, gjc, gir, c, h, null, null, null, null); } /** diff --git a/src/test/java/com/ustermetrics/ecos4j/ModelTest.java b/src/test/java/com/ustermetrics/ecos4j/ModelTest.java index ddb1387..d31db73 100644 --- a/src/test/java/com/ustermetrics/ecos4j/ModelTest.java +++ b/src/test/java/com/ustermetrics/ecos4j/ModelTest.java @@ -33,7 +33,7 @@ static void setup() { -1.}; gjc = new long[]{0, 2, 5, 9, 14, 16}; gir = new long[]{0, 6, 1, 6, 7, 2, 6, 7, 8, 3, 6, 7, 8, 9, 4, 5}; - apr = new double[]{1., 1., 1., 1., 0.}; + apr = new double[]{1., 1., 1., 1.}; ajc = new long[]{0, 1, 2, 3, 4, 4}; air = new long[]{0, 0, 0, 0}; c = new double[]{-0.05, -0.06, -0.08, -0.06, 0.}; @@ -160,7 +160,7 @@ void setupWithInvalidPositiveOrthantDimensionThrowsException() { } }); - assertEquals("l must be non-negative", exception.getMessage()); + assertEquals("dimension of the positive orthant l must be non-negative", exception.getMessage()); } @Test @@ -171,7 +171,7 @@ void setupWithInvalidNumberOfExponentialConesThrowsException() { } }); - assertEquals("nExC must be non-negative", exception.getMessage()); + assertEquals("number of exponential cones nExC must be non-negative", exception.getMessage()); } @Test