
Objax Basics Tutorial

This tutorial introduces basic Objax concepts.


Objax is a machine learning library written in Python which works on top of JAX. Readers should therefore have some familiarity with the following:

  • Python. If you’re new to Python or need a refresher, check the Python Tutorial or the Python Beginner’s Guide.

  • NumPy is a library for mathematical computations. Many JAX primitives are based on NumPy and have the same syntax. NumPy is also useful for data manipulation outside of JAX/Objax. The NumPy quickstart covers most of the needed topics. More information can be found on the NumPy documentation site.

  • JAX can be described as NumPy with gradients which runs on accelerators (GPU and TPU). The JAX quickstart covers most of the concepts needed to understand Objax.

Installation and imports

Let’s first install Objax:

%pip --quiet install objax

After Objax is installed, you can import all necessary modules:

import jax.numpy as jn
import numpy as np
import objax

Tensors

Tensors are essentially multi-dimensional arrays. In JAX and Objax, tensors can be placed on GPUs or TPUs to accelerate computations.

Objax relies on the jax.numpy.ndarray primitive from JAX to represent tensors. This primitive has an API very similar to that of the NumPy ndarray.
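
Because jax.numpy mirrors NumPy, the same expression can usually be written with either library just by swapping the module prefix. A minimal sketch (the matrix values are arbitrary):

```python
import numpy as np
import jax.numpy as jn

# The same 2x2 matrix created with NumPy and with jax.numpy.
a_np = np.array([[1.0, 2.0], [3.0, 4.0]])
a_jn = jn.array([[1.0, 2.0], [3.0, 4.0]])

# Matrix product with the transpose, then a column-wise sum:
# the function and method names are identical in both libraries.
res_np = np.matmul(a_np, a_np.T).sum(axis=0)
res_jn = jn.matmul(a_jn, a_jn.T).sum(axis=0)

print('NumPy:', res_np)
print('JAX:  ', res_jn)
```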

Creating tensors

Tensor creation is very similar to NumPy and can be done in multiple ways:

  1. Provide explicit values to the tensor:

# Providing explicit values
jn.array([[1.0, 2.0, 3.0],
          [4.0, 5.0, 6.0]])
DeviceArray([[1., 2., 3.],
             [4., 5., 6.]], dtype=float32)
  2. From a NumPy array:

arr = np.array([1.0, 2.0, 3.0])
jn.array(arr)
DeviceArray([1., 2., 3.], dtype=float32)
  3. From another JAX tensor:

another_tensor = jn.array([[1.0, 2.0, 3.0],
                           [4.0, 5.0, 6.0]])
jn.array(another_tensor)
DeviceArray([[1., 2., 3.],
             [4., 5., 6.]], dtype=float32)
  4. Using ones or zeros:

jn.ones((3, 4))
DeviceArray([[1., 1., 1., 1.],
             [1., 1., 1., 1.],
             [1., 1., 1., 1.]], dtype=float32)
jn.zeros((4, 5))
DeviceArray([[0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0.]], dtype=float32)
  5. As a result of a mathematical operation performed on other tensors:

t1 = jn.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])
t2 = jn.ones(t1.shape) * 3
t1 + t2
DeviceArray([[4., 5., 6.],
             [7., 8., 9.]], dtype=float32)
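
Note that the NumPy example above prints dtype=float32 even though NumPy creates float64 arrays by default. JAX uses 32-bit types unless 64-bit mode is enabled (the jax_enable_x64 configuration flag), so values are converted when the tensor is created. A short check, assuming the default configuration:

```python
import numpy as np
import jax.numpy as jn

arr = np.array([1.0, 2.0, 3.0])
print('NumPy dtype:', arr.dtype)   # NumPy defaults to float64

t = jn.array(arr)
print('JAX dtype:  ', t.dtype)     # float32 unless 64-bit mode is enabled

# A specific dtype can also be requested when creating a tensor.
t16 = jn.zeros((2, 2), dtype=jn.float16)
print('Requested:  ', t16.dtype)
```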

Tensor Properties

Similar to NumPy, one can explore various properties of tensors like shape, number of dimensions, or data type:

t = jn.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0]])
print('Number of dimensions: ', t.ndim)
print('Shape: ', t.shape)
print('Data type: ', t.dtype)
Number of dimensions:  2
Shape:  (2, 3)
Data type:  float32
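
These properties are related to each other in the same way as in NumPy: ndim is the length of the shape tuple, and size (the total number of elements) is the product of the shape entries. A small sketch on the same 2x3 tensor:

```python
import math
import jax.numpy as jn

t = jn.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0]])

# size is the total number of elements, i.e. the product of the shape entries.
print('Size: ', t.size)
print('Product of shape: ', math.prod(t.shape))

# ndim is the length of the shape tuple.
print('len(shape): ', len(t.shape))

# Memory used by the elements: element count times bytes per element.
print('Bytes: ', t.size * t.dtype.itemsize)
```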

Converting tensors to NumPy arrays

Objax/JAX tensors can be converted to NumPy arrays when needed to perform computations with NumPy:

np.array(t)
array([[1., 2., 3.],
       [4., 5., 6.]], dtype=float32)
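
Converting with np.array produces an ordinary NumPy array that holds a copy of the data, so, unlike the JAX tensor, it can be modified in place. A small sketch of the round trip (jn.asarray is used here to convert back to JAX):

```python
import numpy as np
import jax.numpy as jn

t = jn.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0]])

# np.array copies the data into a regular, writable NumPy array.
a = np.array(t)
a[0, 0] = -5.0            # allowed: NumPy arrays support item assignment

# The JAX tensor t is unaffected by changes made to the copy.
print('NumPy copy:', a[0, 0], ' JAX tensor:', t[0, 0])

# Converting back from NumPy to JAX.
t2 = jn.asarray(a)
print('Back in JAX:', t2[0, 0])
```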

Tensors are immutable

One big difference between JAX ndarray and NumPy ndarray is that JAX ndarray is immutable:

print('Original tensor t:\n', t)

try:
    t[0, 0] = -5.0  # This line will fail
except Exception as e:
    print(f'Exception {e}')

print('Tensor t after failed attempt to update:\n', t)
Original tensor t:
 [[1. 2. 3.]
 [4. 5. 6.]]
Exception '<class 'jax.interpreters.xla.DeviceArray'>' object does not support item assignment. JAX arrays are immutable; perhaps you want jax.ops.index_update or jax.ops.index_add instead?
Tensor t after failed attempt to update:
 [[1. 2. 3.]
 [4. 5. 6.]]

Instead of updating an existing tensor, a new tensor should be created with the updated elements. Updates of individual tensor elements are done using the .at property:

# .at[...] builds an updated copy; t itself is not modified
print('Original tensor t:\n', t)
new_t = t.at[0, 0].set(-5.0)
print('Tensor t after update stays the same:\n', t)
print('Tensor new_t has updated value:\n', new_t)
Original tensor t:
 [[1. 2. 3.]
 [4. 5. 6.]]
Tensor t after update stays the same:
 [[1. 2. 3.]
 [4. 5. 6.]]
Tensor new_t has updated value:
 [[-5.  2.  3.]
 [ 4.  5.  6.]]

More details about per-element updates of tensors can be found in the JAX documentation.
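
Besides set, the .at property offers other out-of-place update methods such as add and multiply, and the index may also be a slice. A small sketch on the same 2x3 tensor:

```python
import jax.numpy as jn

t = jn.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0]])

# Each .at[...] method returns a new tensor; t itself never changes.
added = t.at[0, 0].add(10.0)      # element (0, 0): 1 + 10
doubled = t.at[1].multiply(2.0)   # the whole second row is doubled
zeroed = t.at[:, 1].set(0.0)      # slices work too: second column set to 0

print(added)
print(doubled)
print(zeroed)
```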

In practice, most mathematical operations act on tensors as a whole, and manual per-element updates are rarely needed.
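
For example, a conditional change to many elements can be written as a single jn.where call, and broadcasting applies an operation across a whole axis without any explicit loop. A sketch with arbitrary threshold values:

```python
import jax.numpy as jn

t = jn.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0]])

# Conditional "update" of every element at once: values above 4 become 4.
capped = jn.where(t > 4.0, 4.0, t)

# Broadcasting: subtract the per-column mean from every row.
centered = t - t.mean(axis=0)

print(capped)
print(centered)
```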

Random numbers

It’s very easy to generate random tensors in Objax:

x = objax.random.normal((3, 4))
print(x)
[[-0.1441347   0.89507747 -0.46038115  0.10503326]
 [-0.7460886   0.89681065  0.38794124  0.11750659]
 [ 1.0659382   0.22656879 -2.548792    1.9700414 ]]

There are multiple primitives for doing so:

print('Random integers:', objax.random.randint((4,), low=0, high=10))
print('Random normal:', objax.random.normal((4,), mean=1.0, stddev=2.0))
print('Random truncated normal: ', objax.random.truncated_normal((4,), stddev=2.0))
print('Random uniform: ', objax.random.uniform((4,)))
Random integers: [1 3 4 7]
Random normal: [-3.4102912  -1.8277478  -0.02106905  0.65338284]
Random truncated normal:  [ 0.78602946  2.3004575  -0.22719319  0.22819921]
Random uniform:  [0.19456518 0.4642099  0.94732213 0.57298625]

Objax Variables and Modules

Objax Variables store values of tensors. Unlike tensors, variables are mutable, i.e. the value stored in a variable can change. Since tensors are immutable, a variable changes its value by replacing the stored tensor with a new one.
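
The idea of a mutable holder around an immutable tensor can be sketched in a few lines of plain Python. The Var class below is hypothetical and only illustrates the concept; it is not how Objax implements its variable classes:

```python
import jax.numpy as jn

class Var:
    """Hypothetical minimal variable: a mutable holder for an immutable tensor."""

    def __init__(self, value):
        self.value = value

    def assign(self, new_value):
        # The tensor itself cannot be modified, so the holder stores a new one.
        self.value = new_value

v = Var(jn.zeros(3))
old = v.value

# "Update" one element: build a new tensor, then store it in the variable.
v.assign(v.value.at[0].set(1.0))

print('Old tensor: ', old)      # still all zeros
print('Variable:   ', v.value)  # now holds the new tensor
```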

Variables are commonly used together with modules. The Module is a basic building block in Objax that stores variables and other modules. Most modules are also callable (i.e., they implement the __call__ method) and, when called, perform some computations on their variables and sub-modules.

Here is an example of a simple module with two variables. When called, it computes the dot product of the first variable, v1, with an input tensor; the second variable, v2, is not used in the computation:

class SimpleModule(objax.Module):

    def __init__(self, length):
        self.v1 = objax.TrainVar(objax.random.normal((length,)))
        self.v2 = objax.TrainVar(jn.ones((2,)))

    def __call__(self, x):
        return jn.dot(x, self.v1.value)

m = SimpleModule(3)

Modules keep track of all variables they own, including variables in sub-modules. The .vars() method lists all of the module’s variables. It returns an instance of VarCollection, which is a dictionary with several additional useful methods.

module_vars = m.vars()

print('type(module_vars): ', type(module_vars))
print('isinstance(module_vars, dict): ', isinstance(module_vars, dict))
print()

print('Variable names and shapes:')
print(module_vars)
print()

print('Variable names and values:')
for k, v in module_vars.items():
  print(f'{k}      {v.value}')
type(module_vars):  objax.variable.VarCollection
isinstance(module_vars, dict):  True

Variable names and shapes:
(SimpleModule).v1           3 (3,)
(SimpleModule).v2           2 (2,)
+Total(2)                   5

Variable names and values:
(SimpleModule).v1      [-1.1010289  -0.68184537 -0.95236546]
(SimpleModule).v2      [1. 1.]

If the __call__ method of a module takes tensors as input and outputs tensors, then it can act as a mathematical function. In the general case, __call__ can be a multivariate vector-valued function.

The SimpleModule described above takes a vector of size length as input and outputs a scalar:

x = jn.ones((3,))
y = m(x)
print('Input: ', x)
print('Output: ', y)
Input:  [1. 1. 1.]
Output:  -2.7352397

The way jn.dot works allows us to run the module on 2D tensors as well. In this case, SimpleModule treats the input as a batch of vectors, computes the dot product for each of them and returns a vector with the results:

x = jn.array([[1., 1., 1.],
              [1., 0., 0.],
              [0., 1., 0.],
              [0., 0., 1.]])
y = m(x)
print('Input:\n', x)
print('Output:\n', y)
Input:
 [[1. 1. 1.]
 [1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]]
Output:
 [-2.7352397  -1.1010289  -0.68184537 -0.95236546]

For comparison, here is the result of calling module m on each row of tensor x:

print('Sequentially calling module on each row of 2D tensor:')
for idx in range(x.shape[0]):
  row_value = x[idx]
  out_value = m(row_value)
  print(f'm( {row_value} ) = {out_value}')
Sequentially calling module on each row of 2D tensor:
m( [1. 1. 1.] ) = -2.7352397441864014
m( [1. 0. 0.] ) = -1.1010289192199707
m( [0. 1. 0.] ) = -0.6818453669548035
m( [0. 0. 1.] ) = -0.9523654580116272
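
The batching comes from jn.dot itself: with a 2D first argument and a 1D second argument it computes a matrix-vector product, which equals the 1D dot product applied to each row. The sketch below checks this in plain JAX, with a made-up vector v standing in for the module's v1, and also shows jax.vmap as an explicit way to map a per-row function over a batch:

```python
import jax
import jax.numpy as jn

v = jn.array([0.5, -1.0, 2.0])   # made-up stand-in for SimpleModule's v1
x = jn.array([[1., 1., 1.],
              [1., 0., 0.],
              [0., 1., 0.],
              [0., 0., 1.]])

# One call on the whole batch: a matrix-vector product.
batched = jn.dot(x, v)

# The same result built row by row with a Python loop.
row_by_row = jn.stack([jn.dot(row, v) for row in x])

# jax.vmap maps a per-row function over the leading axis of x.
vmapped = jax.vmap(lambda row: jn.dot(row, v))(x)

print(batched)
print(row_by_row)
print(vmapped)
```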

How to compute gradients

As shown above, modules can act as mathematical functions. Computing gradients of functions is essential in machine learning, and Objax provides a simple way to do this.

It’s important to keep in mind that gradients are usually defined for scalar-valued functions, while our modules can be vector-valued. In this case we need to define an additional function that converts the vector-valued output of the module into a scalar. Then we can compute the gradients of this scalar-valued function with respect to the variables it depends on.
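
Plain JAX enforces this requirement: jax.grad raises a TypeError when the function being differentiated returns a non-scalar, and reducing the output (for example with .sum()) fixes it. The sketch below uses jax.grad directly rather than Objax, with an all-ones x and made-up values for v:

```python
import jax
import jax.numpy as jn

x = jn.ones((4, 3))
v = jn.array([0.5, -1.0, 2.0])

def vector_fn(v):
    return jn.dot(x, v)          # output shape (4,): vector-valued

def scalar_fn(v):
    return jn.dot(x, v).sum()    # reduced to a single number

# jax.grad refuses non-scalar outputs.
try:
    jax.grad(vector_fn)(v)
    grad_failed = False
except TypeError as e:
    grad_failed = True
    print('Vector-valued output:', e)

# After reducing to a scalar, the gradient has the same shape as v.
g = jax.grad(scalar_fn)(v)
print('Gradient of the summed output:', g)
```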

In the example with SimpleModule above, let’s first define a scalar-valued loss function:

def loss_fn(x):
    return m(x).sum()

print('loss_fn(x) = ', loss_fn(x))
loss_fn(x) =  -5.4704795

Then we create an objax.GradValues module which computes the gradients of loss_fn. We need to pass the function itself to the constructor of objax.GradValues as well as a VarCollection with the variables that loss_fn depends on:

# Construct a module which computes gradients
gv = objax.GradValues(loss_fn, module_vars)

gv is a module which returns the gradients of loss_fn and the values of loss_fn for the given input:

# gv returns both gradients and values of original function
grads, value = gv(x)

for g, var_name in zip(grads, module_vars.keys()):
    print(g, ' w.r.t. ', var_name)
print('Value: ', value)
[2. 2. 2.]  w.r.t.  (SimpleModule).v1
[0. 0.]  w.r.t.  (SimpleModule).v2

Value:  [DeviceArray(-5.4704795, dtype=float32)]

In the example above, grads is a list of gradients with respect to all variables from module_vars. The order of gradients in the grads list is the same as the order of corresponding variables in module_vars. So grads[0] is the gradient of the function w.r.t. m.v1 and grads[1] is the gradient w.r.t. m.v2.
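
The gradient values printed above can be checked by hand: loss_fn sums the dot products of each row of x with v1, so its derivative with respect to v1 is the column-wise sum of x, which is [2, 2, 2] for the 4x3 input used here, while v2 is never used and gets a zero gradient. The sketch below reproduces this in plain JAX with jax.value_and_grad, passing both parameters as explicit arguments instead of a VarCollection. The v1 values are placeholders; since the loss is linear in v1, the gradient does not depend on them:

```python
import jax
import jax.numpy as jn

x = jn.array([[1., 1., 1.],
              [1., 0., 0.],
              [0., 1., 0.],
              [0., 0., 1.]])

def loss(v1, v2):
    # Mirrors loss_fn: sum of the per-row dot products; v2 is never used.
    return jn.dot(x, v1).sum()

v1 = jn.array([0.3, -0.2, 0.1])   # placeholder values
v2 = jn.ones((2,))

# argnums=(0, 1) requests gradients with respect to both arguments.
value, (g1, g2) = jax.value_and_grad(loss, argnums=(0, 1))(v1, v2)

print('d loss / d v1:', g1)   # column sums of x
print('d loss / d v2:', g2)   # zeros, since v2 does not affect the loss
print('value:', value)
```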

Just-in-time compilation (JIT)

In the examples shown so far, the Python interpreter executes all operations one by one. This mode of execution becomes slow for larger and more complicated code.

Objax provides an easy and convenient way to compile a sequence of operations using objax.Jit:

jit_m = objax.Jit(m)
y = jit_m(x)
print('Input:\n', x)
print('Output:\n', y)
Input:
 [[1. 1. 1.]
 [1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]]
Output:
 [-2.7352397  -1.1010289  -0.68184537 -0.95236546]

objax.Jit can compile not only modules but also functions and other callables. In this case, the variable collection that the function depends on should be passed to objax.Jit as well:

def loss_fn(x, y):
  return ((m(x) - y) ** 2).sum()

jit_loss_fn = objax.Jit(loss_fn, module_vars)

x = objax.random.normal((2, 3))
y = jn.array((-1.0, 1.0))

print('x:\n ', x)
print('y:\n', y)

print('loss_fn(x, y): ', loss_fn(x, y))
print('jit_loss_fn(x, y): ', jit_loss_fn(x, y))
x:
  [[ 2.2491198   0.18783404  0.65321374]
 [-0.23017201  0.18411613  1.341197  ]]
y:
 [-1.  1.]
loss_fn(x, y):  9.577398
jit_loss_fn(x, y):  9.577398
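
One way to see why the variables matter for compiled plain functions: jax.jit traces a function once and treats arrays read from the enclosing scope as compile-time constants, so later changes to them are not picked up by the compiled code. Our reading (an assumption, not taken from the Objax documentation) is that passing module_vars lets objax.Jit feed the variables' current values in on every call. The plain-JAX sketch below shows the pitfall with a hypothetical global parameter w:

```python
import jax
import jax.numpy as jn

w = jn.array([1.0, 2.0])       # hypothetical parameter stored in a global

@jax.jit
def f(x):
    return jn.dot(x, w)        # w is read from the enclosing scope

x = jn.ones(2)
before = f(x)                  # traced and compiled now, with w == [1, 2]

w = jn.array([10.0, 20.0])     # rebind the global to a new tensor
after = f(x)                   # the cached compiled code still uses the old w

# Passing the parameter as an explicit argument avoids the problem.
g = jax.jit(lambda x, w: jn.dot(x, w))
current = g(x, w)

print(before, after, current)
```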

There is no need to use JIT if you only need to compute a single JAX operation. However, JIT can give significant speedups when multiple Objax/JAX operations are chained together. The next tutorial will show examples of how JIT is used in practice.

Nevertheless, the difference in execution speed with and without JIT is evident even in this simple example:

x = objax.random.normal((100, 3))
# gv is the module defined above, which computes gradients
jit_gv = objax.Jit(gv)
print('Timing for jit_gv:')
%timeit jit_gv(x)
print('Timing for gv:')
%timeit gv(x)
Timing for jit_gv:
1000 loops, best of 3: 309 µs per loop
Timing for gv:
100 loops, best of 3: 3.57 ms per loop
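
When timing JAX code by hand instead of with %timeit, two details matter: JAX dispatches work asynchronously, so a call can return before the computation has finished (block_until_ready() waits for it), and the first call of a jitted function includes compilation time. A rough sketch with a made-up chained function; the absolute numbers depend entirely on the hardware:

```python
import time
import jax
import jax.numpy as jn

def chained(x):
    # Several element-wise operations chained together.
    for _ in range(10):
        x = jn.tanh(x) * 1.01 + 0.1
    return x.sum()

jit_chained = jax.jit(chained)
x = jn.linspace(-1.0, 1.0, 10_000)

# Warm-up call: compilation happens here and is kept out of the measurement.
jit_chained(x).block_until_ready()

def time_per_call(fn, n=50):
    start = time.perf_counter()
    for _ in range(n):
        out = fn(x)
    out.block_until_ready()     # wait for the last result before stopping the clock
    return (time.perf_counter() - start) / n

print(f'eager: {time_per_call(chained) * 1e6:.0f} us per call')
print(f'jit:   {time_per_call(jit_chained) * 1e6:.0f} us per call')
```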

When to use JAX and when to use Objax primitives?

Attentive readers will notice that while Objax works on top of JAX, it redefines quite a few concepts from JAX. Some examples are:

  • objax.GradValues vs jax.value_and_grad for computing gradients.

  • objax.Jit vs jax.jit for just-in-time compilation.

  • objax.random vs jax.random to generate random numbers.

All these differences originate from the fact that JAX is a stateless functional framework, while Objax provides a stateful, object-oriented way to use JAX.
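
Random numbers illustrate the difference well. In plain JAX there is no global generator: a key is passed explicitly and must be split to obtain new values, whereas objax.random keeps this state in a generator object so that repeated calls return different numbers. A short sketch using only jax.random (the seed 0 is arbitrary):

```python
import jax
import jax.numpy as jn

key = jax.random.PRNGKey(0)

# The same key always yields the same numbers: no hidden state is updated.
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))

# Fresh randomness requires deriving new keys explicitly.
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, (3,))

print(a)
print(b)   # identical to a
print(c)   # different numbers
```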

Mixing OOP and functional code can be quite confusing, so we recommend using JAX primitives only for basic mathematical operations (defined in jax.numpy) and Objax primitives for everything else.

Next: Logistic Regression Tutorial

This tutorial introduces all concepts necessary to build and train a machine learning classifier. The next tutorial shows how to apply all of them in logistic regression.