objax package
Modules
| Module | A module is a container to associate variables and functions. |
| ModuleList | This is a replacement for Python’s list that provides a vars() method to return all the variables that it contains, including the ones contained in the modules and sub-modules in it. |
| ForceArgs | Forces the override of arguments of a given module. |
| Function | Turn a function into a Module by keeping the vars it uses. |
| Grad | The Grad module is used to compute the gradients of a function. |
| GradValues | The GradValues module is used to compute the gradients of a function, together with the values returned by the function. |
| Jit | The JIT (Just-In-Time) module takes a function or a module and compiles it for faster execution. |
| Parallel | The Parallel module takes a function or a module and compiles it for running on multiple devices in parallel. |
| Vectorize | The Vectorize module takes a function or a module and compiles it for running in parallel on a single device. |
-
class objax.Module
A module is a container to associate variables and functions.
-
vars(scope='')
Collect all the variables (and their names) contained in the module and its submodules. Important: Variables and modules stored in Python structures such as dict or list are not collected. See ModuleList if you need such a feature.
- Parameters
scope (str) – string to prefix to the variable names.
- Returns
A VarCollection of all the variables.
- Return type
objax.variable.VarCollection
-
-
class objax.ModuleList(iterable=(), /)
Bases: objax.module.Module, list
This is a replacement for Python’s list that provides a vars() method to return all the variables that it contains, including the ones contained in the modules and sub-modules in it.
Usage example:
import objax

ml = objax.ModuleList(['hello', objax.TrainVar(objax.random.normal((10, 2)))])
print(ml.vars())
# (ModuleList)[1] 20 (10, 2)
# +Total(1) 20

ml.pop()
ml.append(objax.nn.Linear(2, 3))
print(ml.vars())
# (ModuleList)[1](Linear).b 3 (3,)
# (ModuleList)[1](Linear).w 6 (2, 3)
# +Total(2) 9
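To see why a list subclass is needed, here is a minimal pure-Python sketch of the collection logic. `Var`, `Module`, `ModuleList` and `Holder` below are hypothetical stand-ins, not the objax classes, and the naming scheme only loosely imitates objax's: `Module.vars()` walks instance attributes, so a variable hidden in a plain Python list is skipped, while the list subclass also walks its own items.

```python
from typing import Dict


class Var:
    """Hypothetical stand-in for an objax variable: it just holds a value."""

    def __init__(self, value):
        self.value = value


class Module:
    """Hypothetical stand-in for objax.Module (attribute walking only)."""

    def vars(self, scope: str = '') -> Dict[str, Var]:
        prefix = f'{scope}({type(self).__name__})'
        out: Dict[str, Var] = {}
        # The builtin vars() returns the instance __dict__; values stored in
        # plain lists or dicts are not Var/Module instances and are skipped.
        for name, attr in vars(self).items():
            if isinstance(attr, Var):
                out[f'{prefix}.{name}'] = attr
            elif isinstance(attr, Module):
                out.update(attr.vars(f'{prefix}.{name}'))
        return out


class ModuleList(Module, list):
    """Hypothetical list subclass whose vars() also visits its items."""

    def vars(self, scope: str = '') -> Dict[str, Var]:
        out = super().vars(scope)
        prefix = f'{scope}({type(self).__name__})'
        for i, item in enumerate(self):
            if isinstance(item, Var):
                out[f'{prefix}[{i}]'] = item
            elif isinstance(item, Module):
                out.update(item.vars(f'{prefix}[{i}]'))
        return out


class Holder(Module):
    def __init__(self):
        self.w = Var(1.0)
        self.hidden = [Var(2.0)]  # plain list: invisible to vars()


holder_names = list(Holder().vars())
list_names = list(ModuleList(['hello', Var(3.0), Holder()]).vars())
print(holder_names)  # ['(Holder).w']
print(list_names)    # ['(ModuleList)[1]', '(ModuleList)[2](Holder).w']
```

Non-variable items such as the string 'hello' are simply ignored, which matches the objax example where only index 1 is reported.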
-
class objax.Function(f, vc)
Turn a function into a Module by keeping the vars it uses.
Usage example:
import objax
import jax.numpy as jn

m = objax.nn.Linear(2, 3)

def f1(x, y):
    return ((m(x) - y) ** 2).mean()

# Method 1: Create module by calling objax.Function to tell which variables are used.
m1 = objax.Function(f1, m.vars())

# Method 2: Use the decorator.
@objax.Function.with_vars(m.vars())
def f2(x, y):
    return ((m(x) - y) ** 2).mean()

# All behave like functions
x = jn.arange(10).reshape((5, 2))
y = jn.arange(15).reshape((5, 3))
print(type(f1), f1(x, y))  # <class 'function'> 237.01947
print(type(m1), m1(x, y))  # <class 'objax.module.Function'> 237.01947
print(type(f2), f2(x, y))  # <class 'objax.module.Function'> 237.01947
Usage of Function is not necessary: it is made available for aesthetic reasons (to accommodate users' personal taste). It is also used internally to keep the code simple for Grad, Jit, Parallel, Vectorize and future primitives.
-
static with_vars(vc)
Decorator which turns a function into a module using the provided variable collection.
- Parameters
vc (objax.variable.VarCollection) – the VarCollection of variables used by the function.
Usage example:
import objax

m = objax.nn.Linear(2, 3)

@objax.Function.with_vars(m.vars())
def f(x, y):
    return ((m(x) - y) ** 2).mean()

print(type(f))  # <class 'objax.module.Function'>
-
static auto_vars(f)
Turns a function into a module by automatically detecting the Objax variables it uses. Can be used as a decorator.
WARNING: This is an experimental feature. It can detect the variables used by a function in many common cases, but not in all cases. This feature may be removed in a future version of Objax if it proves too unreliable.
- Parameters
f (Callable) – function which will be converted into a module.
Usage example:
import objax

m = objax.nn.Linear(2, 3)

@objax.Function.auto_vars
def f(x, y):
    return ((m(x) - y) ** 2).mean()

print(type(f))  # <class 'objax.module.Function'>
-
class objax.ForceArgs(module, **kwargs)
Forces the override of arguments of a given module.
One example of ForceArgs usage is to override the training argument for batch normalization:
import objax
from objax.zoo.resnet_v2 import ResNet50

model = ResNet50(in_channels=3, num_classes=1000)

# Modify model to force training=False on the first resnet block.
# The first two ops in the resnet are convolution and padding,
# so resnet blocks start at index 2.
model[2] = objax.ForceArgs(model[2], training=False)

# model(x, training=True) will be using `training=False` on model[2] due to ForceArgs
# ...

# Undo a specific value of forced arguments in `model` and all submodules of `model`
objax.ForceArgs.undo(model, training=True)

# Undo all values of a specific argument in `model` and all submodules of `model`
objax.ForceArgs.undo(model, training=objax.ForceArgs.ANY)

# Undo all values of all arguments in `model` and all submodules of `model`
objax.ForceArgs.undo(model)
-
class ANY
Token used in ForceArgs.undo to indicate the undo of all values of a specific argument.
-
static undo(module, **kwargs)
Undo ForceArgs on each submodule of the module. Modifications are done in-place.
- Parameters
module (objax.module.Module) – module for which to undo ForceArgs.
**kwargs – dictionary of argument overrides to undo. name=val removes the override for value val of argument name; name=ForceArgs.ANY removes all overrides of argument name. If **kwargs is empty, all overrides are undone.
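The undo rules above can be modelled with a small pure-Python sketch. `ForceArgsSketch` and its `ANY` token are hypothetical and much simpler than objax.ForceArgs: they wrap a plain callable rather than a Module, and `undo` acts on a single wrapper instead of walking every submodule in place.

```python
from typing import Any, Callable, Dict


class ANY:
    """Hypothetical token meaning 'whatever value was forced' (cf. ForceArgs.ANY)."""


class ForceArgsSketch:
    """Hypothetical wrapper that forces keyword arguments of a callable."""

    def __init__(self, fn: Callable[..., Any], **forced: Any):
        self.fn = fn
        self.forced: Dict[str, Any] = dict(forced)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Forced values override whatever the caller passes for the same names.
        return self.fn(*args, **{**kwargs, **self.forced})

    def undo(self, **kwargs: Any) -> None:
        if not kwargs:  # no arguments given: drop every override
            self.forced.clear()
            return
        for name, val in kwargs.items():
            if name in self.forced and (val is ANY or self.forced[name] == val):
                del self.forced[name]


def layer(x, training=True):
    return x, training


wrapped = ForceArgsSketch(layer, training=False)
forced = wrapped(1, training=True)     # (1, False): the override wins
wrapped.undo(training=True)            # forced value is False, not True: kept
kept = wrapped(1, training=True)       # (1, False)
wrapped.undo(training=ANY)             # any forced value of `training` is dropped
restored = wrapped(1, training=True)   # (1, True)
```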
-
__init__(module, **kwargs)
Initializes ForceArgs by wrapping another module.
- Parameters
module (objax.module.Module) – module whose arguments will be overridden.
kwargs – values of the keyword arguments to force.
-
class objax.Grad(f, variables, input_argnums=None)
The Grad module is used to compute the gradients of a function.
Usage example:
import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])

def f(x, y):
    return ((m(x) - y) ** 2).mean()

# Create module to compute gradients of f for m.vars()
grad_f = objax.Grad(f, m.vars())

# Create module to compute gradients of f for input 0 (x) and m.vars()
grad_fx = objax.Grad(f, m.vars(), input_argnums=(0,))
For more information and examples, see Understanding Gradients.
-
__init__(f, variables, input_argnums=None)
Constructs an instance to compute the gradient of f w.r.t. variables.
- Parameters
f (Callable) – the function for which to compute gradients.
variables (Optional[objax.variable.VarCollection]) – the variables for which to compute gradients.
input_argnums (Optional[Tuple[int, ...]]) – input indexes, if any, on which to compute gradients.
-
-
class objax.GradValues(f, variables, input_argnums=None)
The GradValues module is used to compute the gradients of a function, together with the values returned by the function.
Usage example:
import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu, objax.nn.Linear(3, 2)])

@objax.Function.with_vars(m.vars())
def f(x, y):
    return ((m(x) - y) ** 2).mean()

# Create module to compute gradients of f for m.vars()
grad_val_f = objax.GradValues(f, m.vars())

# Create module to compute gradients of f for only some variables
grad_val_f_head = objax.GradValues(f, m[:1].vars())

# Create module to compute gradients of f for input 0 (x) and m.vars()
grad_val_fx = objax.GradValues(f, m.vars(), input_argnums=(0,))
For more information and examples, see Understanding Gradients.
-
__init__(f, variables, input_argnums=None)
Constructs an instance to compute the gradient of f w.r.t. variables.
- Parameters
f (Union[objax.module.Module, Callable]) – the function for which to compute gradients.
variables (Optional[objax.variable.VarCollection]) – the variables for which to compute gradients.
input_argnums (Optional[Tuple[int, ...]]) – input indexes, if any, on which to compute gradients.
-
-
class objax.Jit(f, vc=None, static_argnums=None)
The JIT (Just-In-Time) module takes a function or a module and compiles it for faster execution.
Usage example:
import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])
jit_m = objax.Jit(m)  # Jit a module

# For jitting functions, use objax.Function.with_vars
@objax.Function.with_vars(m.vars())
def f(x):
    return m(x) + 1

jit_f = objax.Jit(f)
For more information, refer to JIT Compilation. Also note that, for a module m, one can pass the variables to be used by Jit through vc; the remaining variables are then optimized away as constants. For more information, refer to Constant optimization.
-
__init__(f, vc=None, static_argnums=None)
Jit constructor.
- Parameters
f (Union[objax.module.Module, Callable]) – the function or the module to compile.
vc (Optional[objax.variable.VarCollection]) – the VarCollection of variables used by the function or module. This argument is required for functions.
static_argnums (Optional[Tuple[int, ...]]) – tuple of indexes of f’s input arguments to treat as static (constants). A new graph is compiled for each different combination of values for such inputs.
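The effect of static_argnums on recompilation can be pictured with a pure-Python sketch. `jit_sketch` is hypothetical and performs no compilation: it keeps one cache entry per distinct combination of static argument values and counts how many entries were created, which is the behaviour the parameter description implies. In this sketch, static values must be hashable because they form the cache key.

```python
from typing import Any, Callable, Dict, Tuple


def jit_sketch(f: Callable[..., Any], static_argnums: Tuple[int, ...] = ()):
    """Hypothetical cache keyed on static argument values (no real compilation)."""
    cache: Dict[Tuple[Any, ...], Callable[..., Any]] = {}
    stats = {'compiles': 0}

    def wrapper(*args: Any) -> Any:
        # Static values become part of the cache key, so they must be hashable.
        key = tuple(args[i] for i in static_argnums)
        if key not in cache:
            stats['compiles'] += 1  # stands in for tracing and compiling a graph
            cache[key] = f          # a real JIT would store the compiled graph here
        return cache[key](*args)

    wrapper.stats = stats
    return wrapper


def power(x, n):
    return x ** n


fast_power = jit_sketch(power, static_argnums=(1,))
results = [fast_power(2, 3), fast_power(5, 3), fast_power(2, 4)]
print(results, fast_power.stats['compiles'])  # [8, 125, 16] 2
```

Changing only the non-static argument x reuses the cached entry, while each new value of n triggers another "compilation".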
-
vars(scope='')
Collect all the variables (and their names) contained in the module and its submodules. Important: Variables and modules stored in Python structures such as dict or list are not collected. See ModuleList if you need such a feature.
- Parameters
scope (str) – string to prefix to the variable names.
- Returns
A VarCollection of all the variables.
- Return type
objax.variable.VarCollection
-
-
class objax.Parallel(f, vc=None, reduce=<function concatenate>, axis_name='device', static_argnums=None)
The Parallel module takes a function or a module and compiles it for running on multiple devices in parallel.
Usage example:
import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])
para_m = objax.Parallel(m)  # Parallelize a module

# For parallelizing functions, use objax.Function.with_vars
@objax.Function.with_vars(m.vars())
def f(x):
    return m(x) + 1

para_f = objax.Parallel(f)
When calling a parallelized module, one must replicate the variables it uses on all devices:
x = objax.random.normal((16, 2))
with m.vars().replicate():
    y = para_m(x)
For more information, refer to Parallelism. Also note that, for a module m, one can pass the variables to be used by Parallel through vc; the remaining variables are then optimized away as constants. For more information, refer to Constant optimization.
-
__init__(f, vc=None, reduce=<function concatenate>, axis_name='device', static_argnums=None)
Parallel constructor.
- Parameters
f (Union[objax.module.Module, Callable]) – the function or the module to compile for parallelism.
vc (Optional[objax.variable.VarCollection]) – the VarCollection of variables used by the function or module. This argument is required for functions.
reduce (Callable[[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]) – the function used to reduce the outputs from many devices to a single device value.
axis_name (str) – the name to give to the device dimension, used in conjunction with objax.functional.parallel.
static_argnums (Optional[Tuple[int, ...]]) – tuple of indexes of f’s input arguments to treat as static (constants). A new graph is compiled for each different combination of values for such inputs.
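The role of reduce can be illustrated with a single-process simulation in plain Python. `parallel_sketch` and its list-based `concatenate` are hypothetical: lists stand in for arrays, each "device" is just a contiguous slice of the batch, and the requirement that the batch size be divisible by the number of devices is an assumption of the sketch. The default mirrors reduce=concatenate in the signature; any other callable that maps the list of per-device outputs to one value can be passed instead.

```python
from typing import Any, Callable, List, Sequence


def concatenate(outputs: List[list]) -> list:
    """Analogue of the default reduce: join per-device outputs along axis 0."""
    return [row for out in outputs for row in out]


def parallel_sketch(f: Callable[[list], Any], batch: Sequence[Any], n_devices: int,
                    reduce: Callable[[List[Any]], Any] = concatenate) -> Any:
    """Hypothetical single-process simulation of a split/run/reduce flow."""
    if len(batch) % n_devices:
        raise ValueError('batch size must be divisible by the number of devices')
    shard = len(batch) // n_devices
    # Each "device" receives one contiguous shard of the batch.
    outputs = [f(list(batch[d * shard:(d + 1) * shard])) for d in range(n_devices)]
    return reduce(outputs)


doubled = parallel_sketch(lambda xs: [2 * x for x in xs], list(range(8)), n_devices=4)
mean_of_sums = parallel_sketch(sum, list(range(8)), n_devices=4,
                               reduce=lambda outs: sum(outs) / len(outs))
print(doubled)       # [0, 2, 4, 6, 8, 10, 12, 14]
print(mean_of_sums)  # 7.0
```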
-
__call__(*args)
Call the compiled function or module on multiple devices in parallel. Important: Make sure you call this function within the scope of a VarCollection.replicate() statement.
-
vars(scope='')
Collect all the variables (and their names) contained in the module and its submodules. Important: Variables and modules stored in Python structures such as dict or list are not collected. See ModuleList if you need such a feature.
- Parameters
scope (str) – string to prefix to the variable names.
- Returns
A VarCollection of all the variables.
- Return type
objax.variable.VarCollection
-
-
class objax.Vectorize(f, vc=None, batch_axis=(0,))
The Vectorize module takes a function or a module and compiles it for running in parallel on a single device.
Usage example:
import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])
vec_m = objax.Vectorize(m)  # Vectorize a module

# For vectorizing functions, use objax.Function.with_vars
@objax.Function.with_vars(m.vars())
def f(x):
    return m(x) + 1

vec_f = objax.Vectorize(f)
For more information and examples, refer to Vectorization.
-
__init__(f, vc=None, batch_axis=(0,))
Vectorize constructor.
- Parameters
f (Union[objax.module.Module, Callable]) – the function or the module to compile for vectorization.
vc (Optional[objax.variable.VarCollection]) – the VarCollection of variables used by the function or module. This argument is required for functions.
batch_axis (Tuple[Optional[int], ...]) – tuple of int or None for each of f’s input arguments: the axis to use as batch during vectorization. Use None to automatically broadcast.
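A loop-based sketch in plain Python shows what batch_axis selects. `vectorize_sketch` is hypothetical and does no compilation; it supports only axis 0 or None over Python lists, while the real batch_axis can name other axes. Arguments marked 0 are indexed per example, and arguments marked None are passed unchanged (broadcast) to every call.

```python
from typing import Any, Callable, Optional, Tuple


def vectorize_sketch(f: Callable[..., Any], batch_axis: Tuple[Optional[int], ...]):
    """Hypothetical loop-based stand-in for vectorization over lists."""
    def wrapper(*args: Any) -> list:
        if len(args) != len(batch_axis):
            raise ValueError('batch_axis needs one entry per argument')
        if any(ax not in (0, None) for ax in batch_axis):
            raise NotImplementedError('this sketch only supports axis 0 or None')
        sizes = {len(a) for a, ax in zip(args, batch_axis) if ax == 0}
        if len(sizes) != 1:
            raise ValueError('batched arguments must share one batch size')
        (n,) = sizes
        # Axis-0 arguments are indexed per example; None arguments are broadcast.
        return [f(*(a[i] if ax == 0 else a for a, ax in zip(args, batch_axis)))
                for i in range(n)]
    return wrapper


scale = vectorize_sketch(lambda x, k: x * k, batch_axis=(0, None))
scaled = scale([1, 2, 3], 10)
print(scaled)  # [10, 20, 30]
```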
-
vars(scope='')
Collect all the variables (and their names) contained in the module and its submodules. Important: Variables and modules stored in Python structures such as dict or list are not collected. See ModuleList if you need such a feature.
- Parameters
scope (str) – string to prefix to the variable names.
- Returns
A VarCollection of all the variables.
- Return type
objax.variable.VarCollection
-
Variables
| BaseVar | The abstract base class to represent objax variables. |
| TrainVar | A trainable variable. |
| BaseState | The abstract base class used to represent objax state variables. |
| StateVar | StateVar variables are updated manually, not automatically by optimizers. |
| TrainRef | A state variable that references a trainable variable for assignment. |
| RandomState | RandomState variables track the random generator state. |
| VarCollection | A VarCollection is a dictionary (name, var) with some additional methods to make manipulation of collections of variables easy. |
-
class objax.BaseVar(reduce)
The abstract base class to represent objax variables.
-
__init__(reduce)
Constructor for BaseVar class.
- Parameters
reduce (Optional[Callable[[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]]) – a function that takes an array of shape (n, *dims) and returns one of shape (*dims). Used to combine the multiple states produced in an objax.Vectorize or an objax.Parallel call.
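As a concrete case of the (n, *dims) to (*dims) contract, the sketch below averages over the leading axis in plain Python, which is the kind of reduction the reduce_mean default in the TrainVar and StateVar signatures suggests. `reduce_mean_sketch` is a hypothetical stand-in operating on nested lists, not objax's function.

```python
from typing import Any, List


def reduce_mean_sketch(tensors: List[Any]) -> Any:
    """Hypothetical elementwise mean over the leading axis: (n, *dims) -> (*dims).

    `tensors` holds n nested lists (or numbers) that all share the same shape.
    """
    first = tensors[0]
    if isinstance(first, list):
        # Recurse position by position; the n copies are averaged at the leaves.
        return [reduce_mean_sketch([t[i] for t in tensors]) for i in range(len(first))]
    return sum(tensors) / len(tensors)


# Three per-device copies of a variable of shape (2,), i.e. an input of shape (3, 2).
copies = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
averaged = reduce_mean_sketch(copies)
print(averaged)  # [3.0, 4.0]
```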
-
-
class objax.TrainVar(tensor, reduce=<function reduce_mean>)
A trainable variable.
-
__init__(tensor, reduce=<function reduce_mean>)
TrainVar constructor.
- Parameters
tensor (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – the initial value of the TrainVar.
reduce (Optional[Callable[[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]]) – a function that takes an array of shape (n, *dims) and returns one of shape (*dims). Used to combine the multiple states produced in an objax.Vectorize or an objax.Parallel call.
-
property value
The value is read-only as a safety measure to avoid accidentally making TrainVar non-differentiable. You can write a value to a TrainVar by using assign.
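The read-only pattern can be expressed in plain Python with a property that has no setter, plus an explicit assign method. `TrainVarSketch` is hypothetical; it ignores assign's check argument and reduce, and only shows that a direct write to value fails while assign succeeds.

```python
class TrainVarSketch:
    """Hypothetical stand-in for a TrainVar: read-only value, explicit assign()."""

    def __init__(self, tensor):
        self._value = tensor

    @property
    def value(self):
        # A property without a setter: `v.value = ...` raises AttributeError.
        return self._value

    def assign(self, tensor):
        # The only sanctioned way to change the value in this sketch.
        self._value = tensor


v = TrainVarSketch([1.0, 2.0])
try:
    v.value = [0.0, 0.0]
    direct_write_allowed = True
except AttributeError:
    direct_write_allowed = False
v.assign([3.0, 4.0])
print(direct_write_allowed, v.value)  # False [3.0, 4.0]
```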
-
assign(tensor, check=True)
Sets the value of the variable.
- Parameters
tensor (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – the new value of the variable.
-
property dtype
Variable data type.
-
property ndim
Number of dimensions.
-
reduce(tensors)
Method called by Parallel and Vectorize to reduce a multiple-device (or, in the case of vectorization, batched) value to a single device.
- Parameters
tensors (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – the multiple-device or batched values to reduce.
-
property shape
Variable shape.
-
-
class objax.BaseState(reduce)
The abstract base class used to represent objax state variables. State variables are not trainable.
-
__init__(reduce)
Constructor for BaseVar class.
- Parameters
reduce (Optional[Callable[[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]]) – a function that takes an array of shape (n, *dims) and returns one of shape (*dims). Used to combine the multiple states produced in an objax.Vectorize or an objax.Parallel call.
-
-
class objax.TrainRef(ref)
A state variable that references a trainable variable for assignment.
TrainRef variables are used by optimizers to keep references to trainable variables. This is necessary to differentiate them from the optimizer's own training variables, if any.
-
__init__(ref)
TrainRef constructor.
- Parameters
ref (objax.variable.TrainVar) – the TrainVar to keep a reference to.
-
property value
The value stored in the referenced TrainVar; it can be read or written.
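The indirection can be sketched in plain Python: an optimizer keeps references and writes through them, so the underlying trainable variable is updated without the optimizer owning it. `TrainVarSketch`, `TrainRefSketch` and `sgd_step` are hypothetical, and the single-step SGD rule is only an example update.

```python
class TrainVarSketch:
    """Hypothetical trainable variable: value is read-only, assign() writes it."""

    def __init__(self, tensor):
        self._value = tensor

    @property
    def value(self):
        return self._value

    def assign(self, tensor):
        self._value = tensor


class TrainRefSketch:
    """Hypothetical reference whose value reads and writes go to the TrainVar."""

    def __init__(self, ref: TrainVarSketch):
        self.ref = ref

    @property
    def value(self):
        return self.ref.value

    @value.setter
    def value(self, tensor):
        self.ref.assign(tensor)


def sgd_step(refs, grads, lr):
    """Hypothetical optimizer update that only ever touches the references."""
    for ref, g in zip(refs, grads):
        ref.value = ref.value - lr * g


w = TrainVarSketch(1.0)
sgd_step([TrainRefSketch(w)], grads=[0.5], lr=0.5)
print(w.value)  # 0.75
```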
-
assign(tensor, check=True)
Sets the value of the variable.
- Parameters
tensor (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – the new value of the variable.
-
property dtype
Variable data type.
-
property ndim
Number of dimensions.
-
reduce(tensors)
Method called by Parallel and Vectorize to reduce a multiple-device (or, in the case of vectorization, batched) value to a single device.
- Parameters
tensors (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – the multiple-device or batched values to reduce.
-
property shape
Variable shape.
-
-
class objax.StateVar(tensor, reduce=<function reduce_mean>)
StateVar variables are updated manually, not automatically by optimizers. For example, the mean and variance statistics in BatchNorm are StateVar.
-
__init__(tensor, reduce=<function reduce_mean>)
StateVar constructor.
- Parameters
tensor (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – the initial value of the StateVar.
reduce (Optional[Callable[[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]]) – a function that takes an array of shape (n, *dims) and returns one of shape (*dims). Used to combine the multiple states produced in an objax.Vectorize or an objax.Parallel call.
-
property value
The value stored in the StateVar; it can be read or written.
-
assign(tensor, check=True)
Sets the value of the variable.
- Parameters
tensor (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – the new value of the variable.
-
property dtype
Variable data type.
-
property ndim
Number of dimensions.
-
reduce(tensors)
Method called by Parallel and Vectorize to reduce a multiple-device (or, in the case of vectorization, batched) value to a single device.
- Parameters
tensors (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – the multiple-device or batched values to reduce.
-
property shape
Variable shape.
-
-
class objax.RandomState(seed)
RandomState variables track the random generator state. They are meant to be used internally. Currently, only the random.Generator module uses them.
-
__init__(seed)
RandomState constructor.
- Parameters
seed (int) – the initial seed of the random number generator.
-
seed(seed)
Sets a new random seed.
- Parameters
seed (int) – the new initial seed of the random number generator.
-
split(n)
Create multiple seeds from the current seed. This is used internally by Parallel and Vectorize to ensure that random numbers are different in parallel threads.
- Parameters
n (int) – the number of seeds to generate.
- Return type
List[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]
-
assign(tensor, check=True)
Sets the value of the variable.
- Parameters
tensor (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – the new value of the variable.
-
property dtype
Variable data type.
-
property ndim
Number of dimensions.
-
reduce(tensors)
Method called by Parallel and Vectorize to reduce a multiple-device (or, in the case of vectorization, batched) value to a single device.
- Parameters
tensors (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – the multiple-device or batched values to reduce.
-
property shape
Variable shape.
-
property value
The value stored in the StateVar; it can be read or written.
-
-
class objax.VarCollection
A VarCollection is a dictionary mapping names to variables, with some additional methods that make manipulating collections of variables easy. A VarCollection is ordered by insertion order. It is the object returned by Module.vars() and used as input by many modules: optimizers, Jit, etc.
Usage example:
import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])
vc = m.vars()  # This is a VarCollection

# It is a dictionary
print(repr(vc))
# {'(Sequential)[0](Linear).b': <objax.variable.TrainVar object at 0x7faecb506390>,
#  '(Sequential)[0](Linear).w': <objax.variable.TrainVar object at 0x7faec81ee350>}
print(vc.keys())  # dict_keys(['(Sequential)[0](Linear).b', '(Sequential)[0](Linear).w'])
assert (vc['(Sequential)[0](Linear).w'].value == m[0].w.value).all()

# Convenience print
print(vc)
# (Sequential)[0](Linear).b 3 (3,)
# (Sequential)[0](Linear).w 6 (2, 3)
# +Total(2) 9

# Extra methods for manipulation of variables:
# For example, increment all variables by 1
vc.assign([x + 1 for x in vc.tensors()])

# It's used by other modules.
# For example, it's used to tell what variables are used by a function.
@objax.Function.with_vars(vc)
def my_function(x):
    return objax.functional.softmax(m(x))
For more information and examples, refer to VarCollection.
-
assign(tensors)
Assign tensors to the variables in the VarCollection. Each variable is assigned only once and in the order following the iter(self) iterator.
- Parameters
tensors (List[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]) – the list of tensors used to update variables values.
-
rename(renamer)
Rename the entries in the VarCollection.
Renaming entries in a VarCollection is a powerful tool that can be used for:
- mapping weights between models that differ slightly;
- loading data checkpoints from foreign ML frameworks.
Usage example:
import re
import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])
print(m.vars())
# (Sequential)[0](Linear).b 3 (3,)
# (Sequential)[0](Linear).w 6 (2, 3)
# +Total(2) 9

# For example, remove modules from the name
renamer = objax.util.Renamer([(re.compile(r'\([^)]+\)'), '')])
print(m.vars().rename(renamer))
# [0].b 3 (3,)
# [0].w 6 (2, 3)
# +Total(2) 9

# One can chain renamers; their syntax is flexible and they can use a string mapping:
renamer_all = objax.util.Renamer({'[': '.', ']': ''}, renamer)
print(m.vars().rename(renamer_all))
# .0.b 3 (3,)
# .0.w 6 (2, 3)
# +Total(2) 9
- Parameters
renamer (objax.util.util.Renamer) – the renamer applied to the variable names.
-
replicate()
A context manager to use in a with statement that replicates the variables in this collection to multiple devices. This is typically used prior to calling objax.Parallel, so that all variables have a copy on each device. Important: replicating also updates the random state in order to have a new one per device.
-
subset(is_a=None, is_not=None)
Return a new VarCollection that is a filtered subset of the current collection.
- Parameters
is_a (Optional[Union[type, Tuple[type, ...]]]) – either a variable type or a tuple of variable types to include.
is_not (Optional[Union[type, Tuple[type, ...]]]) – either a variable type or a tuple of variable types to exclude.
- Returns
A new VarCollection containing the subset of variables.
- Return type
objax.variable.VarCollection
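The filtering rule can be sketched with isinstance checks in plain Python. Everything below is hypothetical: the classes only mirror the hierarchy these docs suggest (RandomState is assumed to derive from StateVar, and TrainRef from BaseState), and `subset` returns a plain dict rather than a VarCollection.

```python
from typing import Dict, Optional, Tuple, Union

TypeSpec = Optional[Union[type, Tuple[type, ...]]]


# Hypothetical stand-ins mirroring the hierarchy suggested by this section.
class BaseVar: ...
class TrainVar(BaseVar): ...
class BaseState(BaseVar): ...
class StateVar(BaseState): ...
class TrainRef(BaseState): ...
class RandomState(StateVar): ...


def subset(vc: Dict[str, BaseVar], is_a: TypeSpec = None,
           is_not: TypeSpec = None) -> Dict[str, BaseVar]:
    """Keep entries that are instances of `is_a` (if given) and not of `is_not` (if given)."""
    return {name: var for name, var in vc.items()
            if (is_a is None or isinstance(var, is_a))
            and (is_not is None or not isinstance(var, is_not))}


vc = {'w': TrainVar(), 'mean': StateVar(), 'rng': RandomState(), 'ref': TrainRef()}
state_only = list(subset(vc, is_a=StateVar))
no_rng = list(subset(vc, is_a=BaseState, is_not=RandomState))
trainable = list(subset(vc, is_not=BaseState))
print(state_only, no_rng, trainable)  # ['mean', 'rng'] ['mean', 'ref'] ['w']
```

Because isinstance is used, filtering on a base class also keeps its subclasses, which is why 'rng' appears in the StateVar subset.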
-
tensors(is_a=None)
Return the list of values for this collection. Similarly to the assign method, each variable value is reported only once and in the order following the iter(self) iterator.
- Parameters
is_a (Optional[Union[type, Tuple[type, ...]]]) – either a variable type or a tuple of variable types to include.
- Returns
A list of the values of the (optionally filtered) variables.
- Return type
List[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]
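Why "only once" matters is easiest to see when one variable is registered under two names, as can happen with shared weights. The plain-Python `VarCollectionSketch` below is hypothetical: it deduplicates by object identity in insertion order through a helper named `unique_vars`, which stands in for the iter(self) iterator mentioned above.

```python
from typing import Any, Iterator, List


class Var:
    """Hypothetical variable holding a single value."""

    def __init__(self, value: Any):
        self.value = value


class VarCollectionSketch(dict):
    """Hypothetical name -> Var dict whose tensors()/assign() visit each Var once."""

    def unique_vars(self) -> Iterator[Var]:
        seen = set()
        for var in self.values():       # insertion order
            if id(var) not in seen:     # the same Var may appear under several names
                seen.add(id(var))
                yield var

    def tensors(self) -> List[Any]:
        return [var.value for var in self.unique_vars()]

    def assign(self, tensors: List[Any]) -> None:
        variables = list(self.unique_vars())
        if len(variables) != len(tensors):
            raise ValueError(f'expected {len(variables)} tensors, got {len(tensors)}')
        for var, tensor in zip(variables, tensors):
            var.value = tensor


shared = Var(1.0)
vc = VarCollectionSketch({'encoder.w': shared, 'decoder.w': shared, 'b': Var(2.0)})
before = vc.tensors()
vc.assign([x + 1 for x in before])   # each distinct variable is incremented once
print(before, vc.tensors())          # [1.0, 2.0] [2.0, 3.0]
```

Because tensors() and assign() use the same ordering, the "increment all variables by 1" pattern from the VarCollection example round-trips correctly even with shared variables.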
-
update(other)
Overload of the dict.update method that catches potential conflicts during assignment.
- Parameters
other (Union[objax.variable.VarCollection, Iterable[Tuple[str, objax.variable.BaseVar]]]) – the VarCollection or iterable of (name, variable) pairs to add.
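A plain-Python sketch of the kind of conflict check described above. `VarCollectionSketch` is hypothetical, and raising ValueError is the sketch's own choice rather than documented objax behaviour: re-registering the same variable under the same name is accepted, while binding an existing name to a different object is rejected.

```python
from typing import Any, Iterable, Tuple, Union


class VarCollectionSketch(dict):
    """Hypothetical dict whose update() refuses to rebind a name to another object."""

    def update(self, other: Union[dict, Iterable[Tuple[str, Any]]] = ()) -> None:
        items = other.items() if isinstance(other, dict) else other
        for name, var in items:
            if name in self and self[name] is not var:
                raise ValueError(f'name conflict for {name!r}')
            self[name] = var


a, b = object(), object()
vc = VarCollectionSketch({'w': a})
vc.update([('w', a), ('b', b)])   # re-adding the same object under the same name is fine
try:
    vc.update({'w': b})           # a different object under an existing name
    conflict_detected = False
except ValueError:
    conflict_detected = True
print(list(vc), conflict_detected)  # ['w', 'b'] True
```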
-
__init__(*args, **kwargs)
Initialize self. See help(type(self)) for accurate signature.
-