objax.zoo package
objax.zoo.convnet
- class objax.zoo.convnet.ConvNet(nin, nclass, scales, filters, filters_max, pooling=<function max_pool_2d>, **kwargs)
ConvNet implementation.
- __init__(nin, nclass, scales, filters, filters_max, pooling=<function max_pool_2d>, **kwargs)
Creates ConvNet instance.
- Parameters
nin – number of channels in the input image.
nclass – number of output classes.
scales – number of pooling layers, each of which halves the spatial dimensions.
filters – base number of convolution filters. The number of filters is doubled at every scale until it reaches filters_max.
filters_max – maximum number of filters.
pooling – pooling function applied once per scale.
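The filter and resolution bookkeeping controlled by `scales`, `filters` and `filters_max` can be sketched in plain Python. The helper names below are hypothetical (they are not part of objax), and rounding odd spatial sizes down is an assumption; the sketch only mirrors the "double at every scale, cap at filters_max, halve at every pooling layer" rule described above.

```python
from typing import List, Tuple


def convnet_filter_schedule(filters: int, filters_max: int, scales: int) -> List[int]:
    """Filter count at scales 0..scales: doubled at every scale, capped at filters_max."""
    return [min(filters * 2 ** s, filters_max) for s in range(scales + 1)]


def convnet_spatial_sizes(height: int, width: int, scales: int) -> List[Tuple[int, int]]:
    """Spatial size of the input and after each pooling layer.

    Assumes every pooling layer halves each dimension, rounding down for odd sizes.
    """
    sizes = [(height, width)]
    for _ in range(scales):
        height, width = height // 2, width // 2
        sizes.append((height, width))
    return sizes


# Example: a 32x32 input with filters=32, filters_max=256 and scales=4.
print(convnet_filter_schedule(filters=32, filters_max=256, scales=4))  # [32, 64, 128, 256, 256]
print(convnet_spatial_sizes(32, 32, scales=4))  # [(32, 32), (16, 16), (8, 8), (4, 4), (2, 2)]
```

The cap matters once `filters * 2**scale` exceeds `filters_max`: in the example the last two scales both use 256 filters.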
objax.zoo.dnnet
objax.zoo.resnet_v2
- class objax.zoo.resnet_v2.ResNetV2(in_channels, num_classes, blocks_per_group, bottleneck=True, channels_per_group=(256, 512, 1024, 2048), group_strides=(1, 2, 2, 2), group_use_projection=(True, True, True, True), normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)
Base implementation of ResNet v2 from https://arxiv.org/abs/1603.05027.
- __init__(in_channels, num_classes, blocks_per_group, bottleneck=True, channels_per_group=(256, 512, 1024, 2048), group_strides=(1, 2, 2, 2), group_use_projection=(True, True, True, True), normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)
Creates ResNetV2 instance.
- Parameters
in_channels (int) – number of channels in the input image.
num_classes (int) – number of output classes.
blocks_per_group (Sequence[int]) – number of blocks in each block group.
bottleneck (bool) – if True, use bottleneck blocks; otherwise use basic (non-bottleneck) blocks.
channels_per_group (Sequence[int]) – number of output channels for each block group.
group_strides (Sequence[int]) – strides for each block group.
group_use_projection (Sequence[bool]) – for each block group, whether to use a projection shortcut.
normalization_fn (Callable[[...], Module]) – module class (or factory) used to create the normalization layers.
activation_fn (Callable[[Array], Array]) – activation function.
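ResNetV2 takes one entry per block group in each of `blocks_per_group`, `channels_per_group`, `group_strides` and `group_use_projection`. The sketch below shows how these sequences line up, using hypothetical helper names; the equal-length check is added by the sketch and is not a documented objax behavior. The product of the group strides is the downsampling contributed by the block groups alone (any downsampling outside the groups is not counted).

```python
from typing import List, NamedTuple, Sequence


class GroupPlan(NamedTuple):
    """Hypothetical summary of one block group's configuration."""
    num_blocks: int
    channels: int
    stride: int
    use_projection: bool


def resnet_v2_group_plan(blocks_per_group: Sequence[int],
                         channels_per_group: Sequence[int] = (256, 512, 1024, 2048),
                         group_strides: Sequence[int] = (1, 2, 2, 2),
                         group_use_projection: Sequence[bool] = (True, True, True, True)) -> List[GroupPlan]:
    """Zips the per-group sequences together after checking that their lengths agree."""
    lengths = {len(blocks_per_group), len(channels_per_group), len(group_strides), len(group_use_projection)}
    if len(lengths) != 1:
        raise ValueError(f'All per-group sequences must have the same length, got {sorted(lengths)}')
    return [GroupPlan(*group) for group in zip(blocks_per_group, channels_per_group,
                                                group_strides, group_use_projection)]


def group_downsampling(plan: Sequence[GroupPlan]) -> int:
    """Spatial downsampling factor contributed by the group strides."""
    factor = 1
    for group in plan:
        factor *= group.stride
    return factor


# Default channels/strides/projections with a (3, 4, 6, 3) block layout.
plan = resnet_v2_group_plan(blocks_per_group=(3, 4, 6, 3))
for group in plan:
    print(group)
print(group_downsampling(plan))  # 8
```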
- class objax.zoo.resnet_v2.ResNet18(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)
Implementation of ResNet v2 with 18 layers.
- __init__(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)
Creates ResNet18 instance.
- Parameters
in_channels (int) – number of channels in the input image.
num_classes (int) – number of output classes.
normalization_fn (Callable[[...], Module]) – module class (or factory) used to create the normalization layers.
activation_fn (Callable[[Array], Array]) – activation function.
- class objax.zoo.resnet_v2.ResNet34(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)
Implementation of ResNet v2 with 34 layers.
- __init__(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)
Creates ResNet34 instance.
- Parameters
in_channels (int) – number of channels in the input image.
num_classes (int) – number of output classes.
normalization_fn (Callable[[...], Module]) – module class (or factory) used to create the normalization layers.
activation_fn (Callable[[Array], Array]) – activation function.
- class objax.zoo.resnet_v2.ResNet50(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)
Implementation of ResNet v2 with 50 layers.
- __init__(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)
Creates ResNet50 instance.
- Parameters
in_channels (int) – number of channels in the input image.
num_classes (int) – number of output classes.
normalization_fn (Callable[[...], Module]) – module class (or factory) used to create the normalization layers.
activation_fn (Callable[[Array], Array]) – activation function.
- class objax.zoo.resnet_v2.ResNet101(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)
Implementation of ResNet v2 with 101 layers.
- __init__(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)
Creates ResNet101 instance.
- Parameters
in_channels (int) – number of channels in the input image.
num_classes (int) – number of output classes.
normalization_fn (Callable[[...], Module]) – module class (or factory) used to create the normalization layers.
activation_fn (Callable[[Array], Array]) – activation function.
- class objax.zoo.resnet_v2.ResNet152(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)
Implementation of ResNet v2 with 152 layers.
- __init__(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)
Creates ResNet152 instance.
- Parameters
in_channels (int) – number of channels in the input image.
num_classes (int) – number of output classes.
normalization_fn (Callable[[...], Module]) – module class (or factory) used to create the normalization layers.
activation_fn (Callable[[Array], Array]) – activation function.
- class objax.zoo.resnet_v2.ResNet200(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)
Implementation of ResNet v2 with 200 layers.
- __init__(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)
Creates ResNet200 instance.
- Parameters
in_channels (int) – number of channels in the input image.
num_classes (int) – number of output classes.
normalization_fn (Callable[[...], Module]) – module class (or factory) used to create the normalization layers.
activation_fn (Callable[[Array], Array]) – activation function.
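The depth in each class name follows the usual ResNet counting convention: the convolutions inside the residual blocks (two per basic block, three per bottleneck block), plus the stem convolution and the final linear layer; projection shortcuts are not counted. The per-group block counts below are those from the original ResNet papers by He et al.; that ResNet18 through ResNet200 use exactly these values (with basic blocks for 18 and 34 layers) is an assumption, and the names are hypothetical.

```python
from typing import Dict, Sequence, Tuple

# (blocks_per_group, bottleneck) per model, as in the ResNet papers (assumed, not read from objax).
RESNET_CONFIGS: Dict[str, Tuple[Tuple[int, ...], bool]] = {
    'ResNet18': ((2, 2, 2, 2), False),
    'ResNet34': ((3, 4, 6, 3), False),
    'ResNet50': ((3, 4, 6, 3), True),
    'ResNet101': ((3, 4, 23, 3), True),
    'ResNet152': ((3, 8, 36, 3), True),
    'ResNet200': ((3, 24, 36, 3), True),
}


def resnet_depth(blocks_per_group: Sequence[int], bottleneck: bool) -> int:
    """Convolutions inside the blocks, plus the stem convolution and the final linear layer."""
    convs_per_block = 3 if bottleneck else 2
    return convs_per_block * sum(blocks_per_group) + 2


for name, (blocks, bottleneck) in RESNET_CONFIGS.items():
    print(name, resnet_depth(blocks, bottleneck))
```

ResNet34 and ResNet50 share the same block layout; the extra 16 layers come entirely from switching to three-convolution bottleneck blocks.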
objax.zoo.wide_resnet
- class objax.zoo.wide_resnet.WRNBlock(nin, nout, stride=1, bn=functools.partial(<class 'objax.nn.layers.BatchNorm2D'>, momentum=0.9, eps=1e-05))
WideResNet block.
- __init__(nin, nout, stride=1, bn=functools.partial(<class 'objax.nn.layers.BatchNorm2D'>, momentum=0.9, eps=1e-05))
Creates WRNBlock instance.
- Parameters
nin (int) – number of input filters.
nout (int) – number of output filters.
stride (int) – stride for convolution and projection convolution in this block.
bn (Callable) – module class (or factory) used to create the batch normalization layers.
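A WRNBlock maps `nin` input filters to `nout` output filters. The residual addition needs the shortcut to match the output shape, so a projection is the usual choice when `nin != nout` or `stride > 1`. The sketch below assumes that rule, an NCHW layout, and 'SAME' padding (so each spatial size becomes ceil(size / stride)); the helper names are hypothetical.

```python
from typing import Tuple


def wrn_block_needs_projection(nin: int, nout: int, stride: int) -> bool:
    """Assumed rule: project whenever an identity shortcut would not match the output shape."""
    return nin != nout or stride > 1


def wrn_block_output_shape(input_shape: Tuple[int, int, int, int], nout: int,
                           stride: int) -> Tuple[int, int, int, int]:
    """Output shape for an NCHW input, assuming 'SAME' padding: ceil(size / stride) per spatial dim."""
    batch, _, height, width = input_shape
    return batch, nout, (height + stride - 1) // stride, (width + stride - 1) // stride


print(wrn_block_needs_projection(nin=32, nout=64, stride=2))  # True
print(wrn_block_needs_projection(nin=64, nout=64, stride=1))  # False
print(wrn_block_output_shape((8, 32, 32, 32), nout=64, stride=2))  # (8, 64, 16, 16)
```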
- class objax.zoo.wide_resnet.WideResNetGeneral(nin, nclass, blocks_per_group, width, bn=functools.partial(<class 'objax.nn.layers.BatchNorm2D'>, momentum=0.9, eps=1e-05))
Base WideResNet implementation.
- __init__(nin, nclass, blocks_per_group, width, bn=functools.partial(<class 'objax.nn.layers.BatchNorm2D'>, momentum=0.9, eps=1e-05))
Creates WideResNetGeneral instance.
- Parameters
nin (int) – number of channels in the input image.
nclass (int) – number of output classes.
blocks_per_group (List[int]) – number of blocks in each block group.
width (int) – multiplier for the number of convolution filters.
bn (Callable) – module class (or factory) used to create the batch normalization layers.
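`width` scales the number of filters in every block group. Assuming the base widths of the original Wide ResNet paper (Zagoruyko and Komodakis: 16 filters in the first group, doubling in each later group), the per-group filter counts can be sketched as follows; the base value of 16 and the helper name are assumptions, not taken from objax.

```python
from typing import List, Sequence


def wrn_group_widths(blocks_per_group: Sequence[int], width: int, base: int = 16) -> List[int]:
    """Filters per block group: base * 2**i for group i, scaled by the width multiplier."""
    return [base * 2 ** i * width for i in range(len(blocks_per_group))]


print(wrn_group_widths([4, 4, 4], width=2))   # [32, 64, 128]
print(wrn_group_widths([4, 4, 4], width=10))  # [160, 320, 640]
```

Only the number of groups in `blocks_per_group` affects the widths; the block counts themselves control depth.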
- class objax.zoo.wide_resnet.WideResNet(nin, nclass, depth=28, width=2, bn=functools.partial(<class 'objax.nn.layers.BatchNorm2D'>, momentum=0.9, eps=1e-05))
WideResNet implementation with 3 groups.
- __init__(nin, nclass, depth=28, width=2, bn=functools.partial(<class 'objax.nn.layers.BatchNorm2D'>, momentum=0.9, eps=1e-05))
Creates WideResNet instance.
- Parameters
nin (int) – number of channels in the input image.
nclass (int) – number of output classes.
depth (int) – number of convolution layers; (depth - 4) should be divisible by 6.
width (int) – multiplier for the number of convolution filters.
bn (Callable) – module class (or factory) used to create the batch normalization layers.
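The divisibility condition comes from the Wide ResNet layout: each of the 3 groups holds n blocks of two 3x3 convolutions, so depth = 6n + 4. The sketch below converts `depth` into a per-group block count of the kind WideResNetGeneral takes as `blocks_per_group`. The helper is hypothetical, and raising `ValueError` for an invalid depth is a choice made by the sketch.

```python
from typing import List


def wide_resnet_blocks_per_group(depth: int) -> List[int]:
    """Splits depth = 6 * n + 4 into n blocks for each of the 3 block groups."""
    if (depth - 4) % 6 != 0:
        raise ValueError(f'(depth - 4) must be divisible by 6, got depth={depth}')
    n = (depth - 4) // 6
    return [n] * 3


print(wide_resnet_blocks_per_group(28))  # [4, 4, 4]
print(wide_resnet_blocks_per_group(16))  # [2, 2, 2]
```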
objax.zoo.rnn
- class objax.zoo.rnn.RNN(nstate, nin, nout, activation=<function jax.numpy.tanh>, w_init=<function kaiming_normal>)
Recurrent Neural Network (RNN) block.
- __init__(nstate, nin, nout, activation=<function jax.numpy.tanh>, w_init=<function kaiming_normal>)
Creates an RNN instance.
- Parameters
nstate (int) – number of hidden units.
nin (int) – number of input units.
nout (int) – number of output units.
activation (Callable) – activation function for the hidden layer.
w_init (Callable) – weight initializer for RNN model weights.
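For sizing purposes it helps to know how `nstate`, `nin` and `nout` translate into parameters. The sketch below assumes the standard Elman RNN layout (input-to-hidden and hidden-to-hidden weights with a hidden bias, followed by a linear output layer); the dictionary keys and helper names are hypothetical and need not match the attribute names inside objax.

```python
from typing import Dict, Tuple


def rnn_parameter_shapes(nstate: int, nin: int, nout: int) -> Dict[str, Tuple[int, ...]]:
    """Parameter shapes of an Elman RNN with a linear output layer (assumed layout)."""
    return {
        'w_xh': (nin, nstate),     # input -> hidden weights
        'w_hh': (nstate, nstate),  # hidden -> hidden (recurrent) weights
        'b_h': (nstate,),          # hidden bias
        'w_hy': (nstate, nout),    # hidden -> output weights
        'b_y': (nout,),            # output bias
    }


def count_parameters(shapes: Dict[str, Tuple[int, ...]]) -> int:
    """Total number of scalar parameters across all arrays."""
    total = 0
    for shape in shapes.values():
        size = 1
        for dim in shape:
            size *= dim
        total += size
    return total


shapes = rnn_parameter_shapes(nstate=128, nin=27, nout=27)
print(count_parameters(shapes))  # 23451
```

The recurrent matrix grows quadratically with `nstate`, so it dominates the count for any sizeable hidden state.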
- __call__(inputs, only_return_final=False)
Forward pass through RNN.
- Parameters
inputs (Array) – JaxArray with dimensions num_steps, batch_size, vocabulary_size.
only_return_final – return only the last output if True, or all outputs otherwise.
- Returns
Output tensor with dimensions num_steps * batch_size, vocabulary_size (batch_size, vocabulary_size if only_return_final is True).
- Return type
Array
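To make the shape contract concrete, here is a dependency-free sketch of the forward pass. It assumes the standard Elman recurrence h_t = activation(x_t W_xh + h_{t-1} W_hh + b_h) with h_0 = 0, followed by a linear output layer y_t = h_t W_hy + b_y; `math.tanh` stands in for `jax.numpy.tanh`, and the weight names and helpers are hypothetical. The objax module works on JAX arrays and its internals may differ.

```python
import math
import random
from typing import Callable, List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Product of an (m, k) and a (k, n) matrix stored as nested lists."""
    return [[sum(a_ik * b[k][j] for k, a_ik in enumerate(row)) for j in range(len(b[0]))]
            for row in a]


def add(a: Matrix, b: Matrix) -> Matrix:
    """Elementwise sum of two matrices with the same shape."""
    return [[u + v for u, v in zip(row_a, row_b)] for row_a, row_b in zip(a, b)]


def add_bias(a: Matrix, bias: List[float]) -> Matrix:
    """Adds a bias vector to every row."""
    return [[v + b for v, b in zip(row, bias)] for row in a]


def rnn_forward(inputs: List[Matrix], w_xh: Matrix, w_hh: Matrix, b_h: List[float],
                w_hy: Matrix, b_y: List[float],
                activation: Callable[[float], float] = math.tanh,
                only_return_final: bool = False) -> Matrix:
    """Runs an Elman RNN over inputs of shape (num_steps, batch_size, nin)."""
    batch_size, nstate = len(inputs[0]), len(w_hh)
    h = [[0.0] * nstate for _ in range(batch_size)]  # zero initial hidden state
    outputs: List[Matrix] = []
    for x in inputs:  # one iteration per time step
        pre_activation = add_bias(add(matmul(x, w_xh), matmul(h, w_hh)), b_h)
        h = [[activation(v) for v in row] for row in pre_activation]
        outputs.append(add_bias(matmul(h, w_hy), b_y))  # (batch_size, nout)
    if only_return_final:
        return outputs[-1]
    # Concatenate the per-step outputs along the first axis: (num_steps * batch_size, nout).
    return [row for step in outputs for row in step]


def random_matrix(rows: int, cols: int) -> Matrix:
    return [[random.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
num_steps, batch_size, nin, nstate, nout = 5, 2, 3, 4, 3
inputs = [random_matrix(batch_size, nin) for _ in range(num_steps)]
params = dict(w_xh=random_matrix(nin, nstate), w_hh=random_matrix(nstate, nstate), b_h=[0.0] * nstate,
              w_hy=random_matrix(nstate, nout), b_y=[0.0] * nout)

all_outputs = rnn_forward(inputs, **params)
final = rnn_forward(inputs, **params, only_return_final=True)
print(len(all_outputs), len(all_outputs[0]))  # 10 3
print(len(final), len(final[0]))  # 2 3
```

Because the per-step outputs are stacked along the first axis, the last `batch_size` rows of the full output are exactly the rows returned with `only_return_final=True`.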