@@ -82,6 +82,7 @@ struct Options{T}
8282 initial_state:: Dict{Symbol, Float64}
8383 # The sampling scheme to use on the forward pass.
8484 sampling_scheme:: AbstractSamplingScheme
85+ backward_sampling_scheme:: AbstractBackwardSamplingScheme
8586 # Storage for the set of possible sampling states at each node. We only use
8687 # this if there is a cycle in the policy graph.
8788 starting_states:: Dict{T, Vector{Dict{Symbol, Float64}}}
@@ -100,12 +101,14 @@ struct Options{T}
100101 function Options (model:: PolicyGraph{T} ,
101102 initial_state:: Dict{Symbol, Float64} ,
102103 sampling_scheme:: AbstractSamplingScheme ,
104+ backward_sampling_scheme:: AbstractBackwardSamplingScheme ,
103105 risk_measures,
104106 cycle_discretization_delta:: Float64 ,
105107 refine_at_similar_nodes:: Bool ) where {T, S}
106108 return new {T} (
107109 initial_state,
108110 sampling_scheme,
111+ backward_sampling_scheme,
109112 to_nodal_form (model, x -> Dict{Symbol, Float64}[]),
110113 to_nodal_form (model, risk_measures),
111114 cycle_discretization_delta,
@@ -477,7 +480,9 @@ function backward_pass(
477480 belief == 0.0 && continue
478481 solve_all_children (
479482 model, model[node_index], items, belief, belief_state,
480- objective_state, outgoing_state, scenario_path[1 : index])
483+ objective_state, outgoing_state,
484+ options. backward_sampling_scheme,
485+ scenario_path[1 : index])
481486 end
482487 # We need to refine our estimate at all nodes in the partition.
483488 for node_index in model. belief_partition[partition_index]
@@ -501,7 +506,9 @@ function backward_pass(
501506 end
502507 solve_all_children (
503508 model, node, items, 1.0 , belief_state, objective_state,
504- outgoing_state, scenario_path[1 : index])
509+ outgoing_state, options. backward_sampling_scheme,
510+ scenario_path[1 : index]
511+ )
505512 refine_bellman_function (
506513 model, node, node. bellman_function,
507514 options. risk_measures[node_index], outgoing_state,
@@ -545,13 +552,19 @@ struct BackwardPassItems{T, U}
545552end
546553
547554function solve_all_children (
548- model:: PolicyGraph{T} , node:: Node{T} , items:: BackwardPassItems ,
549- belief:: Float64 , belief_state, objective_state,
550- outgoing_state:: Dict{Symbol, Float64} , scenario_path) where {T}
555+ model:: PolicyGraph{T} , node:: Node{T} , items:: BackwardPassItems ,
556+ belief:: Float64 , belief_state, objective_state,
557+ outgoing_state:: Dict{Symbol, Float64} ,
558+ backward_sampling_scheme:: AbstractBackwardSamplingScheme ,
559+ scenario_path
560+ ) where {T}
551561 length_scenario_path = length (scenario_path)
552562 for child in node. children
563+ if isapprox (child. probability, 0.0 , atol= 1e-6 )
564+ continue
565+ end
553566 child_node = model[child. term]
554- for noise in child_node. noise_terms
567+ for noise in sample_backward_noise_terms (backward_sampling_scheme, child_node)
555568 if length (scenario_path) == length_scenario_path
556569 push! (scenario_path, (child. term, noise. term))
557570 else
@@ -620,6 +633,9 @@ function calculate_bound(model::PolicyGraph{T},
620633
621634 # Solve all problems that are children of the root node.
622635 for child in model. root_children
636+ if isapprox (child. probability, 0.0 , atol= 1e-6 )
637+ continue
638+ end
623639 node = model[child. term]
624640 for noise in node. noise_terms
625641 if node. objective_state != = nothing
@@ -745,6 +761,9 @@ Train the policy for `model`. Keyword arguments:
745761 - `sampling_scheme`: a sampling scheme to use on the forward pass of the
746762 algorithm. Defaults to [`InSampleMonteCarlo`](@ref).
747763
764+ - `backward_sampling_scheme`: a backward pass sampling scheme to use on the
765+ backward pass of the algorithm. Defaults to `CompleteSampler`.
766+
748767 - `cut_type`: choose between `SDDP.SINGLE_CUT` and `SDDP.MULTI_CUT` versions of SDDP.
749768
750769 - `dashboard::Bool`: open a visualization of the training over time. Defaults
@@ -770,6 +789,7 @@ function train(
770789 cycle_discretization_delta:: Float64 = 0.0 ,
771790 refine_at_similar_nodes:: Bool = true ,
772791 cut_deletion_minimum:: Int = 1 ,
792+ backward_sampling_scheme:: AbstractBackwardSamplingScheme = SDDP. CompleteSampler (),
773793 dashboard:: Bool = false
774794)
775795 # Reset the TimerOutput.
@@ -812,6 +832,7 @@ function train(
812832 model,
813833 model. initial_root_state,
814834 sampling_scheme,
835+ backward_sampling_scheme,
815836 risk_measure,
816837 cycle_discretization_delta,
817838 refine_at_similar_nodes