Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions pytensor/compile/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import copyreg
import time
import warnings
from functools import singledispatch
from typing import TYPE_CHECKING

import numpy as np
Expand Down Expand Up @@ -37,6 +38,18 @@ class AliasedMemoryError(Exception):
DUPLICATE = object()


@singledispatch
def reseeded_rng_value(current, generator: np.random.Generator):
"""Return ``generator`` in the storage representation of ``current``.

Used by :meth:`Function.reseed_rngs`. Most backends store a NumPy ``Generator``
directly, so the default returns it unchanged. Backends that store RNGs in another
representation register a conversion keyed on that representation (e.g. the JAX backend
stores a state ``dict``).
"""
return generator


class Function:
r"""A class that wraps the execution of a `VM` making it easier for use as a "function".

Expand Down Expand Up @@ -810,6 +823,31 @@ def get_shared(self):
"""
return [i.variable for i in self.maker.inputs if i.implicit]

def reseed_rngs(self, seed=None) -> None:
"""Reseed the random generators used by this function.

Each random input is set to a fresh stream spawned from ``seed`` (an ``int``,
sequence of ints, or ``SeedSequence``; ``None`` draws fresh entropy). This works
for every backend, including JAX, whose compiled functions copy their RNGs at
compile time and so cannot be reseeded through the original shared variables.
"""
from pytensor.tensor.random.type import RandomType

rng_containers = [
container
for inp, container in zip(
self.maker.expanded_inputs, self.input_storage, strict=True
)
if isinstance(inp.variable.type, RandomType)
]
if not rng_containers:
return

seed_seqs = np.random.SeedSequence(seed).spawn(len(rng_containers))
for container, seed_seq in zip(rng_containers, seed_seqs, strict=True):
generator = np.random.Generator(np.random.PCG64(seed_seq))
container.storage[0] = reseeded_rng_value(container.storage[0], generator)

def dprint(self, **kwargs):
"""Debug print itself

Expand Down
7 changes: 7 additions & 0 deletions pytensor/link/jax/dispatch/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
)

import pytensor.tensor.random.basic as ptr
from pytensor.compile.executor import reseeded_rng_value
from pytensor.graph import Constant
from pytensor.link.jax.dispatch.basic import jax_funcify, jax_typify
from pytensor.link.jax.dispatch.shape import JAXShapeTuple
Expand Down Expand Up @@ -80,6 +81,12 @@ def jax_typify_Generator(rng, **kwargs):
return state


@reseeded_rng_value.register(dict)
def reseeded_rng_value_jax(current, generator):
# JAX stores RNGs as the typified state dict produced by `jax_typify_Generator`.
return jax_typify_Generator(generator)


@jax_funcify.register(ptr.RandomVariable)
def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
"""JAX implementation of random variables."""
Expand Down
17 changes: 17 additions & 0 deletions tests/compile/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1111,3 +1111,20 @@ def test_pickle_class_with_functions(self):

blah.f2(5, 1)
assert blah.f1._finder[blah.s].value != blah2.f1._finder[blah2.s].value


def test_reseed_rngs():
rng = shared(np.random.default_rng(0))
rv = pt.random.normal(0, 1, size=3, rng=rng)
f = function([], rv, updates={rng: rv.owner.outputs[0]})

f.reseed_rngs(123)
draw = f()
f.reseed_rngs(123)
np.testing.assert_array_equal(draw, f()) # same seed -> same draw
f.reseed_rngs(456)
assert not np.array_equal(draw, f()) # different seed -> different draw

# A function without random inputs is a no-op.
x = scalar("x")
function([x], x * 2).reseed_rngs(0)
15 changes: 15 additions & 0 deletions tests/link/jax/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,3 +963,18 @@ def test_constant_shape_after_graph_rewriting(self):
new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=False)
assert new_x.type.shape == (2, 5)
assert compile_random_function([], new_x)().shape == (2, 5)


def test_reseed_rngs():
# JAX copies RNG shared variables at compile time, so the originals can't be reseeded
# via set_value; Function.reseed_rngs reseeds the function's own (typified) RNG storage.
rng = shared(np.random.default_rng(0))
rv = pt.random.normal(0, 1, size=3, rng=rng)
f = function([], rv, updates={rng: rv.owner.outputs[0]}, mode=jax_mode)

f.reseed_rngs(123)
draw = np.asarray(f())
f.reseed_rngs(123)
np.testing.assert_array_equal(draw, np.asarray(f())) # same seed -> same draw
f.reseed_rngs(456)
assert not np.array_equal(draw, np.asarray(f())) # different seed -> different draw
Loading