Interactive online version: Open In Colab

Logistic Regression

In this example we will see how to classify images as horses or people using logistic regression. The tutorial builds upon the concepts introduced in the Objax basics tutorial. Consider reading that tutorial first, or you can always go back.


First, we import the modules we will use in our code.

%pip --quiet install objax

import matplotlib.pyplot as plt
import os

import numpy as np
import tensorflow_datasets as tfds

import objax
from objax.util import EasyDict

Loading the Dataset

Next, we will load the “horses_or_humans” dataset from TensorFlow DataSets.

The prepare method downscales the image by 3x to reduce training time, flattens each image to a vector, and rescales each pixel value to [-1,1].

# Data: train has 1027 images - test has 256 images
# Each image is 300 x 300 x 3 bytes
DATA_DIR = os.path.join(os.environ['HOME'], 'TFDS')
data = tfds.as_numpy(tfds.load(name='horses_or_humans', batch_size=-1, data_dir=DATA_DIR))

def prepare(x, downscale=3):
  """Normalize images to [-1, 1] and downscale them to 100x100x3 (for faster training) and flatten them."""
  s = x.shape
  x = x.astype('f').reshape((s[0], s[1] // downscale, downscale, s[2] // downscale, downscale, s[3]))
  return x.mean((2, 4)).reshape((s[0], -1)) * (1 / 127.5) - 1

train = EasyDict(image=prepare(data['train']['image']), label=data['train']['label'])
test = EasyDict(image=prepare(data['test']['image']), label=data['test']['label'])
ndim = train.image.shape[-1]
del data

Visualizing the Data

Let’s see a couple of the images in the dataset and their corresponding labels. Note that label 0 corresponds to a horse, while label 1 corresponds to a human.

#sample image of a horse.
horse_image = np.reshape(train.image[0], [100,100,3])
print("label for horse_image:", train.label[0])
label for horse_image: 0
#sample image of a human.
human_image = np.reshape(train.image[9], [100, 100, 3])
print("label for human_image:", train.label[9])

label for human_image: 1

Model Definition

objax.nn.Linear(ndim, 1) is a linear neural unit with ndim inputs and a single output. Given input \(\mathbf{X}\), the output is equal to \(\mathbf{W}\mathbf{X} + \mathbf{b}\) where \(\mathbf{W}, \mathbf{b}\) are the model’s parameters. These parameters are available through model.vars()

# Settings
lr = 0.0001  # learning rate
batch = 256
epochs = 20

model = objax.nn.Linear(ndim, 1)
(Linear).b                  1 (1,)
(Linear).w              30000 (30000, 1)
+Total(2)               30001

Model Inference

Now that we have defined the model we can use to classify images. To do so, we call the model with an image from the train dataset we previously prepared. Notice that we use the image of a human we previously visualized.

We get the output of the model by calling model(). We then apply the sigmoid activation function and round the output. Activation outputs lower than or equal to 0.5 are rounded to zero (i.e., horses) whereas outputs larger than 0.5 are rounded to one (i.e., humans).

# This is an image of a human.

Considering that we initialized the model with random weights, it should not come as a surprise that the model may misclassify a human as a horse.

Optimizer and Loss Function

In this example we use the objax.optimizer.SGD optimizer. Next, we define the loss function we will use to optimize the network. In this case we use the cross entropy loss function. Note that we use objax.functional.loss.sigmoid_cross_entropy_logits because we perform binary classification.

opt = objax.optimizer.SGD(model.vars())

# Cross Entropy Loss
def loss(x, label):
  return objax.functional.loss.sigmoid_cross_entropy_logits(model(x)[:, 0], label).mean()

Back Propagation and Gradient Descent

objax.GradValues calculates the gradient of loss wrt model.vars(). If you want to learn more about gradients read the Understanding Gradients in-depth topic.

The train_op function implements the core of backward propagation and gradient descent. First, we calculate the gradient g and then pass it to the optimizer which updates the model’s weights.

gv = objax.GradValues(loss, model.vars())

def train_op(x, label):
  g, v = gv(x, label)  # returns gradients, loss
  opt(lr, g)
  return v

# This line is optional: it is compiling the code to make it faster.
train_op = objax.Jit(train_op, gv.vars() + opt.vars())

Training and Evaluation Loop

For each of the training epochs we process all the training data, contained in the train dictionary, in batches of batch size. At the end of each epoch we compute the classification accuracy by comparing the model’s predictions over the test data to the ground truth labels.

for epoch in range(epochs):
  # Train
  avg_loss = 0
  # randomly shuffle training data
  shuffle_idx = np.random.permutation(train.image.shape[0])
  for it in range(0, train.image.shape[0], batch):
    sel = shuffle_idx[it: it + batch]
    avg_loss += float(train_op(train.image[sel], train.label[sel])[0]) * len(sel)
  avg_loss /= it + batch

  # Eval
  accuracy = 0
  for it in range(0, test.image.shape[0], batch):
    x, y = test.image[it: it + batch], test.label[it: it + batch]
    accuracy += (np.round(objax.functional.sigmoid(model(x)))[:, 0] == y).sum()
  accuracy /= test.image.shape[0]
  print('Epoch %04d  Loss %.2f  Accuracy %.2f' % (epoch + 1, avg_loss, 100 * accuracy))
Epoch 0001  Loss 0.25  Accuracy 82.81
Epoch 0002  Loss 0.25  Accuracy 82.81
Epoch 0003  Loss 0.25  Accuracy 83.59
Epoch 0004  Loss 0.25  Accuracy 82.81
Epoch 0005  Loss 0.25  Accuracy 82.03
Epoch 0006  Loss 0.25  Accuracy 80.86
Epoch 0007  Loss 0.25  Accuracy 80.86
Epoch 0008  Loss 0.24  Accuracy 80.86
Epoch 0009  Loss 0.24  Accuracy 82.03
Epoch 0010  Loss 0.24  Accuracy 83.20
Epoch 0011  Loss 0.24  Accuracy 82.03
Epoch 0012  Loss 0.24  Accuracy 80.86
Epoch 0013  Loss 0.24  Accuracy 81.25
Epoch 0014  Loss 0.24  Accuracy 83.59
Epoch 0015  Loss 0.24  Accuracy 84.38
Epoch 0016  Loss 0.24  Accuracy 85.16
Epoch 0017  Loss 0.24  Accuracy 84.77
Epoch 0018  Loss 0.24  Accuracy 83.98
Epoch 0019  Loss 0.24  Accuracy 83.20
Epoch 0020  Loss 0.24  Accuracy 83.59

Model Inference After Training

Now that the network is trained we can retry classification example above:


Next Steps

We saw how we can define, use, and train a very simple model in Objax to classify images of humans and horses. Next, we will learn how to create custom models to classify handwritten digits.