Skip to content

Commit 25d64d3

Browse files
committed
another fix
1 parent 878dd71 commit 25d64d3

File tree

3 files changed

+30
-29
lines changed

3 files changed

+30
-29
lines changed

library/math/matrix_related/row_reduce.hpp

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#pragma once
2-
#include "../mod_int.hpp"
2+
#include "../mod_division.hpp"
33
//! @code
44
//! auto [rank, det] = row_reduce(mat, cols);
55
//! @endcode
@@ -8,28 +8,31 @@
88
//! affected by row operations
99
//! @time O(n * m * min(cols, n))
1010
//! @space O(1)
11-
pair<int, mint> row_reduce(vector<vector<mint>>& mat,
11+
pair<int, int> row_reduce(vector<vector<int>>& mat,
1212
int cols) {
1313
int n = sz(mat), m = sz(mat[0]), rank = 0;
14-
mint det = 1;
14+
int det = 1;
1515
for (int col = 0; col < cols && rank < n; col++) {
1616
auto it = find_if(rank + all(mat),
17-
[&](auto& v) { return v[col].x; });
17+
[&](auto& v) { return v[col]; });
1818
if (it == end(mat)) {
1919
det = 0;
2020
continue;
2121
}
2222
if (it != begin(mat) + rank) {
23-
det = mint(0) - det;
23+
det = (mod - det) % mod;
2424
iter_swap(begin(mat) + rank, it);
2525
}
26-
det = det * mat[rank][col];
27-
mint a_inv = mint(1) / mat[rank][col];
28-
for (mint& num : mat[rank]) num = num * a_inv;
29-
rep(i, 0, n) if (i != rank && mat[i][col].x != 0) {
30-
mint num = mat[i][col];
26+
det = 1LL * det * mat[rank][col] % mod;
27+
int a_inv = mod_div(1, mat[rank][col]);
28+
for (int& num : mat[rank])
29+
num = 1LL * num * a_inv % mod;
30+
rep(i, 0, n) if (i != rank && mat[i][col] != 0) {
31+
int num = mat[i][col];
3132
rep(j, 0, m) mat[i][j] =
32-
mat[i][j] - mat[rank][j] * num;
33+
((mat[i][j] - 1LL * mat[rank][j] * num) % mod +
34+
mod) %
35+
mod;
3336
}
3437
rank++;
3538
}

library/math/matrix_related/solve_linear_mod.hpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,21 @@
1111
//! @time O(n * m * min(n, m))
1212
//! @space O(m)
1313
struct solve_linear_mod {
14-
int rank;
15-
mint det;
16-
vector<mint> sol;
17-
solve_linear_mod(vector<vector<mint>>& mat,
18-
const vector<mint>& rhs) {
14+
int rank, det;
15+
vector<int> sol;
16+
solve_linear_mod(vector<vector<int>>& mat,
17+
const vector<int>& rhs) {
1918
int n = sz(mat), m = sz(mat[0]);
2019
rep(i, 0, n) mat[i].push_back(rhs[i]);
2120
tie(rank, det) = row_reduce(mat, m);
2221
if (any_of(rank + all(mat),
23-
[](auto& v) { return v.back().x; })) {
22+
[](vi& v) { return v.back(); })) {
2423
return;
2524
}
2625
sol.resize(m);
2726
int j = 0;
28-
for_each(begin(mat), begin(mat) + rank, [&](auto& v) {
29-
while (v[j].x == 0) j++;
27+
for_each(begin(mat), begin(mat) + rank, [&](vi& v) {
28+
while (!v[j]) j++;
3029
sol[j] = v.back();
3130
});
3231
}

tests/library_checker_aizu_tests/math/solve_linear_mod.test.cpp

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,33 +6,32 @@ int main() {
66
cin.tie(0)->sync_with_stdio(0);
77
int n, m;
88
cin >> n >> m;
9-
vector<vector<mint>> mat(n, vector<mint>(m));
9+
vector<vector<int>> mat(n, vector<int>(m));
1010
for (int i = 0; i < n; i++)
11-
for (int j = 0; j < m; j++) cin >> mat[i][j].x;
12-
vector<mint> b(n);
13-
for (int i = 0; i < n; i++) cin >> b[i].x;
11+
for (int j = 0; j < m; j++) cin >> mat[i][j];
12+
vector<int> b(n);
13+
for (int i = 0; i < n; i++) cin >> b[i];
1414
solve_linear_mod info(mat, b);
1515
assert(info.rank <= min(n, m));
1616
if (empty(info.sol)) {
1717
cout << -1 << '\n';
1818
return 0;
1919
}
2020
cout << m - info.rank << '\n';
21-
for (auto val : info.sol) cout << val.x << " ";
21+
for (int val : info.sol) cout << val << " ";
2222
cout << '\n';
2323
vector<int> pivot(m, -1);
2424
for (int i = 0, j = 0; i < info.rank; i++) {
25-
while (mat[i][j].x == 0) j++;
25+
while (mat[i][j] == 0) j++;
2626
pivot[j] = i;
2727
}
2828
for (int j = 0; j < m; j++)
2929
if (pivot[j] == -1) {
30-
vector<mint> x(m, 0);
31-
x[j] = -1;
32-
assert(0 <= x[j].x && x[j].x < mod);
30+
vector<int> x(m, 0);
31+
x[j] = mod - 1;
3332
for (int k = 0; k < j; k++)
3433
if (pivot[k] != -1) x[k] = mat[pivot[k]][j];
35-
for (int k = 0; k < m; k++) cout << x[k].x << " ";
34+
for (int k = 0; k < m; k++) cout << x[k] << " ";
3635
cout << '\n';
3736
}
3837
return 0;

0 commit comments

Comments
 (0)