objax.optimizer package

Adam(vc[, beta1, beta2, eps])

Adam optimizer.

ExponentialMovingAverageModule(module[, …])

Creates a module that uses the moving average weights of another module.

ExponentialMovingAverage(vc[, momentum, …])

Maintains exponential moving averages for each variable from provided VarCollection.

LARS(vc[, momentum, weight_decay, tc, eps])

Layerwise adaptive rate scaling (LARS) optimizer.

Momentum(vc[, momentum, nesterov])

Momentum optimizer.

SGD(vc)

Stochastic Gradient Descent (SGD) optimizer.

class objax.optimizer.Adam(vc, beta1=0.9, beta2=0.999, eps=1e-08)[source]

Adam optimizer.

Adam is an adaptive learning rate optimization algorithm originally presented in Adam: A Method for Stochastic Optimization. Specifically, when optimizing a loss function \(f\) parameterized by model weights \(w\), the update rule is as follows:

\[\begin{split}\begin{eqnarray} v_{k} &=& \beta_1 v_{k-1} + (1 - \beta_1) \nabla f (.; w_{k-1}) \nonumber \\ s_{k} &=& \beta_2 s_{k-1} + (1 - \beta_2) (\nabla f (.; w_{k-1}))^2 \nonumber \\ \hat{v_{k}} &=& \frac{v_{k}}{(1 - \beta_{1}^{k})} \nonumber \\ \hat{s_{k}} &=& \frac{s_{k}}{(1 - \beta_{2}^{k})} \nonumber \\ w_{k} &=& w_{k-1} - \eta \frac{\hat{v_{k}}}{\sqrt{\hat{s_{k}}} + \epsilon} \nonumber \end{eqnarray}\end{split}\]

Adam updates exponential moving averages of the gradient \((v_{k})\) and of the squared gradient \((s_{k})\), where the hyperparameters \(\beta_1, \beta_2 \in [0, 1)\) control the exponential decay rates of these moving averages. The constant \(\eta\) in the weight update rule is the learning rate and is passed as a parameter to the __call__ method. Note that the implementation uses the approximation \(\sqrt{\hat{s_{k}} + \epsilon} \approx \sqrt{\hat{s_{k}}} + \epsilon\).
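
A minimal training-step sketch is shown below; the toy model, loss function, and random data are illustrative placeholders, not part of the optimizer API:

import objax

# Toy model; any module with trainable variables works the same way.
model = objax.nn.Sequential([objax.nn.Linear(2, 3)])
opt = objax.optimizer.Adam(model.vars())

def loss_fn(x, y):
    return ((model(x) - y) ** 2).mean()

# GradValues returns the gradients for model.vars() together with the loss value.
gv = objax.GradValues(loss_fn, model.vars())

def train_op(x, y, lr):
    g, v = gv(x, y)
    opt(lr, g)  # Adam update with learning rate lr
    return v

x = objax.random.uniform((16, 2))
y = objax.random.uniform((16, 3))
train_op(x, y, lr=0.001)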

__init__(vc, beta1=0.9, beta2=0.999, eps=1e-08)[source]

Constructor for Adam optimizer class.

Parameters
  • vc (objax.variable.VarCollection) – collection of variables to optimize.

  • beta1 (float) – value of Adam’s beta1 hyperparameter. Defaults to 0.9.

  • beta2 (float) – value of Adam’s beta2 hyperparameter. Defaults to 0.999.

  • eps (float) – value of Adam’s epsilon hyperparameter. Defaults to 1e-8.

__call__(lr, grads, beta1=None, beta2=None)[source]

Updates variables and other state based on the Adam algorithm.

Parameters
  • lr (float) – the learning rate.

  • grads (List[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]) – the gradients to apply.

  • beta1 (Optional[float]) – optional, override the default beta1.

  • beta2 (Optional[float]) – optional, override the default beta2.

class objax.optimizer.ExponentialMovingAverageModule(module, momentum=0.999, debias=False, eps=1e-06)[source]

Creates a module that uses the moving average weights of another module.

Convenience interface to apply objax.optimizer.ExponentialMovingAverage to a module.

Usage example:

import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.nn.BatchNorm0D(3)])
m_ema = objax.optimizer.ExponentialMovingAverageModule(m, momentum=0.999, debias=True)

x = objax.random.uniform((16, 2))

# When the weights of m change, simply call update_ema() to update moving averages
v1 = m(x, training=True)
m_ema.update_ema()

# You call m_ema just like you would call m
v2 = m_ema(x, training=False)
__init__(module, momentum=0.999, debias=False, eps=1e-06)[source]

Creates ExponentialMovingAverageModule instance with given hyperparameters.

Parameters
  • module (objax.module.Module) – a module for which to compute the moving average.

  • momentum (float) – the decay factor for the moving average.

  • debias (bool) – bool indicating whether to use initialization bias correction.

  • eps (float) – small adjustment to prevent division by zero.

__call__(*args, **kwargs)[source]

Calls the original module with moving average weights.

update_ema()[source]

Updates the moving average.

class objax.optimizer.ExponentialMovingAverage(vc, momentum=0.999, debias=False, eps=1e-06)[source]

Maintains exponential moving averages for each variable from provided VarCollection.

When training a model, it is often beneficial to maintain exponential moving averages (EMA) of the trained parameters. Evaluations that use averaged parameters sometimes produce significantly better results than the final trained values (see Acceleration of Stochastic Approximation by Averaging).

This class maintains an EMA of the parameters in the VarCollection vc. For weights \(w\), EMA \(m\), and momentum \(\mu\), the update rule at step \(t\) is:

\[m_t = \mu m_{t-1} + (1 - \mu) w_t\]

The EMA weights \(\hat{w_t}\) are simply \(m_t\) when debias=False. When debias=True, the EMA weights are defined as:

\[\hat{w_t} = \frac{m_t}{1 - (1 - \epsilon)\mu^t}\]

where \(\epsilon\) is a small constant that prevents division by zero.
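
As an illustration, a minimal sketch of the typical pattern follows; the model and the predict function are placeholders:

import objax

m = objax.nn.Sequential([objax.nn.Linear(2, 3)])
ema = objax.optimizer.ExponentialMovingAverage(m.vars(), momentum=0.999, debias=True)

# After every optimizer step that changes the weights of m, refresh the averages.
ema()

# Wrap a function so it runs with the averaged weights instead of the current ones.
predict = lambda x: m(x)
predict_ema = ema.replace_vars(predict)
y = predict_ema(objax.random.uniform((4, 2)))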

__init__(vc, momentum=0.999, debias=False, eps=1e-06)[source]

Creates ExponentialMovingAverage instance with given hyperparameters.

Parameters
  • vc (objax.variable.VarCollection) – collection of variables for which to maintain exponential moving averages.

  • momentum (float) – the decay factor for the moving average.

  • debias (bool) – bool indicating whether to use initialization bias correction.

  • eps (float) – small adjustment to prevent division by zero.

__call__()[source]

Updates the moving average.

refs_and_values()[source]

Returns the VarCollection of variables affected by Exponential Moving Average (EMA) and their corresponding EMA values.

Return type

Tuple[objax.variable.VarCollection, List[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]]

replace_vars(f)[source]

Returns a function that calls f with the variables replaced by their moving averages.

Parameters

f (Callable) – function to be called with the variables replaced by their stored moving averages.

Returns

A function that returns the output of calling f with stored variables replaced by their moving averages.

class objax.optimizer.LARS(vc, momentum=0.9, weight_decay=0.0001, tc=0.001, eps=1e-05)[source]

Layerwise adaptive rate scaling (LARS) optimizer.

See https://arxiv.org/abs/1708.03888

The Layer-wise Adaptive Rate Scaling (LARS) optimizer implements the scheme originally proposed in Large Batch Training of Convolutional Networks. The optimizer takes as input the base learning rate \(\gamma_0\), momentum \(m\), weight decay \(\beta\), and trust coefficient \(\eta\), and updates the model weights \(w\) as follows:

\[\begin{split}\begin{eqnarray} g_{t}^{l} &\leftarrow& \nabla L(w_{t}^{l}) \nonumber \\ \gamma_t &\leftarrow& \gamma_0 \ast (1 - \frac{t}{T})^{2} \nonumber \\ \lambda^{l} &\leftarrow& \frac{\| w_{t}^{l} \| }{ \| g_t^{l} \| + \beta \| w_{t}^{l} \|} \nonumber \\ v_{t+1}^{l} &\leftarrow& m v_{t}^{l} + \gamma_{t+1} \ast \lambda^{l} \ast (g_{t}^{l} + \beta w_{t}^{l}) \nonumber \\ w_{t+1}^{l} &\leftarrow& w_{t}^{l} - v_{t+1}^{l} \nonumber \\ \end{eqnarray}\end{split}\]

where \(T\) is the total number of steps (epochs) that the optimizer will take, \(t\) is the current step number, and \(w_{t}^{l}\) are the weights of layer \(l\) at step \(t\).
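
A minimal sketch of a training loop with the polynomial learning rate decay suggested by the paper is shown below; the toy model, loss, data, and the values of lr_0 and T are illustrative placeholders:

import objax

model = objax.nn.Sequential([objax.nn.Linear(2, 3)])
opt = objax.optimizer.LARS(model.vars(), momentum=0.9, weight_decay=0.0001)

def loss_fn(x, y):
    return ((model(x) - y) ** 2).mean()

gv = objax.GradValues(loss_fn, model.vars())

lr_0, T = 0.1, 90  # base learning rate and total number of epochs
x = objax.random.uniform((16, 2))
y = objax.random.uniform((16, 3))
for t in range(T):
    g, v = gv(x, y)
    opt(lr_0 * (1 - t / T) ** 2, g)  # lr = lr_0 * (1 - t/T)**2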

__init__(vc, momentum=0.9, weight_decay=0.0001, tc=0.001, eps=1e-05)[source]

Constructor for LARS optimizer.

Parameters
  • vc (objax.variable.VarCollection) – collection of variables to optimize.

  • momentum (float) – coefficient used for the moving average of the gradient.

  • weight_decay (float) – weight decay coefficient.

  • tc (float) – trust coefficient eta ( < 1) for trust ratio computation.

  • eps (float) – epsilon used for trust ratio computation.

__call__(lr, grads)[source]

Updates variables based on the LARS algorithm.

Parameters
  • lr (float) – learning rate. The LARS paper suggests using lr = lr_0 * (1 - t/T)**2, where t is the current epoch number and T the maximum number of epochs.

  • grads (List[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]) – the gradients to apply.

class objax.optimizer.Momentum(vc, momentum=0.9, nesterov=False)[source]

Momentum optimizer.

The momentum optimizer (expository article) introduces a tweak to standard gradient descent. Specifically, when optimizing a loss function \(f\) parameterized by model weights \(w\), the update rule is as follows:

\[\begin{split}\begin{eqnarray} v_{k} &=& \mu v_{k-1} + \nabla f (.; w_{k-1}) \nonumber \\ w_{k} &=& w_{k-1} - \eta v_{k} \nonumber \end{eqnarray}\end{split}\]

The term \(v\) is the velocity: it accumulates past gradients through a weighted moving average. The parameters \(\mu\) and \(\eta\) are the momentum and the learning rate, respectively.

The Momentum class also implements Nesterov’s Accelerated Gradient (NAG) (see Sutskever et al.). Like momentum, NAG is a first-order optimization method with a better convergence rate than gradient descent in certain situations. The NAG update can be written as:

\[\begin{split}\begin{eqnarray} v_{k} &=& \mu v_{k-1} + \nabla f(.; w_{k-1} - \eta \mu v_{k-1}) \nonumber \\ w_{k} &=& w_{k-1} - \eta v_{k} \nonumber \end{eqnarray}\end{split}\]

The implementation uses the simplification presented by Bengio et al.
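
A minimal sketch using Nesterov momentum is shown below; the toy model, loss, and data are illustrative placeholders. The second call also shows the per-call momentum override described under __call__:

import objax

model = objax.nn.Sequential([objax.nn.Linear(2, 3)])
opt = objax.optimizer.Momentum(model.vars(), momentum=0.9, nesterov=True)

def loss_fn(x, y):
    return ((model(x) - y) ** 2).mean()

gv = objax.GradValues(loss_fn, model.vars())

x = objax.random.uniform((16, 2))
y = objax.random.uniform((16, 3))
g, v = gv(x, y)
opt(0.1, g)                  # Nesterov momentum update with the default momentum
opt(0.1, g, momentum=0.95)   # same update, overriding momentum for this call only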

__init__(vc, momentum=0.9, nesterov=False)[source]

Constructor for momentum optimizer class.

Parameters
  • vc (objax.variable.VarCollection) – collection of variables to optimize.

  • momentum (float) – the momentum hyperparameter.

  • nesterov (bool) – bool indicating whether to use the Nesterov method.

__call__(lr, grads, momentum=None)[source]

Updates variables and other state based on momentum (or Nesterov) SGD.

Parameters
  • lr (float) – the learning rate.

  • grads (List[jax._src.numpy.lax_numpy.ndarray]) – the gradients to apply.

  • momentum (Optional[float]) – optional, override the default momentum.

class objax.optimizer.SGD(vc)[source]

Stochastic Gradient Descent (SGD) optimizer.

The stochastic gradient optimizer performs Stochastic Gradient Descent (SGD). It uses the following update rule for a loss \(f\) parameterized by model weights \(w\) and a user-provided learning rate \(\eta\):

\[w_k = w_{k-1} - \eta\nabla f(.; w_{k-1})\]
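
To illustrate the update rule, a minimal sketch on a single variable with hand-supplied gradients follows; the variable name and gradient values are arbitrary, and in practice vc is usually model.vars():

import jax.numpy as jn
import objax

w = objax.TrainVar(jn.ones(3))
opt = objax.optimizer.SGD(objax.VarCollection({'w': w}))

g = [jn.array([0.5, -1.0, 2.0])]  # one gradient per trainable variable in vc
opt(0.1, g)                       # w <- w - 0.1 * g
print(w.value)                    # [0.95  1.1   0.8]
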
__init__(vc)[source]

Constructor for SGD optimizer.

Parameters

vc (objax.variable.VarCollection) – collection of variables to optimize.

__call__(lr, grads)[source]

Updates variables based on the SGD algorithm.

Parameters
  • lr (float) – the learning rate.

  • grads (List[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]) – the gradients to apply.