# Source code for objax.gradient

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = ['GradValues']

from typing import List, Optional, Callable, Tuple

import jax

from objax.module import ModuleWrapper
from objax.typing import JaxArray
from objax.variable import BaseState, TrainVar, VarCollection


class GradValues(ModuleWrapper):
    """The GradValues module is used to compute the gradients of a function."""

    def __init__(self, f: Callable, variables: Optional[VarCollection],
                 input_argnums: Optional[Tuple[int, ...]] = None):
        """Constructs an instance to compute the gradient of f w.r.t. variables.

        Args:
            f: the function for which to compute gradients.
            variables: the variables for which to compute gradients.
            input_argnums: input indexes, if any, on which to compute gradients.
        """
        variables = variables or VarCollection()
        super().__init__(variables)
        self.input_argnums = input_argnums or tuple()

        def f_func(inputs_and_train_tensors: List[JaxArray], state_tensors: List[JaxArray], list_args: List):
            # Function differentiated by jax.grad (w.r.t. its first argument only). The first argument
            # is the selected inputs followed by the TrainVar tensors; state tensors and the remaining
            # call arguments are passed through without being differentiated.
            num_inputs = len(self.input_argnums)
            inputs = inputs_and_train_tensors[:num_inputs]
            train_tensors = inputs_and_train_tensors[num_inputs:]
            # Snapshot the current variable values so they can be restored after f has run.
            original_vc = self.vc.tensors()
            try:
                self.vc.subset(TrainVar).assign(train_tensors)
                self.vc.subset(BaseState).assign(state_tensors)
                # Work on a copy so the caller's argument list is never mutated.
                list_args = list(list_args)
                for i, arg in zip(self.input_argnums, inputs):
                    list_args[i] = arg
                outputs = f(*list_args)
                if not isinstance(outputs, (list, tuple)):
                    outputs = [outputs]
                # The first output is the value being differentiated. The auxiliary data carries all
                # outputs plus the state values as left by f. They are read through the same
                # self.vc.subset(BaseState) collection that __call__ assigns them back to, so the
                # ordering matches on both sides.
                return outputs[0], (outputs, self.vc.subset(BaseState).tensors())
            finally:
                # Always restore the snapshot, even if f raises. Otherwise the module's variables
                # would keep the traced values assigned above.
                self.vc.assign(original_vc)

        self.f = jax.grad(f_func, has_aux=True)

    def __call__(self, *args):
        """Returns the computed gradients for the first value returned by `f` and the values returned by `f`.

        Returns:
            A tuple (gradients, values of f), where gradients is a list containing the input gradients,
            if any, followed by the variable gradients.
        """
        inputs = [args[i] for i in self.input_argnums]
        g, (outputs, changes) = self.f(inputs + self.vc.subset(TrainVar).tensors(),
                                       self.vc.subset(BaseState).tensors(),
                                       list(args))
        # Write back the BaseState values as they were after f ran, so state changes made by f persist.
        self.vc.subset(BaseState).assign(changes)
        return g, outputs