Update type annotations for Pyro 1.9 #527

Open · wants to merge 53 commits into master from eb-fix-types-pyro

Changes from all commits

Commits (53)
9cbf026
fix types in indexed module
eb8680 Feb 20, 2024
179d851
fix observational
eb8680 Feb 22, 2024
8f94794
fix cond
eb8680 Feb 22, 2024
3a9508a
indexed
eb8680 Feb 22, 2024
d2e3e73
fix counterfactual module
eb8680 Feb 22, 2024
12d00cf
fix part of explainable module
eb8680 Feb 22, 2024
152c49f
nit
eb8680 Feb 22, 2024
501c25f
fix dynamical
eb8680 Feb 22, 2024
bb08e6e
fix robust
eb8680 Feb 23, 2024
3e5b96d
fix explanation
eb8680 Feb 23, 2024
4f90106
nit
eb8680 Feb 23, 2024
ee9665d
revert docstring
eb8680 Feb 23, 2024
daa2678
Merge branch 'master' into eb-fix-types-pyro
eb8680 Feb 23, 2024
6837696
bump pyro version
eb8680 Feb 23, 2024
c39160d
remove comment
eb8680 Feb 23, 2024
00d85b3
format
eb8680 Feb 23, 2024
bbe6307
fix tests
eb8680 Feb 23, 2024
15f3e89
dont render params
eb8680 Feb 23, 2024
02b350e
remove duplicate boilerplate
eb8680 Feb 23, 2024
b36c3f9
fix
eb8680 Feb 23, 2024
a887100
disable line
eb8680 Feb 23, 2024
75731e4
patch rendering
eb8680 Feb 23, 2024
5bb6a6e
add comment explaining patch
eb8680 Feb 23, 2024
9e82051
lint
eb8680 Feb 23, 2024
033ee1e
float version
eb8680 Feb 23, 2024
cc2cc67
move patches to a separate file
eb8680 Feb 23, 2024
567c767
import just
eb8680 Feb 23, 2024
e55fc3a
just
eb8680 Feb 27, 2024
fef6fb2
aim for backward compatibility
eb8680 Feb 27, 2024
400f6ea
add pyro 1.8.6 lint stage
eb8680 Feb 27, 2024
6fdec8d
add pyro 1.8.6 lint stage
eb8680 Feb 27, 2024
59f79fe
unpin version
eb8680 Feb 27, 2024
8afbe07
ignores
eb8680 Feb 27, 2024
40359f3
tweaking
eb8680 Feb 27, 2024
9c78abb
repin pyro version
eb8680 Feb 27, 2024
bef38fd
remove backward compatibility attempt
eb8680 Feb 27, 2024
f3b33c1
Merge branch 'master' into eb-fix-types-pyro
eb8680 Apr 25, 2024
f46b3dc
Update ops.py
eb8680 Apr 25, 2024
89a4041
Update soft_conditioning.py
eb8680 Apr 25, 2024
c01c4dd
reenable both linting stages for backwards compatibility
eb8680 Apr 25, 2024
e3b0254
unpin pyro version in setup.py
eb8680 Apr 25, 2024
3b27c71
overload just
eb8680 Apr 25, 2024
ef84f1f
Merge branch 'master' into eb-fix-types-pyro
eb8680 Apr 25, 2024
c0cb512
remove Message for now
eb8680 Apr 25, 2024
b8de4f4
remove inferdict
eb8680 Apr 26, 2024
fa20481
type ignores
eb8680 Apr 26, 2024
6885df8
Merge branch 'master' into eb-fix-types-pyro
eb8680 Apr 26, 2024
614ce97
revert patches
eb8680 May 7, 2024
5d9d454
remove backward compatibility
eb8680 May 7, 2024
de91c58
Merge branch 'master' into eb-fix-types-pyro
eb8680 May 7, 2024
285710c
Merge branch 'master' into eb-fix-types-pyro
eb8680 Jul 12, 2024
ae27552
Update components.py
eb8680 Jul 12, 2024
cb7eb19
Update components.py
eb8680 Jul 12, 2024
6 changes: 6 additions & 0 deletions .github/workflows/lint.yml
@@ -12,6 +12,7 @@ jobs:
strategy:
matrix:
python-version: [3.8, 3.9, '3.10']
pyro-version: ['1.8.6', 'latest']

steps:
- uses: actions/checkout@v2
@@ -32,6 +33,11 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
if [ "${{ matrix.pyro-version }}" = "latest" ]; then
pip install pyro-ppl
else
pip install pyro-ppl==${{ matrix.pyro-version }}
fi
pip install .[test]

- name: Lint
15 changes: 10 additions & 5 deletions chirho/counterfactual/handlers/ambiguity.py
@@ -1,8 +1,10 @@
import functools
import typing
from typing import TypeVar

import pyro
import pyro.distributions as dist
import pyro.poutine.indep_messenger
import torch

from chirho.counterfactual.handlers.selection import (
@@ -23,21 +25,21 @@ class FactualConditioningMessenger(pyro.poutine.messenger.Messenger):
counterfactual semantics handlers such as :class:`MultiWorldCounterfactual` .
"""

def _pyro_post_sample(self, msg: dict) -> None:
def _pyro_post_sample(self, msg) -> None:
Collaborator: See below comment about removal of dict

# expand latent values to include all index plates
if not msg["is_observed"] and not pyro.poutine.util.site_is_subsample(msg):
rv, value, event_dim = msg["fn"], msg["value"], len(msg["fn"].event_shape)
index_plates = get_index_plates()

new_shape = list(value.shape)
for k in set(indices_of(rv)) - set(indices_of(value, event_dim=event_dim)):
dim = index_plates[k].dim
dim = typing.cast(int, index_plates[k].dim)
new_shape = [1] * ((event_dim - dim) - len(new_shape)) + new_shape
new_shape[dim - event_dim] = rv.batch_shape[dim]

msg["value"] = value.expand(tuple(new_shape))

def _pyro_observe(self, msg: dict) -> None:
def _pyro_observe(self, msg) -> None:
Collaborator: What is msg now, if not a dict?

Collaborator: I see later it's typed as Mapping[str, Any]. Would that work here?

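A minimal sketch of the trade-off being discussed, not part of this diff: under Pyro 1.9 the site dictionary passed to handler hooks is the TypedDict `pyro.poutine.runtime.Message`, which is compatible with `Mapping[str, Any]` for read-only helpers but not with `Dict[str, Any]`, and hooks that mutate the message need the TypedDict (or no annotation), since `Mapping` forbids item assignment. The handler below is hypothetical and only illustrates the pattern.

```python
from typing import Any, Mapping

import pyro.poutine.messenger
from pyro.poutine.runtime import Message  # assumes the Pyro 1.9 TypedDict


def site_is_latent(msg: Mapping[str, Any]) -> bool:
    # Read-only access type-checks fine against Mapping[str, Any].
    return msg["type"] == "sample" and not msg["is_observed"]


class ExampleMessenger(pyro.poutine.messenger.Messenger):
    def _pyro_post_sample(self, msg: Message) -> None:
        # Mutation requires the TypedDict (or an unannotated parameter);
        # Mapping does not permit item assignment under mypy.
        if site_is_latent(msg):
            msg["stop"] = True
```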
if "name" not in msg["kwargs"]:
msg["kwargs"]["name"] = msg["name"]

@@ -55,7 +57,7 @@ def _dispatched_observe(self, rv, obs: torch.Tensor, name: str) -> torch.Tensor:
@_dispatched_observe.register(dist.FoldedDistribution)
@_dispatched_observe.register(dist.Distribution)
def _observe_dist(
self, rv: dist.Distribution, obs: torch.Tensor, name: str
self, rv: dist.TorchDistribution, obs: torch.Tensor, name: str
) -> torch.Tensor:
with pyro.poutine.infer_config(config_fn=no_ambiguity):
with SelectFactual():
@@ -74,7 +76,10 @@ def _observe_dist(

@_dispatched_observe.register
def _observe_tfmdist(
self, rv: dist.TransformedDistribution, value: torch.Tensor, name: str
self,
rv: dist.TransformedDistribution,
value: torch.Tensor,
name: str,
) -> torch.Tensor:
tfm = (
rv.transforms[-1]
20 changes: 15 additions & 5 deletions chirho/counterfactual/internals.py
@@ -1,9 +1,12 @@
from typing import Any, Dict
import typing
from typing import Any, Mapping, TypedDict

import pyro

from chirho.indexed.ops import indices_of, union


def site_is_ambiguous(msg: Dict[str, Any]) -> bool:
def site_is_ambiguous(msg: Mapping[str, Any]) -> bool:
"""
Helper function used with :func:`observe` to determine
whether a site is observed or ambiguous.
@@ -13,17 +16,24 @@ def site_is_ambiguous(msg: Dict[str, Any]) -> bool:
are fixed/observed (as opposed to random/unobserved).
"""
rv, obs = msg["args"][:2]
infer_dict = msg["infer"]
if typing.TYPE_CHECKING:
assert infer_dict is not None
value_indices = indices_of(obs, event_dim=len(rv.event_shape))
dist_indices = indices_of(rv)
return (
bool(union(value_indices, dist_indices)) and value_indices != dist_indices
) or not msg["infer"].get("_specified_conditioning", True)
) or not infer_dict.get("_specified_conditioning", True)


class AmbiguityInferDict(pyro.poutine.runtime.InferDict):
_specified_conditioning: bool


def no_ambiguity(msg: Dict[str, Any]) -> Dict[str, Any]:
def no_ambiguity(msg: Mapping[str, Any]) -> AmbiguityInferDict:
"""
Helper function used with :func:`pyro.poutine.infer_config` to inform
:class:`FactualConditioningMessenger` that all ambiguity in the current
context has been resolved.
"""
return {"_specified_conditioning": True}
return AmbiguityInferDict(_specified_conditioning=True)
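For context, a minimal sketch of the TypedDict pattern that `AmbiguityInferDict` uses above, assuming Pyro 1.9's `pyro.poutine.runtime.InferDict`; the names below are illustrative and not part of the PR.

```python
from pyro.poutine.runtime import InferDict


class MyInferDict(InferDict):
    # project-specific keys are declared by subclassing the TypedDict
    my_custom_flag: bool


def my_config_fn(msg) -> MyInferDict:
    # keys returned here are merged into each site's `infer` dict by
    # pyro.poutine.infer_config, as in the FactualConditioningMessenger usage above
    return MyInferDict(my_custom_flag=True)
```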
12 changes: 9 additions & 3 deletions chirho/dynamical/handlers/trajectory.py
@@ -1,3 +1,4 @@
import typing
from typing import Generic, TypeVar

import pyro
@@ -55,7 +56,7 @@ def __init__(self, times: torch.Tensor, is_traced: bool = False):

super().__init__()

def _pyro_post_simulate(self, msg: dict) -> None:
def _pyro_post_simulate(self, msg) -> None:
Collaborator: See comment in ambiguity.py about dropping the msg type. Could this be replaced with Mapping[str, Any]?

initial_state: State[T] = msg["args"][1]
start_time = msg["args"][2]

@@ -74,7 +75,10 @@ def _pyro_post_simulate(self, msg: dict) -> None:

if self.is_traced:
# This adds the trajectory to the trace so that it can be accessed later.
[pyro.deterministic(name, value) for name, value in self.trajectory.items()]
[
pyro.deterministic(name, typing.cast(torch.Tensor, value))
for name, value in self.trajectory.items()
]

def _pyro_simulate_point(self, msg) -> None:
# Turn a simulate that returns a state into a simulate that returns a trajectory at each of the logging_times
@@ -100,7 +104,9 @@ def _pyro_simulate_point(self, msg) -> None:

# TODO support dim != -1
idx_name = "__time"
name_to_dim = {k: f.dim - 1 for k, f in get_index_plates().items()}
name_to_dim = {
k: typing.cast(int, f.dim) - 1 for k, f in get_index_plates().items()
}
name_to_dim[idx_name] = -1

if len(timespan) > 2:
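The `typing.cast(int, f.dim)` calls above are no-ops at runtime; a small illustration of the pattern, on the assumption that the plate's `dim` attribute is annotated as `Optional[int]` in Pyro 1.9 (it can be `None` before a dimension is allocated):

```python
import typing
from typing import Optional


def shifted_dim(dim: Optional[int]) -> int:
    # cast() does nothing at runtime; it only tells mypy to treat `dim` as int,
    # which is what the name_to_dim construction above relies on.
    return typing.cast(int, dim) - 1
```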
4 changes: 2 additions & 2 deletions chirho/dynamical/internals/_utils.py
@@ -103,7 +103,7 @@ def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)

@typing.final
def _process_message(self, msg: Dict[str, Any]) -> None:
def _process_message(self, msg) -> None:
Collaborator: See comment in ambiguity.py about dropping the msg type. Could this be replaced with Mapping[str, Any]?

if not self.used and hasattr(self, f"_pyro_{msg['type']}"):
self.used = True
super()._process_message(msg)
@@ -119,7 +119,7 @@ def cont(msg: Dict[str, Any]) -> None:
msg["continuation"] = cont

@typing.final
def _postprocess_message(self, msg: Dict[str, Any]) -> None:
def _postprocess_message(self, msg) -> None:
Collaborator: see above comment.

if hasattr(self, f"_pyro_post_{msg['type']}"):
raise NotImplementedError("ShallowHandler does not support postprocessing")

7 changes: 5 additions & 2 deletions chirho/dynamical/internals/backends/torchdiffeq.py
@@ -1,4 +1,5 @@
import functools
import typing
from typing import Callable, List, Optional, Tuple, TypeVar

import pyro
@@ -151,7 +152,9 @@ def torchdiffeq_simulate_point(

# TODO support dim != -1
idx_name = "__time"
name_to_dim = {k: f.dim - 1 for k, f in get_index_plates().items()}
name_to_dim = {
k: typing.cast(int, f.dim) - 1 for k, f in get_index_plates().items()
}
name_to_dim[idx_name] = -1

final_idx = IndexSet(**{idx_name: {len(timespan) - 1}})
Expand Down Expand Up @@ -249,7 +252,7 @@ def torchdiffeq_simulate_to_interruption(
dynamics, initial_state, start_time, interruptions, **kwargs
)

value = simulate_point(
value: State[torch.Tensor] = simulate_point(
dynamics, initial_state, start_time, interruption_time, **kwargs
)
return value, interruption_time, next_interruption
14 changes: 8 additions & 6 deletions chirho/dynamical/internals/solver.py
@@ -30,7 +30,7 @@ def __init__(
self.predicate = predicate
self.callback = callback

def _pyro_get_new_interruptions(self, msg: dict) -> None:
def _pyro_get_new_interruptions(self, msg) -> None:
Collaborator: See comment in ambiguity.py about dropping the msg type. Could this be replaced with Mapping[str, Any]?

if msg["value"] is None:
msg["value"] = []
assert isinstance(msg["value"], list)
@@ -42,7 +42,7 @@ def get_new_interruptions() -> List[Interruption]:
"""
Install the active interruptions into the context.
"""
return []
return typing.cast(List[Interruption], [])


class Solver(Generic[T], pyro.poutine.messenger.Messenger):
@@ -58,11 +58,11 @@ def _prioritize_interruption(h: Interruption[T]) -> Prioritized[Interruption[T]]:
raise NotImplementedError(f"cannot install interruption {h}")

@typing.final
def _pyro_simulate(self, msg: dict) -> None:
def _pyro_simulate(self, msg) -> None:
Collaborator: See comment in ambiguity.py about dropping the msg type. Could this be replaced with Mapping[str, Any]?

from chirho.dynamical.handlers.interruption import StaticEvent

dynamics: Dynamics[T] = msg["args"][0]
state: State[T] = msg["args"][1]
dynamics: Dynamics[T] = typing.cast(Dynamics[T], msg["args"][0])
Collaborator: Is the type specification Dynamics[T] in dynamics: Dynamics[T] still necessary here now that you're using typing.cast?

state: State[T] = typing.cast(State[T], msg["args"][1])
start_time: R = msg["args"][2]
end_time: R = msg["args"][3]

@@ -106,6 +106,7 @@ def _pyro_simulate(self, msg: dict) -> None:
if ph.priority > start_time:
break

next_interruption: Optional[Interruption[T]]
state, start_time, next_interruption = simulate_to_interruption(
possible_interruptions,
dynamics,
@@ -116,7 +117,8 @@ def _pyro_simulate(self, msg: dict) -> None:
)

if next_interruption is not None:
dynamics, state = next_interruption.callback(dynamics, state)
# TODO reenable this check once we figure out why mypy is inferring "Never"s
dynamics, state = next_interruption.callback(dynamics, state) # type: ignore

for h in possible_interruptions:
if h is not next_interruption:
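A small sketch bearing on the reviewer's question above, illustrative rather than part of the diff: once the right-hand side is wrapped in `typing.cast`, the left-hand annotation is redundant for mypy, because `cast` already fixes the static type of the expression.

```python
import typing
from typing import Any, Tuple


def unpack(args: Tuple[Any, ...]) -> int:
    x = typing.cast(int, args[0])       # inferred as int without an annotation
    y: int = typing.cast(int, args[1])  # equivalent; the annotation is redundant
    return x + y
```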
7 changes: 4 additions & 3 deletions chirho/explainable/handlers/components.py
@@ -1,4 +1,5 @@
import itertools
import typing
from typing import Callable, Iterable, MutableMapping, Optional, TypeVar

import pyro
@@ -91,7 +92,7 @@ def _random_intervention(value: T) -> T:
support,
event_shape=event_shape,
)
return pyro.sample(name, proposal_dist)
return typing.cast(T, pyro.sample(name, proposal_dist))

return _random_intervention

@@ -309,7 +310,7 @@ def _consequent_eq_neq(consequent: T) -> torch.Tensor:
},
}

new_value = scatter_n(
new_value: torch.Tensor = scatter_n(
nec_suff_log_probs_partitioned,
event_dim=0,
)
@@ -343,6 +344,6 @@ def __init__(self):

self.supports = {}

def _pyro_post_sample(self, msg: dict) -> None:
def _pyro_post_sample(self, msg) -> None:
if not pyro.poutine.util.site_is_subsample(msg):
self.supports[msg["name"]] = msg["fn"].support
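A hedged sketch of the `typing.cast(T, pyro.sample(...))` change above, assuming Pyro 1.9 annotates `pyro.sample` as returning `torch.Tensor`: the intervention factory is generic in `T`, so the sampled value is cast back to `T` for mypy without changing runtime behavior. The function below is illustrative, not the PR's code.

```python
import typing
from typing import Callable, TypeVar

import pyro
import pyro.distributions as dist

T = TypeVar("T")


def make_random_intervention(
    proposal: dist.TorchDistribution, name: str
) -> Callable[[T], T]:
    def _intervene(value: T) -> T:
        # cast the torch.Tensor returned by pyro.sample back to the generic T
        return typing.cast(T, pyro.sample(name, proposal))

    return _intervene
```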
1 change: 1 addition & 0 deletions chirho/explainable/handlers/explanation.py
@@ -1,4 +1,5 @@
import contextlib
import typing
from typing import Callable, Mapping, Optional, TypeVar, Union

import pyro.distributions.constraints as constraints
9 changes: 6 additions & 3 deletions chirho/explainable/internals/defaults.py
@@ -2,6 +2,9 @@
from typing import TypeVar

import pyro
import pyro.distributions
import pyro.distributions.constraints
import pyro.distributions.torch
import torch

S = TypeVar("S")
@@ -12,7 +15,7 @@
def uniform_proposal(
support: pyro.distributions.constraints.Constraint,
**kwargs,
) -> pyro.distributions.Distribution:
) -> pyro.distributions.TorchDistribution:
"""
This function heuristically constructs a probability distribution over a specified
support. The choice of distribution depends on the type of support provided.
@@ -44,7 +47,7 @@ def _uniform_proposal_indep(
*,
event_shape: torch.Size = torch.Size([]),
**kwargs,
) -> pyro.distributions.Distribution:
) -> pyro.distributions.TorchDistribution:
d = uniform_proposal(support.base_constraint, event_shape=event_shape, **kwargs)
return d.expand(event_shape).to_event(support.reinterpreted_batch_ndims)

@@ -53,7 +56,7 @@ def _uniform_proposal_indep(
def _uniform_proposal_integer(
support: pyro.distributions.constraints.integer_interval,
**kwargs,
) -> pyro.distributions.Distribution:
) -> pyro.distributions.TorchDistribution:
if support.lower_bound != 0:
raise NotImplementedError(
"integer_interval with lower_bound > 0 not yet supported"
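A brief sketch of why the return annotations above are narrowed from `Distribution` to `TorchDistribution`, assuming Pyro 1.9's annotations: methods the proposals rely on, such as `.expand()` and `.to_event()`, live on the torch-backed `TorchDistribution` mixin rather than on the abstract `Distribution` base. Illustrative example:

```python
import pyro.distributions as dist
import torch


def expanded_normal(event_shape: torch.Size) -> dist.TorchDistribution:
    base = dist.Normal(torch.tensor(0.0), torch.tensor(1.0))
    # .expand() and .to_event() are TorchDistribution methods,
    # mirroring _uniform_proposal_indep above
    return base.expand(event_shape).to_event(len(event_shape))
```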
2 changes: 1 addition & 1 deletion chirho/indexed/handlers.py
@@ -124,7 +124,7 @@ def get_mask(
) -> torch.Tensor:
raise NotImplementedError

def _pyro_sample(self, msg: Dict[str, Any]) -> None:
def _pyro_sample(self, msg) -> None:
if pyro.poutine.util.site_is_subsample(msg):
return
