objax.nn package

objax.nn

BatchNorm(dims, redux[, momentum, eps])

Applies a batch normalization on different ranks of an input tensor.

BatchNorm0D(nin[, momentum, eps])

Applies a 0D batch normalization on a 2D-input batch of shape (N,C).

BatchNorm1D(nin[, momentum, eps])

Applies a 1D batch normalization on a 3D-input batch of shape (N,C,L).

BatchNorm2D(nin[, momentum, eps])

Applies a 2D batch normalization on a 4D-input batch of shape (N,C,H,W).

Conv2D(nin, nout, k[, strides, dilations, …])

Applies a 2D convolution on a 4D-input batch of shape (N,C,H,W).

ConvTranspose2D(nin, nout, k[, strides, …])

Applies a 2D transposed convolution on a 4D-input batch of shape (N,C,H,W).

Dropout(keep[, generator])

In the training phase, a dropout layer zeroes some elements of the input tensor with probability 1-keep and scales the other elements by a factor of 1/keep.

Linear(nin, nout[, use_bias, w_init])

Applies a linear transformation on an input batch.

MovingAverage(shape, buffer_size[, init_value])

Computes moving average of an input batch.

ExponentialMovingAverage(shape[, momentum, …])

Computes the exponential moving average (also called EMA or EWMA) of an input batch.

Sequential([iterable])

Executes modules in the order they were passed to the constructor.

SyncedBatchNorm(dims, redux[, momentum, eps])

Synchronized batch normalization which aggregates batch statistics across all devices (GPUs/TPUs).

SyncedBatchNorm0D(nin[, momentum, eps])

Applies a 0D synchronized batch normalization on a 2D-input batch of shape (N,C).

SyncedBatchNorm1D(nin[, momentum, eps])

Applies a 1D synchronized batch normalization on a 3D-input batch of shape (N,C,L).

SyncedBatchNorm2D(nin[, momentum, eps])

Applies a 2D synchronized batch normalization on a 4D-input batch of shape (N,C,H,W).

class objax.nn.BatchNorm(dims, redux, momentum=0.999, eps=1e-06)[source]

Applies a batch normalization on different ranks of an input tensor.

The module follows the operation described in Algorithm 1 of Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.

\[y = \frac{x-\mathrm{E}[x]}{\sqrt{\mathrm{Var}[x]+\epsilon}} \times \gamma + \beta\]

The mean (\(\mathrm{E}[x]\)) and variance (\(\mathrm{Var}[x]\)) are calculated per specified dimensions and over the mini-batches. \(\beta\) and \(\gamma\) are trainable parameter tensors of shape dims. The elements of \(\beta\) are initialized with zeros and those of \(\gamma\) are initialized with ones.
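
A minimal sketch of direct use (the shapes and tensors below are illustrative): the generic module can reproduce BatchNorm2D(3) by hand, with state variables of shape (1, 3, 1, 1) and statistics reduced over the batch and spatial axes.

import objax

# Per-channel statistics for NCHW inputs with 3 channels:
# state shape (1, nin, 1, 1), reduction over axes N, H, W.
bn = objax.nn.BatchNorm(dims=(1, 3, 1, 1), redux=(0, 2, 3))
x = objax.random.normal((8, 3, 32, 32))
y_train = bn(x, training=True)    # updates the running statistics
y_eval = bn(x, training=False)    # uses the accumulated statistics
print(y_train.shape)              # (8, 3, 32, 32)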

__init__(dims, redux, momentum=0.999, eps=1e-06)[source]

Creates a BatchNorm module instance.

Parameters
  • dims (Iterable[int]) – shape of the batch normalization state variables.

  • redux (Iterable[int]) – list of indices of reduction axes. Batch norm statistics are computed by averaging over these axes.

  • momentum (float) – value used to compute exponential moving average of batch statistics.

  • eps (float) – small value which is used for numerical stability.

__call__(x, training)[source]

Performs batch normalization of input tensor.

Parameters
  • x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.

  • training (bool) – if True compute batch normalization in training mode (accumulating batch statistics), otherwise compute in evaluation mode (using already accumulated batch statistics).

Returns

Batch normalized tensor.

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]

class objax.nn.BatchNorm0D(nin, momentum=0.999, eps=1e-06)[source]

Applies a 0D batch normalization on a 2D-input batch of shape (N,C).

The module follows the operation described in Algorithm 1 of Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.

\[y = \frac{x-\mathrm{E}[x]}{\sqrt{\mathrm{Var}[x]+\epsilon}} \times \gamma + \beta\]

The mean (\(\mathrm{E}[x]\)) and variance (\(\mathrm{Var}[x]\)) are calculated over the mini-batches. \(\beta\) and \(\gamma\) are trainable parameter tensors of shape (1, nin). The elements of \(\beta\) are initialized with zeros and those of \(\gamma\) are initialized with ones.

__init__(nin, momentum=0.999, eps=1e-06)[source]

Creates a BatchNorm0D module instance.

Parameters
  • nin (int) – number of channels in the input example.

  • momentum (float) – value used to compute exponential moving average of batch statistics.

  • eps (float) – small value which is used for numerical stability.

__call__(x, training)

Performs batch normalization of input tensor.

Parameters
  • x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.

  • training (bool) – if True compute batch normalization in training mode (accumulating batch statistics), otherwise compute in evaluation mode (using already accumulated batch statistics).

Returns

Batch normalized tensor.

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]

class objax.nn.BatchNorm1D(nin, momentum=0.999, eps=1e-06)[source]

Applies a 1D batch normalization on a 3D-input batch of shape (N,C,L).

The module follows the operation described in Algorithm 1 of Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.

\[y = \frac{x-\mathrm{E}[x]}{\sqrt{\mathrm{Var}[x]+\epsilon}} \times \gamma + \beta\]

The mean (\(\mathrm{E}[x]\)) and variance (\(\mathrm{Var}[x]\)) are calculated per channel and over the mini-batches. \(\beta\) and \(\gamma\) are trainable parameter tensors of shape (1, nin, 1). The elements of \(\beta\) are initialized with zeros and those of \(\gamma\) are initialized with ones.

__init__(nin, momentum=0.999, eps=1e-06)[source]

Creates a BatchNorm1D module instance.

Parameters
  • nin (int) – number of channels in the input example.

  • momentum (float) – value used to compute exponential moving average of batch statistics.

  • eps (float) – small value which is used for numerical stability.

__call__(x, training)

Performs batch normalization of input tensor.

Parameters
  • x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.

  • training (bool) – if True compute batch normalization in training mode (accumulating batch statistics), otherwise compute in evaluation mode (using already accumulated batch statistics).

Returns

Batch normalized tensor.

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]

class objax.nn.BatchNorm2D(nin, momentum=0.999, eps=1e-06)[source]

Applies a 2D batch normalization on a 4D-input batch of shape (N,C,H,W).

The module follows the operation described in Algorithm 1 of Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.

\[y = \frac{x-\mathrm{E}[x]}{\sqrt{\mathrm{Var}[x]+\epsilon}} \times \gamma + \beta\]

The mean (\(\mathrm{E}[x]\)) and variance (\(\mathrm{Var}[x]\)) are calculated per channel and over the mini-batches. \(\beta\) and \(\gamma\) are trainable parameter tensors of shape (1, nin, 1, 1). The elements of \(\beta\) are initialized with zeros and those of \(\gamma\) are initialized with ones.
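
A minimal usage sketch (shapes are illustrative); the same pattern applies to BatchNorm0D and BatchNorm1D with (N,C) and (N,C,L) inputs respectively.

import objax

bn = objax.nn.BatchNorm2D(nin=3)
x = objax.random.normal((16, 3, 28, 28))   # (N, C, H, W)
y = bn(x, training=True)                   # per-channel normalization
print(y.shape)                             # (16, 3, 28, 28)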

__init__(nin, momentum=0.999, eps=1e-06)[source]

Creates a BatchNorm2D module instance.

Parameters
  • nin (int) – number of channels in the input example.

  • momentum (float) – value used to compute exponential moving average of batch statistics.

  • eps (float) – small value which is used for numerical stability.

__call__(x, training)

Performs batch normalization of input tensor.

Parameters
  • x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.

  • training (bool) – if True compute batch normalization in training mode (accumulating batch statistics), otherwise compute in evaluation mode (using already accumulated batch statistics).

Returns

Batch normalized tensor.

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]

class objax.nn.Conv2D(nin, nout, k, strides=1, dilations=1, groups=1, padding=<ConvPadding.SAME: 'SAME'>, use_bias=True, w_init=<function kaiming_normal>)[source]

Applies a 2D convolution on a 4D-input batch of shape (N,C,H,W).

In the simplest case (strides = 1, padding = VALID), the output tensor \((N,C_{out},H_{out},W_{out})\) is computed from an input tensor \((N,C_{in},H,W)\) with kernel weight \((k,k,C_{in},C_{out})\) and bias \((C_{out})\) as follows:

\[\mathrm{out}[n,c,h,w] = \mathrm{b}[c] + \sum_{t=0}^{C_{in}-1}\sum_{i=0}^{k-1}\sum_{j=0}^{k-1} \mathrm{in}[n,t,i+h,j+w] \times \mathrm{w}[i,j,t,c]\]

where \(H_{out}=H-k+1\), \(W_{out}=W-k+1\). Note that the implementation follows the definition of cross-correlation. When padding = SAME, the input tensor is zero-padded by \(\lfloor\frac{k-1}{2}\rfloor\) on the top and left sides and by \(\lfloor\frac{k}{2}\rfloor\) on the bottom and right sides.
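
The sketch below illustrates the shape arithmetic described above (the sizes are illustrative): with padding='VALID' the spatial output is (H-k+1, W-k+1), while the default SAME padding preserves the spatial size at stride 1.

import objax

x = objax.random.normal((4, 3, 32, 32))                            # (N, C_in, H, W)
conv_same = objax.nn.Conv2D(nin=3, nout=16, k=3)                   # default padding=SAME
conv_valid = objax.nn.Conv2D(nin=3, nout=16, k=3, padding='VALID')
print(conv_same(x).shape)    # (4, 16, 32, 32): spatial size preserved
print(conv_valid(x).shape)   # (4, 16, 30, 30): H-k+1 = W-k+1 = 30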

__init__(nin, nout, k, strides=1, dilations=1, groups=1, padding=<ConvPadding.SAME: 'SAME'>, use_bias=True, w_init=<function kaiming_normal>)[source]

Creates a Conv2D module instance.

Parameters
  • nin (int) – number of channels of the input tensor.

  • nout (int) – number of channels of the output tensor.

  • k (Union[Tuple[int, int], int]) – size of the convolution kernel, either tuple (height, width) or single number if they’re the same.

  • strides (Union[Tuple[int, int], int]) – convolution strides, either tuple (stride_y, stride_x) or single number if they’re the same.

  • dilations (Union[Tuple[int, int], int]) – spacing between kernel points (also known as atrous convolution), either tuple (dilation_y, dilation_x) or single number if they’re the same.

  • groups (int) – number of groups that the input and output channels are split into. When groups > 1 the convolution is applied separately to each group. nin and nout must both be divisible by groups.

  • padding (Union[objax.constants.ConvPadding, str, Sequence[Tuple[int, int]], Tuple[int, int], int]) – padding of the input tensor, either Padding.SAME, Padding.VALID or numerical values.

  • use_bias (bool) – if True then convolution will have bias term.

  • w_init (Callable) – initializer for convolution kernel (a function that takes in a HWIO shape and returns a 4D matrix).

__call__(x)[source]

Returns the results of applying the convolution to input x.

Parameters

x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) –

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]

class objax.nn.ConvTranspose2D(nin, nout, k, strides=1, dilations=1, padding=<ConvPadding.SAME: 'SAME'>, use_bias=True, w_init=<function kaiming_normal>)[source]

Applies a 2D transposed convolution on a 4D-input batch of shape (N,C,H,W).

This module can be seen as a transformation going in the opposite direction of a normal convolution, i.e., from something that has the shape of the output of some convolution to something that has the shape of its input, while maintaining a connectivity pattern that is compatible with said convolution. Note that ConvTranspose2D is consistent with Conv2DTranspose of TensorFlow, but not with ConvTranspose2d of PyTorch, due to differences in kernel transposition and padding.
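
As a hedged sketch (sizes are illustrative), a strided transposed convolution is commonly used for upsampling; with strides=2 and the default SAME padding the spatial dimensions are expected to double.

import objax

x = objax.random.normal((4, 16, 8, 8))                         # (N, C_in, H, W)
up = objax.nn.ConvTranspose2D(nin=16, nout=8, k=3, strides=2)
print(up(x).shape)   # expected (4, 8, 16, 16): spatial dimensions doubled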

__init__(nin, nout, k, strides=1, dilations=1, padding=<ConvPadding.SAME: 'SAME'>, use_bias=True, w_init=<function kaiming_normal>)[source]

Creates a ConvTranspose2D module instance.

Parameters
  • nin (int) – number of channels of the input tensor.

  • nout (int) – number of channels of the output tensor.

  • k (Union[Tuple[int, int], int]) – size of the convolution kernel, either tuple (height, width) or single number if they’re the same.

  • strides (Union[Tuple[int, int], int]) – convolution strides, either tuple (stride_y, stride_x) or single number if they’re the same.

  • dilations (Union[Tuple[int, int], int]) – spacing between kernel points (also known as atrous convolution), either tuple (dilation_y, dilation_x) or single number if they’re the same.

  • padding (Union[objax.constants.ConvPadding, str, Sequence[Tuple[int, int]], Tuple[int, int], int]) – padding of the input tensor, either Padding.SAME, Padding.VALID or numerical values.

  • use_bias (bool) – if True then convolution will have bias term.

  • w_init (Callable) – initializer for convolution kernel (a function that takes in a HWIO shape and returns a 4D matrix).

__call__(x)[source]

Returns the results of applying the transposed convolution to input x.

Parameters

x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) –

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]

class objax.nn.Dropout(keep, generator=objax.random.Generator(seed=0))[source]

In the training phase, a dropout layer zeroes some elements of the input tensor with probability 1-keep and scales the other elements by a factor of 1/keep.

During evaluation, the module does not modify the input tensor. Dropout (Improving neural networks by preventing co-adaptation of feature detectors) is an effective regularization technique that reduces overfitting and typically improves generalization.
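
A minimal sketch of the train/eval behavior (numbers are illustrative): with keep=0.8, roughly 20% of the elements are zeroed during training and the rest are scaled by 1/0.8 = 1.25, while evaluation leaves the input unchanged.

import objax

drop = objax.nn.Dropout(keep=0.8)
x = objax.random.normal((2, 5))
y_train = drop(x, training=True)                     # ~20% of elements zeroed, rest scaled by 1.25
y_eval = drop(x, training=False)                     # identical to x
y_half = drop(x, training=True, dropout_keep=0.5)    # per-call override of the keep probability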

__init__(keep, generator=objax.random.Generator(seed=0))[source]

Creates Dropout module instance.

Parameters
  • keep (float) – probability to keep element of the tensor.

  • generator – optional instance of an Objax random generator.

__call__(x, training, dropout_keep=None)[source]

Performs dropout of input tensor.

Parameters
  • x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.

  • training (bool) – if True then apply dropout to the input, otherwise keep input tensor unchanged.

  • dropout_keep (Optional[float]) – optional argument, when set overrides dropout keep probability.

Returns

Tensor with applied dropout.

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]

class objax.nn.Linear(nin, nout, use_bias=True, w_init=<function xavier_normal>)[source]

Applies a linear transformation on an input batch.

The output tensor \((N,C_{out})\) is computed from an input tensor \((N,C_{in})\) with kernel weight \((C_{in},C_{out})\) and bias \((C_{out})\) as follows:

\[\mathrm{out}[n,c] = \mathrm{b}[c] + \sum_{t=1}^{C_{in}} \mathrm{in}[n,t] \times \mathrm{w}[t,c]\]
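
The formula above is the affine map y = x w + b. As a hedged sketch (the manual computation assumes the usual TrainVar layout with w of shape (nin, nout) and b of shape (nout,), as in the Sequential vars() listing below):

import objax
import jax.numpy as jn

lin = objax.nn.Linear(nin=4, nout=2)
x = objax.random.normal((8, 4))
y = lin(x)
y_manual = x @ lin.w.value + lin.b.value   # same affine map written out by hand
print(jn.allclose(y, y_manual))            # True
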
__init__(nin, nout, use_bias=True, w_init=<function xavier_normal>)[source]

Creates a Linear module instance.

Parameters
  • nin (int) – number of channels of the input tensor.

  • nout (int) – number of channels of the output tensor.

  • use_bias (bool) – if True then linear layer will have bias term.

  • w_init (Callable) – weight initializer for linear layer (a function that takes in an IO shape and returns a 2D matrix).

__call__(x)[source]

Returns the results of applying the linear transformation to input x.

Parameters

x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) –

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]

class objax.nn.MovingAverage(shape, buffer_size, init_value=0)[source]

Computes moving average of an input batch.

__init__(shape, buffer_size, init_value=0)[source]

Creates a MovingAverage module instance.

Parameters
  • shape (Tuple[int, ...]) – shape of the input tensor.

  • buffer_size (int) – buffer size for moving average.

  • init_value (float) – initial value for moving average buffer.

__call__(x)[source]

Update the statistics using x and return the moving average.

Parameters

x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) –

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]
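
A hedged sketch of typical use (shape and buffer size are illustrative): the module keeps a buffer of the last buffer_size inputs (initialized with init_value) and returns their mean on each call.

import objax

ma = objax.nn.MovingAverage(shape=(3,), buffer_size=4)
for step in range(10):
    x = objax.random.normal((3,))
    avg = ma(x)        # mean over the 4 most recent inputs
print(avg.shape)       # (3,)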

class objax.nn.ExponentialMovingAverage(shape, momentum=0.999, init_value=0)[source]

Computes the exponential moving average (also called EMA or EWMA) of an input batch.

\[x_{\mathrm{EMA}} \leftarrow \mathrm{momentum} \times x_{\mathrm{EMA}} + (1-\mathrm{momentum}) \times x\]
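
A hedged numeric sketch of the update rule above (momentum and values are illustrative):

import objax
import jax.numpy as jn

ema = objax.nn.ExponentialMovingAverage(shape=(3,), momentum=0.9, init_value=0)
x = jn.ones(3)
print(ema(x))   # 0.9 * 0.0 + 0.1 * 1.0 = 0.1 for every element
print(ema(x))   # 0.9 * 0.1 + 0.1 * 1.0 = 0.19
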
__init__(shape, momentum=0.999, init_value=0)[source]

Creates an ExponentialMovingAverage module instance.

Parameters
  • shape (Tuple[int, ...]) – shape of the input tensor.

  • momentum (float) – momentum for exponential decrease of accumulated value.

  • init_value (float) – initial value for exponential moving average.

__call__(x)[source]

Update the statistics using x and return the exponential moving average.

Parameters

x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) –

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]

class objax.nn.Sequential(iterable=(), /)[source]

Executes modules in the order they were passed to the constructor.

Usage example:

import objax

ml = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu,
                          objax.nn.Linear(3, 4)])
x = objax.random.normal((10, 2))
y = ml(x)  # Runs all the operations (Linear -> ReLU -> Linear).
print(y.shape)  # (10, 4)

# objax.nn.Sequential is really a list.
ml.insert(2, objax.nn.BatchNorm0D(3))  # Add a batch norm layer after ReLU
ml.append(objax.nn.Dropout(keep=0.5))  # Add a dropout layer at the end
y = ml(x, training=False)  # Both batch norm and dropout expect a training argument.
# Sequential automatically passes arguments to the modules that use them.

# You can run a subset of operations since it is a list.
y1 = ml[:2](x)  # Run first two layers (Linear -> ReLU)
y2 = ml[2:](y1, training=False)  # Run all layers starting from third (BatchNorm0D -> Dropout)
print(ml(x, training=False) - y2)  # [[0. 0. ...]] - results are the same.

print(ml.vars())
# (Sequential)[0](Linear).b                              3 (3,)
# (Sequential)[0](Linear).w                              6 (2, 3)
# (Sequential)[2](BatchNorm0D).running_mean              3 (1, 3)
# (Sequential)[2](BatchNorm0D).running_var               3 (1, 3)
# (Sequential)[2](BatchNorm0D).beta                      3 (1, 3)
# (Sequential)[2](BatchNorm0D).gamma                     3 (1, 3)
# (Sequential)[3](Linear).b                              4 (4,)
# (Sequential)[3](Linear).w                             12 (3, 4)
# (Sequential)[4](Dropout).keygen(Generator)._key        2 (2,)
# +Total(9)                                             39
__call__(*args, **kwargs)[source]

Executes the sequence of operations on the inputs passed via *args and **kwargs and returns the result.

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray, List[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]]

__init__(*args, **kwargs)

Initialize self. See help(type(self)) for accurate signature.

append(object, /)

Append object to the end of the list.

clear()

Remove all items from list.

copy()

Return a shallow copy of the list.

count(value, /)

Return number of occurrences of value.

extend(iterable, /)

Extend list by appending elements from the iterable.

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.

insert(index, object, /)

Insert object before index.

pop(index=-1, /)

Remove and return item at index (default last).

Raises IndexError if list is empty or index is out of range.

remove(value, /)

Remove first occurrence of value.

Raises ValueError if the value is not present.

reverse()

Reverse IN PLACE.

vars(scope='')

Collect all the variables (and their names) contained in the list and its submodules.

Parameters

scope (str) – string to prefix to the variable names.

Returns

A VarCollection of all the variables.

Return type

objax.variable.VarCollection

class objax.nn.SyncedBatchNorm(dims, redux, momentum=0.999, eps=1e-06)[source]

Synchronized batch normalization which aggregates batch statistics across all devices (GPUs/TPUs).
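
As a hedged sketch, the synced variants are intended as drop-in replacements for the regular batch norm layers when a model is replicated across multiple devices (for example with objax.Parallel); the constructor and call signature are unchanged. The model below is purely illustrative.

import objax

model = objax.nn.Sequential([
    objax.nn.Conv2D(3, 16, k=3),
    objax.nn.SyncedBatchNorm2D(16),   # same (x, training) call signature as BatchNorm2D
    objax.functional.relu,
])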

__call__(x, training, batch_norm_update=True)[source]

Performs batch normalization of input tensor.

Parameters
  • x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.

  • training (bool) – if True compute batch normalization in training mode (accumulating batch statistics), otherwise compute in evaluation mode (using already accumulated batch statistics).

  • batch_norm_update (bool) –

Returns

Batch normalized tensor.

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]

__init__(dims, redux, momentum=0.999, eps=1e-06)

Creates a BatchNorm module instance.

Parameters
  • dims (Iterable[int]) – shape of the batch normalization state variables.

  • redux (Iterable[int]) – list of indices of reduction axes. Batch norm statistics are computed by averaging over these axes.

  • momentum (float) – value used to compute exponential moving average of batch statistics.

  • eps (float) – small value which is used for numerical stability.

class objax.nn.SyncedBatchNorm0D(nin, momentum=0.999, eps=1e-06)[source]

Applies a 0D synchronized batch normalization on a 2D-input batch of shape (N,C).

Synchronized batch normalization aggregates batch statistics across all devices (GPUs/TPUs) on each call. Compared to regular batch norm this usually leads to better accuracy at a slight performance cost.

__init__(nin, momentum=0.999, eps=1e-06)[source]

Creates a SyncedBatchNorm0D module instance.

Parameters
  • nin (int) – number of channels in the input example.

  • momentum (float) – value used to compute exponential moving average of batch statistics.

  • eps (float) – small value which is used for numerical stability.

__call__(x, training, batch_norm_update=True)

Performs batch normalization of input tensor.

Parameters
  • x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.

  • training (bool) – if True compute batch normalization in training mode (accumulating batch statistics), otherwise compute in evaluation mode (using already accumulated batch statistics).

  • batch_norm_update (bool) –

Returns

Batch normalized tensor.

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]

class objax.nn.SyncedBatchNorm1D(nin, momentum=0.999, eps=1e-06)[source]

Applies a 1D synchronized batch normalization on a 3D-input batch of shape (N,C,L).

Synchronized batch normalization aggregates batch statistics across all devices (GPUs/TPUs) on each call. Compared to regular batch norm this usually leads to better accuracy at a slight performance cost.

__init__(nin, momentum=0.999, eps=1e-06)[source]

Creates a SyncedBatchNorm1D module instance.

Parameters
  • nin (int) – number of channels in the input example.

  • momentum (float) – value used to compute exponential moving average of batch statistics.

  • eps (float) – small value which is used for numerical stability.

__call__(x, training, batch_norm_update=True)

Performs batch normalization of input tensor.

Parameters
  • x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.

  • training (bool) – if True compute batch normalization in training mode (accumulating batch statistics), otherwise compute in evaluation mode (using already accumulated batch statistics).

  • batch_norm_update (bool) –

Returns

Batch normalized tensor.

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]

class objax.nn.SyncedBatchNorm2D(nin, momentum=0.999, eps=1e-06)[source]

Applies a 2D synchronized batch normalization on a 4D-input batch of shape (N,C,H,W).

Synchronized batch normalization aggregates batch statistics across all devices (GPUs/TPUs) on each call. Compared to regular batch norm this usually leads to better accuracy at a slight performance cost.

__init__(nin, momentum=0.999, eps=1e-06)[source]

Creates a SyncedBatchNorm2D module instance.

Parameters
  • nin (int) – number of channels in the input example.

  • momentum (float) – value used to compute exponential moving average of batch statistics.

  • eps (float) – small value which is used for numerical stability.

__call__(x, training, batch_norm_update=True)

Performs batch normalization of input tensor.

Parameters
  • x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.

  • training (bool) – if True compute batch normalization in training mode (accumulating batch statistics), otherwise compute in evaluation mode (using already accumulated batch statistics).

  • batch_norm_update (bool) –

Returns

Batch normalized tensor.

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]

objax.nn.init

gain_leaky_relu([relu_slope])

The recommended gain value for leaky_relu.

identity(shape[, gain])

Returns the identity matrix.

kaiming_normal_gain(shape)

Returns Kaiming He gain from Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification.

kaiming_normal(shape[, gain])

Returns a tensor with values assigned using Kaiming He normal initializer from Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification.

kaiming_truncated_normal(shape[, lower, …])

Returns a tensor with values assigned using Kaiming He truncated normal initializer from Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification.

orthogonal(shape[, gain, axis])

Returns a uniformly distributed orthogonal tensor from Exact solutions to the nonlinear dynamics of learning in deep linear neural networks.

truncated_normal(shape[, lower, upper, stddev])

Returns a tensor with values assigned using truncated normal initialization.

xavier_normal_gain(shape)

Returns Xavier Glorot gain from Understanding the difficulty of training deep feedforward neural networks.

xavier_normal(shape[, gain])

Returns a tensor with values assigned using Xavier Glorot normal initializer from Understanding the difficulty of training deep feedforward neural networks.

xavier_truncated_normal(shape[, lower, …])

Returns a tensor with values assigned using Xavier Glorot truncated normal initializer from Understanding the difficulty of training deep feedforward neural networks.

class objax.nn.init.gain_leaky_relu(relu_slope=0.1)[source]

The recommended gain value for leaky_relu.

Parameters

relu_slope – negative slope of leaky_relu.

Returns

The recommended gain value for leaky_relu.

The returned gain value is

\[\sqrt{\frac{2}{1 + \text{relu_slope}^2}}.\]
class objax.nn.init.identity(shape, gain=1)[source]

Returns the identity matrix. This initializer was proposed in A Simple Way to Initialize Recurrent Networks of Rectified Linear Units.

Parameters
  • shape – Shape of the tensor. It should have exactly rank 2.

  • gain – optional scaling factor.

Returns

Tensor initialized to the identity matrix.

class objax.nn.init.kaiming_normal_gain(shape)[source]

Returns Kaiming He gain from Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification.

Parameters

shape – shape of the output tensor.

Returns

Scalar, the standard deviation gain.

The returned gain value is

\[\sqrt{\frac{1}{\text{fan_in}}}.\]
class objax.nn.init.kaiming_normal(shape, gain=1)[source]

Returns a tensor with values assigned using Kaiming He normal initializer from Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification.

Parameters
  • shape – shape of the output tensor.

  • gain – optional scaling factor.

Returns

Tensor initialized with normal random variables with standard deviation (gain * kaiming_normal_gain).
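
A hedged sketch of how the initializers are typically used, either called directly with a shape or passed as the w_init argument of a layer (the shapes below are illustrative):

import objax

# Direct call: a (3, 3, 16, 32) HWIO convolution kernel with He-normal values.
w = objax.nn.init.kaiming_normal((3, 3, 16, 32))

# As layer initializers (these happen to be the respective defaults).
conv = objax.nn.Conv2D(nin=16, nout=32, k=3, w_init=objax.nn.init.kaiming_normal)
lin = objax.nn.Linear(nin=16, nout=32, w_init=objax.nn.init.xavier_normal)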

class objax.nn.init.kaiming_truncated_normal(shape, lower=- 2, upper=2, gain=1)[source]

Returns a tensor with values assigned using Kaiming He truncated normal initializer from Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification.

Parameters
  • shape – shape of the output tensor.

  • lower – lower truncation of the normal.

  • upper – upper truncation of the normal.

  • gain – optional scaling factor.

Returns

Tensor initialized with truncated normal random variables with standard deviation (gain * kaiming_normal_gain) and support [lower, upper].

class objax.nn.init.orthogonal(shape, gain=1, axis=- 1)[source]

Returns a uniformly distributed orthogonal tensor from Exact solutions to the nonlinear dynamics of learning in deep linear neural networks.

Parameters
  • shape – shape of the output tensor.

  • gain – optional scaling factor.

  • axis – the orthogonalization axis.

Returns

An orthogonally initialized tensor. These tensors will be row-orthonormal along the axis specified by axis. If the rank of the weight is greater than 2, the shape will be flattened in all other dimensions and then will be row-orthonormal along the final dimension. Note that this only holds if the axis dimension is the larger one; otherwise the tensor will be transposed (equivalently, it will be column-orthonormal instead of row-orthonormal). If the shape is not square, the matrices will have orthonormal rows or columns depending on which side is smaller.
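
A hedged sketch checking the row-orthonormality described above (a small square shape is used so the result is fully orthogonal):

import objax
import jax.numpy as jn

w = objax.nn.init.orthogonal((4, 4))
print(jn.allclose(w @ w.T, jn.eye(4), atol=1e-5))   # expected: True (rows are orthonormal)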

class objax.nn.init.truncated_normal(shape, lower=- 2, upper=2, stddev=1)[source]

Returns a tensor with values assigned using truncated normal initialization.

Parameters
  • shape – shape of the output tensor.

  • lower – lower truncation of the normal.

  • upper – upper truncation of the normal.

  • stddev – expected standard deviation.

Returns

Tensor initialized with truncated normal random variables with standard deviation stddev and support [lower, upper].

class objax.nn.init.xavier_normal_gain(shape)[source]

Returns Xavier Glorot gain from Understanding the difficulty of training deep feedforward neural networks.

Parameters

shape – shape of the output tensor.

Returns

Scalar, the standard deviation gain.

The returned gain value is

\[\sqrt{\frac{2}{\text{fan_in} + \text{fan_out}}}.\]
class objax.nn.init.xavier_normal(shape, gain=1)[source]

Returns a tensor with values assigned using Xavier Glorot normal initializer from Understanding the difficulty of training deep feedforward neural networks.

Parameters
  • shape – shape of the output tensor.

  • gain – optional scaling factor.

Returns

Tensor initialized with normal random variables with standard deviation (gain * xavier_normal_gain).

class objax.nn.init.xavier_truncated_normal(shape, lower=- 2, upper=2, gain=1)[source]

Returns a tensor with values assigned using Xavier Glorot truncated normal initializer from Understanding the difficulty of training deep feedforward neural networks.

Parameters
  • shape – shape of the output tensor.

  • lower – lower truncation of the normal.

  • upper – upper truncation of the normal.

  • gain – optional scaling factor.

Returns

Tensor initialized with truncated normal random variables with standard deviation (gain * xavier_normal_gain) and support [lower, upper].