Creating Custom Networks for Multi-Class Classification¶
This tutorial demonstrates how to define, train, and use different models for multi-class classification. We will reuse most of the code from the Logistic Regression tutorial, so if you haven’t gone through it yet, consider reviewing it first.
Note that this tutorial includes a demonstration of how to build and train a simple convolutional neural network, and running this colab on a CPU may take some time. We therefore recommend running it on a GPU: select Runtime
-> Change runtime type
and set Hardware accelerator
to GPU if it is not already.
Import Modules¶
We start by importing the modules we will use in our code.
[1]:
%pip --quiet install objax
import os
import numpy as np
import tensorflow_datasets as tfds
import objax
from objax.util import EasyDict
from objax.zoo.dnnet import DNNet
Load the data¶
Next, we will load the “MNIST” dataset from TensorFlow DataSets. This dataset contains handwritten digits (i.e., numbers between 0 and 9), and the task is to correctly identify each handwritten digit.
The prepare
method pads 2 pixels to the left, right, top, and bottom of each image to resize it to 32 x 32 pixels. While MNIST images are grayscale, the prepare
method expands each image to three color channels to demonstrate the process of working with color images. The same method also rescales each pixel value to [-1, 1] and converts the images to (N, C, H, W) format.
[2]:
# Data: train has 60000 images - test has 10000 images
# Each image is resized and converted to 32 x 32 x 3
DATA_DIR = os.path.join(os.environ['HOME'], 'TFDS')
data = tfds.as_numpy(tfds.load(name='mnist', batch_size=-1, data_dir=DATA_DIR))
def prepare(x):
"""Pads 2 pixels to the left, right, top, and bottom of each image, scales pixel value to [-1, 1], and converts to NCHW format."""
s = x.shape
x_pad = np.zeros((s[0], 32, 32, 1))
x_pad[:, 2:-2, 2:-2, :] = x
return objax.util.image.nchw(
np.concatenate([x_pad.astype('f') * (1 / 127.5) - 1] * 3, axis=-1))
train = EasyDict(image=prepare(data['train']['image']), label=data['train']['label'])
test = EasyDict(image=prepare(data['test']['image']), label=data['test']['label'])
ndim = train.image.shape[-1]
del data
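As a quick sanity check (a small sketch, not part of the original notebook), you can print the shapes of the prepared arrays to confirm they are in (N, C, H, W) format:
print(train.image.shape, train.label.shape)  # expected: (60000, 3, 32, 32) (60000,)
print(test.image.shape, test.label.shape)    # expected: (10000, 3, 32, 32) (10000,)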
Deep Neural Network Model¶
Objax offers many predefined models that we can use for classification. One example is the objax.zoo.DNNet
model, which comprises multiple fully connected layers with configurable layer sizes and activation function.
[3]:
dnn_layer_sizes = 3072, 128, 10
dnn_model = DNNet(dnn_layer_sizes, objax.functional.leaky_relu)
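The first layer size, 3072, corresponds to a flattened 32 x 32 x 3 image. As a hedged sketch (not part of the original tutorial), you can flatten a few prepared images and run them through the untrained model to confirm it produces one logit per class:
x_flat = objax.functional.flatten(train.image[:4])  # (4, 3 * 32 * 32) = (4, 3072)
logits = dnn_model(x_flat, training=False)
print(x_flat.shape, logits.shape)  # expected: (4, 3072) and (4, 10)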
Custom Model Definition¶
Alternatively, we can define a new model customized to our machine learning task. We demonstrate this process by defining a convolutional network (ConvNet) from scratch.
We use objax.nn.Sequential
to compose convolution (objax.nn.Conv2D
), batch normalization (objax.nn.BatchNorm2D
), ReLU (objax.functional.relu
), max pooling (objax.functional.max_pool_2d
), global average pooling (taking the mean over the spatial dimensions), and linear (objax.nn.Linear
) layers.
Since the batch normalization layer behaves differently during training and prediction, we pass a training
flag to the __call__
method of the ConvNet
class. The model returns logits, which we feed directly to the loss during training; at prediction time we apply softmax to turn them into probabilities.
[4]:
class ConvNet(objax.Module):
"""ConvNet implementation."""
def __init__(self, nin, nclass):
"""Define 3 blocks of conv-bn-relu-conv-bn-relu followed by linear layer."""
self.conv_block1 = objax.nn.Sequential([objax.nn.Conv2D(nin, 16, 3, use_bias=False),
objax.nn.BatchNorm2D(16),
objax.functional.relu,
objax.nn.Conv2D(16, 16, 3, use_bias=False),
objax.nn.BatchNorm2D(16),
objax.functional.relu])
self.conv_block2 = objax.nn.Sequential([objax.nn.Conv2D(16, 32, 3, use_bias=False),
objax.nn.BatchNorm2D(32),
objax.functional.relu,
objax.nn.Conv2D(32, 32, 3, use_bias=False),
objax.nn.BatchNorm2D(32),
objax.functional.relu])
self.conv_block3 = objax.nn.Sequential([objax.nn.Conv2D(32, 64, 3, use_bias=False),
objax.nn.BatchNorm2D(64),
objax.functional.relu,
objax.nn.Conv2D(64, 64, 3, use_bias=False),
objax.nn.BatchNorm2D(64),
objax.functional.relu])
self.linear = objax.nn.Linear(64, nclass)
def __call__(self, x, training):
x = self.conv_block1(x, training=training)
x = objax.functional.max_pool_2d(x, size=2, strides=2)
x = self.conv_block2(x, training=training)
x = objax.functional.max_pool_2d(x, size=2, strides=2)
x = self.conv_block3(x, training=training)
x = x.mean((2, 3))
x = self.linear(x)
return x
[5]:
cnn_model = ConvNet(nin=3, nclass=10)
print(cnn_model.vars())
(ConvNet).conv_block1(Sequential)[0](Conv2D).w 432 (3, 3, 3, 16)
(ConvNet).conv_block1(Sequential)[1](BatchNorm2D).running_mean 16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[1](BatchNorm2D).running_var 16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[1](BatchNorm2D).beta 16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[1](BatchNorm2D).gamma 16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[3](Conv2D).w 2304 (3, 3, 16, 16)
(ConvNet).conv_block1(Sequential)[4](BatchNorm2D).running_mean 16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[4](BatchNorm2D).running_var 16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[4](BatchNorm2D).beta 16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[4](BatchNorm2D).gamma 16 (1, 16, 1, 1)
(ConvNet).conv_block2(Sequential)[0](Conv2D).w 4608 (3, 3, 16, 32)
(ConvNet).conv_block2(Sequential)[1](BatchNorm2D).running_mean 32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[1](BatchNorm2D).running_var 32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[1](BatchNorm2D).beta 32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[1](BatchNorm2D).gamma 32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[3](Conv2D).w 9216 (3, 3, 32, 32)
(ConvNet).conv_block2(Sequential)[4](BatchNorm2D).running_mean 32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[4](BatchNorm2D).running_var 32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[4](BatchNorm2D).beta 32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[4](BatchNorm2D).gamma 32 (1, 32, 1, 1)
(ConvNet).conv_block3(Sequential)[0](Conv2D).w 18432 (3, 3, 32, 64)
(ConvNet).conv_block3(Sequential)[1](BatchNorm2D).running_mean 64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[1](BatchNorm2D).running_var 64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[1](BatchNorm2D).beta 64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[1](BatchNorm2D).gamma 64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[3](Conv2D).w 36864 (3, 3, 64, 64)
(ConvNet).conv_block3(Sequential)[4](BatchNorm2D).running_mean 64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[4](BatchNorm2D).running_var 64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[4](BatchNorm2D).beta 64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[4](BatchNorm2D).gamma 64 (1, 64, 1, 1)
(ConvNet).linear(Linear).b 10 (10,)
(ConvNet).linear(Linear).w 640 (64, 10)
+Total(32) 73402
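As a quick forward-pass check (a sketch, assuming the prepared NCHW training images from the data cell are still in memory), a small batch should produce one logit per class:
logits = cnn_model(train.image[:4], training=False)  # (4, 3, 32, 32) NCHW batch
print(logits.shape)  # expected: (4, 10)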
Model Training and Evaluation¶
The train_model
method combines the loss function definition, gradient descent, training loop, and evaluation. It takes the model
as a parameter so it can be reused with the two models we defined earlier.
Unlike the Logistic Regression tutorial, we use objax.functional.loss.cross_entropy_logits_sparse
because we perform multi-class classification. The optimizer, gradient descent operation, and training loop remain the same.
The DNNet
model expects flattened images, whereas ConvNet
expects images in (C, H, W) format. The flatten_image
method prepares images accordingly before passing them to the model.
When using the model for inference, we apply the objax.functional.softmax
method to compute the probability distribution from the model’s logits.
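To illustrate what the sparse cross-entropy loss expects (a minimal sketch, not part of the original notebook): it takes raw logits of shape (N, nclass) and integer class labels of shape (N,) and returns one loss value per example, which we average over the batch.
toy_logits = np.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]], dtype='f')  # (N=2, nclass=3)
toy_labels = np.array([0, 2])  # correct class index for each example
print(objax.functional.loss.cross_entropy_logits_sparse(toy_logits, toy_labels).mean())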
[6]:
# Settings
lr = 0.03 # learning rate
batch = 128
epochs = 100
# Train loop
def train_model(model):
def predict(model, x):
""""""
return objax.functional.softmax(model(x, training=False))
def flatten_image(x):
"""Flatten the image before passing it to the DNN."""
if isinstance(model, DNNet):
return objax.functional.flatten(x)
else:
return x
opt = objax.optimizer.Momentum(model.vars())
# Cross Entropy Loss
def loss(x, label):
return objax.functional.loss.cross_entropy_logits_sparse(model(x, training=True), label).mean()
gv = objax.GradValues(loss, model.vars())
def train_op(x, label):
g, v = gv(x, label) # returns gradients, loss
opt(lr, g)
return v
train_op = objax.Jit(train_op, gv.vars() + opt.vars())
for epoch in range(epochs):
avg_loss = 0
# randomly shuffle training data
shuffle_idx = np.random.permutation(train.image.shape[0])
for it in range(0, train.image.shape[0], batch):
sel = shuffle_idx[it: it + batch]
avg_loss += float(train_op(flatten_image(train.image[sel]), train.label[sel])[0]) * len(sel)
avg_loss /= it + len(sel)
# Eval
accuracy = 0
for it in range(0, test.image.shape[0], batch):
x, y = test.image[it: it + batch], test.label[it: it + batch]
accuracy += (np.argmax(predict(model, flatten_image(x)), axis=1) == y).sum()
accuracy /= test.image.shape[0]
print('Epoch %04d Loss %.2f Accuracy %.2f' % (epoch + 1, avg_loss, 100 * accuracy))
Training the DNN Model¶
[7]:
train_model(dnn_model)
Epoch 0001 Loss 2.39 Accuracy 56.31
Epoch 0002 Loss 1.24 Accuracy 74.19
Epoch 0003 Loss 0.75 Accuracy 84.91
Epoch 0004 Loss 0.56 Accuracy 83.10
Epoch 0005 Loss 0.48 Accuracy 86.42
Epoch 0006 Loss 0.43 Accuracy 89.00
Epoch 0007 Loss 0.41 Accuracy 89.62
Epoch 0008 Loss 0.39 Accuracy 89.82
Epoch 0009 Loss 0.37 Accuracy 90.04
Epoch 0010 Loss 0.36 Accuracy 89.55
Epoch 0011 Loss 0.35 Accuracy 90.53
Epoch 0012 Loss 0.35 Accuracy 90.64
Epoch 0013 Loss 0.34 Accuracy 90.85
Epoch 0014 Loss 0.33 Accuracy 90.87
Epoch 0015 Loss 0.33 Accuracy 91.02
Epoch 0016 Loss 0.32 Accuracy 91.35
Epoch 0017 Loss 0.32 Accuracy 91.35
Epoch 0018 Loss 0.31 Accuracy 91.50
Epoch 0019 Loss 0.31 Accuracy 91.57
Epoch 0020 Loss 0.31 Accuracy 91.75
Epoch 0021 Loss 0.30 Accuracy 91.47
Epoch 0022 Loss 0.30 Accuracy 88.14
Epoch 0023 Loss 0.30 Accuracy 91.82
Epoch 0024 Loss 0.30 Accuracy 91.92
Epoch 0025 Loss 0.29 Accuracy 92.03
Epoch 0026 Loss 0.29 Accuracy 92.04
Epoch 0027 Loss 0.29 Accuracy 92.11
Epoch 0028 Loss 0.29 Accuracy 92.11
Epoch 0029 Loss 0.29 Accuracy 92.18
Epoch 0030 Loss 0.28 Accuracy 92.24
Epoch 0031 Loss 0.28 Accuracy 92.36
Epoch 0032 Loss 0.28 Accuracy 92.17
Epoch 0033 Loss 0.28 Accuracy 92.42
Epoch 0034 Loss 0.28 Accuracy 92.42
Epoch 0035 Loss 0.27 Accuracy 92.47
Epoch 0036 Loss 0.27 Accuracy 92.50
Epoch 0037 Loss 0.27 Accuracy 92.49
Epoch 0038 Loss 0.27 Accuracy 92.58
Epoch 0039 Loss 0.26 Accuracy 92.56
Epoch 0040 Loss 0.26 Accuracy 92.56
Epoch 0041 Loss 0.26 Accuracy 92.77
Epoch 0042 Loss 0.26 Accuracy 92.72
Epoch 0043 Loss 0.26 Accuracy 92.80
Epoch 0044 Loss 0.25 Accuracy 92.85
Epoch 0045 Loss 0.25 Accuracy 92.90
Epoch 0046 Loss 0.25 Accuracy 92.96
Epoch 0047 Loss 0.25 Accuracy 93.00
Epoch 0048 Loss 0.25 Accuracy 92.82
Epoch 0049 Loss 0.25 Accuracy 93.18
Epoch 0050 Loss 0.24 Accuracy 93.09
Epoch 0051 Loss 0.24 Accuracy 92.94
Epoch 0052 Loss 0.24 Accuracy 93.20
Epoch 0053 Loss 0.24 Accuracy 93.26
Epoch 0054 Loss 0.23 Accuracy 93.21
Epoch 0055 Loss 0.24 Accuracy 93.42
Epoch 0056 Loss 0.23 Accuracy 93.35
Epoch 0057 Loss 0.23 Accuracy 93.36
Epoch 0058 Loss 0.23 Accuracy 93.56
Epoch 0059 Loss 0.23 Accuracy 93.54
Epoch 0060 Loss 0.22 Accuracy 93.39
Epoch 0061 Loss 0.23 Accuracy 93.56
Epoch 0062 Loss 0.22 Accuracy 93.74
Epoch 0063 Loss 0.22 Accuracy 93.68
Epoch 0064 Loss 0.22 Accuracy 93.72
Epoch 0065 Loss 0.22 Accuracy 93.76
Epoch 0066 Loss 0.22 Accuracy 93.87
Epoch 0067 Loss 0.21 Accuracy 93.89
Epoch 0068 Loss 0.21 Accuracy 93.96
Epoch 0069 Loss 0.21 Accuracy 93.90
Epoch 0070 Loss 0.21 Accuracy 93.99
Epoch 0071 Loss 0.21 Accuracy 94.02
Epoch 0072 Loss 0.21 Accuracy 93.86
Epoch 0073 Loss 0.21 Accuracy 94.06
Epoch 0074 Loss 0.21 Accuracy 94.14
Epoch 0075 Loss 0.20 Accuracy 94.31
Epoch 0076 Loss 0.20 Accuracy 94.14
Epoch 0077 Loss 0.20 Accuracy 94.15
Epoch 0078 Loss 0.20 Accuracy 94.10
Epoch 0079 Loss 0.20 Accuracy 94.16
Epoch 0080 Loss 0.20 Accuracy 94.28
Epoch 0081 Loss 0.20 Accuracy 94.30
Epoch 0082 Loss 0.20 Accuracy 94.28
Epoch 0083 Loss 0.19 Accuracy 94.37
Epoch 0084 Loss 0.19 Accuracy 94.33
Epoch 0085 Loss 0.19 Accuracy 94.31
Epoch 0086 Loss 0.19 Accuracy 94.25
Epoch 0087 Loss 0.19 Accuracy 94.37
Epoch 0088 Loss 0.19 Accuracy 94.38
Epoch 0089 Loss 0.19 Accuracy 94.35
Epoch 0090 Loss 0.19 Accuracy 94.38
Epoch 0091 Loss 0.19 Accuracy 94.41
Epoch 0092 Loss 0.19 Accuracy 94.46
Epoch 0093 Loss 0.19 Accuracy 94.53
Epoch 0094 Loss 0.18 Accuracy 94.47
Epoch 0095 Loss 0.18 Accuracy 94.54
Epoch 0096 Loss 0.18 Accuracy 94.65
Epoch 0097 Loss 0.18 Accuracy 94.56
Epoch 0098 Loss 0.18 Accuracy 94.60
Epoch 0099 Loss 0.18 Accuracy 94.63
Epoch 0100 Loss 0.18 Accuracy 94.46
Training the ConvNet Model¶
[8]:
train_model(cnn_model)
Epoch 0001 Loss 0.27 Accuracy 27.08
Epoch 0002 Loss 0.05 Accuracy 41.07
Epoch 0003 Loss 0.03 Accuracy 67.77
Epoch 0004 Loss 0.03 Accuracy 73.31
Epoch 0005 Loss 0.02 Accuracy 90.30
Epoch 0006 Loss 0.02 Accuracy 93.10
Epoch 0007 Loss 0.02 Accuracy 95.98
Epoch 0008 Loss 0.01 Accuracy 98.77
Epoch 0009 Loss 0.01 Accuracy 96.58
Epoch 0010 Loss 0.01 Accuracy 99.12
Epoch 0011 Loss 0.01 Accuracy 98.88
Epoch 0012 Loss 0.01 Accuracy 98.64
Epoch 0013 Loss 0.01 Accuracy 98.66
Epoch 0014 Loss 0.00 Accuracy 98.38
Epoch 0015 Loss 0.00 Accuracy 99.15
Epoch 0016 Loss 0.00 Accuracy 97.50
Epoch 0017 Loss 0.00 Accuracy 98.98
Epoch 0018 Loss 0.00 Accuracy 98.94
Epoch 0019 Loss 0.00 Accuracy 98.56
Epoch 0020 Loss 0.00 Accuracy 99.06
Epoch 0021 Loss 0.00 Accuracy 99.26
Epoch 0022 Loss 0.00 Accuracy 99.30
Epoch 0023 Loss 0.00 Accuracy 99.18
Epoch 0024 Loss 0.00 Accuracy 99.49
Epoch 0025 Loss 0.00 Accuracy 99.34
Epoch 0026 Loss 0.00 Accuracy 99.24
Epoch 0027 Loss 0.00 Accuracy 99.38
Epoch 0028 Loss 0.00 Accuracy 99.43
Epoch 0029 Loss 0.00 Accuracy 99.40
Epoch 0030 Loss 0.00 Accuracy 99.50
Epoch 0031 Loss 0.00 Accuracy 99.44
Epoch 0032 Loss 0.00 Accuracy 99.52
Epoch 0033 Loss 0.00 Accuracy 99.46
Epoch 0034 Loss 0.00 Accuracy 99.39
Epoch 0035 Loss 0.00 Accuracy 99.22
Epoch 0036 Loss 0.00 Accuracy 99.26
Epoch 0037 Loss 0.00 Accuracy 99.47
Epoch 0038 Loss 0.00 Accuracy 99.18
Epoch 0039 Loss 0.00 Accuracy 99.39
Epoch 0040 Loss 0.00 Accuracy 99.44
Epoch 0041 Loss 0.00 Accuracy 99.43
Epoch 0042 Loss 0.00 Accuracy 99.50
Epoch 0043 Loss 0.00 Accuracy 99.50
Epoch 0044 Loss 0.00 Accuracy 99.53
Epoch 0045 Loss 0.00 Accuracy 99.51
Epoch 0046 Loss 0.00 Accuracy 99.49
Epoch 0047 Loss 0.00 Accuracy 99.46
Epoch 0048 Loss 0.00 Accuracy 99.46
Epoch 0049 Loss 0.00 Accuracy 99.35
Epoch 0050 Loss 0.00 Accuracy 99.50
Epoch 0051 Loss 0.00 Accuracy 99.48
Epoch 0052 Loss 0.00 Accuracy 99.48
Epoch 0053 Loss 0.00 Accuracy 99.48
Epoch 0054 Loss 0.00 Accuracy 99.46
Epoch 0055 Loss 0.00 Accuracy 99.48
Epoch 0056 Loss 0.00 Accuracy 99.50
Epoch 0057 Loss 0.00 Accuracy 99.41
Epoch 0058 Loss 0.00 Accuracy 99.49
Epoch 0059 Loss 0.00 Accuracy 99.48
Epoch 0060 Loss 0.00 Accuracy 99.47
Epoch 0061 Loss 0.00 Accuracy 99.52
Epoch 0062 Loss 0.00 Accuracy 99.49
Epoch 0063 Loss 0.00 Accuracy 99.48
Epoch 0064 Loss 0.00 Accuracy 99.51
Epoch 0065 Loss 0.00 Accuracy 99.46
Epoch 0066 Loss 0.00 Accuracy 99.51
Epoch 0067 Loss 0.00 Accuracy 99.49
Epoch 0068 Loss 0.00 Accuracy 99.52
Epoch 0069 Loss 0.00 Accuracy 99.49
Epoch 0070 Loss 0.00 Accuracy 99.51
Epoch 0071 Loss 0.00 Accuracy 99.51
Epoch 0072 Loss 0.00 Accuracy 99.52
Epoch 0073 Loss 0.00 Accuracy 99.43
Epoch 0074 Loss 0.00 Accuracy 99.53
Epoch 0075 Loss 0.00 Accuracy 99.47
Epoch 0076 Loss 0.00 Accuracy 99.51
Epoch 0077 Loss 0.00 Accuracy 99.55
Epoch 0078 Loss 0.00 Accuracy 99.52
Epoch 0079 Loss 0.00 Accuracy 99.52
Epoch 0080 Loss 0.00 Accuracy 98.78
Epoch 0081 Loss 0.00 Accuracy 99.16
Epoch 0082 Loss 0.00 Accuracy 99.40
Epoch 0083 Loss 0.00 Accuracy 99.35
Epoch 0084 Loss 0.00 Accuracy 99.32
Epoch 0085 Loss 0.00 Accuracy 99.49
Epoch 0086 Loss 0.00 Accuracy 99.49
Epoch 0087 Loss 0.00 Accuracy 99.56
Epoch 0088 Loss 0.00 Accuracy 99.48
Epoch 0089 Loss 0.00 Accuracy 99.48
Epoch 0090 Loss 0.00 Accuracy 99.51
Epoch 0091 Loss 0.00 Accuracy 99.45
Epoch 0092 Loss 0.00 Accuracy 99.52
Epoch 0093 Loss 0.00 Accuracy 99.52
Epoch 0094 Loss 0.00 Accuracy 99.51
Epoch 0095 Loss 0.00 Accuracy 99.51
Epoch 0096 Loss 0.00 Accuracy 99.48
Epoch 0097 Loss 0.00 Accuracy 99.51
Epoch 0098 Loss 0.00 Accuracy 99.53
Epoch 0099 Loss 0.00 Accuracy 99.50
Epoch 0100 Loss 0.00 Accuracy 99.53
Training with PyTorch data processing API¶
One of the pain points for ML researchers and practitioners when building a new ML model is data processing. Here, we demonstrate how to use PyTorch's data processing API to train a model with Objax. Different deep learning libraries come with different data processing APIs; depending on your preference, you can choose an API and easily combine it with Objax.
Similarly, we prepare an MNIST
dataset and apply the same data preprocessing.
[9]:
import torch
from torchvision import datasets, transforms
transform=transforms.Compose([
transforms.Pad((2,2,2,2), 0),
transforms.ToTensor(),
transforms.Lambda(lambda x: np.concatenate([x] * 3, axis=0)),
transforms.Lambda(lambda x: x * 2 - 1)
])
train_dataset = datasets.MNIST(os.environ['HOME'], train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(os.environ['HOME'], train=False, download=True, transform=transform)
# Define data loader
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /root/MNIST/raw/train-images-idx3-ubyte.gz
Extracting /root/MNIST/raw/train-images-idx3-ubyte.gz to /root/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /root/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting /root/MNIST/raw/train-labels-idx1-ubyte.gz to /root/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /root/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting /root/MNIST/raw/t10k-images-idx3-ubyte.gz to /root/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /root/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting /root/MNIST/raw/t10k-labels-idx1-ubyte.gz to /root/MNIST/raw
Processing...
Done!
/usr/local/lib/python3.6/dist-packages/torchvision/datasets/mnist.py:469: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:141.)
return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
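Before training, it can help to pull a single batch from train_loader and verify that it matches the TFDS pipeline, i.e. (N, C, H, W) shape with pixel values in [-1, 1] (a quick sketch, not part of the original notebook):
img, label = next(iter(train_loader))
print(img.shape, float(img.min()), float(img.max()))  # e.g. torch.Size([128, 3, 32, 32]) -1.0 1.0
print(label.shape)  # torch.Size([128])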
We replace the data processing pipeline in the training and test loops with train_loader
and test_loader
, and that’s it!
[10]:
# Train loop
def train_model_with_torch_data_api(model):
def predict(model, x):
""""""
return objax.functional.softmax(model(x, training=False))
def flatten_image(x):
"""Flatten the image before passing it to the DNN."""
if isinstance(model, DNNet):
return objax.functional.flatten(x)
else:
return x
opt = objax.optimizer.Momentum(model.vars())
# Cross Entropy Loss
def loss(x, label):
return objax.functional.loss.cross_entropy_logits_sparse(model(x, training=True), label).mean()
gv = objax.GradValues(loss, model.vars())
def train_op(x, label):
g, v = gv(x, label) # returns gradients, loss
opt(lr, g)
return v
train_op = objax.Jit(train_op, gv.vars() + opt.vars())
for epoch in range(epochs):
avg_loss = 0
tot_data = 0
for _, (img, label) in enumerate(train_loader):
avg_loss += float(train_op(flatten_image(img.numpy()), label.numpy())[0]) * len(img)
tot_data += len(img)
avg_loss /= tot_data
# Eval
accuracy = 0
tot_data = 0
for _, (img, label) in enumerate(test_loader):
accuracy += (np.argmax(predict(model, flatten_image(img.numpy())), axis=1) == label.numpy()).sum()
tot_data += len(img)
accuracy /= tot_data
print('Epoch %04d Loss %.2f Accuracy %.2f' % (epoch + 1, avg_loss, 100 * accuracy))
Training the DNN Model with PyTorch data API¶
[11]:
dnn_layer_sizes = 3072, 128, 10
dnn_model = DNNet(dnn_layer_sizes, objax.functional.leaky_relu)
train_model_with_torch_data_api(dnn_model)
Epoch 0001 Loss 2.57 Accuracy 34.35
Epoch 0002 Loss 1.93 Accuracy 58.51
Epoch 0003 Loss 1.32 Accuracy 68.46
Epoch 0004 Loss 0.83 Accuracy 80.95
Epoch 0005 Loss 0.62 Accuracy 84.74
Epoch 0006 Loss 0.53 Accuracy 86.53
Epoch 0007 Loss 0.48 Accuracy 84.18
Epoch 0008 Loss 0.45 Accuracy 88.42
Epoch 0009 Loss 0.42 Accuracy 87.34
Epoch 0010 Loss 0.40 Accuracy 89.29
Epoch 0011 Loss 0.39 Accuracy 89.31
Epoch 0012 Loss 0.38 Accuracy 89.86
Epoch 0013 Loss 0.37 Accuracy 89.91
Epoch 0014 Loss 0.36 Accuracy 86.94
Epoch 0015 Loss 0.36 Accuracy 89.89
Epoch 0016 Loss 0.35 Accuracy 90.12
Epoch 0017 Loss 0.34 Accuracy 90.40
Epoch 0018 Loss 0.34 Accuracy 90.31
Epoch 0019 Loss 0.34 Accuracy 90.79
Epoch 0020 Loss 0.33 Accuracy 90.71
Epoch 0021 Loss 0.33 Accuracy 90.70
Epoch 0022 Loss 0.33 Accuracy 90.69
Epoch 0023 Loss 0.33 Accuracy 90.91
Epoch 0024 Loss 0.32 Accuracy 90.92
Epoch 0025 Loss 0.32 Accuracy 91.06
Epoch 0026 Loss 0.32 Accuracy 91.19
Epoch 0027 Loss 0.32 Accuracy 91.31
Epoch 0028 Loss 0.31 Accuracy 91.31
Epoch 0029 Loss 0.31 Accuracy 91.20
Epoch 0030 Loss 0.31 Accuracy 91.31
Epoch 0031 Loss 0.31 Accuracy 91.36
Epoch 0032 Loss 0.31 Accuracy 91.42
Epoch 0033 Loss 0.30 Accuracy 91.27
Epoch 0034 Loss 0.31 Accuracy 91.47
Epoch 0035 Loss 0.30 Accuracy 91.57
Epoch 0036 Loss 0.30 Accuracy 91.44
Epoch 0037 Loss 0.30 Accuracy 91.55
Epoch 0038 Loss 0.30 Accuracy 91.56
Epoch 0039 Loss 0.29 Accuracy 91.75
Epoch 0040 Loss 0.29 Accuracy 91.69
Epoch 0041 Loss 0.29 Accuracy 91.60
Epoch 0042 Loss 0.29 Accuracy 91.77
Epoch 0043 Loss 0.29 Accuracy 91.76
Epoch 0044 Loss 0.29 Accuracy 91.84
Epoch 0045 Loss 0.28 Accuracy 92.05
Epoch 0046 Loss 0.28 Accuracy 91.78
Epoch 0047 Loss 0.28 Accuracy 92.01
Epoch 0048 Loss 0.28 Accuracy 91.95
Epoch 0049 Loss 0.28 Accuracy 90.11
Epoch 0050 Loss 0.28 Accuracy 92.14
Epoch 0051 Loss 0.28 Accuracy 92.03
Epoch 0052 Loss 0.27 Accuracy 92.29
Epoch 0053 Loss 0.27 Accuracy 92.17
Epoch 0054 Loss 0.27 Accuracy 92.12
Epoch 0055 Loss 0.27 Accuracy 92.34
Epoch 0056 Loss 0.27 Accuracy 92.32
Epoch 0057 Loss 0.27 Accuracy 92.47
Epoch 0058 Loss 0.27 Accuracy 92.38
Epoch 0059 Loss 0.27 Accuracy 92.39
Epoch 0060 Loss 0.26 Accuracy 92.51
Epoch 0061 Loss 0.27 Accuracy 92.50
Epoch 0062 Loss 0.26 Accuracy 92.46
Epoch 0063 Loss 0.26 Accuracy 92.65
Epoch 0064 Loss 0.26 Accuracy 92.57
Epoch 0065 Loss 0.26 Accuracy 92.63
Epoch 0066 Loss 0.26 Accuracy 92.75
Epoch 0067 Loss 0.26 Accuracy 92.57
Epoch 0068 Loss 0.26 Accuracy 92.88
Epoch 0069 Loss 0.25 Accuracy 92.53
Epoch 0070 Loss 0.25 Accuracy 92.80
Epoch 0071 Loss 0.25 Accuracy 92.71
Epoch 0072 Loss 0.25 Accuracy 92.75
Epoch 0073 Loss 0.25 Accuracy 92.84
Epoch 0074 Loss 0.25 Accuracy 92.71
Epoch 0075 Loss 0.25 Accuracy 92.95
Epoch 0076 Loss 0.25 Accuracy 92.82
Epoch 0077 Loss 0.25 Accuracy 92.90
Epoch 0078 Loss 0.25 Accuracy 92.87
Epoch 0079 Loss 0.25 Accuracy 89.55
Epoch 0080 Loss 0.25 Accuracy 92.86
Epoch 0081 Loss 0.24 Accuracy 92.99
Epoch 0082 Loss 0.24 Accuracy 93.03
Epoch 0083 Loss 0.24 Accuracy 93.03
Epoch 0084 Loss 0.24 Accuracy 93.01
Epoch 0085 Loss 0.24 Accuracy 93.13
Epoch 0086 Loss 0.24 Accuracy 93.17
Epoch 0087 Loss 0.24 Accuracy 92.87
Epoch 0088 Loss 0.24 Accuracy 92.93
Epoch 0089 Loss 0.24 Accuracy 93.16
Epoch 0090 Loss 0.24 Accuracy 93.38
Epoch 0091 Loss 0.24 Accuracy 92.98
Epoch 0092 Loss 0.24 Accuracy 93.30
Epoch 0093 Loss 0.23 Accuracy 93.09
Epoch 0094 Loss 0.23 Accuracy 93.19
Epoch 0095 Loss 0.23 Accuracy 93.25
Epoch 0096 Loss 0.23 Accuracy 93.22
Epoch 0097 Loss 0.23 Accuracy 93.28
Epoch 0098 Loss 0.23 Accuracy 93.39
Epoch 0099 Loss 0.23 Accuracy 93.25
Epoch 0100 Loss 0.23 Accuracy 93.30
Training the ConvNet Model with PyTorch data API¶
[12]:
cnn_model = ConvNet(nin=3, nclass=10)
print(cnn_model.vars())
train_model_with_torch_data_api(cnn_model)
(ConvNet).conv_block1(Sequential)[0](Conv2D).w 432 (3, 3, 3, 16)
(ConvNet).conv_block1(Sequential)[1](BatchNorm2D).running_mean 16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[1](BatchNorm2D).running_var 16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[1](BatchNorm2D).beta 16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[1](BatchNorm2D).gamma 16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[3](Conv2D).w 2304 (3, 3, 16, 16)
(ConvNet).conv_block1(Sequential)[4](BatchNorm2D).running_mean 16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[4](BatchNorm2D).running_var 16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[4](BatchNorm2D).beta 16 (1, 16, 1, 1)
(ConvNet).conv_block1(Sequential)[4](BatchNorm2D).gamma 16 (1, 16, 1, 1)
(ConvNet).conv_block2(Sequential)[0](Conv2D).w 4608 (3, 3, 16, 32)
(ConvNet).conv_block2(Sequential)[1](BatchNorm2D).running_mean 32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[1](BatchNorm2D).running_var 32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[1](BatchNorm2D).beta 32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[1](BatchNorm2D).gamma 32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[3](Conv2D).w 9216 (3, 3, 32, 32)
(ConvNet).conv_block2(Sequential)[4](BatchNorm2D).running_mean 32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[4](BatchNorm2D).running_var 32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[4](BatchNorm2D).beta 32 (1, 32, 1, 1)
(ConvNet).conv_block2(Sequential)[4](BatchNorm2D).gamma 32 (1, 32, 1, 1)
(ConvNet).conv_block3(Sequential)[0](Conv2D).w 18432 (3, 3, 32, 64)
(ConvNet).conv_block3(Sequential)[1](BatchNorm2D).running_mean 64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[1](BatchNorm2D).running_var 64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[1](BatchNorm2D).beta 64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[1](BatchNorm2D).gamma 64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[3](Conv2D).w 36864 (3, 3, 64, 64)
(ConvNet).conv_block3(Sequential)[4](BatchNorm2D).running_mean 64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[4](BatchNorm2D).running_var 64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[4](BatchNorm2D).beta 64 (1, 64, 1, 1)
(ConvNet).conv_block3(Sequential)[4](BatchNorm2D).gamma 64 (1, 64, 1, 1)
(ConvNet).linear(Linear).b 10 (10,)
(ConvNet).linear(Linear).w 640 (64, 10)
+Total(32) 73402
Epoch 0001 Loss 0.26 Accuracy 24.18
Epoch 0002 Loss 0.05 Accuracy 37.53
Epoch 0003 Loss 0.03 Accuracy 42.17
Epoch 0004 Loss 0.03 Accuracy 73.50
Epoch 0005 Loss 0.02 Accuracy 80.33
Epoch 0006 Loss 0.02 Accuracy 83.28
Epoch 0007 Loss 0.02 Accuracy 90.87
Epoch 0008 Loss 0.01 Accuracy 98.77
Epoch 0009 Loss 0.01 Accuracy 98.42
Epoch 0010 Loss 0.01 Accuracy 98.16
Epoch 0011 Loss 0.01 Accuracy 98.74
Epoch 0012 Loss 0.01 Accuracy 95.05
Epoch 0013 Loss 0.01 Accuracy 98.89
Epoch 0014 Loss 0.00 Accuracy 98.70
Epoch 0015 Loss 0.00 Accuracy 99.01
Epoch 0016 Loss 0.00 Accuracy 98.97
Epoch 0017 Loss 0.00 Accuracy 98.79
Epoch 0018 Loss 0.00 Accuracy 98.37
Epoch 0019 Loss 0.00 Accuracy 99.19
Epoch 0020 Loss 0.00 Accuracy 99.22
Epoch 0021 Loss 0.00 Accuracy 98.43
Epoch 0022 Loss 0.00 Accuracy 99.02
Epoch 0023 Loss 0.00 Accuracy 99.38
Epoch 0024 Loss 0.00 Accuracy 99.42
Epoch 0025 Loss 0.00 Accuracy 99.45
Epoch 0026 Loss 0.00 Accuracy 99.35
Epoch 0027 Loss 0.00 Accuracy 99.42
Epoch 0028 Loss 0.00 Accuracy 99.42
Epoch 0029 Loss 0.00 Accuracy 99.14
Epoch 0030 Loss 0.00 Accuracy 99.33
Epoch 0031 Loss 0.00 Accuracy 99.36
Epoch 0032 Loss 0.00 Accuracy 99.18
Epoch 0033 Loss 0.00 Accuracy 99.43
Epoch 0034 Loss 0.00 Accuracy 99.47
Epoch 0035 Loss 0.00 Accuracy 99.49
Epoch 0036 Loss 0.00 Accuracy 99.53
Epoch 0037 Loss 0.00 Accuracy 99.38
Epoch 0038 Loss 0.00 Accuracy 99.39
Epoch 0039 Loss 0.00 Accuracy 99.49
Epoch 0040 Loss 0.00 Accuracy 99.49
Epoch 0041 Loss 0.00 Accuracy 99.47
Epoch 0042 Loss 0.00 Accuracy 99.54
Epoch 0043 Loss 0.00 Accuracy 99.35
Epoch 0044 Loss 0.00 Accuracy 99.45
Epoch 0045 Loss 0.00 Accuracy 99.47
Epoch 0046 Loss 0.00 Accuracy 99.53
Epoch 0047 Loss 0.00 Accuracy 99.50
Epoch 0048 Loss 0.00 Accuracy 99.52
Epoch 0049 Loss 0.00 Accuracy 99.51
Epoch 0050 Loss 0.00 Accuracy 99.49
Epoch 0051 Loss 0.00 Accuracy 99.45
Epoch 0052 Loss 0.00 Accuracy 99.48
Epoch 0053 Loss 0.00 Accuracy 99.50
Epoch 0054 Loss 0.00 Accuracy 99.46
Epoch 0055 Loss 0.00 Accuracy 99.50
Epoch 0056 Loss 0.00 Accuracy 99.48
Epoch 0057 Loss 0.00 Accuracy 99.46
Epoch 0058 Loss 0.00 Accuracy 99.44
Epoch 0059 Loss 0.00 Accuracy 99.46
Epoch 0060 Loss 0.00 Accuracy 99.26
Epoch 0061 Loss 0.00 Accuracy 93.99
Epoch 0062 Loss 0.00 Accuracy 97.80
Epoch 0063 Loss 0.00 Accuracy 80.26
Epoch 0064 Loss 0.00 Accuracy 99.20
Epoch 0065 Loss 0.00 Accuracy 99.38
Epoch 0066 Loss 0.00 Accuracy 99.44
Epoch 0067 Loss 0.00 Accuracy 99.51
Epoch 0068 Loss 0.00 Accuracy 99.45
Epoch 0069 Loss 0.00 Accuracy 99.42
Epoch 0070 Loss 0.00 Accuracy 99.50
Epoch 0071 Loss 0.00 Accuracy 99.52
Epoch 0072 Loss 0.00 Accuracy 99.44
Epoch 0073 Loss 0.00 Accuracy 99.41
Epoch 0074 Loss 0.00 Accuracy 99.46
Epoch 0075 Loss 0.00 Accuracy 99.42
Epoch 0076 Loss 0.00 Accuracy 99.49
Epoch 0077 Loss 0.00 Accuracy 99.50
Epoch 0078 Loss 0.00 Accuracy 99.56
Epoch 0079 Loss 0.00 Accuracy 99.52
Epoch 0080 Loss 0.00 Accuracy 99.42
Epoch 0081 Loss 0.00 Accuracy 99.49
Epoch 0082 Loss 0.00 Accuracy 99.48
Epoch 0083 Loss 0.00 Accuracy 99.44
Epoch 0084 Loss 0.00 Accuracy 99.49
Epoch 0085 Loss 0.00 Accuracy 99.53
Epoch 0086 Loss 0.00 Accuracy 99.52
Epoch 0087 Loss 0.00 Accuracy 99.52
Epoch 0088 Loss 0.00 Accuracy 99.50
Epoch 0089 Loss 0.00 Accuracy 99.51
Epoch 0090 Loss 0.00 Accuracy 99.50
Epoch 0091 Loss 0.00 Accuracy 99.49
Epoch 0092 Loss 0.00 Accuracy 99.52
Epoch 0093 Loss 0.00 Accuracy 99.50
Epoch 0094 Loss 0.00 Accuracy 99.50
Epoch 0095 Loss 0.00 Accuracy 99.55
Epoch 0096 Loss 0.00 Accuracy 99.48
Epoch 0097 Loss 0.00 Accuracy 99.51
Epoch 0098 Loss 0.00 Accuracy 99.52
Epoch 0099 Loss 0.00 Accuracy 99.49
Epoch 0100 Loss 0.00 Accuracy 99.52
What’s Next¶
We have learned how to use existing models and define new models to classify MNIST. Next, you can read one or more of the in-depth topics or browse through Objax's APIs.