objax package
Modules
objax.Module | A module is a container to associate variables and functions.
objax.ModuleList | This is a replacement for Python’s list that provides a vars() method to return all the variables that it contains, including the ones contained in the modules and sub-modules in it.
objax.ForceArgs | Forces override of arguments of given module.
objax.Function | Turn a function into a Module by keeping the vars it uses.
objax.Grad | The Grad module is used to compute the gradients of a function.
objax.GradValues | The GradValues module is used to compute the gradients of a function.
objax.Jit | JIT (Just-In-Time) module takes a function or a module and compiles it for faster execution.
objax.Parallel | Parallel module takes a function or a module and compiles it for running on multiple devices in parallel.
objax.Vectorize | Vectorize module takes a function or a module and compiles it for running in parallel on a single device.
-
class objax.Module
A module is a container to associate variables and functions.
-
vars(scope='')
Collect all the variables (and their names) contained in the module and its submodules. Important: variables and modules stored in Python structures such as dict or list are not collected. See ModuleList if you need such a feature.
- Parameters
scope (str) – string to prefix to the variable names.
- Returns
A VarCollection of all the variables.
- Return type
objax.variable.VarCollection
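The collection rule can be pictured with a small library-free sketch. `ToyVar`, `ToyModule`, `Inner` and `Outer` are hypothetical stand-ins, not objax classes, and the naming scheme only approximates the one shown in the usage examples. The sketch illustrates how a scope prefix builds names and why a variable kept inside a plain list is skipped.

```python
class ToyVar:
    """Hypothetical stand-in for a variable; it only holds a value."""
    def __init__(self, value):
        self.value = value


class ToyModule:
    """Hypothetical stand-in for a module that collects its own attributes."""

    def vars(self, scope=''):
        collected = {}
        prefix = scope + f'({type(self).__name__}).'
        for name, attr in self.__dict__.items():
            if isinstance(attr, ToyVar):
                collected[prefix + name] = attr
            elif isinstance(attr, ToyModule):
                # Recurse into sub-modules, extending the scope.
                collected.update(attr.vars(scope=prefix + name))
            # Anything else (list, dict, ...) is deliberately ignored.
        return collected


class Inner(ToyModule):
    def __init__(self):
        self.w = ToyVar(1.0)


class Outer(ToyModule):
    def __init__(self):
        self.b = ToyVar(0.0)
        self.inner = Inner()
        self.hidden = [ToyVar(2.0)]  # stored in a list: not collected


names = sorted(Outer().vars())
print(names)
```

The variable inside `hidden` never appears in `names`, which is the situation ModuleList is meant to address.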
-
-
class objax.ModuleList(iterable=(), /)
Bases: objax.module.Module, list
This is a replacement for Python’s list that provides a vars() method to return all the variables that it contains, including the ones contained in the modules and sub-modules in it.
See also
Usage example:
    import objax

    ml = objax.ModuleList(['hello', objax.TrainVar(objax.random.normal((10,2)))])
    print(ml.vars())
    # (ModuleList)[1]                 20 (10, 2)
    # +Total(1)                       20

    ml.pop()
    ml.append(objax.nn.Linear(2, 3))
    print(ml.vars())
    # (ModuleList)[1](Linear).b        3 (3,)
    # (ModuleList)[1](Linear).w        6 (2, 3)
    # +Total(2)                        9
-
class objax.Function(f, vc)
Turn a function into a Module by keeping the vars it uses.
Usage example:
    import objax
    import jax.numpy as jn

    m = objax.nn.Linear(2, 3)

    def f1(x, y):
        return ((m(x) - y) ** 2).mean()

    # Method 1: Create module by calling objax.Function to tell which variables are used.
    m1 = objax.Function(f1, m.vars())

    # Method 2: Use the decorator.
    @objax.Function.with_vars(m.vars())
    def f2(x, y):
        return ((m(x) - y) ** 2).mean()

    # All behave like functions
    x = jn.arange(10).reshape((5, 2))
    y = jn.arange(15).reshape((5, 3))
    print(type(f1), f1(x, y))  # <class 'function'> 237.01947
    print(type(m1), m1(x, y))  # <class 'objax.module.Function'> 237.01947
    print(type(f2), f2(x, y))  # <class 'objax.module.Function'> 237.01947
Usage of Function is not necessary: it is made available for aesthetic reasons (to accommodate users' personal taste). It is also used internally to keep the code simple for Grad, Jit, Parallel, Vectorize and future primitives.
-
class objax.ForceArgs(module, **kwargs)
Forces override of arguments of given module.
One example of ForceArgs usage is to override training argument for batch normalization:
    import objax
    from objax.zoo.resnet_v2 import ResNet50

    model = ResNet50(in_channels=3, num_classes=1000)

    # Modify model to force training=False on first resnet block.
    # First two ops in the resnet are convolution and padding,
    # resnet blocks are starting at index 2.
    model[2] = objax.ForceArgs(model[2], training=False)

    # model(x, training=True) will be using `training=False` on model[2] due to ForceArgs
    # ...

    # Undo specific value of forced arguments in `model` and all submodules of `model`
    objax.ForceArgs.undo(model, training=True)

    # Undo all values of specific argument in `model` and all submodules of `model`
    objax.ForceArgs.undo(model, training=objax.ForceArgs.ANY)

    # Undo all values of all arguments in `model` and all submodules of `model`
    objax.ForceArgs.undo(model)
-
class ANY
Token used in ForceArgs.undo to indicate undo of all values of a specific argument.
-
static undo(module, **kwargs)
Undo ForceArgs on each submodule of the module. Modifications are done in place.
- Parameters
module (objax.module.Module) – module for which to undo ForceArgs.
**kwargs – dictionary of argument overrides to undo. name=val removes the override for value val of argument name. name=ForceArgs.ANY removes all overrides of argument name. If **kwargs is empty, all overrides are undone.
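The three undo modes can be illustrated on a plain dictionary of forced arguments. This is only a sketch of the matching rule: `ANY` and `undo_overrides` are hypothetical names, and unlike ForceArgs.undo the function returns a copy instead of modifying a module tree in place.

```python
ANY = object()  # hypothetical sentinel playing the role of ForceArgs.ANY


def undo_overrides(forced, **kwargs):
    """Return a copy of `forced` with the requested overrides removed."""
    if not kwargs:
        return {}  # no arguments given: drop every override
    remaining = dict(forced)
    for name, val in kwargs.items():
        # Remove only when the stored value matches, or when ANY is passed.
        if name in remaining and (val is ANY or remaining[name] == val):
            del remaining[name]
    return remaining


forced = {'training': False, 'rate': 0.5}
print(undo_overrides(forced, training=True))  # stored value differs: kept
print(undo_overrides(forced, training=ANY))   # any value: 'training' removed
print(undo_overrides(forced))                 # everything removed
```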
-
__init__(module, **kwargs)
Initializes ForceArgs by wrapping another module.
- Parameters
module (objax.module.Module) – module whose arguments will be overridden.
kwargs – values of the keyword arguments that will be forced.
-
class objax.Grad(f, variables, input_argnums=None)
The Grad module is used to compute the gradients of a function.
Usage example:
    import objax

    m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])

    def f(x, y):
        return ((m(x) - y) ** 2).mean()

    # Create module to compute gradients of f for m.vars()
    grad_f = objax.Grad(f, m.vars())

    # Create module to compute gradients of f for input 0 (x) and m.vars()
    grad_fx = objax.Grad(f, m.vars(), input_argnums=(0,))
For more information and examples, see Understanding Gradients.
-
__init__(f, variables, input_argnums=None)
Constructs an instance to compute the gradient of f w.r.t. variables.
- Parameters
f (Callable) – the function for which to compute gradients.
variables (Optional[objax.variable.VarCollection]) – the variables for which to compute gradients.
input_argnums (Optional[Tuple[int, ...]]) – input indexes, if any, on which to compute gradients.
-
-
class objax.GradValues(f, variables, input_argnums=None)
The GradValues module is used to compute the gradients of a function, together with the values the function returns.
Usage example:
    import objax

    m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])

    def f(x, y):
        return ((m(x) - y) ** 2).mean()

    # Create module to compute gradients of f for m.vars()
    grad_val_f = objax.GradValues(f, m.vars())

    # Create module to compute gradients of f for input 0 (x) and m.vars()
    grad_val_fx = objax.GradValues(f, m.vars(), input_argnums=(0,))
For more information and examples, see Understanding Gradients.
-
__init__(f, variables, input_argnums=None)
Constructs an instance to compute the gradient of f w.r.t. variables.
- Parameters
f (Union[objax.module.Module, Callable]) – the function for which to compute gradients.
variables (Optional[objax.variable.VarCollection]) – the variables for which to compute gradients.
input_argnums (Optional[Tuple[int, ...]]) – input indexes, if any, on which to compute gradients.
-
-
class objax.Jit(f, vc=None, static_argnums=None)
JIT (Just-In-Time) module takes a function or a module and compiles it for faster execution.
Usage example:
    import objax

    m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])
    jit_m = objax.Jit(m)                         # Jit a module
    jit_f = objax.Jit(lambda x: m(x), m.vars())  # Jit a function: provide vars it uses
For more information, refer to JIT Compilation. Also note that one can pass the variables to be used by Jit for a module m; the rest will be optimized away as constants. For more information, refer to Constant optimization.
-
__init__(f, vc=None, static_argnums=None)
Jit constructor.
- Parameters
f (Union[objax.module.Module, Callable]) – the function or the module to compile.
vc (Optional[objax.variable.VarCollection]) – the VarCollection of variables used by the function or module. This argument is required for functions.
static_argnums (Optional[Tuple[int, ...]]) – tuple of indexes of f’s input arguments to treat as static (constants). A new graph is compiled for each different combination of values for such inputs.
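The cost implied by static_argnums, one compilation per distinct combination of static values, can be pictured with a toy cache. `ToyJit` is a hypothetical class, not objax or JAX code; it only counts how often a new "graph" would be built. It ignores that in JAX new input shapes or dtypes also trigger compilation.

```python
class ToyJit:
    """Hypothetical cache that 'compiles' once per combination of static values."""

    def __init__(self, f, static_argnums=()):
        self.f = f
        self.static_argnums = tuple(static_argnums)
        self.cache = {}
        self.compile_count = 0

    def __call__(self, *args):
        # The cache key is made of the static argument values only.
        key = tuple(args[i] for i in self.static_argnums)
        if key not in self.cache:
            self.compile_count += 1  # stands in for an expensive trace and compile
            self.cache[key] = self.f
        return self.cache[key](*args)


power_fn = ToyJit(lambda x, power: x ** power, static_argnums=(1,))
power_fn(2.0, 2)
power_fn(3.0, 2)  # same static value: served from the cache
power_fn(2.0, 3)  # new static value: compiled again
print(power_fn.compile_count)
```

Marking an argument as static is therefore best reserved for inputs that take few distinct values.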
-
vars(scope='')
Collect all the variables (and their names) contained in the module and its submodules. Important: variables and modules stored in Python structures such as dict or list are not collected. See ModuleList if you need such a feature.
- Parameters
scope (str) – string to prefix to the variable names.
- Returns
A VarCollection of all the variables.
- Return type
objax.variable.VarCollection
-
-
class objax.Parallel(f, vc=None, reduce=<function concatenate>, axis_name='device', static_argnums=None)
Parallel module takes a function or a module and compiles it for running on multiple devices in parallel.
Usage example:
    import objax

    m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])
    para_m = objax.Parallel(m)                         # Parallelize a module
    para_f = objax.Parallel(lambda x: m(x), m.vars())  # Parallelize a function: provide vars it uses
When calling a parallelized module, one must replicate the variables it uses on all devices:
    x = objax.random.normal((16, 2))
    with m.vars().replicate():
        y = para_m(x)
For more information, refer to Parallelism. Also note that one can pass the variables to be used by Parallel for a module m; the rest will be optimized away as constants. For more information, refer to Constant optimization.
-
__init__(f, vc=None, reduce=<function concatenate>, axis_name='device', static_argnums=None)
Parallel constructor.
- Parameters
f (Union[objax.module.Module, Callable]) – the function or the module to compile for parallelism.
vc (Optional[objax.variable.VarCollection]) – the VarCollection of variables used by the function or module. This argument is required for functions.
reduce (Callable[[Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]]) – the function used to reduce the outputs from many devices to a single device value.
axis_name (str) – what name to give to the device dimension, used in conjunction with objax.functional.parallel.
static_argnums (Optional[Tuple[int, ...]]) – tuple of indexes of f’s input arguments to treat as static (constants). A new graph is compiled for each different combination of values for such inputs.
-
__call__(*args)
Call the compiled function or module on multiple devices in parallel. Important: make sure you call this function within the scope of a VarCollection.replicate() statement.
-
vars(scope='')
Collect all the variables (and their names) contained in the module and its submodules. Important: variables and modules stored in Python structures such as dict or list are not collected. See ModuleList if you need such a feature.
- Parameters
scope (str) – string to prefix to the variable names.
- Returns
A VarCollection of all the variables.
- Return type
objax.variable.VarCollection
-
-
class objax.Vectorize(f, vc=None, batch_axis=0)
Vectorize module takes a function or a module and compiles it for running in parallel on a single device.
Usage example:
    import objax

    m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])
    vec_m = objax.Vectorize(m)                         # Vectorize a module
    vec_f = objax.Vectorize(lambda x: m(x), m.vars())  # Vectorize a function: provide vars it uses
For more information and examples, refer to Vectorization.
-
__init__(f, vc=None, batch_axis=0)
Vectorize constructor.
- Parameters
f (Union[objax.module.Module, Callable]) – the function or the module to compile for vectorization.
vc (Optional[objax.variable.VarCollection]) – the VarCollection of variables used by the function or module. This argument is required for functions.
batch_axis (Tuple[Optional[int], ...]) – tuple of int or None for each of f’s input arguments: the axis to use as batch during vectorization. Use None to automatically broadcast.
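What a per-argument batch axis means can be shown with plain Python lists. `toy_vectorize` is a hypothetical helper that only supports axis 0 or None and runs a Python loop; real vectorization compiles a batched computation instead. The point of the sketch is the pairing of each argument with either "slice along the batch" or "broadcast unchanged".

```python
def toy_vectorize(f, batch_axis):
    """Map `f` over axis 0 of the batched arguments; None means broadcast."""
    def wrapped(*args):
        sizes = {len(a) for a, axis in zip(args, batch_axis) if axis is not None}
        if len(sizes) != 1:
            raise ValueError('batched arguments must share one batch size')
        n = sizes.pop()
        return [
            f(*(a if axis is None else a[i] for a, axis in zip(args, batch_axis)))
            for i in range(n)
        ]
    return wrapped


def affine(x, bias):
    return 2 * x + bias


# x is batched along axis 0, bias is broadcast to every batch element.
batched = toy_vectorize(affine, batch_axis=(0, None))
print(batched([1, 2, 3], 10))
```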
-
vars(scope='')
Collect all the variables (and their names) contained in the module and its submodules. Important: variables and modules stored in Python structures such as dict or list are not collected. See ModuleList if you need such a feature.
- Parameters
scope (str) – string to prefix to the variable names.
- Returns
A VarCollection of all the variables.
- Return type
objax.variable.VarCollection
-
Variables
objax.BaseVar | The abstract base class to represent objax variables.
objax.TrainVar | A trainable variable.
objax.BaseState | The abstract base class used to represent objax state variables.
objax.StateVar | StateVar are variables that get updated manually, and are not automatically updated by optimizers.
objax.TrainRef | A state variable that references a trainable variable for assignment.
objax.RandomState | RandomState are variables that track the random generator state.
objax.VarCollection | A VarCollection is a dictionary (name, var) with some additional methods to make manipulation of collections of variables easy.
-
class objax.BaseVar(reduce)
The abstract base class to represent objax variables.
-
__init__(reduce)
Constructor for BaseVar class.
- Parameters
reduce (Optional[Callable[[Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]]]) – a function that takes an array of shape
(n, *dims) and returns one of shape (*dims). Used to combine the multiple states produced in an objax.Vectorize or an objax.Parallel call.
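A reducer with this shape contract can be sketched without any array library. `reduce_mean_rows` is a hypothetical helper that handles only the two-dimensional case, shape (n, d) to shape (d,), using nested lists in place of arrays.

```python
def reduce_mean_rows(stacked):
    """Average n equally shaped rows: shape (n, d) -> shape (d,)."""
    n = len(stacked)
    return [sum(column) / n for column in zip(*stacked)]


# Three per-device copies of a 2-element state, i.e. shape (3, 2).
per_device = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]]
print(reduce_mean_rows(per_device))
```

The leading dimension n disappears, leaving a single state that can be stored back into the variable.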
-
-
class objax.TrainVar(tensor, reduce=<function reduce_mean>)
A trainable variable.
-
__init__(tensor, reduce=<function reduce_mean>)
TrainVar constructor.
- Parameters
tensor (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – the initial value of the TrainVar.
reduce (Optional[Callable[[Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]]]) – a function that takes an array of shape
(n, *dims) and returns one of shape (*dims). Used to combine the multiple states produced in an objax.Vectorize or an objax.Parallel call.
-
property value
The value is read-only as a safety measure to avoid accidentally making TrainVar non-differentiable. You can write a value to a TrainVar by using assign.
-
assign(tensor)
Sets the value of the variable.
- Parameters
tensor (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) –
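The split between a read-only value and an explicit assign method follows a common Python pattern, sketched below. `ToyTrainVar` is a hypothetical class, not objax's implementation, and the exception objax raises on a direct write may differ from the AttributeError that a setter-less property produces here.

```python
class ToyTrainVar:
    """Hypothetical variable whose value can be read but not set directly."""

    def __init__(self, tensor):
        self._value = tensor

    @property
    def value(self):
        return self._value

    def assign(self, tensor):
        # The single, explicit entry point for writing.
        self._value = tensor


v = ToyTrainVar([1.0, 2.0])
try:
    v.value = [0.0, 0.0]  # no setter is defined, so Python raises AttributeError
    direct_write_failed = False
except AttributeError:
    direct_write_failed = True

v.assign([3.0, 4.0])
print(direct_write_failed, v.value)
```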
-
reduce(tensors)
Method called by Parallel and Vectorize to reduce a multiple-device (or batched, in the case of vectorization) value to a single device.
- Parameters
tensors (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) –
-
-
class objax.BaseState(reduce)
The abstract base class used to represent objax state variables. State variables are not trainable.
-
__init__(reduce)
Constructor for BaseVar class.
- Parameters
reduce (Optional[Callable[[Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]]]) – a function that takes an array of shape
(n, *dims) and returns one of shape (*dims). Used to combine the multiple states produced in an objax.Vectorize or an objax.Parallel call.
-
-
class objax.TrainRef(ref)
A state variable that references a trainable variable for assignment.
TrainRef are used by optimizers to keep references to trainable variables. This is necessary to differentiate them from the optimizer's own training variables, if any.
-
__init__(ref)
TrainRef constructor.
- Parameters
ref (objax.variable.TrainVar) – the TrainVar to keep a reference to.
-
property value
The value stored in the referenced TrainVar; it can be read or written.
-
assign(tensor)
Sets the value of the variable.
- Parameters
tensor (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) –
-
reduce(tensors)
Method called by Parallel and Vectorize to reduce a multiple-device (or batched, in the case of vectorization) value to a single device.
- Parameters
tensors (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) –
-
-
class objax.StateVar(tensor, reduce=<function reduce_mean>)
StateVar are variables that get updated manually, and are not automatically updated by optimizers. For example, the mean and variance statistics in BatchNorm are StateVar.
-
__init__(tensor, reduce=<function reduce_mean>)
StateVar constructor.
- Parameters
tensor (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – the initial value of the StateVar.
reduce (Optional[Callable[[Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]]]) – a function that takes an array of shape
(n, *dims) and returns one of shape (*dims). Used to combine the multiple states produced in an objax.Vectorize or an objax.Parallel call.
-
property value
The value stored in the StateVar; it can be read or written.
-
assign(tensor)
Sets the value of the variable.
- Parameters
tensor (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) –
-
reduce(tensors)
Method called by Parallel and Vectorize to reduce a multiple-device (or batched, in the case of vectorization) value to a single device.
- Parameters
tensors (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) –
-
-
class objax.RandomState(seed)
RandomState are variables that track the random generator state. They are meant to be used internally. Currently only the random.Generator module uses them.
-
__init__(seed)
RandomState constructor.
- Parameters
seed (int) – the initial seed of the random number generator.
-
seed(seed)
Sets a new random seed.
- Parameters
seed (int) – the new initial seed of the random number generator.
-
split(n)
Create multiple seeds from the current seed. This is used internally by Parallel and Vectorize to ensure that random numbers are different in parallel threads.
- Parameters
n (int) – the number of seeds to generate.
- Return type
List[Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]]
-
assign(tensor)
Sets the value of the variable.
- Parameters
tensor (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) –
-
reduce(tensors)
Method called by Parallel and Vectorize to reduce a multiple-device (or batched, in the case of vectorization) value to a single device.
- Parameters
tensors (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) –
-
property value
The value stored in the StateVar; it can be read or written.
-
-
class objax.VarCollection
A VarCollection is a dictionary (name, var) with some additional methods to make manipulation of collections of variables easy. A VarCollection is ordered by insertion order. It is the object returned by Module.vars() and used as input by many modules: optimizers, Jit, etc.
Usage example:
    import objax

    m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])
    vc = m.vars()  # This is a VarCollection

    # It is a dictionary
    print(repr(vc))
    # {'(Sequential)[0](Linear).b': <objax.variable.TrainVar object at 0x7faecb506390>,
    #  '(Sequential)[0](Linear).w': <objax.variable.TrainVar object at 0x7faec81ee350>}
    print(vc.keys())  # dict_keys(['(Sequential)[0](Linear).b', '(Sequential)[0](Linear).w'])
    assert (vc['(Sequential)[0](Linear).w'].value == m[0].w.value).all()

    # Convenience print
    print(vc)
    # (Sequential)[0](Linear).b        3 (3,)
    # (Sequential)[0](Linear).w        6 (2, 3)
    # +Total(2)                        9

    # Extra methods for manipulation of variables:
    # For example, increment all variables by 1
    vc.assign([x+1 for x in vc.tensors()])

    # It's used by other modules.
    # For example it's used to tell Jit what variables are used by a function.
    jit_f = objax.Jit(lambda x: m(x), vc)
For more information and examples, refer to VarCollection.
-
assign(tensors)
Assign tensors to the variables in the VarCollection. Each variable is assigned only once and in the order following the iter(self) iterator.
- Parameters
tensors (List[Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]]) – the list of tensors used to update variables values.
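The "only once" rule matters when one variable is reachable under several names. The sketch below uses a plain dict and hypothetical helpers (`ToyVar`, `unique_vars`, `assign_all`), not objax code, to show that aliased entries consume a single tensor.

```python
class ToyVar:
    """Hypothetical stand-in for a variable; it only holds a value."""
    def __init__(self, value):
        self.value = value


def unique_vars(collection):
    """Yield each distinct variable once, in dictionary iteration order."""
    seen = set()
    for var in collection.values():
        if id(var) not in seen:
            seen.add(id(var))
            yield var


def assign_all(collection, tensors):
    variables = list(unique_vars(collection))
    if len(variables) != len(tensors):
        raise ValueError('need exactly one tensor per distinct variable')
    for var, tensor in zip(variables, tensors):
        var.value = tensor


shared = ToyVar(0)
vc = {'a.w': shared, 'b.w': shared, 'b.b': ToyVar(0)}  # 'a.w' and 'b.w' alias
assign_all(vc, [5, 7])                                 # two tensors, not three
print(vc['a.w'].value, vc['b.w'].value, vc['b.b'].value)
```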
-
rename(renamer)
Rename the entries in the VarCollection.
Renaming entries in a VarCollection is a powerful tool that can be used for mapping weights between models that differ slightly and for loading data checkpoints from foreign ML frameworks.
Usage example:
    import re
    import objax

    m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])
    print(m.vars())
    # (Sequential)[0](Linear).b        3 (3,)
    # (Sequential)[0](Linear).w        6 (2, 3)
    # +Total(2)                        9

    # For example remove modules from the name (raw string avoids invalid escape warnings)
    renamer = objax.util.Renamer([(re.compile(r'\([^)]+\)'), '')])
    print(m.vars().rename(renamer))
    # [0].b                            3 (3,)
    # [0].w                            6 (2, 3)
    # +Total(2)                        9

    # One can chain renamers, their syntax is flexible and it can use a string mapping:
    renamer_all = objax.util.Renamer({'[': '.', ']': ''}, renamer)
    print(m.vars().rename(renamer_all))
    # .0.b                             3 (3,)
    # .0.w                             6 (2, 3)
    # +Total(2)                        9
- Parameters
renamer (objax.util.util.Renamer) –
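The two renaming steps in the usage example can be reproduced on bare name strings with the standard `re` module, which makes the effect of each rule easy to inspect. This is a sketch of the string transformations only; it does not use objax.util.Renamer, and the variable names `strip_modules`, `step1` and `step2` are illustrative.

```python
import re

names = ['(Sequential)[0](Linear).b', '(Sequential)[0](Linear).w']

# Step 1: drop every parenthesised module name, like the regex rule above.
strip_modules = re.compile(r'\([^)]+\)')
step1 = [strip_modules.sub('', name) for name in names]

# Step 2: apply a plain string mapping on top of the result of step 1.
mapping = {'[': '.', ']': ''}
step2 = []
for name in step1:
    for old, new in mapping.items():
        name = name.replace(old, new)
    step2.append(name)

print(step1)
print(step2)
```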
-
replicate()
A context manager to use in a with statement that replicates the variables in this collection to multiple devices. This is typically used prior to a call to objax.Parallel, so that all variables have a copy on each device. Important: replicating also updates the random state in order to have a new one per device.
-
subset(is_a=None, is_not=None)
Return a new VarCollection that is a filtered subset of the current collection.
- Parameters
is_a (Optional[Union[type, Tuple[type, ...]]]) – either a variable type or a tuple of variable types to include.
is_not (Optional[Union[type, Tuple[type, ...]]]) – either a variable type or a tuple of variable types to exclude.
- Returns
A new VarCollection containing the subset of variables.
- Return type
objax.variable.VarCollection
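Type-based filtering of this kind maps naturally onto `isinstance`, which accepts either a single type or a tuple of types. The sketch below works on a plain dict with hypothetical `Toy*` classes and a hypothetical `subset` function; it is an illustration of the filtering rule, not objax's implementation.

```python
class ToyBaseVar: ...
class ToyTrainVar(ToyBaseVar): ...
class ToyStateVar(ToyBaseVar): ...


def subset(collection, is_a=None, is_not=None):
    """Keep entries that match `is_a` (if given) and do not match `is_not`."""
    return {
        name: var for name, var in collection.items()
        if (is_a is None or isinstance(var, is_a))
        and (is_not is None or not isinstance(var, is_not))
    }


toy_vc = {'w': ToyTrainVar(), 'running_mean': ToyStateVar()}
print(list(subset(toy_vc, is_a=ToyTrainVar)))                  # trainable only
print(list(subset(toy_vc, is_not=ToyTrainVar)))                # everything else
print(list(subset(toy_vc, is_a=(ToyTrainVar, ToyStateVar))))   # tuple of types
```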
-
tensors(is_a=None)
Return the list of values for this collection. Similarly to the assign method, each variable value is reported only once and in the order following the iter(self) iterator.
- Parameters
is_a (Optional[Union[type, Tuple[type, ...]]]) – either a variable type or a tuple of variable types to include.
- Returns
The list of values of the selected variables.
- Return type
List[Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]]
-
update(other)
Overloads the dict.update method to catch potential conflicts during assignment.
- Parameters
other (Union[VarCollection, Iterable[Tuple[str, objax.variable.BaseVar]]]) –
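One way to picture a conflict-catching update is a dict subclass that refuses to rebind an existing name to a different object. `ToyCollection` is hypothetical, and both the conflict rule and the choice of ValueError are assumptions made for this sketch; objax's actual checks may differ.

```python
class ToyCollection(dict):
    """Hypothetical dict that refuses to rebind a name to a different object."""

    def update(self, other):
        # Accept either a mapping or an iterable of (name, var) pairs.
        items = other.items() if isinstance(other, dict) else other
        for name, var in items:
            if name in self and self[name] is not var:
                raise ValueError(f'name conflict for {name!r}')
            self[name] = var


a, b = object(), object()
toy = ToyCollection()
toy.update({'w': a})
toy.update([('w', a)])  # same object under the same name: accepted

try:
    toy.update({'w': b})  # different object under an existing name
    conflict_detected = False
except ValueError:
    conflict_detected = True

print(conflict_detected)
```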
-
__init__(*args, **kwargs)
Initialize self. See help(type(self)) for accurate signature.
-