Code Examples

This section describes the code examples found in objax/examples.

Classification

Image

Example code is available at examples/classify.

Logistic Regression

Train and evaluate a logistic regression model for binary classification on the horses_or_humans dataset.

# Run command
python3 examples/classify/img/logistic.py

Code

examples/classify/img/logistic.py

Data

horses_or_humans from tensorflow_datasets

Network

Custom single layer

Loss

objax.functional.loss.sigmoid_cross_entropy_logits()

Optimizer

objax.optimizer.SGD

Accuracy

~77%

Hardware

CPU or GPU or TPU
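
To make the Loss and Optimizer fields concrete, here is a minimal pure-Python sketch (no Objax or JAX) of binary logistic regression trained with the numerically stable sigmoid cross-entropy on logits and plain SGD. The toy 1-D data and all function names are hypothetical; logistic.py itself uses a custom single Objax layer on horses_or_humans images.

```python
import math


def sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def sigmoid_cross_entropy(logit, label):
    # Stable form of -[y*log(sigmoid(z)) + (1 - y)*log(1 - sigmoid(z))].
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def mean_loss(w, b, xs, ys):
    logits = [sum(wi * xi for wi, xi in zip(w, x)) + b for x in xs]
    return sum(sigmoid_cross_entropy(z, y) for z, y in zip(logits, ys)) / len(xs)


def sgd_step(w, b, xs, ys, lr):
    # d(loss)/d(logit) = sigmoid(logit) - label, averaged over the batch.
    n = len(xs)
    grad_w = [0.0] * len(w)
    grad_b = 0.0
    for x, y in zip(xs, ys):
        err = sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b) - y
        for i, xi in enumerate(x):
            grad_w[i] += err * xi / n
        grad_b += err / n
    # Plain SGD update: parameter -= learning_rate * gradient.
    return [wi - lr * gi for wi, gi in zip(w, grad_w)], b - lr * grad_b


# Toy 1-D data: inputs below 1.5 are class 0, inputs above are class 1.
xs = [[0.0], [1.0], [2.0], [3.0]]
ys = [0.0, 0.0, 1.0, 1.0]
w, b = [0.0], 0.0
initial_loss = mean_loss(w, b, xs, ys)
for _ in range(500):
    w, b = sgd_step(w, b, xs, ys, lr=0.5)
final_loss = mean_loss(w, b, xs, ys)
```

The stable loss form avoids computing log(sigmoid(z)) directly, which underflows for large negative logits.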

Digit Classification with Deep Neural Network (DNN)

Train and evaluate a DNNet model for multiclass classification on the MNIST dataset.

# Run command
python3 examples/classify/img/mnist_dnn.py

Code

examples/classify/img/mnist_dnn.py

Data

MNIST from tensorflow_datasets

Network

Deep neural network using objax.zoo.DNNet

Loss

objax.functional.loss.cross_entropy_logits()

Optimizer

objax.optimizer.Adam

Accuracy

~98%

Hardware

CPU or GPU or TPU

Techniques

Model weight averaging for improved accuracy using objax.optimizer.ExponentialMovingAverage.
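
objax.optimizer.ExponentialMovingAverage maintains this kind of running average over model variables. The sketch below shows only the underlying update rule in plain Python, using a hypothetical WeightEMA class rather than the Objax API; the averaged weights are then typically the ones used for evaluation.

```python
class WeightEMA:
    """Hypothetical helper: keeps an exponential moving average of weights."""

    def __init__(self, weights, momentum=0.999):
        self.momentum = momentum
        # Start the average from a copy of the current weights.
        self.average = list(weights)

    def update(self, weights):
        # average <- momentum * average + (1 - momentum) * weights
        m = self.momentum
        self.average = [m * a + (1.0 - m) * w for a, w in zip(self.average, weights)]
        return self.average


ema = WeightEMA([0.0, 0.0], momentum=0.5)
ema.update([1.0, 2.0])  # average is now [0.5, 1.0]
ema.update([1.0, 2.0])  # average is now [0.75, 1.5]
```

A momentum close to 1 (such as the 0.999 default assumed here) averages over many recent training steps, which smooths out noise from individual updates.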

Digit Classification with Convolutional Neural Network (CNN)

Train and evaluate a simple custom CNN model for multiclass classification on the MNIST dataset.

# Run command
python3 examples/classify/img/mnist_cnn.py

Code

examples/classify/img/mnist_cnn.py

Data

MNIST from tensorflow_datasets

Network

Custom convolutional neural network using objax.nn.Sequential

Loss

objax.functional.loss.cross_entropy_logits_sparse()

Optimizer

objax.optimizer.Adam

Accuracy

~99.5%

Hardware

CPU or GPU or TPU
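
The MNIST examples use two losses: objax.functional.loss.cross_entropy_logits() takes one-hot (or probability) label vectors, while objax.functional.loss.cross_entropy_logits_sparse() takes integer class indices. The pure-Python sketch below (hypothetical function names, not the Objax implementation) shows that both compute the same value for a one-hot label.

```python
import math


def log_softmax(logits):
    # Subtract the max logit before exponentiating for numerical stability.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_z for z in logits]


def cross_entropy_dense(logits, label_vector):
    # Labels given as a probability (e.g. one-hot) vector over classes.
    return -sum(p * lp for p, lp in zip(label_vector, log_softmax(logits)))


def cross_entropy_sparse(logits, label_index):
    # Labels given as a single integer class index.
    return -log_softmax(logits)[label_index]


logits = [2.0, 0.5, -1.0]
dense = cross_entropy_dense(logits, [0.0, 1.0, 0.0])
sparse = cross_entropy_sparse(logits, 1)
```

The sparse form avoids building one-hot vectors, which is convenient when labels are stored as integers.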

Digit Classification using Differential Privacy

Train and evaluate a ConvNet model on the MNIST dataset with differential privacy.

# Run command
python3 examples/classify/img/mnist_dp.py
# See available options with
python3 examples/classify/img/mnist_dp.py --help

Code

examples/classify/img/mnist_dp.py

Data

MNIST from tensorflow_datasets

Network

Custom convolutional neural network using objax.nn.Sequential

Loss

objax.functional.loss.cross_entropy_logits()

Optimizer

objax.optimizer.SGD

Hardware

GPU
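
The entry above does not spell out the privacy mechanism. A widely used recipe is DP-SGD: clip each per-example gradient to a maximum L2 norm, sum, add Gaussian noise scaled to that norm, and average. The plain-Python sketch below illustrates that general recipe only, under the assumption that it is representative of mnist_dp.py; the function names and parameter values are hypothetical, and the script's real options are listed by its --help flag.

```python
import math
import random


def clip(grad, max_norm):
    # Scale a per-example gradient so its L2 norm is at most max_norm.
    norm = math.sqrt(sum(g * g for g in grad))
    scale = min(1.0, max_norm / norm) if norm > 0 else 1.0
    return [g * scale for g in grad]


def private_gradient(per_example_grads, max_norm, noise_multiplier, rng):
    """Hypothetical DP-SGD style aggregation: clip, sum, add noise, average."""
    dim = len(per_example_grads[0])
    total = [0.0] * dim
    for grad in per_example_grads:
        for i, g in enumerate(clip(grad, max_norm)):
            total[i] += g
    # Noise std is proportional to the clipping bound (the per-example sensitivity).
    sigma = noise_multiplier * max_norm
    noisy = [t + rng.gauss(0.0, sigma) for t in total]
    return [t / len(per_example_grads) for t in noisy]


rng = random.Random(0)
grads = [[3.0, 4.0], [0.3, 0.4]]
update = private_gradient(grads, max_norm=1.0, noise_multiplier=1.1, rng=rng)
```

Clipping bounds how much any single example can move the update; the added noise then masks that bounded contribution.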

Image Classification on CIFAR-10 (Simple)

Train and evaluate a Wide ResNet model for multiclass classification on the CIFAR10 dataset.

# Run command
python3 examples/classify/img/cifar10_simple.py

Code

examples/classify/img/cifar10_simple.py

Data

CIFAR10 from tf.keras.datasets

Network

Wide ResNet using objax.zoo.wide_resnet.WideResNet

Loss

objax.functional.loss.cross_entropy_logits_sparse()

Optimizer

objax.optimizer.Momentum

Accuracy

~91%

Hardware

GPU or TPU

Techniques

  • Learning rate schedule.

  • Data augmentation (mirror / pixel shifts) in NumPy.

  • Regularization using extra weight decay term in loss.
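
The mirror / pixel-shift augmentation can be sketched as follows. The exact padding mode and shift range used by cifar10_simple.py are not stated here; this sketch assumes reflection padding and works on a single-channel image stored as nested lists (the real script works on NumPy arrays with color channels). All function names are hypothetical.

```python
import random


def mirror(image):
    # Horizontal flip: reverse every row of an H x W image (list of rows).
    return [row[::-1] for row in image]


def reflect_index(i, n):
    # Map an index in [-(n - 1), 2n - 2] back into [0, n - 1] by reflection.
    if i < 0:
        return -i
    if i >= n:
        return 2 * (n - 1) - i
    return i


def shift(image, dy, dx):
    # Translate by (dy, dx) pixels, filling borders by reflection padding.
    h, w = len(image), len(image[0])
    return [[image[reflect_index(y + dy, h)][reflect_index(x + dx, w)] for x in range(w)]
            for y in range(h)]


def augment(image, max_shift, rng):
    """Hypothetical augmentation: random mirror plus random pixel shift."""
    if rng.random() < 0.5:
        image = mirror(image)
    dy = rng.randint(-max_shift, max_shift)
    dx = rng.randint(-max_shift, max_shift)
    return shift(image, dy, dx)


image = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
augmented = augment(image, max_shift=1, rng=random.Random(0))
```

Both transformations keep the image size unchanged, so augmented batches can be fed to the network without reshaping.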

Image Classification on CIFAR-10 (Advanced)

Train and evaluate ConvNet models for multiclass classification on the CIFAR10 dataset.

# Run command
python3 examples/classify/img/cifar10_advanced.py
# Run with custom settings
python3 examples/classify/img/cifar10_advanced.py --weight_decay=0.0001 --batch=64 --lr=0.03 --epochs=256
# See available options with
python3 examples/classify/img/cifar10_advanced.py --help

Code

examples/classify/img/cifar10_advanced.py

Data

CIFAR10 from tensorflow_datasets

Network

Configurable with --arch="network":

  • wrn28-1, wrn28-2 using objax.zoo.wide_resnet.WideResNet

  • cnn32-3-max, cnn32-3-mean, cnn64-3-max, cnn64-3-mean using objax.zoo.convnet.ConvNet

Loss

objax.functional.loss.cross_entropy_logits()

Optimizer

objax.optimizer.Momentum

Accuracy

~94%

Hardware

GPU, Multi-GPU or TPU

Techniques

  • Model weight averaging for improved accuracy using objax.optimizer.ExponentialMovingAverage.

  • Parallelized on multiple GPUs using objax.Parallel.

  • Data augmentation (mirror / pixel shifts) in TensorFlow.

  • Cosine learning rate decay.

  • Regularization using extra weight decay term in loss.

  • Checkpointing using objax.io.Checkpoint, with automatic resumption from the latest checkpoint if training is interrupted.

  • Saving of TensorBoard visualization files using objax.jaxboard.SummaryWriter.

  • Multi-loss reporting (cross-entropy, L2).

  • Reusable training loop example.
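
Two of the techniques above, cosine learning rate decay and the extra weight decay term reported alongside cross-entropy, can be sketched in plain Python. The exact decay curve and the scaling of the L2 term in cifar10_advanced.py may differ; this sketch assumes a standard half-cosine schedule and a penalty of 0.5 times the sum of squared weights. The helper names are hypothetical; the values 0.03 and 0.0001 are taken from the example run command.

```python
import math


def cosine_lr(base_lr, step, total_steps):
    # Half-cosine decay from base_lr at step 0 down to 0 at total_steps.
    progress = min(step / total_steps, 1.0)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


def total_loss(xe_loss, weights, weight_decay):
    """Hypothetical multi-loss helper: cross-entropy plus an L2 penalty.

    Returns the loss to optimize plus the individual terms for reporting.
    """
    l2 = 0.5 * sum(w * w for w in weights)
    loss = xe_loss + weight_decay * l2
    return loss, {"losses/xe": xe_loss, "losses/l2": l2}


lrs = [cosine_lr(0.03, s, 100) for s in (0, 50, 100)]
loss, report = total_loss(xe_loss=1.2, weights=[1.0, -2.0], weight_decay=0.0001)
```

Returning the individual terms separately makes it easy to log cross-entropy and L2 as distinct curves while optimizing only their weighted sum.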

Image Classification on ImageNet

Train and evaluate a ResNet50 model on the ImageNet dataset. See README for additional information.

Code

examples/classify/img/imagenet/imagenet_train.py

Data

ImageNet from tensorflow_datasets

Network

ResNet50

Loss

objax.functional.loss.cross_entropy_logits_sparse()

Optimizer

objax.optimizer.Momentum

Hardware

GPU, Multi-GPU or TPU

Techniques

  • Parallelized on multiple GPUs using objax.Parallel.

  • Data augmentation (distorted bounding box crop) in TensorFlow.

  • Linear warmup followed by multi-step learning rate decay.

  • Regularization using extra weight decay term in loss.

  • Checkpointing using objax.io.Checkpoint, with automatic resumption from the latest checkpoint if training is interrupted.

  • Saving of TensorBoard visualization files using objax.jaxboard.SummaryWriter.
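
A schedule of linear warmup followed by multi-step decay can be sketched as a plain function of the epoch. The warmup length, decay factor and boundaries below (5 epochs, 0.1, and epochs 30/60/80) are illustrative assumptions, not values taken from imagenet_train.py; the function name is hypothetical.

```python
def warmup_multistep_lr(base_lr, epoch, warmup_epochs, boundaries, decay=0.1):
    """Hypothetical schedule: linear warmup, then step decay at each boundary."""
    if epoch < warmup_epochs:
        # Ramp linearly from 0 to base_lr over the warmup period.
        return base_lr * epoch / warmup_epochs
    # Multiply by `decay` once for every boundary already passed.
    passed = sum(1 for b in boundaries if epoch >= b)
    return base_lr * decay ** passed


schedule = [warmup_multistep_lr(0.1, e, warmup_epochs=5, boundaries=(30, 60, 80))
            for e in (0, 2.5, 5, 30, 60, 85)]
```

Fractional epochs are accepted so the same function can be evaluated per training step during warmup.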

Image Classification using Pretrained VGG Network

Image classification using an ImageNet-pretrained VGG19 model. See README for additional information.

Code

examples/classify/img/pretrained_vgg.py

Techniques

Load a VGG19 model with pretrained weights and run 1000-way image classification.
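
Turning 1000-way logits into a prediction usually means applying a softmax and reporting the most likely classes. The plain-Python sketch below shows that post-processing with hypothetical logits; mapping indices to ImageNet class names is left out, and whether pretrained_vgg.py reports a particular top-k is not stated here.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(logits, k=5):
    # Return (class_index, probability) pairs for the k most likely classes.
    probs = softmax(logits)
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(i, probs[i]) for i in ranked[:k]]


# Hypothetical 1000-way logits, e.g. the output of an ImageNet classifier.
logits = [0.0] * 1000
logits[207] = 5.0
logits[208] = 3.0
predictions = top_k(logits, k=2)
```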

Semi-Supervised Learning

Example code is available at examples/semi_supervised.

Semi-Supervised Learning with FixMatch

Semi-supervised learning of image classification models with FixMatch.

# Run command
python3 examples/classify/semi_supervised/img/fixmatch.py
# Run with custom settings
python3 examples/classify/semi_supervised/img/fixmatch.py --dataset=cifar10.3@1000-0
# See available options with
python3 examples/classify/semi_supervised/img/fixmatch.py --help

Code

examples/classify/semi_supervised/img/fixmatch.py

Data

CIFAR10, CIFAR100, SVHN, STL10

Network

Custom implementation of Wide ResNet.

Loss

objax.functional.loss.cross_entropy_logits() and objax.functional.loss.cross_entropy_logits_sparse()

Optimizer

objax.optimizer.Momentum

Accuracy

See the FixMatch paper.

Hardware

GPU, Multi-GPU, TPU
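
FixMatch's central idea is to use confident predictions on a weakly augmented unlabeled image as hard pseudo-labels for a strongly augmented view of the same image. The plain-Python sketch below shows that unlabeled loss with the paper's default confidence threshold of 0.95. It is not the code in fixmatch.py: it omits the supervised loss on labeled data and the augmentations themselves, and the function name is hypothetical.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def fixmatch_unlabeled_loss(weak_logits, strong_logits, threshold=0.95):
    """Sketch of FixMatch's unlabeled loss for a batch of examples.

    weak_logits / strong_logits: per-example logits for a weakly and a strongly
    augmented view of the same unlabeled image.
    """
    total = 0.0
    for weak, strong in zip(weak_logits, strong_logits):
        probs = softmax(weak)
        pseudo_label = max(range(len(probs)), key=probs.__getitem__)
        if probs[pseudo_label] >= threshold:
            # Sparse cross-entropy of the strong view against the hard pseudo-label.
            total += -math.log(softmax(strong)[pseudo_label])
    # Average over all examples; low-confidence ones contribute zero (masked out).
    return total / len(weak_logits)


confident = [10.0, 0.0, 0.0]  # max softmax probability is about 0.9999
unsure = [0.1, 0.0, 0.0]      # max softmax probability is about 0.36
loss = fixmatch_unlabeled_loss([confident, unsure], [[0.0, 0.0, 0.0], [5.0, 0.0, 0.0]])
```

Here only the confident example contributes: its strong view is uniform over three classes, so the batch loss is log(3) / 2.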

GPT-2

Example code is available at examples/gpt-2.

Generating a Text Sequence using GPT-2

Load a pretrained GPT-2 model (124M parameters) and demonstrate how to use it to generate a text sequence. See README for additional information.

Code

examples/gpt-2/gpt2.py

Hardware

GPU or TPU

Techniques

  • Define Transformer model.

  • Load GPT-2 model with pretrained weights and generate a sequence.
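
Generating a sequence with a language model is an autoregressive loop: run the model on the tokens so far, pick the next token, append it, and repeat. The sketch below uses greedy selection and a toy stand-in model (next_token_logits and toy_model are hypothetical); gpt2.py may use a different selection rule such as sampling, and its model is the pretrained Transformer rather than this toy function.

```python
def generate(next_token_logits, prompt, num_tokens):
    """Greedy autoregressive decoding.

    next_token_logits: hypothetical callable mapping the token sequence so far
    to a list of logits over the vocabulary (stands in for the GPT-2 forward pass).
    """
    tokens = list(prompt)
    for _ in range(num_tokens):
        logits = next_token_logits(tokens)
        # Greedy choice: append the highest-scoring token and feed it back in.
        tokens.append(max(range(len(logits)), key=logits.__getitem__))
    return tokens


VOCAB_SIZE = 10


def toy_model(tokens):
    # Toy stand-in "model" that always prefers (last token + 1) mod VOCAB_SIZE.
    logits = [0.0] * VOCAB_SIZE
    logits[(tokens[-1] + 1) % VOCAB_SIZE] = 1.0
    return logits


sequence = generate(toy_model, prompt=[7], num_tokens=4)  # [7, 8, 9, 0, 1]
```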

RNN

Example code is available at examples/rnn.

Train a Vanilla RNN to Predict Characters

Train and evaluate a vanilla RNN model on the Shakespeare corpus dataset. See README for additional information.

# Run command
python3 examples/rnn/shakespeare.py

Code

examples/rnn/shakespeare.py

Data

Shakespeare corpus from tensorflow_datasets

Network

Custom implementation of vanilla RNN.

Loss

objax.functional.loss.cross_entropy_logits()

Optimizer

objax.optimizer.Adam

Hardware

GPU or TPU
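
A vanilla RNN updates a hidden state h_t = tanh(W_xh x_t + W_hh h_{t-1} + b_h) for each input character and produces per-step logits W_hy h_t + b_y over the vocabulary; in character-level training these logits are typically compared with the next character using cross-entropy. The plain-Python sketch below shows only the forward pass with small hypothetical weights; it is not the implementation in shakespeare.py.

```python
import math


def rnn_step(x, h, w_xh, w_hh, b_h):
    # h_t = tanh(W_xh x_t + W_hh h_{t-1} + b_h), with matrices as lists of rows.
    return [math.tanh(sum(w * xi for w, xi in zip(w_xh[j], x))
                      + sum(w * hi for w, hi in zip(w_hh[j], h)) + b_h[j])
            for j in range(len(h))]


def rnn_logits(h, w_hy, b_y):
    # Per-step output logits over the character vocabulary: W_hy h_t + b_y.
    return [sum(w * hi for w, hi in zip(row, h)) + b for row, b in zip(w_hy, b_y)]


def run_rnn(one_hot_inputs, params, hidden_size):
    h = [0.0] * hidden_size
    outputs = []
    for x in one_hot_inputs:
        h = rnn_step(x, h, params["w_xh"], params["w_hh"], params["b_h"])
        outputs.append(rnn_logits(h, params["w_hy"], params["b_y"]))
    return outputs, h


vocab, hidden = 3, 2
params = {
    "w_xh": [[0.5, -0.5, 0.0], [0.1, 0.2, 0.3]],  # hidden x vocab
    "w_hh": [[0.9, 0.0], [0.0, 0.9]],             # hidden x hidden
    "b_h": [0.0, 0.0],
    "w_hy": [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],  # vocab x hidden
    "b_y": [0.0, 0.0, 0.0],
}
inputs = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]  # two one-hot characters
logits, final_h = run_rnn(inputs, params, hidden)
```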

Optimization

Example code is available at examples/optimization.

Model Agnostic Meta-Learning (MAML)

An implementation of the MAML meta-learning method that demonstrates computing the gradient of a gradient.

# Run command
python3 examples/optimization/maml.py

Code

examples/optimization/maml.py

Data

Synthetic data

Network

3-layer DNNet

Hardware

CPU or GPU or TPU

Techniques

Gradient of a gradient.
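
In MAML, the meta-gradient is obtained by differentiating the outer loss through an inner gradient step, which brings in second derivatives of the inner loss. The plain-Python sketch below works that chain rule by hand for a one-parameter quadratic task (a hypothetical toy, not the script's synthetic data), checks it against a finite difference, and shows that the first-order approximation, which drops the second-derivative term, gives a different value.

```python
def inner_loss_grad(theta, a):
    # Gradient of the inner (task-training) loss L_in(theta) = 0.5 * (theta - a)**2.
    return theta - a


def adapted(theta, a, alpha):
    # One inner SGD step: theta' = theta - alpha * dL_in/dtheta.
    return theta - alpha * inner_loss_grad(theta, a)


def outer_loss(theta, a, b, alpha):
    # Outer (meta) loss evaluated at the adapted parameter: 0.5 * (theta' - b)**2.
    return 0.5 * (adapted(theta, a, alpha) - b) ** 2


def meta_gradient(theta, a, b, alpha):
    # Chain rule: dL_out/dtheta = (theta' - b) * dtheta'/dtheta, where
    # dtheta'/dtheta = 1 - alpha * d2L_in/dtheta2 = 1 - alpha (the Hessian is 1 here).
    # That second-derivative term is the "gradient of a gradient".
    return (adapted(theta, a, alpha) - b) * (1.0 - alpha)


def first_order_gradient(theta, a, b, alpha):
    # First-order approximation that ignores the second-derivative term.
    return adapted(theta, a, alpha) - b


theta, a, b, alpha = 2.0, 1.0, -1.0, 0.3
eps = 1e-6
numeric = (outer_loss(theta + eps, a, b, alpha)
           - outer_loss(theta - eps, a, b, alpha)) / (2 * eps)
```

With these values the exact meta-gradient is 2.7 * 0.7 = 1.89, while the first-order approximation gives 2.7.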

Jaxboard

Example code is available at examples/jaxboard.

How to Use Jaxboard

Sample usage of Jaxboard. See README for additional information.

# Run command
python3 examples/jaxboard/summary.py

Code

examples/jaxboard/summary.py

Hardware

CPU

Usages

  • summary scalar

  • summary text

  • summary image