33#include  " ../template.hpp" 
44#include  " ../edge_cd_asserts.hpp" 
55#include  " ../../../library/trees/edge_cd.hpp" 
6- # include   " ../../../library/math/mod_int.hpp " 
6+ const   int  mod =  998244353 ; 
77int  main () {
88  cin.tie (0 )->sync_with_stdio (0 );
99  int  n;
1010  cin >> n;
1111  vector<int > a (n);
12-   vector<mint > res (n);
12+   vector<int > res (n);
1313  for  (int  i = 0 ; i < n; i++) {
1414    cin >> a[i];
1515    res[i] = a[i];
1616  }
1717  vector<vi> adj (n);
18-   vector<mint > b (n - 1 ), c (n - 1 );
18+   vector<int > b (n - 1 ), c (n - 1 );
1919  vector<pair<int , int >> par (n, {-1 , -1 });
2020  vector<vector<int >> base_adj (n);
2121  {
2222    vector<vector<pair<int , int >>> adj_with_id (n);
2323    for  (int  i = 0 ; i < n - 1 ; i++) {
2424      int  u, v;
25-       cin >> u >> v >> b[i]. x  >> c[i]. x ;
25+       cin >> u >> v >> b[i] >> c[i];
2626      adj[u].push_back (v);
2727      adj[v].push_back (u);
2828      base_adj[u].push_back (v);
2929      base_adj[v].push_back (u);
3030      adj_with_id[u].emplace_back (v, i);
3131      adj_with_id[v].emplace_back (u, i);
32-       res[u] = res[u] + b[i] * a[v] + c[i];
33-       res[v] = res[v] + b[i] * a[u] + c[i];
32+       res[u] = ( res[u] + 1LL  *  b[i] * a[v] + c[i]) % mod ;
33+       res[v] = ( res[v] + 1LL  *  b[i] * a[u] + c[i]) % mod ;
3434    }
3535    auto  dfs = [&](auto && self, int  u) -> void  {
3636      for  (auto  [v, e_id] : adj_with_id[u])
@@ -49,30 +49,35 @@ int main() {
4949  edge_cd (adj,
5050    [&](const  vector<vi>& cd_adj, int  cent,
5151      int  split) -> void  {
52-       array<vector<array<mint , 3 >>, 2 > all_backwards;
53-       array<mint , 2 > sum_forward = {0 , 0 };
52+       array<vector<array<int , 3 >>, 2 > all_backwards;
53+       array<int , 2 > sum_forward = {0 , 0 };
5454      array<int , 2 > cnt_nodes = {0 , 0 };
5555      auto  dfs = [&](auto && self, int  u, int  p,
56-                    array<mint , 2 > forwards,
57-                    array<mint , 2 > backwards,
56+                    array<int , 2 > forwards,
57+                    array<int , 2 > backwards,
5858                   int  side) -> void  {
5959        all_backwards[side].push_back (
6060          {u, backwards[0 ], backwards[1 ]});
61-         sum_forward[side] = sum_forward[side] +
62-           forwards[0 ] * a[u] + forwards[1 ];
61+         sum_forward[side] =
62+           (sum_forward[side] + 1LL  * forwards[0 ] * a[u] +
63+             forwards[1 ]) %
64+           mod;
6365        cnt_nodes[side]++;
6466        for  (int  v : cd_adj[u]) {
6567          if  (v == p) continue ;
6668          int  e_id = edge_id (u, v);
6769          //  f(x) = ax+b
6870          //  g(x) = cx+d
6971          //  f(g(x)) = a(cx+d)+b = acx+ad+b
70-           array<mint, 2 > curr_forw = {
71-             forwards[0 ] * b[e_id],
72-             forwards[0 ] * c[e_id] + forwards[1 ]};
73-           array<mint, 2 > curr_backw = {
74-             backwards[0 ] * b[e_id],
75-             backwards[1 ] * b[e_id] + c[e_id]};
72+           array<int , 2 > curr_forw = {
73+             int (1LL  * forwards[0 ] * b[e_id] % mod),
74+             int (
75+               (1LL  * forwards[0 ] * c[e_id] + forwards[1 ]) %
76+               mod)};
77+           array<int , 2 > curr_backw = {
78+             int (1LL  * backwards[0 ] * b[e_id] % mod),
79+             int ((1LL  * backwards[1 ] * b[e_id] + c[e_id]) %
80+               mod)};
7681          self (self, v, u, curr_forw, curr_backw, side);
7782        }
7883      };
@@ -84,13 +89,14 @@ int main() {
8489      for  (int  side = 0 ; side < 2 ; side++) {
8590        for  (
8691          auto  [u, curr_b, curr_c] : all_backwards[side]) {
87-           res[u.x ] = res[u.x ] +
88-             curr_b * sum_forward[!side] +
89-             curr_c * cnt_nodes[!side];
92+           res[u] =
93+             (res[u] + 1LL  * curr_b * sum_forward[!side] +
94+               1LL  * curr_c * cnt_nodes[!side]) %
95+             mod;
9096        }
9197      }
9298    });
93-   for  (int  i = 0 ; i < n; i++) cout << res[i]. x  << '  '  ;
99+   for  (int  i = 0 ; i < n; i++) cout << res[i] << '  '  ;
94100  cout << ' \n '  ;
95101  return  0 ;
96102}
0 commit comments