# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ['ExponentialMovingAverage', 'ExponentialMovingAverageModule']
from typing import Callable, Tuple, List
import jax.numpy as jn
from objax.module import Module, ModuleList
from objax.typing import JaxArray
from objax.util import class_name
from objax.variable import RandomState, TrainRef, StateVar, TrainVar, VarCollection
class ExponentialMovingAverage(Module):
    """Tracks an exponential moving average (EMA) of every variable in a VarCollection."""

    def __init__(self, vc: VarCollection, momentum: float = 0.999, debias: bool = False, eps: float = 1e-6):
        """Builds the EMA state for the variables found in ``vc``.

        Args:
            vc: variables to average. RandomState entries are ignored, TrainRef entries are
                resolved to the variable they reference, and each variable is kept only once.
            momentum: the decay factor for the moving average.
            debias: bool indicating whether to use initialization bias correction.
            eps: small adjustment to prevent division by zero.
        """
        self.momentum = momentum
        self.debias = debias
        self.eps = eps
        # Number of updates performed so far; only read when debiasing.
        self.step = StateVar(jn.array(0, jn.uint32), reduce=lambda x: x[0])
        # Bucket the variables by kind, keyed by object identity so that a variable seen
        # several times is tracked once. Plain dicts keep insertion order (python >= 3.6),
        # so the resulting lists are deterministic.
        trainable, non_trainable = {}, {}
        for var in vc:
            if isinstance(var, RandomState):
                continue  # Random generator state cannot be averaged.
            target = var.ref if isinstance(var, TrainRef) else var
            bucket = trainable if isinstance(target, TrainVar) else non_trainable
            bucket[id(target)] = target
        # Non-trainable variables are stored directly; trainable ones through a TrainRef.
        tracked = list(non_trainable.values())
        tracked.extend(TrainRef(var) for var in trainable.values())
        self.refs = ModuleList(tracked)
        # One zero-initialised accumulator per tracked variable, in the same order as refs.
        self.m = ModuleList(StateVar(jn.zeros_like(ref.value)) for ref in self.refs)

    def __call__(self):
        """Advances the step counter and moves every average towards its variable's current value."""
        self.step.value += 1
        blend = 1 - self.momentum
        for ref, avg in zip(self.refs, self.m):
            avg.value += blend * (ref.value - avg.value)

    def refs_and_values(self) -> Tuple[VarCollection, List[JaxArray]]:
        """Returns the tracked variables together with their current EMA values.

        Returns:
            A tuple ``(vars, values)`` where ``vars`` is the VarCollection of the tracked
            variables and ``values`` holds the matching averages. When ``debias`` is set,
            each average is scaled by ``1 / (1 - (1 - eps) * momentum ** step)``.
        """
        if not self.debias:
            averages = self.m.vars().tensors()
            return self.refs.vars(), averages
        correction = 1 / (1 - (1 - self.eps) * self.momentum ** self.step.value)
        averages = [avg.value * correction for _, avg in zip(self.refs, self.m)]
        return self.refs.vars(), averages

    def replace_vars(self, f: Callable):
        """Wraps ``f`` so that it runs with the tracked variables set to their averages.

        Args:
            f: function to be called on the stored averages.

        Returns:
            A function that forwards its arguments to ``f`` while the tracked variables hold
            their moving averages, and puts the original values back afterwards, even if
            ``f`` raises.
        """

        def wrap(*args, **kwargs):
            tracked, averaged = self.refs_and_values()
            saved = tracked.tensors()
            tracked.assign(averaged)
            try:
                result = f(*args, **kwargs)
            finally:
                # Always restore the original values of the variables.
                tracked.assign(saved)
            return result

        return wrap

    def __repr__(self):
        return f'{class_name(self)}(momentum={self.momentum}, debias={self.debias}, eps={self.eps})'
class ExponentialMovingAverageModule(Module):
    """Module wrapper that evaluates another module using the moving average of its variables."""

    def __init__(self, module: Module, momentum: float = 0.999, debias: bool = False, eps: float = 1e-6):
        """Wraps ``module`` and creates the moving average state for its variables.

        Args:
            module: a module for which to compute the moving average.
            momentum: the decay factor for the moving average.
            debias: bool indicating whether to use initialization bias correction.
            eps: small adjustment to prevent division by zero.
        """
        self.__wrapped__ = module
        wrapped_vars = module.vars()
        self.ema = ExponentialMovingAverage(wrapped_vars, momentum=momentum, debias=debias, eps=eps)

    def __call__(self, *args, **kwargs):
        """Runs the wrapped module with its variables replaced by their moving averages.

        All positional and keyword arguments are forwarded to the wrapped module, and its
        result is returned unchanged.
        """
        averaged_module = self.ema.replace_vars(self.__wrapped__)
        return averaged_module(*args, **kwargs)

    def update_ema(self):
        """Performs one moving average update from the wrapped module's current variables."""
        self.ema()

    def __repr__(self):
        return f'{class_name(self)}(momentum={self.ema.momentum}, debias={self.ema.debias}, eps={self.ema.eps})'