diff --git a/src/LightSumTypes.jl b/src/LightSumTypes.jl index 3ea8ab1..c2f3503 100644 --- a/src/LightSumTypes.jl +++ b/src/LightSumTypes.jl @@ -3,7 +3,7 @@ module LightSumTypes using MacroTools: namify -export @sumtype, sumtype_expr, variant, variantof, allvariants, is_sumtype +export @sumtype, sumtype_expr, variant, variantof, allvariants, is_sumtype, apply unwrap(sumt) = getfield(sumt, :variants) @@ -194,6 +194,55 @@ is_sumtype(T::Type) = false function variant_idx end +function _is_sumtype_structurally(T) + return T isa DataType && fieldcount(T) == 1 && fieldname(T, 1) === :variants && fieldtype(T, 1) isa Union +end + +function _get_variant_types(T_sum) + field_T = fieldtype(T_sum, 1) + types = [] + curr = field_T + while curr isa Union + push!(types, curr.a) + curr = curr.b + end + push!(types, curr) + return types +end + +@generated function apply(f::F, args::Tuple) where {F} + + args = fieldtypes(args) + sumtype_args = [(i, T) for (i, T) in enumerate(args) if _is_sumtype_structurally(T)] + + final_args = Any[:(args[$i]) for i in 1:length(args)] + for (idx, T) in sumtype_args + final_args[idx] = Symbol("v_", idx) + end + + body = :(f($(final_args...))) + + for (idx, T) in reverse(sumtype_args) + unwrapped_var = Symbol("v_", idx) + + variant_types = _get_variant_types(T) + + branch_expr = :(error("THIS_SHOULD_BE_UNREACHABLE")) + for V_type in reverse(variant_types) + condition = :($unwrapped_var isa $V_type) + branch_expr = Expr(:elseif, condition, body, branch_expr) + end + branch_expr = Expr(:if, branch_expr.args...) + + body = quote + let $(unwrapped_var) = $LightSumTypes.unwrap(args[$idx]) + $branch_expr + end + end + end + return body +end + include("precompile.jl") end