Skip to content

Commit

Permalink
pytree interface
Browse files Browse the repository at this point in the history
  • Loading branch information
ntessore committed Dec 16, 2024
1 parent e3477f9 commit 283d2c6
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 8 deletions.
90 changes: 82 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,84 @@
# NumPy random number generator API for JAX

Small wrapper to use JAX random number generators via the NumPy random API.

>>> from jrng import JRNG
>>> rng = JRNG(42)
>>> rng.standard_normal(3)
Array([-0.5675502 , 0.28439185, -0.9320608 ], dtype=float32)
>>> rng.standard_normal(3)
Array([ 0.67903334, -1.220606 , 0.94670606], dtype=float32)
**This is a proof of concept only.**

Wraps stateless JAX random number generation in the
[`numpy.random.Generator`](generator) API.

```py
from jrng import JRNG

rng = JRNG(42)

rng.standard_normal(3)
# Array([-0.5675502 , 0.28439185, -0.9320608 ], dtype=float32)
rng.standard_normal(3)
# Array([ 0.67903334, -1.220606 , 0.94670606], dtype=float32)
```

The goal of this experiment is to investigate ways in which there can be a
random number generation API that works in tandem with the Python Array API.

## How it works

The `JRNG` class works in the obvious way: it keeps track of the JAX `key` and
calls `jax.random.split()` before every random operation.

## JIT

The problem with a stateful RNG is that it cannot easily be passed into a
compiled function. However, the `JRNG` class is only "stateful" in that it
keeps track of the current `key`.

When a `JRNG` pytree is flattened, the resulting child node contains an
_independent_ random key, while the internal state of the existing `JRNG` is
advanced at the same time. This allows passing `JRNG` instances into compiled
functions and still obtaining independent random outputs:

```py
import jax
from jrng import JRNG

def print_key(key): print(jax.random.key_data(key))

@jax.jit
def f(x, rng):
return x + rng.standard_normal(x.shape)

x = jax.numpy.array([1, 2, 3])
rng = JRNG(42)

print_key(rng.key) # [ 0 42]

print(f(x, rng)) # [0.047065 1.6797752 3.9650078]

print_key(rng.key) # [4249898905 2425127087]

print(f(x, rng)) # [0.60631436 1.0040649 2.4605024 ]

print_key(rng.key) # [ 499334550 3925197703]
```

However, this mechanism means flattening the `JRNG` pytree changes internal
state (due to the internal details of JAX, it actually advances the random
number generator multiple times).

```py
# same initial state as above
key = jax.random.key(42)
print_key(key) # [ 0 42]

# pytree is flattened 4 times per invocation
key, _ = jax.random.split(key)
key, _ = jax.random.split(key)
key, _ = jax.random.split(key)
key, _ = jax.random.split(key)
print_key(key) # [4249898905 2425127087]
```

While this is not an ideal solution, it may be an acceptable one: the goal of
this API is to work in tandem with the Array API. Array-agnostic code is not
usually compiled at low level. Using the `JRNG` class _inside_ a compiled
function works without issue.

[generator]: https://numpy.org/doc/stable/reference/random/generator.html
16 changes: 16 additions & 0 deletions jrng.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
split,
uniform,
)
from jax.tree_util import register_pytree_node_class
from jax.typing import ArrayLike, DTypeLike


Expand All @@ -46,6 +47,7 @@ def _s(size: Size, *bcast: ArrayLike) -> tuple[int, ...]:
return size


@register_pytree_node_class
class JRNG:
"""
Wrapper class for JAX random number generation.
Expand Down Expand Up @@ -76,6 +78,20 @@ def __key(self) -> Array:
self.key, key = split(self.key)
return key

def tree_flatten(self) -> tuple[tuple[Array], None]:
"""
Return pytree representation of JRNG instance.
"""
return (self.__key,), None

@classmethod
def tree_unflatten(cls, aux_data: None, children: tuple[Array]) -> Self:
"""
Construct JRNG instance from pytree representation.
"""
(key,) = children
return cls.from_key(key)

def spawn(self, n: int) -> list[Self]:
"""
Create new independent child generators.
Expand Down

0 comments on commit 283d2c6

Please sign in to comment.