Skip to content

Commit 878dd71

Browse files
committed
more fixes
1 parent 7875040 commit 878dd71

File tree

4 files changed

+28
-107
lines changed

4 files changed

+28
-107
lines changed

tests/library_checker_aizu_tests/math/mod_int_gcd_convolution.test.cpp

Lines changed: 0 additions & 28 deletions
This file was deleted.

tests/library_checker_aizu_tests/math/mod_int_n_choose_k.test.cpp

Lines changed: 0 additions & 28 deletions
This file was deleted.

tests/library_checker_aizu_tests/math/mod_int_tetration.test.cpp

Lines changed: 0 additions & 29 deletions
This file was deleted.

tests/library_checker_aizu_tests/trees/edge_cd_reroot_dp.test.cpp

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,34 +3,34 @@
33
#include "../template.hpp"
44
#include "../edge_cd_asserts.hpp"
55
#include "../../../library/trees/edge_cd.hpp"
6-
#include "../../../library/math/mod_int.hpp"
6+
const int mod = 998244353;
77
int main() {
88
cin.tie(0)->sync_with_stdio(0);
99
int n;
1010
cin >> n;
1111
vector<int> a(n);
12-
vector<mint> res(n);
12+
vector<int> res(n);
1313
for (int i = 0; i < n; i++) {
1414
cin >> a[i];
1515
res[i] = a[i];
1616
}
1717
vector<vi> adj(n);
18-
vector<mint> b(n - 1), c(n - 1);
18+
vector<int> b(n - 1), c(n - 1);
1919
vector<pair<int, int>> par(n, {-1, -1});
2020
vector<vector<int>> base_adj(n);
2121
{
2222
vector<vector<pair<int, int>>> adj_with_id(n);
2323
for (int i = 0; i < n - 1; i++) {
2424
int u, v;
25-
cin >> u >> v >> b[i].x >> c[i].x;
25+
cin >> u >> v >> b[i] >> c[i];
2626
adj[u].push_back(v);
2727
adj[v].push_back(u);
2828
base_adj[u].push_back(v);
2929
base_adj[v].push_back(u);
3030
adj_with_id[u].emplace_back(v, i);
3131
adj_with_id[v].emplace_back(u, i);
32-
res[u] = res[u] + b[i] * a[v] + c[i];
33-
res[v] = res[v] + b[i] * a[u] + c[i];
32+
res[u] = (res[u] + 1LL * b[i] * a[v] + c[i]) % mod;
33+
res[v] = (res[v] + 1LL * b[i] * a[u] + c[i]) % mod;
3434
}
3535
auto dfs = [&](auto&& self, int u) -> void {
3636
for (auto [v, e_id] : adj_with_id[u])
@@ -49,30 +49,35 @@ int main() {
4949
edge_cd(adj,
5050
[&](const vector<vi>& cd_adj, int cent,
5151
int split) -> void {
52-
array<vector<array<mint, 3>>, 2> all_backwards;
53-
array<mint, 2> sum_forward = {0, 0};
52+
array<vector<array<int, 3>>, 2> all_backwards;
53+
array<int, 2> sum_forward = {0, 0};
5454
array<int, 2> cnt_nodes = {0, 0};
5555
auto dfs = [&](auto&& self, int u, int p,
56-
array<mint, 2> forwards,
57-
array<mint, 2> backwards,
56+
array<int, 2> forwards,
57+
array<int, 2> backwards,
5858
int side) -> void {
5959
all_backwards[side].push_back(
6060
{u, backwards[0], backwards[1]});
61-
sum_forward[side] = sum_forward[side] +
62-
forwards[0] * a[u] + forwards[1];
61+
sum_forward[side] =
62+
(sum_forward[side] + 1LL * forwards[0] * a[u] +
63+
forwards[1]) %
64+
mod;
6365
cnt_nodes[side]++;
6466
for (int v : cd_adj[u]) {
6567
if (v == p) continue;
6668
int e_id = edge_id(u, v);
6769
// f(x) = ax+b
6870
// g(x) = cx+d
6971
// f(g(x)) = a(cx+d)+b = acx+ad+b
70-
array<mint, 2> curr_forw = {
71-
forwards[0] * b[e_id],
72-
forwards[0] * c[e_id] + forwards[1]};
73-
array<mint, 2> curr_backw = {
74-
backwards[0] * b[e_id],
75-
backwards[1] * b[e_id] + c[e_id]};
72+
array<int, 2> curr_forw = {
73+
int(1LL * forwards[0] * b[e_id] % mod),
74+
int(
75+
(1LL * forwards[0] * c[e_id] + forwards[1]) %
76+
mod)};
77+
array<int, 2> curr_backw = {
78+
int(1LL * backwards[0] * b[e_id] % mod),
79+
int((1LL * backwards[1] * b[e_id] + c[e_id]) %
80+
mod)};
7681
self(self, v, u, curr_forw, curr_backw, side);
7782
}
7883
};
@@ -84,13 +89,14 @@ int main() {
8489
for (int side = 0; side < 2; side++) {
8590
for (
8691
auto [u, curr_b, curr_c] : all_backwards[side]) {
87-
res[u.x] = res[u.x] +
88-
curr_b * sum_forward[!side] +
89-
curr_c * cnt_nodes[!side];
92+
res[u] =
93+
(res[u] + 1LL * curr_b * sum_forward[!side] +
94+
1LL * curr_c * cnt_nodes[!side]) %
95+
mod;
9096
}
9197
}
9298
});
93-
for (int i = 0; i < n; i++) cout << res[i].x << ' ';
99+
for (int i = 0; i < n; i++) cout << res[i] << ' ';
94100
cout << '\n';
95101
return 0;
96102
}

0 commit comments

Comments
 (0)