@@ -32,6 +32,10 @@ function eliminate_dimensions(x::AbstractArray{T, N}, ix::AbstractVector{L}, el:
32
32
@assert length (ix) == N
33
33
return x[eliminated_selector (size (x), ix, el. first, el. second)... ]
34
34
end
35
# `eliminate_dimensions` specialization for `RescaledArray`: apply the
# elimination to the normalized payload only, carrying the scalar
# `log_factor` through unchanged (index selection does not alter the
# overall rescaling of the tensor).
function eliminate_dimensions(x::RescaledArray{T, N}, ix::AbstractVector{L}, el::Pair{<:AbstractVector{L}, <:AbstractVector}) where {T, N, L}
    return RescaledArray(x.log_factor, eliminate_dimensions(x.normalized_value, ix, el))
end
38
+
35
39
function eliminated_size (size0, ix, labels)
36
40
@assert length (size0) == length (ix)
37
41
return ntuple (length (ix)) do i
@@ -53,7 +57,7 @@ function eliminate_dimensions_addbatch!(x::AbstractArray{T, N}, ix::AbstractVect
53
57
@assert length (ix) == N
54
58
res = similar (x, (eliminated_size (size (x), ix, el. first)... , nbatch))
55
59
for ibatch in 1 : nbatch
56
- selectdim (res, N+ 1 , ibatch) . = eliminate_dimensions (x, ix, el. first=> view (el. second, :, ibatch))
60
+ copyto! ( selectdim (res, N+ 1 , ibatch), eliminate_dimensions (x, ix, el. first=> view (el. second, :, ibatch) ))
57
61
end
58
62
push! (ix, batch_label)
59
63
return res
@@ -63,7 +67,7 @@ function eliminate_dimensions_withbatch(x::AbstractArray{T, N}, ix::AbstractVect
63
67
@assert length (ix) == N && size (x, N) == nbatch
64
68
res = similar (x, (eliminated_size (size (x), ix, el. first)))
65
69
for ibatch in 1 : nbatch
66
- selectdim (res, N, ibatch) . = eliminate_dimensions (selectdim (x, N, ibatch), ix[1 : end - 1 ], el. first=> view (el. second, :, ibatch))
70
+ copyto! ( selectdim (res, N, ibatch), eliminate_dimensions (selectdim (x, N, ibatch), ix[1 : end - 1 ], el. first=> view (el. second, :, ibatch) ))
67
71
end
68
72
return res
69
73
end
@@ -79,28 +83,28 @@ Returns a vector of vectors, each element being a configuration defined on `get_
79
83
* `n` is the number of samples to be returned.
80
84
81
85
### Keyword Arguments
86
+ * `rescale` is a boolean flag to indicate whether to rescale the tensors during contraction.
82
87
* `usecuda` is a boolean flag to indicate whether to use CUDA for tensor computation.
83
88
* `queryvars` is the variables to be sampled, default is `get_vars(tn)`.
84
89
"""
85
- function sample (tn:: TensorNetworkModel , n:: Int ; usecuda = false , queryvars = get_vars (tn)):: Samples
90
+ function sample (tn:: TensorNetworkModel , n:: Int ; usecuda = false , queryvars = get_vars (tn), rescale :: Bool = false ):: Samples
86
91
# generate tropical tensors with its elements being log(p).
87
- xs = adapt_tensors (tn; usecuda, rescale = false )
92
+ xs = adapt_tensors (tn; usecuda, rescale)
88
93
# infer size from the contraction code and the input tensors `xs`, returns a label-size dictionary.
89
94
size_dict = OMEinsum. get_size_dict! (getixsv (tn. code), xs, Dict {Int, Int} ())
90
95
# forward compute and cache intermediate results.
91
96
cache = cached_einsum (tn. code, xs, size_dict)
92
97
# initialize `y̅` as the initial batch of samples.
93
98
iy = getiyv (tn. code)
94
99
idx = map (l-> findfirst (== (l), queryvars), iy ∩ queryvars)
95
- indices = StatsBase. sample (CartesianIndices (size (cache. content)), _Weights ( vec ( cache. content) ), n)
100
+ indices = StatsBase. sample (CartesianIndices (size (cache. content)), _weight ( cache. content), n)
96
101
configs = zeros (Int, length (queryvars), n)
97
102
for i= 1 : n
98
103
configs[idx, i] .= indices[i]. I .- 1
99
104
end
100
105
samples = Samples (configs, queryvars)
101
106
# back-propagate
102
- env = similar (cache. content, (size (cache. content)... , n)) # batched env
103
- fill! (env, one (eltype (env)))
107
+ env = ones_like (cache. content, n)
104
108
batch_label = _newindex (OMEinsum. uniquelabels (tn. code))
105
109
code = deepcopy (tn. code)
106
110
iy_env = [OMEinsum. getiyv (code)... , batch_label]
@@ -115,10 +119,22 @@ function sample(tn::TensorNetworkModel, n::Int; usecuda = false, queryvars = get
115
119
end
116
120
# Generate a fresh axis label for the batch dimension, guaranteed not to
# collide with any existing label.
# For integer/character labels: one past the largest label in use.
function _newindex(labels::AbstractVector{<:Union{Int, Char}})
    return maximum(labels) + 1
end
# For symbolic labels: a freshly generated, globally unique symbol.
_newindex(::AbstractVector{Symbol}) = gensym(:batch)
118
# Convert an array of sampled amplitudes into StatsBase `Weights` for use with
# `StatsBase.sample`. Weights are taken from the flattened normalized values
# (`_normvec` drops a `RescaledArray`'s uniform log-factor, which does not
# change the relative sampling distribution).
_weight(x::AbstractArray{<:Real}) = Weights(_normvec(x))
function _weight(_x::AbstractArray{<:Complex})
    x = _normvec(_x)
    # Probabilities must be numerically real: tolerate tiny imaginary noise
    # from floating-point contraction, reject anything larger. Use an explicit
    # throw rather than `@assert`, which is a debug-only check that may be
    # disabled at higher optimization levels.
    all(e -> abs(imag(e)) < max(100 * eps(abs(e)), 1e-8), x) ||
        throw(ArgumentError("Complex probability encountered: $x"))
    return _weight(real.(x))
end
128
# Flatten an array to a vector for use as sampling weights.
_normvec(x::AbstractArray) = vec(x)
# For a `RescaledArray`, flatten only the normalized payload; the uniform
# `log_factor` rescaling is dropped, which leaves relative weights unchanged.
_normvec(x::RescaledArray) = vec(x.normalized_value)
+
131
# Allocate an array shaped like `x` with one extra trailing axis of length
# `n` (a batch axis), filled with the multiplicative identity of the
# element type. Preserves the storage kind of `x` via `similar` (e.g. GPU
# arrays stay on device).
function ones_like(x::AbstractArray{T}, n::Int) where {T}
    out = similar(x, (size(x)..., n))
    return fill!(out, one(T))
end
136
# `ones_like` for `RescaledArray`: an all-ones array carries no overall
# scale, so the log-factor is zero and the payload is a plain ones array
# (with the extra trailing batch axis of length `n`).
function ones_like(x::RescaledArray, n::Int)
    return RescaledArray(zero(x.log_factor), ones_like(x.normalized_value, n))
end
123
139
124
140
function generate_samples! (se:: SlicedEinsum , cache:: CacheTree{T} , iy_env:: Vector{Int} , env:: AbstractArray{T} , samples:: Samples{L} , pool, batch_label:: L , size_dict:: Dict{L} ) where {T, L}
@@ -177,7 +193,7 @@ function update_samples!(labels, sample, vars::AbstractVector{L}, probabilities:
177
193
@assert length (vars) == N
178
194
totalset = CartesianIndices (probabilities)
179
195
eliminated_locs = idx4labels (labels, vars)
180
- config = StatsBase. sample (totalset, _Weights ( vec ( probabilities) ))
196
+ config = StatsBase. sample (totalset, _weight ( probabilities))
181
197
sample[eliminated_locs] .= config. I .- 1
182
198
end
183
199
0 commit comments