objax.zoo package

objax.zoo.convnet

class objax.zoo.convnet.ConvNet(nin, nclass, scales, filters, filters_max, pooling=<function max_pool_2d>, **kwargs)

ConvNet implementation.

__init__(nin, nclass, scales, filters, filters_max, pooling=<function max_pool_2d>, **kwargs)

Creates ConvNet instance.

Parameters
  • nin – number of channels in the input image.

  • nclass – number of output classes.

  • scales – number of pooling layers, each of which reduces spatial dimension by 2.

  • filters – base number of convolution filters. The number of filters doubles at every scale until it reaches filters_max.

  • filters_max – maximum number of filters.

  • pooling – pooling function applied at each scale (max_pool_2d by default).
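The filter and resolution schedule described above can be tabulated with a short plain-Python sketch. It assumes two things: the number of filters at scale `i` is `min(filters_max, filters * 2**i)`, and each pooling layer halves height and width with integer division. `convnet_schedule` is a hypothetical helper, not part of objax.

```python
from typing import List, Tuple


def convnet_schedule(filters: int, filters_max: int, scales: int,
                     size: int) -> List[Tuple[int, int, int]]:
    """Return (scale, num_filters, spatial_size) for each scale.

    Hypothetical helper: assumes filters double per scale (capped at
    filters_max) and each 2x2 pooling layer halves the spatial size.
    """
    schedule = []
    for scale in range(scales + 1):
        num_filters = min(filters_max, filters * 2 ** scale)
        schedule.append((scale, num_filters, size))
        size //= 2  # one pooling layer per scale
    return schedule


if __name__ == "__main__":
    # Example: 32x32 input, 3 scales, 32 base filters capped at 128.
    for scale, nf, hw in convnet_schedule(32, 128, 3, 32):
        print(f"scale {scale}: {nf} filters at {hw}x{hw}")
```

Under these assumptions, a 32x32 input with `scales=3`, `filters=32` and `filters_max=128` ends with 4x4 feature maps. The filter count is capped at `filters_max` from scale 2 onward.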

objax.zoo.dnnet

class objax.zoo.dnnet.DNNet(layer_sizes, activation)

Deep neural network (MLP) implementation.

__init__(layer_sizes, activation)

Creates DNNet instance.

Parameters
  • layer_sizes (Iterable[int]) – number of neurons in each layer, from the input layer to the output layer.

  • activation (Callable) – activation function applied after each hidden layer.
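Here is a minimal sketch of how `layer_sizes` could translate into fully connected layers. It assumes consecutive entries give the input and output widths of each linear layer, so a list of k sizes yields k - 1 linear layers. `mlp_layer_shapes` and `num_parameters` are hypothetical helpers that only do the shape and parameter-count arithmetic.

```python
from typing import List, Sequence, Tuple


def mlp_layer_shapes(layer_sizes: Sequence[int]) -> List[Tuple[int, int]]:
    """Return the (fan_in, fan_out) pair of each linear layer.

    Hypothetical helper: assumes layer_sizes[0] is the input width and
    layer_sizes[-1] is the output width.
    """
    if len(layer_sizes) < 2:
        raise ValueError("need at least an input and an output size")
    return list(zip(layer_sizes[:-1], layer_sizes[1:]))


def num_parameters(layer_sizes: Sequence[int]) -> int:
    """Count weights plus biases across all linear layers."""
    return sum(fan_in * fan_out + fan_out
               for fan_in, fan_out in mlp_layer_shapes(layer_sizes))


if __name__ == "__main__":
    sizes = [784, 256, 128, 10]
    print(mlp_layer_shapes(sizes))  # three linear layers
    print(num_parameters(sizes))
```

For `[784, 256, 128, 10]`, this gives three linear layers with 235,146 weights and biases in total.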

objax.zoo.resnet_v2

class objax.zoo.resnet_v2.ResNetV2(in_channels, num_classes, blocks_per_group, bottleneck=True, channels_per_group=(256, 512, 1024, 2048), group_strides=(1, 2, 2, 2), group_use_projection=(True, True, True, True), normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)

Base implementation of ResNet v2 from https://arxiv.org/abs/1603.05027.

__init__(in_channels, num_classes, blocks_per_group, bottleneck=True, channels_per_group=(256, 512, 1024, 2048), group_strides=(1, 2, 2, 2), group_use_projection=(True, True, True, True), normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)

Creates ResNetV2 instance.

Parameters
  • in_channels (int) – number of channels in the input image.

  • num_classes (int) – number of output classes.

  • blocks_per_group (Sequence[int]) – number of blocks in each block group.

  • bottleneck (bool) – if True then use bottleneck blocks.

  • channels_per_group (Sequence[int]) – number of output channels for each block group.

  • group_strides (Sequence[int]) – strides for each block group.

  • group_use_projection (Sequence[bool]) – for each block group, whether to use a projection shortcut.

  • normalization_fn (Callable[..., objax.module.Module]) – module used as the normalization function.

  • activation_fn (Callable[[JaxArray], JaxArray]) – activation function.
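The depths in the subclass names below (ResNet18 through ResNet200) can be recovered from `blocks_per_group` and `bottleneck`. This plain-Python sketch assumes the standard configurations from He et al. Each basic block has two 3x3 convolutions and each bottleneck block has three convolutions. The stem convolution and the final classifier are also counted, but projection shortcuts are not. The `CONFIGS` table and `resnet_depth` helper are hypothetical and are not read from objax.

```python
from typing import Sequence

# Hypothetical table of standard ResNet configurations (He et al.):
# name -> (blocks_per_group, bottleneck)
CONFIGS = {
    "ResNet18": ((2, 2, 2, 2), False),
    "ResNet34": ((3, 4, 6, 3), False),
    "ResNet50": ((3, 4, 6, 3), True),
    "ResNet101": ((3, 4, 23, 3), True),
    "ResNet152": ((3, 8, 36, 3), True),
    "ResNet200": ((3, 24, 36, 3), True),
}


def resnet_depth(blocks_per_group: Sequence[int], bottleneck: bool) -> int:
    """Count weighted layers: stem conv + block convs + final linear layer."""
    convs_per_block = 3 if bottleneck else 2
    return 1 + convs_per_block * sum(blocks_per_group) + 1


if __name__ == "__main__":
    for name, (blocks, bottleneck) in CONFIGS.items():
        print(name, resnet_depth(blocks, bottleneck))
```

For example, ResNet50 and ResNet34 share the same block counts; only the switch to bottleneck blocks raises the depth from 34 to 50.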

class objax.zoo.resnet_v2.ResNet18(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)

Implementation of ResNet v2 with 18 layers.

__init__(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)

Creates ResNet18 instance.

Parameters
  • in_channels (int) – number of channels in the input image.

  • num_classes (int) – number of output classes.

  • normalization_fn (Callable[..., objax.module.Module]) – module used as the normalization function.

  • activation_fn (Callable[[JaxArray], JaxArray]) – activation function.

class objax.zoo.resnet_v2.ResNet34(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)

Implementation of ResNet v2 with 34 layers.

__init__(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)

Creates ResNet34 instance.

Parameters
  • in_channels (int) – number of channels in the input image.

  • num_classes (int) – number of output classes.

  • normalization_fn (Callable[..., objax.module.Module]) – module used as the normalization function.

  • activation_fn (Callable[[JaxArray], JaxArray]) – activation function.

class objax.zoo.resnet_v2.ResNet50(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)

Implementation of ResNet v2 with 50 layers.

__init__(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)

Creates ResNet50 instance.

Parameters
  • in_channels (int) – number of channels in the input image.

  • num_classes (int) – number of output classes.

  • normalization_fn (Callable[..., objax.module.Module]) – module used as the normalization function.

  • activation_fn (Callable[[JaxArray], JaxArray]) – activation function.

class objax.zoo.resnet_v2.ResNet101(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)

Implementation of ResNet v2 with 101 layers.

__init__(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)

Creates ResNet101 instance.

Parameters
  • in_channels (int) – number of channels in the input image.

  • num_classes (int) – number of output classes.

  • normalization_fn (Callable[..., objax.module.Module]) – module used as the normalization function.

  • activation_fn (Callable[[JaxArray], JaxArray]) – activation function.

class objax.zoo.resnet_v2.ResNet152(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)

Implementation of ResNet v2 with 152 layers.

__init__(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)

Creates ResNet152 instance.

Parameters
  • in_channels (int) – number of channels in the input image.

  • num_classes (int) – number of output classes.

  • normalization_fn (Callable[..., objax.module.Module]) – module used as the normalization function.

  • activation_fn (Callable[[JaxArray], JaxArray]) – activation function.

class objax.zoo.resnet_v2.ResNet200(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)

Implementation of ResNet v2 with 200 layers.

__init__(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)

Creates ResNet200 instance.

Parameters
  • in_channels (int) – number of channels in the input image.

  • num_classes (int) – number of output classes.

  • normalization_fn (Callable[..., objax.module.Module]) – module used as the normalization function.

  • activation_fn (Callable[[JaxArray], JaxArray]) – activation function.

objax.zoo.wide_resnet

class objax.zoo.wide_resnet.WRNBlock(nin, nout, stride=1, bn=functools.partial(<class 'objax.nn.layers.BatchNorm2D'>, momentum=0.9, eps=1e-05))

WideResNet block.

__init__(nin, nout, stride=1, bn=functools.partial(<class 'objax.nn.layers.BatchNorm2D'>, momentum=0.9, eps=1e-05))

Creates WRNBlock instance.

Parameters
  • nin (int) – number of input filters.

  • nout (int) – number of output filters.

  • stride (int) – stride for convolution and projection convolution in this block.

  • bn (Callable) – module used as the batch normalization function.

__call__(x, training)

Forward pass through the block.

Parameters
  • x (JaxArray) – input tensor.

  • training (bool) – whether the block runs in training mode (affects batch normalization).

Return type

JaxArray

class objax.zoo.wide_resnet.WideResNetGeneral(nin, nclass, blocks_per_group, width, bn=functools.partial(<class 'objax.nn.layers.BatchNorm2D'>, momentum=0.9, eps=1e-05))

Base WideResNet implementation.

static mean_reduce(x)

Averages x over its spatial dimensions (global average pooling).

Parameters
  • x (JaxArray) – input tensor.

Return type

JaxArray
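Assume that mean_reduce performs global average pooling: it averages an NCHW tensor over its height and width axes (objax convolutions are channels-first). Under that assumption, the following plain-Python sketch shows the shape change from (N, C, H, W) to (N, C). `mean_reduce_nchw` is a hypothetical stand-in written on nested lists, not the objax function.

```python
from typing import List

Tensor4D = List[List[List[List[float]]]]


def mean_reduce_nchw(x: Tensor4D) -> List[List[float]]:
    """Average each (H, W) feature map, mapping (N, C, H, W) -> (N, C)."""
    out = []
    for sample in x:  # loop over N
        row = []
        for fmap in sample:  # loop over C
            values = [v for line in fmap for v in line]  # flatten H x W
            row.append(sum(values) / len(values))
        out.append(row)
    return out


if __name__ == "__main__":
    # One sample, two channels, each a 2x2 feature map.
    x = [[[[1.0, 2.0], [3.0, 4.0]],
          [[0.0, 0.0], [0.0, 8.0]]]]
    print(mean_reduce_nchw(x))
```

On a JAX array, the same reduction can be written `x.mean(axis=(2, 3))`.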

__init__(nin, nclass, blocks_per_group, width, bn=functools.partial(<class 'objax.nn.layers.BatchNorm2D'>, momentum=0.9, eps=1e-05))

Creates WideResNetGeneral instance.

Parameters
  • nin (int) – number of channels in the input image.

  • nclass (int) – number of output classes.

  • blocks_per_group (List[int]) – number of blocks in each block group.

  • width (int) – multiplier for the number of convolution filters.

  • bn (Callable) – module used as the batch normalization function.

class objax.zoo.wide_resnet.WideResNet(nin, nclass, depth=28, width=2, bn=functools.partial(<class 'objax.nn.layers.BatchNorm2D'>, momentum=0.9, eps=1e-05))

WideResNet implementation with 3 groups.

Reference:

  • http://arxiv.org/abs/1605.07146

  • https://github.com/szagoruyko/wide-residual-networks

__init__(nin, nclass, depth=28, width=2, bn=functools.partial(<class 'objax.nn.layers.BatchNorm2D'>, momentum=0.9, eps=1e-05))

Creates WideResNet instance.

Parameters
  • nin (int) – number of channels in the input image.

  • nclass (int) – number of output classes.

  • depth (int) – number of convolution layers; (depth - 4) must be divisible by 6.

  • width (int) – multiplier for the number of convolution filters.

  • bn (Callable) – module used as the batch normalization function.
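The `depth` constraint can be unpacked with a small sketch. It assumes WideResNet spreads `n = (depth - 4) // 6` blocks evenly over its three groups. It also assumes the group widths follow the WRN paper's 16·k, 32·k and 64·k filters, with k = `width`. Both helpers are hypothetical.

```python
from typing import List


def wrn_blocks_per_group(depth: int) -> List[int]:
    """Split a WideResNet depth into blocks per group (3 groups)."""
    if (depth - 4) % 6 != 0:
        raise ValueError("depth should be 6n + 4, e.g. 16, 22, 28, 40")
    n = (depth - 4) // 6
    return [n] * 3


def wrn_group_widths(width: int) -> List[int]:
    """Output filters per group, following the WRN paper's 16k/32k/64k."""
    return [16 * width, 32 * width, 64 * width]


if __name__ == "__main__":
    # WideResNet defaults: depth=28, width=2 ("WRN-28-2").
    print(wrn_blocks_per_group(28))  # 4 blocks in each group
    print(wrn_group_widths(2))
```

With the defaults (`depth=28`, `width=2`, i.e. WRN-28-2), this gives 4 blocks per group and 32, 64 and 128 filters.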

objax.zoo.rnn

class objax.zoo.rnn.RNN(nstate, nin, nout, activation=<function _one_to_one_unop.<locals>.<lambda>>, w_init=<function kaiming_normal>)

Recurrent Neural Network (RNN) block.

__init__(nstate, nin, nout, activation=<function _one_to_one_unop.<locals>.<lambda>>, w_init=<function kaiming_normal>)

Creates an RNN instance.

Parameters
  • nstate (int) – number of hidden units.

  • nin (int) – number of input units.

  • nout (int) – number of output units.

  • activation (Callable) – activation function for the hidden layer.

  • w_init (Callable) – weight initializer for RNN model weights.

init_state(batch_size)

Initializes the hidden state for an input batch of size batch_size.

__call__(inputs, only_return_final=False)

Forward pass through RNN.

Parameters
  • inputs (JaxArray) – JaxArray with dimensions num_steps, batch_size, vocabulary_size.

  • only_return_final – if True, return only the last output; otherwise return all outputs.

Returns

Output tensor with dimensions num_steps * batch_size, vocabulary_size when all outputs are returned.

Return type

JaxArray
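The recurrence behind the forward pass can be sketched in plain Python. The sketch makes four assumptions:

  • the update is Elman-style: h_t = activation(x_t·W_xh + h_{t-1}·W_hh + b_h);
  • a linear output layer follows: y_t = h_t·W_hy + b_y;
  • the initial state is zero (an assumption about what `init_state` produces);
  • per-step outputs are stacked along the first axis.

All helper names are hypothetical, and `math.tanh` stands in for the activation.

```python
import math
from typing import Callable, List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python matrix product of (m, k) and (k, n) lists."""
    return [[sum(a_ik * b[k][j] for k, a_ik in enumerate(row))
             for j in range(len(b[0]))] for row in a]


def add_bias(a: Matrix, bias: List[float]) -> Matrix:
    """Add a bias vector to every row."""
    return [[v + b for v, b in zip(row, bias)] for row in a]


def rnn_forward(inputs: List[Matrix], w_xh: Matrix, w_hh: Matrix,
                b_h: List[float], w_hy: Matrix, b_y: List[float],
                activation: Callable[[float], float] = math.tanh,
                only_return_final: bool = False) -> Matrix:
    """Run an Elman RNN over inputs of shape (num_steps, batch_size, nin)."""
    batch_size, nstate = len(inputs[0]), len(w_hh)
    state = [[0.0] * nstate for _ in range(batch_size)]  # zero initial state
    outputs: Matrix = []
    step_out: Matrix = []
    for x in inputs:  # iterate over time steps
        summed = [[p + q for p, q in zip(r1, r2)]
                  for r1, r2 in zip(matmul(x, w_xh), matmul(state, w_hh))]
        state = [[activation(v) for v in row] for row in add_bias(summed, b_h)]
        step_out = add_bias(matmul(state, w_hy), b_y)  # (batch_size, nout)
        outputs.extend(step_out)  # stack steps along the first axis
    return step_out if only_return_final else outputs


if __name__ == "__main__":
    # 3 time steps, batch of 2, nin=1, nstate=2, nout=1 (toy weights).
    inputs = [[[1.0], [0.0]], [[0.5], [1.0]], [[0.0], [0.0]]]
    out = rnn_forward(inputs, [[0.5, -0.5]], [[0.1, 0.0], [0.0, 0.1]],
                      [0.0, 0.0], [[1.0], [0.0]], [0.0])
    print(len(out), len(out[0]))  # num_steps * batch_size rows, nout columns
```

Because every step's output is appended along the first axis, the full result has `num_steps * batch_size` rows, which matches the output shape documented above.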

objax.zoo.vgg

class objax.zoo.vgg.VGG19(pretrained=False)

VGG19 implementation.

__init__(pretrained=False)

Creates VGG19 instance.

Parameters
  • pretrained – if True, load weights from a model pretrained on ImageNet.

build()