2
2
"""
3
3
This is the main module for PyErf.
4
4
"""
5
- # ---------------------------------------------------------------------------
6
- ### Imports
7
- # ---------------------------------------------------------------------------
8
- # Standard Library
5
+
9
6
import math
10
7
11
8
12
- # ---------------------------------------------------------------------------
13
- ### Constants
14
- # ---------------------------------------------------------------------------
15
9
# While some of these are used only in _ndtri, we don't want to
16
10
# calculate them each time a user calls erfinv. So we define them at the
17
11
# module level and they'll only be calculated once.
23
17
try :
24
18
from math import inf
25
19
except ImportError :
26
- inf = float (' inf' )
20
+ inf = float (" inf" )
27
21
28
22
29
- #: Inputs above this value are considered infinity.
23
+ # Inputs above this value are considered infinity.
30
24
MAXVAL = 1e50
31
25
32
26
33
- # ---------------------------------------------------------------------------
34
- ### Functions
35
- # ---------------------------------------------------------------------------
36
27
def _erf (x ):
37
28
"""
38
29
Port of cephes ``ndtr.c`` ``erf`` function.
39
30
40
31
See https://github.com/jeremybarnes/cephes/blob/master/cprob/ndtr.c
41
32
"""
42
33
T = [
43
- 9.60497373987051638749E0 ,
44
- 9.00260197203842689217E1 ,
45
- 2.23200534594684319226E3 ,
46
- 7.00332514112805075473E3 ,
47
- 5.55923013010394962768E4 ,
34
+ 9.60497373987051638749e0 ,
35
+ 9.00260197203842689217e1 ,
36
+ 2.23200534594684319226e3 ,
37
+ 7.00332514112805075473e3 ,
38
+ 5.55923013010394962768e4 ,
48
39
]
49
40
50
41
U = [
51
- 3.35617141647503099647E1 ,
52
- 5.21357949780152679795E2 ,
53
- 4.59432382970980127987E3 ,
54
- 2.26290000613890934246E4 ,
55
- 4.92673942608635921086E4 ,
42
+ 3.35617141647503099647e1 ,
43
+ 5.21357949780152679795e2 ,
44
+ 4.59432382970980127987e3 ,
45
+ 2.26290000613890934246e4 ,
46
+ 4.92673942608635921086e4 ,
56
47
]
57
48
58
49
# Shorcut special cases
@@ -78,45 +69,45 @@ def _erfc(a):
78
69
"""
79
70
# approximation for abs(a) < 8 and abs(a) >= 1
80
71
P = [
81
- 2.46196981473530512524E -10 ,
82
- 5.64189564831068821977E -1 ,
83
- 7.46321056442269912687E0 ,
84
- 4.86371970985681366614E1 ,
85
- 1.96520832956077098242E2 ,
86
- 5.26445194995477358631E2 ,
87
- 9.34528527171957607540E2 ,
88
- 1.02755188689515710272E3 ,
89
- 5.57535335369399327526E2 ,
72
+ 2.46196981473530512524e -10 ,
73
+ 5.64189564831068821977e -1 ,
74
+ 7.46321056442269912687e0 ,
75
+ 4.86371970985681366614e1 ,
76
+ 1.96520832956077098242e2 ,
77
+ 5.26445194995477358631e2 ,
78
+ 9.34528527171957607540e2 ,
79
+ 1.02755188689515710272e3 ,
80
+ 5.57535335369399327526e2 ,
90
81
]
91
82
92
83
Q = [
93
- 1.32281951154744992508E1 ,
94
- 8.67072140885989742329E1 ,
95
- 3.54937778887819891062E2 ,
96
- 9.75708501743205489753E2 ,
97
- 1.82390916687909736289E3 ,
98
- 2.24633760818710981792E3 ,
99
- 1.65666309194161350182E3 ,
100
- 5.57535340817727675546E2 ,
84
+ 1.32281951154744992508e1 ,
85
+ 8.67072140885989742329e1 ,
86
+ 3.54937778887819891062e2 ,
87
+ 9.75708501743205489753e2 ,
88
+ 1.82390916687909736289e3 ,
89
+ 2.24633760818710981792e3 ,
90
+ 1.65666309194161350182e3 ,
91
+ 5.57535340817727675546e2 ,
101
92
]
102
93
103
94
# approximation for abs(a) >= 8
104
95
R = [
105
- 5.64189583547755073984E -1 ,
106
- 1.27536670759978104416E0 ,
107
- 5.01905042251180477414E0 ,
108
- 6.16021097993053585195E0 ,
109
- 7.40974269950448939160E0 ,
110
- 2.97886665372100240670E0 ,
96
+ 5.64189583547755073984e -1 ,
97
+ 1.27536670759978104416e0 ,
98
+ 5.01905042251180477414e0 ,
99
+ 6.16021097993053585195e0 ,
100
+ 7.40974269950448939160e0 ,
101
+ 2.97886665372100240670e0 ,
111
102
]
112
103
113
104
S = [
114
- 2.26052863220117276590E0 ,
115
- 9.39603524938001434673E0 ,
116
- 1.20489539808096656605E1 ,
117
- 1.70814450747565897222E1 ,
118
- 9.60896809063285878198E0 ,
119
- 3.36907645100081516050E0 ,
105
+ 2.26052863220117276590e0 ,
106
+ 9.39603524938001434673e0 ,
107
+ 1.20489539808096656605e1 ,
108
+ 1.70814450747565897222e1 ,
109
+ 9.60896809063285878198e0 ,
110
+ 3.36907645100081516050e0 ,
120
111
]
121
112
122
113
# Shortcut special cases
@@ -188,72 +179,72 @@ def _ndtri(y):
188
179
"""
189
180
# approximation for 0 <= abs(z - 0.5) <= 3/8
190
181
P0 = [
191
- - 5.99633501014107895267E1 ,
192
- 9.80010754185999661536E1 ,
193
- - 5.66762857469070293439E1 ,
194
- 1.39312609387279679503E1 ,
195
- - 1.23916583867381258016E0 ,
182
+ - 5.99633501014107895267e1 ,
183
+ 9.80010754185999661536e1 ,
184
+ - 5.66762857469070293439e1 ,
185
+ 1.39312609387279679503e1 ,
186
+ - 1.23916583867381258016e0 ,
196
187
]
197
188
198
189
Q0 = [
199
- 1.95448858338141759834E0 ,
200
- 4.67627912898881538453E0 ,
201
- 8.63602421390890590575E1 ,
202
- - 2.25462687854119370527E2 ,
203
- 2.00260212380060660359E2 ,
204
- - 8.20372256168333339912E1 ,
205
- 1.59056225126211695515E1 ,
206
- - 1.18331621121330003142E0 ,
190
+ 1.95448858338141759834e0 ,
191
+ 4.67627912898881538453e0 ,
192
+ 8.63602421390890590575e1 ,
193
+ - 2.25462687854119370527e2 ,
194
+ 2.00260212380060660359e2 ,
195
+ - 8.20372256168333339912e1 ,
196
+ 1.59056225126211695515e1 ,
197
+ - 1.18331621121330003142e0 ,
207
198
]
208
199
209
200
# Approximation for interval z = sqrt(-2 log y ) between 2 and 8
210
201
# i.e., y between exp(-2) = .135 and exp(-32) = 1.27e-14.
211
202
P1 = [
212
- 4.05544892305962419923E0 ,
213
- 3.15251094599893866154E1 ,
214
- 5.71628192246421288162E1 ,
215
- 4.40805073893200834700E1 ,
216
- 1.46849561928858024014E1 ,
217
- 2.18663306850790267539E0 ,
218
- - 1.40256079171354495875E -1 ,
219
- - 3.50424626827848203418E -2 ,
220
- - 8.57456785154685413611E -4 ,
203
+ 4.05544892305962419923e0 ,
204
+ 3.15251094599893866154e1 ,
205
+ 5.71628192246421288162e1 ,
206
+ 4.40805073893200834700e1 ,
207
+ 1.46849561928858024014e1 ,
208
+ 2.18663306850790267539e0 ,
209
+ - 1.40256079171354495875e -1 ,
210
+ - 3.50424626827848203418e -2 ,
211
+ - 8.57456785154685413611e -4 ,
221
212
]
222
213
223
214
Q1 = [
224
- 1.57799883256466749731E1 ,
225
- 4.53907635128879210584E1 ,
226
- 4.13172038254672030440E1 ,
227
- 1.50425385692907503408E1 ,
228
- 2.50464946208309415979E0 ,
229
- - 1.42182922854787788574E -1 ,
230
- - 3.80806407691578277194E -2 ,
231
- - 9.33259480895457427372E -4 ,
215
+ 1.57799883256466749731e1 ,
216
+ 4.53907635128879210584e1 ,
217
+ 4.13172038254672030440e1 ,
218
+ 1.50425385692907503408e1 ,
219
+ 2.50464946208309415979e0 ,
220
+ - 1.42182922854787788574e -1 ,
221
+ - 3.80806407691578277194e -2 ,
222
+ - 9.33259480895457427372e -4 ,
232
223
]
233
224
234
225
# Approximation for interval z = sqrt(-2 log y ) between 8 and 64
235
226
# i.e., y between exp(-32) = 1.27e-14 and exp(-2048) = 3.67e-890.
236
227
P2 = [
237
- 3.23774891776946035970E0 ,
238
- 6.91522889068984211695E0 ,
239
- 3.93881025292474443415E0 ,
240
- 1.33303460815807542389E0 ,
241
- 2.01485389549179081538E -1 ,
242
- 1.23716634817820021358E -2 ,
243
- 3.01581553508235416007E -4 ,
244
- 2.65806974686737550832E -6 ,
245
- 6.23974539184983293730E -9 ,
228
+ 3.23774891776946035970e0 ,
229
+ 6.91522889068984211695e0 ,
230
+ 3.93881025292474443415e0 ,
231
+ 1.33303460815807542389e0 ,
232
+ 2.01485389549179081538e -1 ,
233
+ 1.23716634817820021358e -2 ,
234
+ 3.01581553508235416007e -4 ,
235
+ 2.65806974686737550832e -6 ,
236
+ 6.23974539184983293730e -9 ,
246
237
]
247
238
248
239
Q2 = [
249
- 6.02427039364742014255E0 ,
250
- 3.67983563856160859403E0 ,
251
- 1.37702099489081330271E0 ,
252
- 2.16236993594496635890E -1 ,
253
- 1.34204006088543189037E -2 ,
254
- 3.28014464682127739104E -4 ,
255
- 2.89247864745380683936E -6 ,
256
- 6.79019408009981274425E -9 ,
240
+ 6.02427039364742014255e0 ,
241
+ 3.67983563856160859403e0 ,
242
+ 1.37702099489081330271e0 ,
243
+ 2.16236993594496635890e -1 ,
244
+ 1.34204006088543189037e -2 ,
245
+ 3.28014464682127739104e -4 ,
246
+ 2.89247864745380683936e -6 ,
247
+ 6.79019408009981274425e -9 ,
257
248
]
258
249
259
250
sign_flag = 1
@@ -266,7 +257,7 @@ def _ndtri(y):
266
257
# between -0.135 and 0.135
267
258
if y > EXP_NEG2 :
268
259
y -= 0.5
269
- y2 = y ** 2
260
+ y2 = y ** 2
270
261
x = y + y * (y2 * _polevl (y2 , P0 , 4 ) / _p1evl (y2 , Q0 , 8 ))
271
262
x = x * ROOT_2PI
272
263
return x
@@ -275,7 +266,7 @@ def _ndtri(y):
275
266
x0 = x - math .log (x ) / x
276
267
277
268
z = 1.0 / x
278
- if x < 8.0 : # y > exp(-32) = 1.2664165549e-14
269
+ if x < 8.0 : # y > exp(-32) = 1.2664165549e-14
279
270
x1 = z * _polevl (z , P1 , 8 ) / _p1evl (z , Q1 , 8 )
280
271
else :
281
272
x1 = z * _polevl (z , P2 , 8 ) / _p1evl (z , Q2 , 8 )
0 commit comments