objax.optimizer package
Adam | Adam optimizer.
ExponentialMovingAverageModule | Creates a module that uses the moving average weights of another module.
ExponentialMovingAverage | Maintains exponential moving averages for each variable from the provided VarCollection.
LARS | Layerwise adaptive rate scaling (LARS) optimizer.
Momentum | Momentum optimizer.
SGD | Stochastic Gradient Descent (SGD) optimizer.
-
class objax.optimizer.Adam(vc, beta1=0.9, beta2=0.999, eps=1e-08)
Adam optimizer.
Adam is an adaptive learning rate optimization algorithm originally presented in Adam: A Method for Stochastic Optimization. Specifically, when optimizing a loss function \(f\) parameterized by model weights \(w\), the update rule is as follows:
\[\begin{split}\begin{eqnarray} v_{k} &=& \beta_1 v_{k-1} + (1 - \beta_1) \nabla f (.; w_{k-1}) \nonumber \\ s_{k} &=& \beta_2 s_{k-1} + (1 - \beta_2) (\nabla f (.; w_{k-1}))^2 \nonumber \\ \hat{v_{k}} &=& \frac{v_{k}}{(1 - \beta_{1}^{k})} \nonumber \\ \hat{s_{k}} &=& \frac{s_{k}}{(1 - \beta_{2}^{k})} \nonumber \\ w_{k} &=& w_{k-1} - \eta \frac{\hat{v_{k}}}{\sqrt{\hat{s_{k}}} + \epsilon} \nonumber \end{eqnarray}\end{split}\]
Adam maintains exponential moving averages of the gradient \((v_{k})\) and of the squared gradient \((s_{k})\); the hyperparameters \(\beta_1, \beta_2 \in [0, 1)\) control the exponential decay rates of these moving averages. The constant \(\eta\) in the weight update rule is the learning rate, which is passed as a parameter to the __call__ method. Note that the implementation uses the approximation \(\sqrt{\hat{s_{k}} + \epsilon} \approx \sqrt{\hat{s_{k}}} + \epsilon\).
-
__init__(vc, beta1=0.9, beta2=0.999, eps=1e-08)
Constructor for the Adam optimizer class.
- Parameters
vc (objax.variable.VarCollection) – collection of variables to optimize.
beta1 (float) – value of Adam’s beta1 hyperparameter. Defaults to 0.9.
beta2 (float) – value of Adam’s beta2 hyperparameter. Defaults to 0.999.
eps (float) – value of Adam’s epsilon hyperparameter. Defaults to 1e-8.
-
__call__(lr, grads, beta1=None, beta2=None)
Updates variables and other state based on the Adam algorithm.
- Parameters
lr (float) – the learning rate.
grads (List[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]) – the gradients to apply.
beta1 (Optional[float]) – optional override of the default beta1.
beta2 (Optional[float]) – optional override of the default beta2.
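To make the update rule concrete, here is a minimal NumPy sketch of a single Adam step on one array, following the equations above term by term. It is not objax's implementation: the helper name adam_step and its argument layout are hypothetical, the step counter k is assumed to start at 1, and NumPy stands in for jax.numpy so the snippet runs on its own.

```python
import numpy as np

def adam_step(w, grad, v, s, k, lr, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update (hypothetical helper); k is the 1-based step count."""
    v = beta1 * v + (1 - beta1) * grad            # moving average of the gradient
    s = beta2 * s + (1 - beta2) * grad ** 2       # moving average of the squared gradient
    v_hat = v / (1 - beta1 ** k)                  # bias-corrected first moment
    s_hat = s / (1 - beta2 ** k)                  # bias-corrected second moment
    w = w - lr * v_hat / (np.sqrt(s_hat) + eps)   # eps added outside the square root
    return w, v, s

# Toy problem: minimize f(w) = 0.5 * ||w||^2, whose gradient is w itself.
w = np.array([1.0, -2.0])
v = np.zeros_like(w)
s = np.zeros_like(w)
for k in range(1, 501):
    w, v, s = adam_step(w, w, v, s, k, lr=0.05)
```

On the first step the bias corrections give \(\hat{v_{1}} = g\) and \(\hat{s_{1}} = g^2\), so each coordinate moves by roughly \(\eta\) against the sign of its gradient, regardless of the gradient's magnitude.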
-
-
class objax.optimizer.ExponentialMovingAverageModule(module, momentum=0.999, debias=False, eps=1e-06)
Creates a module that uses the moving average weights of another module.
Convenience interface to apply objax.optimizer.ExponentialMovingAverage to a module.
Usage example:
    import objax

    m = objax.nn.Sequential([objax.nn.Linear(2, 3), objax.nn.BatchNorm0D(3)])
    m_ema = objax.optimizer.ExponentialMovingAverageModule(m, momentum=0.999, debias=True)
    x = objax.random.uniform((16, 2))

    # When the weights of m change, simply call update_ema() to update the moving averages.
    v1 = m(x, training=True)
    m_ema.update_ema()

    # Call m_ema just as you would call m.
    v2 = m_ema(x, training=False)
-
__init__(module, momentum=0.999, debias=False, eps=1e-06)
Creates an ExponentialMovingAverageModule instance with the given hyperparameters.
- Parameters
module (objax.module.Module) – a module for which to compute the moving average.
momentum (float) – the decay factor for the moving average.
debias (bool) – bool indicating whether to use initialization bias correction.
eps (float) – small adjustment to prevent division by zero.
-
-
class objax.optimizer.ExponentialMovingAverage(vc, momentum=0.999, debias=False, eps=1e-06)
Maintains exponential moving averages for each variable from the provided VarCollection.
When training a model, it is often beneficial to maintain exponential moving averages (EMA) of the trained parameters. Evaluations that use averaged parameters sometimes produce significantly better results than the final trained values (see Acceleration of Stochastic Approximation by Averaging).
This maintains an EMA of the parameters passed in the VarCollection vc. For weights \(w\) and momentum \(\mu\), the EMA \(m\) at step \(t\) is updated as:
\[m_t = \mu m_{t-1} + (1 - \mu) w_t\]
The EMA weights \(\hat{w_t}\) are simply \(m_t\) when debias=False. When debias=True, the EMA weights are defined as:
\[\hat{w_t} = \frac{m_t}{1 - (1 - \epsilon)\mu^t}\]
where \(\epsilon\) is a small constant that avoids division by zero.
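The update and debiasing rules can be checked with a small NumPy sketch. EMASketch is a hypothetical stand-in for illustration, not objax's ExponentialMovingAverage; it assumes the average starts at \(m_0 = 0\), which is the situation the bias correction is meant to fix, and it tracks \(\mu^t\) as a running product.

```python
import numpy as np

class EMASketch:
    """Hypothetical EMA tracker: m_t = mu * m_{t-1} + (1 - mu) * w_t."""

    def __init__(self, shape, momentum=0.999, debias=False, eps=1e-6):
        self.mu = momentum
        self.debias = debias
        self.eps = eps
        self.m = np.zeros(shape)  # assumption: the average starts at zero
        self.mu_t = 1.0           # running value of mu**t

    def update(self, w):
        self.m = self.mu * self.m + (1 - self.mu) * w
        self.mu_t *= self.mu

    def value(self):
        if not self.debias:
            return self.m                                   # plain m_t
        return self.m / (1 - (1 - self.eps) * self.mu_t)    # debiased estimate

# Feed a constant weight of 2.0 for ten steps with momentum 0.9.
raw = EMASketch((), momentum=0.9, debias=False)
debiased = EMASketch((), momentum=0.9, debias=True)
for _ in range(10):
    raw.update(2.0)
    debiased.update(2.0)
```

After ten steps the raw average is \(2(1 - 0.9^{10}) \approx 1.30\), still pulled toward its zero initialization, while the debiased value is essentially 2.0.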
-
__init__(vc, momentum=0.999, debias=False, eps=1e-06)
Creates an ExponentialMovingAverage instance with the given hyperparameters.
- Parameters
vc (objax.variable.VarCollection) – collection of variables for which to maintain moving averages.
momentum (float) – the decay factor for the moving average.
debias (bool) – bool indicating whether to use initialization bias correction.
eps (float) – small adjustment to prevent division by zero.
-
refs_and_values()
Returns the VarCollection of variables affected by the exponential moving average (EMA) and their corresponding EMA values.
- Return type
Tuple[objax.variable.VarCollection, List[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]]
-
-
class objax.optimizer.LARS(vc, momentum=0.9, weight_decay=0.0001, tc=0.001, eps=1e-05)
Layerwise adaptive rate scaling (LARS) optimizer.
See https://arxiv.org/abs/1708.03888
The Layer-wise Adaptive Rate Scaling (LARS) optimizer implements the scheme originally proposed in Large Batch Training of Convolutional Networks. The optimizer takes as input the base learning rate \(\gamma_0\), momentum \(m\), weight decay \(\beta\), and trust coefficient \(\eta\), and updates the model weights \(w\) as follows:
\[\begin{split}\begin{eqnarray} g_{t}^{l} &\leftarrow& \nabla L(w_{t}^{l}) \nonumber \\ \gamma_t &\leftarrow& \gamma_0 \ast (1 - \frac{t}{T})^{2} \nonumber \\ \lambda^{l} &\leftarrow& \eta \frac{\| w_{t}^{l} \| }{ \| g_t^{l} \| + \beta \| w_{t}^{l} \|} \nonumber \\ v_{t+1}^{l} &\leftarrow& m v_{t}^{l} + \gamma_{t+1} \ast \lambda^{l} \ast (g_{t}^{l} + \beta w_{t}^{l}) \nonumber \\ w_{t+1}^{l} &\leftarrow& w_{t}^{l} - v_{t+1}^{l} \nonumber \end{eqnarray}\end{split}\]
where \(T\) is the total number of steps (epochs) that the optimizer will take, \(t\) is the current step number, and \(w_{t}^{l}\) are the weights of layer \(l\) at step \(t\).
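A single LARS step for one layer can be sketched in NumPy under these assumptions: the hypothetical helper lars_step receives the already-scheduled learning rate \(\gamma_{t+1}\) as lr, tc plays the role of the trust coefficient \(\eta\), and eps is added to the denominator of the trust ratio (its exact placement is a guess based on the constructor's description). Zero weight or gradient norms are not special-cased here.

```python
import numpy as np

def lars_step(w, g, v, lr, momentum=0.9, weight_decay=1e-4, tc=1e-3, eps=1e-5):
    """One LARS update for a single layer (hypothetical helper)."""
    w_norm = np.linalg.norm(w)
    g_norm = np.linalg.norm(g)
    # Layer-wise trust ratio lambda^l; placing eps in the denominator is an assumption.
    trust_ratio = tc * w_norm / (g_norm + weight_decay * w_norm + eps)
    # Momentum on the rescaled, weight-decayed gradient.
    v = momentum * v + lr * trust_ratio * (g + weight_decay * w)
    w = w - v
    return w, v

# Worked example: ||w|| = 5, ||g|| = 1, no weight decay, so lambda = tc * 5.
w, v = lars_step(np.array([3.0, 4.0]), np.array([0.6, 0.8]), np.zeros(2),
                 lr=1.0, weight_decay=0.0, tc=0.5, eps=0.0)
```

With tc=0.5 the trust ratio is 2.5, so the first step moves w from [3, 4] to [1.5, 2]: the step length is set by the ratio of the weight and gradient norms rather than by the raw gradient magnitude.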
-
__init__(vc, momentum=0.9, weight_decay=0.0001, tc=0.001, eps=1e-05)
Constructor for the LARS optimizer.
- Parameters
vc (objax.variable.VarCollection) – collection of variables to optimize.
momentum (float) – coefficient used for the moving average of the gradient.
weight_decay (float) – weight decay coefficient.
tc (float) – trust coefficient eta (< 1) used in the trust ratio computation.
eps (float) – epsilon used for trust ratio computation.
-
__call__(lr, grads)
Updates variables based on the LARS algorithm.
- Parameters
lr (float) – learning rate. The LARS paper suggests using lr = lr_0 * (1 - t/T)**2, where t is the current epoch number and T the maximum number of epochs.
grads (List[Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jax.interpreters.pxla.ShardedDeviceArray]]) – the gradients to apply.
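The schedule suggested for lr can be written as a small helper; lars_poly_lr is a hypothetical name, and its result would be passed as the lr argument of __call__ once per epoch. The sketch clamps t at T, an added assumption, because \((1 - t/T)^2\) would start growing again for t > T.

```python
def lars_poly_lr(lr0, t, T):
    """Polynomial (power 2) decay lr0 * (1 - t/T)**2, clamped at t = T (hypothetical helper)."""
    frac = min(t / T, 1.0)        # clamp so the rate stays at 0 after T epochs
    return lr0 * (1.0 - frac) ** 2

# Learning rates for a hypothetical 10-epoch run with base rate 0.1.
schedule = [lars_poly_lr(0.1, epoch, 10) for epoch in range(10)]
```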
-
-
class objax.optimizer.Momentum(vc, momentum=0.9, nesterov=False)
Momentum optimizer.
The momentum optimizer (expository article) introduces a tweak to standard gradient descent. Specifically, when optimizing a loss function \(f\) parameterized by model weights \(w\), the update rule is as follows:
\[\begin{split}\begin{eqnarray} v_{k} &=& \mu v_{k-1} + \nabla f (.; w_{k-1}) \nonumber \\ w_{k} &=& w_{k-1} - \eta v_{k} \nonumber \end{eqnarray}\end{split}\]
The term \(v\) is the velocity: it accumulates past gradients as an exponentially decaying sum, with each older gradient weighted by a further factor of \(\mu\). The parameters \(\mu\) and \(\eta\) are the momentum and the learning rate, respectively.
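A sketch of the plain momentum rule above, for illustration only; momentum_step is a hypothetical helper rather than objax's code, and mu defaults to the constructor's momentum=0.9. It works on Python floats or NumPy arrays alike.

```python
def momentum_step(w, grad, v, lr, mu=0.9):
    """One heavy-ball momentum update (hypothetical helper)."""
    v = mu * v + grad   # velocity: decaying sum of past gradients
    w = w - lr * v      # step along the velocity, scaled by the learning rate
    return w, v

# With a constant gradient of 1, the velocity grows 1, 1.9, 2.71, ... toward 1 / (1 - mu).
w, v = 0.0, 0.0
for _ in range(3):
    w, v = momentum_step(w, 1.0, v, lr=0.1)
```

This growth toward \(1/(1-\mu)\) is why a momentum of 0.9 behaves roughly like a tenfold larger step size along directions where the gradient stays consistent.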
The momentum class also implements Nesterov’s Accelerated Gradient (NAG) (see Sutskever et al.). Like momentum, NAG is a first-order optimization method, with a better convergence rate than gradient descent in certain situations. The NAG update can be written as:
\[\begin{split}\begin{eqnarray} v_{k} &=& \mu v_{k-1} + \nabla f(.; w_{k-1} - \eta \mu v_{k-1}) \nonumber \\ w_{k} &=& w_{k-1} - \eta v_{k} \nonumber \end{eqnarray}\end{split}\]
The implementation uses the simplification presented by Bengio et al.
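Because the weight update subtracts \(\eta v_k\), the point plain momentum would move to next is \(w_{k-1} - \eta\mu v_{k-1}\), and that is where the sketch below evaluates the look-ahead gradient. It then checks numerically that one common way of writing the Bengio et al. simplification, which stores the shifted weights \(w_k - \eta\mu v_k\) so the gradient is taken at the stored point, follows the same trajectory. The quadratic test function, the helper names, and the exact simplified form are assumptions for illustration, not objax's code.

```python
import numpy as np

A = np.diag([1.0, 10.0])  # toy quadratic f(w) = 0.5 * w @ A @ w (an assumption)

def grad_f(w):
    return A @ w

def nag_lookahead(w, v, lr, mu):
    """NAG as written: gradient at the look-ahead point w - lr * mu * v."""
    v = mu * v + grad_f(w - lr * mu * v)
    return w - lr * v, v

def nag_simplified(w_shift, v, lr, mu):
    """Reparameterized NAG: w_shift = w - lr * mu * v, gradient at the stored point."""
    g = grad_f(w_shift)
    v = mu * v + g
    return w_shift - lr * (g + mu * v), v

lr, mu = 0.05, 0.9
w, v1 = np.array([1.0, 1.0]), np.zeros(2)
w_shift, v2 = w.copy(), np.zeros(2)
for _ in range(20):
    w, v1 = nag_lookahead(w, v1, lr, mu)
    w_shift, v2 = nag_simplified(w_shift, v2, lr, mu)

# Both forms share the same velocity, and the stored weights are the shifted ones.
print(np.allclose(v1, v2), np.allclose(w_shift, w - lr * mu * v1))
```

The simplified form needs only one gradient evaluation at a point it already stores, which is what makes it convenient in practice.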
-
__init__(vc, momentum=0.9, nesterov=False)
Constructor for the momentum optimizer class.
- Parameters
vc (objax.variable.VarCollection) – collection of variables to optimize.
momentum (float) – the momentum hyperparameter.
nesterov (bool) – bool indicating whether to use the Nesterov method.
-
__call__(lr, grads, momentum=None)
Updates variables and other state based on momentum (or Nesterov) SGD.
- Parameters
lr (float) – the learning rate.
grads (List[jax._src.numpy.lax_numpy.ndarray]) – the gradients to apply.
momentum (Optional[float]) – optional override of the default momentum.
-
-
class objax.optimizer.SGD(vc)
Stochastic Gradient Descent (SGD) optimizer.
The stochastic gradient optimizer performs Stochastic Gradient Descent (SGD). It uses the following update rule for a loss \(f\) parameterized by model weights \(w\) and a user-provided learning rate \(\eta\):
\[w_k = w_{k-1} - \eta\nabla f(.; w_{k-1})\]
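A minimal NumPy sketch of this rule, applied with randomly drawn mini-batches so that each gradient is a stochastic estimate. The least-squares problem, batch size, step count, and the helper sgd_step are assumptions for illustration, not objax's code.

```python
import numpy as np

def sgd_step(w, grad, lr):
    """w_k = w_{k-1} - lr * grad (hypothetical helper)."""
    return w - lr * grad

# Toy problem (an assumption): recover w_true from noiseless targets y = X @ w_true.
rng = np.random.default_rng(0)
X = rng.normal(size=(256, 3))
w_true = np.array([0.5, -1.0, 2.0])
y = X @ w_true

w = np.zeros(3)
for step in range(500):
    idx = rng.integers(0, len(X), size=32)        # sample a mini-batch with replacement
    xb, yb = X[idx], y[idx]
    grad = xb.T @ (xb @ w - yb) / len(idx)         # gradient of 0.5 * mean squared error
    w = sgd_step(w, grad, lr=0.1)
```

Because the targets are noiseless, every mini-batch gradient vanishes at w_true, so even a constant learning rate drives w to w_true; with noisy targets a constant rate would instead leave w fluctuating around the solution.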