Welcome to Objax’s documentation!¶
Objax is an open source machine learning framework that accelerates research and learning thanks to a minimalist object-oriented design and a readable code base. Its name comes from the contraction of Object and JAX – a popular high-performance framework. Objax is designed by researchers for researchers with a focus on simplicity and understandability. Its users should be able to easily read, understand, extend, and modify it to fit their needs.
Machine learning’s 'Hello world': optimizing the weights of a classifier net through gradient descent:
opt = objax.optimizer.Adam(net.vars())

def loss(x, y):
    logits = net(x)  # Output of classifier on x
    xe = cross_entropy_logits(logits, y)
    return xe.mean()

# Perform gradient descent w.r.t. the net weights
gv = objax.GradValues(loss, net.vars())

def train_op(x, y):
    g, v = gv(x, y)  # returns gradients g and loss v
    opt(lr, g)       # update weights
    return v

train_op = objax.Jit(train_op, net.vars() + opt.vars())
Objax philosophy¶
Objax pursues the quest for the simplest design and code that’s as easy as possible to extend without sacrificing performance.
– Objax Devs
Motivation¶
Researchers and students look at machine learning frameworks in their own way. Often they read the code of some technique, say an Adam optimizer, to understand how it works so they can extend it or design a new optimizer. This is how machine learning frameworks differ from standard libraries: a large class of users not only look at the APIs but also at the code behind these APIs.
Coded for simplicity¶
Source code should be understandable by everyone, including users without a background in computer science. So how simple is it really? Judge for yourself with this tutorial: Logistic Regression.
Object-oriented¶
It is common in machine learning to separate the inputs (\(X\)) from the parameters (\(\theta\)) of a function \(f(X; \theta)\). Math notation captures this difference by using a semi-colon to semantically separate the first group of arguments from the other.
Objax represents this semantic distinction through objax.Module (see the sketch below):

- the module’s parameters \(\theta\) are attributes of the form self.w, ...
- inputs \(X\) are method arguments such as def __call__(self, x, y, ...):
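As a minimal sketch of this convention (the Scale module below is a hypothetical example written for this page, not part of Objax), the parameters live on the object while the inputs arrive through __call__:

import jax.numpy as jn
import objax

class Scale(objax.Module):
    def __init__(self, dim):
        # theta: a trainable parameter stored as an attribute
        self.w = objax.TrainVar(jn.ones(dim))

    def __call__(self, x):
        # X: the input, passed as a method argument
        return x * self.w.value

m = Scale(3)
print(m(jn.array([1.0, 2.0, 3.0])))  # [1. 2. 3.]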
Designed for flexibility¶
Objax minimizes the number of abstractions users need to understand. There are two main ones: Modules and Variables. Everything is built out of these two basic classes. You can read more about this in Variables and Modules.
Engineered for performance¶
In machine learning, performance is essential. Every second counts. Objax makes it count by using the JAX/XLA engine that also powers TensorFlow. Read more about this in Compilation and Parallelism.
Installation and Setup¶
User installation¶
Install using pip with the following command:
pip install --upgrade objax
For GPU support, we assume you already have a version of CUDA installed. The extra steps are:
# Specify your installed CUDA version.
CUDA_VERSION=11.0
pip install --upgrade https://storage.googleapis.com/jax-releases/cuda`echo $CUDA_VERSION | sed s:\\\.::g`/jaxlib-`python3 -c 'import jaxlib; print(jaxlib.__version__)'`-`python3 -V | sed -En "s/Python ([0-9]*)\.([0-9]*).*/cp\1\2/p"`-none-manylinux2010_x86_64.whl
Useful shell configurations¶
Here are a few useful options:
# Prevent JAX from taking the whole GPU memory
# (useful if you want to run several programs on a single GPU)
export XLA_PYTHON_CLIENT_PREALLOCATE=false
Testing your installation¶
You can run the code below to test your installation:
import jax
import objax
print(f'Number of GPUs {jax.device_count()}')
x = objax.random.normal((100, 4))
m = objax.nn.Linear(4, 5)
print('Matrix product shape', m(x).shape) # (100, 5)
x = objax.random.normal((100, 3, 32, 32))
m = objax.nn.Conv2D(3, 4, k=3)
print('Conv2D return shape', m(x).shape) # (100, 4, 32, 32)
If you get errors running this using CUDA, it probably means your installation of CUDA or CuDNN has issues.
Installing examples¶
Clone the code repository:
git clone https://github.com/google/objax.git
cd objax/examples
Developer installation¶
For development purposes we recommend using virtualenv. The setup on Ubuntu or similar Linux distributions is as follows:
# Install virtualenv if you haven't done so already
sudo apt install python3-dev python3-virtualenv python3-tk imagemagick virtualenv
# Create a virtual environment (for example ~/jax3; you can use a different name)
virtualenv -p python3 --system-site-packages ~/jax3
# Start the virtual environment
. ~/jax3/bin/activate
# Clone objax git repository.
git clone https://github.com/google/objax.git
cd objax
# Install python dependencies.
pip install --upgrade -r requirements.txt
pip install --upgrade -r docs/requirements.txt
pip install --upgrade -r examples/requirements.txt
# If you have CUDA installed, specify your installed CUDA version.
CUDA_VERSION=11.0
pip install --upgrade https://storage.googleapis.com/jax-releases/cuda`echo $CUDA_VERSION | sed s:\\\.::g`/jaxlib-`python3 -c 'import jaxlib; print(jaxlib.__version__)'`-`python3 -V | sed -En "s/Python ([0-9]*)\.([0-9]*).*/cp\1\2/p"`-none-manylinux2010_x86_64.whl
The current folder must be in PYTHONPATH. This can be done with the following command:
export PYTHONPATH=$PYTHONPATH:.
Objax Basics Tutorial¶
This tutorial introduces basic Objax concepts.
Prerequisites¶
Objax is a machine learning library written in Python which works on top of JAX. Readers should therefore have some familiarity with the following:
Python. If you’re new to Python or need a refresher check Python Tutorial or Python Beginner’s Guide.
NumPy is a library for mathematical computations. Many JAX primitives are based on NumPy and have the same syntax. NumPy is also useful for data manipulation outside of JAX/Objax. The NumPy quickstart covers most of the needed topics; more information can be found on the NumPy documentation site.
JAX can be described as NumPy with gradients that runs on accelerators (GPU and TPU). The JAX quickstart covers most of the concepts needed to understand Objax; a short sketch contrasting NumPy and JAX follows this list.
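To connect these prerequisites, here is a small sketch (it only uses standard NumPy and JAX APIs) showing that jax.numpy mirrors the NumPy syntax while JAX adds gradients:

import numpy as np
import jax
import jax.numpy as jn

# Same syntax in NumPy and jax.numpy
print(np.sum(np.arange(4.0) ** 2))   # 14.0
print(jn.sum(jn.arange(4.0) ** 2))   # 14.0

# JAX adds gradients: d/dx of sum(x**2) is 2x
grad_fn = jax.grad(lambda x: jn.sum(x ** 2))
print(grad_fn(jn.arange(4.0)))       # [0. 2. 4. 6.]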
Installation and imports¶
Let’s first install Objax:
[1]:
%pip --quiet install objax
After Objax is installed, you can import all necessary modules:
[2]:
import jax.numpy as jn
import numpy as np
import objax
Tensors¶
Tensors are essentially multi-dimensional arrays. In JAX and Objax, tensors can be placed on GPUs or TPUs to accelerate computations.
Objax relies on the jax.numpy.ndarray primitive from JAX to represent tensors. In turn, this primitive has a very similar API to NumPy ndarray.
Creating tensors¶
Tensor creation is very similar to NumPy and can be done in multiple ways:
Provide explicit values to the tensor:
[3]:
# Providing explicit values
jn.array([[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0]])
[3]:
DeviceArray([[1., 2., 3.],
[4., 5., 6.]], dtype=float32)
From a NumPy array:
[4]:
arr = np.array([1.0, 2.0, 3.0])
jn.array(arr)
[4]:
DeviceArray([1., 2., 3.], dtype=float32)
From another JAX tensor:
[5]:
another_tensor = jn.array([[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0]])
jn.array(another_tensor)
[5]:
DeviceArray([[1., 2., 3.],
[4., 5., 6.]], dtype=float32)
Using ones or zeros:
[6]:
jn.ones((3, 4))
[6]:
DeviceArray([[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]], dtype=float32)
[7]:
jn.zeros((4, 5))
[7]:
DeviceArray([[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]], dtype=float32)
As a result of a mathematical operation performed on other tensors:
[8]:
t1 = jn.array([[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0]])
t2 = jn.ones(t1.shape) * 3
t1 + t2
[8]:
DeviceArray([[4., 5., 6.],
[7., 8., 9.]], dtype=float32)
Tensor Properties¶
Similar to NumPy, one can explore various properties of tensors like shape, number of dimensions, or data type:
[9]:
t = jn.array([[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0]])
print('Number of dimensions: ', t.ndim)
print('Shape: ', t.shape)
print('Data type: ', t.dtype)
Number of dimensions: 2
Shape: (2, 3)
Data type: float32
Converting tensors to NumPy arrays¶
Objax/JAX tensors can be converted to NumPy arrays when needed to perform computations with NumPy:
[10]:
np.array(t)
[10]:
array([[1., 2., 3.],
[4., 5., 6.]], dtype=float32)
Tensors are immutable¶
One big difference between JAX ndarray and NumPy ndarray is that JAX ndarray is immutable:
[11]:
print('Original tensor t:\n', t)
try:
    t[0, 0] = -5.0  # This line will fail
except Exception as e:
    print(f'Exception {e}')
print('Tensor t after failed attempt to update:\n', t)
Original tensor t:
[[1. 2. 3.]
[4. 5. 6.]]
Exception '<class 'jax.interpreters.xla.DeviceArray'>' object does not support item assignment. JAX arrays are immutable; perhaps you want jax.ops.index_update or jax.ops.index_add instead?
Tensor t after failed attempt to update:
[[1. 2. 3.]
[4. 5. 6.]]
Instead of updating an existing tensor, a new tensor should be created with the updated elements. Updates of individual tensor elements are done using index_update, index_add, and some other JAX primitives:
[12]:
import jax.ops
print('Original tensor t:\n', t)
new_t = jax.ops.index_update(t, jax.ops.index[0, 0], -5.0)
print('Tensor t after update stays the same:\n', t)
print('Tensor new_t has updated value:\n', new_t)
Original tensor t:
[[1. 2. 3.]
[4. 5. 6.]]
Tensor t after update stays the same:
[[1. 2. 3.]
[4. 5. 6.]]
Tensor new_t has updated value:
[[-5. 2. 3.]
[ 4. 5. 6.]]
More details about per-element updates of tensors can be found in JAX documentation.
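Note that in more recent JAX releases the jax.ops.index_update family has been removed in favor of the indexed-update syntax x.at[...].set(...). A minimal sketch of the equivalent update, assuming such a JAX version, looks like this:

# Functional per-element update with the .at[] syntax (newer JAX versions)
new_t = t.at[0, 0].set(-5.0)  # t itself stays unchanged
print(new_t)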
In practice, most mathematical operations act on tensors as a whole, so manual per-element updates are rarely needed.
Random numbers¶
It’s very easy to generate random tensors in Objax:
[13]:
x = objax.random.normal((3, 4))
print(x)
[[-0.1441347 0.89507747 -0.46038115 0.10503326]
[-0.7460886 0.89681065 0.38794124 0.11750659]
[ 1.0659382 0.22656879 -2.548792 1.9700414 ]]
There are multiple primitives for doing so:
[14]:
print('Random integers:', objax.random.randint((4,), low=0, high=10))
print('Random normal:', objax.random.normal((4,), mean=1.0, stddev=2.0))
print('Random truncated normal: ', objax.random.truncated_normal((4,), stddev=2.0))
print('Random uniform: ', objax.random.uniform((4,)))
Random integers: [1 3 4 7]
Random normal: [-3.4102912 -1.8277478 -0.02106905 0.65338284]
Random truncated normal: [ 0.78602946 2.3004575 -0.22719319 0.22819921]
Random uniform: [0.19456518 0.4642099 0.94732213 0.57298625]
Objax Variables and Modules¶
Objax Variables store tensor values. Unlike tensors, variables are mutable, i.e. the value stored in a variable can change. Since tensors are immutable, variables change their value by replacing it with a new tensor.
Variables are commonly used together with modules. The Module is the basic building block in Objax that stores variables and other modules. Most modules are also callable (i.e., implement the __call__ method) and, when called, perform some computation on their variables and sub-modules.
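As a small sketch of this mutability (using objax.TrainVar and its assign method), the tensor held by a variable can be swapped for a new one:

import jax.numpy as jn
import objax

v = objax.TrainVar(jn.zeros(3))
print(v.value)                       # [0. 0. 0.]

# Replace the stored tensor with a new one
v.assign(jn.array([1.0, 2.0, 3.0]))
print(v.value)                       # [1. 2. 3.]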
Here is an example of a simple module with one variable which performs the dot product of that variable with an input tensor:
[15]:
class SimpleModule(objax.Module):
    def __init__(self, length):
        self.v1 = objax.TrainVar(objax.random.normal((length,)))
        self.v2 = objax.TrainVar(jn.ones((2,)))

    def __call__(self, x):
        return jn.dot(x, self.v1.value)
m = SimpleModule(3)
Modules keep track of all variables they own, including variables in sub-modules. The .vars() method lists all the module’s variables. The method returns an instance of VarCollection, which is a dictionary with several other useful methods.
[16]:
module_vars = m.vars()
print('type(module_vars): ', type(module_vars))
print('isinstance(module_vars, dict): ', isinstance(module_vars, dict))
print()
print('Variable names and shapes:')
print(module_vars)
print()
print('Variable names and values:')
for k, v in module_vars.items():
    print(f'{k} {v.value}')
type(module_vars): objax.variable.VarCollection
isinstance(module_vars, dict): True
Variable names and shapes:
(SimpleModule).v1 3 (3,)
(SimpleModule).v2 2 (2,)
+Total(2) 5
Variable names and values:
(SimpleModule).v1 [-1.1010289 -0.68184537 -0.95236546]
(SimpleModule).v2 [1. 1.]
If the __call__ method of the module takes tensors as input and outputs tensors, then it can act as a mathematical function. In the general case, __call__ can be a multivariate vector-valued function.
The SimpleModule described above takes a vector of size length as input and outputs a scalar:
[17]:
x = jn.ones((3,))
y = m(x)
print('Input: ', x)
print('Output: ', y)
Input: [1. 1. 1.]
Output: -2.7352397
The way jn.dot works allows us to run the module on 2D tensors as well. In this case SimpleModule will treat the input as a batch of vectors, perform the dot product on each of them, and return a vector with the results:
[18]:
x = jn.array([[1., 1., 1.],
[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]])
y = m(x)
print('Input:\n', x)
print('Output:\n', y)
Input:
[[1. 1. 1.]
[1. 0. 0.]
[0. 1. 0.]
[0. 0. 1.]]
Output:
[-2.7352397 -1.1010289 -0.68184537 -0.95236546]
For comparison, here is the result of calling module m on each row of tensor x:
[19]:
print('Sequentially calling module on each row of 2D tensor:')
for idx in range(x.shape[0]):
    row_value = x[idx]
    out_value = m(row_value)
    print(f'm( {row_value} ) = {out_value}')
Sequentially calling module on each row of 2D tensor:
m( [1. 1. 1.] ) = -2.7352397441864014
m( [1. 0. 0.] ) = -1.1010289192199707
m( [0. 1. 0.] ) = -0.6818453669548035
m( [0. 0. 1.] ) = -0.9523654580116272
How to compute gradients¶
As shown above, modules can act as mathematical functions. It’s essential in machine learning to be able to compute gradients of functions, and Objax provides a simple way to do this.
It’s important to keep in mind that gradients are usually defined for scalar-valued functions, while our modules can be vector-valued. In this case we need to define an additional function which converts the vector-valued output of the module into a scalar. Then we can compute the gradients of that scalar-valued function with respect to all input variables.
In the example with SimpleModule above, let’s define a scalar-valued loss function first:
[20]:
def loss_fn(x):
    return m(x).sum()
print('loss_fn(x) = ', loss_fn(x))
loss_fn(x) = -5.4704795
Then we create an objax.GradValues module which computes the gradients of loss_fn. We need to pass the function itself to the constructor of objax.GradValues, as well as a VarCollection with the variables that loss_fn depends on:
[21]:
# Construct a module which computes gradients
gv = objax.GradValues(loss_fn, module_vars)
gv is a module which returns both the gradients of loss_fn and the values of loss_fn for the given input:
[22]:
# gv returns both gradients and values of original function
grads, value = gv(x)
print('Gradients:')
for g, var_name in zip(grads, module_vars.keys()):
    print(g, ' w.r.t. ', var_name)
print()
print('Value: ', value)
Gradients:
[2. 2. 2.] w.r.t. (SimpleModule).v1
[0. 0.] w.r.t. (SimpleModule).v2
Value: [DeviceArray(-5.4704795, dtype=float32)]
In the example above, grads is a list of gradients with respect to all variables from module_vars. The order of gradients in the grads list is the same as the order of the corresponding variables in module_vars. So grads[0] is the gradient of the function w.r.t. m.v1 and grads[1] is the gradient w.r.t. m.v2.
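If a name-to-gradient mapping is more convenient, the two sequences can simply be zipped together (a small sketch building on the cells above):

# Map each variable name to its gradient
named_grads = dict(zip(module_vars.keys(), grads))
print(named_grads['(SimpleModule).v1'])  # gradient w.r.t. m.v1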
Just-in-time compilation (JIT)¶
In the examples shown so far, the Python interpreter executes all operations one by one. This mode of execution becomes slow for larger and more complicated code.
Objax provides an easy and convenient way to compile a sequence of operations using objax.Jit:
[23]:
jit_m = objax.Jit(m)
y = jit_m(x)
print('Input:\n', x)
print('Output:\n', y)
Input:
[[1. 1. 1.]
[1. 0. 0.]
[0. 1. 0.]
[0. 0. 1.]]
Output:
[-2.7352397 -1.1010289 -0.68184537 -0.95236546]
objax.Jit can compile not only modules, but also functions and callables. In this case a variable collection should be passed to objax.Jit:
[24]:
def loss_fn(x, y):
    return ((m(x) - y) ** 2).sum()
jit_loss_fn = objax.Jit(loss_fn, module_vars)
x = objax.random.normal((2, 3))
y = jn.array((-1.0, 1.0))
print('x:\n ', x)
print('y:\n', y)
print('loss_fn(x, y): ', loss_fn(x, y))
print('jit_loss_fn(x, y): ', jit_loss_fn(x, y))
x:
[[ 2.2491198 0.18783404 0.65321374]
[-0.23017201 0.18411613 1.341197 ]]
y:
[-1. 1.]
loss_fn(x, y): 9.577398
jit_loss_fn(x, y): 9.577398
There is no need to use JIT if you only need to compute a single JAX operation. However, JIT can give significant speedups when multiple Objax/JAX operations are chained together. The next tutorial will show examples of how JIT is used in practice.
Nevertheless, the difference in execution speed with and without JIT is evident even in this simple example:
[25]:
x = objax.random.normal((100, 3))
# gv is the module defined above which computes gradients
jit_gv = objax.Jit(gv)
print('Timing for jit_gv:')
%timeit jit_gv(x)
print('Timing for gv:')
%timeit gv(x)
Timing for jit_gv:
1000 loops, best of 3: 309 µs per loop
Timing for gv:
100 loops, best of 3: 3.57 ms per loop
When to use JAX and when to use Objax primitives?¶
Attentive readers will notice that while Objax works on top of JAX, it redefines quite a few concepts from JAX. Some examples are:
- objax.GradValues vs jax.value_and_grad for computing gradients.
- objax.Jit vs jax.jit for just-in-time compilation.
- objax.random vs jax.random to generate random numbers.
All these differences originate from the fact that JAX is a stateless functional framework, while Objax provides a stateful, object-oriented way to use JAX.
Mixing OOP and functional code can be quite confusing, so we recommend using JAX primitives only for basic mathematical operations (defined in jax.numpy) and Objax primitives for everything else.
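In concrete terms, a typical Objax module follows this split: jax.numpy for the math inside __call__, and Objax for variables, gradients, and compilation. The Affine class below is a hypothetical sketch written to illustrate this, not part of the library:

import jax.numpy as jn
import objax

class Affine(objax.Module):
    def __init__(self, nin, nout):
        # State handled by Objax variables
        self.w = objax.TrainVar(objax.random.normal((nin, nout)))
        self.b = objax.TrainVar(jn.zeros(nout))

    def __call__(self, x):
        # Math handled by jax.numpy
        return jn.dot(x, self.w.value) + self.b.value

model = Affine(4, 2)
loss = lambda x: model(x).sum()
# Gradients and compilation handled by Objax
gv = objax.GradValues(loss, model.vars())
fast_gv = objax.Jit(gv)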
Next: Logistic Regression Tutorial¶
This tutorial introduced all the concepts necessary to build and train a machine learning classifier. The next tutorial shows how to apply them in a logistic regression.
Logistic Regression¶
In this example we will see how to classify images as horses or people using logistic regression. The tutorial builds upon the concepts introduced in the Objax basics tutorial. Consider reading that tutorial first, or you can always go back.
Imports¶
First, we import the modules we will use in our code.
[1]:
%pip --quiet install objax
import matplotlib.pyplot as plt
import os
import numpy as np
import tensorflow_datasets as tfds
import objax
from objax.util import EasyDict
Loading the Dataset¶
Next, we will load the “horses_or_humans” dataset from TensorFlow DataSets.
The prepare method downscales the image by 3x to reduce training time, flattens each image to a vector, and rescales each pixel value to [-1, 1].
[2]:
# Data: train has 1027 images - test has 256 images
# Each image is 300 x 300 x 3 bytes
DATA_DIR = os.path.join(os.environ['HOME'], 'TFDS')
data = tfds.as_numpy(tfds.load(name='horses_or_humans', batch_size=-1, data_dir=DATA_DIR))
def prepare(x, downscale=3):
    """Normalize images to [-1, 1] and downscale them to 100x100x3 (for faster training) and flatten them."""
    s = x.shape
    x = x.astype('f').reshape((s[0], s[1] // downscale, downscale, s[2] // downscale, downscale, s[3]))
    return x.mean((2, 4)).reshape((s[0], -1)) * (1 / 127.5) - 1
train = EasyDict(image=prepare(data['train']['image']), label=data['train']['label'])
test = EasyDict(image=prepare(data['test']['image']), label=data['test']['label'])
ndim = train.image.shape[-1]
del data
Visualizing the Data¶
Let’s see a couple of the images in the dataset and their corresponding labels. Note that label 0 corresponds to a horse, while label 1 corresponds to a human.
[3]:
#sample image of a horse.
horse_image = np.reshape(train.image[0], [100,100,3])
plt.imshow(horse_image)
print("label for horse_image:", train.label[0])
label for horse_image: 0
[matplotlib output: plot of the horse image]
[4]:
#sample image of a human.
human_image = np.reshape(train.image[9], [100, 100, 3])
plt.imshow(human_image)
print("label for human_image:", train.label[9])
label for human_image: 1
[matplotlib output: plot of the human image]
Model Definition¶
objax.nn.Linear(ndim, 1) is a linear neural unit with ndim inputs and a single output. Given input \(\mathbf{X}\), the output is equal to \(\mathbf{W}\mathbf{X} + \mathbf{b}\), where \(\mathbf{W}, \mathbf{b}\) are the model’s parameters. These parameters are available through model.vars().
[5]:
# Settings
lr = 0.0001 # learning rate
batch = 256
epochs = 20
model = objax.nn.Linear(ndim, 1)
print(model.vars())
(Linear).b 1 (1,)
(Linear).w 30000 (30000, 1)
+Total(2) 30001
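To make the \(\mathbf{W}\mathbf{X} + \mathbf{b}\) claim concrete, here is a small verification sketch (written for this page, not part of the tutorial) comparing the module’s output with the explicit matrix product:

import jax.numpy as jn

x = train.image[:2]                                 # two prepared images
manual = jn.dot(x, model.w.value) + model.b.value   # explicit W X + b
print(jn.allclose(model(x), manual))                # True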
Model Inference¶
Now that we have defined the model, we can use it to classify images. To do so, we call the model with an image from the train dataset we previously prepared. Notice that we use the image of a human we previously visualized.
We get the output of the model by calling model(). We then apply the sigmoid activation function and round the output. Activation outputs lower than or equal to 0.5 are rounded to zero (i.e., horses), whereas outputs larger than 0.5 are rounded to one (i.e., humans).
[6]:
# This is an image of a human.
print(np.round(objax.functional.sigmoid(model(train.image[9]))))
[1.]
Considering that we initialized the model with random weights, it should not come as a surprise that the model may misclassify a human as a horse.
Optimizer and Loss Function¶
In this example we use the objax.optimizer.SGD optimizer. Next, we define the loss function we will use to optimize the network. In this case we use the cross entropy loss function; note that we use objax.functional.loss.sigmoid_cross_entropy_logits because we perform binary classification.
[7]:
opt = objax.optimizer.SGD(model.vars())
# Cross Entropy Loss
def loss(x, label):
    return objax.functional.loss.sigmoid_cross_entropy_logits(model(x)[:, 0], label).mean()
Back Propagation and Gradient Descent¶
objax.GradValues calculates the gradient of loss w.r.t. model.vars(). If you want to learn more about gradients, read the Understanding Gradients in-depth topic.
The train_op function implements the core of backward propagation and gradient descent. First, we calculate the gradient g and then pass it to the optimizer, which updates the model’s weights.
[8]:
gv = objax.GradValues(loss, model.vars())
def train_op(x, label):
    g, v = gv(x, label)  # returns gradients, loss
    opt(lr, g)
    return v
# This line is optional: it compiles the code to make it faster.
train_op = objax.Jit(train_op, gv.vars() + opt.vars())
Training and Evaluation Loop¶
For each of the training epochs we process all the training data, contained in the train dictionary, in batches of batch size. At the end of each epoch we compute the classification accuracy by comparing the model’s predictions over the test data to the ground truth labels.
[9]:
for epoch in range(epochs):
    # Train
    avg_loss = 0
    # randomly shuffle training data
    shuffle_idx = np.random.permutation(train.image.shape[0])
    for it in range(0, train.image.shape[0], batch):
        sel = shuffle_idx[it: it + batch]
        avg_loss += float(train_op(train.image[sel], train.label[sel])[0]) * len(sel)
    avg_loss /= it + batch

    # Eval
    accuracy = 0
    for it in range(0, test.image.shape[0], batch):
        x, y = test.image[it: it + batch], test.label[it: it + batch]
        accuracy += (np.round(objax.functional.sigmoid(model(x)))[:, 0] == y).sum()
    accuracy /= test.image.shape[0]

    print('Epoch %04d Loss %.2f Accuracy %.2f' % (epoch + 1, avg_loss, 100 * accuracy))
Epoch 0001 Loss 0.25 Accuracy 82.81
Epoch 0002 Loss 0.25 Accuracy 82.81
Epoch 0003 Loss 0.25 Accuracy 83.59
Epoch 0004 Loss 0.25 Accuracy 82.81
Epoch 0005 Loss 0.25 Accuracy 82.03
Epoch 0006 Loss 0.25 Accuracy 80.86
Epoch 0007 Loss 0.25 Accuracy 80.86
Epoch 0008 Loss 0.24 Accuracy 80.86
Epoch 0009 Loss 0.24 Accuracy 82.03
Epoch 0010 Loss 0.24 Accuracy 83.20
Epoch 0011 Loss 0.24 Accuracy 82.03
Epoch 0012 Loss 0.24 Accuracy 80.86
Epoch 0013 Loss 0.24 Accuracy 81.25
Epoch 0014 Loss 0.24 Accuracy 83.59
Epoch 0015 Loss 0.24 Accuracy 84.38
Epoch 0016 Loss 0.24 Accuracy 85.16
Epoch 0017 Loss 0.24 Accuracy 84.77
Epoch 0018 Loss 0.24 Accuracy 83.98
Epoch 0019 Loss 0.24 Accuracy 83.20
Epoch 0020 Loss 0.24 Accuracy 83.59
Model Inference After Training¶
Now that the network is trained, we can retry the classification example above:
[10]:
print(np.round(objax.functional.sigmoid(model(train.image[9]))))
[1.]
Next Steps¶
We saw how we can define, use, and train a very simple model in Objax to classify images of humans and horses. Next, we will learn how to create custom models to classify handwritten digits.
Creating Custom Networks for Multi-Class Classification¶
This tutorial demonstrates how to define, train, and use different models for multi-class classification. We will reuse most of the code from the Logistic Regression tutorial, so if you haven’t gone through it, consider reviewing it first.
Note that this tutorial includes a demonstration of how to build and train a simple convolutional neural network, and running this colab on CPU may take some time. Therefore, we recommend running this colab on GPU (select GPU under Runtime -> Change runtime type -> Hardware accelerator if the hardware accelerator is not set to GPU).
Import Modules¶
We start by importing the modules we will use in our code.
[1]:
%pip --quiet install objax
import os
import numpy as np
import tensorflow_datasets as tfds
import objax
from objax.util import EasyDict
from objax.zoo.dnnet import DNNet
Load the data¶
Next, we will load the “MNIST” dataset from TensorFlow DataSets. This dataset contains handwritten digits (i.e., numbers between 0 and 9), and the task is to correctly identify each digit.
The prepare method pads 2 pixels to the left, right, top, and bottom of each image to resize it to 32 x 32 pixels. While MNIST images are grayscale, the prepare method expands each image to three color channels to demonstrate working with color images. The same method also rescales each pixel value to [-1, 1] and converts the images to (N, C, H, W) format.
[2]:
# Data: train has 60000 images - test has 10000 images
# Each image is resized and converted to 32 x 32 x 3
DATA_DIR = os.path.join(os.environ['HOME'], 'TFDS')
data = tfds.as_numpy(tfds.load(name='mnist', batch_size=-1, data_dir=DATA_DIR))
def prepare(x):
    """Pads 2 pixels to the left, right, top, and bottom of each image, scales pixel value to [-1, 1], and converts to NCHW format."""
    s = x.shape
    x_pad = np.zeros((s[0], 32, 32, 1))
    x_pad[:, 2:-2, 2:-2, :] = x
    return objax.util.image.nchw(
        np.concatenate([x_pad.astype('f') * (1 / 127.5) - 1] * 3, axis=-1))
train = EasyDict(image=prepare(data['train']['image']), label=data['train']['label'])
test = EasyDict(image=prepare(data['test']['image']), label=data['test']['label'])
ndim = train.image.shape[-1]
del data
Deep Neural Network Model¶
Objax offers many predefined models that we can use for classification. One example is the objax.zoo.DNNet model, comprising multiple fully connected layers with configurable sizes and activation functions.
[3]:
dnn_layer_sizes = 3072, 128, 10
dnn_model = DNNet(dnn_layer_sizes, objax.functional.leaky_relu)
Custom Model Definition¶
Alternatively, we can define a new model customized to our machine learning task. We demonstrate this process by defining a convolutional network (ConvNet) from scratch.
We use objax.nn.Sequential to compose multiple layers of convolution (objax.nn.Conv2D), batch normalization (objax.nn.BatchNorm2D), ReLU (objax.functional.relu), max pooling (objax.functional.max_pool_2d), average pooling (a mean over the spatial dimensions), and linear (objax.nn.Linear) layers.
Since the batch normalization layers behave differently during training and prediction, we pass a training flag to the __call__ method of the ConvNet class. The network returns logits; the probability distribution over classes is obtained by applying softmax to these logits at prediction time (see the predict helper in the training code below).
[4]:
class ConvNet(objax.Module):
    """ConvNet implementation."""

    def __init__(self, nin, nclass):
        """Define 3 blocks of conv-bn-relu-conv-bn-relu followed by linear layer."""
        self.conv_block1 = objax.nn.Sequential([objax.nn.Conv2D(nin, 16, 3, use_bias=False),
                                                objax.nn.BatchNorm2D(16),
                                                objax.functional.relu,
                                                objax.nn.Conv2D(16, 16, 3, use_bias=False),
                                                objax.nn.BatchNorm2D(16),
                                                objax.functional.relu])
        self.conv_block2 = objax.nn.Sequential([objax.nn.Conv2D(16, 32, 3, use_bias=False),
                                                objax.nn.BatchNorm2D(32),
                                                objax.functional.relu,
                                                objax.nn.Conv2D(32, 32, 3, use_bias=False),
                                                objax.nn.BatchNorm2D(32),
                                                objax.functional.relu])
        self.conv_block3 = objax.nn.Sequential([objax.nn.Conv2D(32, 64, 3, use_bias=False),
                                                objax.nn.BatchNorm2D(64),
                                                objax.functional.relu,
                                                objax.nn.Conv2D(64, 64, 3, use_bias=False),
                                                objax.nn.BatchNorm2D(64),
                                                objax.functional.relu])
        self.linear = objax.nn.Linear(64, nclass)

    def __call__(self, x, training):
        x = self.conv_block1(x, training=training)
        x = objax.functional.max_pool_2d(x, size=2, strides=2)
        x = self.conv_block2(x, training=training)
        x = objax.functional.max_pool_2d(x, size=2, strides=2)
        x = self.conv_block3(x, training=training)
        x = x.mean((2, 3))
        x = self.linear(x)
        return x
[5]:
cnn_model = ConvNet(nin=3, nclass=10)
print(cnn_model.vars())
(ConvNet).conv_block1(Sequential)[0](Conv2D).w 432 (3, 3, 3, 16)
(ConvNet).conv_block1(Sequential)[1](BatchNorm2D).running_mean 16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[1](BatchNorm2D).running_var 16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[1](BatchNorm2D).beta 16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[1](BatchNorm2D).gamma 16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[3](Conv2D).w 2304 (3, 3, 16, 16)
(ConvNet).conv_block1(Sequential)[4](BatchNorm2D).running_mean 16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[4](BatchNorm2D).running_var 16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[4](BatchNorm2D).beta 16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[4](BatchNorm2D).gamma 16 (1, 16, 1, 1)
(ConvNet).conv_block2(Sequential)[0](Conv2D).w 4608 (3, 3, 16, 32)
(ConvNet).conv_block2(Sequential)[1](BatchNorm2D).running_mean 32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[1](BatchNorm2D).running_var 32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[1](BatchNorm2D).beta 32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[1](BatchNorm2D).gamma 32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[3](Conv2D).w 9216 (3, 3, 32, 32)
(ConvNet).conv_block2(Sequential)[4](BatchNorm2D).running_mean 32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[4](BatchNorm2D).running_var 32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[4](BatchNorm2D).beta 32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[4](BatchNorm2D).gamma 32 (1, 32, 1, 1)
(ConvNet).conv_block3(Sequential)[0](Conv2D).w 18432 (3, 3, 32, 64)
(ConvNet).conv_block3(Sequential)[1](BatchNorm2D).running_mean 64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[1](BatchNorm2D).running_var 64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[1](BatchNorm2D).beta 64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[1](BatchNorm2D).gamma 64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[3](Conv2D).w 36864 (3, 3, 64, 64)
(ConvNet).conv_block3(Sequential)[4](BatchNorm2D).running_mean 64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[4](BatchNorm2D).running_var 64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[4](BatchNorm2D).beta 64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[4](BatchNorm2D).gamma 64 (1, 64, 1, 1)
(ConvNet).linear(Linear).b 10 (10,)
(ConvNet).linear(Linear).w 640 (64, 10)
+Total(32) 73402
Model Training and Evaluation¶
The train_model method combines all the parts: defining the loss function, gradient descent, the training loop, and evaluation. It takes the model as a parameter so it can be reused with the two models we defined earlier.
Unlike the Logistic Regression tutorial, we use objax.functional.loss.cross_entropy_logits_sparse because we perform multi-class classification. The optimizer, gradient descent operation, and training loop remain the same.
The DNNet model expects flattened images, whereas ConvNet expects images in (C, H, W) format. The flatten_image method prepares images accordingly before passing them to the model.
When using the model for inference, we apply the objax.functional.softmax method to compute the probability distribution from the model’s logits.
[6]:
# Settings
lr = 0.03 # learning rate
batch = 128
epochs = 100
# Train loop
def train_model(model):
    def predict(model, x):
        """Return the model's softmax predictions."""
        return objax.functional.softmax(model(x, training=False))

    def flatten_image(x):
        """Flatten the image before passing it to the DNN."""
        if isinstance(model, DNNet):
            return objax.functional.flatten(x)
        else:
            return x

    opt = objax.optimizer.Momentum(model.vars())

    # Cross Entropy Loss
    def loss(x, label):
        return objax.functional.loss.cross_entropy_logits_sparse(model(x, training=True), label).mean()

    gv = objax.GradValues(loss, model.vars())

    def train_op(x, label):
        g, v = gv(x, label)  # returns gradients, loss
        opt(lr, g)
        return v

    train_op = objax.Jit(train_op, gv.vars() + opt.vars())

    for epoch in range(epochs):
        avg_loss = 0
        # randomly shuffle training data
        shuffle_idx = np.random.permutation(train.image.shape[0])
        for it in range(0, train.image.shape[0], batch):
            sel = shuffle_idx[it: it + batch]
            avg_loss += float(train_op(flatten_image(train.image[sel]), train.label[sel])[0]) * len(sel)
        avg_loss /= it + len(sel)

        # Eval
        accuracy = 0
        for it in range(0, test.image.shape[0], batch):
            x, y = test.image[it: it + batch], test.label[it: it + batch]
            accuracy += (np.argmax(predict(model, flatten_image(x)), axis=1) == y).sum()
        accuracy /= test.image.shape[0]

        print('Epoch %04d Loss %.2f Accuracy %.2f' % (epoch + 1, avg_loss, 100 * accuracy))
Training the DNN Model¶
[7]:
train_model(dnn_model)
Epoch 0001 Loss 2.39 Accuracy 56.31
Epoch 0002 Loss 1.24 Accuracy 74.19
Epoch 0003 Loss 0.75 Accuracy 84.91
Epoch 0004 Loss 0.56 Accuracy 83.10
Epoch 0005 Loss 0.48 Accuracy 86.42
Epoch 0006 Loss 0.43 Accuracy 89.00
Epoch 0007 Loss 0.41 Accuracy 89.62
Epoch 0008 Loss 0.39 Accuracy 89.82
Epoch 0009 Loss 0.37 Accuracy 90.04
Epoch 0010 Loss 0.36 Accuracy 89.55
Epoch 0011 Loss 0.35 Accuracy 90.53
Epoch 0012 Loss 0.35 Accuracy 90.64
Epoch 0013 Loss 0.34 Accuracy 90.85
Epoch 0014 Loss 0.33 Accuracy 90.87
Epoch 0015 Loss 0.33 Accuracy 91.02
Epoch 0016 Loss 0.32 Accuracy 91.35
Epoch 0017 Loss 0.32 Accuracy 91.35
Epoch 0018 Loss 0.31 Accuracy 91.50
Epoch 0019 Loss 0.31 Accuracy 91.57
Epoch 0020 Loss 0.31 Accuracy 91.75
Epoch 0021 Loss 0.30 Accuracy 91.47
Epoch 0022 Loss 0.30 Accuracy 88.14
Epoch 0023 Loss 0.30 Accuracy 91.82
Epoch 0024 Loss 0.30 Accuracy 91.92
Epoch 0025 Loss 0.29 Accuracy 92.03
Epoch 0026 Loss 0.29 Accuracy 92.04
Epoch 0027 Loss 0.29 Accuracy 92.11
Epoch 0028 Loss 0.29 Accuracy 92.11
Epoch 0029 Loss 0.29 Accuracy 92.18
Epoch 0030 Loss 0.28 Accuracy 92.24
Epoch 0031 Loss 0.28 Accuracy 92.36
Epoch 0032 Loss 0.28 Accuracy 92.17
Epoch 0033 Loss 0.28 Accuracy 92.42
Epoch 0034 Loss 0.28 Accuracy 92.42
Epoch 0035 Loss 0.27 Accuracy 92.47
Epoch 0036 Loss 0.27 Accuracy 92.50
Epoch 0037 Loss 0.27 Accuracy 92.49
Epoch 0038 Loss 0.27 Accuracy 92.58
Epoch 0039 Loss 0.26 Accuracy 92.56
Epoch 0040 Loss 0.26 Accuracy 92.56
Epoch 0041 Loss 0.26 Accuracy 92.77
Epoch 0042 Loss 0.26 Accuracy 92.72
Epoch 0043 Loss 0.26 Accuracy 92.80
Epoch 0044 Loss 0.25 Accuracy 92.85
Epoch 0045 Loss 0.25 Accuracy 92.90
Epoch 0046 Loss 0.25 Accuracy 92.96
Epoch 0047 Loss 0.25 Accuracy 93.00
Epoch 0048 Loss 0.25 Accuracy 92.82
Epoch 0049 Loss 0.25 Accuracy 93.18
Epoch 0050 Loss 0.24 Accuracy 93.09
Epoch 0051 Loss 0.24 Accuracy 92.94
Epoch 0052 Loss 0.24 Accuracy 93.20
Epoch 0053 Loss 0.24 Accuracy 93.26
Epoch 0054 Loss 0.23 Accuracy 93.21
Epoch 0055 Loss 0.24 Accuracy 93.42
Epoch 0056 Loss 0.23 Accuracy 93.35
Epoch 0057 Loss 0.23 Accuracy 93.36
Epoch 0058 Loss 0.23 Accuracy 93.56
Epoch 0059 Loss 0.23 Accuracy 93.54
Epoch 0060 Loss 0.22 Accuracy 93.39
Epoch 0061 Loss 0.23 Accuracy 93.56
Epoch 0062 Loss 0.22 Accuracy 93.74
Epoch 0063 Loss 0.22 Accuracy 93.68
Epoch 0064 Loss 0.22 Accuracy 93.72
Epoch 0065 Loss 0.22 Accuracy 93.76
Epoch 0066 Loss 0.22 Accuracy 93.87
Epoch 0067 Loss 0.21 Accuracy 93.89
Epoch 0068 Loss 0.21 Accuracy 93.96
Epoch 0069 Loss 0.21 Accuracy 93.90
Epoch 0070 Loss 0.21 Accuracy 93.99
Epoch 0071 Loss 0.21 Accuracy 94.02
Epoch 0072 Loss 0.21 Accuracy 93.86
Epoch 0073 Loss 0.21 Accuracy 94.06
Epoch 0074 Loss 0.21 Accuracy 94.14
Epoch 0075 Loss 0.20 Accuracy 94.31
Epoch 0076 Loss 0.20 Accuracy 94.14
Epoch 0077 Loss 0.20 Accuracy 94.15
Epoch 0078 Loss 0.20 Accuracy 94.10
Epoch 0079 Loss 0.20 Accuracy 94.16
Epoch 0080 Loss 0.20 Accuracy 94.28
Epoch 0081 Loss 0.20 Accuracy 94.30
Epoch 0082 Loss 0.20 Accuracy 94.28
Epoch 0083 Loss 0.19 Accuracy 94.37
Epoch 0084 Loss 0.19 Accuracy 94.33
Epoch 0085 Loss 0.19 Accuracy 94.31
Epoch 0086 Loss 0.19 Accuracy 94.25
Epoch 0087 Loss 0.19 Accuracy 94.37
Epoch 0088 Loss 0.19 Accuracy 94.38
Epoch 0089 Loss 0.19 Accuracy 94.35
Epoch 0090 Loss 0.19 Accuracy 94.38
Epoch 0091 Loss 0.19 Accuracy 94.41
Epoch 0092 Loss 0.19 Accuracy 94.46
Epoch 0093 Loss 0.19 Accuracy 94.53
Epoch 0094 Loss 0.18 Accuracy 94.47
Epoch 0095 Loss 0.18 Accuracy 94.54
Epoch 0096 Loss 0.18 Accuracy 94.65
Epoch 0097 Loss 0.18 Accuracy 94.56
Epoch 0098 Loss 0.18 Accuracy 94.60
Epoch 0099 Loss 0.18 Accuracy 94.63
Epoch 0100 Loss 0.18 Accuracy 94.46
Training the ConvNet Model¶
[8]:
train_model(cnn_model)
Epoch 0001 Loss 0.27 Accuracy 27.08
Epoch 0002 Loss 0.05 Accuracy 41.07
Epoch 0003 Loss 0.03 Accuracy 67.77
Epoch 0004 Loss 0.03 Accuracy 73.31
Epoch 0005 Loss 0.02 Accuracy 90.30
Epoch 0006 Loss 0.02 Accuracy 93.10
Epoch 0007 Loss 0.02 Accuracy 95.98
Epoch 0008 Loss 0.01 Accuracy 98.77
Epoch 0009 Loss 0.01 Accuracy 96.58
Epoch 0010 Loss 0.01 Accuracy 99.12
Epoch 0011 Loss 0.01 Accuracy 98.88
Epoch 0012 Loss 0.01 Accuracy 98.64
Epoch 0013 Loss 0.01 Accuracy 98.66
Epoch 0014 Loss 0.00 Accuracy 98.38
Epoch 0015 Loss 0.00 Accuracy 99.15
Epoch 0016 Loss 0.00 Accuracy 97.50
Epoch 0017 Loss 0.00 Accuracy 98.98
Epoch 0018 Loss 0.00 Accuracy 98.94
Epoch 0019 Loss 0.00 Accuracy 98.56
Epoch 0020 Loss 0.00 Accuracy 99.06
Epoch 0021 Loss 0.00 Accuracy 99.26
Epoch 0022 Loss 0.00 Accuracy 99.30
Epoch 0023 Loss 0.00 Accuracy 99.18
Epoch 0024 Loss 0.00 Accuracy 99.49
Epoch 0025 Loss 0.00 Accuracy 99.34
Epoch 0026 Loss 0.00 Accuracy 99.24
Epoch 0027 Loss 0.00 Accuracy 99.38
Epoch 0028 Loss 0.00 Accuracy 99.43
Epoch 0029 Loss 0.00 Accuracy 99.40
Epoch 0030 Loss 0.00 Accuracy 99.50
Epoch 0031 Loss 0.00 Accuracy 99.44
Epoch 0032 Loss 0.00 Accuracy 99.52
Epoch 0033 Loss 0.00 Accuracy 99.46
Epoch 0034 Loss 0.00 Accuracy 99.39
Epoch 0035 Loss 0.00 Accuracy 99.22
Epoch 0036 Loss 0.00 Accuracy 99.26
Epoch 0037 Loss 0.00 Accuracy 99.47
Epoch 0038 Loss 0.00 Accuracy 99.18
Epoch 0039 Loss 0.00 Accuracy 99.39
Epoch 0040 Loss 0.00 Accuracy 99.44
Epoch 0041 Loss 0.00 Accuracy 99.43
Epoch 0042 Loss 0.00 Accuracy 99.50
Epoch 0043 Loss 0.00 Accuracy 99.50
Epoch 0044 Loss 0.00 Accuracy 99.53
Epoch 0045 Loss 0.00 Accuracy 99.51
Epoch 0046 Loss 0.00 Accuracy 99.49
Epoch 0047 Loss 0.00 Accuracy 99.46
Epoch 0048 Loss 0.00 Accuracy 99.46
Epoch 0049 Loss 0.00 Accuracy 99.35
Epoch 0050 Loss 0.00 Accuracy 99.50
Epoch 0051 Loss 0.00 Accuracy 99.48
Epoch 0052 Loss 0.00 Accuracy 99.48
Epoch 0053 Loss 0.00 Accuracy 99.48
Epoch 0054 Loss 0.00 Accuracy 99.46
Epoch 0055 Loss 0.00 Accuracy 99.48
Epoch 0056 Loss 0.00 Accuracy 99.50
Epoch 0057 Loss 0.00 Accuracy 99.41
Epoch 0058 Loss 0.00 Accuracy 99.49
Epoch 0059 Loss 0.00 Accuracy 99.48
Epoch 0060 Loss 0.00 Accuracy 99.47
Epoch 0061 Loss 0.00 Accuracy 99.52
Epoch 0062 Loss 0.00 Accuracy 99.49
Epoch 0063 Loss 0.00 Accuracy 99.48
Epoch 0064 Loss 0.00 Accuracy 99.51
Epoch 0065 Loss 0.00 Accuracy 99.46
Epoch 0066 Loss 0.00 Accuracy 99.51
Epoch 0067 Loss 0.00 Accuracy 99.49
Epoch 0068 Loss 0.00 Accuracy 99.52
Epoch 0069 Loss 0.00 Accuracy 99.49
Epoch 0070 Loss 0.00 Accuracy 99.51
Epoch 0071 Loss 0.00 Accuracy 99.51
Epoch 0072 Loss 0.00 Accuracy 99.52
Epoch 0073 Loss 0.00 Accuracy 99.43
Epoch 0074 Loss 0.00 Accuracy 99.53
Epoch 0075 Loss 0.00 Accuracy 99.47
Epoch 0076 Loss 0.00 Accuracy 99.51
Epoch 0077 Loss 0.00 Accuracy 99.55
Epoch 0078 Loss 0.00 Accuracy 99.52
Epoch 0079 Loss 0.00 Accuracy 99.52
Epoch 0080 Loss 0.00 Accuracy 98.78
Epoch 0081 Loss 0.00 Accuracy 99.16
Epoch 0082 Loss 0.00 Accuracy 99.40
Epoch 0083 Loss 0.00 Accuracy 99.35
Epoch 0084 Loss 0.00 Accuracy 99.32
Epoch 0085 Loss 0.00 Accuracy 99.49
Epoch 0086 Loss 0.00 Accuracy 99.49
Epoch 0087 Loss 0.00 Accuracy 99.56
Epoch 0088 Loss 0.00 Accuracy 99.48
Epoch 0089 Loss 0.00 Accuracy 99.48
Epoch 0090 Loss 0.00 Accuracy 99.51
Epoch 0091 Loss 0.00 Accuracy 99.45
Epoch 0092 Loss 0.00 Accuracy 99.52
Epoch 0093 Loss 0.00 Accuracy 99.52
Epoch 0094 Loss 0.00 Accuracy 99.51
Epoch 0095 Loss 0.00 Accuracy 99.51
Epoch 0096 Loss 0.00 Accuracy 99.48
Epoch 0097 Loss 0.00 Accuracy 99.51
Epoch 0098 Loss 0.00 Accuracy 99.53
Epoch 0099 Loss 0.00 Accuracy 99.50
Epoch 0100 Loss 0.00 Accuracy 99.53
Training with PyTorch data processing API¶
One of the pain points for ML researchers/practitioners when building a new ML model is data processing. Here, we demonstrate how to use the data processing API of PyTorch to train a model with Objax. Different deep learning libraries come with different data processing APIs, and depending on your preference, you can choose an API and easily combine it with Objax.
Similarly, we prepare an MNIST dataset and apply the same data preprocessing.
[9]:
import torch
from torchvision import datasets, transforms
transform = transforms.Compose([
    transforms.Pad((2, 2, 2, 2), 0),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: np.concatenate([x] * 3, axis=0)),
    transforms.Lambda(lambda x: x * 2 - 1)
])
train_dataset = datasets.MNIST(os.environ['HOME'], train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(os.environ['HOME'], train=False, download=True, transform=transform)
# Define data loader
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /root/MNIST/raw/train-images-idx3-ubyte.gz
Extracting /root/MNIST/raw/train-images-idx3-ubyte.gz to /root/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /root/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting /root/MNIST/raw/train-labels-idx1-ubyte.gz to /root/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /root/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting /root/MNIST/raw/t10k-images-idx3-ubyte.gz to /root/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /root/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting /root/MNIST/raw/t10k-labels-idx1-ubyte.gz to /root/MNIST/raw
Processing...
Done!
/usr/local/lib/python3.6/dist-packages/torchvision/datasets/mnist.py:469: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:141.)
return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
We replace the data processing pipeline of the train and test loops with train_loader and test_loader, and that’s it!
[10]:
# Train loop
def train_model_with_torch_data_api(model):
    def predict(model, x):
        """Return the model's softmax predictions."""
        return objax.functional.softmax(model(x, training=False))

    def flatten_image(x):
        """Flatten the image before passing it to the DNN."""
        if isinstance(model, DNNet):
            return objax.functional.flatten(x)
        else:
            return x

    opt = objax.optimizer.Momentum(model.vars())

    # Cross Entropy Loss
    def loss(x, label):
        return objax.functional.loss.cross_entropy_logits_sparse(model(x, training=True), label).mean()

    gv = objax.GradValues(loss, model.vars())

    def train_op(x, label):
        g, v = gv(x, label)  # returns gradients, loss
        opt(lr, g)
        return v

    train_op = objax.Jit(train_op, gv.vars() + opt.vars())

    for epoch in range(epochs):
        avg_loss = 0
        tot_data = 0
        for _, (img, label) in enumerate(train_loader):
            avg_loss += float(train_op(flatten_image(img.numpy()), label.numpy())[0]) * len(img)
            tot_data += len(img)
        avg_loss /= tot_data

        # Eval
        accuracy = 0
        tot_data = 0
        for _, (img, label) in enumerate(test_loader):
            accuracy += (np.argmax(predict(model, flatten_image(img.numpy())), axis=1) == label.numpy()).sum()
            tot_data += len(img)
        accuracy /= tot_data

        print('Epoch %04d Loss %.2f Accuracy %.2f' % (epoch + 1, avg_loss, 100 * accuracy))
Training the DNN Model with PyTorch data API¶
[11]:
dnn_layer_sizes = 3072, 128, 10
dnn_model = DNNet(dnn_layer_sizes, objax.functional.leaky_relu)
train_model_with_torch_data_api(dnn_model)
Epoch 0001 Loss 2.57 Accuracy 34.35
Epoch 0002 Loss 1.93 Accuracy 58.51
Epoch 0003 Loss 1.32 Accuracy 68.46
Epoch 0004 Loss 0.83 Accuracy 80.95
Epoch 0005 Loss 0.62 Accuracy 84.74
Epoch 0006 Loss 0.53 Accuracy 86.53
Epoch 0007 Loss 0.48 Accuracy 84.18
Epoch 0008 Loss 0.45 Accuracy 88.42
Epoch 0009 Loss 0.42 Accuracy 87.34
Epoch 0010 Loss 0.40 Accuracy 89.29
Epoch 0011 Loss 0.39 Accuracy 89.31
Epoch 0012 Loss 0.38 Accuracy 89.86
Epoch 0013 Loss 0.37 Accuracy 89.91
Epoch 0014 Loss 0.36 Accuracy 86.94
Epoch 0015 Loss 0.36 Accuracy 89.89
Epoch 0016 Loss 0.35 Accuracy 90.12
Epoch 0017 Loss 0.34 Accuracy 90.40
Epoch 0018 Loss 0.34 Accuracy 90.31
Epoch 0019 Loss 0.34 Accuracy 90.79
Epoch 0020 Loss 0.33 Accuracy 90.71
Epoch 0021 Loss 0.33 Accuracy 90.70
Epoch 0022 Loss 0.33 Accuracy 90.69
Epoch 0023 Loss 0.33 Accuracy 90.91
Epoch 0024 Loss 0.32 Accuracy 90.92
Epoch 0025 Loss 0.32 Accuracy 91.06
Epoch 0026 Loss 0.32 Accuracy 91.19
Epoch 0027 Loss 0.32 Accuracy 91.31
Epoch 0028 Loss 0.31 Accuracy 91.31
Epoch 0029 Loss 0.31 Accuracy 91.20
Epoch 0030 Loss 0.31 Accuracy 91.31
Epoch 0031 Loss 0.31 Accuracy 91.36
Epoch 0032 Loss 0.31 Accuracy 91.42
Epoch 0033 Loss 0.30 Accuracy 91.27
Epoch 0034 Loss 0.31 Accuracy 91.47
Epoch 0035 Loss 0.30 Accuracy 91.57
Epoch 0036 Loss 0.30 Accuracy 91.44
Epoch 0037 Loss 0.30 Accuracy 91.55
Epoch 0038 Loss 0.30 Accuracy 91.56
Epoch 0039 Loss 0.29 Accuracy 91.75
Epoch 0040 Loss 0.29 Accuracy 91.69
Epoch 0041 Loss 0.29 Accuracy 91.60
Epoch 0042 Loss 0.29 Accuracy 91.77
Epoch 0043 Loss 0.29 Accuracy 91.76
Epoch 0044 Loss 0.29 Accuracy 91.84
Epoch 0045 Loss 0.28 Accuracy 92.05
Epoch 0046 Loss 0.28 Accuracy 91.78
Epoch 0047 Loss 0.28 Accuracy 92.01
Epoch 0048 Loss 0.28 Accuracy 91.95
Epoch 0049 Loss 0.28 Accuracy 90.11
Epoch 0050 Loss 0.28 Accuracy 92.14
Epoch 0051 Loss 0.28 Accuracy 92.03
Epoch 0052 Loss 0.27 Accuracy 92.29
Epoch 0053 Loss 0.27 Accuracy 92.17
Epoch 0054 Loss 0.27 Accuracy 92.12
Epoch 0055 Loss 0.27 Accuracy 92.34
Epoch 0056 Loss 0.27 Accuracy 92.32
Epoch 0057 Loss 0.27 Accuracy 92.47
Epoch 0058 Loss 0.27 Accuracy 92.38
Epoch 0059 Loss 0.27 Accuracy 92.39
Epoch 0060 Loss 0.26 Accuracy 92.51
Epoch 0061 Loss 0.27 Accuracy 92.50
Epoch 0062 Loss 0.26 Accuracy 92.46
Epoch 0063 Loss 0.26 Accuracy 92.65
Epoch 0064 Loss 0.26 Accuracy 92.57
Epoch 0065 Loss 0.26 Accuracy 92.63
Epoch 0066 Loss 0.26 Accuracy 92.75
Epoch 0067 Loss 0.26 Accuracy 92.57
Epoch 0068 Loss 0.26 Accuracy 92.88
Epoch 0069 Loss 0.25 Accuracy 92.53
Epoch 0070 Loss 0.25 Accuracy 92.80
Epoch 0071 Loss 0.25 Accuracy 92.71
Epoch 0072 Loss 0.25 Accuracy 92.75
Epoch 0073 Loss 0.25 Accuracy 92.84
Epoch 0074 Loss 0.25 Accuracy 92.71
Epoch 0075 Loss 0.25 Accuracy 92.95
Epoch 0076 Loss 0.25 Accuracy 92.82
Epoch 0077 Loss 0.25 Accuracy 92.90
Epoch 0078 Loss 0.25 Accuracy 92.87
Epoch 0079 Loss 0.25 Accuracy 89.55
Epoch 0080 Loss 0.25 Accuracy 92.86
Epoch 0081 Loss 0.24 Accuracy 92.99
Epoch 0082 Loss 0.24 Accuracy 93.03
Epoch 0083 Loss 0.24 Accuracy 93.03
Epoch 0084 Loss 0.24 Accuracy 93.01
Epoch 0085 Loss 0.24 Accuracy 93.13
Epoch 0086 Loss 0.24 Accuracy 93.17
Epoch 0087 Loss 0.24 Accuracy 92.87
Epoch 0088 Loss 0.24 Accuracy 92.93
Epoch 0089 Loss 0.24 Accuracy 93.16
Epoch 0090 Loss 0.24 Accuracy 93.38
Epoch 0091 Loss 0.24 Accuracy 92.98
Epoch 0092 Loss 0.24 Accuracy 93.30
Epoch 0093 Loss 0.23 Accuracy 93.09
Epoch 0094 Loss 0.23 Accuracy 93.19
Epoch 0095 Loss 0.23 Accuracy 93.25
Epoch 0096 Loss 0.23 Accuracy 93.22
Epoch 0097 Loss 0.23 Accuracy 93.28
Epoch 0098 Loss 0.23 Accuracy 93.39
Epoch 0099 Loss 0.23 Accuracy 93.25
Epoch 0100 Loss 0.23 Accuracy 93.30
Training the ConvNet Model with PyTorch data API¶
[12]:
cnn_model = ConvNet(nin=3, nclass=10)
print(cnn_model.vars())
train_model_with_torch_data_api(cnn_model)
(ConvNet).conv_block1(Sequential)[0](Conv2D).w 432 (3, 3, 3, 16)
(ConvNet).conv_block1(Sequential)[1](BatchNorm2D).running_mean 16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[1](BatchNorm2D).running_var 16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[1](BatchNorm2D).beta 16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[1](BatchNorm2D).gamma 16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[3](Conv2D).w 2304 (3, 3, 16, 16)
(ConvNet).conv_block1(Sequential)[4](BatchNorm2D).running_mean 16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[4](BatchNorm2D).running_var 16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[4](BatchNorm2D).beta 16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[4](BatchNorm2D).gamma 16 (1, 16, 1, 1)
(ConvNet).conv_block2(Sequential)[0](Conv2D).w 4608 (3, 3, 16, 32)
(ConvNet).conv_block2(Sequential)[1](BatchNorm2D).running_mean 32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[1](BatchNorm2D).running_var 32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[1](BatchNorm2D).beta 32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[1](BatchNorm2D).gamma 32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[3](Conv2D).w 9216 (3, 3, 32, 32)
(ConvNet).conv_block2(Sequential)[4](BatchNorm2D).running_mean 32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[4](BatchNorm2D).running_var 32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[4](BatchNorm2D).beta 32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[4](BatchNorm2D).gamma 32 (1, 32, 1, 1)
(ConvNet).conv_block3(Sequential)[0](Conv2D).w 18432 (3, 3, 32, 64)
(ConvNet).conv_block3(Sequential)[1](BatchNorm2D).running_mean 64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[1](BatchNorm2D).running_var 64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[1](BatchNorm2D).beta 64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[1](BatchNorm2D).gamma 64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[3](Conv2D).w 36864 (3, 3, 64, 64)
(ConvNet).conv_block3(Sequential)[4](BatchNorm2D).running_mean 64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[4](BatchNorm2D).running_var 64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[4](BatchNorm2D).beta 64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[4](BatchNorm2D).gamma 64 (1, 64, 1, 1)
(ConvNet).linear(Linear).b 10 (10,)
(ConvNet).linear(Linear).w 640 (64, 10)
+Total(32) 73402
Epoch 0001 Loss 0.26 Accuracy 24.18
Epoch 0002 Loss 0.05 Accuracy 37.53
Epoch 0003 Loss 0.03 Accuracy 42.17
Epoch 0004 Loss 0.03 Accuracy 73.50
Epoch 0005 Loss 0.02 Accuracy 80.33
Epoch 0006 Loss 0.02 Accuracy 83.28
Epoch 0007 Loss 0.02 Accuracy 90.87
Epoch 0008 Loss 0.01 Accuracy 98.77
Epoch 0009 Loss 0.01 Accuracy 98.42
Epoch 0010 Loss 0.01 Accuracy 98.16
Epoch 0011 Loss 0.01 Accuracy 98.74
Epoch 0012 Loss 0.01 Accuracy 95.05
Epoch 0013 Loss 0.01 Accuracy 98.89
Epoch 0014 Loss 0.00 Accuracy 98.70
Epoch 0015 Loss 0.00 Accuracy 99.01
Epoch 0016 Loss 0.00 Accuracy 98.97
Epoch 0017 Loss 0.00 Accuracy 98.79
Epoch 0018 Loss 0.00 Accuracy 98.37
Epoch 0019 Loss 0.00 Accuracy 99.19
Epoch 0020 Loss 0.00 Accuracy 99.22
Epoch 0021 Loss 0.00 Accuracy 98.43
Epoch 0022 Loss 0.00 Accuracy 99.02
Epoch 0023 Loss 0.00 Accuracy 99.38
Epoch 0024 Loss 0.00 Accuracy 99.42
Epoch 0025 Loss 0.00 Accuracy 99.45
Epoch 0026 Loss 0.00 Accuracy 99.35
Epoch 0027 Loss 0.00 Accuracy 99.42
Epoch 0028 Loss 0.00 Accuracy 99.42
Epoch 0029 Loss 0.00 Accuracy 99.14
Epoch 0030 Loss 0.00 Accuracy 99.33
Epoch 0031 Loss 0.00 Accuracy 99.36
Epoch 0032 Loss 0.00 Accuracy 99.18
Epoch 0033 Loss 0.00 Accuracy 99.43
Epoch 0034 Loss 0.00 Accuracy 99.47
Epoch 0035 Loss 0.00 Accuracy 99.49
Epoch 0036 Loss 0.00 Accuracy 99.53
Epoch 0037 Loss 0.00 Accuracy 99.38
Epoch 0038 Loss 0.00 Accuracy 99.39
Epoch 0039 Loss 0.00 Accuracy 99.49
Epoch 0040 Loss 0.00 Accuracy 99.49
Epoch 0041 Loss 0.00 Accuracy 99.47
Epoch 0042 Loss 0.00 Accuracy 99.54
Epoch 0043 Loss 0.00 Accuracy 99.35
Epoch 0044 Loss 0.00 Accuracy 99.45
Epoch 0045 Loss 0.00 Accuracy 99.47
Epoch 0046 Loss 0.00 Accuracy 99.53
Epoch 0047 Loss 0.00 Accuracy 99.50
Epoch 0048 Loss 0.00 Accuracy 99.52
Epoch 0049 Loss 0.00 Accuracy 99.51
Epoch 0050 Loss 0.00 Accuracy 99.49
Epoch 0051 Loss 0.00 Accuracy 99.45
Epoch 0052 Loss 0.00 Accuracy 99.48
Epoch 0053 Loss 0.00 Accuracy 99.50
Epoch 0054 Loss 0.00 Accuracy 99.46
Epoch 0055 Loss 0.00 Accuracy 99.50
Epoch 0056 Loss 0.00 Accuracy 99.48
Epoch 0057 Loss 0.00 Accuracy 99.46
Epoch 0058 Loss 0.00 Accuracy 99.44
Epoch 0059 Loss 0.00 Accuracy 99.46
Epoch 0060 Loss 0.00 Accuracy 99.26
Epoch 0061 Loss 0.00 Accuracy 93.99
Epoch 0062 Loss 0.00 Accuracy 97.80
Epoch 0063 Loss 0.00 Accuracy 80.26
Epoch 0064 Loss 0.00 Accuracy 99.20
Epoch 0065 Loss 0.00 Accuracy 99.38
Epoch 0066 Loss 0.00 Accuracy 99.44
Epoch 0067 Loss 0.00 Accuracy 99.51
Epoch 0068 Loss 0.00 Accuracy 99.45
Epoch 0069 Loss 0.00 Accuracy 99.42
Epoch 0070 Loss 0.00 Accuracy 99.50
Epoch 0071 Loss 0.00 Accuracy 99.52
Epoch 0072 Loss 0.00 Accuracy 99.44
Epoch 0073 Loss 0.00 Accuracy 99.41
Epoch 0074 Loss 0.00 Accuracy 99.46
Epoch 0075 Loss 0.00 Accuracy 99.42
Epoch 0076 Loss 0.00 Accuracy 99.49
Epoch 0077 Loss 0.00 Accuracy 99.50
Epoch 0078 Loss 0.00 Accuracy 99.56
Epoch 0079 Loss 0.00 Accuracy 99.52
Epoch 0080 Loss 0.00 Accuracy 99.42
Epoch 0081 Loss 0.00 Accuracy 99.49
Epoch 0082 Loss 0.00 Accuracy 99.48
Epoch 0083 Loss 0.00 Accuracy 99.44
Epoch 0084 Loss 0.00 Accuracy 99.49
Epoch 0085 Loss 0.00 Accuracy 99.53
Epoch 0086 Loss 0.00 Accuracy 99.52
Epoch 0087 Loss 0.00 Accuracy 99.52
Epoch 0088 Loss 0.00 Accuracy 99.50
Epoch 0089 Loss 0.00 Accuracy 99.51
Epoch 0090 Loss 0.00 Accuracy 99.50
Epoch 0091 Loss 0.00 Accuracy 99.49
Epoch 0092 Loss 0.00 Accuracy 99.52
Epoch 0093 Loss 0.00 Accuracy 99.50
Epoch 0094 Loss 0.00 Accuracy 99.50
Epoch 0095 Loss 0.00 Accuracy 99.55
Epoch 0096 Loss 0.00 Accuracy 99.48
Epoch 0097 Loss 0.00 Accuracy 99.51
Epoch 0098 Loss 0.00 Accuracy 99.52
Epoch 0099 Loss 0.00 Accuracy 99.49
Epoch 0100 Loss 0.00 Accuracy 99.52
What’s Next¶
We have learned how to use existing models and define new models to classify MNIST. Next, you can read one or more of the in-depth topics or browse through Objax’s APIs.
Code Examples¶
This section describes the code examples found in objax/examples.
Classification¶
Image¶
Example code is available at examples/classify.
examples/classify/img/logistic.py¶
Train and evaluate a logistic regression model for binary classification on the horses or humans dataset.
# Run command
python3 examples/classify/img/logistic.py
Data      | horses_or_humans from tensorflow_datasets
Network   | Custom single layer
Loss      |
Optimizer |
Accuracy  | ~77%
Hardware  | CPU or GPU or TPU
examples/classify/img/mnist_dnn.py¶
Train and evaluate a DNNet model for multiclass classification on the MNIST dataset.
# Run command
python3 examples/classify/img/mnist_dnn.py
Data       | MNIST from tensorflow_datasets
Network    | Deep Neural Net
Loss       |
Optimizer  |
Accuracy   | ~98%
Hardware   | CPU or GPU or TPU
Techniques | Model weight averaging for improved accuracy
examples/classify/img/mnist_cnn.py¶
Train and evaluate a simple custom CNN model for multiclass classification on the MNIST dataset.
# Run command
python3 examples/classify/img/mnist_cnn.py
Data       | MNIST from tensorflow_datasets
Network    | Custom Convolutional Neural Net
Loss       |
Optimizer  |
Accuracy   | ~99.5%
Hardware   | CPU or GPU or TPU
Techniques |
examples/classify/img/mnist_dp.py¶
Train and evaluate a ConvNet model on the MNIST dataset with differential privacy.
# Run command
python3 examples/classify/img/mnist_dp.py
# See available options with
python3 examples/classify/img/mnist_dp.py --help
Data       | MNIST from tensorflow_datasets
Network    | Custom Convolutional Neural Net
Loss       |
Optimizer  |
Accuracy   |
Hardware   | GPU
Techniques |
examples/classify/img/cifar10_simple.py¶
Train and evaluate a Wide ResNet model for multiclass classification on the CIFAR10 dataset.
# Run command
python3 examples/classify/img/cifar10_simple.py
Data       | CIFAR10 from tf.keras.datasets
Network    | Wide ResNet
Loss       |
Optimizer  |
Accuracy   | ~91%
Hardware   | GPU or TPU
Techniques |
examples/classify/img/cifar10_advanced.py¶
Train and evaluate ConvNet models for multiclass classification on the CIFAR10 dataset.
# Run command
python3 examples/classify/img/cifar10_advanced.py
# Run with custom settings
python3 examples/classify/img/cifar10_advanced.py --weight_decay=0.0001 --batch=64 --lr=0.03 --epochs=256
# See available options with
python3 examples/classify/img/cifar10_advanced.py --help
Data |
|
Network |
Configurable with |
Loss |
|
Optimizer |
|
Accuracy |
~94% |
Hardware |
GPU, Multi-GPU or TPU |
Techniques |
|
examples/classify/img/imagenet/imagenet_train.py¶
Train and evaluate a ResNet50 model on the ImageNet dataset.
See examples/classify/img/imagenet/README.md for additional information.
Data |
Network |
Loss |
Optimizer |
Accuracy |
Hardware | GPU, Multi-GPU or TPU
Techniques |
Semi-Supervised Learning¶
Example code is available at examples/semi_supervised.
examples/semi_supervised/img/fixmatch.py¶
Semi-supervised learning of image classification models with FixMatch.
# Run command
python3 examples/classify/semi_supervised/img/fixmatch.py
# Run with custom settings
python3 examples/classify/semi_supervised/img/fixmatch.py --dataset=cifar10.3@1000-0
# See available options with
python3 examples/classify/semi_supervised/img/fixmatch.py --help
Data |
Network | Custom implementation of Wide ResNet.
Loss |
Optimizer |
Accuracy | See paper
Hardware | GPU, Multi-GPU, TPU
Techniques |
GPT-2¶
Example code is available at examples/gpt-2.
examples/gpt-2/gpt2.py¶
Load a pretrained GPT-2 model (124M parameters) and demonstrate how to use the model to generate a text sequence.
See examples/gpt-2/README.md for additional information.
Hardware | GPU or TPU
Techniques |
RNN¶
Example code is available at examples/rnn.
examples/rnn/shakespeare.py¶
Train and evaluate a vanilla RNN model on the Shakespeare corpus dataset.
See examples/rnn/README.md for additional information.
# Run command
python3 examples/rnn/shakespeare.py
Data |
Network | Custom implementation of vanilla RNN.
Loss |
Optimizer |
Hardware | GPU or TPU
Techniques |
Optimization¶
Example code is available at examples/optimization.
Objax API¶
objax package¶
Modules¶
Module | A module is a container to associate variables and functions.
ModuleList | This is a replacement for Python’s list that provides a vars() method to return all the variables that it contains, including the ones contained in the modules and sub-modules in it.
GradValues | The GradValues module is used to compute the gradients of a function.
Jit | JIT (Just-In-Time) module takes a function or a module and compiles it for faster execution.
Parallel | Parallel module takes a function or a module and compiles it for running on multiple devices in parallel.
Vectorize | Vectorize module takes a function or a module and compiles it for running in parallel on a single device.
class objax.Module[source]¶
A module is a container to associate variables and functions.
vars(scope='')[source]¶
Collect all the variables (and their names) contained in the module and its submodules. Important: Variables and modules stored in Python structures such as dict or list are not collected. See ModuleList if you need such a feature.
- Parameters
scope (str) – string to prefix to the variable names.
- Returns
A VarCollection of all the variables.
- Return type
objax.variable.VarCollection
-
-
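Usage example (a minimal sketch; the MyNet class, its layer sizes, and the input shapes are illustrative, not part of the API):
import objax
import jax.numpy as jn

class MyNet(objax.Module):
    def __init__(self, nin, nout):
        # Parameters are stored as attributes of the module.
        self.linear = objax.nn.Linear(nin, nout)

    def __call__(self, x):
        # Inputs are passed as call arguments.
        return objax.functional.relu(self.linear(x))

net = MyNet(4, 2)
print(net.vars())         # Lists (MyNet).linear.b and (MyNet).linear.w with their shapes.
y = net(jn.ones((3, 4)))  # Forward pass on a batch of 3 examples.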
class objax.ModuleList(iterable=(), /)[source]¶
Bases: objax.module.Module, list
This is a replacement for Python’s list that provides a vars() method to return all the variables that it contains, including the ones contained in the modules and sub-modules in it.
Usage example:
import objax

ml = objax.ModuleList(['hello', objax.TrainVar(objax.random.normal((10,2)))])
print(ml.vars())
# (ModuleList)[1]            20 (10, 2)
# +Total(1)                  20

ml.pop()
ml.append(objax.nn.Linear(2, 3))
print(ml.vars())
# (ModuleList)[1](Linear).b   3 (3,)
# (ModuleList)[1](Linear).w   6 (2, 3)
# +Total(2)                   9
class objax.GradValues(f, variables, input_argnums=None)[source]¶
The GradValues module is used to compute the gradients of a function.
Usage example:
import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])

def f(x, y):
    return ((m(x) - y) ** 2).mean()

# Create module to compute gradients of f for m.vars()
grad_val_f = objax.GradValues(f, m.vars())
# Create module to compute gradients of f for input 0 (x) and m.vars()
grad_val_fx = objax.GradValues(f, m.vars(), input_argnums=(0,))
For more information and examples, see Understanding Gradients.
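Calling the resulting module returns both the gradients and the value(s) computed by f. The snippet below is a hedged sketch continuing the example above; the input shapes are illustrative:
x = objax.random.normal((4, 2))
y = objax.random.normal((4, 3))
g, v = grad_val_f(x, y)  # g: list of gradients, one per variable in m.vars(); v: value(s) returned by f.
for name, grad in zip(m.vars().keys(), g):
    print(name, grad.shape)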
__init__(f, variables, input_argnums=None)[source]¶
Constructs an instance to compute the gradient of f w.r.t. variables.
- Parameters
f (Callable) – the function for which to compute gradients.
variables (Optional[objax.variable.VarCollection]) – the variables for which to compute gradients.
input_argnums (Optional[Tuple[int, ..]]) – input indexes, if any, on which to compute gradients.
-
-
class objax.Jit(f, vc=None, static_argnums=None)[source]¶
JIT (Just-In-Time) module takes a function or a module and compiles it for faster execution.
Usage example:
import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])
jit_m = objax.Jit(m)                         # Jit a module
jit_f = objax.Jit(lambda x: m(x), m.vars())  # Jit a function: provide vars it uses
For more information, refer to JIT Compilation.
__init__(f, vc=None, static_argnums=None)[source]¶
Jit constructor.
- Parameters
f (Union[objax.module.Module, Callable]) – the function or the module to compile.
vc (Optional[objax.variable.VarCollection]) – the VarCollection of variables used by the function or module. This argument is required for functions.
static_argnums (Optional[Tuple[int, ..]]) – tuple of indexes of f’s input arguments to treat as static (constants). A new graph is compiled for each different combination of values for such inputs.
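For example, marking a Python argument as static compiles one graph per distinct value of that argument (a minimal sketch; the scale function and its values are illustrative):
import objax
import jax.numpy as jn

m = objax.nn.Linear(3, 3)

def scale(x, factor):
    return m(x) * factor

# factor is treated as a compile-time constant: a new graph is compiled per distinct value.
jit_scale = objax.Jit(scale, m.vars(), static_argnums=(1,))
y2 = jit_scale(jn.ones((4, 3)), 2.0)
y3 = jit_scale(jn.ones((4, 3)), 3.0)  # Triggers a second compilation.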
vars(scope='')¶
Collect all the variables (and their names) contained in the VarCollection.
- Parameters
scope (str) – string to prefix to the variable names.
- Returns
A VarCollection of all the variables.
- Return type
objax.variable.VarCollection
-
-
class objax.Parallel(f, vc=None, reduce=<function concatenate>, axis_name='device', static_argnums=None)[source]¶
Parallel module takes a function or a module and compiles it for running on multiple devices in parallel.
Usage example:
import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])
para_m = objax.Parallel(m)                         # Parallelize a module
para_f = objax.Parallel(lambda x: m(x), m.vars())  # Parallelize a function: provide vars it uses
When calling a parallelized module, one must replicate the variables it uses on all devices:
x = objax.random.normal((16, 2))
with m.vars().replicate():
    y = para_m(x)
For more information, refer to Parallelism.
__init__(f, vc=None, reduce=<function concatenate>, axis_name='device', static_argnums=None)[source]¶
Parallel constructor.
- Parameters
f (Union[objax.module.Module, Callable]) – the function or the module to compile for parallelism.
vc (Optional[objax.variable.VarCollection]) – the VarCollection of variables used by the function or module. This argument is required for functions.
reduce (Callable[[Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]]) – the function used to reduce the outputs from many devices to a single device value.
axis_name (str) – what name to give to the device dimension, used in conjunction with objax.functional.parallel.
static_argnums (Optional[Tuple[int, ..]]) – tuple of indexes of f’s input arguments to treat as static (constants). A new graph is compiled for each different combination of values for such inputs.
__call__(*args)[source]¶
Call the compiled function or module on multiple devices in parallel. Important: Make sure you call this function within the scope of a VarCollection.replicate() statement.
vars(scope='')¶
Collect all the variables (and their names) contained in the VarCollection.
- Parameters
scope (str) – string to prefix to the variable names.
- Returns
A VarCollection of all the variables.
- Return type
objax.variable.VarCollection
-
-
class objax.Vectorize(f, vc=None, batch_axis=0)[source]¶
Vectorize module takes a function or a module and compiles it for running in parallel on a single device.
Usage example:
import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])
vec_m = objax.Vectorize(m)                         # Vectorize a module
vec_f = objax.Vectorize(lambda x: m(x), m.vars())  # Vectorize a function: provide vars it uses
For more information and examples, refer to Vectorization.
__init__(f, vc=None, batch_axis=0)[source]¶
Vectorize constructor.
- Parameters
f (Union[objax.module.Module, Callable]) – the function or the module to compile for vectorization.
vc (Optional[objax.variable.VarCollection]) – the VarCollection of variables used by the function or module. This argument is required for functions.
batch_axis (Tuple[Optional[int], ..]) – tuple of int or None for each of f’s input arguments: the axis to use as batch during vectorization. Use None to automatically broadcast.
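For example, a function written for single examples can be mapped over a batch by specifying the batch axis of each argument (a minimal sketch; the distance function and shapes are illustrative):
import objax
import jax.numpy as jn

def distance(a, b):
    # a and b are single vectors of shape (3,).
    return jn.sqrt(((a - b) ** 2).sum())

# Map over axis 0 of both arguments; an empty VarCollection since the function uses no variables.
vec_distance = objax.Vectorize(distance, objax.VarCollection(), batch_axis=(0, 0))
d = vec_distance(jn.ones((8, 3)), jn.zeros((8, 3)))  # Shape (8,).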
vars(scope='')¶
Collect all the variables (and their names) contained in the VarCollection.
- Parameters
scope (str) – string to prefix to the variable names.
- Returns
A VarCollection of all the variables.
- Return type
objax.variable.VarCollection
-
Variables¶
BaseVar | The abstract base class to represent objax variables.
TrainVar | A trainable variable.
BaseState | The abstract base class used to represent objax state variables.
StateVar | StateVar are variables that get updated manually, and are not automatically updated by optimizers.
TrainRef | A state variable that references a trainable variable for assignment.
RandomState | RandomState are variables that track the random generator state.
VarCollection | A VarCollection is a dictionary (name, var) with some additional methods to make manipulation of collections of variables easy.
class objax.BaseVar(reduce)[source]¶
The abstract base class to represent objax variables.
__init__(reduce)[source]¶
Constructor for BaseVar class.
- Parameters
reduce (Optional[Callable[[Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]]]) – a function that takes an array of shape (n, *dims) and returns one of shape (*dims). Used to combine the multiple states produced in an objax.Vectorize or an objax.Parallel call.
-
-
class objax.TrainVar(tensor, reduce=<function reduce_mean>)[source]¶
A trainable variable.
__init__(tensor, reduce=<function reduce_mean>)[source]¶
TrainVar constructor.
- Parameters
tensor (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – the initial value of the TrainVar.
reduce (Optional[Callable[[Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]]]) – a function that takes an array of shape (n, *dims) and returns one of shape (*dims). Used to combine the multiple states produced in an objax.Vectorize or an objax.Parallel call.
property value¶
The value is read only as a safety measure to avoid accidentally making TrainVar non-differentiable. You can write a value to a TrainVar by using assign.
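A short sketch of reading and writing a TrainVar (the shapes are illustrative):
import objax
import jax.numpy as jn

w = objax.TrainVar(jn.zeros(3))
print(w.value)           # Read the current tensor.
w.assign(w.value + 1.0)  # Write a new value through assign; direct writes to .value are not allowed.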
assign(tensor)[source]¶
Sets the value of the variable.
- Parameters
tensor (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) –
reduce(tensors)¶
Method called by Parallel and Vectorize to reduce a multiple-device (or batched in case of vectorization) value to a single device.
- Parameters
tensors (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) –
-
-
class objax.BaseState(reduce)[source]¶
The abstract base class used to represent objax state variables. State variables are not trainable.
__init__(reduce)¶
Constructor for BaseVar class.
- Parameters
reduce (Optional[Callable[[Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]]]) – a function that takes an array of shape (n, *dims) and returns one of shape (*dims). Used to combine the multiple states produced in an objax.Vectorize or an objax.Parallel call.
-
-
class objax.TrainRef(ref)[source]¶
A state variable that references a trainable variable for assignment.
TrainRef are used by optimizers to keep references to trainable variables. This is necessary to differentiate them from the optimizer’s own training variables, if any.
__init__(ref)[source]¶
TrainRef constructor.
- Parameters
ref (objax.variable.TrainVar) – the TrainVar to keep the reference of.
property value¶
The value stored in the referenced TrainVar; it can be read or written.
assign(tensor)¶
Sets the value of the variable.
- Parameters
tensor (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) –
reduce(tensors)¶
Method called by Parallel and Vectorize to reduce a multiple-device (or batched in case of vectorization) value to a single device.
- Parameters
tensors (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) –
-
-
class objax.StateVar(tensor, reduce=<function reduce_mean>)[source]¶
StateVar are variables that get updated manually, and are not automatically updated by optimizers. For example, the mean and variance statistics in BatchNorm are StateVar.
__init__(tensor, reduce=<function reduce_mean>)[source]¶
StateVar constructor.
- Parameters
tensor (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – the initial value of the StateVar.
reduce (Optional[Callable[[Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]]]) – a function that takes an array of shape (n, *dims) and returns one of shape (*dims). Used to combine the multiple states produced in an objax.Vectorize or an objax.Parallel call.
property value¶
The value stored in the StateVar; it can be read or written.
assign(tensor)¶
Sets the value of the variable.
- Parameters
tensor (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) –
reduce(tensors)¶
Method called by Parallel and Vectorize to reduce a multiple-device (or batched in case of vectorization) value to a single device.
- Parameters
tensors (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) –
-
-
class objax.RandomState(seed)[source]¶
RandomState are variables that track the random generator state. They are meant to be used internally. Currently only the random.Generator module uses them.
__init__(seed)[source]¶
RandomState constructor.
- Parameters
seed (int) – the initial seed of the random number generator.
seed(seed)[source]¶
Sets a new random seed.
- Parameters
seed (int) – the new initial seed of the random number generator.
split(n)[source]¶
Create multiple seeds from the current seed. This is used internally by Parallel and Vectorize to ensure that random numbers are different in parallel threads.
- Parameters
n (int) – the number of seeds to generate.
- Return type
List[Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]]
assign(tensor)¶
Sets the value of the variable.
- Parameters
tensor (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) –
reduce(tensors)¶
Method called by Parallel and Vectorize to reduce a multiple-device (or batched in case of vectorization) value to a single device.
- Parameters
tensors (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) –
property value¶
The value stored in the StateVar; it can be read or written.
-
-
class objax.VarCollection[source]¶
A VarCollection is a dictionary (name, var) with some additional methods to make manipulation of collections of variables easy. A VarCollection is ordered by insertion order. It is the object returned by Module.vars() and used as input by many modules: optimizers, Jit, etc…
Usage example:
import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu])
vc = m.vars()  # This is a VarCollection

# It is a dictionary
print(repr(vc))
# {'(Sequential)[0](Linear).b': <objax.variable.TrainVar object at 0x7faecb506390>,
#  '(Sequential)[0](Linear).w': <objax.variable.TrainVar object at 0x7faec81ee350>}
print(vc.keys())
# dict_keys(['(Sequential)[0](Linear).b', '(Sequential)[0](Linear).w'])
assert (vc['(Sequential)[0](Linear).w'].value == m[0].w.value).all()

# Convenience print
print(vc)
# (Sequential)[0](Linear).b        3 (3,)
# (Sequential)[0](Linear).w        6 (2, 3)
# +Total(2)                        9

# Extra methods for manipulation of variables:
# For example, increment all variables by 1
vc.assign([x+1 for x in vc.tensors()])

# It's used by other modules.
# For example it's used to tell Jit what variables are used by a function.
jit_f = objax.Jit(lambda x: m(x), vc)
For more information and examples, refer to VarCollection.
update(other)[source]¶
Overload dict.update method to catch potential conflicts during assignment.
- Parameters
other (Union[VarCollection, Iterable[Tuple[str, objax.variable.BaseVar]]]) –
assign(tensors)[source]¶
Assign tensors to the variables in the VarCollection. Each variable is assigned only once and in the order following the iter(self) iterator.
- Parameters
tensors (List[Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]]) – the list of tensors used to update variables values.
replicate()[source]¶
A context manager to use in a with statement that replicates the variables in this collection to multiple devices. This is typically used prior to a call to objax.Parallel, so that all variables have a copy on each device. Important: replicating also updates the random state in order to have a new one per device.
subset(is_a=None, is_not=None)[source]¶
Return a new VarCollection that is a filtered subset of the current collection.
- Parameters
is_a (Optional[Union[type, Tuple[type, ..]]]) – either a variable type or a list of variables types to include.
is_not (Optional[Union[type, Tuple[type, ..]]]) – either a variable type or a list of variables types to exclude.
- Returns
A new VarCollection containing the subset of variables.
- Return type
objax.variable.VarCollection
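A common use of subset is separating trainable variables from state variables (a minimal sketch):
import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.nn.BatchNorm0D(3)])
train_vars = m.vars().subset(is_a=objax.TrainVar)    # Only TrainVar entries.
state_vars = m.vars().subset(is_not=objax.TrainVar)  # Everything else, e.g. BatchNorm statistics.
print(len(train_vars), len(state_vars))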
tensors(is_a=None)[source]¶
Return the list of values for this collection. Similarly to the assign method, each variable value is reported only once and in the order following the iter(self) iterator.
- Parameters
is_a (Optional[Union[type, Tuple[type, ..]]]) – either a variable type or a list of variables types to include.
- Returns
The list of tensor values for the (optionally filtered) variables in the collection.
- Return type
List[Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]]
-
objax.functional package¶
objax.functional¶
Due to the large number of APIs in this section, we organized it into the following sub-sections:
celu | Continuously-differentiable exponential linear unit activation.
elu | Exponential linear unit activation function.
leaky_relu | Leaky rectified linear unit activation function.
log_sigmoid | Log-sigmoid activation function.
log_softmax | Log-Softmax function.
logsumexp | Compute the log of the sum of exponentials of input elements.
relu | Rectified linear unit activation function.
selu | Scaled exponential linear unit activation.
sigmoid | Sigmoid activation function.
softmax | Softmax function.
softplus | Softplus activation function.
tanh | Elementwise hyperbolic tangent: \(\mathrm{tanh}(x)\).
objax.functional.celu(x, alpha=1.0)[source]¶
Continuously-differentiable exponential linear unit activation.
Computes the element-wise function:
\[\begin{split}\mathrm{celu}(x) = \begin{cases} x, & x > 0\\ \alpha \left(\exp(\frac{x}{\alpha}) - 1\right), & x \le 0 \end{cases}\end{split}\]For more information, see Continuously Differentiable Exponential Linear Units.
objax.functional.elu(x, alpha=1.0)[source]¶
Exponential linear unit activation function.
Computes the element-wise function:
\[\begin{split}\mathrm{elu}(x) = \begin{cases} x, & x > 0\\ \alpha \left(\exp(x) - 1\right), & x \le 0 \end{cases}\end{split}\]
objax.functional.leaky_relu(x, negative_slope=0.01)[source]¶
Leaky rectified linear unit activation function.
Computes the element-wise function:
\[\begin{split}\mathrm{leaky\_relu}(x) = \begin{cases} x, & x \ge 0\\ \alpha x, & x < 0 \end{cases}\end{split}\]
where \(\alpha\) = negative_slope.
objax.functional.log_sigmoid(x)[source]¶
Log-sigmoid activation function.
Computes the element-wise function:
\[\mathrm{log\_sigmoid}(x) = \log(\mathrm{sigmoid}(x)) = -\log(1 + e^{-x})\]
objax.functional.log_softmax(x, axis=-1)[source]¶
Log-Softmax function.
Computes the logarithm of the softmax function, which rescales elements to the range \([-\infty, 0)\).
\[\mathrm{log\_softmax}(x) = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)} \right)\]
- Parameters
axis – the axis or axes along which the log_softmax should be computed. Either an integer or a tuple of integers.
objax.functional.logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False)[source]¶
Compute the log of the sum of exponentials of input elements.
LAX-backend implementation of logsumexp(). Original docstring below.
- Parameters
a (array_like) – Input array.
axis (None or int or tuple of ints, optional) – Axis or axes over which the sum is taken. By default axis is None, and all elements are summed.
keepdims (bool, optional) – If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the original array.
b (array-like, optional) – Scaling factor for exp(a) must be of the same shape as a or broadcastable to a. These values may be negative in order to implement subtraction.
return_sign (bool, optional) – If this is set to True, the result will be a pair containing sign information; if False, results that are negative will be returned as NaN. Default is False (no sign information).
- Returns
res (ndarray) – The result, np.log(np.sum(np.exp(a))), calculated in a numerically more stable way. If b is given then np.log(np.sum(b*np.exp(a))) is returned.
sgn (ndarray) – If return_sign is True, this will be an array of floating-point numbers matching res and +1, 0, or -1 depending on the sign of the result. If False, only one result is returned.
See also
numpy.logaddexp()
,numpy.logaddexp2()
Notes
NumPy has a logaddexp function which is very similar to logsumexp, but only handles two arguments. logaddexp.reduce is similar to this function, but may be less stable.
Examples
>>> from scipy.special import logsumexp
>>> a = np.arange(10)
>>> np.log(np.sum(np.exp(a)))
9.4586297444267107
>>> logsumexp(a)
9.4586297444267107
With weights
>>> a = np.arange(10)
>>> b = np.arange(10, 0, -1)
>>> logsumexp(a, b=b)
9.9170178533034665
>>> np.log(np.sum(b*np.exp(a)))
9.9170178533034647
Returning a sign flag
>>> logsumexp([1,2],b=[1,-1],return_sign=True)
(1.5413248546129181, -1.0)
Notice that logsumexp does not directly support masked arrays. To use it on a masked array, convert the mask into zero weights:
>>> a = np.ma.array([np.log(2), 2, np.log(3)],
...                 mask=[False, True, False])
>>> b = (~a.mask).astype(int)
>>> logsumexp(a.data, b=b), np.log(5)
1.6094379124341005, 1.6094379124341005
objax.functional.relu(x)[source]¶
Rectified linear unit activation function.
- Parameters
x (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.
- Returns
tensor with the element-wise output relu(x) = max(x, 0).
- Return type
Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
objax.functional.selu(x)[source]¶
Scaled exponential linear unit activation.
Computes the element-wise function:
\[\begin{split}\mathrm{selu}(x) = \lambda \begin{cases} x, & x > 0\\ \alpha e^x - \alpha, & x \le 0 \end{cases}\end{split}\]where \(\lambda = 1.0507009873554804934193349852946\) and \(\alpha = 1.6732632423543772848170429916717\).
For more information, see Self-Normalizing Neural Networks.
objax.functional.sigmoid(x)[source]¶
Sigmoid activation function.
Computes the element-wise function:
\[\mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}\]
objax.functional.softmax(x, axis=-1)[source]¶
Softmax function.
Computes the function which rescales elements to the range \([0, 1]\) such that the elements along axis sum to \(1\).
\[\mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}\]
- Parameters
axis – the axis or axes along which the softmax should be computed. The softmax output summed across these dimensions should sum to \(1\). Either an integer or a tuple of integers.
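A small numeric sketch relating softmax and log_softmax (values shown in comments are approximate):
import jax.numpy as jn
import objax

x = jn.array([1.0, 2.0, 3.0])
p = objax.functional.softmax(x)       # Approximately [0.09, 0.24, 0.67]; sums to 1.
lp = objax.functional.log_softmax(x)  # Equals jn.log(p) up to numerical precision.
print(p.sum(), jn.allclose(lp, jn.log(p)))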
average_pool_2d | Applies average pooling using a square 2D filter.
batch_to_space2d | Transfer batch dimension N into spatial dimensions (H, W).
channel_to_space2d | Transfer channel dimension C into spatial dimensions (H, W).
max_pool_2d | Applies max pooling using a square 2D filter.
space_to_batch2d | Transfer spatial dimensions (H, W) into batch dimension N.
space_to_channel2d | Transfer spatial dimensions (H, W) into channel dimension C.
objax.functional.average_pool_2d(x, size=2, strides=2, padding=<ConvPadding.VALID: 'VALID'>)[source]¶
Applies average pooling using a square 2D filter.
- Parameters
x (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor of shape (N, C, H, W).
size (Union[Tuple[int, int], int]) – size of pooling filter.
strides (Union[Tuple[int, int], int]) – stride step.
padding (objax.constants.ConvPadding) – type of padding used in pooling operation.
- Returns
output tensor of shape (N, C, H, W).
- Return type
Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
For a definition of pooling, including examples see Pooling Layer.
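A shape-level sketch (input sizes are illustrative):
import jax.numpy as jn
import objax

x = jn.ones((8, 3, 32, 32))                      # (N, C, H, W)
y = objax.functional.average_pool_2d(x, size=2)  # Shape (8, 3, 16, 16) with the default stride of 2.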
objax.functional.batch_to_space2d(x, size=2)[source]¶
Transfer batch dimension N into spatial dimensions (H, W).
- Parameters
x (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor of shape (N, C, H, W).
size (Union[Tuple[int, int], int]) – size of spatial area.
- Returns
output tensor of shape (N // (size[0] * size[1]), C, H * size[0], W * size[1]).
- Return type
Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
objax.functional.channel_to_space2d(x, size=2)[source]¶
Transfer channel dimension C into spatial dimensions (H, W).
- Parameters
x (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor of shape (N, C, H, W).
size (Union[Tuple[int, int], int]) – size of spatial area.
- Returns
output tensor of shape (N, C // (size[0] * size[1]), H * size[0], W * size[1]).
- Return type
Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
objax.functional.max_pool_2d(x, size=2, strides=2, padding=<ConvPadding.VALID: 'VALID'>)[source]¶
Applies max pooling using a square 2D filter.
- Parameters
x (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor of shape (N, C, H, W).
size (Union[Tuple[int, int], int]) – size of pooling filter.
strides (Union[Tuple[int, int], int]) – stride step.
padding (objax.constants.ConvPadding) – type of padding used in pooling operation.
- Returns
output tensor of shape (N, C, H, W).
- Return type
Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
For a definition of pooling, including examples see Pooling Layer.
objax.functional.space_to_batch2d(x, size=2)[source]¶
Transfer spatial dimensions (H, W) into batch dimension N.
- Parameters
x (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor of shape (N, C, H, W).
size (Union[Tuple[int, int], int]) – size of spatial area.
- Returns
output tensor of shape (N * size[0] * size[1], C, H // size[0], W // size[1]).
- Return type
Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
objax.functional.space_to_channel2d(x, size=2)[source]¶
Transfer spatial dimensions (H, W) into channel dimension C.
- Parameters
x (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor of shape (N, C, H, W).
size (Union[Tuple[int, int], int]) – size of spatial area.
- Returns
output tensor of shape (N, C * size[0] * size[1], H // size[0], W // size[1]).
- Return type
Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
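A shape-level sketch of the space/channel transfers (input sizes are illustrative):
import jax.numpy as jn
import objax

x = jn.ones((2, 4, 8, 8))                           # (N, C, H, W)
y = objax.functional.space_to_channel2d(x, size=2)  # Shape (2, 16, 4, 4).
z = objax.functional.channel_to_space2d(y, size=2)  # Back to shape (2, 4, 8, 8).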
dynamic_slice | Wraps XLA’s DynamicSlice operator.
flatten | Flattens input tensor to a 2D tensor.
one_hot | One-hot encodes the given indices.
pad | Pad an array.
stop_gradient | Stops gradient computation.
top_k | Returns top k values and their indices along the last axis of operand.
rsqrt | Elementwise reciprocal square root: \(1 / \sqrt{x}\).
upscale_nn | Nearest neighbor upscale for image batches of shape (N, C, H, W).
objax.functional.dynamic_slice(operand, start_indices, slice_sizes)[source]¶
Wraps XLA’s DynamicSlice operator.
- Parameters
operand (Any) – an array to slice.
start_indices (Sequence[Any]) – a list of scalar indices, one per dimension. These values may be dynamic.
slice_sizes (Sequence[int]) – the size of the slice. Must be a sequence of non-negative integers with length equal to ndim(operand). Inside a JIT compiled function, only static values are supported (all JAX arrays inside JIT must have statically known size).
- Returns
An array containing the slice.
- Return type
Any
objax.functional.flatten(x)[source]¶
Flattens input tensor to a 2D tensor.
- Parameters
x (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor with dimensions (n_1, n_2, …, n_k)
- Returns
The input tensor reshaped to two dimensions (n_1, n_prod), where n_prod is equal to the product of n_2 to n_k.
- Return type
Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
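A shape-level sketch (input sizes are illustrative):
import jax.numpy as jn
import objax

x = jn.ones((32, 3, 8, 8))
y = objax.functional.flatten(x)  # Shape (32, 192), i.e. (n_1, n_2 * n_3 * n_4).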
objax.functional.one_hot(x, num_classes, *, dtype=<class 'jax.numpy.lax_numpy.float64'>)[source]¶
One-hot encodes the given indices.
Each index in the input x is encoded as a vector of zeros of length num_classes with the element at index set to one:
>>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
DeviceArray([[1., 0., 0.],
             [0., 1., 0.],
             [0., 0., 1.]], dtype=float32)
Indices outside the range [0, num_classes) will be encoded as zeros:
>>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
DeviceArray([[0., 0., 0.],
             [0., 0., 0.]], dtype=float32)
- Parameters
x – A tensor of indices.
num_classes – Number of classes in the one-hot dimension.
dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
objax.functional.pad(array, pad_width, mode='constant', constant_values=0)[source]¶
Pad an array.
LAX-backend implementation of pad(). Original docstring below.
- Parameters
array (array_like of rank N) – The array to pad.
pad_width ({sequence, array_like, int}) – Number of values padded to the edges of each axis. ((before_1, after_1), … (before_N, after_N)) unique pad widths for each axis. ((before, after),) yields same before and after pad for each axis. (pad,) or int is a shortcut for before = after = pad width for all axes.
mode (str or function, optional) – One of the following string values or a user supplied function.
constant_values (sequence or scalar, optional) – Used in ‘constant’. The values to set the padded values for each axis.
- Returns
pad – Padded array of rank equal to array with shape increased according to pad_width.
- Return type
ndarray
Notes
New in version 1.7.0.
For an array with rank greater than 1, some of the padding of later axes is calculated from padding of previous axes. This is easiest to think about with a rank 2 array where the corners of the padded array are calculated by using padded values from the first axis.
The padding function, if used, should modify a rank 1 array in-place. It has the following signature:
padding_func(vector, iaxis_pad_width, iaxis, kwargs)
where
- vector : ndarray
A rank 1 array already padded with zeros. Padded values are vector[:iaxis_pad_width[0]] and vector[-iaxis_pad_width[1]:].
- iaxis_pad_width : tuple
A 2-tuple of ints, iaxis_pad_width[0] represents the number of values padded at the beginning of vector where iaxis_pad_width[1] represents the number of values padded at the end of vector.
- iaxis : int
The axis currently being calculated.
- kwargs : dict
Any keyword arguments the function requires.
Examples
>>> a = [1, 2, 3, 4, 5]
>>> np.pad(a, (2, 3), 'constant', constant_values=(4, 6))
array([4, 4, 1, ..., 6, 6, 6])
>>> np.pad(a, (2, 3), 'edge') array([1, 1, 1, ..., 5, 5, 5])
>>> np.pad(a, (2, 3), 'linear_ramp', end_values=(5, -4)) array([ 5, 3, 1, 2, 3, 4, 5, 2, -1, -4])
>>> np.pad(a, (2,), 'maximum') array([5, 5, 1, 2, 3, 4, 5, 5, 5])
>>> np.pad(a, (2,), 'mean') array([3, 3, 1, 2, 3, 4, 5, 3, 3])
>>> np.pad(a, (2,), 'median') array([3, 3, 1, 2, 3, 4, 5, 3, 3])
>>> a = [[1, 2], [3, 4]]
>>> np.pad(a, ((3, 2), (2, 3)), 'minimum')
array([[1, 1, 1, 2, 1, 1, 1],
       [1, 1, 1, 2, 1, 1, 1],
       [1, 1, 1, 2, 1, 1, 1],
       [1, 1, 1, 2, 1, 1, 1],
       [3, 3, 3, 4, 3, 3, 3],
       [1, 1, 1, 2, 1, 1, 1],
       [1, 1, 1, 2, 1, 1, 1]])
>>> a = [1, 2, 3, 4, 5]
>>> np.pad(a, (2, 3), 'reflect')
array([3, 2, 1, 2, 3, 4, 5, 4, 3, 2])
>>> np.pad(a, (2, 3), 'reflect', reflect_type='odd') array([-1, 0, 1, 2, 3, 4, 5, 6, 7, 8])
>>> np.pad(a, (2, 3), 'symmetric') array([2, 1, 1, 2, 3, 4, 5, 5, 4, 3])
>>> np.pad(a, (2, 3), 'symmetric', reflect_type='odd') array([0, 1, 1, 2, 3, 4, 5, 5, 6, 7])
>>> np.pad(a, (2, 3), 'wrap') array([4, 5, 1, 2, 3, 4, 5, 1, 2, 3])
>>> def pad_with(vector, pad_width, iaxis, kwargs):
...     pad_value = kwargs.get('padder', 10)
...     vector[:pad_width[0]] = pad_value
...     vector[-pad_width[1]:] = pad_value
>>> a = np.arange(6)
>>> a = a.reshape((2, 3))
>>> np.pad(a, 2, pad_with)
array([[10, 10, 10, 10, 10, 10, 10],
       [10, 10, 10, 10, 10, 10, 10],
       [10, 10,  0,  1,  2, 10, 10],
       [10, 10,  3,  4,  5, 10, 10],
       [10, 10, 10, 10, 10, 10, 10],
       [10, 10, 10, 10, 10, 10, 10]])
>>> np.pad(a, 2, pad_with, padder=100)
array([[100, 100, 100, 100, 100, 100, 100],
       [100, 100, 100, 100, 100, 100, 100],
       [100, 100,   0,   1,   2, 100, 100],
       [100, 100,   3,   4,   5, 100, 100],
       [100, 100, 100, 100, 100, 100, 100],
       [100, 100, 100, 100, 100, 100, 100]])
objax.functional.stop_gradient(x)[source]¶
Stops gradient computation.
Operationally stop_gradient is the identity function, that is, it returns argument x unchanged. However, stop_gradient prevents the flow of gradients during forward or reverse-mode automatic differentiation. If there are multiple nested gradient computations, stop_gradient stops gradients for all of them.
For example:
>>> jax.grad(lambda x: x**2)(3.)
array(6., dtype=float32)
>>> jax.grad(lambda x: jax.lax.stop_gradient(x)**2)(3.)
array(0., dtype=float32)
>>> jax.grad(jax.grad(lambda x: x**2))(3.)
array(2., dtype=float32)
>>> jax.grad(jax.grad(lambda x: jax.lax.stop_gradient(x)**2))(3.)
array(0., dtype=float32)
objax.functional.top_k(operand, k)[source]¶
Returns top k values and their indices along the last axis of operand.
- Parameters
operand (Any) –
k (int) –
- Return type
Tuple[Any, Any]
objax.functional.rsqrt(x)[source]¶
Elementwise reciprocal square root: \(1 / \sqrt{x}\).
- Parameters
x (Any) –
- Return type
Any
objax.functional.upscale_nn(x, scale=2)[source]¶
Nearest neighbor upscale for image batches of shape (N, C, H, W).
- Parameters
x (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor of shape (N, C, H, W).
scale (int) – integer scaling factor.
- Returns
Output tensor of shape (N, C, H * scale, W * scale).
- Return type
Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
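A shape-level sketch (input sizes are illustrative):
import jax.numpy as jn
import objax

x = jn.arange(4.0).reshape((1, 1, 2, 2))
y = objax.functional.upscale_nn(x, scale=2)  # Shape (1, 1, 4, 4); each pixel is repeated in a 2x2 block.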
objax.functional.divergence¶
kl | Calculates the Kullback-Leibler divergence between arrays p and q.
objax.functional.divergence.kl(p, q, eps=7.62939453125e-06)[source]¶
Calculates the Kullback-Leibler divergence between arrays p and q.
\[kl(p,q) = p \cdot \log{\frac{p + \epsilon}{q + \epsilon}}\]
The \(\epsilon\) term is added to ensure that neither p nor q is zero.
- Parameters
p (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) –
q (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) –
eps (float) –
- Return type
Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
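A small sketch of the call (the distributions are illustrative):
import jax.numpy as jn
import objax

p = jn.array([0.5, 0.5])
q = jn.array([0.9, 0.1])
d = objax.functional.divergence.kl(p, q)  # Kullback-Leibler divergence of p from q.
print(d)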
objax.functional.loss¶
cross_entropy_logits | Computes the softmax cross-entropy loss on n-dimensional data.
cross_entropy_logits_sparse | Computes the softmax cross-entropy loss.
l2 | Computes the L2 loss.
sigmoid_cross_entropy_logits | Computes the sigmoid cross-entropy loss.
objax.functional.loss.cross_entropy_logits(logits, labels)[source]¶
Computes the softmax cross-entropy loss on n-dimensional data.
- Parameters
logits (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – (batch, …, #class) tensor of logits.
labels (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – (batch, …, #class) tensor of label probabilities (e.g. labels.sum(axis=-1) must be 1)
- Returns
(batch, …) tensor of the cross-entropies for each entry.
- Return type
Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
Calculates the cross entropy loss, defined as follows:
\[\begin{split}\begin{eqnarray} l(y,\hat{y}) & = & - \sum_{j=1}^{q} y_j \log \frac{e^{o_j}}{\sum_{k=1}^{q} e^{o_k}} \nonumber \\ & = & \log \sum_{k=1}^{q} e^{o_k} - \sum_{j=1}^{q} y_j o_j \nonumber \end{eqnarray}\end{split}\]where \(o_k\) are the logits and \(y_k\) are the labels.
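A minimal sketch combining one_hot with this loss (the logits and labels are illustrative):
import jax.numpy as jn
import objax

logits = jn.array([[2.0, -1.0, 0.5]])                # (batch=1, #class=3)
labels = objax.functional.one_hot(jn.array([0]), 3)  # Probability 1 on class 0.
xe = objax.functional.loss.cross_entropy_logits(logits, labels)  # Shape (1,).
loss = xe.mean()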
objax.functional.loss.cross_entropy_logits_sparse(logits, labels)[source]¶
Computes the softmax cross-entropy loss.
- Parameters
logits (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – (batch, …, #class) tensor of logits.
labels (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray, int]) – (batch, …) integer tensor of label indexes in {0, …,#nclass-1} or just a single integer.
- Returns
(batch,) tensor of the cross-entropies for each entry.
- Return type
Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
objax.functional.loss.l2(x)[source]¶
Computes the L2 loss.
- Parameters
x (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – n-dimensional tensor of floats.
- Returns
scalar tensor containing the l2 loss of x.
- Return type
Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
Calculates the l2 loss.
objax.functional.loss.sigmoid_cross_entropy_logits(logits, labels)[source]¶
Computes the sigmoid cross-entropy loss.
- Parameters
logits (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – (batch, …, #class) tensor of logits.
labels (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray, int]) – (batch, …, #class) tensor of per-class label probabilities in [0, 1] (for the sigmoid loss, labels do not need to sum to 1 across classes).
- Returns
(batch, …) tensor of the cross-entropies for each entry.
- Return type
Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
objax.functional.parallel¶
pmax | Compute a multi-device reduce max on x over the device axis axis_name.
pmean | Compute a multi-device reduce mean on x over the device axis axis_name.
pmin | Compute a multi-device reduce min on x over the device axis axis_name.
psum | Compute a multi-device reduce sum on x over the device axis axis_name.
objax.functional.parallel.pmax(x, axis_name='device')[source]¶
Compute a multi-device reduce max on x over the device axis axis_name.
- Parameters
x (jax.interpreters.pxla.ShardedDeviceArray) –
axis_name (str) –
objax.functional.parallel.pmean(x, axis_name='device')[source]¶
Compute a multi-device reduce mean on x over the device axis axis_name.
- Parameters
x (jax.interpreters.pxla.ShardedDeviceArray) –
axis_name (str) –
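A typical use is averaging a per-device statistic inside a function compiled with objax.Parallel (a hedged sketch; it assumes the batch size is divisible by the number of available devices):
import objax

m = objax.nn.Linear(2, 3)

def mean_output(x):
    y = m(x).mean()
    return objax.functional.parallel.pmean(y)  # Average the per-device scalar across devices.

para_f = objax.Parallel(mean_output, m.vars())
with m.vars().replicate():
    out = para_f(objax.random.normal((16, 2)))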
objax.io package¶
Checkpoint | Helper class which performs saving and restoring of the variables.
load_var_collection | Loads values of all variables in the given variables collection from file.
save_var_collection | Saves variables collection into file.
class objax.io.Checkpoint(logdir, keep_ckpts, makedir=True, verbose=True)[source]¶
Helper class which performs saving and restoring of the variables.
Variables are stored in checkpoint files. One checkpoint file stores a single snapshot of the variables. Different checkpoint files store different snapshots of the variables (for example, at different training steps). Each checkpoint has an associated index, which identifies the time when the snapshot of the variables was made. Typically the training step or training epoch is used as the index.
DIR_NAME: str = 'ckpt'¶
Name of the subdirectory of model directory where checkpoints will be saved.
FILE_MATCH: str = '*.npz'¶
File pattern which is used to search for checkpoint files.
FILE_FORMAT: str = '%010d.npz'¶
Format of the filename of one checkpoint file.
static LOAD_FN(file, vc)¶
Load function, which loads variables collection from given file.
- Parameters
file (Union[str, IO[BinaryIO]]) –
vc (objax.variable.VarCollection) –
static SAVE_FN(file, vc)¶
Save function, which saves variables collection into given file.
- Parameters
file (Union[str, IO[BinaryIO]]) –
vc (objax.variable.VarCollection) –
__init__(logdir, keep_ckpts, makedir=True, verbose=True)[source]¶
Creates instance of the Checkpoint class.
- Parameters
logdir (str) – model directory. Checkpoints will be saved in the subdirectory of model directory.
keep_ckpts (int) – maximum number of checkpoints to keep.
makedir (bool) – if True then directory for checkpoints will be created, otherwise it’s expected that directory already exists.
verbose (bool) – if True then print when data is restored from checkpoint.
static checkpoint_idx(filename)[source]¶
Returns index of checkpoint from given checkpoint filename.
- Parameters
filename (str) – checkpoint filename.
- Returns
checkpoint index.
restore(vc, idx=None)[source]¶
Restores values of all variables of the given variables collection from the checkpoint.
Old values in the variables collection will be replaced with the new values read from the checkpoint. If a variable does not exist in the variables collection, it won’t be restored from the checkpoint.
- Parameters
vc (objax.variable.VarCollection) – variables collection to restore.
idx (Optional[int]) – if provided then checkpoint index to use, if None then latest checkpoint will be restored.
- Returns
idx – index of the restored checkpoint.
ckpt – full path to the restored checkpoint.
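A hedged sketch of saving and restoring with Checkpoint. The save(vc, idx) call is an assumption based on the class summary above (only restore is documented in this excerpt); the directory name is illustrative:
import objax

m = objax.nn.Linear(3, 2)
ckpt = objax.io.Checkpoint(logdir='experiment', keep_ckpts=5)
# Assumption: Checkpoint also exposes a save(vc, idx) method for writing a snapshot.
ckpt.save(m.vars(), idx=100)
idx, path = ckpt.restore(m.vars())  # Restores the latest snapshot and returns its index and file path.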
-
-
objax.io.load_var_collection(file, vc)[source]¶
Loads values of all variables in the given variables collection from file.
Values loaded from the file will replace old values in the variables collection. If a variable exists in the file but does not exist in the variables collection, it will be ignored. If a variable exists in the variables collection but is not found in the file, an exception will be raised.
- Parameters
file (Union[str, IO[BinaryIO]]) – filename or python file handle of the input file.
vc (objax.variable.VarCollection) – variables collection which will be loaded from file.
- Raises
ValueError – if variable from variables collection is not found in the input file.
objax.jaxboard package¶
| Reduces tensor batch into a single tensor.
Summary | Writes entries to Summary protocol buffer.
SummaryWriter | Writes entries to event files in the logdir to be consumed by Tensorboard.
class objax.jaxboard.Summary[source]¶
Writes entries to Summary protocol buffer.
image(tag, image)[source]¶
Adds an image to the summary. A float image in [-1, 1] in CHW format is expected.
- Parameters
tag (str) –
image (numpy.ndarray) –
-
-
class objax.jaxboard.SummaryWriter(logdir, queue_size=5, write_interval=5)[source]¶
Writes entries to event files in the logdir to be consumed by Tensorboard.
__init__(logdir, queue_size=5, write_interval=5)[source]¶
Creates SummaryWriter instance.
- Parameters
logdir (str) – directory where event file will be written.
queue_size (int) – size of the queue for pending events and summaries before one of the ‘add’ calls forces a flush to disk.
write_interval (int) – how often, in seconds, to write the pending events and summaries to disk.
write(summary, step)[source]¶
Adds an event to the event file.
- Parameters
summary (objax.jaxboard.Summary) –
step (int) –
-
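A minimal sketch tying Summary and SummaryWriter together (the log directory, tag, and image are illustrative):
import numpy as np
import objax

writer = objax.jaxboard.SummaryWriter('experiment/tb')
summary = objax.jaxboard.Summary()
# A float CHW image in [-1, 1], as expected by Summary.image.
summary.image('samples/zeros', np.zeros((3, 32, 32), dtype=np.float32))
writer.write(summary, step=1)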
objax.nn package¶
objax.nn¶
BatchNorm | Applies a batch normalization on different ranks of an input tensor.
BatchNorm0D | Applies a 0D batch normalization on a 2D-input batch of shape (N,C).
BatchNorm1D | Applies a 1D batch normalization on a 3D-input batch of shape (N,C,L).
BatchNorm2D | Applies a 2D batch normalization on a 4D-input batch of shape (N,C,H,W).
Conv2D | Applies a 2D convolution on a 4D-input batch of shape (N,C,H,W).
ConvTranspose2D | Applies a 2D transposed convolution on a 4D-input batch of shape (N,C,H,W).
Dropout | In the training phase, a dropout layer zeroes some elements of the input tensor with probability 1-keep and scales the other elements by a factor of 1/keep.
Linear | Applies a linear transformation on an input batch.
MovingAverage | Computes moving average of an input batch.
ExponentialMovingAverage | Computes exponential moving average (also called EMA or EWMA) of an input batch.
Sequential | Executes modules in the order they were passed to the constructor.
SyncedBatchNorm | Synchronized batch normalization which aggregates batch statistics across all devices (GPUs/TPUs).
SyncedBatchNorm0D | Applies a 0D synchronized batch normalization on a 2D-input batch of shape (N,C).
SyncedBatchNorm1D | Applies a 1D synchronized batch normalization on a 3D-input batch of shape (N,C,L).
SyncedBatchNorm2D | Applies a 2D synchronized batch normalization on a 4D-input batch of shape (N,C,H,W).
class objax.nn.BatchNorm(dims, redux, momentum=0.999, eps=1e-06)[source]¶
Applies a batch normalization on different ranks of an input tensor.
The module follows the operation described in Algorithm 1 of Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.
\[y = \frac{x-\mathrm{E}[x]}{\sqrt{\mathrm{Var}[x]+\epsilon}} \times \gamma + \beta\]The mean (\(\mathrm{E}[x]\)) and variance (\(\mathrm{Var}[x]\)) are calculated per specified dimensions and over the mini-batches. \(\beta\) and \(\gamma\) are trainable parameter tensors of shape dims. The elements of \(\beta\) are initialized with zeros and those of \(\gamma\) are initialized with ones.
__init__(dims, redux, momentum=0.999, eps=1e-06)[source]¶
Creates a BatchNorm module instance.
- Parameters
dims (Iterable[int]) – shape of the batch normalization state variables.
redux (Iterable[int]) – list of indices of reduction axes. Batch norm statistics are computed by averaging over these axes.
momentum (float) – value used to compute exponential moving average of batch statistics.
eps (float) – small value which is used for numerical stability.
__call__(x, training)[source]¶
Performs batch normalization of input tensor.
- Parameters
x (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.
training (bool) – if True compute batch normalization in training mode (accumulating batch statistics), otherwise compute in evaluation mode (using already accumulated batch statistics).
- Returns
Batch normalized tensor.
- Return type
Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
-
-
class objax.nn.BatchNorm0D(nin, momentum=0.999, eps=1e-06)[source]¶
Applies a 0D batch normalization on a 2D-input batch of shape (N,C).
The module follows the operation described in Algorithm 1 of Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.
\[y = \frac{x-\mathrm{E}[x]}{\sqrt{\mathrm{Var}[x]+\epsilon}} \times \gamma + \beta\]The mean (\(\mathrm{E}[x]\)) and variance (\(\mathrm{Var}[x]\)) are calculated over the mini-batches. \(\beta\) and \(\gamma\) are trainable parameter tensors of shape (1, nin). The elements of \(\beta\) are initialized with zeros and those of \(\gamma\) are initialized with ones.
__init__(nin, momentum=0.999, eps=1e-06)[source]¶
Creates a BatchNorm0D module instance.
- Parameters
nin (int) – number of channels in the input example.
momentum (float) – value used to compute exponential moving average of batch statistics.
eps (float) – small value which is used for numerical stability.
__call__(x, training)¶
Performs batch normalization of input tensor.
- Parameters
x (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.
training (bool) – if True compute batch normalization in training mode (accumulating batch statistics), otherwise compute in evaluation mode (using already accumulated batch statistics).
- Returns
Batch normalized tensor.
- Return type
Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
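A minimal sketch of the training/evaluation behavior (sizes are illustrative):
import objax

bn = objax.nn.BatchNorm0D(4)
x = objax.random.normal((8, 4))
y_train = bn(x, training=True)   # Uses batch statistics and updates the running averages.
y_eval = bn(x, training=False)   # Uses the accumulated running averages.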
-
-
class objax.nn.BatchNorm1D(nin, momentum=0.999, eps=1e-06)[source]¶
Applies a 1D batch normalization on a 3D-input batch of shape (N,C,L).
The module follows the operation described in Algorithm 1 of Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.
\[y = \frac{x-\mathrm{E}[x]}{\sqrt{\mathrm{Var}[x]+\epsilon}} \times \gamma + \beta\]The mean (\(\mathrm{E}[x]\)) and variance (\(\mathrm{Var}[x]\)) are calculated per channel and over the mini-batches. \(\beta\) and \(\gamma\) are trainable parameter tensors of shape (1, nin, 1). The elements of \(\beta\) are initialized with zeros and those of \(\gamma\) are initialized with ones.
__init__(nin, momentum=0.999, eps=1e-06)[source]¶
Creates a BatchNorm1D module instance.
- Parameters
nin (int) – number of channels in the input example.
momentum (float) – value used to compute exponential moving average of batch statistics.
eps (float) – small value which is used for numerical stability.
__call__(x, training)¶
Performs batch normalization of input tensor.
- Parameters
x (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.
training (bool) – if True compute batch normalization in training mode (accumulating batch statistics), otherwise compute in evaluation mode (using already accumulated batch statistics).
- Returns
Batch normalized tensor.
- Return type
Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
-
-
class objax.nn.BatchNorm2D(nin, momentum=0.999, eps=1e-06)[source]¶
Applies a 2D batch normalization on a 4D-input batch of shape (N,C,H,W).
The module follows the operation described in Algorithm 1 of Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.
\[y = \frac{x-\mathrm{E}[x]}{\sqrt{\mathrm{Var}[x]+\epsilon}} \times \gamma + \beta\]The mean (\(\mathrm{E}[x]\)) and variance (\(\mathrm{Var}[x]\)) are calculated per channel and over the mini-batches. \(\beta\) and \(\gamma\) are trainable parameter tensors of shape (1, nin, 1, 1). The elements of \(\beta\) are initialized with zeros and those of \(\gamma\) are initialized with ones.
__init__(nin, momentum=0.999, eps=1e-06)[source]¶
Creates a BatchNorm2D module instance.
- Parameters
nin (int) – number of channels in the input example.
momentum (float) – value used to compute exponential moving average of batch statistics.
eps (float) – small value which is used for numerical stability.
__call__(x, training)¶
Performs batch normalization of input tensor.
- Parameters
x (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.
training (bool) – if True compute batch normalization in training mode (accumulating batch statistics), otherwise compute in evaluation mode (using already accumulated batch statistics).
- Returns
Batch normalized tensor.
- Return type
Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
-
-
class objax.nn.Conv2D(nin, nout, k, strides=1, dilations=1, groups=1, padding=<ConvPadding.SAME: 'SAME'>, use_bias=True, w_init=<function kaiming_normal>)[source]¶
Applies a 2D convolution on a 4D-input batch of shape (N,C,H,W).
In the simplest case (strides = 1, padding = VALID), the output tensor \((N,C_{out},H_{out},W_{out})\) is computed from an input tensor \((N,C_{in},H,W)\) with kernel weight \((k,k,C_{in},C_{out})\) and bias \((C_{out})\) as follows:
\[\mathrm{out}[n,c,h,w] = \mathrm{b}[c] + \sum_{t=0}^{C_{in}-1}\sum_{i=0}^{k-1}\sum_{j=0}^{k-1} \mathrm{in}[n,c,i+h,j+w] \times \mathrm{w}[i,j,t,c]\]where \(H_{out}=H-k+1\), \(W_{out}=W-k+1\). Note that the implementation follows the definition of cross-correlation. When padding = SAME, the input tensor is zero-padded by \(\lfloor\frac{k-1}{2}\rfloor\) for left and up sides and \(\lfloor\frac{k}{2}\rfloor\) for right and down sides.
__init__(nin, nout, k, strides=1, dilations=1, groups=1, padding=<ConvPadding.SAME: 'SAME'>, use_bias=True, w_init=<function kaiming_normal>)[source]¶
Creates a Conv2D module instance.
- Parameters
nin (int) – number of channels of the input tensor.
nout (int) – number of channels of the output tensor.
k (Union[Tuple[int, int], int]) – size of the convolution kernel, either tuple (height, width) or single number if they’re the same.
strides (Union[Tuple[int, int], int]) – convolution strides, either tuple (stride_y, stride_x) or single number if they’re the same.
dilations (Union[Tuple[int, int], int]) – spacing between kernel points (also known as atrous convolution), either tuple (dilation_y, dilation_x) or single number if they’re the same.
groups (int) – number of groups that the input and output channels are split into. When groups > 1 the convolution operation is applied individually to each group. nin and nout must both be divisible by groups.
padding (objax.constants.ConvPadding) – padding of the input tensor, either Padding.SAME or Padding.VALID.
use_bias (bool) – if True then convolution will have bias term.
w_init (Callable) – initializer for convolution kernel (a function that takes in a HWIO shape and returns a 4D matrix).
__call__(x)[source]¶
Returns the results of applying the convolution to input x.
- Parameters
x (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) –
- Return type
Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
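A shape-level sketch (channel counts and image size are illustrative):
import objax

conv = objax.nn.Conv2D(nin=3, nout=16, k=3)
x = objax.random.normal((8, 3, 32, 32))
y = conv(x)  # Shape (8, 16, 32, 32) with the default SAME padding.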
-
-
class objax.nn.ConvTranspose2D(nin, nout, k, strides=1, dilations=1, padding=<ConvPadding.SAME: 'SAME'>, use_bias=True, w_init=<function kaiming_normal>)[source]¶
Applies a 2D transposed convolution on a 4D-input batch of shape (N,C,H,W).
This module can be seen as a transformation going in the opposite direction of a normal convolution, i.e., from something that has the shape of the output of some convolution to something that has the shape of its input while maintaining a connectivity pattern that is compatible with said convolution. Note that ConvTranspose2D is consistent with Conv2DTranspose of TensorFlow, but is not consistent with ConvTranspose2d of PyTorch due to kernel transpose and padding.
__init__(nin, nout, k, strides=1, dilations=1, padding=<ConvPadding.SAME: 'SAME'>, use_bias=True, w_init=<function kaiming_normal>)[source]¶
Creates a ConvTranspose2D module instance.
- Parameters
nin (int) – number of channels of the input tensor.
nout (int) – number of channels of the output tensor.
k (Union[Tuple[int, int], int]) – size of the convolution kernel, either tuple (height, width) or single number if they’re the same.
strides (Union[Tuple[int, int], int]) – convolution strides, either tuple (stride_y, stride_x) or single number if they’re the same.
dilations (Union[Tuple[int, int], int]) – spacing between kernel points (also known as atrous convolution), either tuple (dilation_y, dilation_x) or single number if they’re the same.
padding (objax.constants.ConvPadding) – padding of the input tensor, either Padding.SAME or Padding.VALID.
use_bias (bool) – if True then convolution will have bias term.
w_init (Callable) – initializer for convolution kernel (a function that takes in a HWIO shape and returns a 4D matrix).
-
__call__
(x)[source]¶ Returns the results of applying the transposed convolution to input x.
- Parameters
x (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) –
- Return type
Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
-
-
class
objax.nn.
Dropout
(keep, generator=<objax.random.random.Generator object>)[source]¶ In the training phase, a dropout layer zeroes some elements of the input tensor with probability 1-keep and scales the other elements by a factor of 1/keep.
During evaluation, the module does not modify the input tensor. Dropout (Improving neural networks by preventing co-adaptation of feature detectors) is an effective regularization technique which reduces overfitting and improves generalization.
-
__init__
(keep, generator=<objax.random.random.Generator object>)[source]¶ Creates Dropout module instance.
- Parameters
keep (float) – probability to keep element of the tensor.
generator – optional argument with an instance of an Objax random generator.
-
__call__
(x, training, dropout_keep=None)[source]¶ Performs dropout of input tensor.
- Parameters
x (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.
training (bool) – if True then apply dropout to the input, otherwise keep input tensor unchanged.
dropout_keep (Optional[float]) – optional argument, when set overrides dropout keep probability.
- Returns
Tensor with applied dropout.
- Return type
Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
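For example, a minimal sketch (the keep probability and shape are arbitrary):
import objax
drop = objax.nn.Dropout(keep=0.9)
x = objax.random.normal((4, 10))
y_train = drop(x, training=True)    # about 10% of the elements are zeroed, the rest scaled by 1/0.9
y_eval = drop(x, training=False)    # evaluation mode: x is returned unchanged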
-
-
class
objax.nn.
Linear
(nin, nout, use_bias=True, w_init=<function xavier_normal>)[source]¶ Applies a linear transformation on an input batch.
The output tensor \((N,C_{out})\) is computed from an input tensor \((N,C_{in})\) with kernel weight \((C_{in},C_{out})\) and bias \((C_{out})\) as follows:
\[\mathrm{out}[n,c] = \mathrm{b}[c] + \sum_{t=1}^{C_{in}} \mathrm{in}[n,t] \times \mathrm{w}[t,c]\]-
__init__
(nin, nout, use_bias=True, w_init=<function xavier_normal>)[source]¶ Creates a Linear module instance.
- Parameters
nin (int) – number of channels of the input tensor.
nout (int) – number of channels of the output tensor.
use_bias (bool) – if True then linear layer will have bias term.
w_init (Callable) – weight initializer for linear layer (a function that takes in a IO shape and returns a 2D matrix).
-
__call__
(x)[source]¶ Returns the results of applying the linear transformation to input x.
- Parameters
x (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) –
- Return type
Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
-
-
class
objax.nn.
MovingAverage
(shape, buffer_size, init_value=0)[source]¶ Computes moving average of an input batch.
-
__init__
(shape, buffer_size, init_value=0)[source]¶ Creates a MovingAverage module instance.
- Parameters
shape (Tuple[int, ..]) – shape of the input tensor.
buffer_size (int) – buffer size for moving average.
init_value (float) – initial value for moving average buffer.
-
__call__
(x)[source]¶ Update the statistics using x and return the moving average.
- Parameters
x (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) –
- Return type
Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
-
-
class
objax.nn.
ExponentialMovingAverage
(shape, momentum=0.999, init_value=0)[source]¶ Computes exponential moving average (also called EMA or EWMA) of an input batch.
\[x_{\mathrm{EMA}} \leftarrow \mathrm{momentum} \times x_{\mathrm{EMA}} + (1-\mathrm{momentum}) \times x\]-
__init__
(shape, momentum=0.999, init_value=0)[source]¶ Creates an ExponentialMovingAverage module instance.
- Parameters
shape (Tuple[int, ..]) – shape of the input tensor.
momentum (float) – momentum for exponential decrease of accumulated value.
init_value (float) – initial value for exponential moving average.
-
__call__
(x)[source]¶ Update the statistics using x and return the exponential moving average.
- Parameters
x (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) –
- Return type
Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
-
-
class
objax.nn.
Sequential
(iterable=(), /)[source]¶ Executes modules in the order they were passed to the constructor.
-
__call__
(x, **kwargs)[source]¶ Executes the sequence of operations on x, forwarding **kwargs to the modules that accept them, and returns the result.
- Parameters
x (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) –
- Return type
Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
-
__init__
(*args, **kwargs)¶ Initialize self. See help(type(self)) for accurate signature.
-
append
(object, /)¶ Append object to the end of the list.
-
clear
()¶ Remove all items from list.
-
copy
()¶ Return a shallow copy of the list.
-
count
(value, /)¶ Return number of occurrences of value.
-
extend
(iterable, /)¶ Extend list by appending elements from the iterable.
-
index
(value, start=0, stop=9223372036854775807, /)¶ Return first index of value.
Raises ValueError if the value is not present.
-
insert
(index, object, /)¶ Insert object before index.
-
pop
(index=-1, /)¶ Remove and return item at index (default last).
Raises IndexError if list is empty or index is out of range.
-
remove
(value, /)¶ Remove first occurrence of value.
Raises ValueError if the value is not present.
-
reverse
()¶ Reverse IN PLACE.
-
vars
(scope='')¶ Collect all the variables (and their names) contained in the list and its submodules.
- Parameters
scope (str) – string to prefix to the variable names.
- Returns
A VarCollection of all the variables.
- Return type
objax.variable.VarCollection
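For example, a small classifier can be assembled from layers and activation functions (a minimal sketch; the layer sizes are arbitrary):
import objax
net = objax.nn.Sequential([objax.nn.Linear(784, 128),
                           objax.functional.relu,
                           objax.nn.Linear(128, 10)])
x = objax.random.normal((32, 784))
logits = net(x)        # logits.shape == (32, 10)
print(net.vars())      # collects the variables of both Linear layers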
-
-
class
objax.nn.
SyncedBatchNorm
(dims, redux, momentum=0.999, eps=1e-06)[source]¶ Synchronized batch normalization which aggregates batch statistics across all devices (GPUs/TPUs).
-
__call__
(x, training, batch_norm_update=True)[source]¶ Performs batch normalization of input tensor.
- Parameters
x (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.
training (bool) – if True compute batch normalization in training mode (accumulating batch statistics), otherwise compute in evaluation mode (using already accumulated batch statistics).
batch_norm_update (bool) –
- Returns
Batch normalized tensor.
- Return type
Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
-
__init__
(dims, redux, momentum=0.999, eps=1e-06)¶ Creates a SyncedBatchNorm module instance.
- Parameters
dims (Iterable[int]) – shape of the batch normalization state variables.
redux (Iterable[int]) – list of indices of reduction axes. Batch norm statistics are computed by averaging over these axes.
momentum (float) – value used to compute exponential moving average of batch statistics.
eps (float) – small value which is used for numerical stability.
-
-
class
objax.nn.
SyncedBatchNorm0D
(nin, momentum=0.999, eps=1e-06)[source]¶ Applies a 0D synchronized batch normalization on a 2D-input batch of shape (N,C).
Synchronized batch normalization aggregates batch statistics across all devices (GPUs/TPUs) on each call. Compared to regular batch norm this usually leads to better accuracy at a slight performance cost.
-
__init__
(nin, momentum=0.999, eps=1e-06)[source]¶ Creates a SyncedBatchNorm0D module instance.
- Parameters
nin (int) – number of channels in the input example.
momentum (float) – value used to compute exponential moving average of batch statistics.
eps (float) – small value which is used for numerical stability.
-
__call__
(x, training, batch_norm_update=True)¶ Performs batch normalization of input tensor.
- Parameters
x (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.
training (bool) – if True compute batch normalization in training mode (accumulating batch statistics), otherwise compute in evaluation mode (using already accumulated batch statistics).
batch_norm_update (bool) –
- Returns
Batch normalized tensor.
- Return type
Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
-
-
class
objax.nn.
SyncedBatchNorm1D
(nin, momentum=0.999, eps=1e-06)[source]¶ Applies a 1D synchronized batch normalization on a 3D-input batch of shape (N,C,L).
Synchronized batch normalization aggregates batch statistics across all devices (GPUs/TPUs) on each call. Compared to regular batch norm this usually leads to better accuracy at a slight performance cost.
-
__init__
(nin, momentum=0.999, eps=1e-06)[source]¶ Creates a SyncedBatchNorm1D module instance.
- Parameters
nin (int) – number of channels in the input example.
momentum (float) – value used to compute exponential moving average of batch statistics.
eps (float) – small value which is used for numerical stability.
-
__call__
(x, training, batch_norm_update=True)¶ Performs batch normalization of input tensor.
- Parameters
x (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.
training (bool) – if True compute batch normalization in training mode (accumulating batch statistics), otherwise compute in evaluation mode (using already accumulated batch statistics).
batch_norm_update (bool) –
- Returns
Batch normalized tensor.
- Return type
Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
-
-
class
objax.nn.
SyncedBatchNorm2D
(nin, momentum=0.999, eps=1e-06)[source]¶ Applies a 2D synchronized batch normalization on a 4D-input batch of shape (N,C,H,W).
Synchronized batch normalization aggregates batch statistics across all devices (GPUs/TPUs) on each call. Compared to regular batch norm this usually leads to better accuracy at a slight performance cost.
-
__init__
(nin, momentum=0.999, eps=1e-06)[source]¶ Creates a SyncedBatchNorm2D module instance.
- Parameters
nin (int) – number of channels in the input example.
momentum (float) – value used to compute exponential moving average of batch statistics.
eps (float) – small value which is used for numerical stability.
-
__call__
(x, training, batch_norm_update=True)¶ Performs batch normalization of input tensor.
- Parameters
x (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.
training (bool) – if True compute batch normalization in training mode (accumulating batch statistics), otherwise compute in evaluation mode (using already accumulated batch statistics).
batch_norm_update (bool) –
- Returns
Batch normalized tensor.
- Return type
Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
-
objax.nn.init¶
gain_leaky_relu – The recommended gain value for leaky_relu.
kaiming_normal_gain – Returns Kaiming He gain from Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification.
kaiming_normal – Returns a tensor with values assigned using Kaiming He normal initializer from Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification.
kaiming_truncated_normal – Returns a tensor with values assigned using Kaiming He truncated normal initializer from Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification.
truncated_normal – Returns a tensor with values assigned using truncated normal initialization.
xavier_normal_gain – Returns Xavier Glorot gain from Understanding the difficulty of training deep feedforward neural networks.
xavier_normal – Returns a tensor with values assigned using Xavier Glorot normal initializer from Understanding the difficulty of training deep feedforward neural networks.
xavier_truncated_normal – Returns a tensor with values assigned using Xavier Glorot truncated normal initializer from Understanding the difficulty of training deep feedforward neural networks.
-
class
objax.nn.init.
gain_leaky_relu
(relu_slope=0.1)[source]¶ The recommended gain value for leaky_relu.
- Parameters
relu_slope – negative slope of leaky_relu.
- Returns
The recommended gain value for leaky_relu.
The returned gain value is
\[\sqrt{\frac{2}{1 + \text{relu_slope}^2}}.\]
-
class
objax.nn.init.
kaiming_normal_gain
(shape)[source]¶ Returns Kaiming He gain from Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification.
- Parameters
shape – shape of the output tensor.
- Returns
Scalar, the standard deviation gain.
The returned gain value is
\[\sqrt{\frac{1}{\text{fan_in}}}.\]
-
class
objax.nn.init.
kaiming_normal
(shape, gain=1)[source]¶ Returns a tensor with values assigned using Kaiming He normal initializer from Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification.
- Parameters
shape – shape of the output tensor.
gain – optional scaling factor.
- Returns
Tensor initialized with normal random variables with standard deviation (gain * kaiming_normal_gain).
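Since initializers are plain callables taking a shape, a non-default gain can be supplied by wrapping the initializer, for example with functools.partial (a minimal sketch; the gain choice is arbitrary):
import functools
import objax
w_init = functools.partial(objax.nn.init.kaiming_normal,
                           gain=objax.nn.init.gain_leaky_relu(relu_slope=0.1))
conv = objax.nn.Conv2D(nin=3, nout=16, k=3, w_init=w_init)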
-
class
objax.nn.init.
kaiming_truncated_normal
(shape, lower=- 2, upper=2, gain=1)[source]¶ Returns a tensor with values assigned using Kaiming He truncated normal initializer from Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification.
- Parameters
shape – shape of the output tensor.
lower – lower truncation of the normal.
upper – upper truncation of the normal.
gain – optional scaling factor.
- Returns
Tensor initialized with truncated normal random variables with standard deviation (gain * kaiming_normal_gain) and support [lower, upper].
-
class
objax.nn.init.
truncated_normal
(shape, lower=- 2, upper=2, stddev=1)[source]¶ Returns a tensor with values assigned using truncated normal initialization.
- Parameters
shape – shape of the output tensor.
lower – lower truncation of the normal.
upper – upper truncation of the normal.
stddev – expected standard deviation.
- Returns
Tensor initialized with truncated normal random variables with standard deviation stddev and support [lower, upper].
-
class
objax.nn.init.
xavier_normal_gain
(shape)[source]¶ Returns Xavier Glorot gain from Understanding the difficulty of training deep feedforward neural networks.
- Parameters
shape – shape of the output tensor.
- Returns
Scalar, the standard deviation gain.
The returned gain value is
\[\sqrt{\frac{2}{\text{fan_in} + \text{fan_out}}}.\]
-
class
objax.nn.init.
xavier_normal
(shape, gain=1)[source]¶ Returns a tensor with values assigned using Xavier Glorot normal initializer from Understanding the difficulty of training deep feedforward neural networks.
- Parameters
shape – shape of the output tensor.
gain – optional scaling factor.
- Returns
Tensor initialized with normal random variables with standard deviation (gain * xavier_normal_gain).
-
class
objax.nn.init.
xavier_truncated_normal
(shape, lower=- 2, upper=2, gain=1)[source]¶ Returns a tensor with values assigned using Xavier Glorot truncated normal initializer from Understanding the difficulty of training deep feedforward neural networks.
- Parameters
shape – shape of the output tensor.
lower – lower truncation of the normal.
upper – upper truncation of the normal.
gain – optional scaling factor.
- Returns
Tensor initialized with truncated normal random variables with standard deviation (gain * xavier_normal_gain) and support [lower, upper].
objax.optimizer package¶
Adam – Adam optimizer.
ExponentialMovingAverage – Maintains exponential moving averages for each variable from provided VarCollection.
Momentum – Momentum optimizer.
SGD – Stochastic Gradient Descent (SGD) optimizer.
-
class
objax.optimizer.
Adam
(vc, beta1=0.9, beta2=0.999, eps=1e-08)[source]¶ Adam optimizer.
Adam is an adaptive learning rate optimization algorithm originally presented in Adam: A Method for Stochastic Optimization. Specifically, when optimizing a loss function \(f\) parameterized by model weights \(w\), the update rule is as follows:
\[\begin{split}\begin{eqnarray} v_{k} &=& \beta_1 v_{k-1} + (1 - \beta_1) \nabla f (.; w_{k-1}) \nonumber \\ s_{k} &=& \beta_2 s_{k-1} + (1 - \beta_2) (\nabla f (.; w_{k-1}))^2 \nonumber \\ \hat{v_{k}} &=& \frac{v_{k}}{(1 - \beta_{1}^{k})} \nonumber \\ \hat{s_{k}} &=& \frac{s_{k}}{(1 - \beta_{2}^{k})} \nonumber \\ w_{k} &=& w_{k-1} - \eta \frac{\hat{v_{k}}}{\sqrt{\hat{s_{k}}} + \epsilon} \nonumber \end{eqnarray}\end{split}\]Adam updates exponential moving averages of the gradient \((v_{k})\) and the squared gradient \((s_{k})\) where the hyper-parameters \(\beta_1\) and \(\beta_2 \in [0, 1)\) control the exponential decay rates of these moving averages. The \(\eta\) constant in the weight update rule is the learning rate and is passed as a parameter in the
__call__
method. Note that the implementation uses the approximation \(\sqrt{(\hat{s_{k}} + \epsilon)} \approx \sqrt{\hat{s_{k}}} + \epsilon\).-
__init__
(vc, beta1=0.9, beta2=0.999, eps=1e-08)[source]¶ Constructor for Adam optimizer class.
- Parameters
vc (objax.variable.VarCollection) – collection of variables to optimize.
beta1 (float) – value of Adam’s beta1 hyperparameter. Defaults to 0.9.
beta2 (float) – value of Adam’s beta2 hyperparameter. Defaults to 0.999.
eps (float) – value of Adam’s epsilon hyperparameter. Defaults to 1e-8.
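In use, the optimizer is called with the learning rate and a list of gradients matching the variables it was constructed with. A minimal sketch combined with objax.GradValues (the model and data shapes are arbitrary):
import objax
net = objax.nn.Linear(4, 2)
opt = objax.optimizer.Adam(net.vars())
def loss(x, y):
    return ((net(x) - y) ** 2).mean()
gv = objax.GradValues(loss, net.vars())
def train_op(x, y, lr):
    g, v = gv(x, y)   # gradients and loss value
    opt(lr, g)        # apply the Adam update with learning rate lr
    return v
v = train_op(objax.random.normal((16, 4)), objax.random.normal((16, 2)), 0.001)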
-
-
class
objax.optimizer.
ExponentialMovingAverage
(vc, momentum=0.999, debias=False, eps=1e-06)[source]¶ Maintains exponential moving averages for each variable from provided VarCollection.
When training a model, it is often beneficial to maintain exponential moving averages (EMA) of the trained parameters. Evaluations that use averaged parameters sometimes produce significantly better results than the final trained values (see Acceleration of Stochastic Approximation by Averaging).
This maintains an EMA of the parameters passed in the VarCollection
vc
. The EMA update rule for weights \(w\), the EMA \(m\) at step \(t\) when using a momentum \(\mu\) is:\[m_t = \mu m_{t-1} + (1 - \mu) w_t\]The EMA weights \(\hat{w_t}\) are simply \(m_t\) when
debias=False
. Whendebias=True
, the EMA weights are defined as:\[\hat{w_t} = \frac{m_t}{1 - (1 - \epsilon)\mu^t}\]Where \(\epsilon\) is a small constant to avoid a divide-by-0.
-
__init__
(vc, momentum=0.999, debias=False, eps=1e-06)[source]¶ Creates ExponentialMovingAverage instance with given hyperparameters.
- Parameters
momentum (float) – the decay factor for the moving average.
debias (bool) – bool indicating whether to use initialization bias correction.
eps (float) – small adjustment to prevent division by zero.
vc (objax.variable.VarCollection) –
-
refs_and_values
()[source]¶ Returns the VarCollection of variables affected by Exponential Moving Average (EMA) and their corresponding EMA values.
- Return type
Tuple[objax.variable.VarCollection, List[Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]]]
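A minimal usage sketch, assuming that calling the module (with no arguments) performs the averaging update after each training step; refs_and_values then exposes the averaged values:
import objax
net = objax.nn.Linear(4, 2)
ema = objax.optimizer.ExponentialMovingAverage(net.vars(), momentum=0.999)
# ... inside the training loop, after the optimizer update:
ema()  # assumption: updates the moving averages from the current variable values
refs, values = ema.refs_and_values()   # tracked variables and their EMA values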
-
-
class
objax.optimizer.
Momentum
(vc, momentum=0.9, nesterov=False)[source]¶ Momentum optimizer.
The momentum optimizer (expository article) introduces a tweak to the standard gradient descent. Specifically, when optimizing a loss function \(f\) parameterized by model weights \(w\) the update rule is as follows:
\[\begin{split}\begin{eqnarray} v_{k} &=& \mu v_{k-1} + \nabla f (.; w_{k-1}) \nonumber \\ w_{k} &=& w_{k-1} - \eta v_{k} \nonumber \end{eqnarray}\end{split}\]The term \(v\) is the velocity: It accumulates past gradients through a weighted moving average calculation. The parameters \(\mu, \eta\) are the momentum and the learning rate.
The momentum class also implements Nesterov’s Accelerated Gradient (NAG) (see Sutskever et al.). Like momentum, NAG is a first-order optimization method with a better convergence rate than gradient descent in certain situations. The NAG update can be written as:
\[\begin{split}\begin{eqnarray} v_{k} &=& \mu v_{k-1} + \nabla f(.; w_{k-1} + \mu v_{k-1}) \nonumber \\ w_{k} &=& w_{k-1} - \eta v_{k} \nonumber \end{eqnarray}\end{split}\]The implementation uses the simplification presented by Bengio et al.
-
class
objax.optimizer.
SGD
(vc)[source]¶ Stochastic Gradient Descent (SGD) optimizer.
The stochastic gradient optimizer performs Stochastic Gradient Descent (SGD). It uses the following update rule for a loss \(f\) parameterized with model weights \(w\) and a user provided learning rate \(\eta\):
\[w_k = w_{k-1} - \eta\nabla f(.; w_{k-1})\]
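Its usage mirrors the other optimizers: pass the learning rate and the gradients to apply one update (a minimal sketch; shapes and learning rate are arbitrary):
import objax
net = objax.nn.Linear(4, 2)
opt = objax.optimizer.SGD(net.vars())
gv = objax.GradValues(lambda x, y: ((net(x) - y) ** 2).mean(), net.vars())
g, v = gv(objax.random.normal((8, 4)), objax.random.normal((8, 2)))
opt(0.01, g)   # w <- w - 0.01 * g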
objax.privacy package¶
PrivateGradValues – Computes differentially private gradients as required by DP-SGD.
apply_dp_sgd_analysis – Compute and print results of DP-SGD analysis.
compute_rdp – Compute RDP of the Sampled Gaussian Mechanism.
get_privacy_spent – Compute delta (or eps) for given eps (or delta) from RDP values.
-
class
objax.privacy.
PrivateGradValues
(f, vc, noise_multiplier, l2_norm_clip, microbatch, batch_axis=(0, ), keygen=<objax.random.random.Generator object>)[source]¶ Computes differentially private gradients as required by DP-SGD. This module can be used in place of GradValues, and automatically makes the optimization differentially private.
-
__init__
(f, vc, noise_multiplier, l2_norm_clip, microbatch, batch_axis=(0, ), keygen=<objax.random.random.Generator object>)[source]¶ Constructs a PrivateGradValues instance.
- Parameters
f (Callable) – the function for which to compute gradients.
vc (objax.variable.VarCollection) – the variables for which to compute gradients.
noise_multiplier (float) – scale of standard deviation for added noise in DP-SGD.
l2_norm_clip (float) – value of clipping norm for DP-SGD.
microbatch (int) – the size of each microbatch.
batch_axis (Tuple[Optional[int], ..]) – the axis to use as batch during vectorization. Should be a tuple of 0s.
keygen (objax.random.random.Generator) – a Generator for random numbers. Defaults to objax.random.DEFAULT_GENERATOR.
-
reshape_microbatch
(x)[source]¶ Reshapes examples into microbatches. DP-SGD requires that per-example gradients are clipped and noised; however, this can be inefficient. To speed this up, it is possible to clip and noise a microbatch of examples, at a slight cost to privacy. If speed is not an issue, the microbatch size should be set to 1.
If x has shape [D0, D1, …, Dn], the reshaped output will have shape [number_of_microbatches, microbatch_size, D1, …, Dn].
- Parameters
x (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – items to be reshaped.
- Returns
The reshaped items.
- Return type
Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
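A minimal sketch of how it can replace objax.GradValues in a training loop, assuming that (as with GradValues) calling it returns the gradients and the loss value; the hyperparameter values are arbitrary:
import objax
net = objax.nn.Linear(10, 2)
opt = objax.optimizer.SGD(net.vars())
def loss(x, y):
    return ((net(x) - y) ** 2).mean()
gv = objax.privacy.PrivateGradValues(loss, net.vars(),
                                     noise_multiplier=1.1,
                                     l2_norm_clip=1.0,
                                     microbatch=1)
def train_op(x, y, lr):
    g, v = gv(x, y)   # per-microbatch clipped and noised gradients, plus the loss
    opt(lr, g)
    return v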
-
-
objax.privacy.
apply_dp_sgd_analysis
(q, noise_multiplier, steps, orders=(1.25, 1.5, 1.75, 2.0, 2.25, 2.5, 3.0, 3.5, 4.0, 4.5, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20), delta=1e-05)[source]¶ Compute and print results of DP-SGD analysis.
- Parameters
q (float) – The sampling rate.
noise_multiplier (float) – The ratio of the standard deviation of the Gaussian noise to the l2-sensitivity of the function to which it is added.
steps (int) – The number of steps.
orders (Tuple[float, ..]) – An array (or a scalar) of RDP orders.
delta (float) – The target delta.
- Returns
eps
- Raises
ValueError – If the given target_delta is invalid.
- Return type
float
-
objax.privacy.
compute_rdp
(q, noise_multiplier, steps, orders)[source]¶ Compute RDP of the Sampled Gaussian Mechanism.
- Parameters
q (float) – The sampling rate.
noise_multiplier (float) – The ratio of the standard deviation of the Gaussian noise to the l2-sensitivity of the function to which it is added.
steps (int) – The number of steps.
orders (Tuple[float, ..]) – An array (or a scalar) of RDP orders.
- Returns
The RDPs at all orders, can be np.inf.
-
objax.privacy.
get_privacy_spent
(orders, rdp, target_eps=None, target_delta=None)[source]¶ Compute delta (or eps) for given eps (or delta) from RDP values.
- Parameters
orders (Tuple[float, ..]) – An array (or a scalar) of RDP orders.
rdp (Tuple[float, ..]) – An array of RDP values. Must be of the same length as the orders list.
target_eps (float) – If not None, the epsilon for which we compute the corresponding delta.
target_delta (float) – If not None, the delta for which we compute the corresponding epsilon. Exactly one of target_eps and target_delta must be None.
- Returns
eps, delta, opt_order.
- Raises
ValueError – If both or neither of target_eps and target_delta are specified.
- Return type
Tuple[float, float, float]
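For example, to estimate the privacy budget of a training run (a minimal sketch; the batch size, dataset size, noise multiplier, and number of steps are arbitrary):
import objax
q = 256 / 50000   # sampling rate: batch size over dataset size
eps = objax.privacy.apply_dp_sgd_analysis(q, noise_multiplier=1.1, steps=10000, delta=1e-5)
# The same quantities can also be computed explicitly:
orders = (1.25, 1.5, 2.0, 2.5, 3.0, 4.0, 6.0, 8.0, 16.0, 32.0)
rdp = objax.privacy.compute_rdp(q, noise_multiplier=1.1, steps=10000, orders=orders)
eps, delta, opt_order = objax.privacy.get_privacy_spent(orders, rdp, target_delta=1e-5)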
objax.random package¶
Generator – Random number generator module.
normal – Returns a JaxArray of random numbers from a normal distribution.
randint – Returns a JaxArray of random integers.
truncated_normal – Returns a JaxArray of random numbers from a truncated normal distribution.
uniform – Returns a JaxArray of random numbers from a uniform distribution.
-
class
objax.random.
Generator
(seed=0)[source]¶ Random number generator module.
-
__init__
(seed=0)[source]¶ Creates a random key generator; seed is the generator's initial seed.
- Parameters
seed (int) –
-
property
key
¶ The random generator state (a tensor of 2 int32).
-
-
objax.random.
normal
(shape, *, mean=0, stddev=1, generator=<objax.random.random.Generator object>)[source]¶ Returns a
JaxArray
of shapeshape
with random numbers from a normal distribution with meanmean
and standard deviationstddev
.- Parameters
shape (Tuple[int, ..]) –
mean (float) –
stddev (float) –
generator (objax.random.random.Generator) –
-
objax.random.
randint
(shape, low, high, generator=<objax.random.random.Generator object>)[source]¶ Returns a
JaxArray
of shapeshape
with random integers in {low, …, high-1}.- Parameters
shape (Tuple[int, ..]) –
low (int) –
high (int) –
generator (objax.random.random.Generator) –
-
objax.random.
truncated_normal
(shape, *, stddev=1, lower=-2, upper=2, generator=<objax.random.random.Generator object>)[source]¶ Returns a
JaxArray
of shapeshape
with random numbers from a normal distribution with mean 0 and standard deviationstddev
truncated by (lower
,upper
).- Parameters
shape (Tuple[int, ..]) –
stddev (float) –
lower (float) –
upper (float) –
generator (objax.random.random.Generator) –
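For example (a minimal sketch; shapes, bounds, and the seed are arbitrary):
import objax
g = objax.random.Generator(seed=42)
a = objax.random.normal((3, 2), mean=0, stddev=1, generator=g)
b = objax.random.randint((4,), low=0, high=10, generator=g)
c = objax.random.truncated_normal((3,), stddev=1, lower=-2, upper=2, generator=g)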
objax.util package¶
objax.util¶
EasyDict – Custom dictionary that allows accessing dict values as attributes.
args_indexes – Returns the positional indexes of the given argument names of a function.
The module also provides an empty context manager, an integer log2 helper, a helper that returns the ordered names of the positional arguments of a function, and a helper that converts its input to a tuple.
-
class
objax.util.
EasyDict
(*args, **kwargs)[source]¶ Custom dictionary that allows accessing dict values as attributes.
-
clear
() → None. Remove all items from D.¶
-
copy
() → a shallow copy of D¶
-
fromkeys
(value=None, /)¶ Create a new dictionary with keys from iterable and values set to value.
-
get
(key, default=None, /)¶ Return the value for key if key is in the dictionary, else default.
-
items
() → a set-like object providing a view on D’s items¶
-
keys
() → a set-like object providing a view on D’s keys¶
-
pop
(k[, d]) → v, remove specified key and return the corresponding value.¶ If key is not found, d is returned if given, otherwise KeyError is raised
-
popitem
() → (k, v), remove and return some (key, value) pair as a¶ 2-tuple; but raise KeyError if D is empty.
-
setdefault
(key, default=None, /)¶ Insert key with a value of default if key is not in the dictionary.
Return the value for key if key is in the dictionary, else default.
-
update
([E, ]**F) → None. Update D from dict/iterable E and F.¶ If E is present and has a .keys() method, then does: for k in E: D[k] = E[k] If E is present and lacks a .keys() method, then does: for k, v in E: D[k] = v In either case, this is followed by: for k in F: D[k] = F[k]
-
values
() → an object providing a view on D’s values¶
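For example (a minimal sketch):
import objax
config = objax.util.EasyDict(lr=0.01, batch=256)
print(config.lr)        # 0.01, attribute-style access
print(config['batch'])  # 256, regular dict access still works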
-
-
objax.util.
args_indexes
(f, args)[source]¶ Returns the positional indexes of the given argument names of a function.
- Parameters
f (Callable) –
args (Iterable[str]) –
- Return type
Iterable[int]
objax.util.image¶
nchw – Converts an array in (N,H,W,C) format to (N,C,H,W) format.
nhwc – Converts an array in (N,C,H,W) format to (N,H,W,C) format.
normalize_to_uint8 – Maps a float image in [1/256-1, 1-1/256] to uint8 {0, 1, …, 255}.
normalize_to_unit_float – Maps a uint8 image in {0, 1, …, 255} to the float interval [1/256-1, 1-1/256].
The module also provides a helper that converts a numpy array in (C,H,W) format into PNG format.
-
objax.util.image.
nchw
(x)[source]¶ Converts an array in (N,H,W,C) format to (N,C,H,W) format.
- Parameters
x (Union[numpy.ndarray, jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) –
- Return type
Union[numpy.ndarray, jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
-
objax.util.image.
nhwc
(x)[source]¶ Converts an array in (N,C,H,W) format to (N,H,W,C) format.
- Parameters
x (Union[numpy.ndarray, jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) –
- Return type
Union[numpy.ndarray, jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
-
objax.util.image.
normalize_to_uint8
(x)[source]¶ Map a float image in [1/256-1, 1-1/256] to uint8 {0, 1, …, 255}.
- Parameters
x (Union[numpy.ndarray, jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) –
- Return type
Union[numpy.ndarray, jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
-
objax.util.image.
normalize_to_unit_float
(x)[source]¶ Map a uint8 image in {0, 1, …, 255} to the float interval [1/256-1, 1-1/256].
- Parameters
x (Union[numpy.ndarray, jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) –
- Return type
Union[numpy.ndarray, jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
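For example, converting a uint8 NHWC image batch into the float NCHW convention expected by objax.nn.Conv2D (a minimal sketch; the shape is arbitrary):
import numpy as np
import objax
x_uint8 = np.random.randint(0, 256, size=(8, 32, 32, 3), dtype=np.uint8)  # NHWC, uint8
x = objax.util.image.normalize_to_unit_float(objax.util.image.nchw(x_uint8))
print(x.shape)  # (8, 3, 32, 32), values in [1/256-1, 1-1/256]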
objax.zoo package¶
objax.zoo.convnet¶
-
class
objax.zoo.convnet.
ConvNet
(nin, nclass, scales, filters, filters_max, pooling=<function max_pool_2d>, **kwargs)[source]¶ ConvNet implementation.
-
__init__
(nin, nclass, scales, filters, filters_max, pooling=<function max_pool_2d>, **kwargs)[source]¶ Creates ConvNet instance.
- Parameters
nin – number of channels in the input image.
nclass – number of output classes.
scales – number of pooling layers, each of which reduces spatial dimension by 2.
filters – base number of convolution filters. The number of convolution filters is multiplied by 2 at every scale until it reaches filters_max.
filters_max – maximum number of filters.
pooling – type of pooling layer.
-
objax.zoo.dnnet¶
objax.zoo.resnet_v2¶
-
class
objax.zoo.resnet_v2.
ResNetV2
(in_channels, num_classes, blocks_per_group, bottleneck=True, channels_per_group=(256, 512, 1024, 2048), group_strides=(1, 2, 2, 2), group_use_projection=(True, True, True, True), normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)[source]¶ Base implementation of ResNet v2 from https://arxiv.org/abs/1603.05027.
-
__init__
(in_channels, num_classes, blocks_per_group, bottleneck=True, channels_per_group=(256, 512, 1024, 2048), group_strides=(1, 2, 2, 2), group_use_projection=(True, True, True, True), normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)[source]¶ Creates ResNetV2 instance.
- Parameters
in_channels (int) – number of channels in the input image.
num_classes (int) – number of output classes.
blocks_per_group (Sequence[int]) – number of blocks in each block group.
bottleneck (bool) – if True then use bottleneck blocks.
channels_per_group (Sequence[int]) – number of output channels for each block group.
group_strides (Sequence[int]) – strides for each block group.
normalization_fn (Callable[[..], objax.module.Module]) – module used as the normalization function.
activation_fn (Callable[[Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]]) – activation function.
group_use_projection (Sequence[bool]) –
-
-
class
objax.zoo.resnet_v2.
ResNet18
(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)[source]¶ Implementation of ResNet v2 with 18 layers.
-
__init__
(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)[source]¶ Creates ResNet18 instance.
- Parameters
in_channels (int) – number of channels in the input image.
num_classes (int) – number of output classes.
normalization_fn (Callable[[..], objax.module.Module]) – module used as the normalization function.
activation_fn (Callable[[Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]]) – activation function.
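A minimal usage sketch (the model contains batch normalization, so calls take a training flag; the input size is an arbitrary example):
import objax
model = objax.zoo.resnet_v2.ResNet18(in_channels=3, num_classes=10)
x = objax.random.normal((4, 3, 64, 64))   # batch in (N,C,H,W) format
logits = model(x, training=False)         # logits.shape == (4, 10)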
-
-
class
objax.zoo.resnet_v2.
ResNet34
(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)[source]¶ Implementation of ResNet v2 with 34 layers.
-
__init__
(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)[source]¶ Creates ResNet34 instance.
- Parameters
in_channels (int) – number of channels in the input image.
num_classes (int) – number of output classes.
normalization_fn (Callable[[..], objax.module.Module]) – module used as the normalization function.
activation_fn (Callable[[Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]]) – activation function.
-
-
class
objax.zoo.resnet_v2.
ResNet50
(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)[source]¶ Implementation of ResNet v2 with 50 layers.
-
__init__
(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)[source]¶ Creates ResNet50 instance.
- Parameters
in_channels (int) – number of channels in the input image.
num_classes (int) – number of output classes.
normalization_fn (Callable[[..], objax.module.Module]) – module used as the normalization function.
activation_fn (Callable[[Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]]) – activation function.
-
-
class
objax.zoo.resnet_v2.
ResNet101
(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)[source]¶ Implementation of ResNet v2 with 101 layers.
-
__init__
(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)[source]¶ Creates ResNet101 instance.
- Parameters
in_channels (int) – number of channels in the input image.
num_classes (int) – number of output classes.
normalization_fn (Callable[[..], objax.module.Module]) – module used as the normalization function.
activation_fn (Callable[[Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]]) – activation function.
-
-
class
objax.zoo.resnet_v2.
ResNet152
(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)[source]¶ Implementation of ResNet v2 with 152 layers.
-
__init__
(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)[source]¶ Creates ResNet152 instance.
- Parameters
in_channels (int) – number of channels in the input image.
num_classes (int) – number of output classes.
normalization_fn (Callable[[..], objax.module.Module]) – module used as the normalization function.
activation_fn (Callable[[Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]]) – activation function.
-
-
class
objax.zoo.resnet_v2.
ResNet200
(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)[source]¶ Implementation of ResNet v2 with 200 layers.
-
__init__
(in_channels, num_classes, normalization_fn=<class 'objax.nn.layers.BatchNorm2D'>, activation_fn=<function relu>)[source]¶ Creates ResNet200 instance.
- Parameters
in_channels (int) – number of channels in the input image.
num_classes (int) – number of output classes.
normalization_fn (Callable[[..], objax.module.Module]) – module used as the normalization function.
activation_fn (Callable[[Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]], Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]]) – activation function.
-
objax.zoo.wide_resnet¶
-
class
objax.zoo.wide_resnet.
WRNBlock
(nin, nout, stride=1, bn=functools.partial(<class 'objax.nn.layers.BatchNorm2D'>, momentum=0.9, eps=1e-05))[source]¶ WideResNet block.
-
__init__
(nin, nout, stride=1, bn=functools.partial(<class 'objax.nn.layers.BatchNorm2D'>, momentum=0.9, eps=1e-05))[source]¶ Creates WRNBlock instance.
- Parameters
nin (int) – number of input filters.
nout (int) – number of output filters.
stride (int) – stride for convolution and projection convolution in this block.
bn (Callable) – module used as the batch norm function.
-
__call__
(x, training)[source]¶ Optional module __call__ method, typically a forward pass computation for standard primitives.
- Parameters
x (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) –
training (bool) –
- Return type
Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
-
-
class
objax.zoo.wide_resnet.
WideResNetGeneral
(nin, nclass, blocks_per_group, width, bn=functools.partial(<class 'objax.nn.layers.BatchNorm2D'>, momentum=0.9, eps=1e-05))[source]¶ Base WideResNet implementation.
-
static
mean_reduce
(x)[source]¶ - Parameters
x (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) –
- Return type
Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
-
__init__
(nin, nclass, blocks_per_group, width, bn=functools.partial(<class 'objax.nn.layers.BatchNorm2D'>, momentum=0.9, eps=1e-05))[source]¶ Creates WideResNetGeneral instance.
- Parameters
nin (int) – number of channels in the input image.
nclass (int) – number of output classes.
blocks_per_group (List[int]) – number of blocks in each block group.
width (int) – multiplier to the number of convolution filters.
bn (Callable) – module used as the batch norm function.
-
-
class
objax.zoo.wide_resnet.
WideResNet
(nin, nclass, depth=28, width=2, bn=functools.partial(<class 'objax.nn.layers.BatchNorm2D'>, momentum=0.9, eps=1e-05))[source]¶ WideResNet implementation with 3 groups.
-
__init__
(nin, nclass, depth=28, width=2, bn=functools.partial(<class 'objax.nn.layers.BatchNorm2D'>, momentum=0.9, eps=1e-05))[source]¶ Creates WideResNet instance.
- Parameters
nin (int) – number of channels in the input image.
nclass (int) – number of output classes.
depth (int) – number of convolution layers. (depth-4) should be divisible by 6
width (int) – multiplier to the number of convolution filters.
bn (Callable) – module used as the batch norm function.
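A minimal usage sketch with a CIFAR-sized input (the shapes are arbitrary examples):
import objax
model = objax.zoo.wide_resnet.WideResNet(nin=3, nclass=10, depth=28, width=2)
x = objax.random.normal((4, 3, 32, 32))
logits = model(x, training=True)   # logits.shape == (4, 10)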
-
objax.zoo.rnn¶
-
class
objax.zoo.rnn.
RNN
(nstate, nin, nout, activation=<function _one_to_one_unop.<locals>.fn>, w_init=<function kaiming_normal>)[source]¶ Recurrent Neural Network (RNN) block.
-
__init__
(nstate, nin, nout, activation=<function _one_to_one_unop.<locals>.fn>, w_init=<function kaiming_normal>)[source]¶ Creates an RNN instance.
- Parameters
nstate (int) – number of hidden units.
nin (int) – number of input units.
nout (int) – number of output units.
activation (Callable) – activation function for the hidden layer.
w_init (Callable) – weight initializer for RNN model weights.
-
__call__
(inputs, only_return_final=False)[source]¶ Forward pass through RNN.
- Parameters
inputs (Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) –
JaxArray
with dimensionsnum_steps, batch_size, vocabulary_size
.only_return_final – return only the last output if
True
, or all outputs otherwise.
- Returns
Output tensor with dimensions
num_steps * batch_size, vocabulary_size
.- Return type
Union[jax.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
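A minimal usage sketch following the documented input layout (the sizes are arbitrary examples):
import objax
rnn = objax.zoo.rnn.RNN(nstate=64, nin=100, nout=100)
inputs = objax.random.normal((20, 8, 100))   # (num_steps, batch_size, vocabulary_size)
outputs = rnn(inputs)                        # shape (num_steps * batch_size, vocabulary_size)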
-
Variables and Modules¶
Objax relies on only two concepts: Variable and Module. While not a concept in itself, another commonly used object is the VarCollection.
A variable is an object that has a value, typically a JaxArray.
A module is an object that has variables and methods attached to it.
A VarCollection is a dictionary {name: variable} with some additional methods to facilitate variable passing and manipulation.
The following topics are covered in this guide:
Variable¶
Variables store a value. Contrary to most frameworks, variables do not require users to give them names. Instead, the names are inferred from Python. See the Variable names section for more details.
Objax has four types of variables to handle the various situations encountered in machine learning:
TrainVar is a trainable variable. Its value cannot be directly modified, so as to maintain its differentiability.
StateVar is a state variable. It is not trainable and its value can be directly modified.
TrainRef is a reference to a TrainVar. It is a state variable used to change the value of a TrainVar, typically in the context of optimizers such as SGD.
RandomState is a random state variable. JAX made the design choice to require explicit random state manipulation, and this variable does that for you. This type of variable is used in objax.random.Generator.
Note
Some of the variable types above take a reduce argument; this is discussed in the Parallelism topic.
TrainVar¶
An objax.TrainVar
is a trainable variable.
TrainVar variables are meant to keep the trainable weights of neural networks.
As such, when calling a gradient module such as objax.GradValues, their gradients are computed.
This contrasts with the other types of variables (state variables),
which do not have gradients.
A TrainVar is created by passing a JaxArray containing its initial value:
import objax
import jax.numpy as jn
v = objax.TrainVar(jn.arange(6, dtype=jn.float32))
print(v.value) # [0. 1. 2. 3. 4. 5.]
# It is not directly writable.
v.value += 1 # Raises a ValueError
# You can force assign it; however, as expected, all its uses before the assignment are not
# differentiable.
v.assign(v.value + 1)
print(v.value) # [1. 2. 3. 4. 5. 6.]
StateVar¶
An objax.StateVar
is a state variable. Unlike TrainVar variables, state variables are non-trainable.
StateVar variables are used for parameters that are manually/programmatically updated.
For example, when computing a running mean of the input to a module, a StateVar is used.
StateVars are created just like TrainVars, by passing a JaxArray containing their initial value:
import objax
import jax.numpy as jn
v = objax.StateVar(jn.arange(6, dtype=jn.float32))
print(v.value) # [0. 1. 2. 3. 4. 5.]
# It is directly writable.
v.value += 1
print(v.value) # [1. 2. 3. 4. 5. 6.]
# You can also assign to it, it's the same as doing v.value = ...
v.assign(v.value + 1)
print(v.value) # [2. 3. 4. 5. 6. 7.]
StateVar variables are ignored by gradient methods. Unlike TrainVar variables, their gradients are not computed.
Since StateVars are programmatically updated, you may be tempted to simply use Python values or numpy arrays directly.
However, StateVars are necessary.
They are needed to run on GPU, since standard Python values and numpy arrays would not run on the GPU.
Another reason is that objax.Jit
or objax.Parallel
only recognize Objax variables.
TrainRef¶
An objax.TrainRef
is a state variable which is used to keep a reference to a TrainVar.
TrainRef variables are used in optimizers since optimizers need to keep a reference to the
TrainVar they are meant to optimize.
TrainRef creation differs from the previously seen variables as it takes a TrainVar as its input:
import objax
import jax.numpy as jn
t = objax.TrainVar(jn.arange(6, dtype=jn.float32))
v = objax.TrainRef(t)
print(t.value) # [0. 1. 2. 3. 4. 5.]
# It is directly writable.
v.value += 1
print(v.value) # [1. 2. 3. 4. 5. 6.]
# It writes the TrainVar it references.
print(t.value) # [1. 2. 3. 4. 5. 6.]
# You can also assign to it, it's the same as doing v.value = ...
v.assign(v.value + 1)
print(v.value) # [2. 3. 4. 5. 6. 7.]
print(t.value) # [2. 3. 4. 5. 6. 7.]
TrainRef variables are ignored by gradient methods. Unlike TrainVar variables, their gradients are not computed.
Philosophically, one may ask why a TrainRef is needed to keep a reference to a TrainVar in an optimizer. Indeed, why not simply keep the TrainVar itself in the optimizer? The answer is that the optimizer is a module like any other (make sure to read the Module section first). As such, one could compute the gradient of the optimizer itself. It is for this situation that we need a TrainRef to distinguish between the optimizer’s own trainable variables (needed for its functionality) and the trainable variables of the neural network it is meant to optimize. It should be noted that most current optimizers do not have their own trainable variables, but we wanted to provide the flexibility needed for future research.
RandomState¶
A objax.RandomState
is a state variable which is used to handle the tracking of random number generator
states.
It is only used in objax.random.Generator
.
It is responsible for automatically creating different states when the code is run in parallel on multiple GPUs
(see objax.Parallel
) or in a vectorized way (see objax.Vectorize
).
This is necessary in order for random numbers to be truly random.
In the rare event that you want to use the same random seed in a multi-GPU or vectorized module, you can use a StateVar
to store the seed.
Here’s a simple example using the objax.random.Generator
API:
import objax
# Use default objax.random.DEFAULT_GENERATOR that transparently handles RandomState
print(objax.random.normal((2,))) # [ 0.19307713 -0.52678305]
# A subsequent call gives, as expected, new random numbers.
print(objax.random.normal((2,))) # [ 0.00870693 -0.04888531]
# Make two random generators with same seeds
g1 = objax.random.Generator(seed=1337)
g2 = objax.random.Generator(seed=1337)
# Random numbers using g1
print(objax.random.normal((2,), generator=g1)) # [-0.3361883 -0.9903351]
print(objax.random.normal((2,), generator=g1)) # [ 0.5825488 -1.4342074]
# Random numbers using g2
print(objax.random.normal((2,), generator=g2)) # [-0.3361883 -0.9903351]
print(objax.random.normal((2,), generator=g2)) # [ 0.5825488 -1.4342074]
# The results are reproducible: we obtained the same random numbers with 2 generators
# using the same random seed.
You can also manually manipulate RandomState directly for the purpose of designing custom random numbers rules,
for example with forced correlation.
A RandomState has an extra method called objax.RandomState.split()
which lets it create n
new random
states.
Here’s a basic example of RandomState manipulation:
import objax
v = objax.RandomState(1) # 1 is the seed
print(v.value) # [0 1]
# We call v.split(1) to generate 1 new state, note that split also updates v.value
print(v.split(1)) # [[3819641963 2025898573]]
print(v.value) # [2441914641 1384938218]
# We call v.split(2) to generate 2 new states, again v.value is updated
print(v.split(2)) # [[ 622232657 209145368] [2741198523 2127103341]]
print(v.value) # [3514448473 2078537737]
Module¶
An objax.Module
is a simple container in which to store variables or other modules and on which to attach
methods that use these variables. Objax uses the term module instead of class to avoid confusion with the Python term class.
The Module class only offers one method objax.Module.vars()
which returns all variables contained by the
module and its submodules in a VarCollection.
Warning
To avoid surprising unintended behaviors, vars()
won’t look for variables or modules in lists, dicts
or any structure that is not a Module
.
See Scope of the Module.vars method for how to handle lists in Objax.
Let’s start with a simple example: a module called Linear
, which does a simple matrix product and adds a bias
y = x.w + b
, where \(w\in\mathbb{R}^{m\times n}\) and \(b\in\mathbb{R}^n\):
import objax
import jax.numpy as jn
class Linear(objax.Module):
def __init__(self, m, n):
self.w = objax.TrainVar(objax.random.normal((m, n)))
self.b = objax.TrainVar(jn.zeros(n))
def __call__(self, x):
return x.dot(self.w.value) + self.b.value
This simple module can be used on a batch \(x\in\mathbb{R}^{d\times m}\) to compute the resulting value \(y\in\mathbb{R}^{d\times n}\) for batch size \(d\). Let’s continue our example by creating an actual instance of our module and running a random batch x through it:
f = Linear(4, 5)
x = objax.random.normal((100, 4)) # A (100 x 4) matrix of random numbers
y = f(x) # y.shape == (100, 5)
We can easily make a more complicated module that uses the previously defined module Linear:
class MiniNet(objax.Module):
def __init__(self, m, n, p):
self.f1 = Linear(m, n)
self.f2 = Linear(n, p)
def __call__(self, x):
y = self.f1(x)
y = objax.functional.relu(y) # Apply a non-linearity.
return self.f2(y)
# You can create as many functions as you want.
def another_function(self, x1, x2):
return self.f2(self.f1(x1) + x2)
f = MiniNet(4, 5, 6)
y = f(x) # y.shape == (100, 6)
x2 = objax.random.normal((100, 5)) # A (100 x 5) matrix of random numbers
another_y = f.another_function(x, x2)
# You can also call internal parts for example to see intermediate values.
y1 = f.f1(x)
y2 = objax.functional.relu(y1)
y3 = f.f2(y2) # y3 == y
Variable names¶
Continuing with the previous example, let’s find out what the names of the variables are.
We mentioned earlier that variable names are inferred from Python and not specified by the programmer.
The way their names are inferred is from the field names, such as self.w
.
This has the benefit of ensuring consistency: a variable has a single name, and it’s the name it is given in the Python
code.
Let’s inspect the names:
f = Linear(4, 5)
print(f.vars()) # print name, size, dimensions
# (Linear).w 20 (4, 5)
# (Linear).b 5 (5,)
# +Total(2) 25
f = MiniNet(4, 5, 6)
print(f.vars())
# (MiniNet).f1(Linear).w 20 (4, 5)
# (MiniNet).f1(Linear).b 5 (5,)
# (MiniNet).f2(Linear).w 30 (5, 6)
# (MiniNet).f2(Linear).b 6 (6,)
# +Total(4) 61
As you can see, the names correspond to the names of the fields in which the variables are kept.
Scope of the Module.vars method¶
The objax.Module.vars()
method is meant to be simple and to remain simple.
With that in mind, we limited its scope: vars()
won’t look for variables or modules in lists, dicts or any
structure that is not a Module
.
This is to avoid surprising unintended behavior.
Instead we made the decision to create an explicit class objax.ModuleList
to store a list of variables and
modules.
ModuleList¶
The class objax.ModuleList
inherits from list
and behaves exactly like a list with the
difference that vars()
looks for variables and modules in it.
This class is very simple, and we invite you to look at it and use it for inspiration if you want to extend other
Python containers or design your own.
Here’s a simple example of its usage:
import objax
import jax.numpy as jn
class MyModule(objax.Module):
def __init__(self):
self.bad = [objax.TrainVar(jn.zeros(1)),
objax.TrainVar(jn.zeros(2))]
self.good = objax.ModuleList([objax.TrainVar(jn.zeros(3)),
objax.TrainVar(jn.zeros(4))])
print(MyModule().vars())
# (MyModule).good(ModuleList)[0] 3 (3,)
# (MyModule).good(ModuleList)[1] 4 (4,)
# +Total(2) 7
VarCollection¶
The Module.vars
method returns an objax.VarCollection
.
This class is a dictionary that maps names to variables.
It has some additional methods and some modified behaviors specifically for variable manipulation.
In most cases, you won’t need to use the more advanced methods such as __iter__
, tensors
and
assign
, but this is an in-depth topic.
Let’s take a look at some of them through an example:
import objax
import jax.numpy as jn
class Linear(objax.Module):
def __init__(self, m, n):
self.w = objax.TrainVar(objax.random.normal((m, n)))
self.b = objax.TrainVar(jn.zeros(n))
m1 = Linear(3, 4)
m2 = Linear(4, 5)
# First, as seen before, we can print the contents with the print() method
print(m1.vars())
# (Linear).w 12 (3, 4)
# (Linear).b 4 (4,)
# +Total(2) 16
# A VarCollection is really a dictionary
print(repr(m1.vars()))
# {'(Linear).w': <objax.variable.TrainVar object at 0x7fb5e47c0ad0>,
# '(Linear).b': <objax.variable.TrainVar object at 0x7fb5ec017890>}
Combining multiple VarCollections is done by using addition:
all_vars = m1.vars('m1') + m2.vars('m2')
print(all_vars)
# m1(Linear).w 12 (3, 4)
# m1(Linear).b 4 (4,)
# m2(Linear).w 20 (4, 5)
# m2(Linear).b 5 (5,)
# +Total(4) 41
# We had to specify starting names for each of the var collections since
# they have variables with the same name. Had we not, a name collision would
# have occurred since VarCollection is a dictionary that maps names to variables.
m1.vars() + m2.vars() # raises ValueError('Name conflicts...')
Weight sharing¶
It’s a common technique in machine learning to share some weights.
However, it is important not to apply gradients twice or more to shared weights.
This is handled automatically by VarCollection and its __iter__
method described in the next section.
Here’s a simple weight sharing example where we simply refer to the same module twice under different names:
# Weight sharing
shared_vars = m1.vars('m1') + m1.vars('m1_shared')
print(shared_vars)
# m1(Linear).w 12 (3, 4)
# m1(Linear).b 4 (4,)
# m1_shared(Linear).w 12 (3, 4)
# m1_shared(Linear).b 4 (4,)
# +Total(4) 32
VarCollection.__iter__¶
Deduplication is handled automatically by the VarCollection default iterator objax.VarCollection.__iter__()
.
Following up on the weight sharing example above, the iterator only returns each variable once:
list(shared_vars) # [<objax.variable.TrainVar>, <objax.variable.TrainVar>]
VarCollection.tensors¶
You can collect all the values (JaxArray) for all the variables with objax.VarCollection.tensors()
, again in a
deduplicated manner:
shared_vars.tensors() # DeviceArray([[-0.1441347...]), DeviceArray([0...], dtype=float32)]
VarCollection.assign¶
The last important method objax.VarCollection.assign()
lets you assign a tensor list to all the
VarCollection’s (deduplicated) variables at once:
shared_vars.tensors() # DeviceArray([[-0.1441347...]), DeviceArray([0...], dtype=float32)]
# The following increments all the variables.
shared_vars.assign([x + 1 for x in shared_vars.tensors()])
shared_vars.tensors() # DeviceArray([[0.8558653...]), DeviceArray([1...], dtype=float32)]
Understanding Gradients¶
Pre-requisites: Variables and Modules.
In this guide we discuss how to compute and use gradients in various situations and will cover the following topics:
Examples will illustrate the following cases:
Describe the
objax.GradValues
Module.
Show how to write your own optimizer from scratch.
How to write a basic training iteration.
How to handle complex gradients such as in Generative Adversarial Networks (GANs) or meta-learning.
Explain potential optimizations in the presence of constants.
Computing gradients¶
JAX, and therefore Objax, differs from most frameworks in how gradients are represented. In JAX, gradients are represented as functions, since everything in JAX is a function. In Objax, however, they are represented as module objects.
Gradient as a module¶
In machine learning, for a function \(f(X; \theta)\), it is common practice to separate the inputs \(X\) from the parameters \(\theta\). Mathematically, this is captured by using a semi-colon to semantically separate one group of arguments from another.
In Objax, we represent this semantic distinction through an object, objax.Module:
the module parameters \(\theta\) are object attributes of the form
self.w, ...
the inputs \(X\) are arguments to the methods such as
def __call__(self, x1, x2, ...):
The gradient of a function \(f(X; \theta)\) w.r.t. \(Y\subseteq X, \phi\subseteq\theta\) is a function \(g(X; \theta) = \nabla_{Y,\phi} f(X; \theta)\).
The gradient function g is also a module, since the same semantic distinction between inputs \(X\) and parameters \(\theta\) can be made as in f. Meanwhile, \(Y, \phi\) are constants of g (they determine which inputs and which variables to compute the gradient of); in practice, \(Y, \phi\) are also implemented as object attributes.
The direct benefit of this design is that gradient manipulation is very easy and explicit: in fact it follows the standard mathematical formulation of gradients. While this description may seem abstract, the examples below show how simple it turns out to be.
A simple example¶
Let’s look at what gradient as a module looks like through a simple example:
import objax
m = objax.nn.Linear(2, 3)
def loss(x, y):
    return ((m(x) - y) ** 2).mean()
# Create Module that returns a tuple (g, v):
# g is the gradient of the loss
# v is the value of the loss
gradient_loss = objax.GradValues(loss, m.vars())
# Make up some fake data
x = objax.random.normal((100, 2))
y = objax.random.normal((100, 3))
# Calling the module gradient_loss returns the actual g, v for (x, y)
g, v = gradient_loss(x, y)
print(v, '==', loss(x, y)) # [DeviceArray(2.7729921, dtype=float32)] == 2.7729921
print(g) # A list of tensors (gradients of variables in module m)
As stated, gradient_loss is a module instance and has variables.
Its variables are simply the ones passed to objax.GradValues; we can verify this:
print(gradient_loss.vars())
# (GradValues)(Linear).b 3 (3,)
# (GradValues)(Linear).w 6 (2, 3)
# +Total(2) 9
# These variables come from the module m:
print(m.vars())
# (Linear).b 3 (3,)
# (Linear).w 6 (2, 3)
# +Total(2) 9
Let's be clear: these are the exact same variables, not copies.
This is an instance of weight sharing: m and gradient_loss share the same weights.
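A quick way to convince yourself is to compare object identities; here is a small sketch, relying only on the fact, shown above, that a VarCollection is a dictionary mapping names to variables:
# The variable objects are identical; only their names differ by the (GradValues) prefix.
print({id(v) for v in m.vars().values()} ==
      {id(v) for v in gradient_loss.vars().values()})  # True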
Loss optimization¶
Gradients are useful to minimize or maximize losses. This can be done using Stochastic Gradient Descent (SGD) with the following steps, for a network with weights \(\theta\) and a learning rate \(\mu\):
At iteration \(t\), take a batch of data \(x_t\)
Compute the gradient \(g_t=\nabla loss(x_t)\)
Update the weights \(\theta_t = \theta_{t-1} - \mu \cdot g_t\)
Goto 1
Objax already has a library of optimizers in the objax.optimizer package. However, we are going to create our own to demonstrate how optimizers work with gradients. First, let's recall that everything in Objax is a Module (or a function). SGD will therefore be a module: it stores the list of variables on which to perform gradient descent, and its call method takes the gradients as inputs and applies them to the variables.
Read first the part about Variables and Modules if you haven’t done so yet. Let’s get started:
import objax
class SGD(objax.Module):
    def __init__(self, variables: objax.VarCollection):
        self.refs = objax.ModuleList(objax.TrainRef(x)
                                     for x in variables.subset(objax.TrainVar))

    def __call__(self, lr: float, gradients: list):
        for v, g in zip(self.refs, gradients):
            v.value -= lr * g
In short, self.refs keeps a list of references to the network's trainable variables (TrainVar).
When the __call__ method is invoked, the values of these variables get updated following the SGD rule.
From this we can demonstrate the training of a classifier:
import objax
# SGD definition code from above.
my_classifier = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu,
objax.nn.Linear(3, 4)])
opt = SGD(my_classifier.vars())
def loss(x, labels):
    logits = my_classifier(x)
    return objax.functional.loss.cross_entropy_logits(logits, labels).mean()
gradient_loss = objax.GradValues(loss, my_classifier.vars())
def train(x, labels, lr):
    g, v = gradient_loss(x, labels)  # Compute gradients and loss
    opt(lr, g)                       # Apply SGD
    return v                         # Return loss value
# Observe that the gradient contains the variables of the model (weight sharing)
print(gradient_loss.vars())
# (GradValues)(Sequential)[0](Linear).b 3 (3,)
# (GradValues)(Sequential)[0](Linear).w 6 (2, 3)
# (GradValues)(Sequential)[2](Linear).b 4 (4,)
# (GradValues)(Sequential)[2](Linear).w 12 (3, 4)
# +Total(4) 25
# At this point you can simply call train on your training data and pass the learning rate.
# Each call performs a single SGD step minimizing the loss on the given data.
# Repeated calls (through various batches of data) will minimize the loss on your data.
x = objax.random.normal((100, 2))
labels = objax.random.randint((100,), low=0, high=4)
labels = objax.functional.one_hot(labels, 4)
print(train(x, labels, lr=0.01))
# and so on...
# See examples section for real examples.
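To make the repeated calls mentioned above concrete, here is a minimal training-loop sketch; the number of steps and the learning rate are arbitrary illustrative choices, and a real loop would draw a fresh batch at every step:
for step in range(100):
    loss_value = train(x, labels, lr=0.01)  # In practice, use a new batch each step.
print(loss_value)  # The loss decreases as training progresses.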
Returning multiple values for the loss¶
Let's say we want to add weight decay and return the individual components of the loss (cross-entropy, weight decay). The loss function can return any number of values or even structures such as dicts or lists. Only the first returned value is used to compute the gradient; the others are simply returned along with it as the loss value.
Continuing our example, let's create a new loss that returns its individual components:
def losses(x, labels):
    logits = my_classifier(x)
    loss_xe = objax.functional.loss.cross_entropy_logits(logits, labels).mean()
    loss_wd = sum((v.value ** 2).sum() for k, v in my_classifier.vars().items() if k.endswith('.w'))
    return loss_xe + 0.0002 * loss_wd, loss_xe, loss_wd
gradient_losses = objax.GradValues(losses, my_classifier.vars())
print(gradient_losses(x, labels)[1])
# (DeviceArray(1.7454103, dtype=float32),
# DeviceArray(1.7434813, dtype=float32),
# DeviceArray(9.645493, dtype=float32))
Or one might prefer to return a dict to keep things organized:
def loss_dict(x, labels):
    logits = my_classifier(x)
    loss_xe = objax.functional.loss.cross_entropy_logits(logits, labels).mean()
    loss_wd = sum((v.value ** 2).sum() for k, v in my_classifier.vars().items() if k.endswith('.w'))
    return loss_xe + 0.0002 * loss_wd, {'loss/xe': loss_xe, 'loss/wd': loss_wd}
gradient_loss_dict = objax.GradValues(loss_dict, my_classifier.vars())
print(gradient_loss_dict(x, labels)[1])
# (DeviceArray(1.7454103, dtype=float32),
# {'loss/wd': DeviceArray(9.645493, dtype=float32),
# 'loss/xe': DeviceArray(1.7434813, dtype=float32)})
Input gradients¶
When computing gradients, it's sometimes useful to also compute the gradients with respect to some or all of the inputs of the network.
For example, such functionality is needed for adversarial training or gradient penalties in GANs.
This can be easily achieved using the input_argnums argument of objax.GradValues; here's an example:
# Compute the gradient for my_classifier variables and for the first input of the loss:
gradient_loss_v_x = objax.GradValues(loss, my_classifier.vars(), input_argnums=(0,))
print(gradient_loss_v_x(x, labels)[0])
# g = [gradient(x)] + [gradient(v) for v in classifier.vars().subset(TrainVar)]
# Compute the gradient for my_classifier variables and for the second input of the loss:
gradient_loss_v_y = objax.GradValues(loss, my_classifier.vars(), input_argnums=(1,))
print(gradient_loss_v_y(x, labels)[0])
# g = [gradient(labels)] + [gradient(v) for v in classifier.vars().subset(TrainVar)]
# Compute the gradient for my_classifier variables and for all the inputs of the loss:
gradient_loss_v_xy = objax.GradValues(loss, my_classifier.vars(), input_argnums=(0, 1))
print(gradient_loss_v_xy(x, labels)[0])
# g = [gradient(x), gradient(labels)] + [gradient(v) for v in classifier.vars().subset(TrainVar)]
# You can also compute the gradients from the inputs alone
gradient_loss_xy = objax.GradValues(loss, None, constants=my_classifier.vars(), input_argnums=(0, 1))
print(gradient_loss_xy(x, labels)[0])
# g = [gradient(x), gradient(labels)]
# The order of the inputs matters, using input_argnums=(1, 0) instead of (0, 1)
gradient_loss_yx = objax.GradValues(loss, None, constants=my_classifier.vars(), input_argnums=(1, 0))
print(gradient_loss_yx(x, labels)[0])
# g = [gradient(labels), gradient(x)]
Gradients of a subset of variables¶
When doing more complex optimizations, one might want to temporarily treat part of a network as constant.
This is achieved by simply passing only the variables you want the gradient of to objax.GradValues.
This is useful, for example, in GANs where one has to optimize the discriminator and the generator networks separately.
Continuing our example:
all_vars = my_classifier.vars()
print(all_vars)
# (Sequential)[0](Linear).b 3 (3,)
# (Sequential)[0](Linear).w 6 (2, 3)
# (Sequential)[2](Linear).b 4 (4,)
# (Sequential)[2](Linear).w 12 (3, 4)
# +Total(4) 25
Let’s say we want to freeze the second Linear layer by treating it as constant:
# We create a VarCollection without the second Linear layer's variables
vars_train = objax.VarCollection((k, v) for k, v in all_vars.items() if '[2](Linear)' not in k)
print(vars_train)
# (Sequential)[0](Linear).b 3 (3,)
# (Sequential)[0](Linear).w 6 (2, 3)
# +Total(2) 9
# We define a gradient function that ignores variables not in vars_train
gradient_loss_freeze = objax.GradValues(loss, vars_train)
print(gradient_loss_freeze(x, labels)[0])
# As expected, we now have two gradient arrays, corresponding to vars_train.
# [DeviceArray([0.19490579, 0.12267624, 0.05770121], dtype=float32),
# DeviceArray([[-0.21900907, -0.10813318, -0.05385721],
# [ 0.12701261, -0.03145855, -0.04397186]], dtype=float32)]
Higher-order gradients¶
Finally, one might want to optimize a loss that contains a gradient within a gradient. For example, let's consider the following nested loss that relies on another loss \(\mathcal{L}=\texttt{loss}\):
\(\texttt{nested\_loss}(x_1, y_1, x_2, y_2, \mu) = \mathcal{L}(x_1, y_1; \theta - \mu\nabla_\theta\mathcal{L}(x_2, y_2; \theta))\)
Implementing this in Objax remains simple: one just applies the formula verbatim.
In the following example we picked a cross-entropy loss for \(\mathcal{L}\), but we could have picked any other loss since nested_loss is independent of the choice of loss:
train_vars = my_classifier.vars().subset(objax.TrainVar)
def loss(x, labels):
    logits = my_classifier(x)
    return objax.functional.loss.cross_entropy_logits(logits, labels).mean()

gradient_loss = objax.GradValues(loss, train_vars)

def nested_loss(x1, y1, x2, y2, mu):
    # Save original network variable values
    original_values = train_vars.tensors()
    # Apply gradient from loss(x2, y2)
    for v, g in zip(train_vars, gradient_loss(x2, y2)[0]):
        v.assign(v.value - mu * g)
    # Compute loss(x1, y1)
    loss_x1y1 = loss(x1, y1)
    # Undo the gradient from loss(x2, y2)
    for v, val in zip(train_vars, original_values):
        v.assign(val)
    # Return the loss
    return loss_x1y1
gradient_nested_loss = objax.GradValues(nested_loss, train_vars)
# Run with mock-up data; note this is only an example since the loss is not written for batched data.
x1 = objax.random.normal((1, 2))
y1 = objax.functional.one_hot(objax.random.randint((1,), low=0, high=4), 4)
x2 = objax.random.normal((1, 2))
y2 = objax.functional.one_hot(objax.random.randint((1,), low=0, high=4), 4)
print(gradient_nested_loss(x1, y1, x2, y2, 0.1))
# (gradients, loss), where the gradients are 4 tensors of the same shape as the layer variables.
# (Sequential)[0](Linear).b 3 (3,)
# (Sequential)[0](Linear).w 6 (2, 3)
# (Sequential)[2](Linear).b 4 (4,)
# (Sequential)[2](Linear).w 12 (3, 4)
Generally speaking, using objax.TrainVar.assign() is discouraged unless you know what you are doing: the point is to avoid accidental bugs caused by overwriting a trainable variable.
This is precisely a situation where one knows what they are doing, so it is perfectly fine to use assign here.
On a final note, by observing that the weight update in the code above is invertible, the nested loss can be simplified to:
def nested_loss(x1, y1, x2, y2, mu):
    # Compute the gradient for loss(x2, y2)
    g_x2y2 = gradient_loss(x2, y2)[0]
    # Apply gradient from loss(x2, y2)
    for v, g in zip(train_vars, g_x2y2):
        v.assign(v.value - mu * g)
    # Compute loss(x1, y1)
    loss_x1y1 = loss(x1, y1)
    # Undo the gradient from loss(x2, y2)
    for v, g in zip(train_vars, g_x2y2):
        v.assign(v.value + mu * g)
    # Return the loss
    return loss_x1y1
Local gradients¶
In even more advanced situations, such as meta-learning, it can be desirable to have even finer control over gradients.
In the example above, nested_loss can accept vectors or matrices for its inputs x1, y1, x2, y2.
In the case of matrices, nested_loss is computed as \(\mathcal{L}(X_1, Y_1; \theta - \mu\nabla_\theta\mathcal{L}(X_2, Y_2; \theta))\), i.e. the inner gradient is taken over the whole batch \(X_2, Y_2\) at once.
As a more advanced example, let's reproduce the loss from Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks in a batch form. It is expressed as \(\frac{1}{n}\sum_{i=1}^{n}\mathcal{L}(x_1^{(i)}, y_1^{(i)}; \theta - \mu\nabla_\theta\mathcal{L}(x_2^{(i)}, y_2^{(i)}; \theta))\), i.e. the per-entry nested losses are averaged over the batch.
Using the previously defined nested_loss, we can apply vectorization (see Vectorization for details) to it.
In doing so we create a module vec_nested_loss that computes nested_loss for each entry of the batches X1, Y1, X2, Y2:
# Make vec_nested_loss a Module that calls nested_loss on one batch entry at a time
vec_nested_loss = objax.Vectorize(nested_loss, gradient_loss.vars(),
batch_axis=(0, 0, 0, 0, None))
# The final loss just calls vec_nested_loss and returns the mean of the losses
def nested_pairwise_loss(X1, Y1, X2, Y2, mu):
    return vec_nested_loss(X1, Y1, X2, Y2, mu).mean()
# Just like any simpler loss, we can compute its gradient.
gradient_nested_pairwise_loss = objax.GradValues(nested_pairwise_loss, vec_nested_loss.vars())
# Run with mock-up batch data.
X1 = objax.random.normal((100, 2))
Y1 = objax.functional.one_hot(objax.random.randint((100,), low=0, high=4), 4)
X2 = objax.random.normal((100, 2))
Y2 = objax.functional.one_hot(objax.random.randint((100,), low=0, high=4), 4)
print(gradient_nested_pairwise_loss(X1, Y1, X2, Y2, 0.1))
Have fun!
Compilation and Parallelism¶
In this section we discuss the concepts of code compilation and parallelism, typically used to accelerate performance. We'll cover the following subtopics:
Just-In-Time (JIT) compilation compiles the code the first time it is executed, with the goal of speeding up subsequent runs.
Parallelism runs operations on multiple devices (for example multiple GPUs).
Vectorization can be seen as batch-level parallelization, running an operation over the entries of a batch in parallel.
JIT Compilation¶
objax.Jit
is a Module that takes a module or a function and compiles it for faster performance.
As a simple starting example, let’s jit a module:
import objax
net = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu, objax.nn.Linear(3, 4)])
jit_net = objax.Jit(net)
x = objax.random.normal((100, 2))
print(((net(x) - jit_net(x)) ** 2).sum()) # 0.0
You can also jit a function; in this case you must pass the variables it uses:
jit_func = objax.Jit(lambda x: objax.functional.softmax(net(x)), net.vars())
print(((objax.functional.softmax(net(x)) - jit_func(x)) ** 2).sum())
In terms of performance, even on this small example there is a significant gain in speed; the exact numbers vary depending on the hardware in your computer and on the code being jitted:
from time import time
t0 = time(); y = net(x); print(time() - t0) # 0.005...
t0 = time(); y = jit_net(x); print(time() - t0) # 0.001...
As mentioned earlier, jit_net is a module instance and it shares its variables with the module net; we can verify it:
print(net.vars())
# (Sequential)[0](Linear).b 3 (3,)
# (Sequential)[0](Linear).w 6 (2, 3)
# (Sequential)[2](Linear).b 4 (4,)
# (Sequential)[2](Linear).w 12 (3, 4)
# +Total(4) 25
print(jit_net.vars())
# (Jit)(Sequential)[0](Linear).b 3 (3,)
# (Jit)(Sequential)[0](Linear).w 6 (2, 3)
# (Jit)(Sequential)[2](Linear).b 4 (4,)
# (Jit)(Sequential)[2](Linear).w 12 (3, 4)
# +Total(4) 25
# We can verify that jit_func also shares the same variables
print(jit_func.vars())
# (Jit)(Sequential)[0](Linear).b 3 (3,)
# (Jit)(Sequential)[0](Linear).w 6 (2, 3)
# (Jit)(Sequential)[2](Linear).b 4 (4,)
# (Jit)(Sequential)[2](Linear).w 12 (3, 4)
# +Total(4) 25
Note that we only verified that the variable names and sizes are the same (or almost the same, since the variables in Jit are prefixed with (Jit)).
Let’s now verify that the weights are indeed shared by modifying the weights:
net[-1].b.assign(net[-1].b.value + 1)
print(((net(x) - jit_net(x)) ** 2).sum()) # 0.0
# Both net(x) and jit_net(x) were affected in the same way by the change
# since the weights are shared.
# You can also inspect the values print(net(x)) for more insight.
A realistic case: Fully jitted training step¶
Let's write a classifier training op; this is very similar to the example shown in Loss optimization. We are going to define a model, an optimizer, a loss and a gradient:
import objax
m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu, objax.nn.Linear(3, 4)])
opt = objax.optimizer.Momentum(m.vars())
def loss(x, labels):
    logits = m(x)
    return objax.functional.loss.cross_entropy_logits(logits, labels).mean()
gradient_loss = objax.GradValues(loss, m.vars())
def train(x, labels, lr):
    g, v = gradient_loss(x, labels)  # Compute gradients and loss
    opt(lr, g)                       # Apply SGD
    return v                         # Return loss value
# It's better to jit the top level call to allow internal optimizations.
train_jit = objax.Jit(train, gradient_loss.vars() + opt.vars())
Note that we passed to Jit all the vars used in train, namely gradient_loss.vars() + opt.vars().
Why didn't we pass m.vars() + gradient_loss.vars() + opt.vars()?
We could, and it's perfectly fine to do so, but keep in mind that gradient_loss is itself a module which shares the weights of m, so m.vars() is already included in gradient_loss.vars().
Static arguments¶
Static arguments are arguments that are treated as compile-time constants in the jitted function. Boolean arguments, numerical arguments used in comparisons (resulting in a bool), and strings must be marked as static.
Calling the jitted function with different values for these constants triggers a recompilation. As a rule of thumb:
Good static arguments: training (boolean), my_mode (int that can take only a few values), …
Bad static arguments: training_step (int that can take a lot of values)
Let’s look at an example with BatchNorm which takes a training argument:
import objax
net = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.nn.BatchNorm0D(3)])
jit_net_static = objax.Jit(lambda x, training: net(x, training=training), net.vars(),
static_argnums=(1,))
# Note the static_argnums=(1,) which indicates that argument 1 (training) is static.
x = objax.random.normal((100, 2))
print(((net(x, training=False) - jit_net_static(x, False)) ** 2).sum()) # 0.0
What happens if you don’t use static_argnums
?
jit_net = objax.Jit(lambda x, training: net(x, training=training), net.vars())
y = jit_net(x, False)
# Traceback (most recent call last):
# File <...>
# <long stack trace>
# jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected (in `bool`).
# Use transformation parameters such as `static_argnums` for `jit` to avoid tracing input values.
# See `https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error`.
# Encountered value: Traced<ShapedArray(bool[], weak_type=True):JaxprTrace(level=-1/1)>
To cut a long story short: when compiling, boolean inputs must be made static.
For more information, please refer to jax.jit, which is the API Objax uses under the hood.
Constant optimization¶
As seen previously, objax.Jit takes a variables argument to specify the variables of the function or module being compiled.
If a variable is not passed to Jit, it is treated as a constant and optimized away.
Warning
A jitted module will not see any change made to a constant. A constant is not expected to change since it is supposed to be… constant!
A simple constant optimization example:
import objax
m = objax.nn.Linear(3, 4)
# Pass an empty VarCollection to signify to Jit that m has no variable.
jit_constant = objax.Jit(m, objax.VarCollection())
x = objax.random.normal((10, 3))
print(((m(x) - jit_constant(x)) ** 2).sum()) # 0.0
# Modify m (which was supposed to be constant!)
m.b.assign(m.b.value + 1)
print(((m(x) - jit_constant(x)) ** 2).sum()) # 40.0
# As expected jit_constant didn't see the change.
Warning
The constant folding performed by the XLA backend (the interface to the hardware) may take a long time and a lot of memory at compilation, often with very little gain in final performance, if any.
Parallelism¶
Note
If you don’t have multiple devices, you can simulate them on CPU by starting python with the following command:
CUDA_VISIBLE_DEVICES= XLA_FLAGS=--xla_force_host_platform_device_count=8 python
Alternatively you can do it in Python directly by inserting this snippet before importing Objax:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'
objax.Parallel provides a way to distribute computations across multiple GPUs (or TPUs).
It also performs JIT under the hood, and its API shares a lot with objax.Jit:
it takes a function to be compiled, a VarCollection, and a static_argnums parameter, all of which behave the same as in Jit.
However, it also takes arguments specific to handling parallelism, which we introduce below.
When running a parallelized function f on a batch \(x\) of shape \((n, ...)\) on \(d\) devices, the following steps happen:
The batch \(x\) is divided into \(d\) sub-batches \(x_i\) of shape \((n/d, ...)\) for \(i\in\{0, ..., d-1\}\)
Each sub-batch \(x_i\) is passed to f and run on device \(i\) in parallel. The results are collected as output sub-batches \(y_i=f(x_i)\)
The outputs \(y_i\) are represented as a single tensor \(y\) of shape \((d, ...)\)
The final output is obtained by calling the reduce function on \(y\): out = reduce(y).
With this in mind, we can now detail the additional arguments of objax.Parallel
:
reduce: a function that aggregates the output results from each GPU/TPU.
axis_name: the name of the device dimension which we referred to as \(d\) earlier. By default, it is called 'device'.
Let's illustrate this with a simple example parallelizing a module (para_net) and a function (para_func):
# This code was run on 8 simulated devices
import objax
net = objax.nn.Sequential([objax.nn.Linear(3, 4), objax.functional.relu])
para_net = objax.Parallel(net)
para_func = objax.Parallel(lambda x: net(x) + 1, net.vars())
# A batch of mockup data
x = objax.random.normal((96, 3))
# We're running on multiple devices, copy the model variables to all of them first.
with net.vars().replicate():
    y = para_net(x)
    z = para_func(x)
print(((net(x) - y) ** 2).sum()) # 8.90954e-14
print(((net(x) - (z - 1)) ** 2).sum()) # 4.6487814e-13
We can also show the parallel version of A realistic case: Fully jitted training step; the changes from the jitted version are the pmean calls in train and the use of objax.Parallel instead of objax.Jit:
import objax
m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu, objax.nn.Linear(3, 4)])
opt = objax.optimizer.Momentum(m.vars())
def loss(x, labels):
    logits = m(x)
    return objax.functional.loss.cross_entropy_logits(logits, labels).mean()

gradient_loss = objax.GradValues(loss, m.vars())

def train(x, labels, lr):
    g, v = gradient_loss(x, labels)              # Compute gradients and loss
    opt(lr, objax.functional.parallel.pmean(g))  # Apply averaged gradients
    return objax.functional.parallel.pmean(v)    # Return averaged loss value
# It's better to parallelize the top level call to allow internal optimizations.
train_para = objax.Parallel(train, gradient_loss.vars() + opt.vars(), reduce=lambda x:x[0])
Let's study the concepts introduced in this example in more detail.
Variable replication¶
Variable replication copies the variables to each device's own memory.
It is necessary to replicate variables before calling a parallelized module or function.
Replication is done through objax.VarCollection.replicate(), which is a context manager.
One could go further and create their own replication scheme; this is not covered here, but the source of replicate is rather simple and a good starting point.
Here is a detailed example:
# This code was run on 8 simulated devices
import objax
import jax.numpy as jn
m = objax.ModuleList([objax.TrainVar(jn.arange(5))])
# We use "repr" to see the whole type information.
print(repr(m[0].value)) # DeviceArray([0, 1, 2, 3, 4], dtype=int32)
with m.vars().replicate():
    # In the scope of the with-statement, the variables are replicated to all devices.
    print(repr(m[0].value))
    # ShardedDeviceArray([[0, 1, 2, 3, 4],
    #                     [0, 1, 2, 3, 4],
    #                     [0, 1, 2, 3, 4],
    #                     [0, 1, 2, 3, 4],
    #                     [0, 1, 2, 3, 4],
    #                     [0, 1, 2, 3, 4],
    #                     [0, 1, 2, 3, 4],
    #                     [0, 1, 2, 3, 4]], dtype=int32)
    # A ShardedDeviceArray is a DeviceArray spread across multiple devices.
# When exiting the with-statement, the variables are reduced back to their original device.
print(repr(m[0].value))  # DeviceArray([0., 1., 2., 3., 4.], dtype=float32)
Something interesting happened: the value of m[0] was initially of integer type but it became a float by the end.
This is due to the reduction that follows a replication.
By default, the reduction method takes the average of the copies from each device, and the average of multiple integer values is automatically cast to a float.
You can customize the variable reduction; this is not something one typically needs to do, but it's available for advanced users nonetheless:
# This code was run on 8 simulated devices
import objax
import jax.numpy as jn
m = objax.ModuleList([objax.TrainVar(jn.arange(5), reduce=lambda x: x[0]),
objax.TrainVar(jn.arange(5), reduce=lambda x: x.sum(0)),
objax.TrainVar(jn.arange(5), reduce=lambda x: jn.stack(x))])
print(repr(m[0].value)) # DeviceArray([0, 1, 2, 3, 4], dtype=int32)
print(repr(m[1].value)) # DeviceArray([0, 1, 2, 3, 4], dtype=int32)
print(repr(m[2].value)) # DeviceArray([0, 1, 2, 3, 4], dtype=int32)
with m.vars().replicate():
    pass
# When exiting the with-statement, the variables are reduced back to their original device.
print(repr(m[0].value)) # DeviceArray([0, 1, 2, 3, 4], dtype=int32)
print(repr(m[1].value)) # DeviceArray([ 0, 8, 16, 24, 32], dtype=int32)
print(repr(m[2].value)) # DeviceArray([[0, 1, 2, 3, 4],
# [0, 1, 2, 3, 4],
# [0, 1, 2, 3, 4],
# [0, 1, 2, 3, 4],
# [0, 1, 2, 3, 4],
# [0, 1, 2, 3, 4],
# [0, 1, 2, 3, 4],
# [0, 1, 2, 3, 4]], dtype=int32)
Output aggregation¶
Similarly, the output \(y\) of a parallel call is reduced using the reduce argument.
The first dimension \(d\) of \(y\) is the device dimension; its name comes from the axis_name argument, which by default is simply "device".
Again, let’s look at a simple example:
# This code was run on 8 simulated devices
import objax
import jax.numpy as jn
net = objax.nn.Sequential([objax.nn.Linear(3, 4), objax.functional.relu])
para_none = objax.Parallel(net, reduce=lambda x: x)
para_first = objax.Parallel(net, reduce=lambda x: x[0])
para_concat = objax.Parallel(net, reduce=lambda x: jn.concatenate(x))
para_average = objax.Parallel(net, reduce=lambda x: x.mean(0))
# A batch of mockup data
x = objax.random.normal((96, 3))
# We're running on multiple devices, copy the model variables to all of them first.
with net.vars().replicate():
    print(para_none(x).shape)     # (8, 12, 4)
    print(para_first(x).shape)    # (12, 4)
    print(para_concat(x).shape)   # (96, 4) - This is the default setting
    print(para_average(x).shape)  # (12, 4)
In the example above, the batch x (of size 96) was divided into 8 sub-batches of size 12 by the parallel call. Each sub-batch was processed on its own device. The final value was then reduced using the provided reduce method.
para_none didn't do any reduction; it just returned the value it was given, so, as expected, its shape is (devices, batch // devices, ...).
para_first and para_average took either the first entry or the average over dimension 0, resulting in a shape of (batch // devices, ...).
para_concat concatenated all the values, resulting in a shape of (batch, ...).
Synchronized computations¶
So far we only considered the case where all the devices act on their own, unaware of each other's existence. It is often desirable for devices to communicate with each other.
For example, when training a model, for efficiency one wants the optimizer to update the weights on all the devices at the same time. To achieve this, we would like the gradients to be computed for each sub-batch on its device and then averaged across all devices.
The good news is that this is very easy to do: a set of predefined primitives can be found in objax.functional.parallel, and they are the direct equivalents of single-device primitives:
objax.functional.parallel.pmax() is the multi-device equivalent of jax.numpy.max
objax.functional.parallel.pmean() is the multi-device equivalent of jax.numpy.mean
and so on…
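As a tiny illustration of these primitives, here is a sketch (assuming the 8 simulated devices from the earlier examples; sub_batch_mean is just an illustrative name) that uses pmean so that every device ends up with the mean of the whole batch:
import objax

def sub_batch_mean(x):
    # Each device computes the mean of its own sub-batch, then pmean averages
    # these means across devices: every device returns the global batch mean.
    return objax.functional.parallel.pmean(x.mean())

# No variables are used, so we pass an empty VarCollection.
para_mean = objax.Parallel(sub_batch_mean, objax.VarCollection(), reduce=lambda x: x[0])
x = objax.random.normal((96, 3))
print(para_mean(x))  # Same value as x.mean(), up to floating point error.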
Recalling the code for the parallelized train operation:
def train(x, labels, lr):
    g, v = gradient_loss(x, labels)              # Compute gradients and loss
    opt(lr, objax.functional.parallel.pmean(g))  # Apply averaged gradients
    return objax.functional.parallel.pmean(v)    # Return averaged loss value
The train function is called on each device in parallel.
The call objax.functional.parallel.pmean(g) averages the gradients g across all devices.
Then, on each device, the optimizer applies the averaged gradients to the local copy of the weights.
Finally, the averaged loss objax.functional.parallel.pmean(v) is returned.
Vectorization¶
objax.Vectorize
is the module responsible for code vectorization.
Vectorization can be seen as a parallelization without knowledge of the devices available.
On a single GPU, vectorization parallelizes the execution in concurrent threads.
It can be combined with objax.Parallel
resulting in multi-GPU multi-threading!
Vectorization can also be done on a single CPU.
A typical example of CPU vectorization could data pre-processing or augmentation.
In its simplest form vectorization applies a function to the elements of a batch concurrently.
objax.Vectorize takes a module or a function f and vectorizes it.
Similarly to Jit and Parallel, you must specify the variables used by the function.
Finally, batch_axis specifies which axis should be considered the batch axis for each input argument of f.
For values with no batch axis, for example when passing a value to be shared by all the calls to the function f, set its batch axis to None to broadcast it.
Let's clarify this with a simple example:
# Randomly reverse rows in a batch.
import objax
import jax.numpy as jn
class RandomReverse(objax.Module):
    """Randomly reverse a single vector x and add a value y to it."""

    def __init__(self, keygen=objax.random.DEFAULT_GENERATOR):
        self.keygen = keygen

    def __call__(self, x, y):
        r = objax.random.randint([], 0, 2, generator=self.keygen)
        return x + y + r * (x[::-1] - x), r, y
random_reverse = RandomReverse()
vector_reverse = objax.Vectorize(random_reverse, batch_axis=(0, None))
# vector_reverse takes two arguments (just like random_reverse), we're going to pass:
# - a matrix x for the first argument, interpreted as a batch of vectors (batch_axis=0).
# - a value y for the second argument, interpreted as a broadcasted value (batch_axis=None).
# Test it on some mock up data
x = jn.arange(20).reshape((5, 4))
print(x) # [[ 0 1 2 3]
# [ 4 5 6 7]
# [ 8 9 10 11]
# [12 13 14 15]
# [16 17 18 19]]
objax.random.DEFAULT_GENERATOR.seed(1337)
z, r, y = vector_reverse(x, 1)
print(r) # [0 1 0 1 1] - whether a row was reversed
print(y) # [1 1 1 1 1] - the broadcasted input y
print(z) # [[ 1 2 3 4]
# [ 8 7 6 5]
# [ 9 10 11 12]
# [16 15 14 13]
# [20 19 18 17]]
# Above we added a single constant (y=1)
# We can also add a vector y=(-2, -1, 0, 1)
objax.random.DEFAULT_GENERATOR.seed(1337)
z, r, y = vector_reverse(x, jn.array((-2, -1, 0, 1)))
print(r) # [0 1 0 1 1] - whether a row was reversed
print(y) # [[-2 -1 0 1] - the broadcasted input y
# [-2 -1 0 1]
# [-2 -1 0 1]
# [-2 -1 0 1]
# [-2 -1 0 1]]
print(z) # [[-2 0 2 4]
# [ 5 5 5 5]
# [ 6 8 10 12]
# [13 13 13 13]
# [17 17 17 17]]
Computing weight gradients per batch entry¶
This is a more advanced example; conceptually it is similar to what powers differential privacy gradients:
import objax
m = objax.nn.Linear(3, 4)
def loss(x, y):
return ((m(x) - y) ** 2).mean()
gv = objax.GradValues(loss, m.vars())
single_gradients = objax.Vectorize(lambda x, y: gv(x, y)[0],  # Only interested in the gradient
                                   gv.vars(),                 # f uses variables from gv
                                   batch_axis=(0, 0))         # Batch is dimension 0 of x and y
# Mock some data
x = objax.random.normal((10, 3))
y = objax.random.normal((10, 4))
# Compute standard gradients
print([v.shape for v in gv(x, y)[0]]) # [(4,), (3, 4)]
# Compute per batch entry gradients
print([v.shape for v in single_gradients(x, y)]) # [(10, 4), (10, 3, 4)]
As expected, we obtained as many gradients for each of the network’s weights as there are entries in the batch.
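Since the loss above is a mean over the batch, averaging the per-entry gradients over the batch dimension should recover the standard gradients; here is a small sanity-check sketch:
import jax.numpy as jn

standard_g = gv(x, y)[0]              # Standard (whole batch) gradients
per_entry_g = single_gradients(x, y)  # Per-entry gradients
print([float(jn.abs(p.mean(0) - s).max()) for p, s in zip(per_entry_g, standard_g)])
# Values close to 0 (up to floating point error).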
Loading and Saving¶
Being able to load and save the weights of a model, or the model itself (i.e. the weights together with the functions), is essential in machine learning. In this section we describe how to load and save weights, and also how to save an entire model. Furthermore, we discuss how to keep multiple saves, a concept known as checkpointing, which is typically used to resume interrupted training sessions.
Saving and loading model weights¶
Loading and saving is done on objax.VarCollection objects.
Such objects are returned by the objax.Module.vars() method, or can be constructed manually if one wishes to.
The saving method uses the numpy .npz format, which in essence stores tensors in a zip file.
Here's a simple example:
import objax
# Let's pretend we have a neural network net and we want to save it.
net = objax.nn.Sequential([objax.nn.Linear(768, 1), objax.functional.sigmoid])
# Saving only takes one line.
objax.io.save_var_collection('net.npz', net.vars())
# Let's modify the bias of the Linear layer
net[0].b.assign(net[0].b.value + 1)
print(net[0].b.value.sum()) # 1.0
# Loading
objax.io.load_var_collection('net.npz', net.vars())
print(net[0].b.value.sum()) # 0.0
Note that in the example above we used a filename to specify where to save the weights. These APIs also accept a file descriptor, so another way to save would be:
# Saving with file descriptor
with open('net.npz', 'wb') as f:
    objax.io.save_var_collection(f, net.vars())

# Loading with file descriptor
with open('net.npz', 'rb') as f:
    objax.io.load_var_collection(f, net.vars())
Note
The advantage of using a filename instead of a file handle is that data is first written to a temporary file, and the temporary file is renamed to the provided filename only after all data has been written. If the program is killed mid-write, this prevents truncated files. When using a file descriptor the code does not have this protection. File descriptors are typically used for unit testing.
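For the curious, the pattern described in this note boils down to something like the following sketch; safe_save is a hypothetical helper for illustration, not the actual objax.io implementation:
import os
import tempfile
import objax

def safe_save(filename, vc):
    # Write to a temporary file in the same directory, then atomically rename it.
    fd, tmp_name = tempfile.mkstemp(dir=os.path.dirname(filename) or '.')
    try:
        with os.fdopen(fd, 'wb') as f:
            objax.io.save_var_collection(f, vc)
        os.rename(tmp_name, filename)  # Atomic on POSIX filesystems.
    except BaseException:
        os.remove(tmp_name)
        raise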
Custom saving and loading¶
You can easily write your own saving and loading functions.
In essence, saving has to store pairs of (name, numpy array), while loading must provide a numpy array for each variable of the objax.VarCollection.
The only gotcha to pay attention to is to avoid saving duplicated information, such as shared weights under different names or variable references (TrainRef).
Since the code for loading and saving is very concise, simply looking at it is the best example.
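That said, here is a minimal sketch of what custom functions could look like; my_save and my_load are hypothetical helpers, and for simplicity the sketch ignores the deduplication concerns mentioned above:
import numpy as np
import jax.numpy as jn
import objax

def my_save(filename, vc):
    # Store one numpy array per variable name (a VarCollection is a dict of name -> variable).
    np.savez(filename, **{name: np.array(var.value) for name, var in vc.items()})

def my_load(filename, vc):
    # Assign each stored array back to the variable with the matching name.
    data = np.load(filename)
    for name, var in vc.items():
        var.assign(jn.array(data[name]))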
Checkpointing¶
Checkpointing means saving a neural network's weights during training. Often checkpointing keeps multiple saves, each from a different training step. For space reasons, it's common to keep only the latest k saves. Checkpointing can be used for a variety of purposes:
Resuming training after the program was interrupted.
Keeping multiple copies of the network for weight averaging strategies.
Objax provides a simple checkpointing interface called objax.io.Checkpoint; here's an example:
import objax
# Let's pretend we have a neural network net and we want to save it.
net = objax.nn.Sequential([objax.nn.Linear(768, 1), objax.functional.sigmoid])
# This time we use the Checkpoint class
ckpt = objax.io.Checkpoint(logdir='save_folder', keep_ckpts=5)
# Saving
ckpt.save(net.vars(), epoch=1)
net[0].b.assign(net[0].b.value + 1)
ckpt.save(net.vars(), epoch=2)
# Restoring
ckpt.restore(net.vars(), idx=1) # net[0].b.value = (0,)
ckpt.restore(net.vars(), idx=2) # net[0].b.value = (1,)
# When no idx is specified, the latest checkpoint is used (2 here)
idx, file = ckpt.restore(net.vars())
print(idx, file) # 2 save_folder/ckpt/0000000002.npz
Customized checkpointing¶
The objax.io.Checkpoint class has a few constants that let you customize its behavior.
You can redefine them, for example by creating a child class that inherits from Checkpoint.
The fields are the following:
class Checkpoint:
    DIR_NAME: str = 'ckpt'
    FILE_MATCH: str = '*.npz'
    FILE_FORMAT: str = '%010d.npz'
    LOAD_FN: Callable[[FileOrStr, VarCollection], None] = staticmethod(load_var_collection)
    SAVE_FN: Callable[[FileOrStr, VarCollection], None] = staticmethod(save_var_collection)
This lets you change the folder where the checkpoints are saved, the file extension and the numbering format.
If you have your own saving and loading functions, you can also plug them in.
Remember to wrap them in staticmethod since they don't depend on the Checkpoint class itself.
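For example, here is a sketch of such a child class; the field values below are illustrative choices, not the defaults:
import objax

class MyCheckpoint(objax.io.Checkpoint):
    DIR_NAME = 'my_ckpt'      # Checkpoints go to <logdir>/my_ckpt/
    FILE_MATCH = '*.npz'
    FILE_FORMAT = '%06d.npz'  # Shorter numbering than the default.

ckpt = MyCheckpoint(logdir='save_folder', keep_ckpts=3)
ckpt.save(net.vars(), epoch=1)  # Saved as save_folder/my_ckpt/000001.npz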
Saving a module¶
Warning
Python pickle is not secure. Only use it for your own saves and loads; any pickle coming from an external source is a potential risk.
Now that we have warned you, let's mention that Objax modules can be pickled with Python's pickle module, like many other Python objects. This can be quite convenient since you can save not only the module's weights but the module itself.
Let’s look at a simple example:
import pickle
import objax
# Let's pretend we have a neural network net and we want to save it as whole.
net = objax.nn.Sequential([objax.nn.Linear(768, 1), objax.functional.sigmoid])
# Pickling
pickle.dump(net, open('net.pickle', 'wb'))
# Unpickling and storing into a new network
net2 = pickle.load(open('net.pickle', 'rb'))
# Confirm the network net2 has the same function as net
x = objax.random.normal((100, 768))
print(((net(x) - net2(x)) ** 2).mean()) # 0.0
# Confirm the network net2 does not share net's weights
net[0].b.assign(net[0].b.value + 1)
print(((net(x) - net2(x)) ** 2).mean()) # 0.038710583
As the example shows, pickling is really easy to use. Be aware that Python pickling has some limitations; notably, lambda functions cannot always be saved (they have to be named). Objax is not limited to pickle: since its design is pythonic, it should be compatible with other Python serialization systems.
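To make the lambda limitation concrete, here is a small sketch; the names net_lambda, net_named and double are illustrative:
import pickle
import objax

# A module containing a lambda cannot be pickled...
net_lambda = objax.nn.Sequential([objax.nn.Linear(2, 2), lambda x: x * 2])
try:
    pickle.dumps(net_lambda)
except Exception as e:
    print('Cannot pickle:', e)

# ... while the same module built with a named function works.
def double(x):
    return x * 2

net_named = objax.nn.Sequential([objax.nn.Linear(2, 2), double])
data = pickle.dumps(net_named)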