@@ -27,22 +27,22 @@ function DEQChain(layers...)
2727 push! (encounter_deq ? post_deq : pre_deq, l)
2828 end
2929 @assert encounter_deq " No DEQ Layer in the Chain!!! Maybe you wanted to use Chain"
30- pre_deq = length (pre_deq) == 0 ? nothing : Chain (pre_deq... )
31- post_deq = length (post_deq) == 0 ? nothing : Chain (post_deq... )
30+ pre_deq = length (pre_deq) == 0 ? NoOpLayer () : Chain (pre_deq... )
31+ post_deq = length (post_deq) == 0 ? NoOpLayer () : Chain (post_deq... )
3232 return DEQChain (pre_deq, deq, post_deq)
3333end
3434
35- function (deq :: DEQChain{P1,D,P2} )(x, ps :: Union{ComponentArray,NamedTuple} , st :: NamedTuple ) where {P1,D,P2}
36- x1, st1 = if P1 == Nothing
37- x, st . pre_deq
38- else
39- deq . pre_deq (x, ps . pre_deq, st . pre_deq)
40- end
41- (x2, deq_soln), st2 = deq . deq (x1, ps . deq, st . deq)
42- x3, st3 = if P2 == Nothing
43- x2, st . post_deq
44- else
45- deq. post_deq (x2 , ps. post_deq , st. post_deq )
46- end
35+ function get_deq_return_type (
36+ deq :: DEQChain{P1,<:Union{MultiScaleDeepEquilibriumNetwork,MultiScaleSkipDeepEquilibriumNetwork}} , :: T
37+ ) where {P1,T}
38+ return NTuple{ length (deq . deq . scales),T}
39+ end
40+ get_deq_return_type ( :: DEQChain , :: T ) where {T} = T
41+
42+ function (deq :: DEQChain )(x, ps :: Union{ComponentArray,NamedTuple} , st :: NamedTuple )
43+ T = get_deq_return_type (deq, x)
44+ x1, st1 = deq . pre_deq (x, ps . pre_deq, st . pre_deq)
45+ (x2 :: T , deq_soln), st2 = deq. deq (x1 , ps. deq , st. deq )
46+ x3, st3 = deq . post_deq (x2, ps . post_deq, st . post_deq)
4747 return (x3, deq_soln), (pre_deq= st1, deq= st2, post_deq= st3)
4848end
0 commit comments