Skip to content

Commit edbb7a3

Browse files
committed
Make post_deq type-stable
1 parent 1cd9a6e commit edbb7a3

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

src/layers/chain.jl

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,22 +27,22 @@ function DEQChain(layers...)
2727
push!(encounter_deq ? post_deq : pre_deq, l)
2828
end
2929
@assert encounter_deq "No DEQ Layer in the Chain!!! Maybe you wanted to use Chain"
30-
pre_deq = length(pre_deq) == 0 ? nothing : Chain(pre_deq...)
31-
post_deq = length(post_deq) == 0 ? nothing : Chain(post_deq...)
30+
pre_deq = length(pre_deq) == 0 ? NoOpLayer() : Chain(pre_deq...)
31+
post_deq = length(post_deq) == 0 ? NoOpLayer() : Chain(post_deq...)
3232
return DEQChain(pre_deq, deq, post_deq)
3333
end
3434

35-
function (deq::DEQChain{P1,D,P2})(x, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple) where {P1,D,P2}
36-
x1, st1 = if P1 == Nothing
37-
x, st.pre_deq
38-
else
39-
deq.pre_deq(x, ps.pre_deq, st.pre_deq)
40-
end
41-
(x2, deq_soln), st2 = deq.deq(x1, ps.deq, st.deq)
42-
x3, st3 = if P2 == Nothing
43-
x2, st.post_deq
44-
else
45-
deq.post_deq(x2, ps.post_deq, st.post_deq)
46-
end
35+
function get_deq_return_type(
36+
deq::DEQChain{P1,<:Union{MultiScaleDeepEquilibriumNetwork,MultiScaleSkipDeepEquilibriumNetwork}}, ::T
37+
) where {P1,T}
38+
return NTuple{length(deq.deq.scales),T}
39+
end
40+
get_deq_return_type(::DEQChain, ::T) where {T} = T
41+
42+
function (deq::DEQChain)(x, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple)
43+
T = get_deq_return_type(deq, x)
44+
x1, st1 = deq.pre_deq(x, ps.pre_deq, st.pre_deq)
45+
(x2::T, deq_soln), st2 = deq.deq(x1, ps.deq, st.deq)
46+
x3, st3 = deq.post_deq(x2, ps.post_deq, st.post_deq)
4747
return (x3, deq_soln), (pre_deq=st1, deq=st2, post_deq=st3)
4848
end

0 commit comments

Comments
 (0)