Skip to content

Commit bb78acf

Browse files
authored
Merge pull request #22 from j-danner/master
fix sanity check for underdetermined systems + add test case
2 parents 40112a1 + f28c78b commit bb78acf

File tree

3 files changed

+40
-10
lines changed

3 files changed

+40
-10
lines changed

m4ri/solve.c

+9-3
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ int mzd_solve_left(mzd_t *A, mzd_t *B, int const cutoff, int const inconsistency
3232
m4ri_die("mzd_solve_left: A ncols (%d) must be smaller than B nrows (%d).\n", A->ncols,
3333
B->nrows);
3434

35-
if (A->nrows != B->nrows)
36-
m4ri_die("mzd_solve_left: A nrows (%d) must be equal to B nrows (%d).\n", A->nrows,
37-
B->nrows);
35+
if (B->nrows != MAX(A->ncols,A->nrows))
36+
m4ri_die("mzd_solve_left: B nrows (%d) must be equal to max of A nrows (%d) and A ncols (%d).\n", B->nrows,
37+
A->nrows, A->ncols);
3838

3939
return _mzd_solve_left(A, B, cutoff, inconsistency_check);
4040
}
@@ -120,6 +120,12 @@ int _mzd_pluq_solve_left(mzd_t const *A, rci_t rank, mzp_t const *P, mzp_t const
120120
}
121121

122122
int _mzd_solve_left(mzd_t *A, mzd_t *B, int const cutoff, int const inconsistency_check) {
123+
if (inconsistency_check && B->nrows > A->nrows) {
124+
mzd_t const *Bpad = mzd_init_window_const(B, A->nrows+1, 0, B->nrows, B->ncols);
125+
if(!mzd_is_zero(Bpad)) return -1;
126+
mzd_free_window((mzd_t *) Bpad);
127+
}
128+
123129
/**
124130
* B is modified in place
125131
* (Bi's in the comments are just modified versions of B)

m4ri/solve.h

+2
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
* \brief Solves A X = B with A and B matrices.
3737
*
3838
* The solution X is stored inplace on B.
39+
* If A->nrows < A->ncols, the matrix A is implicitly padded with zeros to
40+
* match B->nrows.
3941
*
4042
* \param A Input matrix (overwritten).
4143
* \param B Input matrix, being overwritten by the solution matrix X

tests/test_solve.c

+29-7
Original file line numberDiff line numberDiff line change
@@ -59,19 +59,39 @@ int test_solve_left_random(rci_t mA, rci_t nA, rci_t nB, int consistent) {
5959
mzd_t *secret = mzd_init(nA, nB);
6060
mzd_randomize(A);
6161
mzd_randomize(secret);
62-
mzd_t *B;
63-
if (consistent) {
64-
B = mzd_mul(NULL, A, secret, 0);
62+
mzd_t *B;
63+
if(consistent) {
64+
B = mzd_mul(NULL, A, secret, 0);
6565
assert(B->nrows == A->nrows);
6666
assert(B->ncols == nB);
6767
}
6868
else {
69-
B = mzd_init(mB, nB);
69+
B = mzd_init(mA, nB);
7070
mzd_randomize(B);
7171
}
72+
mzd_free(secret);
73+
7274
// copy A & B
7375
mzd_t *Acopy = mzd_copy(NULL, A);
7476
mzd_t *Bcopy = mzd_copy(NULL, B);
77+
78+
// add rows to B, s.t. X fits into B
79+
if(B->nrows < mB) {
80+
mzd_t *padding = mzd_init(nA-mA, nB);
81+
if(!consistent)
82+
mzd_randomize(padding);
83+
mzd_t *B_padded = mzd_stack(NULL, B, padding);
84+
mzd_free(B);
85+
B = B_padded;
86+
mzd_free(padding);
87+
88+
mzd_t const *Bpad = mzd_init_window_const(B, A->nrows+1, 0, B->nrows, B->ncols);
89+
mzd_print(Bpad);
90+
mzd_free_window((mzd_t *) Bpad);
91+
}
92+
assert(B->nrows == mB);
93+
assert(B->ncols == nB);
94+
7595
int consistency = !mzd_solve_left(A, B, 0, 1);
7696

7797
if (consistent && !consistency) {
@@ -83,8 +103,8 @@ int test_solve_left_random(rci_t mA, rci_t nA, rci_t nB, int consistent) {
83103
printf("skipped (OK, no solution found)\n");
84104
return 0;
85105
}
86-
// copy B
87-
mzd_t *X = mzd_submatrix(NULL, B, 0, 0, A->ncols, B->ncols);
106+
107+
mzd_t const *X = mzd_init_window_const(B, 0, 0, A->ncols, B->ncols);
88108
mzd_t *B1 = mzd_mul(NULL, Acopy, X, 0);
89109
mzd_t *Z = mzd_add(NULL, Bcopy, B1);
90110

@@ -99,8 +119,8 @@ int test_solve_left_random(rci_t mA, rci_t nA, rci_t nB, int consistent) {
99119
mzd_free(B1);
100120
mzd_free(Z);
101121
mzd_free(A);
122+
mzd_free((mzd_t*) X);
102123
mzd_free(B);
103-
mzd_free(X);
104124
return status;
105125
}
106126

@@ -114,12 +134,14 @@ int main() {
114134
status += test_solve_left_random(1100, 1100, 1000, 1);
115135
status += test_solve_left_random(1000, 1000, 1100, 1);
116136
status += test_solve_left_random(1100, 1000, 1100, 1);
137+
status += test_solve_left_random(1000, 1100, 1000, 1);
117138

118139
status += test_solve_left_random(1100, 1000, 1000, 0);
119140
status += test_solve_left_random(1000, 1000, 1000, 0);
120141
status += test_solve_left_random(1100, 1100, 1000, 0);
121142
status += test_solve_left_random(1000, 1000, 1100, 0);
122143
status += test_solve_left_random(1100, 1000, 1100, 0);
144+
status += test_solve_left_random(1000, 1100, 1000, 0);
123145

124146

125147
for (size_t i = 0; i < 100; i++) {

0 commit comments

Comments
 (0)