Welcome to Objax’s documentation!
Objax is an open source machine learning framework that accelerates research and learning thanks to a minimalist object-oriented design and a readable code base. Its name is a contraction of Object and JAX, a popular high-performance framework. Objax is designed by researchers for researchers, with a focus on simplicity and understandability. Its users should be able to easily read, understand, extend, and modify it to fit their needs.
'Hello world': optimizing the weights of a classifier net through gradient descent:

```python
import objax
from objax.functional.loss import cross_entropy_logits

# Assumes net (an Objax classifier module) and lr (a learning rate) are already defined.
opt = objax.optimizer.Adam(net.vars())

def loss(x, y):
    logits = net(x)  # Output of classifier on x
    xe = cross_entropy_logits(logits, y)
    return xe.mean()

# Perform gradient descent w.r.t. net weights
gv = objax.GradValues(loss, net.vars())

def train_op(x, y):
    g, v = gv(x, y)  # returns gradients g and loss v
    opt(lr, g)  # update weights
    return v

train_op = objax.Jit(train_op, net.vars() + opt.vars())
```
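The loss function in the hello-world snippet averages a softmax cross-entropy over the batch. The plain-Python sketch below spells out that quantity; it is not Objax's code, and it assumes that each label y is a one-hot (or probability) vector. Under that assumption the per-example loss is the log-sum-exp of the logits minus the logit weight placed on the label:

```python
import math

def cross_entropy_from_logits(logits, labels):
    """Softmax cross-entropy for one example: logsumexp(z) - sum_k y_k * z_k."""
    m = max(logits)  # subtract the max so exp() cannot overflow
    logsumexp = m + math.log(sum(math.exp(z - m) for z in logits))
    return logsumexp - sum(y * z for y, z in zip(labels, logits))

def mean_cross_entropy(batch_logits, batch_labels):
    """Average the per-example losses over a batch (the role of xe.mean())."""
    losses = [cross_entropy_from_logits(z, y) for z, y in zip(batch_logits, batch_labels)]
    return sum(losses) / len(losses)

# Two classes with equal logits: the loss is log(2) whichever label is correct.
print(mean_cross_entropy([[0.0, 0.0]], [[1.0, 0.0]]))
```

Subtracting the largest logit before exponentiating leaves the result unchanged mathematically but keeps very large logits, such as 1000, from overflowing.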
Objax pursues the quest for the simplest design and code that’s as easy as possible to extend without sacrificing performance.
– Objax Devs
Researchers and students look at machine learning frameworks in their own way. Often they read the code of some technique, say an Adam optimizer, to understand how it works so they can extend it or design a new optimizer. This is how machine learning frameworks differ from standard libraries: many users look not only at the APIs but also at the code behind them.
Coded for simplicity
Source code should be understandable by everyone, including users without a background in computer science. So how simple is it really? Judge for yourself with this tutorial: Logistic Regression.
It is common in machine learning to separate the inputs (\(X\)) from the parameters (\(\theta\)) of a function \(f(X; \theta)\). Math notation captures this difference with a semicolon, which separates the first group of arguments (the inputs) from the second (the parameters).
Objax represents this semantic distinction through its modules:
- the module’s parameters \(\theta\) are attributes of the module, accessed through self;
- inputs \(X\) are method arguments, such as x and y in def __call__(self, x, y, ...).
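To make the \(f(X; \theta)\) split concrete, here is a plain-Python sketch that deliberately does not use Objax; the class name LogisticRegression and the attribute names w and b are illustrative choices. The parameters \(\theta = (w, b)\) live on self, while the input \(X\) arrives as an argument to __call__:

```python
import math

class LogisticRegression:
    """Illustrative model: parameters are attributes, inputs are call arguments."""

    def __init__(self, w, b):
        self.w = w  # part of theta: one weight per input feature
        self.b = b  # part of theta: scalar bias

    def __call__(self, x):
        # x is the input X; the parameters are read from self, never passed in.
        z = sum(wi * xi for wi, xi in zip(self.w, x)) + self.b
        return 1.0 / (1.0 + math.exp(-z))  # sigmoid(w . x + b)

model = LogisticRegression(w=[1.0, 1.0], b=-3.0)
print(model([1.0, 2.0]))  # w . x + b = 0, so this prints 0.5
```

Evaluating the model on new data only changes the call arguments, while changing \(\theta\) means updating the object's attributes: the same division the semicolon expresses.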
Designed for flexibility
Objax minimizes the number of abstractions users need to understand. There are two main ones: Modules and Variables. Everything is built out of these two basic classes. You can read more about this in Variables and Modules.
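A toy version of a two-class design like this can be sketched in plain Python. The Variable, Module, Linear, and TwoLayer classes below are hypothetical stand-ins written for illustration, not Objax's implementation (Objax's real classes and its vars() method differ in detail). The sketch only shows how a module can discover every variable it owns, including those of nested sub-modules, by inspecting its own attributes:

```python
class Variable:
    """Hypothetical stand-in for a variable: a mutable box around a value."""

    def __init__(self, value):
        self.value = value


class Module:
    """Hypothetical base class: finds Variables among its attributes."""

    def vars(self, prefix=""):
        found = {}
        for name, attr in self.__dict__.items():
            if isinstance(attr, Variable):
                found[prefix + name] = attr
            elif isinstance(attr, Module):
                # Recurse into sub-modules, extending the name prefix.
                found.update(attr.vars(prefix + name + "."))
        return found


class Linear(Module):
    def __init__(self, n_in, n_out):
        self.w = Variable([[0.0] * n_out for _ in range(n_in)])  # n_in x n_out weights
        self.b = Variable([0.0] * n_out)  # one bias per output


class TwoLayer(Module):
    def __init__(self):
        self.l1 = Linear(4, 8)
        self.l2 = Linear(8, 2)


net = TwoLayer()
print(sorted(net.vars()))  # ['l1.b', 'l1.w', 'l2.b', 'l2.w']
```

In this sketch, variables are discovered by walking attributes rather than registered by hand, which is the kind of property that lets a single call like net.vars() hand every parameter to an optimizer, as in the hello-world snippet.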
Engineered for performance
In machine learning, performance is essential. Every second counts. Objax makes it count by building on JAX, whose computations are compiled by XLA, the same compiler used by TensorFlow. Read more about this in Compilation and Parallelism.
- Variables and Modules
- Understanding Gradients
- Compilation and Parallelism
- Loading and Saving