Welcome to Objax’s documentation!

Objax is an open source machine learning framework that accelerates research and learning thanks to a minimalist object-oriented design and a readable code base. Its name comes from the contraction of Object and JAX – a popular high-performance framework. Objax is designed by researchers for researchers with a focus on simplicity and understandability. Its users should be able to easily read, understand, extend, and modify it to fit their needs.

Try the 5-minute tutorial.

Machine learning’s 'Hello world': optimizing the weights of a classifier net through gradient descent:

opt = objax.optimizer.Adam(net.vars())

@objax.Function.with_vars(net.vars())
def loss(x, y):
    logits = net(x)  # Output of classifier on x
    xe = cross_entropy_logits(logits, y)
    return xe.mean()

# Perform gradient descent w.r.t. the net weights
gv = objax.GradValues(loss, net.vars())

@objax.Function.with_vars(net.vars() + opt.vars())
def train_op(x, y):
    g, v = gv(x, y)  # returns gradients g and loss v
    opt(lr, g)  # update weights
    return v

train_op = objax.Jit(train_op)

Objax philosophy

Objax pursues the quest for the simplest design and code that’s as easy as possible to extend without sacrificing performance.

– Objax Devs

Motivation

Researchers and students look at machine learning frameworks in their own way. Often they read the code of some technique, say an Adam optimizer, to understand how it works so they can extend it or design a new optimizer. This is how machine learning frameworks differ from standard libraries: a large class of users looks not only at the APIs but also at the code behind them.

Coded for simplicity

Source code should be understandable by everyone, including users without a background in computer science. So how simple is it really? Judge for yourself with this tutorial: Logistic Regression.

Object-oriented

It is common in machine learning to separate the inputs (\(X\)) from the parameters (\(\theta\)) of a function \(f(X; \theta)\). Mathematical notation captures this distinction with a semicolon that separates the inputs from the parameters.

Objax represents this semantic distinction through objax.Module (a short sketch follows the list below):

  • the module’s parameters \(\theta\) are attributes of the form self.w, ...

  • inputs \(X\) are method arguments such as def __call__(self, x, y, ...):
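
A minimal sketch of this convention (the Scale module, its name, and its shapes are purely illustrative):

import jax.numpy as jn
import objax

class Scale(objax.Module):
    def __init__(self, dim):
        # Parameters (theta) live on the object as attributes.
        self.w = objax.TrainVar(jn.ones(dim))

    def __call__(self, x):
        # Inputs (X) are passed as method arguments.
        return x * self.w.value

m = Scale(3)
print(m(jn.array([1.0, 2.0, 3.0])))  # [1. 2. 3.]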

Designed for flexibility

Objax minimizes the number of abstractions users need to understand. There are two main ones: Modules and Variables. Everything is built out of these two basic classes. You can read more about this in Variables and Modules.

Engineered for performance

In machine learning, performance is essential. Every second counts. Objax makes it count by using the JAX/XLA engine that also powers TensorFlow. Read more about this in Compilation and Parallelism.

Installation and Setup

For developing or contributing to Objax, see Development setup.

User installation

Install using pip with the following command:

pip install --upgrade objax

For GPU support, we assume you already have some version of CUDA installed. Here are the extra steps:

# Specify your installed CUDA version.
CUDA_VERSION=11.0
pip install -f https://storage.googleapis.com/jax-releases/jax_releases.html jaxlib==`python3 -c 'import jaxlib; print(jaxlib.__version__)'`+cuda`echo $CUDA_VERSION | sed s:\\\.::g`
Useful shell configurations

Here are a few useful options:

# Prevent JAX from taking the whole GPU memory
# (useful if you want to run several programs on a single GPU)
export XLA_PYTHON_CLIENT_PREALLOCATE=false
Testing your installation

You can run the code below to test your installation:

import jax
import objax

print(f'Number of GPUs {jax.device_count()}')

x = objax.random.normal((100, 4))
m = objax.nn.Linear(4, 5)
print('Matrix product shape', m(x).shape)  # (100, 5)

x = objax.random.normal((100, 3, 32, 32))
m = objax.nn.Conv2D(3, 4, k=3)
print('Conv2D return shape', m(x).shape)  # (100, 4, 32, 32)

If you get errors running this using CUDA, it probably means your installation of CUDA or cuDNN has issues.

Installing examples

Clone the code repository:

git clone https://github.com/google/objax.git
cd objax/examples


Objax Basics Tutorial

This tutorial introduces basic Objax concepts.

Prerequisites

Objax is a machine learning library written in Python which works on top of JAX. Readers should therefore have some familiarity with the following:

  • Python. If you’re new to Python or need a refresher, check the Python Tutorial or Python Beginner’s Guide.

  • NumPy is a library for mathematical computations. Many JAX primitives are based on NumPy and have the same syntax. NumPy is also useful for data manipulation outside of JAX/Objax. The NumPy quickstart covers most of the needed topics. More information can be found on the NumPy documentation site.

  • JAX can be described as NumPy with gradients, which runs on accelerators (GPU and TPU). The JAX quickstart covers most of the concepts needed to understand Objax.

Installation and imports

Let’s first install Objax:

[1]:
%pip --quiet install objax

After Objax is installed, you can import all necessary modules:

[2]:
import jax.numpy as jn
import numpy as np
import objax
Tensors

Tensors are essentially multi-dimensional arrays. In JAX and Objax tensors can be placed on GPUs or TPUs to accelerate computations.

Objax relies on the jax.numpy.ndarray primitive from JAX to represent tensors. In turn, this primitive has a very similar API to NumPy ndarray.

Creating tensors

Tensor creation is very similar to NumPy and can be done in multiple ways:

  1. Provide explicit values to the tensor:

[3]:
# Providing explicit values
jn.array([[1.0, 2.0, 3.0],
          [4.0, 5.0, 6.0]])
[3]:
DeviceArray([[1., 2., 3.],
             [4., 5., 6.]], dtype=float32)
  2. From a NumPy array:

[4]:
arr = np.array([1.0, 2.0, 3.0])
jn.array(arr)
[4]:
DeviceArray([1., 2., 3.], dtype=float32)
  3. From another JAX tensor:

[5]:
another_tensor = jn.array([[1.0, 2.0, 3.0],
                           [4.0, 5.0, 6.0]])
jn.array(another_tensor)
[5]:
DeviceArray([[1., 2., 3.],
             [4., 5., 6.]], dtype=float32)
  4. Using ones or zeros:

[6]:
jn.ones((3, 4))
[6]:
DeviceArray([[1., 1., 1., 1.],
             [1., 1., 1., 1.],
             [1., 1., 1., 1.]], dtype=float32)
[7]:
jn.zeros((4, 5))
[7]:
DeviceArray([[0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0.]], dtype=float32)
  5. As a result of a mathematical operation performed on other tensors:

[8]:
t1 = jn.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])
t2 = jn.ones(t1.shape) * 3
t1 + t2
[8]:
DeviceArray([[4., 5., 6.],
             [7., 8., 9.]], dtype=float32)
Tensor Properties

Similar to NumPy, one can explore various properties of tensors like shape, number of dimensions, or data type:

[9]:
t = jn.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0]])
print('Number of dimensions: ', t.ndim)
print('Shape: ', t.shape)
print('Data type: ', t.dtype)
Number of dimensions:  2
Shape:  (2, 3)
Data type:  float32
Converting tensors to NumPy arrays

Objax/JAX tensors can be converted to NumPy arrays when needed to perform computations with NumPy:

[10]:
np.array(t)
[10]:
array([[1., 2., 3.],
       [4., 5., 6.]], dtype=float32)
Tensors are immutable

One big difference between JAX ndarrays and NumPy ndarrays is that JAX ndarrays are immutable:

[11]:
print('Original tensor t:\n', t)

try:
    t[0, 0] = -5.0  # This line will fail
except Exception as e:
    print(f'Exception {e}')

print('Tensor t after failed attempt to update:\n', t)
Original tensor t:
 [[1. 2. 3.]
 [4. 5. 6.]]
Exception '<class 'jax.interpreters.xla.DeviceArray'>' object does not support item assignment. JAX arrays are immutable; perhaps you want jax.ops.index_update or jax.ops.index_add instead?
Tensor t after failed attempt to update:
 [[1. 2. 3.]
 [4. 5. 6.]]

Instead of updating an existing tensor, a new tensor should be created with the updated elements. Updates of individual tensor elements are done using index_update, index_add, and a few other JAX primitives:

[12]:
import jax.ops

print('Original tensor t:\n', t)
new_t = jax.ops.index_update(t, jax.ops.index[0, 0], -5.0)
print('Tensor t after update stays the same:\n', t)
print('Tensor new_t has updated value:\n', new_t)
Original tensor t:
 [[1. 2. 3.]
 [4. 5. 6.]]
Tensor t after update stays the same:
 [[1. 2. 3.]
 [4. 5. 6.]]
Tensor new_t has updated value:
 [[-5.  2.  3.]
 [ 4.  5.  6.]]

More details about per-element updates of tensors can be found in JAX documentation.
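
Note that recent JAX releases removed jax.ops.index_update and jax.ops.index_add in favor of the .at[] indexed-update syntax; the equivalent of the update above (on the same tensor t) would be:

new_t = t.at[0, 0].set(-5.0)  # returns a new tensor; t itself is unchanged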

In practice, most mathematical operations operate on tensors as a whole, so manual per-element updates are rarely needed.

Random numbers

It’s very easy to generate random tensors in Objax:

[13]:
x = objax.random.normal((3, 4))
print(x)
[[-0.1441347   0.89507747 -0.46038115  0.10503326]
 [-0.7460886   0.89681065  0.38794124  0.11750659]
 [ 1.0659382   0.22656879 -2.548792    1.9700414 ]]

There are multiple primitives for doing so:

[14]:
print('Random integers:', objax.random.randint((4,), low=0, high=10))
print('Random normal:', objax.random.normal((4,), mean=1.0, stddev=2.0))
print('Random truncated normal: ', objax.random.truncated_normal((4,), stddev=2.0))
print('Random uniform: ', objax.random.uniform((4,)))
Random integers: [1 3 4 7]
Random normal: [-3.4102912  -1.8277478  -0.02106905  0.65338284]
Random truncated normal:  [ 0.78602946  2.3004575  -0.22719319  0.22819921]
Random uniform:  [0.19456518 0.4642099  0.94732213 0.57298625]
Objax Variables and Modules

Objax Variables store tensor values. Unlike tensors, variables are mutable, i.e. the value stored in a variable can change. Since tensors themselves are immutable, a variable changes its value by replacing the stored tensor with a new one.

Variables are commonly used together with modules. A Module is the basic building block in Objax that stores variables and other modules. Most modules are also callable (i.e., they implement the __call__ method) and, when called, perform some computation using their variables and sub-modules.
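
For example, the tensor held by a variable can be read through its .value property and replaced through .assign; a minimal sketch, reusing the imports above:

v = objax.TrainVar(jn.zeros(3))
print(v.value)           # [0. 0. 0.]
v.assign(v.value + 1.0)  # replaces the stored tensor with a new one
print(v.value)           # [1. 1. 1.]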

Here is an example of a simple module with two variables; when called, it computes the dot product of one of them (v1) with an input tensor:

[15]:
class SimpleModule(objax.Module):

    def __init__(self, length):
        self.v1 = objax.TrainVar(objax.random.normal((length,)))
        self.v2 = objax.TrainVar(jn.ones((2,)))

    def __call__(self, x):
        return jn.dot(x, self.v1.value)


m = SimpleModule(3)

Modules keep track of all variables they own, including variables in sub-modules. The .vars() method lists all the module’s variables. It returns an instance of VarCollection, which is a dictionary with several additional useful methods.

[16]:
module_vars = m.vars()

print('type(module_vars): ', type(module_vars))
print('isinstance(module_vars, dict): ', isinstance(module_vars, dict))
print()

print('Variable names and shapes:')
print(module_vars)
print()

print('Variable names and values:')
for k, v in module_vars.items():
  print(f'{k}      {v.value}')
type(module_vars):  objax.variable.VarCollection
isinstance(module_vars, dict):  True

Variable names and shapes:
(SimpleModule).v1           3 (3,)
(SimpleModule).v2           2 (2,)
+Total(2)                   5

Variable names and values:
(SimpleModule).v1      [-1.1010289  -0.68184537 -0.95236546]
(SimpleModule).v2      [1. 1.]

If the __call__ method of the module takes tensors as input and outputs tensors, then it can act as a mathematical function. In the general case, __call__ can be a multivariate vector-valued function.

The SimpleModule described above takes a vector of size length as input and outputs a scalar:

[17]:
x = jn.ones((3,))
y = m(x)
print('Input: ', x)
print('Output: ', y)
Input:  [1. 1. 1.]
Output:  -2.7352397

The way jn.dot works allows us to run the same code on 2D tensors as well. In this case SimpleModule treats the input as a batch of vectors, performs the dot product on each of them, and returns a vector with the results:

[18]:
x = jn.array([[1., 1., 1.],
              [1., 0., 0.],
              [0., 1., 0.],
              [0., 0., 1.]])
y = m(x)
print('Input:\n', x)
print('Output:\n', y)
Input:
 [[1. 1. 1.]
 [1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]]
Output:
 [-2.7352397  -1.1010289  -0.68184537 -0.95236546]

For comparison, here is the result of calling module m on each row of tensor x:

[19]:
print('Sequentially calling module on each row of 2D tensor:')
for idx in range(x.shape[0]):
  row_value = x[idx]
  out_value = m(row_value)
  print(f'm( {row_value} ) = {out_value}')
Sequentially calling module on each row of 2D tensor:
m( [1. 1. 1.] ) = -2.7352397441864014
m( [1. 0. 0.] ) = -1.1010289192199707
m( [0. 1. 0.] ) = -0.6818453669548035
m( [0. 0. 1.] ) = -0.9523654580116272
How to compute gradients

As shown above, modules can act as mathematical functions. It’s essential in machine learning to be able to compute gradients of functions and Objax provides a simple way to do this.

It’s important to keep in mind that gradients are usually defined for scalar-valued functions, while our modules can be vector-valued. In that case we need to define an additional function that converts the vector-valued output of the module into a scalar. Then we can compute gradients of that scalar-valued function with respect to all its variables.

Continuing the SimpleModule example above, let’s first define a scalar-valued loss function:

[20]:
def loss_fn(x):
    return m(x).sum()

print('loss_fn(x) = ', loss_fn(x))
loss_fn(x) =  -5.4704795

Then we create an objax.GradValues module which computes the gradients of loss_fn. We need to pass the function itself to the constructor of objax.GradValues as well as a VarCollection with the variables that loss_fn depends on:

[21]:
# Construct a module which computes gradients
gv = objax.GradValues(loss_fn, module_vars)

gv is a module which returns the gradients of loss_fn and the values of loss_fn for the given input:

[22]:
# gv returns both gradients and values of original function
grads, value = gv(x)

print('Gradients:')
for g, var_name in zip(grads, module_vars.keys()):
    print(g, ' w.r.t. ', var_name)
print()
print('Value: ', value)
Gradients:
[2. 2. 2.]  w.r.t.  (SimpleModule).v1
[0. 0.]  w.r.t.  (SimpleModule).v2

Value:  [DeviceArray(-5.4704795, dtype=float32)]

In the example above, grads is a list of gradients with respect to all variables from module_vars. The order of gradients in the grads list is the same as the order of corresponding variables in module_vars. So grads[0] is the gradient of the function w.r.t. m.v1 and grads[1] is the gradient w.r.t. m.v2.
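
Since module_vars behaves like a dictionary, one way to make this pairing explicit (a small illustration reusing the objects above) is:

named_grads = dict(zip(module_vars.keys(), grads))
print(named_grads['(SimpleModule).v1'])  # gradient w.r.t. m.v1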

Just-in-time compilation (JIT)

In the examples shown so far, the Python interpreter executes all operations one by one. This mode of execution becomes slow for larger and more complicated code.

Objax provides an easy and convenient way to compile a sequence of operations using objax.Jit:

[23]:
jit_m = objax.Jit(m)
y = jit_m(x)
print('Input:\n', x)
print('Output:\n', y)
Input:
 [[1. 1. 1.]
 [1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]]
Output:
 [-2.7352397  -1.1010289  -0.68184537 -0.95236546]

objax.Jit can compile not only modules but also functions and other callables. In this case a variable collection should be passed to objax.Jit:

[24]:
def loss_fn(x, y):
  return ((m(x) - y) ** 2).sum()

jit_loss_fn = objax.Jit(loss_fn, module_vars)

x = objax.random.normal((2, 3))
y = jn.array((-1.0, 1.0))

print('x:\n ', x)
print('y:\n', y)

print('loss_fn(x, y): ', loss_fn(x, y))
print('jit_loss_fn(x, y): ', jit_loss_fn(x, y))
x:
  [[ 2.2491198   0.18783404  0.65321374]
 [-0.23017201  0.18411613  1.341197  ]]
y:
 [-1.  1.]
loss_fn(x, y):  9.577398
jit_loss_fn(x, y):  9.577398

There is no need to use JIT if you only need to compute a single JAX operation. However JIT can give significant speedups when multiple Objax/JAX operations are chained together. The next tutorial will show examples of how JIT is used in practice.

Nevertheless, the difference in execution speed with and without JIT is evident even in this simple example:

[25]:
x = objax.random.normal((100, 3))
# gv is the module defined above which computes gradients
jit_gv = objax.Jit(gv)
print('Timing for jit_gv:')
%timeit jit_gv(x)
print('Timing for gv:')
%timeit gv(x)
Timing for jit_gv:
1000 loops, best of 3: 309 µs per loop
Timing for gv:
100 loops, best of 3: 3.57 ms per loop
When to use JAX and when to use Objax primitives?

Attentive readers will notice that while Objax works on top of JAX, it redefines quite a few concepts from JAX. Some examples are:

  • objax.GradValues vs jax.value_and_grad for computing gradients.

  • objax.Jit vs jax.jit for just-in-time compilation.

  • objax.random vs jax.random to generate random numbers.

All these differences originate from the fact that JAX is a stateless functional framework, while Objax provides a stateful, object-oriented way to use JAX.

Mixing OOP and functional code can be quite confusing, so we recommend using JAX primitives only for basic mathematical operations (defined in jax.numpy) and Objax primitives for everything else.
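
As a rough illustration of that split, here is a minimal sketch (the module, shapes, and names below are arbitrary, not part of the tutorial):

import jax.numpy as jn
import objax

m = objax.nn.Linear(4, 1)

def mse(x, y):
    # Basic math: jax.numpy primitives.
    return jn.mean((m(x)[:, 0] - y) ** 2)

# State, gradients, compilation, and randomness: Objax primitives.
gv = objax.GradValues(mse, m.vars())
jit_mse = objax.Jit(mse, m.vars())
x = objax.random.normal((8, 4))
y = objax.random.normal((8,))
print(jit_mse(x, y), gv(x, y)[1])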

Next: Logistic Regression Tutorial

This tutorial introduces all concepts necessary to build and train a machine learning classifier. The next tutorial shows how to apply all of them in logistic regression.


Logistic Regression

In this example we will see how to classify images as horses or people using logistic regression. The tutorial builds upon the concepts introduced in the Objax basics tutorial. Consider reading that tutorial first, or you can always go back.

Imports

First, we import the modules we will use in our code.

[1]:
%pip --quiet install objax

import matplotlib.pyplot as plt
import os

import numpy as np
import tensorflow_datasets as tfds

import objax
from objax.util import EasyDict
Loading the Dataset

Next, we will load the “horses_or_humans” dataset from TensorFlow DataSets.

The prepare method downscales the image by 3x to reduce training time, flattens each image to a vector, and rescales each pixel value to [-1,1].

[2]:
# Data: train has 1027 images - test has 256 images
# Each image is 300 x 300 x 3 bytes
DATA_DIR = os.path.join(os.environ['HOME'], 'TFDS')
data = tfds.as_numpy(tfds.load(name='horses_or_humans', batch_size=-1, data_dir=DATA_DIR))

def prepare(x, downscale=3):
  """Normalize images to [-1, 1] and downscale them to 100x100x3 (for faster training) and flatten them."""
  s = x.shape
  x = x.astype('f').reshape((s[0], s[1] // downscale, downscale, s[2] // downscale, downscale, s[3]))
  return x.mean((2, 4)).reshape((s[0], -1)) * (1 / 127.5) - 1

train = EasyDict(image=prepare(data['train']['image']), label=data['train']['label'])
test = EasyDict(image=prepare(data['test']['image']), label=data['test']['label'])
ndim = train.image.shape[-1]
del data
Visualizing the Data

Let’s see a couple of the images in the dataset and their corresponding labels. Note that label 0 corresponds to a horse, while label 1 corresponds to a human.

[3]:
#sample image of a horse.
horse_image = np.reshape(train.image[0], [100,100,3])
plt.imshow(horse_image)
print("label for horse_image:", train.label[0])
label for horse_image: 0
[Image output: sample training image of a horse (label 0)]
[4]:
#sample image of a human.
human_image = np.reshape(train.image[9], [100, 100, 3])
plt.imshow(human_image)
print("label for human_image:", train.label[9])

label for human_image: 1
[Image output: sample training image of a human (label 1)]
Model Definition

objax.nn.Linear(ndim, 1) is a linear neural unit with ndim inputs and a single output. Given input \(\mathbf{X}\), the output is equal to \(\mathbf{W}\mathbf{X} + \mathbf{b}\) where \(\mathbf{W}, \mathbf{b}\) are the model’s parameters. These parameters are available through model.vars().

[5]:
# Settings
lr = 0.0001  # learning rate
batch = 256
epochs = 20

model = objax.nn.Linear(ndim, 1)
print(model.vars())
(Linear).b                  1 (1,)
(Linear).w              30000 (30000, 1)
+Total(2)               30001
Model Inference

Now that we have defined the model, we can use it to classify images. To do so, we call the model on an image from the train dataset we previously prepared. Notice that we use the image of a human we previously visualized.

We get the output of the model by calling model(). We then apply the sigmoid activation function and round the output. Activation outputs lower than or equal to 0.5 are rounded to zero (i.e., horses) whereas outputs larger than 0.5 are rounded to one (i.e., humans).

[6]:
# This is an image of a human.
print(np.round(objax.functional.sigmoid(model(train.image[9]))))
[1.]

Considering that we initialized the model with random weights, it should not come as a surprise that the model may misclassify a human as a horse.

Optimizer and Loss Function

In this example we use the objax.optimizer.SGD optimizer. Next, we define the loss function we will use to optimize the network. In this case we use the cross entropy loss function. Note that we use objax.functional.loss.sigmoid_cross_entropy_logits because we perform binary classification.

[7]:
opt = objax.optimizer.SGD(model.vars())

# Cross Entropy Loss
def loss(x, label):
  return objax.functional.loss.sigmoid_cross_entropy_logits(model(x)[:, 0], label).mean()
Back Propagation and Gradient Descent

objax.GradValues calculates the gradient of loss w.r.t. model.vars(). If you want to learn more about gradients read the Understanding Gradients in-depth topic.

The train_op function implements the core of backward propagation and gradient descent. First, we calculate the gradient g and then pass it to the optimizer which updates the model’s weights.

[8]:
gv = objax.GradValues(loss, model.vars())

def train_op(x, label):
  g, v = gv(x, label)  # returns gradients, loss
  opt(lr, g)
  return v

# This line is optional: it is compiling the code to make it faster.
train_op = objax.Jit(train_op, gv.vars() + opt.vars())
Training and Evaluation Loop

For each of the training epochs we process all the training data, contained in the train dictionary, in batches of size batch. At the end of each epoch we compute the classification accuracy by comparing the model’s predictions over the test data to the ground truth labels.

[9]:
for epoch in range(epochs):
  # Train
  avg_loss = 0
  # randomly shuffle training data
  shuffle_idx = np.random.permutation(train.image.shape[0])
  for it in range(0, train.image.shape[0], batch):
    sel = shuffle_idx[it: it + batch]
    avg_loss += float(train_op(train.image[sel], train.label[sel])[0]) * len(sel)
  avg_loss /= it + batch

  # Eval
  accuracy = 0
  for it in range(0, test.image.shape[0], batch):
    x, y = test.image[it: it + batch], test.label[it: it + batch]
    accuracy += (np.round(objax.functional.sigmoid(model(x)))[:, 0] == y).sum()
  accuracy /= test.image.shape[0]
  print('Epoch %04d  Loss %.2f  Accuracy %.2f' % (epoch + 1, avg_loss, 100 * accuracy))
Epoch 0001  Loss 0.25  Accuracy 82.81
Epoch 0002  Loss 0.25  Accuracy 82.81
Epoch 0003  Loss 0.25  Accuracy 83.59
Epoch 0004  Loss 0.25  Accuracy 82.81
Epoch 0005  Loss 0.25  Accuracy 82.03
Epoch 0006  Loss 0.25  Accuracy 80.86
Epoch 0007  Loss 0.25  Accuracy 80.86
Epoch 0008  Loss 0.24  Accuracy 80.86
Epoch 0009  Loss 0.24  Accuracy 82.03
Epoch 0010  Loss 0.24  Accuracy 83.20
Epoch 0011  Loss 0.24  Accuracy 82.03
Epoch 0012  Loss 0.24  Accuracy 80.86
Epoch 0013  Loss 0.24  Accuracy 81.25
Epoch 0014  Loss 0.24  Accuracy 83.59
Epoch 0015  Loss 0.24  Accuracy 84.38
Epoch 0016  Loss 0.24  Accuracy 85.16
Epoch 0017  Loss 0.24  Accuracy 84.77
Epoch 0018  Loss 0.24  Accuracy 83.98
Epoch 0019  Loss 0.24  Accuracy 83.20
Epoch 0020  Loss 0.24  Accuracy 83.59
Model Inference After Training

Now that the network is trained we can retry the classification example from above:

[10]:
print(np.round(objax.functional.sigmoid(model(train.image[9]))))
[1.]
Next Steps

We saw how we can define, use, and train a very simple model in Objax to classify images of humans and horses. Next, we will learn how to create custom models to classify handwritten digits.


Creating Custom Networks for Multi-Class Classification

This tutorial demonstrates how to define, train, and use different models for multi-class classification. We will reuse most of the code from the Logistic Regression tutorial so if you haven’t gone through that, consider reviewing it first.

Note that this tutorial includes a demonstration of how to build and train a simple convolutional neural network, and running this colab on CPU may take some time. Therefore, we recommend running this colab on GPU (select GPU in the menu Runtime -> Change runtime type -> Hardware accelerator if the hardware accelerator is not set to GPU).

Import Modules

We start by importing the modules we will use in our code.

[1]:
%pip --quiet install objax

import os

import numpy as np
import tensorflow_datasets as tfds

import objax
from objax.util import EasyDict
from objax.zoo.dnnet import DNNet
Load the data

Next, we will load the “MNIST” dataset from TensorFlow DataSets. This dataset contains handwritten digits (i.e., numbers between 0 and 9), and the task is to correctly identify the digit in each image.

The prepare method pads 2 pixels on the left, right, top, and bottom of each image to resize it to 32 x 32 pixels. While MNIST images are grayscale, the prepare method expands each image to three color channels to demonstrate the process of working with color images. The same method also rescales each pixel value to [-1, 1] and converts the images to (N, C, H, W) format.

[2]:
# Data: train has 60000 images - test has 10000 images
# Each image is resized and converted to 32 x 32 x 3
DATA_DIR = os.path.join(os.environ['HOME'], 'TFDS')
data = tfds.as_numpy(tfds.load(name='mnist', batch_size=-1, data_dir=DATA_DIR))

def prepare(x):
  """Pads 2 pixels to the left, right, top, and bottom of each image, scales pixel value to [-1, 1], and converts to NCHW format."""
  s = x.shape
  x_pad = np.zeros((s[0], 32, 32, 1))
  x_pad[:, 2:-2, 2:-2, :] = x
  return objax.util.image.nchw(
      np.concatenate([x_pad.astype('f') * (1 / 127.5) - 1] * 3, axis=-1))

train = EasyDict(image=prepare(data['train']['image']), label=data['train']['label'])
test = EasyDict(image=prepare(data['test']['image']), label=data['test']['label'])
ndim = train.image.shape[-1]

del data
Deep Neural Network Model

Objax offers many predefined models that we can use for classification. One example is the objax.zoo.dnnet.DNNet model, comprising multiple fully connected layers with configurable sizes and activation functions.

[3]:
dnn_layer_sizes = 3072, 128, 10
dnn_model = DNNet(dnn_layer_sizes, objax.functional.leaky_relu)
Custom Model Definition

Alternatively, we can define a new model customized to our machine learning task. We demonstrate this process by defining a convolutional network (ConvNet) from scratch.

We use objax.nn.Sequential to compose multiple layers of convolution (objax.nn.Conv2D), batch normalization (objax.nn.BatchNorm2D), ReLU (objax.functional.relu), max pooling (objax.functional.max_pool_2d), average pooling (a mean over the spatial dimensions, x.mean((2, 3))), and linear (objax.nn.Linear) layers.

Since batch normalization layers behave differently at training and at prediction time, we pass a training flag to the __call__ method of the ConvNet class. The model returns logits; at prediction time we separately apply softmax to obtain probabilities (see the predict helper in the training loop below).

[4]:
class ConvNet(objax.Module):
  """ConvNet implementation."""

  def __init__(self, nin, nclass):
    """Define 3 blocks of conv-bn-relu-conv-bn-relu followed by linear layer."""
    self.conv_block1 = objax.nn.Sequential([objax.nn.Conv2D(nin, 16, 3, use_bias=False),
                                            objax.nn.BatchNorm2D(16),
                                            objax.functional.relu,
                                            objax.nn.Conv2D(16, 16, 3, use_bias=False),
                                            objax.nn.BatchNorm2D(16),
                                            objax.functional.relu])
    self.conv_block2 = objax.nn.Sequential([objax.nn.Conv2D(16, 32, 3, use_bias=False),
                                            objax.nn.BatchNorm2D(32),
                                            objax.functional.relu,
                                            objax.nn.Conv2D(32, 32, 3, use_bias=False),
                                            objax.nn.BatchNorm2D(32),
                                            objax.functional.relu])
    self.conv_block3 = objax.nn.Sequential([objax.nn.Conv2D(32, 64, 3, use_bias=False),
                                            objax.nn.BatchNorm2D(64),
                                            objax.functional.relu,
                                            objax.nn.Conv2D(64, 64, 3, use_bias=False),
                                            objax.nn.BatchNorm2D(64),
                                            objax.functional.relu])
    self.linear = objax.nn.Linear(64, nclass)

  def __call__(self, x, training):
    x = self.conv_block1(x, training=training)
    x = objax.functional.max_pool_2d(x, size=2, strides=2)
    x = self.conv_block2(x, training=training)
    x = objax.functional.max_pool_2d(x, size=2, strides=2)
    x = self.conv_block3(x, training=training)
    x = x.mean((2, 3))
    x = self.linear(x)
    return x
[5]:
cnn_model = ConvNet(nin=3, nclass=10)
print(cnn_model.vars())
(ConvNet).conv_block1(Sequential)[0](Conv2D).w                      432 (3, 3, 3, 16)
(ConvNet).conv_block1(Sequential)[1](BatchNorm2D).running_mean       16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[1](BatchNorm2D).running_var        16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[1](BatchNorm2D).beta               16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[1](BatchNorm2D).gamma              16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[3](Conv2D).w                     2304 (3, 3, 16, 16)
(ConvNet).conv_block1(Sequential)[4](BatchNorm2D).running_mean       16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[4](BatchNorm2D).running_var        16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[4](BatchNorm2D).beta               16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[4](BatchNorm2D).gamma              16 (1, 16, 1, 1)
(ConvNet).conv_block2(Sequential)[0](Conv2D).w                     4608 (3, 3, 16, 32)
(ConvNet).conv_block2(Sequential)[1](BatchNorm2D).running_mean       32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[1](BatchNorm2D).running_var        32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[1](BatchNorm2D).beta               32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[1](BatchNorm2D).gamma              32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[3](Conv2D).w                     9216 (3, 3, 32, 32)
(ConvNet).conv_block2(Sequential)[4](BatchNorm2D).running_mean       32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[4](BatchNorm2D).running_var        32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[4](BatchNorm2D).beta               32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[4](BatchNorm2D).gamma              32 (1, 32, 1, 1)
(ConvNet).conv_block3(Sequential)[0](Conv2D).w                    18432 (3, 3, 32, 64)
(ConvNet).conv_block3(Sequential)[1](BatchNorm2D).running_mean       64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[1](BatchNorm2D).running_var        64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[1](BatchNorm2D).beta               64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[1](BatchNorm2D).gamma              64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[3](Conv2D).w                    36864 (3, 3, 64, 64)
(ConvNet).conv_block3(Sequential)[4](BatchNorm2D).running_mean       64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[4](BatchNorm2D).running_var        64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[4](BatchNorm2D).beta               64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[4](BatchNorm2D).gamma              64 (1, 64, 1, 1)
(ConvNet).linear(Linear).b                                           10 (10,)
(ConvNet).linear(Linear).w                                          640 (64, 10)
+Total(32)                                                        73402
Model Training and Evaluation

The train_model method combines all the parts: defining the loss function, gradient descent, the training loop, and evaluation. It takes the model as a parameter so it can be reused with the two models we defined earlier.

Unlike in the Logistic Regression tutorial, we use objax.functional.loss.cross_entropy_logits_sparse because we perform multi-class classification. The gradient descent operation and training loop remain largely the same, although here we use the objax.optimizer.Momentum optimizer.

The DNNet model expects flattened images, whereas ConvNet expects images in (C, H, W) format. The flatten_image method prepares the images accordingly before passing them to the model.

When using the model for inference we apply the objax.functional.softmax method to compute the probability distribution from the model’s logits.

[6]:
# Settings
lr = 0.03  # learning rate
batch = 128
epochs = 100

# Train loop

def train_model(model):

  def predict(model, x):
    """Return class probabilities by applying softmax to the model's logits."""
    return objax.functional.softmax(model(x, training=False))

  def flatten_image(x):
    """Flatten the image before passing it to the DNN."""
    if isinstance(model, DNNet):
      return objax.functional.flatten(x)
    else:
      return x

  opt = objax.optimizer.Momentum(model.vars())

  # Cross Entropy Loss
  def loss(x, label):
    return objax.functional.loss.cross_entropy_logits_sparse(model(x, training=True), label).mean()

  gv = objax.GradValues(loss, model.vars())
  def train_op(x, label):
    g, v = gv(x, label)  # returns gradients, loss
    opt(lr, g)
    return v

  train_op = objax.Jit(train_op, gv.vars() + opt.vars())

  for epoch in range(epochs):
    avg_loss = 0
    # randomly shuffle training data
    shuffle_idx = np.random.permutation(train.image.shape[0])
    for it in range(0, train.image.shape[0], batch):
      sel = shuffle_idx[it: it + batch]
      avg_loss += float(train_op(flatten_image(train.image[sel]), train.label[sel])[0]) * len(sel)
    avg_loss /= it + len(sel)

    # Eval
    accuracy = 0
    for it in range(0, test.image.shape[0], batch):
      x, y = test.image[it: it + batch], test.label[it: it + batch]
      accuracy += (np.argmax(predict(model, flatten_image(x)), axis=1) == y).sum()
    accuracy /= test.image.shape[0]
    print('Epoch %04d  Loss %.2f  Accuracy %.2f' % (epoch + 1, avg_loss, 100 * accuracy))
Training the DNN Model
[7]:
train_model(dnn_model)
Epoch 0001  Loss 2.39  Accuracy 56.31
Epoch 0002  Loss 1.24  Accuracy 74.19
Epoch 0003  Loss 0.75  Accuracy 84.91
Epoch 0004  Loss 0.56  Accuracy 83.10
Epoch 0005  Loss 0.48  Accuracy 86.42
Epoch 0006  Loss 0.43  Accuracy 89.00
Epoch 0007  Loss 0.41  Accuracy 89.62
Epoch 0008  Loss 0.39  Accuracy 89.82
Epoch 0009  Loss 0.37  Accuracy 90.04
Epoch 0010  Loss 0.36  Accuracy 89.55
Epoch 0011  Loss 0.35  Accuracy 90.53
Epoch 0012  Loss 0.35  Accuracy 90.64
Epoch 0013  Loss 0.34  Accuracy 90.85
Epoch 0014  Loss 0.33  Accuracy 90.87
Epoch 0015  Loss 0.33  Accuracy 91.02
Epoch 0016  Loss 0.32  Accuracy 91.35
Epoch 0017  Loss 0.32  Accuracy 91.35
Epoch 0018  Loss 0.31  Accuracy 91.50
Epoch 0019  Loss 0.31  Accuracy 91.57
Epoch 0020  Loss 0.31  Accuracy 91.75
Epoch 0021  Loss 0.30  Accuracy 91.47
Epoch 0022  Loss 0.30  Accuracy 88.14
Epoch 0023  Loss 0.30  Accuracy 91.82
Epoch 0024  Loss 0.30  Accuracy 91.92
Epoch 0025  Loss 0.29  Accuracy 92.03
Epoch 0026  Loss 0.29  Accuracy 92.04
Epoch 0027  Loss 0.29  Accuracy 92.11
Epoch 0028  Loss 0.29  Accuracy 92.11
Epoch 0029  Loss 0.29  Accuracy 92.18
Epoch 0030  Loss 0.28  Accuracy 92.24
Epoch 0031  Loss 0.28  Accuracy 92.36
Epoch 0032  Loss 0.28  Accuracy 92.17
Epoch 0033  Loss 0.28  Accuracy 92.42
Epoch 0034  Loss 0.28  Accuracy 92.42
Epoch 0035  Loss 0.27  Accuracy 92.47
Epoch 0036  Loss 0.27  Accuracy 92.50
Epoch 0037  Loss 0.27  Accuracy 92.49
Epoch 0038  Loss 0.27  Accuracy 92.58
Epoch 0039  Loss 0.26  Accuracy 92.56
Epoch 0040  Loss 0.26  Accuracy 92.56
Epoch 0041  Loss 0.26  Accuracy 92.77
Epoch 0042  Loss 0.26  Accuracy 92.72
Epoch 0043  Loss 0.26  Accuracy 92.80
Epoch 0044  Loss 0.25  Accuracy 92.85
Epoch 0045  Loss 0.25  Accuracy 92.90
Epoch 0046  Loss 0.25  Accuracy 92.96
Epoch 0047  Loss 0.25  Accuracy 93.00
Epoch 0048  Loss 0.25  Accuracy 92.82
Epoch 0049  Loss 0.25  Accuracy 93.18
Epoch 0050  Loss 0.24  Accuracy 93.09
Epoch 0051  Loss 0.24  Accuracy 92.94
Epoch 0052  Loss 0.24  Accuracy 93.20
Epoch 0053  Loss 0.24  Accuracy 93.26
Epoch 0054  Loss 0.23  Accuracy 93.21
Epoch 0055  Loss 0.24  Accuracy 93.42
Epoch 0056  Loss 0.23  Accuracy 93.35
Epoch 0057  Loss 0.23  Accuracy 93.36
Epoch 0058  Loss 0.23  Accuracy 93.56
Epoch 0059  Loss 0.23  Accuracy 93.54
Epoch 0060  Loss 0.22  Accuracy 93.39
Epoch 0061  Loss 0.23  Accuracy 93.56
Epoch 0062  Loss 0.22  Accuracy 93.74
Epoch 0063  Loss 0.22  Accuracy 93.68
Epoch 0064  Loss 0.22  Accuracy 93.72
Epoch 0065  Loss 0.22  Accuracy 93.76
Epoch 0066  Loss 0.22  Accuracy 93.87
Epoch 0067  Loss 0.21  Accuracy 93.89
Epoch 0068  Loss 0.21  Accuracy 93.96
Epoch 0069  Loss 0.21  Accuracy 93.90
Epoch 0070  Loss 0.21  Accuracy 93.99
Epoch 0071  Loss 0.21  Accuracy 94.02
Epoch 0072  Loss 0.21  Accuracy 93.86
Epoch 0073  Loss 0.21  Accuracy 94.06
Epoch 0074  Loss 0.21  Accuracy 94.14
Epoch 0075  Loss 0.20  Accuracy 94.31
Epoch 0076  Loss 0.20  Accuracy 94.14
Epoch 0077  Loss 0.20  Accuracy 94.15
Epoch 0078  Loss 0.20  Accuracy 94.10
Epoch 0079  Loss 0.20  Accuracy 94.16
Epoch 0080  Loss 0.20  Accuracy 94.28
Epoch 0081  Loss 0.20  Accuracy 94.30
Epoch 0082  Loss 0.20  Accuracy 94.28
Epoch 0083  Loss 0.19  Accuracy 94.37
Epoch 0084  Loss 0.19  Accuracy 94.33
Epoch 0085  Loss 0.19  Accuracy 94.31
Epoch 0086  Loss 0.19  Accuracy 94.25
Epoch 0087  Loss 0.19  Accuracy 94.37
Epoch 0088  Loss 0.19  Accuracy 94.38
Epoch 0089  Loss 0.19  Accuracy 94.35
Epoch 0090  Loss 0.19  Accuracy 94.38
Epoch 0091  Loss 0.19  Accuracy 94.41
Epoch 0092  Loss 0.19  Accuracy 94.46
Epoch 0093  Loss 0.19  Accuracy 94.53
Epoch 0094  Loss 0.18  Accuracy 94.47
Epoch 0095  Loss 0.18  Accuracy 94.54
Epoch 0096  Loss 0.18  Accuracy 94.65
Epoch 0097  Loss 0.18  Accuracy 94.56
Epoch 0098  Loss 0.18  Accuracy 94.60
Epoch 0099  Loss 0.18  Accuracy 94.63
Epoch 0100  Loss 0.18  Accuracy 94.46
Training the ConvNet Model
[8]:
train_model(cnn_model)
Epoch 0001  Loss 0.27  Accuracy 27.08
Epoch 0002  Loss 0.05  Accuracy 41.07
Epoch 0003  Loss 0.03  Accuracy 67.77
Epoch 0004  Loss 0.03  Accuracy 73.31
Epoch 0005  Loss 0.02  Accuracy 90.30
Epoch 0006  Loss 0.02  Accuracy 93.10
Epoch 0007  Loss 0.02  Accuracy 95.98
Epoch 0008  Loss 0.01  Accuracy 98.77
Epoch 0009  Loss 0.01  Accuracy 96.58
Epoch 0010  Loss 0.01  Accuracy 99.12
Epoch 0011  Loss 0.01  Accuracy 98.88
Epoch 0012  Loss 0.01  Accuracy 98.64
Epoch 0013  Loss 0.01  Accuracy 98.66
Epoch 0014  Loss 0.00  Accuracy 98.38
Epoch 0015  Loss 0.00  Accuracy 99.15
Epoch 0016  Loss 0.00  Accuracy 97.50
Epoch 0017  Loss 0.00  Accuracy 98.98
Epoch 0018  Loss 0.00  Accuracy 98.94
Epoch 0019  Loss 0.00  Accuracy 98.56
Epoch 0020  Loss 0.00  Accuracy 99.06
Epoch 0021  Loss 0.00  Accuracy 99.26
Epoch 0022  Loss 0.00  Accuracy 99.30
Epoch 0023  Loss 0.00  Accuracy 99.18
Epoch 0024  Loss 0.00  Accuracy 99.49
Epoch 0025  Loss 0.00  Accuracy 99.34
Epoch 0026  Loss 0.00  Accuracy 99.24
Epoch 0027  Loss 0.00  Accuracy 99.38
Epoch 0028  Loss 0.00  Accuracy 99.43
Epoch 0029  Loss 0.00  Accuracy 99.40
Epoch 0030  Loss 0.00  Accuracy 99.50
Epoch 0031  Loss 0.00  Accuracy 99.44
Epoch 0032  Loss 0.00  Accuracy 99.52
Epoch 0033  Loss 0.00  Accuracy 99.46
Epoch 0034  Loss 0.00  Accuracy 99.39
Epoch 0035  Loss 0.00  Accuracy 99.22
Epoch 0036  Loss 0.00  Accuracy 99.26
Epoch 0037  Loss 0.00  Accuracy 99.47
Epoch 0038  Loss 0.00  Accuracy 99.18
Epoch 0039  Loss 0.00  Accuracy 99.39
Epoch 0040  Loss 0.00  Accuracy 99.44
Epoch 0041  Loss 0.00  Accuracy 99.43
Epoch 0042  Loss 0.00  Accuracy 99.50
Epoch 0043  Loss 0.00  Accuracy 99.50
Epoch 0044  Loss 0.00  Accuracy 99.53
Epoch 0045  Loss 0.00  Accuracy 99.51
Epoch 0046  Loss 0.00  Accuracy 99.49
Epoch 0047  Loss 0.00  Accuracy 99.46
Epoch 0048  Loss 0.00  Accuracy 99.46
Epoch 0049  Loss 0.00  Accuracy 99.35
Epoch 0050  Loss 0.00  Accuracy 99.50
Epoch 0051  Loss 0.00  Accuracy 99.48
Epoch 0052  Loss 0.00  Accuracy 99.48
Epoch 0053  Loss 0.00  Accuracy 99.48
Epoch 0054  Loss 0.00  Accuracy 99.46
Epoch 0055  Loss 0.00  Accuracy 99.48
Epoch 0056  Loss 0.00  Accuracy 99.50
Epoch 0057  Loss 0.00  Accuracy 99.41
Epoch 0058  Loss 0.00  Accuracy 99.49
Epoch 0059  Loss 0.00  Accuracy 99.48
Epoch 0060  Loss 0.00  Accuracy 99.47
Epoch 0061  Loss 0.00  Accuracy 99.52
Epoch 0062  Loss 0.00  Accuracy 99.49
Epoch 0063  Loss 0.00  Accuracy 99.48
Epoch 0064  Loss 0.00  Accuracy 99.51
Epoch 0065  Loss 0.00  Accuracy 99.46
Epoch 0066  Loss 0.00  Accuracy 99.51
Epoch 0067  Loss 0.00  Accuracy 99.49
Epoch 0068  Loss 0.00  Accuracy 99.52
Epoch 0069  Loss 0.00  Accuracy 99.49
Epoch 0070  Loss 0.00  Accuracy 99.51
Epoch 0071  Loss 0.00  Accuracy 99.51
Epoch 0072  Loss 0.00  Accuracy 99.52
Epoch 0073  Loss 0.00  Accuracy 99.43
Epoch 0074  Loss 0.00  Accuracy 99.53
Epoch 0075  Loss 0.00  Accuracy 99.47
Epoch 0076  Loss 0.00  Accuracy 99.51
Epoch 0077  Loss 0.00  Accuracy 99.55
Epoch 0078  Loss 0.00  Accuracy 99.52
Epoch 0079  Loss 0.00  Accuracy 99.52
Epoch 0080  Loss 0.00  Accuracy 98.78
Epoch 0081  Loss 0.00  Accuracy 99.16
Epoch 0082  Loss 0.00  Accuracy 99.40
Epoch 0083  Loss 0.00  Accuracy 99.35
Epoch 0084  Loss 0.00  Accuracy 99.32
Epoch 0085  Loss 0.00  Accuracy 99.49
Epoch 0086  Loss 0.00  Accuracy 99.49
Epoch 0087  Loss 0.00  Accuracy 99.56
Epoch 0088  Loss 0.00  Accuracy 99.48
Epoch 0089  Loss 0.00  Accuracy 99.48
Epoch 0090  Loss 0.00  Accuracy 99.51
Epoch 0091  Loss 0.00  Accuracy 99.45
Epoch 0092  Loss 0.00  Accuracy 99.52
Epoch 0093  Loss 0.00  Accuracy 99.52
Epoch 0094  Loss 0.00  Accuracy 99.51
Epoch 0095  Loss 0.00  Accuracy 99.51
Epoch 0096  Loss 0.00  Accuracy 99.48
Epoch 0097  Loss 0.00  Accuracy 99.51
Epoch 0098  Loss 0.00  Accuracy 99.53
Epoch 0099  Loss 0.00  Accuracy 99.50
Epoch 0100  Loss 0.00  Accuracy 99.53
Training with PyTorch data processing API

One of the pain points for ML researchers and practitioners when building a new ML model is data processing. Here, we demonstrate how to use PyTorch’s data processing API to train a model with Objax. Different deep learning libraries come with different data processing APIs, and depending on your preference, you can choose one and easily combine it with Objax.

Similarly, we prepare an MNIST dataset and apply the same data preprocessing.

[9]:
import torch
from torchvision import datasets, transforms

transform=transforms.Compose([
                              transforms.Pad((2,2,2,2), 0),
                              transforms.ToTensor(),
                              transforms.Lambda(lambda x: np.concatenate([x] * 3, axis=0)),
                              transforms.Lambda(lambda x: x * 2 - 1)
                              ])
train_dataset = datasets.MNIST(os.environ['HOME'], train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(os.environ['HOME'], train=False, download=True, transform=transform)

# Define data loader
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /root/MNIST/raw/train-images-idx3-ubyte.gz
Extracting /root/MNIST/raw/train-images-idx3-ubyte.gz to /root/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /root/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting /root/MNIST/raw/train-labels-idx1-ubyte.gz to /root/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /root/MNIST/raw/t10k-images-idx3-ubyte.gz

Extracting /root/MNIST/raw/t10k-images-idx3-ubyte.gz to /root/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /root/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting /root/MNIST/raw/t10k-labels-idx1-ubyte.gz to /root/MNIST/raw
Processing...
Done!
/usr/local/lib/python3.6/dist-packages/torchvision/datasets/mnist.py:469: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at  /pytorch/torch/csrc/utils/tensor_numpy.cpp:141.)
  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)

We replace the data processing pipeline of the train and test loops with train_loader and test_loader, and that’s it!

[10]:
# Train loop

def train_model_with_torch_data_api(model):

  def predict(model, x):
    """Return class probabilities by applying softmax to the model's logits."""
    return objax.functional.softmax(model(x, training=False))

  def flatten_image(x):
    """Flatten the image before passing it to the DNN."""
    if isinstance(model, DNNet):
      return objax.functional.flatten(x)
    else:
      return x

  opt = objax.optimizer.Momentum(model.vars())

  # Cross Entropy Loss
  def loss(x, label):
    return objax.functional.loss.cross_entropy_logits_sparse(model(x, training=True), label).mean()

  gv = objax.GradValues(loss, model.vars())
  def train_op(x, label):
    g, v = gv(x, label)  # returns gradients, loss
    opt(lr, g)
    return v

  train_op = objax.Jit(train_op, gv.vars() + opt.vars())

  for epoch in range(epochs):
    avg_loss = 0
    tot_data = 0
    for _, (img, label) in enumerate(train_loader):
      avg_loss += float(train_op(flatten_image(img.numpy()), label.numpy())[0]) * len(img)
      tot_data += len(img)
    avg_loss /= tot_data

    # Eval
    accuracy = 0
    tot_data = 0
    for _, (img, label) in enumerate(test_loader):
      accuracy += (np.argmax(predict(model, flatten_image(img.numpy())), axis=1) == label.numpy()).sum()
      tot_data += len(img)
    accuracy /= tot_data
    print('Epoch %04d  Loss %.2f  Accuracy %.2f' % (epoch + 1, avg_loss, 100 * accuracy))

Training the DNN Model with PyTorch data API
[11]:
dnn_layer_sizes = 3072, 128, 10
dnn_model = DNNet(dnn_layer_sizes, objax.functional.leaky_relu)
train_model_with_torch_data_api(dnn_model)
Epoch 0001  Loss 2.57  Accuracy 34.35
Epoch 0002  Loss 1.93  Accuracy 58.51
Epoch 0003  Loss 1.32  Accuracy 68.46
Epoch 0004  Loss 0.83  Accuracy 80.95
Epoch 0005  Loss 0.62  Accuracy 84.74
Epoch 0006  Loss 0.53  Accuracy 86.53
Epoch 0007  Loss 0.48  Accuracy 84.18
Epoch 0008  Loss 0.45  Accuracy 88.42
Epoch 0009  Loss 0.42  Accuracy 87.34
Epoch 0010  Loss 0.40  Accuracy 89.29
Epoch 0011  Loss 0.39  Accuracy 89.31
Epoch 0012  Loss 0.38  Accuracy 89.86
Epoch 0013  Loss 0.37  Accuracy 89.91
Epoch 0014  Loss 0.36  Accuracy 86.94
Epoch 0015  Loss 0.36  Accuracy 89.89
Epoch 0016  Loss 0.35  Accuracy 90.12
Epoch 0017  Loss 0.34  Accuracy 90.40
Epoch 0018  Loss 0.34  Accuracy 90.31
Epoch 0019  Loss 0.34  Accuracy 90.79
Epoch 0020  Loss 0.33  Accuracy 90.71
Epoch 0021  Loss 0.33  Accuracy 90.70
Epoch 0022  Loss 0.33  Accuracy 90.69
Epoch 0023  Loss 0.33  Accuracy 90.91
Epoch 0024  Loss 0.32  Accuracy 90.92
Epoch 0025  Loss 0.32  Accuracy 91.06
Epoch 0026  Loss 0.32  Accuracy 91.19
Epoch 0027  Loss 0.32  Accuracy 91.31
Epoch 0028  Loss 0.31  Accuracy 91.31
Epoch 0029  Loss 0.31  Accuracy 91.20
Epoch 0030  Loss 0.31  Accuracy 91.31
Epoch 0031  Loss 0.31  Accuracy 91.36
Epoch 0032  Loss 0.31  Accuracy 91.42
Epoch 0033  Loss 0.30  Accuracy 91.27
Epoch 0034  Loss 0.31  Accuracy 91.47
Epoch 0035  Loss 0.30  Accuracy 91.57
Epoch 0036  Loss 0.30  Accuracy 91.44
Epoch 0037  Loss 0.30  Accuracy 91.55
Epoch 0038  Loss 0.30  Accuracy 91.56
Epoch 0039  Loss 0.29  Accuracy 91.75
Epoch 0040  Loss 0.29  Accuracy 91.69
Epoch 0041  Loss 0.29  Accuracy 91.60
Epoch 0042  Loss 0.29  Accuracy 91.77
Epoch 0043  Loss 0.29  Accuracy 91.76
Epoch 0044  Loss 0.29  Accuracy 91.84
Epoch 0045  Loss 0.28  Accuracy 92.05
Epoch 0046  Loss 0.28  Accuracy 91.78
Epoch 0047  Loss 0.28  Accuracy 92.01
Epoch 0048  Loss 0.28  Accuracy 91.95
Epoch 0049  Loss 0.28  Accuracy 90.11
Epoch 0050  Loss 0.28  Accuracy 92.14
Epoch 0051  Loss 0.28  Accuracy 92.03
Epoch 0052  Loss 0.27  Accuracy 92.29
Epoch 0053  Loss 0.27  Accuracy 92.17
Epoch 0054  Loss 0.27  Accuracy 92.12
Epoch 0055  Loss 0.27  Accuracy 92.34
Epoch 0056  Loss 0.27  Accuracy 92.32
Epoch 0057  Loss 0.27  Accuracy 92.47
Epoch 0058  Loss 0.27  Accuracy 92.38
Epoch 0059  Loss 0.27  Accuracy 92.39
Epoch 0060  Loss 0.26  Accuracy 92.51
Epoch 0061  Loss 0.27  Accuracy 92.50
Epoch 0062  Loss 0.26  Accuracy 92.46
Epoch 0063  Loss 0.26  Accuracy 92.65
Epoch 0064  Loss 0.26  Accuracy 92.57
Epoch 0065  Loss 0.26  Accuracy 92.63
Epoch 0066  Loss 0.26  Accuracy 92.75
Epoch 0067  Loss 0.26  Accuracy 92.57
Epoch 0068  Loss 0.26  Accuracy 92.88
Epoch 0069  Loss 0.25  Accuracy 92.53
Epoch 0070  Loss 0.25  Accuracy 92.80
Epoch 0071  Loss 0.25  Accuracy 92.71
Epoch 0072  Loss 0.25  Accuracy 92.75
Epoch 0073  Loss 0.25  Accuracy 92.84
Epoch 0074  Loss 0.25  Accuracy 92.71
Epoch 0075  Loss 0.25  Accuracy 92.95
Epoch 0076  Loss 0.25  Accuracy 92.82
Epoch 0077  Loss 0.25  Accuracy 92.90
Epoch 0078  Loss 0.25  Accuracy 92.87
Epoch 0079  Loss 0.25  Accuracy 89.55
Epoch 0080  Loss 0.25  Accuracy 92.86
Epoch 0081  Loss 0.24  Accuracy 92.99
Epoch 0082  Loss 0.24  Accuracy 93.03
Epoch 0083  Loss 0.24  Accuracy 93.03
Epoch 0084  Loss 0.24  Accuracy 93.01
Epoch 0085  Loss 0.24  Accuracy 93.13
Epoch 0086  Loss 0.24  Accuracy 93.17
Epoch 0087  Loss 0.24  Accuracy 92.87
Epoch 0088  Loss 0.24  Accuracy 92.93
Epoch 0089  Loss 0.24  Accuracy 93.16
Epoch 0090  Loss 0.24  Accuracy 93.38
Epoch 0091  Loss 0.24  Accuracy 92.98
Epoch 0092  Loss 0.24  Accuracy 93.30
Epoch 0093  Loss 0.23  Accuracy 93.09
Epoch 0094  Loss 0.23  Accuracy 93.19
Epoch 0095  Loss 0.23  Accuracy 93.25
Epoch 0096  Loss 0.23  Accuracy 93.22
Epoch 0097  Loss 0.23  Accuracy 93.28
Epoch 0098  Loss 0.23  Accuracy 93.39
Epoch 0099  Loss 0.23  Accuracy 93.25
Epoch 0100  Loss 0.23  Accuracy 93.30
Training the ConvNet Model with PyTorch data API
[12]:
cnn_model = ConvNet(nin=3, nclass=10)
print(cnn_model.vars())
train_model_with_torch_data_api(cnn_model)
(ConvNet).conv_block1(Sequential)[0](Conv2D).w                      432 (3, 3, 3, 16)
(ConvNet).conv_block1(Sequential)[1](BatchNorm2D).running_mean       16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[1](BatchNorm2D).running_var        16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[1](BatchNorm2D).beta               16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[1](BatchNorm2D).gamma              16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[3](Conv2D).w                     2304 (3, 3, 16, 16)
(ConvNet).conv_block1(Sequential)[4](BatchNorm2D).running_mean       16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[4](BatchNorm2D).running_var        16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[4](BatchNorm2D).beta               16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[4](BatchNorm2D).gamma              16 (1, 16, 1, 1)
(ConvNet).conv_block2(Sequential)[0](Conv2D).w                     4608 (3, 3, 16, 32)
(ConvNet).conv_block2(Sequential)[1](BatchNorm2D).running_mean       32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[1](BatchNorm2D).running_var        32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[1](BatchNorm2D).beta               32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[1](BatchNorm2D).gamma              32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[3](Conv2D).w                     9216 (3, 3, 32, 32)
(ConvNet).conv_block2(Sequential)[4](BatchNorm2D).running_mean       32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[4](BatchNorm2D).running_var        32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[4](BatchNorm2D).beta               32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[4](BatchNorm2D).gamma              32 (1, 32, 1, 1)
(ConvNet).conv_block3(Sequential)[0](Conv2D).w                    18432 (3, 3, 32, 64)
(ConvNet).conv_block3(Sequential)[1](BatchNorm2D).running_mean       64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[1](BatchNorm2D).running_var        64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[1](BatchNorm2D).beta               64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[1](BatchNorm2D).gamma              64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[3](Conv2D).w                    36864 (3, 3, 64, 64)
(ConvNet).conv_block3(Sequential)[4](BatchNorm2D).running_mean       64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[4](BatchNorm2D).running_var        64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[4](BatchNorm2D).beta               64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[4](BatchNorm2D).gamma              64 (1, 64, 1, 1)
(ConvNet).linear(Linear).b                                           10 (10,)
(ConvNet).linear(Linear).w                                          640 (64, 10)
+Total(32)                                                        73402
Epoch 0001  Loss 0.26  Accuracy 24.18
Epoch 0002  Loss 0.05  Accuracy 37.53
Epoch 0003  Loss 0.03  Accuracy 42.17
Epoch 0004  Loss 0.03  Accuracy 73.50
Epoch 0005  Loss 0.02  Accuracy 80.33
Epoch 0006  Loss 0.02  Accuracy 83.28
Epoch 0007  Loss 0.02  Accuracy 90.87
Epoch 0008  Loss 0.01  Accuracy 98.77
Epoch 0009  Loss 0.01  Accuracy 98.42
Epoch 0010  Loss 0.01  Accuracy 98.16
Epoch 0011  Loss 0.01  Accuracy 98.74
Epoch 0012  Loss 0.01  Accuracy 95.05
Epoch 0013  Loss 0.01  Accuracy 98.89
Epoch 0014  Loss 0.00  Accuracy 98.70
Epoch 0015  Loss 0.00  Accuracy 99.01
Epoch 0016  Loss 0.00  Accuracy 98.97
Epoch 0017  Loss 0.00  Accuracy 98.79
Epoch 0018  Loss 0.00  Accuracy 98.37
Epoch 0019  Loss 0.00  Accuracy 99.19
Epoch 0020  Loss 0.00  Accuracy 99.22
Epoch 0021  Loss 0.00  Accuracy 98.43
Epoch 0022  Loss 0.00  Accuracy 99.02
Epoch 0023  Loss 0.00  Accuracy 99.38
Epoch 0024  Loss 0.00  Accuracy 99.42
Epoch 0025  Loss 0.00  Accuracy 99.45
Epoch 0026  Loss 0.00  Accuracy 99.35
Epoch 0027  Loss 0.00  Accuracy 99.42
Epoch 0028  Loss 0.00  Accuracy 99.42
Epoch 0029  Loss 0.00  Accuracy 99.14
Epoch 0030  Loss 0.00  Accuracy 99.33
Epoch 0031  Loss 0.00  Accuracy 99.36
Epoch 0032  Loss 0.00  Accuracy 99.18
Epoch 0033  Loss 0.00  Accuracy 99.43
Epoch 0034  Loss 0.00  Accuracy 99.47
Epoch 0035  Loss 0.00  Accuracy 99.49
Epoch 0036  Loss 0.00  Accuracy 99.53
Epoch 0037  Loss 0.00  Accuracy 99.38
Epoch 0038  Loss 0.00  Accuracy 99.39
Epoch 0039  Loss 0.00  Accuracy 99.49
Epoch 0040  Loss 0.00  Accuracy 99.49
Epoch 0041  Loss 0.00  Accuracy 99.47
Epoch 0042  Loss 0.00  Accuracy 99.54
Epoch 0043  Loss 0.00  Accuracy 99.35
Epoch 0044  Loss 0.00  Accuracy 99.45
Epoch 0045  Loss 0.00  Accuracy 99.47
Epoch 0046  Loss 0.00  Accuracy 99.53
Epoch 0047  Loss 0.00  Accuracy 99.50
Epoch 0048  Loss 0.00  Accuracy 99.52
Epoch 0049  Loss 0.00  Accuracy 99.51
Epoch 0050  Loss 0.00  Accuracy 99.49
Epoch 0051  Loss 0.00  Accuracy 99.45
Epoch 0052  Loss 0.00  Accuracy 99.48
Epoch 0053  Loss 0.00  Accuracy 99.50
Epoch 0054  Loss 0.00  Accuracy 99.46
Epoch 0055  Loss 0.00  Accuracy 99.50
Epoch 0056  Loss 0.00  Accuracy 99.48
Epoch 0057  Loss 0.00  Accuracy 99.46
Epoch 0058  Loss 0.00  Accuracy 99.44
Epoch 0059  Loss 0.00  Accuracy 99.46
Epoch 0060  Loss 0.00  Accuracy 99.26
Epoch 0061  Loss 0.00  Accuracy 93.99
Epoch 0062  Loss 0.00  Accuracy 97.80
Epoch 0063  Loss 0.00  Accuracy 80.26
Epoch 0064  Loss 0.00  Accuracy 99.20
Epoch 0065  Loss 0.00  Accuracy 99.38
Epoch 0066  Loss 0.00  Accuracy 99.44
Epoch 0067  Loss 0.00  Accuracy 99.51
Epoch 0068  Loss 0.00  Accuracy 99.45
Epoch 0069  Loss 0.00  Accuracy 99.42
Epoch 0070  Loss 0.00  Accuracy 99.50
Epoch 0071  Loss 0.00  Accuracy 99.52
Epoch 0072  Loss 0.00  Accuracy 99.44
Epoch 0073  Loss 0.00  Accuracy 99.41
Epoch 0074  Loss 0.00  Accuracy 99.46
Epoch 0075  Loss 0.00  Accuracy 99.42
Epoch 0076  Loss 0.00  Accuracy 99.49
Epoch 0077  Loss 0.00  Accuracy 99.50
Epoch 0078  Loss 0.00  Accuracy 99.56
Epoch 0079  Loss 0.00  Accuracy 99.52
Epoch 0080  Loss 0.00  Accuracy 99.42
Epoch 0081  Loss 0.00  Accuracy 99.49
Epoch 0082  Loss 0.00  Accuracy 99.48
Epoch 0083  Loss 0.00  Accuracy 99.44
Epoch 0084  Loss 0.00  Accuracy 99.49
Epoch 0085  Loss 0.00  Accuracy 99.53
Epoch 0086  Loss 0.00  Accuracy 99.52
Epoch 0087  Loss 0.00  Accuracy 99.52
Epoch 0088  Loss 0.00  Accuracy 99.50
Epoch 0089  Loss 0.00  Accuracy 99.51
Epoch 0090  Loss 0.00  Accuracy 99.50
Epoch 0091  Loss 0.00  Accuracy 99.49
Epoch 0092  Loss 0.00  Accuracy 99.52
Epoch 0093  Loss 0.00  Accuracy 99.50
Epoch 0094  Loss 0.00  Accuracy 99.50
Epoch 0095  Loss 0.00  Accuracy 99.55
Epoch 0096  Loss 0.00  Accuracy 99.48
Epoch 0097  Loss 0.00  Accuracy 99.51
Epoch 0098  Loss 0.00  Accuracy 99.52
Epoch 0099  Loss 0.00  Accuracy 99.49
Epoch 0100  Loss 0.00  Accuracy 99.52
What’s Next

We have learned how to use existing models and define new models to classify MNIST. Next, you can read one or more of the in-depth topics or browse through Objax’s APIs.

Code Examples

This section describes the code examples found in objax/examples.

Classification
Image

Example code available at examples/image_classification.

Logistic Regression

Train and evaluate a logistic regression model for binary classification on the horses_or_humans dataset.

# Run command
python3 examples/image_classification/horses_or_humans_logistic.py

Code

examples/image_classification/horses_or_humans_logistic.py

Data

horses_or_humans from tensorflow_datasets

Network

Custom single layer

Loss

objax.functional.loss.sigmoid_cross_entropy_logits()

Optimizer

objax.optimizer.SGD

Accuracy

~77%

Hardware

CPU or GPU or TPU

Digit Classification with Deep Neural Network (DNN)

Train and evaluate a DNNet model for multiclass classification on the MNIST dataset.

# Run command
python3 examples/image_classification/mnist_dnn.py

Code

examples/image_classification/mnist_dnn.py

Data

MNIST from tensorflow_datasets

Network

Deep Neural Net objax.zoo.dnnet.DNNet

Loss

objax.functional.loss.cross_entropy_logits()

Optimizer

objax.optimizer.Adam

Accuracy

~98%

Hardware

CPU or GPU or TPU

Techniques

Model weight averaging for improved accuracy using objax.optimizer.ExponentialMovingAverage.

Digit Classification with Convolutional Neural Network (CNN)

Train and evaluate a simple custom CNN model for multiclass classification on the MNIST dataset.

# Run command
python3 examples/image_classification/mnist_cnn.py

Code

examples/image_classification/mnist_cnn.py

Data

MNIST from tensorflow_datasets

Network

Custom Convolution Neural Net using objax.nn.Sequential

Loss

objax.functional.loss.cross_entropy_logits_sparse()

Optimizer

objax.optimizer.Adam

Accuracy

~99.5%

Hardware

CPU or GPU or TPU

Techniques

Digit Classification using Differential Privacy

Train and evaluate a ConvNet model on the MNIST dataset with differential privacy.

# Run command
python3 examples/image_classification/mnist_dp.py
# See available options with
python3 examples/image_classification/mnist_dp.py --help

Code

examples/image_classification/mnist_dp.py

Data

MNIST from tensorflow_datasets

Network

Custom Convolution Neural Net using objax.nn.Sequential

Loss

objax.functional.loss.cross_entropy_logits()

Optimizer

objax.optimizer.SGD

Accuracy

Hardware

GPU

Techniques

Image Classification on CIFAR-10 (Simple)

Train and evaluate a Wide ResNet model for multiclass classification on the CIFAR10 dataset.

# Run command
python3 examples/image_classification/cifar10_simple.py

Code

examples/image_classification/cifar10_simple.py

Data

CIFAR10 from tf.keras.datasets

Network

Wide ResNet using objax.zoo.wide_resnet.WideResNet

Loss

objax.functional.loss.cross_entropy_logits_sparse()

Optimizer

objax.optimizer.Momentum

Accuracy

~91%

Hardware

GPU or TPU

Techniques

  • Learning rate schedule.

  • Data augmentation (mirror / pixel shifts) in Numpy.

  • Regularization using extra weight decay term in loss.

Image Classification on CIFAR-10 (Advanced)

Train and evaluate ConvNet models for multiclass classification on the CIFAR10 dataset.

# Run command
python3 examples/image_classification/cifar10_advanced.py
# Run with custom settings
python3 examples/image_classification/cifar10_advanced.py --weight_decay=0.0001 --batch=64 --lr=0.03 --epochs=256
# See available options with
python3 examples/image_classification/cifar10_advanced.py --help

Code

examples/image_classification/cifar10_advanced.py

Data

CIFAR10 from tensorflow_datasets

Network

Configurable with --arch="network"

  • wrn28-1, wrn28-2 using objax.zoo.wide_resnet.WideResNet

  • cnn32-3-max, cnn32-3-mean, cnn64-3-max, cnn64-3-mean using objax.zoo.convnet.ConvNet

Loss

objax.functional.loss.cross_entropy_logits()

Optimizer

objax.optimizer.Momentum

Accuracy

~94%

Hardware

GPU, Multi-GPU or TPU

Techniques

  • Model weight averaging for improved accuracy using objax.optimizer.ExponentialMovingAverage.

  • Parallelized on multiple GPUs using objax.Parallel.

  • Data augmentation (mirror / pixel shifts) in TensorFlow.

  • Cosine learning rate decay.

  • Regularization using extra weight decay term in loss.

  • Checkpointing, automatic resuming from latest checkpoint if training is interrupted using objax.io.Checkpoint.

  • Saving of tensorboard visualization files using objax.jaxboard.SummaryWriter.

  • Multi-loss reporting (cross-entropy, L2).

  • Reusable training loop example.

Image Classification on ImageNet

Train and evaluate a ResNet50 model on the ImageNet dataset. See README for additional information.

Code

examples/image_classification/imagenet_resnet50_train.py

Data

ImageNet from tensorflow_datasets

Network

ResNet50

Loss

objax.functional.loss.cross_entropy_logits_sparse()

Optimizer

objax.optimizer.Momentum

Accuracy

Hardware

GPU, Multi-GPU or TPU

Techniques

  • Parallelized on multiple GPUs using objax.Parallel.

  • Data augmentation (distorted bounding box crop) in TensorFlow.

  • Linear warmup followed by multi-step learning rate decay.

  • Regularization using extra weight decay term in loss.

  • Checkpointing, automatic resuming from latest checkpoint if training is interrupted using objax.io.Checkpoint.

  • Saving of tensorboard visualization files using objax.jaxboard.SummaryWriter.

Image Classification using Pretrained VGG Network

Image classification using an ImageNet-pretrained VGG19 model. See README for additional information.

Code

examples/image_classification/imagenet_pretrained_vgg.py

Techniques

Load VGG-19 model with pretrained weights and run 1000-way image classification.

Semi-Supervised Learning

Example code available at examples/fixmatch.

Semi-Supervised Learning with FixMatch

Semi-supervised learning of image classification models with FixMatch.

# Run command
python3 examples/fixmatch/fixmatch.py
# Run with custom settings
python3 examples/fixmatch/fixmatch.py --dataset=cifar10.3@1000-0
# See available options with
python3 examples/fixmatch/fixmatch.py --help

Code

examples/fixmatch/fixmatch.py

Data

CIFAR10, CIFAR100, SVHN, STL10

Network

Custom implementation of Wide ResNet.

Loss

objax.functional.loss.cross_entropy_logits() and objax.functional.loss.cross_entropy_logits_sparse()

Optimizer

objax.optimizer.Momentum

Accuracy

See paper

Hardware

GPU, Multi-GPU, TPU

Techniques

GPT-2

Example code is available at examples/gpt-2.

Generating a Text Sequence using GPT-2

Load the pretrained GPT-2 model (124M parameters) and demonstrate how to use it to generate a text sequence. See README for additional information.

Code

examples/gpt-2/gpt2.py

Hardware

GPU or TPU

Techniques

  • Define Transformer model.

  • Load GPT-2 model with pretrained weights and generate a sequence.

RNN

Example code is available at examples/text_generation.

Train a Vanilla RNN to Predict Characters

Train and evaluate a vanilla RNN model on the Shakespeare corpus dataset. See README for additional information.

# Run command
python3 examples/text_generation/shakespeare_rnn.py

Code

examples/text_generation/shakespeare_rnn.py

Data

Shakespeare corpus from tensorflow_datasets

Network

Custom implementation of vanilla RNN.

Loss

objax.functional.loss.cross_entropy_logits()

Optimizer

objax.optimizer.Adam

Hardware

GPU or TPU

Techniques

Optimization

Example code is available at examples/maml.

Model Agnostic Meta-Learning (MAML)

Meta-learning method MAML implementation to demonstrate computing the gradient of a gradient.

# Run command
python3 examples/maml/maml.py

Code

examples/maml/maml.py

Data

Synthetic data

Network

3-layer DNNet

Hardware

CPU or GPU or TPU

Techniques

Gradient of gradient.

Jaxboard

Example code available at examples/jaxboard.

How to Use Jaxboard

Sample usage of jaxboard. See README for additional information.

# Run command
python3 examples/jaxboard/summary.py

Code

examples/jaxboard/summary.py

Hardware

CPU

Usages

  • summary scalar

  • summary text

  • summary image

Additional tutorials

This section includes various tutorials for Objax.

Frequently Asked Questions

What is the difference between Objax and other JAX frameworks?

JAX itself, as well as most JAX frameworks other than Objax, follows a functional programming paradigm. This means that all computations are expected to be performed by stateless pure functions, and state (i.e. model weights) has to be passed to these functions manually.

On the other hand, Objax follows an object-oriented programming paradigm (similar to PyTorch and TensorFlow). Objax provides objects (called Objax modules) which store and manage the state of a neural network.

To better illustrate this distinction, below are two examples of similar code written in pure JAX and in Objax.

Every time a user calls neural network components in JAX (and in many JAX frameworks), they have to pass both the neural network parameters params and the training examples batch['x'], batch['y']:

import jax
import jax.numpy as jn

# ndim and batch are assumed to be defined elsewhere.
params = (jn.zeros(ndim), jn.zeros(1))

def loss(params, x, y):
    w, b = params
    pred = jn.dot(x, w) + b
    return 0.5 * ((y - pred) ** 2).mean()

g_fn = jax.grad(loss)              # g_fn is a function

# Need to pass both parameters and batch to g_fn
g_value = g_fn(params, batch['x'], batch['y'])

On the other hand, modules in Objax store parameters and state internally. Thus a user only has to pass around the training examples batch['x'], batch['y']:

import objax
import jax.numpy as jn

w = objax.TrainVar(jn.zeros(ndim))
b = objax.TrainVar(jn.zeros(1))

def loss(x, y):
    pred = jn.dot(x, w.value) + b.value
    return 0.5 * ((y - pred) ** 2).mean()

g_fn = objax.Grad(loss,           # g_fn is Objax module
                  objax.VarCollection({'w': w, 'b': b}))

# Need to pass only batch to g_fn
g_value = g_fn(batch['x'], batch['y'])
What is the difference between Objax and PyTorch/TensorFlow?
Execution runtime

Objax is implemented on top of JAX, while PyTorch and TensorFlow have their own underlying runtime environments. In practice this mainly means that some conversion is needed to interoperate between these frameworks: for example, converting a PyTorch/TensorFlow tensor to a NumPy array and then feeding that NumPy array to Objax code.
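
As a rough illustrative sketch of such interoperation (the PyTorch conversion is shown only as a comment, and the module and shapes are assumptions, not taken from the examples):

import numpy as np
import objax

net = objax.nn.Linear(4, 2)                      # any Objax module

# x_np = torch_tensor.detach().cpu().numpy()     # convert a PyTorch tensor to NumPy
x_np = np.random.randn(8, 4).astype(np.float32)  # stand-in for the converted tensor

y = net(x_np)                                    # Objax (JAX) consumes NumPy arrays directly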

Design of API

Objax was inspired by the best of other machine learning frameworks (including PyTorch and TensorFlow). Thus readers may observe similarities between the Objax API and the APIs of PyTorch or other frameworks.

Nevertheless, Objax is not intended to be a re-implementation of any other framework’s API, and each Objax design decision is weighed on its own merit. So there will always be differences between the Objax API and the APIs of other frameworks.

Objax API

objax package
Modules

Module()

A module is a container to associate variables and functions.

ModuleList([iterable])

This is a replacement for Python’s list that provides a vars() method to return all the variables that it contains, including the ones contained in the modules and sub-modules in it.

ForceArgs(module, **kwargs)

Forces override of arguments of given module.

Function(f, vc)

Turn a function into a Module by keeping the vars it uses.

Grad(f, variables[, input_argnums])

The Grad module is used to compute the gradients of a function.

GradValues(f, variables[, input_argnums])

The GradValues module is used to compute the gradients of a function.

Jit(f[, vc, static_argnums])

JIT (Just-In-Time) module takes a function or a module and compiles it for faster execution.

Parallel(f[, vc, reduce, axis_name, …])

Parallel module takes a function or a module and compiles it for running on multiple devices in parallel.

Vectorize(f[, vc, batch_axis])

Vectorize module takes a function or a module and compiles it for running in parallel on a single device.

class objax.Module[source]

A module is a container to associate variables and functions.

vars(scope='')[source]

Collect all the variables (and their names) contained in the module and its submodules. Important: Variables and modules stored in Python structures such as dict or list are not collected. See ModuleList if you need such a feature.

Parameters

scope (str) – string to prefix to the variable names.

Returns

A VarCollection of all the variables.

Return type

objax.variable.VarCollection

__call__(*args, **kwargs)[source]

Optional module __call__ method, typically a forward pass computation for standard primitives.
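
As a minimal sketch of this pattern (the Affine module below is an illustrative example, not part of the Objax API): variables are stored as attributes, inputs are passed to __call__, and vars() collects everything the module owns.

import objax
import jax.numpy as jn

class Affine(objax.Module):
    """Minimal custom module computing x * w + b element-wise."""
    def __init__(self, dim):
        self.w = objax.TrainVar(jn.ones(dim))
        self.b = objax.TrainVar(jn.zeros(dim))

    def __call__(self, x):
        return x * self.w.value + self.b.value

m = Affine(3)
print(m.vars())                          # lists (Affine).b and (Affine).w with their shapes
y = m(jn.arange(3, dtype=jn.float32))    # forward pass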

class objax.ModuleList(iterable=(), /)[source]

Bases: objax.module.Module, list

This is a replacement for Python’s list that provides a vars() method to return all the variables that it contains, including the ones contained in the modules and sub-modules in it.

Usage example:

import objax

ml = objax.ModuleList(['hello', objax.TrainVar(objax.random.normal((10,2)))])
print(ml.vars())
# (ModuleList)[1]            20 (10, 2)
# +Total(1)                  20

ml.pop()
ml.append(objax.nn.Linear(2, 3))
print(ml.vars())
# (ModuleList)[1](Linear).b        3 (3,)
# (ModuleList)[1](Linear).w        6 (2, 3)
# +Total(2)                        9
vars(scope='')[source]

Collect all the variables (and their names) contained in the list and its submodules.

Parameters

scope (str) – string to prefix to the variable names.

Returns

A VarCollection of all the variables.

Return type

objax.variable.VarCollection

class objax.Function(f, vc)[source]

Turn a function into a Module by keeping the vars it uses.

Usage example:

import objax
import jax.numpy as jn

m = objax.nn.Linear(2, 3)

def f1(x, y):
    return ((m(x) - y) ** 2).mean()

# Method 1: Create module by calling objax.Function to tell which variables are used.
m1 = objax.Function(f1, m.vars())

# Method 2: Use the decorator.
@objax.Function.with_vars(m.vars())
def f2(x, y):
    return ((m(x) - y) ** 2).mean()

# All behave like functions
x = jn.arange(10).reshape((5, 2))
y = jn.arange(15).reshape((5, 3))
print(type(f1), f1(x, y))  # <class 'function'> 237.01947
print(type(m1), m1(x, y))  # <class 'objax.module.Function'> 237.01947
print(type(f2), f2(x, y))  # <class 'objax.module.Function'> 237.01947

Using Function is not necessary: it is made available for aesthetic reasons (to accommodate users’ personal taste). It is also used internally to keep the code simple for Grad, Jit, Parallel, Vectorize and future primitives.

__init__(f, vc)[source]

Function constructor.

Parameters
  • f (Callable) – the function or the module to represent.

  • vc (objax.variable.VarCollection) – the VarCollection of variables used by the function.

__call__(*args, **kwargs)[source]

Call the function.

vars(scope='')[source]

Return the VarCollection of the variables used by the function.

Parameters

scope (str) –

Return type

objax.variable.VarCollection

class objax.ForceArgs(module, **kwargs)[source]

Forces override of arguments of given module.

One example of ForceArgs usage is to override training argument for batch normalization:

import objax
from objax.zoo.resnet_v2 import ResNet50

model = ResNet50(in_channels=3, num_classes=1000)

# Modify model to force training=False on first resnet block.
# First two ops in the resnet are convolution and padding,
# resnet blocks are starting at index 2.
model[2] = objax.ForceArgs(model[2], training=False)

# model(x, training=True) will be using `training=False` on model[2] due to ForceArgs
# ...

# Undo specific value of forced arguments in `model` and all submodules of `model`
objax.ForceArgs.undo(model, training=True)

# Undo all values of specific argument in `model` and all submodules of `model`
objax.ForceArgs.undo(model, training=objax.ForceArgs.ANY)

# Undo all values of all arguments in `model` and all submodules of `model`
objax.ForceArgs.undo(model)
class ANY

Token used in ForceArgs.undo to indicate undo of all values of specific argument.

static undo(module, **kwargs)[source]

Undo ForceArgs on each submodule of the module. Modifications are done in-place.

Parameters
  • module (objax.module.Module) – module for which to undo ForceArgs.

  • **kwargs – dictionary of argument overrides to undo. name=val removes the override for value val of argument name. name=ForceArgs.ANY removes all overrides of argument name. If **kwargs is empty then all overrides will be undone.

__init__(module, **kwargs)[source]

Initializes ForceArgs by wrapping another module.

Parameters
  • module (objax.module.Module) – module whose arguments will be overridden.

  • kwargs – values of keyword arguments that will be forced on the module.

vars(scope='')[source]

Returns the VarCollection of the wrapped module.

Parameters

scope (str) – string to prefix to the variable names.

Returns

A VarCollection of all the variables of wrapped module.

Return type

objax.variable.VarCollection

__call__(*args, **kwargs)[source]

Calls wrapped module using forced args to override wrapped module arguments.

class objax.Grad(f, variables, input_argnums=None)[source]

The Grad module is used to compute the gradients of a function.

Usage example:

import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])

def f(x, y):
    return ((m(x) - y) ** 2).mean()

# Create module to compute gradients of f for m.vars()
grad_f = objax.Grad(f, m.vars())

# Create module to compute gradients of f for input 0 (x) and m.vars()
grad_fx = objax.Grad(f, m.vars(), input_argnums=(0,))

For more information and examples, see Understanding Gradients.

__init__(f, variables, input_argnums=None)[source]

Constructs an instance to compute the gradient of f w.r.t. variables.

Parameters
  • f (Callable) – the function for which to compute gradients.

  • variables (Optional[objax.variable.VarCollection]) – the variables for which to compute gradients.

  • input_argnums (Optional[Tuple[int, ..]]) – input indexes, if any, on which to compute gradients.

__call__(*args, **kwargs)[source]

Returns the computed gradients for the first value returned by f.

Returns

A list of input gradients, if any, followed by the variable gradients.
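
Continuing the usage example above, calling these modules might look like this (the inputs x and y are illustrative placeholders):

import jax.numpy as jn

x = jn.zeros((4, 2))
y = jn.zeros((4, 3))

g = grad_f(x, y)     # list of gradients, one per variable in m.vars()
gx = grad_fx(x, y)   # gradient w.r.t. input x first, then the variable gradients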

class objax.GradValues(f, variables, input_argnums=None)[source]

The GradValues module is used to compute the gradients of a function.

Usage example:

import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu, objax.nn.Linear(3, 2)])

@objax.Function.with_vars(m.vars())
def f(x, y):
    return ((m(x) - y) ** 2).mean()

# Create module to compute gradients of f for m.vars()
grad_val_f = objax.GradValues(f, m.vars())

# Create module to compute gradients of f for only some variables
grad_val_f_head = objax.GradValues(f, m[:1].vars())

# Create module to compute gradients of f for input 0 (x) and m.vars()
grad_val_fx = objax.GradValues(f, m.vars(), input_argnums=(0,))

For more information and examples, see Understanding Gradients.

__init__(f, variables, input_argnums=None)[source]

Constructs an instance to compute the gradient of f w.r.t. variables.

Parameters
  • f (Union[objax.module.Module, Callable]) – the function for which to compute gradients.

  • variables (Optional[objax.variable.VarCollection]) – the variables for which to compute gradients.

  • input_argnums (Optional[Tuple[int, ..]]) – input indexes, if any, on which to compute gradients.

__call__(*args, **kwargs)[source]

Returns the computed gradients for the first value returned by f and the values returned by f.

Returns

A tuple (gradients, values of f), where gradients is a list containing the input gradients, if any, followed by the variable gradients.

vars(scope='')[source]

Return the VarCollection of the variables used.

Parameters

scope (str) –

Return type

objax.variable.VarCollection

class objax.Jit(f, vc=None, static_argnums=None)[source]

JIT (Just-In-Time) module takes a function or a module and compiles it for faster execution.

Usage example:

import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])
jit_m = objax.Jit(m)                          # Jit a module

# For jitting functions, use objax.Function.with_vars
@objax.Function.with_vars(m.vars())
def f(x):
    return m(x) + 1

jit_f = objax.Jit(f)

For more information, refer to JIT Compilation. Also note that one can pass variables to be used by Jit for a module m: the rest will be optimized away as constants; for more information, refer to Constant optimization.

__init__(f, vc=None, static_argnums=None)[source]

Jit constructor.

Parameters
  • f (Union[objax.module.Module, Callable]) – the function or the module to compile.

  • vc (Optional[objax.variable.VarCollection]) – the VarCollection of variables used by the function or module. This argument is required for functions.

  • static_argnums (Optional[Tuple[int, ..]]) – tuple of indexes of f’s input arguments to treat as static (constants). A new graph is compiled for each different combination of values for such inputs.

__call__(*args, **kwargs)[source]

Call the compiled version of the function or module.

vars(scope='')

Collect all the variables (and their names) contained in the module and its submodules. Important: Variables and modules stored in Python structures such as dict or list are not collected. See ModuleList if you need such a feature.

Parameters

scope (str) – string to prefix to the variable names.

Returns

A VarCollection of all the variables.

Return type

objax.variable.VarCollection

class objax.Parallel(f, vc=None, reduce=<function concatenate>, axis_name='device', static_argnums=None)[source]

Parallel module takes a function or a module and compiles it for running on multiple devices in parallel.

Usage example:

import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])
para_m = objax.Parallel(m)                         # Parallelize a module

# For parallelizing functions, use objax.Function.with_vars
@objax.Function.with_vars(m.vars())
def f(x):
    return m(x) + 1

para_f = objax.Parallel(f)

When calling a parallelized module, one must replicate the variables it uses on all devices:

x = objax.random.normal((16, 2))
with m.vars().replicate():
    y = para_m(x)

For more information, refer to Parallelism. Also note that one can pass variables to be used by Parallel for a module m: the rest will be optimized away as constants; for more information, refer to Constant optimization.

__init__(f, vc=None, reduce=<function concatenate>, axis_name='device', static_argnums=None)[source]

Parallel constructor.

Parameters
  • f (Union[objax.module.Module, Callable]) – the function or the module to compile for parallelism.

  • vc (Optional[objax.variable.VarCollection]) – the VarCollection of variables used by the function or module. This argument is required for functions.

  • reduce (Callable[[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]) – the function used to reduce the outputs from many devices to a single device value.

  • axis_name (str) – what name to give to the device dimension, used in conjunction with objax.functional.parallel.

  • static_argnums (Optional[Tuple[int, ..]]) – tuple of indexes of f’s input arguments to treat as static (constants). A new graph is compiled for each different combination of values for such inputs.

__call__(*args)[source]

Call the compiled function or module on multiple devices in parallel. Important: Make sure you call this function within the scope of VarCollection.replicate() statement.

vars(scope='')

Collect all the variables (and their names) contained in the module and its submodules. Important: Variables and modules stored in Python structures such as dict or list are not collected. See ModuleList if you need such a feature.

Parameters

scope (str) – string to prefix to the variable names.

Returns

A VarCollection of all the variables.

Return type

objax.variable.VarCollection

class objax.Vectorize(f, vc=None, batch_axis=(0))[source]

Vectorize module takes a function or a module and compiles it for running in parallel on a single device.

Usage example:

import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])
vec_m = objax.Vectorize(m)                         # Vectorize a module

# For vectorizing functions, use objax.Function.with_vars
@objax.Function.with_vars(m.vars())
def f(x):
    return m(x) + 1

vec_f = objax.Vectorize(f)

For more information and examples, refer to Vectorization.
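
As a small illustrative addition (the batch size is an assumption), calling the vectorized function maps f over axis 0 of its inputs by default:

import objax

x = objax.random.normal((16, 2))   # batch of 16 examples of dimension 2
y = vec_f(x)                       # f applied per example, output shape (16, 3)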

__init__(f, vc=None, batch_axis=(0))[source]

Vectorize constructor.

Parameters
  • f (Union[objax.module.Module, Callable]) – the function or the module to compile for vectorization.

  • vc (Optional[objax.variable.VarCollection]) – the VarCollection of variables used by the function or module. This argument is required for functions.

  • batch_axis (Tuple[Optional[int], ..]) – tuple of int or None for each of f’s input arguments: the axis to use as batch during vectorization. Use None to automatically broadcast.

vars(scope='')

Collect all the variables (and their names) contained in the module and its submodules. Important: Variables and modules stored in Python structures such as dict or list are not collected. See ModuleList if you need such a feature.

Parameters

scope (str) – string to prefix to the variable names.

Returns

A VarCollection of all the variables.

Return type

objax.variable.VarCollection

__call__(*args)[source]

Call the vectorized version of the function or module.

Variables

BaseVar(reduce)

The abstract base class to represent objax variables.

TrainVar(tensor[, reduce])

A trainable variable.

BaseState(reduce)

The abstract base class used to represent objax state variables.

StateVar(tensor[, reduce])

StateVar are variables that get updated manually, and are not automatically updated by optimizers.

TrainRef(ref)

A state variable that references a trainable variable for assignment.

RandomState(seed)

RandomState are variables that track the random generator state.

VarCollection

A VarCollection is a dictionary (name, var) with some additional methods to make manipulation of collections of variables easy.

class objax.BaseVar(reduce)[source]

The abstract base class to represent objax variables.

__init__(reduce)[source]

Constructor for BaseVar class.

Parameters

reduce (Optional[Callable[[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]]) – a function that takes an array of shape (n, *dims) and returns one of shape (*dims). Used to combine the multiple states produced in an objax.Vectorize or an objax.Parallel call.

class objax.TrainVar(tensor, reduce=<function reduce_mean>)[source]

A trainable variable.

__init__(tensor, reduce=<function reduce_mean>)[source]

TrainVar constructor.

Parameters
  • tensor (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – the initial value of the TrainVar.

  • reduce (Optional[Callable[[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]]) – a function that takes an array of shape (n, *dims) and returns one of shape (*dims). Used to combine the multiple states produced in an objax.Vectorize or an objax.Parallel call.

property value

The value is read only as a safety measure to avoid accidentally making TrainVar non-differentiable. You can write a value to a TrainVar by using assign.
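
For example (a minimal sketch):

import objax
import jax.numpy as jn

v = objax.TrainVar(jn.zeros(3))
print(v.value)           # reading is fine: [0. 0. 0.]
# v.value = jn.ones(3)   # would fail: value is read-only
v.assign(jn.ones(3))     # the supported way to write a new value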

assign(tensor, check=True)[source]

Sets the value of the variable.

Parameters

tensor (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) –

reduce(tensors)

Method called by Parallel and Vectorize to reduce a multiple-device (or batched, in the case of vectorization) value to a single device.

Parameters

tensors (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) –

class objax.BaseState(reduce)[source]

The abstract base class used to represent objax state variables. State variables are not trainable.

__init__(reduce)

Constructor for BaseVar class.

Parameters

reduce (Optional[Callable[[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]]) – a function that takes an array of shape (n, *dims) and returns one of shape (*dims). Used to combine the multiple states produced in an objax.Vectorize or an objax.Parallel call.

class objax.TrainRef(ref)[source]

A state variable that references a trainable variable for assignment.

TrainRef variables are used by optimizers to keep references to trainable variables. This is necessary to differentiate them from the optimizer’s own trainable variables, if any.

__init__(ref)[source]

TrainRef constructor.

Parameters

ref (objax.variable.TrainVar) – the TrainVar to keep the reference of.

property value

The value stored in the referenced TrainVar; it can be read or written.

assign(tensor, check=True)

Sets the value of the variable.

Parameters

tensor (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) –

reduce(tensors)

Method called by Parallel and Vectorize to reduce a multiple-device (or batched, in the case of vectorization) value to a single device.

Parameters

tensors (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) –

class objax.StateVar(tensor, reduce=<function reduce_mean>)[source]

StateVar are variables that get updated manually, and are not automatically updated by optimizers. For example, the mean and variance statistics in BatchNorm are StateVar.
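
Below is a minimal sketch of a module that keeps its own state in a StateVar (the CallCounter class is illustrative, not part of Objax):

import objax
import jax.numpy as jn

class CallCounter(objax.Module):
    """Counts how many times it was called, using a StateVar."""
    def __init__(self):
        self.count = objax.StateVar(jn.zeros(()))

    def __call__(self, x):
        self.count.value += 1    # StateVar values can be written directly
        return x

c = CallCounter()
c(jn.ones(4))
print(c.count.value)             # 1.0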

__init__(tensor, reduce=<function reduce_mean>)[source]

StateVar constructor.

Parameters
  • tensor (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – the initial value of the StateVar.

  • reduce (Optional[Callable[[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]]) – a function that takes an array of shape (n, *dims) and returns one of shape (*dims). Used to combine the multiple states produced in an objax.Vectorize or an objax.Parallel call.

property value

The value stored in the StateVar; it can be read or written.

assign(tensor, check=True)

Sets the value of the variable.

Parameters

tensor (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) –

reduce(tensors)

Method called by Parallel and Vectorize to reduce a multiple-device (or batched, in the case of vectorization) value to a single device.

Parameters

tensors (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) –

class objax.RandomState(seed)[source]

RandomState are variables that track the random generator state. They are meant to be used internally. Currently only the random.Generator module uses them.

__init__(seed)[source]

RandomState constructor.

Parameters

seed (int) – the initial seed of the random number generator.

seed(seed)[source]

Sets a new random seed.

Parameters

seed (int) – the new initial seed of the random number generator.

split(n)[source]

Create multiple seeds from the current seed. This is used internally by Parallel and Vectorize to ensure that random numbers are different in parallel threads.

Parameters

n (int) – the number of seeds to generate.

Return type

List[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]

assign(tensor, check=True)

Sets the value of the variable.

Parameters

tensor (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) –

reduce(tensors)

Method called by Parallel and Vectorize to reduce a multiple-device (or batched, in the case of vectorization) value to a single device.

Parameters

tensors (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) –

property value

The value stored in the RandomState; it can be read or written.

class objax.VarCollection[source]

A VarCollection is a dictionary (name, var) with some additional methods to make manipulation of collections of variables easy. A VarCollection is ordered by insertion order. It is the object returned by Module.vars() and used as input by many modules: optimizers, Jit, etc…

Usage example:

import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])
vc = m.vars()  # This is a VarCollection

# It is a dictionary
print(repr(vc))
# {'(Sequential)[0](Linear).b': <objax.variable.TrainVar object at 0x7faecb506390>,
#  '(Sequential)[0](Linear).w': <objax.variable.TrainVar object at 0x7faec81ee350>}
print(vc.keys())  # dict_keys(['(Sequential)[0](Linear).b', '(Sequential)[0](Linear).w'])
assert (vc['(Sequential)[0](Linear).w'].value == m[0].w.value).all()

# Convenience print
print(vc)
# (Sequential)[0](Linear).b        3 (3,)
# (Sequential)[0](Linear).w        6 (2, 3)
# +Total(2)                        9

# Extra methods for manipulation of variables:
# For example, increment all variables by 1
vc.assign([x+1 for x in vc.tensors()])

# It's used by other modules.
# For example it's used to tell what variables are used by a function.

@objax.Function.with_vars(vc)
def my_function(x):
    return objax.functional.softmax(m(x))

For more information and examples, refer to VarCollection.

assign(tensors)[source]

Assign tensors to the variables in the VarCollection. Each variable is assigned only once and in the order following the iter(self) iterator.

Parameters

tensors (List[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]) – the list of tensors used to update variables values.

rename(renamer)[source]

Rename the entries in the VarCollection.

Renaming entries in a VarCollection is a powerful tool that can be used for

  • mapping weights between models that differ slightly.

  • loading data checkpoints from foreign ML frameworks.

Usage example:

import re
import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])
print(m.vars())
# (Sequential)[0](Linear).b        3 (3,)
# (Sequential)[0](Linear).w        6 (2, 3)
# +Total(2)                        9

# For example remove modules from the name
renamer = objax.util.Renamer([(re.compile('\([^)]+\)'), '')])
print(m.vars().rename(renamer))
# [0].b                       3 (3,)
# [0].w                       6 (2, 3)
# +Total(2)                   9

# One can chain renamers, their syntax is flexible and it can use a string mapping:
renamer_all = objax.util.Renamer({'[': '.', ']': ''}, renamer)
print(m.vars().rename(renamer_all))
# .0.b                        3 (3,)
# .0.w                        6 (2, 3)
# +Total(2)                   9
Parameters

renamer (objax.util.util.Renamer) –

replicate()[source]

A context manager to use in a with statement that replicates the variables in this collection to multiple devices. This is typically used prior to a call to objax.Parallel, so that all variables have a copy on each device. Important: replicating also updates the random state in order to have a new one per device.

subset(is_a=None, is_not=None)[source]

Return a new VarCollection that is a filtered subset of the current collection.

Parameters
  • is_a (Optional[Union[type, Tuple[type, ..]]]) – either a variable type or a list of variable types to include.

  • is_not (Optional[Union[type, Tuple[type, ..]]]) – either a variable type or a list of variable types to exclude.

Returns

A new VarCollection containing the subset of variables.

Return type

objax.variable.VarCollection
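
For example, a sketch of filtering a collection by variable type (the module used here is only illustrative):

import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.nn.BatchNorm0D(3)])
vc = m.vars()

train_only = vc.subset(is_a=objax.TrainVar)     # keep only trainable variables
no_state = vc.subset(is_not=objax.BaseState)    # drop state variables (e.g. running statistics)
print(train_only)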

tensors(is_a=None)[source]

Return the list of values for this collection. Similarly to the assign method, each variable value is reported only once and in the order following the iter(self) iterator.

Parameters

is_a (Optional[Union[type, Tuple[type, ..]]]) – either a variable type or a list of variable types to include.

Returns

A list of the tensor values of the variables in the collection, optionally filtered by type.

Return type

List[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]

update(other)[source]

Overload dict.update method to catch potential conflicts during assignment.

Parameters

other (Union[objax.variable.VarCollection, Iterable[Tuple[str, objax.variable.BaseVar]]]) –

__init__(*args, **kwargs)

Initialize self. See help(type(self)) for accurate signature.

Constants
class objax.ConvPadding(value)[source]

An Enum holding the possible padding values for convolution modules.

SAME = 'SAME'
VALID = 'VALID'
objax.functional package
objax.functional

Due to the large number of APIs in this section, we organized it into the following sub-sections:

Activation

celu(x[, alpha])

Continuously-differentiable exponential linear unit activation.

elu(x[, alpha])

Exponential linear unit activation function.

leaky_relu(x[, negative_slope])

Leaky rectified linear unit activation function.

log_sigmoid(x)

Log-sigmoid activation function.

log_softmax(x[, axis])

Log-Softmax function.

logsumexp(a[, axis, b, keepdims, return_sign])

Compute the log of the sum of exponentials of input elements.

relu(x)

Rectified linear unit activation function.

selu(x)

Scaled exponential linear unit activation.

sigmoid(x)

Sigmoid activation function.

softmax(x[, axis])

Softmax function.

softplus(x)

Softplus activation function.

tanh(x)

Elementwise hyperbolic tangent: \(\mathrm{tanh}(x)\).

objax.functional.celu(x, alpha=1.0)[source]

Continuously-differentiable exponential linear unit activation.

Computes the element-wise function:

\[\begin{split}\mathrm{celu}(x) = \begin{cases} x, & x > 0\\ \alpha \left(\exp(\frac{x}{\alpha}) - 1\right), & x \le 0 \end{cases}\end{split}\]

For more information, see Continuously Differentiable Exponential Linear Units.

Parameters
  • x (Any) –

  • alpha (Any) –

Return type

Any

objax.functional.elu(x, alpha=1.0)[source]

Exponential linear unit activation function.

Computes the element-wise function:

\[\begin{split}\mathrm{elu}(x) = \begin{cases} x, & x > 0\\ \alpha \left(\exp(x) - 1\right), & x \le 0 \end{cases}\end{split}\]
Parameters
  • x (Any) –

  • alpha (Any) –

Return type

Any

objax.functional.leaky_relu(x, negative_slope=0.01)[source]

Leaky rectified linear unit activation function.

Computes the element-wise function:

\[\begin{split}\mathrm{leaky\_relu}(x) = \begin{cases} x, & x \ge 0\\ \alpha x, & x < 0 \end{cases}\end{split}\]

where \(\alpha\) = negative_slope.

Parameters
  • x (Any) –

  • negative_slope (Any) –

Return type

Any

objax.functional.log_sigmoid(x)[source]

Log-sigmoid activation function.

Computes the element-wise function:

\[\mathrm{log\_sigmoid}(x) = \log(\mathrm{sigmoid}(x)) = -\log(1 + e^{-x})\]
Parameters

x (Any) –

Return type

Any

objax.functional.log_softmax(x, axis=- 1)[source]

Log-Softmax function.

Computes the logarithm of the softmax function, which rescales elements to the range \([-\infty, 0)\).

\[\mathrm{log\_softmax}(x) = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)} \right)\]
Parameters
  • axis (Optional[Union[int, Tuple[int, ..]]]) – the axis or axes along which the log_softmax should be computed. Either an integer or a tuple of integers.

  • x (Any) –

Return type

Any

objax.functional.logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False)[source]

Compute the log of the sum of exponentials of input elements.

LAX-backend implementation of logsumexp(). Original docstring below.

Parameters
  • a (array_like) – Input array.

  • axis (None or int or tuple of ints, optional) – Axis or axes over which the sum is taken. By default axis is None, and all elements are summed.

  • keepdims (bool, optional) – If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the original array.

  • b (array-like, optional) – Scaling factor for exp(a) must be of the same shape as a or broadcastable to a. These values may be negative in order to implement subtraction.

  • return_sign (bool, optional) – If this is set to True, the result will be a pair containing sign information; if False, results that are negative will be returned as NaN. Default is False (no sign information).

Returns

  • res (ndarray) – The result, np.log(np.sum(np.exp(a))) calculated in a numerically more stable way. If b is given then np.log(np.sum(b*np.exp(a))) is returned.

  • sgn (ndarray) – If return_sign is True, this will be an array of floating-point numbers matching res and +1, 0, or -1 depending on the sign of the result. If False, only one result is returned.

See also

numpy.logaddexp, numpy.logaddexp2

Notes

NumPy has a logaddexp function which is very similar to logsumexp, but only handles two arguments. logaddexp.reduce is similar to this function, but may be less stable.

Examples

>>> from scipy.special import logsumexp
>>> a = np.arange(10)
>>> np.log(np.sum(np.exp(a)))
9.4586297444267107
>>> logsumexp(a)
9.4586297444267107

With weights

>>> a = np.arange(10)
>>> b = np.arange(10, 0, -1)
>>> logsumexp(a, b=b)
9.9170178533034665
>>> np.log(np.sum(b*np.exp(a)))
9.9170178533034647

Returning a sign flag

>>> logsumexp([1,2],b=[1,-1],return_sign=True)
(1.5413248546129181, -1.0)

Notice that logsumexp does not directly support masked arrays. To use it on a masked array, convert the mask into zero weights:

>>> a = np.ma.array([np.log(2), 2, np.log(3)],
...                  mask=[False, True, False])
>>> b = (~a.mask).astype(int)
>>> logsumexp(a.data, b=b), np.log(5)
1.6094379124341005, 1.6094379124341005
objax.functional.relu(x)[source]

Rectified linear unit activation function.

Parameters

x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.

Returns

tensor with the element-wise output relu(x) = max(x, 0).

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]

objax.functional.selu(x)[source]

Scaled exponential linear unit activation.

Computes the element-wise function:

\[\begin{split}\mathrm{selu}(x) = \lambda \begin{cases} x, & x > 0\\ \alpha e^x - \alpha, & x \le 0 \end{cases}\end{split}\]

where \(\lambda = 1.0507009873554804934193349852946\) and \(\alpha = 1.6732632423543772848170429916717\).

For more information, see Self-Normalizing Neural Networks.

Parameters

x (Any) –

Return type

Any

objax.functional.sigmoid(x)[source]

Sigmoid activation function.

Computes the element-wise function:

\[\mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}\]
Parameters

x (Any) –

Return type

Any

objax.functional.softmax(x, axis=- 1)[source]

Softmax function.

Computes the function which rescales elements to the range \([0, 1]\) such that the elements along axis sum to \(1\).

\[\mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}\]
Parameters
  • axis (Optional[Union[int, Tuple[int, ..]]]) – the axis or axes along which the softmax should be computed. The softmax output summed across these dimensions should sum to \(1\). Either an integer or a tuple of integers.

  • x (Any) –

Return type

Any

objax.functional.softplus(x)[source]

Softplus activation function.

Computes the element-wise function

\[\mathrm{softplus}(x) = \log(1 + e^x)\]
Parameters

x (Any) –

Return type

Any

objax.functional.tanh(x)[source]

Elementwise hyperbolic tangent: \(\mathrm{tanh}(x)\).

Parameters

x (Any) –

Return type

Any
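
For a quick illustration, the activations above can be applied directly to arrays (a minimal sketch):

import objax
import jax.numpy as jn

x = jn.array([-2.0, -0.5, 0.0, 1.5])

print(objax.functional.relu(x))       # [0.  0.  0.  1.5]
print(objax.functional.sigmoid(x))    # values in (0, 1)
print(objax.functional.softmax(x))    # non-negative, sums to 1 along the last axis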

Pooling

average_pool_2d(x[, size, strides, padding])

Applies average pooling using a square 2D filter.

batch_to_space2d(x[, size])

Transfer batch dimension N into spatial dimensions (H, W).

channel_to_space2d(x[, size])

Transfer channel dimension C into spatial dimensions (H, W).

max_pool_2d(x[, size, strides, padding])

Applies max pooling using a square 2D filter.

space_to_batch2d(x[, size])

Transfer spatial dimensions (H, W) into batch dimension N.

space_to_channel2d(x[, size])

Transfer spatial dimensions (H, W) into channel dimension C.

objax.functional.average_pool_2d(x, size=2, strides=None, padding=<ConvPadding.VALID: 'VALID'>)[source]

Applies average pooling using a square 2D filter.

Parameters
  • x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor of shape (N, C, H, W).

  • size (Union[Tuple[int, int], int]) – size of pooling filter.

  • strides (Optional[Union[Tuple[int, int], int]]) – stride step; defaults to size when strides is None.

  • padding (Union[objax.constants.ConvPadding, str, Sequence[Tuple[int, int]], Tuple[int, int], int]) – padding of the input tensor, either Padding.SAME or Padding.VALID or numerical values.

Returns

output tensor of shape (N, C, H, W).

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]

For a definition of pooling, including examples see Pooling Layer.

objax.functional.batch_to_space2d(x, size=2)[source]

Transfer batch dimension N into spatial dimensions (H, W).

Parameters
  • x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor of shape (N, C, H, W).

  • size (Union[Tuple[int, int], int]) – size of spatial area.

Returns

output tensor of shape (N // (size[0] * size[1]), C, H * size[0], W * size[1]).

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]

objax.functional.channel_to_space2d(x, size=2)[source]

Transfer channel dimension C into spatial dimensions (H, W).

Parameters
  • x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor of shape (N, C, H, W).

  • size (Union[Tuple[int, int], int]) – size of spatial area.

Returns

output tensor of shape (N, C // (size[0] * size[1]), H * size[0], W * size[1]).

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]

objax.functional.max_pool_2d(x, size=2, strides=None, padding=<ConvPadding.VALID: 'VALID'>)[source]

Applies max pooling using a square 2D filter.

Parameters
  • x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor of shape (N, C, H, W).

  • size (Union[Tuple[int, int], int]) – size of pooling filter.

  • strides (Optional[Union[Tuple[int, int], int]]) – stride step; defaults to size when strides is None.

  • padding (Union[objax.constants.ConvPadding, str, Sequence[Tuple[int, int]], Tuple[int, int], int]) – padding of the input tensor, either Padding.SAME or Padding.VALID or numerical values.

Returns

output tensor of shape (N, C, H, W).

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]

For a definition of pooling, including examples see Pooling Layer.

objax.functional.space_to_batch2d(x, size=2)[source]

Transfer spatial dimensions (H, W) into batch dimension N.

Parameters
  • x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor of shape (N, C, H, W).

  • size (Union[Tuple[int, int], int]) – size of spatial area.

Returns

output tensor of shape (N * size[0] * size[1], C, H // size[0], W // size[1]).

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]

objax.functional.space_to_channel2d(x, size=2)[source]

Transfer spatial dimensions (H, W) into channel dimension C.

Parameters
  • x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor of shape (N, C, H, W).

  • size (Union[Tuple[int, int], int]) – size of spatial area.

Returns

output tensor of shape (N, C * size[0] * size[1], H // size[0], W // size[1]).

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]
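
As a brief sketch of the shape effects described above (the shapes are illustrative):

import objax

x = objax.random.normal((4, 3, 8, 8))                # (N, C, H, W)

y = objax.functional.max_pool_2d(x, size=2)          # shape (4, 3, 4, 4)
z = objax.functional.space_to_channel2d(x, size=2)   # shape (4, 12, 4, 4)
print(y.shape, z.shape)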

Misc

dynamic_slice(operand, start_indices, …)

Wraps XLA’s DynamicSlice operator.

flatten(x)

Flattens input tensor to a 2D tensor.

interpolate(input[, size, scale_factor, mode])

Interpolate JaxArrays by size or scaling factor.

one_hot(x, num_classes, *[, dtype])

One-hot encodes the given indices.

pad(array, pad_width[, mode, stat_length, …])

Pad an array.

scan(f, init, xs[, length, reverse, unroll])

Scan a function over leading array axes while carrying along state.

stop_gradient(x)

Stops gradient computation.

top_k(operand, k)

Returns top k values and their indices along the last axis of operand.

rsqrt(x)

Elementwise reciprocal square root: \(1 / \sqrt{x}\).

upsample_2d(x, scale[, method])

Function to upscale 2D images.

upscale_nn(x[, scale])

Nearest neighbor upscale for image batches of shape (N, C, H, W).

objax.functional.dynamic_slice(operand, start_indices, slice_sizes)[source]

Wraps XLA’s DynamicSlice operator.

Parameters
  • operand (Any) – an array to slice.

  • start_indices (Sequence[Any]) – a list of scalar indices, one per dimension. These values may be dynamic.

  • slice_sizes (Sequence[int]) – the size of the slice. Must be a sequence of non-negative integers with length equal to ndim(operand). Inside a JIT compiled function, only static values are supported (all JAX arrays inside JIT must have statically known size).

Returns

An array containing the slice.

Return type

Any

objax.functional.flatten(x)[source]

Flattens input tensor to a 2D tensor.

Parameters

x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor with dimensions (n_1, n_2, …, n_k)

Returns

The input tensor reshaped to two dimensions (n_1, n_prod), where n_prod is equal to the product of n_2 to n_k.

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]

objax.functional.interpolate(input, size=None, scale_factor=None, mode=<Interpolate.BILINEAR: 'bilinear'>)[source]

Function to interpolate JaxArrays by size or scaling factor.

Parameters
  • input (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.

  • size (Optional[Union[int, Tuple[int, ..]]]) – int or tuple for output size.

  • scale_factor (Optional[Union[int, Tuple[int, ..]]]) – int or tuple scaling factor for each dimension.

  • mode (Union[objax.constants.Interpolate, str]) – str or Interpolate interpolation method, e.g. ['bilinear', 'nearest'].

Returns

output JaxArray after interpolation.

objax.functional.one_hot(x, num_classes, *, dtype=<class 'jax._src.numpy.lax_numpy.float64'>)[source]

One-hot encodes the given indices.

Each index in the input x is encoded as a vector of zeros of length num_classes with the element at index set to one:

>>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
DeviceArray([[1., 0., 0.],
             [0., 1., 0.],
             [0., 0., 1.]], dtype=float32)

Indices outside the range [0, num_classes) will be encoded as zeros:

>>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
DeviceArray([[0., 0., 0.],
             [0., 0., 0.]], dtype=float32)

Parameters
  • x (Any) – A tensor of indices.

  • num_classes (int) – Number of classes in the one-hot dimension.

  • dtype (Any) – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

Return type

Any

objax.functional.pad(array, pad_width, mode='constant', *, stat_length=None, constant_values=0, end_values=None, reflect_type=None)[source]

Pad an array.

LAX-backend implementation of pad(). Unlike numpy, JAX “function” mode’s argument (which is another function) should return the modified array. This is because Jax arrays are immutable. (In numpy, “function” mode’s argument should modify a rank 1 array in-place.)

Original docstring below.

Parameters
  • array (array_like of rank N) – The array to pad.

  • pad_width ({sequence, array_like, int}) – Number of values padded to the edges of each axis. ((before_1, after_1), … (before_N, after_N)) unique pad widths for each axis. ((before, after),) yields same before and after pad for each axis. (pad,) or int is a shortcut for before = after = pad width for all axes.

  • mode (str or function, optional) – One of the following string values or a user supplied function.

  • stat_length (sequence or int, optional) – Used in ‘maximum’, ‘mean’, ‘median’, and ‘minimum’. Number of values at edge of each axis used to calculate the statistic value.

  • constant_values (sequence or scalar, optional) – Used in ‘constant’. The values to set the padded values for each axis.

  • end_values (sequence or scalar, optional) – Used in ‘linear_ramp’. The values used for the ending value of the linear_ramp and that will form the edge of the padded array.

  • reflect_type ({'even', 'odd'}, optional) – Used in ‘reflect’, and ‘symmetric’. The ‘even’ style is the default with an unaltered reflection around the edge value. For the ‘odd’ style, the extended part of the array is created by subtracting the reflected values from two times the edge value.

Returns

pad – Padded array of rank equal to array with shape increased according to pad_width.

Return type

ndarray

Notes

New in version 1.7.0.

For an array with rank greater than 1, some of the padding of later axes is calculated from padding of previous axes. This is easiest to think about with a rank 2 array where the corners of the padded array are calculated by using padded values from the first axis.

The padding function, if used, should modify a rank 1 array in-place. It has the following signature:

padding_func(vector, iaxis_pad_width, iaxis, kwargs)

where

vector : ndarray

A rank 1 array already padded with zeros. Padded values are vector[:iaxis_pad_width[0]] and vector[-iaxis_pad_width[1]:].

iaxis_pad_width : tuple

A 2-tuple of ints, iaxis_pad_width[0] represents the number of values padded at the beginning of vector where iaxis_pad_width[1] represents the number of values padded at the end of vector.

iaxis : int

The axis currently being calculated.

kwargs : dict

Any keyword arguments the function requires.

Examples

>>> a = [1, 2, 3, 4, 5]
>>> np.pad(a, (2, 3), 'constant', constant_values=(4, 6))
array([4, 4, 1, ..., 6, 6, 6])
>>> np.pad(a, (2, 3), 'edge')
array([1, 1, 1, ..., 5, 5, 5])
>>> np.pad(a, (2, 3), 'linear_ramp', end_values=(5, -4))
array([ 5,  3,  1,  2,  3,  4,  5,  2, -1, -4])
>>> np.pad(a, (2,), 'maximum')
array([5, 5, 1, 2, 3, 4, 5, 5, 5])
>>> np.pad(a, (2,), 'mean')
array([3, 3, 1, 2, 3, 4, 5, 3, 3])
>>> np.pad(a, (2,), 'median')
array([3, 3, 1, 2, 3, 4, 5, 3, 3])
>>> a = [[1, 2], [3, 4]]
>>> np.pad(a, ((3, 2), (2, 3)), 'minimum')
array([[1, 1, 1, 2, 1, 1, 1],
       [1, 1, 1, 2, 1, 1, 1],
       [1, 1, 1, 2, 1, 1, 1],
       [1, 1, 1, 2, 1, 1, 1],
       [3, 3, 3, 4, 3, 3, 3],
       [1, 1, 1, 2, 1, 1, 1],
       [1, 1, 1, 2, 1, 1, 1]])
>>> a = [1, 2, 3, 4, 5]
>>> np.pad(a, (2, 3), 'reflect')
array([3, 2, 1, 2, 3, 4, 5, 4, 3, 2])
>>> np.pad(a, (2, 3), 'reflect', reflect_type='odd')
array([-1,  0,  1,  2,  3,  4,  5,  6,  7,  8])
>>> np.pad(a, (2, 3), 'symmetric')
array([2, 1, 1, 2, 3, 4, 5, 5, 4, 3])
>>> np.pad(a, (2, 3), 'symmetric', reflect_type='odd')
array([0, 1, 1, 2, 3, 4, 5, 5, 6, 7])
>>> np.pad(a, (2, 3), 'wrap')
array([4, 5, 1, 2, 3, 4, 5, 1, 2, 3])
>>> def pad_with(vector, pad_width, iaxis, kwargs):
...     pad_value = kwargs.get('padder', 10)
...     vector[:pad_width[0]] = pad_value
...     vector[-pad_width[1]:] = pad_value
>>> a = np.arange(6)
>>> a = a.reshape((2, 3))
>>> np.pad(a, 2, pad_with)
array([[10, 10, 10, 10, 10, 10, 10],
       [10, 10, 10, 10, 10, 10, 10],
       [10, 10,  0,  1,  2, 10, 10],
       [10, 10,  3,  4,  5, 10, 10],
       [10, 10, 10, 10, 10, 10, 10],
       [10, 10, 10, 10, 10, 10, 10]])
>>> np.pad(a, 2, pad_with, padder=100)
array([[100, 100, 100, 100, 100, 100, 100],
       [100, 100, 100, 100, 100, 100, 100],
       [100, 100,   0,   1,   2, 100, 100],
       [100, 100,   3,   4,   5, 100, 100],
       [100, 100, 100, 100, 100, 100, 100],
       [100, 100, 100, 100, 100, 100, 100]])
objax.functional.scan(f, init, xs, length=None, reverse=False, unroll=1)[source]

Scan a function over leading array axes while carrying along state.

The type signature in brief is

scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])

where we use [t] here to denote the type t with an additional leading axis. That is, if t is an array type then [t] represents the type with an additional leading axis, and if t is a pytree (container) type with array leaves then [t] represents the type with the same pytree structure and corresponding leaves each with an additional leading axis.

When a is an array type or None, and b is an array type, the semantics of scan are given roughly by this Python implementation:

def scan(f, init, xs, length=None):
  if xs is None:
    xs = [None] * length
  carry = init
  ys = []
  for x in xs:
    carry, y = f(carry, x)
    ys.append(y)
  return carry, np.stack(ys)

Unlike that Python version, both a and b may be arbitrary pytree types, and so multiple arrays can be scanned over at once and produce multiple output arrays. (None is actually an empty pytree.)

Also unlike that Python version, scan is a JAX primitive and is lowered to a single XLA While HLO. That makes it useful for reducing compilation times for jit-compiled functions, since native Python loop constructs in an @jit function are unrolled, leading to large XLA computations.

Finally, the loop-carried value carry must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type c in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).

Parameters
  • f (Callable[[Carry, X], Tuple[Carry, Y]]) – a Python function to be scanned of type c -> a -> (c, b), meaning that f accepts two arguments where the first is a value of the loop carry and the second is a slice of xs along its leading axis, and that f returns a pair where the first element represents a new value for the loop carry and the second represents a slice of the output.

  • init (Carry) – an initial loop carry value of type c, which can be a scalar, array, or any pytree (nested Python tuple/list/dict) thereof, representing the initial loop carry value. This value must have the same structure as the first element of the pair returned by f.

  • xs (X) – the value of type [a] over which to scan along the leading axis, where [a] can be an array or any pytree (nested Python tuple/list/dict) thereof with consistent leading axis sizes.

  • length (Optional[int]) – optional integer specifying the number of loop iterations, which must agree with the sizes of leading axes of the arrays in xs (but can be used to perform scans where no input xs are needed).

  • reverse (bool) – optional boolean specifying whether to run the scan iteration forward (the default) or in reverse, equivalent to reversing the leading axes of the arrays in both xs and in ys.

  • unroll (int) – optional positive int specifying, in the underlying operation of the scan primitive, how many scan iterations to unroll within a single iteration of a loop.

Returns

A pair of type (c, [b]) where the first element represents the final loop carry value and the second element represents the stacked outputs of the second output of f when scanned over the leading axis of the inputs.

Return type

Tuple[Carry, Y]
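
For instance, a cumulative sum can be expressed as a scan; a small illustrative sketch:

import jax.numpy as jn
import objax

def cumsum_step(carry, x):
    carry = carry + x
    return carry, carry                 # (new carry, per-step output)

xs = jn.arange(1., 6.)                  # [1., 2., 3., 4., 5.]
total, running = objax.functional.scan(cumsum_step, jn.zeros(()), xs)
print(total)                            # 15.0
print(running)                          # [ 1.  3.  6. 10. 15.]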

objax.functional.stop_gradient(x)[source]

Stops gradient computation.

Operationally stop_gradient is the identity function, that is, it returns argument x unchanged. However, stop_gradient prevents the flow of gradients during forward or reverse-mode automatic differentiation. If there are multiple nested gradient computations, stop_gradient stops gradients for all of them.

For example:

>>> jax.grad(lambda x: x**2)(3.)
array(6., dtype=float32)
>>> jax.grad(lambda x: jax.lax.stop_gradient(x)**2)(3.)
array(0., dtype=float32)
>>> jax.grad(jax.grad(lambda x: x**2))(3.)
array(2., dtype=float32)
>>> jax.grad(jax.grad(lambda x: jax.lax.stop_gradient(x)**2))(3.)
array(0., dtype=float32)
objax.functional.top_k(operand, k)[source]

Returns top k values and their indices along the last axis of operand.

Parameters
  • operand (Any) –

  • k (int) –

Return type

Tuple[Any, Any]
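
A small illustrative example:

import jax.numpy as jn
import objax

values, indices = objax.functional.top_k(jn.array([1., 5., 2., 4.]), k=2)
print(values)   # [5. 4.]
print(indices)  # [1 3]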

objax.functional.rsqrt(x)[source]

Elementwise reciprocal square root: \(\frac{1}{\sqrt{x}}\).

Parameters

x (Any) –

Return type

Any

objax.functional.upsample_2d(x, scale, method=<Interpolate.BILINEAR: 'bilinear'>)[source]

Function to upscale 2D images.

Parameters
  • x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.

  • scale (Union[Tuple[int, int], int]) – int or tuple scaling factor

  • method (Union[objax.constants.Interpolate, str]) – interpolation method, either an objax.constants.Interpolate value or a string such as ‘bilinear’ or ‘nearest’.

Returns

upscaled 2d image tensor

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]

objax.functional.upscale_nn(x, scale=2)[source]

Nearest neighbor upscale for image batches of shape (N, C, H, W).

Parameters
  • x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor of shape (N, C, H, W).

  • scale (int) – integer scaling factor.

Returns

Output tensor of shape (N, C, H * scale, W * scale).

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]
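
A short sketch showing both upscaling helpers (shapes are illustrative):

import objax

x = objax.random.normal((4, 3, 8, 8))                            # (N, C, H, W)
y = objax.functional.upscale_nn(x, scale=2)                      # nearest-neighbor upscale
print(y.shape)                                                   # (4, 3, 16, 16)
z = objax.functional.upsample_2d(x, scale=2, method='bilinear')  # bilinear upsample
print(z.shape)                                                   # (4, 3, 16, 16)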

objax.functional.divergence

kl(p, q[, eps])

Calculates the Kullback-Leibler divergence between arrays p and q.

objax.functional.divergence.kl(p, q, eps=7.62939453125e-06)[source]

Calculates the Kullback-Leibler divergence between arrays p and q.

\[kl(p,q) = p \cdot \log{\frac{p + \epsilon}{q + \epsilon}}\]

The \(\epsilon\) term is added to ensure that neither p nor q is zero inside the logarithm.

Parameters
  • p (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) –

  • q (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) –

  • eps (float) –

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]
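
A small sketch with two illustrative distributions:

import jax.numpy as jn
import objax

p = jn.array([0.2, 0.3, 0.5])
q = jn.array([0.1, 0.4, 0.5])
print(objax.functional.divergence.kl(p, q))  # KL divergence between p and q (non-negative)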

objax.functional.loss

cross_entropy_logits(logits, labels)

Computes the softmax cross-entropy loss on n-dimensional data.

cross_entropy_logits_sparse(logits, labels)

Computes the softmax cross-entropy loss.

l2(x)

Computes the L2 loss.

mean_absolute_error(x, y[, keep_axis])

Computes the mean absolute error between x and y.

mean_squared_error(x, y[, keep_axis])

Computes the mean squared error between x and y.

mean_squared_log_error(y_true, y_pred[, …])

Computes the mean squared logarithmic error between y_true and y_pred.

sigmoid_cross_entropy_logits(logits, labels)

Computes the sigmoid cross-entropy loss.

objax.functional.loss.cross_entropy_logits(logits, labels)[source]

Computes the softmax cross-entropy loss on n-dimensional data.

Parameters
  • logits (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – (batch, …, #class) tensor of logits.

  • labels (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – (batch, …, #class) tensor of label probabilities (e.g. labels.sum(axis=-1) must be 1)

Returns

(batch, …) tensor of the cross-entropies for each entry.

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]

Calculates the cross entropy loss, defined as follows:

\[\begin{split}\begin{eqnarray} l(y,\hat{y}) & = & - \sum_{j=1}^{q} y_j \log \frac{e^{o_j}}{\sum_{k=1}^{q} e^{o_k}} \nonumber \\ & = & \log \sum_{k=1}^{q} e^{o_k} - \sum_{j=1}^{q} y_j o_j \nonumber \end{eqnarray}\end{split}\]

where \(o_k\) are the logits and \(y_k\) are the labels.
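
A short sketch with one-hot label probabilities (values are illustrative):

import jax.numpy as jn
import objax

logits = jn.array([[2.0, 0.5, -1.0]])   # (batch=1, #class=3)
labels = jn.array([[1.0, 0.0, 0.0]])    # one-hot probabilities summing to 1
xe = objax.functional.loss.cross_entropy_logits(logits, labels)
print(xe.shape)                         # (1,) - one cross-entropy per batch entry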

objax.functional.loss.cross_entropy_logits_sparse(logits, labels)[source]

Computes the softmax cross-entropy loss.

Parameters
  • logits (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – (batch, …, #class) tensor of logits.

  • labels (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray, int]) – (batch, …) integer tensor of label indexes in {0, …,#nclass-1} or just a single integer.

Returns

(batch, …) tensor of the cross-entropies for each entry.

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]
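
A short sketch with integer class indices (values are illustrative):

import jax.numpy as jn
import objax

logits = jn.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 3.0]])    # (batch=2, #class=3)
labels = jn.array([0, 2])               # integer label indexes
xe = objax.functional.loss.cross_entropy_logits_sparse(logits, labels)
print(xe.shape)                         # (2,)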

objax.functional.loss.l2(x)[source]

Computes the L2 loss.

Parameters

x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – n-dimensional tensor of floats.

Returns

scalar tensor containing the l2 loss of x.

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]

Calculates the l2 loss, as:

\[l_2 = \frac{\sum_{i} x_{i}^2}{2}\]
objax.functional.loss.mean_absolute_error(x, y, keep_axis=(0,))[source]

Computes the mean absolute error between x and y.

Parameters
  • x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – a tensor of shape (d0, .. dN-1).

  • y (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – a tensor of shape (d0, .. dN-1).

  • keep_axis (Optional[Iterable[int]]) – a sequence of the dimensions to keep, use None to return a scalar value.

Returns

tensor of shape (d_i, …, for i in keep_axis) containing the mean absolute error.

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]

objax.functional.loss.mean_squared_error(x, y, keep_axis=(0,))[source]

Computes the mean squared error between x and y.

Parameters
  • x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – a tensor of shape (d0, .. dN-1).

  • y (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – a tensor of shape (d0, .. dN-1).

  • keep_axis (Optional[Iterable[int]]) – a sequence of the dimensions to keep, use None to return a scalar value.

Returns

tensor of shape (d_i, …, for i in keep_axis) containing the mean squared error.

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]
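
A short sketch of the keep_axis behavior (shapes are illustrative):

import objax

x = objax.random.normal((32, 10))
y = objax.random.normal((32, 10))
per_example = objax.functional.loss.mean_squared_error(x, y)              # shape (32,): axis 0 is kept
overall = objax.functional.loss.mean_squared_error(x, y, keep_axis=None)  # scalar: mean over all axes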

objax.functional.loss.sigmoid_cross_entropy_logits(logits, labels)[source]

Computes the sigmoid cross-entropy loss.

Parameters
  • logits (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – (batch, …, #class) tensor of logits.

  • labels (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray, int]) – (batch, …, #class) tensor of label probabilities (e.g. labels.sum(axis=-1) must be 1)

Returns

(batch, …) tensor of the cross-entropies for each entry.

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]

objax.functional.parallel

pmax(x[, axis_name])

Compute a multi-device reduce max on x over the device axis axis_name.

pmean(x[, axis_name])

Compute a multi-device reduce mean on x over the device axis axis_name.

pmin(x[, axis_name])

Compute a multi-device reduce min on x over the device axis axis_name.

psum(x[, axis_name])

Compute a multi-device reduce sum on x over the device axis axis_name.

objax.functional.parallel.pmax(x, axis_name='device')[source]

Compute a multi-device reduce max on x over the device axis axis_name.

Parameters
  • x (jax.interpreters.pxla.ShardedDeviceArray) –

  • axis_name (str) –

objax.functional.parallel.pmean(x, axis_name='device')[source]

Compute a multi-device reduce mean on x over the device axis axis_name.

Parameters
  • x (jax.interpreters.pxla.ShardedDeviceArray) –

  • axis_name (str) –

objax.functional.parallel.pmin(x, axis_name='device')[source]

Compute a multi-device reduce min on x over the device axis axis_name.

Parameters
  • x (jax.interpreters.pxla.ShardedDeviceArray) –

  • axis_name (str) –

objax.functional.parallel.psum(x, axis_name='device')[source]

Compute a multi-device reduce sum on x over the device axis axis_name.

Parameters
  • x (jax.interpreters.pxla.ShardedDeviceArray) –

  • axis_name (str) –
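
These reductions only make sense in code that runs across multiple devices, for example inside a module wrapped with objax.Parallel, which binds the default ‘device’ axis. A hedged sketch, assuming it is called from within such a parallel context:

import objax

# Assumes this function is executed under objax.Parallel (or jax.pmap with
# axis_name='device'); outside such a context the 'device' axis is undefined.
def device_averaged_loss(per_device_loss):
    return objax.functional.parallel.pmean(per_device_loss)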

objax.io package

Checkpoint(logdir, keep_ckpts[, makedir, …])

Helper class which performs saving and restoring of the variables.

load_var_collection(file, vc[, renamer])

Loads values of all variables in the given variables collection from file.

save_var_collection(file, vc)

Saves variables collection into file.

class objax.io.Checkpoint(logdir, keep_ckpts, makedir=True, verbose=True)[source]

Helper class which performs saving and restoring of the variables.

Variables are stored in checkpoint files. One checkpoint file stores a single snapshot of the variables. Different checkpoint files store different snapshots of the variables (for example, at different training steps). Each checkpoint has an associated index, which is used to identify when the snapshot of the variables was made. Typically the training step or training epoch is used as the index.

DIR_NAME: str = 'ckpt'

Name of the subdirectory of model directory where checkpoints will be saved.

FILE_MATCH: str = '*.npz'

File pattern which is used to search for checkpoint files.

FILE_FORMAT: str = '%010d.npz'

Format of the filename of one checkpoint file.

static LOAD_FN(file, vc, renamer=None)

Load function, which loads variables collection from given file.

Parameters
  • file (Union[str, IO[BinaryIO]]) –

  • vc (objax.variable.VarCollection) –

  • renamer (Optional[objax.util.util.Renamer]) –

static SAVE_FN(file, vc)

Save function, which saves variables collection into given file.

Parameters
  • file (Union[str, IO[BinaryIO]]) –

  • vc (objax.variable.VarCollection) –

__init__(logdir, keep_ckpts, makedir=True, verbose=True)[source]

Creates instance of the Checkpoint class.

Parameters
  • logdir (str) – model directory. Checkpoints will be saved in the subdirectory of model directory.

  • keep_ckpts (int) – maximum number of checkpoints to keep.

  • makedir (bool) – if True then directory for checkpoints will be created, otherwise it’s expected that directory already exists.

  • verbose (bool) – if True then print when data is restored from checkpoint.

static checkpoint_idx(filename)[source]

Returns index of checkpoint from given checkpoint filename.

Parameters

filename (str) – checkpoint filename.

Returns

checkpoint index.

restore(vc, idx=None)[source]

Restores values of all variables of given variables collection from the checkpoint.

Old values in the variables collection are replaced with the new values read from the checkpoint. If a variable does not exist in the variables collection, it won’t be restored from the checkpoint.

Parameters
  • vc (objax.variable.VarCollection) – variables collection to restore.

  • idx (Optional[int]) – if provided then checkpoint index to use, if None then latest checkpoint will be restored.

Returns

A pair (idx, ckpt) where idx is the index of the restored checkpoint and ckpt is the full path to the restored checkpoint.

save(vc, idx)[source]

Saves variables collection to checkpoint with given index.

Parameters
  • vc (objax.variable.VarCollection) – variables collection to save.

  • idx (int) – index of the new checkpoint where variables should be saved.
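
A short usage sketch, where ‘experiment_dir’ is a hypothetical model directory:

import objax

net = objax.nn.Linear(3, 2)
ckpt = objax.io.Checkpoint(logdir='experiment_dir', keep_ckpts=5)
ckpt.save(net.vars(), idx=100)          # snapshot the variables at step 100
idx, path = ckpt.restore(net.vars())    # restore the latest snapshot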

objax.io.load_var_collection(file, vc, renamer=None)[source]

Loads values of all variables in the given variables collection from file.

Values loaded from the file replace old values in the variables collection. If a variable exists in the file but not in the variables collection, it is ignored. If a variable exists in the variables collection but is not found in the file, an exception is raised.

Parameters
  • file (Union[str, IO[BinaryIO]]) – filename or python file handle of the input file.

  • vc (objax.variable.VarCollection) – variables collection which will be loaded from file.

  • renamer (Optional[objax.util.util.Renamer]) – optional renamer to pre-process variables names from the file being read.

Raises

ValueError – if variable from variables collection is not found in the input file.

objax.io.save_var_collection(file, vc)[source]

Saves variables collection into file.

Parameters
  • file (Union[str, IO[BinaryIO]]) – filename or python file handle of the file where variables will be saved.

  • vc (objax.variable.VarCollection) – variables collection which will be saved into file.
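
A short usage sketch, where ‘weights.npz’ is a hypothetical filename:

import objax

net = objax.nn.Linear(3, 2)
objax.io.save_var_collection('weights.npz', net.vars())  # write variable values to disk
objax.io.load_var_collection('weights.npz', net.vars())  # read them back into the collection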

objax.jaxboard package

Reducer(value)

Reduces tensor batch into a single tensor.

Summary

Writes entries to Summary protocol buffer.

SummaryWriter(logdir[, queue_size, …])

Writes entries to event files in the logdir to be consumed by Tensorboard.

class objax.jaxboard.Summary[source]

Writes entries to Summary protocol buffer.

image(tag, image)[source]

Adds an image to the summary. A float image in [-1, 1] in CHW format is expected.

Parameters
  • tag (str) –

  • image (numpy.ndarray) –

scalar(tag, value, reduce=<function Reducer.<lambda>>)[source]

Adds scalar to the summary.

Parameters
  • tag (str) –

  • value (float) –

  • reduce (Union[Callable, objax.jaxboard.Reducer]) –

text(tag, text)[source]

Adds text to the summary.

Parameters
  • tag (str) –

  • text (str) –

__call__()[source]

Call self as a function.

class objax.jaxboard.SummaryWriter(logdir, queue_size=5, write_interval=5)[source]

Writes entries to event files in the logdir to be consumed by Tensorboard.

__init__(logdir, queue_size=5, write_interval=5)[source]

Creates SummaryWriter instance.

Parameters
  • logdir (str) – directory where event file will be written.

  • queue_size (int) – size of the queue for pending events and summaries before one of the ‘add’ calls forces a flush to disk.

  • write_interval (int) – how often, in seconds, to write the pending events and summaries to disk.

write(summary, step)[source]

Adds one event to the event file.

Parameters
  • summary – summary to add to the event file.

  • step – step number associated with the summary (e.g. training step).

close()[source]

Flushes the event file to disk and closes the file.
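
A short usage sketch, where ‘logs/experiment’ is a hypothetical log directory:

import objax

summary = objax.jaxboard.Summary()
summary.scalar('losses/xe', 0.42)
summary.text('info/run', 'baseline configuration')

writer = objax.jaxboard.SummaryWriter('logs/experiment')
writer.write(summary, step=1)   # add one event at step 1
writer.close()                  # flush and close the event file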

objax.nn package
objax.nn

BatchNorm(dims, redux[, momentum, eps])

Applies a batch normalization on different ranks of an input tensor.

BatchNorm0D(nin[, momentum, eps])

Applies a 0D batch normalization on a 2D-input batch of shape (N,C).

BatchNorm1D(nin[, momentum, eps])

Applies a 1D batch normalization on a 3D-input batch of shape (N,C,L).

BatchNorm2D(nin[, momentum, eps])

Applies a 2D batch normalization on a 4D-input batch of shape (N,C,H,W).

Conv2D(nin, nout, k[, strides, dilations, …])

Applies a 2D convolution on a 4D-input batch of shape (N,C,H,W).

ConvTranspose2D(nin, nout, k[, strides, …])

Applies a 2D transposed convolution on a 4D-input batch of shape (N,C,H,W).

Dropout(keep[, generator])

In the training phase, a dropout layer zeroes some elements of the input tensor with probability 1-keep and scales the other elements by a factor of 1/keep.

Linear(nin, nout[, use_bias, w_init])

Applies a linear transformation on an input batch.

MovingAverage(shape, buffer_size[, init_value])

Computes moving average of an input batch.

ExponentialMovingAverage(shape[, momentum, …])

Computes exponential moving average (also called EMA or EWMA) of an input batch.

Sequential([iterable])

Executes modules in the order they were passed to the constructor.

SyncedBatchNorm(dims, redux[, momentum, eps])

Synchronized batch normalization which aggregates batch statistics across all devices (GPUs/TPUs).

SyncedBatchNorm0D(nin[, momentum, eps])

Applies a 0D synchronized batch normalization on a 2D-input batch of shape (N,C).

SyncedBatchNorm1D(nin[, momentum, eps])

Applies a 1D synchronized batch normalization on a 3D-input batch of shape (N,C,L).

SyncedBatchNorm2D(nin[, momentum, eps])

Applies a 2D synchronized batch normalization on a 4D-input batch of shape (N,C,H,W).

class objax.nn.BatchNorm(dims, redux, momentum=0.999, eps=1e-06)[source]

Applies a batch normalization on different ranks of an input tensor.

The module follows the operation described in Algorithm 1 of Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.

\[y = \frac{x-\mathrm{E}[x]}{\sqrt{\mathrm{Var}[x]+\epsilon}} \times \gamma + \beta\]

The mean (\(\mathrm{E}[x]\)) and variance (\(\mathrm{Var}[x]\)) are calculated per specified dimensions and over the mini-batches. \(\beta\) and \(\gamma\) are trainable parameter tensors of shape dims. The elements of \(\beta\) are initialized with zeros and those of \(\gamma\) are initialized with ones.

__init__(dims, redux, momentum=0.999, eps=1e-06)[source]

Creates a BatchNorm module instance.

Parameters
  • dims (Iterable[int]) – shape of the batch normalization state variables.

  • redux (Iterable[int]) – list of indices of reduction axes. Batch norm statistics are computed by averaging over these axes.

  • momentum (float) – value used to compute exponential moving average of batch statistics.

  • eps (float) – small value which is used for numerical stability.

__call__(x, training)[source]

Performs batch normalization of input tensor.

Parameters
  • x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.

  • training (bool) – if True compute batch normalization in training mode (accumulating batch statistics), otherwise compute in evaluation mode (using already accumulated batch statistics).

Returns

Batch normalized tensor.

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]

class objax.nn.BatchNorm0D(nin, momentum=0.999, eps=1e-06)[source]

Applies a 0D batch normalization on a 2D-input batch of shape (N,C).

The module follows the operation described in Algorithm 1 of Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.

\[y = \frac{x-\mathrm{E}[x]}{\sqrt{\mathrm{Var}[x]+\epsilon}} \times \gamma + \beta\]

The mean (\(\mathrm{E}[x]\)) and variance (\(\mathrm{Var}[x]\)) are calculated over the mini-batches. \(\beta\) and \(\gamma\) are trainable parameter tensors of shape (1, nin). The elements of \(\beta\) are initialized with zeros and those of \(\gamma\) are initialized with ones.

__init__(nin, momentum=0.999, eps=1e-06)[source]

Creates a BatchNorm0D module instance.

Parameters
  • nin (int) – number of channels in the input example.

  • momentum (float) – value used to compute exponential moving average of batch statistics.

  • eps (float) – small value which is used for numerical stability.

__call__(x, training)

Performs batch normalization of input tensor.

Parameters
  • x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.

  • training (bool) – if True compute batch normalization in training mode (accumulating batch statistics), otherwise compute in evaluation mode (using already accumulated batch statistics).

Returns

Batch normalized tensor.

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]

class objax.nn.BatchNorm1D(nin, momentum=0.999, eps=1e-06)[source]

Applies a 1D batch normalization on a 3D-input batch of shape (N,C,L).

The module follows the operation described in Algorithm 1 of Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.

\[y = \frac{x-\mathrm{E}[x]}{\sqrt{\mathrm{Var}[x]+\epsilon}} \times \gamma + \beta\]

The mean (\(\mathrm{E}[x]\)) and variance (\(\mathrm{Var}[x]\)) are calculated per channel and over the mini-batches. \(\beta\) and \(\gamma\) are trainable parameter tensors of shape (1, nin, 1). The elements of \(\beta\) are initialized with zeros and those of \(\gamma\) are initialized with ones.

__init__(nin, momentum=0.999, eps=1e-06)[source]

Creates a BatchNorm1D module instance.

Parameters
  • nin (int) – number of channels in the input example.

  • momentum (float) – value used to compute exponential moving average of batch statistics.

  • eps (float) – small value which is used for numerical stability.

__call__(x, training)

Performs batch normalization of input tensor.

Parameters
  • x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.

  • training (bool) – if True compute batch normalization in training mode (accumulating batch statistics), otherwise compute in evaluation mode (using already accumulated batch statistics).

Returns

Batch normalized tensor.

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]

class objax.nn.BatchNorm2D(nin, momentum=0.999, eps=1e-06)[source]

Applies a 2D batch normalization on a 4D-input batch of shape (N,C,H,W).

The module follows the operation described in Algorithm 1 of Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.

\[y = \frac{x-\mathrm{E}[x]}{\sqrt{\mathrm{Var}[x]+\epsilon}} \times \gamma + \beta\]

The mean (\(\mathrm{E}[x]\)) and variance (\(\mathrm{Var}[x]\)) are calculated per channel and over the mini-batches. \(\beta\) and \(\gamma\) are trainable parameter tensors of shape (1, nin, 1, 1). The elements of \(\beta\) are initialized with zeros and those of \(\gamma\) are initialized with ones.

__init__(nin, momentum=0.999, eps=1e-06)[source]

Creates a BatchNorm2D module instance.

Parameters
  • nin (int) – number of channels in the input example.

  • momentum (float) – value used to compute exponential moving average of batch statistics.

  • eps (float) – small value which is used for numerical stability.

__call__(x, training)

Performs batch normalization of input tensor.

Parameters
  • x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.

  • training (bool) – if True compute batch normalization in training mode (accumulating batch statistics), otherwise compute in evaluation mode (using already accumulated batch statistics).

Returns

Batch normalized tensor.

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]
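
Usage sketch (shapes are illustrative):

import objax

bn = objax.nn.BatchNorm2D(nin=3)
x = objax.random.normal((8, 3, 32, 32))   # (N, C, H, W) batch
y_train = bn(x, training=True)            # normalizes x and updates running statistics
y_eval = bn(x, training=False)            # normalizes x with accumulated statistics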

class objax.nn.Conv2D(nin, nout, k, strides=1, dilations=1, groups=1, padding=<ConvPadding.SAME: 'SAME'>, use_bias=True, w_init=<function kaiming_normal>)[source]

Applies a 2D convolution on a 4D-input batch of shape (N,C,H,W).

In the simplest case (strides = 1, padding = VALID), the output tensor \((N,C_{out},H_{out},W_{out})\) is computed from an input tensor \((N,C_{in},H,W)\) with kernel weight \((k,k,C_{in},C_{out})\) and bias \((C_{out})\) as follows:

\[\mathrm{out}[n,c,h,w] = \mathrm{b}[c] + \sum_{t=0}^{C_{in}-1}\sum_{i=0}^{k-1}\sum_{j=0}^{k-1} \mathrm{in}[n,c,i+h,j+w] \times \mathrm{w}[i,j,t,c]\]

where \(H_{out}=H-k+1\), \(W_{out}=W-k+1\). Note that the implementation follows the definition of cross-correlation. When padding = SAME, the input tensor is zero-padded by \(\lfloor\frac{k-1}{2}\rfloor\) for left and up sides and \(\lfloor\frac{k}{2}\rfloor\) for right and down sides.

__init__(nin, nout, k, strides=1, dilations=1, groups=1, padding=<ConvPadding.SAME: 'SAME'>, use_bias=True, w_init=<function kaiming_normal>)[source]

Creates a Conv2D module instance.

Parameters
  • nin (int) – number of channels of the input tensor.

  • nout (int) – number of channels of the output tensor.

  • k (Union[Tuple[int, int], int]) – size of the convolution kernel, either tuple (height, width) or single number if they’re the same.

  • strides (Union[Tuple[int, int], int]) – convolution strides, either tuple (stride_y, stride_x) or single number if they’re the same.

  • dilations (Union[Tuple[int, int], int]) – spacing between kernel points (also known as atrous convolution), either tuple (dilation_y, dilation_x) or single number if they’re the same.

  • groups (int) – number of input and output channels group. When groups > 1 convolution operation is applied individually for each group. nin and nout must both be divisible by groups.

  • padding (Union[objax.constants.ConvPadding, str, Sequence[Tuple[int, int]], Tuple[int, int], int]) – padding of the input tensor, either Padding.SAME, Padding.VALID or numerical values.

  • use_bias (bool) – if True then convolution will have bias term.

  • w_init (Callable) – initializer for convolution kernel (a function that takes in a HWIO shape and returns a 4D matrix).

__call__(x)[source]

Returns the results of applying the convolution to input x.

Parameters

x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) –

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]
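
Usage sketch (shapes are illustrative):

import objax

conv = objax.nn.Conv2D(nin=3, nout=16, k=3, strides=2)   # SAME padding by default
x = objax.random.normal((8, 3, 32, 32))
print(conv(x).shape)   # (8, 16, 16, 16): 16 output channels, spatial size halved by stride 2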

class objax.nn.ConvTranspose2D(nin, nout, k, strides=1, dilations=1, padding=<ConvPadding.SAME: 'SAME'>, use_bias=True, w_init=<function kaiming_normal>)[source]

Applies a 2D transposed convolution on a 4D-input batch of shape (N,C,H,W).

This module can be seen as a transformation going in the opposite direction of a normal convolution, i.e., from something that has the shape of the output of some convolution to something that has the shape of its input while maintaining a connectivity pattern that is compatible with said convolution. Note that ConvTranspose2D is consistent with Conv2DTranspose of TensorFlow, but is not consistent with ConvTranspose2D of PyTorch due to kernel transpose and padding.

__init__(nin, nout, k, strides=1, dilations=1, padding=<ConvPadding.SAME: 'SAME'>, use_bias=True, w_init=<function kaiming_normal>)[source]

Creates a ConvTranspose2D module instance.

Parameters
  • nin (int) – number of channels of the input tensor.

  • nout (int) – number of channels of the output tensor.

  • k (Union[Tuple[int, int], int]) – size of the convolution kernel, either tuple (height, width) or single number if they’re the same.

  • strides (Union[Tuple[int, int], int]) – convolution strides, either tuple (stride_y, stride_x) or single number if they’re the same.

  • dilations (Union[Tuple[int, int], int]) – spacing between kernel points (also known as atrous convolution), either tuple (dilation_y, dilation_x) or single number if they’re the same.

  • padding (Union[objax.constants.ConvPadding, str, Sequence[Tuple[int, int]], Tuple[int, int], int]) – padding of the input tensor, either Padding.SAME, Padding.VALID or numerical values.

  • use_bias (bool) – if True then convolution will have bias term.

  • w_init (Callable) – initializer for convolution kernel (a function that takes in a HWIO shape and returns a 4D matrix).

__call__(x)[source]

Returns the results of applying the transposed convolution to input x.

Parameters

x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) –

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]

class objax.nn.Dropout(keep, generator=objax.random.Generator(seed=0))[source]

In the training phase, a dropout layer zeroes some elements of the input tensor with probability 1-keep and scales the other elements by a factor of 1/keep.

During evaluation, the module does not modify the input tensor. Dropout (Improving neural networks by preventing co-adaptation of feature detectors) is an effective regularization technique which reduces overfitting and typically improves generalization.

__init__(keep, generator=objax.random.Generator(seed=0))[source]

Creates Dropout module instance.

Parameters
  • keep (float) – probability to keep element of the tensor.

  • generator – optional instance of an Objax random generator.

__call__(x, training, dropout_keep=None)[source]

Performs dropout of input tensor.

Parameters
  • x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.

  • training (bool) – if True then apply dropout to the input, otherwise keep input tensor unchanged.

  • dropout_keep (Optional[float]) – optional argument, when set overrides dropout keep probability.

Returns

Tensor with applied dropout.

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]
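
Usage sketch (shapes are illustrative):

import objax

drop = objax.nn.Dropout(keep=0.8)
x = objax.random.normal((4, 10))
y_train = drop(x, training=True)    # roughly 20% of elements zeroed, the rest scaled by 1/0.8
y_eval = drop(x, training=False)    # identity during evaluation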

class objax.nn.Linear(nin, nout, use_bias=True, w_init=<function xavier_normal>)[source]

Applies a linear transformation on an input batch.

The output tensor \((N,C_{out})\) is computed from an input tensor \((N,C_{in})\) with kernel weight \((C_{in},C_{out})\) and bias \((C_{out})\) as follows:

\[\mathrm{out}[n,c] = \mathrm{b}[c] + \sum_{t=1}^{C_{in}} \mathrm{in}[n,t] \times \mathrm{w}[t,c]\]
__init__(nin, nout, use_bias=True, w_init=<function xavier_normal>)[source]

Creates a Linear module instance.

Parameters
  • nin (int) – number of channels of the input tensor.

  • nout (int) – number of channels of the output tensor.

  • use_bias (bool) – if True then linear layer will have bias term.

  • w_init (Callable) – weight initializer for linear layer (a function that takes in a IO shape and returns a 2D matrix).

__call__(x)[source]

Returns the results of applying the linear transformation to input x.

Parameters

x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) –

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]
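
Usage sketch (shapes are illustrative):

import objax

linear = objax.nn.Linear(nin=5, nout=2)
x = objax.random.normal((16, 5))
print(linear(x).shape)   # (16, 2)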

class objax.nn.MovingAverage(shape, buffer_size, init_value=0)[source]

Computes moving average of an input batch.

__init__(shape, buffer_size, init_value=0)[source]

Creates a MovingAverage module instance.

Parameters
  • shape (Tuple[int, ..]) – shape of the input tensor.

  • buffer_size (int) – buffer size for moving average.

  • init_value (float) – initial value for moving average buffer.

__call__(x)[source]

Update the statistics using x and return the moving average.

Parameters

x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) –

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]

class objax.nn.ExponentialMovingAverage(shape, momentum=0.999, init_value=0)[source]

Computes exponential moving average (also called EMA or EWMA) of an input batch.

\[x_{\mathrm{EMA}} \leftarrow \mathrm{momentum} \times x_{\mathrm{EMA}} + (1-\mathrm{momentum}) \times x\]
__init__(shape, momentum=0.999, init_value=0)[source]

Creates a ExponentialMovingAverage module instance.

Parameters
  • shape (Tuple[int, ..]) – shape of the input tensor.

  • momentum (float) – momentum for exponential decrease of accumulated value.

  • init_value (float) – initial value for exponential moving average.

__call__(x)[source]

Update the statistics using x and return the exponential moving average.

Parameters

x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) –

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]
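
Usage sketch (shapes are illustrative):

import objax

ema = objax.nn.ExponentialMovingAverage(shape=(1, 3), momentum=0.9)
for _ in range(5):
    x = objax.random.normal((1, 3))
    smoothed = ema(x)   # updates the EMA state and returns the current average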

class objax.nn.Sequential(iterable=(), /)[source]

Executes modules in the order they were passed to the constructor.

Usage example:

import objax

ml = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu,
                          objax.nn.Linear(3, 4)])
x = objax.random.normal((10, 2))
y = ml(x)  # Runs all the operations (Linear -> ReLU -> Linear).
print(y.shape)  # (10, 4)

# objax.nn.Sequential is really a list.
ml.insert(2, objax.nn.BatchNorm0D(3))  # Add a batch norm layer after ReLU
ml.append(objax.nn.Dropout(keep=0.5))  # Add a dropout layer at the end
y = ml(x, training=False)  # Both batch norm and dropout expect a training argument.
# Sequential automatically passes arguments to the modules that use them.

# You can run a subset of operations since it is a list.
y1 = ml[:2](x)  # Run first two layers (Linear -> ReLU)
y2 = ml[2:](y1, training=False)  # Run all layers starting from third (BatchNorm0D -> Dropout)
print(ml(x, training=False) - y2)  # [[0. 0. ...]] - results are the same.

print(ml.vars())
# (Sequential)[0](Linear).b                              3 (3,)
# (Sequential)[0](Linear).w                              6 (2, 3)
# (Sequential)[2](BatchNorm0D).running_mean              3 (1, 3)
# (Sequential)[2](BatchNorm0D).running_var               3 (1, 3)
# (Sequential)[2](BatchNorm0D).beta                      3 (1, 3)
# (Sequential)[2](BatchNorm0D).gamma                     3 (1, 3)
# (Sequential)[3](BatchNorm0D).running_mean              3 (1, 3)
# (Sequential)[3](BatchNorm0D).running_var               3 (1, 3)
# (Sequential)[3](BatchNorm0D).beta                      3 (1, 3)
# (Sequential)[3](BatchNorm0D).gamma                     3 (1, 3)
# (Sequential)[4](Linear).b                              4 (4,)
# (Sequential)[4](Linear).w                             12 (3, 4)
# (Sequential)[5](Dropout).keygen(Generator)._key        2 (2,)
# +Total(13)                                            51
__call__(*args, **kwargs)[source]

Executes the sequence of operations on the inputs *args and **kwargs and returns the result.

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray, List[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]]

__init__(*args, **kwargs)

Initialize self. See help(type(self)) for accurate signature.

append(object, /)

Append object to the end of the list.

clear()

Remove all items from list.

copy()

Return a shallow copy of the list.

count(value, /)

Return number of occurrences of value.

extend(iterable, /)

Extend list by appending elements from the iterable.

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.

insert(index, object, /)

Insert object before index.

pop(index=-1, /)

Remove and return item at index (default last).

Raises IndexError if list is empty or index is out of range.

remove(value, /)

Remove first occurrence of value.

Raises ValueError if the value is not present.

reverse()

Reverse IN PLACE.

vars(scope='')

Collect all the variables (and their names) contained in the list and its submodules.

Parameters

scope (str) – string to prefix to the variable names.

Returns

A VarCollection of all the variables.

Return type

objax.variable.VarCollection

class objax.nn.SyncedBatchNorm(dims, redux, momentum=0.999, eps=1e-06)[source]

Synchronized batch normalization which aggregates batch statistics across all devices (GPUs/TPUs).

__call__(x, training, batch_norm_update=True)[source]

Performs batch normalization of input tensor.

Parameters
  • x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.

  • training (bool) – if True compute batch normalization in training mode (accumulating batch statistics), otherwise compute in evaluation mode (using already accumulated batch statistics).

  • batch_norm_update (bool) –

Returns

Batch normalized tensor.

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]

__init__(dims, redux, momentum=0.999, eps=1e-06)

Creates a BatchNorm module instance.

Parameters
  • dims (Iterable[int]) – shape of the batch normalization state variables.

  • redux (Iterable[int]) – list of indices of reduction axes. Batch norm statistics are computed by averaging over these axes.

  • momentum (float) – value used to compute exponential moving average of batch statistics.

  • eps (float) – small value which is used for numerical stability.

class objax.nn.SyncedBatchNorm0D(nin, momentum=0.999, eps=1e-06)[source]

Applies a 0D synchronized batch normalization on a 2D-input batch of shape (N,C).

Synchronized batch normalization aggregates batch statistics across all devices (GPUs/TPUs) on each call. Compared to regular batch norm this usually leads to better accuracy at a slight performance cost.

__init__(nin, momentum=0.999, eps=1e-06)[source]

Creates a SyncedBatchNorm0D module instance.

Parameters
  • nin (int) – number of channels in the input example.

  • momentum (float) – value used to compute exponential moving average of batch statistics.

  • eps (float) – small value which is used for numerical stability.

__call__(x, training, batch_norm_update=True)

Performs batch normalization of input tensor.

Parameters
  • x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.

  • training (bool) – if True compute batch normalization in training mode (accumulating batch statistics), otherwise compute in evaluation mode (using already accumulated batch statistics).

  • batch_norm_update (bool) –

Returns

Batch normalized tensor.

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]

class objax.nn.SyncedBatchNorm1D(nin, momentum=0.999, eps=1e-06)[source]

Applies a 1D synchronized batch normalization on a 3D-input batch of shape (N,C,L).

Synchronized batch normalization aggregates batch statistics across all devices (GPUs/TPUs) on each call. Compared to regular batch norm this usually leads to better accuracy at a slight performance cost.

__init__(nin, momentum=0.999, eps=1e-06)[source]

Creates a SyncedBatchNorm1D module instance.

Parameters
  • nin (int) – number of channels in the input example.

  • momentum (float) – value used to compute exponential moving average of batch statistics.

  • eps (float) – small value which is used for numerical stability.

__call__(x, training, batch_norm_update=True)

Performs batch normalization of input tensor.

Parameters
  • x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.

  • training (bool) – if True compute batch normalization in training mode (accumulating batch statistics), otherwise compute in evaluation mode (using already accumulated batch statistics).

  • batch_norm_update (bool) –

Returns

Batch normalized tensor.

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]

class objax.nn.SyncedBatchNorm2D(nin, momentum=0.999, eps=1e-06)[source]

Applies a 2D synchronized batch normalization on a 4D-input batch of shape (N,C,H,W).

Synchronized batch normalization aggregates batch statistics across all devices (GPUs/TPUs) on each call. Compared to regular batch norm this usually leads to better accuracy at a slight performance cost.

__init__(nin, momentum=0.999, eps=1e-06)[source]

Creates a SyncedBatchNorm2D module instance.

Parameters
  • nin (int) – number of channels in the input example.

  • momentum (float) – value used to compute exponential moving average of batch statistics.

  • eps (float) – small value which is used for numerical stability.

__call__(x, training, batch_norm_update=True)

Performs batch normalization of input tensor.

Parameters
  • x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.

  • training (bool) – if True compute batch normalization in training mode (accumulating batch statistics), otherwise compute in evaluation mode (using already accumulated batch statistics).

  • batch_norm_update (bool) –

Returns

Batch normalized tensor.

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]

objax.nn.init

gain_leaky_relu([relu_slope])

The recommended gain value for leaky_relu.

identity(shape[, gain])

Returns the identity matrix.

kaiming_normal_gain(shape)

Returns Kaiming He gain from Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification.

kaiming_normal(shape[, gain])

Returns a tensor with values assigned using Kaiming He normal initializer from Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification.

kaiming_truncated_normal(shape[, lower, …])

Returns a tensor with values assigned using Kaiming He truncated normal initializer from Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification.

orthogonal(shape[, gain, axis])

Returns a uniformly distributed orthogonal tensor from Exact solutions to the nonlinear dynamics of learning in deep linear neural networks.

truncated_normal(shape[, lower, upper, stddev])

Returns a tensor with values assigned using truncated normal initialization.

xavier_normal_gain(shape)

Returns Xavier Glorot gain from Understanding the difficulty of training deep feedforward neural networks.

xavier_normal(shape[, gain])

Returns a tensor with values assigned using Xavier Glorot normal initializer from Understanding the difficulty of training deep feedforward neural networks.

xavier_truncated_normal(shape[, lower, …])

Returns a tensor with values assigned using Xavier Glorot truncated normal initializer from Understanding the difficulty of training deep feedforward neural networks.

class objax.nn.init.gain_leaky_relu(relu_slope=0.1)[source]

The recommended gain value for leaky_relu.

Parameters

relu_slope – negative slope of leaky_relu.

Returns

The recommended gain value for leaky_relu.

The returned gain value is

\[\sqrt{\frac{2}{1 + \text{relu_slope}^2}}.\]
class objax.nn.init.identity(shape, gain=1)[source]

Returns the identity matrix. This initializer was proposed in A Simple Way to Initialize Recurrent Networks of Rectified Linear Units.

Parameters
  • shape – Shape of the tensor. It should have exactly rank 2.

  • gain – optional scaling factor.

Returns

Tensor initialized to the identity matrix.

class objax.nn.init.kaiming_normal_gain(shape)[source]

Returns Kaiming He gain from Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification.

Parameters

shape – shape of the output tensor.

Returns

Scalar, the standard deviation gain.

The returned gain value is

\[\sqrt{\frac{1}{\text{fan_in}}}.\]
class objax.nn.init.kaiming_normal(shape, gain=1)[source]

Returns a tensor with values assigned using Kaiming He normal initializer from Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification.

Parameters
  • shape – shape of the output tensor.

  • gain – optional scaling factor.

Returns

Tensor initialized with normal random variables with standard deviation (gain * kaiming_normal_gain).

class objax.nn.init.kaiming_truncated_normal(shape, lower=- 2, upper=2, gain=1)[source]

Returns a tensor with values assigned using Kaiming He truncated normal initializer from Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification.

Parameters
  • shape – shape of the output tensor.

  • lower – lower truncation of the normal.

  • upper – upper truncation of the normal.

  • gain – optional scaling factor.

Returns

Tensor initialized with truncated normal random variables with standard deviation (gain * kaiming_normal_gain) and support [lower, upper].

class objax.nn.init.orthogonal(shape, gain=1, axis=- 1)[source]

Returns a uniformly distributed orthogonal tensor from Exact solutions to the nonlinear dynamics of learning in deep linear neural networks.

Parameters
  • shape – shape of the output tensor.

  • gain – optional scaling factor.

  • axis – the orthogonalization axis.

Returns

An orthogonally initialized tensor. These tensors will be row-orthonormal along the axis specified by axis. If the rank of the weight is greater than 2, the shape will be flattened in all other dimensions and then will be row-orthonormal along the final dimension. Note that this only works if the axis dimension is larger, otherwise the tensor will be transposed (equivalently, it will be column orthonormal instead of row orthonormal). If the shape is not square, the matrices will have orthonormal rows or columns depending on which side is smaller.

class objax.nn.init.truncated_normal(shape, lower=- 2, upper=2, stddev=1)[source]

Returns a tensor with values assigned using truncated normal initialization.

Parameters
  • shape – shape of the output tensor.

  • lower – lower truncation of the normal.

  • upper – upper truncation of the normal.

  • stddev – expected standard deviation.

Returns

Tensor initialized with truncated normal random variables with standard deviation stddev and support [lower, upper].

class objax.nn.init.xavier_normal_gain(shape)[source]

Returns Xavier Glorot gain from Understanding the difficulty of training deep feedforward neural networks.

Parameters

shape – shape of the output tensor.

Returns

Scalar, the standard deviation gain.

The returned gain value is

\[\sqrt{\frac{2}{\text{fan_in} + \text{fan_out}}}.\]
class objax.nn.init.xavier_normal(shape, gain=1)[source]

Returns a tensor with values assigned using Xavier Glorot normal initializer from Understanding the difficulty of training deep feedforward neural networks.

Parameters
  • shape – shape of the output tensor.

  • gain – optional scaling factor.

Returns

Tensor initialized with normal random variables with standard deviation (gain * xavier_normal_gain).

class objax.nn.init.xavier_truncated_normal(shape, lower=- 2, upper=2, gain=1)[source]

Returns a tensor with values assigned using Xavier Glorot truncated normal initializer from Understanding the difficulty of training deep feedforward neural networks.

Parameters
  • shape – shape of the output tensor.

  • lower – lower truncation of the normal.

  • upper – upper truncation of the normal.

  • gain – optional scaling factor.

Returns

Tensor initialized with truncated normal random variables with standard deviation (gain * xavier_normal_gain) and support [lower, upper].
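
A short sketch of how these initializers are typically used (layer sizes are illustrative):

import objax

w = objax.nn.init.xavier_normal((128, 64))   # create a tensor directly from an initializer

# Or pass an initializer to a layer constructor.
layer = objax.nn.Linear(128, 64, w_init=objax.nn.init.kaiming_truncated_normal)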

objax.optimizer package

Adam(vc[, beta1, beta2, eps])

Adam optimizer.

ExponentialMovingAverageModule(module[, …])

Creates a module that uses the moving average weights of another module.

ExponentialMovingAverage(vc[, momentum, …])

Maintains exponential moving averages for each variable from provided VarCollection.

LARS(vc[, momentum, weight_decay, tc, eps])

Layerwise adaptive rate scaling (LARS) optimizer.

Momentum(vc[, momentum, nesterov])

Momentum optimizer.

SGD(vc)

Stochastic Gradient Descent (SGD) optimizer.

class objax.optimizer.Adam(vc, beta1=0.9, beta2=0.999, eps=1e-08)[source]

Adam optimizer.

Adam is an adaptive learning rate optimization algorithm originally presented in Adam: A Method for Stochastic Optimization. Specifically, when optimizing a loss function \(f\) parameterized by model weights \(w\), the update rule is as follows:

\[\begin{split}\begin{eqnarray} v_{k} &=& \beta_1 v_{k-1} + (1 - \beta_1) \nabla f (.; w_{k-1}) \nonumber \\ s_{k} &=& \beta_2 s_{k-1} + (1 - \beta_2) (\nabla f (.; w_{k-1}))^2 \nonumber \\ \hat{v_{k}} &=& \frac{v_{k}}{(1 - \beta_{1}^{k})} \nonumber \\ \hat{s_{k}} &=& \frac{s_{k}}{(1 - \beta_{2}^{k})} \nonumber \\ w_{k} &=& w_{k-1} - \eta \frac{\hat{v_{k}}}{\sqrt{\hat{s_{k}}} + \epsilon} \nonumber \end{eqnarray}\end{split}\]

Adam updates exponential moving averages of the gradient \((v_{k})\) and the squared gradient \((s_{k})\) where the hyper-parameters \(\beta_1\) and \(\beta_2 \in [0, 1)\) control the exponential decay rates of these moving averages. The \(\eta\) constant in the weight update rule is the learning rate and is passed as a parameter in the __call__ method. Note that the implementation uses the approximation \(\sqrt{(\hat{s_{k}} + \epsilon)} \approx \sqrt{\hat{s_{k}}} + \epsilon\).

__init__(vc, beta1=0.9, beta2=0.999, eps=1e-08)[source]

Constructor for Adam optimizer class.

Parameters
  • vc (objax.variable.VarCollection) – collection of variables to optimize.

  • beta1 (float) – value of Adam’s beta1 hyperparameter. Defaults to 0.9.

  • beta2 (float) – value of Adam’s beta2 hyperparameter. Defaults to 0.999.

  • eps (float) – value of Adam’s epsilon hyperparameter. Defaults to 1e-8.

__call__(lr, grads, beta1=None, beta2=None)[source]

Updates variables and other state based on Adam algorithm.

Parameters
  • lr (float) – the learning rate.

  • grads (List[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]) – the gradients to apply.

  • beta1 (Optional[float]) – optional, override the default beta1.

  • beta2 (Optional[float]) – optional, override the default beta2.
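
A short sketch of a single Adam step, assuming gradients come from objax.GradValues (shapes and learning rate are illustrative):

import objax

net = objax.nn.Linear(4, 1)
opt = objax.optimizer.Adam(net.vars(), beta1=0.9, beta2=0.999)

def loss(x, y):
    return objax.functional.loss.mean_squared_error(net(x), y, keep_axis=None)

gv = objax.GradValues(loss, net.vars())
x, y = objax.random.normal((8, 4)), objax.random.normal((8, 1))
grads, values = gv(x, y)
opt(0.001, grads)   # apply one Adam update with learning rate 0.001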

class objax.optimizer.ExponentialMovingAverageModule(module, momentum=0.999, debias=False, eps=1e-06)[source]

Creates a module that uses the moving average weights of another module.

Convenience interface to apply objax.optimizer.ExponentialMovingAverage to a module.

Usage example:

import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.nn.BatchNorm0D(3)])
m_ema = objax.optimizer.ExponentialMovingAverageModule(m, momentum=0.999, debias=True)

x = objax.random.uniform((16, 2))

# When the weights of m change, simply call update_ema() to update moving averages
v1 = m(x, training=True)
m_ema.update_ema()

# You call m_ema just like you would call m
v2 = m_ema(x, training=False)
__init__(module, momentum=0.999, debias=False, eps=1e-06)[source]

Creates ExponentialMovingAverageModule instance with given hyperparameters.

Parameters
  • module (objax.module.Module) – a module for which to compute the moving average.

  • momentum (float) – the decay factor for the moving average.

  • debias (bool) – bool indicating whether to use initialization bias correction.

  • eps (float) – small adjustment to prevent division by zero.

__call__(*args, **kwargs)[source]

Calls the original module with moving average weights.

update_ema()[source]

Updates the moving average.

class objax.optimizer.ExponentialMovingAverage(vc, momentum=0.999, debias=False, eps=1e-06)[source]

Maintains exponential moving averages for each variable from provided VarCollection.

When training a model, it is often beneficial to maintain exponential moving averages (EMA) of the trained parameters. Evaluations that use averaged parameters sometimes produce significantly better results than the final trained values (see Acceleration of Stochastic Approximation by Averaging).

This maintains an EMA of the parameters passed in the VarCollection vc. The EMA update rule for weights \(w\), the EMA \(m\) at step \(t\) when using a momentum \(\mu\) is:

\[m_t = \mu m_{t-1} + (1 - \mu) w_t\]

The EMA weights \(\hat{w_t}\) are simply \(m_t\) when debias=False. When debias=True, the EMA weights are defined as:

\[\hat{w_t} = \frac{m_t}{1 - (1 - \epsilon)\mu^t}\]

Where \(\epsilon\) is a small constant to avoid a divide-by-0.

__init__(vc, momentum=0.999, debias=False, eps=1e-06)[source]

Creates ExponentialMovingAverage instance with given hyperparameters.

Parameters
  • momentum (float) – the decay factor for the moving average.

  • debias (bool) – bool indicating whether to use initialization bias correction.

  • eps (float) – small adjustment to prevent division by zero.

  • vc (objax.variable.VarCollection) –

__call__()[source]

Updates the moving average.

refs_and_values()[source]

Returns the VarCollection of variables affected by Exponential Moving Average (EMA) and their corresponding EMA values.

Return type

Tuple[objax.variable.VarCollection, List[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]]

replace_vars(f)[source]

Returns a function that acts like f, but is evaluated with the variables replaced by their moving averages.

Parameters

f (Callable) – function to be called on the stored averages.

Returns

A function that returns the output of calling f with stored variables replaced by their moving averages.
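
A hedged sketch of the workflow described above; the module, data and hyperparameters are illustrative only:

import objax

m = objax.nn.Linear(2, 3)
ema = objax.optimizer.ExponentialMovingAverage(m.vars(), momentum=0.999, debias=True)

x = objax.random.uniform((4, 2))

# After every update of m's weights, refresh the moving averages.
ema()

# replace_vars wraps a callable so that it runs with the averaged weights.
eval_with_ema = ema.replace_vars(m)
y = eval_with_ema(x)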

class objax.optimizer.LARS(vc, momentum=0.9, weight_decay=0.0001, tc=0.001, eps=1e-05)[source]

Layerwise adaptive rate scaling (LARS) optimizer.

See https://arxiv.org/abs/1708.03888

The Layer-wise Adaptive Rate Scaling (LARS) optimizer implements the scheme originally proposed in Large Batch Training of Convolutional Networks. The optimizer takes as input the base learning rate \(\gamma_0\), momentum \(m\), weight decay \(\beta\), and trust coefficient \(\eta\) and updates the model weights \(w\) as follows:

\[\begin{split}\begin{eqnarray} g_{t}^{l} &\leftarrow& \nabla L(w_{t}^{l}) \nonumber \\ \gamma_t &\leftarrow& \gamma_0 \ast (1 - \frac{t}{T})^{2} \nonumber \\ \lambda^{l} &\leftarrow& \frac{\| w_{t}^{l} \| }{ \| g_t^{l} \| + \beta \| w_{t}^{l} \|} \nonumber \\ v_{t+1}^{l} &\leftarrow& m v_{t}^{l} + \gamma_{t+1} \ast \lambda^{l} \ast (g_{t}^{l} + \beta w_{t}^{l}) \nonumber \\ w_{t+1}^{l} &\leftarrow& w_{t}^{l} - v_{t+1}^{l} \nonumber \\ \end{eqnarray}\end{split}\]

where \(T\) is the total number of steps (epochs) that the optimizer will take, \(t\) is the current step number, and \(w_{t}^{l}\) are the weights at step \(t\) for layer \(l\).

__init__(vc, momentum=0.9, weight_decay=0.0001, tc=0.001, eps=1e-05)[source]

Constructor for LARS optimizer.

Parameters
  • vc (objax.variable.VarCollection) – collection of variables to optimize.

  • momentum (float) – coefficient used for the moving average of the gradient.

  • weight_decay (float) – weight decay coefficient.

  • tc (float) – trust coefficient eta ( < 1) for trust ratio computation.

  • eps (float) – epsilon used for trust ratio computation.

__call__(lr, grads)[source]

Updates variables based on LARS algorithm.

Parameters
  • lr (float) – learning rate. The LARS paper suggests using lr = lr_0 * (1 - t/T)**2, where t is the current epoch number and T the maximum number of epochs.

  • grads (List[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]) – the gradients to apply.
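
A minimal sketch of the schedule mentioned above on a toy regression problem; the network, data and hyperparameter values are illustrative only:

import objax

net = objax.nn.Linear(4, 2)
opt = objax.optimizer.LARS(net.vars(), momentum=0.9, weight_decay=1e-4)

@objax.Function.with_vars(net.vars())
def loss(x, y):
    return ((net(x) - y) ** 2).mean()

gv = objax.GradValues(loss, net.vars())

x = objax.random.normal((8, 4))
y = objax.random.normal((8, 2))

lr0, epochs = 0.1, 10
for epoch in range(epochs):
    lr = lr0 * (1 - epoch / epochs) ** 2  # polynomial decay suggested by the LARS paper
    g, v = gv(x, y)                       # gradients and loss value
    opt(lr, g)                            # one LARS update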

class objax.optimizer.Momentum(vc, momentum=0.9, nesterov=False)[source]

Momentum optimizer.

The momentum optimizer (expository article) introduces a tweak to the standard gradient descent. Specifically, when optimizing a loss function \(f\) parameterized by model weights \(w\) the update rule is as follows:

\[\begin{split}\begin{eqnarray} v_{k} &=& \mu v_{k-1} + \nabla f (.; w_{k-1}) \nonumber \\ w_{k} &=& w_{k-1} - \eta v_{k} \nonumber \end{eqnarray}\end{split}\]

The term \(v\) is the velocity: It accumulates past gradients through a weighted moving average calculation. The parameters \(\mu, \eta\) are the momentum and the learning rate.

The momentum class also implements Nesterov’s Accelerated Gradient (NAG) (see Sutskever et. al.). Like momentum, NAG is a first-order optimization method with better convergence rate than gradient descent in certain situations. The NAG update can be written as:

\[\begin{split}\begin{eqnarray} v_{k} &=& \mu v_{k-1} + \nabla f(.; w_{k-1} + \mu v_{k-1}) \nonumber \\ w_{k} &=& w_{k-1} - \eta v_{k} \nonumber \end{eqnarray}\end{split}\]

The implementation uses the simplification presented by Bengio et. al.

__init__(vc, momentum=0.9, nesterov=False)[source]

Constructor for momentum optimizer class.

Parameters
  • vc (objax.variable.VarCollection) – collection of variables to optimize.

  • momentum (float) – the momentum hyperparameter.

  • nesterov (bool) – bool indicating whether to use the Nesterov method.

__call__(lr, grads, momentum=None)[source]

Updates variables and other state based on momentum (or Nesterov) SGD.

Parameters
  • lr (float) – the learning rate.

  • grads (List[jax._src.numpy.lax_numpy.ndarray]) – the gradients to apply.

  • momentum (Optional[float]) – optional, override the default momentum.
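
A minimal sketch of the call above; the weight and gradient values are made up for illustration:

import objax
import jax.numpy as jn

w = objax.TrainVar(jn.zeros(3))
opt = objax.optimizer.Momentum(objax.VarCollection({'w': w}), momentum=0.9, nesterov=True)

grads = [jn.array([1.0, -1.0, 0.5])]  # made-up gradient for w
opt(lr=0.1, grads=grads)              # one Nesterov momentum step
opt(0.1, grads, momentum=0.5)         # override the momentum for this call only
print(w.value)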

class objax.optimizer.SGD(vc)[source]

Stochastic Gradient Descent (SGD) optimizer.

The stochastic gradient optimizer performs Stochastic Gradient Descent (SGD). It uses the following update rule for a loss \(f\) parameterized with model weights \(w\) and a user provided learning rate \(\eta\):

\[w_k = w_{k-1} - \eta\nabla f(.; w_{k-1})\]
__init__(vc)[source]

Constructor for SGD optimizer.

Parameters

vc (objax.variable.VarCollection) – collection of variables to optimize.

__call__(lr, grads)[source]

Updates variables based on SGD algorithm.

Parameters
  • lr (float) – the learning rate.

  • grads (List[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]) – the gradients to apply.

objax.privacy.dpsgd package

PrivateGradValues(f, vc, noise_multiplier, …)

Computes differentially private gradients as required by DP-SGD.

analyze_dp(q, noise_multiplier, steps[, …])

Compute and print results of DP-SGD analysis.

analyze_renyi(q, noise_multiplier, steps, orders)

Compute RDP of the Sampled Gaussian Mechanism.

convert_renyidp_to_dp(orders, rdp[, …])

Compute delta (or eps) for given eps (or delta) from RDP values.

class objax.privacy.dpsgd.PrivateGradValues(f, vc, noise_multiplier, l2_norm_clip, microbatch, batch_axis=(0), keygen=objax.random.Generator(seed=0))[source]

Computes differentially private gradients as required by DP-SGD. This module can be used in place of GradValues, and automatically makes the optimizer differentially private.

__init__(f, vc, noise_multiplier, l2_norm_clip, microbatch, batch_axis=(0), keygen=objax.random.Generator(seed=0))[source]

Constructs a PrivateGradValues instance.

Parameters
  • f (Callable) – the function for which to compute gradients.

  • vc (objax.variable.VarCollection) – the variables for which to compute gradients.

  • noise_multiplier (float) – scale of standard deviation for added noise in DP-SGD.

  • l2_norm_clip (float) – value of clipping norm for DP-SGD.

  • microbatch (int) – the size of each microbatch.

  • batch_axis (Tuple[Optional[int], ..]) – the axis to use as batch during vectorization. Should be a tuple of 0s.

  • keygen (objax.random.random.Generator) – a Generator for random numbers. Defaults to objax.random.DEFAULT_GENERATOR.

reshape_microbatch(x)[source]

Reshapes examples into microbatches. DP-SGD requires that per-example gradients are clipped and noised; however, this can be inefficient. To speed this up, it is possible to clip and noise a microbatch of examples, at a slight cost to privacy. If speed is not an issue, the microbatch size should be set to 1.

If x has shape [D0, D1, …, Dn], the reshaped output will have shape [number_of_microbatches, microbatch_size, D1, …, Dn].

Parameters

x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – items to be reshaped.

Returns

The reshaped items.

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]

__call__(*args)[source]

Returns the computed DP-SGD gradients.

Returns

A tuple (gradients, value of f).
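
A hedged sketch of the drop-in usage described above; the network, data, noise multiplier, clipping norm, sampling rate and step count are illustrative values only:

import objax

net = objax.nn.Linear(4, 2)

@objax.Function.with_vars(net.vars())
def loss(x, y):
    return ((net(x) - y) ** 2).mean()

# Used in place of objax.GradValues: per-microbatch gradients are clipped and noised.
gv = objax.privacy.dpsgd.PrivateGradValues(loss, net.vars(),
                                           noise_multiplier=1.1,
                                           l2_norm_clip=1.0,
                                           microbatch=1)

x = objax.random.normal((8, 4))
y = objax.random.normal((8, 2))
g, v = gv(x, y)  # differentially private gradients and the loss value

# Privacy accounting: epsilon after 1000 steps at a sampling rate of 8/50000.
eps = objax.privacy.dpsgd.analyze_dp(q=8 / 50000, noise_multiplier=1.1,
                                     steps=1000, delta=1e-5)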

objax.privacy.dpsgd.analyze_dp(q, noise_multiplier, steps, orders=(1.25, 1.5, 1.75, 2.0, 2.25, 2.5, 3.0, 3.5, 4.0, 4.5, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20), delta=1e-05)[source]

Compute and print results of DP-SGD analysis.

Parameters
  • q (float) – The sampling rate.

  • noise_multiplier (float) – The ratio of the standard deviation of the Gaussian noise to the l2-sensitivity of the function to which it is added.

  • steps (int) – The number of steps.

  • orders (Tuple[float, ..]) – An array (or a scalar) of RDP orders.

  • delta (float) – The target delta.

Returns

eps

Raises

ValueError – If delta is invalid.

Return type

float

objax.privacy.dpsgd.analyze_renyi(q, noise_multiplier, steps, orders)[source]

Compute RDP of the Sampled Gaussian Mechanism.

Parameters
  • q (float) – The sampling rate.

  • noise_multiplier (float) – The ratio of the standard deviation of the Gaussian noise to the l2-sensitivity of the function to which it is added.

  • steps (int) – The number of steps.

  • orders (Tuple[float, ..]) – An array (or a scalar) of RDP orders.

Returns

The RDPs at all orders, can be np.inf.

objax.privacy.dpsgd.convert_renyidp_to_dp(orders, rdp, target_eps=None, target_delta=None)[source]

Compute delta (or eps) for given eps (or delta) from RDP values.

Parameters
  • orders (Tuple[float, ..]) – An array (or a scalar) of RDP orders.

  • rdp (Tuple[float, ..]) – An array of RDP values. Must be of the same length as the orders list.

  • target_eps (Optional[float]) – If not None, the epsilon for which we compute the corresponding delta.

  • target_delta (Optional[float]) – If not None, the delta for which we compute the corresponding epsilon. Exactly one of target_eps and target_delta must be None.

Returns

eps, delta, opt_order.

Raises

ValueError – If both or neither of target_eps and target_delta are specified.

Return type

Tuple[float, float, float]

objax.random package

Generator([seed])

Random number generator module.

normal(shape, *[, mean, stddev, generator])

Returns a JaxArray of shape shape with random numbers from a normal distribution with mean mean and standard deviation stddev.

randint(shape, low, high[, generator])

Returns a JaxArray of shape shape with random integers in {low, …, high-1}.

truncated_normal(shape, *[, stddev, lower, …])

Returns a JaxArray of shape shape with random numbers from a normal distribution with mean 0 and standard deviation stddev truncated by (lower, upper).

uniform(shape[, generator])

Returns a JaxArray of shape shape with random numbers from a uniform distribution [0, 1].

class objax.random.Generator(seed=0)[source]

Random number generator module.

The default generator can be accessed through objax.random.DEFAULT_GENERATOR. Its seed is 0 by default, and can be set through objax.random.DEFAULT_GENERATOR.seed(s) where integer s is the desired seed.

__init__(seed=0)[source]

Create a random key generator, seed is the random generator initial seed.

Parameters

seed (int) –

property key

The random generator state (a tensor of 2 int32).

seed(seed=0)[source]

Sets a new random generator seed.

Parameters

seed (int) –

__call__()[source]

Generate a new generator state.

objax.random.normal(shape, *, mean=0, stddev=1, generator=objax.random.Generator(seed=0))[source]

Returns a JaxArray of shape shape with random numbers from a normal distribution with mean mean and standard deviation stddev.

Parameters
  • shape (Tuple[int, ..]) –

  • mean (float) –

  • stddev (float) –

  • generator (objax.random.random.Generator) –

objax.random.randint(shape, low, high, generator=objax.random.Generator(seed=0))[source]

Returns a JaxArray of shape shape with random integers in {low, …, high-1}.

Parameters
  • shape (Tuple[int, ..]) –

  • low (int) –

  • high (int) –

  • generator (objax.random.random.Generator) –

objax.random.truncated_normal(shape, *, stddev=1, lower=- 2, upper=2, generator=objax.random.Generator(seed=0))[source]

Returns a JaxArray of shape shape with random numbers from a normal distribution with mean 0 and standard deviation stddev truncated by (lower, upper).

Parameters
  • shape (Tuple[int, ..]) –

  • stddev (float) –

  • lower (float) –

  • upper (float) –

  • generator (objax.random.random.Generator) –

objax.random.uniform(shape, generator=objax.random.Generator(seed=0))[source]

Returns a JaxArray of shape shape with random numbers from a uniform distribution [0, 1].

Parameters
  • shape (Tuple[int, ..]) –

  • generator (objax.random.random.Generator) –

objax.util package
objax.util

EasyDict(*args, **kwargs)

Custom dictionary that allows accessing dict values as attributes.

Objax2Tf(module)

Objax to TensorFlow converter, which converts an Objax module to a tf.Module.

Renamer(rules[, chain])

Helper class for renaming string contents.

args_indexes(f, args)

Returns the indexes of variable names of a function.

dummy_context_mgr()

Empty Context Manager.

ilog2(x)

Integer log2.

positional_args_names(f)

Returns the ordered names of the positional arguments of a function.

to_tuple(v, n)

Converts input to tuple.

class objax.util.EasyDict(*args, **kwargs)[source]

Custom dictionary that allows accessing dict values as attributes.
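
For example (a trivial illustration of the attribute access described above):

import objax

d = objax.util.EasyDict(lr=0.1, epochs=10)
print(d.lr)         # 0.1, the same value as d['lr']
print(d['epochs'])  # 10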

__init__(*args, **kwargs)[source]

Initialize self. See help(type(self)) for accurate signature.

clear() → None. Remove all items from D.
copy() → a shallow copy of D
fromkeys(value=None, /)

Create a new dictionary with keys from iterable and values set to value.

get(key, default=None, /)

Return the value for key if key is in the dictionary, else default.

items() → a set-like object providing a view on D’s items
keys() → a set-like object providing a view on D’s keys
pop(k[, d]) → v, remove specified key and return the corresponding value.

If key is not found, d is returned if given, otherwise KeyError is raised

popitem() → (k, v), remove and return some (key, value) pair as a 2-tuple; but raise KeyError if D is empty.

setdefault(key, default=None, /)

Insert key with a value of default if key is not in the dictionary.

Return the value for key if key is in the dictionary, else default.

update([E, ]**F) → None. Update D from dict/iterable E and F.

If E is present and has a .keys() method, then does: for k in E: D[k] = E[k] If E is present and lacks a .keys() method, then does: for k, v in E: D[k] = v In either case, this is followed by: for k in F: D[k] = F[k]

values() → an object providing a view on D’s values
class objax.util.Objax2Tf(module)[source]

Objax to TensorFlow converter, which converts an Objax module to a tf.Module.

__init__(module)[source]

Create a TensorFlow module from an Objax module.

Parameters

module (objax.module.Module) – Objax module to be converted to Tensorflow tf.Module.

__call__(*args, **kwargs)[source]

Calls the TensorFlow function generated from the Objax module.
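
A hedged sketch of the conversion, assuming the converted module is then called with TensorFlow tensors:

import objax
import tensorflow as tf

m = objax.nn.Sequential([objax.nn.Linear(4, 3), objax.functional.relu])
m_tf = objax.util.Objax2Tf(m)  # wraps m as a tf.Module

x = tf.random.normal((2, 4))
y = m_tf(x)                    # runs the converted computation in TensorFlow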

class objax.util.Renamer(rules, chain=None)[source]

Helper class for renaming string contents.

__init__(rules, chain=None)[source]

Create a renamer object.

Parameters
  • rules (Union[Dict[str, str], Sequence[Tuple[Pattern[str], str]], Callable[[str], str]]) – the replacement mapping.

  • chain (Optional[objax.util.util.Renamer]) – optionally, another renamer to call after this one completes.

__call__(s)[source]

Rename input string s using the rules provided to the constructor.

Parameters

s (str) –

Return type

str
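
A small sketch, assuming dictionary rules act as substring replacements (as typically used to rename the entries of a VarCollection, for example when loading checkpoints):

import objax

rename = objax.util.Renamer({'(Linear)': '(Dense)'})
print(rename('(MiniNet).f1(Linear).w'))  # (MiniNet).f1(Dense).w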

objax.util.args_indexes(f, args)[source]

Returns the indexes of variable names of a function.

Parameters
  • f (Callable) –

  • args (Iterable[str]) –

Return type

Iterable[int]

objax.util.dummy_context_mgr()[source]

Empty Context Manager.

objax.util.ilog2(x)[source]

Integer log2.

Parameters

x (float) –

objax.util.positional_args_names(f)[source]

Returns the ordered names of the positional arguments of a function.

Parameters

f (Callable) –

Return type

List[str]

objax.util.to_tuple(v, n)[source]

Converts input to tuple.

Parameters
  • v (Union[Tuple[numbers.Number, ..], numbers.Number, Iterable]) –

  • n (int) –

objax.util.image

nchw(x)

Converts an array in (N,H,W,C) format to (N,C,H,W) format.

nhwc(x)

Converts an array in (N,C,H,W) format to (N,H,W,C) format.

normalize_to_uint8(x)

Map a float image in [1/256-1, 1-1/256] to uint8 {0, 1, …, 255}.

normalize_to_unit_float(x)

Map a uint8 image in {0, 1, …, 255} to float interval [1/256-1, 1-1/256].

to_png(x)

Converts numpy array in (C,H,W) format into PNG format.

objax.util.image.from_file(file)[source]

Read an image from a file, convert it to RGB, and return it as an array.

Parameters

file (Union[str, IO[BinaryIO]]) – filename or python file handle of the input file.

Returns

3D numpy array (C, H, W) normalized with normalize_to_unit_float.

Return type

numpy.ndarray

objax.util.image.image_grid(image)[source]

Rearrange an array of images (nh, nw, c, h, w) into an image grid forming a single image (c, nh * h, nw * w).

Parameters

image (numpy.ndarray) –

Return type

numpy.ndarray

objax.util.image.nchw(x)[source]

Converts an array in (N,H,W,C) format to (N,C,H,W) format.

Parameters

x (Union[numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) –

Return type

Union[numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]

objax.util.image.nhwc(x)[source]

Converts an array in (N,C,H,W) format to (N,H,W,C) format.

Parameters

x (Union[numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) –

Return type

Union[numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]

objax.util.image.normalize_to_uint8(x)[source]

Map a float image in [1/256-1, 1-1/256] to uint8 {0, 1, …, 255}.

Parameters

x (Union[numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) –

Return type

Union[numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]

objax.util.image.normalize_to_unit_float(x)[source]

Map a uint8 image in {0, 1, …, 255} to float interval [1/256-1, 1-1/256].

Parameters

x (Union[numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) –

Return type

Union[numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]

objax.util.image.to_png(x)[source]

Converts numpy array in (C,H,W) format into PNG format.

Parameters

x (Union[numpy.ndarray, jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) –

Return type

bytes

objax.zoo package
objax.zoo.convnet
class objax.zoo.convnet.ConvNet(nin, nclass, scales, filters, filters_max, pooling=<function max_pool_2d>, **kwargs)[source]

ConvNet implementation.

__init__(nin, nclass, scales, filters, filters_max, pooling=<function max_pool_2d>, **kwargs)[source]

Creates ConvNet instance.

Parameters
  • nin – number of channels in the input image.

  • nclass – number of output classes.

  • scales – number of pooling layers, each of which reduces spatial dimension by 2.

  • filters – base number of convolution filters. The number of convolution filters is doubled at every scale until it reaches filters_max.

  • filters_max – maximum number of filters.

  • pooling – type of pooling layer.

objax.zoo.dnnet
class objax.zoo.dnnet.DNNet(layer_sizes, activation)[source]

Deep neural network (MLP) implementation.

__init__(layer_sizes, activation)[source]

Creates DNNet instance.

Parameters
  • layer_sizes (Iterable[int]) – number of neurons for each layer.

  • activation (Callable) – activation function applied between layers.
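
A minimal sketch of constructing and calling the network; the layer sizes and batch are arbitrary:

import objax

mlp = objax.zoo.dnnet.DNNet((784, 128, 10), objax.functional.relu)

x = objax.random.normal((32, 784))
logits = mlp(x)  # shape (32, 10)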

objax.zoo.resnet_v2
class objax.zoo.resnet_v2.ResNetV2(in_channels, num_classes, blocks_per_group, bottleneck=True, channels_per_group=(256, 512, 1024, 2048), group_strides=(1, 2, 2, 2), group_use_projection=(True, True, True, True), normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)[source]

Base implementation of ResNet v2 from https://arxiv.org/abs/1603.05027.

__init__(in_channels, num_classes, blocks_per_group, bottleneck=True, channels_per_group=(256, 512, 1024, 2048), group_strides=(1, 2, 2, 2), group_use_projection=(True, True, True, True), normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)[source]

Creates ResNetV2 instance.

Parameters
  • in_channels (int) – number of channels in the input image.

  • num_classes (int) – number of output classes.

  • blocks_per_group (Sequence[int]) – number of blocks in each block group.

  • bottleneck (bool) – if True then use bottleneck blocks.

  • channels_per_group (Sequence[int]) – number of output channels for each block group.

  • group_strides (Sequence[int]) – strides for each block group.

  • normalization_fn (Callable[[..], objax.module.Module]) – module which is used as the normalization function.

  • activation_fn (Callable[[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]) – activation function.

  • group_use_projection (Sequence[bool]) –

class objax.zoo.resnet_v2.ResNet18(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)[source]

Implementation of ResNet v2 with 18 layers.

__init__(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)[source]

Creates ResNet18 instance.

Parameters
  • in_channels (int) – number of channels in the input image.

  • num_classes (int) – number of output classes.

  • normalization_fn (Callable[[..], objax.module.Module]) – module which is used as the normalization function.

  • activation_fn (Callable[[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]) – activation function.
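
A minimal sketch of constructing and calling the model; the input resolution, batch size and class count are arbitrary:

import objax

model = objax.zoo.resnet_v2.ResNet18(in_channels=3, num_classes=10)

x = objax.random.normal((4, 3, 32, 32))  # a batch of images in NCHW format
logits = model(x, training=False)        # shape (4, 10)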

class objax.zoo.resnet_v2.ResNet34(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)[source]

Implementation of ResNet v2 with 34 layers.

__init__(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)[source]

Creates ResNet34 instance.

Parameters
  • in_channels (int) – number of channels in the input image.

  • num_classes (int) – number of output classes.

  • normalization_fn (Callable[[..], objax.module.Module]) – module which is used as the normalization function.

  • activation_fn (Callable[[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]) – activation function.

class objax.zoo.resnet_v2.ResNet50(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)[source]

Implementation of ResNet v2 with 50 layers.

__init__(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)[source]

Creates ResNet50 instance.

Parameters
  • in_channels (int) – number of channels in the input image.

  • num_classes (int) – number of output classes.

  • normalization_fn (Callable[[..], objax.module.Module]) – module which is used as the normalization function.

  • activation_fn (Callable[[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]) – activation function.

class objax.zoo.resnet_v2.ResNet101(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)[source]

Implementation of ResNet v2 with 101 layers.

__init__(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)[source]

Creates ResNet101 instance.

Parameters
  • in_channels (int) – number of channels in the input image.

  • num_classes (int) – number of output classes.

  • normalization_fn (Callable[[..], objax.module.Module]) – module which is used as the normalization function.

  • activation_fn (Callable[[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]) – activation function.

class objax.zoo.resnet_v2.ResNet152(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)[source]

Implementation of ResNet v2 with 152 layers.

__init__(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)[source]

Creates ResNet152 instance.

Parameters
  • in_channels (int) – number of channels in the input image.

  • num_classes (int) – number of output classes.

  • normalization_fn (Callable[[..], objax.module.Module]) – module which is used as the normalization function.

  • activation_fn (Callable[[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]) – activation function.

class objax.zoo.resnet_v2.ResNet200(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)[source]

Implementation of ResNet v2 with 200 layers.

__init__(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)[source]

Creates ResNet200 instance.

Parameters
  • in_channels (int) – number of channels in the input image.

  • num_classes (int) – number of output classes.

  • normalization_fn (Callable[[..], objax.module.Module]) – module which is used as the normalization function.

  • activation_fn (Callable[[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]) – activation function.

objax.zoo.wide_resnet
class objax.zoo.wide_resnet.WRNBlock(nin, nout, stride=1, bn=functools.partial(<class 'objax.nn.layers.BatchNorm2D'>, momentum=0.9, eps=1e-05))[source]

WideResNet block.

__init__(nin, nout, stride=1, bn=functools.partial(<class 'objax.nn.layers.BatchNorm2D'>, momentum=0.9, eps=1e-05))[source]

Creates WRNBlock instance.

Parameters
  • nin (int) – number of input filters.

  • nout (int) – number of output filters.

  • stride (int) – stride for convolution and projection convolution in this block.

  • bn (Callable) – module which is used as the batch norm function.

__call__(x, training)[source]

Optional module __call__ method, typically a forward pass computation for standard primitives.

Parameters
  • x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) –

  • training (bool) –

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]

class objax.zoo.wide_resnet.WideResNetGeneral(nin, nclass, blocks_per_group, width, bn=functools.partial(<class 'objax.nn.layers.BatchNorm2D'>, momentum=0.9, eps=1e-05))[source]

Base WideResNet implementation.

static mean_reduce(x)[source]
Parameters

x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) –

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]

__init__(nin, nclass, blocks_per_group, width, bn=functools.partial(<class 'objax.nn.layers.BatchNorm2D'>, momentum=0.9, eps=1e-05))[source]

Creates WideResNetGeneral instance.

Parameters
  • nin (int) – number of channels in the input image.

  • nclass (int) – number of output classes.

  • blocks_per_group (List[int]) – number of blocks in each block group.

  • width (int) – multiplier to the number of convolution filters.

  • bn (Callable) – module which is used as the batch norm function.

class objax.zoo.wide_resnet.WideResNet(nin, nclass, depth=28, width=2, bn=functools.partial(<class 'objax.nn.layers.BatchNorm2D'>, momentum=0.9, eps=1e-05))[source]

WideResNet implementation with 3 groups.

Reference:

http://arxiv.org/abs/1605.07146 https://github.com/szagoruyko/wide-residual-networks

__init__(nin, nclass, depth=28, width=2, bn=functools.partial(<class 'objax.nn.layers.BatchNorm2D'>, momentum=0.9, eps=1e-05))[source]

Creates WideResNet instance.

Parameters
  • nin (int) – number of channels in the input image.

  • nclass (int) – number of output classes.

  • depth (int) – number of convolution layers. (depth-4) should be divisible by 6

  • width (int) – multiplier to the number of convolution filters.

  • bn (Callable) – module which is used as the batch norm function.

objax.zoo.rnn
class objax.zoo.rnn.RNN(nstate, nin, nout, activation=<function _one_to_one_unop.<locals>.<lambda>>, w_init=<function kaiming_normal>)[source]

Recurrent Neural Network (RNN) block.

__init__(nstate, nin, nout, activation=<function _one_to_one_unop.<locals>.<lambda>>, w_init=<function kaiming_normal>)[source]

Creates an RNN instance.

Parameters
  • nstate (int) – number of hidden units.

  • nin (int) – number of input units.

  • nout (int) – number of output units.

  • activation (Callable) – activation function for the hidden layer.

  • w_init (Callable) – weight initializer for RNN model weights.

init_state(batch_size)[source]

Initialize hidden state for input batch of size batch_size.

__call__(inputs, only_return_final=False)[source]

Forward pass through RNN.

Parameters
  • inputs (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]) – JaxArray with dimensions num_steps, batch_size, vocabulary_size.

  • only_return_final – return only the last output if True, or all outputs otherwise.

Returns

Output tensor with dimensions num_steps * batch_size, vocabulary_size.

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]
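
A minimal sketch, assuming init_state is called before the forward pass to size the hidden state for the batch; all dimensions are arbitrary:

import objax

rnn = objax.zoo.rnn.RNN(nstate=64, nin=10, nout=10)
rnn.init_state(batch_size=4)

x = objax.random.normal((5, 4, 10))      # (num_steps, batch_size, vocabulary_size)
y = rnn(x)                               # (num_steps * batch_size, vocabulary_size)
y_last = rnn(x, only_return_final=True)  # output of the last step only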

objax.zoo.vgg
class objax.zoo.vgg.VGG19(pretrained=False)[source]

VGG19 implementation.

__init__(pretrained=False)[source]

Creates VGG19 instance.

Parameters

pretrained – if True, load weights from an ImageNet pretrained model.

build()[source]

Variables and Modules

Objax relies on only two concepts: Variable and Module. While not a concept in itself, another commonly used object is the VarCollection.

  • A variable is an object that has a value, typically a JaxArray.

  • A module is an object that has variables and methods attached to it.

  • A VarCollection is a dictionary {name: variable} with some additional methods to facilitate variable passing and manipulation.

The following topics are covered in this guide:

Variable

Variables store a value. Contrary to most frameworks, Objax variables do not require users to give them names. Instead, their names are inferred from the Python code. See the Variable names section for more details.

Objax has four types of variables to handle the various situations encountered in machine learning:

  1. TrainVar is a trainable variable. Its value cannot be directly modified, so as to maintain its differentiability.

  2. StateVar is a state variable. It is not trainable and its value can be directly modified.

  3. TrainRef is a reference to a TrainVar. It is a state variable used to change the value of a TrainVar, typically in the context of optimizers such as SGD.

  4. RandomState is a random state variable. JAX made the design choice to require explicit random state manipulation, and this variable does that for you. This type of variable is used in objax.random.Generator.

Note

Some of the variable types above take a reduce argument; this is discussed in the Parallelism topic.

TrainVar

An objax.TrainVar is a trainable variable. TrainVar variables are meant to keep the trainable weights of neural networks. As such, when calling a gradient module such as objax.GradValues, their gradients are computed. This contrasts with the other type of variable (state variables), which do not have gradients. A TrainVar is created by passing a JaxArray containing its initial value:

import objax
import jax.numpy as jn

v = objax.TrainVar(jn.arange(6, dtype=jn.float32))
print(v.value)  # [0. 1. 2. 3. 4. 5.]

# It is not directly writable to avoid accidentally breaking differentiability.
v.value += 1  # Raises a ValueError

# You can force assign it if you know what you are doing.
v.assign(v.value + 1)  # Still differentiable, v depends on its previous values.
print(v.value)  # [1. 2. 3. 4. 5. 6.]
v.assign(jn.arange(1, 7, dtype=jn.float32))  # Old values lost, v is not differentiable anymore.
print(v.value)  # [1. 2. 3. 4. 5. 6.]
StateVar

An objax.StateVar is a state variable. Unlike TrainVar variables, state variables are non-trainable. StateVar variables are used for parameters that are manually/programmatically updated. For example, when computing a running mean of the input to a module, a StateVar is used. StateVars are created just like TrainVars, by passing a JaxArray containing their initial value:

import objax
import jax.numpy as jn

v = objax.StateVar(jn.arange(6, dtype=jn.float32))
print(v.value)  # [0. 1. 2. 3. 4. 5.]

# It is directly writable.
v.value += 1
print(v.value)  # [1. 2. 3. 4. 5. 6.]

# You can also assign to it, it's the same as doing v.value = ...
v.assign(v.value + 1)
print(v.value)  # [2. 3. 4. 5. 6. 7.]

StateVar variables are ignored by gradient methods. Unlike TrainVar variables, their gradients are not computed.

Why not use Python variables instead of StateVars?

You may be tempted to simply use Python values or numpy arrays directly since StateVars are programmatically updated.

StateVars are necessary. They are needed to run on the GPU, since standard Python values and numpy arrays would not. Another reason is that objax.Jit and objax.Parallel only recognize Objax variables.

TrainRef

An objax.TrainRef is a state variable which is used to keep a reference to a TrainVar. TrainRef variables are used in optimizers since optimizers need to keep a reference to the TrainVar they are meant to optimize. TrainRef creation differs from the previously seen variables as it takes a TrainVar as its input:

import objax
import jax.numpy as jn

t = objax.TrainVar(jn.arange(6, dtype=jn.float32))
v = objax.TrainRef(t)
print(t.value)  # [0. 1. 2. 3. 4. 5.]

# It is directly writable.
v.value += 1
print(v.value)  # [1. 2. 3. 4. 5. 6.]

# It writes the TrainVar it references.
print(t.value)  # [1. 2. 3. 4. 5. 6.]

# You can also assign to it, it's the same as doing v.value = ...
v.assign(v.value + 1)
print(v.value)  # [2. 3. 4. 5. 6. 7.]
print(t.value)  # [2. 3. 4. 5. 6. 7.]

TrainRef variables are ignored by gradient methods. Unlike TrainVar variables, their gradients are not computed.

Philosophically, one may ask why a TrainRef is needed to keep a reference to a TrainVar in an optimizer. Indeed, why not simply keep the TrainVar itself in the optimizer? The answer is that the optimizer is a module like any other (make sure to read the Module section first). As such, one could compute the gradient of the optimizer itself. It is for this situation that we need a TrainRef to distinguish between the optimizer’s own trainable variables (needed for its functionality) and the trainable variables of the neural network it is meant to optimize. It should be noted that most current optimizers do not have their own trainable variables, but we wanted to provide the flexibility needed for future research.

RandomState

An objax.RandomState is a state variable which is used to handle the tracking of random number generator states. It is only used in objax.random.Generator. It is responsible for automatically creating different states when the code is run in parallel on multiple GPUs (see objax.Parallel) or in a vectorized way (see objax.Vectorize). This is necessary in order for random numbers to be truly random. In the rare event that you want to use the same random seed in a multi-GPU or vectorized module, you can use a StateVar to store the seed.

Here’s a simple example using the objax.random.Generator API:

import objax

# Use default objax.random.DEFAULT_GENERATOR that transparently handles RandomState
print(objax.random.normal((2,)))  # [ 0.19307713 -0.52678305]
# A subsequent call gives, as expected new random numbers.
print(objax.random.normal((2,)))  # [ 0.00870693 -0.04888531]

# Make two random generators with same seeds
g1 = objax.random.Generator(seed=1337)
g2 = objax.random.Generator(seed=1337)

# Random numbers using g1
print(objax.random.normal((2,), generator=g1))  # [-0.3361883 -0.9903351]
print(objax.random.normal((2,), generator=g1))  # [ 0.5825488 -1.4342074]

# Random numbers using g2
print(objax.random.normal((2,), generator=g2))  # [-0.3361883 -0.9903351]
print(objax.random.normal((2,), generator=g2))  # [ 0.5825488 -1.4342074]
# The results are reproducible: we obtained the same random numbers with 2 generators
# using the same random seed.

You can also manually manipulate RandomState directly for the purpose of designing custom random numbers rules, for example with forced correlation. A RandomState has an extra method called objax.RandomState.split() which lets it create n new random states. Here’s a basic example of RandomState manipulation:

import objax

v = objax.RandomState(1)  # 1 is the seed
print(v.value)     # [0 1]

# We call v.split(1) to generate 1 new state, note that split also updates v.value
print(v.split(1))  # [[3819641963 2025898573]]
print(v.value)     # [2441914641 1384938218]

# We call v.split(2) to generate 2 new states, again v.value is updated
print(v.split(2))  # [[ 622232657  209145368] [2741198523 2127103341]]
print(v.value)     # [3514448473 2078537737]
Module

An objax.Module is a simple container in which to store variables or other modules and on which to attach methods that use these variables. Objax uses the term module instead of class to avoid confusion with the Python term class. The Module class only offers one method, objax.Module.vars(), which returns all variables contained by the module and its submodules in a VarCollection.

Warning

To avoid surprising unintended behaviors, vars() won’t look for variables or modules in lists, dicts or any structure that is not a Module. See the Scope of the Module.vars method section below for how to handle lists in Objax.

Let’s start with a simple example: a module called Linear, which does a simple matrix product and adds a bias y = x.w + b, where \(w\in\mathbb{R}^{m\times n}\) and \(b\in\mathbb{R}^n\):

import objax
import jax.numpy as jn

class Linear(objax.Module):
    def __init__(self, m, n):
        self.w = objax.TrainVar(objax.random.normal((m, n)))
        self.b = objax.TrainVar(jn.zeros(n))

    def __call__(self, x):
        return x.dot(self.w.value) + self.b.value

This simple module can be used on a batch \(x\in\mathbb{R}^{d\times m}\) to compute the resulting value \(y\in\mathbb{R}^{d\times n}\) for batch size \(d\). Let’s continue our example by creating an actual instance of our module and running a random batch x through it:

f = Linear(4, 5)
x = objax.random.normal((100, 4))  # A (100 x 4) matrix of random numbers
y = f(x)  # y.shape == (100, 5)

We can easily make a more complicated module that uses the previously defined module Linear:

class MiniNet(objax.Module):
    def __init__(self, m, n, p):
        self.f1 = objax.nn.Linear(m, n)
        self.f2 = objax.nn.Linear(n, p)

    def __call__(self, x):
        y = self.f1(x)
        y = objax.functional.relu(y)  # Apply a non-linearity.
        return self.f2(y)

    # You can create as many functions as you want.
    def another_function(self, x1, x2):
        return self.f2(self.f1(x1) + x2)

f = MiniNet(4, 5, 6)
y = f(x)  # y.shape == (100, 6)
x2 = objax.random.normal((100, 5))  # A (100 x 5) matrix of random numbers
another_y = f.another_function(x, x2)

# You can also call internal parts for example to see intermediate values.
y1 = f.f1(x)
y2 = objax.functional.relu(y1)
y3 = f.f2(y2)  # y3 == y
Variable names

Continuing on the previous example, let’s find what the name of the variables are. We mentioned earlier that variable names are inferred from Python and not specified by the programmer. The way their names are inferred is from the field names, such as self.w. This has the benefit of ensuring consistency: a variable has a single name, and it’s the name it is given in the Python code.

Let’s inspect the names:

f = Linear(4, 5)
print(f.vars())  # print name, size, dimensions
# (Linear).w                 20 (4, 5)
# (Linear).b                  5 (5,)
# +Total(2)                  25

f = MiniNet(4, 5, 6)
print(f.vars())
# (MiniNet).f1(Linear).w       20 (4, 5)
# (MiniNet).f1(Linear).b        5 (5,)
# (MiniNet).f2(Linear).w       30 (5, 6)
# (MiniNet).f2(Linear).b        6 (6,)
# +Total(4)                    61

As you can see, the names correspond to the names of the fields in which the variables are kept.

Scope of the Module.vars method

The objax.Module.vars() method is meant to be simple and to remain simple. With that in mind, we limited its scope: vars() won’t look for variables or modules in lists, dicts or any structure that is not a Module. This is to avoid surprising unintended behavior.

Instead we made the decision to create an explicit class objax.ModuleList to store a list of variables and modules.

ModuleList

The class objax.ModuleList inherits from list and behaves exactly like a list with the difference that vars() looks for variables and modules in it. This class is very simple, and we invite you to look at it and use it for inspiration if you want to extend other Python containers or design your own.

Here’s a simple example of its usage:

import objax
import jax.numpy as jn

class MyModule(objax.Module):
    def __init__(self):
        self.bad = [objax.TrainVar(jn.zeros(1)),
                    objax.TrainVar(jn.zeros(2))]
        self.good = objax.ModuleList([objax.TrainVar(jn.zeros(3)),
                                      objax.TrainVar(jn.zeros(4))])

print(MyModule().vars())
# (MyModule).good(ModuleList)[0]        3 (3,)
# (MyModule).good(ModuleList)[1]        4 (4,)
# +Total(2)                             7
VarCollection

The Module.vars method returns an objax.VarCollection. This class is a dictionary that maps names to variables. It has some additional methods and some modified behaviors specifically for variable manipulation. In most cases, you won’t need to use the more advanced methods such as __iter__, tensors and assign, but this is an in-depth topic.

Let’s take a look at some of them through an example:

import objax
import jax.numpy as jn

class Linear(objax.Module):
    def __init__(self, m, n):
        self.w = objax.TrainVar(objax.random.normal((m, n)))
        self.b = objax.TrainVar(jn.zeros(n))

m1 = Linear(3, 4)
m2 = Linear(4, 5)

# First, as seen before, we can print the contents with print() method
print(m1.vars())
# (Linear).w                 12 (3, 4)
# (Linear).b                  4 (4,)
# +Total(2)                  16

# A VarCollection is really a dictionary
print(repr(m1.vars()))
# {'(Linear).w': <objax.variable.TrainVar object at 0x7fb5e47c0ad0>,
#  '(Linear).b': <objax.variable.TrainVar object at 0x7fb5ec017890>}

Combining multiple VarCollections is done by using addition:

all_vars = m1.vars('m1') + m2.vars('m2')
print(all_vars)
# m1(Linear).w               12 (3, 4)
# m1(Linear).b                4 (4,)
# m2(Linear).w               20 (4, 5)
# m2(Linear).b                5 (5,)
# +Total(4)                  41

# We had to specify starting names for each of the var collections since
# they have variables with the same name. Had we not, a name collision would
# have occurred since VarCollection is a dictionary that maps names to variables.
m1.vars() + m2.vars()  # raises ValueError('Name conflicts...')
Weight sharing

It’s a common technique in machine learning to share some weights. However, it is important not to apply gradients twice or more to shared weights. This is handled automatically by VarCollection and its __iter__ method described in the next section. Here’s a simple weight sharing example where we simply refer to the same module twice under different names:

# Weight sharing
shared_vars = m1.vars('m1') + m1.vars('m1_shared')
print(shared_vars)
# m1(Linear).w               12 (3, 4)
# m1(Linear).b                4 (4,)
# m1_shared(Linear).w        12 (3, 4)
# m1_shared(Linear).b         4 (4,)
# +Total(4)                  32
VarCollection.__iter__

Deduplication is handled automatically by the VarCollection default iterator objax.VarCollection.__iter__(). Following up on the weight sharing example above, the iterator only returns each variable once:

list(shared_vars)  # [<objax.variable.TrainVar>, <objax.variable.TrainVar>]
VarCollection.tensors

You can collect all the values (JaxArray) for all the variables with objax.VarCollection.tensors(), again in a deduplicated manner:

shared_vars.tensors()  # DeviceArray([[-0.1441347...]), DeviceArray([0...], dtype=float32)]
VarCollection.assign

The last important method objax.VarCollection.assign() lets you assign a tensor list to all the VarCollection’s (deduplicated) variables at once:

shared_vars.tensors()  # DeviceArray([[-0.1441347...]), DeviceArray([0...], dtype=float32)]
# The following increments all the variables.
shared_vars.assign([x + 1 for x in shared_vars.tensors()])
shared_vars.tensors()  # DeviceArray([[0.8558653...]), DeviceArray([1...], dtype=float32)]

Understanding Gradients

Pre-requisites: Variables and Modules.

In this guide we discuss how to compute and use gradients in various situations. Examples will illustrate the following cases:

  • Describe the objax.GradValues Module.

  • Show how to write your own optimizer from scratch.

  • How to write a basic training iteration.

  • How to handle complex gradients such as in Generative Adversarial Networks (GANs) or meta-learning.

  • Explain potential optimizations in the presence of constants.

Computing gradients

JAX, and therefore Objax, differs from most frameworks in how gradients are represented. Gradients in JAX are represented as functions, since everything in JAX is a function. In Objax, however, they are represented as module objects.

Gradient as a module

In machine learning, for a function \(f(X; \theta)\), it is common practice to separate the inputs \(X\) from the parameters \(\theta\). Mathematically, this is captured by using a semi-colon to semantically separate one group of arguments from another.

In Objax, we represent this semantic distinction through an object objax.Module:

  • the module parameters \(\theta\) are object attributes of the form self.w, ...

  • the inputs \(X\) are arguments to the methods such as def __call__(self, x1, x2, ...):

The gradient of a function \(f(X; \theta)\) w.r.t. \(Y\subseteq X, \phi\subseteq\theta\) is a function

\[g_{\scriptscriptstyle Y, \phi}(X; \theta) = (\nabla_Y f(X; \theta), \nabla_\phi f(X; \theta))\]

The gradient function is also a module, since the same semantic distinction as in f can be made between inputs \(X\) and parameters \(\theta\). Meanwhile, \(Y, \phi\) are constants of g: they determine which inputs and which variables to compute the gradient of. In practice, \(Y, \phi\) are also implemented as object attributes.

The direct benefit of such a decision is that gradient manipulation is very easy and explicit: in fact it follows the standard mathematical formulation of gradients. While this demonstration may seem abstract, we are going to see in examples how simple it turns out to be.

A simple example

Let’s look at what gradient as a module looks like through a simple example:

import objax

m = objax.nn.Linear(2, 3)

@objax.Function.with_vars(m.vars())
def loss(x, y):
    return ((m(x) - y) ** 2).mean()

# Create Module that returns a tuple (g, v):
#    g is the gradient of the loss
#    v is the value of the loss
gradient_loss = objax.GradValues(loss, m.vars())

# Make up some fake data
x = objax.random.normal((100, 2))
y = objax.random.normal((100, 3))

# Calling the module gradient_loss returns the actual g, v for (x, y)
g, v = gradient_loss(x, y)
print(v, '==', loss(x, y))  #  [DeviceArray(2.7729921, dtype=float32)] == 2.7729921
print(g)  # A list of tensors (gradients of variables in module m)

As stated, gradient_loss is a module instance and has variables. Its variables are simply the ones passed to objax.GradValues, we can verify it:

print(gradient_loss.vars())
# (Linear).b                  3 (3,)
# (Linear).w                  6 (2, 3)
# +Total(2)                   9

# These variables are from
print(m.vars())
# (Linear).b                  3 (3,)
# (Linear).w                  6 (2, 3)
# +Total(2)                   9

Let’s be clear: These are the exact same variables, not copies. This is an instance of weight sharing, m and gradient_loss share the same weights.

Loss optimization

Gradients are useful to minimize or maximize losses. This can be done using Stochastic Gradient Descent (SGD) with the following steps, for a network with weights \(\theta\) and a learning rate \(\mu\):

  1. At iteration \(t\), take a batch of data \(x_t\)

  2. Compute the gradient \(g_t=\nabla loss(x_t)\)

  3. Update the weights \(\theta_t = \theta_{t-1} - \mu \cdot g_t\)

  4. Goto 1

Objax already has a library of optimizers: the objax.optimizer package. However, we are going to create our own to demonstrate how it works with gradients. First, let’s recall that everything is a Module (or a function) in Objax. In this case, SGD will be a module: it stores the list of variables on which to do gradient descent, and its __call__ method takes the gradients as inputs and applies them to the variables.

Read first the part about Variables and Modules if you haven’t done so yet. Let’s get started:

import objax

class SGD(objax.Module):
    def __init__(self, variables: objax.VarCollection):
        self.refs = objax.ModuleList(objax.TrainRef(x)
                                     for x in variables.subset(objax.TrainVar))

    def __call__(self, lr: float, gradients: list):
        for v, g in zip(self.refs, gradients):
            v.value -= lr * g

In short, self.refs keeps a list of references to the network trainable variables TrainVar. When calling the __call__ method, the values of the variables get updated by the SGD method.

From this we can demonstrate the training of a classifier:

import objax

# SGD definition code from above.

my_classifier = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu,
                                     objax.nn.Linear(3, 4)])
opt = SGD(my_classifier.vars())

@objax.Function.with_vars(my_classifier.vars())
def loss(x, labels):
    logits = my_classifier(x)
    return objax.functional.loss.cross_entropy_logits(logits, labels).mean()

gradient_loss = objax.GradValues(loss, my_classifier.vars())

@objax.Function.with_vars(my_classifier.vars() + opt.vars())
def train(x, labels, lr):
    g, v = gradient_loss(x, labels)  # Compute gradients and loss
    opt(lr, g)                       # Apply SGD
    return v                         # Return loss value

# Observe that the gradient contains the variables of the model (weight sharing)
print(gradient_loss.vars())
# (Sequential)[0](Linear).b        3 (3,)
# (Sequential)[0](Linear).w        6 (2, 3)
# (Sequential)[2](Linear).b        4 (4,)
# (Sequential)[2](Linear).w       12 (3, 4)
# +Total(4)                                   25

# At this point you can simply call train on your training data and pass the learning rate.
# The call will do a single minimization step of the loss, following the SGD method, on your data.
# Repeated calls (through various batches of data) will minimize the loss on your data.
x = objax.random.normal((100, 2))
labels = objax.random.randint((100,), low=0, high=4)
labels = objax.functional.one_hot(labels, 4)
print(train(x, labels, lr=0.01))
# and so on...

# See examples section for real examples.
Returning multiple values for the loss

Let’s say we want to add weight decay and return the individual components of the loss (cross-entropy, weight decay). The loss function can return any number of values or even structures such as dicts or lists. Only the first returned value is used to compute the gradient; the others are returned as the loss value.

Continuing our example, let’s create a new loss that returns its multiple components:

@objax.Function.with_vars(my_classifier.vars())
def losses(x, labels):
    logits = my_classifier(x)
    loss_xe = objax.functional.loss.cross_entropy_logits(logits, labels).mean()
    loss_wd = sum((v.value ** 2).sum() for k, v in my_classifier.vars().items() if k.endswith('.w'))
    return loss_xe + 0.0002 * loss_wd, loss_xe, loss_wd

gradient_losses = objax.GradValues(losses, my_classifier.vars())
print(gradient_losses(x, labels)[1])
# (DeviceArray(1.7454103, dtype=float32),
#  DeviceArray(1.7434813, dtype=float32),
#  DeviceArray(9.645493, dtype=float32))

Or one might prefer to return a dict to keep things organized:

@objax.Function.with_vars(my_classifier.vars())
def loss_dict(x, labels):
    logits = my_classifier(x)
    loss_xe = objax.functional.loss.cross_entropy_logits(logits, labels).mean()
    loss_wd = sum((v.value ** 2).sum() for k, v in my_classifier.vars().items() if k.endswith('.w'))
    return loss_xe + 0.0002 * loss_wd, {'loss/xe': loss_xe, 'loss/wd': loss_wd}

gradient_loss_dict = objax.GradValues(loss_dict, my_classifier.vars())
print(gradient_loss_dict(x, labels)[1])
# (DeviceArray(1.7454103, dtype=float32),
#  {'loss/wd': DeviceArray(9.645493, dtype=float32),
#   'loss/xe': DeviceArray(1.7434813, dtype=float32)})
Input gradients

When computing gradients it’s sometimes useful to compute the gradients for some or all of the inputs of the network. For example, such functionality is needed for adversarial training or gradient penalties in GANs. This can be easily achieved using the input_argnums argument of objax.GradValues; here’s an example:

# Compute the gradient for my_classifier variables and for the first input of the loss:
gradient_loss_v_x = objax.GradValues(loss, my_classifier.vars(), input_argnums=(0,))
print(gradient_loss_v_x(x, labels)[0])
# g = [gradient(x)] + [gradient(v) for v in classifier.vars().subset(TrainVar)]

# Compute the gradient for my_classifier variables and for the second input of the loss:
gradient_loss_v_y = objax.GradValues(loss, my_classifier.vars(), input_argnums=(1,))
print(gradient_loss_v_y(x, labels)[0])
# g = [gradient(labels)] + [gradient(v) for v in classifier.vars().subset(TrainVar)]

# Compute the gradient for my_classifier variables and for all the inputs of the loss:
gradient_loss_v_xy = objax.GradValues(loss, my_classifier.vars(), input_argnums=(0, 1))
print(gradient_loss_v_xy(x, labels)[0])
# g = [gradient(x), gradient(labels)] + [gradient(v) for v in classifier.vars().subset(TrainVar)]

# You can also compute the gradients from the inputs alone
gradient_loss_xy = objax.GradValues(loss, objax.VarCollection(), input_argnums=(0, 1))
print(gradient_loss_xy(x, labels)[0])
# g = [gradient(x), gradient(labels)]

# The order of the inputs matters: using input_argnums=(1, 0) instead of (0, 1) reverses
# the order of the returned input gradients.
gradient_loss_yx = objax.GradValues(loss, objax.VarCollection(), input_argnums=(1, 0))
print(gradient_loss_yx(x, labels)[0])
# g = [gradient(labels), gradient(x)]
Gradients of a subset of variables

When doing more complex optimizations, one might want to temporarily treat a part of a network as constant. This is achieved by simply passing only the variables you want the gradient of to objax.GradValues. This is useful for example in GANs where one has to optimize the discriminator and the generator networks separately.

Continuing our example:

all_vars = my_classifier.vars()
print(all_vars)
# (Sequential)[0](Linear).b        3 (3,)
# (Sequential)[0](Linear).w        6 (2, 3)
# (Sequential)[2](Linear).b        4 (4,)
# (Sequential)[2](Linear).w       12 (3, 4)
# +Total(4)                       25

Let’s say we want to freeze the second Linear layer by treating it as constant:

# We create a VarCollection containing only the variables we want to train
vars_train = objax.VarCollection((k, v) for k, v in all_vars.items() if '[2](Linear)' not in k)
print(vars_train)
# (Sequential)[0](Linear).b        3 (3,)
# (Sequential)[0](Linear).w        6 (2, 3)
# +Total(2)                        9

# We define a gradient function that ignores variables not in vars_train
gradient_loss_freeze = objax.GradValues(loss, vars_train)
print(gradient_loss_freeze(x, labels)[0])
# As expected, we now have two gradient arrays, corresponding to vars_train.
# [DeviceArray([0.19490579, 0.12267624, 0.05770121], dtype=float32),
#  DeviceArray([[-0.21900907, -0.10813318, -0.05385721],
#               [ 0.12701261, -0.03145855, -0.04397186]], dtype=float32)]
Higher-order gradients

Finally, one might want to optimize a loss whose computation itself contains a gradient. For example, let’s consider the following nested loss that relies on another loss \(\mathcal{L}=\texttt{loss}\):

\[\texttt{nested_loss}(x_1, y_1, x_2, y_2, \mu) = \mathcal{L}(x_1, y_1; \theta - \mu\nabla\mathcal{L}(x_2, y_2; \theta))\]

Implementing this in Objax remains simple: one just applies the formula verbatim. In the following example we picked a cross-entropy loss for \(\mathcal{L}\), but we could have picked any other loss since nested_loss is independent of that choice:

train_vars = my_classifier.vars().subset(objax.TrainVar)

@objax.Function.with_vars(my_classifier.vars())
def loss(x, labels):
    logits = my_classifier(x)
    return objax.functional.loss.cross_entropy_logits(logits, labels).mean()

gradient_loss = objax.GradValues(loss, train_vars)

@objax.Function.with_vars(my_classifier.vars())
def nested_loss(x1, y1, x2, y2, mu):
    # Save original network variable values
    original_values = train_vars.tensors()
    # Apply gradient from loss(x2, y2)
    for v, g in zip(train_vars, gradient_loss(x2, y2)[0]):
         v.assign(v.value - mu * g)
    # Compute loss(x1, y1)
    loss_x1y1 = loss(x1, y1)
    # Undo the gradient from loss(x2, y2)
    for v, val in zip(train_vars, original_values):
         v.assign(val)
    # Return the loss
    return loss_x1y1

gradient_nested_loss = objax.GradValues(nested_loss, train_vars)

# Run with mock-up data. Note this is just an example; here nested_loss is applied to single examples rather than batches.
x1 = objax.random.normal((1, 2))
y1 = objax.functional.one_hot(objax.random.randint((1,), low=0, high=4), 4)
x2 = objax.random.normal((1, 2))
y2 = objax.functional.one_hot(objax.random.randint((1,), low=0, high=4), 4)
print(gradient_nested_loss(x1, y1, x2, y2, 0.1))
# (gradients, loss), where the gradients are 4 tensors of the same shape as the layer variables.
# (Sequential)[0](Linear).b        3 (3,)
# (Sequential)[0](Linear).w        6 (2, 3)
# (Sequential)[2](Linear).b        4 (4,)
# (Sequential)[2](Linear).w       12 (3, 4)

Generally speaking, using objax.TrainVar.assign() is discouraged unless you know what you are doing: the restriction exists to avoid accidental bugs caused by overwriting a trainable variable. This is precisely a situation where one knows what they are doing, so it’s perfectly fine to use assign here.

On a final note, observing that the weight update in the code above is invertible, the nested loss can be simplified to:

@objax.Function.with_vars(my_classifier.vars())
def nested_loss(x1, y1, x2, y2, mu):
    # Compute the gradient for loss(x2, y2)
    g_x2y2 = gradient_loss(x2, y2)[0]
    # Apply gradient from loss(x2, y2)
    for v, g in zip(train_vars, g_x2y2):
         v.assign(v.value - mu * g)
    # Compute loss(x1, y1)
    loss_x1y1 = loss(x1, y1)
    # Undo the gradient from loss(x2, y2)
    for v, g in zip(train_vars, g_x2y2):
         v.assign(v.value + mu * g)
    # Return the loss
    return loss_x1y1
Local gradients

In still more advanced situations, such as meta-learning, it can be desirable to have even finer control over gradients. In the above example, nested_loss can accept vectors or matrices for its inputs x1, y1, x2, y2. In the case of matrices, the nested_loss is computed as:

\[\texttt{nested_loss}(X_1, Y_1, X_2, Y_2, \mu) = \mathbb{E}_{i}\mathcal{L}(X_1^{(i)}, Y_1^{(i)}; \theta - \mu\mathbb{E}_{j}\nabla\mathcal{L}(X_2^{(j)}, Y_2^{(j)}; \theta))\]

As a more advanced example, let’s reproduce the loss from Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks in a batch form. It is expressed as follows:

\[\begin{split}\texttt{nested_pairwise_loss}(X_1, Y_1, X_2, Y_2, \mu) &= \mathbb{E}_{i}\mathcal{L}(X_1^{(i)}, Y_1^{(i)}; \theta - \mu\nabla\mathcal{L}(X_2^{(i)}, Y_2^{(i)}; \theta)) \\ &= \mathbb{E}_{i}\texttt{nested_loss}(X_1^{(i)}, Y_1^{(i)}, X_2^{(i)}, Y_2^{(i)}, \mu)\end{split}\]

Using the previously defined nested_loss, we can apply vectorization (see Vectorization for details) to it. In doing so we create a module vec_nested_loss that computes nested_loss for every entry of the batches X1, Y1, X2, Y2:

# Make vec_nested_loss a Module that calls nested_loss on one batch entry at a time
vec_nested_loss = objax.Vectorize(nested_loss, batch_axis=(0, 0, 0, 0, None))

# The final loss just calls vec_nested_loss and returns the mean of the losses
@objax.Function.with_vars(my_classifier.vars())
def nested_pairwise_loss(X1, Y1, X2, Y2, mu):
    return vec_nested_loss(X1, Y1, X2, Y2, mu).mean()

# Just like any simpler loss, we can compute its gradient.
gradient_nested_pairwise_loss = objax.GradValues(nested_pairwise_loss, vec_nested_loss.vars())

# Run with mock-up batch data.
X1 = objax.random.normal((100, 2))
Y1 = objax.functional.one_hot(objax.random.randint((100,), low=0, high=4), 4)
X2 = objax.random.normal((100, 2))
Y2 = objax.functional.one_hot(objax.random.randint((100,), low=0, high=4), 4)
print(gradient_nested_pairwise_loss(X1, Y1, X2, Y2, 0.1))

Have fun!

Compilation and Parallelism

In this section we discuss the concepts of code compilation and parallelism, which are typically used to accelerate performance. We’ll cover the following subtopics:

  • Just-In-Time (JIT) Compilation compiles the code the first time it’s executed, with the goal of speeding up subsequent runs.

  • Parallelism runs operations on multiple devices (for example multiple GPUs).

  • Vectorization can be seen as batch-level parallelization, running an operation on a batch in parallel.

JIT Compilation

objax.Jit is a Module that takes a module or a function and compiles it for faster performance.

As a simple starting example, let’s jit a module:

import objax

net = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu, objax.nn.Linear(3, 4)])
jit_net = objax.Jit(net)

x = objax.random.normal((100, 2))
print(((net(x) - jit_net(x)) ** 2).sum())  # 0.0

You can also jit a function; in this case you must decorate the function with the variables it uses:

@objax.Function.with_vars(net.vars())
def my_function(x):
    return objax.functional.softmax(net(x))

jit_func = objax.Jit(my_function)
print(((objax.functional.softmax(net(x)) - jit_func(x)) ** 2).sum())

In terms of performance, even on this small example there’s a significant gain in speed. The numbers vary depending on the hardware in your computer and on what code is being jitted:

from time import time

t0 = time(); y = net(x); print(time() - t0)       # 0.005...
t0 = time(); y = jit_net(x); print(time() - t0)   # 0.001...

As mentioned earlier, jit_net is a module instance that shares its variables with the module net; we can verify it:

print(net.vars())
# (Sequential)[0](Linear).b        3 (3,)
# (Sequential)[0](Linear).w        6 (2, 3)
# (Sequential)[2](Linear).b        4 (4,)
# (Sequential)[2](Linear).w       12 (3, 4)
# +Total(4)                       25

print(jit_net.vars())
# (Jit)(Sequential)[0](Linear).b        3 (3,)
# (Jit)(Sequential)[0](Linear).w        6 (2, 3)
# (Jit)(Sequential)[2](Linear).b        4 (4,)
# (Jit)(Sequential)[2](Linear).w       12 (3, 4)
# +Total(4)                            25

# We can verify that jit_func also shares the same variables
print(jit_func.vars())
# (Jit){my_function}(Sequential)[0](Linear).b        3 (3,)
# (Jit){my_function}(Sequential)[0](Linear).w        6 (2, 3)
# (Jit){my_function}(Sequential)[2](Linear).b        4 (4,)
# (Jit){my_function}(Sequential)[2](Linear).w       12 (3, 4)
# +Total(4)                                         25

Note that we only verified that the variable names and sizes were the same (or almost the same, since the variable names in Jit are prefixed with (Jit)). Let’s now verify that the weights are indeed shared by modifying them:

net[-1].b.assign(net[-1].b.value + 1)
print(((net(x) - jit_net(x)) ** 2).sum())  # 0.0
# Both net(x) and jit_net(x) were affected in the same way by the change
# since the weights are shared.
# You can also inspect the values print(net(x)) for more insight.
A realistic case: Fully jitted training step

Let’s write a classifier training op; this is very similar to the example shown in Loss optimization. We are going to define a model, an optimizer, a loss and a gradient:

import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu, objax.nn.Linear(3, 4)])
opt = objax.optimizer.Momentum(m.vars())

@objax.Function.with_vars(m.vars())
def loss(x, labels):
    logits = m(x)
    return objax.functional.loss.cross_entropy_logits(logits, labels).mean()

gradient_loss = objax.GradValues(loss, m.vars())

@objax.Function.with_vars(m.vars() + opt.vars())
def train(x, labels, lr):
    g, v = gradient_loss(x, labels)  # Compute gradients and loss
    opt(lr, g)                       # Apply SGD
    return v                         # Return loss value

# It's better to jit the top level call to allow internal optimizations.
train_jit = objax.Jit(train)

Note that we declared all the vars used in train, namely m.vars() + opt.vars(), when decorating it with objax.Function.with_vars. Why didn’t we also need to include gradient_loss.vars()? We could, and it’s perfectly fine to do so, but keep in mind that gradient_loss is itself a module that shares the weights of m, so its variables are already included in m.vars().
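If in doubt, a quick check (a sketch, not part of the original example) confirms that the variables of gradient_loss are the same objects as the variables of m, merely registered under different names:

# Sketch: compare the underlying variable objects by identity.
m_ids = {id(v) for v in m.vars().values()}
g_ids = {id(v) for v in gradient_loss.vars().values()}
print(m_ids <= g_ids)  # True: every variable of m also appears in gradient_loss.vars()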

Static arguments

Static arguments are arguments that are treated as static (compile-time constants) in the jitted function. Boolean arguments, numerical arguments used in comparisons (resulting in a bool), and strings must be marked as static.

Calling the jitted function with different values for these constants will trigger recompilation. As a rule of thumb:

  • Good static arguments: training (boolean), my_mode (int that can take only a few values), …

  • Bad static arguments: training_step (int that can take a lot of values); see the sketch below.
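To make this rule of thumb concrete, here is a minimal sketch (the function and names are illustrative, not part of the example that follows): marking a rapidly changing counter as static forces a new compilation for every new value, while passing it as a regular argument compiles only once.

import objax

net = objax.nn.Linear(2, 3)

@objax.Function.with_vars(net.vars())
def f(x, step):
    # Pretend the step counter participates in the computation (e.g. a schedule).
    return net(x) * (1.0 / (1.0 + step))

# Marking step as static: every new step value triggers a fresh compilation.
f_static_step = objax.Jit(f, static_argnums=(1,))
# Passing step as a regular (traced) argument: compiled once, reused for all step values.
f_traced_step = objax.Jit(f)

x = objax.random.normal((4, 2))
print(((f_static_step(x, 3) - f_traced_step(x, 3)) ** 2).sum())  # 0.0 - same result either way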

Let’s look at an example with BatchNorm which takes a training argument:

import objax

net = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.nn.BatchNorm0D(3)])

@objax.Function.with_vars(net.vars())
def f(x, training):
    return net(x, training=training)

jit_f_static = objax.Jit(f, static_argnums=(1,))
# Note the static_argnums=(1,) which indicates that argument 1 (training) is static.

x = objax.random.normal((100, 2))
print(((net(x, training=False) - jit_f_static(x, False)) ** 2).sum())  # 0.0

What happens if you don’t use static_argnums?

jit_f = objax.Jit(f)
y = jit_f(x, False)
# Traceback (most recent call last):
#   File <...>
#   <long stack trace>
# jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected (in `bool`).
# Use transformation parameters such as `static_argnums` for `jit` to avoid tracing input values.
# See `https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error`.
# Encountered value: Traced<ShapedArray(bool[], weak_type=True):JaxprTrace(level=-1/1)>

To cut a long story short: when compiling, boolean inputs must be made static.

For more information, please refer to jax.jit which is the API Objax uses under the hood.

Constant optimization

As seen previously, objax.Jit takes a variables argument to specify the variables of a function or of a module that Jit is compiling.

If a variable is not passed to Jit it will be treated as a constant and will be optimized away.

Warning

A jitted module will not see any change made to a constant. A constant is not expected to change since it is supposed to be… constant!

A simple constant optimization example:

import objax

m = objax.nn.Linear(3, 4)
# Pass an empty VarCollection to signify to Jit that m has no variable.
jit_constant = objax.Jit(m, objax.VarCollection())

x = objax.random.normal((10, 3))
print(((m(x) - jit_constant(x)) ** 2).sum())  # 0.0

# Modify m (which was supposed to be constant!)
m.b.assign(m.b.value + 1)
print(((m(x) - jit_constant(x)) ** 2).sum())  # 40.0
# As expected jit_constant didn't see the change.

Warning

The XLA backend (the interface to the hardware) performs the constant optimization, which may take a long time and a lot of memory during compilation, often with very little gain in final performance, if any.

Parallelism

Note

If you don’t have multiple devices, you can simulate them on CPU by starting python with the following command:

CUDA_VISIBLE_DEVICES= XLA_FLAGS=--xla_force_host_platform_device_count=8 python

Alternatively you can do it in Python directly by inserting this snippet before importing Objax:

import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

objax.Parallel provides a way to distribute computations across multiple GPUs (or TPUs). It also performs JIT under the hood, and its API shares a lot with objax.Jit: it takes a function to be compiled, a VarCollection and a static_argnums parameter, which all behave the same as in Jit. However, it also takes arguments specific to the task of handling parallelism, which we are going to introduce.

When running a parallelized function f on a batch \(x\) of shape \((n, ...)\) on \(d\) devices, the following steps happen (a conceptual sketch follows the list):

  1. The batch \(x\) is divided into \(d\) sub-batches \(x_i\) of shape \((n/d, ...)\) for \(i\in\{0, ..., d-1\}\)

  2. Each sub-batch \(x_i\) is passed to f and run on device \(i\) in parallel.

  3. The results are collected as output sub-batches \(y_i=f(x_i)\)

  4. The outputs \(y_i\) are represented as a single tensor \(y\) of shape \((d, ...)\)

  5. The final output is obtained by calling the reduce function on \(y\): out = reduce(y).
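The following minimal sketch (plain numpy pseudo-code, not Objax’s actual implementation) summarizes steps 1-5:

import numpy as np

def conceptual_parallel(f, x, d, reduce):
    """Conceptual model of a parallel call: split, run, stack, reduce."""
    sub_batches = np.split(x, d)          # step 1: d sub-batches of shape (n/d, ...)
    ys = [f(xi) for xi in sub_batches]    # steps 2-3: run f on each sub-batch (on its own device in practice)
    y = np.stack(ys)                      # step 4: a single tensor of shape (d, ...)
    return reduce(y)                      # step 5: aggregate the outputs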

With this in mind, we can now detail the additional arguments of objax.Parallel:

  • reduce: a function that aggregates the output results from each GPU/TPU.

  • axis_name: the name of the device dimension, which we referred to as \(d\) earlier. By default, it is called 'device'.

Let’s illustrate this with a simple example that parallelizes a module (para_net) and a function (para_func):

# This code was run on 8 simulated devices
import objax

net = objax.nn.Sequential([objax.nn.Linear(3, 4), objax.functional.relu])
para_net = objax.Parallel(net)
para_func = objax.Parallel(objax.Function(lambda x: net(x) + 1, net.vars()))

# A batch of mockup data
x = objax.random.normal((96, 3))

# We're running on multiple devices, copy the model variables to all of them first.
with net.vars().replicate():
    y = para_net(x)
    z = para_func(x)

print(((net(x) - y) ** 2).sum())        # 8.90954e-14
print(((net(x) - (z - 1)) ** 2).sum())  # 4.6487814e-13

We can also show the parallel version of A realistic case: Fully jitted training step; the changes from the jitted version are the pmean calls inside train and the reduce argument passed to objax.Parallel:

import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu, objax.nn.Linear(3, 4)])
opt = objax.optimizer.Momentum(m.vars())

@objax.Function.with_vars(m.vars())
def loss(x, labels):
    logits = m(x)
    return objax.functional.loss.cross_entropy_logits(logits, labels).mean()

gradient_loss = objax.GradValues(loss, m.vars())

@objax.Function.with_vars(m.vars() + opt.vars())
def train(x, labels, lr):
    g, v = gradient_loss(x, labels)                     # Compute gradients and loss
    opt(lr, objax.functional.parallel.pmean(g))        # Apply averaged gradients
    return objax.functional.parallel.pmean(v)          # Return averaged loss value

# It's better to parallelize the top level call to allow internal optimizations.
train_para = objax.Parallel(train, reduce=lambda x:x[0])

Let’s study the concepts introduced in this example in more detail.

Variable replication

Variable replication copies the variables to each device’s own memory. Variables must be replicated before calling a parallelized module or function. Replication is done through objax.VarCollection.replicate(), which is a context manager. One could go further and create their own replication; this is not covered here, but the source of replicate is rather simple and a good starting point.

Here is a detailed example:

# This code was run on 8 simulated devices
import objax
import jax.numpy as jn

m = objax.ModuleList([objax.TrainVar(jn.arange(5))])
# We use "repr" to see the whole type information.
print(repr(m[0].value))  # DeviceArray([0, 1, 2, 3, 4], dtype=int32)

with m.vars().replicate():
    # In the scope of the with-statement, the variables are replicated to all devices.
    print(repr(m[0].value))
    # ShardedDeviceArray([[0, 1, 2, 3, 4],
    #                     [0, 1, 2, 3, 4],
    #                     [0, 1, 2, 3, 4],
    #                     [0, 1, 2, 3, 4],
    #                     [0, 1, 2, 3, 4],
    #                     [0, 1, 2, 3, 4],
    #                     [0, 1, 2, 3, 4],
    #                     [0, 1, 2, 3, 4]], dtype=int32)
    # ShardedDeviceArray is a DeviceArray spread across multiple devices.

# When exiting the with-statement, the variables are reduced back to their original device.
print(repr(m[0].value))  # DeviceArray([0., 1., 2., 3., 4.], dtype=float32)

Something interesting happened: the value of m[0] was initially of integer type but it became a float by the end. This is due to the reduction that follows a replication. By default, the reduction method takes the average of the copies on each device, and the average of multiple integer values is automatically cast to a float.

You can customize the variable reduction; this is not something one typically needs to do, but it’s available for advanced users nonetheless:

# This code was run on 8 simulated devices
import objax
import jax.numpy as jn

m = objax.ModuleList([objax.TrainVar(jn.arange(5), reduce=lambda x: x[0]),
                      objax.TrainVar(jn.arange(5), reduce=lambda x: x.sum(0)),
                      objax.TrainVar(jn.arange(5), reduce=lambda x: jn.stack(x))])
print(repr(m[0].value))  # DeviceArray([0, 1, 2, 3, 4], dtype=int32)
print(repr(m[1].value))  # DeviceArray([0, 1, 2, 3, 4], dtype=int32)
print(repr(m[2].value))  # DeviceArray([0, 1, 2, 3, 4], dtype=int32)

with m.vars().replicate():
    pass

# When exiting the with-statement, the variables are reduced back to their original device.
print(repr(m[0].value))  # DeviceArray([0, 1, 2, 3, 4], dtype=int32)
print(repr(m[1].value))  # DeviceArray([ 0,  8, 16, 24, 32], dtype=int32)
print(repr(m[2].value))  # DeviceArray([[0, 1, 2, 3, 4],
                         #              [0, 1, 2, 3, 4],
                         #              [0, 1, 2, 3, 4],
                         #              [0, 1, 2, 3, 4],
                         #              [0, 1, 2, 3, 4],
                         #              [0, 1, 2, 3, 4],
                         #              [0, 1, 2, 3, 4],
                         #              [0, 1, 2, 3, 4]], dtype=int32)
Output aggregation

Similarly, the output \(y\) of a parallel call is reduced using the reduce argument. The first dimension \(d\) of \(y\) is the device dimension; its name comes from the axis_name argument and is simply "device" by default.

Again, let’s look at a simple example:

# This code was run on 8 simulated devices
import objax
import jax.numpy as jn

net = objax.nn.Sequential([objax.nn.Linear(3, 4), objax.functional.relu])
para_none = objax.Parallel(net, reduce=lambda x: x)
para_first = objax.Parallel(net, reduce=lambda x: x[0])
para_concat = objax.Parallel(net, reduce=lambda x: jn.concatenate(x))
para_average = objax.Parallel(net, reduce=lambda x: x.mean(0))

# A batch of mockup data
x = objax.random.normal((96, 3))

# We're running on multiple devices, copy the model variables to all of them first.
with net.vars().replicate():
    print(para_none(x).shape)     # (8, 12, 4)
    print(para_first(x).shape)    # (12, 4)
    print(para_concat(x).shape)   # (96, 4)  - This is the default setting
    print(para_average(x).shape)  # (12, 4)

In the example above, the batch x (of size 96) was divided into 8 batches of size 12 by the parallel call. Each of these batches was processed on its own device. The final value was then reduced using the provided reduce method.

  • para_none didn’t do any reduction, it just returned the value it was given; as expected, its shape is (devices, batch // devices, ...).

  • para_first and para_average took either the first entry or the average over dimension 0, resulting in a shape of (batch // devices, ...).

  • para_concat concatenated all the values resulting in a shape of (batch, ...).

Synchronized computations

So far, we only considered the case where all the devices were acting on their own, unaware of others’ existence. It’s commonly desirable for devices to communicate with each other.

For example, when training a model, for efficiency one would want the optimizer to update the weights on all the devices at the same time. To achieve this, we would like the gradients to be computed for each sub-batch on the device, and then averaged across all devices.

The good news is that it is very easy to do: objax.functional.parallel provides a set of predefined primitives (such as pmean, which averages a value across devices) that are the direct equivalents of single-device primitives.

Recalling the code for the parallelized train operation:

@objax.Function.with_vars(m.vars() + opt.vars())
def train(x, labels, lr):
    g, v = gradient_loss(x, labels)                     # Compute gradients and loss
    opt(lr, objax.functional.parallel.pmean(g))         # Apply averaged gradients
    return objax.functional.parallel.pmean(v)           # Return averaged loss value

The train function is called on each device in parallel. The call objax.functional.parallel.pmean(g) averages the gradients g across all devices. Then, on each device, the optimizer applies the averaged gradient to the local weight copy. Finally, the average loss objax.functional.parallel.pmean(v) is returned.

Vectorization

objax.Vectorize is the module responsible for code vectorization. Vectorization can be seen as parallelization without knowledge of the available devices. On a single GPU, vectorization parallelizes the execution in concurrent threads. It can be combined with objax.Parallel, resulting in multi-GPU multi-threading! Vectorization can also be done on a single CPU; a typical example of CPU vectorization could be data pre-processing or augmentation.

In its simplest form, vectorization applies a function to the elements of a batch concurrently. objax.Vectorize takes a module or a function f and vectorizes it. Similarly to Jit and Parallel, you must specify the variables used by the function. Finally, batch_axis specifies which axis should be considered the batch axis for each input argument of f. For arguments with no batch axis, for example a value shared by all the calls to f, set the batch axis to None to broadcast it.

Let’s clarify this with a simple example:

# Randomly reverse rows in a batch.
import objax
import jax.numpy as jn

class RandomReverse(objax.Module):
    """Randomly reverse a single vector x and add a value y to it."""

    def __init__(self, keygen=objax.random.DEFAULT_GENERATOR):
        self.keygen = keygen

    def __call__(self, x, y):
        r = objax.random.randint([], 0, 2, generator=self.keygen)
        return x + y + r * (x[::-1] - x), r, y

random_reverse = RandomReverse()
vector_reverse = objax.Vectorize(random_reverse, batch_axis=(0, None))
# vector_reverse takes two arguments (just like random_reverse), we're going to pass:
# - a matrix x for the first argument, interpreted as a batch of vectors (batch_axis=0).
# - a value y for the second argument, interpreted as a broadcasted value (batch_axis=None).

# Test it on some mock up data
x = jn.arange(20).reshape((5, 4))
print(x)  # [[ 0  1  2  3]
          #  [ 4  5  6  7]
          #  [ 8  9 10 11]
          #  [12 13 14 15]
          #  [16 17 18 19]]

objax.random.DEFAULT_GENERATOR.seed(1337)
z, r, y = vector_reverse(x, 1)
print(r)  # [0 1 0 1 1] - whether a row was reversed
print(y)  # [1 1 1 1 1] - the broadcasted input y
print(z)  # [[ 1  2  3  4]
          #  [ 8  7  6  5]
          #  [ 9 10 11 12]
          #  [16 15 14 13]
          #  [20 19 18 17]]

# Above we added a single constant (y=1)
# We can also add a vector y=(-2, -1, 0, 1)
objax.random.DEFAULT_GENERATOR.seed(1337)
z, r, y = vector_reverse(x, jn.array((-2, -1, 0, 1)))
print(r)  # [0 1 0 1 1] - whether a row was reversed
print(y)  # [[-2 -1  0  1] - the broadcasted input y
          #  [-2 -1  0  1]
          #  [-2 -1  0  1]
          #  [-2 -1  0  1]
          #  [-2 -1  0  1]]
print(z)  # [[-2  0  2  4]
          #  [ 5  5  5  5]
          #  [ 6  8 10 12]
          #  [13 13 13 13]
          #  [17 17 17 17]]
Computing weight gradients per batch entry

This is a more advanced example; conceptually it is similar to what powers differentially private gradients:

import objax

m = objax.nn.Linear(3, 4)

@objax.Function.with_vars(m.vars())
def loss(x, y):
    return ((m(x) - y) ** 2).mean()

g = objax.Grad(loss, m.vars())
single_gradients = objax.Vectorize(g, batch_axis=(0, 0))  # Batch is dimension of x and y

# Mock some data
x = objax.random.normal((10, 3))
y = objax.random.normal((10, 4))

# Compute standard gradients
print([v.shape for v in g(x, y)])              # [(4,), (3, 4)]

# Compute per batch entry gradients
print([v.shape for v in single_gradients(x, y)])   # [(10, 4), (10, 3, 4)]

As expected, we obtained as many gradients for each of the network’s weights as there are entries in the batch.

Loading and Saving

Being able to load and save the weights of a model, or a model itself (i.e. the weights together with the function), is essential for machine learning purposes. In this section we describe how to load/save the weights and also how to save an entire model. Furthermore, we discuss how to keep multiple saves, a concept known as checkpointing, which is typically used for resuming interrupted training sessions.

Saving and loading model weights

Loading and saving is done on objax.VarCollection objects. Such objects are returned by the objax.Module.vars() method, or can be constructed manually if one wishes to. The saving method uses the numpy .npz format, which in essence stores tensors in a zip file.

Here’s a simple example:

import objax

# Let's pretend we have a neural network net and we want to save it.
net = objax.nn.Sequential([objax.nn.Linear(768, 1), objax.functional.sigmoid])

# Saving only takes one line.
objax.io.save_var_collection('net.npz', net.vars())

# Let's modify the bias of the Linear layer
net[0].b.assign(net[0].b.value + 1)
print(net[0].b.value.sum())         # 1.0

# Loading
objax.io.load_var_collection('net.npz', net.vars())
print(net[0].b.value.sum())         # 0.0

Note that in the example above we used a filename to specify where to save the weights. These APIs also accept a file descriptor, so another way to save would be:

# Saving with file descriptor
with open('net.npz', 'wb') as f:
    objax.io.save_var_collection(f, net.vars())

# Loading with file descriptor
with open('net.npz', 'rb') as f:
    objax.io.load_var_collection(f, net.vars())

Note

The advantage of using a filename instead of a file handle is that the data is written to a temporary file first, and the temporary file is renamed to the provided filename only after all data has been written. This prevents truncated files in the event of the program being killed. When using a file descriptor, the code does not have this protection. File descriptors are typically used for unit testing.
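For illustration, here is a minimal sketch of this write-then-rename pattern (a simplification, not Objax’s actual implementation; the name atomic_savez is hypothetical):

import os
import tempfile

import numpy as np

def atomic_savez(filename, **arrays):
    """Write arrays to a temporary file, then atomically rename it to filename."""
    dirname = os.path.dirname(filename) or '.'
    fd, tmp_path = tempfile.mkstemp(dir=dirname, suffix='.tmp')
    try:
        with os.fdopen(fd, 'wb') as f:
            np.savez(f, **arrays)
        # The final file only appears once all the data has been written.
        os.replace(tmp_path, filename)
    except BaseException:
        os.remove(tmp_path)
        raise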

Custom saving and loading

You can easily write your own saving and loading functions. In essence, saving has to store (name, numpy array) pairs, and loading must provide a numpy array for each variable of the objax.VarCollection. The only gotcha is to avoid saving duplicated information, such as shared weights appearing under different names or variable references (TrainRef). Since the code for loading and saving is very concise, simply looking at it is the best example.
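As an illustration only, here is a rough sketch of what such custom functions could look like (my_save and my_load are hypothetical names, and this is not Objax’s actual implementation):

import numpy as np
import jax.numpy as jn
import objax

def my_save(filename, vc):
    """Store unique (name, numpy array) pairs, skipping references and shared duplicates."""
    seen, names, values = set(), [], []
    for name, var in vc.items():
        if isinstance(var, objax.TrainRef) or id(var) in seen:
            continue  # skip variable references and weights already stored under another name
        seen.add(id(var))
        names.append(name)
        values.append(np.asarray(var.value))
    np.savez(filename, *values, names=np.array(names))

def my_load(filename, vc):
    """Assign the saved arrays back to the matching variables of the VarCollection."""
    data = np.load(filename)
    arrays = {name: data['arr_%d' % i] for i, name in enumerate(data['names'])}
    for name, var in vc.items():
        if name in arrays and not isinstance(var, objax.TrainRef):
            var.assign(jn.asarray(arrays[name]))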

Checkpointing

Checkpointing can be defined as saving neural network weights during training. Often checkpointing keeps multiple saves, each from a different training step. For space reasons, it’s common to keep only the latest k saves. Checkpointing can be used for a variety of purposes:

  • Resuming training after the program was interrupted.

  • Keeping multiple copies of the network for weight averaging strategies.

Objax provides a simple checkpointing interface called objax.io.Checkpoint; here’s an example:

import objax

# Let's pretend we have a neural network net and we want to save it.
net = objax.nn.Sequential([objax.nn.Linear(768, 1), objax.functional.sigmoid])

# This time we use the Checkpoint class
ckpt = objax.io.Checkpoint(logdir='save_folder', keep_ckpts=5)

# Saving
ckpt.save(net.vars(), idx=1)
net[0].b.assign(net[0].b.value + 1)
ckpt.save(net.vars(), idx=2)

# Restoring
ckpt.restore(net.vars(), idx=1)   # net[0].b.value = (0,)
ckpt.restore(net.vars(), idx=2)   # net[0].b.value = (1,)

# When no idx is specified, the latest checkpoint is used (e.g. 2 here)
idx, file = ckpt.restore(net.vars())
print(idx, file)  # 2 save_folder/ckpt/0000000002.npz
Customized checkpointing

The objax.io.Checkpoint class has some constants that allow you to customize its behavior. You can redefine them, for example, by creating a child class that inherits from Checkpoint. The fields are the following:

class Checkpoint:
    DIR_NAME: str = 'ckpt'
    FILE_MATCH: str = '*.npz'
    FILE_FORMAT: str = '%010d.npz'
    LOAD_FN: Callable[[FileOrStr, VarCollection], None] = staticmethod(load_var_collection)
    SAVE_FN: Callable[[FileOrStr, VarCollection], None] = staticmethod(save_var_collection)

This lets you change the folder name where the checkpoints are saved, the file extension and the numbering format. If you have your own saving and loading functions, you can also replace them. Remember to wrap them in staticmethod since they don’t depend on the Checkpoint class itself.
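For example, a child class could look like the following sketch (the class name, folder name and file pattern are illustrative choices, not Objax defaults):

import objax
from objax.io import load_var_collection, save_var_collection

class MyCheckpoint(objax.io.Checkpoint):
    DIR_NAME = 'my_checkpoints'        # checkpoints go to <logdir>/my_checkpoints
    FILE_MATCH = '*.npz'
    FILE_FORMAT = 'step_%06d.npz'      # e.g. step_000002.npz
    # Custom I/O functions (here simply the defaults) must be wrapped in staticmethod.
    LOAD_FN = staticmethod(load_var_collection)
    SAVE_FN = staticmethod(save_var_collection)

net = objax.nn.Sequential([objax.nn.Linear(768, 1), objax.functional.sigmoid])
ckpt = MyCheckpoint(logdir='save_folder', keep_ckpts=3)
ckpt.save(net.vars(), idx=2)   # expected to write save_folder/my_checkpoints/step_000002.npz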

Saving a module

Warning

Python pickle is not *secure*. Only use it for your own saves and loads. Any pickle coming from an external source is a potential risk.

Now that we have warned you, let’s mention that Objax modules can be pickled with Python’s pickle module like many other Python objects. This can be quite convenient since you can save not only the module’s weights, but the module itself.

Let’s look at a simple example:

import pickle
import objax

# Let's pretend we have a neural network net and we want to save it as whole.
net = objax.nn.Sequential([objax.nn.Linear(768, 1), objax.functional.sigmoid])

# Pickling
pickle.dump(net, open('net.pickle', 'wb'))

# Unpickling and storing into a new network
net2 = pickle.load(open('net.pickle', 'rb'))

# Confirm the network net2 has the same function as net
x = objax.random.normal((100, 768))
print(((net(x) - net2(x)) ** 2).mean())  # 0.0

# Confirm the network net2 does not share net's weights
net[0].b.assign(net[0].b.value + 1)
print(((net(x) - net2(x)) ** 2).mean())  # 0.038710583

As the example shows, pickling is really easy to use. Be aware that Python pickling has some limitations: notably, lambda functions cannot be pickled by the standard pickle module (functions have to be named). Objax is not limited to pickle; since its design is pythonic, it should be compatible with other Python pickling systems.
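Here is a minimal sketch of the lambda limitation (illustrative only):

import pickle
import objax

# A module containing a lambda cannot be pickled...
net_lambda = objax.nn.Sequential([objax.nn.Linear(4, 4), lambda x: x * 2])
try:
    pickle.dumps(net_lambda)
except Exception as e:
    print(type(e).__name__)  # typically a PicklingError: the lambda has no importable name

# ...whereas the same model built with a named, module-level function pickles fine.
def double(x):
    return x * 2

net_named = objax.nn.Sequential([objax.nn.Linear(4, 4), double])
data = pickle.dumps(net_named)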

Development setup

This section describes some basic setup to start developing and extending Objax.

Environment setup

First of all, you need to install all the necessary dependencies. We recommend setting up a separate virtualenv to work on Objax; this can be done with the following commands on Ubuntu or a similar Linux distribution:

# Install virtualenv if you haven't done so already
sudo apt install python3-dev python3-virtualenv python3-tk imagemagick virtualenv pandoc
# Create a virtual environment (for example ~/.venv/objax, you can use your name here)
virtualenv -p python3 --system-site-packages ~/.venv/objax
# Start the virtual environment
. ~/.venv/objax/bin/activate

# Clone objax git repository, if you haven't.
git clone https://github.com/google/objax.git
cd objax

# Install python dependencies.
pip install --upgrade -r requirements.txt
pip install --upgrade -r tests/requirements.txt
pip install --upgrade -r docs/requirements.txt
pip install --upgrade -r examples/requirements.txt
pip install flake8

# If you have CUDA installed, specify your installed CUDA version.
CUDA_VERSION=11.0
pip install -f https://storage.googleapis.com/jax-releases/jax_releases.html jaxlib==`python3 -c 'import jaxlib; print(jaxlib.__version__)'`+cuda`echo $CUDA_VERSION | sed s:\\\.::g`
Running tests and linter

Run linter:

./tests/run_linter.sh

Run tests:

./tests/run_tests.sh

Running a single test:

CUDA_VISIBLE_DEVICES= python3 -m unittest tests/jit.py

Adding or changing Objax modules

This guide explains how to add a new module to Objax or change an existing one.

In addition to this guide, consider looking at an example pull request which adds a new module with documentation.

Writing code

When adding a new module or function, you have to decide where to put it. Typical locations for new modules and functions are the following:

  • objax/functional contains various stateless functions (non-modules) which are used in machine learning. For example: loss functions, activations, stateless ops.

  • objax/io contains routines for model saving and loading.

  • objax/nn contains layers, which serve as building blocks for neural networks. It also contains initializers for layer parameters.

  • objax/optimizer contains optimizers.

  • objax/privacy contains code for privacy-preserving training of neural networks.

  • objax/random contains routines for random number generation.

  • objax/zoo is a “model zoo” of various well-known neural network architectures.

When writing code we follow the PEP8 style guide, with the following two exceptions:

  • We set the maximum line length to 120 characters.

  • We allow assigning lambdas, e.g. f = lambda x: x

  • In addition, we recommend keeping APIs ordered alphabetically within source code files when feasible.

Note: Remember to add your new APIs to the __all__ variable (also ordered alphabetically) at the top of the file to make them visible.

The script ./tests/run_linter.sh automatically checks for the majority of code style violations. The PyCharm code formatter can be used to automatically reformat code.

Writing unit tests

Unit tests are required for most code changes and new code in the Objax library. However, tests are not required for examples.

All unit tests are placed in the tests directory. They are grouped into different files based on what they are testing. For example, tests/conv.py contains the unit tests for convolution modules.

We use the Python unittest module for tests.
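Here is a minimal sketch of what a test file could look like (the file name and test contents are illustrative, not an existing test):

# Hypothetical file: tests/my_linear.py
import unittest

import objax

class TestLinear(unittest.TestCase):
    def test_output_shape(self):
        """A Linear(3, 4) layer maps a (10, 3) batch to a (10, 4) output."""
        layer = objax.nn.Linear(3, 4)
        x = objax.random.normal((10, 3))
        self.assertEqual(layer(x).shape, (10, 4))

if __name__ == '__main__':
    unittest.main()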

Writing documentation

Documentation for specific APIs is written inside docstrings within the code (example for Conv2D).

Other documentation is stored in the docs subdirectory of the Objax repository. It uses reST as a markup language, and Sphinx automatically generates the documentation for objax.readthedocs.io.

Docstrings

All public-facing classes, functions and class methods should have a short docstring describing what they do. Functions and methods should also have a description of their arguments and return value.

To keep the code easy to read, we recommend writing short and concise docstrings:

  • Try to keep the description of classes and functions within 1-5 lines.

  • Try to fit the description of each function argument within 1-2 lines.

  • Avoid putting examples or long descriptions into docstrings; those should go into the reST docs.

Here is an example of how to write a docstring for a function:

def cross_entropy_logits(logits: JaxArray, labels: JaxArray) -> JaxArray:
    """Computes the cross-entropy loss.

    Args:
        logits: (batch, #class) tensor of logits.
        labels: (batch, #class) tensor of label probabilities (e.g. labels.sum(axis=1) must be 1)

    Returns:
        (batch,) tensor of the cross-entropies for each entry.
    """
    return logsumexp(logits, axis=1) - (logits * labels).sum(1)

If you are only updating existing docstrings, these changes will be automatically reflected on objax.readthedocs.io after the pull request is merged into the repository. When adding docstrings for new classes and functions, you may also need to update the reST files, as described below.

reST documentation

Updates to the reST files are required either when new APIs are added (a new function or module) or when other (non-API) documentation is needed.

Most of the API documentation is located in docs/source/objax. It is grouped into different .rst files by the name of the Python package; for example, docs/source/objax/nn.rst contains the documentation for the objax.nn package.

To add a new class or function, you typically need to add its name to the autosummary section and add an autoclass or autofunction section for it. Here is an example of the changes needed to add documentation for the Conv2D module:

.. autosummary::

  ...
  Conv2D
  ...

...

.. autoclass:: Conv2D
    :members:

    Additional documentation (non-docstrings) for Conv2D goes here.

For reference about reST syntax, refer to the reST documentation or cheat sheet.

Indices and tables