objax package

Modules

Module()

A module is a container to associate variables and functions.

ModuleList([iterable])

This is a replacement for Python’s list that provides a vars() method to return all the variables that it contains, including the ones contained in the modules and sub-modules in it.

ForceArgs(module, **kwargs)

Forces an override of the arguments of a given module.

Function(f, vc)

Turn a function into a Module by keeping the vars it uses.

Grad(f, variables[, input_argnums])

The Grad module is used to compute the gradients of a function.

GradValues(f, variables[, input_argnums])

The GradValues module is used to compute the gradients of a function and the values it returns.

Jit(f[, vc, static_argnums])

JIT (Just-In-Time) module takes a function or a module and compiles it for faster execution.

Parallel(f[, vc, reduce, axis_name, …])

Parallel module takes a function or a module and compiles it for running on multiple devices in parallel.

Vectorize(f[, vc, batch_axis])

Vectorize module takes a function or a module and compiles it for running in parallel on a single device.

class objax.Module

A module is a container to associate variables and functions.

vars(scope='')

Collect all the variables (and their names) contained in the module and its submodules. Important: variables and modules stored in Python structures such as dict or list are not collected. See ModuleList if you need such a feature.
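The caveat about dict and list can be illustrated with a simplified pure-Python sketch. It is not objax's implementation: ToyVar, ToyModule, Layer and Net are hypothetical names, and the name format only imitates the one printed by vars(). The walk collects attributes that hold a variable or a module directly and skips everything else, including a list of variables.

```python
class ToyVar:
    """Hypothetical stand-in for an objax variable."""

    def __init__(self, value):
        self.value = value


class ToyModule:
    """Hypothetical stand-in for objax.Module with a simplified vars() walk."""

    def vars(self, scope=''):
        prefix = f'{scope}({type(self).__name__})'
        found = {}
        for name, attr in self.__dict__.items():
            if isinstance(attr, ToyVar):
                found[f'{prefix}.{name}'] = attr
            elif isinstance(attr, ToyModule):
                # Direct submodules are visited recursively.
                found.update(attr.vars(scope=f'{prefix}.{name}'))
            # Any other attribute, including a list or dict of ToyVar, is skipped.
        return found


class Layer(ToyModule):
    def __init__(self):
        self.w = ToyVar(1.0)
        self.extras = [ToyVar(2.0)]  # held in a plain list: not collected


class Net(ToyModule):
    def __init__(self):
        self.layer = Layer()
        self.b = ToyVar(0.0)


print(list(Net().vars()))  # ['(Net).layer(Layer).w', '(Net).b']
```

The variable inside extras is silently missed, which is exactly the situation ModuleList is meant to handle.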

Parameters

scope (str) – string to prefix to the variable names.

Returns

A VarCollection of all the variables.

Return type

objax.variable.VarCollection

__call__(*args, **kwargs)

Optional module __call__ method, typically a forward pass computation for standard primitives.

class objax.ModuleList(iterable=(), /)

Bases: objax.module.Module, list

This is a replacement for Python’s list that provides a vars() method to return all the variables that it contains, including the ones contained in the modules and sub-modules in it.

Usage example:

import objax

ml = objax.ModuleList(['hello', objax.TrainVar(objax.random.normal((10,2)))])
print(ml.vars())
# (ModuleList)[1]            20 (10, 2)
# +Total(1)                  20

ml.pop()
ml.append(objax.nn.Linear(2, 3))
print(ml.vars())
# (ModuleList)[1](Linear).b        3 (3,)
# (ModuleList)[1](Linear).w        6 (2, 3)
# +Total(2)                        9
vars(scope='')

Collect all the variables (and their names) contained in the list and its submodules.

Parameters

scope (str) – string to prefix to the variable names.

Returns

A VarCollection of all the variables.

Return type

objax.variable.VarCollection

class objax.Function(f, vc)

Turn a function into a Module by keeping the vars it uses.

Usage example:

import objax
import jax.numpy as jn

m = objax.nn.Linear(2, 3)

def f1(x, y):
    return ((m(x) - y) ** 2).mean()

# Method 1: Create module by calling objax.Function to tell which variables are used.
m1 = objax.Function(f1, m.vars())

# Method 2: Use the decorator.
@objax.Function.with_vars(m.vars())
def f2(x, y):
    return ((m(x) - y) ** 2).mean()

# All behave like functions
x = jn.arange(10).reshape((5, 2))
y = jn.arange(15).reshape((5, 3))
print(type(f1), f1(x, y))  # <class 'function'> 237.01947
print(type(m1), m1(x, y))  # <class 'objax.module.Function'> 237.01947
print(type(f2), f2(x, y))  # <class 'objax.module.Function'> 237.01947

Usage of Function is not necessary: it is made available for aesthetic reasons (to accommodate users' personal tastes). It is also used internally to keep the code simple for Grad, Jit, Parallel, Vectorize and future primitives.

static with_vars(vc)

Decorator which turns a function into a module using provided variable collection.

Parameters

vc (objax.variable.VarCollection) – the VarCollection of variables used by the function.

Usage example:

import objax

m = objax.nn.Linear(2, 3)

@objax.Function.with_vars(m.vars())
def f(x, y):
    return ((m(x) - y) ** 2).mean()

print(type(f))  # <class 'objax.module.Function'>
static auto_vars(f)

Turns a function into a module by automatically detecting the Objax variables it uses. Can be used as a decorator.

WARNING: This is an experimental feature. It can detect the variables used by a function in many common cases, but not in all cases. This feature may be removed in a future version of Objax if it proves too unreliable.

Parameters

f (Callable) – function which will be converted into a module.

Usage example:

import objax

m = objax.nn.Linear(2, 3)

@objax.Function.auto_vars
def f(x, y):
    return ((m(x) - y) ** 2).mean()

print(type(f))  # <class 'objax.module.Function'>
__init__(f, vc)

Function constructor.

Parameters
  • f (Callable) – the function or the module to represent.

  • vc (objax.variable.VarCollection) – the VarCollection of variables used by the function.

__call__(*args, **kwargs)

Call the function.

vars(scope='')

Return the VarCollection of the variables used by the function.

Parameters

scope (str) – string to prefix to the variable names.

Return type

objax.variable.VarCollection

class objax.ForceArgs(module, **kwargs)

Forces an override of the arguments of a given module.

One example of ForceArgs usage is to override the training argument for batch normalization:

import objax
from objax.zoo.resnet_v2 import ResNet50

model = ResNet50(in_channels=3, num_classes=1000)

# Modify model to force training=False on first resnet block.
# First two ops in the resnet are convolution and padding,
# resnet blocks are starting at index 2.
model[2] = objax.ForceArgs(model[2], training=False)

# model(x, training=True) will be using `training=False` on model[2] due to ForceArgs
# ...

# Undo specific value of forced arguments in `model` and all submodules of `model`
objax.ForceArgs.undo(model, training=True)

# Undo all values of specific argument in `model` and all submodules of `model`
objax.ForceArgs.undo(model, training=objax.ForceArgs.ANY)

# Undo all values of all arguments in `model` and all submodules of `model`
objax.ForceArgs.undo(model)
class ANY

Token used in ForceArgs.undo to indicate that all values of a specific argument should be undone.

static undo(module, **kwargs)

Undo ForceArgs on each submodule of the module. Modifications are done in-place.

Parameters
  • module (objax.module.Module) – module for which to undo ForceArgs.

  • **kwargs – dictionary of argument overrides to undo. name=val removes the override of argument name when it forces the value val. name=ForceArgs.ANY removes all overrides of argument name. If **kwargs is empty, all overrides are undone.
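The three undo rules can be summarized with a minimal sketch that works on a single dictionary of forced arguments. It is only an assumption-laden model of the rules: the real ForceArgs.undo walks every submodule and modifies modules in place, while undo_overrides and the ANY sentinel below are hypothetical names.

```python
ANY = object()  # hypothetical stand-in for the ForceArgs.ANY token


def undo_overrides(forced, **kwargs):
    """Return a copy of `forced` (argument name -> forced value) without the undone overrides."""
    if not kwargs:
        return {}  # no arguments given: every override is undone
    remaining = dict(forced)
    for name, val in kwargs.items():
        # name=ANY drops the override whatever its value; name=val only drops a matching value.
        if name in remaining and (val is ANY or remaining[name] == val):
            del remaining[name]
    return remaining


forced = {'training': False, 'dropout': 0.5}
print(undo_overrides(forced, training=True))   # {'training': False, 'dropout': 0.5}
print(undo_overrides(forced, training=False))  # {'dropout': 0.5}
print(undo_overrides(forced, training=ANY))    # {'dropout': 0.5}
print(undo_overrides(forced))                  # {}
```

Note how training=True leaves a training=False override untouched, which matches the first undo call in the ForceArgs example above.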

__init__(module, **kwargs)

Initializes ForceArgs by wrapping another module.

Parameters
  • module (objax.module.Module) – module whose arguments will be overridden.

  • kwargs – values of the keyword arguments to force.

vars(scope='')

Returns the VarCollection of the wrapped module.

Parameters

scope (str) – string to prefix to the variable names.

Returns

A VarCollection of all the variables of the wrapped module.

Return type

objax.variable.VarCollection

__call__(*args, **kwargs)

Calls the wrapped module, using the forced arguments to override the arguments it is called with.
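A minimal sketch of the override behavior, assuming forced values simply win over keyword arguments supplied by the caller. ToyForceArgs and block are hypothetical names, not the objax implementation, and this sketch only handles arguments passed by keyword.

```python
class ToyForceArgs:
    """Hypothetical wrapper: forced keyword arguments take precedence over the caller's."""

    def __init__(self, module, **forced):
        self.module = module
        self.forced = forced

    def __call__(self, *args, **kwargs):
        # Caller kwargs are merged first, then the forced values overwrite them.
        return self.module(*args, **{**kwargs, **self.forced})


def block(x, training=True):
    return ('train' if training else 'eval', x)


wrapped = ToyForceArgs(block, training=False)
print(wrapped(3, training=True))  # ('eval', 3)
print(wrapped(3))                 # ('eval', 3)
```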

class objax.Grad(f, variables, input_argnums=None)

The Grad module is used to compute the gradients of a function.

Usage example:

import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])

def f(x, y):
    return ((m(x) - y) ** 2).mean()

# Create module to compute gradients of f for m.vars()
grad_f = objax.Grad(f, m.vars())

# Create module to compute gradients of f for input 0 (x) and m.vars()
grad_fx = objax.Grad(f, m.vars(), input_argnums=(0,))

For more information and examples, see Understanding Gradients.

__init__(f, variables, input_argnums=None)

Constructs an instance to compute the gradient of f w.r.t. variables.

Parameters
  • f (Callable) – the function for which to compute gradients.

  • variables (Optional[objax.variable.VarCollection]) – the variables for which to compute gradients.

  • input_argnums (Optional[Tuple[int, ...]]) – input indexes, if any, on which to compute gradients.

__call__(*args, **kwargs)

Returns the computed gradients for the first value returned by f.

Returns

A list of input gradients, if any, followed by the variable gradients.
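To make the output order concrete, here is a pure-Python sketch in which toy_grad (a hypothetical name) swaps in central finite differences purely for illustration; the real Grad computes exact gradients. The variables are modeled as a plain dict of floats that f reads when called, which is an assumption of this sketch.

```python
def toy_grad(f, variables, input_argnums=()):
    """Hypothetical sketch of Grad's output order using central differences."""
    h = 1e-6

    def call(*args):
        grads = []
        # 1) Gradients w.r.t. the selected inputs, in input_argnums order.
        for i in input_argnums:
            plus, minus = list(args), list(args)
            plus[i] += h
            minus[i] -= h
            grads.append((f(*plus) - f(*minus)) / (2 * h))
        # 2) Gradients w.r.t. each variable, in the collection's order.
        for name in variables:
            original = variables[name]
            variables[name] = original + h
            up = f(*args)
            variables[name] = original - h
            down = f(*args)
            variables[name] = original  # restore the variable
            grads.append((up - down) / (2 * h))
        return grads

    return call


params = {'w': 2.0, 'b': 1.0}


def loss(x, y):
    return (params['w'] * x + params['b'] - y) ** 2


grad_fx = toy_grad(loss, params, input_argnums=(0,))
gx, gw, gb = grad_fx(3.0, 4.0)
print(round(gx, 3), round(gw, 3), round(gb, 3))  # 12.0 18.0 6.0
```

With residual r = w*x + b - y = 3, the analytic values are 2rw = 12 for x, 2rx = 18 for w and 2r = 6 for b, listed input first and variables after.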

class objax.GradValues(f, variables, input_argnums=None)

The GradValues module is used to compute the gradients of a function and the values it returns.

Usage example:

import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu, objax.nn.Linear(3, 2)])

@objax.Function.with_vars(m.vars())
def f(x, y):
    return ((m(x) - y) ** 2).mean()

# Create module to compute gradients of f for m.vars()
grad_val_f = objax.GradValues(f, m.vars())

# Create module to compute gradients of f for only some variables
grad_val_f_head = objax.GradValues(f, m[:1].vars())

# Create module to compute gradients of f for input 0 (x) and m.vars()
grad_val_fx = objax.GradValues(f, m.vars(), input_argnums=(0,))

For more information and examples, see Understanding Gradients.

__init__(f, variables, input_argnums=None)

Constructs an instance to compute the gradient of f w.r.t. variables.

Parameters
  • f (Union[objax.module.Module, Callable]) – the function for which to compute gradients.

  • variables (Optional[objax.variable.VarCollection]) – the variables for which to compute gradients.

  • input_argnums (Optional[Tuple[int, ...]]) – input indexes, if any, on which to compute gradients.

__call__(*args, **kwargs)

Returns the computed gradients for the first value returned by f and the values returned by f.

Returns

A tuple (gradients, values of f), where gradients is a list containing the input gradients, if any, followed by the variable gradients.

vars(scope='')

Return the VarCollection of the variables used.

Parameters

scope (str) – string to prefix to the variable names.

Return type

objax.variable.VarCollection

class objax.Jit(f, vc=None, static_argnums=None)

JIT (Just-In-Time) module takes a function or a module and compiles it for faster execution.

Usage example:

import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])
jit_m = objax.Jit(m)                          # Jit a module

# For jitting functions, use objax.Function.with_vars
@objax.Function.with_vars(m.vars())
def f(x):
    return m(x) + 1

jit_f = objax.Jit(f)

For more information, refer to JIT Compilation. Also note that one can pass the variables to be used by Jit for a module m: the remaining variables are optimized away as constants. For more information, refer to Constant optimization.

__init__(f, vc=None, static_argnums=None)

Jit constructor.

Parameters
  • f (Union[objax.module.Module, Callable]) – the function or the module to compile.

  • vc (Optional[objax.variable.VarCollection]) – the VarCollection of variables used by the function or module. This argument is required for functions.

  • static_argnums (Optional[Tuple[int, ...]]) – tuple of indexes of f’s input arguments to treat as static (constants). A new graph is compiled for each different combination of values for such inputs.
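The "one graph per combination of static values" rule behaves like a cache keyed on those values. The sketch below is a hypothetical pure-Python model (toy_jit and scale are illustrative names): it does not trace or compile anything, it only counts how many distinct versions would be built.

```python
def toy_jit(f, static_argnums=()):
    """Hypothetical sketch: keep one 'compiled' entry per combination of static values."""
    cache = {}

    def call(*args):
        # Static arguments form the cache key, so they must be hashable.
        key = tuple(args[i] for i in static_argnums)
        if key not in cache:
            # Stand-in for tracing and compiling f with these static values.
            cache[key] = f
        return cache[key](*args)

    call.cache = cache  # exposed only to inspect how many versions were built
    return call


def scale(x, factor):
    return [v * factor for v in x]


jit_scale = toy_jit(scale, static_argnums=(1,))
jit_scale([1, 2], 3)
jit_scale([5, 6], 3)  # same static value: the cached version is reused
jit_scale([1, 2], 4)  # new static value: a second version is built
print(len(jit_scale.cache))  # 2
```

This is why static arguments that take many distinct values can trigger many recompilations.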

__call__(*args, **kwargs)

Call the compiled version of the function or module.

vars(scope='')

Collect all the variables (and their names) contained in the module and its submodules. Important: variables and modules stored in Python structures such as dict or list are not collected. See ModuleList if you need such a feature.

Parameters

scope (str) – string to prefix to the variable names.

Returns

A VarCollection of all the variables.

Return type

objax.variable.VarCollection

class objax.Parallel(f, vc=None, reduce=<function concatenate>, axis_name='device', static_argnums=None)

Parallel module takes a function or a module and compiles it for running on multiple devices in parallel.

Usage example:

import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])
para_m = objax.Parallel(m)                         # Parallelize a module

# For parallelizing functions, use objax.Function.with_vars
@objax.Function.with_vars(m.vars())
def f(x):
    return m(x) + 1

para_f = objax.Parallel(f)

When calling a parallelized module, one must replicate the variables it uses on all devices:

x = objax.random.normal((16, 2))
with m.vars().replicate():
    y = para_m(x)

For more information, refer to Parallelism. Also note that one can pass the variables to be used by Parallel for a module m: the remaining variables are optimized away as constants. For more information, refer to Constant optimization.

__init__(f, vc=None, reduce=<function concatenate>, axis_name='device', static_argnums=None)

Parallel constructor.

Parameters
  • f (Union[objax.module.Module, Callable]) – the function or the module to compile for parallelism.

  • vc (Optional[objax.variable.VarCollection]) – the VarCollection of variables used by the function or module. This argument is required for functions.

  • reduce (Callable[[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]) – the function used to reduce the outputs from many devices to a single device value.

  • axis_name (str) – what name to give to the device dimension, used in conjunction with objax.functional.parallel.

  • static_argnums (Optional[Tuple[int, ...]]) – tuple of indexes of f’s input arguments to treat as static (constants). A new graph is compiled for each different combination of values for such inputs.

__call__(*args)

Call the compiled function or module on multiple devices in parallel. Important: make sure you call this function within the scope of a VarCollection.replicate() statement.

vars(scope='')

Collect all the variables (and their names) contained in the module and its submodules. Important: variables and modules stored in Python structures such as dict or list are not collected. See ModuleList if you need such a feature.

Parameters

scope (str) – string to prefix to the variable names.

Returns

A VarCollection of all the variables.

Return type

objax.variable.VarCollection

class objax.Vectorize(f, vc=None, batch_axis=(0,))

Vectorize module takes a function or a module and compiles it for running in parallel on a single device.

Usage example:

import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])
vec_m = objax.Vectorize(m)                         # Vectorize a module

# For vectorizing functions, use objax.Function.with_vars
@objax.Function.with_vars(m.vars())
def f(x):
    return m(x) + 1

vec_f = objax.Vectorize(f)

For more information and examples, refer to Vectorization.

__init__(f, vc=None, batch_axis=(0,))

Vectorize constructor.

Parameters
  • f (Union[objax.module.Module, Callable]) – the function or the module to compile for vectorization.

  • vc (Optional[objax.variable.VarCollection]) – the VarCollection of variables used by the function or module. This argument is required for functions.

  • batch_axis (Tuple[Optional[int], ...]) – tuple of int or None for each of f’s input arguments: the axis to use as batch during vectorization. Use None to automatically broadcast.
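The meaning of a per-argument batch_axis can be sketched in pure Python. This is a hypothetical model of the semantics only (toy_vectorize and affine are illustrative names, and only axis 0 or None is supported): an argument with axis 0 is split into its entries, while an argument with None is passed unchanged to every call.

```python
def toy_vectorize(f, batch_axis=(0,)):
    """Hypothetical sketch: map f over a batch, one call per batch entry."""
    assert all(ax in (0, None) for ax in batch_axis), 'sketch only supports 0 or None'

    def call(*args):
        assert len(args) == len(batch_axis), 'one batch_axis entry per argument'
        sizes = {len(a) for a, ax in zip(args, batch_axis) if ax is not None}
        assert len(sizes) == 1, 'batched arguments must share one batch size'
        n = sizes.pop()
        # Batched arguments contribute entry k; broadcast (None) arguments are passed whole.
        return [f(*(a[k] if ax is not None else a for a, ax in zip(args, batch_axis)))
                for k in range(n)]

    return call


def affine(x, w):
    return w * x + 1


vec_affine = toy_vectorize(affine, batch_axis=(0, None))
print(vec_affine([1, 2, 3], 10))  # [11, 21, 31]
```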

vars(scope='')

Collect all the variables (and their names) contained in the module and its submodules. Important: variables and modules stored in Python structures such as dict or list are not collected. See ModuleList if you need such a feature.

Parameters

scope (str) – string to prefix to the variable names.

Returns

A VarCollection of all the variables.

Return type

objax.variable.VarCollection

__call__(*args)

Call the vectorized version of the function or module.

Variables

BaseVar(reduce)

The abstract base class to represent objax variables.

TrainVar(tensor[, reduce])

A trainable variable.

BaseState(reduce)

The abstract base class used to represent objax state variables.

StateVar(tensor[, reduce])

StateVars are variables that get updated manually, and are not automatically updated by optimizers.

TrainRef(ref)

A state variable that references a trainable variable for assignment.

RandomState(seed)

RandomStates are variables that track the random generator state.

VarCollection

A VarCollection is a dictionary (name, var) with some additional methods to make manipulation of collections of variables easy.

class objax.BaseVar(reduce)

The abstract base class to represent objax variables.

__init__(reduce)

Constructor for BaseVar class.

Parameters

reduce (Optional[Callable[[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]]) – a function that takes an array of shape (n, *dims) and returns one of shape (*dims). Used to combine the multiple states produced in an objax.Vectorize or an objax.Parallel call.
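The shape contract of reduce, (n, *dims) in and (*dims) out, is easiest to see on a small case. The sketch below assumes a mean over the leading axis, which is what the reduce_mean default in the TrainVar and StateVar signatures suggests; reduce_mean_axis0 is a hypothetical pure-Python helper limited to 2-D input, not the objax function.

```python
def reduce_mean_axis0(stacked):
    """Average n equal-length rows (shape (n, d)) into a single row (shape (d,))."""
    n = len(stacked)
    return [sum(column) / n for column in zip(*stacked)]


# e.g. one copy of a 3-element state per device (or per vectorized batch entry)
per_device = [[1.0, 2.0, 3.0],
              [3.0, 4.0, 5.0]]
print(reduce_mean_axis0(per_device))  # [2.0, 3.0, 4.0]
```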

class objax.TrainVar(tensor, reduce=<function reduce_mean>)

A trainable variable.

__init__(tensor, reduce=<function reduce_mean>)

TrainVar constructor.

Parameters
  • tensor (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – the initial value of the TrainVar.

  • reduce (Optional[Callable[[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]]) – a function that takes an array of shape (n, *dims) and returns one of shape (*dims). Used to combine the multiple states produced in an objax.Vectorize or an objax.Parallel call.

property value

The value is read-only as a safety measure to avoid accidentally making the TrainVar non-differentiable. You can write a value to a TrainVar by using assign.

assign(tensor, check=True)

Sets the value of the variable.

Parameters

tensor (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) –

property dtype

Variable data type.

property ndim

Number of dimensions.

reduce(tensors)

Method called by Parallel and Vectorize to reduce a multiple-device (or batched, in the case of vectorization) value to a single-device value.

Parameters

tensors (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) –

property shape

Variable shape.

class objax.BaseState(reduce)

The abstract base class used to represent objax state variables. State variables are not trainable.

__init__(reduce)

Constructor for BaseVar class.

Parameters

reduce (Optional[Callable[[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]]) – a function that takes an array of shape (n, *dims) and returns one of shape (*dims). Used to combine the multiple states produced in an objax.Vectorize or an objax.Parallel call.

class objax.TrainRef(ref)

A state variable that references a trainable variable for assignment.

TrainRefs are used by optimizers to keep references to trainable variables. This is necessary to distinguish them from the optimizer's own training variables, if any.

__init__(ref)

TrainRef constructor.

Parameters

ref (objax.variable.TrainVar) – the TrainVar to keep the reference of.

property value

The value stored in the referenced TrainVar; it can be read or written.
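The split between a read-only TrainVar.value and a writable TrainRef.value can be modeled with Python properties. This is a hypothetical sketch (ToyTrainVar and ToyTrainRef are illustrative names, not objax classes) showing how an optimizer-style update can go through the reference while direct writes to the variable's value are refused.

```python
class ToyTrainVar:
    """Hypothetical stand-in for TrainVar: `value` is read-only, use assign()."""

    def __init__(self, tensor):
        self._value = tensor

    @property
    def value(self):
        return self._value

    def assign(self, tensor):
        self._value = tensor


class ToyTrainRef:
    """Hypothetical stand-in for TrainRef: `value` reads and writes the referenced var."""

    def __init__(self, ref):
        self.ref = ref

    @property
    def value(self):
        return self.ref.value

    @value.setter
    def value(self, tensor):
        self.ref.assign(tensor)


w = ToyTrainVar(1.0)
try:
    w.value = 2.0  # property without a setter: raises AttributeError
except AttributeError:
    print('TrainVar.value is read-only')

ref = ToyTrainRef(w)
ref.value = ref.value - 0.5  # optimizer-style update through the reference
print(w.value)  # 0.5
```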

assign(tensor, check=True)

Sets the value of the variable.

Parameters

tensor (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) –

property dtype

Variable data type.

property ndim

Number of dimensions.

reduce(tensors)

Method called by Parallel and Vectorize to reduce a multiple-device (or batched, in the case of vectorization) value to a single-device value.

Parameters

tensors (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) –

property shape

Variable shape.

class objax.StateVar(tensor, reduce=<function reduce_mean>)

StateVars are variables that get updated manually, and are not automatically updated by optimizers. For example, the mean and variance statistics in BatchNorm are StateVars.

__init__(tensor, reduce=<function reduce_mean>)

StateVar constructor.

Parameters
  • tensor (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – the initial value of the StateVar.

  • reduce (Optional[Callable[[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]]) – a function that takes an array of shape (n, *dims) and returns one of shape (*dims). Used to combine the multiple states produced in an objax.Vectorize or an objax.Parallel call.

property value

The value stored in the StateVar; it can be read or written.

assign(tensor, check=True)

Sets the value of the variable.

Parameters

tensor (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) –

property dtype

Variable data type.

property ndim

Number of dimensions.

reduce(tensors)

Method called by Parallel and Vectorize to reduce a multiple-device (or batched, in the case of vectorization) value to a single-device value.

Parameters

tensors (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) –

property shape

Variable shape.

class objax.RandomState(seed)

RandomStates are variables that track the random generator state. They are meant to be used internally. Currently only the random.Generator module uses them.

__init__(seed)

RandomState constructor.

Parameters

seed (int) – the initial seed of the random number generator.

seed(seed)

Sets a new random seed.

Parameters

seed (int) – the new initial seed of the random number generator.

split(n)

Create multiple seeds from the current seed. This is used internally by Parallel and Vectorize to ensure that random numbers are different in parallel threads.
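The purpose of split, deriving several distinct but reproducible seeds from one, can be sketched with the standard library. The real RandomState works on JAX arrays (see the return type below); split_seed is a hypothetical helper that substitutes SHA-256 hashing purely to illustrate the idea and is not the algorithm objax uses.

```python
import hashlib


def split_seed(seed, n):
    """Hypothetical sketch: derive n reproducible child seeds from one parent seed."""
    children = []
    for i in range(n):
        digest = hashlib.sha256(f'{seed}:{i}'.encode()).digest()
        children.append(int.from_bytes(digest[:8], 'little'))
    return children


seeds = split_seed(42, 4)  # e.g. one seed per device or per batch entry
print(len(set(seeds)))  # 4
```

Each child seed would then drive its own generator, so parallel threads draw different random numbers while the whole run stays reproducible from the parent seed.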

Parameters

n (int) – the number of seeds to generate.

Return type

List[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]

assign(tensor, check=True)

Sets the value of the variable.

Parameters

tensor (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) –

property dtype

Variable data type.

property ndim

Number of dimensions.

reduce(tensors)

Method called by Parallel and Vectorize to reduce a multiple-device (or batched, in the case of vectorization) value to a single-device value.

Parameters

tensors (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) –

property shape

Variable shape.

property value

The value stored in the StateVar; it can be read or written.

class objax.VarCollection

A VarCollection is a dictionary (name, var) with some additional methods to make manipulation of collections of variables easy. A VarCollection is ordered by insertion order. It is the object returned by Module.vars() and used as input by many modules: optimizers, Jit, etc…

Usage example:

import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])
vc = m.vars()  # This is a VarCollection

# It is a dictionary
print(repr(vc))
# {'(Sequential)[0](Linear).b': <objax.variable.TrainVar object at 0x7faecb506390>,
#  '(Sequential)[0](Linear).w': <objax.variable.TrainVar object at 0x7faec81ee350>}
print(vc.keys())  # dict_keys(['(Sequential)[0](Linear).b', '(Sequential)[0](Linear).w'])
assert (vc['(Sequential)[0](Linear).w'].value == m[0].w.value).all()

# Convenience print
print(vc)
# (Sequential)[0](Linear).b        3 (3,)
# (Sequential)[0](Linear).w        6 (2, 3)
# +Total(2)                        9

# Extra methods for manipulation of variables:
# For example, increment all variables by 1
vc.assign([x+1 for x in vc.tensors()])

# It's used by other modules.
# For example it's used to tell what variables are used by a function.

@objax.Function.with_vars(vc)
def my_function(x):
    return objax.functional.softmax(m(x))

For more information and examples, refer to VarCollection.

assign(tensors)

Assign tensors to the variables in the VarCollection. Each variable is assigned only once and in the order following the iter(self) iterator.

Parameters

tensors (List[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]) – the list of tensors used to update variables values.
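One plausible reason for the "only once" rule is that a single variable can appear under several names, for example when weights are tied. Under that assumption, deduplicating by identity keeps tensors() and assign() aligned. The sketch below is a hypothetical pure-Python model on a plain dict; ToyVar, unique_vars, tensors and assign are illustrative names, not the VarCollection implementation.

```python
class ToyVar:
    """Hypothetical stand-in for an objax variable."""

    def __init__(self, value):
        self.value = value


def unique_vars(collection):
    """Variables of a name -> var dict, each listed once, in iteration order."""
    seen, result = set(), []
    for var in collection.values():
        if id(var) not in seen:
            seen.add(id(var))
            result.append(var)
    return result


def tensors(collection):
    return [var.value for var in unique_vars(collection)]


def assign(collection, new_values):
    targets = unique_vars(collection)
    assert len(targets) == len(new_values), 'one tensor per distinct variable'
    for var, value in zip(targets, new_values):
        var.value = value


shared = ToyVar(1.0)
vc = {'encoder.w': shared, 'decoder.w': shared, 'decoder.b': ToyVar(0.0)}  # tied weight
print(tensors(vc))                    # [1.0, 0.0]
assign(vc, [x + 1 for x in tensors(vc)])
print(vc['decoder.w'].value)          # 2.0
```

Without deduplication the shared weight would be incremented twice, which mirrors the vc.assign([x+1 for x in vc.tensors()]) pattern in the VarCollection example above.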

rename(renamer)

Rename the entries in the VarCollection.

Renaming entries in a VarCollection is a powerful tool that can be used for

  • mapping weights between models that differ slightly.

  • loading data checkpoints from foreign ML frameworks.

Usage example:

import re
import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])
print(m.vars())
# (Sequential)[0](Linear).b        3 (3,)
# (Sequential)[0](Linear).w        6 (2, 3)
# +Total(2)                        9

# For example remove modules from the name
renamer = objax.util.Renamer([(re.compile(r'\([^)]+\)'), '')])  # raw string avoids invalid-escape warnings
print(m.vars().rename(renamer))
# [0].b                       3 (3,)
# [0].w                       6 (2, 3)
# +Total(2)                   9

# One can chain renamers, their syntax is flexible and it can use a string mapping:
renamer_all = objax.util.Renamer({'[': '.', ']': ''}, renamer)
print(m.vars().rename(renamer_all))
# .0.b                        3 (3,)
# .0.w                        6 (2, 3)
# +Total(2)                   9
Parameters

renamer (objax.util.util.Renamer) – the renamer applied to each variable name.

replicate()

A context manager to use in a with statement that replicates the variables in this collection to multiple devices. This is typically used prior to calling objax.Parallel, so that all variables have a copy on each device. Important: replicating also updates the random state in order to have a new one per device.

subset(is_a=None, is_not=None)

Return a new VarCollection that is a filtered subset of the current collection.

Parameters
  • is_a (Optional[Union[type, Tuple[type, ...]]]) – either a variable type or a tuple of variable types to include.

  • is_not (Optional[Union[type, Tuple[type, ...]]]) – either a variable type or a tuple of variable types to exclude.
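A minimal sketch of how the two filters can combine, assuming a variable is kept when it matches is_a (if given) and does not match is_not (if given). The toy class hierarchy and the subset helper are hypothetical; isinstance is used because it accepts either a single type or a tuple of types, matching the parameter types above.

```python
class ToyTrainVar:
    pass


class ToyStateVar:
    pass


class ToyRandomState(ToyStateVar):  # subclass relation assumed for illustration
    pass


def subset(collection, is_a=None, is_not=None):
    """Keep entries whose variable matches is_a (if given) and not is_not (if given)."""
    return {name: var for name, var in collection.items()
            if (is_a is None or isinstance(var, is_a))
            and (is_not is None or not isinstance(var, is_not))}


vc = {'w': ToyTrainVar(), 'bn.mean': ToyStateVar(), 'rng': ToyRandomState()}
print(list(subset(vc, is_a=ToyStateVar)))                         # ['bn.mean', 'rng']
print(list(subset(vc, is_a=ToyStateVar, is_not=ToyRandomState)))  # ['bn.mean']
print(list(subset(vc, is_not=(ToyStateVar,))))                    # ['w']
```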

Returns

A new VarCollection containing the subset of variables.

Return type

objax.variable.VarCollection

tensors(is_a=None)

Return the list of values for this collection. Similarly to the assign method, each variable value is reported only once and in the order following the iter(self) iterator.

Parameters

is_a (Optional[Union[type, Tuple[type, ...]]]) – either a variable type or a tuple of variable types to include.

Returns

A list containing the values of the selected variables.

Return type

List[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]

update(other)

Overload dict.update method to catch potential conflicts during assignment.

Parameters

other (Union[objax.variable.VarCollection, Iterable[Tuple[str, objax.variable.BaseVar]]]) – the VarCollection or iterable of (name, variable) pairs to add.
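Why overload update at all: in CPython, dict.update on a subclass does not go through an overridden __setitem__, so a conflict check placed there would be bypassed. The sketch below assumes that a "conflict" means rebinding an existing name to a different variable. ToyVarCollection is a hypothetical name, and the exact error objax raises is not taken from the source.

```python
class ToyVarCollection(dict):
    """Hypothetical sketch: a dict that refuses to rebind a name to a different variable."""

    def __setitem__(self, key, value):
        if key in self and self[key] is not value:
            raise ValueError(f'Name conflict: {key!r} already maps to another variable')
        super().__setitem__(key, value)

    def update(self, other):
        # dict.update bypasses __setitem__, so route every pair through it explicitly.
        items = other.items() if isinstance(other, dict) else other
        for key, value in items:
            self[key] = value


a, b = object(), object()
vc = ToyVarCollection()
vc.update({'w': a})
vc.update([('w', a)])  # same variable under the same name: accepted
try:
    vc.update({'w': b})  # different variable under the same name: rejected
    conflict = False
except ValueError as err:
    conflict = True
    print(err)  # Name conflict: 'w' already maps to another variable
```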

__init__(*args, **kwargs)

Initialize self. See help(type(self)) for accurate signature.

Constants

class objax.ConvPadding(value)

An Enum holding the possible padding values for convolution modules.

SAME = 'SAME'
VALID = 'VALID'