Interactive online version: Open In Colab

Creating Custom Networks for Multi-Class Classification

This tutorial demonstrates how to define, train, and use different models for multi-class classification. We will reuse most of the code from the Logistic Regression tutorial, so if you haven’t gone through it yet, consider reviewing it first.

Note that this tutorial includes a demonstration of how to build and train a simple convolutional neural network, and running this colab on a CPU may take some time. Therefore, we recommend running this colab on a GPU (select GPU in the menu Runtime -> Change runtime type -> Hardware accelerator if the hardware accelerator is not already set to GPU).

Import Modules

We start by importing the modules we will use in our code.

[1]:
%pip --quiet install objax

import os

import numpy as np
import tensorflow_datasets as tfds

import objax
from objax.util import EasyDict
from objax.zoo.dnnet import DNNet

Load the data

Next, we load the “MNIST” dataset from TensorFlow DataSets. This dataset contains images of handwritten digits (i.e., the numbers 0 through 9), and the task is to correctly identify the digit in each image.

The prepare method pads 2 pixels to the left, right, top, and bottom of each image to resize it to 32 x 32 pixels. While MNIST images are grayscale, the prepare method expands each image to three color channels to demonstrate the process of working with color images. The same method also rescales each pixel value to [-1, 1] and converts the images to (N, C, H, W) format.

[2]:
# Data: train has 60000 images - test has 10000 images
# Each image is resized and converted to 32 x 32 x 3
DATA_DIR = os.path.join(os.environ['HOME'], 'TFDS')
data = tfds.as_numpy(tfds.load(name='mnist', batch_size=-1, data_dir=DATA_DIR))

def prepare(x):
  """Pads 2 pixels to the left, right, top, and bottom of each image, scales pixel value to [-1, 1], and converts to NCHW format."""
  s = x.shape
  x_pad = np.zeros((s[0], 32, 32, 1))
  x_pad[:, 2:-2, 2:-2, :] = x
  return objax.util.image.nchw(
      np.concatenate([x_pad.astype('f') * (1 / 127.5) - 1] * 3, axis=-1))

train = EasyDict(image=prepare(data['train']['image']), label=data['train']['label'])
test = EasyDict(image=prepare(data['test']['image']), label=data['test']['label'])
ndim = train.image.shape[-1]

del data
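
As a quick sanity check (a small sketch, not one of the tutorial’s numbered cells), the prepared arrays can be inspected to confirm their shapes:

# Expected shapes after prepare(): NCHW images with 3 channels.
print(train.image.shape, train.label.shape)  # (60000, 3, 32, 32) (60000,)
print(test.image.shape, test.label.shape)    # (10000, 3, 32, 32) (10000,)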

Deep Neural Network Model

Objax offers many predefined models that we can use for classification. One example is the objax.zoo.DNNet model, which comprises multiple fully connected layers with configurable sizes and activation functions.

[3]:
dnn_layer_sizes = 3072, 128, 10
dnn_model = DNNet(dnn_layer_sizes, objax.functional.leaky_relu)
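
As a quick check (a minimal sketch under the setup above, not one of the tutorial’s numbered cells), the first layer size of 3072 corresponds to a flattened 32 x 32 x 3 image, and the final layer produces one logit per class:

# Flatten a small batch and pass it through the DNN; the shapes are the point here.
x_flat = objax.functional.flatten(train.image[:4])   # (4, 3072) since 3 * 32 * 32 = 3072
logits = dnn_model(x_flat, training=False)           # (4, 10): one logit per class
print(x_flat.shape, logits.shape)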

Custom Model Definition

Alternatively, we can define a new model customized to our machine learning task. We demonstrate this process by defining a convolutional network (ConvNet) from scratch.

We use objax.nn.Sequential to compose multiple layers: convolution (objax.nn.Conv2D), batch normalization (objax.nn.BatchNorm2D), ReLU (objax.functional.relu), max pooling (objax.functional.max_pool_2d), global average pooling (a mean over the spatial dimensions), and linear (objax.nn.Linear) layers.

Since batch normalization behaves differently during training and prediction, we pass a training flag to the __call__ method of the ConvNet class. The model always returns logits; at prediction time we convert them to probabilities with softmax (see the training code below).

[4]:
class ConvNet(objax.Module):
  """ConvNet implementation."""

  def __init__(self, nin, nclass):
    """Define 3 blocks of conv-bn-relu-conv-bn-relu followed by linear layer."""
    self.conv_block1 = objax.nn.Sequential([objax.nn.Conv2D(nin, 16, 3, use_bias=False),
                                            objax.nn.BatchNorm2D(16),
                                            objax.functional.relu,
                                            objax.nn.Conv2D(16, 16, 3, use_bias=False),
                                            objax.nn.BatchNorm2D(16),
                                            objax.functional.relu])
    self.conv_block2 = objax.nn.Sequential([objax.nn.Conv2D(16, 32, 3, use_bias=False),
                                            objax.nn.BatchNorm2D(32),
                                            objax.functional.relu,
                                            objax.nn.Conv2D(32, 32, 3, use_bias=False),
                                            objax.nn.BatchNorm2D(32),
                                            objax.functional.relu])
    self.conv_block3 = objax.nn.Sequential([objax.nn.Conv2D(32, 64, 3, use_bias=False),
                                            objax.nn.BatchNorm2D(64),
                                            objax.functional.relu,
                                            objax.nn.Conv2D(64, 64, 3, use_bias=False),
                                            objax.nn.BatchNorm2D(64),
                                            objax.functional.relu])
    self.linear = objax.nn.Linear(64, nclass)

  def __call__(self, x, training):
    x = self.conv_block1(x, training=training)
    x = objax.functional.max_pool_2d(x, size=2, strides=2)
    x = self.conv_block2(x, training=training)
    x = objax.functional.max_pool_2d(x, size=2, strides=2)
    x = self.conv_block3(x, training=training)
    x = x.mean((2, 3))
    x = self.linear(x)
    return x
[5]:
cnn_model = ConvNet(nin=3, nclass=10)
print(cnn_model.vars())
(ConvNet).conv_block1(Sequential)[0](Conv2D).w                      432 (3, 3, 3, 16)
(ConvNet).conv_block1(Sequential)[1](BatchNorm2D).running_mean       16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[1](BatchNorm2D).running_var        16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[1](BatchNorm2D).beta               16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[1](BatchNorm2D).gamma              16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[3](Conv2D).w                     2304 (3, 3, 16, 16)
(ConvNet).conv_block1(Sequential)[4](BatchNorm2D).running_mean       16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[4](BatchNorm2D).running_var        16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[4](BatchNorm2D).beta               16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[4](BatchNorm2D).gamma              16 (1, 16, 1, 1)
(ConvNet).conv_block2(Sequential)[0](Conv2D).w                     4608 (3, 3, 16, 32)
(ConvNet).conv_block2(Sequential)[1](BatchNorm2D).running_mean       32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[1](BatchNorm2D).running_var        32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[1](BatchNorm2D).beta               32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[1](BatchNorm2D).gamma              32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[3](Conv2D).w                     9216 (3, 3, 32, 32)
(ConvNet).conv_block2(Sequential)[4](BatchNorm2D).running_mean       32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[4](BatchNorm2D).running_var        32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[4](BatchNorm2D).beta               32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[4](BatchNorm2D).gamma              32 (1, 32, 1, 1)
(ConvNet).conv_block3(Sequential)[0](Conv2D).w                    18432 (3, 3, 32, 64)
(ConvNet).conv_block3(Sequential)[1](BatchNorm2D).running_mean       64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[1](BatchNorm2D).running_var        64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[1](BatchNorm2D).beta               64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[1](BatchNorm2D).gamma              64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[3](Conv2D).w                    36864 (3, 3, 64, 64)
(ConvNet).conv_block3(Sequential)[4](BatchNorm2D).running_mean       64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[4](BatchNorm2D).running_var        64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[4](BatchNorm2D).beta               64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[4](BatchNorm2D).gamma              64 (1, 64, 1, 1)
(ConvNet).linear(Linear).b                                           10 (10,)
(ConvNet).linear(Linear).w                                          640 (64, 10)
+Total(32)                                                        73402
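
The training flag controls how the BatchNorm2D layers normalize activations: with training=True they use the current batch statistics and update the running averages, while with training=False they use the stored running statistics. A minimal sketch (not one of the tutorial’s numbered cells) of calling the model in both modes:

x_small = train.image[:4]                           # a small NCHW batch, shape (4, 3, 32, 32)
logits_train = cnn_model(x_small, training=True)    # batch statistics; running averages are updated
logits_eval = cnn_model(x_small, training=False)    # stored running statistics
print(logits_train.shape, logits_eval.shape)        # both (4, 10)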

Model Training and Evaluation

The train_model method combines all the pieces: the loss function, gradient descent, the training loop, and evaluation. It takes the model as a parameter so it can be reused with the two models we defined earlier.

Unlike the Logistic Regression tutorial, we use objax.functional.loss.cross_entropy_logits_sparse because we are performing multi-class classification. The optimizer, gradient descent operation, and training loop remain the same.

The DNNet model expects flattened images, whereas ConvNet expects images in (C, H, W) format. The flatten_image method prepares images accordingly before passing them to the model.

When using the model for inference, we apply the objax.functional.softmax method to compute a probability distribution from the model’s logits.
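
For intuition, here is a minimal sketch (using a hypothetical logits array, not one of the tutorial’s numbered cells) of how softmax turns logits into a probability distribution:

import jax.numpy as jn

logits = jn.array([[2.0, 0.5, -1.0]])       # hypothetical logits for a 3-class example
probs = objax.functional.softmax(logits)    # entries are non-negative and each row sums to 1
print(probs, probs.sum())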

[6]:
# Settings
lr = 0.03  # learning rate
batch = 128
epochs = 100

# Train loop

def train_model(model):

  def predict(model, x):
    """"""
    return objax.functional.softmax(model(x,  training=False))

  def flatten_image(x):
    """Flatten the image before passing it to the DNN."""
    if isinstance(model, DNNet):
      return objax.functional.flatten(x)
    else:
      return x

  opt = objax.optimizer.Momentum(model.vars())

  # Cross Entropy Loss
  def loss(x, label):
    return objax.functional.loss.cross_entropy_logits_sparse(model(x, training=True), label).mean()

  gv = objax.GradValues(loss, model.vars())
  def train_op(x, label):
    g, v = gv(x, label)  # returns gradients, loss
    opt(lr, g)
    return v

  train_op = objax.Jit(train_op, gv.vars() + opt.vars())

  for epoch in range(epochs):
    avg_loss = 0
    # randomly shuffle training data
    shuffle_idx = np.random.permutation(train.image.shape[0])
    for it in range(0, train.image.shape[0], batch):
      sel = shuffle_idx[it: it + batch]
      avg_loss += float(train_op(flatten_image(train.image[sel]), train.label[sel])[0]) * len(sel)
    avg_loss /= it + len(sel)

    # Eval
    accuracy = 0
    for it in range(0, test.image.shape[0], batch):
      x, y = test.image[it: it + batch], test.label[it: it + batch]
      accuracy += (np.argmax(predict(model, flatten_image(x)), axis=1) == y).sum()
    accuracy /= test.image.shape[0]
    print('Epoch %04d  Loss %.2f  Accuracy %.2f' % (epoch + 1, avg_loss, 100 * accuracy))

Training the DNN Model

[7]:
train_model(dnn_model)
Epoch 0001  Loss 2.39  Accuracy 56.31
Epoch 0002  Loss 1.24  Accuracy 74.19
Epoch 0003  Loss 0.75  Accuracy 84.91
Epoch 0004  Loss 0.56  Accuracy 83.10
Epoch 0005  Loss 0.48  Accuracy 86.42
Epoch 0006  Loss 0.43  Accuracy 89.00
Epoch 0007  Loss 0.41  Accuracy 89.62
Epoch 0008  Loss 0.39  Accuracy 89.82
Epoch 0009  Loss 0.37  Accuracy 90.04
Epoch 0010  Loss 0.36  Accuracy 89.55
Epoch 0011  Loss 0.35  Accuracy 90.53
Epoch 0012  Loss 0.35  Accuracy 90.64
Epoch 0013  Loss 0.34  Accuracy 90.85
Epoch 0014  Loss 0.33  Accuracy 90.87
Epoch 0015  Loss 0.33  Accuracy 91.02
Epoch 0016  Loss 0.32  Accuracy 91.35
Epoch 0017  Loss 0.32  Accuracy 91.35
Epoch 0018  Loss 0.31  Accuracy 91.50
Epoch 0019  Loss 0.31  Accuracy 91.57
Epoch 0020  Loss 0.31  Accuracy 91.75
Epoch 0021  Loss 0.30  Accuracy 91.47
Epoch 0022  Loss 0.30  Accuracy 88.14
Epoch 0023  Loss 0.30  Accuracy 91.82
Epoch 0024  Loss 0.30  Accuracy 91.92
Epoch 0025  Loss 0.29  Accuracy 92.03
Epoch 0026  Loss 0.29  Accuracy 92.04
Epoch 0027  Loss 0.29  Accuracy 92.11
Epoch 0028  Loss 0.29  Accuracy 92.11
Epoch 0029  Loss 0.29  Accuracy 92.18
Epoch 0030  Loss 0.28  Accuracy 92.24
Epoch 0031  Loss 0.28  Accuracy 92.36
Epoch 0032  Loss 0.28  Accuracy 92.17
Epoch 0033  Loss 0.28  Accuracy 92.42
Epoch 0034  Loss 0.28  Accuracy 92.42
Epoch 0035  Loss 0.27  Accuracy 92.47
Epoch 0036  Loss 0.27  Accuracy 92.50
Epoch 0037  Loss 0.27  Accuracy 92.49
Epoch 0038  Loss 0.27  Accuracy 92.58
Epoch 0039  Loss 0.26  Accuracy 92.56
Epoch 0040  Loss 0.26  Accuracy 92.56
Epoch 0041  Loss 0.26  Accuracy 92.77
Epoch 0042  Loss 0.26  Accuracy 92.72
Epoch 0043  Loss 0.26  Accuracy 92.80
Epoch 0044  Loss 0.25  Accuracy 92.85
Epoch 0045  Loss 0.25  Accuracy 92.90
Epoch 0046  Loss 0.25  Accuracy 92.96
Epoch 0047  Loss 0.25  Accuracy 93.00
Epoch 0048  Loss 0.25  Accuracy 92.82
Epoch 0049  Loss 0.25  Accuracy 93.18
Epoch 0050  Loss 0.24  Accuracy 93.09
Epoch 0051  Loss 0.24  Accuracy 92.94
Epoch 0052  Loss 0.24  Accuracy 93.20
Epoch 0053  Loss 0.24  Accuracy 93.26
Epoch 0054  Loss 0.23  Accuracy 93.21
Epoch 0055  Loss 0.24  Accuracy 93.42
Epoch 0056  Loss 0.23  Accuracy 93.35
Epoch 0057  Loss 0.23  Accuracy 93.36
Epoch 0058  Loss 0.23  Accuracy 93.56
Epoch 0059  Loss 0.23  Accuracy 93.54
Epoch 0060  Loss 0.22  Accuracy 93.39
Epoch 0061  Loss 0.23  Accuracy 93.56
Epoch 0062  Loss 0.22  Accuracy 93.74
Epoch 0063  Loss 0.22  Accuracy 93.68
Epoch 0064  Loss 0.22  Accuracy 93.72
Epoch 0065  Loss 0.22  Accuracy 93.76
Epoch 0066  Loss 0.22  Accuracy 93.87
Epoch 0067  Loss 0.21  Accuracy 93.89
Epoch 0068  Loss 0.21  Accuracy 93.96
Epoch 0069  Loss 0.21  Accuracy 93.90
Epoch 0070  Loss 0.21  Accuracy 93.99
Epoch 0071  Loss 0.21  Accuracy 94.02
Epoch 0072  Loss 0.21  Accuracy 93.86
Epoch 0073  Loss 0.21  Accuracy 94.06
Epoch 0074  Loss 0.21  Accuracy 94.14
Epoch 0075  Loss 0.20  Accuracy 94.31
Epoch 0076  Loss 0.20  Accuracy 94.14
Epoch 0077  Loss 0.20  Accuracy 94.15
Epoch 0078  Loss 0.20  Accuracy 94.10
Epoch 0079  Loss 0.20  Accuracy 94.16
Epoch 0080  Loss 0.20  Accuracy 94.28
Epoch 0081  Loss 0.20  Accuracy 94.30
Epoch 0082  Loss 0.20  Accuracy 94.28
Epoch 0083  Loss 0.19  Accuracy 94.37
Epoch 0084  Loss 0.19  Accuracy 94.33
Epoch 0085  Loss 0.19  Accuracy 94.31
Epoch 0086  Loss 0.19  Accuracy 94.25
Epoch 0087  Loss 0.19  Accuracy 94.37
Epoch 0088  Loss 0.19  Accuracy 94.38
Epoch 0089  Loss 0.19  Accuracy 94.35
Epoch 0090  Loss 0.19  Accuracy 94.38
Epoch 0091  Loss 0.19  Accuracy 94.41
Epoch 0092  Loss 0.19  Accuracy 94.46
Epoch 0093  Loss 0.19  Accuracy 94.53
Epoch 0094  Loss 0.18  Accuracy 94.47
Epoch 0095  Loss 0.18  Accuracy 94.54
Epoch 0096  Loss 0.18  Accuracy 94.65
Epoch 0097  Loss 0.18  Accuracy 94.56
Epoch 0098  Loss 0.18  Accuracy 94.60
Epoch 0099  Loss 0.18  Accuracy 94.63
Epoch 0100  Loss 0.18  Accuracy 94.46

Training the ConvNet Model

[8]:
train_model(cnn_model)
Epoch 0001  Loss 0.27  Accuracy 27.08
Epoch 0002  Loss 0.05  Accuracy 41.07
Epoch 0003  Loss 0.03  Accuracy 67.77
Epoch 0004  Loss 0.03  Accuracy 73.31
Epoch 0005  Loss 0.02  Accuracy 90.30
Epoch 0006  Loss 0.02  Accuracy 93.10
Epoch 0007  Loss 0.02  Accuracy 95.98
Epoch 0008  Loss 0.01  Accuracy 98.77
Epoch 0009  Loss 0.01  Accuracy 96.58
Epoch 0010  Loss 0.01  Accuracy 99.12
Epoch 0011  Loss 0.01  Accuracy 98.88
Epoch 0012  Loss 0.01  Accuracy 98.64
Epoch 0013  Loss 0.01  Accuracy 98.66
Epoch 0014  Loss 0.00  Accuracy 98.38
Epoch 0015  Loss 0.00  Accuracy 99.15
Epoch 0016  Loss 0.00  Accuracy 97.50
Epoch 0017  Loss 0.00  Accuracy 98.98
Epoch 0018  Loss 0.00  Accuracy 98.94
Epoch 0019  Loss 0.00  Accuracy 98.56
Epoch 0020  Loss 0.00  Accuracy 99.06
Epoch 0021  Loss 0.00  Accuracy 99.26
Epoch 0022  Loss 0.00  Accuracy 99.30
Epoch 0023  Loss 0.00  Accuracy 99.18
Epoch 0024  Loss 0.00  Accuracy 99.49
Epoch 0025  Loss 0.00  Accuracy 99.34
Epoch 0026  Loss 0.00  Accuracy 99.24
Epoch 0027  Loss 0.00  Accuracy 99.38
Epoch 0028  Loss 0.00  Accuracy 99.43
Epoch 0029  Loss 0.00  Accuracy 99.40
Epoch 0030  Loss 0.00  Accuracy 99.50
Epoch 0031  Loss 0.00  Accuracy 99.44
Epoch 0032  Loss 0.00  Accuracy 99.52
Epoch 0033  Loss 0.00  Accuracy 99.46
Epoch 0034  Loss 0.00  Accuracy 99.39
Epoch 0035  Loss 0.00  Accuracy 99.22
Epoch 0036  Loss 0.00  Accuracy 99.26
Epoch 0037  Loss 0.00  Accuracy 99.47
Epoch 0038  Loss 0.00  Accuracy 99.18
Epoch 0039  Loss 0.00  Accuracy 99.39
Epoch 0040  Loss 0.00  Accuracy 99.44
Epoch 0041  Loss 0.00  Accuracy 99.43
Epoch 0042  Loss 0.00  Accuracy 99.50
Epoch 0043  Loss 0.00  Accuracy 99.50
Epoch 0044  Loss 0.00  Accuracy 99.53
Epoch 0045  Loss 0.00  Accuracy 99.51
Epoch 0046  Loss 0.00  Accuracy 99.49
Epoch 0047  Loss 0.00  Accuracy 99.46
Epoch 0048  Loss 0.00  Accuracy 99.46
Epoch 0049  Loss 0.00  Accuracy 99.35
Epoch 0050  Loss 0.00  Accuracy 99.50
Epoch 0051  Loss 0.00  Accuracy 99.48
Epoch 0052  Loss 0.00  Accuracy 99.48
Epoch 0053  Loss 0.00  Accuracy 99.48
Epoch 0054  Loss 0.00  Accuracy 99.46
Epoch 0055  Loss 0.00  Accuracy 99.48
Epoch 0056  Loss 0.00  Accuracy 99.50
Epoch 0057  Loss 0.00  Accuracy 99.41
Epoch 0058  Loss 0.00  Accuracy 99.49
Epoch 0059  Loss 0.00  Accuracy 99.48
Epoch 0060  Loss 0.00  Accuracy 99.47
Epoch 0061  Loss 0.00  Accuracy 99.52
Epoch 0062  Loss 0.00  Accuracy 99.49
Epoch 0063  Loss 0.00  Accuracy 99.48
Epoch 0064  Loss 0.00  Accuracy 99.51
Epoch 0065  Loss 0.00  Accuracy 99.46
Epoch 0066  Loss 0.00  Accuracy 99.51
Epoch 0067  Loss 0.00  Accuracy 99.49
Epoch 0068  Loss 0.00  Accuracy 99.52
Epoch 0069  Loss 0.00  Accuracy 99.49
Epoch 0070  Loss 0.00  Accuracy 99.51
Epoch 0071  Loss 0.00  Accuracy 99.51
Epoch 0072  Loss 0.00  Accuracy 99.52
Epoch 0073  Loss 0.00  Accuracy 99.43
Epoch 0074  Loss 0.00  Accuracy 99.53
Epoch 0075  Loss 0.00  Accuracy 99.47
Epoch 0076  Loss 0.00  Accuracy 99.51
Epoch 0077  Loss 0.00  Accuracy 99.55
Epoch 0078  Loss 0.00  Accuracy 99.52
Epoch 0079  Loss 0.00  Accuracy 99.52
Epoch 0080  Loss 0.00  Accuracy 98.78
Epoch 0081  Loss 0.00  Accuracy 99.16
Epoch 0082  Loss 0.00  Accuracy 99.40
Epoch 0083  Loss 0.00  Accuracy 99.35
Epoch 0084  Loss 0.00  Accuracy 99.32
Epoch 0085  Loss 0.00  Accuracy 99.49
Epoch 0086  Loss 0.00  Accuracy 99.49
Epoch 0087  Loss 0.00  Accuracy 99.56
Epoch 0088  Loss 0.00  Accuracy 99.48
Epoch 0089  Loss 0.00  Accuracy 99.48
Epoch 0090  Loss 0.00  Accuracy 99.51
Epoch 0091  Loss 0.00  Accuracy 99.45
Epoch 0092  Loss 0.00  Accuracy 99.52
Epoch 0093  Loss 0.00  Accuracy 99.52
Epoch 0094  Loss 0.00  Accuracy 99.51
Epoch 0095  Loss 0.00  Accuracy 99.51
Epoch 0096  Loss 0.00  Accuracy 99.48
Epoch 0097  Loss 0.00  Accuracy 99.51
Epoch 0098  Loss 0.00  Accuracy 99.53
Epoch 0099  Loss 0.00  Accuracy 99.50
Epoch 0100  Loss 0.00  Accuracy 99.53

Training with PyTorch data processing API

One of the pain points for ML researchers and practitioners when building a new ML model is data processing. Here, we demonstrate how to use PyTorch’s data processing API to train a model with Objax. Different deep learning libraries come with different data processing APIs; depending on your preference, you can choose an API and easily combine it with Objax.

As before, we prepare the MNIST dataset and apply the same preprocessing.

[9]:
import torch
from torchvision import datasets, transforms

transform=transforms.Compose([
                              transforms.Pad((2,2,2,2), 0),
                              transforms.ToTensor(),
                              transforms.Lambda(lambda x: np.concatenate([x] * 3, axis=0)),
                              transforms.Lambda(lambda x: x * 2 - 1)
                              ])
train_dataset = datasets.MNIST(os.environ['HOME'], train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(os.environ['HOME'], train=False, download=True, transform=transform)

# Define data loader
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /root/MNIST/raw/train-images-idx3-ubyte.gz
Extracting /root/MNIST/raw/train-images-idx3-ubyte.gz to /root/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /root/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting /root/MNIST/raw/train-labels-idx1-ubyte.gz to /root/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /root/MNIST/raw/t10k-images-idx3-ubyte.gz

Extracting /root/MNIST/raw/t10k-images-idx3-ubyte.gz to /root/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /root/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting /root/MNIST/raw/t10k-labels-idx1-ubyte.gz to /root/MNIST/raw
Processing...
Done!
/usr/local/lib/python3.6/dist-packages/torchvision/datasets/mnist.py:469: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at  /pytorch/torch/csrc/utils/tensor_numpy.cpp:141.)
  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)

We replace the data processing pipeline of the training and test loops with train_loader and test_loader, and that’s it!

[10]:
# Train loop

def train_model_with_torch_data_api(model):

  def predict(model, x):
    """"""
    return objax.functional.softmax(model(x,  training=False))

  def flatten_image(x):
    """Flatten the image before passing it to the DNN."""
    if isinstance(model, DNNet):
      return objax.functional.flatten(x)
    else:
      return x

  opt = objax.optimizer.Momentum(model.vars())

  # Cross Entropy Loss
  def loss(x, label):
    return objax.functional.loss.cross_entropy_logits_sparse(model(x, training=True), label).mean()

  gv = objax.GradValues(loss, model.vars())
  def train_op(x, label):
    g, v = gv(x, label)  # returns gradients, loss
    opt(lr, g)
    return v

  train_op = objax.Jit(train_op, gv.vars() + opt.vars())

  for epoch in range(epochs):
    avg_loss = 0
    tot_data = 0
    for _, (img, label) in enumerate(train_loader):
      avg_loss += float(train_op(flatten_image(img.numpy()), label.numpy())[0]) * len(img)
      tot_data += len(img)
    avg_loss /= tot_data

    # Eval
    accuracy = 0
    tot_data = 0
    for _, (img, label) in enumerate(test_loader):
      accuracy += (np.argmax(predict(model, flatten_image(img.numpy())), axis=1) == label.numpy()).sum()
      tot_data += len(img)
    accuracy /= tot_data
    print('Epoch %04d  Loss %.2f  Accuracy %.2f' % (epoch + 1, avg_loss, 100 * accuracy))

Training the DNN Model with PyTorch data API

[11]:
dnn_layer_sizes = 3072, 128, 10
dnn_model = DNNet(dnn_layer_sizes, objax.functional.leaky_relu)
train_model_with_torch_data_api(dnn_model)
Epoch 0001  Loss 2.57  Accuracy 34.35
Epoch 0002  Loss 1.93  Accuracy 58.51
Epoch 0003  Loss 1.32  Accuracy 68.46
Epoch 0004  Loss 0.83  Accuracy 80.95
Epoch 0005  Loss 0.62  Accuracy 84.74
Epoch 0006  Loss 0.53  Accuracy 86.53
Epoch 0007  Loss 0.48  Accuracy 84.18
Epoch 0008  Loss 0.45  Accuracy 88.42
Epoch 0009  Loss 0.42  Accuracy 87.34
Epoch 0010  Loss 0.40  Accuracy 89.29
Epoch 0011  Loss 0.39  Accuracy 89.31
Epoch 0012  Loss 0.38  Accuracy 89.86
Epoch 0013  Loss 0.37  Accuracy 89.91
Epoch 0014  Loss 0.36  Accuracy 86.94
Epoch 0015  Loss 0.36  Accuracy 89.89
Epoch 0016  Loss 0.35  Accuracy 90.12
Epoch 0017  Loss 0.34  Accuracy 90.40
Epoch 0018  Loss 0.34  Accuracy 90.31
Epoch 0019  Loss 0.34  Accuracy 90.79
Epoch 0020  Loss 0.33  Accuracy 90.71
Epoch 0021  Loss 0.33  Accuracy 90.70
Epoch 0022  Loss 0.33  Accuracy 90.69
Epoch 0023  Loss 0.33  Accuracy 90.91
Epoch 0024  Loss 0.32  Accuracy 90.92
Epoch 0025  Loss 0.32  Accuracy 91.06
Epoch 0026  Loss 0.32  Accuracy 91.19
Epoch 0027  Loss 0.32  Accuracy 91.31
Epoch 0028  Loss 0.31  Accuracy 91.31
Epoch 0029  Loss 0.31  Accuracy 91.20
Epoch 0030  Loss 0.31  Accuracy 91.31
Epoch 0031  Loss 0.31  Accuracy 91.36
Epoch 0032  Loss 0.31  Accuracy 91.42
Epoch 0033  Loss 0.30  Accuracy 91.27
Epoch 0034  Loss 0.31  Accuracy 91.47
Epoch 0035  Loss 0.30  Accuracy 91.57
Epoch 0036  Loss 0.30  Accuracy 91.44
Epoch 0037  Loss 0.30  Accuracy 91.55
Epoch 0038  Loss 0.30  Accuracy 91.56
Epoch 0039  Loss 0.29  Accuracy 91.75
Epoch 0040  Loss 0.29  Accuracy 91.69
Epoch 0041  Loss 0.29  Accuracy 91.60
Epoch 0042  Loss 0.29  Accuracy 91.77
Epoch 0043  Loss 0.29  Accuracy 91.76
Epoch 0044  Loss 0.29  Accuracy 91.84
Epoch 0045  Loss 0.28  Accuracy 92.05
Epoch 0046  Loss 0.28  Accuracy 91.78
Epoch 0047  Loss 0.28  Accuracy 92.01
Epoch 0048  Loss 0.28  Accuracy 91.95
Epoch 0049  Loss 0.28  Accuracy 90.11
Epoch 0050  Loss 0.28  Accuracy 92.14
Epoch 0051  Loss 0.28  Accuracy 92.03
Epoch 0052  Loss 0.27  Accuracy 92.29
Epoch 0053  Loss 0.27  Accuracy 92.17
Epoch 0054  Loss 0.27  Accuracy 92.12
Epoch 0055  Loss 0.27  Accuracy 92.34
Epoch 0056  Loss 0.27  Accuracy 92.32
Epoch 0057  Loss 0.27  Accuracy 92.47
Epoch 0058  Loss 0.27  Accuracy 92.38
Epoch 0059  Loss 0.27  Accuracy 92.39
Epoch 0060  Loss 0.26  Accuracy 92.51
Epoch 0061  Loss 0.27  Accuracy 92.50
Epoch 0062  Loss 0.26  Accuracy 92.46
Epoch 0063  Loss 0.26  Accuracy 92.65
Epoch 0064  Loss 0.26  Accuracy 92.57
Epoch 0065  Loss 0.26  Accuracy 92.63
Epoch 0066  Loss 0.26  Accuracy 92.75
Epoch 0067  Loss 0.26  Accuracy 92.57
Epoch 0068  Loss 0.26  Accuracy 92.88
Epoch 0069  Loss 0.25  Accuracy 92.53
Epoch 0070  Loss 0.25  Accuracy 92.80
Epoch 0071  Loss 0.25  Accuracy 92.71
Epoch 0072  Loss 0.25  Accuracy 92.75
Epoch 0073  Loss 0.25  Accuracy 92.84
Epoch 0074  Loss 0.25  Accuracy 92.71
Epoch 0075  Loss 0.25  Accuracy 92.95
Epoch 0076  Loss 0.25  Accuracy 92.82
Epoch 0077  Loss 0.25  Accuracy 92.90
Epoch 0078  Loss 0.25  Accuracy 92.87
Epoch 0079  Loss 0.25  Accuracy 89.55
Epoch 0080  Loss 0.25  Accuracy 92.86
Epoch 0081  Loss 0.24  Accuracy 92.99
Epoch 0082  Loss 0.24  Accuracy 93.03
Epoch 0083  Loss 0.24  Accuracy 93.03
Epoch 0084  Loss 0.24  Accuracy 93.01
Epoch 0085  Loss 0.24  Accuracy 93.13
Epoch 0086  Loss 0.24  Accuracy 93.17
Epoch 0087  Loss 0.24  Accuracy 92.87
Epoch 0088  Loss 0.24  Accuracy 92.93
Epoch 0089  Loss 0.24  Accuracy 93.16
Epoch 0090  Loss 0.24  Accuracy 93.38
Epoch 0091  Loss 0.24  Accuracy 92.98
Epoch 0092  Loss 0.24  Accuracy 93.30
Epoch 0093  Loss 0.23  Accuracy 93.09
Epoch 0094  Loss 0.23  Accuracy 93.19
Epoch 0095  Loss 0.23  Accuracy 93.25
Epoch 0096  Loss 0.23  Accuracy 93.22
Epoch 0097  Loss 0.23  Accuracy 93.28
Epoch 0098  Loss 0.23  Accuracy 93.39
Epoch 0099  Loss 0.23  Accuracy 93.25
Epoch 0100  Loss 0.23  Accuracy 93.30

Training the ConvNet Model with PyTorch data API

[12]:
cnn_model = ConvNet(nin=3, nclass=10)
print(cnn_model.vars())
train_model_with_torch_data_api(cnn_model)
(ConvNet).conv_block1(Sequential)[0](Conv2D).w                      432 (3, 3, 3, 16)
(ConvNet).conv_block1(Sequential)[1](BatchNorm2D).running_mean       16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[1](BatchNorm2D).running_var        16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[1](BatchNorm2D).beta               16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[1](BatchNorm2D).gamma              16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[3](Conv2D).w                     2304 (3, 3, 16, 16)
(ConvNet).conv_block1(Sequential)[4](BatchNorm2D).running_mean       16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[4](BatchNorm2D).running_var        16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[4](BatchNorm2D).beta               16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[4](BatchNorm2D).gamma              16 (1, 16, 1, 1)
(ConvNet).conv_block2(Sequential)[0](Conv2D).w                     4608 (3, 3, 16, 32)
(ConvNet).conv_block2(Sequential)[1](BatchNorm2D).running_mean       32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[1](BatchNorm2D).running_var        32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[1](BatchNorm2D).beta               32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[1](BatchNorm2D).gamma              32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[3](Conv2D).w                     9216 (3, 3, 32, 32)
(ConvNet).conv_block2(Sequential)[4](BatchNorm2D).running_mean       32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[4](BatchNorm2D).running_var        32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[4](BatchNorm2D).beta               32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[4](BatchNorm2D).gamma              32 (1, 32, 1, 1)
(ConvNet).conv_block3(Sequential)[0](Conv2D).w                    18432 (3, 3, 32, 64)
(ConvNet).conv_block3(Sequential)[1](BatchNorm2D).running_mean       64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[1](BatchNorm2D).running_var        64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[1](BatchNorm2D).beta               64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[1](BatchNorm2D).gamma              64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[3](Conv2D).w                    36864 (3, 3, 64, 64)
(ConvNet).conv_block3(Sequential)[4](BatchNorm2D).running_mean       64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[4](BatchNorm2D).running_var        64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[4](BatchNorm2D).beta               64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[4](BatchNorm2D).gamma              64 (1, 64, 1, 1)
(ConvNet).linear(Linear).b                                           10 (10,)
(ConvNet).linear(Linear).w                                          640 (64, 10)
+Total(32)                                                        73402
Epoch 0001  Loss 0.26  Accuracy 24.18
Epoch 0002  Loss 0.05  Accuracy 37.53
Epoch 0003  Loss 0.03  Accuracy 42.17
Epoch 0004  Loss 0.03  Accuracy 73.50
Epoch 0005  Loss 0.02  Accuracy 80.33
Epoch 0006  Loss 0.02  Accuracy 83.28
Epoch 0007  Loss 0.02  Accuracy 90.87
Epoch 0008  Loss 0.01  Accuracy 98.77
Epoch 0009  Loss 0.01  Accuracy 98.42
Epoch 0010  Loss 0.01  Accuracy 98.16
Epoch 0011  Loss 0.01  Accuracy 98.74
Epoch 0012  Loss 0.01  Accuracy 95.05
Epoch 0013  Loss 0.01  Accuracy 98.89
Epoch 0014  Loss 0.00  Accuracy 98.70
Epoch 0015  Loss 0.00  Accuracy 99.01
Epoch 0016  Loss 0.00  Accuracy 98.97
Epoch 0017  Loss 0.00  Accuracy 98.79
Epoch 0018  Loss 0.00  Accuracy 98.37
Epoch 0019  Loss 0.00  Accuracy 99.19
Epoch 0020  Loss 0.00  Accuracy 99.22
Epoch 0021  Loss 0.00  Accuracy 98.43
Epoch 0022  Loss 0.00  Accuracy 99.02
Epoch 0023  Loss 0.00  Accuracy 99.38
Epoch 0024  Loss 0.00  Accuracy 99.42
Epoch 0025  Loss 0.00  Accuracy 99.45
Epoch 0026  Loss 0.00  Accuracy 99.35
Epoch 0027  Loss 0.00  Accuracy 99.42
Epoch 0028  Loss 0.00  Accuracy 99.42
Epoch 0029  Loss 0.00  Accuracy 99.14
Epoch 0030  Loss 0.00  Accuracy 99.33
Epoch 0031  Loss 0.00  Accuracy 99.36
Epoch 0032  Loss 0.00  Accuracy 99.18
Epoch 0033  Loss 0.00  Accuracy 99.43
Epoch 0034  Loss 0.00  Accuracy 99.47
Epoch 0035  Loss 0.00  Accuracy 99.49
Epoch 0036  Loss 0.00  Accuracy 99.53
Epoch 0037  Loss 0.00  Accuracy 99.38
Epoch 0038  Loss 0.00  Accuracy 99.39
Epoch 0039  Loss 0.00  Accuracy 99.49
Epoch 0040  Loss 0.00  Accuracy 99.49
Epoch 0041  Loss 0.00  Accuracy 99.47
Epoch 0042  Loss 0.00  Accuracy 99.54
Epoch 0043  Loss 0.00  Accuracy 99.35
Epoch 0044  Loss 0.00  Accuracy 99.45
Epoch 0045  Loss 0.00  Accuracy 99.47
Epoch 0046  Loss 0.00  Accuracy 99.53
Epoch 0047  Loss 0.00  Accuracy 99.50
Epoch 0048  Loss 0.00  Accuracy 99.52
Epoch 0049  Loss 0.00  Accuracy 99.51
Epoch 0050  Loss 0.00  Accuracy 99.49
Epoch 0051  Loss 0.00  Accuracy 99.45
Epoch 0052  Loss 0.00  Accuracy 99.48
Epoch 0053  Loss 0.00  Accuracy 99.50
Epoch 0054  Loss 0.00  Accuracy 99.46
Epoch 0055  Loss 0.00  Accuracy 99.50
Epoch 0056  Loss 0.00  Accuracy 99.48
Epoch 0057  Loss 0.00  Accuracy 99.46
Epoch 0058  Loss 0.00  Accuracy 99.44
Epoch 0059  Loss 0.00  Accuracy 99.46
Epoch 0060  Loss 0.00  Accuracy 99.26
Epoch 0061  Loss 0.00  Accuracy 93.99
Epoch 0062  Loss 0.00  Accuracy 97.80
Epoch 0063  Loss 0.00  Accuracy 80.26
Epoch 0064  Loss 0.00  Accuracy 99.20
Epoch 0065  Loss 0.00  Accuracy 99.38
Epoch 0066  Loss 0.00  Accuracy 99.44
Epoch 0067  Loss 0.00  Accuracy 99.51
Epoch 0068  Loss 0.00  Accuracy 99.45
Epoch 0069  Loss 0.00  Accuracy 99.42
Epoch 0070  Loss 0.00  Accuracy 99.50
Epoch 0071  Loss 0.00  Accuracy 99.52
Epoch 0072  Loss 0.00  Accuracy 99.44
Epoch 0073  Loss 0.00  Accuracy 99.41
Epoch 0074  Loss 0.00  Accuracy 99.46
Epoch 0075  Loss 0.00  Accuracy 99.42
Epoch 0076  Loss 0.00  Accuracy 99.49
Epoch 0077  Loss 0.00  Accuracy 99.50
Epoch 0078  Loss 0.00  Accuracy 99.56
Epoch 0079  Loss 0.00  Accuracy 99.52
Epoch 0080  Loss 0.00  Accuracy 99.42
Epoch 0081  Loss 0.00  Accuracy 99.49
Epoch 0082  Loss 0.00  Accuracy 99.48
Epoch 0083  Loss 0.00  Accuracy 99.44
Epoch 0084  Loss 0.00  Accuracy 99.49
Epoch 0085  Loss 0.00  Accuracy 99.53
Epoch 0086  Loss 0.00  Accuracy 99.52
Epoch 0087  Loss 0.00  Accuracy 99.52
Epoch 0088  Loss 0.00  Accuracy 99.50
Epoch 0089  Loss 0.00  Accuracy 99.51
Epoch 0090  Loss 0.00  Accuracy 99.50
Epoch 0091  Loss 0.00  Accuracy 99.49
Epoch 0092  Loss 0.00  Accuracy 99.52
Epoch 0093  Loss 0.00  Accuracy 99.50
Epoch 0094  Loss 0.00  Accuracy 99.50
Epoch 0095  Loss 0.00  Accuracy 99.55
Epoch 0096  Loss 0.00  Accuracy 99.48
Epoch 0097  Loss 0.00  Accuracy 99.51
Epoch 0098  Loss 0.00  Accuracy 99.52
Epoch 0099  Loss 0.00  Accuracy 99.49
Epoch 0100  Loss 0.00  Accuracy 99.52

What’s Next

We have learned how to use existing models and how to define new models to classify MNIST. Next, you can read one or more of the in-depth topics or browse through Objax’s APIs.