Skip to content

Commit 1e88935

Browse files
roporte authored and odow committed
Add backward pass sampling scheme (#224)
Conceived and implemented by @roporte. Closes #177
1 parent 1386f61 commit 1e88935

File tree

6 files changed

+147
-6
lines changed

6 files changed

+147
-6
lines changed

src/SDDP.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ include("plugins/risk_measures.jl")
4040
include("plugins/sampling_schemes.jl")
4141
include("plugins/bellman_functions.jl")
4242
include("plugins/stopping_rules.jl")
43+
include("plugins/backward_sampling_schemes.jl")
4344

4445
# Visualization related code.
4546
include("visualization/publication_plot.jl")

src/algorithm.jl

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ struct Options{T}
8282
initial_state::Dict{Symbol, Float64}
8383
# The sampling scheme to use on the forward pass.
8484
sampling_scheme::AbstractSamplingScheme
85+
backward_sampling_scheme::AbstractBackwardSamplingScheme
8586
# Storage for the set of possible sampling states at each node. We only use
8687
# this if there is a cycle in the policy graph.
8788
starting_states::Dict{T, Vector{Dict{Symbol, Float64}}}
@@ -100,12 +101,14 @@ struct Options{T}
100101
function Options(model::PolicyGraph{T},
101102
initial_state::Dict{Symbol, Float64},
102103
sampling_scheme::AbstractSamplingScheme,
104+
backward_sampling_scheme::AbstractBackwardSamplingScheme,
103105
risk_measures,
104106
cycle_discretization_delta::Float64,
105107
refine_at_similar_nodes::Bool) where {T, S}
106108
return new{T}(
107109
initial_state,
108110
sampling_scheme,
111+
backward_sampling_scheme,
109112
to_nodal_form(model, x -> Dict{Symbol, Float64}[]),
110113
to_nodal_form(model, risk_measures),
111114
cycle_discretization_delta,
@@ -477,7 +480,9 @@ function backward_pass(
477480
belief == 0.0 && continue
478481
solve_all_children(
479482
model, model[node_index], items, belief, belief_state,
480-
objective_state, outgoing_state, scenario_path[1:index])
483+
objective_state, outgoing_state,
484+
options.backward_sampling_scheme,
485+
scenario_path[1:index])
481486
end
482487
# We need to refine our estimate at all nodes in the partition.
483488
for node_index in model.belief_partition[partition_index]
@@ -501,7 +506,9 @@ function backward_pass(
501506
end
502507
solve_all_children(
503508
model, node, items, 1.0, belief_state, objective_state,
504-
outgoing_state, scenario_path[1:index])
509+
outgoing_state, options.backward_sampling_scheme,
510+
scenario_path[1:index]
511+
)
505512
refine_bellman_function(
506513
model, node, node.bellman_function,
507514
options.risk_measures[node_index], outgoing_state,
@@ -545,13 +552,19 @@ struct BackwardPassItems{T, U}
545552
end
546553

547554
function solve_all_children(
548-
model::PolicyGraph{T}, node::Node{T}, items::BackwardPassItems,
549-
belief::Float64, belief_state, objective_state,
550-
outgoing_state::Dict{Symbol, Float64}, scenario_path) where {T}
555+
model::PolicyGraph{T}, node::Node{T}, items::BackwardPassItems,
556+
belief::Float64, belief_state, objective_state,
557+
outgoing_state::Dict{Symbol, Float64},
558+
backward_sampling_scheme::AbstractBackwardSamplingScheme,
559+
scenario_path
560+
) where {T}
551561
length_scenario_path = length(scenario_path)
552562
for child in node.children
563+
if isapprox(child.probability, 0.0, atol=1e-6)
564+
continue
565+
end
553566
child_node = model[child.term]
554-
for noise in child_node.noise_terms
567+
for noise in sample_backward_noise_terms(backward_sampling_scheme, child_node)
555568
if length(scenario_path) == length_scenario_path
556569
push!(scenario_path, (child.term, noise.term))
557570
else
@@ -620,6 +633,9 @@ function calculate_bound(model::PolicyGraph{T},
620633

621634
# Solve all problems that are children of the root node.
622635
for child in model.root_children
636+
if isapprox(child.probability, 0.0, atol=1e-6)
637+
continue
638+
end
623639
node = model[child.term]
624640
for noise in node.noise_terms
625641
if node.objective_state !== nothing
@@ -745,6 +761,9 @@ Train the policy for `model`. Keyword arguments:
745761
- `sampling_scheme`: a sampling scheme to use on the forward pass of the
746762
algorithm. Defaults to [`InSampleMonteCarlo`](@ref).
747763
764+
- `backward_sampling_scheme`: a backward pass sampling scheme to use on the
765+
backward pass of the algorithm. Defaults to [`CompleteSampler`](@ref).
766+
748767
- `cut_type`: choose between `SDDP.SINGLE_CUT` and `SDDP.MULTI_CUT` versions of SDDP.
749768
750769
- `dashboard::Bool`: open a visualization of the training over time. Defaults
@@ -770,6 +789,7 @@ function train(
770789
cycle_discretization_delta::Float64 = 0.0,
771790
refine_at_similar_nodes::Bool = true,
772791
cut_deletion_minimum::Int = 1,
792+
backward_sampling_scheme::AbstractBackwardSamplingScheme = SDDP.CompleteSampler(),
773793
dashboard::Bool = false
774794
)
775795
# Reset the TimerOutput.
@@ -812,6 +832,7 @@ function train(
812832
model,
813833
model.initial_root_state,
814834
sampling_scheme,
835+
backward_sampling_scheme,
815836
risk_measure,
816837
cycle_discretization_delta,
817838
refine_at_similar_nodes
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Copyright 2017-19, Oscar Dowson and contributors.
2+
# This Source Code Form is subject to the terms of the Mozilla Public
3+
# License, v. 2.0. If a copy of the MPL was not distributed with this
4+
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
5+
6+
"""
    CompleteSampler()

A backward sampling scheme that enumerates every noise term attached to a
child node, so the backward pass solves each child for all of its
realizations.
"""
struct CompleteSampler <: AbstractBackwardSamplingScheme end

# The complete sampler hands back the node's full noise list unchanged.
function sample_backward_noise_terms(::CompleteSampler, node)
    return node.noise_terms
end
14+
15+
16+
"""
    MonteCarloSampler(number_of_samples::Int)

A backward sampling scheme that draws `number_of_samples` noise terms with
replacement from the noise terms of the corresponding node, assigning each
draw the uniform probability `1 / number_of_samples`.

Throws an `ArgumentError` if `number_of_samples` is not positive.
"""
struct MonteCarloSampler <: AbstractBackwardSamplingScheme
    number_of_samples::Int

    function MonteCarloSampler(number_of_samples::Int)
        # Guard against a zero or negative sample count, which would
        # otherwise silently produce an empty sample set (and a nonsensical
        # per-sample probability of Inf or a negative value).
        number_of_samples >= 1 || throw(ArgumentError(
            "number_of_samples must be positive, got $(number_of_samples)."))
        return new(number_of_samples)
    end
end

function sample_backward_noise_terms(sampler::MonteCarloSampler, node::Node)
    # Each draw is re-weighted uniformly so the probabilities of the
    # returned vector sum to one.
    prob = 1 / sampler.number_of_samples
    return [
        Noise(sample_noise(InSampleMonteCarlo(), node.noise_terms), prob)
        for _ in 1:sampler.number_of_samples
    ]
end

src/plugins/headers.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,3 +133,26 @@ function convergence_test(graph::PolicyGraph, log::Vector{Log},
133133
end
134134
return false, :not_solved
135135
end
136+
137+
# ============================== backward_samplers =========================== #

"""
    AbstractBackwardSamplingScheme

The abstract type for the backward-pass sampling scheme interface. Concrete
subtypes choose which noise realizations of a child node are solved on the
backward pass of the algorithm.

You need to define the following methods:
 - [`SDDP.sample_backward_noise_terms`](@ref)
"""
abstract type AbstractBackwardSamplingScheme end
148+
149+
"""
    sample_backward_noise_terms(
        backward_sampling_scheme::AbstractBackwardSamplingScheme,
        node::Node{T}
    )::Vector{Noise}

Return a `Vector{Noise}` of noises sampled from `node.noise_terms` using
`backward_sampling_scheme`.
"""
function sample_backward_noise_terms end

test/algorithm.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ using SDDP, Test, GLPK
2323
model,
2424
Dict(:x => 1.0),
2525
SDDP.InSampleMonteCarlo(),
26+
SDDP.CompleteSampler(),
2627
SDDP.Expectation(),
2728
0.0,
2829
true
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Copyright 2017-19, Oscar Dowson.
2+
# This Source Code Form is subject to the terms of the Mozilla Public
3+
# License, v. 2.0. If a copy of the MPL was not distributed with this
4+
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
5+
6+
using SDDP, Test
7+
8+
@testset "CompleteSampler" begin
    # Two-stage toy model; each node carries two equally likely noise terms.
    model = SDDP.LinearPolicyGraph(
        stages = 2, lower_bound = 0.0, direct_mode = false
    ) do subproblem, t
        @variable(subproblem, 0 <= x <= 1)
        SDDP.parameterize(subproblem, t * [1, 3], [0.5, 0.5]) do ω
            JuMP.set_upper_bound(x, ω)
        end
    end
    # The complete sampler must return the node's noise terms unchanged.
    sampled = SDDP.sample_backward_noise_terms(SDDP.CompleteSampler(), model[1])
    @test sampled == model[1].noise_terms
end
20+
21+
@testset "MonteCarloSampler(1)" begin
    # One-stage model with an asymmetric (0.9 / 0.1) noise distribution.
    model = SDDP.LinearPolicyGraph(
        stages = 1, lower_bound = 0.0, direct_mode = false
    ) do subproblem, t
        @variable(subproblem, 0 <= x <= 1)
        SDDP.parameterize(subproblem, [1, 3], [0.9, 0.1]) do ω
            JuMP.set_upper_bound(x, ω)
        end
    end
    # Draw a single noise 100 times: every draw carries probability 1.0, and
    # the 0.9-weighted term should dominate the net tally.
    term_count = 0
    for _ in 1:100
        sampled = SDDP.sample_backward_noise_terms(
            SDDP.MonteCarloSampler(1), model[1]
        )
        @test sampled[1].probability == 1.0
        term_count +=
            sampled[1].term == model[1].noise_terms[1].term ? 1 : -1
    end
    @test term_count > 20
end
42+
43+
@testset "MonteCarloSampler(100)" begin
    # One-stage model with an asymmetric (0.9 / 0.1) noise distribution.
    model = SDDP.LinearPolicyGraph(
        stages = 1, lower_bound = 0.0, direct_mode = false
    ) do subproblem, t
        @variable(subproblem, 0 <= x <= 1)
        SDDP.parameterize(subproblem, [1, 3], [0.9, 0.1]) do ω
            JuMP.set_upper_bound(x, ω)
        end
    end
    # 100 draws with replacement: each must carry probability 1/100, and the
    # 0.9-weighted term should dominate the net tally.
    sampled = SDDP.sample_backward_noise_terms(
        SDDP.MonteCarloSampler(100), model[1]
    )
    term_count = 0
    for noise in sampled
        @test noise.probability == 0.01
        term_count += noise.term == model[1].noise_terms[1].term ? 1 : -1
    end
    @test term_count > 20
end

0 commit comments

Comments
 (0)