Compilation and Parallelism¶
In this section we discuss code compilation and parallelism, typically used to accelerate performance. We’ll cover the following subtopics:
Just-In-Time (JIT) Compilation compiles code the first time it’s executed, with the goal of speeding up subsequent runs.
Parallelism runs operations on multiple devices (for example multiple GPUs).
Vectorization can be seen as batch-level parallelization, running an operation on a batch in parallel.
JIT Compilation¶
objax.Jit
is a Module that takes a module or a function and compiles it for faster performance.
As a simple starting example, let’s jit a module:
import objax
net = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu, objax.nn.Linear(3, 4)])
jit_net = objax.Jit(net)
x = objax.random.normal((100, 2))
print(((net(x) - jit_net(x)) ** 2).sum()) # 0.0
You can also jit a function, in this case you must decorate the function with the variables it uses:
@objax.Function.with_vars(net.vars())
def my_function(x):
    return objax.functional.softmax(net(x))
jit_func = objax.Jit(my_function)
print(((objax.functional.softmax(net(x)) - jit_func(x)) ** 2).sum())
In terms of performance, even this small example shows a significant gain in speed; the numbers vary depending on the hardware in your computer and on the code being jitted:
from time import time
t0 = time(); y = net(x); print(time() - t0) # 0.005...
t0 = time(); y = jit_net(x); print(time() - t0) # 0.001...
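Note that JAX dispatches computations asynchronously, so such naive timings can under-measure the real cost. A hedged sketch of a more careful measurement, blocking on the result (block_until_ready is a standard JAX array method) before reading the clock:
t0 = time(); y = jit_net(x); y.block_until_ready(); print(time() - t0)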
As mentioned earlier, jit_net is a module instance that shares its variables with the module net; we can verify it:
print(net.vars())
# (Sequential)[0](Linear).b 3 (3,)
# (Sequential)[0](Linear).w 6 (2, 3)
# (Sequential)[2](Linear).b 4 (4,)
# (Sequential)[2](Linear).w 12 (3, 4)
# +Total(4) 25
print(jit_net.vars())
# (Jit)(Sequential)[0](Linear).b 3 (3,)
# (Jit)(Sequential)[0](Linear).w 6 (2, 3)
# (Jit)(Sequential)[2](Linear).b 4 (4,)
# (Jit)(Sequential)[2](Linear).w 12 (3, 4)
# +Total(4) 25
# We can verify that jit_func also shares the same variables
print(jit_func.vars())
# (Jit){my_function}(Sequential)[0](Linear).b 3 (3,)
# (Jit){my_function}(Sequential)[0](Linear).w 6 (2, 3)
# (Jit){my_function}(Sequential)[2](Linear).b 4 (4,)
# (Jit){my_function}(Sequential)[2](Linear).w 12 (3, 4)
# +Total(4) 25
Note that we only verified that the variable names and sizes were the same (or almost the same, since the variables in Jit are prefixed with (Jit)).
Let’s now verify that the weights are indeed shared by modifying the weights:
net[-1].b.assign(net[-1].b.value + 1)
print(((net(x) - jit_net(x)) ** 2).sum()) # 0.0
# Both net(x) and jit_net(x) were affected in the same way by the change
# since the weights are shared.
# You can also inspect the values with print(net(x)) for more insight.
A realistic case: Fully jitted training step¶
Let’s write a classifier training op; this is very similar to the example shown in Loss optimization. We are going to define a model, an optimizer, a loss and a gradient:
import objax
m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu, objax.nn.Linear(3, 4)])
opt = objax.optimizer.Momentum(m.vars())
@objax.Function.with_vars(m.vars())
def loss(x, labels):
    logits = m(x)
    return objax.functional.loss.cross_entropy_logits(logits, labels).mean()
gradient_loss = objax.GradValues(loss, m.vars())
@objax.Function.with_vars(m.vars() + opt.vars())
def train(x, labels, lr):
    g, v = gradient_loss(x, labels)  # Compute gradients and loss
    opt(lr, g)                       # Apply SGD
    return v                         # Return loss value
# It's better to jit the top level call to allow internal optimizations.
train_jit = objax.Jit(train)
Note that all the vars used in train were declared via objax.Function.with_vars: we passed m.vars() + opt.vars(). Why didn’t we pass m.vars() + gradient_loss.vars() + opt.vars()? We could and it’s perfectly fine to do so, but keep in mind that gradient_loss is itself a module which shares the weights of m, and consequently gradient_loss.vars() contains the same variables as m.vars().
Static arguments¶
Static arguments are arguments that are treated as compile-time constants in the jitted function. Boolean arguments, numerical arguments used in comparisons (resulting in a bool), and strings must be marked as static.
Calling the jitted function with different values for these constants will trigger recompilation. As a rule of thumb:
Good static arguments: training (boolean), my_mode (int that can take only a few values), …
Bad static arguments: training_step (int that can take a lot of values)
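To make this concrete, here is a minimal sketch (using a hypothetical function f with a static mode argument); each new value of a static argument compiles a separate specialization of the function:
import objax
@objax.Function.with_vars(objax.VarCollection())
def f(x, mode):
    return x * mode
jit_f = objax.Jit(f, static_argnums=(1,))
x = objax.random.normal((4,))
y0 = jit_f(x, 0)  # compiles a version specialized for mode=0
y1 = jit_f(x, 1)  # compiles a second version for mode=1
y2 = jit_f(x, 0)  # reuses the mode=0 version, no recompilation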
Let’s look at an example with BatchNorm which takes a training argument:
import objax
net = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.nn.BatchNorm0D(3)])
@objax.Function.with_vars(net.vars())
def f(x, training):
    return net(x, training=training)
jit_f_static = objax.Jit(f, static_argnums=(1,))
# Note the static_argnums=(1,) which indicates that argument 1 (training) is static.
x = objax.random.normal((100, 2))
print(((net(x, training=False) - jit_f_static(x, False)) ** 2).sum()) # 0.0
What happens if you don’t use static_argnums
?
jit_f = objax.Jit(f)
y = jit_f(x, False)
# Traceback (most recent call last):
# File <...>
# <long stack trace>
# jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected (in `bool`).
# Use transformation parameters such as `static_argnums` for `jit` to avoid tracing input values.
# See `https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error`.
# Encountered value: Traced<ShapedArray(bool[], weak_type=True):JaxprTrace(level=-1/1)>
To cut a long story short: when compiling, boolean inputs must be made static.
For more information, please refer to jax.jit which is the API Objax uses under the hood.
Constant optimization¶
As seen previously, objax.Jit
takes a variables
argument to specify the variables of a function
or of a module that Jit is compiling.
If a variable is not passed to Jit it will be treated as a constant and will be optimized away.
Warning
A jitted module will not see any change made to a constant. A constant is not expected to change since it is supposed to be… constant!
A simple constant optimization example:
import objax
m = objax.nn.Linear(3, 4)
# Pass an empty VarCollection to signify to Jit that m has no variable.
jit_constant = objax.Jit(m, objax.VarCollection())
x = objax.random.normal((10, 3))
print(((m(x) - jit_constant(x)) ** 2).sum()) # 0.0
# Modify m (which was supposed to be constant!)
m.b.assign(m.b.value + 1)
print(((m(x) - jit_constant(x)) ** 2).sum()) # 40.0
# As expected jit_constant didn't see the change.
Warning
Constant optimization is performed by the XLA backend (the interface to the hardware); it may take a long time and a lot of memory during compilation, often with very little gain in final performance, if any.
Parallelism¶
Note
If you don’t have multiple devices, you can simulate them on CPU by starting python with the following command:
CUDA_VISIBLE_DEVICES= XLA_FLAGS=--xla_force_host_platform_device_count=8 python
Alternatively you can do it in Python directly by inserting this snippet before importing Objax:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'
objax.Parallel
provides a way to distribute computations across multiple GPUs (or TPUs).
It also performs JIT under the hood, and its API shares a lot with objax.Jit:
it takes a function or module to be compiled, a VarCollection, and a static_argnums parameter, all of which behave the same as in Jit.
However it also takes arguments specific to handling parallelism, which we are going to introduce.
When running a parallelized function f on a batch \(x\) of shape \((n, ...)\) on \(d\) devices, the following steps happen:
The batch \(x\) is divided into \(d\) sub-batches \(x_i\) of shape \((n/d, ...)\) for \(i\in\{0, ..., d-1\}\).
Each sub-batch \(x_i\) is passed to f and run on device \(i\) in parallel. The results are collected as output sub-batches \(y_i=f(x_i)\).
The outputs \(y_i\) are stacked into a single tensor \(y\) of shape \((d, ...)\).
The final output is obtained by calling the reduce function on \(y\): out = reduce(y).
With this in mind, we can now detail the additional arguments of objax.Parallel:
reduce: a function that aggregates the output results from each GPU/TPU.
axis_name: the name of the device dimension, which we referred to as \(d\) earlier. By default, it is called 'device'.
Let’s illustrate this with a simple example: the parallelization of a module (para_net) and of a function (para_func):
# This code was run on 8 simulated devices
import objax
net = objax.nn.Sequential([objax.nn.Linear(3, 4), objax.functional.relu])
para_net = objax.Parallel(net)
para_func = objax.Parallel(objax.Function(lambda x: net(x) + 1, net.vars()))
# A batch of mockup data
x = objax.random.normal((96, 3))
# We're running on multiple devices, copy the model variables to all of them first.
with net.vars().replicate():
    y = para_net(x)
    z = para_func(x)
print(((net(x) - y) ** 2).sum()) # 8.90954e-14
print(((net(x) - (z - 1)) ** 2).sum()) # 4.6487814e-13
We can also show the parallel version of A realistic case: Fully jitted training step; the changes from the jitted version are the pmean calls in train and the Parallel wrapper at the end:
import objax
m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.functional.relu, objax.nn.Linear(3, 4)])
opt = objax.optimizer.Momentum(m.vars())
@objax.Function.with_vars(m.vars())
def loss(x, labels):
    logits = m(x)
    return objax.functional.loss.cross_entropy_logits(logits, labels).mean()
gradient_loss = objax.GradValues(loss, m.vars())
@objax.Function.with_vars(m.vars() + opt.vars())
def train(x, labels, lr):
    g, v = gradient_loss(x, labels)              # Compute gradients and loss
    opt(lr, objax.functional.parallel.pmean(g))  # Apply averaged gradients
    return objax.functional.parallel.pmean(v)    # Return averaged loss value
# It's better to parallelize the top level call to allow internal optimizations.
train_para = objax.Parallel(train, reduce=lambda x: x[0])
Let’s study the concepts introduced in this example in more detail.
Variable replication¶
Variable replication copies the variables to each device’s own memory.
Variables must be replicated before calling a parallelized module or function.
Variable replication is done through objax.VarCollection.replicate()
which is a context manager.
One could go further and create their own replication; this is not covered here, but the source of replicate is rather simple and a good starting point.
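For reference, here is a minimal sketch of what a custom replication context manager could look like; it assumes the standard JAX helpers jax.local_devices and jax.device_put_replicated, relies on each variable's own reduce method, and is illustrative rather than Objax's actual implementation:
import contextlib
import jax

@contextlib.contextmanager
def my_replicate(vc):
    # vc is an objax.VarCollection; copy every variable to all devices.
    devices = jax.local_devices()
    for v in vc:
        v.assign(jax.device_put_replicated(v.value, devices))
    try:
        yield
    finally:
        # On exit, reduce the per-device copies back to a single value.
        for v in vc:
            v.reduce(v.value)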
Here is a detailed example:
# This code was run on 8 simulated devices
import objax
import jax.numpy as jn
m = objax.ModuleList([objax.TrainVar(jn.arange(5))])
# We use "repr" to see the whole type information.
print(repr(m[0].value)) # DeviceArray([0, 1, 2, 3, 4], dtype=int32)
with m.vars().replicate():
    # In the scope of the with-statement, the variables are replicated to all devices.
    print(repr(m[0].value))
    # ShardedDeviceArray([[0, 1, 2, 3, 4],
    #                     [0, 1, 2, 3, 4],
    #                     [0, 1, 2, 3, 4],
    #                     [0, 1, 2, 3, 4],
    #                     [0, 1, 2, 3, 4],
    #                     [0, 1, 2, 3, 4],
    #                     [0, 1, 2, 3, 4],
    #                     [0, 1, 2, 3, 4]], dtype=int32)
    # A ShardedDeviceArray is a DeviceArray spread across multiple devices.
# When exiting the with-statement, the variables are reduced back to their original device.
print(repr(m[0].value)) # DeviceArray([0., 1., 2., 3., 4.], dtype=float32)
Something interesting happened: the value of m[0] was initially of integer type, but it became a float by the end.
This is due to the reduction that follows a replication: by default, the reduction method takes the average of the copies from each device, and the average of multiple integers is automatically cast to a float.
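The underlying cast is easy to reproduce in isolation; a small sketch:
import jax.numpy as jn
copies = jn.stack([jn.arange(5)] * 8)  # the 8 replicated integer copies
print(repr(copies.mean(0)))  # DeviceArray([0., 1., 2., 3., 4.], dtype=float32)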
You can customize the variable reduction; this is not something one typically needs to do, but it’s available for advanced users nonetheless:
# This code was run on 8 simulated devices
import objax
import jax.numpy as jn
m = objax.ModuleList([objax.TrainVar(jn.arange(5), reduce=lambda x: x[0]),
                      objax.TrainVar(jn.arange(5), reduce=lambda x: x.sum(0)),
                      objax.TrainVar(jn.arange(5), reduce=lambda x: jn.stack(x))])
print(repr(m[0].value)) # DeviceArray([0, 1, 2, 3, 4], dtype=int32)
print(repr(m[1].value)) # DeviceArray([0, 1, 2, 3, 4], dtype=int32)
print(repr(m[2].value)) # DeviceArray([0, 1, 2, 3, 4], dtype=int32)
with m.vars().replicate():
    pass
# When exiting the with-statement, the variables are reduced back to their original device.
print(repr(m[0].value)) # DeviceArray([0, 1, 2, 3, 4], dtype=int32)
print(repr(m[1].value)) # DeviceArray([ 0, 8, 16, 24, 32], dtype=int32)
print(repr(m[2].value)) # DeviceArray([[0, 1, 2, 3, 4],
# [0, 1, 2, 3, 4],
# [0, 1, 2, 3, 4],
# [0, 1, 2, 3, 4],
# [0, 1, 2, 3, 4],
# [0, 1, 2, 3, 4],
# [0, 1, 2, 3, 4],
# [0, 1, 2, 3, 4]], dtype=int32)
Output aggregation¶
Similarly the output \(y\) of parallel call is reduced using the reduce
argument.
The first dimension \(d\) of \(y\) is the device dimension and its name comes from the axis_name
argument while by default is simply "device"
.
Again, let’s look at a simple example:
# This code was run on 8 simulated devices
import objax
import jax.numpy as jn
net = objax.nn.Sequential([objax.nn.Linear(3, 4), objax.functional.relu])
para_none = objax.Parallel(net, reduce=lambda x: x)
para_first = objax.Parallel(net, reduce=lambda x: x[0])
para_concat = objax.Parallel(net, reduce=lambda x: jn.concatenate(x))
para_average = objax.Parallel(net, reduce=lambda x: x.mean(0))
# A batch of mockup data
x = objax.random.normal((96, 3))
# We're running on multiple devices, copy the model variables to all of them first.
with net.vars().replicate():
    print(para_none(x).shape)     # (8, 12, 4)
    print(para_first(x).shape)    # (12, 4)
    print(para_concat(x).shape)   # (96, 4) - This is the default setting
    print(para_average(x).shape)  # (12, 4)
In the example above, the batch x (of size 96) was divided by the parallel call into 8 sub-batches of size 12. Each of these sub-batches was processed on its own device. The final value was then reduced using the provided reduce method.
para_none didn’t do any reduction; it just returned the value it was given. As expected, its shape is (devices, batch // devices, ...).
para_first and para_average took either the first entry or the average over dimension 0, resulting in a shape of (batch // devices, ...).
para_concat concatenated all the values, resulting in a shape of (batch, ...).
Synchronized computations¶
So far, we only considered the case where all the devices were acting on their own, unaware of others’ existence. It’s commonly desirable for devices to communicate with each other.
For example, when training a model, for efficiency one would want the optimizer to update the weights on all the devices at the same time. To achieve this, we would like the gradients to be computed for each sub-batch on the device, and then averaged across all devices.
The good news is that it is very easy to do: there is a set of predefined primitives in objax.functional.parallel which are the direct equivalents of single-device primitives:
objax.functional.parallel.pmax() is the multi-device equivalent of jax.numpy.max
objax.functional.parallel.pmean() is the multi-device equivalent of jax.numpy.mean
and so on…
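As an illustration, here is a small sketch (assuming the 8 simulated devices from the note at the top of Parallelism) where each device sums its own sub-batch and pmean then averages these local sums across devices:
import objax
import jax.numpy as jn
@objax.Function.with_vars(objax.VarCollection())
def per_device_mean(x):
    local = x.sum()  # sum of this device's sub-batch
    return objax.functional.parallel.pmean(local)  # average over all devices
para = objax.Parallel(per_device_mean, reduce=lambda x: x[0])
x = jn.arange(16.0)
print(para(x))  # 15.0, i.e. (0 + 1 + ... + 15) / 8 devices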
Recalling the code for the parallelized train operation:
@objax.Function.with_vars(m.vars() + opt.vars())
def train(x, labels, lr):
    g, v = gradient_loss(x, labels)              # Compute gradients and loss
    opt(lr, objax.functional.parallel.pmean(g))  # Apply averaged gradients
    return objax.functional.parallel.pmean(v)    # Return averaged loss value
The train function is called on each device in parallel.
The call objax.functional.parallel.pmean(g) averages the gradients g across all devices.
Then, on each device, the optimizer applies the averaged gradients to its local copy of the weights.
Finally, the averaged loss objax.functional.parallel.pmean(v) is returned.
Vectorization¶
objax.Vectorize
is the module responsible for code vectorization.
Vectorization can be seen as a parallelization without knowledge of the devices available.
On a single GPU, vectorization parallelizes the execution in concurrent threads.
It can be combined with objax.Parallel
resulting in multi-GPU multi-threading!
Vectorization can also be done on a single CPU.
A typical example of CPU vectorization could be data pre-processing or augmentation.
In its simplest form vectorization applies a function to the elements of a batch concurrently.
objax.Vectorize
takes a module or a function f
and vectorizes it.
Similarly to Jit and Parallel, you must specify the variables used by the function.
Finally, batch_axis specifies which axis should be considered the batch axis for each input argument of f.
For values with no batch axis, for example when passing a value to be shared by all the calls to the function f, set its batch axis to None to broadcast it.
Let’s clarify this with a simple example:
# Randomly reverse rows in a batch.
import objax
import jax.numpy as jn
class RandomReverse(objax.Module):
    """Randomly reverse a single vector x and add a value y to it."""

    def __init__(self, keygen=objax.random.DEFAULT_GENERATOR):
        self.keygen = keygen

    def __call__(self, x, y):
        r = objax.random.randint([], 0, 2, generator=self.keygen)
        return x + y + r * (x[::-1] - x), r, y
random_reverse = RandomReverse()
vector_reverse = objax.Vectorize(random_reverse, batch_axis=(0, None))
# vector_reverse takes two arguments (just like random_reverse), we're going to pass:
# - a matrix x for the first argument, interpreted as a batch of vectors (batch_axis=0).
# - a value y for the second argument, interpreted as a broadcasted value (batch_axis=None).
# Test it on some mock up data
x = jn.arange(20).reshape((5, 4))
print(x) # [[ 0 1 2 3]
# [ 4 5 6 7]
# [ 8 9 10 11]
# [12 13 14 15]
# [16 17 18 19]]
objax.random.DEFAULT_GENERATOR.seed(1337)
z, r, y = vector_reverse(x, 1)
print(r) # [0 1 0 1 1] - whether a row was reversed
print(y) # [1 1 1 1 1] - the broadcasted input y
print(z) # [[ 1 2 3 4]
# [ 8 7 6 5]
# [ 9 10 11 12]
# [16 15 14 13]
# [20 19 18 17]]
# Above we added a single constant (y=1)
# We can also add a vector y=(-2, -1, 0, 1)
objax.random.DEFAULT_GENERATOR.seed(1337)
z, r, y = vector_reverse(x, jn.array((-2, -1, 0, 1)))
print(r) # [0 1 0 1 1] - whether a row was reversed
print(y) # [[-2 -1 0 1] - the broadcasted input y
# [-2 -1 0 1]
# [-2 -1 0 1]
# [-2 -1 0 1]
# [-2 -1 0 1]]
print(z) # [[-2 0 2 4]
# [ 5 5 5 5]
# [ 6 8 10 12]
# [13 13 13 13]
# [17 17 17 17]]
Computing weight gradients per batch entry¶
This is a more advanced example; conceptually, it is similar to what powers differentially private gradients:
import objax
m = objax.nn.Linear(3, 4)
@objax.Function.with_vars(m.vars())
def loss(x, y):
    return ((m(x) - y) ** 2).mean()
g = objax.Grad(loss, m.vars())
single_gradients = objax.Vectorize(g, batch_axis=(0, 0))  # Batch is dimension 0 of both x and y
# Mock some data
x = objax.random.normal((10, 3))
y = objax.random.normal((10, 4))
# Compute standard gradients
print([v.shape for v in g(x, y)]) # [(4,), (3, 4)]
# Compute per batch entry gradients
print([v.shape for v in single_gradients(x, y)]) # [(10, 4), (10, 3, 4)]
As expected, we obtained as many gradients for each of the network’s weights as there are entries in the batch.
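As a hedged follow-up sketch, these per-entry gradients can be turned into per-example gradient norms, the quantity that differentially private training clips:
import jax.numpy as jn
per_example = single_gradients(x, y)  # one tensor per variable, batch dimension first
# Squared L2 norm of each example's full gradient, summed across all variables.
sq = sum((g.reshape(g.shape[0], -1) ** 2).sum(1) for g in per_example)
print(jn.sqrt(sq).shape)  # (10,) - one norm per batch entry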