objax.random package
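
Summary:
- Generator: Random number generator module.
- normal: Returns a JaxArray of normally distributed random numbers.
- randint: Returns a JaxArray of random integers in {low, …, high-1}.
- truncated_normal: Returns a JaxArray of truncated normally distributed random numbers.
- uniform: Returns a JaxArray of uniformly distributed random numbers.
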
- class objax.random.Generator(seed=0)
Random number generator module.
The default generator can be accessed through objax.random.DEFAULT_GENERATOR. Its seed is 0 by default and can be set through objax.random.DEFAULT_GENERATOR.seed(s), where the integer s is the desired seed.
- Parameters:
seed (int): initial seed of the random generator.
- __init__(seed=0)
Create a random key generator; seed is the initial seed of the random generator.
- Parameters:
seed (int): initial seed of the random generator.
- property key
The random generator state (a tensor of 2 int32).
- objax.random.normal(shape, *, mean=0, stddev=1, generator=objax.random.DEFAULT_GENERATOR)
Returns a JaxArray of the given shape containing random numbers from a normal distribution with the given mean and standard deviation (stddev).
NOTE: If random numbers are generated inside a jitted, parallelized or vectorized function, the generator variables (including DEFAULT_GENERATOR) must be added to the variable collection.
- Parameters:
shape (Tuple[int, ...]): shape of the returned array.
mean (float): mean of the normal distribution.
stddev (float): standard deviation of the normal distribution.
generator (Generator): random number generator to draw from.
- objax.random.randint(shape, low, high, generator=objax.random.DEFAULT_GENERATOR)
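Returns a JaxArray of the given shape containing random integers in {low, …, high-1}.
NOTE: If random numbers are generated inside a jitted, parallelized or vectorized function, the generator variables (including DEFAULT_GENERATOR) must be added to the variable collection.
- Parameters:
shape (Tuple[int, ...]): shape of the returned array.
low (int): smallest value that can be returned (inclusive).
high (int): upper bound (exclusive); the largest value that can be returned is high-1.
generator (Generator): random number generator to draw from.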
- objax.random.truncated_normal(shape, *, stddev=1, lower=-2, upper=2, generator=objax.random.DEFAULT_GENERATOR)
Returns a JaxArray of the given shape containing random numbers from a normal distribution with mean 0 and standard deviation stddev, truncated by (lower, upper).
NOTE: If random numbers are generated inside a jitted, parallelized or vectorized function, the generator variables (including DEFAULT_GENERATOR) must be added to the variable collection.
- Parameters:
shape (Tuple[int, ...]): shape of the returned array.
stddev (float): standard deviation of the distribution.
lower (float): lower truncation bound.
upper (float): upper truncation bound.
generator (Generator): random number generator to draw from.
- objax.random.uniform(shape, generator=objax.random.DEFAULT_GENERATOR)
Returns a JaxArray of the given shape containing random numbers from a uniform distribution on [0, 1).
NOTE: If random numbers are generated inside a jitted, parallelized or vectorized function, the generator variables (including DEFAULT_GENERATOR) must be added to the variable collection.
- Parameters:
shape (Tuple[int, ...]): shape of the returned array.
generator (Generator): random number generator to draw from.