diff --git a/src/obstransform.jl b/src/obstransform.jl index e5c6168..4952668 100644 --- a/src/obstransform.jl +++ b/src/obstransform.jl @@ -159,12 +159,12 @@ end # joinumobs -struct JoinedData{T,N} <: AbstractDataContainer - datas::NTuple{N,T} +struct JoinedData{T<:Tuple,N} <: AbstractDataContainer + datas::T ns::NTuple{N,Int} end -JoinedData(datas) = JoinedData(datas, numobs.(datas)) +JoinedData(datas::Tuple) = JoinedData(datas, numobs.(datas)) Base.length(data::JoinedData) = sum(data.ns) @@ -194,7 +194,12 @@ jdata = joinumobs(data1, data2) getobs(jdata, 15) == 15 ``` """ -joinobs(datas...) = JoinedData(datas) +joinobs(datas...) = JoinedData(cleanjoin(datas...)) + +cleanjoin(x::JoinedData, ys...) = (x.datas..., cleanjoin(ys...)...) +cleanjoin(x, ys...) = (x, cleanjoin(ys...)...) +cleanjoin() = () + """ shuffleobs([rng], data) diff --git a/test/obstransform.jl b/test/obstransform.jl index 5b6ec7a..f94ea32 100644 --- a/test/obstransform.jl +++ b/test/obstransform.jl @@ -94,6 +94,27 @@ end @test data[5:6] == [5, 6] data = joinobs(ones(2, 3), zeros(2, 3)) @test data[3:4] == [[1.0, 1.0], [0.0, 0.0]] + + @testset "joins of joins" begin + data1, data2 = 1:10, 11:20 + data12 = joinobs(data1, data2) + data3 = 21:30 + data123 = joinobs(data12, data3) + @test getobs(data123, 15) == 15 + @test getobs(data123, 25) == 25 + @test length(data123) == 30 + @test data123.datas[1] == data1 + @test data123.datas[2] == data2 + @test data123.datas[3] == data3 + end + + @testset "join different types" begin + data1 = 1:5 + data2 = ones(2, 3) + data12 = joinobs(data1, data2) + @test data12[3] == 3 + @test data12[6] == [1.0, 1.0] + end end @testset "shuffleobs" begin