objax package

Modules

Module()

A module is a container to associate variables and functions.

ModuleList([iterable])

This is a replacement for Python’s list that provides a vars() method to return all the variables that it contains, including the ones contained in the modules and sub-modules in it.

GradValues(f, variables[, input_argnums])

The GradValues module is used to compute the gradients of a function.

Jit(f[, vc, static_argnums])

JIT (Just-In-Time) module takes a function or a module and compiles it for faster execution.

Parallel(f[, vc, reduce, axis_name, …])

Parallel module takes a function or a module and compiles it for running on multiple devices in parallel.

Vectorize(f[, vc, batch_axis])

Vectorize module takes a function or a module and compiles it for running in parallel on a single device.

class objax.Module[source]

A module is a container to associate variables and functions.

vars(scope='')[source]

Collect all the variables (and their names) contained in the module and its submodules. Important: Variables and modules stored in Python structures such as dict or list are not collected. See ModuleList if you need such a feature.

Parameters

scope (str) – string to prefix to the variable names.

Returns

A VarCollection of all the variables.

Return type

objax.variable.VarCollection
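The collection rule described above can be pictured with a small plain-Python stand-in. This is a toy model, not objax's implementation: ToyVar, ToyModule, ToyModuleList and Net are hypothetical names, and the "(ClassName).attribute" naming only imitates the names printed in the usage examples on this page.

```python
class ToyVar:
    """Hypothetical stand-in for a variable."""

    def __init__(self, value):
        self.value = value


class ToyModule:
    """Hypothetical stand-in for a module: it walks its own attributes only."""

    def vars(self, scope=''):
        found = {}
        scope += f'({type(self).__name__}).'
        for name, attr in self.__dict__.items():
            if isinstance(attr, ToyVar):
                found[scope + name] = attr
            elif isinstance(attr, (ToyModule, ToyModuleList)):
                found.update(attr.vars(scope + name))
            # Anything else (plain list, dict, ...) is skipped.
        return found


class ToyModuleList(list):
    """Hypothetical list that also knows how to report its variables."""

    def vars(self, scope=''):
        found = {}
        scope += f'({type(self).__name__})'
        for i, item in enumerate(self):
            if isinstance(item, ToyVar):
                found[f'{scope}[{i}]'] = item
            elif isinstance(item, (ToyModule, ToyModuleList)):
                found.update(item.vars(f'{scope}[{i}]'))
        return found


class Net(ToyModule):
    def __init__(self):
        self.w = ToyVar(1.0)
        self.hidden = [ToyVar(2.0)]                  # plain list: not collected
        self.tracked = ToyModuleList([ToyVar(3.0)])  # list-like container: collected


names = sorted(Net().vars('model'))
print(names)
```

The variable held in the plain list never shows up, while the one held in the list-like container does, and the scope string is simply prepended to every name.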

__call__(*args, **kwargs)[source]

Optional module __call__ method, typically a forward pass computation for standard primitives.

class objax.ModuleList(iterable=(), /)[source]

Bases: objax.module.Module, list

This is a replacement for Python’s list that provides a vars() method to return all the variables that it contains, including the ones contained in the modules and sub-modules in it.

Usage example:

import objax

ml = objax.ModuleList(['hello', objax.TrainVar(objax.random.normal((10,2)))])
print(ml.vars())
# (ModuleList)[1]            20 (10, 2)
# +Total(1)                  20

ml.pop()
ml.append(objax.nn.Linear(2, 3))
print(ml.vars())
# (ModuleList)[1](Linear).b        3 (3,)
# (ModuleList)[1](Linear).w        6 (2, 3)
# +Total(2)                        9

vars(scope='')[source]

Collect all the variables (and their names) contained in the list and its submodules.

Parameters

scope (str) – string to prefix to the variable names.

Returns

A VarCollection of all the variables.

Return type

objax.variable.VarCollection

class objax.GradValues(f, variables, input_argnums=None)[source]

The GradValues module is used to compute the gradients of a function.

Usage example:

import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])

def f(x, y):
    return ((m(x) - y) ** 2).mean()

# Create module to compute gradients of f for m.vars()
grad_val_f = objax.GradValues(f, m.vars())

# Create module to compute gradients of f for input 0 (x) and m.vars()
grad_val_fx = objax.GradValues(f, m.vars(), input_argnums=(0,))

For more information and examples, see Understanding Gradients.

__init__(f, variables, input_argnums=None)[source]

Constructs an instance to compute the gradient of f w.r.t. variables.

Parameters
  • f (Callable) – the function for which to compute gradients.

  • variables (Optional[objax.variable.VarCollection]) – the variables for which to compute gradients.

  • input_argnums (Optional[Tuple[int, ..]]) – input indexes, if any, on which to compute gradients.

__call__(*args)[source]

Returns the gradients of the first value returned by f, together with all the values returned by f.

Returns

A tuple (gradients, values of f), where gradients is a list containing the input gradients, if any, followed by the variable gradients.
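The ordering of the returned tuple can be illustrated without any autodiff library. The sketch below is a hypothetical stand-in (toy_grad_values is not part of objax) that uses central finite differences on scalars; it assumes variables are stored as one-element lists so they can be perturbed in place, and it returns the values of f as a list, which is a choice made for this sketch.

```python
def toy_grad_values(f, variables, input_argnums=()):
    """Hypothetical finite-difference stand-in for a gradient module.

    variables: dict mapping a name to a one-element list holding a float.
    """
    eps = 1e-6

    def first(out):
        # Gradients are taken with respect to the first returned value only.
        return out[0] if isinstance(out, (list, tuple)) else out

    def call(*args):
        grads = []
        # 1) Gradients for the selected inputs, in input_argnums order.
        for i in input_argnums:
            hi = list(args)
            lo = list(args)
            hi[i] += eps
            lo[i] -= eps
            grads.append((first(f(*hi)) - first(f(*lo))) / (2 * eps))
        # 2) Gradients for the variables, in dictionary order.
        for box in variables.values():
            original = box[0]
            box[0] = original + eps
            up = first(f(*args))
            box[0] = original - eps
            down = first(f(*args))
            box[0] = original
            grads.append((up - down) / (2 * eps))
        out = f(*args)
        values = list(out) if isinstance(out, (list, tuple)) else [out]
        return grads, values

    return call


w = {'w': [3.0]}


def loss(x, y):
    return (w['w'][0] * x - y) ** 2


grads, values = toy_grad_values(loss, w, input_argnums=(0,))(2.0, 1.0)
print(grads, values)  # input gradient first, then the variable gradient
```

With input_argnums=(0,), the first entry of grads is the gradient with respect to x and the second is the gradient with respect to w.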

class objax.Jit(f, vc=None, static_argnums=None)[source]

JIT (Just-In-Time) module takes a function or a module and compiles it for faster execution.

Usage example:

import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])
jit_m = objax.Jit(m)                          # Jit a module
jit_f = objax.Jit(lambda x: m(x), m.vars())   # Jit a function: provide vars it uses

For more information, refer to JIT Compilation.

__init__(f, vc=None, static_argnums=None)[source]

Jit constructor.

Parameters
  • f (Union[objax.module.Module, Callable]) – the function or the module to compile.

  • vc (Optional[objax.variable.VarCollection]) – the VarCollection of variables used by the function or module. This argument is required for functions.

  • static_argnums (Optional[Tuple[int, ..]]) – tuple of indexes of f’s input arguments to treat as static (constants). A new graph is compiled for each different combination of values for such inputs.
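The cost of static arguments can be pictured as a cache keyed by their values. The following is a plain-Python toy model, not how objax or JAX compile code: ToyJit and its compilations counter are hypothetical, and "compiling" is reduced to storing the function in a dictionary.

```python
class ToyJit:
    """Hypothetical stand-in: 'compiles' once per distinct combination of static values."""

    def __init__(self, f, static_argnums=()):
        self.f = f
        self.static_argnums = tuple(static_argnums)
        self.cache = {}        # static values -> specialized callable
        self.compilations = 0

    def __call__(self, *args):
        key = tuple(args[i] for i in self.static_argnums)
        if key not in self.cache:
            # A real compiler would trace f and build a graph here.
            self.compilations += 1
            self.cache[key] = self.f
        return self.cache[key](*args)


def scale(x, factor):
    return x * factor


jit_scale = ToyJit(scale, static_argnums=(1,))
jit_scale(1.0, 2)
jit_scale(5.0, 2)  # same static value: cache hit
jit_scale(5.0, 3)  # new static value: one more compilation
print(jit_scale.compilations)
```

Arguments that take many different values are therefore poor candidates for static_argnums, because each new value triggers another compilation.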

__call__(*args)[source]

Call the compiled version of the function or module.

vars(scope='')

Collect all the variables (and their names) contained in the VarCollection.

Parameters

scope (str) – string to prefix to the variable names.

Returns

A VarCollection of all the variables.

Return type

objax.variable.VarCollection

class objax.Parallel(f, vc=None, reduce=<function concatenate>, axis_name='device', static_argnums=None)[source]

Parallel module takes a function or a module and compiles it for running on multiple devices in parallel.

Usage example:

import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])
para_m = objax.Parallel(m)                         # Parallelize a module
para_f = objax.Parallel(lambda x: m(x), m.vars())  # Parallelize a function: provide vars it uses

When calling a parallelized module, one must replicate the variables it uses on all devices:

x = objax.random.normal((16, 2))
with m.vars().replicate():
    y = para_m(x)

For more information, refer to Parallelism.

__init__(f, vc=None, reduce=<function concatenate>, axis_name='device', static_argnums=None)[source]

Parallel constructor.

Parameters
  • f (Union[objax.module.Module, Callable]) – the function or the module to compile for parallelism.

  • vc (Optional[objax.variable.VarCollection]) – the VarCollection of variables used by the function or module. This argument is required for functions.

  • reduce (Callable[[Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]]) – the function used to reduce the outputs from many devices to a single device value.

  • axis_name (str) – what name to give to the device dimension, used in conjunction with objax.functional.parallel.

  • static_argnums (Optional[Tuple[int, ..]]) – tuple of indexes of f’s input arguments to treat as static (constants). A new graph is compiled for each different combination of values for such inputs.

__call__(*args)[source]

Call the compiled function or module on multiple devices in parallel. Important: Make sure you call this function within the scope of a VarCollection.replicate() statement.

vars(scope='')

Collect all the variables (and their names) contained in the VarCollection.

Parameters

scope (str) – string to prefix to the variable names.

Returns

A VarCollection of all the variables.

Return type

objax.variable.VarCollection

class objax.Vectorize(f, vc=None, batch_axis=0)[source]

Vectorize module takes a function or a module and compiles it for running in parallel on a single device.

Usage example:

import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])
vec_m = objax.Vectorize(m)                         # Vectorize a module
vec_f = objax.Vectorize(lambda x: m(x), m.vars())  # Vectorize a function: provide vars it uses

For more information and examples, refer to Vectorization.

__init__(f, vc=None, batch_axis=0)[source]

Vectorize constructor.

Parameters
  • f (Union[objax.module.Module, Callable]) – the function or the module to compile for vectorization.

  • vc (Optional[objax.variable.VarCollection]) – the VarCollection of variables used by the function or module. This argument is required for functions.

  • batch_axis (Tuple[Optional[int], ..]) – tuple of int or None for each of f’s input arguments: the axis to use as batch during vectorization. Use None to automatically broadcast.
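The meaning of batch_axis can be sketched with plain lists. This is a hypothetical toy (toy_vectorize is not an objax function) that only supports axis 0, meaning "iterate over this argument", and None, meaning "pass this argument unchanged to every call"; a real vectorizer runs the batch in one fused computation rather than a Python loop.

```python
def toy_vectorize(f, batch_axis):
    """Hypothetical stand-in supporting only axis 0 (batched) or None (broadcast)."""

    def call(*args):
        sizes = {len(a) for a, axis in zip(args, batch_axis) if axis == 0}
        if len(sizes) != 1:
            raise ValueError('batched arguments must share one batch size')
        n = sizes.pop()
        results = []
        for i in range(n):
            # Slice batched arguments, reuse broadcast ones as they are.
            sliced = [a[i] if axis == 0 else a for a, axis in zip(args, batch_axis)]
            results.append(f(*sliced))
        return results

    return call


def affine(x, scale):
    return x * scale + 1


batched = toy_vectorize(affine, batch_axis=(0, None))
print(batched([1, 2, 3], 10))
```

Here x is batched along its first axis while scale is broadcast, so affine runs once per element of x with the same scale.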

__call__(*args)[source]

Call the vectorized version of the function or module.

vars(scope='')

Collect all the variables (and their names) contained in the VarCollection.

Parameters

scope (str) – string to prefix to the variable names.

Returns

A VarCollection of all the variables.

Return type

objax.variable.VarCollection

Variables

BaseVar(reduce)

The abstract base class to represent objax variables.

TrainVar(tensor[, reduce])

A trainable variable.

BaseState(reduce)

The abstract base class used to represent objax state variables.

StateVar(tensor[, reduce])

StateVar are variables that get updated manually, and are not automatically updated by optimizers.

TrainRef(ref)

A state variable that references a trainable variable for assignment.

RandomState(seed)

RandomState are variables that track the random generator state.

VarCollection

A VarCollection is a dictionary (name, var) with some additional methods to make manipulation of collections of variables easy.

class objax.BaseVar(reduce)[source]

The abstract base class to represent objax variables.

__init__(reduce)[source]

Constructor for BaseVar class.

Parameters

reduce (Optional[Callable[[Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]]]) – a function that takes an array of shape (n, *dims) and returns one of shape (*dims). Used to combine the multiple states produced in an objax.Vectorize or an objax.Parallel call.
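The shape contract of reduce, from (n, *dims) to (*dims), can be shown with nested lists. The helpers below are hypothetical plain-Python stand-ins limited to shape (n, d); they are not the reduce functions shipped with objax.

```python
from statistics import fmean


def toy_reduce_mean(stacked):
    """Hypothetical reduce: maps nested lists of shape (n, d) to a list of shape (d,)."""
    return [fmean(column) for column in zip(*stacked)]


def toy_take_first(stacked):
    """Another valid reduce: keep the first of the n copies."""
    return stacked[0]


# Three copies (n = 3) of a state of shape (2,), e.g. one per device.
per_device = [[1.0, 10.0], [3.0, 30.0], [5.0, 50.0]]
print(toy_reduce_mean(per_device))
print(toy_take_first(per_device))
```

Both functions drop the leading axis of size n, which is all the contract requires; which one is appropriate depends on what the variable represents.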

class objax.TrainVar(tensor, reduce=<function reduce_mean>)[source]

A trainable variable.

__init__(tensor, reduce=<function reduce_mean>)[source]

TrainVar constructor.

Parameters
  • tensor (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – the initial value of the TrainVar.

  • reduce (Optional[Callable[[Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]]]) – a function that takes an array of shape (n, *dims) and returns one of shape (*dims). Used to combine the multiple states produced in an objax.Vectorize or an objax.Parallel call.

property value

The value is read only as a safety measure to avoid accidentally making TrainVar non-differentiable. You can write a value to a TrainVar by using assign.
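The read-only pattern can be sketched in plain Python with a property that has a getter but no setter. ToyTrainVar is a hypothetical class, and the AttributeError seen here is simply Python's default for such a property; the exception objax raises may differ.

```python
class ToyTrainVar:
    """Hypothetical variable whose value can be read but not assigned directly."""

    def __init__(self, tensor):
        self._value = tensor

    @property
    def value(self):  # getter only: no setter is defined
        return self._value

    def assign(self, tensor):
        # The explicit method is the only sanctioned way to write.
        self._value = tensor


v = ToyTrainVar(1.0)
try:
    v.value = 2.0
    blocked = False
except AttributeError:
    blocked = True

v.assign(2.0)
print(blocked, v.value)
```

Making writes go through an explicit method keeps accidental assignments from slipping in silently.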

assign(tensor)[source]

Sets the value of the variable.

Parameters

tensor (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) –

reduce(tensors)

Method called by Parallel and Vectorize to reduce a multiple-device (or batched in case of vectorization) value to a single device.

Parameters

tensors (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) –

class objax.BaseState(reduce)[source]

The abstract base class used to represent objax state variables. State variables are not trainable.

__init__(reduce)

Constructor for BaseVar class.

Parameters

reduce (Optional[Callable[[Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]]]) – a function that takes an array of shape (n, *dims) and returns one of shape (*dims). Used to combine the multiple states produced in an objax.Vectorize or an objax.Parallel call.

class objax.TrainRef(ref)[source]

A state variable that references a trainable variable for assignment.

TrainRef are used by optimizers to keep references to trainable variables. This is necessary to differentiate them from the optimizer's own training variables, if any.

__init__(ref)[source]

TrainRef constructor.

Parameters

ref (objax.variable.TrainVar) – the TrainVar to keep the reference of.

property value

The value stored in the referenced TrainVar. It can be read or written.
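A reference of this kind can be sketched in plain Python as a thin wrapper whose reads and writes go through to the wrapped variable. ToyTrainVar and ToyTrainRef are hypothetical classes written for this illustration, not objax code.

```python
class ToyTrainVar:
    """Hypothetical trainable variable: readable value, explicit assign."""

    def __init__(self, tensor):
        self._value = tensor

    @property
    def value(self):
        return self._value

    def assign(self, tensor):
        self._value = tensor


class ToyTrainRef:
    """Hypothetical reference: reads and writes go through to the wrapped variable."""

    def __init__(self, ref):
        self.ref = ref

    @property
    def value(self):
        return self.ref.value

    @value.setter
    def value(self, tensor):
        self.ref.assign(tensor)


weight = ToyTrainVar(1.0)
handle = ToyTrainRef(weight)
handle.value = handle.value - 0.5 * 1.0  # e.g. one gradient-descent step
print(weight.value)
```

The optimizer-like code only touches handle, yet the update lands in weight, and handle is a different type from the variable it points to.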

assign(tensor)

Sets the value of the variable.

Parameters

tensor (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) –

reduce(tensors)

Method called by Parallel and Vectorize to reduce a multiple-device (or batched in case of vectorization) value to a single device.

Parameters

tensors (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) –

class objax.StateVar(tensor, reduce=<function reduce_mean>)[source]

StateVar are variables that get updated manually, and are not automatically updated by optimizers. For example, the mean and variance statistics in BatchNorm are StateVar.

__init__(tensor, reduce=<function reduce_mean>)[source]

StateVar constructor.

Parameters
  • tensor (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – the initial value of the StateVar.

  • reduce (Optional[Callable[[Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]]]) – a function that takes an array of shape (n, *dims) and returns one of shape (*dims). Used to combine the multiple states produced in an objax.Vectorize or an objax.Parallel call.

property value

The value stored in the StateVar. It can be read or written.

assign(tensor)

Sets the value of the variable.

Parameters

tensor (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) –

reduce(tensors)

Method called by Parallel and Vectorize to reduce a multiple-device (or batched in case of vectorization) value to a single device.

Parameters

tensors (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) –

class objax.RandomState(seed)[source]

RandomState are variables that track the random generator state. They are meant to be used internally. Currently only the random.Generator module uses them.

__init__(seed)[source]

RandomState constructor.

Parameters

seed (int) – the initial seed of the random number generator.

seed(seed)[source]

Sets a new random seed.

Parameters

seed (int) – the new initial seed of the random number generator.

split(n)[source]

Create multiple seeds from the current seed. This is used internally by Parallel and Vectorize to ensure that random numbers are different in parallel threads.

Parameters

n (int) – the number of seeds to generate.

Return type

List[Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]]

assign(tensor)

Sets the value of the variable.

Parameters

tensor (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) –

reduce(tensors)

Method called by Parallel and Vectorize to reduce a multiple-device (or batched in case of vectorization) value to a single device.

Parameters

tensors (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) –

property value

The value stored in the StateVar. It can be read or written.

class objax.VarCollection[source]

A VarCollection is a dictionary (name, var) with some additional methods to make manipulation of collections of variables easy. A VarCollection is ordered by insertion order. It is the object returned by Module.vars() and used as input by many modules: optimizers, Jit, etc…

Usage example:

import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])
vc = m.vars()  # This is a VarCollection

# It is a dictionary
print(repr(vc))
# {'(Sequential)[0](Linear).b': <objax.variable.TrainVar object at 0x7faecb506390>,
#  '(Sequential)[0](Linear).w': <objax.variable.TrainVar object at 0x7faec81ee350>}
print(vc.keys())  # dict_keys(['(Sequential)[0](Linear).b', '(Sequential)[0](Linear).w'])
assert (vc['(Sequential)[0](Linear).w'].value == m[0].w.value).all()

# Convenience print
print(vc)
# (Sequential)[0](Linear).b        3 (3,)
# (Sequential)[0](Linear).w        6 (2, 3)
# +Total(2)                        9

# Extra methods for manipulation of variables:
# For example, increment all variables by 1
vc.assign([x+1 for x in vc.tensors()])

# It's used by other modules.
# For example it's used to tell Jit what variables are used by a function.
jit_f = objax.Jit(lambda x: m(x), vc)

For more information and examples, refer to VarCollection.

update(other)[source]

Overload dict.update method to catch potential conflicts during assignment.

Parameters

other (Union[VarCollection, Iterable[Tuple[str, objax.variable.BaseVar]]]) –
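One way to picture conflict detection is a dictionary that refuses to rebind an existing name to a different object. The sketch below is a plain-Python toy: ToyCollection is hypothetical, and both the definition of a conflict (same name, different object) and the use of ValueError are assumptions made for illustration.

```python
class ToyCollection(dict):
    """Hypothetical dict that rejects rebinding a name to a different object."""

    def __setitem__(self, name, var):
        if name in self and self[name] is not var:
            raise ValueError(f'name conflict for {name!r}')
        super().__setitem__(name, var)

    def update(self, other):
        # dict.update bypasses __setitem__, so route every item through it.
        items = other.items() if isinstance(other, dict) else other
        for name, var in items:
            self[name] = var


a, b = object(), object()
vc = ToyCollection()
vc.update([('w', a)])
vc.update({'w': a})  # same object under the same name: accepted
try:
    vc.update([('w', b)])
    conflict = False
except ValueError:
    conflict = True
print(conflict)
```

Overriding update matters because the built-in dict.update would silently overwrite the earlier entry.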

assign(tensors)[source]

Assign tensors to the variables in the VarCollection. Each variable is assigned only once and in the order following the iter(self) iterator.

Parameters

tensors (List[Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]]) – the list of tensors used to update variables values.

replicate()[source]

A context manager to use in a with statement that replicates the variables in this collection to multiple devices. This is typically used prior to a call to objax.Parallel, so that all variables have a copy on each device. Important: replicating also updates the random state in order to have a new one per device.

subset(is_a=None, is_not=None)[source]

Return a new VarCollection that is a filtered subset of the current collection.

Parameters
  • is_a (Optional[Union[type, Tuple[type, ..]]]) – either a variable type or a tuple of variable types to include.

  • is_not (Optional[Union[type, Tuple[type, ..]]]) – either a variable type or a tuple of variable types to exclude.

Returns

A new VarCollection containing the subset of variables.

Return type

objax.variable.VarCollection
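The filtering can be sketched with isinstance, which already accepts either a single type or a tuple of types. This is a toy over a plain dict: toy_subset, ToyTrainVar and ToyStateVar are hypothetical names, not the objax classes.

```python
class ToyTrainVar:
    """Hypothetical trainable variable type."""


class ToyStateVar:
    """Hypothetical state variable type."""


def toy_subset(collection, is_a=None, is_not=None):
    """Return a new dict with the entries that pass both type filters."""
    result = {}
    for name, var in collection.items():
        if is_a is not None and not isinstance(var, is_a):
            continue
        if is_not is not None and isinstance(var, is_not):
            continue
        result[name] = var
    return result


vc = {'w': ToyTrainVar(), 'running_mean': ToyStateVar(), 'b': ToyTrainVar()}
trainable = toy_subset(vc, is_a=ToyTrainVar)
others = toy_subset(vc, is_not=ToyTrainVar)
print(list(trainable), list(others))
```

The original collection is left untouched, and the filtered copy shares the same variable objects rather than duplicating them.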

tensors(is_a=None)[source]

Return the list of values for this collection. Similarly to the assign method, each variable value is reported only once and in the order following the iter(self) iterator.

Parameters

is_a (Optional[Union[type, Tuple[type, ..]]]) – either a variable type or a tuple of variable types to include.

Returns

A list of the values (tensors) of the variables in the collection, one entry per distinct variable.

Return type

List[Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]]
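The "each variable only once" rule matters when the same variable object is registered under several names. The sketch below is a toy over a plain dict: ToyVar, unique_vars, toy_tensors and toy_assign are hypothetical helpers, and identifying duplicates by object identity is an assumption of this sketch.

```python
class ToyVar:
    """Hypothetical variable holding a single value."""

    def __init__(self, value):
        self.value = value


def unique_vars(collection):
    """Yield each variable object once, in insertion order."""
    seen = set()
    for var in collection.values():
        if id(var) not in seen:
            seen.add(id(var))
            yield var


def toy_tensors(collection):
    return [var.value for var in unique_vars(collection)]


def toy_assign(collection, tensors):
    variables = list(unique_vars(collection))
    if len(variables) != len(tensors):
        raise ValueError('need exactly one tensor per distinct variable')
    for var, tensor in zip(variables, tensors):
        var.value = tensor


shared = ToyVar(1)
vc = {'encoder.w': shared, 'decoder.w': shared, 'bias': ToyVar(5)}
values = toy_tensors(vc)                   # the shared variable is listed once
toy_assign(vc, [x + 1 for x in values])    # so it is incremented once, not twice
print(values, shared.value, vc['bias'].value)
```

Because both functions walk the variables in the same order, the output of one can be transformed and fed straight back into the other.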

Constants

class objax.ConvPadding(value)[source]

An Enum holding the possible padding values for convolution modules.

SAME = 'SAME'
VALID = 'VALID'