|
14 | 14 | import warnings
|
15 | 15 |
|
16 | 16 | from collections.abc import Callable, Generator, Iterable, Sequence
|
17 |
| -from typing import cast |
| 17 | +from typing import cast, overload |
18 | 18 |
|
19 | 19 | import numpy as np
|
20 | 20 | import pandas as pd
|
21 | 21 | import pytensor
|
22 | 22 | import pytensor.tensor as pt
|
23 | 23 | import scipy.sparse as sps
|
24 | 24 |
|
| 25 | +from pytensor import shared |
25 | 26 | from pytensor.compile import Function, Mode, get_mode
|
26 | 27 | from pytensor.compile.builders import OpFromGraph
|
27 | 28 | from pytensor.gradient import grad
|
|
42 | 43 | from pytensor.tensor.basic import _as_tensor_variable
|
43 | 44 | from pytensor.tensor.elemwise import Elemwise
|
44 | 45 | from pytensor.tensor.random.op import RandomVariable
|
45 |
| -from pytensor.tensor.random.type import RandomType |
| 46 | +from pytensor.tensor.random.type import RandomGeneratorType, RandomType |
46 | 47 | from pytensor.tensor.random.var import RandomGeneratorSharedVariable
|
47 | 48 | from pytensor.tensor.rewriting.basic import topo_unconditional_constant_folding
|
48 | 49 | from pytensor.tensor.rewriting.shape import ShapeFeature
|
|
51 | 52 | from pytensor.tensor.variable import TensorVariable
|
52 | 53 |
|
53 | 54 | from pymc.exceptions import NotConstantValueError
|
54 |
| -from pymc.util import makeiter |
| 55 | +from pymc.util import RandomGeneratorState, makeiter, random_generator_from_state |
55 | 56 | from pymc.vartypes import continuous_types, isgenerator, typefilter
|
56 | 57 |
|
57 | 58 | PotentialShapeType = int | np.ndarray | Sequence[int | Variable] | TensorVariable
|
@@ -1163,3 +1164,55 @@ def normalize_rng_param(rng: None | Variable) -> Variable:
|
1163 | 1164 | "The type of rng should be an instance of either RandomGeneratorType or RandomStateType"
|
1164 | 1165 | )
|
1165 | 1166 | return rng
|
| 1167 | + |
| 1168 | + |
| 1169 | +@overload |
| 1170 | +def copy_function_with_new_rngs( |
| 1171 | + fn: PointFunc, rng: np.random.Generator | RandomGeneratorState |
| 1172 | +) -> PointFunc: ... |
| 1173 | + |
| 1174 | + |
| 1175 | +@overload |
| 1176 | +def copy_function_with_new_rngs( |
| 1177 | + fn: Function, rng: np.random.Generator | RandomGeneratorState |
| 1178 | +) -> Function: ... |
| 1179 | + |
| 1180 | + |
| 1181 | +def copy_function_with_new_rngs( |
| 1182 | + fn: Function, rng: np.random.Generator | RandomGeneratorState |
| 1183 | +) -> Function: |
| 1184 | + """Copy a compiled pytensor function and replace the random Generators with spawns. |
| 1185 | +
|
| 1186 | + Parameters |
| 1187 | + ---------- |
| 1188 | + fn : pytensor.compile.function.types.Function | pymc.util.PointFunc |
| 1189 | + The compiled function |
| 1190 | + rng : numpy.random.Generator | RandomGeneratorState |
| 1191 | + The random generator or its state |
| 1192 | +
|
| 1193 | + Returns |
| 1194 | + ------- |
| 1195 | + fn_out : pytensor.compile.function.types.Function | pymc.pytensorf.PointFunc |
| 1196 | + A copy of the input function with the shared random generator states set to |
| 1197 | + spawns of the supplied ``rng``. If the function has no shared random generators |
| 1198 | + in it, the input ``fn`` is returned without any changes. |
| 1199 | + If ``fn`` is a :clas:`~pymc.pytensorf.PointFunc` instance, and the inner |
| 1200 | + pytensor function has random variables, then the inner pytensor function is |
| 1201 | + copied, setting new random generators, and a new ``PointFunc`` instance is |
| 1202 | + returned. |
| 1203 | + """ |
| 1204 | + # Copy the function and replace any shared RNGs |
| 1205 | + # This is needed so that it can work correctly with multiple traces |
| 1206 | + # This will be costly if set_rng is called too often! |
| 1207 | + rng_gen = rng if isinstance(rng, np.random.Generator) else random_generator_from_state(rng) |
| 1208 | + fn_ = fn.f if isinstance(fn, PointFunc) else fn |
| 1209 | + shared_rngs = [var for var in fn_.get_shared() if isinstance(var.type, RandomGeneratorType)] |
| 1210 | + n_shared_rngs = len(shared_rngs) |
| 1211 | + swap = { |
| 1212 | + old_shared_rng: shared(rng, borrow=True) |
| 1213 | + for old_shared_rng, rng in zip(shared_rngs, rng_gen.spawn(n_shared_rngs), strict=True) |
| 1214 | + } |
| 1215 | + if isinstance(fn, PointFunc): |
| 1216 | + return PointFunc(fn.f.copy(swap=swap)) if n_shared_rngs > 0 else fn |
| 1217 | + else: |
| 1218 | + return fn.copy(swap=swap) if n_shared_rngs > 0 else fn |
0 commit comments