Skip to content

Commit

Permalink
locking and split() method
Browse files Browse the repository at this point in the history
  • Loading branch information
ntessore committed Dec 29, 2024
1 parent 68e968d commit 2223745
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 20 deletions.
15 changes: 7 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# `rng-jax`JAX random number generation as a NumPy generator
# `rng-jax`NumPy random number generator API for JAX

**This is a proof of concept only.**

Expand All @@ -9,7 +9,6 @@ Wraps JAX's stateless random number generation in a class implementing the

```py
>>> import rng_jax
>>>
>>> rng = rng_jax.Generator(42) # same arguments as jax.random.key()
>>> rng.standard_normal(3)
Array([-0.5675502 , 0.28439185, -0.9320608 ], dtype=float32)
Expand Down Expand Up @@ -38,23 +37,23 @@ package is to work in tandem with the Array API: array-agnostic code is not
usually compiled at low level. Conversely, native JAX code usually expects a
`key`, anyway, not a `rng_jax.Generator` instance.

To interface with a native JAX function expecting a `key`, use the `.key()`
To interface with a native JAX function expecting a `key`, use the `.split()`
method to obtain a new random key and advance the internal state of the
generator:

```py
>>> import jax
>>> rng = rng_jax.Generator(42)
>>> key = rng.key()
>>> key = rng.split()
>>> jax.random.normal(key, 3)
Array([-0.5675502 , 0.28439185, -0.9320608 ], dtype=float32)
>>> key = rng.key()
>>> key = rng.split()
>>> jax.random.normal(key, 3)
Array([ 0.67903334, -1.220606 , 0.94670606], dtype=float32)
```

The right way to compile array-agnostic code is usually to compile the "main"
function at the highest level of the code. Using the `rng_jax.Generator` class
fully _within_ a compiled function works without issue.
Using the `rng_jax.Generator` class fully _within_ a compiled JAX function
works without issue.

[array-api]: https://data-apis.org/array-api/latest/
[generator]: https://numpy.org/doc/stable/reference/random/generator.html
30 changes: 18 additions & 12 deletions rng_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

import math
from threading import Lock
from typing import Literal, Self, TypeAlias

from jax import Array
Expand Down Expand Up @@ -52,8 +53,9 @@ class Generator:
Wrapper class for JAX random number generation.
"""

__slots__ = ("_key",)
_key: Array
__slots__ = ("key", "lock")
key: Array
lock: Lock

@classmethod
def from_key(cls, key: Array) -> Self:
Expand All @@ -63,38 +65,43 @@ def from_key(cls, key: Array) -> Self:
if not isinstance(key, Array) or not issubdtype(key.dtype, prng_key):
raise ValueError("not a random key")
rng = object.__new__(cls)
rng._key = key
rng.key = key
rng.lock = Lock()
return rng

def __init__(self, seed: int | ArrayLike, *, impl: str | None = None) -> None:
"""
Create a wrapper instance with a new key.
"""
self._key = key(seed, impl=impl)
self.key = key(seed, impl=impl)
self.lock = Lock()

@property
def __key(self) -> Array:
"""
Return next key for sampling while updating internal state.
"""
self._key, key = split(self._key)
with self.lock:
self.key, key = split(self.key)
return key

def key(self, size: Size = None) -> Array:
def split(self, size: Size = None) -> Array:
"""
Return random key, advancing internal state.
Split random key.
"""
shape = _s(size)
keys = split(self._key, 1 + math.prod(shape))
self._key = keys[0]
with self.lock:
keys = split(self.key, 1 + math.prod(shape))
self.key = keys[0]
return keys[1:].reshape(shape)

def spawn(self, n_children: int) -> list[Self]:
"""
Create new independent child generators.
"""
self._key, *subkeys = split(self._key, num=n_children + 1)
return list(map(self.from_key, subkeys))
with self.lock:
self.key, *keys = split(self.key, num=n_children + 1)
return list(map(self.from_key, keys))

def integers(
self,
Expand All @@ -119,7 +126,6 @@ def random(self, size: Size = None, dtype: DTypeLike = float) -> Array:
"""
Return random floats in the half-open interval [0.0, 1.0).
"""
self._key, key = split(self._key)
return uniform(self.__key, _s(size), dtype)

def choice(
Expand Down

0 comments on commit 2223745

Please sign in to comment.