objax.random package

Generator([seed])

Random number generator module.

normal(shape, *[, mean, stddev, generator])

Returns a JaxArray of shape shape with random numbers from a normal distribution with mean mean and standard deviation stddev.

randint(shape, low, high[, generator])

Returns a JaxArray of shape shape with random integers in {low, …, high-1}.

truncated_normal(shape, *[, stddev, lower, …])

Returns a JaxArray of shape shape with random numbers from a normal distribution with mean 0 and standard deviation stddev truncated by (lower, upper).

uniform(shape[, generator])

Returns a JaxArray of shape shape with random numbers from a uniform distribution over [0, 1).

class objax.random.Generator(seed=0)

Random number generator module.

The default generator can be accessed through objax.random.DEFAULT_GENERATOR. Its seed is 0 by default, and can be set through objax.random.DEFAULT_GENERATOR.seed(s) where integer s is the desired seed.

__init__(seed=0)

Create a random key generator; seed is the generator's initial seed.

Parameters

seed (int) – initial seed of the random generator.

property key

The random generator state (a tensor of 2 uint32).

seed(seed=0)

Sets a new random generator seed.

Parameters

seed (int) – the new seed.

__call__()

Advance the generator state and return a new random key.

objax.random.normal(shape, *, mean=0, stddev=1, generator=objax.random.Generator(seed=0))

Returns a JaxArray of shape shape with random numbers from a normal distribution with mean mean and standard deviation stddev.

Parameters
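  • shape (Tuple[int, ...]) – shape of the output array.

  • mean (float) – mean of the distribution.

  • stddev (float) – standard deviation of the distribution.

  • generator (objax.random.random.Generator) – the random number generator to use.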

objax.random.randint(shape, low, high, generator=objax.random.Generator(seed=0))

Returns a JaxArray of shape shape with random integers in {low, …, high-1}.

Parameters
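  • shape (Tuple[int, ...]) – shape of the output array.

  • low (int) – smallest value that can be drawn (inclusive).

  • high (int) – upper end of the range (exclusive); the largest value drawn is high-1.

  • generator (objax.random.random.Generator) – the random number generator to use.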

objax.random.truncated_normal(shape, *, stddev=1, lower=-2, upper=2, generator=objax.random.Generator(seed=0))

Returns a JaxArray of shape shape with random numbers from a normal distribution with mean 0 and standard deviation stddev truncated by (lower, upper).

Parameters
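  • shape (Tuple[int, ...]) – shape of the output array.

  • stddev (float) – standard deviation of the distribution.

  • lower (float) – lower bound of the truncation.

  • upper (float) – upper bound of the truncation.

  • generator (objax.random.random.Generator) – the random number generator to use.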

objax.random.uniform(shape, generator=objax.random.Generator(seed=0))

Returns a JaxArray of shape shape with random numbers from a uniform distribution over [0, 1).

Parameters
  • shape (Tuple[int, ...]) – shape of the output array.

  • generator (objax.random.random.Generator) – the random number generator to use.