Skip to content

Commit 0e83962

Browse files
committed
[feat] optimize tetrahedron integration for numpy
1 parent c197df3 commit 0e83962

File tree

1 file changed

+223
-126
lines changed

1 file changed

+223
-126
lines changed

python/triqs_dft_tools/converters/plovasp/lintetra.py

Lines changed: 223 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -2,138 +2,235 @@
22

33
# Energy differences at or below this threshold are treated as degenerate.
_TOL = 1e-8


def _regularize_scalar(value):
    """Return ``max(float(value), 1e-20) / 100`` for a scalar input.

    The result is strictly positive. The vectorized helpers below apply the
    same floor-and-scale rule to arrays via ``_imaginary_shift``.
    """
    return max(float(value), 1.0e-20) / 100.0


def _select(mask, *arrays):
    """Return the indices where ``mask`` is true and each array at them.

    Parameters
    ----------
    mask : (n,) ndarray of bool
        Selection mask.
    *arrays : (n,) ndarray
        Arrays indexed along their first axis with the selected positions.

    Returns
    -------
    idx : ndarray of int
        Positions where ``mask`` is true.
    selected : list of ndarray
        ``[a[idx] for a in arrays]``, in the order given.
    """
    idx = np.nonzero(mask)[0]
    return idx, [a[idx] for a in arrays]


def _imaginary_shift(gap):
    """Return the purely imaginary shift ``1j * max(gap, 1e-20) / 100``."""
    return 1j * (np.maximum(gap, 1.0e-20) / 100.0)


def _blend(mask, numerator, denominator, shifted_num, shifted_den):
    """Pick the closed-form quotient where ``mask`` holds, else the shifted one.

    The closed-form quotient is evaluated on every lane, including lanes whose
    denominator is (nearly) zero and whose value is then discarded. Floating
    point warnings are silenced for that evaluation only.
    """
    with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
        safe_val = numerator / denominator
    fallback = (shifted_num / shifted_den).real
    return np.where(mask, safe_val, fallback)


# === Auxiliary numerical functions (formerly F, K1, K2) ===

def _stable_fraction_F(e_eval, e1, e2, e3, e4):
    """
    Stable evaluation of the F-term used in tetrahedron DOS integration.

    Evaluates ``(e1 - e_eval) * (e_eval - e2) / ((e1 - e3) * (e4 - e2))``.
    Where ``|e1 - e3|`` or ``|e4 - e2|`` does not exceed ``_TOL``, every
    difference is shifted by a small imaginary amount and the real part of
    the shifted quotient is returned instead.

    Parameters
    ----------
    e_eval : float or ndarray
        Evaluation energy.
    e1, e2, e3, e4 : ndarray
        Corner energies of the tetrahedron.

    Returns
    -------
    ndarray
        Evaluated F-term.
    """
    e1, e2, e3, e4, e_eval = map(np.asarray, (e1, e2, e3, e4, e_eval))
    gap_a = np.abs(e1 - e3)
    gap_b = np.abs(e4 - e2)
    mask = (gap_a > _TOL) & (gap_b > _TOL)

    s = _imaginary_shift(np.minimum(gap_a, gap_b))
    return _blend(
        mask,
        (e1 - e_eval) * (e_eval - e2),
        (e1 - e3) * (e4 - e2),
        (e1 - e_eval + s) * (e_eval - e2 + s),
        (e1 - e3 + s) * (e4 - e2 + s),
    )


def _stable_fraction_K1(e_eval, e1, e2):
    """
    Stable evaluation of the first K-term used in tetrahedron DOS integration.

    Evaluates ``(e1 - e_eval) / (e2 - e1) ** 2``. Where ``|e1 - e2|`` does not
    exceed ``_TOL``, the differences are shifted by a small imaginary amount
    and the real part of the shifted quotient is returned instead.

    Parameters
    ----------
    e_eval : float or ndarray
        Evaluation energy.
    e1, e2 : ndarray
        Corner energies of the tetrahedron.

    Returns
    -------
    ndarray
        Evaluated K1-term.
    """
    e1, e2, e_eval = map(np.asarray, (e1, e2, e_eval))
    gap = np.abs(e1 - e2)
    mask = gap > _TOL

    s = _imaginary_shift(gap)
    return _blend(
        mask,
        e1 - e_eval,
        (e2 - e1) ** 2,
        e1 - e_eval + s,
        (e2 - e1 + s) ** 2,
    )


def _stable_fraction_K2(e_eval, e1, e2, e3):
    """
    Stable evaluation of the second K-term used in tetrahedron DOS integration.

    Evaluates ``(e_eval - e1) / ((e2 - e1) * (e3 - e1))``. Where ``|e1 - e3|``
    or ``|e1 - e2|`` does not exceed ``_TOL``, the differences are shifted by
    a small imaginary amount and the real part of the shifted quotient is
    returned instead.

    Parameters
    ----------
    e_eval : float or ndarray
        Evaluation energy.
    e1, e2, e3 : ndarray
        Corner energies of the tetrahedron.

    Returns
    -------
    ndarray
        Evaluated K2-term.
    """
    e1, e2, e3, e_eval = map(np.asarray, (e1, e2, e3, e_eval))
    gap_a = np.abs(e1 - e3)
    gap_b = np.abs(e1 - e2)
    mask = (gap_a > _TOL) & (gap_b > _TOL)

    s = _imaginary_shift(np.minimum(gap_a, gap_b))
    return _blend(
        mask,
        e_eval - e1,
        (e2 - e1) * (e3 - e1),
        e_eval - e1 + s,
        (e2 - e1 + s) * (e3 - e1 + s),
    )


# === Per-range weight formulas (corner energies sorted ascending) ===

def _weights_case1(ee, e1, e2, e3, e4):
    """Corner weights, shape (4, n), for ``e1 <= ee <= e2``."""
    F, K1, K2 = _stable_fraction_F, _stable_fraction_K1, _stable_fraction_K2
    w0 = (K2(ee, e1, e2, e4) * F(ee, e2, e1, e1, e3)
          + K2(ee, e1, e2, e3) * F(ee, e3, e1, e1, e4)
          + K2(ee, e1, e3, e4) * F(ee, e4, e1, e1, e2))
    w1 = -K1(ee, e1, e2) * F(ee, e1, e1, e3, e4)
    w2 = -K1(ee, e1, e3) * F(ee, e1, e1, e2, e4)
    w3 = -K1(ee, e1, e4) * F(ee, e1, e1, e2, e3)
    return np.vstack([w0, w1, w2, w3])


def _weights_case2(ee, e1, e2, e3, e4):
    """Corner weights, shape (4, n), for ``e2 <= ee <= e3``."""
    F, K1 = _stable_fraction_F, _stable_fraction_K1
    # Each bracketed F-sum appears in two of the four weights; compute once.
    sum_24 = F(ee, e3, e2, e2, e4) + F(ee, e4, e1, e2, e4) + F(ee, e3, e1, e2, e4)
    sum_23 = F(ee, e4, e1, e2, e3) + F(ee, e4, e2, e2, e3) + F(ee, e3, e1, e2, e3)
    sum_14 = F(ee, e3, e2, e1, e4) + F(ee, e4, e2, e1, e4) + F(ee, e3, e1, e1, e4)
    sum_13 = F(ee, e3, e2, e1, e3) + F(ee, e4, e1, e1, e3) + F(ee, e4, e2, e1, e3)

    w0 = 0.5 * (K1(ee, e3, e1) * sum_24 + K1(ee, e4, e1) * sum_23)
    w1 = 0.5 * (K1(ee, e3, e2) * sum_14 + K1(ee, e4, e2) * sum_13)
    w2 = 0.5 * (-K1(ee, e2, e3) * sum_14 - K1(ee, e1, e3) * sum_24)
    w3 = 0.5 * (-K1(ee, e2, e4) * sum_13 - K1(ee, e1, e4) * sum_23)
    return np.vstack([w0, w1, w2, w3])


def _weights_case3(ee, e1, e2, e3, e4):
    """Corner weights, shape (4, n), for ``e3 <= ee <= e4``."""
    F, K1, K2 = _stable_fraction_F, _stable_fraction_K1, _stable_fraction_K2
    w0 = K1(ee, e4, e1) * F(ee, e4, e4, e2, e3)
    w1 = K1(ee, e4, e2) * F(ee, e4, e4, e1, e3)
    w2 = K1(ee, e4, e3) * F(ee, e4, e4, e1, e2)
    w3 = (-K2(ee, e4, e3, e1) * F(ee, e4, e3, e2, e4)
          - K2(ee, e4, e2, e3) * F(ee, e4, e2, e1, e4)
          - K2(ee, e4, e1, e2) * F(ee, e4, e1, e3, e4))
    return np.vstack([w0, w1, w2, w3])


# === Main driver ===

def dos_tetra_weights_3d(eigenvalues, energy, k_points):
    """
    Compute tetrahedron corner weights for 3D DOS integration.

    All tetrahedra are processed together with NumPy array operations; there
    is no per-tetrahedron Python loop.

    Each tetrahedron is assigned to exactly one of the following, tested in
    this order (first match wins), with corner energies sorted as
    ``e1 <= e2 <= e3 <= e4``:

    1. uniform: ``e1 <= energy <= e4`` and ``e4 - e1 < _TOL`` -> all 0.25
    2. case 1:  ``e1 <= energy <= e2``
    3. case 2:  ``e2 <= energy <= e3``
    4. case 3:  ``e3 <= energy <= e4``

    Tetrahedra matching none of these (energy outside ``[e1, e4]``) get zero
    weights.

    Parameters
    ----------
    eigenvalues : (n_kpoints,) array_like of float
        Energies at k-points.
    energy : float
        Target energy for DOS evaluation.
    k_points : (5, n_tetra) array_like of int
        Tetrahedron connectivity. Only rows [1:5] are used for corner indices.

    Returns
    -------
    corner_weights : (4, n_tetra) ndarray of float
        Corner weights for each tetrahedron at the specified energy, in the
        same corner order as rows [1:5] of ``k_points``.

    Raises
    ------
    ValueError
        If ``k_points`` is not two-dimensional with a first axis of length 5.
    """
    eigk = np.asarray(eigenvalues, dtype=float)
    tetra = np.asarray(k_points, dtype=np.int64)
    if tetra.ndim != 2 or tetra.shape[0] != 5:
        raise ValueError("k_points must have shape (5, n_tetra)")

    n_tetra = tetra.shape[1]
    corner_energies = eigk[tetra[1:5, :]]  # (4, n_tetra)

    # Sort each tetrahedron's corner energies ascending
    order = np.argsort(corner_energies, axis=0)
    sorted_energies = np.take_along_axis(corner_energies, order, axis=0)

    e1, e2, e3, e4 = sorted_energies
    e_eval = float(energy)

    # The range tests overlap: an energy equal to a corner energy satisfies
    # two adjacent ranges, and a degenerate tetrahedron satisfies all of them.
    # Make the masks mutually exclusive so a later case never overwrites the
    # weights written by an earlier one.
    flag_uniform = (e1 <= e_eval) & (e_eval <= e4) & (np.abs(e4 - e1) < _TOL)
    taken = flag_uniform
    flag_case1 = (e1 <= e_eval) & (e_eval <= e2) & ~taken
    taken = taken | flag_case1
    flag_case2 = (e2 <= e_eval) & (e_eval <= e3) & ~taken
    taken = taken | flag_case2
    flag_case3 = (e3 <= e_eval) & (e_eval <= e4) & ~taken

    weights_sorted = np.zeros_like(sorted_energies, dtype=float)

    # Uniform tetrahedra (degenerate energies): equal share per corner
    weights_sorted[:, flag_uniform] = 0.25

    cases = (
        (flag_case1, _weights_case1),
        (flag_case2, _weights_case2),
        (flag_case3, _weights_case3),
    )
    for mask, formula in cases:
        if mask.any():
            i, selected = _select(mask, e1, e2, e3, e4)
            weights_sorted[:, i] = formula(e_eval, *selected)

    # Remap to original corner order: order[j, t] is the original corner
    # position holding the j-th smallest energy of tetrahedron t.
    corner_weights = np.zeros_like(corner_energies, dtype=float)
    corner_weights[order, np.arange(n_tetra)[None, :]] = weights_sorted

    return corner_weights

0 commit comments

Comments
 (0)