diff --git a/chirho/contrib/__init__.py b/chirho/contrib/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/chirho/contrib/compexp/__init__.py b/chirho/contrib/compexp/__init__.py
new file mode 100644
index 00000000..11237f5e
--- /dev/null
+++ b/chirho/contrib/compexp/__init__.py
@@ -0,0 +1,8 @@
+from .composeable_expectation.composed_expectation import ComposedExpectation
+from .composeable_expectation.expectation_atom import ExpectationAtom
+from .handlers.expectation_handler import ExpectationHandler
+from .handlers.importance_sampling_expectation_handler import ImportanceSamplingExpectationHandler
+from .handlers.montecarlo_expectation_handler import MonteCarloExpectationHandler
+from .typedecs import StochasticFunction, ExpectationFunction
+from .composeable_expectation import grad
+E = ExpectationAtom
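+
+# A minimal usage sketch (comments only, so nothing runs on import; `my_model`
+# and `my_cost` are hypothetical placeholders, not part of this package):
+#
+#     e = E(f=my_cost, name="cost")           # atom representing E_p[my_cost(s)]
+#     ce = (e + e) / e                        # atoms compose via +, -, *, /
+#     with MonteCarloExpectationHandler(num_samples=1000):
+#         estimate = ce(my_model)             # returns a scalar tensor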
diff --git a/chirho/contrib/compexp/composeable_expectation/__init__.py b/chirho/contrib/compexp/composeable_expectation/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/chirho/contrib/compexp/composeable_expectation/composed_expectation.py b/chirho/contrib/compexp/composeable_expectation/composed_expectation.py
new file mode 100644
index 00000000..726b8a55
--- /dev/null
+++ b/chirho/contrib/compexp/composeable_expectation/composed_expectation.py
@@ -0,0 +1,55 @@
+from typing import Callable, List, TYPE_CHECKING
+from ..typedecs import ModelType
+import torch
+from torch import Tensor as TT
+if TYPE_CHECKING:
+ from .expectation_atom import ExpectationAtom
+
+
+class ComposedExpectation:
+ def __init__(
+ self,
+ children: List["ComposedExpectation"],
+            op: Callable[..., TT],
+ parts: List["ExpectationAtom"]
+ ):
+ self.op = op
+ self.children = children
+ self.parents: List[ComposedExpectation] = []
+ self.parts = parts
+
+ self._normalization_constant_cancels = False
+
+ def recursively_refresh_parts(self):
+ self.parts = None
+
+ for child in self.children:
+ child.recursively_refresh_parts()
+
+ self.parts = []
+ for child in self.children:
+ self.parts.extend(child.parts)
+
+ def __op_other(self, other: "ComposedExpectation", op) -> "ComposedExpectation":
+ ce = ComposedExpectation(
+ children=[self, other],
+ op=op,
+ parts=self.parts + other.parts)
+ self.parents.append(ce)
+ other.parents.append(ce)
+ return ce
+
+ def __truediv__(self, other: "ComposedExpectation") -> "ComposedExpectation":
+ return self.__op_other(other, torch.divide)
+
+ def __add__(self, other: "ComposedExpectation") -> "ComposedExpectation":
+ return self.__op_other(other, torch.add)
+
+ def __mul__(self, other: "ComposedExpectation") -> "ComposedExpectation":
+ return self.__op_other(other, torch.multiply)
+
+ def __sub__(self, other: "ComposedExpectation") -> "ComposedExpectation":
+ return self.__op_other(other, torch.subtract)
+
+ def __call__(self, p: ModelType) -> TT:
+ return self.op(*[child(p) for child in self.children])
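+
+
+# Evaluation sketch: for hypothetical atoms ea1 and ea2, `(ea1 + ea2)(p)`
+# recurses as `torch.add(ea1(p), ea2(p))`; each leaf atom's __call__ defers to
+# the expectation handler currently in effect via _compute_expectation_atom.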
diff --git a/chirho/contrib/compexp/composeable_expectation/expectation_atom.py b/chirho/contrib/compexp/composeable_expectation/expectation_atom.py
new file mode 100644
index 00000000..6cc8dc8e
--- /dev/null
+++ b/chirho/contrib/compexp/composeable_expectation/expectation_atom.py
@@ -0,0 +1,93 @@
+import pyro
+from typing import Optional
+import torch
+from torch import tensor as tt
+from torch import Tensor as TT
+from ..typedecs import ModelType, KWType
+from .composed_expectation import ComposedExpectation
+from ..typedecs import StochasticFunction
+from ..ops import _compute_expectation_atom
+
+
+class ExpectationAtom(ComposedExpectation):
+
+ def __init__(
+ self,
+            f: StochasticFunction,  # TODO note in docstring that this must have scalar-valued output.
+ name: str,
+ log_fac_eps: float = 1e-45,
+ guide: Optional[ModelType] = None):
+
+ self.f = f
+ self.name = name
+ self.log_fac_eps = log_fac_eps
+ self.guide = guide
+
+ self._is_positive_everywhere = False
+
+ super().__init__(children=[], op=lambda v: v, parts=[self])
+
+ def recursively_refresh_parts(self):
+ self.parts = [self]
+
+ def __call__(self, p: ModelType) -> TT:
+ """
+ Overrides the non-atomic call to actually estimate the value of this expectation atom.
+ """
+ ret = _compute_expectation_atom(self, p)
+ if ret.ndim != 0:
+ raise ValueError(f"Argument f to {ExpectationAtom.__name__} with name {self.name} must return a scalar,"
+ f" but got {ret} instead.")
+ return ret
+
+ def build_pseudo_density(self, p: ModelType) -> ModelType:
+
+ if not self._is_positive_everywhere:
+            raise NotImplementedError("Non-positive pseudo-density construction is not supported. "
+                                      f"Convert atom named {self.name} by using the output of "
+                                      "ExpectationAtom.split_into_positive_components().")
+
+ # This defines a new density that is the product of the density defined by p and this all-positive function
+ # we want to take the expectation wrt.
+ def pseudo_density() -> KWType:
+ stochastics = p()
+ factor_name = self.name + "_factor"
+ pyro.factor(factor_name, torch.log(self.log_fac_eps + self.f(stochastics)))
+ return stochastics
+
+ return pseudo_density
+
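+    # Note: the pseudo-density defined above is proportional to p(x) * f(x); its
+    # normalizing constant is exactly E_p[f(x)], which is what the denominator
+    # atom produced by split_into_positive_components estimates.
+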
+ # TODO maybe rename this to get_tabi_decomposition.
+ def split_into_positive_components(
+ self,
+ # TODO bdt18dosjk maybe don't allow for guide specification, but rather handler specification that
+ # may specify a particular guide arrangement?
+ pos_guide: Optional[ModelType] = None,
+ neg_guide: Optional[ModelType] = None,
+ den_guide: Optional[ModelType] = None) -> "ComposedExpectation":
+
+ pos_part = ExpectationAtom(
+ f=lambda s: torch.relu(self.f(s)), name=self.name + "_split_pos", guide=pos_guide)
+ pos_part._is_positive_everywhere = True
+ neg_part = ExpectationAtom(
+ f=lambda s: torch.relu(-self.f(s)), name=self.name + "_split_neg", guide=neg_guide)
+ neg_part._is_positive_everywhere = True
+ den_part = ExpectationAtom(
+ lambda s: tt(1.), name=self.name + "_split_den", guide=den_guide)
+ den_part._is_positive_everywhere = True
+
+ ret: ComposedExpectation = (pos_part - neg_part) / den_part
+ ret._normalization_constant_cancels = True
+
+ return ret
+
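+    # The decomposition above is the TABI identity: writing f = relu(f) - relu(-f),
+    #     E_p[f] = (E_p[f_pos] - E_p[f_neg]) / E_p[1].
+    # Each atom gets its own guide; with optimal guides the importance-sampled
+    # numerators and denominator share the same normalization error, so the ratio
+    # can be estimated exactly (see _normalization_constant_cancels).
+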
+ def swap_self_for_other_child(self, other):
+ for parent in self.parents:
+ positions_as_child = [i for i, child in enumerate(parent.children) if child is self]
+ assert len(positions_as_child) >= 1, "This shouldn't be possible." \
+ " There's a reference mismatch with parents."
+ # Now, swap out the old atom with the new composite.
+ for pac in positions_as_child:
+ parent.children[pac] = other
+ other.parents.append(parent)
+ self.parents = []
diff --git a/chirho/contrib/compexp/composeable_expectation/grad.py b/chirho/contrib/compexp/composeable_expectation/grad.py
new file mode 100644
index 00000000..8b5311ad
--- /dev/null
+++ b/chirho/contrib/compexp/composeable_expectation/grad.py
@@ -0,0 +1,113 @@
+from typing import Callable
+from ..typedecs import KWType
+import torch
+from torch import Tensor as TT
+import warnings
+from .expectation_atom import ExpectationAtom
+from .composed_expectation import ComposedExpectation
+
+
+def _build_df_dd(dparams: TT, di: int, part: ExpectationAtom) -> Callable[[KWType], TT]:
+ def df_dd(stochastics: KWType) -> TT:
+ y: TT = part.f(stochastics)
+
+ if y.ndim != 0:
+ raise ValueError(f"Argument f to {ExpectationAtom.__name__} with name {part.name} must return a scalar,"
+ f" but got {y} instead.")
+
+ assert dparams[di].ndim == 0, "This shouldn't be possible due to the outer check of 1 dimension."
+
+ try:
+ df_ddparam, = torch.autograd.grad(
+ outputs=(y,),
+ # FIXME HACK Have to grad wrt the whole tensor apparently, and then index after.
+ inputs=(dparams,),
+ create_graph=True
+ )
+ df_ddparam = df_ddparam[di]
+ except RuntimeError as e:
+ if "does not require grad and does not have a grad_fn" in str(e):
+                # FIXME FIXME kf2801dgi1 this is only correct when this particular atom is a mul or div of one
+                #  that does require grad. It would be nice not to have to reproduce autodiff here, but somehow
+                #  this needs to know how the parent incorporates this atom. Maybe we could autodiff the
+                #  parent's op and see how this atom relates, then use that to determine what
+                #  should be returned here?
+ warnings.warn(f"The gradient of atom named {part.name} with respect to dparam {di}"
+ f" is 0.0, but returning the original atom's value for now because"
+ f" it's probably a scaling factor. This is a hack and should be fixed."
+ f" See FIXME tagged kf2801dgi1.")
+ return y
+ else:
+ raise
+
+ assert df_ddparam.ndim == 0, "This shouldn't be possible due to out and in being 0 dimensional."
+
+ return df_ddparam
+
+ return df_dd
+
+
+# FIXME 7301ykd0sk See below. Want to convert to a proper in-place operation with no return value.
+def gradify_in_place_but_need_return(
+ output: ComposedExpectation, dparams: TT, split_atoms=False) -> ComposedExpectation:
+
+ if dparams.ndim != 1:
+ raise ValueError(f"Argument dparams to {gradify_in_place_but_need_return.__name__} must be a 1d tensor, "
+ f"but got ndim {dparams.ndim} instead.")
+
+ assert len(output.parts) >= 1, "This shouldn't be possible due to composites always having at least one " \
+ "part (themselves)."
+
+ if not len(dparams) >= 1:
+ raise ValueError(f"Argument dparams to {gradify_in_place_but_need_return.__name__} must have at least one "
+ f"element, but got {len(dparams)} instead.")
+
+ # Only relevant if output is an atom. Just defining outside of loop so type checking is happy below.
+ sub_atom_composite = None
+
+ for part in output.parts:
+
+ sub_atoms = []
+
+ # Create a new atom for each of the old atoms.
+ for di, _ in enumerate(dparams):
+
+ # Create a new atom just for just this element of the gradient vector.
+ ea = ExpectationAtom(
+ f=_build_df_dd(dparams, di, part),
+ name=f"d{part.name}_dd{di}",
+ log_fac_eps=part.log_fac_eps
+ # TODO maybe seed a new guide with the original guide (if present)?
+ )
+
+ if split_atoms:
+ ea = ea.split_into_positive_components()
+
+ sub_atoms.append(ea)
+
+ # Create a composite that simply concatenates the new atoms into one tensor.
+ sub_atom_composite = ComposedExpectation(
+ children=sub_atoms,
+ op=lambda *v: torch.stack(v, dim=0),
+ # Note bm72gdi1: This will be updated before the return of this function.
+ parts=[]
+ )
+
+ for parent in part.parents:
+ positions_as_child = [i for i, child in enumerate(parent.children) if child is part]
+ assert len(positions_as_child) >= 1, "This shouldn't be possible." \
+ " There's a reference mismatch with parents."
+ # Now, swap out the old atom with the new composite.
+ for pac in positions_as_child:
+ parent.children[pac] = sub_atom_composite
+ sub_atom_composite.parents.append(parent)
+
+ if isinstance(output, ExpectationAtom):
+        assert len(output.parts) == 1 and output.parts[0] is output, "This shouldn't be possible: atom code broken?"
+ # FIXME 7301ykd0sk this is why you have to take the return value here...
+ output = sub_atom_composite
+
+ # Note bm72gdi1 this ensures that the cached part list is up-to-date.
+ output.recursively_refresh_parts()
+
+ return output
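+
+
+# Usage sketch (hypothetical `theta` and `model`; a sketch of intended use, not
+# part of the API surface): rewrites an expectation tree in place so that it
+# evaluates to the gradient vector d E_p[f] / d theta.
+#
+#     theta = torch.nn.Parameter(torch.tensor([0.5]))
+#     ce = ExpectationAtom(f=lambda s: theta[0] * s["x"], name="f")
+#     ce = gradify_in_place_but_need_return(ce, dparams=theta, split_atoms=False)
+#     # ce(model) now returns a length-1 tensor estimating d E_p[f] / d theta.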
diff --git a/chirho/contrib/compexp/handlers/__init__.py b/chirho/contrib/compexp/handlers/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/chirho/contrib/compexp/handlers/expectation_handler.py b/chirho/contrib/compexp/handlers/expectation_handler.py
new file mode 100644
index 00000000..2737ce52
--- /dev/null
+++ b/chirho/contrib/compexp/handlers/expectation_handler.py
@@ -0,0 +1,12 @@
+import pyro
+
+
+class ExpectationHandler(pyro.poutine.messenger.Messenger):
+
+ # noinspection PyMethodMayBeStatic
+ def _pyro_compute_expectation_atom(self, msg) -> None:
+ if msg["done"]:
+ # TODO bdt18dosjk Do something similar to the demo1 setup where we define an OOP interface as well.
+ # Then handler can be passed to expectation atoms so that users can specify different handlers
+ # for different atoms.
+ raise RuntimeError("Only one default expectation handler can be in effect at a time.")
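+
+
+# Note on dispatch: pyro Messengers route an effectful op of type T to a method
+# named `_pyro_T`. Since the op type here is "_compute_expectation_atom" (leading
+# underscore), concrete handlers implement `_pyro__compute_expectation_atom`
+# (double underscore) and call this base method explicitly via super().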
diff --git a/chirho/contrib/compexp/handlers/guide_registration_mixin.py b/chirho/contrib/compexp/handlers/guide_registration_mixin.py
new file mode 100644
index 00000000..18fd7b62
--- /dev/null
+++ b/chirho/contrib/compexp/handlers/guide_registration_mixin.py
@@ -0,0 +1,173 @@
+import pyro
+from typing import Callable, List, Tuple, Type, Optional
+import torch
+from pyro.infer.autoguide import AutoGuide
+from ..typedecs import ModelType, KWType
+from ..composeable_expectation.composed_expectation import ComposedExpectation
+from ..composeable_expectation.expectation_atom import ExpectationAtom
+
+
+class _GuideRegistrationMixin:
+ def __init__(self):
+ self.pseudo_densities = dict()
+ self.guides = dict()
+ self.registered_model = None
+
+ def optimize_guides(self, lr: float, n_steps: int,
+                        adjust_grads_: Optional[Callable[..., None]] = None,
+ callback: Optional[Callable[[str, int], None]] = None,
+ start_scale=1.0):
+ if not len(self.keys()):
+ raise ValueError("No guides registered. Did you call "
+ f"{_GuideRegistrationMixin.__name__}.register_guides?")
+
+ for k in self.keys():
+ pseudo_density = self.pseudo_densities[k]
+ guide = self.guides[k]
+
+ current_scale = start_scale
+
+ def scaling_pseudo_density() -> KWType:
+ return pyro.poutine.scale(pseudo_density, scale=current_scale)()
+
+ # noinspection PyTypeChecker
+ elbo = pyro.infer.Trace_ELBO()(scaling_pseudo_density, guide)
+ elbo() # Call to surface parameters for optimizer.
+ optim = torch.optim.ASGD(elbo.parameters(), lr=lr)
+
+ for i in range(n_steps):
+
+ # Scale from start_scale up to 1.
+ current_scale = start_scale + (1. - start_scale) * (i / n_steps)
+
+ for param in elbo.parameters():
+ param.grad = None
+ optim.zero_grad()
+
+ loss = elbo()
+ loss.backward()
+
+ if adjust_grads_ is not None:
+ adjust_grads_(*tuple(elbo.parameters()))
+
+ if callback is not None:
+ callback(k, i)
+
+ optim.step()
+
+ def register_guides(self, ce: ComposedExpectation, model: ModelType,
+ auto_guide: Optional[Type[AutoGuide]], auto_guide_kwargs=None,
+ allow_repeated_names=False):
+ self.clear_guides()
+
+ if auto_guide_kwargs is None:
+ auto_guide_kwargs = dict()
+ else:
+ if auto_guide is None:
+ raise ValueError("auto_guide_kwargs provided but no auto_guide class provided. Did you mean to "
+ "provide an auto_guide class?")
+
+ for part in ce.parts:
+ pseudo_density = part.build_pseudo_density(model)
+ if part.guide is not None:
+ guide = part.guide
+
+ if tuple(guide().keys()) != tuple(model().keys()):
+ raise ValueError("A preset guide must return the same variables as the model, but got "
+ f"{tuple(guide().keys())} and {tuple(model().keys())} instead.")
+
+ else:
+ if auto_guide is None:
+                    raise ValueError("No guide preregistered and no auto_guide class provided.")
+ guide = auto_guide(model, **auto_guide_kwargs)
+
+ if not allow_repeated_names:
+ if part.name in self.pseudo_densities:
+ raise ValueError(f"Repeated part name {part.name}.")
+ if part.name in self.guides:
+ raise ValueError(f"Repeated part name {part.name}.")
+
+ if part.name not in self.pseudo_densities:
+ self.pseudo_densities[part.name] = pseudo_density
+
+ if part.name not in self.guides:
+ self.guides[part.name] = guide
+
+ self.registered_model = model
+
+ def keys(self) -> frozenset:
+ assert tuple(self.pseudo_densities.keys()) == tuple(self.guides.keys()), "Should not be possible b/c these" \
+ " are added to at the same time."
+ return frozenset(self.pseudo_densities.keys())
+
+ def clear_guides(self):
+ self.pseudo_densities = dict()
+ self.guides = dict()
+ self.registered_model = None
+
+ def _get_pq(self, ea: "ExpectationAtom", p: ModelType) -> Tuple[ModelType, ModelType]:
+ try:
+ q: ModelType = self.guides[ea.name]
+ except KeyError:
+ raise KeyError(f"No guide registered for {ea.name}. "
+ f"Did you call {_GuideRegistrationMixin.__name__}.register_guides?")
+
+ if p is not self.registered_model:
+ raise ValueError("The probability distribution registered with the guides does not match the "
+                             "probability distribution called to compute the expectation. In other words, "
+ f"the same p must be used in {_GuideRegistrationMixin.__name__}"
+ f".register_guides and {ComposedExpectation.__name__}.__call__.")
+
+ return p, q
+
+ def plot_guide_pseudo_likelihood(
+            self, rv_name: str, guide_kde_kwargs, pseudo_density_plot_kwargs, keys: Optional[List[str]] = None):
+ # TODO move this to a separate class and inherit or something, just so plotting code doesn't clutter
+ # up functional code.
+ import seaborn as sns
+ import matplotlib.pyplot as plt
+
+ if not len(self.keys()):
+ raise ValueError("No guides registered. Did you call "
+ f"{_GuideRegistrationMixin.__name__}.register_guides?")
+
+ figs = []
+
+ if keys is None:
+ keys = self.keys()
+
+ for k in keys:
+ pseudo_density = self.pseudo_densities[k]
+ guide = self.guides[k]
+
+ if self.registered_model()[rv_name].ndim != 0:
+ raise ValueError("Can only plot pseudo likelihood/guide comparisons for univariates.")
+
+ fig, ax = plt.subplots(1, 1)
+ sns.kdeplot([guide()[rv_name].item() for _ in range(1000)], label="guide", **guide_kde_kwargs)
+
+ tax = ax.twinx()
+
+ model_samples = torch.tensor([self.registered_model()[rv_name] for _ in range(1000)])
+ xx = torch.linspace(model_samples.min(), model_samples.max(), 1000).detach()
+
+ lps = []
+ for x in xx:
+ cm = pyro.poutine.condition(pseudo_density, data={rv_name: x})
+ lp = pyro.poutine.trace(cm).get_trace().log_prob_sum()
+ lps.append(lp)
+ lps = torch.tensor(lps).exp().detach().numpy()
+
+ # This will be squiggly if there are other latents. TODO smooth?
+ tax.plot(xx, lps, label="pseudo-density", **pseudo_density_plot_kwargs)
+
+ ax.set_title(f"q and pseudo-p for {rv_name} and part {k}")
+
+ # Add single legend for both.
+ lines, labels = ax.get_legend_handles_labels()
+ lines2, labels2 = tax.get_legend_handles_labels()
+ ax.legend(lines + lines2, labels + labels2, loc=0)
+
+ figs.append(fig)
+
+ return figs
diff --git a/chirho/contrib/compexp/handlers/importance_sampling_expectation_handler.py b/chirho/contrib/compexp/handlers/importance_sampling_expectation_handler.py
new file mode 100644
index 00000000..b7114e5f
--- /dev/null
+++ b/chirho/contrib/compexp/handlers/importance_sampling_expectation_handler.py
@@ -0,0 +1,35 @@
+from ..composeable_expectation.expectation_atom import ExpectationAtom
+from .expectation_handler import ExpectationHandler
+from .guide_registration_mixin import _GuideRegistrationMixin
+import pyro
+import torch
+from ..utils import msg_args_kwargs_to_kwargs, kft
+from ..typedecs import ModelType
+
+
+class ImportanceSamplingExpectationHandler(ExpectationHandler, _GuideRegistrationMixin):
+
+ def __init__(self, num_samples: int):
+ super().__init__()
+ self.num_samples = num_samples
+
+ def _pyro__compute_expectation_atom(self, msg) -> None:
+ super()._pyro_compute_expectation_atom(msg)
+
+ kwargs = msg_args_kwargs_to_kwargs(msg)
+
+ ea: ExpectationAtom = kwargs.pop("ea")
+ p: ModelType = kwargs["p"]
+ p, q = self._get_pq(ea, p)
+
+ fpqvals = []
+
+ for _ in range(self.num_samples):
+ qtr = pyro.poutine.trace(q).get_trace()
+ ptr = pyro.poutine.trace(pyro.poutine.replay(p, trace=qtr)).get_trace()
+ s = kft(qtr)
+
+ fpqval = ptr.log_prob_sum() + torch.log(ea.log_fac_eps + ea.f(s)) - qtr.log_prob_sum()
+ fpqvals.append(torch.exp(fpqval))
+
+ msg["value"] = torch.mean(torch.stack(fpqvals), dim=0)
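+
+
+# Each draw above computes an importance-weighted integrand in log space:
+#     exp(log p(s) + log(eps + f(s)) - log q(s)) = f(s) * p(s) / q(s)  (up to eps),
+# so the sample mean estimates E_p[f] when the model trace's log_prob_sum is the
+# normalized log density of p.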
diff --git a/chirho/contrib/compexp/handlers/montecarlo_expectation_handler.py b/chirho/contrib/compexp/handlers/montecarlo_expectation_handler.py
new file mode 100644
index 00000000..47b7ff53
--- /dev/null
+++ b/chirho/contrib/compexp/handlers/montecarlo_expectation_handler.py
@@ -0,0 +1,25 @@
+from .expectation_handler import ExpectationHandler
+import torch
+from ..composeable_expectation.expectation_atom import ExpectationAtom
+from ..typedecs import ModelType
+from ..utils import msg_args_kwargs_to_kwargs
+
+
+class MonteCarloExpectationHandler(ExpectationHandler):
+
+    def __init__(self, num_samples: int):
+        super().__init__()
+        self.num_samples = num_samples
+
+ def _pyro__compute_expectation_atom(self, msg) -> None:
+ super()._pyro_compute_expectation_atom(msg)
+
+ kwargs = msg_args_kwargs_to_kwargs(msg)
+ ea: ExpectationAtom = kwargs.pop("ea")
+ p: ModelType = kwargs["p"]
+
+ fvals = []
+ for _ in range(self.num_samples):
+ s = p()
+ fvals.append(ea.f(s))
+
+ msg["value"] = torch.mean(torch.stack(fvals), dim=0)
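+
+
+# Usage sketch (hypothetical `atom` and `model`): a vanilla Monte Carlo estimate,
+# i.e. the mean of f over num_samples independent draws from p.
+#
+#     with MonteCarloExpectationHandler(num_samples=100):
+#         value = atom(model)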
diff --git a/chirho/contrib/compexp/ops.py b/chirho/contrib/compexp/ops.py
new file mode 100644
index 00000000..95010ced
--- /dev/null
+++ b/chirho/contrib/compexp/ops.py
@@ -0,0 +1,11 @@
+import pyro
+from .handlers.expectation_handler import ExpectationHandler
+from .typedecs import ModelType
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+ from .composeable_expectation.expectation_atom import ExpectationAtom
+
+
+@pyro.poutine.runtime.effectful(type="_compute_expectation_atom")
+def _compute_expectation_atom(ea: "ExpectationAtom", p: ModelType):
+ raise NotImplementedError(f"Must be called in the context of an {ExpectationHandler.__name__}.")
diff --git a/chirho/contrib/compexp/typedecs.py b/chirho/contrib/compexp/typedecs.py
new file mode 100644
index 00000000..0fa6245b
--- /dev/null
+++ b/chirho/contrib/compexp/typedecs.py
@@ -0,0 +1,11 @@
+from typing import Callable
+from torch import Tensor
+from torch.nn import Parameter
+from collections import OrderedDict
+
+KWType = OrderedDict[str, Tensor]
+KWTypeNNParams = OrderedDict[str, Parameter]
+ModelType = Callable[[], KWType]
+
+StochasticFunction = Callable[[KWType], Tensor]
+ExpectationFunction = Callable[[ModelType], Tensor]
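+
+# Sketch of a conforming ModelType (hypothetical; assumes pyro and
+# pyro.distributions as dist are available): a nullary callable that samples
+# with pyro primitives and returns its named values as an OrderedDict of tensors.
+#
+#     def model() -> KWType:
+#         return OrderedDict(x=pyro.sample("x", dist.Normal(0., 1.)))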
diff --git a/chirho/contrib/compexp/utils.py b/chirho/contrib/compexp/utils.py
new file mode 100644
index 00000000..6bab1b1a
--- /dev/null
+++ b/chirho/contrib/compexp/utils.py
@@ -0,0 +1,41 @@
+import torch
+from torch import tensor as tt, Tensor as TT
+from collections import OrderedDict
+import inspect
+from .typedecs import KWType, KWTypeNNParams
+
+
+def flatten_dparams(dparams: KWTypeNNParams) -> TT:
+ return torch.cat([torch.flatten(dparams[k]) for k in dparams.keys()])
+
+
+def unflatten_df_dparams(cat_df_dparams: TT, dparams: KWTypeNNParams) -> KWType:
+ df_dparams = OrderedDict()
+
+ last_fidx = 0
+ for k in dparams.keys():
+ slice_len = dparams[k].numel()
+ cat_df_dparams_slice = cat_df_dparams[last_fidx:last_fidx + slice_len]
+ last_fidx += slice_len
+
+ df_dparams[k] = torch.unflatten(cat_df_dparams_slice, dim=0, sizes=dparams[k].shape)
+
+ return df_dparams
+
+
+def msg_args_kwargs_to_kwargs(msg):
+
+ ba = inspect.signature(msg["fn"]).bind(*msg["args"], **msg["kwargs"])
+ ba.apply_defaults()
+
+ return ba.arguments
+
+
+# Keyword arguments From Trace
+def kft(trace) -> KWType:
+ # Copy the ordereddict.
+ new_trace = OrderedDict(trace.nodes.items())
+ # Remove the _INPUT and _RETURN nodes.
+ del new_trace["_INPUT"]
+ del new_trace["_RETURN"]
+ return OrderedDict(zip(new_trace.keys(), [v["value"] for v in new_trace.values()]))
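+
+
+# Example (hypothetical): for a trace of a model with a single site "x", kft
+# returns OrderedDict(x=<sampled value>), dropping the _INPUT/_RETURN
+# bookkeeping nodes that pyro adds to every trace.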
diff --git a/docs/source/expectation_programming/README.md b/docs/source/expectation_programming/README.md
new file mode 100644
index 00000000..dc35cd6a
--- /dev/null
+++ b/docs/source/expectation_programming/README.md
@@ -0,0 +1,2 @@
+The demo scripts in this folder need to be converted and reduced into notebooks.
+The folder also contains two still-functional older demos that do not use the compositional expectation machinery.
diff --git a/docs/source/expectation_programming/demo_compexp.py b/docs/source/expectation_programming/demo_compexp.py
new file mode 100644
index 00000000..615ffa2e
--- /dev/null
+++ b/docs/source/expectation_programming/demo_compexp.py
@@ -0,0 +1,336 @@
+# TODO Needs supporting stuff ported here. Also break out into notebook and tests.
+
+
+import pyro
+import torch
+from torch import tensor as tt
+from pyro.infer.autoguide import AutoGuide
+
+from toy_tabi_problem import (
+ cost,
+ model as model_ttp,
+ q_optimal_normal_guide_mean_var,
+ MODEL_DIST as MODEL_TTP_DIST
+ )
+import pyro.distributions as dist
+from collections import OrderedDict
+import matplotlib.pyplot as plt
+import numpy as np
+import chirho.contrib.compexp as compexp
+
+import old_ep_demo_scratch as stor
+
+pyro.settings.set(module_local_params=True)
+
+
+def main():
+ # noinspection PyPep8Naming
+ D = tt(0.5)
+ # noinspection PyPep8Naming
+ C = tt(1.0)
+ # noinspection PyPep8Naming
+ GT = -1.1337
+
+ # Make sure we match a ground truth value.
+ toy_opt_guidep = dist.Normal(*q_optimal_normal_guide_mean_var(d=D, c=C, z=False))
+ toy_opt_guiden = dist.Normal(*q_optimal_normal_guide_mean_var(d=D, c=C, z=True))
+
+ # decomposed toy expectation
+ # noinspection PyPep8Naming
+ ttE: compexp.ComposedExpectation = compexp.E(
+ f=lambda s: 1e1*cost(d=tt(0.5), c=tt(1.), **s), name='ttE'
+ ).split_into_positive_components(
+ # TODO bdt18dosjk
+ pos_guide=lambda: OrderedDict(x=pyro.sample('x', toy_opt_guidep)),
+ neg_guide=lambda: OrderedDict(x=pyro.sample('x', toy_opt_guiden)),
+ den_guide=lambda: OrderedDict(x=pyro.sample('x', MODEL_TTP_DIST))
+ )
+
+ # Just to make sure everything still works after this is called.
+ ttE.recursively_refresh_parts()
+
+ #
+ mcestimates = []
+ for _ in range(1000):
+ with compexp.MonteCarloExpectationHandler(num_samples=100):
+ mcestimates.append(ttE(model_ttp))
+
+ plt.suptitle("MC Estimate")
+ plt.hist(mcestimates, bins=30)
+ mcestimate = torch.mean(torch.stack(mcestimates))
+ plt.axvline(x=mcestimate, color='r')
+ plt.axvline(x=GT, color='black', linestyle='--')
+ plt.suptitle(f"MC Estimate: {mcestimate}\n GT = {GT}")
+ plt.show()
+ #
+
+ #
+ # When using the decomposition above with per-atom guides and importance sampling, we get TABI. Because
+ # we've preset the optimal guides above, we will get an exact estimate.
+ iseh = compexp.ImportanceSamplingExpectationHandler(num_samples=1)
+ iseh.register_guides(ce=ttE, model=model_ttp, auto_guide=None, auto_guide_kwargs=None)
+ with iseh:
+ tabiestimate = ttE(model_ttp)
+ print(f"TABI Estimate: {tabiestimate}", f"GT = {GT}")
+ assert torch.isclose(tabiestimate, tt(GT), atol=1e-4)
+ #
+
+ #
+
+ ttE = compexp.E(
+ f=lambda s: 1e1 * cost(d=tt(0.5), c=tt(1.), **s), name='ttE'
+ ).split_into_positive_components()
+
+ iseh2 = compexp.ImportanceSamplingExpectationHandler(num_samples=300)
+ iseh2.register_guides(
+ ce=ttE,
+ model=model_ttp,
+ auto_guide=pyro.infer.autoguide.AutoNormal,
+ auto_guide_kwargs=dict(init_scale=2.))
+
+ def plot_callback_(k, i):
+ if i % 5000 == 0:
+ figs = iseh2.plot_guide_pseudo_likelihood(
+ rv_name='x',
+ guide_kde_kwargs=dict(bw_method=0.1, color='orange'),
+ pseudo_density_plot_kwargs=dict(color='purple'),
+ keys=[k] if k is not None else None
+ )
+ plt.show()
+ for f in figs:
+ plt.close(f)
+
+ iseh2.optimize_guides(
+ lr=5e-3, n_steps=10001,
+ callback=plot_callback_
+ )
+
+ tabi_learned_estimates = []
+ for _ in range(100):
+ with iseh2:
+ tabi_learned_estimates.append(ttE(model_ttp))
+ tabi_learned_estimates = torch.stack(tabi_learned_estimates).detach().numpy()
+
+ plt.suptitle("TABI Gradient Estimate")
+ plt.hist(tabi_learned_estimates, bins=30)
+ plt.axvline(x=np.mean(tabi_learned_estimates), color='r')
+ plt.axvline(x=GT, color='black', linestyle='--')
+ plt.suptitle(f"TABI Gradient Estimate: {np.mean(tabi_learned_estimates)}\n GT = {GT}")
+ plt.show()
+
+ #
+
+ #
+ dps = torch.nn.Parameter(tt([0.5]))
+
+ # noinspection PyPep8Naming
+ ttgradE1: compexp.ComposedExpectation = compexp.E(
+ # f=lambda s: 1e1*cost(d=dps[0] * dps[1], c=tt(1.), **s), name='ttgradE'
+ f=lambda s: 1e1*cost(d=dps[0], c=tt(1.), **s), name='ttgradE'
+ ).split_into_positive_components()
+ ttgradE1 = compexp.grad.gradify_in_place_but_need_return(ttgradE1, dparams=dps)
+
+ mc_grad_estimates = []
+ for _ in range(1000):
+ with compexp.MonteCarloExpectationHandler(num_samples=100):
+ mc_grad_estimates.append(ttgradE1(model_ttp))
+ mc_grad_estimates = torch.stack(mc_grad_estimates).detach().numpy().T[0]
+
+ plt.suptitle("MC Gradient Estimate")
+ plt.hist(mc_grad_estimates, bins=30)
+ plt.axvline(x=np.mean(mc_grad_estimates), color='r')
+ plt.axvline(x=2.358, color='black', linestyle='--')
+ plt.suptitle(f"MC Gradient Estimate: {np.mean(mc_grad_estimates)}\n GT = {2.358}")
+ plt.show()
+ #
+
+ #
+ dps = torch.nn.Parameter(tt([0.25, 2.]))
+
+ # noinspection PyPep8Naming
+ ttgradE2: compexp.ComposedExpectation = compexp.E(
+ f=lambda s: 1e1*cost(d=dps[0] * dps[1], c=tt(1.), **s), name='ttgradE'
+ ).split_into_positive_components()
+    ttgradE2 = compexp.grad.gradify_in_place_but_need_return(ttgradE2, dparams=dps)
+
+ with compexp.MonteCarloExpectationHandler(num_samples=10):
+ print(ttgradE2(model_ttp), "GT unknown but it runs")
+ #
+
+ #
+ dps = torch.nn.Parameter(tt([0.5]))
+
+ ttgrad_tabi_unfit = compexp.E(
+ f=lambda s: 1e1*cost(d=dps[0], c=tt(1.), **s), name='ttgradE'
+ )
+ ttgrad_tabi_unfit = compexp.grad.gradify_in_place_but_need_return(ttgrad_tabi_unfit, dparams=dps, split_atoms=True)
+
+ iseh2 = compexp.ImportanceSamplingExpectationHandler(num_samples=50)
+ iseh2.register_guides(
+ ce=ttgrad_tabi_unfit,
+ model=model_ttp,
+ auto_guide=pyro.infer.autoguide.AutoNormal,
+ auto_guide_kwargs=dict(init_scale=1.5))
+ iseh2.plot_guide_pseudo_likelihood(
+ rv_name='x',
+ guide_kde_kwargs=dict(bw_method=0.1, color='orange'),
+ pseudo_density_plot_kwargs=dict(color='purple')
+ )
+ plt.show()
+
+ tabi_unlearned_grad_estimates = []
+ for _ in range(1000):
+ with iseh2:
+ tabi_unlearned_grad_estimates.append(ttgrad_tabi_unfit(model_ttp))
+ tabi_unlearned_grad_estimates = torch.stack(tabi_unlearned_grad_estimates).detach().numpy()
+
+ plt.suptitle("TABI Gradient Estimate")
+ plt.hist(tabi_unlearned_grad_estimates, bins=30)
+ plt.axvline(x=np.mean(tabi_unlearned_grad_estimates), color='r')
+ plt.axvline(x=2.358, color='black', linestyle='--')
+ plt.suptitle(f"TABI Gradient Estimate: {np.mean(tabi_unlearned_grad_estimates)}\n GT = {2.358}")
+ plt.show()
+ #
+
+ #
+ dps = torch.nn.Parameter(tt([0.5]))
+
+ ttgrad_tabi_unfit = compexp.E(
+ f=lambda s: 1e1 * cost(d=dps[0], c=tt(1.), **s), name='ttgradE'
+ )
+ ttgrad_tabi_unfit = compexp.grad.gradify_in_place_but_need_return(ttgrad_tabi_unfit, dparams=dps, split_atoms=True)
+
+ iseh2 = compexp.ImportanceSamplingExpectationHandler(num_samples=300)
+ iseh2.register_guides(
+ ce=ttgrad_tabi_unfit,
+ model=model_ttp,
+ auto_guide=pyro.infer.autoguide.AutoNormal,
+ auto_guide_kwargs=dict(init_scale=1.))
+
+ # Make guides that are roughly in the correct positions.
+ pos_iloc_grad, pos_istd_grad = q_optimal_normal_guide_mean_var(d=tt(0.5), c=tt(1.0), z=False)
+ neg_iloc_grad, neg_istd_grad = q_optimal_normal_guide_mean_var(d=tt(0.5), c=tt(1.0), z=True)
+ pos_guide_grad = stor.MultiModalGuide1D(
+ num_components=2,
+ init_loc=[pos_iloc_grad.item(), neg_iloc_grad.item()],
+ init_scale=[pos_istd_grad.item(), neg_istd_grad.item()]
+ )
+ neg_guide_grad = stor.MultiModalGuide1D(
+ num_components=2,
+ init_loc=[pos_iloc_grad.item(), neg_iloc_grad.item()],
+ init_scale=[pos_istd_grad.item(), neg_istd_grad.item()]
+ )
+
+ # TODO is this a good pattern for specifying guides? If so should add some more runtime checking to make sure
+ # e.g. names exist and that the guides spit out the right stochastics.
+ iseh2.guides['dttgradE_dd0_split_pos'] = pos_guide_grad
+ iseh2.guides['dttgradE_dd0_split_neg'] = neg_guide_grad
+ # Leave 'dttgradE_dd0_split_den' as default AutoNormal.
+
+ iseh2.plot_guide_pseudo_likelihood(
+ rv_name='x',
+ guide_kde_kwargs=dict(bw_method=0.1, color='orange'),
+ pseudo_density_plot_kwargs=dict(color='purple')
+ )
+ plt.show()
+
+ tabi_unlearned_good_grad_estimates = []
+ for _ in range(100):
+ with iseh2:
+ tabi_unlearned_good_grad_estimates.append(ttgrad_tabi_unfit(model_ttp))
+ tabi_unlearned_good_grad_estimates = torch.stack(tabi_unlearned_good_grad_estimates).detach().numpy()
+
+ plt.suptitle("TABI Unlearned Good Gradient Estimate")
+ plt.hist(tabi_unlearned_good_grad_estimates, bins=30)
+ plt.axvline(x=np.mean(tabi_unlearned_good_grad_estimates), color='r')
+ plt.axvline(x=2.358, color='black', linestyle='--')
+ plt.suptitle(f"TABI Unlearned Good Gradient Estimate: {np.mean(tabi_unlearned_good_grad_estimates)}\n GT = {2.358}")
+ plt.show()
+ #
+
+ #
+ dps = torch.nn.Parameter(tt([0.5]))
+
+ ttgrad_tabi_unfit = compexp.E(
+ f=lambda s: 1e1 * cost(d=dps[0], c=tt(1.), **s), name='ttgradE'
+ ) # .split_into_positive_components()
+ ttgrad_tabi_unfit = compexp.grad.gradify_in_place_but_need_return(ttgrad_tabi_unfit, dparams=dps, split_atoms=True)
+
+ # # HACK any atom with the word "den" in it is actually the same atom.
+ # broadcasting_den = None
+ # for part in ttgrad_tabi_unfit.parts:
+ # if 'den' in part.name:
+ # if broadcasting_den is None:
+ # broadcasting_den = part
+ # else:
+ # part.swap_self_for_other_child(broadcasting_den)
+ # ttgrad_tabi_unfit.recursively_refresh_parts()
+
+ iseh2 = compexp.ImportanceSamplingExpectationHandler(num_samples=50)
+ iseh2.register_guides(
+ ce=ttgrad_tabi_unfit,
+ model=model_ttp,
+ auto_guide=pyro.infer.autoguide.AutoNormal,
+ auto_guide_kwargs=dict(init_scale=2.))
+
+ # Start the guides roughly in the correct positions.
+ pos_iloc_grad, pos_istd_grad = q_optimal_normal_guide_mean_var(d=tt(0.5), c=tt(1.0), z=False)
+ neg_iloc_grad, neg_istd_grad = q_optimal_normal_guide_mean_var(d=tt(0.5), c=tt(1.0), z=True)
+ pos_guide_grad = stor.MultiModalGuide1D(
+ num_components=2,
+ init_loc=[pos_iloc_grad.item(), neg_iloc_grad.item()],
+ init_scale=[pos_istd_grad.item() * 2., neg_istd_grad.item() * 2.],
+ studentt=True
+ )
+ neg_guide_grad = stor.MultiModalGuide1D(
+ num_components=2,
+ init_loc=[pos_iloc_grad.item(), neg_iloc_grad.item()],
+ init_scale=[pos_istd_grad.item() * 2., neg_istd_grad.item() * 2.],
+ studentt=True
+ )
+
+ # TODO is this a good pattern for specifying guides? If so should add some more runtime checking to make sure
+ # e.g. names exist and that the guides spit out the right stochastics.
+ iseh2.guides['dttgradE_dd0_split_pos'] = pos_guide_grad
+ iseh2.guides['dttgradE_dd0_split_neg'] = neg_guide_grad
+ # Leave 'dttgradE_dd0_split_den' as default AutoNormal.
+
+ NSTEPS = 10000
+
+ def plot_callback_(k, i):
+ if i % NSTEPS == 0:
+ figs = iseh2.plot_guide_pseudo_likelihood(
+ rv_name='x',
+ guide_kde_kwargs=dict(bw_method=0.1, color='orange'),
+ pseudo_density_plot_kwargs=dict(color='purple'),
+ keys=[k] if k is not None else None
+ )
+ plt.show()
+ for f in figs:
+ plt.close(f)
+
+ # plot_callback_(None, 0)
+
+ iseh2.optimize_guides(
+ lr=1e-3, n_steps=NSTEPS + 1,
+ adjust_grads_=stor.abort_guide_grads_,
+ callback=plot_callback_
+ )
+
+ tabi_learned_grad_estimates = []
+ for _ in range(100):
+ with iseh2:
+ tabi_learned_grad_estimates.append(ttgrad_tabi_unfit(model_ttp))
+ tabi_learned_grad_estimates = torch.stack(tabi_learned_grad_estimates).detach().numpy()
+
+ plt.suptitle("TABI Learned Gradient Estimate")
+ plt.hist(tabi_learned_grad_estimates, bins=30)
+ plt.axvline(x=np.mean(tabi_learned_grad_estimates), color='r')
+ plt.axvline(x=2.358, color='black', linestyle='--')
+ plt.suptitle(f"TABI Learned Gradient Estimate: {np.mean(tabi_learned_grad_estimates)}\n GT = {2.358}")
+ plt.show()
+ #
+
+
+if __name__ == "__main__":
+ main()
diff --git a/docs/source/expectation_programming/old_ep_demo_scratch.py b/docs/source/expectation_programming/old_ep_demo_scratch.py
new file mode 100644
index 00000000..d21a8acd
--- /dev/null
+++ b/docs/source/expectation_programming/old_ep_demo_scratch.py
@@ -0,0 +1,1191 @@
+import pyro
+import torch
+from pyro.infer.autoguide import AutoMultivariateNormal
+import pyro.distributions as dist
+import matplotlib.pyplot as plt
+import seaborn as sns
+import inspect
+from torch import tensor as tt
+
+from collections import OrderedDict
+
+from typing import (
+ Callable,
+ TypeVar,
+ Optional,
+ Dict,
+ List,
+ Generic,
+ Union
+)
+
+# from typing_extensions import Unpack
+
+import numpy as np
+import functools
+
+from pyro.poutine.replay_messenger import ReplayMessenger
+
+pyro.settings.set(module_local_params=True)
+
+
+# FIXME Use the actual type from pyro. Gotta find ref. Just using this to force type checking.
+class TraceType:
+ @property
+ def nodes(self) -> Dict[str, Dict[str, torch.Tensor]]:
+ raise NotImplementedError()
+
+ def log_prob_sum(self) -> torch.Tensor:
+ raise NotImplementedError()
+
+
+KWType = OrderedDict[str, torch.Tensor]
+KWTypeNNParams = OrderedDict[str, torch.nn.Parameter]
+KWTypeWithTrace = OrderedDict[str, Union[torch.Tensor, TraceType]]
+
+# A model should involve pyro primitives.
+ModelType = Callable[[], KWType]
+
+# This should not involve pyro primitives
+# This tuple return is primarily to support gradients of expectations.
+ExpectigrandType = Callable[[KWType], KWType]
+
+# This function takes both decision parameters and stochastics, and returns a tuple of things.
+# This tuple return is primarily to support gradients of expectations.
+DecisionExpectigrandType = Callable[[KWTypeNNParams, KWType], KWType]
+
+
+# noinspection PyUnusedLocal
+@pyro.poutine.runtime.effectful(type="expectation")
+def expectation(
+ f: ExpectigrandType,
+ p: ModelType) -> KWType:
+ raise NotImplementedError("Requires an expectation handler to function.")
+
+
+@pyro.poutine.runtime.effectful(type="build_expectigrand_gradient")
+def build_expectigrand_gradient(
+ dparams: KWTypeNNParams,
+ f: DecisionExpectigrandType
+) -> ExpectigrandType:
+
+ # Making this an effectful operation allows constraints to 1) add their gradients to the decision
+ # parameters and 2) add any auxiliary parameters to dparams that are required to represent themselves
+ # as an unconstrained problem.
+
+ def df_dd(stochastics: KWType) -> KWType:
+
+ y: KWType = f(dparams, stochastics)
+
+ if len(y) != 1:
+ # TODO eventually support multiple outputs? Will probably want to do this
+ # once constraints get properly involved.
+ raise ValueError("Decision function must return a single tensor value.")
+
+ ddparams = torch.autograd.grad(
+ outputs=list(y.values()),
+ inputs=list(dparams.values()),
+ create_graph=True)
+
+ return OrderedDict(zip((_gradify_dparam_name(k) for k in dparams.keys()), ddparams))
+
+ return df_dd
+
+
+def _gradify_dparam_name(name: str) -> str:
+ return f"df_d{name}"
+
+
+# noinspection PyUnusedLocal
+@pyro.poutine.runtime.effectful(type="optimize_decision")
+def optimize_decision(
+ f: DecisionExpectigrandType,
+ p: ModelType,
+ terminal_condition: Callable[[KWTypeNNParams, int], bool],
+ # Note the difference between this signature and the one in optimize_proposal. TODO unify.
+ adjust_grads: Optional[Callable[[KWType], KWType]] = None,
+ callback: Optional[Callable[[], None]] = None
+) -> KWType:
+ raise NotImplementedError("Requires both an expectation and optimizer handler.")
+
+
+def msg_args_kwargs_to_kwargs(msg):
+
+ ba = inspect.signature(msg["fn"]).bind(*msg["args"], **msg["kwargs"])
+ ba.apply_defaults()
+
+ return ba.arguments
+
+
+class ExpectationHandler(pyro.poutine.messenger.Messenger):
+
+ def _pyro_expectation(self, msg) -> None:
+ if msg["done"]:
+ raise ValueError("You may be operating in a context with more than one expectation handler. In these"
+ " cases, you must explicitly specify which expectation handler to use by using the OOP"
+ " style call with the desired handler (e.g. `ExpectationHandler(...).expectation(...)`).")
+
+ def expectation(self, *args, **kwargs):
+ # Calling this method blocks all other expectation handlers and uses only this one.
+ with self:
+ with pyro.poutine.messenger.block_messengers(lambda m: isinstance(m, ExpectationHandler) and m is not self):
+ return expectation(*args, **kwargs)
+
+ def optimize_proposal(self, *args, **kwargs):
+ # Calling this method blocks all other expectation handlers and uses only this one.
+ with self:
+ with pyro.poutine.messenger.block_messengers(lambda m: isinstance(m, ExpectationHandler) and m is not self):
+ return optimize_proposal(*args, **kwargs)
+
+
+# Keyword arguments From Trace
+def kft(trace: TraceType) -> KWType:
+ # Copy the ordereddict.
+ new_trace = OrderedDict(trace.nodes.items())
+ # Remove the _INPUT and _RETURN nodes.
+ del new_trace["_INPUT"]
+ del new_trace["_RETURN"]
+ return OrderedDict(zip(new_trace.keys(), [v["value"] for v in new_trace.values()]))
+
+
+DT = TypeVar("DT")
+DTr = TypeVar("DTr")
+
+
+class OppableDictParent: # FIXME 28dj10dlk
+ pass
+
+
+# TODO inherit from dict to just extend relevant methods, but couldn't figure out proper
+# inheritance of generics.
+class OppableDict(OppableDictParent, Generic[DT]):
+ """
+ Helper class for executing operations on all the values of a dictionary.
+ Heavily typed to make sure everything is lining up.
+ """
+
+ def __init__(self, d: Optional[OrderedDict[str, DT]] = None):
+ if d is None:
+ d: OrderedDict[str, DT] = OrderedDict()
+
+ self._d = d
+
+ def __getitem__(self, item: str) -> DT:
+ return self._d[item]
+
+ def __contains__(self, item: str) -> bool:
+ return item in self._d
+
+ def __setitem__(self, key: str, value: DT) -> None:
+ self._d[key] = value
+
+ def items(self):
+ return self._d.items()
+
+ def op(self, f: Callable[[DT], DTr]) -> 'OppableDict[DTr]':
+ ret = OppableDict()
+
+ for k, v in self.items():
+ ret[k] = f(v)
+
+ return ret
+
+ def op_other(
+ self,
+            f: Callable[..., DTr],
+ *others: 'OppableDict[DT]') -> 'OppableDict[DTr]':
+ ret = OppableDict()
+
+ for k, v in self.items():
+ ret[k] = f(v, *tuple(o[k] for o in others))
+
+ return ret
+
+ @functools.singledispatchmethod
+ def __sub__(self, other):
+ raise NotImplementedError()
+
+ @__sub__.register(Union[float, torch.Tensor])
+ def _(self, other: Union[float, torch.Tensor]) -> "OppableDict":
+ return self.op(lambda v: v - other)
+
+ # FIXME 28dj10dlk Once https://github.com/python/cpython/issues/86153 drops.
+
+ # @__add__.register(OppableDict) # FIXME 28dj10dlk
+ @__sub__.register(OppableDictParent)
+ # def _(self, other: "OppableDict") -> "OppableDict": # FIXME 28dj10dlk desired
+ def _(self, other: OppableDictParent) -> "OppableDict": # FIXME 28dj10dlk
+ assert isinstance(other, OppableDict) # FIXME 28dj10dlk runtime check instead.
+ return self.op_other(lambda v, o: v - o, other)
+
+ @functools.singledispatchmethod
+ def __add__(self, other):
+ raise NotImplementedError()
+
+ @__add__.register(Union[float, torch.Tensor])
+ def _(self, other: Union[float, torch.Tensor]) -> "OppableDict":
+ return self.op(lambda v: v + other)
+
+ # FIXME 28dj10dlk
+ @__add__.register(OppableDictParent)
+ def _(self, other: OppableDictParent) -> "OppableDict":
+ assert isinstance(other, OppableDict)
+ return self.op_other(lambda v, o: v + o, other)
+
+ @functools.singledispatchmethod
+ def __truediv__(self, other):
+ raise NotImplementedError()
+
+ @__truediv__.register(Union[float, torch.Tensor])
+ def _(self, other: Union[float, torch.Tensor]) -> "OppableDict":
+ return self.op(lambda v: v / other)
+
+ # @__truediv__.register(OppableDict) # FIXME 28dj10dlk
+ @__truediv__.register(OppableDictParent)
+ # def _(self, other: "OppableDict") -> "OppableDict": # FIXME 28dj10dlk desired
+ def _(self, other: OppableDictParent) -> "OppableDict": # FIXME 28dj10dlk
+ assert isinstance(other, OppableDict) # FIXME 28dj10dlk runtime check instead.
+ return self.op_other(lambda v, o: v / o, other)
+
+ @functools.singledispatchmethod
+ def __mul__(self, other):
+ raise NotImplementedError()
+
+ @__mul__.register(Union[float, torch.Tensor])
+ def _(self, other: Union[float, torch.Tensor]) -> "OppableDict":
+ return self.op(lambda v: v * other)
+
+ # FIXME 28dj10dlk desired
+ @__mul__.register(OppableDictParent)
+ def _(self, other: OppableDictParent) -> "OppableDict":
+ assert isinstance(other, OppableDict)
+ return self.op_other(lambda v, o: v * o, other)
+
+ @property
+ def wrapped(self) -> OrderedDict[str, DT]:
+ return self._d
+
+
+class DictOLists(OppableDict[List[torch.Tensor]]):
+ def append(self, value: Dict[str, torch.Tensor]) -> None:
+
+ for k, v in value.items():
+
+ if k not in self:
+ self[k] = []
+
+ self[k].append(v)
+
+
+class MonteCarloExpectation(ExpectationHandler):
+ # Adapted from Rafal's "query library" code.
+
+ def __init__(self, num_samples: int):
+ self.num_samples = num_samples
+
+ super().__init__()
+
+ def _pyro_expectation(self, msg) -> None:
+ super()._pyro_expectation(msg)
+
+ kwargs = msg_args_kwargs_to_kwargs(msg)
+ p: ModelType = kwargs["p"]
+ f: ExpectigrandType = kwargs["f"]
+
+ fvs = DictOLists()
+
+ for i in range(self.num_samples):
+ trace = pyro.poutine.trace(p).get_trace()
+ ret = trace.nodes["_RETURN"]["value"] # type: KWType
+ fv = f(ret)
+ fvs.append(fv)
+
+ msg_value = fvs.op(lambda v: torch.sum(torch.tensor(v)) / self.num_samples)
+
+ msg["value"] = msg_value
+ msg["done"] = True
+
+
+class SNISExpectation(ExpectationHandler):
+
+ def __init__(self, q: ModelType, num_samples: int):
+ self.q = q
+ self.num_samples = num_samples
+
+ super().__init__()
+
+ def _pyro_expectation(self, msg) -> None:
+ super()._pyro_expectation(msg)
+
+ kwargs = msg_args_kwargs_to_kwargs(msg)
+ q = self.q
+ p = kwargs["p"] # type: ModelType
+ f = kwargs["f"] # type: ExpectigrandType
+
+ fvs = DictOLists()
+ plps = []
+ qlps = []
+
+ for i in range(self.num_samples):
+ # Sample stochastics from the proposal distribution.
+ q_trace = pyro.poutine.trace(q).get_trace()
+ qlp = q_trace.log_prob_sum()
+
+ # Trace the full model with the proposed stochastics.
+ with ReplayMessenger(trace=q_trace):
+ trace = pyro.poutine.trace(lambda: f(p())).get_trace()
+
+ # Record the return value and the log probability with respect to the model.
+ fv = trace.nodes["_RETURN"]["value"] # type: KWType
+ plp = trace.log_prob_sum()
+
+ fvs.append(fv)
+ plps.append(plp)
+ qlps.append(qlp)
+
+ plps = torch.tensor(plps)
+ qlps = torch.tensor(qlps)
+
+ unw = plps - qlps # unnormalized weights
+ w = torch.exp(unw - torch.logsumexp(unw, dim=0)) # normalized weights
+
+ msg_value = fvs.op(lambda v: torch.sum(torch.tensor(v) * w))
+
+ msg["value"] = msg_value
+ msg["done"] = True
+
+
+# noinspection PyUnusedLocal
+@pyro.poutine.runtime.effectful(type="optimize_proposal")
+def optimize_proposal(p: ModelType, f: ExpectigrandType, n_steps=1, lr=0.01,
+ # Note the difference in signature here vs in the optimize_decision. TODO unify.
+                      adjust_grads_: Optional[Callable[..., None]] = None,
+ callback: Optional[Callable[[], None]] = None):
+ raise NotImplementedError()
+
+
+class LazySVIStuff:
+ elbo: pyro.infer.Trace_ELBO = None
+ optim = None
+
+
+class TABIExpectation(ExpectationHandler):
+ QTRACE_KEY = "TABIExpectation_q_trace_for_log_prob_and_replay"
+ FAC_KEY = "TABIExpectation_expectation_targeting_log_factor"
+
+ def __init__(self, q_plus: ModelType, q_den: ModelType, num_samples: int,
+ grad_clip: Optional[Callable] = None,
+ q_minus: Optional[ModelType] = None):
+ self.q_plus = q_plus
+ self.q_minus = q_minus
+ self.q_den = q_den
+ self.num_samples = num_samples
+ self.grad_clip = grad_clip
+
+ self._lazy_q_plus_svi = LazySVIStuff()
+ self._lazy_q_minus_svi = LazySVIStuff()
+ self._lazy_q_den_svi = LazySVIStuff()
+
+ super().__init__()
+
+ def __enter__(self):
+ self._lazy_q_plus_svi = LazySVIStuff()
+ self._lazy_q_minus_svi = LazySVIStuff()
+ self._lazy_q_den_svi = LazySVIStuff()
+
+ return super().__enter__()
+
+ def _optimize_proposal_part(self, n_steps, elbo, optim, adjust_grads_=None,
+ callback: Optional[Callable[[], None]] = None):
+ for step in range(0, n_steps + 1):
+ for param in elbo.parameters():
+ param.grad = None
+
+ optim.zero_grad()
+ loss = elbo()
+ loss.backward()
+
+ if adjust_grads_ is not None:
+ adjust_grads_(*tuple(elbo.parameters()))
+
+ if callback is not None:
+ callback()
+
+ optim.step()
+
+ def get_part(self, sign: float, f: ExpectigrandType, p: ModelType):
+ def factor_augmented_p():
+ stochastics: KWType = p()
+ fv: KWType = f(stochastics)
+ fv_od = OppableDict(fv)
+
+            # FIXME HACK 1e-6 is a hack to avoid log(0); make it a passable argument?
+ facval = fv_od.op(lambda v: torch.log(1e-6 + torch.relu(sign * v)))
+
+ for k, v in facval.items():
+ aug_k = f'{self.FAC_KEY}_{sign}_{k}'
+ fac = pyro.factor(aug_k, v)
+
+ assert aug_k not in stochastics
+ stochastics[aug_k] = fac
+
+ return stochastics
+ return factor_augmented_p
+
+ @staticmethod
+ def get_svi_stuff(p, q, svi_stuff, lr):
+ if svi_stuff.elbo is None:
+ svi_stuff.elbo = pyro.infer.Trace_ELBO()(p, q)
+ svi_stuff.elbo()
+ if svi_stuff.optim is None:
+ svi_stuff.optim = torch.optim.ASGD(svi_stuff.elbo.parameters(), lr=lr)
+ return svi_stuff
+
+ def _pyro_optimize_proposal(self, msg) -> None:
+ kwargs = msg_args_kwargs_to_kwargs(msg)
+ p: ModelType = kwargs["p"]
+ f: ExpectigrandType = kwargs["f"]
+ n_steps: int = kwargs["n_steps"]
+ lr: float = kwargs["lr"]
+        adjust_grads_: Optional[Callable[..., None]] = kwargs["adjust_grads_"]
+ callback: Optional[Callable[[], None]] = kwargs["callback"]
+
+ self.get_svi_stuff(self.get_part(1., f, p), self.q_plus, self._lazy_q_plus_svi, lr)
+ self._optimize_proposal_part(n_steps, self._lazy_q_plus_svi.elbo, self._lazy_q_plus_svi.optim,
+ adjust_grads_, callback=callback)
+
+ self.get_svi_stuff(p, self.q_den, self._lazy_q_den_svi, lr)
+ self._optimize_proposal_part(n_steps, self._lazy_q_den_svi.elbo, self._lazy_q_den_svi.optim,
+ adjust_grads_, callback=callback)
+
+ if self.q_minus is not None:
+ self.get_svi_stuff(self.get_part(-1., f, p), self.q_minus, self._lazy_q_minus_svi, lr)
+ self._optimize_proposal_part(n_steps, self._lazy_q_minus_svi.elbo, self._lazy_q_minus_svi.optim,
+ adjust_grads_, callback=callback)
+
+ msg["value"] = None
+ msg["done"] = True
+
+ def _pyro_expectation(self, msg) -> None:
+ super()._pyro_expectation(msg)
+
+ kwargs = msg_args_kwargs_to_kwargs(msg)
+ q_plus = self.q_plus
+ q_minus = self.q_minus
+ q_den = self.q_den
+ p = kwargs["p"] # type: ModelType
+ f = kwargs["f"] # type: ExpectigrandType
+
+        # TODO All this funkiness is to reuse logic between the component proposals, but it's super opaque
+        #  and annoying. It's also set up to get the actual proposal trace object to where it needs to go
+        #  so that importance weights can be computed inside the expectigrand. Would like to clean up/simplify.
+        #  This would allow us to get rid of KWTypeWithTrace also. Also, now that the typing has been added to
+        #  sort out the redirection mess, it's about as verbose as not sharing code.
+
+ def expectigrand(lf_: Optional[OppableDict[torch.Tensor]], q_trace: TraceType) -> KWType:
+ with ReplayMessenger(trace=q_trace):
+ unnorm_log_p = pyro.poutine.trace(p).get_trace().log_prob_sum()
+
+ unnorm_log_q = q_trace.log_prob_sum()
+
+ ret: KWType
+ if lf_ is not None:
+ ret = lf_.op(lambda v: torch.exp(v + unnorm_log_p - unnorm_log_q)).wrapped
+ else:
+ ret = OrderedDict(expected_importance_weight=torch.exp(unnorm_log_p - unnorm_log_q))
+ return ret
+
+ def get_signed_lf(kwstochastics: KWType, s: float) -> OppableDict[torch.Tensor]:
+ fv = OppableDict(f(kwstochastics))
+ return fv.op(lambda v: torch.log(torch.relu(s * v)))
+
+ def expectigrand_plus(kwstochastics: KWTypeWithTrace) -> KWType:
+ q_trace: TraceType = kwstochastics.pop(self.QTRACE_KEY)
+ kwstochastics: KWType
+ return expectigrand(lf_=get_signed_lf(kwstochastics, 1.), q_trace=q_trace)
+
+ def expectigrand_minus(kwstochastics: KWTypeWithTrace) -> KWType:
+ q_trace: TraceType = kwstochastics.pop(self.QTRACE_KEY)
+ kwstochastics: KWType
+ return expectigrand(lf_=get_signed_lf(kwstochastics, -1.), q_trace=q_trace)
+
+ def expectigrand_den(kwstochastics: KWTypeWithTrace) -> KWType:
+ q_trace: TraceType = kwstochastics.pop(self.QTRACE_KEY)
+ kwstochastics: KWType
+ return expectigrand(lf_=None, q_trace=q_trace)
+
+ def get_get_qkwstochastics(q: ModelType) -> ModelType:
+ def get_qkwstochastics() -> KWTypeWithTrace:
+ qtr: TraceType = pyro.poutine.trace(q).get_trace()
+ qkwstochastics = kft(qtr)
+ # Add the log prob of the proposal for use down the line.
+ assert self.QTRACE_KEY not in qkwstochastics
+ qkwstochastics: KWTypeWithTrace
+ qkwstochastics[self.QTRACE_KEY] = qtr
+ return qkwstochastics
+
+ return get_qkwstochastics
+
+ with pyro.poutine.messenger.block_messengers(lambda m: m is self):
+
+ with MonteCarloExpectation(self.num_samples):
+ e_plus = expectation(
+ f=expectigrand_plus,
+ p=get_get_qkwstochastics(q_plus)
+ )
+
+ if q_minus is not None:
+ with MonteCarloExpectation(self.num_samples):
+ e_minus = expectation(
+ f=expectigrand_minus,
+ p=get_get_qkwstochastics(q_minus)
+ )
+ else:
+ e_minus = 0.0
+
+ with MonteCarloExpectation(self.num_samples):
+ e_den: torch.Tensor = expectation(
+ f=expectigrand_den,
+ p=get_get_qkwstochastics(q_den)
+ )["expected_importance_weight"]
+
+ # Set this up to "broadcast" between the named tensors of the return dictionaries.
+ e_plus_od = OppableDict(e_plus)
+ e_minus_od: Union[OppableDict[torch.Tensor], float] = \
+ OppableDict(e_minus) if isinstance(e_minus, OrderedDict) else e_minus
+
+ msg["value"] = ((e_plus_od - e_minus_od) / e_den).wrapped
+ msg["done"] = True
+
+
+class ConstraintHandler(pyro.poutine.messenger.Messenger):
+ def __init__(self, g: DecisionExpectigrandType, tau: float, threshold: float):
+ """
+ :param g: The function of decision parameters and stochastics that feeds into the constraint.
+ :param tau: The scaling factor used to add the constraint gradient to the decision parameter gradient. This
+ is required to convert constrained problems into unconstrained problems.
+ :param threshold: The threshold value for the constraint.
+ """
+
+ self.g = g
+ self.tau = tt(tau)
+ self.threshold = tt(threshold)
+
+# TODO Proposal Optimization for Constraints
+# So I think these constraints need to have expectation handlers passed to them. They can add them to the stack
+# in the enter/exit, and then use the OOP strategy to call a specific optimize_proposal? I guess the constraints
+# should be able to preempt the optimize_proposal call and do their own optimization; they have everything they
+# need because they have the constraint expectigrand passed directly to them. So no OOP for optimize_proposal: the
+# constraints just preempt the call and run optimize_proposal in the context of whatever internal constraint
+# expectation estimators they happen to be using; they just need to block other expectation handlers during that
+# execution.
+# Remember that the DecisionOptimizerHandler automatically converts any optimize_proposal calls to operate over the
+# gradient, so proposal optimization here just needs to use the raw, non-differentiated function.
+
+
+class MeanConstraintHandler(ConstraintHandler):
+
+ def __init__(self, *args, **kwargs):
+
+ super().__init__(*args, **kwargs)
+
+ def _pyro_post_build_expectigrand_gradient(self, msg) -> None:
+ kwargs = msg_args_kwargs_to_kwargs(msg)
+ dparams: KWTypeNNParams = kwargs["dparams"]
+ df_dd: ExpectigrandType = msg["value"]
+
+ # Block all constraint handlers, because we just want the raw gradients here so we can add them.
+ with pyro.poutine.messenger.block_messengers(lambda m: isinstance(m, ConstraintHandler)):
+ dg_dd: ExpectigrandType = build_expectigrand_gradient(dparams, self.g)
+
+ def rewrapped_df_dd(stochastics: KWType) -> KWType:
+
+ # FIXME so the trick here is that these same stochastics aren't supposed to be sent into all
+ # three of these functions (df_dd, dg_dd, and g). Instead they're supposed to be three separate
+ # expectation operations. So what we need here is to actually make grad_expectation an
+ # operation, and then in post we need to call self.grad_expectation_handler.expectation and
+ # self.expectation_handler.expectation, and add everything up as needed.
+
+ df_dd_ret = OppableDict(df_dd(stochastics))
+
+ cret = self.g(dparams, stochastics)
+ if len(cret) > 1:
+ raise ValueError(f"{self.__class__.__name__} only supports scalar constraints,"
+ f" but got return values named {tuple(cret.keys())}")
+
+ cval = tuple(cret.values())[0]
+ constraint_violated = cval > self.threshold
+
+ if constraint_violated:
+ print("Constraint Violated!")
+ dg_dd_ret = OppableDict(dg_dd(stochastics))
+ return (df_dd_ret + dg_dd_ret * self.tau).wrapped
+ else:
+ return df_dd_ret.wrapped
+
+ msg["value"] = rewrapped_df_dd
+
+
+class DecisionOptimizerHandler(pyro.poutine.messenger.Messenger):
+
+ def __init__(self, dparams: KWTypeNNParams, lr: float,
+ proposal_update_steps: int = 0, proposal_update_lr: float = None,
+                 proposal_adjust_grads_: Optional[Callable[..., None]] = None):
+ self.dparams = dparams
+ self.lr = lr
+ self.proposal_update_steps = proposal_update_steps
+ self.proposal_update_lr = proposal_update_lr
+ self.proposal_adjust_grads_ = proposal_adjust_grads_
+
+ if self.proposal_update_steps > 0:
+ assert self.proposal_update_lr is not None, "Learning rate for proposal update must be specified if " \
+ "proposal update steps is greater than 0."
+
+ def _swap_f_for_df(self, msg) -> None:
+ kwargs = msg_args_kwargs_to_kwargs(msg)
+ f: DecisionExpectigrandType = kwargs["f"]
+ fp: ExpectigrandType = build_expectigrand_gradient(self.dparams, f)
+ kwargs["f"] = fp
+ msg["kwargs"] = kwargs
+
+ def _pyro_optimize_proposal(self, msg) -> None:
+ # This handler replaces the function of interest with the gradient of that function with respect to the
+ # decision parameters. This supports any proposal optimization by giving it the correct target function.
+ self._swap_f_for_df(msg)
+
+ def _swap_f_for_f_of_d(self, msg) -> None:
+ # This converts the decision function to a function just of stochastics so it can be used with
+ # standard expectation operations.
+ kwargs = msg_args_kwargs_to_kwargs(msg)
+ f: DecisionExpectigrandType = kwargs["f"]
+
+ # Only actually swap this out if it's required. This allows users to pass in the results of
+ # build_expectigrand_gradient, which already returns a non-decision expectigrand, without
+ # having to block this handler's effect.
+ if len(inspect.signature(f).parameters) == 1:
+ return
+
+ fp: ExpectigrandType = lambda stochastics: f(self.dparams, stochastics)
+ kwargs["f"] = fp
+ msg["kwargs"] = kwargs
+
+ def _pyro_expectation(self, msg) -> None:
+ # Expectation calls within this handler need to include the decision parameters as arguments,
+ # so this converts things to the default case not involving decision parameters.
+ self._swap_f_for_f_of_d(msg)
+
+ def _pyro_optimize_decision(self, msg) -> None:
+ kwargs = msg_args_kwargs_to_kwargs(msg)
+ f: DecisionExpectigrandType = kwargs["f"]
+ p: ModelType = kwargs["p"]
+ adjust_grads: Optional[Callable[[KWType], KWType]] = kwargs["adjust_grads"]
+ terminal_condition: Callable[[KWTypeNNParams, int], bool] = kwargs["terminal_condition"]
+ callback: Optional[Callable[[], None]] = kwargs["callback"]
+
+ optim = torch.optim.SGD(list(self.dparams.values()), lr=self.lr)
+
+ i = 0
+ while not terminal_condition(self.dparams, i):
+
+ optim.zero_grad()
+            # Not sure if this is necessary.
+ for d in self.dparams.values():
+ d.grad = None
+
+            # Block self here because the gradient function already handles the partial evaluation. I.e.
+            # we don't want _swap_f_for_f_of_d called, because build_expectigrand_gradient already returns a valid
+            # ExpectigrandType.
+ with pyro.poutine.messenger.block_messengers(lambda m: m is self):
+ grad_estimate = expectation(f=build_expectigrand_gradient(self.dparams, f), p=p)
+
+ res = OppableDict(grad_estimate).op(lambda v: torch.isfinite(v))
+ if not torch.tensor(tuple(res.wrapped.values())).all():
+ print("Warning: Non-finite gradient estimate encountered. Skipping update.") # TODO warnings module
+ continue
+
+ if adjust_grads is not None:
+ grad_estimate: KWType = adjust_grads(grad_estimate)
+
+ # Sanity check.
+ assert tuple(grad_estimate.keys()) == tuple(_gradify_dparam_name(k) for k in self.dparams.keys())
+
+ # Assign gradients to parameters.
+ for k, dp in self.dparams.items():
+ dp.grad = grad_estimate[_gradify_dparam_name(k)]
+
+            # And update those parameters with the specified optimizer.
+ optim.step()
+
+ if self.proposal_update_steps > 0:
+ optimize_proposal(p=p, f=f, n_steps=self.proposal_update_steps, lr=self.proposal_update_lr,
+ adjust_grads_=self.proposal_adjust_grads_)
+
+ if callback is not None:
+ callback()
+
+ i += 1
+
+ msg["value"] = tuple(v.item() for v in self.dparams.values())
+ msg["done"] = True
+
+
+class MultiModalGuide1D(pyro.nn.PyroModule):
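+    """
+    A 1D mixture guide with num_components Normal (or StudentT, if studentt=True) components, learnable mixture
+    weights pi, and per-component loc and scale parameters. The scale is clipped in forward to keep components
+    from collapsing or exploding.
+    """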
+
+ @pyro.nn.PyroParam(constraint=dist.constraints.simplex)
+ def pi(self):
+ return torch.ones(self.num_components) / self.num_components
+
+ @pyro.nn.PyroParam(constraint=dist.constraints.positive)
+ def scale(self):
+ return self.init_scale
+
+ @pyro.nn.PyroParam(constraint=dist.constraints.real)
+ def loc(self):
+ return self.init_loc
+
+ def __init__(self, *args, num_components: int, init_loc, init_scale, studentt=False, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ self.num_components = num_components
+ self.init_loc = torch.tensor(init_loc)
+ self.init_scale = torch.tensor(init_scale)
+ self.studentt = studentt
+
+ if self.init_loc.shape != (self.num_components,):
+ raise ValueError("init_loc must be a tensor of shape (num_components,)")
+ if self.init_scale.shape != (self.num_components,):
+ raise ValueError("init_scale must be a tensor of shape (num_components,)")
+
+ def forward(self):
+ component_idx = pyro.sample("component_idx", dist.Categorical(self.pi), infer={'is_auxiliary': True})
+ scale = torch.clip(self.scale[component_idx], torch.tensor(0.05), torch.tensor(1.5))
+
+ if self.studentt:
+ x = pyro.sample("x", dist.StudentT(10, self.loc[component_idx], scale))
+ else:
+ x = pyro.sample("x", dist.Normal(self.loc[component_idx], scale))
+ return OrderedDict(x=x)
+
+ def __call__(self, *args, **kwargs) -> KWType:
+ return super().__call__(*args, **kwargs)
+
+
+def clip_decision_grads(grads: KWType):
+ # These blow up sometimes, so just clip them (they are 1d so no need to worry about norms).
+ ret_grads = OrderedDict()
+ for k in grads:
+ ret_grads[k] = torch.clip(grads[k], -2e-1, 2e-1)
+ return ret_grads
+
+
+def abort_guide_grads_(*parameters: torch.nn.Parameter, lim=50.):
+ # These gradients also blow up, but clipping them causes weird non-convergence. Just aborting
+ # the gradient update seems to work.
+ if torch.any(torch.tensor([torch.any(param.grad > lim) for param in parameters])):
+ for param in parameters:
+ param.grad = torch.zeros_like(param.grad)
+
+
+def main():
+
+ # See https://www.desmos.com/calculator/ixtzpb4l75 for analytical computation of the expectation.
+ # This all amounts to a functional on C and D, which can be set in the desmos graph.
+ C = tt(1.)
+ D = OrderedDict(d=torch.nn.Parameter(tt(0.5)))
+ NSnaive = 100000
+ GT = -1.1337
+ print(f"Ground Truth: {GT}")
+
+    # Freeze this decision parameter, as we're using it as a non-optimizable constant.
+ for d in D.values():
+ d.requires_grad = False
+
+ from toy_tabi_problem import (
+ model as model_ttp,
+ cost as _cost_ttp,
+ q_optimal_normal_guide_mean_var,
+ MODEL_DIST as MODEL_TTP_DIST
+ )
+
+ # A scaled up cost function, just because the original one is small.
+ def dparam_scost_ttp(dparams: KWTypeNNParams, stochastics: KWType, c: torch.Tensor) -> KWType:
+ return OrderedDict(cost=1e1 * _cost_ttp(**dparams, **stochastics, c=c))
+
+ def get_scost_ttp(dparams: KWTypeNNParams, c: torch.Tensor) -> Callable[[KWType], KWType]:
+ return lambda stochastics: dparam_scost_ttp(dparams, stochastics, c=c)
+
+ def get_dparam_scost_ttp(c: torch.Tensor) -> Callable[[KWTypeNNParams, KWType], KWType]:
+ return lambda dparams, stochastics: dparam_scost_ttp(dparams, stochastics, c=c)
+
+ with MonteCarloExpectation(num_samples=NSnaive):
+ print(f"MCE TABI Toy (N={NSnaive})",
+ expectation(f=get_scost_ttp(dparams=D, c=C), p=model_ttp)["cost"])
+
+ def subopt_guide() -> KWType:
+ return OrderedDict(x=pyro.sample('x', dist.Normal(0.5, 2.)))
+
+ with SNISExpectation(
+ q=subopt_guide,
+ num_samples=50000
+ ):
+ print(f"SNIS SubOptGuide TABI Toy (N={NSnaive})",
+ expectation(f=get_scost_ttp(dparams=D, c=C), p=model_ttp)["cost"])
+
+    # Do SNIS again but with an optimized proposal. We use a studentt here because SNIS requires dividing the
+    # original probability by the proposal probability, so if something is sampled in the tails of the proposal
+    # but not in the tails of the original, this ratio will blow up and cause problems. SNIS does this before it
+    # ever sees the cost function, in order to self-normalize the weights.
+ # noinspection PyUnresolvedReferences
+ snis_opt_guidep = dist.StudentT(1, *q_optimal_normal_guide_mean_var(**D, c=C, z=False))
+ # noinspection PyUnresolvedReferences
+ snis_opt_guiden = dist.StudentT(1, *q_optimal_normal_guide_mean_var(**D, c=C, z=True))
+
+ def opt_guide_mix() -> KWType:
+ # noinspection PyUnresolvedReferences
+ if pyro.sample('z', dist.Bernoulli(probs=torch.tensor([0.5]))):
+ x = pyro.sample('x', snis_opt_guidep)
+ else:
+ x = pyro.sample('x', snis_opt_guiden)
+
+ return OrderedDict(x=x)
+
+ NS_snis_opt = 10000
+
+ with SNISExpectation(
+ # The bi-modal optimal proposal covers the product of the model and the absval of the cost function.
+ q=opt_guide_mix,
+ num_samples=NS_snis_opt
+ ):
+ print(f"SNIS OptGuide TABI Toy (N={NS_snis_opt})",
+ expectation(f=get_scost_ttp(dparams=D, c=C), p=model_ttp)["cost"])
+
+ tabi_opt_guidep = dist.Normal(*q_optimal_normal_guide_mean_var(**D, c=C, z=False))
+ tabi_opt_guiden = dist.Normal(*q_optimal_normal_guide_mean_var(**D, c=C, z=True))
+
+ with TABIExpectation(
+ q_plus=lambda: OrderedDict(x=pyro.sample('x', tabi_opt_guidep)),
+ q_minus=lambda: OrderedDict(x=pyro.sample('x', tabi_opt_guiden)),
+ q_den=lambda: OrderedDict(x=pyro.sample('x', MODEL_TTP_DIST)),
+ num_samples=1
+ ):
+ print(f"TABI Toy Exact (N=1)", expectation(f=get_scost_ttp(dparams=D, c=C), p=model_ttp)["cost"])
+
+ #
+
+ # noinspection DuplicatedCode
+ def pos_comp():
+ xp = model_ttp()
+ pos_fac = pyro.factor(
+ 'pos_fac', torch.log(1e-6 + torch.relu(dparam_scost_ttp(dparams=D, stochastics=xp, c=C)["cost"])))
+
+ return OrderedDict(x=xp, pos_fac=pos_fac)
+
+ # noinspection DuplicatedCode
+ def neg_comp():
+ xn = model_ttp()
+ neg_fac = pyro.factor(
+ 'neg_fac', torch.log(1e-6 + torch.relu(-dparam_scost_ttp(dparams=D, stochastics=xn, c=C)["cost"])))
+
+ return OrderedDict(x=xn, neg_fac=neg_fac)
+
+ N_STEPS = 10000
+ LR = 1e-3
+ NS = 100
+
+ # noinspection DuplicatedCode
+ def run_svi_inference(model_, guide, n_steps=100, verbose=True):
+ elbo = pyro.infer.Trace_ELBO()(model_, guide)
+ elbo()
+ optim = torch.optim.SGD(elbo.parameters(), lr=LR)
+ for step in range(0, n_steps):
+ optim.zero_grad()
+ loss = elbo()
+ loss.backward()
+ optim.step()
+        if (step % 100 == 0) and verbose:
+ print("[iteration %04d] loss: %.4f" % (step, loss))
+ return guide
+
+ def get_num_guide_init(d, c, z: bool):
+ iloc, istd = q_optimal_normal_guide_mean_var(d=d, c=c, z=z)
+ return pyro.infer.autoguide.AutoNormal(
+ model=pos_comp,
+ init_loc_fn=pyro.infer.autoguide.initialization.init_to_value(values={'x': iloc}),
+ init_scale=istd.item() * 3., # scale up the std to give SVI something to do.
+ ), iloc, istd
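+
+    # Note: both the positive and negative guides below are built against pos_comp; presumably only the model's
+    # site structure (the 'x' site) matters for AutoNormal initialization here.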
+ pos_guide, pos_iloc, pos_istd = get_num_guide_init(**D, c=C, z=False)
+ neg_guide, neg_iloc, neg_istd = get_num_guide_init(**D, c=C, z=True)
+
+ def get_den_guide_init():
+ return pyro.infer.autoguide.AutoNormal(
+ model=model_ttp,
+ init_loc_fn=pyro.infer.autoguide.initialization.init_to_value(
+ # Very important to get a COPY of this tensor and not pass the model parameter itself. Otherwise
+ # when the guide updates the model will also change, which naturally leads to insanity.
+ values={'x': torch.tensor(MODEL_TTP_DIST.loc.item())}),
+ init_scale=MODEL_TTP_DIST.scale.item() * 3., # scale up the std to give SVI something to do.
+ )
+ den_guide = get_den_guide_init()
+
+ def plot_tabi_guides(tabi_handler: TABIExpectation, og_pos, og_neg, og_den):
+ plt.figure()
+ xx_ = torch.linspace(-10, 10, 1000)
+ sns.kdeplot([tabi_handler.q_plus.forward()['x'].item() for _ in range(10000)],
+ label='pos', linestyle='--', color='red')
+ sns.kdeplot([tabi_handler.q_minus.forward()['x'].item() for _ in range(10000)],
+ label='neg', linestyle='--', color='blue')
+ sns.kdeplot([tabi_handler.q_den.forward()['x'].item() for _ in range(10000)],
+ label='den', linestyle='--', color='green')
+
+ plt.plot(xx_, og_den.log_prob(xx_).exp(), color='green', alpha=0.5)
+ plt.plot(xx_, og_pos.log_prob(xx_).exp(), color='red', alpha=0.5)
+ plt.plot(xx_, og_neg.log_prob(xx_).exp(), color='blue', alpha=0.5)
+
+ plt.show()
+
+ # # <-----Manual execution>
+ #
+ # run_svi_inference(pos_comp, pos_guide, n_steps=N_STEPS, verbose=False)
+ # run_svi_inference(neg_comp, neg_guide, n_steps=N_STEPS, verbose=False)
+ # run_svi_inference(model_ttp, den_guide, n_steps=N_STEPS, verbose=False)
+ # with TABIExpectation(
+ # q_plus=pos_guide,
+ # q_minus=neg_guide,
+ # q_den=den_guide,
+ # num_samples=NS
+ # ) as te:
+ # print(f"TABI Toy Learned Guide (Manual) (N={NS})",
+ # expectation(f=get_scost_ttp(D, c=C), p=model_ttp))
+ #
+ # plot_tabi_guides(te, dist.Normal(pos_iloc, pos_istd), dist.Normal(neg_iloc, neg_istd), MODEL_TTP_DIST)
+ #
+ # # -----Manual execution>
+
+ #
+
+ #
+
+ with TABIExpectation(
+ q_plus=get_num_guide_init(**D, c=C, z=False)[0],
+ q_minus=get_num_guide_init(**D, c=C, z=True)[0],
+ q_den=get_den_guide_init(),
+ num_samples=NS
+ ) as te:
+ optimize_proposal(p=model_ttp, f=get_scost_ttp(dparams=D, c=C), n_steps=N_STEPS, lr=LR)
+
+ tabi_ress = []
+ for _ in range(100):
+ tabi_ress.append(expectation(f=get_scost_ttp(dparams=D, c=C), p=model_ttp)["cost"])
+
+ plt.figure()
+ plt.suptitle(f"TABI Toy Learned Guide (Integrated) (N={NS})")
+ plt.hist(tabi_ress, bins=20)
+ plt.axvline(x=GT, color='black', linestyle='--')
+
+ plot_tabi_guides(te, dist.Normal(pos_iloc, pos_istd), dist.Normal(neg_iloc, neg_istd), MODEL_TTP_DIST)
+
+ #
+
+ #
+
+ # Wrap in a tensor so we can optimize it.
+ dparams = OrderedDict(d=torch.nn.Parameter(torch.tensor(-1.)))
+ cval = tt(2.) # A value of 2 makes this a bit more difficult than the above.
+
+ dprogression = [dparams['d'].item()]
+ dgrads = []
+
+    # The guides for the gradients have to be multimodal to track the multimodal cost function.
+    # While the positive and negative components of the non-differentiated cost are themselves unimodal, when
+    # working directly with the gradients, each pos/neg component has a positive and a negative component itself.
+    # This means each positive/negative guide has to be multimodal.
+
+ # Plot the d/dd expectigrand as a function of x.
+ plt.figure()
+ xx = torch.linspace(-5., 5., 100)
+ ddf = build_expectigrand_gradient(dparams, get_dparam_scost_ttp(c=cval))
+ plt.plot(xx, [
+ (OppableDict(ddf(OrderedDict(x=x)))*MODEL_TTP_DIST.log_prob(x).exp())[_gradify_dparam_name("d")].detach().item()
+ for x in xx])
+ plt.show()
+
+ # We initialize guides with components at both the positive and negative components of the cost function. They
+ # will then adjust to capture the two components of each side of the gradient.
+ pos_iloc_grad, pos_istd_grad = q_optimal_normal_guide_mean_var(**D, c=cval, z=False)
+ neg_iloc_grad, neg_istd_grad = q_optimal_normal_guide_mean_var(**D, c=cval, z=True)
+ pos_guide_grad = MultiModalGuide1D(
+ num_components=2,
+ init_loc=[pos_iloc_grad.item(), neg_iloc_grad.item()],
+ init_scale=[pos_istd_grad.item(), neg_istd_grad.item()]
+ )
+ neg_guide_grad = MultiModalGuide1D(
+ num_components=2,
+ init_loc=[pos_iloc_grad.item(), neg_iloc_grad.item()],
+ init_scale=[pos_istd_grad.item(), neg_istd_grad.item()]
+ )
+ # The denominator doesn't involve the cost function, so it stays the same.
+
+ def plotss_(title, te):
+ # Plot the d/dd expectigrand as a function of x.
+ plt.figure()
+ plt.suptitle(title)
+
+ def plot_part_(sign, color):
+ xy = []
+ with pyro.poutine.trace() as og_tr:
+ te.get_part(sign, f=ddf, p=model_ttp)()
+ for x in torch.linspace(-4., 4., 1000):
+ og_tr.get_trace().nodes['x']['value'] = x
+ with ReplayMessenger(og_tr.trace):
+ with pyro.poutine.trace() as tr:
+ te.get_part(sign, f=ddf, p=model_ttp)()
+ xy.append((
+ tr.get_trace().nodes['x']['value'].item(),
+ tr.get_trace().log_prob_sum().exp().item()
+ ))
+ xy = np.array(xy)
+ plt.plot(*xy.T, color=color, linestyle='--')
+
+ plot_part_(1., 'blue')
+ plot_part_(-1., 'red')
+
+ # Plot an sns density plot of the guides. Use the same figure.
+ sns.kdeplot([te.q_plus()['x'].item() for _ in range(1000)], label="q_plus", bw_method=0.05, color='blue')
+ sns.kdeplot([te.q_minus()['x'].item() for _ in range(1000)], label="q_minus", bw_method=0.05, color='red')
+ sns.kdeplot([te.q_den()['x'].item() for _ in range(1000)], label="q_den", bw_method=0.25, color='green')
+
+ # Also plot the true model density.
+ sns.kdeplot([model_ttp()['x'].item() for _ in range(1000)],
+ label="model", bw_method=0.25, color='green', linestyle='--')
+
+ plt.legend()
+ plt.xlim(-5., 5.)
+ plt.show()
+ plt.close()
+
+ # # <----Manual Execution>
+ # with TABIExpectation(
+ # q_plus=pos_guide_grad,
+ # q_minus=neg_guide_grad,
+ # q_den=get_den_guide_init(),
+ # num_samples=1
+ # ) as te_:
+ #
+ # plotss_("Before", te_)
+ #
+ # grad_estimate = expectation(f=ddf, p=model_ttp)
+ # print(f"TABI Toy Learned Grads (Truth: -0.0575) (Before): {grad_estimate}")
+ #
+ # for _ in range(5):
+ # optimize_proposal(p=model_ttp, f=ddf, n_steps=3000, lr=1e-4, adjust_grads_=abort_guide_grads_)
+ #
+ # plotss_("After", te_)
+ #
+ # grad_estimate = expectation(f=ddf, p=model_ttp)
+ # print(f"TABI Toy Learned Grads (Truth: -0.0575) (After): {grad_estimate}")
+ #
+ # # Iteratively optimize the decision variable and the proposals.
+ # optim_ = torch.optim.SGD(tuple(dparams.values()), lr=1e-1)
+ #
+ # ii = 1
+ # cii = 0
+ # while cii < 30:
+ #
+ # optim_.zero_grad()
+ # for dp in dparams.values():
+ # dp.grad = None
+ #
+ # # Move the decision variable.
+ # grad_estimate = expectation(f=ddf, p=model_ttp)
+ # print("Gradient Estimate", grad_estimate["d"].item())
+ #
+ # grad_estimate = clip_decision_grads(grad_estimate)
+ #
+ # for k, dp in dparams.items():
+ #
+ # dp.grad = grad_estimate[k]
+ #
+ # optim_.step()
+ #
+ # # Then update the proposals.
+ # optimize_proposal(p=model_ttp, f=ddf, n_steps=300, lr=3e-5, adjust_grads_=abort_guide_grads_)
+ #
+ # if ii % 30 == 0:
+ # plotss_(f"Decision {round(dparams['d'].item(), 3)}", te_)
+ #
+ # # And append the state of dten.
+ # dprogression.append(dparams["d"].item())
+ # dgrads.append(grad_estimate["d"].item())
+ #
+ # ii += 1
+ #
+ # if abs(dparams["d"].item()) < 1e-3:
+ # cii += 1
+ # else:
+ # cii = 0
+ #
+ # plt.figure()
+ # plt.plot(dprogression)
+ # plt.title("Decision Progression")
+ #
+ # plt.figure()
+ # plt.plot(dgrads)
+ # plt.title("Decision Gradients")
+ #
+ # plt.show()
+ # # ----Manual Execution>
+
+ #
+
+ #
+
+ te_ = TABIExpectation(
+ q_plus=pos_guide_grad,
+ q_minus=neg_guide_grad,
+ q_den=get_den_guide_init(),
+ num_samples=1)
+ dh_ = DecisionOptimizerHandler(dparams=dparams, lr=1e-1, proposal_update_lr=3e-5, proposal_update_steps=300,
+ proposal_adjust_grads_=abort_guide_grads_)
+
+ with te_, dh_:
+
+ plotss_("Before", te_)
+
+ grad_estimate = expectation(f=ddf, p=model_ttp)
+ print(f"TABI Toy Learned Grads (Before): {grad_estimate}")
+
+ # Perform initial optimization of proposals.
+ for _ in range(2):
+ # In the dh_ handler, the function here will be converted to the gradient with respect to d.
+ # TODO I don't know if I like this being implicit...?
+ optimize_proposal(p=model_ttp, f=get_dparam_scost_ttp(c=cval), n_steps=3000, lr=1e-4,
+ adjust_grads_=abort_guide_grads_)
+
+ plotss_("After", te_)
+
+ grad_estimate = expectation(f=ddf, p=model_ttp)
+ print(f"TABI Toy Learned Grads (After): {grad_estimate}")
+
+ def terminal_condition_(dparams_, i):
+
+ if i % 10 == 0:
+ plotss_(f"Decision {round(dparams_['d'].item(), 3)}", te_)
+
+ ret = abs(dparams_['d'].item()) < 1e-2
+
+ if ret:
+ plotss_(f"Decision {round(dparams_['d'].item(), 3)}", te_)
+
+ return ret
+
+ optimize_decision(f=get_dparam_scost_ttp(c=cval), p=model_ttp,
+ terminal_condition=terminal_condition_,
+ adjust_grads=clip_decision_grads)
+
+ #
+
+ exit()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/docs/source/expectation_programming/old_ep_demo_scratch_sri.py b/docs/source/expectation_programming/old_ep_demo_scratch_sri.py
new file mode 100644
index 00000000..5b648a45
--- /dev/null
+++ b/docs/source/expectation_programming/old_ep_demo_scratch_sri.py
@@ -0,0 +1,621 @@
+from sri_tabi_problem import ReuseableSimulation, LOCKDOWN_TYPE, LockdownType
+from torch import tensor as tt
+
+import pyro
+import torch
+from pyro.infer.autoguide import AutoMultivariateNormal
+import pyro.distributions as dist
+import matplotlib.pyplot as plt
+
+from enum import Enum
+
+from chirho.dynamical.ops import State, Trajectory
+
+from collections import OrderedDict
+
+from typing import (
+ Optional,
+ Union,
+)
+
+import numpy as np
+
+import old_ep_demo_scratch as stor
+
+from sklearn.neighbors import KernelDensity
+
+PT = torch.nn.Parameter
+TT = torch.Tensor
+UPTT = Union[PT, TT]
+
+if LOCKDOWN_TYPE == LockdownType.NONCONT_STATE:
+ DEFAULT_DPARAMS = DDP = OrderedDict(
+ # # Optimized.
+ # lockdown_trigger=tt(0.176),
+ # lockdown_lift_trigger=tt(0.457),
+ # lockdown_strength=tt(0.656)
+
+ # lockdown_trigger=tt(0.23),
+ # lockdown_lift_trigger=tt(0.45),
+ # lockdown_strength=tt(0.69)
+
+ lockdown_trigger=torch.nn.Parameter(tt(0.1)),
+ lockdown_lift_trigger=torch.nn.Parameter(tt(0.32)),
+ lockdown_strength=torch.nn.Parameter(tt(0.7))
+ )
+elif LOCKDOWN_TYPE == LockdownType.CONT_PLATEAU:
+ DEFAULT_DPARAMS = DDP = OrderedDict(
+        # Optimal-ish for continuous plateau setup.
+ lockdown_trigger=tt(0.03),
+ lockdown_lift_trigger=tt(0.49),
+ lockdown_strength=tt(0.61)
+
+ # # Sub-optimal but decent init for continuous plateau setup.
+ # lockdown_trigger=tt(0.08),
+ # lockdown_lift_trigger=tt(0.6),
+ # lockdown_strength=tt(0.5)
+ )
+elif LOCKDOWN_TYPE == LockdownType.NONCONT_TIME:
+ DEFAULT_DPARAMS = DDP = OrderedDict(
+ lockdown_trigger=tt(0.8),
+ lockdown_lift_trigger=tt(8.0),
+ lockdown_strength=tt(0.5)
+ )
+
+DEFAULT_INIT_STATE = DIS = State(S=tt(0.99), I=tt(0.01), R=tt(0.0), L=tt(0.0), l=tt(0.0), O=tt(0.0))
+
+DEFAULT_STOCHASTICS = DST = OrderedDict(
+ beta=tt(2.),
+ gamma=tt(.4),
+ capacity=tt(0.01),
+ hospitalization_rate=tt(0.05)
+)
+
+DEFAULT_TIMES = DT = torch.linspace(0., 20., 100)
+
+if LOCKDOWN_TYPE == LockdownType.NONCONT_STATE:
+ OEXPO = 1.
+ OSCALING = 2e2
+elif LOCKDOWN_TYPE == LockdownType.NONCONT_TIME:
+ OEXPO = 1.
+ OSCALING = 2e2
+elif LOCKDOWN_TYPE == LockdownType.CONT_PLATEAU:
+ OSCALING = 1.3e2
+ OEXPO = 1.5
+
+
+def o_transform(o: torch.Tensor) -> torch.Tensor:
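+    # Rescale the aggregate capacity overrun so it is comparable to the lockdown cost: a power law with
+    # exponent OEXPO (super-linear for OEXPO > 1, as in the continuous-plateau setup) scaled by OSCALING.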
+ return ((1. + o) ** OEXPO - 1.) * OSCALING
+
+
+def copy_odict(odict: OrderedDict[str, torch.Tensor]) -> OrderedDict[str, torch.Tensor]:
+ return OrderedDict((k, tt(v.item())) for k, v in odict.items())
+
+
+def f_traj(
+ dparams: stor.KWTypeNNParams,
+ stochastics: stor.KWType,
+ rs: Optional[ReuseableSimulation] = None) -> Trajectory[torch.Tensor]:
+
+    # If dparams or stochastics are missing keys defined in the defaults, fill them in with copies of the
+    # default values.
+ dparams = copy_odict(DEFAULT_DPARAMS) | dparams
+ stochastics = copy_odict(DEFAULT_STOCHASTICS) | stochastics
+
+ if rs is None:
+ rs = ReuseableSimulation()
+
+ for k, v in stochastics.items():
+ # All stochastics need to be positive, so just do it here and not worry about it in the guide.
+ stochastics[k] = torch.abs(v)
+
+ for k, v in dparams.items():
+ # All dparams need to be positive and non-zero, so put them through a relu with eps added.
+ dparams[k] = torch.relu(v) + 1e-3
+
+ traj = rs(
+ **dparams,
+ init_state=DEFAULT_INIT_STATE,
+ **stochastics,
+ times=DEFAULT_TIMES)
+
+ return traj
+
+
+def f_combined(dparams: stor.KWTypeNNParams, stochastics: stor.KWType) -> stor.KWType:
+ traj = f_traj(dparams, stochastics)
+
+ total_hospital_overrun = traj[-1].O
+ total_lockdown_unpleasantness = traj[-1].L
+
+ # For now, just combine these into a single cost. In future we want to minimize lockdown with constraints
+ # on total hospital overrun.
+ return OrderedDict(cost=o_transform(total_hospital_overrun) + total_lockdown_unpleasantness)
+
+
+def f_o_only(
+ dparams: stor.KWTypeNNParams,
+ stochastics: stor.KWType,
+ # TODO lo1dop6k This avoids redundancy in the cost and constraint calls, but isn't used right now.
+ rs: Optional[ReuseableSimulation] = None) -> stor.KWType:
+ traj = f_traj(dparams, stochastics, rs=rs)
+
+ total_hospital_overrun = traj[-1].O
+
+ return OrderedDict(cost=o_transform(total_hospital_overrun))
+
+
+def f_l_only(
+ dparams: stor.KWTypeNNParams,
+ stochastics: stor.KWType,
+ # TODO lo1dop6k See tag above.
+ rs: Optional[ReuseableSimulation] = None) -> stor.KWType:
+ traj = f_traj(dparams, stochastics, rs=rs)
+
+ total_lockdown_unpleasantness = traj[-1].L
+
+ return OrderedDict(cost=total_lockdown_unpleasantness)
+
+
+def plot_basic(dparams=None, stochastics=None):
+
+ if dparams is None:
+ dparams = OrderedDict()
+ if stochastics is None:
+ stochastics = OrderedDict()
+
+ traj = f_traj(dparams, stochastics)
+
+ fig, (ax1, ax3, ax2) = plt.subplots(3, 1, figsize=(7, 10))
+ tax3 = ax3.twinx()
+ ax2.axhline(DST['capacity'] * (1. / DST['hospitalization_rate']), color='k', linestyle='--')
+ ax3.axhline(DST['capacity'] * (1. / DST['hospitalization_rate']), color='k', linestyle='--',
+ label='Healthcare Capacity')
+ ax2.plot(DT, traj.S, label='S', color='blue')
+ ax2.plot(DT, traj.I, label='I', color='red')
+ ax3.plot(DT, traj.I, label='I', color='red')
+ ax2.plot(DT, traj.R, label='R', color='green')
+ ax1.plot(DT, traj.L, label='Aggregate Lockdown', color='orange')
+ ax1.plot(DT, traj.l, label='Lockdown', color='orange', linestyle='--')
+ tax3.plot(DT, traj.O, label='Aggregate Overrun', color='k')
+ ax1.legend()
+ ax2.legend()
+ ax3.legend()
+ tax3.legend()
+
+ ax2.set_xlabel('Time')
+ ax1.set_ylabel('Lockdown Strength')
+ ax2.set_ylabel('Proportion of Population')
+ ax3.set_ylabel('Proportion of Population')
+ tax3.set_ylabel('Aggregate Overrun')
+
+ plt.tight_layout()
+
+ plt.show()
+
+ return
+
+
+def plot_cost_vs_parameter(parameter_name, center: bool = False):
+
+ if center:
+ c = DDP[parameter_name]
+ parameter_values = torch.linspace(torch.relu(c - 0.05), c + 0.05, 25)
+ else:
+ # Define the range of parameter values to consider
+ parameter_values = torch.linspace(0.01, 1.0, 100)
+
+ # Initialize empty lists to store the corresponding cost values
+ cost_values = []
+ o_only_values = []
+ l_only_values = []
+
+ # Loop over the parameter values
+ for parameter_value in parameter_values:
+ # Calculate the cost for this parameter value
+ o_only = f_o_only(OrderedDict({parameter_name: parameter_value}), OrderedDict())['cost']
+ l_only = f_l_only(OrderedDict({parameter_name: parameter_value}), OrderedDict())['cost']
+ cost = f_combined(OrderedDict({parameter_name: parameter_value}), OrderedDict())['cost']
+
+ o_only_values.append(o_only.item())
+ l_only_values.append(l_only.item())
+ cost_values.append(cost.item())
+
+ # Create the plot
+ fig, ax = plt.subplots(1, 1)
+ ax.plot(parameter_values, cost_values)
+ ax.plot(parameter_values, o_only_values)
+ ax.plot(parameter_values, l_only_values)
+ ax.set_xlabel(parameter_name.capitalize())
+ ax.set_ylabel('Cost')
+ ax.legend(['Combined', 'Overrun Only', 'Lockdown Only'])
+ ax.grid(True)
+
+ plt.show()
+
+
+def plot_cost_vs_parameters(parameter_name1, parameter_name2):
+ # Define the range of parameter values to consider
+ parameter_values1 = torch.linspace(0.01, 1.0, 10) # Adjust the number of points as needed
+ parameter_values2 = torch.linspace(0.01, 1.0, 10) # Adjust the number of points as needed
+
+ # Create a meshgrid of parameter values
+ parameter_grid1, parameter_grid2 = torch.meshgrid(parameter_values1, parameter_values2)
+
+ # Initialize empty tensors to store the corresponding cost values
+ cost_values = torch.zeros_like(parameter_grid1)
+ o_only_values = torch.zeros_like(parameter_grid1)
+ l_only_values = torch.zeros_like(parameter_grid1)
+
+ # Loop over the parameter values
+ for i in range(parameter_values1.shape[0]):
+ for j in range(parameter_values2.shape[0]):
+ # Calculate the cost for this pair of parameter values
+ o_only = f_o_only(OrderedDict({parameter_name1: parameter_grid1[i, j], parameter_name2: parameter_grid2[i, j]}), OrderedDict())['cost']
+ l_only = f_l_only(OrderedDict({parameter_name1: parameter_grid1[i, j], parameter_name2: parameter_grid2[i, j]}), OrderedDict())['cost']
+ cost = f_combined(OrderedDict({parameter_name1: parameter_grid1[i, j], parameter_name2: parameter_grid2[i, j]}), OrderedDict())['cost']
+
+ o_only_values[i, j] = o_only.item()
+ l_only_values[i, j] = l_only.item()
+ cost_values[i, j] = cost.item()
+
+ # Create the contour plots
+ fig, axs = plt.subplots(3, 1, figsize=(10, 12))
+
+ axs[0].contourf(parameter_grid1.numpy(), parameter_grid2.numpy(), o_only_values.numpy(), cmap='viridis')
+ axs[0].set_xlabel(parameter_name1.capitalize())
+ axs[0].set_ylabel(parameter_name2.capitalize())
+ axs[0].set_title('Overrun Cost')
+ fig.colorbar(ax=axs[0], mappable=axs[0].collections[0], label='Cost')
+
+ axs[1].contourf(parameter_grid1.numpy(), parameter_grid2.numpy(), l_only_values.numpy(), cmap='viridis')
+ axs[1].set_xlabel(parameter_name1.capitalize())
+ axs[1].set_ylabel(parameter_name2.capitalize())
+ axs[1].set_title('Lockdown Cost')
+ fig.colorbar(ax=axs[1], mappable=axs[1].collections[0], label='Cost')
+
+ axs[2].contourf(parameter_grid1.numpy(), parameter_grid2.numpy(), cost_values.numpy(), cmap='viridis')
+ axs[2].set_xlabel(parameter_name1.capitalize())
+ axs[2].set_ylabel(parameter_name2.capitalize())
+ axs[2].set_title('Combined Cost')
+ fig.colorbar(ax=axs[2], mappable=axs[2].collections[0], label='Cost')
+
+ plt.tight_layout()
+ plt.show()
+
+
+def _NNM_vectorized(f_, N, M, X, Y):
+    # Evaluate the non-vectorized function f_ at each (i, j) grid cell, collecting the M outputs per cell.
+    out = np.empty((N, N, M))
+    for idx in np.ndindex(N, N):
+        out[idx] = f_(X[idx], Y[idx])
+    return out
+
+
+def plot_cost_likelihood_convolution_for_stochastics(
+ stochastic_name1: str, stochastic_name2: str,
+ p: stor.ModelType, f_: stor.ExpectigrandType, n=1000):
+
+ samples = []
+    for _ in range(n):  # Generate n samples from the model.
+ sample = p()
+ samples.append((sample[stochastic_name1].item(), sample[stochastic_name2].item()))
+
+ samples = np.array(samples)
+
+ kde: KernelDensity = KernelDensity(kernel='gaussian', bandwidth=0.02).fit(samples)
+
+ def density(s1, s2) -> np.ndarray:
+ return np.array([np.exp(kde.score_samples([[s1, s2]]))])
+
+ def cost(s1, s2) -> np.ndarray:
+ vals = f_(OrderedDict({stochastic_name1: tt(s1), stochastic_name2: tt(s2)})).values()
+ return np.array(tuple(v.item() for v in vals))
+
+ def cust_colorbar(ax, arr):
+ ticks = np.linspace(0., 1.0, 10)
+ cbar = fig.colorbar(ax=ax, mappable=ax.collections[0], ticks=ticks)
+ arrmax = arr.max()
+ arrmin = arr.min()
+ tick_labels = ticks * (arrmax - arrmin) + arrmin
+ cbar.ax.set_yticklabels(['{:.1f}'.format(tick) for tick in tick_labels])
+
+ resolution = 15
+
+ s1ls = np.linspace(0.00, 1.0, resolution)
+ s2ls = np.linspace(0.00, 1.0, resolution)
+
+ s1ls = s1ls * (samples[:, 0].max() - samples[:, 0].min()) + samples[:, 0].min()
+ s2ls = s2ls * (samples[:, 1].max() - samples[:, 1].min()) + samples[:, 1].min()
+
+ X, Y = np.meshgrid(s1ls, s2ls)
+
+ # Make a subplot for the density and one for each component of the cost.
+ cost_component_names = f_(OrderedDict({stochastic_name1: tt(s1ls[0]), stochastic_name2: tt(s2ls[0])})).keys()
+ num_cost_components = len(cost_component_names)
+
+ fig, axs = plt.subplots(num_cost_components + 1, 3, figsize=(6*num_cost_components, 18))
+
+ # Plot the density in the top row.
+ density_array = _NNM_vectorized(f_=density, N=resolution, M=1, X=X, Y=Y)
+ for col in range(3):
+ axs[0][col].contourf(X, Y, density_array[..., 0], cmap='viridis')
+ cust_colorbar(axs[0][col], density_array[..., 0])
+
+    # Compute the cost array. We have to do this manually, since the cost function isn't vectorized over the grid.
+ cost_array = _NNM_vectorized(f_=cost, N=resolution, M=num_cost_components, X=X, Y=Y)
+
+ assert cost_array.shape == (resolution, resolution, num_cost_components)
+
+ # Plot the cost components in the left column.
+ for i, cost_component_name in enumerate(cost_component_names):
+ axs[i+1][0].contourf(X, Y, cost_array[:, :, i], cmap='viridis')
+ axs[i+1][0].set_title(cost_component_name.capitalize())
+ cust_colorbar(axs[i+1][0], cost_array[:, :, i])
+
+    # Elementwise product of cost and density over the grid; the assert below checks the shape.
+ cost_density_convolution_array = cost_array * density_array
+ assert cost_density_convolution_array.shape == (resolution, resolution, num_cost_components)
+
+ # Plot the positive part of the convolved cost components in the middle column.
+
+ def plot_part(arr, col):
+ for i, cost_component_name in enumerate(cost_component_names):
+
+ axs[i+1][col].contourf(X, Y, arr[:, :, i], cmap='viridis')
+ axs[i+1][col].set_title(cost_component_name + (' +' if col == 1 else ' -'))
+ cust_colorbar(axs[i+1][col], arr[:, :, i])
+
+ # And draw crosshairs at the original density mean, for comparison.
+ axs[i+1][col].axvline(x=samples[:, 0].mean(), color='white', linestyle='--')
+ axs[i+1][col].axhline(y=samples[:, 1].mean(), color='white', linestyle='--')
+
+ plot_part(np.maximum(cost_density_convolution_array, 0.0), 1)
+ plot_part(-np.minimum(cost_density_convolution_array, 0.0), 2)
+
+ plt.tight_layout()
+
+ plt.show()
+
+
+def pyro_prior_over_sirlo_params():
+ beta = pyro.sample(
+ "beta", dist.Normal(2., 0.3))
+ gamma = pyro.sample(
+ "gamma", dist.Normal(.4, .06))
+ capacity = pyro.sample(
+ "capacity", dist.Normal(0.01, 0.003))
+ hospitalization_rate = pyro.sample(
+ "hospitalization_rate", dist.Normal(0.05, 0.015))
+ return OrderedDict(
+ beta=beta,
+ gamma=gamma,
+ capacity=capacity,
+ hospitalization_rate=hospitalization_rate
+ )
+
+
+def pyro_prior_over_sirlo_params_2d():
+ beta = pyro.sample(
+ "beta", dist.Normal(2., 0.1))
+ gamma = pyro.sample(
+ "gamma", dist.Normal(.4, .02))
+ return OrderedDict(
+ beta=beta,
+ gamma=gamma
+ )
+
+
+def _grad_debugging():
+ dparams = OrderedDict(
+ lockdown_trigger=torch.nn.Parameter(tt(DDP['lockdown_trigger'].item())),
+ lockdown_lift_trigger=torch.nn.Parameter(tt(DDP['lockdown_lift_trigger'].item())),
+ lockdown_strength=torch.nn.Parameter(tt(DDP['lockdown_strength'].item()))
+ )
+
+ overrun = f_o_only(dparams, OrderedDict())
+ lockdown = f_l_only(dparams, OrderedDict())
+ combined = f_combined(dparams, OrderedDict())
+
+ traj = f_traj(dparams, OrderedDict())
+
+ # Make sure the gradient of the lockdown trigger end state wrt the param is one.
+ dltdlt = torch.autograd.grad(
+ outputs=(traj.lockdown_trigger[-1],),
+ inputs=tuple(dparams.values()),
+ create_graph=True)
+ assert torch.isclose(dltdlt[0], tt(1.0))
+ assert torch.isclose(dltdlt[1], tt(0.0))
+ assert torch.isclose(dltdlt[2], tt(0.0))
+
+    # Make sure the gradient of the lockdown lift trigger end state wrt its own param is one (and zero wrt the
+    # others).
+ dlftdlft = torch.autograd.grad(
+ outputs=(traj.lockdown_lift_trigger[-1],),
+ inputs=tuple(dparams.values()),
+ create_graph=True)
+ assert torch.isclose(dlftdlft[0], tt(0.0), atol=1e-4)
+ assert torch.isclose(dlftdlft[1], tt(1.0))
+ assert torch.isclose(dlftdlft[2], tt(0.0))
+
+ if LOCKDOWN_TYPE != LockdownType.CONT_PLATEAU:
+ # Make sure the gradient of the lockdown strength wrt itself is one.
+ dlstdlst = torch.autograd.grad(
+ outputs=(traj.l[30],), # get strength when lockdown is active.
+ inputs=tuple(dparams.values()),
+ create_graph=True)
+ assert torch.isclose(dlstdlst[0], tt(0.0))
+ assert torch.isclose(dlstdlst[1], tt(0.0))
+ assert torch.isclose(dlstdlst[2], tt(1.0))
+
+ dCd = torch.autograd.grad(
+ outputs=(combined["cost"],),
+ inputs=tuple(dparams.values()),
+ # inputs=(dparams['lockdown_strength'],),
+ # inputs=(dparams['lockdown_trigger'],),
+ # inputs=(dparams['lockdown_lift_trigger'],),
+ create_graph=True)
+
+ assert not torch.isclose(dCd[0], tt(0.0), atol=1e-4)
+ assert not torch.isclose(dCd[1], tt(0.0), atol=1e-4)
+ assert not torch.isclose(dCd[2], tt(0.0), atol=1e-4)
+
+
+class ConstraintType(Enum):
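+    """
+    JOINT folds the hospital-overrun cost into a single combined objective (f_combined); MEAN optimizes the
+    lockdown cost alone (f_l_only) and enforces the overrun constraint with a MeanConstraintHandler.
+    """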
+ JOINT = 1
+ MEAN = 2
+
+
+def optimize_decision_2d_latent(constraint_type: ConstraintType):
+
+ # Beta and gamma are the only latents.
+ p_ = pyro_prior_over_sirlo_params_2d
+
+ dparams = OrderedDict(
+ lockdown_trigger=torch.nn.Parameter(tt(0.1)),
+ lockdown_lift_trigger=torch.nn.Parameter(tt(0.3)),
+ lockdown_strength=torch.nn.Parameter(tt(0.7))
+ )
+
+ q_plus_guide = AutoMultivariateNormal(
+ model=p_,
+ )
+ q_minus_guide = AutoMultivariateNormal(
+ model=p_
+ )
+ q_den_guide = AutoMultivariateNormal(
+ model=p_
+ )
+
+ te = stor.TABIExpectation(
+ q_plus=q_plus_guide,
+ q_minus=q_minus_guide,
+ q_den=q_den_guide,
+ num_samples=1
+ )
+
+ def abort_guide_grads_(*parameters: torch.nn.Parameter):
+ # These gradients also blow up, but clipping them causes weird non-convergence. Just aborting
+ # the gradient update seems to work.
+ if torch.any(torch.tensor([torch.any(torch.abs(param.grad) > 400.) for param in parameters])):
+ for param in parameters:
+ param.grad = torch.zeros_like(param.grad)
+
+ dh = stor.DecisionOptimizerHandler(
+ dparams=dparams,
+ lr=1e-1,
+ proposal_update_lr=1e-4,
+ # TODO restore this once the constraints can get dedicated proposals.
+ proposal_update_steps=10 if constraint_type == ConstraintType.JOINT else 0,
+ proposal_adjust_grads_=abort_guide_grads_
+ )
+
+ mc = stor.MeanConstraintHandler(
+ g=f_o_only,
+ tau=1.,
+ threshold=0.001
+ ) if constraint_type == ConstraintType.MEAN else pyro.poutine.messenger.Messenger()
+
+ def terminal_condition(_, i: int) -> bool:
+ return i > 1000 # TODO more sophisticated convergence criterion.
+
+ betagammas_q_plus_progression = []
+ betagammas_q_minus_progression = []
+ betagammas_q_den_progression = []
+ lockdown_strength_progression = []
+ lockdown_trigger_progression = []
+ lockdown_lift_trigger_progression = []
+
+ def save_progressions():
+ np.savez(
+ f"/Users/azane/Desktop/sirlo_logs/sirlo_opt.npz",
+ lockdown_strength_progression=np.array(lockdown_strength_progression),
+ lockdown_trigger_progression=np.array(lockdown_trigger_progression),
+ lockdown_lift_trigger_progression=np.array(lockdown_lift_trigger_progression),
+ beta_q_plus_progression=np.array([bg['beta'].detach().item() for bg in betagammas_q_plus_progression]),
+ gamma_q_plus_progression=np.array([bg['gamma'].detach().item() for bg in betagammas_q_plus_progression]),
+ beta_q_minus_progression=np.array([bg['beta'].detach().item() for bg in betagammas_q_minus_progression]),
+ gamma_q_minus_progression=np.array([bg['gamma'].detach().item() for bg in betagammas_q_minus_progression]),
+ beta_q_den_progression=np.array([bg['beta'].detach().item() for bg in betagammas_q_den_progression]),
+ gamma_q_den_progression=np.array([bg['gamma'].detach().item() for bg in betagammas_q_den_progression]),
+ )
+
+ class OptimizeProposalCallback:
+
+ init_optimize_proposal_iterations = 0
+
+ def __call__(self):
+ self.init_optimize_proposal_iterations += 1
+ print(f"Optimizing proposal {self.init_optimize_proposal_iterations}.")
+
+ betagammas_q_plus_progression.append(q_plus_guide.forward())
+ betagammas_q_minus_progression.append(q_minus_guide.forward())
+ betagammas_q_den_progression.append(q_den_guide.forward())
+
+ save_progressions()
+
+ class OptimizeDecisionCallback:
+ optimize_decision_iterations = 0
+
+ def __call__(self):
+ self.optimize_decision_iterations += 1
+ print(f"Optimizing decision {self.optimize_decision_iterations}.")
+
+ lockdown_strength_progression.append(dh.dparams['lockdown_strength'].detach().item())
+ lockdown_trigger_progression.append(dh.dparams['lockdown_trigger'].detach().item())
+ lockdown_lift_trigger_progression.append(dh.dparams['lockdown_lift_trigger'].detach().item())
+
+ betagammas_q_plus_progression.append(q_plus_guide.forward())
+ betagammas_q_minus_progression.append(q_minus_guide.forward())
+ betagammas_q_den_progression.append(q_den_guide.forward())
+
+ save_progressions()
+
+ with te, dh, mc:
+
+ # # This changes what optimize_decision sees as the stochastic program to be the product of a conditioned
+ # # inference procedure.
+ # with pyro.condition(sir_noisy_data):
+
+        # TODO enable this once explicit constraints can get dedicated proposals.
+ if constraint_type == ConstraintType.JOINT:
+ # Initial optimization of proposals.
+ stor.optimize_proposal(
+ p=p_,
+ f=f_combined,
+ n_steps=10,
+ lr=1e-4,
+ adjust_grads_=dh.proposal_adjust_grads_,
+ callback=OptimizeProposalCallback()
+ )
+
+ optimal_decision = stor.optimize_decision(
+ p=p_,
+ f=f_combined if constraint_type == ConstraintType.JOINT else f_l_only,
+ terminal_condition=terminal_condition,
+ adjust_grads=lambda g: OrderedDict([(k, torch.clip(g[k], -1./100., 1./100.)) for k in g.keys()]),
+ callback=OptimizeDecisionCallback(),
+ )
+
+ return # just here to put a breakpoint on
+
+
+if __name__ == '__main__':
+
+ # plot_basic()
+
+ # plot_cost_vs_parameter('lockdown_strength')
+ # plot_cost_vs_parameter('lockdown_trigger')
+ # plot_cost_vs_parameter('lockdown_lift_trigger')
+ # plot_cost_vs_parameters('lockdown_strength', 'lockdown_trigger')
+ # plot_cost_vs_parameters('lockdown_strength', 'lockdown_lift_trigger')
+ # plot_cost_vs_parameters('lockdown_trigger', 'lockdown_lift_trigger')
+
+ optimize_decision_2d_latent(constraint_type=ConstraintType.MEAN)
+ # plot_basic(OrderedDict(lockdown_strength=tt(.663)))
+
+ # _grad_debugging()
+
+ # plot_cost_likelihood_convolution_for_stochastics(
+ # 'beta', 'gamma', pyro_prior_over_sirlo_params_2d, f_=stor.build_expectigrand_gradient(DDP, f_combined))
diff --git a/docs/source/expectation_programming/sri_tabi_problem.py b/docs/source/expectation_programming/sri_tabi_problem.py
new file mode 100644
index 00000000..9be5be17
--- /dev/null
+++ b/docs/source/expectation_programming/sri_tabi_problem.py
@@ -0,0 +1,197 @@
+import pyro
+import torch
+import pyro.distributions as dist
+
+from chirho.dynamical.handlers import (
+ DynamicIntervention,
+ SimulatorEventLoop,
+ simulate,
+ ODEDynamics,
+)
+
+from chirho.dynamical.ops import State
+from torch import tensor as tt
+from enum import Enum
+
+
+# Largely for debugging...
+class LockdownType(Enum):
+ CONT_PLATEAU = 1
+ NONCONT_STATE = 2
+ NONCONT_TIME = 3
+
+
+LOCKDOWN_TYPE = LockdownType.NONCONT_STATE
+
+
+class SimpleSIRDynamics(ODEDynamics):
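+    """SIR compartmental dynamics with transmission rate beta and recovery rate gamma."""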
+ def __init__(self, beta, gamma):
+ super().__init__()
+ self.beta = beta
+ self.gamma = gamma
+
+ def diff(self, dX: State[torch.Tensor], X: State[torch.Tensor]):
+ dX.S = -self.beta * X.S * X.I
+ dX.I = self.beta * X.S * X.I - self.gamma * X.I
+ dX.R = self.gamma * X.I
+
+ def observation(self, X: State[torch.Tensor]):
+
+ I_obs = pyro.sample(f"I_obs", dist.Poisson(X.I)) # noisy number of infected actually observed
+ R_obs = pyro.sample(f"R_obs", dist.Poisson(X.R)) # noisy number of recovered actually observed
+
+ return {
+ f"I_obs": I_obs,
+ f"R_obs": R_obs,
+ }
+
+
+class SimpleSIRDynamicsLockdown(SimpleSIRDynamics):
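+    """
+    SIR dynamics in which the lockdown strength l scales down the base transmission rate beta0, while L accrues
+    total lockdown unpleasantness over time.
+    """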
+ def __init__(self, beta0, gamma):
+ super().__init__(torch.zeros_like(gamma), gamma)
+ self.beta0 = beta0
+
+ def diff(self, dX: State[torch.Tensor], X: State[torch.Tensor]):
+
+        # Lockdown strength is a piecewise-constant function changed only by dynamic interventions, so dX.l is 0.
+ dX.l = torch.tensor(0.0)
+
+        if LOCKDOWN_TYPE == LockdownType.CONT_PLATEAU:
+ # Pretend the trigger is also in terms of the number recovered.
+ recovered_trigger = X.lockdown_trigger
+ # Pretend the lift trigger is in terms of how many additional recovered need to trigger.
+ recovered_lift_trigger = recovered_trigger + X.lockdown_lift_trigger
+            # Define dX.L as a continuous plateau.
+ plateau_u = (recovered_lift_trigger + recovered_trigger) / tt(2.)
+ plateau_s = (recovered_lift_trigger - recovered_trigger) / tt(2.)
+ dXL_override = X.static_lockdown_strength * torch.exp(-((X.R - plateau_u) / plateau_s)**tt(10))
+
+ dX.L = dXL_override
+ self.beta = (1 - dXL_override) * self.beta0
+ else:
+            # Time-varying beta parametrized by lockdown strength l_t
+ self.beta = (1 - X.l) * self.beta0
+
+ # Accrual of lockdown time.
+ dX.L = X.l
+
+ # Constant event parameters have to be in the state in order for torchdiffeq to give derivs.
+ dX.lockdown_trigger = tt(0.0)
+ dX.lockdown_lift_trigger = tt(0.0)
+ dX.static_lockdown_strength = tt(0.0)
+
+ # Call the base SIR class diff method
+ super().diff(dX, X)
+
+
+class SimpleSIRDynamicsLockdownCapacityOverrun(SimpleSIRDynamicsLockdown):
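+    """
+    Adds an overrun state O that accrues the amount by which hospitalizations (I * hospitalization_rate) exceed
+    healthcare capacity.
+    """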
+ def __init__(self, beta0, gamma, capacity, hospitalization_rate):
+ super().__init__(beta0, gamma)
+
+ self.capacity = capacity
+ self.hospitalization_rate = hospitalization_rate
+
+ def diff(self, dX: State[torch.Tensor], X: State[torch.Tensor]):
+ # If the number of infected individuals needing hospitalization exceeds the capacity, accrue that difference
+ # in the overrun factor.
+ dX.O = torch.relu(X.I * self.hospitalization_rate - self.capacity)
+
+ super().diff(dX, X)
+
+
+def initiate_lockdown(t: torch.Tensor, state: State[torch.Tensor]):
+ target_recovered = state.lockdown_lift_trigger
+ target_infected = state.lockdown_trigger
+
+ if LOCKDOWN_TYPE == LockdownType.NONCONT_STATE:
+ # To enact the policy, require that the lift trigger hasn't effectively already fired.
+ # Do this by saying that the infected count is treated as zero if the lift trigger had fired.
+ # Let's say the lockdown is approaching its firing state (infected count is increasing), then this value
+ # will start going to zero, but if the lift trigger had fired, then it will jump up to the non-zero
+ # target state again and the lockdown will never fire.
+ return target_infected - state.I * torch.greater_equal(target_recovered, state.R).type(state.R.dtype)
+ elif LOCKDOWN_TYPE == LockdownType.NONCONT_TIME:
+ return target_infected - t
+ elif LOCKDOWN_TYPE == LockdownType.CONT_PLATEAU:
+ # Disabled entirely for continuous plateau.
+ return tt(1.)
+
+
+def lift_lockdown(t: torch.Tensor, state: State[torch.Tensor]):
+ target_recovered = state.lockdown_lift_trigger
+
+ if LOCKDOWN_TYPE == LockdownType.NONCONT_STATE:
+ # To lift the policy, require that the recovered count exceeds a certain level, and that a lockdown of
+ # non-zero strength is in place.
+ return target_recovered - state.R + torch.isclose(state.l, torch.tensor(0.0)).type(state.l.dtype)
+ elif LOCKDOWN_TYPE == LockdownType.NONCONT_TIME:
+ return target_recovered - t
+ elif LOCKDOWN_TYPE == LockdownType.CONT_PLATEAU:
+ # Disabled entirely for continuous plateau.
+ return tt(1.)
+
+
+class ReuseableSimulation:
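+    """
+    Memoizes a single simulation so that cost and constraint functions can share one trajectory. Note that
+    subsequent calls return the cached result regardless of the arguments passed (see TODO below).
+    """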
+ # TODO can something like this be accomplished with functools partial?
+ def __init__(self):
+ self.result = None
+
+ @staticmethod
+ def _inner_call(lockdown_trigger, lockdown_lift_trigger, lockdown_strength,
+ init_state, beta, gamma, capacity, hospitalization_rate, times, **kwargs):
+
+ # Make a new state object so we can add the lockdown trigger constants to the state without modifying the
+ # original. This is required because torchdiffeq requires that even constant event parameters must be in the
+ # state in order to take gradients with respect to them.
+ new_init_state = State()
+ for k in init_state.keys:
+ setattr(new_init_state, k, getattr(init_state, k))
+ setattr(new_init_state, "lockdown_trigger", lockdown_trigger)
+ setattr(new_init_state, "lockdown_lift_trigger", lockdown_lift_trigger)
+ setattr(new_init_state, "static_lockdown_strength", lockdown_strength)
+
+ if torch.isclose(lockdown_trigger, torch.tensor(0.0)) or torch.less(lockdown_trigger, torch.tensor(0.0)):
+ raise ValueError("Lockdown trigger must be greater than zero.")
+
+ if torch.isclose(lockdown_lift_trigger, torch.tensor(0.0)) or torch.less(lockdown_lift_trigger, torch.tensor(0.0)):
+ raise ValueError("Lockdown lift trigger must be greater than zero.")
+
+ sir = SimpleSIRDynamicsLockdownCapacityOverrun(beta, gamma, capacity, hospitalization_rate)
+ with SimulatorEventLoop():
+ with DynamicIntervention(event_f=initiate_lockdown,
+ intervention=State(l=lockdown_strength),
+ var_order=new_init_state.var_order, max_applications=1, ):
+ with DynamicIntervention(event_f=lift_lockdown,
+ intervention=State(l=torch.tensor(0.0)), var_order=new_init_state.var_order,
+ max_applications=1):
+ tspan = times
+ if not torch.isclose(tspan[0], torch.tensor(0.0)):
+ tspan = torch.cat([torch.tensor([0.]), tspan])
+
+ soln = simulate(sir, new_init_state, tspan)
+
+ return soln
+
+ def __call__(self, *args, **kwargs):
+ if self.result is None:
+ self.result = self._inner_call(*args, **kwargs)
+ else:
+ # TODO assert that the arguments are the same?
+ pass
+
+ return self.result
+
+
+def cost(dparams, stochastics, end_time, reuseable_sim) -> torch.Tensor:
+ sim_results = reuseable_sim(dparams, stochastics, end_time)
+
+ # The cost of the lockdown policy is the total lockdown time times the severity. See lockdown diff.
+ return sim_results[-1].L
+
+
+def failure_magnitude(dparams, stochastics, end_time, reuseable_sim) -> torch.Tensor:
+ sim_results = reuseable_sim(dparams, stochastics, end_time)
+
+ # The failure magnitude of lockdown policy is the total time spent by infected individuals who needed
+ # hospitalization but could not get it.
+ return sim_results[-1].O
diff --git a/docs/source/expectation_programming/toy_tabi_problem.py b/docs/source/expectation_programming/toy_tabi_problem.py
new file mode 100644
index 00000000..2f989efa
--- /dev/null
+++ b/docs/source/expectation_programming/toy_tabi_problem.py
@@ -0,0 +1,129 @@
+import torch
+import pyro
+import pyro.distributions as dist
+from collections import OrderedDict
+from torch import tensor as tt
+
+
+def h(d, c):
+ # Convert the decision parameter to a convex space with defined maximum (1+c) and minimum (c).
+    # This is the unnormalized kernel of a Student-t (Cauchy) density, giving longer tails and a better ramp-up
+    # for the decision parameter gradients.
+ return (1. + d**2.)**-1. + c
+
+
+def cost_part(d, x, c, z: bool):
+    # This builds out a part of the cost function. It's a normal positioned in the tails to the degree defined by
+    # c. If z is true, this is the negative part of the cost function in the left tail, with mean h(d, c) - 3 * c;
+    # otherwise it's the positive part of the cost function in the right tail, with mean h(d, c). The standard
+    # deviation is fixed at .2.
+
+ mean = h(d, c) - 3 * float(z) * c
+ std = .2
+ return torch.exp(dist.Normal(mean, std).log_prob(x))
+
+
+def cost(d: torch.Tensor, x: torch.Tensor, c: torch.Tensor, **kwargs) -> torch.Tensor:
+ # The additional kwargs just takes stochastics and decision parameters that don't actually play into this
+ # cost function. This is required because all decision parameters and all stochastics are unpacked here
+ # as keyword arguments to the cost function, and there are auxiliary stochastics we don't care about.
+ return cost_part(d, x, c, False) - cost_part(d, x, c, True)
+
+
+def q_optimal_normal_guide_mean_var(d, c, z: bool):
+ # This is the product of the standard normal and the cost normal pdfs.
+ # See e.g. https://www.johndcook.com/blog/2012/10/29/product-of-normal-pdfs/ for more details.
+ m = h(d, c) - 3 * float(z) * c
+ vp = tt(.2**2.)
+ vn = tt(.2**-2)
+
+ new_mean = (vn * m) / (vn + 1)
+ new_var = vp / (1 + vp)
+
+ return new_mean, torch.sqrt(new_var)
+
+
+MODEL_DIST = dist.Normal(0., 1.)
+
+
+def model():
+ # The model is simply a unit normal.
+ return OrderedDict(x=pyro.sample("x", MODEL_DIST))
+
+
+if __name__ == "__main__":
+
+ import matplotlib.pyplot as plt
+
+ xx = torch.linspace(-5., 5., 300)
+
+ c_ = tt(1.5)
+ d1_ = tt(2.)
+ d2_ = tt(0.)
+ d3_ = tt(0.5)
+
+ # Plot the decision parameter re-mapping.
+ plt.figure()
+ plt.suptitle('Decision Parameter Re-Mapping')
+ plt.plot(xx, h(xx, c_))
+
+ # Plot the cost function for the two decision parameters.
+ fig, (ax1, ax2, ax3) = plt.subplots(3, 1)
+ plt.suptitle('Cost by Model on Stochastics')
+
+ # Plot the cost function for the two decision parameters. Put in the different subplots.
+ ax1.plot(xx, cost(d1_, xx, c_), color='red', label=f"d={d1_}")
+ ax2.plot(xx, cost(d2_, xx, c_), color='red', label=f"d={d2_}")
+ ax3.plot(xx, cost(d3_, xx, c_), color='red', label=f"d={d3_}")
+
+ # Plot the model on top of the cost function.
+ ax1.plot(xx, torch.exp(MODEL_DIST.log_prob(xx)), color='blue')
+ ax2.plot(xx, torch.exp(MODEL_DIST.log_prob(xx)), color='blue')
+ ax3.plot(xx, torch.exp(MODEL_DIST.log_prob(xx)), color='blue')
+
+ # Plot the optimal guides for the two parts of the cost function. These are normal distributions with means and
+ # variances calculated from the product of the model pdf and the cost function "pdf".
+ for d_, ax in zip([d1_, d2_, d3_], [ax1, ax2, ax3]):
+ gpm, gms = q_optimal_normal_guide_mean_var(d_, c_, False)
+ gnm, gns = q_optimal_normal_guide_mean_var(d_, c_, True)
+ ax.plot(xx, torch.exp(-((xx - gpm)/gms)**tt(2.))*.4, color='orange', label='Optimal Guide')
+ ax.plot(xx, torch.exp(-((xx - gnm)/gns)**tt(2.))*.4, color='orange')
+
+ ax1.legend()
+ ax2.legend()
+ ax3.legend()
+
+ # In a different plot, show the unnormalized positive component (using torch.relu) of the cost function
+ # multiplied by the model pdf. Also show the negative component (this can be achieved by using torch.relu(-x)).
+ # Just do this for d3.
+ fig1, ax = plt.subplots()
+ tax = ax.twinx()
+ tax.plot(xx, torch.exp(MODEL_DIST.log_prob(xx)), color='blue', label=f"d={d3_}", linestyle='--')
+ ax.plot(xx, torch.exp(MODEL_DIST.log_prob(xx)) * torch.relu(1e2*cost(d3_, xx, c_)), color='green', label=f"d={d3_}")
+ ax.plot(xx, torch.exp(MODEL_DIST.log_prob(xx)) * torch.relu(-1e2*cost(d3_, xx, c_)), color='red', label=f"d={d3_}")
+
+ # Plot just the cost function with a thin line.
+ ax.plot(xx, cost(d3_, xx, c_), color='black', linestyle='--', linewidth=0.4)
+
+ # Now plot the properly scaled normal distributions that map to the normalizations of the curves plotted above.
+ gpm, gms = q_optimal_normal_guide_mean_var(d3_, c_, False)
+ gnm, gns = q_optimal_normal_guide_mean_var(d3_, c_, True)
+
+ ax.plot(xx, torch.exp(dist.Normal(gpm, gms).log_prob(xx)), color='orange', linestyle='--')
+ ax.plot(xx, torch.exp(dist.Normal(gnm, gns).log_prob(xx)), color='orange', linestyle='--')
+
+ # In a different figure, show the ratio of the model-scaled positive and negative components of the cost function
+ # with respect to the properly normalized optimal guides. Show this across the same xx.
+ fig2, ax = plt.subplots()
+ num_pos = torch.exp(MODEL_DIST.log_prob(xx)) * torch.relu(1e2 * cost(d3_, xx, c_))
+ den_pos = torch.exp(dist.Normal(gpm, gms).log_prob(xx))
+ plt.plot(xx, num_pos / den_pos, color='green', label='Positive Component')
+
+ num_neg = torch.exp(MODEL_DIST.log_prob(xx)) * torch.relu(-1e2 * cost(d3_, xx, c_))
+ den_neg = torch.exp(dist.Normal(gnm, gns).log_prob(xx))
+ plt.plot(xx, num_neg / den_neg, color='red', label='Negative Component')
+
+ plt.show()