diff --git a/src/aggregation.jl b/src/aggregation.jl index 79b690b..a2b768c 100644 --- a/src/aggregation.jl +++ b/src/aggregation.jl @@ -59,7 +59,7 @@ function transform_with(flag::LogJacFlag, t::ArrayTransform, x::RealVector) d = dimension(transformation) I = reshape(range(firstindex(x); length = prod(dims), step = d), dims) yℓ = map(i -> transform_with(flag, transformation, view_into(x, i, d)), I) - ℓz = logjac_zero(flag, eltype(x)) + ℓz = logjac_zero(flag, extended_eltype(x)) first.(yℓ), isempty(yℓ) ? ℓz : ℓz + sum(last, yℓ) end @@ -67,7 +67,7 @@ function transform_with(flag::LogJacFlag, t::ArrayTransform{Identity}, x::RealVe # TODO use version below when https://github.com/FluxML/Flux.jl/issues/416 is fixed # y = reshape(copy(x), t.dims) y = reshape(map(identity, x), t.dims) - y, logjac_zero(flag, eltype(x)) + y, logjac_zero(flag, extended_eltype(x)) end inverse_eltype(t::ArrayTransform, x::AbstractArray) = @@ -152,7 +152,8 @@ $(SIGNATURES) Helper function for transforming tuples. Used internally, to help type inference. Use via `transfom_tuple`. """ -_transform_tuple(flag::LogJacFlag, x::RealVector, index, ::Tuple{}) = (), logjac_zero(flag, eltype(x)) +_transform_tuple(flag::LogJacFlag, x::RealVector, index, ::Tuple{}) = + (), logjac_zero(flag, extended_eltype(x)) function _transform_tuple(flag::LogJacFlag, x::RealVector, index, ts) tfirst = first(ts) diff --git a/src/special_arrays.jl b/src/special_arrays.jl index 275da1f..653d2a0 100644 --- a/src/special_arrays.jl +++ b/src/special_arrays.jl @@ -41,8 +41,9 @@ end dimension(t::UnitVector) = t.n - 1 -function transform_with(flag::LogJacFlag, t::UnitVector, x::RealVector{T}) where T +function transform_with(flag::LogJacFlag, t::UnitVector, x::RealVector) @unpack n = t + T = extended_eltype(x) r = one(T) y = Vector{T}(undef, n) ℓ = logjac_zero(flag, T) @@ -57,7 +58,7 @@ function transform_with(flag::LogJacFlag, t::UnitVector, x::RealVector{T}) where y, ℓ end -inverse_eltype(t::UnitVector, y::RealVector) = float(eltype(y)) +inverse_eltype(t::UnitVector, y::RealVector) = extended_eltype(y) function inverse!(x::RealVector, t::UnitVector, y::RealVector) @unpack n = t @@ -93,11 +94,11 @@ end dimension(t::CorrCholeskyFactor) = unit_triangular_dimension(t.n) -function transform_with(flag::LogJacFlag, t::CorrCholeskyFactor, - x::RealVector{T}) where T +function transform_with(flag::LogJacFlag, t::CorrCholeskyFactor, x::RealVector) @unpack n = t + T = extended_eltype(x) ℓ = logjac_zero(flag, T) - U = zeros(typeof(√one(T)), n, n) + U = Matrix{T}(undef, n, n) index = firstindex(x) @inbounds for col in 1:n r = one(T) @@ -112,7 +113,7 @@ function transform_with(flag::LogJacFlag, t::CorrCholeskyFactor, UpperTriangular(U), ℓ end -inverse_eltype(t::CorrCholeskyFactor, U::UpperTriangular) = float(eltype(U)) +inverse_eltype(t::CorrCholeskyFactor, U::UpperTriangular) = extended_eltype(U) function inverse!(x::RealVector, t::CorrCholeskyFactor, U::UpperTriangular) @unpack n = t diff --git a/src/utilities.jl b/src/utilities.jl index 72e1c39..78a2d2f 100644 --- a/src/utilities.jl +++ b/src/utilities.jl @@ -1,5 +1,6 @@ - -# logistic and logit +### +### logistic and logit +### logistic(x::Real) = inv(one(x) + exp(-x)) @@ -12,8 +13,9 @@ logit(x::Real) = log(x / (one(x) - x)) logit_logjac(y) = -log(y) - log1p(-y) - -# calculations +### +### calculations +### """ $SIGNATURES @@ -22,8 +24,27 @@ Number of elements (strictly) above the diagonal in an ``n×n`` matrix. """ unit_triangular_dimension(n::Int) = n * (n-1) ÷ 2 - -# view management +### +### type calculations +### + +""" + $(SIGNATURES) + +Extend element type of argument so that it is closed under the algebra used by this package. + +Pessimistic default for non-real types. +""" +function extended_eltype(::Type{S}) where S + T = eltype(S) + T <: Real ? typeof(√(one(T))) : Any +end + +extended_eltype(x::T) where T = extended_eltype(T) + +### +### view management +### """ $SIGNATURES @@ -32,8 +53,9 @@ A view of `v` starting from `i` for `len` elements, no bounds checking. """ view_into(v::AbstractVector, i, len) = @inbounds view(v, i:(i+len-1)) - -# macros +### +### macros +### """ $(SIGNATURES) @@ -57,6 +79,10 @@ macro calltrans(ex) end end +#### +#### random values +#### + "Shared part of docstrings for keyword arguments of or passed to [`random_reals`](@ref)." const _RANDOM_REALS_KWARGS_DOC = """ A standard multivaritate normal or Cauchy is used, depending on `cauchy`, then scaled with diff --git a/test/runtests.jl b/test/runtests.jl index 5e4aa51..ddac45b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -288,6 +288,13 @@ end @test g2.value == v.value @test g2.gradient ≈ g1.gradient + # test element type calculations for Flux + t2 = CorrCholeskyFactor(4) + @test t2(Flux.param(ones(dimension(t2)))) isa UpperTriangular + + t3 = UnitVector(3) + @test sum(abs2, t3(Flux.param(ones(dimension(t3))))) ≈ Flux.param(1.0) + # ReverseDiff P3 = ADgradient(:ReverseDiff, P) g3 = @inferred logdensity(ValueGradient, P3, x)