Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
name = "TensorNetworkQuantumSimulator"
uuid = "4de3b72a-362e-43dd-83ff-3f381eda9f9c"
license = "MIT"
version = "0.3.3"
authors = ["JoeyT1994 <jtindall@flatironinstitute.org>", "MSRudolph <manuel.rudolph@web.de>", "and contributors"]
description = "A Julia package for quantum simulation with tensor networks of near-arbitrary topology."
version = "0.3.2"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
7 changes: 5 additions & 2 deletions src/Apply/apply_gates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,10 @@ function apply_gate!(
envs = length(v⃗) == 1 ? nothing : incoming_messages(ψ_bpc, v⃗)

updated_tensors, s_values, err = simple_update(gate, network(ψ_bpc), v⃗; envs, apply_kwargs...)

invalidate_sequences = false
if length(v⃗) == 2
old_bond_dim = dim(commonind(network(ψ_bpc)[v⃗[1]], network(ψ_bpc)[v⃗[2]]))
new_dim = commonind(updated_tensors[1], updated_tensors[2])
v1, v2 = v⃗
e = NamedEdge(v1 => v2)
ind2 = commonind(s_values, first(updated_tensors))
Expand All @@ -118,10 +120,11 @@ function apply_gate!(
s_values = denseblocks(s_values) * denseblocks(δuv)
setmessage!(ψ_bpc, e, dag(s_values))
setmessage!(ψ_bpc, reverse(e), s_values)
invalidate_sequences = dim(new_dim) != old_bond_dim
end

for (i, v) in enumerate(v⃗)
setindex_preserve!(ψ_bpc, updated_tensors[i], v)
setindex_preserve!(ψ_bpc, updated_tensors[i], v; invalidate_sequences)
end

return ψ_bpc, err
Expand Down
40 changes: 34 additions & 6 deletions src/MessagePassing/abstractbeliefpropagationcache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ abstract type AbstractBeliefPropagationCache{V} <: AbstractNamedGraph{V} end

#Interface
messages(bp_cache::AbstractBeliefPropagationCache) = not_implemented()
contraction_sequences(bp_cache::AbstractBeliefPropagationCache) = not_implemented()
default_messages() = Dictionary{NamedEdge, Union{ITensor, Vector{ITensor}}}()

function rescale_messages!(
Expand Down Expand Up @@ -45,7 +46,6 @@ for f in [
:(maxvirtualdim),
:(default_message),
:(siteinds),
:(setindex_preserve!),
]
@eval begin
function $f(bp_cache::AbstractBeliefPropagationCache, args...; kwargs...)
Expand All @@ -54,6 +54,31 @@ for f in [
end
end

function invalidate_contraction_sequences!(bp_cache::AbstractBeliefPropagationCache, vertex)
seq_cache = contraction_sequences(bp_cache)
isnothing(seq_cache) && return bp_cache
for key in collect(keys(seq_cache))
if first(key) == vertex
delete!(seq_cache, key)
end
end
return bp_cache
end

function invalidate_contraction_sequences!(bp_cache::AbstractBeliefPropagationCache)
seq_cache = contraction_sequences(bp_cache)
!isnothing(seq_cache) && empty!(seq_cache)
return bp_cache
end

function setindex_preserve!(bp_cache::AbstractBeliefPropagationCache, value::ITensor, vertex; invalidate_sequences = true)
if invalidate_sequences
invalidate_contraction_sequences!(bp_cache, vertex)
end
setindex_preserve!(network(bp_cache), value, vertex)
return bp_cache
end

#Forward onto the graph
for f in [
:(NamedGraphs.edgetype),
Expand Down Expand Up @@ -157,12 +182,15 @@ function updated_message(
)
state = bp_factors(bp_cache, vertex)
contract_list = ITensor[incoming_ms; state]
sequence = contraction_sequence(contract_list; alg = alg.kwargs.sequence_alg)
updated_message = contract(contract_list; sequence)

if alg.kwargs.enforce_hermiticity
updated_message = make_hermitian(updated_message)
cache_key = vertex => edge
seq_cache = contraction_sequences(bp_cache)
if haskey(seq_cache, cache_key)
sequence = seq_cache[cache_key]
else
sequence = contraction_sequence(contract_list; alg = alg.kwargs.sequence_alg)
set!(seq_cache, cache_key, sequence)
end
updated_message = contract(contract_list; sequence)

if alg.kwargs.normalize
message_norm = sum(updated_message)
Expand Down
25 changes: 16 additions & 9 deletions src/MessagePassing/beliefpropagationcache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ struct BeliefPropagationCache{V, N <: AbstractTensorNetwork{V}, M <: Union{ITens
AbstractBeliefPropagationCache{V}
network::N
messages::Dictionary{NamedEdge, M}
contraction_sequences::Dictionary{Pair, Vector}
edge_sequence::Vector
end

function message_diff(message_a::ITensor, message_b::ITensor)
Expand All @@ -22,17 +24,24 @@ messages(bp_cache::BeliefPropagationCache) = bp_cache.messages
network(bp_cache::BeliefPropagationCache) = bp_cache.network
graph(bp_cache::BeliefPropagationCache) = graph(network(bp_cache))

function BeliefPropagationCache(network, messages, contraction_sequences)
return BeliefPropagationCache(network, messages, contraction_sequences, forest_cover_edge_sequence(graph(network)))
end
BeliefPropagationCache(network, messages) = BeliefPropagationCache(network, messages, Dictionary{Pair, Vector}())
BeliefPropagationCache(network) = BeliefPropagationCache(network, default_messages())

contraction_sequences(bp_cache::BeliefPropagationCache) = bp_cache.contraction_sequences

function Base.copy(bp_cache::BeliefPropagationCache)
return BeliefPropagationCache(copy(network(bp_cache)), copy(messages(bp_cache)))
return BeliefPropagationCache(copy(network(bp_cache)), copy(messages(bp_cache)), copy(contraction_sequences(bp_cache)), copy(edge_sequence(bp_cache)))
end

default_bp_maxiter(g::AbstractGraph) = is_tree(g) ? 1 : _default_bp_update_maxiter

#TODO: Get subgraph working on an TensorNetwork to overload this directly
function default_bp_edge_sequence(bp_cache::BeliefPropagationCache)
return forest_cover_edge_sequence(graph(bp_cache))
edge_sequence(bp_cache::BeliefPropagationCache) = bp_cache.edge_sequence

function set_edge_sequence(bp_cache::BeliefPropagationCache, edge_sequence::Vector)
return BeliefPropagationCache(network(bp_cache), messages(bp_cache), contraction_sequences(bp_cache), edge_sequence)
end

function edge_scalar(bp_cache::BeliefPropagationCache, edge::AbstractEdge)
Expand All @@ -44,24 +53,22 @@ default_update_alg(bp_cache::BeliefPropagationCache) = "bp"
default_message_update_alg(bp_cache::BeliefPropagationCache) = "contract"
default_normalize(::Algorithm"contract") = true
default_sequence_alg(::Algorithm"contract") = "optimal"
default_enforce_hermicity(::Algorithm"contract", bp_cache::AbstractBeliefPropagationCache) = false
function set_default_kwargs(alg::Algorithm"contract", bp_cache::AbstractBeliefPropagationCache)
normalize = get(alg.kwargs, :normalize, default_normalize(alg))
sequence_alg = get(alg.kwargs, :sequence_alg, default_sequence_alg(alg))
enforce_hermiticity = get(alg.kwargs, :enforce_hermiticity, default_enforce_hermicity(alg, bp_cache))
return Algorithm("contract"; normalize, sequence_alg, enforce_hermiticity)
return Algorithm("contract"; normalize, sequence_alg)
end
default_verbose(::Algorithm"bp") = false
default_tolerance(::Algorithm"bp") = nothing
function set_default_kwargs(alg::Algorithm"bp", bp_cache::BeliefPropagationCache)
verbose = get(alg.kwargs, :verbose, default_verbose(alg))
maxiter = get(alg.kwargs, :maxiter, default_bp_maxiter(bp_cache))
edge_sequence = get(alg.kwargs, :edge_sequence, default_bp_edge_sequence(bp_cache))
_edge_sequence = get(alg.kwargs, :edge_sequence, edge_sequence(bp_cache))
tolerance = get(alg.kwargs, :tolerance, default_tolerance(alg))
message_update_alg = set_default_kwargs(
get(alg.kwargs, :message_update_alg, Algorithm(default_message_update_alg(bp_cache))), bp_cache
)
return Algorithm("bp"; verbose, maxiter, edge_sequence, tolerance, message_update_alg)
return Algorithm("bp"; verbose, maxiter, edge_sequence = _edge_sequence, tolerance, message_update_alg)
end

function update_message!(
Expand Down
14 changes: 9 additions & 5 deletions src/MessagePassing/boundarympscache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,20 @@ struct BoundaryMPSCache{V, N <: AbstractTensorNetwork{V}, M <: Union{ITensor, Ve
supergraph::PartitionedGraph
sorted_edges::Dictionary{PartitionEdge, Vector{NamedEdge}}
mps_bond_dimension::Integer
contraction_sequences::Dictionary{Pair, Vector}
end

default_update_alg(bmps_cache::BoundaryMPSCache) = "bp"
function set_default_kwargs(alg::Algorithm"bp", bmps_cache::BoundaryMPSCache)
maxiter = get(alg.kwargs, :maxiter, default_bp_maxiter(bmps_cache))
edge_sequence = get(alg.kwargs, :edge_sequence, default_bp_edge_sequence(bmps_cache))
edge_sequence = get(alg.kwargs, :edge_sequence, bp_edge_sequence(bmps_cache))
message_update_alg = set_default_kwargs(
get(alg.kwargs, :message_update_alg, Algorithm(default_message_update_alg(bmps_cache))), bmps_cache
)
return Algorithm("bp"; maxiter, edge_sequence, message_update_alg, tolerance = nothing)
end

function default_bp_edge_sequence(bmps_cache::BoundaryMPSCache)
function bp_edge_sequence(bmps_cache::BoundaryMPSCache)
return PartitionEdge.(forest_cover_edge_sequence(partitions_graph(supergraph(bmps_cache))))
end
default_bp_maxiter(bmps_cache::BoundaryMPSCache) = is_tree(partitions_graph(supergraph(bmps_cache))) ? 1 : 5
Expand Down Expand Up @@ -99,13 +100,16 @@ for f in [
end
end

contraction_sequences(bmps_cache::BoundaryMPSCache) = bmps_cache.contraction_sequences

function Base.copy(bmps_cache::BoundaryMPSCache)
return BoundaryMPSCache(
copy(network(bmps_cache)),
copy(messages(bmps_cache)),
copy(supergraph(bmps_cache)),
copy(sorted_edges(bmps_cache)),
mps_bond_dimension(bmps_cache),
copy(contraction_sequences(bmps_cache)),
)
end

Expand Down Expand Up @@ -159,7 +163,7 @@ function BoundaryMPSCache(
sorted_es = Dictionary{PartitionEdge, Vector{NamedEdge}}(pes, Vector{NamedEdge}[sorted_edges(supergraph, pe) for pe in pes])

messages = default_messages()
bmps_cache = BoundaryMPSCache(tn, messages, supergraph, sorted_es, mps_bond_dimension)
bmps_cache = BoundaryMPSCache(tn, messages, supergraph, sorted_es, mps_bond_dimension, Dictionary{Pair, Vector}())
@assert is_correct_format(bmps_cache)
set_messages && set_interpartition_messages!(bmps_cache, pes)

Expand Down Expand Up @@ -226,7 +230,7 @@ end

function update_partition!(bmps_cache::BoundaryMPSCache, seq::Vector)
isempty(seq) && return bmps_cache
alg = set_default_kwargs(Algorithm("contract", normalize = false, enforce_hermiticity = false), bmps_cache)
alg = set_default_kwargs(Algorithm("contract", normalize = false), bmps_cache)
for e in seq
m = updated_message(alg, bmps_cache, e)
setmessage!(bmps_cache, e, m)
Expand Down Expand Up @@ -304,7 +308,7 @@ function extracter(
bmps_cache::BoundaryMPSCache,
update_e::NamedEdge
)
message_update_alg = set_default_kwargs(Algorithm("contract"; normalize = false, enforce_hermiticity = false), bmps_cache)
message_update_alg = set_default_kwargs(Algorithm("contract"; normalize = false), bmps_cache)
m = updated_message(message_update_alg, bmps_cache, update_e)
return m
end
Expand Down
1 change: 1 addition & 0 deletions src/contraction_sequences.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ function contraction_sequence(::Algorithm"optimal", tensors::Vector{<:ITensor};
#Converting dims to Float64 to minimize overflow issues
inds_to_dims = Dict(i => Float64(dim(i)) for i in unique(Iterators.flatten(network)))
seq, _ = optimaltree(network, inds_to_dims)
seq = typeof(seq) <: Int ? [seq] : seq
return seq
end

Expand Down
3 changes: 2 additions & 1 deletion src/sampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ function get_one_sample(
kwargs...,
)
norm_bmps_cache = copy(norm_bmps_cache)
invalidate_contraction_sequences!(norm_bmps_cache)
cutoff, maxdim = 1.0e-10, projected_mps_bond_dimension

bit_string = Dictionary{keytype(vertices(network(norm_bmps_cache))), Int}()
Expand Down Expand Up @@ -248,7 +249,7 @@ function sample_partition!(
q = ρ_diag[config]
logq += log(q)
Pψv = copy(network(norm_bmps_cache)[v]) * inv(sqrt(q)) * P
setindex_preserve!(norm_bmps_cache, Pψv, v)
setindex_preserve!(norm_bmps_cache, Pψv, v; invalidate_sequences = false)
prev_v = v
end

Expand Down
Loading