objax.zoo package¶
objax.zoo.convnet¶
class objax.zoo.convnet.ConvNet(nin, nclass, scales, filters, filters_max, pooling=<function max_pool_2d>, **kwargs)[source]¶
ConvNet implementation.

__init__(nin, nclass, scales, filters, filters_max, pooling=<function max_pool_2d>, **kwargs)[source]¶
Creates ConvNet instance.
- Parameters
nin – number of channels in the input image.
nclass – number of output classes.
scales – number of pooling layers, each of which reduces spatial dimension by 2.
filters – base number of convolution filters. The number of filters is doubled at each scale until it reaches filters_max.
filters_max – maximum number of filters.
pooling – type of pooling layer.
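The filter schedule above can be sketched in a few lines of plain Python. This is an illustrative sketch, assuming the filter count doubles at each scale and is capped at filters_max, as described; it is not taken from the library source.

```python
# Sketch of the ConvNet filter schedule (assumption: doubling per scale,
# capped at filters_max, per the parameter description above).
def filters_at_scale(filters, filters_max, scale):
    return min(filters_max, filters * 2 ** scale)

# With filters=32 and filters_max=128, four scales give:
schedule = [filters_at_scale(32, 128, s) for s in range(4)]
print(schedule)  # [32, 64, 128, 128]
```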
-
objax.zoo.dnnet¶
objax.zoo.resnet_v2¶
class objax.zoo.resnet_v2.ResNetV2(in_channels, num_classes, blocks_per_group, bottleneck=True, channels_per_group=(256, 512, 1024, 2048), group_strides=(1, 2, 2, 2), group_use_projection=(True, True, True, True), normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)[source]¶
Base implementation of ResNet v2 from https://arxiv.org/abs/1603.05027.

__init__(in_channels, num_classes, blocks_per_group, bottleneck=True, channels_per_group=(256, 512, 1024, 2048), group_strides=(1, 2, 2, 2), group_use_projection=(True, True, True, True), normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)[source]¶
Creates ResNetV2 instance.
- Parameters
in_channels (int) – number of channels in the input image.
num_classes (int) – number of output classes.
blocks_per_group (Sequence[int]) – number of blocks in each block group.
bottleneck (bool) – if True then use bottleneck blocks.
channels_per_group (Sequence[int]) – number of output channels for each block group.
group_strides (Sequence[int]) – strides for each block group.
normalization_fn (Callable[[..], objax.module.Module]) – module used as the normalization function.
activation_fn (Callable[[JaxArray], JaxArray]) – activation function.
group_use_projection (Sequence[bool]) – whether to use a projection shortcut in each block group.
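The relation between blocks_per_group, bottleneck, and the network's nominal depth can be checked with simple arithmetic. This sketch assumes the standard ResNet counting convention (three convolutions per bottleneck block, two per basic block, plus the stem convolution and classifier layer); it is not code from the library.

```python
# Nominal depth of a ResNet from its block configuration
# (assumption: standard ResNet layer-counting convention).
def resnet_depth(blocks_per_group, bottleneck=True):
    convs_per_block = 3 if bottleneck else 2
    # +2 accounts for the stem convolution and the final classifier layer.
    return sum(blocks_per_group) * convs_per_block + 2

print(resnet_depth((3, 4, 6, 3)))                     # 50
print(resnet_depth((2, 2, 2, 2), bottleneck=False))   # 18
```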
class objax.zoo.resnet_v2.ResNet18(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)[source]¶
Implementation of ResNet v2 with 18 layers.

__init__(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)[source]¶
Creates ResNet18 instance.
- Parameters
in_channels (int) – number of channels in the input image.
num_classes (int) – number of output classes.
normalization_fn (Callable[[..], objax.module.Module]) – module used as the normalization function.
activation_fn (Callable[[JaxArray], JaxArray]) – activation function.
class objax.zoo.resnet_v2.ResNet34(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)[source]¶
Implementation of ResNet v2 with 34 layers.

__init__(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)[source]¶
Creates ResNet34 instance.
- Parameters
in_channels (int) – number of channels in the input image.
num_classes (int) – number of output classes.
normalization_fn (Callable[[..], objax.module.Module]) – module used as the normalization function.
activation_fn (Callable[[JaxArray], JaxArray]) – activation function.
class objax.zoo.resnet_v2.ResNet50(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)[source]¶
Implementation of ResNet v2 with 50 layers.

__init__(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)[source]¶
Creates ResNet50 instance.
- Parameters
in_channels (int) – number of channels in the input image.
num_classes (int) – number of output classes.
normalization_fn (Callable[[..], objax.module.Module]) – module used as the normalization function.
activation_fn (Callable[[JaxArray], JaxArray]) – activation function.
class objax.zoo.resnet_v2.ResNet101(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)[source]¶
Implementation of ResNet v2 with 101 layers.

__init__(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)[source]¶
Creates ResNet101 instance.
- Parameters
in_channels (int) – number of channels in the input image.
num_classes (int) – number of output classes.
normalization_fn (Callable[[..], objax.module.Module]) – module used as the normalization function.
activation_fn (Callable[[JaxArray], JaxArray]) – activation function.
class objax.zoo.resnet_v2.ResNet152(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)[source]¶
Implementation of ResNet v2 with 152 layers.

__init__(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)[source]¶
Creates ResNet152 instance.
- Parameters
in_channels (int) – number of channels in the input image.
num_classes (int) – number of output classes.
normalization_fn (Callable[[..], objax.module.Module]) – module used as the normalization function.
activation_fn (Callable[[JaxArray], JaxArray]) – activation function.
class objax.zoo.resnet_v2.ResNet200(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)[source]¶
Implementation of ResNet v2 with 200 layers.

__init__(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)[source]¶
Creates ResNet200 instance.
- Parameters
in_channels (int) – number of channels in the input image.
num_classes (int) – number of output classes.
normalization_fn (Callable[[..], objax.module.Module]) – module used as the normalization function.
activation_fn (Callable[[JaxArray], JaxArray]) – activation function.
objax.zoo.wide_resnet¶
class objax.zoo.wide_resnet.WRNBlock(nin, nout, stride=1, bn=functools.partial(<class 'objax.nn.layers.BatchNorm2D'>, momentum=0.9, eps=1e-05))[source]¶
WideResNet block.

__init__(nin, nout, stride=1, bn=functools.partial(<class 'objax.nn.layers.BatchNorm2D'>, momentum=0.9, eps=1e-05))[source]¶
Creates WRNBlock instance.
- Parameters
nin (int) – number of input filters.
nout (int) – number of output filters.
stride (int) – stride for convolution and projection convolution in this block.
bn (Callable) – module used as the batch norm function.
__call__(x, training)[source]¶
Forward pass computation for the block.
- Parameters
x (JaxArray) – input tensor.
training (bool) – if True, run in training mode (e.g. update batch norm statistics).
- Return type
JaxArray
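The residual structure of such a block can be sketched with NumPy as a stand-in. This is a toy illustration of a residual block with a projection shortcut, not the library's implementation: batch norm is omitted, the convolutions are replaced by channel-mixing matrices, and all names are hypothetical.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0)

rng = np.random.default_rng(0)
nin, nout = 4, 8
w1 = rng.normal(size=(nin, nout))       # stand-in for the first convolution
w2 = rng.normal(size=(nout, nout))      # stand-in for the second convolution
w_proj = rng.normal(size=(nin, nout))   # stand-in for the projection shortcut

x = rng.normal(size=(2, nin))           # (batch, channels)
h = relu(x)                             # pre-activation (batch norm omitted)
y = relu(h @ w1) @ w2                   # two conv stages
out = x @ w_proj + y                    # projection shortcut + residual branch
print(out.shape)                        # (2, 8)
```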
class objax.zoo.wide_resnet.WideResNetGeneral(nin, nclass, blocks_per_group, width, bn=functools.partial(<class 'objax.nn.layers.BatchNorm2D'>, momentum=0.9, eps=1e-05))[source]¶
Base WideResNet implementation.

static mean_reduce(x)[source]¶
- Parameters
x (JaxArray) –
- Return type
JaxArray

__init__(nin, nclass, blocks_per_group, width, bn=functools.partial(<class 'objax.nn.layers.BatchNorm2D'>, momentum=0.9, eps=1e-05))[source]¶
Creates WideResNetGeneral instance.
- Parameters
nin (int) – number of channels in the input image.
nclass (int) – number of output classes.
blocks_per_group (List[int]) – number of blocks in each block group.
width (int) – multiplier to the number of convolution filters.
bn (Callable) – module used as the batch norm function.
class objax.zoo.wide_resnet.WideResNet(nin, nclass, depth=28, width=2, bn=functools.partial(<class 'objax.nn.layers.BatchNorm2D'>, momentum=0.9, eps=1e-05))[source]¶
WideResNet implementation with 3 groups.

__init__(nin, nclass, depth=28, width=2, bn=functools.partial(<class 'objax.nn.layers.BatchNorm2D'>, momentum=0.9, eps=1e-05))[source]¶
Creates WideResNet instance.
- Parameters
nin (int) – number of channels in the input image.
nclass (int) – number of output classes.
depth (int) – number of convolution layers; (depth - 4) must be divisible by 6.
width (int) – multiplier to the number of convolution filters.
bn (Callable) – module used as the batch norm function.
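The divisibility constraint on depth follows from the standard WideResNet layout, depth = 6n + 4: with 3 groups of n blocks and 2 convolutions per block, 6n convolutions come from the blocks and 4 further layers make up the rest. A quick sanity check (an illustrative sketch under that assumption, not library code):

```python
# Blocks per group for a WideResNet of a given depth
# (assumption: standard WideResNet counting, depth = 6n + 4 with 3 groups).
def wrn_blocks_per_group(depth, num_groups=3):
    assert (depth - 4) % 6 == 0, "(depth - 4) must be divisible by 6"
    n = (depth - 4) // 6
    return [n] * num_groups

blocks = wrn_blocks_per_group(28)
print(blocks)  # [4, 4, 4]
# 2 convolutions per block plus 4 extra layers recover the depth:
print(sum(blocks) * 2 + 4)  # 28
```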
objax.zoo.rnn¶
class objax.zoo.rnn.RNN(nstate, nin, nout, activation=<function _one_to_one_unop.<locals>.<lambda>>, w_init=<function kaiming_normal>)[source]¶
Recurrent Neural Network (RNN) block.

__init__(nstate, nin, nout, activation=<function _one_to_one_unop.<locals>.<lambda>>, w_init=<function kaiming_normal>)[source]¶
Creates an RNN instance.
- Parameters
nstate (int) – number of hidden units.
nin (int) – number of input units.
nout (int) – number of output units.
activation (Callable) – activation function for the hidden layer.
w_init (Callable) – weight initializer for RNN model weights.
__call__(inputs, only_return_final=False)[source]¶
Forward pass through RNN.
- Parameters
inputs (JaxArray) – input tensor with dimensions num_steps, batch_size, vocabulary_size.
only_return_final – if True, return only the output of the last step; otherwise return the outputs of all steps.
- Returns
Output tensor with dimensions num_steps * batch_size, vocabulary_size.
- Return type
JaxArray
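The shapes above can be illustrated with a minimal NumPy sketch of a vanilla (Elman-style) recurrence. This is a generic illustration of the input and output dimensions, not the library's implementation; all weight names are hypothetical.

```python
import numpy as np

num_steps, batch_size, nin, nstate, nout = 5, 3, 7, 4, 7
rng = np.random.default_rng(0)
w_xh = rng.normal(size=(nin, nstate))    # input -> hidden
w_hh = rng.normal(size=(nstate, nstate)) # hidden -> hidden
w_ho = rng.normal(size=(nstate, nout))   # hidden -> output

inputs = rng.normal(size=(num_steps, batch_size, nin))
h = np.zeros((batch_size, nstate))
outs = []
for x_t in inputs:                       # iterate over time steps
    h = np.tanh(x_t @ w_xh + h @ w_hh)   # hidden-state update
    outs.append(h @ w_ho)                # per-step output

# Outputs of all steps are stacked along the first axis:
out = np.concatenate(outs, axis=0)       # (num_steps * batch_size, nout)
print(out.shape)                         # (15, 7)
```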