objax package

Modules

Module()

A module is a container to associate variables and functions.

ModuleList([iterable])

This is a replacement for Python’s list that provides a vars() method to return all the variables that it contains, including the ones contained in the modules and sub-modules in it.

ForceArgs(module, **kwargs)

Forces override of arguments of given module.

Function(f, vc)

Turn a function into a Module by keeping the vars it uses.

Grad(f, variables[, input_argnums])

The Grad module is used to compute the gradients of a function.

GradValues(f, variables[, input_argnums])

The GradValues module is used to compute the gradients of a function.

Jit(f[, vc, static_argnums])

JIT (Just-In-Time) module takes a function or a module and compiles it for faster execution.

Parallel(f[, vc, reduce, axis_name, …])

Parallel module takes a function or a module and compiles it for running on multiple devices in parallel.

Vectorize(f[, vc, batch_axis])

Vectorize module takes a function or a module and compiles it for running in parallel on a single device.

class objax.Module[source]

A module is a container to associate variables and functions.

vars(scope='')[source]

Collect all the variables (and their names) contained in the module and its submodules. Important: Variables and modules stored in Python structures such as dict or list are not collected. See ModuleList if you need such a feature.

Parameters

scope (str) – string to prefix to the variable names.

Returns

A VarCollection of all the variables.

Return type

objax.variable.VarCollection

__call__(*args, **kwargs)[source]

Optional module __call__ method, typically a forward pass computation for standard primitives.
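
As an illustration (a minimal sketch, not part of the original documentation; ScaleAndShift is a hypothetical name), a custom Module stores its variables as attributes and implements __call__ for the forward pass:

import objax
import jax.numpy as jn

class ScaleAndShift(objax.Module):
    def __init__(self, dim):
        self.scale = objax.TrainVar(jn.ones(dim))
        self.shift = objax.TrainVar(jn.zeros(dim))

    def __call__(self, x):
        return x * self.scale.value + self.shift.value

m = ScaleAndShift(3)
print(m.vars())  # Lists (ScaleAndShift).scale and (ScaleAndShift).shift.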

class objax.ModuleList(iterable=(), /)[source]

Bases: objax.module.Module, list

This is a replacement for Python’s list that provides a vars() method to return all the variables that it contains, including the ones contained in the modules and sub-modules in it.

Usage example:

import objax

ml = objax.ModuleList(['hello', objax.TrainVar(objax.random.normal((10,2)))])
print(ml.vars())
# (ModuleList)[1]            20 (10, 2)
# +Total(1)                  20

ml.pop()
ml.append(objax.nn.Linear(2, 3))
print(ml.vars())
# (ModuleList)[1](Linear).b        3 (3,)
# (ModuleList)[1](Linear).w        6 (2, 3)
# +Total(2)                        9

vars(scope='')[source]

Collect all the variables (and their names) contained in the list and its submodules.

Parameters

scope (str) – string to prefix to the variable names.

Returns

A VarCollection of all the variables.

Return type

objax.variable.VarCollection

class objax.Function(f, vc)[source]

Turn a function into a Module by keeping the vars it uses.

Usage example:

import objax
import jax.numpy as jn

m = objax.nn.Linear(2, 3)

def f1(x, y):
    return ((m(x) - y) ** 2).mean()

# Method 1: Create module by calling objax.Function to tell which variables are used.
m1 = objax.Function(f1, m.vars())

# Method 2: Use the decorator.
@objax.Function.with_vars(m.vars())
def f2(x, y):
    return ((m(x) - y) ** 2).mean()

# All behave like functions
x = jn.arange(10).reshape((5, 2))
y = jn.arange(15).reshape((5, 3))
print(type(f1), f1(x, y))  # <class 'function'> 237.01947
print(type(m1), m1(x, y))  # <class 'objax.module.Function'> 237.01947
print(type(f2), f2(x, y))  # <class 'objax.module.Function'> 237.01947

Usage of Function is not necessary: it is made available for aesthetic reasons (to accommodate users' personal taste). It is also used internally to keep the code simple for Grad, Jit, Parallel, Vectorize and future primitives.

__init__(f, vc)[source]

Function constructor.

Parameters
  • f (Callable) – the function or the module to represent.

  • vc (objax.variable.VarCollection) – the VarCollection of variables used by the function.

__call__(*args, **kwargs)[source]

Call the function.

vars(scope='')[source]

Return the VarCollection of the variables used by the function.

Parameters

scope (str) –

Return type

objax.variable.VarCollection

class objax.ForceArgs(module, **kwargs)[source]

Forces override of arguments of given module.

One example of ForceArgs usage is to override training argument for batch normalization:

import objax
from objax.zoo.resnet_v2 import ResNet50

model = ResNet50(in_channels=3, num_classes=1000)

# Modify model to force training=False on first resnet block.
# First two ops in the resnet are convolution and padding,
# resnet blocks are starting at index 2.
model[2] = objax.ForceArgs(model[2], training=False)

# model(x, training=True) will be using `training=False` on model[2] due to ForceArgs
# ...

# Undo specific value of forced arguments in `model` and all submodules of `model`
objax.ForceArgs.undo(model, training=True)

# Undo all values of specific argument in `model` and all submodules of `model`
objax.ForceArgs.undo(model, training=objax.ForceArgs.ANY)

# Undo all values of all arguments in `model` and all submodules of `model`
objax.ForceArgs.undo(model)

class ANY

Token used in ForceArgs.undo to indicate undo of all values of specific argument.

static undo(module, **kwargs)[source]

Undo ForceArgs on each submodule of the module. Modifications are done in-place.

Parameters
  • module (objax.module.Module) – module for which to undo ForceArgs.

  • **kwargs – dictionary of argument overrides to undo. name=val removes the override for value val of argument name. name=ForceArgs.ANY removes all overrides of argument name. If **kwargs is empty then all overrides are undone.

__init__(module, **kwargs)[source]

Initializes ForceArgs by wrapping another module.

Parameters
  • module (objax.module.Module) – module whose arguments will be overridden.

  • kwargs – the keyword argument values that will be forced on the wrapped module.

vars(scope='')[source]

Returns the VarCollection of the wrapped module.

Parameters

scope (str) – string to prefix to the variable names.

Returns

A VarCollection of all the variables of wrapped module.

Return type

objax.variable.VarCollection

__call__(*args, **kwargs)[source]

Calls wrapped module using forced args to override wrapped module arguments.

class objax.Grad(f, variables, input_argnums=None)[source]

The Grad module is used to compute the gradients of a function.

Usage example:

import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])

def f(x, y):
    return ((m(x) - y) ** 2).mean()

# Create module to compute gradients of f for m.vars()
grad_f = objax.Grad(f, m.vars())

# Create module to compute gradients of f for input 0 (x) and m.vars()
grad_fx = objax.Grad(f, m.vars(), input_argnums=(0,))

For more information and examples, see Understanding Gradients.

__init__(f, variables, input_argnums=None)[source]

Constructs an instance to compute the gradient of f w.r.t. variables.

Parameters
  • f (Callable) – the function for which to compute gradients.

  • variables (Optional[objax.variable.VarCollection]) – the variables for which to compute gradients.

  • input_argnums (Optional[Tuple[int, ..]]) – input indexes, if any, on which to compute gradients.

__call__(*args, **kwargs)[source]

Returns the computed gradients for the first value returned by f.

Returns

A list of input gradients, if any, followed by the variable gradients.
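
Continuing the example above, a sketch of consuming the returned gradients (inputs and shapes are illustrative only):

import jax.numpy as jn

x = jn.zeros((4, 2))
y = jn.zeros((4, 3))

grads = grad_f(x, y)            # List of gradients, one per variable in m.vars().
gx, *var_grads = grad_fx(x, y)  # Gradient w.r.t. input 0 (x) first, then variable gradients.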

class objax.GradValues(f, variables, input_argnums=None)[source]

The GradValues module is used to compute the gradients of a function.

Usage example:

import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu, objax.nn.Linear(3, 2)])

@objax.Function.with_vars(m.vars())
def f(x, y):
    return ((m(x) - y) ** 2).mean()

# Create module to compute gradients of f for m.vars()
grad_val_f = objax.GradValues(f, m.vars())

# Create module to compute gradients of f for only some variables
grad_val_f_head = objax.GradValues(f, m[:1].vars())

# Create module to compute gradients of f for input 0 (x) and m.vars()
grad_val_fx = objax.GradValues(f, m.vars(), input_argnums=(0,))

For more information and examples, see Understanding Gradients.

__init__(f, variables, input_argnums=None)[source]

Constructs an instance to compute the gradient of f w.r.t. variables.

Parameters
  • f (Union[objax.module.Module, Callable]) – the function for which to compute gradients.

  • variables (Optional[objax.variable.VarCollection]) – the variables for which to compute gradients.

  • input_argnums (Optional[Tuple[int, ..]]) – input indexes, if any, on which to compute gradients.

__call__(*args, **kwargs)[source]

Returns the computed gradients for the first value returned by f and the values returned by f.

Returns

A tuple (gradients, values of f), where gradients is a list containing the input gradients, if any, followed by the variable gradients.
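
Continuing the example above, a sketch of a typical call (inputs and shapes are illustrative only):

import jax.numpy as jn

x = jn.zeros((4, 2))
y = jn.zeros((4, 2))

grads, values = grad_val_f(x, y)
# grads is the list of gradients for m.vars(); values holds the value(s) returned by f.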

vars(scope='')[source]

Return the VarCollection of the variables used.

Parameters

scope (str) –

Return type

objax.variable.VarCollection

class objax.Jit(f, vc=None, static_argnums=None)[source]

JIT (Just-In-Time) module takes a function or a module and compiles it for faster execution.

Usage example:

import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])
jit_m = objax.Jit(m)                          # Jit a module

# For jitting functions, use objax.Function.with_vars
@objax.Function.with_vars(m.vars())
def f(x):
    return m(x) + 1

jit_f = objax.Jit(f)

For more information, refer to JIT Compilation. Also note that, for a module m, one can pass only the variables Jit should use; the rest will be optimized away as constants. For more information, refer to Constant optimization.
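
As a sketch of that constant optimization (for this particular m all variables are trainable, so the subset below happens to equal the full collection):

# Jit only the trainable variables; any variable not passed is treated as a constant
# by the compiled graph.
jit_m_train_only = objax.Jit(m, m.vars().subset(objax.TrainVar))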

__init__(f, vc=None, static_argnums=None)[source]

Jit constructor.

Parameters
  • f (Union[objax.module.Module, Callable]) – the function or the module to compile.

  • vc (Optional[objax.variable.VarCollection]) – the VarCollection of variables used by the function or module. This argument is required for functions.

  • static_argnums (Optional[Tuple[int, ..]]) – tuple of indexes of f’s input arguments to treat as static (constants). A new graph is compiled for each different combination of values for such inputs.

__call__(*args, **kwargs)[source]

Call the compiled version of the function or module.

vars(scope='')

Collect all the variables (and their names) contained in the module and its submodules. Important: Variables and modules stored in Python structures such as dict or list are not collected. See ModuleList if you need such a feature.

Parameters

scope (str) – string to prefix to the variable names.

Returns

A VarCollection of all the variables.

Return type

objax.variable.VarCollection

class objax.Parallel(f, vc=None, reduce=<function concatenate>, axis_name='device', static_argnums=None)[source]

Parallel module takes a function or a module and compiles it for running on multiple devices in parallel.

Usage example:

import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])
para_m = objax.Parallel(m)                         # Parallelize a module

# For parallelizing functions, use objax.Function.with_vars
@objax.Function.with_vars(m.vars())
def f(x):
    return m(x) + 1

para_f = objax.Parallel(f)

When calling a parallelized module, one must replicate the variables it uses on all devices:

x = objax.random.normal((16, 2))
with m.vars().replicate():
    y = para_m(x)

For more information, refer to Parallelism. Also note that, for a module m, one can pass only the variables Parallel should use; the rest will be optimized away as constants. For more information, refer to Constant optimization.

__init__(f, vc=None, reduce=<function concatenate>, axis_name='device', static_argnums=None)[source]

Parallel constructor.

Parameters
  • f (Union[objax.module.Module, Callable]) – the function or the module to compile for parallelism.

  • vc (Optional[objax.variable.VarCollection]) – the VarCollection of variables used by the function or module. This argument is required for functions.

  • reduce (Callable[[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]) – the function used to reduce the outputs from multiple devices to a single device value.

  • axis_name (str) – what name to give to the device dimension, used in conjunction with objax.functional.parallel.

  • static_argnums (Optional[Tuple[int, ..]]) – tuple of indexes of f’s input arguments to treat as static (constants). A new graph is compiled for each different combination of values for such inputs.

__call__(*args)[source]

Call the compiled function or module on multiple devices in parallel. Important: Make sure you call this function within the scope of a VarCollection.replicate() statement.

vars(scope='')

Collect all the variables (and their names) contained in the module and its submodules. Important: Variables and modules stored in Python structures such as dict or list are not collected. See ModuleList if you need such a feature.

Parameters

scope (str) – string to prefix to the variable names.

Returns

A VarCollection of all the variables.

Return type

objax.variable.VarCollection

class objax.Vectorize(f, vc=None, batch_axis=(0))[source]

Vectorize module takes a function or a module and compiles it for running in parallel on a single device.

Usage example:

import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])
vec_m = objax.Vectorize(m)                         # Vectorize a module

# For vectorizing functions, use objax.Function.with_vars
@objax.Function.with_vars(m.vars())
def f(x):
    return m(x) + 1

vec_f = objax.Vectorize(f)

For more information and examples, refer to Vectorization.
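
A sketch (not from the original documentation) of the batch_axis argument described below: only the first argument is batched, while the second is broadcast to every element of the batch.

import jax.numpy as jn

@objax.Function.with_vars(m.vars())
def scaled(x, scale):
    return m(x) * scale

vec_scaled = objax.Vectorize(scaled, batch_axis=(0, None))
y = vec_scaled(objax.random.normal((16, 2)), jn.array(2.0))  # y.shape == (16, 3)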

__init__(f, vc=None, batch_axis=(0))[source]

Vectorize constructor.

Parameters
  • f (Union[objax.module.Module, Callable]) – the function or the module to compile for vectorization.

  • vc (Optional[objax.variable.VarCollection]) – the VarCollection of variables used by the function or module. This argument is required for functions.

  • batch_axis (Tuple[Optional[int], ..]) – tuple of int or None for each of f’s input arguments: the axis to use as batch during vectorization. Use None to automatically broadcast.

vars(scope='')

Collect all the variables (and their names) contained in the module and its submodules. Important: Variables and modules stored in Python structures such as dict or list are not collected. See ModuleList if you need such a feature.

Parameters

scope (str) – string to prefix to the variable names.

Returns

A VarCollection of all the variables.

Return type

objax.variable.VarCollection

__call__(*args)[source]

Call the vectorized version of the function or module.

Variables

BaseVar(reduce)

The abstract base class to represent objax variables.

TrainVar(tensor[, reduce])

A trainable variable.

BaseState(reduce)

The abstract base class used to represent objax state variables.

StateVar(tensor[, reduce])

StateVar are variables that get updated manually, and are not automatically updated by optimizers.

TrainRef(ref)

A state variable that references a trainable variable for assignment.

RandomState(seed)

RandomState are variables that track the random generator state.

VarCollection

A VarCollection is a dictionary (name, var) with some additional methods to make manipulation of collections of variables easy.

class objax.BaseVar(reduce)[source]

The abstract base class to represent objax variables.

__init__(reduce)[source]

Constructor for BaseVar class.

Parameters

reduce (Optional[Callable[[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]]) – a function that takes an array of shape (n, *dims) and returns one of shape (*dims). Used to combine the multiple states produced in an objax.Vectorize or an objax.Parallel call.

class objax.TrainVar(tensor, reduce=<function reduce_mean>)[source]

A trainable variable.

__init__(tensor, reduce=<function reduce_mean>)[source]

TrainVar constructor.

Parameters
  • tensor (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – the initial value of the TrainVar.

  • reduce (Optional[Callable[[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]]) – a function that takes an array of shape (n, *dims) and returns one of shape (*dims). Used to combine the multiple states produced in an objax.Vectorize or an objax.Parallel call.

property value

The value is read-only as a safety measure to avoid accidentally making a TrainVar non-differentiable. You can write a value to a TrainVar by using assign.
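
A minimal sketch (not from the original documentation):

import objax
import jax.numpy as jn

v = objax.TrainVar(jn.zeros(3))
print(v.value)          # Reading the value is allowed.
v.assign(v.value + 1.)  # Writing must go through assign.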

assign(tensor, check=True)[source]

Sets the value of the variable.

Parameters

tensor (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) –

reduce(tensors)

Method called by Parallel and Vectorize to reduce a multiple-device (or batched, in the case of vectorization) value to a single device.

Parameters

tensors (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) –

class objax.BaseState(reduce)[source]

The abstract base class used to represent objax state variables. State variables are not trainable.

__init__(reduce)

Constructor for BaseVar class.

Parameters

reduce (Optional[Callable[[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]]) – a function that takes an array of shape (n, *dims) and returns one of shape (*dims). Used to combine the multiple states produced in an objax.Vectorize or an objax.Parallel call.

class objax.TrainRef(ref)[source]

A state variable that references a trainable variable for assignment.

TrainRef is used by optimizers to keep references to trainable variables. This is necessary to differentiate them from the optimizer's own trainable variables, if any.

__init__(ref)[source]

TrainRef constructor.

Parameters

ref (objax.variable.TrainVar) – the TrainVar to keep the reference of.

property value

The value stored in the referenced TrainVar; it can be read or written.

assign(tensor, check=True)

Sets the value of the variable.

Parameters

tensor (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) –

reduce(tensors)

Method called by Parallel and Vectorize to reduce a multiple-device (or batched, in the case of vectorization) value to a single device.

Parameters

tensors (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) –

class objax.StateVar(tensor, reduce=<function reduce_mean>)[source]

StateVar are variables that get updated manually, and are not automatically updated by optimizers. For example, the mean and variance statistics in BatchNorm are StateVar.
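
A sketch of a hypothetical module using a StateVar (illustrative only, not from the original documentation):

import objax
import jax.numpy as jn

class CallCounter(objax.Module):
    def __init__(self):
        self.count = objax.StateVar(jn.zeros(()))  # Updated manually, not by optimizers.

    def __call__(self, x):
        self.count.assign(self.count.value + 1)
        return x

c = CallCounter()
c(jn.ones(4))
print(int(c.count.value))  # 1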

__init__(tensor, reduce=<function reduce_mean>)[source]

StateVar constructor.

Parameters
  • tensor (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – the initial value of the StateVar.

  • reduce (Optional[Callable[[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]]) – a function that takes an array of shape (n, *dims) and returns one of shape (*dims). Used to combine the multiple states produced in an objax.Vectorize or an objax.Parallel call.

property value

The value stored in the StateVar; it can be read or written.

assign(tensor, check=True)

Sets the value of the variable.

Parameters

tensor (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) –

reduce(tensors)

Method called by Parallel and Vectorize to reduce a multiple-device (or batched, in the case of vectorization) value to a single device.

Parameters

tensors (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) –

class objax.RandomState(seed)[source]

RandomState are variables that track the random generator state. They are meant to be used internally. Currently only the random.Generator module uses them.

__init__(seed)[source]

RandomState constructor.

Parameters

seed (int) – the initial seed of the random number generator.

seed(seed)[source]

Sets a new random seed.

Parameters

seed (int) – the new initial seed of the random number generator.

split(n)[source]

Create multiple seeds from the current seed. This is used internally by Parallel and Vectorize to ensure that random numbers are different in parallel threads.

Parameters

n (int) – the number of seeds to generate.

Return type

List[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]

assign(tensor, check=True)

Sets the value of the variable.

Parameters

tensor (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) –

reduce(tensors)

Method called by Parallel and Vectorize to reduce a multiple-device (or batched, in the case of vectorization) value to a single device.

Parameters

tensors (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) –

property value

The value stored in the StateVar; it can be read or written.

class objax.VarCollection[source]

A VarCollection is a dictionary (name, var) with some additional methods to make manipulation of collections of variables easy. A VarCollection is ordered by insertion order. It is the object returned by Module.vars() and used as input by many modules: optimizers, Jit, etc…

Usage example:

import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])
vc = m.vars()  # This is a VarCollection

# It is a dictionary
print(repr(vc))
# {'(Sequential)[0](Linear).b': <objax.variable.TrainVar object at 0x7faecb506390>,
#  '(Sequential)[0](Linear).w': <objax.variable.TrainVar object at 0x7faec81ee350>}
print(vc.keys())  # dict_keys(['(Sequential)[0](Linear).b', '(Sequential)[0](Linear).w'])
assert (vc['(Sequential)[0](Linear).w'].value == m[0].w.value).all()

# Convenience print
print(vc)
# (Sequential)[0](Linear).b        3 (3,)
# (Sequential)[0](Linear).w        6 (2, 3)
# +Total(2)                        9

# Extra methods for manipulation of variables:
# For example, increment all variables by 1
vc.assign([x+1 for x in vc.tensors()])

# It's used by other modules.
# For example it's used to tell what variables are used by a function.

@objax.Function.with_vars(vc)
def my_function(x):
    return objax.functional.softmax(m(x))

For more information and examples, refer to VarCollection.

assign(tensors)[source]

Assign tensors to the variables in the VarCollection. Each variable is assigned only once and in the order following the iter(self) iterator.

Parameters

tensors (List[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]) – the list of tensors used to update variables values.

rename(renamer)[source]

Rename the entries in the VarCollection.

Renaming entries in a VarCollection is a powerful tool that can be used for

  • mapping weights between models that differ slightly.

  • loading data checkpoints from foreign ML frameworks.

Usage example:

import re
import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])
print(m.vars())
# (Sequential)[0](Linear).b        3 (3,)
# (Sequential)[0](Linear).w        6 (2, 3)
# +Total(2)                        9

# For example remove modules from the name
renamer = objax.util.Renamer([(re.compile(r'\([^)]+\)'), '')])
print(m.vars().rename(renamer))
# [0].b                       3 (3,)
# [0].w                       6 (2, 3)
# +Total(2)                   9

# One can chain renamers, their syntax is flexible and it can use a string mapping:
renamer_all = objax.util.Renamer({'[': '.', ']': ''}, renamer)
print(m.vars().rename(renamer_all))
# .0.b                        3 (3,)
# .0.w                        6 (2, 3)
# +Total(2)                   9

Parameters

renamer (objax.util.util.Renamer) –

replicate()[source]

A context manager to use in a with statement that replicates the variables in this collection to multiple devices. This is typically used prior to a call to objax.Parallel, so that all variables have a copy on each device. Important: replicating also updates the random state in order to have a new one per device.

subset(is_a=None, is_not=None)[source]

Return a new VarCollection that is a filtered subset of the current collection.

Parameters
  • is_a (Optional[Union[type, Tuple[type, ..]]]) – either a variable type or a list of variable types to include.

  • is_not (Optional[Union[type, Tuple[type, ..]]]) – either a variable type or a list of variable types to exclude.

Returns

A new VarCollection containing the subset of variables.

Return type

objax.variable.VarCollection
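
For example, continuing the VarCollection example above:

train_vars = vc.subset(is_a=objax.TrainVar)  # Only trainable variables.
no_refs = vc.subset(is_not=objax.TrainRef)   # Everything except TrainRef entries.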

tensors(is_a=None)[source]

Return the list of values for this collection. Similarly to the assign method, each variable value is reported only once and in the order following the iter(self) iterator.

Parameters

is_a (Optional[Union[type, Tuple[type, ..]]]) – either a variable type or a list of variable types to include.

Returns

A list of the tensor values of the variables in the collection (optionally filtered by type).

Return type

List[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]

update(other)[source]

Overload dict.update method to catch potential conflicts during assignment.

Parameters

other (Union[objax.variable.VarCollection, Iterable[Tuple[str, objax.variable.BaseVar]]]) –

__init__(*args, **kwargs)

Initialize self. See help(type(self)) for accurate signature.

Constants

class objax.ConvPadding(value)[source]

An Enum holding the possible padding values for convolution modules.

SAME = 'SAME'
VALID = 'VALID'
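
A sketch of passing the enum to a convolution module (assuming the objax.nn.Conv2D signature with nin, nout and k arguments):

import objax

conv = objax.nn.Conv2D(nin=3, nout=16, k=3, padding=objax.ConvPadding.SAME)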