
Commit 2bd398d

Support rescaled array in sampling (#103)
* support rescaled array in sampling
* fix docs
1 parent 86aa5fd commit 2bd398d

6 files changed: +204 -13 lines changed

docs/make.jl

Lines changed: 5 additions & 2 deletions

@@ -5,6 +5,8 @@ using Documenter, Literate
 using Pkg

 # Literate Examples
+const DRAFT = get(ENV, "DRAFT", "false") == "true"
+@show DRAFT
 const EXAMPLE_DIR = pkgdir(TensorInference, "examples")
 const LITERATE_GENERATED_DIR = pkgdir(TensorInference, "docs", "src", "generated")
 mkpath(LITERATE_GENERATED_DIR)
@@ -19,7 +21,7 @@ for each in readdir(EXAMPLE_DIR)
     # build
     input_file = joinpath(workdir, "main.jl")
     @info "building" input_file
-    Literate.markdown(input_file, workdir; execute=true)
+    Literate.markdown(input_file, workdir; execute=!DRAFT)
     # restore environment
     # Pkg.activate(Pkg.PREV_ENV_PATH[])
 end
@@ -30,7 +32,7 @@ for each in EXTRA_JL
     cp(joinpath(SRC_DIR, each), joinpath(LITERATE_GENERATED_DIR, each); force=true)
     input_file = joinpath(LITERATE_GENERATED_DIR, each)
     @info "building" input_file
-    Literate.markdown(input_file, LITERATE_GENERATED_DIR; execute=true)
+    Literate.markdown(input_file, LITERATE_GENERATED_DIR; execute=!DRAFT)
 end

 DocMeta.setdocmeta!(TensorInference, :DocTestSetup, :(using TensorInference); recursive=true)
@@ -68,6 +70,7 @@ makedocs(;
     ],
     doctest = false,
     warnonly = :missing_docs,
+    draft = DRAFT,
 )

 deploydocs(;
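
With this change, a fast draft build of the docs can presumably be run by setting the DRAFT environment variable before including the build script: Literate then skips executing the examples (execute = !DRAFT) and Documenter runs in draft mode. A sketch of the assumed workflow, from the repository root:

    # Draft docs build (assumed workflow): skip example execution and
    # enable Documenter's draft mode via the DRAFT environment variable.
    ENV["DRAFT"] = "true"
    include(joinpath("docs", "make.jl"))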

src/RescaledArray.jl

Lines changed: 6 additions & 0 deletions

@@ -16,6 +16,12 @@ Base.show(io::IO, ::MIME"text/plain", c::RescaledArray) = Base.show(io, c)
 Base.Array(c::RescaledArray) = rmul!(Array(c.normalized_value), exp(c.log_factor))
 Base.copy(c::RescaledArray) = RescaledArray(c.log_factor, copy(c.normalized_value))
 Base.getindex(r::RescaledArray, indices...) = map(x->x * exp(r.log_factor), getindex(r.normalized_value, indices...))
+Base.similar(r::RescaledArray, ::Type{T}, dims::Dims) where {T} = RescaledArray(r.log_factor, similar(r.normalized_value, T, dims))
+Base.selectdim(r::RescaledArray, d::Int, i::Int) = RescaledArray(r.log_factor, selectdim(r.normalized_value, d, i))
+function Base.copyto!(dest::RescaledArray, src::RescaledArray)
+    dest.normalized_value .= exp(src.log_factor - dest.log_factor) .* src.normalized_value
+    return dest
+end

 """
 $(TYPEDSIGNATURES)
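
The subtle method here is copyto!: the destination keeps its own log factor, so the source data is multiplied by exp(src.log_factor - dest.log_factor) to preserve the represented values. A minimal sketch of that invariant, with illustrative numbers:

    using TensorInference

    src  = RescaledArray(2.0, [0.1 0.2; 0.3 0.4])  # represents exp(2.0) .* [0.1 0.2; 0.3 0.4]
    dest = RescaledArray(0.0, zeros(2, 2))
    copyto!(dest, src)
    @assert Array(dest) ≈ Array(src)  # same values, despite differing log factors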

src/sampling.jl

Lines changed: 27 additions & 11 deletions

@@ -32,6 +32,10 @@ function eliminate_dimensions(x::AbstractArray{T, N}, ix::AbstractVector{L}, el:
     @assert length(ix) == N
     return x[eliminated_selector(size(x), ix, el.first, el.second)...]
 end
+function eliminate_dimensions(x::RescaledArray{T, N}, ix::AbstractVector{L}, el::Pair{<:AbstractVector{L}, <:AbstractVector}) where {T, N, L}
+    return RescaledArray(x.log_factor, eliminate_dimensions(x.normalized_value, ix, el))
+end
+
 function eliminated_size(size0, ix, labels)
     @assert length(size0) == length(ix)
     return ntuple(length(ix)) do i
@@ -53,7 +57,7 @@ function eliminate_dimensions_addbatch!(x::AbstractArray{T, N}, ix::AbstractVect
     @assert length(ix) == N
     res = similar(x, (eliminated_size(size(x), ix, el.first)..., nbatch))
     for ibatch in 1:nbatch
-        selectdim(res, N+1, ibatch) .= eliminate_dimensions(x, ix, el.first=>view(el.second, :, ibatch))
+        copyto!(selectdim(res, N+1, ibatch), eliminate_dimensions(x, ix, el.first=>view(el.second, :, ibatch)))
     end
     push!(ix, batch_label)
     return res
@@ -63,7 +67,7 @@ function eliminate_dimensions_withbatch(x::AbstractArray{T, N}, ix::AbstractVect
     @assert length(ix) == N && size(x, N) == nbatch
     res = similar(x, (eliminated_size(size(x), ix, el.first)))
     for ibatch in 1:nbatch
-        selectdim(res, N, ibatch) .= eliminate_dimensions(selectdim(x, N, ibatch), ix[1:end-1], el.first=>view(el.second, :, ibatch))
+        copyto!(selectdim(res, N, ibatch), eliminate_dimensions(selectdim(x, N, ibatch), ix[1:end-1], el.first=>view(el.second, :, ibatch)))
     end
     return res
 end
@@ -79,28 +83,28 @@ Returns a vector of vector, each element being a configurations defined on `get_
 * `n` is the number of samples to be returned.

 ### Keyword Arguments
+* `rescale` is a boolean flag to indicate whether to rescale the tensors during contraction.
 * `usecuda` is a boolean flag to indicate whether to use CUDA for tensor computation.
 * `queryvars` is the variables to be sampled, default is `get_vars(tn)`.
 """
-function sample(tn::TensorNetworkModel, n::Int; usecuda = false, queryvars = get_vars(tn))::Samples
+function sample(tn::TensorNetworkModel, n::Int; usecuda = false, queryvars = get_vars(tn), rescale::Bool = false)::Samples
     # generate tropical tensors with its elements being log(p).
-    xs = adapt_tensors(tn; usecuda, rescale = false)
+    xs = adapt_tensors(tn; usecuda, rescale)
     # infer size from the contraction code and the input tensors `xs`, returns a label-size dictionary.
     size_dict = OMEinsum.get_size_dict!(getixsv(tn.code), xs, Dict{Int, Int}())
     # forward compute and cache intermediate results.
     cache = cached_einsum(tn.code, xs, size_dict)
     # initialize `y̅` as the initial batch of samples.
     iy = getiyv(tn.code)
     idx = map(l->findfirst(==(l), queryvars), iy ∩ queryvars)
-    indices = StatsBase.sample(CartesianIndices(size(cache.content)), _Weights(vec(cache.content)), n)
+    indices = StatsBase.sample(CartesianIndices(size(cache.content)), _weight(cache.content), n)
     configs = zeros(Int, length(queryvars), n)
     for i=1:n
         configs[idx, i] .= indices[i].I .- 1
     end
     samples = Samples(configs, queryvars)
     # back-propagate
-    env = similar(cache.content, (size(cache.content)..., n)) # batched env
-    fill!(env, one(eltype(env)))
+    env = ones_like(cache.content, n)
     batch_label = _newindex(OMEinsum.uniquelabels(tn.code))
     code = deepcopy(tn.code)
     iy_env = [OMEinsum.getiyv(code)..., batch_label]
@@ -115,10 +119,22 @@ function sample(tn::TensorNetworkModel, n::Int; usecuda = false, queryvars = get
 end
 _newindex(labels::AbstractVector{<:Union{Int, Char}}) = maximum(labels) + 1
 _newindex(::AbstractVector{Symbol}) = gensym(:batch)
-_Weights(x::AbstractVector{<:Real}) = Weights(x)
-function _Weights(x::AbstractArray{<:Complex})
+_weight(x::AbstractArray{<:Real}) = Weights(_normvec(x))
+function _weight(_x::AbstractArray{<:Complex})
+    x = _normvec(_x)
     @assert all(e->abs(imag(e)) < max(100*eps(abs(e)), 1e-8), x) "Complex probability encountered: $x"
-    return Weights(real.(x))
+    return _weight(real.(x))
+end
+_normvec(x::AbstractArray) = vec(x)
+_normvec(x::RescaledArray) = vec(x.normalized_value)
+
+function ones_like(x::AbstractArray{T}, n::Int) where {T}
+    res = similar(x, (size(x)..., n))
+    fill!(res, one(eltype(res)))
+    return res
+end
+function ones_like(x::RescaledArray, n::Int)
+    return RescaledArray(zero(x.log_factor), ones_like(x.normalized_value, n))
 end

 function generate_samples!(se::SlicedEinsum, cache::CacheTree{T}, iy_env::Vector{Int}, env::AbstractArray{T}, samples::Samples{L}, pool, batch_label::L, size_dict::Dict{L}) where {T, L}
@@ -177,7 +193,7 @@ function update_samples!(labels, sample, vars::AbstractVector{L}, probabilities:
     @assert length(vars) == N
     totalset = CartesianIndices(probabilities)
     eliminated_locs = idx4labels(labels, vars)
-    config = StatsBase.sample(totalset, _Weights(vec(probabilities)))
+    config = StatsBase.sample(totalset, _weight(probabilities))
     sample[eliminated_locs] .= config.I .- 1
 end
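
Taken together, these changes let the whole sampling pipeline run on RescaledArray inputs: eliminate_dimensions and selectdim pass the log factor through, copyto! reconciles differing log factors, _weight builds the sampling weights from the normalized values only (the global factor exp(log_factor) is a positive constant, so it cancels when the weights are normalized), and ones_like seeds the batched environment with a zero log factor. The default rescale = false preserves the previous behavior. A usage sketch mirroring the new regression test below (sizes are illustrative):

    using TensorInference, Random

    Random.seed!(42)
    mps = random_matrix_product_state(Float64, 20, 4)  # a small MPS model
    samples = sample(mps, 100; rescale = true)         # opt-in rescaled contraction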

test/RescaledArray.jl

Lines changed: 152 additions & 0 deletions

@@ -0,0 +1,152 @@
+using Test
+using TensorInference
+using OMEinsum
+
+@testset "RescaledArray" begin
+    # Test basic construction
+    @testset "Construction" begin
+        α = 2.0
+        T = [1.0 2.0; 3.0 4.0]
+        r = RescaledArray(α, T)
+
+        @test r.log_factor == α
+        @test r.normalized_value == T
+        @test size(r) == (2, 2)
+        @test size(r, 1) == 2
+        @test size(r, 2) == 2
+    end
+
+    # Test rescale_array function
+    @testset "rescale_array" begin
+        T = [1.0 2.0; 3.0 4.0]
+        r = TensorInference.rescale_array(T)
+
+        # Maximum absolute value should be 1 in normalized_value
+        @test maximum(abs, r.normalized_value) ≈ 1.0
+
+        # Original array should be recoverable
+        @test Array(r) ≈ T
+
+        # Test with zero array
+        zero_T = zeros(2, 2)
+        r_zero = TensorInference.rescale_array(zero_T)
+        @test r_zero.log_factor == 0.0
+        @test r_zero.normalized_value == zero_T
+    end
+
+    # Test Array conversion
+    @testset "Array conversion" begin
+        α = 1.5
+        T = [0.5 1.0; 0.25 0.75]
+        r = RescaledArray(α, T)
+
+        expected = exp(α) * T
+        @test Array(r) ≈ expected
+    end
+
+    # Test indexing
+    @testset "Indexing" begin
+        α = 0.5
+        T = [1.0 2.0; 3.0 4.0]
+        r = RescaledArray(α, T)
+
+        @test r[1, 1] ≈ T[1, 1] * exp(α)
+        @test r[2, 2] ≈ T[2, 2] * exp(α)
+        @test r[1:2, 1] ≈ T[1:2, 1] * exp(α)
+    end
+
+    # Test copy
+    @testset "Copy" begin
+        α = 1.0
+        T = [1.0 2.0; 3.0 4.0]
+        r = RescaledArray(α, T)
+        r_copy = copy(r)
+
+        @test r_copy.log_factor == r.log_factor
+        @test r_copy.normalized_value == r.normalized_value
+        @test r_copy.normalized_value !== r.normalized_value # Different objects
+    end
+
+    # Test selectdim
+    @testset "selectdim" begin
+        T = reshape(Float64.(1:8), 2, 2, 2) # Convert to Float64 to match log factor type
+        α = 0.5
+        r = RescaledArray(α, T)
+
+        r_slice = selectdim(r, 3, 1)
+        @test r_slice.log_factor == α
+        @test r_slice.normalized_value == selectdim(T, 3, 1)
+    end
+
+    # Test einsum operations
+    @testset "Einsum operations" begin
+        # Create two rescaled arrays
+        α1, α2 = 1.0, 1.5
+        T1 = [1.0 0.5; 0.25 1.0]
+        T2 = [0.5 1.0; 1.0 0.5]
+
+        r1 = RescaledArray(α1, T1)
+        r2 = RescaledArray(α2, T2)
+
+        # Test matrix multiplication via einsum
+        code = ein"ij,jk->ik"
+        result = einsum(code, (r1, r2))
+
+        # Compare with regular array multiplication
+        expected_array = Array(r1) * Array(r2)
+        @test Array(result) ≈ expected_array
+
+        # The log factor should be the sum of input log factors plus rescaling
+        @test result isa RescaledArray
+    end
+
+    # Test fill! and conj
+    @testset "fill! and conj" begin
+        α = 0.5
+        T = [1.0 2.0; 3.0 4.0]
+        r = RescaledArray(α, copy(T))
+
+        # Test fill!
+        fill!(r, 2.0)
+        expected_fill_value = 2.0 / exp(α)
+        @test all(x -> x ≈ expected_fill_value, r.normalized_value)
+
+        # Test conj with complex numbers
+        α_complex = 1.0 + 0.5im
+        T_complex = [1.0+1.0im 2.0+2.0im; 3.0+3.0im 4.0+4.0im]
+        r_complex = RescaledArray(α_complex, T_complex)
+        r_conj = conj(r_complex)
+
+        @test r_conj.log_factor == conj(α_complex)
+        @test r_conj.normalized_value == conj(T_complex)
+    end
+
+    # Test show methods
+    @testset "Display" begin
+        α = 1.0
+        T = [1.0 2.0]
+        r = RescaledArray(α, T)
+
+        # Test that show methods don't error
+        @test sprint(show, r) isa String
+        @test sprint(show, "text/plain", r) isa String
+    end
+
+    # Test copyto!
+    @testset "copyto!" begin
+        α = 2.0
+        T = [1.0 2.0; 3.0 4.0]
+        r = RescaledArray(α, T)
+        r_copy = similar(r)
+        copyto!(r_copy, r)
+        @test Array(r_copy) ≈ Array(r)
+
+        α = 2.0
+        T = [1.0 2.0; 3.0 4.0]
+        r = RescaledArray(α, T)
+        r_copy = similar(r)
+        copyto!(selectdim(r_copy, 1, 1), selectdim(r, 1, 1))
+        @test Array(r_copy)[1, :] ≈ Array(r)[1, :]
+        @test !(Array(r_copy)[2, :] ≈ Array(r)[2, :])
+    end
+end
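
A note on the "Einsum operations" testset above: Array(r1) * Array(r2) == exp(α1 + α2) .* (T1 * T2), so the result's log factor is expected to absorb α1 + α2 plus whatever renormalization the backend applies to keep the normalized part near unit scale; that is presumably why the test compares values via Array(result) instead of pinning the log factor exactly.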

test/runtests.jl

Lines changed: 4 additions & 0 deletions

@@ -36,6 +36,10 @@ end
     include("fileio.jl")
 end

+@testset "RescaledArray" begin
+    include("RescaledArray.jl")
+end
+
 using CUDA
 if CUDA.functional()
     include("cuda.jl")

test/sampling.jl

Lines changed: 10 additions & 0 deletions

@@ -86,4 +86,14 @@ end
     entropy(probs) = -sum(probs .* log.(probs))
     @show negative_loglikelyhood(probs, indices), entropy(probs)
     @test negative_loglikelyhood(probs, indices) ≈ entropy(probs) atol=1e-1
+end
+
+@testset "issue 102 - support using rescaled array in sampling" begin
+    n = 100
+    chi = 10
+    Random.seed!(140)
+    mps = random_matrix_product_state(Float64, n, chi)
+    mps.tensors[setdiff(1:length(mps.tensors), mps.unity_tensors_idx)] .*= 100
+    samples = sample(mps, 1; rescale = true)
+    @test samples isa TensorInference.Samples
 end
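
Context for this regression test: with roughly a hundred tensors each scaled by 100, the raw contraction value is far beyond Float64 range, so sampling without rescaling would overflow, which is presumably the failure reported in issue 102. With rescale = true, the magnitude lives in the log factor and the sampling weights stay finite.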
