Skip to content

Commit 70329c2

Browse files
committed
jdyfag
1 parent 77c3e43 commit 70329c2

File tree

1 file changed

+29
-8
lines changed

1 file changed

+29
-8
lines changed

src/device/intrinsics/math.jl

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,45 @@
22

33
using Base: FastMath
44
using Base.Math: throw_complex_domainerror
5+
import Core: Float32, Float16
56

67
# TODO:
78
# - wrap all intrinsics from include/metal/metal_math
89
# - add support for vector types
910
# - consider emitting LLVM intrinsics and lowering those in the back-end
1011

12+
# Precompute irrationals for use on GPU
13+
# macro _const_convert(irr,T,r)
14+
# :($T($irr, $r))
15+
# end
16+
for T in (:Float32, :Float16), irr in (, :ℯ), r in (:RoundUp, :RoundDown)
17+
# newT = Symbol(:new,T)
18+
@eval begin
19+
# $newT(::typeof($irr), ::typeof($r)) = @_const_convert($irr, $T, $r)
20+
# @device_override $T(::typeof($irr), ::typeof($r)) = @_const_convert($irr, $T, $r)
21+
@device_override $T(::typeof($irr), ::typeof($r)) = Base.Rounding._convert_rounding($T, $irr, $r)
22+
end
23+
end
24+
25+
for T in (:Float32, :Float16), irr in (, :ℯ), r in (:RoundUp, :RoundDown)
26+
@eval begin
27+
@device_override $T(::typeof($irr), ::typeof($r)) = Base.Rounding._convert_rounding($T, $irr, $r)
28+
end
29+
end
30+
1131
### Constants
1232
# π
13-
@device_override Core.Float32(::typeof(π), ::RoundingMode) = reinterpret(Float32, 0x40490fdb) # 3.1415927f0 reinterpret(UInt32,Float32(reinterpret(Float64,0x400921FB60000000)))
14-
@device_override Core.Float32(::typeof(π), ::RoundingMode{:Down}) = reinterpret(Float32, 0x40490fda) # 3.1415925f0 prevfloat(reinterpret(UInt32,Float32(reinterpret(Float64,0x400921FB60000000))))
15-
@device_override Core.Float16(::typeof(π), ::RoundingMode{:Up}) = reinterpret(Float16, 0x4249) # Float16(3.143)
16-
@device_override Core.Float16(::typeof(π), ::RoundingMode) = reinterpret(Float16, 0x4248) # Float16(3.14)
33+
# @device_override Core.Float32(::typeof(π), ::RoundingMode) = reinterpret(Float32, 0x40490fdb) # 3.1415927f0 reinterpret(UInt32,Float32(reinterpret(Float64,0x400921FB60000000)))
34+
# @device_override Core.Float32(::typeof(π), ::RoundingMode{:Down}) = reinterpret(Float32, 0x40490fda) # 3.1415925f0 prevfloat(reinterpret(UInt32,Float32(reinterpret(Float64,0x400921FB60000000))))
35+
36+
# @device_override Core.Float16(::typeof(π), ::RoundingMode{:Up}) = reinterpret(Float16, 0x4249) # Float16(3.143)
37+
# @device_override Core.Float16(::typeof(π), ::RoundingMode) = reinterpret(Float16, 0x4248) # Float16(3.14)
1738

1839
#
19-
@device_override Core.Float32(::typeof(ℯ), ::RoundingMode{:Up}) = reinterpret(Float32, 0x402df855) # 2.718282f0 nextfloat(reinterpret(UInt32,Float32(reinterpret(Float64,0x4005BF0A80000000))))
20-
@device_override Core.Float32(::typeof(ℯ), ::RoundingMode) = reinterpret(Float32, 0x402df854) # 2.7182817f0 reinterpret(UInt32,Float32(reinterpret(Float64,0x4005BF0A80000000)))
21-
@device_override Core.Float16(::typeof(ℯ), ::RoundingMode) = reinterpret(Float16, 0x4170) # Float16(2.719)
22-
@device_override Core.Float16(::typeof(ℯ), ::RoundingMode{:Down}) = reinterpret(Float16, 0x416f) # Float16(2.717)
40+
# @device_override Core.Float32(::typeof(ℯ), ::RoundingMode{:Up}) = reinterpret(Float32, 0x402df855) # 2.718282f0 nextfloat(reinterpret(UInt32,Float32(reinterpret(Float64,0x4005BF0A80000000))))
41+
# @device_override Core.Float32(::typeof(ℯ), ::RoundingMode) = reinterpret(Float32, 0x402df854) # 2.7182817f0 reinterpret(UInt32,Float32(reinterpret(Float64,0x4005BF0A80000000)))
42+
# @device_override Core.Float16(::typeof(ℯ), ::RoundingMode) = reinterpret(Float16, 0x4170) # Float16(2.719)
43+
# @device_override Core.Float16(::typeof(ℯ), ::RoundingMode{:Down}) = reinterpret(Float16, 0x416f) # Float16(2.717)
2344

2445
### Common Intrinsics
2546
@device_function clamp_fast(x::Float32, minval::Float32, maxval::Float32) = ccall("extern air.fast_clamp.f32", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, minval, maxval)

0 commit comments

Comments
 (0)