Welcome to Objax’s documentation!

Objax is an open-source machine learning framework that accelerates research and learning thanks to a minimalist object-oriented design and a readable code base. Its name is a contraction of Object and JAX, a popular high-performance framework. Objax is designed by researchers for researchers, with a focus on simplicity and understandability. Its users should be able to easily read, understand, extend, and modify it to fit their needs.

Try the 5-minute tutorial.

Machine learning’s 'Hello world': optimizing the weights of a classifier net through gradient descent:

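import objax
from objax.functional.loss import cross_entropy_logits

# Hypothetical setup so the snippet runs: a linear classifier from
# 4 input features to 3 classes, and a learning rate.
net = objax.nn.Linear(4, 3)
lr = 1e-3

opt = objax.optimizer.Adam(net.vars())

def loss(x, y):
    logits = net(x)  # Output of classifier on x
    xe = cross_entropy_logits(logits, y)  # y holds one-hot labels
    return xe.mean()

# Compute gradients of the loss with respect to the net weights
gv = objax.GradValues(loss, net.vars())

def train_op(x, y):
    g, v = gv(x, y)  # returns gradients g and loss v (as a list)
    opt(lr, g)  # update weights
    return v

# Compile train_op; it reads and updates the net and optimizer variables
train_op = objax.Jit(train_op, net.vars() + opt.vars())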

Objax philosophy

Objax pursues the quest for the simplest design and code that’s as easy as possible to extend without sacrificing performance.

– Objax Devs


Researchers and students look at machine learning frameworks in their own way. They often read the code of a technique, say an Adam optimizer, to understand how it works so they can extend it or design a new optimizer. This is how machine learning frameworks differ from standard libraries: a large class of users look not only at the APIs but also at the code behind them.

Coded for simplicity

Source code should be understandable by everyone, including users without a background in computer science. So how simple is it really? Judge for yourself with this tutorial: Logistic Regression.


It is common in machine learning to separate the inputs (\(X\)) from the parameters (\(\theta\)) of a function \(f(X; \theta)\). Math notation captures this difference with a semicolon, which semantically separates the first group of arguments from the second.
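For instance, logistic regression on a batch of inputs \(X\) fits this pattern, with a weight vector \(w\) and a bias \(b\) playing the role of \(\theta\) (an illustrative model chosen here):

\[ f(X; \theta) = \sigma(Xw + b), \qquad \theta = (w, b), \qquad \sigma(z) = \frac{1}{1 + e^{-z}} \]

Training differentiates a loss built on \(f\) with respect to \(\theta\) only; \(X\) is held fixed as data.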

Objax represents this semantic distinction through objax.Module:

  • the module’s parameters \(\theta\) are attributes of the form self.w, ...
  • inputs \(X\) are method arguments such as def __call__(self, x, y, ...):

Designed for flexibility

Objax minimizes the number of abstractions users need to understand. There are two main ones: Modules and Variables. Everything is built out of these two basic classes. You can read more about this in Variables and Modules.

Indices and tables