objax.random package

Generator([seed])

Random number generator module.

normal(shape, *[, mean, stddev, generator])

Returns a JaxArray of the given shape with random numbers drawn from a normal distribution with the given mean and standard deviation stddev.

randint(shape, low, high[, generator])

Returns a JaxArray of the given shape with random integers drawn uniformly from {low, ..., high-1}.

truncated_normal(shape, *[, stddev, lower, ...])

Returns a JaxArray of the given shape with random numbers drawn from a normal distribution with mean 0 and standard deviation stddev, truncated to (lower, upper).

uniform(shape[, generator])

Returns a JaxArray of the given shape with random numbers drawn from a uniform distribution over [0, 1).

class objax.random.Generator(seed=0)

Random number generator module.

The default generator can be accessed through objax.random.DEFAULT_GENERATOR. Its seed is 0 by default, and can be set through objax.random.DEFAULT_GENERATOR.seed(s) where integer s is the desired seed.

Parameters:

seed (int) – initial seed of the random generator.

__init__(seed=0)

Creates a random key generator; seed is the generator's initial seed.

Parameters:

seed (int) – initial seed of the random generator.

property key

The random generator state, holding a JAX PRNG key (an array of two 32-bit unsigned integers).

seed(seed=0)

Sets a new random generator seed.

Parameters:

seed (int) – the new seed.

__call__()

Advances the generator state and returns a new random key.

objax.random.normal(shape, *, mean=0, stddev=1, generator=objax.random.Generator(seed=0))

Returns a JaxArray of the given shape with random numbers drawn from a normal distribution with the given mean and standard deviation stddev.

NOTE: if random numbers are generated inside a jitted, parallelized or vectorized function, the variables of the generator in use (including DEFAULT_GENERATOR) must be added to the variable collection passed to that transformation (objax.Jit, objax.Parallel or objax.Vectorize).

Parameters:
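  • shape (Tuple[int, ...]) – shape of the output array.

  • mean (float) – mean of the normal distribution.

  • stddev (float) – standard deviation of the normal distribution.

  • generator (Generator) – random number generator to draw from.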

objax.random.randint(shape, low, high, generator=objax.random.Generator(seed=0))

Returns a JaxArray of the given shape with random integers drawn uniformly from {low, …, high-1}.

NOTE: if random numbers are generated inside a jitted, parallelized or vectorized function, the variables of the generator in use (including DEFAULT_GENERATOR) must be added to the variable collection passed to that transformation (objax.Jit, objax.Parallel or objax.Vectorize).

Parameters:
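  • shape (Tuple[int, ...]) – shape of the output array.

  • low (int) – lower bound of the range (inclusive).

  • high (int) – upper bound of the range (exclusive).

  • generator (Generator) – random number generator to draw from.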

objax.random.truncated_normal(shape, *, stddev=1, lower=-2, upper=2, generator=objax.random.Generator(seed=0))

Returns a JaxArray of the given shape with random numbers drawn from a normal distribution with mean 0 and standard deviation stddev, truncated to (lower, upper).

NOTE: if random numbers are generated inside a jitted, parallelized or vectorized function, the variables of the generator in use (including DEFAULT_GENERATOR) must be added to the variable collection passed to that transformation (objax.Jit, objax.Parallel or objax.Vectorize).

Parameters:
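  • shape (Tuple[int, ...]) – shape of the output array.

  • stddev (float) – standard deviation of the distribution.

  • lower (float) – lower truncation bound.

  • upper (float) – upper truncation bound.

  • generator (Generator) – random number generator to draw from.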

objax.random.uniform(shape, generator=objax.random.Generator(seed=0))

Returns a JaxArray of the given shape with random numbers drawn from a uniform distribution over [0, 1).

NOTE: if random numbers are generated inside a jitted, parallelized or vectorized function, the variables of the generator in use (including DEFAULT_GENERATOR) must be added to the variable collection passed to that transformation (objax.Jit, objax.Parallel or objax.Vectorize).

Parameters:
  • shape (Tuple[int, ...]) – shape of the output array.

  • generator (Generator) – random number generator to draw from.