Skip to content

Commit 694a733

Browse files
hmenkemmalcolms
andcommitted
Pass MPI communicator for flexible parallelization
Co-authored-by: Mario Malcolms de Oliveira <[email protected]>
1 parent 3c5e366 commit 694a733

File tree

4 files changed

+44
-56
lines changed

4 files changed

+44
-56
lines changed

c++/triqs_tprf/lattice/fourier.hpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ namespace triqs_tprf {
3434
}
3535

3636
template <typename Gf_type>
37-
auto fourier_wr_to_tr_general_target(Gf_type g_wr, int n_tau = -1) {
37+
auto fourier_wr_to_tr_general_target(Gf_type g_wr, int n_tau = -1, mpi::communicator const &c = {}) {
3838

3939
auto _ = all_t{};
4040
// Get rid of structured binding declarations in this file due to issue #11
@@ -48,7 +48,7 @@ auto fourier_wr_to_tr_general_target(Gf_type g_wr, int n_tau = -1) {
4848
auto r0 = *rmesh.begin();
4949
auto p = _fourier_plan<0>(gf_const_view(g_wr[_, r0]), gf_view(g_tr[_, r0]));
5050

51-
auto r_arr = mpi_view(rmesh);
51+
auto r_arr = mpi_view(rmesh, c);
5252

5353
#pragma omp parallel for
5454
for (unsigned int idx = 0; idx < r_arr.size(); idx++) {
@@ -63,12 +63,12 @@ auto fourier_wr_to_tr_general_target(Gf_type g_wr, int n_tau = -1) {
6363

6464
g_tr[_, r] = g_t;
6565
}
66-
g_tr = mpi::all_reduce(g_tr);
66+
g_tr = mpi::all_reduce(g_tr, c);
6767
return g_tr;
6868
}
6969

7070
template <typename Gf_type>
71-
auto fourier_tr_to_wr_general_target(Gf_type g_tr, int n_w = -1) {
71+
auto fourier_tr_to_wr_general_target(Gf_type g_tr, int n_w = -1, mpi::communicator const &c = {}) {
7272

7373
auto _ = all_t{};
7474
//auto [tmesh, rmesh] = g_tr.mesh();
@@ -81,7 +81,7 @@ auto fourier_tr_to_wr_general_target(Gf_type g_tr, int n_w = -1) {
8181
auto r0 = *rmesh.begin();
8282
auto p = _fourier_plan<0>(gf_const_view(g_tr[_, r0]), gf_view(g_wr[_, r0]));
8383

84-
auto r_arr = mpi_view(rmesh);
84+
auto r_arr = mpi_view(rmesh, c);
8585

8686
#pragma omp parallel for
8787
for (unsigned int idx = 0; idx < r_arr.size(); idx++) {
@@ -96,12 +96,12 @@ auto fourier_tr_to_wr_general_target(Gf_type g_tr, int n_w = -1) {
9696

9797
g_wr[_, r] = g_w;
9898
}
99-
g_wr = mpi::all_reduce(g_wr);
99+
g_wr = mpi::all_reduce(g_wr, c);
100100
return g_wr;
101101
}
102102

103103
template <typename Gf_type>
104-
auto fourier_wk_to_wr_general_target(Gf_type g_wk) {
104+
auto fourier_wk_to_wr_general_target(Gf_type g_wk, mpi::communicator const &c = {}) {
105105

106106
auto _ = all_t{};
107107

@@ -116,7 +116,7 @@ auto fourier_wk_to_wr_general_target(Gf_type g_wk) {
116116
auto w0 = *wmesh.begin();
117117
auto p = _fourier_plan<0>(gf_const_view(g_wk[w0, _]), gf_view(g_wr[w0, _]));
118118

119-
auto w_arr = mpi_view(wmesh);
119+
auto w_arr = mpi_view(wmesh, c);
120120

121121
#pragma omp parallel for
122122
for (unsigned int idx = 0; idx < w_arr.size(); idx++) {
@@ -131,12 +131,12 @@ auto fourier_wk_to_wr_general_target(Gf_type g_wk) {
131131

132132
g_wr[w, _] = g_r;
133133
}
134-
g_wr = mpi::all_reduce(g_wr);
134+
g_wr = mpi::all_reduce(g_wr, c);
135135
return g_wr;
136136
}
137137

138138
template <typename Gf_type>
139-
auto fourier_wr_to_wk_general_target(Gf_type g_wr) {
139+
auto fourier_wr_to_wk_general_target(Gf_type g_wr, mpi::communicator const &c = {}) {
140140

141141
auto _ = all_t{};
142142

@@ -150,7 +150,7 @@ auto fourier_wr_to_wk_general_target(Gf_type g_wr) {
150150
auto w0 = *wmesh.begin();
151151
auto p = _fourier_plan<0>(gf_const_view(g_wr[w0, _]), gf_view(g_wk[w0, _]));
152152

153-
auto w_arr = mpi_view(wmesh);
153+
auto w_arr = mpi_view(wmesh, c);
154154

155155
#pragma omp parallel for
156156
for (unsigned int idx = 0; idx < w_arr.size(); idx++) {
@@ -165,7 +165,7 @@ auto fourier_wr_to_wk_general_target(Gf_type g_wr) {
165165

166166
g_wk[w, _] = g_k;
167167
}
168-
g_wk = mpi::all_reduce(g_wk);
168+
g_wk = mpi::all_reduce(g_wk, c);
169169
return g_wk;
170170
}
171171

c++/triqs_tprf/lattice/gf.cpp

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -38,28 +38,28 @@ namespace triqs_tprf {
3838
// ----------------------------------------------------
3939
// g
4040

41-
g_wk_t lattice_dyson_g0_wk(double mu, e_k_cvt e_k, gf_mesh<imfreq> mesh) {
41+
g_wk_t lattice_dyson_g0_wk(double mu, e_k_cvt e_k, gf_mesh<imfreq> mesh, mpi::communicator const &c) {
4242

4343
auto I = nda::eye<ek_vt::scalar_t>(e_k.target_shape()[0]);
4444
g_wk_t g0_wk({mesh, e_k.mesh()}, e_k.target_shape());
4545
g0_wk() = 0.0;
4646

47-
auto arr = mpi_view(g0_wk.mesh());
47+
auto arr = mpi_view(g0_wk.mesh(), c);
4848

4949
#pragma omp parallel for
5050
for (unsigned int idx = 0; idx < arr.size(); idx++) {
5151
auto &[w, k] = arr(idx);
5252
g0_wk[w, k] = inverse((w + mu)*I - e_k(k));
5353
}
5454

55-
g0_wk = mpi::all_reduce(g0_wk);
55+
g0_wk = mpi::all_reduce(g0_wk, c);
5656
return g0_wk;
5757
}
5858

5959
// ----------------------------------------------------
6060

6161
template<typename sigma_t>
62-
auto lattice_dyson_g_generic(double mu, e_k_cvt e_k, sigma_t sigma){
62+
auto lattice_dyson_g_generic(double mu, e_k_cvt e_k, sigma_t sigma, mpi::communicator const &c){
6363

6464
auto &freqmesh = [&sigma]() -> auto & {
6565
if constexpr (sigma_t::arity == 1) return sigma.mesh();
@@ -72,7 +72,7 @@ auto lattice_dyson_g_generic(double mu, e_k_cvt e_k, sigma_t sigma){
7272
g_wk_t g_wk({freqmesh, e_k.mesh()}, e_k.target_shape());
7373
g_wk() = 0.0;
7474

75-
auto arr = mpi_view(g_wk.mesh());
75+
auto arr = mpi_view(g_wk.mesh(), c);
7676
#pragma omp parallel for
7777
for (unsigned int idx = 0; idx < arr.size(); idx++) {
7878
auto &[w, k] = arr(idx);
@@ -84,60 +84,60 @@ auto lattice_dyson_g_generic(double mu, e_k_cvt e_k, sigma_t sigma){
8484
g_wk[w, k] = inverse((w + mu)*I - e_k(k) - sigmaterm);
8585
}
8686

87-
g_wk = mpi::all_reduce(g_wk);
87+
g_wk = mpi::all_reduce(g_wk, c);
8888
return g_wk;
8989
}
9090

9191

92-
g_wk_t lattice_dyson_g_wk(double mu, e_k_cvt e_k, g_wk_cvt sigma_wk) {
93-
return lattice_dyson_g_generic(mu, e_k, sigma_wk);
92+
g_wk_t lattice_dyson_g_wk(double mu, e_k_cvt e_k, g_wk_cvt sigma_wk, mpi::communicator const &c) {
93+
return lattice_dyson_g_generic(mu, e_k, sigma_wk, c);
9494
}
9595

9696

97-
g_wk_t lattice_dyson_g_wk(double mu, e_k_cvt e_k, g_w_cvt sigma_w) {
98-
return lattice_dyson_g_generic(mu, e_k, sigma_w);
97+
g_wk_t lattice_dyson_g_wk(double mu, e_k_cvt e_k, g_w_cvt sigma_w, mpi::communicator const &c) {
98+
return lattice_dyson_g_generic(mu, e_k, sigma_w, c);
9999
}
100100

101101

102-
g_w_t lattice_dyson_g_w(double mu, e_k_cvt e_k, g_w_cvt sigma_w) {
102+
g_w_t lattice_dyson_g_w(double mu, e_k_cvt e_k, g_w_cvt sigma_w, mpi::communicator const &c) {
103103

104-
auto g_wk = lattice_dyson_g_generic(mu, e_k, sigma_w);
104+
auto g_wk = lattice_dyson_g_generic(mu, e_k, sigma_w, c);
105105
auto &[wmesh, kmesh] = g_wk.mesh();
106106

107107
g_w_t g_w(wmesh, e_k.target_shape());
108108
g_w() = 0.0;
109109

110-
for (auto const &[w, k] : mpi_view(g_wk.mesh()))
110+
for (auto const &[w, k] : mpi_view(g_wk.mesh(), c))
111111
g_w[w] += g_wk[w, k];
112112

113-
g_w = mpi::all_reduce(g_w);
113+
g_w = mpi::all_reduce(g_w, c);
114114
g_w /= kmesh.size();
115115
return g_w;
116116
}
117117

118118
// ----------------------------------------------------
119119
// Transformations: real space <-> reciprocal space
120120

121-
g_wr_t fourier_wk_to_wr(g_wk_cvt g_wk) {
122-
auto g_wr = fourier_wk_to_wr_general_target(g_wk);
121+
g_wr_t fourier_wk_to_wr(g_wk_cvt g_wk, mpi::communicator const &c) {
122+
auto g_wr = fourier_wk_to_wr_general_target(g_wk, c);
123123
return g_wr;
124124
}
125125

126-
g_wk_t fourier_wr_to_wk(g_wr_cvt g_wr) {
127-
auto g_wk = fourier_wr_to_wk_general_target(g_wr);
126+
g_wk_t fourier_wr_to_wk(g_wr_cvt g_wr, mpi::communicator const &c) {
127+
auto g_wk = fourier_wr_to_wk_general_target(g_wr, c);
128128
return g_wk;
129129
}
130130

131131
// ----------------------------------------------------
132132
// Transformations: Matsubara frequency <-> imaginary time
133133

134-
g_wr_t fourier_tr_to_wr(g_tr_cvt g_tr, int nw) {
135-
auto g_wr = fourier_tr_to_wr_general_target(g_tr, nw);
134+
g_wr_t fourier_tr_to_wr(g_tr_cvt g_tr, int nw, mpi::communicator const &c) {
135+
auto g_wr = fourier_tr_to_wr_general_target(g_tr, nw, c);
136136
return g_wr;
137137
}
138138

139-
g_tr_t fourier_wr_to_tr(g_wr_cvt g_wr, int nt) {
140-
auto g_tr = fourier_wr_to_tr_general_target(g_wr, nt);
139+
g_tr_t fourier_wr_to_tr(g_wr_cvt g_wr, int nt, mpi::communicator const &c) {
140+
auto g_tr = fourier_wr_to_tr_general_target(g_wr, nt, c);
141141
return g_tr;
142142
}
143143

c++/triqs_tprf/lattice/gf.hpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ namespace triqs_tprf {
4141
@param mesh imaginary frequency mesh
4242
@return Matsubara frequency lattice Green's function $G^{(0)}_{a\bar{b}}(i\omega_n, \mathbf{k})$
4343
*/
44-
g_wk_t lattice_dyson_g0_wk(double mu, e_k_cvt e_k, mesh::imfreq mesh);
44+
g_wk_t lattice_dyson_g0_wk(double mu, e_k_cvt e_k, mesh::imfreq mesh, mpi::communicator const &c = {});
4545

4646
/** Construct an interacting Matsubara frequency lattice Green's function :math:`G_{a\bar{b}}(i\omega_n, \mathbf{k})`
4747
@@ -61,7 +61,7 @@ g_wk_t lattice_dyson_g0_wk(double mu, e_k_cvt e_k, mesh::imfreq mesh);
6161
@param sigma_w imaginary frequency self-energy :math:`\Sigma_{\bar{a}b}(i\omega_n)`
6262
@return Matsubara frequency lattice Green's function $G_{a\bar{b}}(i\omega_n, \mathbf{k})$
6363
*/
64-
g_wk_t lattice_dyson_g_wk(double mu, e_k_cvt e_k, g_w_cvt sigma_w);
64+
g_wk_t lattice_dyson_g_wk(double mu, e_k_cvt e_k, g_w_cvt sigma_w, mpi::communicator const &c = {});
6565

6666
/** Construct an interacting Matsubara frequency lattice Green's function :math:`G_{a\bar{b}}(i\omega_n, \mathbf{k})`
6767
@@ -81,7 +81,7 @@ g_wk_t lattice_dyson_g_wk(double mu, e_k_cvt e_k, g_w_cvt sigma_w);
8181
@param sigma_wk imaginary frequency self-energy :math:`\Sigma_{\bar{a}b}(i\omega_n, \mathbf{k})`
8282
@return Matsubara frequency lattice Green's function $G_{a\bar{b}}(i\omega_n, \mathbf{k})$
8383
*/
84-
g_wk_t lattice_dyson_g_wk(double mu, e_k_cvt e_k, g_wk_cvt sigma_wk);
84+
g_wk_t lattice_dyson_g_wk(double mu, e_k_cvt e_k, g_wk_cvt sigma_wk, mpi::communicator const &c = {});
8585

8686
/** Construct an interacting Matsubara frequency local (:math:`\mathbf{r}=\mathbf{0}`) lattice Green's function :math:`G_{a\bar{b}}(i\omega_n)`
8787
@@ -101,7 +101,7 @@ g_wk_t lattice_dyson_g_wk(double mu, e_k_cvt e_k, g_wk_cvt sigma_wk);
101101
@param sigma_w imaginary frequency self-energy :math:`\Sigma_{\bar{a}b}(i\omega_n)`
102102
@return Matsubara frequency lattice Green's function $G_{a\bar{b}}(i\omega_n, \mathbf{k})$
103103
*/
104-
g_w_t lattice_dyson_g_w(double mu, e_k_cvt e_k, g_w_cvt sigma_w);
104+
g_w_t lattice_dyson_g_w(double mu, e_k_cvt e_k, g_w_cvt sigma_w, mpi::communicator const &c = {});
105105

106106
/** Inverse fast fourier transform of imaginary frequency Green's function from k-space to real space
107107
@@ -110,7 +110,7 @@ g_w_t lattice_dyson_g_w(double mu, e_k_cvt e_k, g_w_cvt sigma_w);
110110
@param g_wk k-space imaginary frequency Green's function :math:`G_{a\bar{b}}(i\omega_n, \mathbf{k})`
111111
@return real-space imaginary frequency Green's function :math:`G_{a\bar{b}}(i\omega_n, \mathbf{r})`
112112
*/
113-
g_wr_t fourier_wk_to_wr(g_wk_cvt g_wk);
113+
g_wr_t fourier_wk_to_wr(g_wk_cvt g_wk, mpi::communicator const &c = {});
114114

115115
/** Fast fourier transform of imaginary frequency Green's function from real-space to k-space
116116
@@ -119,7 +119,7 @@ g_wr_t fourier_wk_to_wr(g_wk_cvt g_wk);
119119
@param g_wr real-space imaginary frequency Green's function :math:`G_{a\bar{b}}(i\omega_n, \mathbf{r})`
120120
@return k-space imaginary frequency Green's function :math:`G_{a\bar{b}}(i\omega_n, \mathbf{k})`
121121
*/
122-
g_wk_t fourier_wr_to_wk(g_wr_cvt g_wr);
122+
g_wk_t fourier_wr_to_wk(g_wr_cvt g_wr, mpi::communicator const &c = {});
123123

124124
/** Fast fourier transform of real-space Green's function from Matsubara frequency to imaginary time
125125
@@ -128,7 +128,7 @@ g_wk_t fourier_wr_to_wk(g_wr_cvt g_wr);
128128
@param g_wr real-space imaginary frequency Green's function :math:`G_{a\bar{b}}(i\omega_n, \mathbf{r})`
129129
@return real-space imaginary time Green's function :math:`G_{a\bar{b}}(\tau, \mathbf{r})`
130130
*/
131-
g_tr_t fourier_wr_to_tr(g_wr_cvt g_wr, int nt=-1);
131+
g_tr_t fourier_wr_to_tr(g_wr_cvt g_wr, int nt=-1, mpi::communicator const &c = {});
132132

133133
/** Fast fourier transform of real-space Green's function from imaginary time to Matsubara frequency
134134
@@ -137,6 +137,6 @@ g_tr_t fourier_wr_to_tr(g_wr_cvt g_wr, int nt=-1);
137137
@param g_tr real-space imaginary time Green's function :math:`G_{a\bar{b}}(\tau, \mathbf{r})`
138138
@return real-space Matsubara frequency Green's function :math:`G_{a\bar{b}}(i\omega_n, \mathbf{r})`
139139
*/
140-
g_wr_t fourier_tr_to_wr(g_tr_cvt g_tr, int nw=-1);
140+
g_wr_t fourier_tr_to_wr(g_tr_cvt g_tr, int nw=-1, mpi::communicator const &c = {});
141141

142142
} // namespace triqs_tprf

c++/triqs_tprf/mpi.hpp

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
namespace triqs_tprf {
3030

3131
template<class T>
32-
auto mpi_view(const array<T, 1> &arr, mpi::communicator const & c) {
32+
auto mpi_view(const array<T, 1> &arr, mpi::communicator const &c = {}) {
3333

3434
auto slice = itertools::chunk_range(0, arr.shape()[0], c.size(), c.rank());
3535

@@ -42,13 +42,7 @@ auto mpi_view(const array<T, 1> &arr, mpi::communicator const & c) {
4242
}
4343

4444
template<class T>
45-
auto mpi_view(const array<T, 1> &arr) {
46-
mpi::communicator c;
47-
return mpi_view(arr, c);
48-
}
49-
50-
template<class T>
51-
auto mpi_view(const T &mesh, mpi::communicator const & c) {
45+
auto mpi_view(const T &mesh, mpi::communicator const &c = {}) {
5246

5347
auto slice = itertools::chunk_range(0, mesh.size(), c.size(), c.rank());
5448
int size = slice.second - slice.first;
@@ -78,11 +72,5 @@ auto mpi_view(const T &mesh, mpi::communicator const & c) {
7872

7973
return arr;
8074
}
81-
82-
template<class T>
83-
auto mpi_view(const T &mesh) {
84-
mpi::communicator c;
85-
return mpi_view(mesh, c);
86-
}
8775

8876
} // namespace triqs_tprf

0 commit comments

Comments
 (0)