Skip to content

Commit 884eee7

Browse files
hmenkemmalcolms
andcommitted
Pass MPI communicator for flexible parallelization
Co-authored-by: Mario Malcolms de Oliveira <[email protected]>
1 parent 707a04f commit 884eee7

File tree

4 files changed

+44
-56
lines changed

4 files changed

+44
-56
lines changed

c++/triqs_tprf/lattice/fourier.hpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ namespace triqs_tprf {
3434
}
3535

3636
template <typename Gf_type>
37-
auto fourier_wr_to_tr_general_target(Gf_type g_wr, int n_tau = -1) {
37+
auto fourier_wr_to_tr_general_target(Gf_type g_wr, int n_tau = -1, mpi::communicator const &c = {}) {
3838

3939
auto _ = all_t{};
4040
// Get rid of structured binding declarations in this file due to issue #11
@@ -48,7 +48,7 @@ auto fourier_wr_to_tr_general_target(Gf_type g_wr, int n_tau = -1) {
4848
auto r0 = *rmesh.begin();
4949
auto p = _fourier_plan<0>(gf_const_view(g_wr[_, r0]), gf_view(g_tr[_, r0]));
5050

51-
auto r_arr = mpi_view(rmesh);
51+
auto r_arr = mpi_view(rmesh, c);
5252

5353
#pragma omp parallel for
5454
for (unsigned int idx = 0; idx < r_arr.size(); idx++) {
@@ -63,12 +63,12 @@ auto fourier_wr_to_tr_general_target(Gf_type g_wr, int n_tau = -1) {
6363

6464
g_tr[_, r] = g_t;
6565
}
66-
g_tr = mpi::all_reduce(g_tr);
66+
g_tr = mpi::all_reduce(g_tr, c);
6767
return g_tr;
6868
}
6969

7070
template <typename Gf_type>
71-
auto fourier_tr_to_wr_general_target(Gf_type g_tr, int n_w = -1) {
71+
auto fourier_tr_to_wr_general_target(Gf_type g_tr, int n_w = -1, mpi::communicator const &c = {}) {
7272

7373
auto _ = all_t{};
7474
//auto [tmesh, rmesh] = g_tr.mesh();
@@ -81,7 +81,7 @@ auto fourier_tr_to_wr_general_target(Gf_type g_tr, int n_w = -1) {
8181
auto r0 = *rmesh.begin();
8282
auto p = _fourier_plan<0>(gf_const_view(g_tr[_, r0]), gf_view(g_wr[_, r0]));
8383

84-
auto r_arr = mpi_view(rmesh);
84+
auto r_arr = mpi_view(rmesh, c);
8585

8686
#pragma omp parallel for
8787
for (unsigned int idx = 0; idx < r_arr.size(); idx++) {
@@ -96,12 +96,12 @@ auto fourier_tr_to_wr_general_target(Gf_type g_tr, int n_w = -1) {
9696

9797
g_wr[_, r] = g_w;
9898
}
99-
g_wr = mpi::all_reduce(g_wr);
99+
g_wr = mpi::all_reduce(g_wr, c);
100100
return g_wr;
101101
}
102102

103103
template <typename Gf_type>
104-
auto fourier_wk_to_wr_general_target(Gf_type g_wk) {
104+
auto fourier_wk_to_wr_general_target(Gf_type g_wk, mpi::communicator const &c = {}) {
105105

106106
auto _ = all_t{};
107107

@@ -116,7 +116,7 @@ auto fourier_wk_to_wr_general_target(Gf_type g_wk) {
116116
auto w0 = *wmesh.begin();
117117
auto p = _fourier_plan<0>(gf_const_view(g_wk[w0, _]), gf_view(g_wr[w0, _]));
118118

119-
auto w_arr = mpi_view(wmesh);
119+
auto w_arr = mpi_view(wmesh, c);
120120

121121
#pragma omp parallel for
122122
for (unsigned int idx = 0; idx < w_arr.size(); idx++) {
@@ -131,12 +131,12 @@ auto fourier_wk_to_wr_general_target(Gf_type g_wk) {
131131

132132
g_wr[w, _] = g_r;
133133
}
134-
g_wr = mpi::all_reduce(g_wr);
134+
g_wr = mpi::all_reduce(g_wr, c);
135135
return g_wr;
136136
}
137137

138138
template <typename Gf_type>
139-
auto fourier_wr_to_wk_general_target(Gf_type g_wr) {
139+
auto fourier_wr_to_wk_general_target(Gf_type g_wr, mpi::communicator const &c = {}) {
140140

141141
auto _ = all_t{};
142142

@@ -150,7 +150,7 @@ auto fourier_wr_to_wk_general_target(Gf_type g_wr) {
150150
auto w0 = *wmesh.begin();
151151
auto p = _fourier_plan<0>(gf_const_view(g_wr[w0, _]), gf_view(g_wk[w0, _]));
152152

153-
auto w_arr = mpi_view(wmesh);
153+
auto w_arr = mpi_view(wmesh, c);
154154

155155
#pragma omp parallel for
156156
for (unsigned int idx = 0; idx < w_arr.size(); idx++) {
@@ -165,7 +165,7 @@ auto fourier_wr_to_wk_general_target(Gf_type g_wr) {
165165

166166
g_wk[w, _] = g_k;
167167
}
168-
g_wk = mpi::all_reduce(g_wk);
168+
g_wk = mpi::all_reduce(g_wk, c);
169169
return g_wk;
170170
}
171171

c++/triqs_tprf/lattice/gf.cpp

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -38,21 +38,21 @@ namespace triqs_tprf {
3838
// ----------------------------------------------------
3939
// g
4040

41-
g_wk_t lattice_dyson_g0_wk(double mu, e_k_cvt e_k, gf_mesh<imfreq> mesh) {
41+
g_wk_t lattice_dyson_g0_wk(double mu, e_k_cvt e_k, gf_mesh<imfreq> mesh, mpi::communicator const &c) {
4242

4343
auto I = nda::eye<ek_vt::scalar_t>(e_k.target_shape()[0]);
4444
g_wk_t g0_wk({mesh, e_k.mesh()}, e_k.target_shape());
4545
g0_wk() = 0.0;
4646

47-
auto arr = mpi_view(g0_wk.mesh());
47+
auto arr = mpi_view(g0_wk.mesh(), c);
4848

4949
#pragma omp parallel for
5050
for (unsigned int idx = 0; idx < arr.size(); idx++) {
5151
auto &[w, k] = arr(idx);
5252
g0_wk[w, k] = inverse((w + mu)*I - e_k(k));
5353
}
5454

55-
g0_wk = mpi::all_reduce(g0_wk);
55+
g0_wk = mpi::all_reduce(g0_wk, c);
5656
return g0_wk;
5757
}
5858

@@ -80,7 +80,7 @@ g_fk_t lattice_dyson_g0_fk(double mu, e_k_cvt e_k, gf_mesh<refreq> mesh, double
8080
// ----------------------------------------------------
8181

8282
template<typename sigma_t>
83-
auto lattice_dyson_g_generic(double mu, e_k_cvt e_k, sigma_t sigma){
83+
auto lattice_dyson_g_generic(double mu, e_k_cvt e_k, sigma_t sigma, mpi::communicator const &c){
8484

8585
auto &freqmesh = [&sigma]() -> auto & {
8686
if constexpr (sigma_t::arity == 1) return sigma.mesh();
@@ -93,7 +93,7 @@ auto lattice_dyson_g_generic(double mu, e_k_cvt e_k, sigma_t sigma){
9393
g_wk_t g_wk({freqmesh, e_k.mesh()}, e_k.target_shape());
9494
g_wk() = 0.0;
9595

96-
auto arr = mpi_view(g_wk.mesh());
96+
auto arr = mpi_view(g_wk.mesh(), c);
9797
#pragma omp parallel for
9898
for (unsigned int idx = 0; idx < arr.size(); idx++) {
9999
auto &[w, k] = arr(idx);
@@ -105,18 +105,18 @@ auto lattice_dyson_g_generic(double mu, e_k_cvt e_k, sigma_t sigma){
105105
g_wk[w, k] = inverse((w + mu)*I - e_k(k) - sigmaterm);
106106
}
107107

108-
g_wk = mpi::all_reduce(g_wk);
108+
g_wk = mpi::all_reduce(g_wk, c);
109109
return g_wk;
110110
}
111111

112112

113-
g_wk_t lattice_dyson_g_wk(double mu, e_k_cvt e_k, g_wk_cvt sigma_wk) {
114-
return lattice_dyson_g_generic(mu, e_k, sigma_wk);
113+
g_wk_t lattice_dyson_g_wk(double mu, e_k_cvt e_k, g_wk_cvt sigma_wk, mpi::communicator const &c) {
114+
return lattice_dyson_g_generic(mu, e_k, sigma_wk, c);
115115
}
116116

117117

118-
g_wk_t lattice_dyson_g_wk(double mu, e_k_cvt e_k, g_w_cvt sigma_w) {
119-
return lattice_dyson_g_generic(mu, e_k, sigma_w);
118+
g_wk_t lattice_dyson_g_wk(double mu, e_k_cvt e_k, g_w_cvt sigma_w, mpi::communicator const &c) {
119+
return lattice_dyson_g_generic(mu, e_k, sigma_w, c);
120120
}
121121

122122

@@ -144,45 +144,45 @@ g_fk_t lattice_dyson_g_fk(double mu, e_k_cvt e_k, g_fk_cvt sigma_fk, double delt
144144

145145
// ----------------------------------------------------
146146

147-
g_w_t lattice_dyson_g_w(double mu, e_k_cvt e_k, g_w_cvt sigma_w) {
147+
g_w_t lattice_dyson_g_w(double mu, e_k_cvt e_k, g_w_cvt sigma_w, mpi::communicator const &c) {
148148

149-
auto g_wk = lattice_dyson_g_generic(mu, e_k, sigma_w);
149+
auto g_wk = lattice_dyson_g_generic(mu, e_k, sigma_w, c);
150150
auto &[wmesh, kmesh] = g_wk.mesh();
151151

152152
g_w_t g_w(wmesh, e_k.target_shape());
153153
g_w() = 0.0;
154154

155-
for (auto const &[w, k] : mpi_view(g_wk.mesh()))
155+
for (auto const &[w, k] : mpi_view(g_wk.mesh(), c))
156156
g_w[w] += g_wk[w, k];
157157

158-
g_w = mpi::all_reduce(g_w);
158+
g_w = mpi::all_reduce(g_w, c);
159159
g_w /= kmesh.size();
160160
return g_w;
161161
}
162162

163163
// ----------------------------------------------------
164164
// Transformations: real space <-> reciprocal space
165165

166-
g_wr_t fourier_wk_to_wr(g_wk_cvt g_wk) {
167-
auto g_wr = fourier_wk_to_wr_general_target(g_wk);
166+
g_wr_t fourier_wk_to_wr(g_wk_cvt g_wk, mpi::communicator const &c) {
167+
auto g_wr = fourier_wk_to_wr_general_target(g_wk, c);
168168
return g_wr;
169169
}
170170

171-
g_wk_t fourier_wr_to_wk(g_wr_cvt g_wr) {
172-
auto g_wk = fourier_wr_to_wk_general_target(g_wr);
171+
g_wk_t fourier_wr_to_wk(g_wr_cvt g_wr, mpi::communicator const &c) {
172+
auto g_wk = fourier_wr_to_wk_general_target(g_wr, c);
173173
return g_wk;
174174
}
175175

176176
// ----------------------------------------------------
177177
// Transformations: Matsubara frequency <-> imaginary time
178178

179-
g_wr_t fourier_tr_to_wr(g_tr_cvt g_tr, int nw) {
180-
auto g_wr = fourier_tr_to_wr_general_target(g_tr, nw);
179+
g_wr_t fourier_tr_to_wr(g_tr_cvt g_tr, int nw, mpi::communicator const &c) {
180+
auto g_wr = fourier_tr_to_wr_general_target(g_tr, nw, c);
181181
return g_wr;
182182
}
183183

184-
g_tr_t fourier_wr_to_tr(g_wr_cvt g_wr, int nt) {
185-
auto g_tr = fourier_wr_to_tr_general_target(g_wr, nt);
184+
g_tr_t fourier_wr_to_tr(g_wr_cvt g_wr, int nt, mpi::communicator const &c) {
185+
auto g_tr = fourier_wr_to_tr_general_target(g_wr, nt, c);
186186
return g_tr;
187187
}
188188

c++/triqs_tprf/lattice/gf.hpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ namespace triqs_tprf {
4141
@param mesh imaginary frequency mesh
4242
@return Matsubara frequency lattice Green's function $G^{(0)}_{a\bar{b}}(i\omega_n, \mathbf{k})$
4343
*/
44-
g_wk_t lattice_dyson_g0_wk(double mu, e_k_cvt e_k, mesh::imfreq mesh);
44+
g_wk_t lattice_dyson_g0_wk(double mu, e_k_cvt e_k, mesh::imfreq mesh, mpi::communicator const &c = {});
4545

4646
/** Construct an interacting Matsubara frequency lattice Green's function :math:`G_{a\bar{b}}(i\omega_n, \mathbf{k})`
4747
@@ -61,7 +61,7 @@ namespace triqs_tprf {
6161
@param sigma_w imaginary frequency self-energy :math:`\Sigma_{\bar{a}b}(i\omega_n)`
6262
@return Matsubara frequency lattice Green's function $G_{a\bar{b}}(i\omega_n, \mathbf{k})$
6363
*/
64-
g_wk_t lattice_dyson_g_wk(double mu, e_k_cvt e_k, g_w_cvt sigma_w);
64+
g_wk_t lattice_dyson_g_wk(double mu, e_k_cvt e_k, g_w_cvt sigma_w, mpi::communicator const &c = {});
6565

6666
/** Construct a non-interacting real frequency lattice Green's function :math:`G^{(0)}_{a\bar{b}}(\omega, \mathbf{k})`
6767
@@ -101,7 +101,7 @@ namespace triqs_tprf {
101101
@param sigma_wk imaginary frequency self-energy $\Sigma_{\bar{a}b}(i\omega_n, \mathbf{k})$
102102
@return Matsubara frequency lattice Green's function $G_{a\bar{b}}(i\omega_n, \mathbf{k})$
103103
*/
104-
g_wk_t lattice_dyson_g_wk(double mu, e_k_cvt e_k, g_wk_cvt sigma_wk);
104+
g_wk_t lattice_dyson_g_wk(double mu, e_k_cvt e_k, g_wk_cvt sigma_wk, mpi::communicator const &c = {});
105105

106106
/** Construct an interacting real frequency lattice Green's function :math:`G_{a\bar{b}}(\omega, \mathbf{k})`
107107
@@ -142,7 +142,7 @@ namespace triqs_tprf {
142142
@param sigma_w imaginary frequency self-energy :math:`\Sigma_{\bar{a}b}(i\omega_n)`
143143
@return Matsubara frequency lattice Green's function $G_{a\bar{b}}(i\omega_n, \mathbf{k})$
144144
*/
145-
g_w_t lattice_dyson_g_w(double mu, e_k_cvt e_k, g_w_cvt sigma_w);
145+
g_w_t lattice_dyson_g_w(double mu, e_k_cvt e_k, g_w_cvt sigma_w, mpi::communicator const &c = {});
146146

147147
/** Inverse fast fourier transform of imaginary frequency Green's function from k-space to real space
148148
@@ -151,7 +151,7 @@ namespace triqs_tprf {
151151
@param g_wk k-space imaginary frequency Green's function :math:`G_{a\bar{b}}(i\omega_n, \mathbf{k})`
152152
@return real-space imaginary frequency Green's function :math:`G_{a\bar{b}}(i\omega_n, \mathbf{r})`
153153
*/
154-
g_wr_t fourier_wk_to_wr(g_wk_cvt g_wk);
154+
g_wr_t fourier_wk_to_wr(g_wk_cvt g_wk, mpi::communicator const &c = {});
155155

156156
/** Fast fourier transform of imaginary frequency Green's function from real-space to k-space
157157
@@ -160,7 +160,7 @@ namespace triqs_tprf {
160160
@param g_wr real-space imaginary frequency Green's function :math:`G_{a\bar{b}}(i\omega_n, \mathbf{r})`
161161
@return k-space imaginary frequency Green's function :math:`G_{a\bar{b}}(i\omega_n, \mathbf{k})`
162162
*/
163-
g_wk_t fourier_wr_to_wk(g_wr_cvt g_wr);
163+
g_wk_t fourier_wr_to_wk(g_wr_cvt g_wr, mpi::communicator const &c = {});
164164

165165
/** Fast fourier transform of real-space Green's function from Matsubara frequency to imaginary time
166166
@@ -169,7 +169,7 @@ namespace triqs_tprf {
169169
@param g_wr real-space imaginary frequency Green's function :math:`G_{a\bar{b}}(i\omega_n, \mathbf{r})`
170170
@return real-space imaginary time Green's function :math:`G_{a\bar{b}}(\tau, \mathbf{r})`
171171
*/
172-
g_tr_t fourier_wr_to_tr(g_wr_cvt g_wr, int nt = -1);
172+
g_tr_t fourier_wr_to_tr(g_wr_cvt g_wr, int nt=-1, mpi::communicator const &c = {});
173173

174174
/** Fast fourier transform of real-space Green's function from imaginary time to Matsubara frequency
175175
@@ -178,6 +178,6 @@ namespace triqs_tprf {
178178
@param g_tr real-space imaginary time Green's function :math:`G_{a\bar{b}}(\tau, \mathbf{r})`
179179
@return real-space Matsubara frequency Green's function :math:`G_{a\bar{b}}(i\omega_n, \mathbf{r})`
180180
*/
181-
g_wr_t fourier_tr_to_wr(g_tr_cvt g_tr, int nw = -1);
181+
g_wr_t fourier_tr_to_wr(g_tr_cvt g_tr, int nw=-1, mpi::communicator const &c = {});
182182

183183
} // namespace triqs_tprf

c++/triqs_tprf/mpi.hpp

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
namespace triqs_tprf {
3030

3131
template<class T>
32-
auto mpi_view(const array<T, 1> &arr, mpi::communicator const & c) {
32+
auto mpi_view(const array<T, 1> &arr, mpi::communicator const &c = {}) {
3333

3434
auto slice = itertools::chunk_range(0, arr.shape()[0], c.size(), c.rank());
3535

@@ -42,13 +42,7 @@ auto mpi_view(const array<T, 1> &arr, mpi::communicator const & c) {
4242
}
4343

4444
template<class T>
45-
auto mpi_view(const array<T, 1> &arr) {
46-
mpi::communicator c;
47-
return mpi_view(arr, c);
48-
}
49-
50-
template<class T>
51-
auto mpi_view(const T &mesh, mpi::communicator const & c) {
45+
auto mpi_view(const T &mesh, mpi::communicator const &c = {}) {
5246

5347
auto slice = itertools::chunk_range(0, mesh.size(), c.size(), c.rank());
5448
int size = slice.second - slice.first;
@@ -78,11 +72,5 @@ auto mpi_view(const T &mesh, mpi::communicator const & c) {
7872

7973
return arr;
8074
}
81-
82-
template<class T>
83-
auto mpi_view(const T &mesh) {
84-
mpi::communicator c;
85-
return mpi_view(mesh, c);
86-
}
8775

8876
} // namespace triqs_tprf

0 commit comments

Comments
 (0)