Skip to content

Commit 88308f2

Browse files
Set rng state for trace fn mapping draws to posterior samples
Co-authored-by: Ricardo Vieira <[email protected]>
1 parent 09def3b commit 88308f2

File tree

7 files changed

+129
-12
lines changed

7 files changed

+129
-12
lines changed

pymc/backends/__init__.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
from pymc.blocking import PointType
7777
from pymc.model import Model
7878
from pymc.step_methods.compound import BlockedStep, CompoundStep
79+
from pymc.util import get_random_generator
7980

8081
HAS_MCB = False
8182
try:
@@ -103,11 +104,13 @@ def _init_trace(
103104
model: Model,
104105
trace_vars: list[TensorVariable] | None = None,
105106
initial_point: PointType | None = None,
107+
rng: np.random.Generator | None = None,
106108
) -> BaseTrace:
107109
"""Initialize a trace backend for a chain."""
110+
rng_ = get_random_generator(rng)
108111
strace: BaseTrace
109112
if trace is None:
110-
strace = NDArray(model=model, vars=trace_vars, test_point=initial_point)
113+
strace = NDArray(model=model, vars=trace_vars, test_point=initial_point, rng=rng_)
111114
elif isinstance(trace, BaseTrace):
112115
if len(trace) > 0:
113116
raise ValueError("Continuation of traces is no longer supported.")
@@ -129,6 +132,7 @@ def init_traces(
129132
model: Model,
130133
trace_vars: list[TensorVariable] | None = None,
131134
tune: int = 0,
135+
rng: np.random.Generator | None = None,
132136
) -> tuple[RunType | None, Sequence[IBaseTrace]]:
133137
"""Initialize a trace recorder for each chain."""
134138
if isinstance(backend, ZarrTrace):
@@ -140,6 +144,7 @@ def init_traces(
140144
model=model,
141145
vars=trace_vars,
142146
test_point=initial_point,
147+
rng=rng,
143148
)
144149
return None, backend.straces
145150
if HAS_MCB and isinstance(backend, Backend):
@@ -149,6 +154,7 @@ def init_traces(
149154
initial_point=initial_point,
150155
step=step,
151156
model=model,
157+
rng=rng,
152158
)
153159

154160
assert backend is None or isinstance(backend, BaseTrace)
@@ -161,7 +167,8 @@ def init_traces(
161167
model=model,
162168
trace_vars=trace_vars,
163169
initial_point=initial_point,
170+
rng=rng_,
164171
)
165-
for chain_number in range(chains)
172+
for chain_number, rng_ in enumerate(get_random_generator(rng).spawn(chains))
166173
]
167174
return None, traces

pymc/backends/base.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434

3535
from pymc.backends.report import SamplerReport
3636
from pymc.model import modelcontext
37-
from pymc.pytensorf import compile
37+
from pymc.pytensorf import compile, copy_function_with_new_rngs
3838
from pymc.util import get_var_name
3939

4040
logger = logging.getLogger(__name__)
@@ -159,6 +159,7 @@ def __init__(
159159
fn=None,
160160
var_shapes=None,
161161
var_dtypes=None,
162+
rng=None,
162163
):
163164
model = modelcontext(model)
164165

@@ -177,6 +178,8 @@ def __init__(
177178
on_unused_input="ignore",
178179
)
179180
fn.trust_input = True
181+
if rng is not None:
182+
fn = copy_function_with_new_rngs(fn=fn, rng=rng)
180183

181184
# Get variable shapes. Most backends will need this
182185
# information.

pymc/backends/mcbackend.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
from pymc.backends.base import IBaseTrace
3131
from pymc.model import Model
32-
from pymc.pytensorf import PointFunc
32+
from pymc.pytensorf import PointFunc, copy_function_with_new_rngs
3333
from pymc.step_methods.compound import (
3434
BlockedStep,
3535
CompoundStep,
@@ -38,6 +38,7 @@
3838
flat_statname,
3939
flatten_steps,
4040
)
41+
from pymc.util import get_random_generator
4142

4243
_log = logging.getLogger(__name__)
4344

@@ -96,7 +97,11 @@ class ChainRecordAdapter(IBaseTrace):
9697
"""Wraps an McBackend ``Chain`` as an ``IBaseTrace``."""
9798

9899
def __init__(
99-
self, chain: mcb.Chain, point_fn: PointFunc, stats_bijection: StatsBijection
100+
self,
101+
chain: mcb.Chain,
102+
point_fn: PointFunc,
103+
stats_bijection: StatsBijection,
104+
rng: np.random.Generator | None = None,
100105
) -> None:
101106
# Assign attributes required by IBaseTrace
102107
self.chain = chain.cmeta.chain_number
@@ -107,8 +112,11 @@ def __init__(
107112
for sstats in stats_bijection._stat_groups
108113
]
109114

115+
self._rng = rng
110116
self._chain = chain
111117
self._point_fn = point_fn
118+
if rng is not None:
119+
self._point_fn = copy_function_with_new_rngs(self._point_fn, rng)
112120
self._statsbj = stats_bijection
113121
super().__init__()
114122

@@ -257,6 +265,7 @@ def init_chain_adapters(
257265
initial_point: Mapping[str, np.ndarray],
258266
step: CompoundStep | BlockedStep,
259267
model: Model,
268+
rng: np.random.Generator | None,
260269
) -> tuple[mcb.Run, list[ChainRecordAdapter]]:
261270
"""Create an McBackend metadata description for the MCMC run.
262271
@@ -286,7 +295,8 @@ def init_chain_adapters(
286295
chain=run.init_chain(chain_number=chain_number),
287296
point_fn=point_fn,
288297
stats_bijection=statsbj,
298+
rng=rng_,
289299
)
290-
for chain_number in range(chains)
300+
for chain_number, rng_ in enumerate(get_random_generator(rng).spawn(chains))
291301
]
292302
return run, adapters

pymc/backends/zarr.py

+17-3
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,20 @@
3535
from pymc.backends.base import BaseTrace
3636
from pymc.blocking import StatDtype, StatShape
3737
from pymc.model.core import Model, modelcontext
38+
from pymc.pytensorf import copy_function_with_new_rngs
3839
from pymc.step_methods.compound import (
3940
BlockedStep,
4041
CompoundStep,
4142
StatsBijection,
4243
get_stats_dtypes_shapes_from_steps,
4344
)
44-
from pymc.util import UNSET, _UnsetType, get_default_varnames, is_transformed_name
45+
from pymc.util import (
46+
UNSET,
47+
_UnsetType,
48+
get_default_varnames,
49+
get_random_generator,
50+
is_transformed_name,
51+
)
4552

4653
try:
4754
from zarr.storage import BaseStore, default_compressor
@@ -398,6 +405,7 @@ def init_trace(
398405
model: Model | None = None,
399406
vars: Sequence[TensorVariable] | None = None,
400407
test_point: dict[str, np.ndarray] | None = None,
408+
rng: np.random.Generator | None = None,
401409
):
402410
"""Initialize the trace groups and arrays.
403411
@@ -437,6 +445,12 @@ def init_trace(
437445
This is not used and is a product of the inheritance of :class:`ZarrChain`
438446
from :class:`~.BaseTrace`, which uses it to determine the shape and dtype
439447
of `vars`.
448+
rng : numpy.random.Generator | None
449+
A random generator to use to seed the shared random generators that are
450+
present in the pytensor function that maps samples drawn by step methods
451+
onto samples in the posterior trace. Note that this only does anything
452+
if there are deterministic variables that are generated by raw pytensor
453+
random variables.
440454
"""
441455
if self._is_base_setup:
442456
raise RuntimeError("The ZarrTrace has already been initialized") # pragma: no cover
@@ -534,9 +548,9 @@ def init_trace(
534548
test_point=test_point,
535549
stats_bijection=StatsBijection(step.stats_dtypes),
536550
draws_per_chunk=self.draws_per_chunk,
537-
fn=self.fn,
551+
fn=copy_function_with_new_rngs(self.fn, rng_),
538552
)
539-
for _ in range(chains)
553+
for rng_ in get_random_generator(rng).spawn(chains)
540554
]
541555
for chain, strace in enumerate(self.straces):
542556
strace.setup(draws=tune + draws, chain=chain, sampler_vars=None)

pymc/pytensorf.py

+56-3
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,15 @@
1414
import warnings
1515

1616
from collections.abc import Callable, Generator, Iterable, Sequence
17-
from typing import cast
17+
from typing import cast, overload
1818

1919
import numpy as np
2020
import pandas as pd
2121
import pytensor
2222
import pytensor.tensor as pt
2323
import scipy.sparse as sps
2424

25+
from pytensor import shared
2526
from pytensor.compile import Function, Mode, get_mode
2627
from pytensor.compile.builders import OpFromGraph
2728
from pytensor.gradient import grad
@@ -42,7 +43,7 @@
4243
from pytensor.tensor.basic import _as_tensor_variable
4344
from pytensor.tensor.elemwise import Elemwise
4445
from pytensor.tensor.random.op import RandomVariable
45-
from pytensor.tensor.random.type import RandomType
46+
from pytensor.tensor.random.type import RandomGeneratorType, RandomType
4647
from pytensor.tensor.random.var import RandomGeneratorSharedVariable
4748
from pytensor.tensor.rewriting.basic import topo_unconditional_constant_folding
4849
from pytensor.tensor.rewriting.shape import ShapeFeature
@@ -51,7 +52,7 @@
5152
from pytensor.tensor.variable import TensorVariable
5253

5354
from pymc.exceptions import NotConstantValueError
54-
from pymc.util import makeiter
55+
from pymc.util import RandomGeneratorState, makeiter, random_generator_from_state
5556
from pymc.vartypes import continuous_types, isgenerator, typefilter
5657

5758
PotentialShapeType = int | np.ndarray | Sequence[int | Variable] | TensorVariable
@@ -1163,3 +1164,55 @@ def normalize_rng_param(rng: None | Variable) -> Variable:
11631164
"The type of rng should be an instance of either RandomGeneratorType or RandomStateType"
11641165
)
11651166
return rng
1167+
1168+
1169+
@overload
1170+
def copy_function_with_new_rngs(
1171+
fn: PointFunc, rng: np.random.Generator | RandomGeneratorState
1172+
) -> PointFunc: ...
1173+
1174+
1175+
@overload
1176+
def copy_function_with_new_rngs(
1177+
fn: Function, rng: np.random.Generator | RandomGeneratorState
1178+
) -> Function: ...
1179+
1180+
1181+
def copy_function_with_new_rngs(
1182+
fn: Function, rng: np.random.Generator | RandomGeneratorState
1183+
) -> Function:
1184+
"""Copy a compiled pytensor function and replace the random Generators with spawns.
1185+
1186+
Parameters
1187+
----------
1188+
fn : pytensor.compile.function.types.Function | pymc.util.PointFunc
1189+
The compiled function
1190+
rng : numpy.random.Generator | RandomGeneratorState
1191+
The random generator or its state
1192+
1193+
Returns
1194+
-------
1195+
fn_out : pytensor.compile.function.types.Function | pymc.pytensorf.PointFunc
1196+
A copy of the input function with the shared random generator states set to
1197+
spawns of the supplied ``rng``. If the function has no shared random generators
1198+
in it, the input ``fn`` is returned without any changes.
1199+
If ``fn`` is a :clas:`~pymc.pytensorf.PointFunc` instance, and the inner
1200+
pytensor function has random variables, then the inner pytensor function is
1201+
copied, setting new random generators, and a new ``PointFunc`` instance is
1202+
returned.
1203+
"""
1204+
# Copy the function and replace any shared RNGs
1205+
# This is needed so that it can work correctly with multiple traces
1206+
# This will be costly if set_rng is called too often!
1207+
rng_gen = rng if isinstance(rng, np.random.Generator) else random_generator_from_state(rng)
1208+
fn_ = fn.f if isinstance(fn, PointFunc) else fn
1209+
shared_rngs = [var for var in fn_.get_shared() if isinstance(var.type, RandomGeneratorType)]
1210+
n_shared_rngs = len(shared_rngs)
1211+
swap = {
1212+
old_shared_rng: shared(rng, borrow=True)
1213+
for old_shared_rng, rng in zip(shared_rngs, rng_gen.spawn(n_shared_rngs), strict=True)
1214+
}
1215+
if isinstance(fn, PointFunc):
1216+
return PointFunc(fn.f.copy(swap=swap)) if n_shared_rngs > 0 else fn
1217+
else:
1218+
return fn.copy(swap=swap) if n_shared_rngs > 0 else fn

pymc/sampling/mcmc.py

+1
Original file line numberDiff line numberDiff line change
@@ -866,6 +866,7 @@ def joined_blas_limiter():
866866
initial_point=initial_points[0],
867867
model=model,
868868
tune=tune,
869+
rng=rngs[0].spawn(1)[0],
869870
)
870871

871872
sample_args = {

tests/sampling/test_mcmc.py

+29
Original file line numberDiff line numberDiff line change
@@ -909,3 +909,32 @@ def test_sample(self, seeded_test):
909909
np.testing.assert_allclose(
910910
x_pred, pp_trace1.posterior_predictive["obs"].mean(("chain", "draw")), atol=1e-1
911911
)
912+
913+
914+
@pytest.fixture(scope="function", params=[None, "mcbackend", "zarr"])
915+
def trace_backend(request):
916+
if request.param is None:
917+
return None
918+
elif request.param == "mcbackend":
919+
try:
920+
import mcbackend as mcb
921+
except ImportError:
922+
pytest.skip("Requires McBackend to be installed.")
923+
return mcb.NumPyBackend()
924+
elif request.param == "zarr":
925+
try:
926+
trace = pm.backends.zarr.ZarrTrace()
927+
except RuntimeError:
928+
pytest.skip("Requires zarr to be installed")
929+
return trace
930+
931+
932+
def test_random_deterministics(trace_backend):
933+
with pm.Model() as m:
934+
x = pm.Bernoulli("x", p=0.5) * 0 # Force it to be zero
935+
pm.Deterministic("y", x + pm.Normal.dist())
936+
937+
idata1 = pm.sample(tune=0, draws=1, random_seed=1, trace=trace_backend)
938+
idata2 = pm.sample(tune=0, draws=1, random_seed=1, trace=trace_backend)
939+
940+
assert idata1.posterior.equals(idata2.posterior)

0 commit comments

Comments
 (0)