# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = ['Jit', 'Module', 'ModuleList', 'ModuleWrapper', 'Parallel', 'Vectorize']

from types import MethodType
from typing import Optional, List, Union, Callable, Tuple

import jax
import jax.numpy as jn
from jax.interpreters.pxla import ShardedDeviceArray

from objax.typing import JaxArray
from objax.util import positional_args_names
from objax.variable import BaseState, BaseVar, RandomState, VarCollection

class Module:
    """A module is a container to associate variables and functions."""

    def vars(self, scope: str = '') -> VarCollection:
        """Collect all the variables (and their names) contained in the module and its submodules.

        Important: Variables and modules stored in Python structures such as dict or list are not collected. See
        ModuleList if you need such a feature.

        Args:
            scope: string to prefix to the variable names.

        Returns:
            A VarCollection of all the variables.
        """
        vc = VarCollection()
        scope += f'({self.__class__.__name__}).'
        for k, v in self.__dict__.items():
            if isinstance(v, BaseVar):
                vc[scope + k] = v
            elif isinstance(v, Module):
                vc.update(v.vars(scope=scope + k))
        return vc

    def __call__(self, *args, **kwargs):
        """Optional module __call__ method, typically a forward pass computation for standard primitives."""
        raise NotImplementedError
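
# Usage sketch (added commentary, kept as a comment so nothing runs at import time;
# assumes TrainVar, Objax's trainable variable type from objax.variable):
#
#     class Linear(Module):
#         def __init__(self, nin: int, nout: int):
#             self.w = TrainVar(jn.zeros((nin, nout)))
#
#     Linear(3, 4).vars()  # VarCollection with a single entry named '(Linear).w'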

class ModuleList(Module, list):
    """This is a replacement for Python's list that provides a vars() method to return all the variables that it
    contains, including the ones contained in the modules and sub-modules in it."""

    def vars(self, scope: str = '') -> VarCollection:
        """Collect all the variables (and their names) contained in the list and its submodules.

        Args:
            scope: string to prefix to the variable names.

        Returns:
            A VarCollection of all the variables.
        """
        vc = VarCollection()
        scope += f'({self.__class__.__name__})'
        for p, v in enumerate(self):
            if isinstance(v, BaseVar):
                vc[f'{scope}[{p}]'] = v
            elif isinstance(v, Module):
                vc.update(v.vars(scope=f'{scope}[{p}]'))
        return vc
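
# Usage sketch (added commentary, kept as a comment): unlike a plain Python list,
# a ModuleList exposes the variables of its entries through vars(), indexed by position.
#
#     layers = ModuleList([Linear(3, 4), Linear(4, 2)])  # Linear as sketched above
#     layers.vars()  # entries like '(ModuleList)[0](Linear).w', '(ModuleList)[1](Linear).w'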

class ModuleWrapper(Module):
    """Module whose sole purpose is to store a collectable VarCollection. This class is exclusively
    used internally by Objax, for example in Jit, Vectorize and Parallel."""

    def __init__(self, vc: VarCollection):
        super().__init__()
        self.vc = VarCollection((f'({self.__class__.__name__}){k}', v) for k, v in vc.items())

    def vars(self, scope: str = '') -> VarCollection:
        """Collect all the variables (and their names) contained in the VarCollection.

        Args:
            scope: string to prefix to the variable names.

        Returns:
            A VarCollection of all the variables.
        """
        return VarCollection((scope + k, v) for k, v in self.vc.items())

class Jit(ModuleWrapper):
    """JIT (Just-In-Time) module takes a function or a module and compiles it for faster execution."""

    def __init__(self,
                 f: Union[Module, Callable],
                 vc: Optional[VarCollection] = None,
                 static_argnums: Optional[Tuple[int, ...]] = None):
        """Jit constructor.

        Args:
            f: the function or the module to compile.
            vc: the VarCollection of variables used by the function or module. This argument is required for
                functions.
            static_argnums: tuple of indexes of f's input arguments to treat as static (constants).
                A new graph is compiled for each different combination of values for such inputs.
        """
        if vc is None:
            if not isinstance(f, Module):
                raise ValueError('You must supply the VarCollection used by the function f.')
            vc = f.vars()
        super().__init__(vc)
        self._call = self.jit_local_method(f, sorted(static_argnums or ()))

    def jit_local_method(self, f, static_argnums):
        """Compiles a function or module and returns a method bound to this instance.

        Args:
            f: function or module to compile.
            static_argnums: indexes of the arguments to be treated as static.

        Returns:
            A method containing the compiled version of f.
        """

        def jit(tensor_list: List[JaxArray], *args):
            original_values = self.vc.tensors()
            self.vc.assign(tensor_list)
            output = f(*args), self.vc.tensors(BaseState)
            self.vc.assign(original_values)
            return output

        # Shift static_argnums by 1 since jit's first argument is tensor_list.
        jitf = jax.jit(jit, static_argnums=tuple(x + 1 for x in static_argnums))

        def local_method(self, *args):
            output, changes = jitf(self.vc.tensors(), *args)
            self.vc.subset(BaseState).assign(changes)
            return output

        return MethodType(local_method, self)

    def __call__(self, *args):
        """Call the compiled version of the function or module."""
        return self._call(*args)
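
# Usage sketch (added commentary, kept as a comment): a Module can be jitted
# directly; a plain function additionally needs the VarCollection it touches.
#
#     net = Linear(3, 4)                       # Linear as sketched above
#     jit_net = Jit(net)                       # vc defaults to net.vars()
#
#     def weight_sum():
#         return net.w.value.sum()
#
#     jit_ws = Jit(weight_sum, vc=net.vars())  # vc is required for functions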

class Parallel(ModuleWrapper):
    """Parallel module takes a function or a module and compiles it for running on multiple devices in parallel."""

    def __init__(self,
                 f: Union[Module, Callable],
                 vc: Optional[VarCollection] = None,
                 reduce: Callable[[JaxArray], JaxArray] = jn.concatenate,
                 axis_name: str = 'device',
                 static_argnums: Optional[Tuple[int, ...]] = None):
        """Parallel constructor.

        Args:
            f: the function or the module to compile for parallelism.
            vc: the VarCollection of variables used by the function or module. This argument is required for
                functions.
            reduce: the function used to reduce the outputs from many devices to a single device value.
            axis_name: what name to give to the device dimension, used in conjunction with objax.functional.parallel.
            static_argnums: tuple of indexes of f's input arguments to treat as static (constants).
                A new graph is compiled for each different combination of values for such inputs.
        """
        if vc is None:
            if not isinstance(f, Module):
                raise ValueError('You must supply the VarCollection used by the function f.')
            vc = f.vars()
        super().__init__(vc)
        static_argnums = sorted(static_argnums or ())
        self.reduce = reduce
        self.ndevices = jax.device_count()
        self.static_argnums = frozenset(static_argnums)

        def pmap(tensor_list: List[ShardedDeviceArray], random_list: List[ShardedDeviceArray], *args):
            original_values = self.vc.tensors()
            self.vc.assign(tensor_list)
            self.vc.subset(RandomState).assign(random_list)
            output = f(*args), self.vc.tensors(BaseState)
            self.vc.assign(original_values)
            return output

        # Shift static_argnums by 2 since pmap's first two arguments are tensor_list and random_list.
        self._call = jax.pmap(pmap, axis_name=axis_name, static_broadcasted_argnums=[x + 2 for x in static_argnums])

    def device_reshape(self, x: JaxArray) -> JaxArray:
        """Utility to reshape an input array in order to broadcast to multiple devices."""
        return x.reshape((self.ndevices, x.shape[0] // self.ndevices) + x.shape[1:])

    def __call__(self, *args):
        """Call the compiled function or module on multiple devices in parallel.

        Important: Make sure you call this function within the scope of a VarCollection.replicate() statement.
        """
        args = [x if i in self.static_argnums else self.device_reshape(x) for i, x in enumerate(args)]
        output, changes = self._call(self.vc.tensors(), self.vc.subset(RandomState).tensors(), *args)
        self.vc.subset(BaseState).assign(changes)
        return jax.tree_map(self.reduce, output)
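
# Usage sketch (added commentary, kept as a comment): weights must be replicated
# across devices before calling, as the __call__ docstring above requires.
#
#     par_net = Parallel(net)               # net as sketched above
#     with net.vars().replicate():
#         y = par_net(x)                    # x.shape[0] must be divisible by
#                                           # jax.device_count()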

class Vectorize(ModuleWrapper):
    """Vectorize module takes a function or a module and compiles it for running in parallel on a single device."""

    def __init__(self,
                 f: Union[Module, Callable],
                 vc: Optional[VarCollection] = None,
                 batch_axis: Tuple[Optional[int], ...] = (0,)):
        """Vectorize constructor.

        Args:
            f: the function or the module to compile for vectorization.
            vc: the VarCollection of variables used by the function or module. This argument is required for
                functions.
            batch_axis: tuple of int or None for each of f's input arguments: the axis to use as batch during
                vectorization. Use None to automatically broadcast.
        """
        if vc is None:
            if not isinstance(f, Module):
                raise ValueError('You must supply the VarCollection used by the function f.')
            vc = f.vars()
        super().__init__(vc)
        fargs = positional_args_names(f)
        assert len(batch_axis) >= len(fargs), f'The batch axis must be specified for all of {f} arguments {fargs}'
        self.batch_axis = batch_axis
        self.batch_axis_argnums = [(x, v) for x, v in enumerate(batch_axis) if v is not None]
        assert self.batch_axis_argnums, f'No arguments to function {f} are vectorizable'

        def vmap(tensor_list: List[JaxArray], random_list: List[JaxArray], *args):
            original_values = self.vc.tensors()
            self.vc.assign(tensor_list)
            self.vc.subset(RandomState).assign(random_list)
            output = f(*args), self.vc.tensors(BaseState)
            self.vc.assign(original_values)
            return output

        # Variables are broadcast (None axis); random states are split along axis 0 so
        # each vectorized call gets its own random state.
        self._call = jax.vmap(vmap, (None, 0) + batch_axis)

    def __call__(self, *args):
        """Call the vectorized version of the function or module."""
        assert len(args) == len(self.batch_axis), f'Number of arguments passed {len(args)} must match ' \
                                                  f'length of batch_axis {len(self.batch_axis)}'
        nsplits = args[self.batch_axis_argnums[0][0]].shape[self.batch_axis_argnums[0][1]]
        output, changes = self._call(self.vc.tensors(), [v.split(nsplits) for v in self.vc.subset(RandomState)], *args)
        for v, u in zip(self.vc.subset(BaseState), changes):
            v.reduce(u)
        return output
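
# Usage sketch (added commentary, kept as a comment): vectorizing a per-example
# function over a batch; batch_axis gives the mapped axis per argument, with None
# meaning broadcast. A stateless function takes an empty VarCollection.
#
#     def sq_dist(x, center):               # x: (d,), center: (d,)
#         return ((x - center) ** 2).sum()
#
#     vec = Vectorize(sq_dist, vc=VarCollection(), batch_axis=(0, None))
#     vec(xs, center)                       # xs: (n, d) -> (n,) distances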