objax.nn package
objax.nn
| objax.nn.BatchNorm | Applies batch normalization to an input tensor of any rank. |
| objax.nn.BatchNorm0D | Applies 0D batch normalization to a 2D input batch of shape (N,C). |
| objax.nn.BatchNorm1D | Applies 1D batch normalization to a 3D input batch of shape (N,C,L). |
| objax.nn.BatchNorm2D | Applies 2D batch normalization to a 4D input batch of shape (N,C,H,W). |
| objax.nn.Conv2D | Applies a 2D convolution to a 4D input batch of shape (N,C,H,W). |
| objax.nn.ConvTranspose2D | Applies a 2D transposed convolution to a 4D input batch of shape (N,C,H,W). |
| objax.nn.Dropout | In the training phase, zeroes each element of the input tensor with probability 1-keep and scales the remaining elements by a factor of 1/keep. |
| objax.nn.Linear | Applies a linear transformation to an input batch. |
| objax.nn.MovingAverage | Computes the moving average of an input batch. |
| objax.nn.ExponentialMovingAverage | Computes the exponential moving average (also called EMA or EWMA) of an input batch. |
| objax.nn.Sequential | Executes modules in the order they were passed to the constructor. |
| objax.nn.SyncedBatchNorm | Synchronized batch normalization that aggregates batch statistics across all devices (GPUs/TPUs). |
| objax.nn.SyncedBatchNorm0D | Applies 0D synchronized batch normalization to a 2D input batch of shape (N,C). |
| objax.nn.SyncedBatchNorm1D | Applies 1D synchronized batch normalization to a 3D input batch of shape (N,C,L). |
| objax.nn.SyncedBatchNorm2D | Applies 2D synchronized batch normalization to a 4D input batch of shape (N,C,H,W). |
-
class objax.nn.BatchNorm(dims, redux, momentum=0.999, eps=1e-06)
Applies batch normalization to an input tensor of any rank.
The module follows the operation described in Algorithm 1 of Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.
\[y = \frac{x-\mathrm{E}[x]}{\sqrt{\mathrm{Var}[x]+\epsilon}} \times \gamma + \beta\]
The mean (\(\mathrm{E}[x]\)) and variance (\(\mathrm{Var}[x]\)) are computed over the mini-batch by averaging along the reduction axes redux. \(\beta\) and \(\gamma\) are trainable parameter tensors of shape dims. The elements of \(\beta\) are initialized with zeros and those of \(\gamma\) with ones.
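The formula above can be sketched in NumPy as a reference. This is not Objax's implementation: the class name BatchNormSketch is hypothetical, and both the running-statistics update (running ← momentum × running + (1 − momentum) × batch statistic) and the initial running values (zeros for the mean, ones for the variance) are assumptions consistent with the momentum parameter described below.

```python
import numpy as np


class BatchNormSketch:
    """Hypothetical NumPy reference for the batch normalization formula (not objax code)."""

    def __init__(self, dims, redux, momentum=0.999, eps=1e-6):
        self.redux = tuple(redux)           # axes averaged over to get batch statistics
        self.momentum = momentum
        self.eps = eps
        self.beta = np.zeros(dims)          # shift, initialized with zeros
        self.gamma = np.ones(dims)          # scale, initialized with ones
        self.running_mean = np.zeros(dims)  # assumed initial value
        self.running_var = np.ones(dims)    # assumed initial value

    def __call__(self, x, training):
        if training:
            # Statistics of the current mini-batch, kept broadcastable against x.
            mean = x.mean(axis=self.redux, keepdims=True)
            var = x.var(axis=self.redux, keepdims=True)
            # Assumed exponential moving average update of the accumulated statistics.
            self.running_mean = self.momentum * self.running_mean + (1 - self.momentum) * mean
            self.running_var = self.momentum * self.running_var + (1 - self.momentum) * var
        else:
            # Evaluation mode: use the accumulated statistics instead.
            mean, var = self.running_mean, self.running_var
        return (x - mean) / np.sqrt(var + self.eps) * self.gamma + self.beta


# Per-channel normalization of an (N, C, H, W) batch, using the shapes documented for BatchNorm2D.
rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(16, 3, 5, 5))
bn = BatchNormSketch(dims=(1, 3, 1, 1), redux=(0, 2, 3))
y = bn(x, training=True)
print(y.mean(axis=(0, 2, 3)))  # close to 0 for every channel
print(y.std(axis=(0, 2, 3)))   # close to 1 for every channel
```

Right after one training step the running statistics are still close to their initial values, so evaluation-mode outputs differ noticeably from training-mode outputs. This is the usual reason a large momentum such as 0.999 needs many training steps before evaluation results become meaningful.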
-
__init__(dims, redux, momentum=0.999, eps=1e-06)
Creates a BatchNorm module instance.
- Parameters
dims (Iterable[int]) – shape of the batch normalization state variables.
redux (Iterable[int]) – list of indices of reduction axes. Batch norm statistics are computed by averaging over these axes.
momentum (float) – value used to compute the exponential moving average of batch statistics.
eps (float) – small value used for numerical stability.
-
__call__(x, training)
Performs batch normalization of the input tensor.
- Parameters
x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.
training (bool) – if True, compute batch normalization in training mode (accumulating batch statistics); otherwise, compute it in evaluation mode (using the already accumulated batch statistics).
- Returns
Batch-normalized tensor.
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]
-
-
class objax.nn.BatchNorm0D(nin, momentum=0.999, eps=1e-06)
Applies 0D batch normalization to a 2D input batch of shape (N,C).
The module follows the operation described in Algorithm 1 of Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.
\[y = \frac{x-\mathrm{E}[x]}{\sqrt{\mathrm{Var}[x]+\epsilon}} \times \gamma + \beta\]
The mean (\(\mathrm{E}[x]\)) and variance (\(\mathrm{Var}[x]\)) are computed per channel over the mini-batch. \(\beta\) and \(\gamma\) are trainable parameter tensors of shape (1, nin). The elements of \(\beta\) are initialized with zeros and those of \(\gamma\) with ones.
-
__init__(nin, momentum=0.999, eps=1e-06)
Creates a BatchNorm0D module instance.
- Parameters
nin (int) – number of channels in the input example.
momentum (float) – value used to compute the exponential moving average of batch statistics.
eps (float) – small value used for numerical stability.
-
__call__(x, training)
Performs batch normalization of the input tensor.
- Parameters
x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.
training (bool) – if True, compute batch normalization in training mode (accumulating batch statistics); otherwise, compute it in evaluation mode (using the already accumulated batch statistics).
- Returns
Batch-normalized tensor.
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]
-
-
class objax.nn.BatchNorm1D(nin, momentum=0.999, eps=1e-06)
Applies 1D batch normalization to a 3D input batch of shape (N,C,L).
The module follows the operation described in Algorithm 1 of Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.
\[y = \frac{x-\mathrm{E}[x]}{\sqrt{\mathrm{Var}[x]+\epsilon}} \times \gamma + \beta\]
The mean (\(\mathrm{E}[x]\)) and variance (\(\mathrm{Var}[x]\)) are computed per channel over the mini-batch and length axes (N and L). \(\beta\) and \(\gamma\) are trainable parameter tensors of shape (1, nin, 1). The elements of \(\beta\) are initialized with zeros and those of \(\gamma\) with ones.
-
__init__(nin, momentum=0.999, eps=1e-06)
Creates a BatchNorm1D module instance.
- Parameters
nin (int) – number of channels in the input example.
momentum (float) – value used to compute the exponential moving average of batch statistics.
eps (float) – small value used for numerical stability.
-
__call__(x, training)
Performs batch normalization of the input tensor.
- Parameters
x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.
training (bool) – if True, compute batch normalization in training mode (accumulating batch statistics); otherwise, compute it in evaluation mode (using the already accumulated batch statistics).
- Returns
Batch-normalized tensor.
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]
-
-
class objax.nn.BatchNorm2D(nin, momentum=0.999, eps=1e-06)
Applies 2D batch normalization to a 4D input batch of shape (N,C,H,W).
The module follows the operation described in Algorithm 1 of Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.
\[y = \frac{x-\mathrm{E}[x]}{\sqrt{\mathrm{Var}[x]+\epsilon}} \times \gamma + \beta\]
The mean (\(\mathrm{E}[x]\)) and variance (\(\mathrm{Var}[x]\)) are computed per channel over the mini-batch and spatial axes (N, H and W). \(\beta\) and \(\gamma\) are trainable parameter tensors of shape (1, nin, 1, 1). The elements of \(\beta\) are initialized with zeros and those of \(\gamma\) with ones.
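The three rank-specific variants differ only in which axes are averaged. Reading the parameter shapes documented above ((1, nin), (1, nin, 1) and (1, nin, 1, 1)), every axis except the channel axis 1 is reduced. The following NumPy sketch expresses that reading. The helper name channel_stats is hypothetical and is not an Objax function.

```python
import numpy as np


def channel_stats(x):
    """Per-channel mean and variance of an (N, C, ...) batch (hypothetical helper).

    Every axis except the channel axis 1 is reduced, so the results have
    shape (1, C), (1, C, 1) or (1, C, 1, 1) for 2D, 3D and 4D inputs.
    """
    redux = (0,) + tuple(range(2, x.ndim))
    return x.mean(axis=redux, keepdims=True), x.var(axis=redux, keepdims=True)


rng = np.random.default_rng(0)
for shape in [(8, 4), (8, 4, 10), (8, 4, 6, 6)]:  # inputs for the 0D, 1D and 2D variants
    mean, var = channel_stats(rng.normal(size=shape))
    print(shape, "->", mean.shape)
```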
-
__init__(nin, momentum=0.999, eps=1e-06)
Creates a BatchNorm2D module instance.
- Parameters
nin (int) – number of channels in the input example.
momentum (float) – value used to compute the exponential moving average of batch statistics.
eps (float) – small value used for numerical stability.
-
__call__(x, training)
Performs batch normalization of the input tensor.
- Parameters
x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.
training (bool) – if True, compute batch normalization in training mode (accumulating batch statistics); otherwise, compute it in evaluation mode (using the already accumulated batch statistics).
- Returns
Batch-normalized tensor.
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]
-
-
class objax.nn.Conv2D(nin, nout, k, strides=1, dilations=1, groups=1, padding=<ConvPadding.SAME: 'SAME'>, use_bias=True, w_init=<function kaiming_normal>)
Applies a 2D convolution to a 4D input batch of shape (N,C,H,W).
In the simplest case (strides = 1, dilations = 1, groups = 1, padding = VALID), the output tensor \((N,C_{out},H_{out},W_{out})\) is computed from an input tensor \((N,C_{in},H,W)\) with kernel weight \((k,k,C_{in},C_{out})\) and bias \((C_{out})\) as follows:
\[\mathrm{out}[n,c,h,w] = \mathrm{b}[c] + \sum_{t=0}^{C_{in}-1}\sum_{i=0}^{k-1}\sum_{j=0}^{k-1} \mathrm{in}[n,t,h+i,w+j] \times \mathrm{w}[i,j,t,c]\]
where \(H_{out}=H-k+1\) and \(W_{out}=W-k+1\). Note that the implementation follows the definition of cross-correlation (the kernel is not flipped). When padding = SAME, the input tensor is zero-padded by \(\lfloor\frac{k-1}{2}\rfloor\) on the top and left sides and by \(\lfloor\frac{k}{2}\rfloor\) on the bottom and right sides.
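The cross-correlation formula and the SAME padding rule above can be checked with a direct NumPy transcription. This is a readable reference, not Objax's implementation (which runs on XLA). The helper names conv2d_valid and conv2d_same are hypothetical, and the sketch assumes a square kernel with strides, dilations and groups all equal to 1.

```python
import numpy as np


def conv2d_valid(x, w, b):
    """Cross-correlation with strides=1, dilations=1, groups=1, padding=VALID.

    x: (N, C_in, H, W) input, w: (k, k, C_in, C_out) kernel, b: (C_out,) bias.
    """
    k = w.shape[0]  # assumes a square k x k kernel
    n, _, h, wd = x.shape
    h_out, w_out = h - k + 1, wd - k + 1
    out = np.empty((n, w.shape[3], h_out, w_out))
    for row in range(h_out):
        for col in range(w_out):
            patch = x[:, :, row:row + k, col:col + k]  # (N, C_in, k, k) window
            # Sum over input channel t and kernel offsets i, j for every (n, c).
            out[:, :, row, col] = np.einsum("ntij,ijtc->nc", patch, w)
    return out + b[None, :, None, None]


def conv2d_same(x, w, b):
    """SAME padding: zero-pad floor((k-1)/2) before and floor(k/2) after H and W."""
    k = w.shape[0]
    pad = ((k - 1) // 2, k // 2)
    x_padded = np.pad(x, ((0, 0), (0, 0), pad, pad))  # constant zero padding
    return conv2d_valid(x_padded, w, b)


x = np.ones((2, 3, 5, 5))
b = np.full(4, 0.5)
print(conv2d_valid(x, np.ones((3, 3, 3, 4)), b).shape)  # (2, 4, 3, 3)
print(conv2d_same(x, np.ones((4, 4, 3, 4)), b).shape)   # (2, 4, 5, 5)
```

The asymmetric SAME padding is what keeps the spatial size unchanged for even kernel sizes. A 4×4 kernel gets one padded row on top and two on the bottom.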
-
__init__(nin, nout, k, strides=1, dilations=1, groups=1, padding=<ConvPadding.SAME: 'SAME'>, use_bias=True, w_init=<function kaiming_normal>)
Creates a Conv2D module instance.
- Parameters
nin (int) – number of channels of the input tensor.
nout (int) – number of channels of the output tensor.
k (Union[Tuple[int, int], int]) – size of the convolution kernel, either a tuple (height, width) or a single number if they’re the same.
strides (Union[Tuple[int, int], int]) – convolution strides, either a tuple (stride_y, stride_x) or a single number if they’re the same.
dilations (Union[Tuple[int, int], int]) – spacing between kernel points (also known as atrous convolution), either a tuple (dilation_y, dilation_x) or a single number if they’re the same.
groups (int) – number of channel groups. When groups > 1, the input and output channels are split into groups and the convolution is applied to each group separately; nin and nout must both be divisible by groups.
padding (Union[objax.constants.ConvPadding, str, Sequence[Tuple[int, int]], Tuple[int, int], int]) – padding of the input tensor, either ConvPadding.SAME, ConvPadding.VALID or numerical values.
use_bias (bool) – if True, the convolution has a bias term.
w_init (Callable) – initializer for the convolution kernel (a function that takes an HWIO shape and returns a 4D tensor).
-
__call__(x)
Returns the result of applying the convolution to input x.
- Parameters
x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]
-
-
class objax.nn.ConvTranspose2D(nin, nout, k, strides=1, dilations=1, padding=<ConvPadding.SAME: 'SAME'>, use_bias=True, w_init=<function kaiming_normal>)
Applies a 2D transposed convolution to a 4D input batch of shape (N,C,H,W).
This module can be seen as a transformation going in the opposite direction of a normal convolution, i.e., from something that has the shape of the output of some convolution to something that has the shape of its input, while maintaining a connectivity pattern that is compatible with that convolution. Note that ConvTranspose2D is consistent with TensorFlow's Conv2DTranspose but not with PyTorch's ConvTranspose2d, because the two differ in kernel transposition and padding.
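The "opposite direction" can be made concrete. A transposed convolution is the adjoint (the gradient with respect to the input) of a strided convolution, so it scatters each value back over the window that produced it. The NumPy sketch below uses that textbook definition with VALID padding and a square kernel. It illustrates the shapes and the connectivity pattern only and does not claim to reproduce Objax's kernel orientation or padding conventions. Both function names are hypothetical.

```python
import numpy as np


def conv2d(x, w, stride):
    """VALID strided cross-correlation: (N, C_in, H, W) -> (N, C_out, H_out, W_out)."""
    k, c_out = w.shape[0], w.shape[3]  # w: (k, k, C_in, C_out), square kernel assumed
    n, _, h, wd = x.shape
    h_out, w_out = (h - k) // stride + 1, (wd - k) // stride + 1
    out = np.zeros((n, c_out, h_out, w_out))
    for a in range(h_out):
        for b in range(w_out):
            patch = x[:, :, a * stride:a * stride + k, b * stride:b * stride + k]
            out[:, :, a, b] = np.einsum("ntij,ijtc->nc", patch, w)
    return out


def conv2d_transpose(y, w, stride, out_hw):
    """Adjoint of conv2d: scatter each output value back over the window it came from."""
    k, c_in = w.shape[0], w.shape[2]
    n, _, h_out, w_out = y.shape
    x = np.zeros((n, c_in) + tuple(out_hw))
    for a in range(h_out):
        for b in range(w_out):
            contribution = np.einsum("nc,ijtc->ntij", y[:, :, a, b], w)  # (N, C_in, k, k)
            x[:, :, a * stride:a * stride + k, b * stride:b * stride + k] += contribution
    return x


rng = np.random.default_rng(0)
x = rng.normal(size=(2, 3, 7, 7))   # (N, C_in, H, W)
w = rng.normal(size=(3, 3, 3, 4))   # (k, k, C_in, C_out)
y = rng.normal(size=(2, 4, 3, 3))   # shaped like conv2d(x, w, stride=2)
x_back = conv2d_transpose(y, w, stride=2, out_hw=(7, 7))
print(x_back.shape)  # (2, 3, 7, 7): the input shape is restored
# Adjoint identity <conv2d(x), y> == <x, conv2d_transpose(y)>.
print(np.isclose(np.sum(conv2d(x, w, 2) * y), np.sum(x * x_back)))  # True
```

The explicit out_hw argument exists because the inverse shape is ambiguous with strides. For example, 7×7 and 8×8 inputs both produce a 3×3 output at stride 2.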
-
__init__(nin, nout, k, strides=1, dilations=1, padding=<ConvPadding.SAME: 'SAME'>, use_bias=True, w_init=<function kaiming_normal>)
Creates a ConvTranspose2D module instance.
- Parameters
nin (int) – number of channels of the input tensor.
nout (int) – number of channels of the output tensor.
k (Union[Tuple[int, int], int]) – size of the convolution kernel, either a tuple (height, width) or a single number if they’re the same.
strides (Union[Tuple[int, int], int]) – convolution strides, either a tuple (stride_y, stride_x) or a single number if they’re the same.
dilations (Union[Tuple[int, int], int]) – spacing between kernel points (also known as atrous convolution), either a tuple (dilation_y, dilation_x) or a single number if they’re the same.
padding (Union[objax.constants.ConvPadding, str, Sequence[Tuple[int, int]], Tuple[int, int], int]) – padding of the input tensor, either ConvPadding.SAME, ConvPadding.VALID or numerical values.
use_bias (bool) – if True, the convolution has a bias term.
w_init (Callable) – initializer for the convolution kernel (a function that takes an HWIO shape and returns a 4D tensor).
-
__call__(x)
Returns the result of applying the transposed convolution to input x.
- Parameters
x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]
-
-
class objax.nn.Dropout(keep, generator=objax.random.Generator(seed=0))
In the training phase, a dropout layer zeroes each element of the input tensor with probability 1-keep and scales the remaining elements by a factor of 1/keep.
During evaluation, the module does not modify the input tensor. Dropout (Improving neural networks by preventing co-adaptation of feature detectors) is an effective regularization technique that reduces overfitting.
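The behaviour described above (zero with probability 1-keep, scale survivors by 1/keep, pass through during evaluation) can be sketched in NumPy as follows. The helper name dropout and the use of a NumPy Generator in place of objax.random.Generator are assumptions made for this illustration.

```python
import numpy as np


def dropout(x, keep, training, rng):
    """Inverted dropout as described above (hypothetical helper, not objax code)."""
    if not training:
        return x  # evaluation: input is returned unchanged
    mask = rng.random(x.shape) < keep  # each element kept with probability keep
    return np.where(mask, x / keep, 0.0)  # kept elements are scaled by 1/keep


rng = np.random.default_rng(0)
x = np.ones((1000, 100))
y = dropout(x, keep=0.8, training=True, rng=rng)
print((y == 0).mean())  # fraction zeroed, close to 1 - keep = 0.2
print(y.mean())         # close to 1: scaling by 1/keep preserves the expected value
```

Scaling during training rather than during evaluation is why the evaluation path can simply return the input unchanged.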
-
__init__(keep, generator=objax.random.Generator(seed=0))
Creates a Dropout module instance.
- Parameters
keep (float) – probability of keeping each element of the tensor.
generator – optional Objax random generator (objax.random.Generator) instance.
-
__call__(x, training, dropout_keep=None)
Performs dropout on the input tensor.
- Parameters
x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.
training (bool) – if True, apply dropout to the input; otherwise, return the input tensor unchanged.
dropout_keep (Optional[float]) – optional argument; when set, it overrides the keep probability.
- Returns
Tensor with applied dropout.
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]
-
-
class objax.nn.Linear(nin, nout, use_bias=True, w_init=<function xavier_normal>)
Applies a linear transformation to an input batch.
The output tensor \((N,C_{out})\) is computed from an input tensor \((N,C_{in})\) with kernel weight \((C_{in},C_{out})\) and bias \((C_{out})\) as follows:
\[\mathrm{out}[n,c] = \mathrm{b}[c] + \sum_{t=0}^{C_{in}-1} \mathrm{in}[n,t] \times \mathrm{w}[t,c]\]
-
__init__(nin, nout, use_bias=True, w_init=<function xavier_normal>)
Creates a Linear module instance.
- Parameters
nin (int) – number of channels of the input tensor.
nout (int) – number of channels of the output tensor.
use_bias (bool) – if True, the linear layer has a bias term.
w_init (Callable) – weight initializer for the linear layer (a function that takes an IO shape and returns a 2D matrix).
-
__call__(x)
Returns the result of applying the linear transformation to input x.
- Parameters
x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]
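With the (C_in, C_out) weight layout described for Linear, the whole sum in its formula is a single matrix product. The NumPy sketch below uses a hypothetical helper name and is not Objax's code.

```python
import numpy as np


def linear(x, w, b=None):
    """out[n, c] = b[c] + sum_t x[n, t] * w[t, c], for x: (N, C_in), w: (C_in, C_out)."""
    out = x @ w  # the matrix product performs the sum over the input channel t
    return out if b is None else out + b  # bias broadcasts over the batch axis


rng = np.random.default_rng(0)
x = rng.normal(size=(10, 2))
w = rng.normal(size=(2, 3))
b = np.zeros(3)
print(linear(x, w, b).shape)  # (10, 3)
```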
-
-
class objax.nn.MovingAverage(shape, buffer_size, init_value=0)
Computes the moving average of an input batch.
-
__init__(shape, buffer_size, init_value=0)
Creates a MovingAverage module instance.
- Parameters
shape (Tuple[int, ..]) – shape of the input tensor.
buffer_size (int) – buffer size for the moving average.
init_value (float) – initial value of the moving-average buffer.
-
__call__(x)
Updates the statistics using x and returns the moving average.
- Parameters
x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]
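The entry above does not spell out how buffer_size and init_value interact. One plausible reading is a window of the last buffer_size inputs that starts out filled with init_value. Each call drops the oldest entry, appends x and returns the mean of the window. The sketch below implements that reading. It is an assumption, not a description of Objax's code, and the class name is hypothetical.

```python
import numpy as np


class MovingAverageSketch:
    """Hypothetical windowed moving average (an assumed reading of MovingAverage)."""

    def __init__(self, shape, buffer_size, init_value=0):
        # Window of the most recent inputs, initially filled with init_value.
        self.buffer = np.full((buffer_size,) + tuple(shape), init_value, dtype=float)

    def __call__(self, x):
        # Drop the oldest entry, append the newest input, and average the window.
        self.buffer = np.concatenate([self.buffer[1:], np.asarray(x, dtype=float)[None]])
        return self.buffer.mean(axis=0)


avg = MovingAverageSketch(shape=(), buffer_size=3)
print([float(avg(v)) for v in (3.0, 6.0, 9.0, 12.0)])  # [1.0, 3.0, 6.0, 9.0]
```

Under this reading the first buffer_size - 1 outputs are pulled toward init_value, which is why the first result above is 1.0 rather than 3.0.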
-
-
class objax.nn.ExponentialMovingAverage(shape, momentum=0.999, init_value=0)
Computes the exponential moving average (also called EMA or EWMA) of an input batch.
\[x_{\mathrm{EMA}} \leftarrow \mathrm{momentum} \times x_{\mathrm{EMA}} + (1-\mathrm{momentum}) \times x\]
-
__init__(shape, momentum=0.999, init_value=0)
Creates an ExponentialMovingAverage module instance.
- Parameters
shape (Tuple[int, ..]) – shape of the input tensor.
momentum (float) – factor by which the accumulated value decays at each update.
init_value (float) – initial value of the exponential moving average.
-
__call__(x)
Updates the statistics using x and returns the exponential moving average.
- Parameters
x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]
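The update rule for ExponentialMovingAverage can be traced with a scalar, plain-Python sketch (the helper name is hypothetical). After n updates with a constant input x, the value equals x(1 − momentum^n) + init_value × momentum^n. With the default momentum=0.999, init_value therefore still carries a weight of about 0.37 after 1000 calls.

```python
def ema_updates(xs, momentum=0.999, init_value=0.0):
    """Apply x_ema <- momentum * x_ema + (1 - momentum) * x for each x in xs."""
    x_ema = init_value
    history = []
    for x in xs:
        x_ema = momentum * x_ema + (1 - momentum) * x
        history.append(x_ema)
    return history


print(ema_updates([4.0, 4.0, 4.0], momentum=0.5))  # [2.0, 3.0, 3.5]
# Constant input: the weight momentum**n stays on init_value (0 here).
n = 1000
print(ema_updates([1.0] * n)[-1], 1 - 0.999 ** n)  # both about 0.632
```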
-
-
class objax.nn.Sequential(iterable=(), /)
Executes modules in the order they were passed to the constructor.
Usage example:
```python
import objax

ml = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu,
                          objax.nn.Linear(3, 4)])
x = objax.random.normal((10, 2))
y = ml(x)  # Runs all the operations (Linear -> ReLU -> Linear).
print(y.shape)  # (10, 4)

# objax.nn.Sequential is really a list.
ml.insert(2, objax.nn.BatchNorm0D(3))  # Add a batch norm layer after ReLU
ml.append(objax.nn.Dropout(keep=0.5))  # Add a dropout layer at the end
y = ml(x, training=False)  # Both batch norm and dropout expect a training argument.
# Sequential automatically passes arguments only to the modules that use them.

# You can run a subset of operations since it is a list.
y1 = ml[:2](x)  # Run the first two layers (Linear -> ReLU)
y2 = ml[2:](y1, training=False)  # Run all layers from the third on (BatchNorm0D -> Linear -> Dropout)
print(ml(x, training=False) - y2)  # [[0. 0. ...]] - results are the same.

print(ml.vars())
# (Sequential)[0](Linear).b                        3 (3,)
# (Sequential)[0](Linear).w                        6 (2, 3)
# (Sequential)[2](BatchNorm0D).running_mean        3 (1, 3)
# (Sequential)[2](BatchNorm0D).running_var         3 (1, 3)
# (Sequential)[2](BatchNorm0D).beta                3 (1, 3)
# (Sequential)[2](BatchNorm0D).gamma               3 (1, 3)
# (Sequential)[3](Linear).b                        4 (4,)
# (Sequential)[3](Linear).w                       12 (3, 4)
# (Sequential)[4](Dropout).keygen(Generator)._key  2 (2,)
# +Total(9)                                       39
```
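The idea that Sequential passes keyword arguments only to the modules that use them can be illustrated with a plain-Python sketch that inspects each callable's signature. This mechanism is an assumption made for illustration. SequentialSketch and accepted_kwargs are hypothetical names, not Objax APIs.

```python
import inspect


def accepted_kwargs(fn, kwargs):
    """Return the subset of kwargs that fn's signature can accept."""
    params = inspect.signature(fn).parameters
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(kwargs)  # fn takes **kwargs, so pass everything
    return {k: v for k, v in kwargs.items() if k in params}


class SequentialSketch(list):
    """A list of callables applied in order (hypothetical, for illustration)."""

    def __call__(self, *args, **kwargs):
        for fn in self:
            args = fn(*args, **accepted_kwargs(fn, kwargs))
            if not isinstance(args, tuple):
                args = (args,)  # normalize single outputs so they can be re-splatted
        return args[0] if len(args) == 1 else args


def double(x):
    return 2 * x


def flip_in_training(x, training):
    return -x if training else x


seq = SequentialSketch([double, flip_in_training])
print(seq(3, training=True), seq(3, training=False))  # -6 6
```

In this sketch, a layer that returns a tuple has its elements passed as separate positional arguments to the next layer. That is one reasonable convention for chaining multi-output modules.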
-
__call__(*args, **kwargs)
Executes the sequence of operations on *args and **kwargs and returns the result.
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray, List[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]]
-
__init__(*args, **kwargs)
Initialize self. See help(type(self)) for accurate signature.
-
append(object, /)
Append object to the end of the list.
-
clear()
Remove all items from list.
-
copy()
Return a shallow copy of the list.
-
count(value, /)
Return number of occurrences of value.
-
extend(iterable, /)
Extend list by appending elements from the iterable.
-
index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
-
insert(index, object, /)
Insert object before index.
-
pop(index=-1, /)
Remove and return item at index (default last).
Raises IndexError if list is empty or index is out of range.
-
remove(value, /)
Remove first occurrence of value.
Raises ValueError if the value is not present.
-
reverse()
Reverse IN PLACE.
-
vars(scope='')
Collects all the variables (and their names) contained in the list and its submodules.
- Parameters
scope (str) – string to prefix to the variable names.
- Returns
A VarCollection of all the variables.
- Return type
objax.variable.VarCollection
-
-
class objax.nn.SyncedBatchNorm(dims, redux, momentum=0.999, eps=1e-06)
Synchronized batch normalization that aggregates batch statistics across all devices (GPUs/TPUs).
-
__call__(x, training, batch_norm_update=True)
Performs batch normalization of the input tensor.
- Parameters
x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.
training (bool) – if True, compute batch normalization in training mode (accumulating batch statistics); otherwise, compute it in evaluation mode (using the already accumulated batch statistics).
batch_norm_update (bool) –
- Returns
Batch-normalized tensor.
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]
-
__init__(dims, redux, momentum=0.999, eps=1e-06)
Creates a SyncedBatchNorm module instance.
- Parameters
dims (Iterable[int]) – shape of the batch normalization state variables.
redux (Iterable[int]) – list of indices of reduction axes. Batch norm statistics are computed by averaging over these axes.
momentum (float) – value used to compute the exponential moving average of batch statistics.
eps (float) – small value used for numerical stability.
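Synchronization can be illustrated by simulating devices with NumPy shards. Each device computes statistics on its own shard, and averaging those statistics across devices (the role a cross-device mean such as jax.lax.pmean plays under jax.pmap) recovers the statistics of the full batch. This sketch assumes equal-sized shards, and the way the variance is assembled here is an assumption rather than a description of Objax's internals.

```python
import numpy as np

rng = np.random.default_rng(0)
full_batch = rng.normal(size=(32, 3))  # (N, C) input, as for SyncedBatchNorm0D
shards = np.split(full_batch, 4)       # 4 simulated devices, 8 examples each

# Per-device statistics differ from each other and from the full batch.
local_means = [s.mean(axis=0) for s in shards]

# Synchronized statistics: average the per-device results across devices.
synced_mean = np.mean(local_means, axis=0)
synced_var = np.mean([((s - synced_mean) ** 2).mean(axis=0) for s in shards], axis=0)

print(np.allclose(synced_mean, full_batch.mean(axis=0)))  # True
print(np.allclose(synced_var, full_batch.var(axis=0)))    # True
```

Without synchronization, each device normalizes with its own noisier 8-example statistics.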
-
-
class objax.nn.SyncedBatchNorm0D(nin, momentum=0.999, eps=1e-06)
Applies 0D synchronized batch normalization to a 2D input batch of shape (N,C).
Synchronized batch normalization aggregates batch statistics across all devices (GPUs/TPUs) on each call. Compared to regular batch normalization, this usually leads to better accuracy at a slight performance cost.
-
__init__(nin, momentum=0.999, eps=1e-06)
Creates a SyncedBatchNorm0D module instance.
- Parameters
nin (int) – number of channels in the input example.
momentum (float) – value used to compute the exponential moving average of batch statistics.
eps (float) – small value used for numerical stability.
-
__call__(x, training, batch_norm_update=True)
Performs batch normalization of the input tensor.
- Parameters
x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.
training (bool) – if True, compute batch normalization in training mode (accumulating batch statistics); otherwise, compute it in evaluation mode (using the already accumulated batch statistics).
batch_norm_update (bool) –
- Returns
Batch-normalized tensor.
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]
-
-
class objax.nn.SyncedBatchNorm1D(nin, momentum=0.999, eps=1e-06)
Applies 1D synchronized batch normalization to a 3D input batch of shape (N,C,L).
Synchronized batch normalization aggregates batch statistics across all devices (GPUs/TPUs) on each call. Compared to regular batch normalization, this usually leads to better accuracy at a slight performance cost.
-
__init__(nin, momentum=0.999, eps=1e-06)
Creates a SyncedBatchNorm1D module instance.
- Parameters
nin (int) – number of channels in the input example.
momentum (float) – value used to compute the exponential moving average of batch statistics.
eps (float) – small value used for numerical stability.
-
__call__(x, training, batch_norm_update=True)
Performs batch normalization of the input tensor.
- Parameters
x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.
training (bool) – if True, compute batch normalization in training mode (accumulating batch statistics); otherwise, compute it in evaluation mode (using the already accumulated batch statistics).
batch_norm_update (bool) –
- Returns
Batch-normalized tensor.
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]
-
-
class objax.nn.SyncedBatchNorm2D(nin, momentum=0.999, eps=1e-06)
Applies 2D synchronized batch normalization to a 4D input batch of shape (N,C,H,W).
Synchronized batch normalization aggregates batch statistics across all devices (GPUs/TPUs) on each call. Compared to regular batch normalization, this usually leads to better accuracy at a slight performance cost.
-
__init__(nin, momentum=0.999, eps=1e-06)
Creates a SyncedBatchNorm2D module instance.
- Parameters
nin (int) – number of channels in the input example.
momentum (float) – value used to compute the exponential moving average of batch statistics.
eps (float) – small value used for numerical stability.
-
__call__(x, training, batch_norm_update=True)
Performs batch normalization of the input tensor.
- Parameters
x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.
training (bool) – if True, compute batch normalization in training mode (accumulating batch statistics); otherwise, compute it in evaluation mode (using the already accumulated batch statistics).
batch_norm_update (bool) –
- Returns
Batch-normalized tensor.
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]
-
objax.nn.init
| objax.nn.init.gain_leaky_relu | The recommended gain value for leaky_relu. |
| objax.nn.init.identity | Returns the identity matrix. |
| objax.nn.init.kaiming_normal_gain | Returns the Kaiming He gain from Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification. |
| objax.nn.init.kaiming_normal | Returns a tensor with values assigned using the Kaiming He normal initializer from Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification. |
| objax.nn.init.kaiming_truncated_normal | Returns a tensor with values assigned using the Kaiming He truncated normal initializer from Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification. |
| objax.nn.init.orthogonal | Returns a uniformly distributed orthogonal tensor from Exact solutions to the nonlinear dynamics of learning in deep linear neural networks. |
| objax.nn.init.truncated_normal | Returns a tensor with values assigned using truncated normal initialization. |
| objax.nn.init.xavier_normal_gain | Returns the Xavier Glorot gain from Understanding the difficulty of training deep feedforward neural networks. |
| objax.nn.init.xavier_normal | Returns a tensor with values assigned using the Xavier Glorot normal initializer from Understanding the difficulty of training deep feedforward neural networks. |
| objax.nn.init.xavier_truncated_normal | Returns a tensor with values assigned using the Xavier Glorot truncated normal initializer from Understanding the difficulty of training deep feedforward neural networks. |
-
objax.nn.init.gain_leaky_relu(relu_slope=0.1)
The recommended gain value for leaky_relu.
- Parameters
relu_slope – negative slope of leaky_relu.
- Returns
The recommended gain value for leaky_relu.
The returned gain value is
\[\sqrt{\frac{2}{1 + \text{relu\_slope}^2}}.\]
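The formula above has two instructive end points. A slope of 0 (plain ReLU) gives He's gain √2, and a slope of 1 (leaky_relu becomes the identity) gives 1. A plain-Python transcription follows; the name leaky_relu_gain is hypothetical.

```python
import math


def leaky_relu_gain(relu_slope=0.1):
    """sqrt(2 / (1 + relu_slope**2)), transcribed from the formula above."""
    return math.sqrt(2.0 / (1.0 + relu_slope ** 2))


print(leaky_relu_gain(0.0))  # about 1.4142 (sqrt(2)), the usual ReLU gain
print(leaky_relu_gain(1.0))  # 1.0: with slope 1, leaky_relu is the identity
print(leaky_relu_gain())     # about 1.4072 for the default slope 0.1
```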
-
objax.nn.init.identity(shape, gain=1)
Returns the identity matrix. This initializer was proposed in A Simple Way to Initialize Recurrent Networks of Rectified Linear Units.
- Parameters
shape – shape of the tensor; it must have rank exactly 2.
gain – optional scaling factor.
- Returns
Tensor initialized to the identity matrix.
-
objax.nn.init.kaiming_normal_gain(shape)
Returns the Kaiming He gain from Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification.
- Parameters
shape – shape of the output tensor.
- Returns
Scalar, the standard deviation gain.
The returned gain value is
\[\sqrt{\frac{1}{\text{fan\_in}}}.\]
-
objax.nn.init.kaiming_normal(shape, gain=1)
Returns a tensor with values assigned using the Kaiming He normal initializer from Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification.
- Parameters
shape – shape of the output tensor.
gain – optional scaling factor.
- Returns
Tensor initialized with normal random variables with standard deviation gain * kaiming_normal_gain(shape).
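The following NumPy sketch of kaiming_normal rests on one stated assumption: fan_in is taken as the product of all dimensions except the last, which fits the HWIO kernel and IO weight layouts used by Conv2D and Linear. Because kaiming_normal_gain is documented as sqrt(1/fan_in) rather than He et al.'s ReLU value sqrt(2/fan_in), passing gain=√2 (the value of the leaky_relu gain formula at relu_slope=0) recovers the ReLU variant. The helper name is hypothetical.

```python
import math

import numpy as np


def kaiming_normal_sketch(shape, gain=1.0, rng=None):
    """Normal samples with std = gain * sqrt(1 / fan_in) (hypothetical helper)."""
    rng = np.random.default_rng(0) if rng is None else rng
    fan_in = math.prod(shape[:-1])  # assumption: every axis but the last feeds one output
    return rng.normal(scale=gain * math.sqrt(1.0 / fan_in), size=shape)


w = kaiming_normal_sketch((3, 3, 64, 128), gain=math.sqrt(2.0))  # HWIO conv kernel
print(w.std(), math.sqrt(2.0 / (3 * 3 * 64)))  # sample std is close to sqrt(2 / fan_in)
```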
-
objax.nn.init.kaiming_truncated_normal(shape, lower=-2, upper=2, gain=1)
Returns a tensor with values assigned using the Kaiming He truncated normal initializer from Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification.
- Parameters
shape – shape of the output tensor.
lower – lower truncation of the normal.
upper – upper truncation of the normal.
gain – optional scaling factor.
- Returns
Tensor initialized with truncated normal random variables with standard deviation gain * kaiming_normal_gain(shape) and support [lower, upper].
-
objax.nn.init.orthogonal(shape, gain=1, axis=-1)
Returns a uniformly distributed orthogonal tensor from Exact solutions to the nonlinear dynamics of learning in deep linear neural networks.
- Parameters
shape – shape of the output tensor.
gain – optional scaling factor.
axis – the orthogonalization axis.
- Returns
An orthogonally initialized tensor. These tensors will be row-orthonormal along the axis specified by axis. If the rank of the weight is greater than 2, the shape will be flattened in all other dimensions and then will be row-orthonormal along the final dimension. Note that this only works if the axis dimension is larger; otherwise, the tensor will be transposed (equivalently, it will be column-orthonormal instead of row-orthonormal). If the shape is not square, the matrices will have orthonormal rows or columns depending on which side is smaller.
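A common way to draw uniformly distributed (Haar) orthogonal matrices is a QR decomposition of a Gaussian matrix followed by a sign correction. The NumPy sketch below covers the rank-2 case only and is not claimed to be Objax's implementation. The name orthogonal_2d is hypothetical. It makes the rows orthonormal when there are fewer rows than columns, and the columns orthonormal otherwise.

```python
import numpy as np


def orthogonal_2d(n_rows, n_cols, gain=1.0, rng=None):
    """Random (n_rows, n_cols) matrix with orthonormal rows or columns, scaled by gain."""
    rng = np.random.default_rng(0) if rng is None else rng
    big, small = max(n_rows, n_cols), min(n_rows, n_cols)
    q, r = np.linalg.qr(rng.standard_normal((big, small)))  # q: (big, small)
    q = q * np.sign(np.diag(r))  # sign fix so q is uniformly (Haar) distributed
    return gain * (q.T if n_rows < n_cols else q)


wide = orthogonal_2d(3, 5)
tall = orthogonal_2d(5, 3)
print(np.allclose(wide @ wide.T, np.eye(3)))  # True: rows are orthonormal
print(np.allclose(tall.T @ tall, np.eye(3)))  # True: columns are orthonormal
```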
-
objax.nn.init.truncated_normal(shape, lower=-2, upper=2, stddev=1)
Returns a tensor with values assigned using truncated normal initialization.
- Parameters
shape – shape of the output tensor.
lower – lower truncation of the normal.
upper – upper truncation of the normal.
stddev – expected standard deviation.
- Returns
Tensor initialized with truncated normal random variables with standard deviation stddev and support [lower, upper].
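Truncation shrinks the spread of a normal distribution. Meeting the "expected standard deviation" wording of the stddev parameter therefore requires a correction. The sketch below computes the standard deviation of a standard normal truncated to [lower, upper] in closed form (about 0.88 for the default [-2, 2]) and divides it out. It treats lower and upper as bounds on the standard normal before rescaling, which is an assumption. The function names are hypothetical, and this is not presented as Objax's implementation.

```python
import math

import numpy as np


def std_of_truncated_standard_normal(lower=-2.0, upper=2.0):
    """Standard deviation of N(0, 1) restricted to [lower, upper]."""
    def pdf(t):
        return math.exp(-t * t / 2) / math.sqrt(2 * math.pi)

    def cdf(t):
        return 0.5 * (1 + math.erf(t / math.sqrt(2)))

    z = cdf(upper) - cdf(lower)           # probability mass kept by truncation
    mean = (pdf(lower) - pdf(upper)) / z  # mean of the truncated distribution
    var = 1 + (lower * pdf(lower) - upper * pdf(upper)) / z - mean ** 2
    return math.sqrt(var)


def truncated_normal_sketch(shape, lower=-2.0, upper=2.0, stddev=1.0, rng=None):
    """Rejection-sample N(0, 1) on [lower, upper], then rescale the std to stddev."""
    rng = np.random.default_rng(0) if rng is None else rng
    size = math.prod(shape)
    samples = np.empty(0)
    while samples.size < size:
        draw = rng.standard_normal(2 * size)
        samples = np.concatenate([samples, draw[(draw >= lower) & (draw <= upper)]])
    scale = stddev / std_of_truncated_standard_normal(lower, upper)
    return scale * samples[:size].reshape(shape)


print(std_of_truncated_standard_normal())  # about 0.8796: truncation shrinks the std
x = truncated_normal_sketch((500, 400), stddev=0.5)
print(x.std())  # close to 0.5
```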
-
objax.nn.init.xavier_normal_gain(shape)
Returns the Xavier Glorot gain from Understanding the difficulty of training deep feedforward neural networks.
- Parameters
shape – shape of the output tensor.
- Returns
Scalar, the standard deviation gain.
The returned gain value is
\[\sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}}.\]
-
objax.nn.init.xavier_normal(shape, gain=1)
Returns a tensor with values assigned using the Xavier Glorot normal initializer from Understanding the difficulty of training deep feedforward neural networks.
- Parameters
shape – shape of the output tensor.
gain – optional scaling factor.
- Returns
Tensor initialized with normal random variables with standard deviation gain * xavier_normal_gain(shape).
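The following NumPy sketch of xavier_normal assumes a particular fan convention: fan_in is the product of all axes but the last, and fan_out is the last axis, following the IO and HWIO layouts used by Linear and Conv2D. Conventions differ between libraries. PyTorch, for example, also multiplies fan_out by the kernel area for convolutions, so the two give different values for Conv2D kernels. Helper names are hypothetical.

```python
import math

import numpy as np


def xavier_gain_sketch(shape):
    """sqrt(2 / (fan_in + fan_out)) under an assumed fan convention."""
    fan_in = math.prod(shape[:-1])  # assumption: all axes but the last are inputs
    fan_out = shape[-1]             # assumption: the last axis indexes outputs
    return math.sqrt(2.0 / (fan_in + fan_out))


def xavier_normal_sketch(shape, gain=1.0, rng=None):
    """Normal samples with std = gain * xavier_gain_sketch(shape)."""
    rng = np.random.default_rng(0) if rng is None else rng
    return rng.normal(scale=gain * xavier_gain_sketch(shape), size=shape)


print(xavier_gain_sketch((2, 3)))          # Linear(2, 3) IO weight: sqrt(2 / 5)
print(xavier_gain_sketch((3, 3, 16, 32)))  # HWIO kernel: sqrt(2 / (144 + 32))
```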
-
objax.nn.init.xavier_truncated_normal(shape, lower=-2, upper=2, gain=1)
Returns a tensor with values assigned using the Xavier Glorot truncated normal initializer from Understanding the difficulty of training deep feedforward neural networks.
- Parameters
shape – shape of the output tensor.
lower – lower truncation of the normal.
upper – upper truncation of the normal.
gain – optional scaling factor.
- Returns
Tensor initialized with truncated normal random variables with standard deviation gain * xavier_normal_gain(shape) and support [lower, upper].