diff --git a/src/utils.jl b/src/EAUtils.jl similarity index 95% rename from src/utils.jl rename to src/EAUtils.jl index 58b6d20..f78464e 100644 --- a/src/utils.jl +++ b/src/EAUtils.jl @@ -1,19 +1,23 @@ +const EA_AS_PKG = Symbol(@__MODULE__) !== :Base # develop EA as an external package + module EAUtils -import ..EscapeAnalysis: EscapeAnalysis +import ..EA_AS_PKG +if EA_AS_PKG + import ..EscapeAnalysis +else + import Core.Compiler.EscapeAnalysis: EscapeAnalysis + Base.getindex(estate::EscapeAnalysis.EscapeState, @nospecialize(x)) = + Core.Compiler.getindex(estate, x) +end const EA = EscapeAnalysis const CC = Core.Compiler -let - README = normpath(dirname(@__DIR__), "README.md") - include_dependency(README) - @doc read(README, String) EA -end - # entries # ------- -using InteractiveUtils +@static if EA_AS_PKG +import InteractiveUtils: gen_call_with_extracted_types_and_kwargs """ @code_escapes [options...] f(args...) @@ -24,8 +28,9 @@ As with `@code_typed` and its family, any of `code_escapes` keyword arguments ca as the optional arguments like `@code_escpase interp=myinterp myfunc(myargs...)`. """ macro code_escapes(ex0...) - return InteractiveUtils.gen_call_with_extracted_types_and_kwargs(__module__, :code_escapes, ex0) + return gen_call_with_extracted_types_and_kwargs(__module__, :code_escapes, ex0) end +end # @static if EA_AS_PKG """ code_escapes(f, argtypes=Tuple{}; [world], [interp]) -> result::EscapeResult @@ -402,9 +407,10 @@ end end # module EAUtils -using .EAUtils: - code_escapes, - @code_escapes -export - code_escapes, - @code_escapes +if EA_AS_PKG + using .EAUtils: code_escapes, @code_escapes + export code_escapes, @code_escapes +else + using .EAUtils: code_escapes + export code_escapes +end diff --git a/src/EscapeAnalysis.jl b/src/EscapeAnalysis.jl index 0bd1e82..d892678 100644 --- a/src/EscapeAnalysis.jl +++ b/src/EscapeAnalysis.jl @@ -37,7 +37,7 @@ import Core.Compiler: # Core.Compiler specific definitions if _TOP_MOD !== Core.Compiler include(@__MODULE__, "disjoint_set.jl") else - include(@__MODULE__, "compiler/EscapeAnalysis/disjoint_set.jl") + include(@__MODULE__, "compiler/ssair/EscapeAnalysis/disjoint_set.jl") end # XXX better to be IdSet{Int}? @@ -481,7 +481,7 @@ end """ cache_escapes!(linfo::MethodInstance, estate::EscapeState, _::IRCode) -Transforms escape information of `estate` for interprocedural propagation, +Transforms escape information of `estate` for interprocedural propagation, and caches it in a global cache that can then be looked up later when `linfo` callsite is seen again. """ @@ -981,10 +981,15 @@ function escape_call!(astate::AnalysisState, pc::Int, args::Vector{Any}) ft = argextype(first(args), ir, ir.sptypes, ir.argtypes) f = singleton_type(ft) if isa(f, Core.IntrinsicFunction) - # COMBAK we may break soundness and need to account for some aliasing here, e.g. `pointerref` - argtypes = Any[argextype(args[i], astate.ir) for i = 2:length(args)] + # XXX somehow `:call` expression can creep in here, ideally we should be able to do: + # argtypes = Any[argextype(args[i], astate.ir) for i = 2:length(args)] + argtypes = Any[] + for i = 2:length(args) + arg = args[i] + push!(argtypes, isexpr(arg, :call) ? Any : argextype(arg, ir)) + end intrinsic_nothrow(f, argtypes) || add_thrown_escapes!(astate, pc, args, 2) - return + return # TODO accounts for pointer operations end result = escape_builtin!(f, astate, pc, args) if result === missing @@ -1481,7 +1486,7 @@ end # if isdefined(Core, :arrayfreeze) && isdefined(Core, :arraythaw) && isdefi # NOTE define fancy package utilities when developing EA as an external package if _TOP_MOD !== Core.Compiler - include(@__MODULE__, "utils.jl") + include(@__MODULE__, "EAUtils.jl") end end # baremodule EscapeAnalysis diff --git a/test/EscapeAnalysis.jl b/test/EscapeAnalysis.jl new file mode 100644 index 0000000..95370dd --- /dev/null +++ b/test/EscapeAnalysis.jl @@ -0,0 +1,1869 @@ +@isdefined(EA_AS_PKG) || include(normpath(@__DIR__, "setup.jl")) + +@testset "basics" begin + let # simplest + result = code_escapes((Any,)) do a # return to caller + return nothing + end + @test has_return_escape(result.state[Argument(2)]) + end + let # return + result = code_escapes((Any,)) do a + return a + end + i = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(1)], 0) # self + @test !has_return_escape(result.state[Argument(1)], i) # self + @test has_return_escape(result.state[Argument(2)], 0) # a + @test has_return_escape(result.state[Argument(2)], i) # a + end + let # global store + result = code_escapes((Any,)) do a + global aa = a + nothing + end + @test has_all_escape(result.state[Argument(2)]) + end + let # global load + result = code_escapes() do + global gr + return gr + end + i = only(findall(has_return_escape, map(i->result.state[SSAValue(i)], 1:length(result.ir.stmts)))) + @test has_all_escape(result.state[SSAValue(i)]) + end + let # global store / load (https://github.com/aviatesk/EscapeAnalysis.jl/issues/56) + result = code_escapes((Any,)) do s + global v + v = s + return v + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + end + let # :gc_preserve_begin / :gc_preserve_end + result = code_escapes((String,)) do s + m = SafeRef(s) + GC.@preserve m begin + return nothing + end + end + i = findfirst(isT(SafeRef{String}), result.ir.stmts.type) # find allocation statement + @test !isnothing(i) + @test has_no_escape(result.state[SSAValue(i)]) + end + let # :isdefined + result = code_escapes((String, Bool, )) do a, b + if b + s = Ref(a) + end + return @isdefined(s) + end + i = findfirst(isT(Base.RefValue{String}), result.ir.stmts.type) # find allocation statement + @test !isnothing(i) + @test has_no_escape(result.state[SSAValue(i)]) + end + let # ϕ-node + result = code_escapes((Bool,Any,Any)) do cond, a, b + c = cond ? a : b # ϕ(a, b) + return c + end + @assert any(@nospecialize(x)->isa(x, Core.PhiNode), result.ir.stmts.inst) + i = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(3)], i) # a + @test has_return_escape(result.state[Argument(4)], i) # b + end + let # π-node + result = code_escapes((Any,)) do a + if isa(a, Regex) # a::π(Regex) + return a + end + return nothing + end + @assert any(@nospecialize(x)->isa(x, Core.PiNode), result.ir.stmts.inst) + @test any(findall(isreturn, result.ir.stmts.inst)) do i + has_return_escape(result.state[Argument(2)], i) + end + end + let # φᶜ-node / ϒ-node + result = code_escapes((Any,String)) do a, b + local x::String + try + x = a + catch err + x = b + end + return x + end + @assert any(@nospecialize(x)->isa(x, Core.PhiCNode), result.ir.stmts.inst) + @assert any(@nospecialize(x)->isa(x, Core.UpsilonNode), result.ir.stmts.inst) + i = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], i) + @test has_return_escape(result.state[Argument(3)], i) + end + let # branching + result = code_escapes((Any,Bool,)) do a, c + if c + return nothing # a doesn't escape in this branch + else + return a # a escapes to a caller + end + end + @test has_return_escape(result.state[Argument(2)]) + end + let # loop + result = code_escapes((Int,)) do n + c = SafeRef{Bool}(false) + while n > 0 + rand(Bool) && return c + end + nothing + end + i = only(findall(isT(SafeRef{Bool}), result.ir.stmts.type)) + @test has_return_escape(result.state[SSAValue(i)]) + end + let # try/catch + result = code_escapes((Any,)) do a + try + nothing + catch err + return a # return escape + end + end + @test has_return_escape(result.state[Argument(2)]) + end + let result = code_escapes((Any,)) do a + try + nothing + finally + return a # return escape + end + end + @test has_return_escape(result.state[Argument(2)]) + end +end + +let # simple allocation + result = code_escapes((Bool,)) do c + mm = SafeRef{Bool}(c) # just allocated, never escapes + return mm[] ? nothing : 1 + end + + i = only(findall(isT(SafeRef{Bool}), result.ir.stmts.type)) + @test has_no_escape(result.state[SSAValue(i)]) +end + +@testset "inter-procedural" begin + # FIXME currently we can't prove the effect-freeness of `getfield(RefValue{String}, :x)` + # because of this check https://github.com/JuliaLang/julia/blob/94b9d66b10e8e3ebdb268e4be5f7e1f43079ad4e/base/compiler/tfuncs.jl#L745 + # and thus it leads to the following two broken tests + let result = @eval Module() begin + @noinline broadcast_NoEscape(a) = (broadcast(identity, a); nothing) + $code_escapes() do + broadcast_NoEscape(Ref("Hi")) + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test_broken has_no_escape(result.state[SSAValue(i)]) + end + let result = @eval Module() begin + @noinline broadcast_NoEscape2(b) = broadcast(identity, b) + $code_escapes() do + broadcast_NoEscape2(Ref("Hi")) + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test_broken has_no_escape(result.state[SSAValue(i)]) + end + let result = @eval Module() begin + @noinline f_GlobalEscape_a(a) = (global globala = a) # obvious escape + $code_escapes() do + f_GlobalEscape_a(Ref("Hi")) + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)]) && has_thrown_escape(result.state[SSAValue(i)]) + end + # if we can't determine the matching method statically, we should be conservative + let result = @eval Module() $code_escapes((Ref{Any},)) do a + may_exist(a) + end + @test has_all_escape(result.state[Argument(2)]) + end + let result = @eval Module() begin + @noinline broadcast_NoEscape(a) = (broadcast(identity, a); nothing) + $code_escapes((Ref{Any},)) do a + Base.@invokelatest broadcast_NoEscape(a) + end + end + @test has_all_escape(result.state[Argument(2)]) + end + + # handling of simple union-split (just exploit the inliner's effort) + let T = Union{Int,Nothing} + result = @eval Module() begin + @noinline unionsplit_NoEscape_a(a) = string(nothing) + @noinline unionsplit_NoEscape_a(a::Int) = a + 10 + $code_escapes(($T,)) do x + s = $SafeRef{$T}(x) + unionsplit_NoEscape_a(s[]) + return nothing + end + end + inds = findall(isT(SafeRef{T}), result.ir.stmts.type) # find allocation statement + @assert !isempty(inds) + for i in inds + @test has_no_escape(result.state[SSAValue(i)]) + end + end + + # appropriate conversion of inter-procedural context + # https://github.com/aviatesk/EscapeAnalysis.jl/issues/7 + let M = Module() + @eval M @noinline f_NoEscape_a(a) = (println("prevent inlining"); Base.inferencebarrier(nothing)) + + result = @eval M $code_escapes() do + a = Ref("foo") # shouldn't be "return escape" + b = f_NoEscape_a(a) + nothing + end + i = only(findall(isT(Base.RefValue{String}), result.ir.stmts.type)) + @test has_no_escape(result.state[SSAValue(i)]) + + result = @eval M $code_escapes() do + a = Ref("foo") # still should be "return escape" + b = f_NoEscape_a(a) + return a + end + i = only(findall(isT(Base.RefValue{String}), result.ir.stmts.type)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) + end + + # should propagate escape information imposed on return value to the aliased call argument + let result = @eval Module() begin + @noinline f_ReturnEscape_a(a) = (println("prevent inlining"); a) + $code_escapes() do + obj = Ref("foo") # should be "return escape" + ret = f_ReturnEscape_a(obj) + return ret # alias of `obj` + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) + end + let result = @eval Module() begin + @noinline f_NoReturnEscape_a(a) = (println("prevent inlining"); identity("hi")) + $code_escapes() do + obj = Ref("foo") # better to not be "return escape" + ret = f_NoReturnEscape_a(obj) + return ret # must not alias to `obj` + end + end + i = only(findall(isT(Base.RefValue{String}), result.ir.stmts.type)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test !has_return_escape(result.state[SSAValue(i)], r) + end +end + +@testset "builtins" begin + let # throw + r = code_escapes((Any,)) do a + throw(a) + end + @test has_thrown_escape(r.state[Argument(2)]) + end + + let # implicit throws + r = code_escapes((Any,)) do a + getfield(a, :may_not_field) + end + @test has_thrown_escape(r.state[Argument(2)]) + + r = code_escapes((Any,)) do a + sizeof(a) + end + @test has_thrown_escape(r.state[Argument(2)]) + end + + let # :=== + result = code_escapes((Bool, String)) do cond, s + m = cond ? SafeRef(s) : nothing + c = m === nothing + return c + end + i = only(findall(isT(SafeRef{String}), result.ir.stmts.type)) + @test has_no_escape(result.state[SSAValue(i)]) + end + + let # sizeof + ary = [0,1,2] + result = @eval code_escapes() do + ary = $(QuoteNode(ary)) + sizeof(ary) + end + i = only(findall(isT(Core.Const(ary)), result.ir.stmts.type)) + @test has_no_escape(result.state[SSAValue(i)]) + end + + let # ifelse + result = code_escapes((Bool,)) do c + r = ifelse(c, Ref("yes"), Ref("no")) + return r + end + inds = findall(isT(Base.RefValue{String}), result.ir.stmts.type) + @assert !isempty(inds) + for i in inds + @test has_return_escape(result.state[SSAValue(i)]) + end + end + let # ifelse (with constant condition) + result = code_escapes() do + r = ifelse(true, Ref("yes"), Ref(nothing)) + return r + end + for i in 1:length(result.ir.stmts) + if isnew(result.ir.stmts.inst[i]) && isT(Base.RefValue{String})(result.ir.stmts.type[i]) + @test has_return_escape(result.state[SSAValue(i)]) + elseif isnew(result.ir.stmts.inst[i]) && isT(Base.RefValue{Nothing})(result.ir.stmts.type[i]) + @test has_no_escape(result.state[SSAValue(i)]) + end + end + end + + let # typeassert + result = code_escapes((Any,)) do x + y = x::String + return y + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + @test !has_all_escape(result.state[Argument(2)]) + end + + let # isdefined + result = code_escapes((Any,)) do x + isdefined(x, :foo) ? x : throw("undefined") + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + @test !has_all_escape(result.state[Argument(2)]) + + result = code_escapes((Module,)) do m + isdefined(m, 10) # throws + end + @test has_thrown_escape(result.state[Argument(2)]) + end +end + +@testset "flow-sensitivity" begin + # ReturnEscape + let result = code_escapes((Bool,)) do cond + r = Ref("foo") + if cond + return cond + end + return r + end + i = only(findall(isT(Base.RefValue{String}), result.ir.stmts.type)) + rts = findall(isreturn, result.ir.stmts.inst) + @assert length(rts) == 2 + @test count(rt->has_return_escape(result.state[SSAValue(i)], rt), rts) == 1 + end + let result = code_escapes((Bool,)) do cond + r = Ref("foo") + cnt = 0 + while rand(Bool) + cnt += 1 + rand(Bool) && return r + end + rand(Bool) && return r + return cnt + end + i = only(findall(isT(Base.RefValue{String}), result.ir.stmts.type)) + rts = findall(isreturn, result.ir.stmts.inst) # return statement + @assert length(rts) == 3 + @test count(rt->has_return_escape(result.state[SSAValue(i)], rt), rts) == 2 + end +end + +@testset "escape through exceptions" begin + M = @eval Module() begin + unsafeget(x) = isassigned(x) ? x[] : throw(x) + @noinline function rethrow_escape!() + try + rethrow() + catch err + Gx[] = err + end + end + @noinline function current_exceptions_escape!() + excs = Base.current_exceptions() + Gx[] = excs + end + const Gx = Ref{Any}() + @__MODULE__ + end + + let # simple: return escape + result = @eval M $code_escapes() do + r = Ref{String}() + local ret + try + s = unsafeget(r) + ret = sizeof(s) + catch err + ret = err + end + return ret + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)]) + end + + let # simple: global escape + result = @eval M $code_escapes() do + r = Ref{String}() + local ret # prevent DCE + try + s = unsafeget(r) + ret = sizeof(s) + catch err + global g = err + end + nothing + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + end + + let # account for possible escapes via nested throws + result = @eval M $code_escapes() do + r = Ref{String}() + try + try + unsafeget(r) + catch err1 + throw(err1) + end + catch err2 + Gx[] = err2 + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + end + let # account for possible escapes via `rethrow` + result = @eval M $code_escapes() do + r = Ref{String}() + try + try + unsafeget(r) + catch err1 + rethrow(err1) + end + catch err2 + Gx[] = err2 + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + end + let # account for possible escapes via `rethrow` + result = @eval M $code_escapes() do + try + r = Ref{String}() + unsafeget(r) + catch + rethrow_escape!() + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + end + let # account for possible escapes via `rethrow` + result = @eval M $code_escapes() do + local t + try + r = Ref{String}() + t = unsafeget(r) + catch err + t = typeof(err) + rethrow_escape!() + end + return t + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + end + let # account for possible escapes via `Base.current_exceptions` + result = @eval M $code_escapes() do + try + r = Ref{String}() + unsafeget(r) + catch + Gx[] = Base.current_exceptions() + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + end + let # account for possible escapes via `Base.current_exceptions` + result = @eval M $code_escapes() do + try + r = Ref{String}() + unsafeget(r) + catch + current_exceptions_escape!() + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + end + + let # contextual: escape information imposed on `err` shouldn't propagate to `r2`, but only to `r1` + result = @eval M $code_escapes() do + r1 = Ref{String}() + r2 = Ref{String}() + local ret + try + s1 = unsafeget(r1) + ret = sizeof(s1) + catch err + global g = err + end + s2 = unsafeget(r2) + return s2, r2 + end + is = findall(isnew, result.ir.stmts.inst) + @test length(is) == 2 + i1, i2 = is + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i1)]) + @test !has_all_escape(result.state[SSAValue(i2)]) + @test has_return_escape(result.state[SSAValue(i2)], r) + end + + # XXX test cases below are currently broken because of the technical reason described in `escape_exception!` + + let # limited propagation: exception is caught within a frame => doesn't escape to a caller + result = @eval M $code_escapes() do + r = Ref{String}() + local ret + try + s = unsafeget(r) + ret = sizeof(s) + catch + ret = nothing + end + return ret + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test_broken !has_return_escape(result.state[SSAValue(i)], r) + end + let # sequential: escape information imposed on `err1` and `err2 should propagate separately + result = @eval M $code_escapes() do + r1 = Ref{String}() + r2 = Ref{String}() + local ret + try + s1 = unsafeget(r1) + ret = sizeof(s1) + catch err1 + global g = err1 + end + try + s2 = unsafeget(r2) + ret = sizeof(s2) + catch err2 + ret = err2 + end + return ret + end + is = findall(isnew, result.ir.stmts.inst) + @test length(is) == 2 + i1, i2 = is + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i1)]) + @test has_return_escape(result.state[SSAValue(i2)], r) + @test_broken !has_all_escape(result.state[SSAValue(i2)]) + end + let # nested: escape information imposed on `inner` shouldn't propagate to `s` + result = @eval M $code_escapes() do + r = Ref{String}() + local ret + try + s = unsafeget(r) + try + ret = sizeof(s) + catch inner + return inner + end + catch outer + ret = nothing + end + return ret + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test_broken !has_return_escape(result.state[SSAValue(i)]) + end + let # merge: escape information imposed on `err1` and `err2 should be merged + result = @eval M $code_escapes() do + r = Ref{String}() + local ret + try + s = unsafeget(r) + ret = sizeof(s) + catch err1 + return err1 + end + try + s = unsafeget(r) + ret = sizeof(s) + catch err2 + return err2 + end + nothing + end + i = only(findall(isnew, result.ir.stmts.inst)) + rs = findall(isreturn, result.ir.stmts.inst) + @test_broken !has_all_escape(result.state[SSAValue(i)]) + for r in rs + @test has_return_escape(result.state[SSAValue(i)], r) + end + end + let # no exception handling: should keep propagating the escape + result = @eval M $code_escapes() do + r = Ref{String}() + local ret + try + s = unsafeget(r) + ret = sizeof(s) + finally + if !@isdefined(ret) + ret = 42 + end + end + return ret + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test_broken !has_return_escape(result.state[SSAValue(i)], r) + end +end + +@testset "field analysis / alias analysis" begin + # escaped allocations + # ------------------- + + # escaped object should escape its fields as well + let result = code_escapes((Any,)) do a + global g = SafeRef{Any}(a) + nothing + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + @test has_all_escape(result.state[Argument(2)]) + end + let result = code_escapes((Any,)) do a + global g = (a,) + nothing + end + i = only(findall(issubT(Tuple), result.ir.stmts.type)) + @test has_all_escape(result.state[SSAValue(i)]) + @test has_all_escape(result.state[Argument(2)]) + end + let result = code_escapes((Any,)) do a + o0 = SafeRef{Any}(a) + global g = SafeRef(o0) + nothing + end + i0 = only(findall(isT(SafeRef{Any}), result.ir.stmts.type)) + i1 = only(findall(isT(SafeRef{SafeRef{Any}}), result.ir.stmts.type)) + @test has_all_escape(result.state[SSAValue(i0)]) + @test has_all_escape(result.state[SSAValue(i1)]) + @test has_all_escape(result.state[Argument(2)]) + end + let result = code_escapes((Any,)) do a + t0 = (a,) + global g = (t0,) + nothing + end + inds = findall(issubT(Tuple), result.ir.stmts.type) + @assert length(inds) == 2 + for i in inds; @test has_all_escape(result.state[SSAValue(i)]); end + @test has_all_escape(result.state[Argument(2)]) + end + # global escape through `setfield!` + let result = code_escapes((Any,)) do a + r = SafeRef{Any}(:init) + global g = r + r[] = a + nothing + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + @test has_all_escape(result.state[Argument(2)]) + end + let result = code_escapes((Any,Any)) do a, b + r = SafeRef{Any}(a) + global g = r + r[] = b + nothing + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + @test has_all_escape(result.state[Argument(2)]) # a + @test has_all_escape(result.state[Argument(3)]) # b + end + let result = @eval EATModule() begin + const Rx = SafeRef{String}("Rx") + $code_escapes((String,)) do s + Rx[] = s + Core.sizeof(Rx[]) + end + end + @test has_all_escape(result.state[Argument(2)]) + end + let result = @eval EATModule() begin + const Rx = SafeRef{String}("Rx") + $code_escapes((String,)) do s + setfield!(Rx, :x, s) + Core.sizeof(Rx[]) + end + end + @test has_all_escape(result.state[Argument(2)]) + end + let M = EATModule() + @eval M module ___xxx___ + import ..SafeRef + const Rx = SafeRef("Rx") + end + result = @eval M begin + $code_escapes((String,)) do s + rx = getfield(___xxx___, :Rx) + rx[] = s + nothing + end + end + @test has_all_escape(result.state[Argument(2)]) + end + + # field escape + # ------------ + + # field escape should propagate to :new arguments + let result = code_escapes((String,)) do a + o = SafeRef(a) + f = o[] + return f + end + i = only(findall(isT(SafeRef{String}), result.ir.stmts.type)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + let result = code_escapes((String,)) do a + t = (a,) + f = t[1] + return f + end + i = only(findall(t->t<:Tuple, result.ir.stmts.type)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + let result = code_escapes((String, String)) do a, b + obj = SafeRefs(a, b) + fld1 = obj[1] + fld2 = obj[2] + return (fld1, fld2) + end + i = only(findall(isT(SafeRefs{String,String}), result.ir.stmts.type)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # a + @test has_return_escape(result.state[Argument(3)], r) # b + @test is_load_forwardable(result.state[SSAValue(i)]) + end + + # field escape should propagate to `setfield!` argument + let result = code_escapes((String,)) do a + o = SafeRef("foo") + o[] = a + f = o[] + return f + end + i = only(findall(isT(SafeRef{String}), result.ir.stmts.type)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + # propagate escape information imposed on return value of `setfield!` call + let result = code_escapes((String,)) do a + obj = SafeRef("foo") + return (obj[] = a) + end + i = only(findall(isT(SafeRef{String}), result.ir.stmts.type)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + + # nested allocations + let result = code_escapes((String,)) do a + o1 = SafeRef(a) + o2 = SafeRef(o1) + return o2[] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + for i in 1:length(result.ir.stmts) + if isnew(result.ir.stmts.inst[i]) && isT(SafeRef{String})(result.ir.stmts.type[i]) + @test has_return_escape(result.state[SSAValue(i)], r) + elseif isnew(result.ir.stmts.inst[i]) && isT(SafeRef{SafeRef{String}})(result.ir.stmts.type[i]) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + end + end + let result = code_escapes((String,)) do a + o1 = (a,) + o2 = (o1,) + return o2[1] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + for i in 1:length(result.ir.stmts) + if isnew(result.ir.stmts.inst[i]) && isT(Tuple{String})(result.ir.stmts.type[i]) + @test has_return_escape(result.state[SSAValue(i)], r) + elseif isnew(result.ir.stmts.inst[i]) && isT(Tuple{Tuple{String}})(result.ir.stmts.type[i]) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + end + end + let result = code_escapes((String,)) do a + o1 = SafeRef(a) + o2 = SafeRef(o1) + o1′ = o2[] + a′ = o1′[] + return a′ + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + for i in findall(isnew, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + end + let result = code_escapes() do + o1 = SafeRef("foo") + o2 = SafeRef(o1) + return o2 + end + r = only(findall(isreturn, result.ir.stmts.inst)) + for i in findall(isnew, result.ir.stmts.inst) + @test has_return_escape(result.state[SSAValue(i)], r) + end + end + let result = code_escapes() do + o1 = SafeRef("foo") + o2′ = SafeRef(nothing) + o2 = SafeRef{SafeRef}(o2′) + o2[] = o1 + return o2 + end + r = only(findall(isreturn, result.ir.stmts.inst)) + findall(1:length(result.ir.stmts)) do i + if isnew(result.ir.stmts[i][:inst]) + t = result.ir.stmts[i][:type] + return t === SafeRef{String} || # o1 + t === SafeRef{SafeRef} # o2 + end + return false + end |> x->foreach(x) do i + @test has_return_escape(result.state[SSAValue(i)], r) + end + end + let result = code_escapes((String,)) do x + broadcast(identity, Ref(x)) + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + + # ϕ-node allocations + let result = code_escapes((Bool,Any,Any)) do cond, x, y + if cond + ϕ = SafeRef{Any}(x) + else + ϕ = SafeRef{Any}(y) + end + return ϕ[] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(3)], r) # x + @test has_return_escape(result.state[Argument(4)], r) # y + i = only(findall(isϕ, result.ir.stmts.inst)) + @test is_load_forwardable(result.state[SSAValue(i)]) + for i in findall(isnew, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + end + let result = code_escapes((Bool,Any,Any)) do cond, x, y + if cond + ϕ2 = ϕ1 = SafeRef{Any}(x) + else + ϕ2 = ϕ1 = SafeRef{Any}(y) + end + return ϕ1[], ϕ2[] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(3)], r) # x + @test has_return_escape(result.state[Argument(4)], r) # y + for i in findall(isϕ, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + for i in findall(isnew, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + end + # when ϕ-node merges values with different types + let result = code_escapes((Bool,String,String,String)) do cond, x, y, z + local out + if cond + ϕ = SafeRef(x) + out = ϕ[] + else + ϕ = SafeRefs(z, y) + end + return @isdefined(out) ? out : throw(ϕ) + end + r = only(findall(isreturn, result.ir.stmts.inst)) + t = only(findall(iscall((result.ir, throw)), result.ir.stmts.inst)) + ϕ = only(findall(isT(Union{SafeRef{String},SafeRefs{String,String}}), result.ir.stmts.type)) + @test has_return_escape(result.state[Argument(3)], r) # x + @test !has_return_escape(result.state[Argument(4)], r) # y + @test has_return_escape(result.state[Argument(5)], r) # z + @test has_thrown_escape(result.state[SSAValue(ϕ)], t) + end + + # alias analysis + # -------------- + + # alias via getfield & Expr(:new) + let result = @eval EATModule() begin + const Rx = SafeRef("Rx") + $code_escapes((String,)) do s + r = SafeRef(Rx) + rx = r[] # rx aliased to Rx + rx[] = s + nothing + end + end + i = findfirst(isnew, result.ir.stmts.inst) + @test has_all_escape(result.state[Argument(2)]) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + # alias via getfield & setfield! + let result = @eval EATModule() begin + const Rx = SafeRef("Rx") + $code_escapes((SafeRef{String}, String,)) do _rx, s + r = SafeRef(_rx) + r[] = Rx + rx = r[] # rx aliased to Rx + rx[] = s + nothing + end + end + i = findfirst(isnew, result.ir.stmts.inst) + @test has_all_escape(result.state[Argument(3)]) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + # alias via typeassert + let result = code_escapes((Any,)) do a + global g + (g::SafeRef{Any})[] = a + nothing + end + @test has_all_escape(result.state[Argument(2)]) + end + # alias via ifelse + let result = @eval EATModule() begin + const Lx, Rx = SafeRef("Lx"), SafeRef("Rx") + $code_escapes((Bool,String,)) do c, a + r = ifelse(c, Lx, Rx) + r[] = a + nothing + end + end + @test has_all_escape(result.state[Argument(3)]) # a + end + # alias via ϕ-node + let result = code_escapes((Bool,String)) do cond, x + if cond + ϕ2 = ϕ1 = SafeRef("foo") + else + ϕ2 = ϕ1 = SafeRef("bar") + end + ϕ2[] = x + return ϕ1[] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(3)], r) # x + for i in findall(isϕ, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + for i in findall(isnew, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + end + let result = code_escapes((Bool,Bool,String)) do cond1, cond2, x + if cond1 + ϕ2 = ϕ1 = SafeRef("foo") + else + ϕ2 = ϕ1 = SafeRef("bar") + end + cond2 && (ϕ2[] = x) + return ϕ1[] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(4)], r) # x + for i in findall(isϕ, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + for i in findall(isnew, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + end + # alias via π-node + let result = code_escapes((String,)) do x + global g + l = g + if isa(l, SafeRef{String}) + l[] = x + end + nothing + end + @test has_all_escape(result.state[Argument(2)]) # x + end + + # dynamic semantics + # ----------------- + + # conservatively handle untyped objects + let result = @eval code_escapes((Any,Any,)) do T, x + obj = $(Expr(:new, :T, :x)) + end + t = only(findall(isnew, result.ir.stmts.inst)) + @test #=T=# has_thrown_escape(result.state[Argument(2)], t) # T + @test #=x=# has_thrown_escape(result.state[Argument(3)], t) # x + end + let result = @eval code_escapes((Any,Any,Any,Any)) do T, x, y, z + obj = $(Expr(:new, :T, :x, :y)) + return getfield(obj, :x) + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test #=x=# has_return_escape(result.state[Argument(3)], r) + @test #=y=# has_return_escape(result.state[Argument(4)], r) + @test #=z=# !has_return_escape(result.state[Argument(5)], r) + end + let result = @eval code_escapes((Any,Any,Any,Any)) do T, x, y, z + obj = $(Expr(:new, :T, :x)) + setfield!(obj, :x, y) + return getfield(obj, :x) + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test #=x=# has_return_escape(result.state[Argument(3)], r) + @test #=y=# has_return_escape(result.state[Argument(4)], r) + @test #=z=# !has_return_escape(result.state[Argument(5)], r) + end + + # conservatively handle unknown field: + # all fields should be escaped, but the allocation itself doesn't need to be escaped + let result = code_escapes((String, Symbol)) do a, fld + obj = SafeRef(a) + return getfield(obj, fld) + end + i = only(findall(isT(SafeRef{String}), result.ir.stmts.type)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # a + @test is_load_forwardable(result.state[SSAValue(i)]) # obj + end + let result = code_escapes((String, String, Symbol)) do a, b, fld + obj = SafeRefs(a, b) + return getfield(obj, fld) # should escape both `a` and `b` + end + i = only(findall(isT(SafeRefs{String,String}), result.ir.stmts.type)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # a + @test has_return_escape(result.state[Argument(3)], r) # b + @test is_load_forwardable(result.state[SSAValue(i)]) # obj + end + let result = code_escapes((String, String, Int)) do a, b, idx + obj = SafeRefs(a, b) + return obj[idx] # should escape both `a` and `b` + end + i = only(findall(isT(SafeRefs{String,String}), result.ir.stmts.type)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # a + @test has_return_escape(result.state[Argument(3)], r) # b + @test is_load_forwardable(result.state[SSAValue(i)]) # obj + end + let result = code_escapes((String, String, Symbol)) do a, b, fld + obj = SafeRefs("a", "b") + setfield!(obj, fld, a) + return obj[2] # should escape `a` + end + i = only(findall(isT(SafeRefs{String,String}), result.ir.stmts.type)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # a + @test !has_return_escape(result.state[Argument(3)], r) # b + @test is_load_forwardable(result.state[SSAValue(i)]) # obj + end + let result = code_escapes((String, Symbol)) do a, fld + obj = SafeRefs("a", "b") + setfield!(obj, fld, a) + return obj[1] # this should escape `a` + end + i = only(findall(isT(SafeRefs{String,String}), result.ir.stmts.type)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # a + @test is_load_forwardable(result.state[SSAValue(i)]) # obj + end + let result = code_escapes((String, String, Int)) do a, b, idx + obj = SafeRefs("a", "b") + obj[idx] = a + return obj[2] # should escape `a` + end + i = only(findall(isT(SafeRefs{String,String}), result.ir.stmts.type)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # a + @test !has_return_escape(result.state[Argument(3)], r) # b + @test is_load_forwardable(result.state[SSAValue(i)]) # obj + end + + # interprocedural + # --------------- + + let result = @eval EATModule() begin + @noinline getx(obj) = obj[] + $code_escapes((String,)) do a + obj = SafeRef(a) + fld = getx(obj) + return fld + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + # NOTE we can't scalar replace `obj`, but still we may want to stack allocate it + @test_broken is_load_forwardable(result.state[SSAValue(i)]) + end + + # TODO interprocedural field analysis + let result = code_escapes((SafeRef{String},)) do s + s[] = "bar" + global g = s[] + nothing + end + @test_broken !has_all_escape(result.state[Argument(2)]) + end + + # TODO flow-sensitivity? + # ---------------------- + + let result = code_escapes((Any,Any)) do a, b + r = SafeRef{Any}(a) + r[] = b + return r[] + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test_broken !has_return_escape(result.state[Argument(2)], r) # a + @test has_return_escape(result.state[Argument(3)], r) # b + @test is_load_forwardable(result.state[SSAValue(i)]) + end + let result = code_escapes((Any,Any)) do a, b + r = SafeRef{Any}(:init) + r[] = a + r[] = b + return r[] + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test_broken !has_return_escape(result.state[Argument(2)], r) # a + @test has_return_escape(result.state[Argument(3)], r) # b + @test is_load_forwardable(result.state[SSAValue(i)]) + end + let result = code_escapes((Any,Any,Bool)) do a, b, cond + r = SafeRef{Any}(:init) + if cond + r[] = a + return r[] + else + r[] = b + return nothing + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test is_load_forwardable(result.state[SSAValue(i)]) + r = only(findall(result.ir.stmts.inst) do @nospecialize x + isreturn(x) && isa(x.val, Core.SSAValue) + end) + @test has_return_escape(result.state[Argument(2)], r) # a + @test_broken !has_return_escape(result.state[Argument(3)], r) # b + end + + # handle conflicting field information correctly + let result = code_escapes((Bool,String,String,)) do cnd, baz, qux + if cnd + o = SafeRef("foo") + else + o = SafeRefs("bar", baz) + r = getfield(o, 2) + end + if cnd + o = o::SafeRef + setfield!(o, 1, qux) + r = getfield(o, 1) + end + r + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(3)], r) # baz + @test has_return_escape(result.state[Argument(4)], r) # qux + for new in findall(isnew, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(new)]) + end + end + let result = code_escapes((Bool,String,String,)) do cnd, baz, qux + if cnd + o = SafeRefs("foo", "bar") + r = setfield!(o, 2, baz) + else + o = SafeRef(qux) + end + if !cnd + o = o::SafeRef + r = getfield(o, 1) + end + r + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(3)], r) # baz + @test has_return_escape(result.state[Argument(4)], r) # qux + end +end + +# demonstrate the power of our field / alias analysis with a realistic end to end example +abstract type AbstractPoint{T} end +mutable struct MPoint{T} <: AbstractPoint{T} + x::T + y::T +end +add(a::P, b::P) where P<:AbstractPoint = P(a.x + b.x, a.y + b.y) +function compute(T, ax, ay, bx, by) + a = T(ax, ay) + b = T(bx, by) + for i in 0:(100000000-1) + a = add(add(a, b), b) + end + a.x, a.y +end +function compute(a, b) + for i in 0:(100000000-1) + a = add(add(a, b), b) + end + a.x, a.y +end +function compute!(a, b) + for i in 0:(100000000-1) + a′ = add(add(a, b), b) + a.x = a′.x + a.y = a′.y + end +end +let result = @code_escapes compute(MPoint, 1+.5im, 2+.5im, 2+.25im, 4+.75im) + for i in findall(isnew, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end +end +let result = @code_escapes compute(MPoint(1+.5im, 2+.5im), MPoint(2+.25im, 4+.75im)) + for i in findall(isnew, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end +end +let result = @code_escapes compute!(MPoint(1+.5im, 2+.5im), MPoint(2+.25im, 4+.75im)) + for i in findall(isnew, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end +end + +@testset "array primitives" begin + # arrayref + let result = code_escapes((Vector{String},Int)) do xs, i + s = Base.arrayref(true, xs, i) + return s + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # xs + @test_broken !has_thrown_escape(result.state[Argument(2)]) # xs + @test !has_return_escape(result.state[Argument(3)], r) # i + end + let result = code_escapes((Vector{String},Bool)) do xs, i + c = Base.arrayref(true, xs, i) # TypeError will happen here + return c + end + t = only(findall(iscall((result.ir, Base.arrayref)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs + end + let result = code_escapes((String,Int)) do xs, i + c = Base.arrayref(true, xs, i) # TypeError will happen here + return c + end + t = only(findall(iscall((result.ir, Base.arrayref)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs + end + let result = code_escapes((AbstractVector{String},Int)) do xs, i + c = Base.arrayref(true, xs, i) # TypeError may happen here + return c + end + t = only(findall(iscall((result.ir, Base.arrayref)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs + end + let result = code_escapes((Vector{String},Any)) do xs, i + c = Base.arrayref(true, xs, i) # TypeError may happen here + return c + end + t = only(findall(iscall((result.ir, Base.arrayref)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs + end + + # arrayset + let result = code_escapes((Vector{String},String,Int,)) do xs, x, i + Base.arrayset(true, xs, x, i) + return xs + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # xs + @test has_return_escape(result.state[Argument(3)], r) # x + end + let result = code_escapes((String,String,String,)) do s, t, u + xs = Vector{String}(undef, 3) + Base.arrayset(true, xs, s, 1) + Base.arrayset(true, xs, t, 2) + Base.arrayset(true, xs, u, 3) + return xs + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) + for i in 2:result.state.nargs + @test has_return_escape(result.state[Argument(i)], r) + end + end + let result = code_escapes((Vector{String},String,Bool,)) do xs, x, i + Base.arrayset(true, xs, x, i) # TypeError will happen here + return xs + end + t = only(findall(iscall((result.ir, Base.arrayset)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs + @test has_thrown_escape(result.state[Argument(3)], t) # x + end + let result = code_escapes((String,String,Int,)) do xs, x, i + Base.arrayset(true, xs, x, i) # TypeError will happen here + return xs + end + t = only(findall(iscall((result.ir, Base.arrayset)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs::String + @test has_thrown_escape(result.state[Argument(3)], t) # x::String + end + let result = code_escapes((AbstractVector{String},String,Int,)) do xs, x, i + Base.arrayset(true, xs, x, i) # TypeError may happen here + return xs + end + t = only(findall(iscall((result.ir, Base.arrayset)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs + @test has_thrown_escape(result.state[Argument(3)], t) # x + end + let result = code_escapes((Vector{String},AbstractString,Int,)) do xs, x, i + Base.arrayset(true, xs, x, i) # TypeError may happen here + return xs + end + t = only(findall(iscall((result.ir, Base.arrayset)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs + @test has_thrown_escape(result.state[Argument(3)], t) # x + end + let result = code_escapes((Vector{Any},AbstractString,Int,)) do xs, x, i + Base.arrayset(true, xs, x, i) # TypeError may happen here + return xs + end + t = only(findall(iscall((result.ir, Base.arrayset)), result.ir.stmts.inst)) + @test !has_thrown_escape(result.state[Argument(2)], t) # xs + @test !has_thrown_escape(result.state[Argument(3)], t) # x + end + + # arrayref and arrayset + let result = code_escapes() do + a = Vector{Vector{Any}}(undef, 1) + b = Any[] + a[1] = b + return a[1] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + ai = only(findall(result.ir.stmts.inst) do @nospecialize x + isarrayalloc(x) && x.args[2] === Vector{Vector{Any}} + end) + bi = only(findall(result.ir.stmts.inst) do @nospecialize x + isarrayalloc(x) && x.args[2] === Vector{Any} + end) + @test !has_return_escape(result.state[SSAValue(ai)], r) + @test has_return_escape(result.state[SSAValue(bi)], r) + end + let result = code_escapes() do + a = Vector{Vector{Any}}(undef, 1) + b = Any[] + a[1] = b + return a + end + r = only(findall(isreturn, result.ir.stmts.inst)) + ai = only(findall(result.ir.stmts.inst) do @nospecialize x + isarrayalloc(x) && x.args[2] === Vector{Vector{Any}} + end) + bi = only(findall(result.ir.stmts.inst) do @nospecialize x + isarrayalloc(x) && x.args[2] === Vector{Any} + end) + @test has_return_escape(result.state[SSAValue(ai)], r) + @test has_return_escape(result.state[SSAValue(bi)], r) + end + let result = code_escapes((Vector{Any},String,Int,Int)) do xs, s, i, j + x = SafeRef(s) + xs[i] = x + xs[j] # potential error + end + i = only(findall(isnew, result.ir.stmts.inst)) + t = only(findall(iscall((result.ir, Base.arrayref)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(3)], t) # s + @test has_thrown_escape(result.state[SSAValue(i)], t) # x + end + + # arraysize + let result = code_escapes((Vector{Any},)) do xs + Core.arraysize(xs, 1) + end + t = only(findall(iscall((result.ir, Core.arraysize)), result.ir.stmts.inst)) + @test !has_thrown_escape(result.state[Argument(2)], t) + end + let result = code_escapes((Vector{Any},Int,)) do xs, dim + Core.arraysize(xs, dim) + end + t = only(findall(iscall((result.ir, Core.arraysize)), result.ir.stmts.inst)) + @test !has_thrown_escape(result.state[Argument(2)], t) + end + let result = code_escapes((Any,)) do xs + Core.arraysize(xs, 1) + end + t = only(findall(iscall((result.ir, Core.arraysize)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) + end + + # arraylen + let result = code_escapes((Vector{Any},)) do xs + Base.arraylen(xs) + end + t = only(findall(iscall((result.ir, Base.arraylen)), result.ir.stmts.inst)) + @test !has_thrown_escape(result.state[Argument(2)], t) # xs + end + let result = code_escapes((String,)) do xs + Base.arraylen(xs) + end + t = only(findall(iscall((result.ir, Base.arraylen)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs + end + let result = code_escapes((Vector{Any},)) do xs + Base.arraylen(xs, 1) + end + t = only(findall(iscall((result.ir, Base.arraylen)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs + end + + # array resizing + # without BoundsErrors + let result = code_escapes((Vector{Any},String)) do xs, x + @ccall jl_array_grow_beg(xs::Any, 2::UInt)::Cvoid + xs[1] = x + xs + end + t = only(findall(isarrayresize, result.ir.stmts.inst)) + @test !has_thrown_escape(result.state[Argument(2)], t) # xs + @test !has_thrown_escape(result.state[Argument(3)], t) # x + end + let result = code_escapes((Vector{Any},String)) do xs, x + @ccall jl_array_grow_end(xs::Any, 2::UInt)::Cvoid + xs[1] = x + xs + end + t = only(findall(isarrayresize, result.ir.stmts.inst)) + @test !has_thrown_escape(result.state[Argument(2)], t) # xs + @test !has_thrown_escape(result.state[Argument(3)], t) # x + end + # with possible BoundsErrors + let result = code_escapes((String,)) do x + xs = Any[1,2,3] + xs[3] = x + @ccall jl_array_del_beg(xs::Any, 2::UInt)::Cvoid # can potentially throw + xs + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + t = only(findall(isarrayresize, result.ir.stmts.inst)) + @test !has_thrown_escape(result.state[SSAValue(i)], t) # xs + @test has_thrown_escape(result.state[Argument(2)], t) # x + end + let result = code_escapes((String,)) do x + xs = Any[1,2,3] + xs[1] = x + @ccall jl_array_del_end(xs::Any, 2::UInt)::Cvoid # can potentially throw + xs + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + t = only(findall(isarrayresize, result.ir.stmts.inst)) + @test !has_thrown_escape(result.state[SSAValue(i)], t) # xs + @test has_thrown_escape(result.state[Argument(2)], t) # x + end + let result = code_escapes((String,)) do x + xs = Any[x] + @ccall jl_array_grow_at(xs::Any, 1::UInt, 2::UInt)::Cvoid # can potentially throw + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + t = only(findall(isarrayresize, result.ir.stmts.inst)) + @test !has_thrown_escape(result.state[SSAValue(i)], t) # xs + @test has_thrown_escape(result.state[Argument(2)], t) # x + end + let result = code_escapes((String,)) do x + xs = Any[x] + @ccall jl_array_del_at(xs::Any, 1::UInt, 2::UInt)::Cvoid # can potentially throw + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + t = only(findall(isarrayresize, result.ir.stmts.inst)) + @test !has_thrown_escape(result.state[SSAValue(i)], t) # xs + @test has_thrown_escape(result.state[Argument(2)], t) # x + end + + # array copy + let result = code_escapes((Vector{Any},)) do xs + return copy(xs) + end + i = only(findall(isarraycopy, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) + @test_broken !has_return_escape(result.state[Argument(2)], r) + end + let result = code_escapes((String,)) do s + xs = String[s] + xs′ = copy(xs) + return xs′[1] + end + i1 = only(findall(isarrayalloc, result.ir.stmts.inst)) + i2 = only(findall(isarraycopy, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test !has_return_escape(result.state[SSAValue(i1)]) + @test !has_return_escape(result.state[SSAValue(i2)]) + @test has_return_escape(result.state[Argument(2)], r) # s + end + let result = code_escapes((Vector{Any},)) do xs + xs′ = copy(xs) + return xs′[1] # may potentially throw BoundsError, should escape `xs` conservatively (i.e. escape its elements) + end + i = only(findall(isarraycopy, result.ir.stmts.inst)) + ref = only(findall(iscall((result.ir, Base.arrayref)), result.ir.stmts.inst)) + ret = only(findall(isreturn, result.ir.stmts.inst)) + @test_broken !has_thrown_escape(result.state[SSAValue(i)], ref) + @test_broken !has_return_escape(result.state[SSAValue(i)], ret) + @test has_thrown_escape(result.state[Argument(2)], ref) + @test has_return_escape(result.state[Argument(2)], ret) + end + let result = code_escapes((String,)) do s + xs = Vector{String}(undef, 1) + xs[1] = s + xs′ = copy(xs) + length(xs′) > 2 && throw(xs′) + return xs′ + end + i1 = only(findall(isarrayalloc, result.ir.stmts.inst)) + i2 = only(findall(isarraycopy, result.ir.stmts.inst)) + t = only(findall(iscall((result.ir, throw)), result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test_broken !has_thrown_escape(result.state[SSAValue(i1)], t) + @test_broken !has_return_escape(result.state[SSAValue(i1)], r) + @test has_thrown_escape(result.state[SSAValue(i2)], t) + @test has_return_escape(result.state[SSAValue(i2)], r) + @test has_thrown_escape(result.state[Argument(2)], t) + @test has_return_escape(result.state[Argument(2)], r) + end + + # isassigned + let result = code_escapes((Vector{Any},Int)) do xs, i + return isassigned(xs, i) + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test !has_return_escape(result.state[Argument(2)], r) + @test !has_thrown_escape(result.state[Argument(2)]) + end +end + +# demonstrate array primitive support with a realistic end to end example +let result = code_escapes((Int,String,)) do n,s + xs = String[] + for i in 1:n + push!(xs, s) + end + xs + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) + @test !has_thrown_escape(result.state[SSAValue(i)]) + @test has_return_escape(result.state[Argument(3)], r) # s + @test !has_thrown_escape(result.state[Argument(3)]) # s +end +let result = code_escapes((Int,String,)) do n,s + xs = String[] + for i in 1:n + pushfirst!(xs, s) + end + xs + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) + @test !has_thrown_escape(result.state[SSAValue(i)]) + @test has_return_escape(result.state[Argument(3)], r) # s + @test !has_thrown_escape(result.state[Argument(3)]) # s +end +let result = code_escapes((String,String,String)) do s, t, u + xs = String[] + resize!(xs, 3) + xs[1] = s + xs[1] = t + xs[1] = u + xs + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) + @test !has_thrown_escape(result.state[SSAValue(i)]) + @test has_return_escape(result.state[Argument(2)], r) # s + @test has_return_escape(result.state[Argument(3)], r) # t + @test has_return_escape(result.state[Argument(4)], r) # u +end + +@static if isdefined(Core, :ImmutableArray) + +import Core: ImmutableArray, arrayfreeze, mutating_arrayfreeze, arraythaw + +@testset "ImmutableArray" begin + # arrayfreeze + let result = code_escapes((Vector{Any},)) do xs + arrayfreeze(xs) + end + @test !has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes((Vector,)) do xs + arrayfreeze(xs) + end + @test !has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes((Any,)) do xs + arrayfreeze(xs) + end + @test has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes((ImmutableArray{Any,1},)) do xs + arrayfreeze(xs) + end + @test has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes() do + xs = Any[] + arrayfreeze(xs) + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + @test has_no_escape(result.state[SSAValue(1)]) + end + + # mutating_arrayfreeze + let result = code_escapes((Vector{Any},)) do xs + mutating_arrayfreeze(xs) + end + @test !has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes((Vector,)) do xs + mutating_arrayfreeze(xs) + end + @test !has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes((Any,)) do xs + mutating_arrayfreeze(xs) + end + @test has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes((ImmutableArray{Any,1},)) do xs + mutating_arrayfreeze(xs) + end + @test has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes() do + xs = Any[] + mutating_arrayfreeze(xs) + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + @test has_no_escape(result.state[SSAValue(1)]) + end + + # arraythaw + let result = code_escapes((ImmutableArray{Any,1},)) do xs + arraythaw(xs) + end + @test !has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes((ImmutableArray,)) do xs + arraythaw(xs) + end + @test !has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes((Any,)) do xs + arraythaw(xs) + end + @test has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes((Vector{Any},)) do xs + arraythaw(xs) + end + @test has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes() do + xs = ImmutableArray(Any[]) + arraythaw(xs) + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + @test has_no_escape(result.state[SSAValue(1)]) + end +end + +# demonstrate some arrayfreeze optimizations +# has_no_escape(ary) means ary is eligible for arrayfreeze to mutating_arrayfreeze optimization +let result = code_escapes((Int,)) do n + xs = collect(1:n) + ImmutableArray(xs) + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + @test has_no_escape(result.state[SSAValue(i)]) +end +let result = code_escapes((Vector{Float64},)) do xs + ys = sin.(xs) + ImmutableArray(ys) + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + @test_broken has_no_escape(result.state[SSAValue(i)]) +end +let result = code_escapes((Vector{Pair{Int,String}},)) do xs + n = maximum(first, xs) + ys = Vector{String}(undef, n) + for (i, s) in xs + ys[i] = s + end + ImmutableArray(xs) + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + @test has_no_escape(result.state[SSAValue(i)]) +end + +end # @static if isdefined(Core, :ImmutableArray) + +# demonstrate a simple type level analysis can sometimes improve the analysis accuracy +# by compensating the lack of yet unimplemented analyses +@testset "special-casing bitstype" begin + let result = code_escapes((Nothing,)) do a + global bb = a + end + @test !(has_all_escape(result.state[Argument(2)])) + end + + let result = code_escapes((Int,)) do a + o = SafeRef(a) + f = o[] + return f + end + i = only(findall(isT(SafeRef{Int}), result.ir.stmts.type)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test !has_return_escape(result.state[SSAValue(i)], r) + end + + # an escaped tuple stmt will not propagate to its Int argument (since `Int` is of bitstype) + let result = code_escapes((Int,Any,)) do a, b + t = tuple(a, b) + return t + end + i = only(findall(issubT(Tuple), result.ir.stmts.type)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test !has_return_escape(result.state[Argument(2)], r) + @test has_return_escape(result.state[Argument(3)], r) + end +end + +@testset "finalizer elision" begin + @test can_elide_finalizer(EscapeAnalysis.NoEscape(), 1) + @test !can_elide_finalizer(EscapeAnalysis.ReturnEscape(1), 1) + @test can_elide_finalizer(EscapeAnalysis.ReturnEscape(1), 2) + @test !can_elide_finalizer(EscapeAnalysis.ArgumentReturnEscape(), 1) + @test can_elide_finalizer(EscapeAnalysis.ThrownEscape(1), 1) +end + +# # TODO implement a finalizer elision pass +# mutable struct WithFinalizer +# v +# function WithFinalizer(v) +# x = new(v) +# f(t) = @async println("Finalizing $t.") +# return finalizer(x, x) +# end +# end +# make_m(v = 10) = MyMutable(v) +# function simple(cond) +# m = make_m() +# if cond +# # println(m.v) +# return nothing # <= insert `finalize` call here +# end +# return m +# end + +@static EA_AS_PKG && @testset "code quality" begin + using JET + + # assert that our main routine are free from (unnecessary) runtime dispatches + + function function_filter(@nospecialize(ft)) + ft === typeof(Core.Compiler.widenconst) && return false # `widenconst` is very untyped, ignore + ft === typeof(EscapeAnalysis.escape_builtin!) && return false # `escape_builtin!` is very untyped, ignore + return true + end + target_modules = (EscapeAnalysis,) + test_opt(only(methods(EscapeAnalysis.analyze_escapes)).sig; + function_filter, + target_modules, + # skip_nonconcrete_calls=false, + ) + for m in methods(EscapeAnalysis.escape_builtin!) + Base._methods_by_ftype(m.sig, 1, Base.get_world_counter()) === false && continue + test_opt(m.sig; + function_filter, + target_modules, + # skip_nonconcrete_calls=false, + ) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index c2510c7..5d9e54b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,1871 +1,5 @@ -include("setup.jl") - +const EA_AS_PKG = true +include(normpath(@__DIR__, "setup.jl")) @testset "EscapeAnalysis" begin - -@testset "basics" begin - let # simplest - result = code_escapes((Any,)) do a # return to caller - return nothing - end - @test has_return_escape(result.state[Argument(2)]) - end - let # return - result = code_escapes((Any,)) do a - return a - end - i = only(findall(isreturn, result.ir.stmts.inst)) - @test has_return_escape(result.state[Argument(1)], 0) # self - @test !has_return_escape(result.state[Argument(1)], i) # self - @test has_return_escape(result.state[Argument(2)], 0) # a - @test has_return_escape(result.state[Argument(2)], i) # a - end - let # global store - result = code_escapes((Any,)) do a - global aa = a - nothing - end - @test has_all_escape(result.state[Argument(2)]) - end - let # global load - result = code_escapes() do - global gr - return gr - end - i = only(findall(has_return_escape, map(i->result.state[SSAValue(i)], 1:length(result.ir.stmts)))) - @test has_all_escape(result.state[SSAValue(i)]) - end - let # global store / load (https://github.com/aviatesk/EscapeAnalysis.jl/issues/56) - result = code_escapes((Any,)) do s - global v - v = s - return v - end - r = only(findall(isreturn, result.ir.stmts.inst)) - @test has_return_escape(result.state[Argument(2)], r) - end - let # :gc_preserve_begin / :gc_preserve_end - result = code_escapes((String,)) do s - m = SafeRef(s) - GC.@preserve m begin - return nothing - end - end - i = findfirst(isT(SafeRef{String}), result.ir.stmts.type) # find allocation statement - @test !isnothing(i) - @test has_no_escape(result.state[SSAValue(i)]) - end - let # :isdefined - result = code_escapes((String, Bool, )) do a, b - if b - s = Ref(a) - end - return @isdefined(s) - end - i = findfirst(isT(Base.RefValue{String}), result.ir.stmts.type) # find allocation statement - @test !isnothing(i) - @test has_no_escape(result.state[SSAValue(i)]) - end - let # ϕ-node - result = code_escapes((Bool,Any,Any)) do cond, a, b - c = cond ? a : b # ϕ(a, b) - return c - end - @assert any(@nospecialize(x)->isa(x, Core.PhiNode), result.ir.stmts.inst) - i = only(findall(isreturn, result.ir.stmts.inst)) - @test has_return_escape(result.state[Argument(3)], i) # a - @test has_return_escape(result.state[Argument(4)], i) # b - end - let # π-node - result = code_escapes((Any,)) do a - if isa(a, Regex) # a::π(Regex) - return a - end - return nothing - end - @assert any(@nospecialize(x)->isa(x, Core.PiNode), result.ir.stmts.inst) - @test any(findall(isreturn, result.ir.stmts.inst)) do i - has_return_escape(result.state[Argument(2)], i) - end - end - let # φᶜ-node / ϒ-node - result = code_escapes((Any,String)) do a, b - local x::String - try - x = a - catch err - x = b - end - return x - end - @assert any(@nospecialize(x)->isa(x, Core.PhiCNode), result.ir.stmts.inst) - @assert any(@nospecialize(x)->isa(x, Core.UpsilonNode), result.ir.stmts.inst) - i = only(findall(isreturn, result.ir.stmts.inst)) - @test has_return_escape(result.state[Argument(2)], i) - @test has_return_escape(result.state[Argument(3)], i) - end - let # branching - result = code_escapes((Any,Bool,)) do a, c - if c - return nothing # a doesn't escape in this branch - else - return a # a escapes to a caller - end - end - @test has_return_escape(result.state[Argument(2)]) - end - let # loop - result = code_escapes((Int,)) do n - c = SafeRef{Bool}(false) - while n > 0 - rand(Bool) && return c - end - nothing - end - i = only(findall(isT(SafeRef{Bool}), result.ir.stmts.type)) - @test has_return_escape(result.state[SSAValue(i)]) - end - let # try/catch - result = code_escapes((Any,)) do a - try - nothing - catch err - return a # return escape - end - end - @test has_return_escape(result.state[Argument(2)]) - end - let result = code_escapes((Any,)) do a - try - nothing - finally - return a # return escape - end - end - @test has_return_escape(result.state[Argument(2)]) - end + include(normpath(@__DIR__, "EscapeAnalysis.jl")) end - -let # simple allocation - result = code_escapes((Bool,)) do c - mm = SafeRef{Bool}(c) # just allocated, never escapes - return mm[] ? nothing : 1 - end - - i = only(findall(isT(SafeRef{Bool}), result.ir.stmts.type)) - @test has_no_escape(result.state[SSAValue(i)]) -end - -@testset "inter-procedural" begin - # FIXME currently we can't prove the effect-freeness of `getfield(RefValue{String}, :x)` - # because of this check https://github.com/JuliaLang/julia/blob/94b9d66b10e8e3ebdb268e4be5f7e1f43079ad4e/base/compiler/tfuncs.jl#L745 - # and thus it leads to the following two broken tests - let result = @eval Module() begin - @noinline broadcast_NoEscape(a) = (broadcast(identity, a); nothing) - $code_escapes() do - broadcast_NoEscape(Ref("Hi")) - end - end - i = only(findall(isnew, result.ir.stmts.inst)) - @test_broken has_no_escape(result.state[SSAValue(i)]) - end - let result = @eval Module() begin - @noinline broadcast_NoEscape2(b) = broadcast(identity, b) - $code_escapes() do - broadcast_NoEscape2(Ref("Hi")) - end - end - i = only(findall(isnew, result.ir.stmts.inst)) - @test_broken has_no_escape(result.state[SSAValue(i)]) - end - let result = @eval Module() begin - @noinline f_GlobalEscape_a(a) = (global globala = a) # obvious escape - $code_escapes() do - f_GlobalEscape_a(Ref("Hi")) - end - end - i = only(findall(isnew, result.ir.stmts.inst)) - @test has_return_escape(result.state[SSAValue(i)]) && has_thrown_escape(result.state[SSAValue(i)]) - end - # if we can't determine the matching method statically, we should be conservative - let result = @eval Module() $code_escapes((Ref{Any},)) do a - may_exist(a) - end - @test has_all_escape(result.state[Argument(2)]) - end - let result = @eval Module() begin - @noinline broadcast_NoEscape(a) = (broadcast(identity, a); nothing) - $code_escapes((Ref{Any},)) do a - Base.@invokelatest broadcast_NoEscape(a) - end - end - @test has_all_escape(result.state[Argument(2)]) - end - - # handling of simple union-split (just exploit the inliner's effort) - let T = Union{Int,Nothing} - result = @eval Module() begin - @noinline unionsplit_NoEscape_a(a) = string(nothing) - @noinline unionsplit_NoEscape_a(a::Int) = a + 10 - $code_escapes(($T,)) do x - s = $SafeRef{$T}(x) - unionsplit_NoEscape_a(s[]) - return nothing - end - end - inds = findall(isT(SafeRef{T}), result.ir.stmts.type) # find allocation statement - @assert !isempty(inds) - for i in inds - @test has_no_escape(result.state[SSAValue(i)]) - end - end - - # appropriate conversion of inter-procedural context - # https://github.com/aviatesk/EscapeAnalysis.jl/issues/7 - let M = Module() - @eval M @noinline f_NoEscape_a(a) = (println("prevent inlining"); Base.inferencebarrier(nothing)) - - result = @eval M $code_escapes() do - a = Ref("foo") # shouldn't be "return escape" - b = f_NoEscape_a(a) - nothing - end - i = only(findall(isT(Base.RefValue{String}), result.ir.stmts.type)) - @test has_no_escape(result.state[SSAValue(i)]) - - result = @eval M $code_escapes() do - a = Ref("foo") # still should be "return escape" - b = f_NoEscape_a(a) - return a - end - i = only(findall(isT(Base.RefValue{String}), result.ir.stmts.type)) - r = only(findall(isreturn, result.ir.stmts.inst)) - @test has_return_escape(result.state[SSAValue(i)], r) - end - - # should propagate escape information imposed on return value to the aliased call argument - let result = @eval Module() begin - @noinline f_ReturnEscape_a(a) = (println("prevent inlining"); a) - $code_escapes() do - obj = Ref("foo") # should be "return escape" - ret = f_ReturnEscape_a(obj) - return ret # alias of `obj` - end - end - i = only(findall(isnew, result.ir.stmts.inst)) - r = only(findall(isreturn, result.ir.stmts.inst)) - @test has_return_escape(result.state[SSAValue(i)], r) - end - let result = @eval Module() begin - @noinline f_NoReturnEscape_a(a) = (println("prevent inlining"); identity("hi")) - $code_escapes() do - obj = Ref("foo") # better to not be "return escape" - ret = f_NoReturnEscape_a(obj) - return ret # must not alias to `obj` - end - end - i = only(findall(isT(Base.RefValue{String}), result.ir.stmts.type)) - r = only(findall(isreturn, result.ir.stmts.inst)) - @test !has_return_escape(result.state[SSAValue(i)], r) - end -end - -@testset "builtins" begin - let # throw - r = code_escapes((Any,)) do a - throw(a) - end - @test has_thrown_escape(r.state[Argument(2)]) - end - - let # implicit throws - r = code_escapes((Any,)) do a - getfield(a, :may_not_field) - end - @test has_thrown_escape(r.state[Argument(2)]) - - r = code_escapes((Any,)) do a - sizeof(a) - end - @test has_thrown_escape(r.state[Argument(2)]) - end - - let # :=== - result = code_escapes((Bool, String)) do cond, s - m = cond ? SafeRef(s) : nothing - c = m === nothing - return c - end - i = only(findall(isT(SafeRef{String}), result.ir.stmts.type)) - @test has_no_escape(result.state[SSAValue(i)]) - end - - let # sizeof - ary = [0,1,2] - result = @eval code_escapes() do - ary = $(QuoteNode(ary)) - sizeof(ary) - end - i = only(findall(isT(Core.Const(ary)), result.ir.stmts.type)) - @test has_no_escape(result.state[SSAValue(i)]) - end - - let # ifelse - result = code_escapes((Bool,)) do c - r = ifelse(c, Ref("yes"), Ref("no")) - return r - end - inds = findall(isT(Base.RefValue{String}), result.ir.stmts.type) - @assert !isempty(inds) - for i in inds - @test has_return_escape(result.state[SSAValue(i)]) - end - end - let # ifelse (with constant condition) - result = code_escapes() do - r = ifelse(true, Ref("yes"), Ref(nothing)) - return r - end - for i in 1:length(result.ir.stmts) - if isnew(result.ir.stmts.inst[i]) && isT(Base.RefValue{String})(result.ir.stmts.type[i]) - @test has_return_escape(result.state[SSAValue(i)]) - elseif isnew(result.ir.stmts.inst[i]) && isT(Base.RefValue{Nothing})(result.ir.stmts.type[i]) - @test has_no_escape(result.state[SSAValue(i)]) - end - end - end - - let # typeassert - result = code_escapes((Any,)) do x - y = x::String - return y - end - r = only(findall(isreturn, result.ir.stmts.inst)) - @test has_return_escape(result.state[Argument(2)], r) - @test !has_all_escape(result.state[Argument(2)]) - end - - let # isdefined - result = code_escapes((Any,)) do x - isdefined(x, :foo) ? x : throw("undefined") - end - r = only(findall(isreturn, result.ir.stmts.inst)) - @test has_return_escape(result.state[Argument(2)], r) - @test !has_all_escape(result.state[Argument(2)]) - - result = code_escapes((Module,)) do m - isdefined(m, 10) # throws - end - @test has_thrown_escape(result.state[Argument(2)]) - end -end - -@testset "flow-sensitivity" begin - # ReturnEscape - let result = code_escapes((Bool,)) do cond - r = Ref("foo") - if cond - return cond - end - return r - end - i = only(findall(isT(Base.RefValue{String}), result.ir.stmts.type)) - rts = findall(isreturn, result.ir.stmts.inst) - @assert length(rts) == 2 - @test count(rt->has_return_escape(result.state[SSAValue(i)], rt), rts) == 1 - end - let result = code_escapes((Bool,)) do cond - r = Ref("foo") - cnt = 0 - while rand(Bool) - cnt += 1 - rand(Bool) && return r - end - rand(Bool) && return r - return cnt - end - i = only(findall(isT(Base.RefValue{String}), result.ir.stmts.type)) - rts = findall(isreturn, result.ir.stmts.inst) # return statement - @assert length(rts) == 3 - @test count(rt->has_return_escape(result.state[SSAValue(i)], rt), rts) == 2 - end -end - -@testset "escape through exceptions" begin - M = @eval Module() begin - unsafeget(x) = isassigned(x) ? x[] : throw(x) - @noinline function rethrow_escape!() - try - rethrow() - catch err - Gx[] = err - end - end - @noinline function current_exceptions_escape!() - excs = Base.current_exceptions() - Gx[] = excs - end - const Gx = Ref{Any}() - @__MODULE__ - end - - let # simple: return escape - result = @eval M $code_escapes() do - r = Ref{String}() - local ret - try - s = unsafeget(r) - ret = sizeof(s) - catch err - ret = err - end - return ret - end - i = only(findall(isnew, result.ir.stmts.inst)) - @test has_return_escape(result.state[SSAValue(i)]) - end - - let # simple: global escape - result = @eval M $code_escapes() do - r = Ref{String}() - local ret # prevent DCE - try - s = unsafeget(r) - ret = sizeof(s) - catch err - global g = err - end - nothing - end - i = only(findall(isnew, result.ir.stmts.inst)) - @test has_all_escape(result.state[SSAValue(i)]) - end - - let # account for possible escapes via nested throws - result = @eval M $code_escapes() do - r = Ref{String}() - try - try - unsafeget(r) - catch err1 - throw(err1) - end - catch err2 - Gx[] = err2 - end - end - i = only(findall(isnew, result.ir.stmts.inst)) - @test has_all_escape(result.state[SSAValue(i)]) - end - let # account for possible escapes via `rethrow` - result = @eval M $code_escapes() do - r = Ref{String}() - try - try - unsafeget(r) - catch err1 - rethrow(err1) - end - catch err2 - Gx[] = err2 - end - end - i = only(findall(isnew, result.ir.stmts.inst)) - @test has_all_escape(result.state[SSAValue(i)]) - end - let # account for possible escapes via `rethrow` - result = @eval M $code_escapes() do - try - r = Ref{String}() - unsafeget(r) - catch - rethrow_escape!() - end - end - i = only(findall(isnew, result.ir.stmts.inst)) - @test has_all_escape(result.state[SSAValue(i)]) - end - let # account for possible escapes via `rethrow` - result = @eval M $code_escapes() do - local t - try - r = Ref{String}() - t = unsafeget(r) - catch err - t = typeof(err) - rethrow_escape!() - end - return t - end - i = only(findall(isnew, result.ir.stmts.inst)) - @test has_all_escape(result.state[SSAValue(i)]) - end - let # account for possible escapes via `Base.current_exceptions` - result = @eval M $code_escapes() do - try - r = Ref{String}() - unsafeget(r) - catch - Gx[] = Base.current_exceptions() - end - end - i = only(findall(isnew, result.ir.stmts.inst)) - @test has_all_escape(result.state[SSAValue(i)]) - end - let # account for possible escapes via `Base.current_exceptions` - result = @eval M $code_escapes() do - try - r = Ref{String}() - unsafeget(r) - catch - current_exceptions_escape!() - end - end - i = only(findall(isnew, result.ir.stmts.inst)) - @test has_all_escape(result.state[SSAValue(i)]) - end - - let # contextual: escape information imposed on `err` shouldn't propagate to `r2`, but only to `r1` - result = @eval M $code_escapes() do - r1 = Ref{String}() - r2 = Ref{String}() - local ret - try - s1 = unsafeget(r1) - ret = sizeof(s1) - catch err - global g = err - end - s2 = unsafeget(r2) - return s2, r2 - end - is = findall(isnew, result.ir.stmts.inst) - @test length(is) == 2 - i1, i2 = is - r = only(findall(isreturn, result.ir.stmts.inst)) - @test has_all_escape(result.state[SSAValue(i1)]) - @test !has_all_escape(result.state[SSAValue(i2)]) - @test has_return_escape(result.state[SSAValue(i2)], r) - end - - # XXX test cases below are currently broken because of the technical reason described in `escape_exception!` - - let # limited propagation: exception is caught within a frame => doesn't escape to a caller - result = @eval M $code_escapes() do - r = Ref{String}() - local ret - try - s = unsafeget(r) - ret = sizeof(s) - catch - ret = nothing - end - return ret - end - i = only(findall(isnew, result.ir.stmts.inst)) - r = only(findall(isreturn, result.ir.stmts.inst)) - @test_broken !has_return_escape(result.state[SSAValue(i)], r) - end - let # sequential: escape information imposed on `err1` and `err2 should propagate separately - result = @eval M $code_escapes() do - r1 = Ref{String}() - r2 = Ref{String}() - local ret - try - s1 = unsafeget(r1) - ret = sizeof(s1) - catch err1 - global g = err1 - end - try - s2 = unsafeget(r2) - ret = sizeof(s2) - catch err2 - ret = err2 - end - return ret - end - is = findall(isnew, result.ir.stmts.inst) - @test length(is) == 2 - i1, i2 = is - r = only(findall(isreturn, result.ir.stmts.inst)) - @test has_all_escape(result.state[SSAValue(i1)]) - @test has_return_escape(result.state[SSAValue(i2)], r) - @test_broken !has_all_escape(result.state[SSAValue(i2)]) - end - let # nested: escape information imposed on `inner` shouldn't propagate to `s` - result = @eval M $code_escapes() do - r = Ref{String}() - local ret - try - s = unsafeget(r) - try - ret = sizeof(s) - catch inner - return inner - end - catch outer - ret = nothing - end - return ret - end - i = only(findall(isnew, result.ir.stmts.inst)) - @test_broken !has_return_escape(result.state[SSAValue(i)]) - end - let # merge: escape information imposed on `err1` and `err2 should be merged - result = @eval M $code_escapes() do - r = Ref{String}() - local ret - try - s = unsafeget(r) - ret = sizeof(s) - catch err1 - return err1 - end - try - s = unsafeget(r) - ret = sizeof(s) - catch err2 - return err2 - end - nothing - end - i = only(findall(isnew, result.ir.stmts.inst)) - rs = findall(isreturn, result.ir.stmts.inst) - @test_broken !has_all_escape(result.state[SSAValue(i)]) - for r in rs - @test has_return_escape(result.state[SSAValue(i)], r) - end - end - let # no exception handling: should keep propagating the escape - result = @eval M $code_escapes() do - r = Ref{String}() - local ret - try - s = unsafeget(r) - ret = sizeof(s) - finally - if !@isdefined(ret) - ret = 42 - end - end - return ret - end - i = only(findall(isnew, result.ir.stmts.inst)) - r = only(findall(isreturn, result.ir.stmts.inst)) - @test_broken !has_return_escape(result.state[SSAValue(i)], r) - end -end - -@testset "field analysis / alias analysis" begin - # escaped allocations - # ------------------- - - # escaped object should escape its fields as well - let result = code_escapes((Any,)) do a - global g = SafeRef{Any}(a) - nothing - end - i = only(findall(isnew, result.ir.stmts.inst)) - @test has_all_escape(result.state[SSAValue(i)]) - @test has_all_escape(result.state[Argument(2)]) - end - let result = code_escapes((Any,)) do a - global g = (a,) - nothing - end - i = only(findall(issubT(Tuple), result.ir.stmts.type)) - @test has_all_escape(result.state[SSAValue(i)]) - @test has_all_escape(result.state[Argument(2)]) - end - let result = code_escapes((Any,)) do a - o0 = SafeRef{Any}(a) - global g = SafeRef(o0) - nothing - end - i0 = only(findall(isT(SafeRef{Any}), result.ir.stmts.type)) - i1 = only(findall(isT(SafeRef{SafeRef{Any}}), result.ir.stmts.type)) - @test has_all_escape(result.state[SSAValue(i0)]) - @test has_all_escape(result.state[SSAValue(i1)]) - @test has_all_escape(result.state[Argument(2)]) - end - let result = code_escapes((Any,)) do a - t0 = (a,) - global g = (t0,) - nothing - end - inds = findall(issubT(Tuple), result.ir.stmts.type) - @assert length(inds) == 2 - for i in inds; @test has_all_escape(result.state[SSAValue(i)]); end - @test has_all_escape(result.state[Argument(2)]) - end - # global escape through `setfield!` - let result = code_escapes((Any,)) do a - r = SafeRef{Any}(:init) - global g = r - r[] = a - nothing - end - i = only(findall(isnew, result.ir.stmts.inst)) - @test has_all_escape(result.state[SSAValue(i)]) - @test has_all_escape(result.state[Argument(2)]) - end - let result = code_escapes((Any,Any)) do a, b - r = SafeRef{Any}(a) - global g = r - r[] = b - nothing - end - i = only(findall(isnew, result.ir.stmts.inst)) - @test has_all_escape(result.state[SSAValue(i)]) - @test has_all_escape(result.state[Argument(2)]) # a - @test has_all_escape(result.state[Argument(3)]) # b - end - let result = @eval EATModule() begin - const Rx = SafeRef{String}("Rx") - $code_escapes((String,)) do s - Rx[] = s - Core.sizeof(Rx[]) - end - end - @test has_all_escape(result.state[Argument(2)]) - end - let result = @eval EATModule() begin - const Rx = SafeRef{String}("Rx") - $code_escapes((String,)) do s - setfield!(Rx, :x, s) - Core.sizeof(Rx[]) - end - end - @test has_all_escape(result.state[Argument(2)]) - end - let M = EATModule() - @eval M module ___xxx___ - import ..SafeRef - const Rx = SafeRef("Rx") - end - result = @eval M begin - $code_escapes((String,)) do s - rx = getfield(___xxx___, :Rx) - rx[] = s - nothing - end - end - @test has_all_escape(result.state[Argument(2)]) - end - - # field escape - # ------------ - - # field escape should propagate to :new arguments - let result = code_escapes((String,)) do a - o = SafeRef(a) - f = o[] - return f - end - i = only(findall(isT(SafeRef{String}), result.ir.stmts.type)) - r = only(findall(isreturn, result.ir.stmts.inst)) - @test has_return_escape(result.state[Argument(2)], r) - @test is_load_forwardable(result.state[SSAValue(i)]) - end - let result = code_escapes((String,)) do a - t = (a,) - f = t[1] - return f - end - i = only(findall(t->t<:Tuple, result.ir.stmts.type)) - r = only(findall(isreturn, result.ir.stmts.inst)) - @test has_return_escape(result.state[Argument(2)], r) - @test is_load_forwardable(result.state[SSAValue(i)]) - end - let result = code_escapes((String, String)) do a, b - obj = SafeRefs(a, b) - fld1 = obj[1] - fld2 = obj[2] - return (fld1, fld2) - end - i = only(findall(isT(SafeRefs{String,String}), result.ir.stmts.type)) - r = only(findall(isreturn, result.ir.stmts.inst)) - @test has_return_escape(result.state[Argument(2)], r) # a - @test has_return_escape(result.state[Argument(3)], r) # b - @test is_load_forwardable(result.state[SSAValue(i)]) - end - - # field escape should propagate to `setfield!` argument - let result = code_escapes((String,)) do a - o = SafeRef("foo") - o[] = a - f = o[] - return f - end - i = only(findall(isT(SafeRef{String}), result.ir.stmts.type)) - r = only(findall(isreturn, result.ir.stmts.inst)) - @test has_return_escape(result.state[Argument(2)], r) - @test is_load_forwardable(result.state[SSAValue(i)]) - end - # propagate escape information imposed on return value of `setfield!` call - let result = code_escapes((String,)) do a - obj = SafeRef("foo") - return (obj[] = a) - end - i = only(findall(isT(SafeRef{String}), result.ir.stmts.type)) - r = only(findall(isreturn, result.ir.stmts.inst)) - @test has_return_escape(result.state[Argument(2)], r) - @test is_load_forwardable(result.state[SSAValue(i)]) - end - - # nested allocations - let result = code_escapes((String,)) do a - o1 = SafeRef(a) - o2 = SafeRef(o1) - return o2[] - end - r = only(findall(isreturn, result.ir.stmts.inst)) - @test has_return_escape(result.state[Argument(2)], r) - for i in 1:length(result.ir.stmts) - if isnew(result.ir.stmts.inst[i]) && isT(SafeRef{String})(result.ir.stmts.type[i]) - @test has_return_escape(result.state[SSAValue(i)], r) - elseif isnew(result.ir.stmts.inst[i]) && isT(SafeRef{SafeRef{String}})(result.ir.stmts.type[i]) - @test is_load_forwardable(result.state[SSAValue(i)]) - end - end - end - let result = code_escapes((String,)) do a - o1 = (a,) - o2 = (o1,) - return o2[1] - end - r = only(findall(isreturn, result.ir.stmts.inst)) - @test has_return_escape(result.state[Argument(2)], r) - for i in 1:length(result.ir.stmts) - if isnew(result.ir.stmts.inst[i]) && isT(Tuple{String})(result.ir.stmts.type[i]) - @test has_return_escape(result.state[SSAValue(i)], r) - elseif isnew(result.ir.stmts.inst[i]) && isT(Tuple{Tuple{String}})(result.ir.stmts.type[i]) - @test is_load_forwardable(result.state[SSAValue(i)]) - end - end - end - let result = code_escapes((String,)) do a - o1 = SafeRef(a) - o2 = SafeRef(o1) - o1′ = o2[] - a′ = o1′[] - return a′ - end - r = only(findall(isreturn, result.ir.stmts.inst)) - @test has_return_escape(result.state[Argument(2)], r) - for i in findall(isnew, result.ir.stmts.inst) - @test is_load_forwardable(result.state[SSAValue(i)]) - end - end - let result = code_escapes() do - o1 = SafeRef("foo") - o2 = SafeRef(o1) - return o2 - end - r = only(findall(isreturn, result.ir.stmts.inst)) - for i in findall(isnew, result.ir.stmts.inst) - @test has_return_escape(result.state[SSAValue(i)], r) - end - end - let result = code_escapes() do - o1 = SafeRef("foo") - o2′ = SafeRef(nothing) - o2 = SafeRef{SafeRef}(o2′) - o2[] = o1 - return o2 - end - r = only(findall(isreturn, result.ir.stmts.inst)) - findall(1:length(result.ir.stmts)) do i - if isnew(result.ir.stmts[i][:inst]) - t = result.ir.stmts[i][:type] - return t === SafeRef{String} || # o1 - t === SafeRef{SafeRef} # o2 - end - return false - end |> x->foreach(x) do i - @test has_return_escape(result.state[SSAValue(i)], r) - end - end - let result = code_escapes((String,)) do x - broadcast(identity, Ref(x)) - end - i = only(findall(isnew, result.ir.stmts.inst)) - r = only(findall(isreturn, result.ir.stmts.inst)) - @test has_return_escape(result.state[Argument(2)], r) - @test is_load_forwardable(result.state[SSAValue(i)]) - end - - # ϕ-node allocations - let result = code_escapes((Bool,Any,Any)) do cond, x, y - if cond - ϕ = SafeRef{Any}(x) - else - ϕ = SafeRef{Any}(y) - end - return ϕ[] - end - r = only(findall(isreturn, result.ir.stmts.inst)) - @test has_return_escape(result.state[Argument(3)], r) # x - @test has_return_escape(result.state[Argument(4)], r) # y - i = only(findall(isϕ, result.ir.stmts.inst)) - @test is_load_forwardable(result.state[SSAValue(i)]) - for i in findall(isnew, result.ir.stmts.inst) - @test is_load_forwardable(result.state[SSAValue(i)]) - end - end - let result = code_escapes((Bool,Any,Any)) do cond, x, y - if cond - ϕ2 = ϕ1 = SafeRef{Any}(x) - else - ϕ2 = ϕ1 = SafeRef{Any}(y) - end - return ϕ1[], ϕ2[] - end - r = only(findall(isreturn, result.ir.stmts.inst)) - @test has_return_escape(result.state[Argument(3)], r) # x - @test has_return_escape(result.state[Argument(4)], r) # y - for i in findall(isϕ, result.ir.stmts.inst) - @test is_load_forwardable(result.state[SSAValue(i)]) - end - for i in findall(isnew, result.ir.stmts.inst) - @test is_load_forwardable(result.state[SSAValue(i)]) - end - end - # when ϕ-node merges values with different types - let result = code_escapes((Bool,String,String,String)) do cond, x, y, z - local out - if cond - ϕ = SafeRef(x) - out = ϕ[] - else - ϕ = SafeRefs(z, y) - end - return @isdefined(out) ? out : throw(ϕ) - end - r = only(findall(isreturn, result.ir.stmts.inst)) - t = only(findall(iscall((result.ir, throw)), result.ir.stmts.inst)) - ϕ = only(findall(isT(Union{SafeRef{String},SafeRefs{String,String}}), result.ir.stmts.type)) - @test has_return_escape(result.state[Argument(3)], r) # x - @test !has_return_escape(result.state[Argument(4)], r) # y - @test has_return_escape(result.state[Argument(5)], r) # z - @test has_thrown_escape(result.state[SSAValue(ϕ)], t) - end - - # alias analysis - # -------------- - - # alias via getfield & Expr(:new) - let result = @eval EATModule() begin - const Rx = SafeRef("Rx") - $code_escapes((String,)) do s - r = SafeRef(Rx) - rx = r[] # rx aliased to Rx - rx[] = s - nothing - end - end - i = findfirst(isnew, result.ir.stmts.inst) - @test has_all_escape(result.state[Argument(2)]) - @test is_load_forwardable(result.state[SSAValue(i)]) - end - # alias via getfield & setfield! - let result = @eval EATModule() begin - const Rx = SafeRef("Rx") - $code_escapes((SafeRef{String}, String,)) do _rx, s - r = SafeRef(_rx) - r[] = Rx - rx = r[] # rx aliased to Rx - rx[] = s - nothing - end - end - i = findfirst(isnew, result.ir.stmts.inst) - @test has_all_escape(result.state[Argument(3)]) - @test is_load_forwardable(result.state[SSAValue(i)]) - end - # alias via typeassert - let result = code_escapes((Any,)) do a - global g - (g::SafeRef{Any})[] = a - nothing - end - @test has_all_escape(result.state[Argument(2)]) - end - # alias via ifelse - let result = @eval EATModule() begin - const Lx, Rx = SafeRef("Lx"), SafeRef("Rx") - $code_escapes((Bool,String,)) do c, a - r = ifelse(c, Lx, Rx) - r[] = a - nothing - end - end - @test has_all_escape(result.state[Argument(3)]) # a - end - # alias via ϕ-node - let result = code_escapes((Bool,String)) do cond, x - if cond - ϕ2 = ϕ1 = SafeRef("foo") - else - ϕ2 = ϕ1 = SafeRef("bar") - end - ϕ2[] = x - return ϕ1[] - end - r = only(findall(isreturn, result.ir.stmts.inst)) - @test has_return_escape(result.state[Argument(3)], r) # x - for i in findall(isϕ, result.ir.stmts.inst) - @test is_load_forwardable(result.state[SSAValue(i)]) - end - for i in findall(isnew, result.ir.stmts.inst) - @test is_load_forwardable(result.state[SSAValue(i)]) - end - end - let result = code_escapes((Bool,Bool,String)) do cond1, cond2, x - if cond1 - ϕ2 = ϕ1 = SafeRef("foo") - else - ϕ2 = ϕ1 = SafeRef("bar") - end - cond2 && (ϕ2[] = x) - return ϕ1[] - end - r = only(findall(isreturn, result.ir.stmts.inst)) - @test has_return_escape(result.state[Argument(4)], r) # x - for i in findall(isϕ, result.ir.stmts.inst) - @test is_load_forwardable(result.state[SSAValue(i)]) - end - for i in findall(isnew, result.ir.stmts.inst) - @test is_load_forwardable(result.state[SSAValue(i)]) - end - end - # alias via π-node - let result = code_escapes((String,)) do x - global g - l = g - if isa(l, SafeRef{String}) - l[] = x - end - nothing - end - @test has_all_escape(result.state[Argument(2)]) # x - end - - # dynamic semantics - # ----------------- - - # conservatively handle untyped objects - let result = @eval code_escapes((Any,Any,)) do T, x - obj = $(Expr(:new, :T, :x)) - end - t = only(findall(isnew, result.ir.stmts.inst)) - @test #=T=# has_thrown_escape(result.state[Argument(2)], t) # T - @test #=x=# has_thrown_escape(result.state[Argument(3)], t) # x - end - let result = @eval code_escapes((Any,Any,Any,Any)) do T, x, y, z - obj = $(Expr(:new, :T, :x, :y)) - return getfield(obj, :x) - end - i = only(findall(isnew, result.ir.stmts.inst)) - r = only(findall(isreturn, result.ir.stmts.inst)) - @test #=x=# has_return_escape(result.state[Argument(3)], r) - @test #=y=# has_return_escape(result.state[Argument(4)], r) - @test #=z=# !has_return_escape(result.state[Argument(5)], r) - end - let result = @eval code_escapes((Any,Any,Any,Any)) do T, x, y, z - obj = $(Expr(:new, :T, :x)) - setfield!(obj, :x, y) - return getfield(obj, :x) - end - i = only(findall(isnew, result.ir.stmts.inst)) - r = only(findall(isreturn, result.ir.stmts.inst)) - @test #=x=# has_return_escape(result.state[Argument(3)], r) - @test #=y=# has_return_escape(result.state[Argument(4)], r) - @test #=z=# !has_return_escape(result.state[Argument(5)], r) - end - - # conservatively handle unknown field: - # all fields should be escaped, but the allocation itself doesn't need to be escaped - let result = code_escapes((String, Symbol)) do a, fld - obj = SafeRef(a) - return getfield(obj, fld) - end - i = only(findall(isT(SafeRef{String}), result.ir.stmts.type)) - r = only(findall(isreturn, result.ir.stmts.inst)) - @test has_return_escape(result.state[Argument(2)], r) # a - @test is_load_forwardable(result.state[SSAValue(i)]) # obj - end - let result = code_escapes((String, String, Symbol)) do a, b, fld - obj = SafeRefs(a, b) - return getfield(obj, fld) # should escape both `a` and `b` - end - i = only(findall(isT(SafeRefs{String,String}), result.ir.stmts.type)) - r = only(findall(isreturn, result.ir.stmts.inst)) - @test has_return_escape(result.state[Argument(2)], r) # a - @test has_return_escape(result.state[Argument(3)], r) # b - @test is_load_forwardable(result.state[SSAValue(i)]) # obj - end - let result = code_escapes((String, String, Int)) do a, b, idx - obj = SafeRefs(a, b) - return obj[idx] # should escape both `a` and `b` - end - i = only(findall(isT(SafeRefs{String,String}), result.ir.stmts.type)) - r = only(findall(isreturn, result.ir.stmts.inst)) - @test has_return_escape(result.state[Argument(2)], r) # a - @test has_return_escape(result.state[Argument(3)], r) # b - @test is_load_forwardable(result.state[SSAValue(i)]) # obj - end - let result = code_escapes((String, String, Symbol)) do a, b, fld - obj = SafeRefs("a", "b") - setfield!(obj, fld, a) - return obj[2] # should escape `a` - end - i = only(findall(isT(SafeRefs{String,String}), result.ir.stmts.type)) - r = only(findall(isreturn, result.ir.stmts.inst)) - @test has_return_escape(result.state[Argument(2)], r) # a - @test !has_return_escape(result.state[Argument(3)], r) # b - @test is_load_forwardable(result.state[SSAValue(i)]) # obj - end - let result = code_escapes((String, Symbol)) do a, fld - obj = SafeRefs("a", "b") - setfield!(obj, fld, a) - return obj[1] # this should escape `a` - end - i = only(findall(isT(SafeRefs{String,String}), result.ir.stmts.type)) - r = only(findall(isreturn, result.ir.stmts.inst)) - @test has_return_escape(result.state[Argument(2)], r) # a - @test is_load_forwardable(result.state[SSAValue(i)]) # obj - end - let result = code_escapes((String, String, Int)) do a, b, idx - obj = SafeRefs("a", "b") - obj[idx] = a - return obj[2] # should escape `a` - end - i = only(findall(isT(SafeRefs{String,String}), result.ir.stmts.type)) - r = only(findall(isreturn, result.ir.stmts.inst)) - @test has_return_escape(result.state[Argument(2)], r) # a - @test !has_return_escape(result.state[Argument(3)], r) # b - @test is_load_forwardable(result.state[SSAValue(i)]) # obj - end - - # interprocedural - # --------------- - - let result = @eval EATModule() begin - @noinline getx(obj) = obj[] - $code_escapes((String,)) do a - obj = SafeRef(a) - fld = getx(obj) - return fld - end - end - i = only(findall(isnew, result.ir.stmts.inst)) - r = only(findall(isreturn, result.ir.stmts.inst)) - @test has_return_escape(result.state[Argument(2)], r) - # NOTE we can't scalar replace `obj`, but still we may want to stack allocate it - @test_broken is_load_forwardable(result.state[SSAValue(i)]) - end - - # TODO interprocedural field analysis - let result = code_escapes((SafeRef{String},)) do s - s[] = "bar" - global g = s[] - nothing - end - @test_broken !has_all_escape(result.state[Argument(2)]) - end - - # TODO flow-sensitivity? - # ---------------------- - - let result = code_escapes((Any,Any)) do a, b - r = SafeRef{Any}(a) - r[] = b - return r[] - end - i = only(findall(isnew, result.ir.stmts.inst)) - r = only(findall(isreturn, result.ir.stmts.inst)) - @test_broken !has_return_escape(result.state[Argument(2)], r) # a - @test has_return_escape(result.state[Argument(3)], r) # b - @test is_load_forwardable(result.state[SSAValue(i)]) - end - let result = code_escapes((Any,Any)) do a, b - r = SafeRef{Any}(:init) - r[] = a - r[] = b - return r[] - end - i = only(findall(isnew, result.ir.stmts.inst)) - r = only(findall(isreturn, result.ir.stmts.inst)) - @test_broken !has_return_escape(result.state[Argument(2)], r) # a - @test has_return_escape(result.state[Argument(3)], r) # b - @test is_load_forwardable(result.state[SSAValue(i)]) - end - let result = code_escapes((Any,Any,Bool)) do a, b, cond - r = SafeRef{Any}(:init) - if cond - r[] = a - return r[] - else - r[] = b - return nothing - end - end - i = only(findall(isnew, result.ir.stmts.inst)) - @test is_load_forwardable(result.state[SSAValue(i)]) - r = only(findall(result.ir.stmts.inst) do @nospecialize x - isreturn(x) && isa(x.val, Core.SSAValue) - end) - @test has_return_escape(result.state[Argument(2)], r) # a - @test_broken !has_return_escape(result.state[Argument(3)], r) # b - end - - # handle conflicting field information correctly - let result = code_escapes((Bool,String,String,)) do cnd, baz, qux - if cnd - o = SafeRef("foo") - else - o = SafeRefs("bar", baz) - r = getfield(o, 2) - end - if cnd - o = o::SafeRef - setfield!(o, 1, qux) - r = getfield(o, 1) - end - r - end - r = only(findall(isreturn, result.ir.stmts.inst)) - @test has_return_escape(result.state[Argument(3)], r) # baz - @test has_return_escape(result.state[Argument(4)], r) # qux - for new in findall(isnew, result.ir.stmts.inst) - @test is_load_forwardable(result.state[SSAValue(new)]) - end - end - let result = code_escapes((Bool,String,String,)) do cnd, baz, qux - if cnd - o = SafeRefs("foo", "bar") - r = setfield!(o, 2, baz) - else - o = SafeRef(qux) - end - if !cnd - o = o::SafeRef - r = getfield(o, 1) - end - r - end - r = only(findall(isreturn, result.ir.stmts.inst)) - @test has_return_escape(result.state[Argument(3)], r) # baz - @test has_return_escape(result.state[Argument(4)], r) # qux - end -end - -# demonstrate the power of our field / alias analysis with a realistic end to end example -abstract type AbstractPoint{T} end -mutable struct MPoint{T} <: AbstractPoint{T} - x::T - y::T -end -add(a::P, b::P) where P<:AbstractPoint = P(a.x + b.x, a.y + b.y) -function compute(T, ax, ay, bx, by) - a = T(ax, ay) - b = T(bx, by) - for i in 0:(100000000-1) - a = add(add(a, b), b) - end - a.x, a.y -end -function compute(a, b) - for i in 0:(100000000-1) - a = add(add(a, b), b) - end - a.x, a.y -end -function compute!(a, b) - for i in 0:(100000000-1) - a′ = add(add(a, b), b) - a.x = a′.x - a.y = a′.y - end -end -let result = @code_escapes compute(MPoint, 1+.5im, 2+.5im, 2+.25im, 4+.75im) - for i in findall(isnew, result.ir.stmts.inst) - @test is_load_forwardable(result.state[SSAValue(i)]) - end -end -let result = @code_escapes compute(MPoint(1+.5im, 2+.5im), MPoint(2+.25im, 4+.75im)) - for i in findall(isnew, result.ir.stmts.inst) - @test is_load_forwardable(result.state[SSAValue(i)]) - end -end -let result = @code_escapes compute!(MPoint(1+.5im, 2+.5im), MPoint(2+.25im, 4+.75im)) - for i in findall(isnew, result.ir.stmts.inst) - @test is_load_forwardable(result.state[SSAValue(i)]) - end -end - -@testset "array primitives" begin - # arrayref - let result = code_escapes((Vector{String},Int)) do xs, i - s = Base.arrayref(true, xs, i) - return s - end - r = only(findall(isreturn, result.ir.stmts.inst)) - @test has_return_escape(result.state[Argument(2)], r) # xs - @test_broken !has_thrown_escape(result.state[Argument(2)]) # xs - @test !has_return_escape(result.state[Argument(3)], r) # i - end - let result = code_escapes((Vector{String},Bool)) do xs, i - c = Base.arrayref(true, xs, i) # TypeError will happen here - return c - end - t = only(findall(iscall((result.ir, Base.arrayref)), result.ir.stmts.inst)) - @test has_thrown_escape(result.state[Argument(2)], t) # xs - end - let result = code_escapes((String,Int)) do xs, i - c = Base.arrayref(true, xs, i) # TypeError will happen here - return c - end - t = only(findall(iscall((result.ir, Base.arrayref)), result.ir.stmts.inst)) - @test has_thrown_escape(result.state[Argument(2)], t) # xs - end - let result = code_escapes((AbstractVector{String},Int)) do xs, i - c = Base.arrayref(true, xs, i) # TypeError may happen here - return c - end - t = only(findall(iscall((result.ir, Base.arrayref)), result.ir.stmts.inst)) - @test has_thrown_escape(result.state[Argument(2)], t) # xs - end - let result = code_escapes((Vector{String},Any)) do xs, i - c = Base.arrayref(true, xs, i) # TypeError may happen here - return c - end - t = only(findall(iscall((result.ir, Base.arrayref)), result.ir.stmts.inst)) - @test has_thrown_escape(result.state[Argument(2)], t) # xs - end - - # arrayset - let result = code_escapes((Vector{String},String,Int,)) do xs, x, i - Base.arrayset(true, xs, x, i) - return xs - end - r = only(findall(isreturn, result.ir.stmts.inst)) - @test has_return_escape(result.state[Argument(2)], r) # xs - @test has_return_escape(result.state[Argument(3)], r) # x - end - let result = code_escapes((String,String,String,)) do s, t, u - xs = Vector{String}(undef, 3) - Base.arrayset(true, xs, s, 1) - Base.arrayset(true, xs, t, 2) - Base.arrayset(true, xs, u, 3) - return xs - end - i = only(findall(isarrayalloc, result.ir.stmts.inst)) - r = only(findall(isreturn, result.ir.stmts.inst)) - @test has_return_escape(result.state[SSAValue(i)], r) - for i in 2:result.state.nargs - @test has_return_escape(result.state[Argument(i)], r) - end - end - let result = code_escapes((Vector{String},String,Bool,)) do xs, x, i - Base.arrayset(true, xs, x, i) # TypeError will happen here - return xs - end - t = only(findall(iscall((result.ir, Base.arrayset)), result.ir.stmts.inst)) - @test has_thrown_escape(result.state[Argument(2)], t) # xs - @test has_thrown_escape(result.state[Argument(3)], t) # x - end - let result = code_escapes((String,String,Int,)) do xs, x, i - Base.arrayset(true, xs, x, i) # TypeError will happen here - return xs - end - t = only(findall(iscall((result.ir, Base.arrayset)), result.ir.stmts.inst)) - @test has_thrown_escape(result.state[Argument(2)], t) # xs::String - @test has_thrown_escape(result.state[Argument(3)], t) # x::String - end - let result = code_escapes((AbstractVector{String},String,Int,)) do xs, x, i - Base.arrayset(true, xs, x, i) # TypeError may happen here - return xs - end - t = only(findall(iscall((result.ir, Base.arrayset)), result.ir.stmts.inst)) - @test has_thrown_escape(result.state[Argument(2)], t) # xs - @test has_thrown_escape(result.state[Argument(3)], t) # x - end - let result = code_escapes((Vector{String},AbstractString,Int,)) do xs, x, i - Base.arrayset(true, xs, x, i) # TypeError may happen here - return xs - end - t = only(findall(iscall((result.ir, Base.arrayset)), result.ir.stmts.inst)) - @test has_thrown_escape(result.state[Argument(2)], t) # xs - @test has_thrown_escape(result.state[Argument(3)], t) # x - end - let result = code_escapes((Vector{Any},AbstractString,Int,)) do xs, x, i - Base.arrayset(true, xs, x, i) # TypeError may happen here - return xs - end - t = only(findall(iscall((result.ir, Base.arrayset)), result.ir.stmts.inst)) - @test !has_thrown_escape(result.state[Argument(2)], t) # xs - @test !has_thrown_escape(result.state[Argument(3)], t) # x - end - - # arrayref and arrayset - let result = code_escapes() do - a = Vector{Vector{Any}}(undef, 1) - b = Any[] - a[1] = b - return a[1] - end - r = only(findall(isreturn, result.ir.stmts.inst)) - ai = only(findall(result.ir.stmts.inst) do @nospecialize x - isarrayalloc(x) && x.args[2] === Vector{Vector{Any}} - end) - bi = only(findall(result.ir.stmts.inst) do @nospecialize x - isarrayalloc(x) && x.args[2] === Vector{Any} - end) - @test !has_return_escape(result.state[SSAValue(ai)], r) - @test has_return_escape(result.state[SSAValue(bi)], r) - end - let result = code_escapes() do - a = Vector{Vector{Any}}(undef, 1) - b = Any[] - a[1] = b - return a - end - r = only(findall(isreturn, result.ir.stmts.inst)) - ai = only(findall(result.ir.stmts.inst) do @nospecialize x - isarrayalloc(x) && x.args[2] === Vector{Vector{Any}} - end) - bi = only(findall(result.ir.stmts.inst) do @nospecialize x - isarrayalloc(x) && x.args[2] === Vector{Any} - end) - @test has_return_escape(result.state[SSAValue(ai)], r) - @test has_return_escape(result.state[SSAValue(bi)], r) - end - let result = code_escapes((Vector{Any},String,Int,Int)) do xs, s, i, j - x = SafeRef(s) - xs[i] = x - xs[j] # potential error - end - i = only(findall(isnew, result.ir.stmts.inst)) - t = only(findall(iscall((result.ir, Base.arrayref)), result.ir.stmts.inst)) - @test has_thrown_escape(result.state[Argument(3)], t) # s - @test has_thrown_escape(result.state[SSAValue(i)], t) # x - end - - # arraysize - let result = code_escapes((Vector{Any},)) do xs - Core.arraysize(xs, 1) - end - t = only(findall(iscall((result.ir, Core.arraysize)), result.ir.stmts.inst)) - @test !has_thrown_escape(result.state[Argument(2)], t) - end - let result = code_escapes((Vector{Any},Int,)) do xs, dim - Core.arraysize(xs, dim) - end - t = only(findall(iscall((result.ir, Core.arraysize)), result.ir.stmts.inst)) - @test !has_thrown_escape(result.state[Argument(2)], t) - end - let result = code_escapes((Any,)) do xs - Core.arraysize(xs, 1) - end - t = only(findall(iscall((result.ir, Core.arraysize)), result.ir.stmts.inst)) - @test has_thrown_escape(result.state[Argument(2)], t) - end - - # arraylen - let result = code_escapes((Vector{Any},)) do xs - Base.arraylen(xs) - end - t = only(findall(iscall((result.ir, Base.arraylen)), result.ir.stmts.inst)) - @test !has_thrown_escape(result.state[Argument(2)], t) # xs - end - let result = code_escapes((String,)) do xs - Base.arraylen(xs) - end - t = only(findall(iscall((result.ir, Base.arraylen)), result.ir.stmts.inst)) - @test has_thrown_escape(result.state[Argument(2)], t) # xs - end - let result = code_escapes((Vector{Any},)) do xs - Base.arraylen(xs, 1) - end - t = only(findall(iscall((result.ir, Base.arraylen)), result.ir.stmts.inst)) - @test has_thrown_escape(result.state[Argument(2)], t) # xs - end - - # array resizing - # without BoundsErrors - let result = code_escapes((Vector{Any},String)) do xs, x - @ccall jl_array_grow_beg(xs::Any, 2::UInt)::Cvoid - xs[1] = x - xs - end - t = only(findall(isarrayresize, result.ir.stmts.inst)) - @test !has_thrown_escape(result.state[Argument(2)], t) # xs - @test !has_thrown_escape(result.state[Argument(3)], t) # x - end - let result = code_escapes((Vector{Any},String)) do xs, x - @ccall jl_array_grow_end(xs::Any, 2::UInt)::Cvoid - xs[1] = x - xs - end - t = only(findall(isarrayresize, result.ir.stmts.inst)) - @test !has_thrown_escape(result.state[Argument(2)], t) # xs - @test !has_thrown_escape(result.state[Argument(3)], t) # x - end - # with possible BoundsErrors - let result = code_escapes((String,)) do x - xs = Any[1,2,3] - xs[3] = x - @ccall jl_array_del_beg(xs::Any, 2::UInt)::Cvoid # can potentially throw - xs - end - i = only(findall(isarrayalloc, result.ir.stmts.inst)) - t = only(findall(isarrayresize, result.ir.stmts.inst)) - @test !has_thrown_escape(result.state[SSAValue(i)], t) # xs - @test has_thrown_escape(result.state[Argument(2)], t) # x - end - let result = code_escapes((String,)) do x - xs = Any[1,2,3] - xs[1] = x - @ccall jl_array_del_end(xs::Any, 2::UInt)::Cvoid # can potentially throw - xs - end - i = only(findall(isarrayalloc, result.ir.stmts.inst)) - t = only(findall(isarrayresize, result.ir.stmts.inst)) - @test !has_thrown_escape(result.state[SSAValue(i)], t) # xs - @test has_thrown_escape(result.state[Argument(2)], t) # x - end - let result = code_escapes((String,)) do x - xs = Any[x] - @ccall jl_array_grow_at(xs::Any, 1::UInt, 2::UInt)::Cvoid # can potentially throw - end - i = only(findall(isarrayalloc, result.ir.stmts.inst)) - t = only(findall(isarrayresize, result.ir.stmts.inst)) - @test !has_thrown_escape(result.state[SSAValue(i)], t) # xs - @test has_thrown_escape(result.state[Argument(2)], t) # x - end - let result = code_escapes((String,)) do x - xs = Any[x] - @ccall jl_array_del_at(xs::Any, 1::UInt, 2::UInt)::Cvoid # can potentially throw - end - i = only(findall(isarrayalloc, result.ir.stmts.inst)) - t = only(findall(isarrayresize, result.ir.stmts.inst)) - @test !has_thrown_escape(result.state[SSAValue(i)], t) # xs - @test has_thrown_escape(result.state[Argument(2)], t) # x - end - - # array copy - let result = code_escapes((Vector{Any},)) do xs - return copy(xs) - end - i = only(findall(isarraycopy, result.ir.stmts.inst)) - r = only(findall(isreturn, result.ir.stmts.inst)) - @test has_return_escape(result.state[SSAValue(i)], r) - @test_broken !has_return_escape(result.state[Argument(2)], r) - end - let result = code_escapes((String,)) do s - xs = String[s] - xs′ = copy(xs) - return xs′[1] - end - i1 = only(findall(isarrayalloc, result.ir.stmts.inst)) - i2 = only(findall(isarraycopy, result.ir.stmts.inst)) - r = only(findall(isreturn, result.ir.stmts.inst)) - @test !has_return_escape(result.state[SSAValue(i1)]) - @test !has_return_escape(result.state[SSAValue(i2)]) - @test has_return_escape(result.state[Argument(2)], r) # s - end - let result = code_escapes((Vector{Any},)) do xs - xs′ = copy(xs) - return xs′[1] # may potentially throw BoundsError, should escape `xs` conservatively (i.e. escape its elements) - end - i = only(findall(isarraycopy, result.ir.stmts.inst)) - ref = only(findall(iscall((result.ir, Base.arrayref)), result.ir.stmts.inst)) - ret = only(findall(isreturn, result.ir.stmts.inst)) - @test_broken !has_thrown_escape(result.state[SSAValue(i)], ref) - @test_broken !has_return_escape(result.state[SSAValue(i)], ret) - @test has_thrown_escape(result.state[Argument(2)], ref) - @test has_return_escape(result.state[Argument(2)], ret) - end - let result = code_escapes((String,)) do s - xs = Vector{String}(undef, 1) - xs[1] = s - xs′ = copy(xs) - length(xs′) > 2 && throw(xs′) - return xs′ - end - i1 = only(findall(isarrayalloc, result.ir.stmts.inst)) - i2 = only(findall(isarraycopy, result.ir.stmts.inst)) - t = only(findall(iscall((result.ir, throw)), result.ir.stmts.inst)) - r = only(findall(isreturn, result.ir.stmts.inst)) - @test_broken !has_thrown_escape(result.state[SSAValue(i1)], t) - @test_broken !has_return_escape(result.state[SSAValue(i1)], r) - @test has_thrown_escape(result.state[SSAValue(i2)], t) - @test has_return_escape(result.state[SSAValue(i2)], r) - @test has_thrown_escape(result.state[Argument(2)], t) - @test has_return_escape(result.state[Argument(2)], r) - end - - # isassigned - let result = code_escapes((Vector{Any},Int)) do xs, i - return isassigned(xs, i) - end - r = only(findall(isreturn, result.ir.stmts.inst)) - @test !has_return_escape(result.state[Argument(2)], r) - @test !has_thrown_escape(result.state[Argument(2)]) - end -end - -# demonstrate array primitive support with a realistic end to end example -let result = code_escapes((Int,String,)) do n,s - xs = String[] - for i in 1:n - push!(xs, s) - end - xs - end - i = only(findall(isarrayalloc, result.ir.stmts.inst)) - r = only(findall(isreturn, result.ir.stmts.inst)) - @test has_return_escape(result.state[SSAValue(i)], r) - @test !has_thrown_escape(result.state[SSAValue(i)]) - @test has_return_escape(result.state[Argument(3)], r) # s - @test !has_thrown_escape(result.state[Argument(3)]) # s -end -let result = code_escapes((Int,String,)) do n,s - xs = String[] - for i in 1:n - pushfirst!(xs, s) - end - xs - end - i = only(findall(isarrayalloc, result.ir.stmts.inst)) - r = only(findall(isreturn, result.ir.stmts.inst)) - @test has_return_escape(result.state[SSAValue(i)], r) - @test !has_thrown_escape(result.state[SSAValue(i)]) - @test has_return_escape(result.state[Argument(3)], r) # s - @test !has_thrown_escape(result.state[Argument(3)]) # s -end -let result = code_escapes((String,String,String)) do s, t, u - xs = String[] - resize!(xs, 3) - xs[1] = s - xs[1] = t - xs[1] = u - xs - end - i = only(findall(isarrayalloc, result.ir.stmts.inst)) - r = only(findall(isreturn, result.ir.stmts.inst)) - @test has_return_escape(result.state[SSAValue(i)], r) - @test !has_thrown_escape(result.state[SSAValue(i)]) - @test has_return_escape(result.state[Argument(2)], r) # s - @test has_return_escape(result.state[Argument(3)], r) # t - @test has_return_escape(result.state[Argument(4)], r) # u -end - -@static if isdefined(Core, :ImmutableArray) - -import Core: ImmutableArray, arrayfreeze, mutating_arrayfreeze, arraythaw - -@testset "ImmutableArray" begin - # arrayfreeze - let result = code_escapes((Vector{Any},)) do xs - arrayfreeze(xs) - end - @test !has_thrown_escape(result.state[Argument(2)]) - end - let result = code_escapes((Vector,)) do xs - arrayfreeze(xs) - end - @test !has_thrown_escape(result.state[Argument(2)]) - end - let result = code_escapes((Any,)) do xs - arrayfreeze(xs) - end - @test has_thrown_escape(result.state[Argument(2)]) - end - let result = code_escapes((ImmutableArray{Any,1},)) do xs - arrayfreeze(xs) - end - @test has_thrown_escape(result.state[Argument(2)]) - end - let result = code_escapes() do - xs = Any[] - arrayfreeze(xs) - end - i = only(findall(isarrayalloc, result.ir.stmts.inst)) - @test has_no_escape(result.state[SSAValue(1)]) - end - - # mutating_arrayfreeze - let result = code_escapes((Vector{Any},)) do xs - mutating_arrayfreeze(xs) - end - @test !has_thrown_escape(result.state[Argument(2)]) - end - let result = code_escapes((Vector,)) do xs - mutating_arrayfreeze(xs) - end - @test !has_thrown_escape(result.state[Argument(2)]) - end - let result = code_escapes((Any,)) do xs - mutating_arrayfreeze(xs) - end - @test has_thrown_escape(result.state[Argument(2)]) - end - let result = code_escapes((ImmutableArray{Any,1},)) do xs - mutating_arrayfreeze(xs) - end - @test has_thrown_escape(result.state[Argument(2)]) - end - let result = code_escapes() do - xs = Any[] - mutating_arrayfreeze(xs) - end - i = only(findall(isarrayalloc, result.ir.stmts.inst)) - @test has_no_escape(result.state[SSAValue(1)]) - end - - # arraythaw - let result = code_escapes((ImmutableArray{Any,1},)) do xs - arraythaw(xs) - end - @test !has_thrown_escape(result.state[Argument(2)]) - end - let result = code_escapes((ImmutableArray,)) do xs - arraythaw(xs) - end - @test !has_thrown_escape(result.state[Argument(2)]) - end - let result = code_escapes((Any,)) do xs - arraythaw(xs) - end - @test has_thrown_escape(result.state[Argument(2)]) - end - let result = code_escapes((Vector{Any},)) do xs - arraythaw(xs) - end - @test has_thrown_escape(result.state[Argument(2)]) - end - let result = code_escapes() do - xs = ImmutableArray(Any[]) - arraythaw(xs) - end - i = only(findall(isarrayalloc, result.ir.stmts.inst)) - @test has_no_escape(result.state[SSAValue(1)]) - end -end - -# demonstrate some arrayfreeze optimizations -# has_no_escape(ary) means ary is eligible for arrayfreeze to mutating_arrayfreeze optimization -let result = code_escapes((Int,)) do n - xs = collect(1:n) - ImmutableArray(xs) - end - i = only(findall(isarrayalloc, result.ir.stmts.inst)) - @test has_no_escape(result.state[SSAValue(i)]) -end -let result = code_escapes((Vector{Float64},)) do xs - ys = sin.(xs) - ImmutableArray(ys) - end - i = only(findall(isarrayalloc, result.ir.stmts.inst)) - @test_broken has_no_escape(result.state[SSAValue(i)]) -end -let result = code_escapes((Vector{Pair{Int,String}},)) do xs - n = maximum(first, xs) - ys = Vector{String}(undef, n) - for (i, s) in xs - ys[i] = s - end - ImmutableArray(xs) - end - i = only(findall(isarrayalloc, result.ir.stmts.inst)) - @test has_no_escape(result.state[SSAValue(i)]) -end - -end # @static if isdefined(Core, :ImmutableArray) - -# demonstrate a simple type level analysis can sometimes improve the analysis accuracy -# by compensating the lack of yet unimplemented analyses -@testset "special-casing bitstype" begin - let result = code_escapes((Nothing,)) do a - global bb = a - end - @test !(has_all_escape(result.state[Argument(2)])) - end - - let result = code_escapes((Int,)) do a - o = SafeRef(a) - f = o[] - return f - end - i = only(findall(isT(SafeRef{Int}), result.ir.stmts.type)) - r = only(findall(isreturn, result.ir.stmts.inst)) - @test !has_return_escape(result.state[SSAValue(i)], r) - end - - # an escaped tuple stmt will not propagate to its Int argument (since `Int` is of bitstype) - let result = code_escapes((Int,Any,)) do a, b - t = tuple(a, b) - return t - end - i = only(findall(issubT(Tuple), result.ir.stmts.type)) - r = only(findall(isreturn, result.ir.stmts.inst)) - @test !has_return_escape(result.state[Argument(2)], r) - @test has_return_escape(result.state[Argument(3)], r) - end -end - -@testset "finalizer elision" begin - @test can_elide_finalizer(EscapeAnalysis.NoEscape(), 1) - @test !can_elide_finalizer(EscapeAnalysis.ReturnEscape(1), 1) - @test can_elide_finalizer(EscapeAnalysis.ReturnEscape(1), 2) - @test !can_elide_finalizer(EscapeAnalysis.ArgumentReturnEscape(), 1) - @test can_elide_finalizer(EscapeAnalysis.ThrownEscape(1), 1) -end - -# # TODO implement a finalizer elision pass -# mutable struct WithFinalizer -# v -# function WithFinalizer(v) -# x = new(v) -# f(t) = @async println("Finalizing $t.") -# return finalizer(x, x) -# end -# end -# make_m(v = 10) = MyMutable(v) -# function simple(cond) -# m = make_m() -# if cond -# # println(m.v) -# return nothing # <= insert `finalize` call here -# end -# return m -# end - -@testset "code quality" begin - # assert that our main routine are free from (unnecessary) runtime dispatches - - function function_filter(@nospecialize(ft)) - ft === typeof(Core.Compiler.widenconst) && return false # `widenconst` is very untyped, ignore - ft === typeof(EscapeAnalysis.escape_builtin!) && return false # `escape_builtin!` is very untyped, ignore - return true - end - target_modules = (EscapeAnalysis,) - test_opt(only(methods(EscapeAnalysis.analyze_escapes)).sig; - function_filter, - target_modules, - # skip_nonconcrete_calls=false, - ) - for m in methods(EscapeAnalysis.escape_builtin!) - Base._methods_by_ftype(m.sig, 1, Base.get_world_counter()) === false && continue - test_opt(m.sig; - function_filter, - target_modules, - # skip_nonconcrete_calls=false, - ) - end -end - -end # @testset "EscapeAnalysis" begin diff --git a/test/setup.jl b/test/setup.jl index ad7fe81..da26f73 100644 --- a/test/setup.jl +++ b/test/setup.jl @@ -1,4 +1,11 @@ -using EscapeAnalysis, Test, JET +using Test +if @isdefined(EA_AS_PKG) + using EscapeAnalysis +else + using Core.Compiler.EscapeAnalysis + import Base: code_escapes + import InteractiveUtils: @code_escapes +end import Core: Argument, SSAValue @static if isdefined(Core.Compiler, :alloc_array_ndims)