Installation and Setup

For developing or contributing to Objax, see Development setup.

User installation

Install using pip with the following command:

pip install --upgrade objax

For GPU support, we assume you already have a version of CUDA installed (jaxlib releases require CUDA 11.2 or newer). Here are the extra steps:

# RELEASE_URL must be set to the jaxlib CUDA releases index before running the last command.
JAX_VERSION=$(python3 -c 'import jax; print(jax.__version__)')
pip uninstall -y jaxlib
pip install -f "$RELEASE_URL" "jax[cuda]==$JAX_VERSION"
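
After these steps it can help to confirm which package versions ended up installed, since jax and jaxlib are installed separately. The following is a minimal sketch using only the Python standard library; `installed_version` is a hypothetical helper written for this illustration, not part of Objax or JAX.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version of `package`, or None if it is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


# Report what pip installed; None means the package was not found.
for name in ("objax", "jax", "jaxlib"):
    print(name, installed_version(name))
```

If jaxlib shows up as None after the uninstall step, the final pip install command probably did not complete.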

Useful shell configurations

Here are a few useful options:

# Prevent JAX from taking the whole GPU memory
# (useful if you want to run several programs on a single GPU)
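
As far as we know, the behavior described in the comments above is controlled in JAX by the `XLA_PYTHON_CLIENT_PREALLOCATE` environment variable. That this is the option meant here is an assumption. As a sketch, the variable can also be set from Python, provided this happens before JAX initializes its GPU backend. Setting it before the first `import jax` is the safest choice.

```python
import os

# Assumption: XLA_PYTHON_CLIENT_PREALLOCATE is the setting meant above.
# Set it before jax is imported so the GPU backend picks it up.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# import jax  # import jax (and objax) only after the variable is set
```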

Testing your installation

You can run the code below to test your installation:

import jax
import objax

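# device_count() counts devices on the default backend, so a CPU-only install typically reports 1.
print(f'Number of devices {jax.device_count()}')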

x = objax.random.normal((100, 4))
m = objax.nn.Linear(4, 5)
print('Matrix product shape', m(x).shape)  # (100, 5)

x = objax.random.normal((100, 3, 32, 32))
m = objax.nn.Conv2D(3, 4, k=3)
print('Conv2D return shape', m(x).shape)  # (100, 4, 32, 32)

If you get errors running this with CUDA, your CUDA or cuDNN installation is likely misconfigured.
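
To narrow down such problems, it may help to check which backend JAX actually selected. The sketch below uses `jax.devices()` and `jax.default_backend()`; `describe_backend` is a hypothetical helper for illustration and returns a message rather than raising, so it can be run even on a broken install.

```python
def describe_backend():
    """Summarize the JAX backend in use, or explain why it is unavailable."""
    try:
        import jax
    except ImportError as exc:
        return f"jax is not importable: {exc}"
    try:
        devices = jax.devices()
    except RuntimeError as exc:
        # Raised when JAX cannot initialize any backend.
        return f"jax could not initialize a backend: {exc}"
    platforms = [d.platform for d in devices]
    return f"backend={jax.default_backend()} devices={platforms}"


print(describe_backend())
```

A reported backend of `cpu` on a machine with a GPU suggests that the CUDA-enabled jaxlib was not installed or could not load its CUDA libraries.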

Installing examples

Clone the code repository:

git clone
cd objax/examples