-
Notifications
You must be signed in to change notification settings - Fork 12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update type annotations for Pyro 1.9 #527
base: master
Are you sure you want to change the base?
Changes from all commits
9cbf026
179d851
8f94794
3a9508a
d2e3e73
12d00cf
152c49f
501c25f
bb08e6e
3e5b96d
4f90106
ee9665d
daa2678
6837696
c39160d
00d85b3
bbe6307
15f3e89
02b350e
b36c3f9
a887100
75731e4
5bb6a6e
9e82051
033ee1e
cc2cc67
567c767
e55fc3a
fef6fb2
400f6ea
6fdec8d
59f79fe
8afbe07
40359f3
9c78abb
bef38fd
f3b33c1
f46b3dc
89a4041
c01c4dd
e3b0254
3b27c71
ef84f1f
c0cb512
b8de4f4
fa20481
6885df8
614ce97
5d9d454
de91c58
285710c
ae27552
cb7eb19
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,10 @@ | ||
import functools | ||
import typing | ||
from typing import TypeVar | ||
|
||
import pyro | ||
import pyro.distributions as dist | ||
import pyro.poutine.indep_messenger | ||
import torch | ||
|
||
from chirho.counterfactual.handlers.selection import ( | ||
|
@@ -23,21 +25,21 @@ class FactualConditioningMessenger(pyro.poutine.messenger.Messenger): | |
counterfactual semantics handlers such as :class:`MultiWorldCounterfactual` . | ||
""" | ||
|
||
def _pyro_post_sample(self, msg: dict) -> None: | ||
def _pyro_post_sample(self, msg) -> None: | ||
# expand latent values to include all index plates | ||
if not msg["is_observed"] and not pyro.poutine.util.site_is_subsample(msg): | ||
rv, value, event_dim = msg["fn"], msg["value"], len(msg["fn"].event_shape) | ||
index_plates = get_index_plates() | ||
|
||
new_shape = list(value.shape) | ||
for k in set(indices_of(rv)) - set(indices_of(value, event_dim=event_dim)): | ||
dim = index_plates[k].dim | ||
dim = typing.cast(int, index_plates[k].dim) | ||
new_shape = [1] * ((event_dim - dim) - len(new_shape)) + new_shape | ||
new_shape[dim - event_dim] = rv.batch_shape[dim] | ||
|
||
msg["value"] = value.expand(tuple(new_shape)) | ||
|
||
def _pyro_observe(self, msg: dict) -> None: | ||
def _pyro_observe(self, msg) -> None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see later it's typed as |
||
if "name" not in msg["kwargs"]: | ||
msg["kwargs"]["name"] = msg["name"] | ||
|
||
|
@@ -55,7 +57,7 @@ def _dispatched_observe(self, rv, obs: torch.Tensor, name: str) -> torch.Tensor: | |
@_dispatched_observe.register(dist.FoldedDistribution) | ||
@_dispatched_observe.register(dist.Distribution) | ||
def _observe_dist( | ||
self, rv: dist.Distribution, obs: torch.Tensor, name: str | ||
self, rv: dist.TorchDistribution, obs: torch.Tensor, name: str | ||
) -> torch.Tensor: | ||
with pyro.poutine.infer_config(config_fn=no_ambiguity): | ||
with SelectFactual(): | ||
|
@@ -74,7 +76,10 @@ def _observe_dist( | |
|
||
@_dispatched_observe.register | ||
def _observe_tfmdist( | ||
self, rv: dist.TransformedDistribution, value: torch.Tensor, name: str | ||
self, | ||
rv: dist.TransformedDistribution, | ||
value: torch.Tensor, | ||
name: str, | ||
) -> torch.Tensor: | ||
tfm = ( | ||
rv.transforms[-1] | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
import typing | ||
from typing import Generic, TypeVar | ||
|
||
import pyro | ||
|
@@ -55,7 +56,7 @@ def __init__(self, times: torch.Tensor, is_traced: bool = False): | |
|
||
super().__init__() | ||
|
||
def _pyro_post_simulate(self, msg: dict) -> None: | ||
def _pyro_post_simulate(self, msg) -> None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See comment in |
||
initial_state: State[T] = msg["args"][1] | ||
start_time = msg["args"][2] | ||
|
||
|
@@ -74,7 +75,10 @@ def _pyro_post_simulate(self, msg: dict) -> None: | |
|
||
if self.is_traced: | ||
# This adds the trajectory to the trace so that it can be accessed later. | ||
[pyro.deterministic(name, value) for name, value in self.trajectory.items()] | ||
[ | ||
pyro.deterministic(name, typing.cast(torch.Tensor, value)) | ||
for name, value in self.trajectory.items() | ||
] | ||
|
||
def _pyro_simulate_point(self, msg) -> None: | ||
# Turn a simulate that returns a state into a simulate that returns a trajectory at each of the logging_times | ||
|
@@ -100,7 +104,9 @@ def _pyro_simulate_point(self, msg) -> None: | |
|
||
# TODO support dim != -1 | ||
idx_name = "__time" | ||
name_to_dim = {k: f.dim - 1 for k, f in get_index_plates().items()} | ||
name_to_dim = { | ||
k: typing.cast(int, f.dim) - 1 for k, f in get_index_plates().items() | ||
} | ||
name_to_dim[idx_name] = -1 | ||
|
||
if len(timespan) > 2: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -103,7 +103,7 @@ def __exit__(self, exc_type, exc_value, traceback): | |
super().__exit__(exc_type, exc_value, traceback) | ||
|
||
@typing.final | ||
def _process_message(self, msg: Dict[str, Any]) -> None: | ||
def _process_message(self, msg) -> None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See comment in ambiguity.py about dropping the msg type. Could this be replaced with Mapping[str, Any]? |
||
if not self.used and hasattr(self, f"_pyro_{msg['type']}"): | ||
self.used = True | ||
super()._process_message(msg) | ||
|
@@ -119,7 +119,7 @@ def cont(msg: Dict[str, Any]) -> None: | |
msg["continuation"] = cont | ||
|
||
@typing.final | ||
def _postprocess_message(self, msg: Dict[str, Any]) -> None: | ||
def _postprocess_message(self, msg) -> None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. see above comment. |
||
if hasattr(self, f"_pyro_post_{msg['type']}"): | ||
raise NotImplementedError("ShallowHandler does not support postprocessing") | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,7 +30,7 @@ def __init__( | |
self.predicate = predicate | ||
self.callback = callback | ||
|
||
def _pyro_get_new_interruptions(self, msg: dict) -> None: | ||
def _pyro_get_new_interruptions(self, msg) -> None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See comment in ambiguity.py about dropping the msg type. Could this be replaced with Mapping[str, Any]? |
||
if msg["value"] is None: | ||
msg["value"] = [] | ||
assert isinstance(msg["value"], list) | ||
|
@@ -42,7 +42,7 @@ def get_new_interruptions() -> List[Interruption]: | |
""" | ||
Install the active interruptions into the context. | ||
""" | ||
return [] | ||
return typing.cast(List[Interruption], []) | ||
|
||
|
||
class Solver(Generic[T], pyro.poutine.messenger.Messenger): | ||
|
@@ -58,11 +58,11 @@ def _prioritize_interruption(h: Interruption[T]) -> Prioritized[Interruption[T]] | |
raise NotImplementedError(f"cannot install interruption {h}") | ||
|
||
@typing.final | ||
def _pyro_simulate(self, msg: dict) -> None: | ||
def _pyro_simulate(self, msg) -> None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See comment in ambiguity.py about dropping the msg type. Could this be replaced with Mapping[str, Any]? |
||
from chirho.dynamical.handlers.interruption import StaticEvent | ||
|
||
dynamics: Dynamics[T] = msg["args"][0] | ||
state: State[T] = msg["args"][1] | ||
dynamics: Dynamics[T] = typing.cast(Dynamics[T], msg["args"][0]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is the type specification |
||
state: State[T] = typing.cast(State[T], msg["args"][1]) | ||
start_time: R = msg["args"][2] | ||
end_time: R = msg["args"][3] | ||
|
||
|
@@ -106,6 +106,7 @@ def _pyro_simulate(self, msg: dict) -> None: | |
if ph.priority > start_time: | ||
break | ||
|
||
next_interruption: Optional[Interruption[T]] | ||
state, start_time, next_interruption = simulate_to_interruption( | ||
possible_interruptions, | ||
dynamics, | ||
|
@@ -116,7 +117,8 @@ def _pyro_simulate(self, msg: dict) -> None: | |
) | ||
|
||
if next_interruption is not None: | ||
dynamics, state = next_interruption.callback(dynamics, state) | ||
# TODO reenable this check once we figure out why mypy is inferring "Never"s | ||
dynamics, state = next_interruption.callback(dynamics, state) # type: ignore | ||
|
||
for h in possible_interruptions: | ||
if h is not next_interruption: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See below comment about removal of
dict