Skip to content

Commit 77c3e43

Browse files
committed
Mimic Julia behaviour
1 parent b79f0e5 commit 77c3e43

File tree

2 files changed

+36
-9
lines changed

2 files changed

+36
-9
lines changed

src/device/intrinsics/math.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,17 @@ using Base.Math: throw_complex_domainerror
99
# - consider emitting LLVM intrinsics and lowering those in the back-end
1010

1111
### Constants
12-
@device_override Core.Float32(::typeof(π), ::RoundingMode) = reinterpret(Float32, 0x40490fdb) # 3.1415927f0 reinterpret(UInt32,Float32(reinterpret(Float64,0x400921FB60000000)))
13-
@device_override Core.Float16(::typeof(π), ::RoundingMode) = reinterpret(Float16, 0x4248) # Float16(3.14)
14-
@device_override Core.Float32(::typeof(ℯ), ::RoundingMode) = reinterpret(Float32, 0x402df854) # 2.7182817f0 reinterpret(UInt32,Float32(reinterpret(Float64,0x4005BF0A80000000)))
15-
@device_override Core.Float16(::typeof(ℯ), ::RoundingMode) = reinterpret(Float16, 0x4170) # Float16(2.719)
12+
# π
13+
@device_override Core.Float32(::typeof(π), ::RoundingMode) = reinterpret(Float32, 0x40490fdb) # 3.1415927f0 reinterpret(UInt32,Float32(reinterpret(Float64,0x400921FB60000000)))
14+
@device_override Core.Float32(::typeof(π), ::RoundingMode{:Down}) = reinterpret(Float32, 0x40490fda) # 3.1415925f0 prevfloat(reinterpret(UInt32,Float32(reinterpret(Float64,0x400921FB60000000))))
15+
@device_override Core.Float16(::typeof(π), ::RoundingMode{:Up}) = reinterpret(Float16, 0x4249) # Float16(3.143)
16+
@device_override Core.Float16(::typeof(π), ::RoundingMode) = reinterpret(Float16, 0x4248) # Float16(3.14)
17+
18+
#
19+
@device_override Core.Float32(::typeof(ℯ), ::RoundingMode{:Up}) = reinterpret(Float32, 0x402df855) # 2.718282f0 nextfloat(reinterpret(UInt32,Float32(reinterpret(Float64,0x4005BF0A80000000))))
20+
@device_override Core.Float32(::typeof(ℯ), ::RoundingMode) = reinterpret(Float32, 0x402df854) # 2.7182817f0 reinterpret(UInt32,Float32(reinterpret(Float64,0x4005BF0A80000000)))
21+
@device_override Core.Float16(::typeof(ℯ), ::RoundingMode) = reinterpret(Float16, 0x4170) # Float16(2.719)
22+
@device_override Core.Float16(::typeof(ℯ), ::RoundingMode{:Down}) = reinterpret(Float16, 0x416f) # Float16(2.717)
1623

1724
### Common Intrinsics
1825
@device_function clamp_fast(x::Float32, minval::Float32, maxval::Float32) = ccall("extern air.fast_clamp.f32", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, minval, maxval)

test/device/intrinsics/math.jl

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -312,11 +312,31 @@ end
312312
@test occursin(Regex("@air\\.sign\\.f$(8*sizeof(T))"), ir)
313313
end
314314

315-
let # "issue551"
316-
mtl_pi = only(Array(T.(MtlArray([π]), RoundNearest)))
317-
@test mtl_pi == T(π)
318-
mtl_ℯ = only(Array(T.(MtlArray([ℯ]), RoundNearest)))
319-
@test mtl_ℯ == T(ℯ)
315+
# Borrowed from the Julia "Irrationals compared with Rationals and Floats" testset
316+
@testset "Comparisons with $irr" for irr in (π, ℯ)
317+
function convert_test_32(res)
318+
res[1] = Float32(irr,RoundDown) < irr
319+
res[2] = Float32(irr,RoundUp) > irr
320+
res[3] = !(Float32(irr,RoundDown) > irr)
321+
res[4] = !(Float32(irr,RoundUp) < irr)
322+
return nothing
323+
end
324+
325+
res_32 = MtlArray(zeros(Bool,4))
326+
Metal.@sync @metal convert_test_32(res_32)
327+
@test all(Array(res_32))
328+
329+
function convert_test_16(res)
330+
res[1] = Float16(irr,RoundDown) < irr
331+
res[2] = Float16(irr,RoundUp) > irr
332+
res[3] = !(Float16(irr,RoundDown) > irr)
333+
res[4] = !(Float16(irr,RoundUp) < irr)
334+
return nothing
335+
end
336+
337+
res_16 = MtlArray(zeros(Bool,4))
338+
Metal.@sync @metal convert_test_16(res_16)
339+
@test all(Array(res_16))
320340
end
321341
end
322342
end

0 commit comments

Comments
 (0)