# Source code for objax.util.util

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Public names of this module: only these are exported by a star-import.
__all__ = ['EasyDict', 'args_indexes', 'class_name', 'dummy_context_mgr', 'get_local_devices', 'ilog2', 'local_kwargs',
           'map_to_device', 'multi_host_barrier', 'override_args_kwargs', 'positional_args_names', 'Renamer',
           're_sign', 'repr_function', 'to_interpolate', 'to_padding', 'to_tuple']

import contextlib
import functools
import inspect
import itertools
import re
from numbers import Number
from typing import Callable, List, Union, Tuple, Iterable, Dict, Pattern, Optional, Sequence

import jax
import jax.numpy as jn
import numpy as np

from objax.constants import ConvPadding, Interpolate
from objax.typing import ConvPaddingInt

# Maps the module in which a class is actually defined (key) to the shorter
# module path that should be reported for it (value). Consulted by class_name()
# to produce a simplified fully qualified class name.
CLASS_MODULES = {
    'objax.dpsgd.gradient': 'objax.dpsgd',
    'objax.gradient': 'objax',
    'objax.module': 'objax',
    'objax.nn.layers': 'objax.nn',
    'objax.random.random': 'objax.random',
    'objax.variable': 'objax',
}


class EasyDict(dict):
    """Dictionary whose entries can also be read and written as attributes."""

    def __init__(self, *args, **kwargs):
        """Build the dictionary and expose its entries as attributes.

        Args:
            *args: positional arguments forwarded unchanged to ``dict``.
            **kwargs: keyword arguments forwarded unchanged to ``dict``.
        """
        super(EasyDict, self).__init__(*args, **kwargs)
        # Using the mapping itself as the instance namespace makes ``d.key``
        # and ``d['key']`` refer to the very same storage.
        self.__dict__ = self
class Renamer:
    """Helper class for renaming string contents.

    The renaming rules can be given in three forms:

    * a callable taking a string and returning the renamed string;
    * a dict mapping literal substrings to their replacements;
    * a sequence of ``(compiled_regex, replacement)`` pairs applied in order.
    """

    def __init__(self,
                 rules: Union[Dict[str, str], Sequence[Tuple[Pattern[str], str]], Callable[[str], str]],
                 chain: Optional['Renamer'] = None):
        """Create a renamer object.

        Args:
            rules: the replacement mapping.
            chain: optionally, another renamer to call after this one completes.
        """
        self.chain = chain
        if callable(rules):
            self.subfn = rules
        elif isinstance(rules, dict):
            if rules:
                # Keys are escaped so they are matched literally; the pattern is
                # an alternation of all keys and each match is looked up in rules.
                regex = re.compile('(%s)' % '|'.join(map(re.escape, rules.keys())))
                self.subfn = functools.partial(regex.sub, lambda m: rules[m.group(0)])
            else:
                # An empty alternation would compile to '()', which matches the
                # empty string at every position and then fails the dictionary
                # lookup with KeyError(''). No rules simply means no renaming.
                self.subfn = lambda s: s
        else:
            def sequence_rename(x):
                # Apply every (regex, replacement) pair in the given order.
                for regex, repl in rules:
                    x = regex.sub(repl, x)
                return x

            self.subfn = sequence_rename

    def __call__(self, s: str) -> str:
        """Rename input string `s` using the rules provided to the constructor.

        Args:
            s: the string to rename.

        Returns:
            The renamed string, further processed by the chained renamer if one was given.
        """
        news = self.subfn(s)
        return self.chain(news) if self.chain else news
def args_indexes(f: Callable, args: Iterable[str]) -> Iterable[int]:
    """Yield the positional index of each requested argument name of a function.

    Args:
        f: the function whose positional arguments are inspected.
        args: names of the arguments to look up.

    Yields:
        The index of each name among the positional arguments of ``f``, in the order
        the names are given.

    Raises:
        ValueError: when a name is not a positional argument of ``f``. As this is a
            generator, the error is raised when the offending name is reached.
    """
    positions = {arg_name: pos for pos, arg_name in enumerate(positional_args_names(f))}
    for name in args:
        if name not in positions:
            raise ValueError(f'Function {f} does not have argument of name {name}', (f, name))
        yield positions[name]
def class_name(x) -> str:
    """Return the simplified fully qualified class name of an instance.

    The defining module is shortened through ``CLASS_MODULES`` when it has an entry
    there, and any module starting with ``objax.optimizer`` is reported as
    ``objax.optimizer``.

    Args:
        x: the instance whose class is named.

    Returns:
        A string of the form ``<module>.<ClassName>``.
    """
    cls = x.__class__
    module = cls.__module__
    module = CLASS_MODULES.get(module, module)
    if module.startswith('objax.optimizer'):
        module = 'objax.optimizer'
    return '.'.join((module, cls.__name__))
@contextlib.contextmanager
def dummy_context_mgr():
    """Context manager that does nothing; the ``with`` target is bound to ``None``."""
    yield
# Cache for get_local_devices(); stays None until the first call.
_local_devices = None


def get_local_devices():
    """Return the list of local devices in the same order which is used by jax.pmap.

    The list is computed on the first call and cached in the module-level
    ``_local_devices``; later calls return the cached list.
    """
    global _local_devices
    if _local_devices is not None:
        return _local_devices

    # Lazy initialization to prevent any undesirable behavior before devices
    # are initialized: one row per local device is pushed through the identity
    # pmap and each resulting buffer is asked for its device.
    # NOTE(review): relies on ``device_buffers`` / ``device()`` being available
    # on the pmap output - confirm against the JAX version in use.
    probe = jn.zeros((jax.local_device_count(), 1), dtype=jn.float32)
    sharded_probe = map_to_device(probe)
    _local_devices = [buf.device() for buf in sharded_probe.device_buffers]
    return _local_devices
def ilog2(x: float):
    """Integer log2: the base-2 logarithm of ``x`` rounded up to an ``int``.

    Args:
        x: the value to take the logarithm of.

    Returns:
        ``ceil(log2(x))`` as a Python ``int``.
    """
    exponent = np.log2(x)
    rounded_up = np.ceil(exponent)
    return int(rounded_up)
def local_kwargs(kwargs: dict, f: Callable) -> dict:
    """Return the kwargs from dict that are inputs to function f.

    Args:
        kwargs: candidate keyword arguments.
        f: callable whose signature decides which entries are kept.

    Returns:
        ``kwargs`` itself (not a copy) when ``f`` accepts ``**kwargs``; otherwise a new
        dict holding only the entries whose key is a parameter name of ``f``. The
        result is empty when ``f`` takes no parameters.
    """
    params = inspect.signature(f).parameters
    # A **kwargs parameter accepts every keyword. Scanning with any() instead of
    # peeking at the last parameter also works for callables without parameters,
    # where there is no last parameter to look at.
    if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return kwargs
    # Iterate over whichever of the two collections is smaller.
    if len(kwargs) < len(params):
        return {k: v for k, v in kwargs.items() if k in params}
    return {k: kwargs[k] for k in params.keys() if k in kwargs}


# Identity function wrapped in jax.pmap with axis name 'device'.
map_to_device: Callable[[List[jn.ndarray]], List[jn.ndarray]] = jax.pmap(lambda x: x, axis_name='device')


def multi_host_barrier():
    """Barrier op for multi-host setup.

    Draws a single scalar normal sample from a fixed PRNG key and blocks until
    that value is ready.
    """
    jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()


def override_args_kwargs(f: Callable, args: Iterable, kwargs: dict, new_kwargs: dict) -> Tuple[List, dict]:
    """Overrides positional and keyword arguments according to signature of the function using new keyword arguments.

    Args:
        f: callable, which signature is used to determine how to override arguments.
        args: original values of positional arguments.
        kwargs: original values of keyword arguments.
        new_kwargs: new keyword arguments, their values will override original arguments.

    Return:
        args: updated list of positional arguments.
        kwargs: updated dictionary of keyword arguments.
    """
    args = list(args)
    # Work on a copy: consumed entries are popped and the caller's dict must stay intact.
    new_kwargs = new_kwargs.copy()
    p = inspect.signature(f).parameters
    # Only the parameters actually filled by a positional argument are candidates
    # for an in-place override.
    for idx, (k, v) in enumerate(itertools.islice(p.items(), len(args))):
        if v.kind not in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD):
            break
        if k in new_kwargs:
            args[idx] = new_kwargs.pop(k)
    # Whatever was not consumed positionally overrides/extends the keyword arguments.
    return args, {**kwargs, **new_kwargs}
def positional_args_names(f: Callable) -> List[str]:
    """Return the names of the positional arguments of a function, in declaration order.

    Args:
        f: the callable to inspect.

    Returns:
        Names of the positional-only and positional-or-keyword parameters of ``f``.
    """
    positional_kinds = (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
    names = []
    for param in inspect.signature(f).parameters.values():
        if param.kind in positional_kinds:
            names.append(param.name)
    return names
def to_interpolate(interpolate: Union[Interpolate, str]) -> str:
    """Expand to a string method for interpolation.

    Args:
        interpolate: an ``Interpolate`` member, or the case-insensitive name of one.

    Returns:
        The ``value`` of the matching ``Interpolate`` member.

    Raises:
        AssertionError: when ``interpolate`` is neither a string nor an ``Interpolate``.
    """
    if isinstance(interpolate, Interpolate):
        return interpolate.value
    if isinstance(interpolate, str):
        return Interpolate[interpolate.upper()].value
    # Raised explicitly rather than through an ``assert`` statement so the check is
    # not stripped under ``python -O`` (which would silently return None instead).
    raise AssertionError(f'Argument "{interpolate}" must be a string or Interpolate')


def re_sign(f: Callable) -> Callable:
    """Decorator to replace the signature of an operation with the one from f.

    Args:
        f: callable whose signature is copied.

    Returns:
        A decorator that sets ``__signature__`` on the decorated object and returns
        that same object.
    """

    def wrap(op):
        op.__signature__ = inspect.signature(f)
        return op

    return wrap


def repr_function(f: Callable) -> str:
    """Human readable function representation.

    Args:
        f: the callable to describe.

    Returns:
        The name of the callable, followed by ``(*, name=default, ...)`` when its
        signature has parameters with default values. For an object without a
        ``__name__`` (after unwrapping ``.func`` attributes) the class name is returned.
    """
    signature = inspect.signature(f)
    args = [f'{k}={v.default}' for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty]
    args = ', '.join(args)
    # Peel off wrappers that expose the wrapped callable through ``.func``.
    while not hasattr(f, '__name__'):
        if not hasattr(f, 'func'):
            break
        f = f.func  # Handle functools.partial
    if not hasattr(f, '__name__') and hasattr(f, '__class__'):
        return f.__class__.__name__
    if args:
        return f'{f.__name__}(*, {args})'
    return f.__name__


def to_padding(padding: Union[ConvPadding, str, ConvPaddingInt], ndim: int) \
        -> Union[str, Tuple[Tuple[int, int], ...]]:
    """Expand to a string or a ndim-dimensional tuple of pairs usable for padding.

    Args:
        padding: a ``ConvPadding`` member, the case-insensitive name of one, a single
            int, a pair of ints, or a sequence of pairs.
        ndim: number of dimensions the padding is expanded to.

    Returns:
        The ``value`` of the ``ConvPadding`` member for enum/string inputs; otherwise a
        tuple with one pair per dimension. A sequence that is not a plain pair of
        ints is returned converted to a tuple, without using ``ndim``.
    """
    if isinstance(padding, ConvPadding):
        return padding.value
    if isinstance(padding, str):
        return ConvPadding[padding.upper()].value
    if isinstance(padding, int):
        # Same amount of padding on both sides of every dimension.
        return tuple([(padding, padding)] * ndim)
    if isinstance(padding, tuple) and list(map(type, padding)) == [int, int]:
        # A single pair of ints is repeated for every dimension.
        return tuple([padding] * ndim)
    return tuple(padding)
def to_tuple(v: Union[Tuple[Number, ...], Number, Iterable], n: int):
    """Convert the input to a tuple.

    Args:
        v: a tuple (returned unchanged), a single number (repeated ``n`` times), or
            any other iterable (converted with ``tuple``).
        n: number of repetitions used when ``v`` is a single number.

    Returns:
        The resulting tuple.
    """
    if isinstance(v, tuple):
        return v
    if isinstance(v, Number):
        return (v,) * n
    return tuple(v)