Skip to content

Commit 6620fe2

Browse files
author
Gokulnath Srinivasan
authored
Add LLVM Legalization for tir.erf (#18104)
This PR adds LLVM legalization support for tir.erf using the Abramowitz and Stegun approximation, which avoids the precision issues found in the tanh approximation based implementation.
1 parent 1b9da40 commit 6620fe2

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

src/target/llvm/intrin_rule_llvm.cc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,24 @@ TVM_REGISTER_OP("tir.atanh")
242242
return (log(one + x) - log(one - x)) * make_const(x.dtype(), 0.5);
243243
});
244244

245+
TVM_REGISTER_OP("tir.erf").set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
246+
using tir::make_const;
247+
const tir::CallNode* call = e.as<tir::CallNode>();
248+
ICHECK(call != nullptr) << "Invalid call node in erf legalization";
249+
const PrimExpr& x = call->args[0];
250+
PrimExpr abs_x = tvm::abs(x);
251+
PrimExpr t = make_const(x.dtype(), 1.0) /
252+
(make_const(x.dtype(), 1.0) + make_const(x.dtype(), 0.3275911) * abs_x);
253+
PrimExpr a1 = make_const(x.dtype(), 0.254829592);
254+
PrimExpr a2 = make_const(x.dtype(), -0.284496736);
255+
PrimExpr a3 = make_const(x.dtype(), 1.421413741);
256+
PrimExpr a4 = make_const(x.dtype(), -1.453152027);
257+
PrimExpr a5 = make_const(x.dtype(), 1.061405429);
258+
PrimExpr poly = (((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t);
259+
PrimExpr approx = make_const(x.dtype(), 1.0) - poly * exp(-abs_x * abs_x);
260+
return tvm::tir::Select(x < 0, -approx, approx);
261+
});
262+
245263
TVM_REGISTER_OP("tir.clz").set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
246264
const tir::CallNode* call = e.as<tir::CallNode>();
247265
ICHECK(call != nullptr);

tests/python/tir-base/test_tir_intrin.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import numpy as np
2424
import ctypes
2525
import math
26+
import scipy
2627

2728

2829
def test_nearbyint():
@@ -77,6 +78,7 @@ def test_unary_intrin():
7778
(tvm.tir.asinh, lambda x: np.arcsinh(x)),
7879
(tvm.tir.acosh, lambda x: np.arccosh(x)),
7980
(tvm.tir.atanh, lambda x: np.arctanh(x)),
81+
(tvm.tir.erf, lambda x: scipy.special.erf(x)),
8082
]
8183

8284
def run_test(tvm_intrin, np_func, atol=1e-5, rtol=1e-5):

0 commit comments

Comments
 (0)