# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ['ForceArgs', 'Function', 'Jit', 'Module', 'ModuleList', 'Parallel', 'Vectorize']
from collections import namedtuple
from typing import Optional, List, Union, Callable, Tuple
import jax
import jax.numpy as jn
from jax.interpreters.pxla import ShardedDeviceArray
from objax.typing import JaxArray
from objax.util import override_args_kwargs, positional_args_names
from objax.variable import BaseState, BaseVar, RandomState, VarCollection
class Module:
    """Base container that ties variables to the functions operating on them."""

    def vars(self, scope: str = '') -> VarCollection:
        """Gather the variables (with their names) held by this module and, recursively, by its submodules.

        Important: only the direct attributes of the module are inspected. Variables and modules stored inside
        Python structures such as dict or list are not collected. See ModuleList if you need such a feature.

        Args:
            scope: string to prefix to the variable names.

        Returns:
            A VarCollection of all the variables.
        """
        # Every name produced below is prefixed with "<scope>(<ClassName>)."
        prefix = scope + f'({self.__class__.__name__}).'
        collected = VarCollection()
        for name, member in self.__dict__.items():
            if isinstance(member, BaseVar):
                collected[prefix + name] = member
            elif isinstance(member, Module):
                # Submodules contribute their own variables under a nested scope.
                collected.update(member.vars(scope=prefix + name))
        return collected

    def __call__(self, *args, **kwargs):
        """Optional module __call__ method, typically a forward pass computation for standard primitives.

        Raises:
            NotImplementedError: always, the base class defines no computation.
        """
        raise NotImplementedError
class ForceArgs(Module):
    """Forces override of arguments of given module."""

    ANY = namedtuple('ANY', ())
    """Token used in `ForceArgs.undo` to indicate undo of all values of specific argument."""

    @staticmethod
    def undo(module: Module, **kwargs):
        """Undo ForceArgs on each submodule of the module. Modifications are done in-place.

        ForceArgs wrappers found as attributes of a module or as items of a ModuleList are replaced by the module
        they wrap once they have no forced argument left.

        Args:
            module: module for which to undo ForceArgs.
            **kwargs: dictionary of argument overrides to undo.
                `name=val` remove override for value `val` of argument `name`.
                `name=ForceArgs.ANY` remove all overrides of argument `name`.
                If `**kwargs` is empty then all overrides will be undone.
        """
        if isinstance(module, ForceArgs):
            if not kwargs:
                module.forced_kwargs = {}
            else:
                def is_undone(name, value) -> bool:
                    """Tell whether the override `name=value` is selected for removal by `kwargs`."""
                    if name not in kwargs:
                        return False
                    target = kwargs[name]
                    # Identity checks come first so that the ANY token (or the very same object) never reaches
                    # `==`, which may be ambiguous or unsupported for array-valued overrides.
                    if target is ForceArgs.ANY or target is value:
                        return True
                    return bool(target == value)

                module.forced_kwargs = {k: v for k, v in module.forced_kwargs.items() if not is_undone(k, v)}
            # The wrapped module may itself contain ForceArgs wrappers.
            ForceArgs.undo(module.__wrapped__, **kwargs)
        elif isinstance(module, ModuleList):
            for idx, v in enumerate(module):
                if isinstance(v, Module):
                    ForceArgs.undo(v, **kwargs)
                    if isinstance(v, ForceArgs) and not v.forced_kwargs:
                        # Nothing is forced anymore: drop the wrapper and keep the wrapped module in the list.
                        module[idx] = v.__wrapped__
        else:
            for k, v in module.__dict__.items():
                if isinstance(v, Module):
                    ForceArgs.undo(v, **kwargs)
                    if isinstance(v, ForceArgs) and not v.forced_kwargs:
                        # Rebinding an existing attribute keeps the set of keys of __dict__ unchanged while
                        # it is being iterated.
                        setattr(module, k, v.__wrapped__)

    def __init__(self, module: Module, **kwargs):
        """Initializes ForceArgs by wrapping another module.

        Args:
            module: module which argument will be overridden.
            kwargs: values of keyword arguments which will be forced to use.
        """
        self.__wrapped__ = module
        self.forced_kwargs = kwargs

    def vars(self, scope: str = '') -> VarCollection:
        """Returns the VarCollection of the wrapped module.

        Args:
            scope: string to prefix to the variable names.

        Returns:
            A VarCollection of all the variables of wrapped module.
        """
        return self.__wrapped__.vars(scope=scope)

    def __call__(self, *args, **kwargs):
        """Calls wrapped module using forced args to override wrapped module arguments."""
        args, kwargs = override_args_kwargs(self.__wrapped__, args, kwargs, self.forced_kwargs)
        return self.__wrapped__(*args, **kwargs)
class ModuleList(Module, list):
    """Drop-in replacement for Python's list whose vars() method returns all the variables it contains,
    including the ones contained in the modules and sub-modules stored in it."""

    def vars(self, scope: str = '') -> VarCollection:
        """Gather the variables (with their names) stored in the list and in the modules it holds.

        Args:
            scope: string to prefix to the variable names.

        Returns:
            A VarCollection of all the variables.
        """
        prefix = scope + f'({self.__class__.__name__})'
        found = VarCollection()
        for position, item in enumerate(self):
            # Each entry is named after its index in the list: "<scope>(<ClassName>)[<index>]".
            item_scope = f'{prefix}[{position}]'
            if isinstance(item, BaseVar):
                found[item_scope] = item
            elif isinstance(item, Module):
                found.update(item.vars(scope=item_scope))
        return found

    def __getitem__(self, key: Union[int, slice]):
        """Return the selected item, or a new ModuleList when `key` is a slice."""
        picked = list.__getitem__(self, key)
        # A plain list slice would lose vars(), so slices are wrapped back into a ModuleList.
        return ModuleList(picked) if isinstance(key, slice) else picked
class Function(Module):
    """Wrap a function as a Module that remembers the variables the function uses."""

    def __init__(self, f: Callable, vc: VarCollection):
        """Function constructor.

        Args:
            f: the function or the module to represent.
            vc: the VarCollection of variables used by the function.
        """
        if not hasattr(f, '__name__'):
            # Callables without a name keep the variable names as they were given.
            self.vc = VarCollection(vc)
        else:
            # Named callables get their variable names prefixed with "{<function name>}.".
            self.vc = VarCollection((f'{{{f.__name__}}}.{name}', var) for name, var in vc.items())
        self.__wrapped__ = f

    def __call__(self, *args, **kwargs):
        """Call the wrapped function, forwarding every argument."""
        return self.__wrapped__(*args, **kwargs)

    def vars(self, scope: str = '') -> VarCollection:
        """Return a new VarCollection of the variables used by the function.

        Args:
            scope: string to prefix to the variable names.

        Returns:
            A VarCollection of the variables, with `scope` prepended to their names when it is not empty.
        """
        if not scope:
            return VarCollection(self.vc)
        return VarCollection((scope + name, var) for name, var in self.vc.items())

    @staticmethod
    def with_vars(vc: VarCollection):
        """Decorator factory: `@Function.with_vars(vc)` turns the decorated function into a Function using `vc`."""

        def from_function(f: Callable):
            return Function(f, vc)

        return from_function
class Jit(Module):
    """JIT (Just-In-Time) module: compiles a function or a module for faster execution."""

    def __init__(self,
                 f: Union[Module, Callable],
                 vc: Optional[VarCollection] = None,
                 static_argnums: Optional[Tuple[int, ...]] = None):
        """Jit constructor.

        Args:
            f: the function or the module to compile.
            vc: the VarCollection of variables used by the function or module. This argument is required for
                functions.
            static_argnums: tuple of indexes of f's input arguments to treat as static (constants).
                A new graph is compiled for each different combination of values for such inputs.

        Raises:
            ValueError: if f is not a Module and no VarCollection is supplied.
        """
        if not isinstance(f, Module):
            if vc is None:
                raise ValueError('You must supply the VarCollection used by the function f.')
            f = Function(f, vc)

        def jit(tensor_list: List[JaxArray], kwargs, *args):
            # Run f with the given tensors loaded into the variables, then put the previous values back.
            saved_values = self.vc.tensors()
            try:
                self.vc.assign(tensor_list)
                result = f(*args, **kwargs)
                return result, self.vc.tensors(BaseState)
            finally:
                self.vc.assign(saved_values)

        self.vc = vc or f.vars()
        # The compiled function takes (tensor_list, kwargs) before the user arguments, so the user supplied
        # indexes are shifted by two.
        shifted_static = tuple(index + 2 for index in sorted(static_argnums or ()))
        self._call = jax.jit(jit, static_argnums=shifted_static)
        self.__wrapped__ = f

    def __call__(self, *args, **kwargs):
        """Call the compiled version of the function or module."""
        current_tensors = self.vc.tensors()
        output, changes = self._call(current_tensors, kwargs, *args)
        # Write back the BaseState values returned by the compiled call.
        state_vars = self.vc.subset(BaseState)
        state_vars.assign(changes)
        return output
class Parallel(Module):
    """Parallel module takes a function or a module and compiles it for running on multiple devices in parallel."""

    def __init__(self,
                 f: Union[Module, Callable],
                 vc: Optional[VarCollection] = None,
                 reduce: Callable[[JaxArray], JaxArray] = jn.concatenate,
                 axis_name: str = 'device',
                 static_argnums: Optional[Tuple[int, ...]] = None):
        """Parallel constructor.

        Args:
            f: the function or the module to compile for parallelism.
            vc: the VarCollection of variables used by the function or module. This argument is required for
                functions.
            reduce: the function used reduce the outputs from many devices to a single device value.
            axis_name: what name to give to the device dimension, used in conjunction with
                objax.functional.parallel.
            static_argnums: tuple of indexes of f's input arguments to treat as static (constants).
                A new graph is compiled for each different combination of values for such inputs.

        Raises:
            ValueError: if f is not a Module and no VarCollection is supplied.
        """
        if not isinstance(f, Module):
            if vc is None:
                raise ValueError('You must supply the VarCollection used by the function f.')
            f = Function(f, vc)

        def pmap(tensor_list: List[ShardedDeviceArray], random_list: List[ShardedDeviceArray], *args):
            # Run f with the given tensors (and random states) loaded into the variables; the previous values
            # are restored afterwards, whether or not f raised.
            original_values = self.vc.tensors()
            try:
                self.vc.assign(tensor_list)
                self.vc.subset(RandomState).assign(random_list)
                return f(*args), self.vc.tensors(BaseState)
            finally:
                self.vc.assign(original_values)

        static_argnums = sorted(static_argnums or ())
        self.ndevices = jax.local_device_count()
        self.reduce = reduce
        self.static_argnums = frozenset(static_argnums)
        self.vc = vc or f.vars()
        # The compiled function takes (tensor_list, random_list) before the user arguments, so the user
        # supplied indexes are shifted by two.
        self._call = jax.pmap(pmap, axis_name=axis_name, static_broadcasted_argnums=[x + 2 for x in static_argnums])
        self.__wrapped__ = f

    def device_reshape(self, x: JaxArray) -> JaxArray:
        """Utility to reshape an input array in order to broadcast to multiple devices.

        Args:
            x: array whose leading dimension is split among the devices.

        Returns:
            The array reshaped to (ndevices, x.shape[0] // ndevices) followed by the remaining dimensions of x.
        """
        assert x.shape[0] % self.ndevices == 0, f'Must be able to equally divide batch {x.shape} among ' \
                                                f'{self.ndevices} devices, but does not go equally.'
        return x.reshape((self.ndevices, x.shape[0] // self.ndevices) + x.shape[1:])

    def __call__(self, *args):
        """Call the compiled function or module on multiple devices in parallel.

        Important: Make sure you call this function within the scope of VarCollection.replicate() statement.

        Args:
            *args: inputs of the wrapped function; every input that is not listed in static_argnums goes
                through device_reshape.

        Returns:
            The output of the wrapped function with `reduce` applied to each of its leaves.
        """
        args = [x if i in self.static_argnums else self.device_reshape(x) for i, x in enumerate(args)]
        output, changes = self._call(self.vc.tensors(), self.vc.subset(RandomState).tensors(), *args)
        # Write back the BaseState values returned by the compiled call.
        self.vc.subset(BaseState).assign(changes)
        # jax.tree_util.tree_map is used rather than the top-level jax.tree_map alias, which is deprecated in
        # newer JAX releases.
        return jax.tree_util.tree_map(self.reduce, output)
class Vectorize(Module):
    """Vectorize module: compiles a function or a module for running in parallel on a single device."""

    def __init__(self,
                 f: Union[Module, Callable],
                 vc: Optional[VarCollection] = None,
                 batch_axis: Tuple[Optional[int], ...] = (0,)):
        """Vectorize constructor.

        Args:
            f: the function or the module to compile for vectorization.
            vc: the VarCollection of variables used by the function or module. This argument is required for
                functions.
            batch_axis: tuple of int or None for each of f's input arguments: the axis to use as batch during
                vectorization. Use None to automatically broadcast.

        Raises:
            ValueError: if f is not a Module and no VarCollection is supplied.
        """
        if not isinstance(f, Module):
            if vc is None:
                raise ValueError('You must supply the VarCollection used by the function f.')
            f = Function(f, vc)

        def vmap(tensor_list: List[JaxArray], random_list: List[JaxArray], *args):
            # Run f with the given tensors (and random states) loaded into the variables, then put the
            # previous values back.
            saved_values = self.vc.tensors()
            try:
                self.vc.assign(tensor_list)
                self.vc.subset(RandomState).assign(random_list)
                result = f(*args)
                return result, self.vc.tensors(BaseState)
            finally:
                self.vc.assign(saved_values)

        fargs = positional_args_names(f)
        assert len(batch_axis) >= len(fargs), f'The batched argument must be specified for all of {f} arguments {fargs}'
        self.batch_axis = batch_axis
        # Keep (argument index, axis) pairs only for the arguments that are actually batched.
        self.batch_axis_argnums = [(argnum, axis) for argnum, axis in enumerate(batch_axis) if axis is not None]
        assert self.batch_axis_argnums, f'No arguments to function {f} are vectorizable'
        self.vc = vc or f.vars()
        # tensor_list is not mapped (None) while random_list is mapped over its axis 0.
        in_axes = (None, 0) + batch_axis
        self._call = jax.vmap(vmap, in_axes)
        self.__wrapped__ = f

    def __call__(self, *args):
        """Call the vectorized version of the function or module."""
        assert len(args) == len(self.batch_axis), f'Number of arguments passed {len(args)} must match ' \
                                                  f'batched {len(self.batch_axis)}'
        # The batch size is read from the first batched argument along its batch axis.
        first_argnum, first_axis = self.batch_axis_argnums[0]
        nsplits = args[first_argnum].shape[first_axis]
        random_splits = [rng.split(nsplits) for rng in self.vc.subset(RandomState)]
        output, changes = self._call(self.vc.tensors(), random_splits, *args)
        # Each BaseState variable folds its batched values back through its own reduce method.
        for state_var, batched_value in zip(self.vc.subset(BaseState), changes):
            state_var.reduce(batched_value)
        return output