# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pytype: skip-file
"""
Implements the NumPy API, using the primitives in :mod:`jax.lax`.

NumPy operations are implemented in Python in terms of the primitive operations
in :mod:`jax.lax`. Because the NumPy operations are not themselves primitives,
we need not define transformation rules, such as gradient or batching rules,
for them: those transformations follow from the transformation rules of the
underlying :code:`lax` primitives.
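
For example, because :func:`jax.numpy.sin` is implemented below in terms of
:func:`jax.lax.sin`, no dedicated gradient rule is needed for it:
``jax.grad(jnp.sin)(0.0)`` evaluates to ``1.0`` via the JVP rule already
defined for the underlying ``lax`` primitive.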
"""
import abc
import builtins
import collections
from functools import partial
import operator
import types
from typing import Sequence, FrozenSet, Optional, Tuple, Union, Set, Type, Callable
from textwrap import dedent as _dedent
import warnings
import numpy as np
import opt_einsum
import jax
from jax import jit, custom_jvp
from jax._src.numpy.vectorize import vectorize
from jax._src.numpy.util import _wraps
from jax import core
from jax._src import dtypes
from jax._src.api_util import _ensure_index_tuple
from jax import errors
from jax.core import UnshapedArray, ShapedArray, ConcreteArray, canonicalize_shape
from jax.config import config
from jax.interpreters import pxla
from jax import lax
from jax._src import device_array
from jax._src.lax.lax import _array_copy, _sort_lt_comparator, _sort_le_comparator
from jax._src.ops import scatter
from jax._src.util import (unzip2, prod as _prod, subvals, safe_zip, ceil_of_ratio,
canonicalize_axis as _canonicalize_axis, maybe_named_axis)
from jax.tree_util import tree_leaves, tree_flatten, tree_map
newaxis = None
# Common docstring additions:
_PRECISION_DOC = """\
In addition to the original NumPy arguments listed below, also supports
``precision`` for extra control over matrix-multiplication precision
on supported devices. ``precision`` may be set to ``None``, which means
default precision for the backend, a :class:`~jax.lax.Precision` enum value
(``Precision.DEFAULT``, ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple
of two :class:`~jax.lax.Precision` enums indicating separate precision for each argument.
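
For example, both of the following are valid (argument names illustrative)::

    jnp.dot(x, y, precision=lax.Precision.HIGHEST)
    jnp.dot(x, y, precision=(lax.Precision.HIGH, lax.Precision.DEFAULT))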
"""
# We shadow some builtin names to follow NumPy's API, so we capture the
# originals here.
_abs = builtins.abs
_all = builtins.all
_any = builtins.any
_max = builtins.max
_min = builtins.min
_sum = builtins.sum
_divmod = builtins.divmod
# NumPy constants
pi = np.pi
e = np.e
euler_gamma = np.euler_gamma
inf = np.inf
NINF = np.NINF
PZERO = np.PZERO
NZERO = np.NZERO
nan = np.nan
# NumPy utility functions
get_printoptions = np.get_printoptions
printoptions = np.printoptions
set_printoptions = np.set_printoptions
# ndarray is defined as a virtual abstract base class.
class ArrayMeta(abc.ABCMeta):
"""Metaclass for overriding ndarray isinstance checks."""
def __instancecheck__(self, instance):
# Allow tracer instances with avals that are instances of UnshapedArray.
# We could instead just declare Tracer an instance of the ndarray type, but
# there can be traced values that are not arrays. The main downside here is
# that isinstance(x, ndarray) might return true but
# issubclass(type(x), ndarray) might return false for an array tracer.
try:
return (hasattr(instance, "aval") and
isinstance(instance.aval, UnshapedArray))
    except AttributeError:
      return super().__instancecheck__(instance)
class ndarray(metaclass=ArrayMeta):
dtype: np.dtype
ndim: int
shape: Tuple[int, ...]
size: int
  def __init__(self, shape, dtype=None, buffer=None, offset=0, strides=None,
               order=None):
raise TypeError("jax.numpy.ndarray() should not be instantiated explicitly."
" Use jax.numpy.array, or jax.numpy.zeros instead.")
@abc.abstractmethod
def __getitem__(self, key, indices_are_sorted=False,
unique_indices=False): ...
@abc.abstractmethod
def __setitem__(self, key, value): ...
@abc.abstractmethod
def __len__(self): ...
@abc.abstractmethod
def __iter__(self): ...
@abc.abstractmethod
def __reversed__(self): ...
# Comparisons
@abc.abstractmethod
def __lt__(self, other): ...
@abc.abstractmethod
def __le__(self, other): ...
@abc.abstractmethod
def __eq__(self, other): ...
@abc.abstractmethod
def __ne__(self, other): ...
@abc.abstractmethod
def __gt__(self, other): ...
@abc.abstractmethod
def __ge__(self, other): ...
# Unary arithmetic
@abc.abstractmethod
def __neg__(self): ...
@abc.abstractmethod
def __pos__(self): ...
@abc.abstractmethod
def __abs__(self): ...
@abc.abstractmethod
def __invert__(self): ...
# Binary arithmetic
@abc.abstractmethod
def __add__(self, other): ...
@abc.abstractmethod
def __sub__(self, other): ...
@abc.abstractmethod
def __mul__(self, other): ...
@abc.abstractmethod
def __matmul__(self, other): ...
@abc.abstractmethod
def __truediv__(self, other): ...
@abc.abstractmethod
def __floordiv__(self, other): ...
@abc.abstractmethod
def __mod__(self, other): ...
@abc.abstractmethod
def __divmod__(self, other): ...
@abc.abstractmethod
def __pow__(self, other): ...
@abc.abstractmethod
def __lshift__(self, other): ...
@abc.abstractmethod
def __rshift__(self, other): ...
@abc.abstractmethod
def __and__(self, other): ...
@abc.abstractmethod
def __xor__(self, other): ...
@abc.abstractmethod
def __or__(self, other): ...
@abc.abstractmethod
def __radd__(self, other): ...
@abc.abstractmethod
def __rsub__(self, other): ...
@abc.abstractmethod
def __rmul__(self, other): ...
@abc.abstractmethod
def __rmatmul__(self, other): ...
@abc.abstractmethod
def __rtruediv__(self, other): ...
@abc.abstractmethod
def __rfloordiv__(self, other): ...
@abc.abstractmethod
def __rmod__(self, other): ...
@abc.abstractmethod
def __rdivmod__(self, other): ...
@abc.abstractmethod
def __rpow__(self, other): ...
@abc.abstractmethod
def __rlshift__(self, other): ...
@abc.abstractmethod
def __rrshift__(self, other): ...
@abc.abstractmethod
def __rand__(self, other): ...
@abc.abstractmethod
def __rxor__(self, other): ...
@abc.abstractmethod
def __ror__(self, other): ...
@abc.abstractmethod
def __bool__(self): ...
@abc.abstractmethod
def __complex__(self): ...
@abc.abstractmethod
def __int__(self): ...
@abc.abstractmethod
def __float__(self): ...
@abc.abstractmethod
def __round__(self, ndigits=None): ...
@abc.abstractmethod
def __index__(self): ...
# np.ndarray methods:
@abc.abstractmethod
def all(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
keepdims=None): ...
@abc.abstractmethod
def any(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
keepdims=None): ...
@abc.abstractmethod
def argmax(self, axis: Optional[int] = None, out=None): ...
@abc.abstractmethod
def argmin(self, axis: Optional[int] = None, out=None): ...
@abc.abstractmethod
def argpartition(self, kth, axis=-1, kind='introselect', order=None): ...
@abc.abstractmethod
def argsort(self, axis: Optional[int] = -1, kind='quicksort', order=None): ...
@abc.abstractmethod
def astype(self, dtype): ...
@abc.abstractmethod
def choose(self, choices, out=None, mode='raise'): ...
@abc.abstractmethod
def clip(self, a_min=None, a_max=None, out=None): ...
@abc.abstractmethod
def compress(self, condition, axis: Optional[int] = None, out=None): ...
@abc.abstractmethod
def conj(self): ...
@abc.abstractmethod
def conjugate(self): ...
@abc.abstractmethod
def copy(self): ...
@abc.abstractmethod
def cumprod(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
dtype=None, out=None): ...
@abc.abstractmethod
def cumsum(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
dtype=None, out=None): ...
@abc.abstractmethod
def diagonal(self, offset=0, axis1: int = 0, axis2: int = 1): ...
@abc.abstractmethod
def dot(self, b, *, precision=None): ...
@abc.abstractmethod
def flatten(self): ...
@property
@abc.abstractmethod
def imag(self): ...
@abc.abstractmethod
def item(self, *args): ...
@abc.abstractmethod
def max(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
keepdims=None, initial=None, where=None): ...
@abc.abstractmethod
def mean(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
out=None, keepdims=False, *, where=None,): ...
@abc.abstractmethod
def min(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
keepdims=None, initial=None, where=None): ...
@property
@abc.abstractmethod
def nbytes(self): ...
@abc.abstractmethod
def nonzero(self, *, size=None, fill_value=None): ...
@abc.abstractmethod
def prod(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
out=None, keepdims=None, initial=None, where=None): ...
@abc.abstractmethod
def ptp(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
keepdims=False,): ...
@abc.abstractmethod
def ravel(self, order='C'): ...
@property
@abc.abstractmethod
def real(self): ...
@abc.abstractmethod
def repeat(self, repeats, axis: Optional[int] = None, *,
total_repeat_length=None): ...
@abc.abstractmethod
def reshape(self, *args, order='C'): ...
@abc.abstractmethod
def round(self, decimals=0, out=None): ...
@abc.abstractmethod
def searchsorted(self, v, side='left', sorter=None): ...
@abc.abstractmethod
def sort(self, axis: Optional[int] = -1, kind='quicksort', order=None): ...
@abc.abstractmethod
def squeeze(self, axis: Optional[Union[int, Tuple[int, ...]]] = None): ...
@abc.abstractmethod
def std(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
dtype=None, out=None, ddof=0, keepdims=False, *, where=None): ...
@abc.abstractmethod
def sum(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
out=None, keepdims=None, initial=None, where=None): ...
@abc.abstractmethod
def swapaxes(self, axis1: int, axis2: int): ...
@abc.abstractmethod
def take(self, indices, axis: Optional[int] = None, out=None,
mode=None): ...
@abc.abstractmethod
def tobytes(self, order='C'): ...
@abc.abstractmethod
def tolist(self): ...
@abc.abstractmethod
def trace(self, offset=0, axis1: int = 0, axis2: int = 1, dtype=None,
out=None): ...
@abc.abstractmethod
def transpose(self, *args): ...
@abc.abstractmethod
def var(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
dtype=None, out=None, ddof=0, keepdims=False, *, where=None): ...
@abc.abstractmethod
def view(self, dtype=None, type=None): ...
  # Even though we don't always support the NumPy array protocol (e.g., for
  # tracer types), we must declare support for type-checking purposes, so this
  # class implements the NumPy ArrayLike protocol.
def __array__(self): ...
# JAX extensions
@property
@abc.abstractmethod
def at(self): ...
@property
@abc.abstractmethod
def aval(self): ...
@property
@abc.abstractmethod
def weak_type(self) -> bool: ...
ndarray.register(device_array.DeviceArray)
for t in device_array.device_array_types:
ndarray.register(t)
ndarray.register(pxla._SDA_BASE_CLASS)
iscomplexobj = np.iscomplexobj
shape = _shape = np.shape
ndim = _ndim = np.ndim
size = np.size
_dtype = partial(dtypes.dtype, canonicalize=True)
# At present JAX doesn't have a reason to distinguish between scalars and arrays
# in its object system. Further, we want JAX scalars to have the same type
# promotion behaviors as JAX arrays. Rather than introducing a new type of JAX
# scalar object with JAX promotion behaviors, we make the JAX scalar types
# return JAX arrays when instantiated.
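# For example, jnp.float32(1) returns a rank-zero JAX array of dtype float32
# rather than a NumPy scalar; see _ScalarMeta.__call__ below.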
class _ScalarMeta(type):
def __hash__(self):
return hash(self.dtype.type)
def __eq__(self, other):
return id(self) == id(other) or self.dtype.type == other
def __ne__(self, other):
return not (self == other)
def __call__(self, x):
return array(x, dtype=self.dtype)
def __instancecheck__(self, instance):
return isinstance(instance, self.dtype.type)
def _make_scalar_type(np_scalar_type):
return _ScalarMeta(np_scalar_type.__name__, (object,),
{"dtype": np.dtype(np_scalar_type)})
bool_ = _make_scalar_type(np.bool_)
uint8 = _make_scalar_type(np.uint8)
uint16 = _make_scalar_type(np.uint16)
uint32 = _make_scalar_type(np.uint32)
uint64 = _make_scalar_type(np.uint64)
int8 = _make_scalar_type(np.int8)
int16 = _make_scalar_type(np.int16)
int32 = _make_scalar_type(np.int32)
int64 = _make_scalar_type(np.int64)
bfloat16 = _make_scalar_type(dtypes.bfloat16)
float16 = _make_scalar_type(np.float16)
float32 = single = _make_scalar_type(np.float32)
float64 = double = _make_scalar_type(np.float64)
complex64 = csingle = _make_scalar_type(np.complex64)
complex128 = cdouble = _make_scalar_type(np.complex128)
int_ = int32 if dtypes.int_ == np.int32 else int64
uint = uint32 if dtypes.uint == np.uint32 else uint64
float_ = float32 if dtypes.float_ == np.float32 else float64
complex_ = complex64 if dtypes.complex_ == np.complex64 else complex128
number = np.number
inexact = np.inexact
complexfloating = np.complexfloating
floating = np.floating
integer = np.integer
signedinteger = np.signedinteger
unsignedinteger = np.unsignedinteger
flexible = np.flexible
character = np.character
object_ = np.object_
iinfo = dtypes.iinfo
finfo = dtypes.finfo
dtype = np.dtype
can_cast = dtypes.can_cast
issubsctype = dtypes.issubsctype
promote_types = dtypes.promote_types
ComplexWarning = np.ComplexWarning
array_str = np.array_str
array_repr = np.array_repr
save = np.save
savez = np.savez
def _jnp_dtype(obj, align=False, copy=False):
"""Similar to np.dtype, but respects JAX dtype defaults."""
if obj is None:
obj = dtypes.float_
elif isinstance(obj, type) and obj in dtypes.python_scalar_dtypes:
obj = _DEFAULT_TYPEMAP[np.dtype(obj, align=align, copy=copy).type]
return np.dtype(obj, align=align, copy=copy)
### utility functions
_DEFAULT_TYPEMAP = {
np.bool_: bool_,
np.int_: int_,
np.float_: float_,
np.complex_: complex_
}
_INT_DTYPES = {
16: np.int16,
32: np.int32,
64: np.int64,
}
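# _promote_shapes prepends singleton dimensions so that, for example, adding a
# shape (3,) operand to a shape (2, 3) operand first promotes the former to
# shape (1, 3), per NumPy broadcasting rules.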
def _promote_shapes(fun_name, *args):
"""Prepend implicit leading singleton dimensions for Numpy broadcasting."""
if len(args) < 2:
return args
else:
shapes = [shape(arg) for arg in args]
nonscalar_ranks = [len(shp) for shp in shapes if shp]
if not nonscalar_ranks or len(set(nonscalar_ranks)) == 1:
return args
else:
if config.jax_numpy_rank_promotion != "allow":
_rank_promotion_warning_or_error(fun_name, shapes)
result_rank = len(lax.broadcast_shapes(*shapes))
return [broadcast_to(arg, (1,) * (result_rank - len(shp)) + shp)
for arg, shp in zip(args, shapes)]
def _rank_promotion_warning_or_error(fun_name, shapes):
if config.jax_numpy_rank_promotion == "warn":
msg = ("Following NumPy automatic rank promotion for {} on shapes {}. "
"Set the jax_numpy_rank_promotion config option to 'allow' to "
"disable this warning; for more information, see "
"https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.")
warnings.warn(msg.format(fun_name, ' '.join(map(str, shapes))))
elif config.jax_numpy_rank_promotion == "raise":
msg = ("Operands could not be broadcast together for {} on shapes {} "
"and with the config option jax_numpy_rank_promotion='raise'. "
"For more information, see "
"https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.")
raise ValueError(msg.format(fun_name, ' '.join(map(str, shapes))))
def _promote_dtypes(*args):
"""Convenience function to apply Numpy argument dtype promotion."""
# TODO(dougalm,mattjj): This is a performance bottleneck. Consider memoizing.
if len(args) < 2:
return args
else:
to_dtype, weak_type = dtypes._lattice_result_type(*args)
to_dtype = dtypes.canonicalize_dtype(to_dtype)
return [lax._convert_element_type(x, to_dtype, weak_type) for x in args]
def _promote_dtypes_inexact(*args):
"""Convenience function to apply Numpy argument dtype promotion.
Promotes arguments to an inexact type."""
to_dtype, weak_type = dtypes._lattice_result_type(*args)
to_dtype = dtypes.canonicalize_dtype(to_dtype)
to_dtype_inexact = _to_inexact_dtype(to_dtype)
weak_type = (weak_type and to_dtype == to_dtype_inexact)
return [lax._convert_element_type(x, to_dtype_inexact, weak_type) for x in args]
def _to_inexact_dtype(dtype):
"""Promotes a dtype into an inexact dtype, if it is not already one."""
return dtype if issubdtype(dtype, inexact) else promote_types(dtype, float_)
def _complex_elem_type(dtype):
"""Returns the float type of the real/imaginary parts of a complex dtype."""
return np.abs(np.zeros((), dtype)).dtype
def _result_dtype(op, *args):
"""Compute result dtype of applying op to arguments with given dtypes."""
args = [np.ones((0,) * ndim(arg), _dtype(arg)) for arg in args]
return _dtype(op(*args))
def _arraylike(x):
return (isinstance(x, np.ndarray) or isinstance(x, ndarray) or
hasattr(x, '__jax_array__') or isscalar(x))
def _stackable(*args):
return _all(type(arg) in stackables for arg in args)
stackables: Set[Type] = set()
_register_stackable: Callable[[Type], None] = stackables.add
def _check_arraylike(fun_name, *args):
"""Check if all args fit JAX's definition of arraylike."""
assert isinstance(fun_name, str), f"fun_name must be a string. Got {fun_name}"
if _any(not _arraylike(arg) for arg in args):
pos, arg = next((i, arg) for i, arg in enumerate(args)
if not _arraylike(arg))
msg = "{} requires ndarray or scalar arguments, got {} at position {}."
raise TypeError(msg.format(fun_name, type(arg), pos))
def _check_no_float0s(fun_name, *args):
"""Check if none of the args have dtype float0."""
if _any(dtypes.dtype(arg) is dtypes.float0 for arg in args):
    raise TypeError(
        f"Called {fun_name} with a float0 array. "
        "float0s do not support any operations by design because they "
        "are not compatible with non-trivial vector spaces. No implicit dtype "
        "conversion is done. You can use np.zeros_like(arr, dtype=np.float32) "
        "to cast a float0 array to a regular zeros array.\n"
        "If you didn't expect to get a float0 you might have accidentally "
        "taken a gradient with respect to an integer argument.")
def _promote_args(fun_name, *args):
"""Convenience function to apply Numpy argument shape and dtype promotion."""
_check_arraylike(fun_name, *args)
_check_no_float0s(fun_name, *args)
return _promote_shapes(fun_name, *_promote_dtypes(*args))
def _promote_args_inexact(fun_name, *args):
"""Convenience function to apply Numpy argument shape and dtype promotion.
Promotes non-inexact types to an inexact type."""
_check_arraylike(fun_name, *args)
_check_no_float0s(fun_name, *args)
return _promote_shapes(fun_name, *_promote_dtypes_inexact(*args))
def _convert_and_clip_integer(val, dtype):
"""
Convert integer-typed val to specified integer dtype, clipping to dtype
range rather than wrapping.
Args:
val: value to be converted
dtype: dtype of output
Returns:
equivalent of val in new dtype
Examples
--------
Normal integer type conversion will wrap:
>>> val = jnp.uint32(0xFFFFFFFF)
>>> val.astype('int32')
DeviceArray(-1, dtype=int32)
This function clips to the values representable in the new type:
>>> _convert_and_clip_integer(val, 'int32')
DeviceArray(2147483647, dtype=int32)
"""
val = val if isinstance(val, ndarray) else asarray(val)
dtype = dtypes.canonicalize_dtype(dtype)
if not (issubdtype(dtype, integer) and issubdtype(val.dtype, integer)):
raise TypeError("_convert_and_clip_integer only accepts integer dtypes.")
val_dtype = dtypes.canonicalize_dtype(val.dtype)
if val_dtype != val.dtype:
# TODO(jakevdp): this is a weird corner case; need to figure out how to handle it.
# This happens in X32 mode and can either come from a jax value created in another
# context, or a Python integer converted to int64.
pass
min_val = _constant_like(val, _max(iinfo(dtype).min, iinfo(val_dtype).min))
max_val = _constant_like(val, _min(iinfo(dtype).max, iinfo(val_dtype).max))
return clip(val, min_val, max_val).astype(dtype)
def _constant_like(x, const):
return np.array(const, dtype=_dtype(x))
@_wraps(np.load, update_doc=False)
def load(*args, **kwargs):
# The main purpose of this wrapper is to recover bfloat16 data types.
# Note: this will only work for files created via np.save(), not np.savez().
out = np.load(*args, **kwargs)
if isinstance(out, np.ndarray):
# numpy does not recognize bfloat16, so arrays are serialized as void16
if out.dtype == 'V2':
out = out.view(bfloat16)
out = asarray(out)
return out
### implementations of numpy functions in terms of lax
@_wraps(np.fmin)
@jit
def fmin(x1, x2):
return where((x1 < x2) | isnan(x2), x1, x2)
@_wraps(np.fmax)
@jit
def fmax(x1, x2):
return where((x1 > x2) | isnan(x2), x1, x2)
@_wraps(np.issubdtype)
def issubdtype(arg1, arg2):
return dtypes.issubdtype(arg1, arg2)
@_wraps(np.isscalar)
def isscalar(element):
if hasattr(element, '__jax_array__'):
element = element.__jax_array__()
return dtypes.is_python_scalar(element) or np.isscalar(element)
iterable = np.iterable
@_wraps(np.result_type)
def result_type(*args):
return dtypes.result_type(*args)
def _one_to_one_unop(numpy_fn, lax_fn, promote_to_inexact=False, lax_doc=False):
if promote_to_inexact:
fn = lambda x: lax_fn(*_promote_args_inexact(numpy_fn.__name__, x))
else:
fn = lambda x: lax_fn(*_promote_args(numpy_fn.__name__, x))
fn = jit(fn, inline=True)
if lax_doc:
doc = _dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip()
return _wraps(numpy_fn, lax_description=doc)(fn)
else:
return _wraps(numpy_fn)(fn)
def _one_to_one_binop(numpy_fn, lax_fn, promote_to_inexact=False, lax_doc=False):
if promote_to_inexact:
fn = lambda x1, x2: lax_fn(*_promote_args_inexact(numpy_fn.__name__, x1, x2))
else:
fn = lambda x1, x2: lax_fn(*_promote_args(numpy_fn.__name__, x1, x2))
fn = jit(fn, inline=True)
if lax_doc:
doc = _dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip()
return _wraps(numpy_fn, lax_description=doc)(fn)
else:
return _wraps(numpy_fn)(fn)
def _maybe_bool_binop(numpy_fn, lax_fn, bool_lax_fn, lax_doc=False):
def fn(x1, x2):
x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
return lax_fn(x1, x2) if x1.dtype != bool_ else bool_lax_fn(x1, x2)
fn = jit(fn, inline=True)
if lax_doc:
doc = _dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip()
return _wraps(numpy_fn, lax_description=doc)(fn)
else:
return _wraps(numpy_fn)(fn)
fabs = _one_to_one_unop(np.fabs, lax.abs, True)
bitwise_not = _one_to_one_unop(np.bitwise_not, lax.bitwise_not)
invert = _one_to_one_unop(np.invert, lax.bitwise_not)
negative = _one_to_one_unop(np.negative, lax.neg)
positive = _one_to_one_unop(np.positive, lambda x: x)
floor = _one_to_one_unop(np.floor, lax.floor, True)
ceil = _one_to_one_unop(np.ceil, lax.ceil, True)
exp = _one_to_one_unop(np.exp, lax.exp, True)
log = _one_to_one_unop(np.log, lax.log, True)
expm1 = _one_to_one_unop(np.expm1, lax.expm1, True)
log1p = _one_to_one_unop(np.log1p, lax.log1p, True)
sin = _one_to_one_unop(np.sin, lax.sin, True)
cos = _one_to_one_unop(np.cos, lax.cos, True)
tan = _one_to_one_unop(np.tan, lax.tan, True)
arcsin = _one_to_one_unop(np.arcsin, lax.asin, True)
arccos = _one_to_one_unop(np.arccos, lax.acos, True)
arctan = _one_to_one_unop(np.arctan, lax.atan, True)
sinh = _one_to_one_unop(np.sinh, lax.sinh, True)
cosh = _one_to_one_unop(np.cosh, lax.cosh, True)
arcsinh = _one_to_one_unop(np.arcsinh, lax.asinh, True)
tanh = _one_to_one_unop(np.tanh, lax.tanh, True)
arctanh = _one_to_one_unop(np.arctanh, lax.atanh, True)
sqrt = _one_to_one_unop(np.sqrt, lax.sqrt, True)
cbrt = _one_to_one_unop(np.cbrt, lax.cbrt, True)
add = _maybe_bool_binop(np.add, lax.add, lax.bitwise_or)
bitwise_and = _one_to_one_binop(np.bitwise_and, lax.bitwise_and)
bitwise_or = _one_to_one_binop(np.bitwise_or, lax.bitwise_or)
bitwise_xor = _one_to_one_binop(np.bitwise_xor, lax.bitwise_xor)
left_shift = _one_to_one_binop(np.left_shift, lax.shift_left)
equal = _one_to_one_binop(np.equal, lax.eq)
multiply = _maybe_bool_binop(np.multiply, lax.mul, lax.bitwise_and)
not_equal = _one_to_one_binop(np.not_equal, lax.ne)
subtract = _one_to_one_binop(np.subtract, lax.sub)
arctan2 = _one_to_one_binop(np.arctan2, lax.atan2, True)
minimum = _one_to_one_binop(np.minimum, lax.min)
maximum = _one_to_one_binop(np.maximum, lax.max)
float_power = _one_to_one_binop(np.float_power, lax.pow, True)
nextafter = _one_to_one_binop(np.nextafter, lax.nextafter, True, True)
@_wraps(np.arccosh)
@jit
def arccosh(x):
# Note: arccosh is multi-valued for complex input, and lax.acosh uses a different
# convention than np.arccosh.
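  # We map results onto the branch with nonnegative real part: when
  # real(acosh(x)) < 0, the negation -acosh(x) is the equally valid value
  # that NumPy returns.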
out = lax.acosh(*_promote_args_inexact("arccosh", x))
if issubdtype(out.dtype, np.complexfloating):
out = where(real(out) < 0, lax.neg(out), out)
return out
def _comparison_op(numpy_fn, lax_fn):
# TODO(https://github.com/google/jax/issues/6713): decorate this function with
# jit, after fixing a surprising interaction with remat(..., concrete=True).
def fn(x1, x2):
x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
# Comparison on complex types are defined as a lexicographic ordering on
# the (real, imag) pair.
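    # For example, 1+2j < 1+3j: the real parts tie, so the comparison falls
    # through to the imaginary parts.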
if issubdtype(_dtype(x1), complexfloating):
rx = lax.real(x1)
ry = lax.real(x2)
return lax.select(lax.eq(rx, ry), lax_fn(lax.imag(x1), lax.imag(x2)),
lax_fn(rx, ry))
return lax_fn(x1, x2)
return _wraps(numpy_fn)(fn)
greater_equal = _comparison_op(np.greater_equal, lax.ge)
greater = _comparison_op(np.greater, lax.gt)
less_equal = _comparison_op(np.less_equal, lax.le)
less = _comparison_op(np.less, lax.lt)
def _logical_op(np_op, bitwise_op):
@_wraps(np_op, update_doc=False)
@partial(jit, inline=True)
def op(*args):
zero = lambda x: lax.full_like(x, shape=(), fill_value=0)
args = (x if issubdtype(_dtype(x), bool_) else lax.ne(x, zero(x))
for x in args)
return bitwise_op(*_promote_args(np_op.__name__, *args))
return op
logical_and = _logical_op(np.logical_and, lax.bitwise_and)
logical_not = _logical_op(np.logical_not, lax.bitwise_not)
logical_or = _logical_op(np.logical_or, lax.bitwise_or)
logical_xor = _logical_op(np.logical_xor, lax.bitwise_xor)
@_wraps(np.right_shift)
@partial(jit, inline=True)
def right_shift(x1, x2):
x1, x2 = _promote_args(np.right_shift.__name__, x1, x2)
  lax_fn = (lax.shift_right_logical
            if np.issubdtype(x1.dtype, np.unsignedinteger)
            else lax.shift_right_arithmetic)
return lax_fn(x1, x2)
@_wraps(np.absolute)
@partial(jit, inline=True)
def absolute(x):
_check_arraylike('absolute', x)
dt = _dtype(x)
return x if dt == bool_ or issubdtype(dt, unsignedinteger) else lax.abs(x)
abs = _wraps(np.abs)(absolute)
@_wraps(np.rint)
@jit
def rint(x):
_check_arraylike('rint', x)
dtype = _dtype(x)
if issubdtype(dtype, integer):
return lax.convert_element_type(x, float_)
if issubdtype(dtype, complexfloating):
return lax.complex(rint(lax.real(x)), rint(lax.imag(x)))
return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN)
@_wraps(np.sign)
@jit
def sign(x):
_check_arraylike('sign', x)
dtype = _dtype(x)
if issubdtype(dtype, complexfloating):
re = lax.real(x)
return lax.complex(
lax.sign(where(re != 0, re, lax.imag(x))), _constant_like(re, 0))
return lax.sign(x)
@_wraps(np.copysign)
@jit
def copysign(x1, x2):
x1, x2 = _promote_args_inexact("copysign", x1, x2)
if issubdtype(_dtype(x1), complexfloating):
raise TypeError("copysign does not support complex-valued inputs")
return where(signbit(x2), -lax.abs(x1), lax.abs(x1))
@_wraps(np.true_divide)
@partial(jit, inline=True)
def true_divide(x1, x2):
x1, x2 = _promote_args_inexact("true_divide", x1, x2)
return lax.div(x1, x2)
divide = true_divide
@_wraps(np.floor_divide)
@jit
def floor_divide(x1, x2):
x1, x2 = _promote_args("floor_divide", x1, x2)
dtype = _dtype(x1)
if issubdtype(dtype, integer):
quotient = lax.div(x1, x2)
select = logical_and(lax.sign(x1) != lax.sign(x2), lax.rem(x1, x2) != 0)
# TODO(mattjj): investigate why subtracting a scalar was causing promotion
return where(select, quotient - np.array(1, _dtype(quotient)), quotient)
elif issubdtype(dtype, complexfloating):
x1r = lax.real(x1)
x1i = lax.imag(x1)
x2r = lax.real(x2)
x2i = lax.imag(x2)
which = lax.ge(lax.abs(x2r), lax.abs(x2i))
rat1 = where(which, lax._const(x2i, 1), lax.div(x2r, x2i))
rat2 = where(which, lax.div(x2i, x2r), lax._const(x2i, 1))
out = lax.floor(lax.div(lax.add(lax.mul(x1r, rat1), lax.mul(x1i, rat2)),
lax.add(lax.mul(x2r, rat1), lax.mul(x2i, rat2))))
return lax.convert_element_type(out, dtype)
else:
return _float_divmod(x1, x2)[0]
@_wraps(np.divmod)
@jit
def divmod(x1, x2):
x1, x2 = _promote_args("divmod", x1, x2)
if issubdtype(_dtype(x1), integer):
return floor_divide(x1, x2), remainder(x1, x2)
else:
return _float_divmod(x1, x2)
def _float_divmod(x1, x2):
# see float_divmod in floatobject.c of CPython
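  # For example, divmod(-5., 3.): lax.rem gives mod = -2, whose sign differs
  # from x2, so we adjust to mod = 1 and div = -2, matching Python's divmod.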
mod = lax.rem(x1, x2)
div = lax.div(lax.sub(x1, mod), x2)
ind = lax.bitwise_and(mod != 0, lax.sign(x2) != lax.sign(mod))
mod = lax.select(ind, mod + x2, mod)
div = lax.select(ind, div - _constant_like(div, 1), div)
return lax.round(div), mod
@partial(jit, inline=True)
def _power(x1, x2):
x1, x2 = _promote_args("power", x1, x2)
dtype = _dtype(x1)
if not issubdtype(dtype, integer):
return lax.pow(x1, x2)
# Integer power => use binary exponentiation.
# TODO(phawkins): add integer pow support to XLA.
bits = 6 # Anything more would overflow for any x1 > 1
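  # For example, with x2 = 11 (binary 1011), the loop multiplies acc by x1,
  # x1**2, and x1**8 on the iterations whose bit is set, giving x1**11 by
  # repeated squaring.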
zero = _constant_like(x2, 0)
one = _constant_like(x2, 1)
# Initialize acc carefully such that pow(0, x2) is zero for x2 != 0
acc = where(lax.bitwise_and(lax.eq(x1, zero), lax.ne(x2, zero)), zero, one)
for _ in range(bits):
acc = where(lax.bitwise_and(x2, one), lax.mul(acc, x1), acc)
x1 = lax.mul(x1, x1)
x2 = lax.shift_right_logical(x2, one)
return acc
@_wraps(np.power)
def power(x1, x2):
# Special case for concrete integer scalars: use binary exponentiation.
# Using lax.pow may be imprecise for floating-point values; the goal of this
# code path is to make sure we end up with a precise output for the common
# pattern ``x ** 2`` or similar.
if isinstance(core.get_aval(x2), ConcreteArray):
try:
x2 = operator.index(x2)
except TypeError:
pass
else:
return lax.integer_pow(x1, x2)
return _power(x1, x2)
@custom_jvp
@_wraps(np.logaddexp)
@jit
def logaddexp(x1, x2):
x1, x2 = _promote_args_inexact("logaddexp", x1, x2)
amax = lax.max(x1, x2)
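  # For real inputs, compute log(exp(x1) + exp(x2)) stably as
  # amax + log1p(exp(-|x1 - x2|)), which avoids overflow for large arguments.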
if issubdtype(x1.dtype, np.floating):
delta = lax.sub(x1, x2)
return lax.select(isnan(delta),
lax.add(x1, x2), # NaNs or infinities of the same sign.
lax.add(amax, lax.log1p(lax.exp(lax.neg(lax.abs(delta))))))
else:
delta = lax.sub(lax.add(x1, x2), lax.mul(amax, _constant_like(amax, 2)))
out = lax.add(amax, lax.log1p(lax.exp(delta)))
return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi))
def _wrap_between(x, _a):
"""Wraps `x` between `[-a, a]`."""
a = _constant_like(x, _a)
two_a = _constant_like(x, 2 * _a)
zero = _constant_like(x, 0)
rem = lax.rem(lax.add(x, a), two_a)
rem = lax.select(lax.lt(rem, zero), lax.add(rem, two_a), rem)
return lax.sub(rem, a)
@logaddexp.defjvp
def _logaddexp_jvp(primals, tangents):
x1, x2 = primals
t1, t2 = tangents
x1, x2, t1, t2 = _promote_args_inexact("logaddexp_jvp", x1, x2, t1, t2)
primal_out = logaddexp(x1, x2)
tangent_out = lax.add(lax.mul(t1, exp(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))),
lax.mul(t2, exp(lax.sub(_replace_inf(x2), _replace_inf(primal_out)))))
return primal_out, tangent_out
def _replace_inf(x):
return lax.select(isposinf(real(x)), zeros_like(x), x)
@custom_jvp
@_wraps(np.logaddexp2)
@jit
def logaddexp2(x1, x2):
x1, x2 = _promote_args_inexact("logaddexp2", x1, x2)
amax = lax.max(x1, x2)
if issubdtype(x1.dtype, np.floating):
delta = lax.sub(x1, x2)
return lax.select(isnan(delta),
lax.add(x1, x2), # NaNs or infinities of the same sign.
lax.add(amax, lax.div(lax.log1p(exp2(lax.neg(lax.abs(delta)))),
_constant_like(x1, np.log(2)))))
else:
delta = lax.sub(lax.add(x1, x2), lax.mul(amax, _constant_like(amax, 2)))
out = lax.add(amax, lax.div(lax.log1p(exp2(delta)), _constant_like(x1, np.log(2))))
return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi / np.log(2)))
@logaddexp2.defjvp
def _logaddexp2_jvp(primals, tangents):
x1, x2 = primals
t1, t2 = tangents
x1, x2, t1, t2 = _promote_args_inexact("logaddexp2_jvp", x1, x2, t1, t2)
primal_out = logaddexp2(x1, x2)
tangent_out = lax.add(lax.mul(t1, exp2(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))),
lax.mul(t2, exp2(lax.sub(_replace_inf(x2), _replace_inf(primal_out)))))
return primal_out, tangent_out
@_wraps(np.log2)
@partial(jit, inline=True)
def log2(x):
x, = _promote_args_inexact("log2", x)
return lax.div(lax.log(x), lax.log(_constant_like(x, 2)))
@_wraps(np.log10)
@partial(jit, inline=True)
def log10(x):
x, = _promote_args_inexact("log10", x)
return lax.div(lax.log(x), lax.log(_constant_like(x, 10)))
@_wraps(np.exp2)
@partial(jit, inline=True)
def exp2(x):
x, = _promote_args_inexact("exp2", x)
return lax.exp(lax.mul(lax.log(_constant_like(x, 2)), x))
@_wraps(np.signbit)
@jit
def signbit(x):
x, = _promote_args("signbit", x)
dtype = _dtype(x)
if issubdtype(dtype, integer):
return lax.lt(x, _constant_like(x, 0))
elif issubdtype(dtype, bool_):
return full_like(x, False, dtype=bool_)
elif not issubdtype(dtype, floating):
raise ValueError(
"jax.numpy.signbit is not well defined for %s" % dtype)
# TPU supports BF16 but not S16 types, so as a workaround, convert BF16 to
# F32.
if dtype == bfloat16:
dtype = float32
x = lax.convert_element_type(x, float32)
info = finfo(dtype)
if info.bits not in _INT_DTYPES:
raise NotImplementedError(
"jax.numpy.signbit only supports 16, 32, and 64-bit types.")
int_type = _INT_DTYPES[info.bits]
x = lax.bitcast_convert_type(x, int_type)
return lax.convert_element_type(x >> (info.nexp + info.nmant), np.bool_)
@_wraps(np.trapz)
@partial(jit, static_argnames=('axis',))
def trapz(y, x=None, dx=1.0, axis: int = -1):
_check_arraylike('trapz', y)
y = moveaxis(y, axis, -1)
if x is not None:
if ndim(x) == 1:
dx = diff(x)
else:
dx = moveaxis(diff(x, axis=axis), axis, -1)
return 0.5 * (dx * (y[..., 1:] + y[..., :-1])).sum(-1)
@_wraps(np.trunc)
@jit
def trunc(x):
_check_arraylike('trunc', x)
return where(lax.lt(x, lax._const(x, 0)), ceil(x), floor(x))
@partial(jit, static_argnums=(2, 3, 4))
def _conv(x, y, mode, op, precision):
if ndim(x) != 1 or ndim(y) != 1:
    raise ValueError(f"{op}() only supports 1-dimensional inputs.")
x, y = _promote_dtypes_inexact(x, y)
if len(x) == 0 or len(y) == 0:
raise ValueError(f"{op}: inputs cannot be empty, got shapes {x.shape} and {y.shape}.")
out_order = slice(None)
if op == 'correlate':
y = conj(y)
if len(x) < len(y):
x, y = y, x
out_order = slice(None, None, -1)
elif op == 'convolve':
if len(x) < len(y):
x, y = y, x
y = flip(y)
if mode == 'valid':
padding = [(0, 0)]
elif mode == 'same':
padding = [(y.shape[0] // 2, y.shape[0] - y.shape[0] // 2 - 1)]
elif mode == 'full':
padding = [(y.shape[0] - 1, y.shape[0] - 1)]
else:
raise ValueError("mode must be one of ['full', 'same', 'valid']")
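  # For example, in 'full' mode with len(y) == m, padding by (m - 1, m - 1)
  # yields an output of length len(x) + m - 1, as in np.convolve.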
result = lax.conv_general_dilated(x[None, None, :], y[None, None, :], (1,),
padding, precision=precision)
return result[0, 0, out_order]
@_wraps(np.convolve, lax_description=_PRECISION_DOC)
@partial(jit, static_argnames=('mode', 'precision'))
def convolve(a, v, mode='full', *, precision=None):
_check_arraylike("convolve", a, v)
return _conv(a, v, mode, 'convolve', precision)
@_wraps(np.correlate, lax_description=_PRECISION_DOC)
@partial(jit, static_argnames=('mode', 'precision'))
def correlate(a, v, mode='valid', *, precision=None):
_check_arraylike("correlate", a, v)
return _conv(a, v, mode, 'correlate', precision)
def _normalize_float(x):
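  # Scale subnormal inputs up by 2**nmant so they have a normalized
  # representation, returning the resulting bit pattern together with the
  # exponent correction (-nmant for subnormals, 0 otherwise).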
info = finfo(_dtype(x))
cond = lax.abs(x) < info.tiny
x1 = where(cond, x * lax._const(x, 1 << info.nmant), x)
x2 = where(cond, lax._const(np.int32, -info.nmant), lax._const(np.int32, 0))
int_type = _INT_DTYPES[info.bits]
return lax.bitcast_convert_type(x1, int_type), x2
@_wraps(np.ldexp)
@jit
def ldexp(x1, x2):
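  # Computes x1 * 2**x2 by adjusting the exponent bits of x1 directly, with
  # explicit handling of subnormals, underflow, and overflow below.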
_check_arraylike("ldexp", x1, x2)
dtype = dtypes.canonicalize_dtype(_result_dtype(np.ldexp, x1, x2))
x1, x2 = _promote_shapes("ldexp", x1, x2)
x1 = lax.convert_element_type(x1, dtype)
info = finfo(dtype)
mask = (1 << info.nexp) - 1
bias = ((1 << info.nexp) - 1) >> 1
int_type = _INT_DTYPES[info.bits]
x, e = _normalize_float(x1)
x2 += e + ((x >> info.nmant) & mask) - bias
# find underflow/overflow before denormalization
underflow_cond = x2 < -(bias + info.nmant)
overflow_cond = x2 > bias
m = ones_like(x, dtype=dtype)
# denormals
cond = x2 < -bias + 1
x2 = where(cond, x2 + info.nmant, x2)
m = where(cond, m / (1 << info.nmant), m)
x2 = lax.convert_element_type(x2, np.int32)
x &= ~(mask << info.nmant)
x |= ((lax.convert_element_type(x2, int_type) + bias) << info.nmant)
x = lax.convert_element_type(m, dtype) * lax.bitcast_convert_type(x, dtype)
# underflow
x = where(underflow_cond, zeros_like(x, dtype=dtype), x)
# overflow
x = where(overflow_cond, lax.sign(x1) * full_like(x, np.inf), x)
# ldexp(x1, x2) = x1 for x1 = inf, -inf, nan, 0
return where(isinf(x1) | isnan(x1) | (x1 == 0), x1, x)
@_wraps(np.frexp)
@jit
def frexp(x):
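  # Decomposes x into (m, e) with x == m * 2**e and 0.5 <= abs(m) < 1 by
  # operating directly on the exponent bits.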
_check_arraylike("frexp", x)
x = asarray(x)
if issubdtype(x.dtype, complexfloating):
raise TypeError("frexp does not support complex-valued inputs")
elif not issubdtype(x.dtype, floating):
x = lax.convert_element_type(x, float_)
dtype = _dtype(x)
info = finfo(dtype)
mask = (1 << info.nexp) - 1
bias = ((1 << info.nexp) - 1) >> 1
x1, x2 = _normalize_float(x)
x2 += ((x1 >> info.nmant) & mask) - bias + 1
x1 &= ~(mask << info.nmant)
x1 |= (bias - 1) << info.nmant
x1 = lax.bitcast_convert_type(x1, dtype)
cond = isinf(x) | isnan(x) | (x == 0)
x2 = where(cond, zeros_like(x2), x2)
return where(cond, x, x1), lax.convert_element_type(x2, int32)
@_wraps(np.remainder)
@jit
def remainder(x1, x2):
x1, x2 = _promote_args("remainder", x1, x2)
zero = _constant_like(x1, 0)
trunc_mod = lax.rem(x1, x2)
trunc_mod_not_zero = lax.ne(trunc_mod, zero)
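  # Shift the truncated remainder into the sign of x2 when the signs disagree:
  # for example, remainder(-5, 3) starts from rem = -2 and is adjusted to
  # -2 + 3 = 1, matching NumPy's sign-of-divisor convention.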
do_plus = lax.bitwise_and(
lax.ne(lax.lt(trunc_mod, zero), lax.lt(x2, zero)), trunc_mod_not_zero)
return lax.select(do_plus, lax.add(trunc_mod, x2), trunc_mod)
mod = _wraps(np.mod)(remainder)
@_wraps(np.fmod)
@jit
def fmod(x1, x2):
_check_arraylike("fmod", x1, x2)
if issubdtype(result_type(x1, x2), integer):
x2 = where(x2 == 0, 1, x2)
return lax.rem(*_promote_args("fmod", x1, x2))
@_wraps(np.square)
@partial(jit, inline=True)
def square(x):
_check_arraylike("square", x)
return lax.integer_pow(x, 2)
@_wraps(np.deg2rad)
@partial(jit, inline=True)
def deg2rad(x):
x, = _promote_args_inexact("deg2rad", x)
return lax.mul(x, lax._const(x, pi / 180))
@_wraps(np.rad2deg)
@partial(jit, inline=True)
def rad2deg(x):
x, = _promote_args_inexact("rad2deg", x)
return lax.mul(x, lax._const(x, 180 / pi))
degrees = rad2deg
radians = deg2rad
@_wraps(np.histogram_bin_edges)
def histogram_bin_edges(a, bins=10, range=None, weights=None):
if isinstance(bins, str):
raise NotImplementedError("string values for `bins` not implemented.")
_check_arraylike("histogram_bin_edges", a, bins)
a = ravel(a)
b = asarray(bins)
if b.ndim == 1:
return b
if range is None:
range = [a.min(), a.max()]
assert len(range) == 2
range = asarray(range)
range = (where(ptp(range) == 0, range[0] - 0.5, range[0]),
where(ptp(range) == 0, range[1] + 0.5, range[1]))
dtype = _dtype(a)
if issubdtype(dtype, integer):
dtype = promote_types(dtype, float32)
return linspace(range[0], range[1], bins + 1, dtype=dtype)
@_wraps(np.histogram)
def histogram(a, bins=10, range=None, weights=None, density=None):
_check_arraylike("histogram", a, bins)
if weights is not None and a.shape != weights.shape:
raise ValueError("weights should have the same shape as a.")
a = ravel(a)
if weights is not None:
weights = ravel(weights)
else:
weights = ones_like(a)
bin_edges = histogram_bin_edges(a, bins, range, weights)
bin_idx = searchsorted(bin_edges, a, side='right')
bin_idx = where(a == bin_edges[-1], len(bin_edges) - 1, bin_idx)
counts = bincount(bin_idx, weights, length=len(bin_edges))[1:]
if density:
bin_widths = diff(bin_edges)
counts = counts / bin_widths / counts.sum()
return counts, bin_edges
@_wraps(np.histogram2d)
def histogram2d(x, y, bins=10, range=None, weights=None, density=None):
_check_arraylike("histogram2d", x, y)
try:
N = len(bins)
except TypeError:
N = 1
if N != 1 and N != 2:
x_edges = y_edges = asarray(bins)
bins = [x_edges, y_edges]
sample = transpose(asarray([x, y]))
hist, edges = histogramdd(sample, bins, range, weights, density)
return hist, edges[0], edges[1]
@_wraps(np.histogramdd)
def histogramdd(sample, bins=10, range=None, weights=None, density=None):
_check_arraylike("histogramdd", sample)
N, D = shape(sample)
if weights is not None and weights.shape != (N,):
raise ValueError("should have one weight for each sample.")
if range is not None and (
len(range) != D or _any(r is not None and len(r) != 2 for r in range)):
raise ValueError(f"For sample.shape={(N, D)}, range must be a sequence "
f"of {D} pairs or Nones; got range={range}")
try:
num_bins = len(bins)
if num_bins != D:
raise ValueError("should be a bin for each dimension.")
except TypeError:
# when bin_size is integer, the same bin is used for each dimension
bins = D * [bins]
bin_idx_by_dim = D*[None]
nbins = np.empty(D, int)
bin_edges_by_dim = D*[None]
dedges = D*[None]
for i in builtins.range(D):
range_i = None if range is None else range[i]
bin_edges = histogram_bin_edges(sample[:, i], bins[i], range_i, weights)
bin_idx = searchsorted(bin_edges, sample[:, i], side='right')
bin_idx = where(sample[:, i] == bin_edges[-1], bin_idx - 1, bin_idx)
bin_idx_by_dim[i] = bin_idx
nbins[i] = len(bin_edges) + 1
bin_edges_by_dim[i] = bin_edges
dedges[i] = diff(bin_edges_by_dim[i])
xy = ravel_multi_index(bin_idx_by_dim, nbins, mode='clip')
hist = bincount(xy, weights, length=nbins.prod())
hist = reshape(hist, nbins)
core = D*(slice(1, -1),)
hist = hist[core]
if density:
hist /= hist.sum()
for norm in ix_(*dedges):
hist /= norm
return hist, bin_edges_by_dim
@_wraps(np.heaviside)
@jit
def heaviside(x1, x2):
_check_arraylike("heaviside", x1, x2)
x1, x2 = _promote_dtypes_inexact(x1, x2)
zero = lax._const(x1, 0)
return where(lax.lt(x1, zero), zero,
where(lax.gt(x1, zero), lax._const(x1, 1), x2))
@_wraps(np.hypot)
@jit
def hypot(x1, x2):
_check_arraylike("hypot", x1, x2)
x1, x2 = _promote_dtypes_inexact(x1, x2)
x1 = lax.abs(x1)
x2 = lax.abs(x2)
x1, x2 = maximum(x1, x2), minimum(x1, x2)
  return lax.select(x1 == 0, x1,
                    x1 * lax.sqrt(1 + lax.square(lax.div(
                        x2, lax.select(x1 == 0, ones_like(x1), x1)))))
@_wraps(np.reciprocal)
@partial(jit, inline=True)
def reciprocal(x):
_check_arraylike("reciprocal", x)
x, = _promote_dtypes_inexact(x)
return lax.integer_pow(x, -1)
@_wraps(np.sinc, update_doc=False)
@jit
def sinc(x):
_check_arraylike("sinc", x)
x, = _promote_dtypes_inexact(x)
eq_zero = lax.eq(x, lax._const(x, 0))
pi_x = lax.mul(lax._const(x, pi), x)
safe_pi_x = where(eq_zero, lax._const(x, 1), pi_x)
return where(eq_zero, _sinc_maclaurin(0, pi_x),
lax.div(lax.sin(safe_pi_x), safe_pi_x))
@partial(custom_jvp, nondiff_argnums=(0,))
def _sinc_maclaurin(k, x):
# compute the kth derivative of x -> sin(x)/x evaluated at zero (since we
# compute the monomial term in the jvp rule)
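  # From the Maclaurin series sin(x)/x = sum_n (-1)**n x**(2n) / (2n+1)!, the
  # kth derivative at zero is 0 for odd k and (-1)**(k//2) / (k+1) for even k
  # (the coefficient of x**k multiplied by k!).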
if k % 2:
return lax.full_like(x, 0)
else:
return lax.full_like(x, (-1) ** (k // 2) / (k + 1))
@_sinc_maclaurin.defjvp
def _sinc_maclaurin_jvp(k, primals, tangents):
(x,), (t,) = primals, tangents
return _sinc_maclaurin(k, x), _sinc_maclaurin(k + 1, x) * t
_ARRAY_VIEW_DOC = """
The JAX version of this function may in some cases return a copy rather than a
view of the input.
"""
@_wraps(np.transpose, lax_description=_ARRAY_VIEW_DOC)
def transpose(a, axes=None):
_check_arraylike("transpose", a)
axes = np.arange(ndim(a))[::-1] if axes is None else axes
return lax.transpose(a, axes)
@_wraps(np.rot90, lax_description=_ARRAY_VIEW_DOC)
@partial(jit, static_argnames=('k', 'axes'))
def rot90(m, k=1, axes=(0, 1)):
_check_arraylike("rot90", m)
ax1, ax2 = axes
ax1 = _canonicalize_axis(ax1, ndim(m))
ax2 = _canonicalize_axis(ax2, ndim(m))
if ax1 == ax2:
raise ValueError("Axes must be different") # same as numpy error
k = k % 4
if k == 0:
return m
elif k == 2:
return flip(flip(m, ax1), ax2)
else:
perm = list(range(m.ndim))
perm[ax1], perm[ax2] = perm[ax2], perm[ax1]
if k == 1:
return transpose(flip(m, ax2), perm)
else:
return flip(transpose(m, perm), ax2)
@_wraps(np.flip, lax_description=_ARRAY_VIEW_DOC)
def flip(m, axis: Optional[Union[int, Tuple[int, ...]]] = None):
return _flip(m, _ensure_optional_axes(axis))
@partial(jit, static_argnames=('axis',))
def _flip(m, axis: Optional[Union[int, Tuple[int, ...]]] = None):
_check_arraylike("flip", m)
if axis is None:
return lax.rev(m, list(range(len(shape(m)))))
axis = _ensure_index_tuple(axis)
return lax.rev(m, [_canonicalize_axis(ax, ndim(m)) for ax in axis])
@_wraps(np.fliplr, lax_description=_ARRAY_VIEW_DOC)
def fliplr(m):
return _flip(m, 1)
@_wraps(np.flipud, lax_description=_ARRAY_VIEW_DOC)
def flipud(m):
return _flip(m, 0)
@_wraps(np.conjugate)
@partial(jit, inline=True)
def conjugate(x):
_check_arraylike("conjugate", x)
return lax.conj(x) if iscomplexobj(x) else x
conj = conjugate
@_wraps(np.imag)
@partial(jit, inline=True)
def imag(val):
_check_arraylike("imag", val)
return lax.imag(val) if iscomplexobj(val) else zeros_like(val)
@_wraps(np.real)
@partial(jit, inline=True)
def real(val):
_check_arraylike("real", val)
return lax.real(val) if iscomplexobj(val) else val
@_wraps(np.iscomplex)
@jit
def iscomplex(x):
i = imag(x)
return lax.ne(i, lax._const(i, 0))
@_wraps(np.isreal)
@jit
def isreal(x):
i = imag(x)
return lax.eq(i, lax._const(i, 0))
@_wraps(np.angle)
@jit
def angle(z):
re = real(z)
im = imag(z)
dtype = _dtype(re)
if not issubdtype(dtype, inexact) or (
issubdtype(_dtype(z), floating) and ndim(z) == 0):
dtype = dtypes.canonicalize_dtype(float_)
re = lax.convert_element_type(re, dtype)
im = lax.convert_element_type(im, dtype)
return lax.atan2(im, re)
@_wraps(np.diff)
@partial(jit, static_argnames=('n', 'axis'))
def diff(a, n=1, axis: int = -1, prepend=None, append=None):
_check_arraylike("diff", a)
n = core.concrete_or_error(operator.index, n, "'n' argument of jnp.diff")
axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.diff")
if n == 0:
return a
if n < 0:
raise ValueError(f"order must be non-negative but got {n}")
if ndim(a) == 0:
raise ValueError(f"diff requires input that is at least one dimensional; got {a}")
nd = a.ndim
axis = _canonicalize_axis(axis, nd)
combined = []
if prepend is not None:
_check_arraylike("diff", prepend)
if isscalar(prepend):
shape = list(a.shape)
shape[axis] = 1
prepend = broadcast_to(prepend, tuple(shape))
combined.append(prepend)
combined.append(a)
if append is not None:
_check_arraylike("diff", append)
if isscalar(append):
shape = list(a.shape)
shape[axis] = 1
append = broadcast_to(append, tuple(shape))
combined.append(append)
if len(combined) > 1:
a = concatenate(combined, axis)
slice1 = [slice(None)] * nd
slice2 = [slice(None)] * nd
slice1[axis] = slice(1, None)
slice2[axis] = slice(None, -1)
slice1_tuple = tuple(slice1)
slice2_tuple = tuple(slice2)
op = not_equal if a.dtype == np.bool_ else subtract
for _ in range(n):
a = op(a[slice1_tuple], a[slice2_tuple])
return a
_EDIFF1D_DOC = """\
Unlike NumPy's implementation of ediff1d, :py:func:`jax.numpy.ediff1d` will not
issue an error if casting ``to_end`` or ``to_begin`` to the type of ``ary``
loses precision.
"""
@_wraps(np.ediff1d, lax_description=_EDIFF1D_DOC)
@jit
def ediff1d(ary, to_end=None, to_begin=None):
_check_arraylike("ediff1d", ary)
ary = ravel(ary)
result = lax.sub(ary[1:], ary[:-1])
if to_begin is not None:
_check_arraylike("ediff1d", to_begin)
result = concatenate((ravel(asarray(to_begin, dtype=ary.dtype)), result))
if to_end is not None:
_check_arraylike("ediff1d", to_end)
result = concatenate((result, ravel(asarray(to_end, dtype=ary.dtype))))
return result
@_wraps(np.gradient, skip_params=['edge_order'])
@partial(jit, static_argnames=('axis', 'edge_order'))
def gradient(f, *varargs, axis: Optional[Union[int, Tuple[int, ...]]] = None,
edge_order=None):
if edge_order is not None:
raise NotImplementedError("The 'edge_order' argument to jnp.gradient is not supported.")
def gradient_along_axis(a, h, axis):
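    # Second-order central differences in the interior; first-order one-sided
    # differences at the two boundary points.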
sliced = partial(lax.slice_in_dim, a, axis=axis)
a_grad = concatenate((
(sliced(1, 2) - sliced(0, 1)), # upper edge
(sliced(2, None) - sliced(None, -2)) * 0.5, # inner
(sliced(-1, None) - sliced(-2, -1)), # lower edge
), axis)
return a_grad / h
a = f
axis_tuple: Tuple[int, ...]
if axis is None:
axis_tuple = tuple(range(a.ndim))
else:
if isinstance(axis, int):
axis = (axis,)
elif not isinstance(axis, tuple) and not isinstance(axis, list):
raise ValueError("Give `axis` either as int or iterable")
elif len(axis) == 0:
return []
axis_tuple = tuple(_canonicalize_axis(i, a.ndim) for i in axis)
if _min([s for i, s in enumerate(a.shape) if i in axis_tuple]) < 2:
raise ValueError("Shape of array too small to calculate "
"a numerical gradient, "
"at least 2 elements are required.")
len_axes = len(axis_tuple)
n = len(varargs)
if n == 0 or varargs is None:
# no spacing
dx = [1.0] * len_axes
elif n == 1:
# single value for all axes
dx = list(varargs) * len_axes
elif n == len_axes:
dx = list(varargs)
else:
    raise TypeError("Invalid number of spacing arguments %d" % n)
if ndim(dx[0]) != 0:
raise NotImplementedError("Non-constant spacing not implemented")
# TODO: use jax.lax loop tools if possible
a_grad = [gradient_along_axis(a, h, ax) for ax, h in zip(axis_tuple, dx)]
if len(axis_tuple) == 1:
a_grad = a_grad[0]
return a_grad
@_wraps(np.isrealobj)
def isrealobj(x):
return not iscomplexobj(x)
_POLYFIT_DOC = """\
Unlike NumPy's implementation of polyfit, :py:func:`jax.numpy.polyfit` will not
warn on rank reduction, which indicates an ill-conditioned matrix.
Also, it works best with ``rcond`` values of 10e-3 or smaller.
"""
@_wraps(np.polyfit, lax_description=_POLYFIT_DOC)
@partial(jit, static_argnames=('deg', 'rcond', 'full', 'cov'))
def polyfit(x, y, deg, rcond=None, full=False, w=None, cov=False):
_check_arraylike("polyfit", x, y)
deg = core.concrete_or_error(int, deg, "deg must be int")
order = deg + 1
# check arguments
if deg < 0:
raise ValueError("expected deg >= 0")
if x.ndim != 1:
raise TypeError("expected 1D vector for x")
if x.size == 0:
raise TypeError("expected non-empty vector for x")
if y.ndim < 1 or y.ndim > 2:
raise TypeError("expected 1D or 2D array for y")
if x.shape[0] != y.shape[0]:
raise TypeError("expected x and y to have same length")
# set rcond
if rcond is None:
rcond = len(x)*finfo(x.dtype).eps
rcond = core.concrete_or_error(float, rcond, "rcond must be float")
# set up least squares equation for powers of x
lhs = vander(x, order)
rhs = y
# apply weighting
if w is not None:
_check_arraylike("polyfit", w)
w, = _promote_dtypes_inexact(w)
if w.ndim != 1:
raise TypeError("expected a 1-d array for weights")
if w.shape[0] != y.shape[0]:
raise TypeError("expected w and y to have the same length")
lhs *= w[:, newaxis]
if rhs.ndim == 2:
rhs *= w[:, newaxis]
else:
rhs *= w
# scale lhs to improve condition number and solve
scale = sqrt((lhs*lhs).sum(axis=0))
lhs /= scale[newaxis,:]
from jax._src.numpy import linalg
c, resids, rank, s = linalg.lstsq(lhs, rhs, rcond)
c = (c.T/scale).T # broadcast scale coefficients
if full:
return c, resids, rank, s, rcond
elif cov:
Vbase = linalg.inv(dot(lhs.T, lhs))
Vbase /= outer(scale, scale)
if cov == "unscaled":
fac = 1
else:
if len(x) <= order:
raise ValueError("the number of data points must exceed order "
"to scale the covariance matrix")
fac = resids / (len(x) - order)
      fac = fac[0]  # convert ndarray of shape (1,) to a scalar
if y.ndim == 1:
return c, Vbase * fac
else:
return c, Vbase[:,:, newaxis] * fac
else:
return c
@_wraps(np.reshape, lax_description=_ARRAY_VIEW_DOC)
def reshape(a, newshape, order="C"):
_stackable(a) or _check_arraylike("reshape", a)
try:
return a.reshape(newshape, order=order) # forward to method for ndarrays
except AttributeError:
return _reshape(a, newshape, order=order)
def _compute_newshape(a, newshape):
"""Fixes a -1 value in newshape, if present."""
# other errors, like having more than one -1, are caught downstream, in
# reshape_shape_rule.
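  # For example, for an array of 12 elements and newshape (3, -1), the signed
  # division of total sizes, 12 / (3 * -1) = -4, is negated below to give the
  # inferred dimension 4.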
try: iter(newshape)
  except TypeError: iterable = False
else: iterable = True
newshape = core.canonicalize_shape(newshape if iterable else [newshape])
return tuple(- core.divide_shape_sizes(np.shape(a), newshape)
if core.symbolic_equal_dim(d, -1) else d
for d in newshape)
def _reshape(a, *args, order="C"):
newshape = _compute_newshape(a, args[0] if len(args) == 1 else args)
if order == "C":
return lax.reshape(a, newshape, None)
elif order == "F":
dims = np.arange(ndim(a))[::-1]
return lax.reshape(a, newshape[::-1], dims).T
elif order == "A":
raise NotImplementedError("np.reshape order=A is not implemented.")
else:
raise ValueError("Unexpected value for 'order' argument: {}.".format(order))
def _transpose(a, *args):
if not args:
axis = None
elif len(args) == 1:
axis = args[0] if args[0] is None else _ensure_index_tuple(args[0])
else:
axis = _ensure_index_tuple(args)
return transpose(a, axis)
@_wraps(np.ravel, lax_description=_ARRAY_VIEW_DOC)
@partial(jit, static_argnames=('order',), inline=True)
def ravel(a, order="C"):
_stackable(a) or _check_arraylike("ravel", a)
if order == "K":
raise NotImplementedError("Ravel not implemented for order='K'.")
return reshape(a, (size(a),), order)
@_wraps(np.ravel_multi_index)
def ravel_multi_index(multi_index, dims, mode='raise', order='C'):
assert len(multi_index) == len(dims), f"len(multi_index)={len(multi_index)} != len(dims)={len(dims)}"
dims = tuple(core.concrete_or_error(int, d, "in `dims` argument of ravel_multi_index().") for d in dims)
_check_arraylike("ravel_multi_index", *multi_index)
for index in multi_index:
if mode == 'raise':
core.concrete_or_error(array, index,
"The error occurred because ravel_multi_index was jit-compiled"
" with mode='raise'. Use mode='wrap' or mode='clip' instead.")
if not issubdtype(_dtype(index), integer):
raise TypeError("only int indices permitted")
if mode == "raise":
if _any(any((i < 0) | (i >= d)) for i, d in zip(multi_index, dims)):
raise ValueError("invalid entry in coordinates array")
elif mode == "clip":
multi_index = [clip(i, 0, d - 1) for i, d in zip(multi_index, dims)]
elif mode == "wrap":
multi_index = [i % d for i, d in zip(multi_index, dims)]
else:
raise ValueError(f"invalid mode={mode!r}. Expected 'raise', 'wrap', or 'clip'")
if order == "F":
strides = np.cumprod((1,) + dims[:-1])
elif order == "C":
strides = np.cumprod((1,) + dims[1:][::-1])[::-1]
else:
raise ValueError(f"invalid order={order!r}. Expected 'C' or 'F'")
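  # For example, for dims (3, 4) in C order the strides are (4, 1), so the
  # multi-index (i, j) maps to 4*i + j.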
result = array(0, dtype=dtypes.canonicalize_dtype(int_))
for i, s in zip(multi_index, strides):
result = result + i * s
return result
_UNRAVEL_INDEX_DOC = """\
Unlike numpy's implementation of unravel_index, negative indices are accepted
and out-of-bounds indices are clipped.
"""
@_wraps(np.unravel_index, lax_description=_UNRAVEL_INDEX_DOC)
def unravel_index(indices, shape):
_check_arraylike("unravel_index", indices)
sizes = append(array(shape), 1)
cumulative_sizes = cumprod(sizes[::-1])[::-1]
total_size = cumulative_sizes[0]
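  # For example, for shape (3, 4), cumulative_sizes is [12, 4, 1], so index 5
  # unravels to ((5 % 12) // 4, (5 % 4) // 1) == (1, 1).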
# Clip so raveling and unraveling an oob index will not change the behavior
clipped_indices = clip(indices, -total_size, total_size - 1)
# Add enough trailing dims to avoid conflict with clipped_indices
cumulative_sizes = expand_dims(cumulative_sizes, range(1, 1 + _ndim(indices)))
clipped_indices = expand_dims(clipped_indices, axis=0)
idx = clipped_indices % cumulative_sizes[:-1] // cumulative_sizes[1:]
# TODO(jakevdp): return tuple(idx) once it behaves properly (#3821)
return tuple(lax.index_in_dim(idx, i, keepdims=False) for i in range(idx.shape[0]))
@_wraps(np.resize)
@partial(jit, static_argnames=('new_shape',))
def resize(a, new_shape):
_check_arraylike("resize", a)
new_shape = _ensure_index_tuple(new_shape)
if _any(dim_length < 0 for dim_length in new_shape):
raise ValueError("all elements of `new_shape` must be non-negative")
a = ravel(a)
new_size = _prod(new_shape)
if a.size == 0 or new_size == 0:
return zeros_like(a, shape=new_shape)
repeats = ceil_of_ratio(new_size, a.size)
a = tile(a, repeats)[:new_size]
return reshape(a, new_shape)
@_wraps(np.squeeze, lax_description=_ARRAY_VIEW_DOC)
def squeeze(a, axis: Optional[Union[int, Tuple[int, ...]]] = None):
return _squeeze(a, _ensure_index_tuple(axis) if axis is not None else None)
@partial(jit, static_argnames=('axis',), inline=True)
def _squeeze(a, axis):
_check_arraylike("squeeze", a)
if axis is None:
a_shape = shape(a)
axis = tuple(i for i, d in enumerate(a_shape) if d == 1)
return lax.squeeze(a, axis)
@_wraps(np.expand_dims)
def expand_dims(a, axis: Union[int, Sequence[int]]):
_check_arraylike("expand_dims", a)
return lax.expand_dims(a, _ensure_index_tuple(axis))
@_wraps(np.swapaxes, lax_description=_ARRAY_VIEW_DOC)
@partial(jit, static_argnames=('axis1', 'axis2'), inline=True)
def swapaxes(a, axis1: int, axis2: int):
_check_arraylike("swapaxes", a)
perm = np.arange(ndim(a))
perm[axis1], perm[axis2] = perm[axis2], perm[axis1]
return lax.transpose(a, perm)
@_wraps(np.moveaxis, lax_description=_ARRAY_VIEW_DOC)
def moveaxis(a, source: Union[int, Sequence[int]],
destination: Union[int, Sequence[int]]):
return _moveaxis(a, _ensure_index_tuple(source),
_ensure_index_tuple(destination))
@partial(jit, static_argnames=('source', 'destination'), inline=True)
def _moveaxis(a, source: Tuple[int, ...], destination: Tuple[int, ...]):
_check_arraylike("moveaxis", a)
source = tuple(_canonicalize_axis(i, ndim(a)) for i in source)
destination = tuple(_canonicalize_axis(i, ndim(a)) for i in destination)
if len(source) != len(destination):
raise ValueError("Inconsistent number of elements: {} vs {}"
.format(len(source), len(destination)))
perm = [i for i in range(ndim(a)) if i not in source]
for dest, src in sorted(zip(destination, source)):
perm.insert(dest, src)
return lax.transpose(a, perm)
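# For example, non-moved axes keep their relative order:
#   jnp.moveaxis(jnp.zeros((3, 4, 5)), 0, -1).shape  # -> (4, 5, 3)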
@_wraps(np.isclose)
@partial(jit, static_argnames=('equal_nan',))
def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
a, b = _promote_args("isclose", a, b)
dtype = _dtype(a)
if issubdtype(dtype, inexact):
if issubdtype(dtype, complexfloating):
dtype = _complex_elem_type(dtype)
rtol = lax.convert_element_type(rtol, dtype)
atol = lax.convert_element_type(atol, dtype)
out = lax.le(
lax.abs(lax.sub(a, b)),
lax.add(atol, lax.mul(rtol, lax.abs(b))))
# This corrects the comparisons for infinite and nan values
a_inf = isinf(a)
b_inf = isinf(b)
any_inf = logical_or(a_inf, b_inf)
both_inf = logical_and(a_inf, b_inf)
# Set elements where either a or b is infinite to False
out = logical_and(out, logical_not(any_inf))
# Set elements where a and b are the same infinity back to True
same_value = lax.eq(a, b)
same_inf = logical_and(both_inf, same_value)
out = logical_or(out, same_inf)
# Set elements where either a or b is NaN to False
a_nan = isnan(a)
b_nan = isnan(b)
any_nan = logical_or(a_nan, b_nan)
out = logical_and(out, logical_not(any_nan))
if equal_nan:
# Set elements where both a and b are NaN to True
both_nan = logical_and(a_nan, b_nan)
out = logical_or(out, both_nan)
return out
else:
return lax.eq(a, b)
@_wraps(np.interp)
@partial(jit, static_argnames=('period',))
def interp(x, xp, fp, left=None, right=None, period=None):
if shape(xp) != shape(fp) or ndim(xp) != 1:
raise ValueError("xp and fp must be one-dimensional arrays of equal size")
x, xp, fp = _promote_dtypes_inexact(x, xp, fp)
if period is not None:
if period == 0:
raise ValueError(f"period must be a non-zero value; got {period}")
period = abs(period)
x = x % period
xp = xp % period
xp, fp = lax.sort_key_val(xp, fp)
xp = concatenate([xp[-1:] - period, xp, xp[:1] + period])
fp = concatenate([fp[-1:], fp, fp[:1]])
i = clip(searchsorted(xp, x, side='right'), 1, len(xp) - 1)
df = fp[i] - fp[i - 1]
dx = xp[i] - xp[i - 1]
delta = x - xp[i - 1]
f = where((dx == 0), fp[i], fp[i - 1] + (delta / dx) * df)
if period is None:
f = where(x < xp[0], fp[0] if left is None else left, f)
f = where(x > xp[-1], fp[-1] if right is None else right, f)
return f
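# For example, plain linear interpolation between bracketing sample points:
#   jnp.interp(2.5, jnp.array([1., 2., 3.]), jnp.array([3., 2., 0.]))  # -> 1.0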
@_wraps(np.in1d, lax_description="""
In the JAX version, the `assume_unique` argument is not referenced.
""")
@partial(jit, static_argnames=('assume_unique', 'invert',))
def in1d(ar1, ar2, assume_unique=False, invert=False):
_check_arraylike("in1d", ar1, ar2)
ar1 = ravel(ar1)
ar2 = ravel(ar2)
# Note: an algorithm based on searchsorted has better scaling, but in practice
# is very slow on accelerators because it relies on lax control flow. If XLA
# ever supports binary search natively, we should switch to this:
# ar2 = jnp.sort(ar2)
# ind = jnp.searchsorted(ar2, ar1)
# if invert:
# return ar1 != ar2[ind]
# else:
# return ar1 == ar2[ind]
if invert:
return (ar1[:, None] != ar2[None, :]).all(-1)
else:
return (ar1[:, None] == ar2[None, :]).any(-1)
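# For example, each element of ar1 is compared against all of ar2:
#   jnp.in1d(jnp.array([1, 2, 5]), jnp.array([2, 3, 4]))  # -> [False, True, False]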
_SETDIFF1D_DOC = """\
Because the size of the output of ``setdiff1d`` is data-dependent, the function is not
typically compatible with JIT. The JAX version adds the optional `size` argument which
specifies the size of the output array: it must be specified statically for ``jnp.setdiff1d``
to be compiled with non-static operands. If specified, the first `size` unique elements will
be returned; if there are fewer unique elements than `size` indicates, the return value will
be padded with the `fill_value`, which defaults to zero."""
@_wraps(np.setdiff1d, lax_description=_SETDIFF1D_DOC)
def setdiff1d(ar1, ar2, assume_unique=False, *, size=None, fill_value=None):
_check_arraylike("setdiff1d", ar1, ar2)
if size is None:
ar1 = core.concrete_or_error(None, ar1, "The error arose in setdiff1d()")
else:
size = core.concrete_or_error(operator.index, size, "The error arose in setdiff1d()")
ar1 = asarray(ar1)
fill_value = asarray(0 if fill_value is None else fill_value, dtype=ar1.dtype)
if ar1.size == 0:
return full_like(ar1, fill_value, shape=size or 0)
if not assume_unique:
ar1 = unique(ar1, size=size and ar1.size)
mask = in1d(ar1, ar2, invert=True)
if size is None:
return ar1[mask]
else:
if not (assume_unique or size is None):
# Set mask to zero at locations corresponding to unique() padding.
n_unique = ar1.size + 1 - (ar1 == ar1[0]).sum()
mask = where(arange(ar1.size) < n_unique, mask, False)
return where(arange(size) < mask.sum(), ar1[where(mask, size=size)], fill_value)
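# A sketch of the size/fill_value semantics (illustrative, not normative):
#   jnp.setdiff1d(jnp.array([1, 2, 3, 4]), jnp.array([2, 4]), size=3)
#   # -> [1, 3, 0]: two set-difference elements, padded with the default fill_value 0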
_UNION1D_DOC = """\
Because the size of the output of ``union1d`` is data-dependent, the function is not
typically compatible with JIT. The JAX version adds the optional `size` argument which
specifies the size of the output array: it must be specified statically for ``jnp.union1d``
to be compiled with non-static operands. If specified, the first `size` unique elements
will be returned; if there are fewer unique elements than `size` indicates, the return
value will be padded with `fill_value`, which defaults to the minimum value of the union."""
@_wraps(np.union1d, lax_description=_UNION1D_DOC)
def union1d(ar1, ar2, *, size=None, fill_value=None):
_check_arraylike("union1d", ar1, ar2)
if size is None:
ar1 = core.concrete_or_error(None, ar1, "The error arose in union1d()")
ar2 = core.concrete_or_error(None, ar2, "The error arose in union1d()")
else:
size = core.concrete_or_error(operator.index, size, "The error arose in union1d()")
return unique(concatenate((ar1, ar2), axis=None), size=size, fill_value=fill_value)
@_wraps(np.setxor1d, lax_description="""
In the JAX version, the input arrays are explicitly flattened regardless
of the value of ``assume_unique``.
""")
def setxor1d(ar1, ar2, assume_unique=False):
_check_arraylike("setxor1d", ar1, ar2)
ar1 = core.concrete_or_error(None, ar1, "The error arose in setxor1d()")
ar2 = core.concrete_or_error(None, ar2, "The error arose in setxor1d()")
ar1 = ravel(ar1)
ar2 = ravel(ar2)
if not assume_unique:
ar1 = unique(ar1)
ar2 = unique(ar2)
aux = concatenate((ar1, ar2))
if aux.size == 0:
return aux
aux = sort(aux)
flag = concatenate((array([True]), aux[1:] != aux[:-1], array([True])))
return aux[flag[1:] & flag[:-1]]
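# The flag trick above keeps values that appear exactly once in the sorted
# concatenation. For example:
#   jnp.setxor1d(jnp.array([1, 2, 3]), jnp.array([2, 3, 4]))  # -> [1, 4]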
@partial(jit, static_argnums=2)
def _intersect1d_sorted_mask(ar1, ar2, return_indices=False):
"""
Helper function for intersect1d which is jit-able
"""
ar = concatenate((ar1, ar2))
if return_indices:
iota = lax.broadcasted_iota(np.int64, shape(ar), dimension=0)
aux, indices = lax.sort_key_val(ar, iota)
else:
aux = sort(ar)
mask = aux[1:] == aux[:-1]
if return_indices:
return aux, mask, indices
else:
return aux, mask
@_wraps(np.intersect1d)
def intersect1d(ar1, ar2, assume_unique=False, return_indices=False):
_check_arraylike("intersect1d", ar1, ar2)
ar1 = core.concrete_or_error(None, ar1, "The error arose in intersect1d()")
ar2 = core.concrete_or_error(None, ar2, "The error arose in intersect1d()")
if not assume_unique:
if return_indices:
ar1, ind1 = unique(ar1, return_index=True)
ar2, ind2 = unique(ar2, return_index=True)
else:
ar1 = unique(ar1)
ar2 = unique(ar2)
else:
ar1 = ravel(ar1)
ar2 = ravel(ar2)
if return_indices:
aux, mask, aux_sort_indices = _intersect1d_sorted_mask(ar1, ar2, return_indices)
else:
aux, mask = _intersect1d_sorted_mask(ar1, ar2, return_indices)
int1d = aux[:-1][mask]
if return_indices:
ar1_indices = aux_sort_indices[:-1][mask]
ar2_indices = aux_sort_indices[1:][mask] - ar1.size
if not assume_unique:
ar1_indices = ind1[ar1_indices]
ar2_indices = ind2[ar2_indices]
return int1d, ar1_indices, ar2_indices
else:
return int1d
@_wraps(np.isin, lax_description="""
In the JAX version, the `assume_unique` argument is not referenced.
""")
def isin(element, test_elements, assume_unique=False, invert=False):
result = in1d(element, test_elements, assume_unique=assume_unique, invert=invert)
return result.reshape(shape(element))
# The `jit` on `where` exists to avoid materializing constants in cases like
# `np.where(np.zeros(1000), 7, 4)`. In op-by-op mode, we don't want to
# materialize the broadcast forms of scalar arguments.
@jit
def _where(condition, x=None, y=None):
if x is None or y is None:
raise ValueError("Either both or neither of the x and y arguments should "
"be provided to jax.numpy.where, got {} and {}."
.format(x, y))
if not issubdtype(_dtype(condition), bool_):
condition = lax.ne(condition, zeros_like(condition))
x, y = _promote_dtypes(x, y)
condition, x, y = broadcast_arrays(condition, x, y)
return lax.select(condition, x, y) if not core.is_empty_shape(np.shape(x)) else x
_WHERE_DOC = """\
At present, JAX does not support JIT-compilation of the single-argument form
of :py:func:`jax.numpy.where` because its output shape is data-dependent. The
three-argument form does not have a data-dependent shape and can be JIT-compiled
successfully. Alternatively, you can specify the optional ``size`` keyword:
if specified, the first ``size`` True elements will be returned; if there
are fewer True elements than ``size`` indicates, the index arrays will be
padded with ``fill_value`` (default is 0).
"""
@_wraps(np.where, update_doc=False, lax_description=_WHERE_DOC)
def where(condition, x=None, y=None, *, size=None, fill_value=None):
if x is None and y is None:
_check_arraylike("where", condition)
return nonzero(condition, size=size, fill_value=fill_value)
else:
if size is not None or fill_value is not None:
raise ValueError("size and fill_value arguments cannot be used in three-term where function.")
return _where(condition, x, y)
@_wraps(np.select)
def select(condlist, choicelist, default=0):
if len(condlist) != len(choicelist):
msg = "condlist must have length equal to choicelist ({} vs {})"
raise ValueError(msg.format(len(condlist), len(choicelist)))
if len(condlist) == 0:
raise ValueError("condlist must be non-empty")
choices = _promote_dtypes(default, *choicelist)
choicelist = choices[1:]
output = choices[0]
for cond, choice in zip(condlist[::-1], choicelist[::-1]):
output = where(cond, choice, output)
return output
@_wraps(np.bincount, lax_description="""\
JAX adds the optional `length` parameter, which specifies the output length, and
defaults to ``x.max() + 1``. It must be specified for bincount to be compiled
with non-static operands. Values larger than the specified length will be discarded.
If `length` is specified, `minlength` will be ignored.
Additionally, while ``np.bincount`` raises an error if the input array contains
negative values, ``jax.numpy.bincount`` clips negative values to zero.
""")
def bincount(x, weights=None, minlength=0, *, length=None):
_check_arraylike("bincount", x)
if not issubdtype(_dtype(x), integer):
msg = f"x argument to bincount must have an integer type; got {x.dtype}"
raise TypeError(msg)
if ndim(x) != 1:
raise ValueError("only 1-dimensional input supported.")
minlength = core.concrete_or_error(operator.index, minlength,
"The error occurred because of argument 'minlength' of jnp.bincount.")
if length is None:
x = core.concrete_or_error(asarray, x,
"The error occured because of argument 'x' of jnp.bincount. "
"To avoid this error, pass a static `length` argument.")
length = _max(minlength, x.size and x.max() + 1)
else:
length = core.concrete_or_error(operator.index, length,
"The error occurred because of argument 'length' of jnp.bincount.")
if weights is None:
weights = np.array(1, dtype=int_)
elif shape(x) != shape(weights):
raise ValueError("shape of weights must match shape of x.")
return zeros(length, _dtype(weights)).at[clip(x, 0)].add(weights)
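# For example (assuming the ``jnp`` alias):
#   jnp.bincount(jnp.array([0, 1, 1, 2]), length=5)  # -> [1, 2, 1, 0, 0]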
@_wraps(getattr(np, "broadcast_shapes", None))
def broadcast_shapes(*shapes):
if not shapes:
return ()
shapes = [(shape,) if np.ndim(shape) == 0 else tuple(shape) for shape in shapes]
return lax.broadcast_shapes(*shapes)
@partial(jit, inline=True)
def broadcast_arrays(*args):
"""Like Numpy's broadcast_arrays but doesn't return views."""
shapes = [shape(arg) for arg in args]
if len(set(shapes)) == 1:
return [arg if isinstance(arg, ndarray) or isscalar(arg) else array(arg)
for arg in args]
result_shape = lax.broadcast_shapes(*shapes)
return [broadcast_to(arg, result_shape) for arg in args]
@_wraps(np.broadcast_to, lax_description="""\
The JAX version does not necessarily return a view of the input.
""")
def broadcast_to(arr, shape):
if hasattr(arr, "broadcast_to"):
return arr.broadcast_to(shape)
arr = arr if isinstance(arr, ndarray) else array(arr)
shape = (shape,) if ndim(shape) == 0 else shape
shape = canonicalize_shape(shape) # check that shape is concrete
arr_shape = _shape(arr)
if core.symbolic_equal_shape(arr_shape, shape):
return arr
else:
nlead = len(shape) - len(arr_shape)
shape_tail = shape[nlead:]
compatible = _all(core.symbolic_equal_one_of_dim(arr_d, [1, shape_d])
for arr_d, shape_d in safe_zip(arr_shape, shape_tail))
if nlead < 0 or not compatible:
msg = "Incompatible shapes for broadcasting: {} and requested shape {}"
raise ValueError(msg.format(arr_shape, shape))
diff, = np.where(tuple(not core.symbolic_equal_dim(arr_d, shape_d)
for arr_d, shape_d in safe_zip(arr_shape, shape_tail)))
new_dims = tuple(range(nlead)) + tuple(nlead + diff)
kept_dims = tuple(np.delete(np.arange(len(shape)), new_dims))
return lax.broadcast_in_dim(squeeze(arr, tuple(diff)), shape, kept_dims)
def _split(op, ary, indices_or_sections, axis=0):
axis = core.concrete_or_error(int, axis, f"in jax.numpy.{op} argument `axis`")
size = ary.shape[axis]
if isinstance(indices_or_sections, (tuple, list)):
indices_or_sections = np.array(
[core.concrete_or_error(np.int64, i_s, f"in jax.numpy.{op} argument 1")
for i_s in indices_or_sections], np.int64)
split_indices = np.concatenate([[np.int64(0)], indices_or_sections,
[np.int64(size)]])
elif (isinstance(indices_or_sections, (np.ndarray, ndarray)) and
indices_or_sections.ndim > 0):
indices_or_sections = np.array(
[core.concrete_or_error(np.int64, i_s, f"in jax.numpy.{op} argument 1")
for i_s in indices_or_sections], np.int64)
split_indices = np.concatenate([[np.int64(0)], indices_or_sections,
[np.int64(size)]])
else:
indices_or_sections = core.concrete_or_error(np.int64, indices_or_sections,
f"in jax.numpy.{op} argument 1")
part_size, r = _divmod(size, indices_or_sections)
if r == 0:
split_indices = np.arange(indices_or_sections + 1,
dtype=np.int64) * part_size
elif op == "array_split":
split_indices = np.concatenate(
[np.arange(r + 1, dtype=np.int64) * (part_size + 1),
np.arange(indices_or_sections - r, dtype=np.int64) * part_size
+ ((r + 1) * (part_size + 1) - 1)])
else:
raise ValueError("array split does not result in an equal division")
starts, ends = [0] * ndim(ary), shape(ary)
_subval = lambda x, i, v: subvals(x, [(i, v)])
return [lax.slice(ary, _subval(starts, axis, start), _subval(ends, axis, end))
for start, end in zip(split_indices[:-1], split_indices[1:])]
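# In the uneven case, array_split gives the remainder to the leading chunks.
# For example:
#   jnp.array_split(jnp.arange(8), 3)  # -> chunks of sizes 3, 3, 2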
@_wraps(np.split, lax_description=_ARRAY_VIEW_DOC)
def split(ary, indices_or_sections, axis: int = 0):
return _split("split", ary, indices_or_sections, axis=axis)
def _split_on_axis(np_fun, axis):
@_wraps(np_fun, update_doc=False)
def f(ary, indices_or_sections):
return split(ary, indices_or_sections, axis=axis)
return f
vsplit = _split_on_axis(np.vsplit, axis=0)
hsplit = _split_on_axis(np.hsplit, axis=1)
dsplit = _split_on_axis(np.dsplit, axis=2)
@_wraps(np.array_split)
def array_split(ary, indices_or_sections, axis: int = 0):
return _split("array_split", ary, indices_or_sections, axis=axis)
@_wraps(np.clip, skip_params=['out'])
@jit
def clip(a, a_min=None, a_max=None, out=None):
_check_arraylike("clip", a)
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.clip is not supported.")
if a_min is None and a_max is None:
raise ValueError("At most one of a_min and a_max may be None")
if a_min is not None:
a = maximum(a_min, a)
if a_max is not None:
a = minimum(a_max, a)
return a
@_wraps(np.around, skip_params=['out'])
@partial(jit, static_argnames=('decimals',))
def round(a, decimals=0, out=None):
_check_arraylike("round", a)
decimals = core.concrete_or_error(operator.index, decimals, "'decimals' argument of jnp.round")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.round is not supported.")
dtype = _dtype(a)
if issubdtype(dtype, integer):
if decimals < 0:
raise NotImplementedError(
"integer np.round not implemented for decimals < 0")
return a # no-op on integer types
def _round_float(x):
if decimals == 0:
return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN)
# TODO(phawkins): the strategy of rescaling the value isn't necessarily a
# good one since we may be left with an incorrectly rounded value at the
# end due to precision problems. As a workaround for float16, convert to
# float32 and back.
x = lax.convert_element_type(x, np.float32) if dtype == np.float16 else x
factor = _constant_like(x, 10 ** decimals)
out = lax.div(lax.round(lax.mul(x, factor),
lax.RoundingMethod.TO_NEAREST_EVEN), factor)
return lax.convert_element_type(out, dtype) if dtype == np.float16 else out
if issubdtype(dtype, complexfloating):
return lax.complex(_round_float(lax.real(a)), _round_float(lax.imag(a)))
else:
return _round_float(a)
around = round
round_ = round
@_wraps(np.fix, skip_params=['out'])
@jit
def fix(x, out=None):
_check_arraylike("fix", x)
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.fix is not supported.")
zero = lax._const(x, 0)
return where(lax.ge(x, zero), floor(x), ceil(x))
@_wraps(np.modf, skip_params=['out'])
@jit
def modf(x, out=None):
_check_arraylike("modf", x)
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.modf is not supported.")
whole = fix(x)
return x - whole, whole
@_wraps(np.isfinite)
@jit
def isfinite(x):
_check_arraylike("isfinite", x)
dtype = _dtype(x)
if issubdtype(dtype, floating):
return lax.is_finite(x)
elif issubdtype(dtype, complexfloating):
return lax.bitwise_and(lax.is_finite(real(x)), lax.is_finite(imag(x)))
else:
return full_like(x, True, dtype=bool_)
@_wraps(np.isinf)
@jit
def isinf(x):
_check_arraylike("isinf", x)
dtype = _dtype(x)
if issubdtype(dtype, floating):
return lax.eq(lax.abs(x), _constant_like(x, inf))
elif issubdtype(dtype, complexfloating):
re = lax.real(x)
im = lax.imag(x)
return lax.bitwise_or(lax.eq(lax.abs(re), _constant_like(re, inf)),
lax.eq(lax.abs(im), _constant_like(im, inf)))
else:
return full_like(x, False, dtype=bool_)
def _isposneginf(infinity, x, out):
if out is not None:
raise NotImplementedError("The 'out' argument to isneginf/isposinf is not supported.")
dtype = _dtype(x)
if issubdtype(dtype, floating):
return lax.eq(x, _constant_like(x, infinity))
elif issubdtype(dtype, complexfloating):
raise ValueError("isposinf/isneginf are not well defined for complex types")
else:
return full_like(x, False, dtype=bool_)
isposinf = _wraps(np.isposinf, skip_params=['out'])(
lambda x, out=None: _isposneginf(inf, x, out)
)
isneginf = _wraps(np.isneginf, skip_params=['out'])(
lambda x, out=None: _isposneginf(-inf, x, out)
)
@_wraps(np.isnan)
@jit
def isnan(x):
_check_arraylike("isnan", x)
return lax.ne(x, x)
@_wraps(np.nan_to_num)
@jit
def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None):
del copy
_check_arraylike("nan_to_num", x)
dtype = _dtype(x)
if issubdtype(dtype, complexfloating):
return lax.complex(
nan_to_num(lax.real(x), nan=nan, posinf=posinf, neginf=neginf),
nan_to_num(lax.imag(x), nan=nan, posinf=posinf, neginf=neginf))
info = finfo(dtypes.canonicalize_dtype(dtype))
posinf = info.max if posinf is None else posinf
neginf = info.min if neginf is None else neginf
x = where(isnan(x), array(nan, dtype=x.dtype), x)
x = where(isposinf(x), array(posinf, dtype=x.dtype), x)
x = where(isneginf(x), array(neginf, dtype=x.dtype), x)
return x
### Reducers
def _reduction(a, name, np_fun, op, init_val, has_identity=True,
preproc=None, bool_op=None, upcast_f16_for_computation=False,
axis=None, dtype=None, out=None, keepdims=False, initial=None,
where_=None, parallel_reduce=None):
bool_op = bool_op or op
# Note: we must accept out=None as an argument, because numpy reductions delegate to
# object methods. For example `np.sum(x)` will call `x.sum()` if the `sum()` method
# exists, passing along all its arguments.
if out is not None:
raise NotImplementedError(f"The 'out' argument to jnp.{name} is not supported.")
_check_arraylike(name, a)
lax._check_user_dtype_supported(dtype, name)
axis = core.concrete_or_error(None, axis, f"axis argument to jnp.{name}().")
if initial is None and not has_identity:
if not _all(core.greater_equal_dim(d, 1) for d in np.shape(a)):
raise ValueError(f"zero-size array to reduction operation {name} which has no identity")
if where_ is not None:
raise ValueError(f"reduction operation {name} does not have an identity, so to use a "
f"where mask one has to specify 'initial'")
a = a if isinstance(a, ndarray) else asarray(a)
a = preproc(a) if preproc else a
pos_dims, dims = _reduction_dims(a, axis)
result_dtype = dtypes.canonicalize_dtype(dtype or _dtype(np_fun(np.ones((), dtype=_dtype(a)))))
if upcast_f16_for_computation and issubdtype(result_dtype, inexact):
computation_dtype = promote_types(result_dtype, float32)
else:
computation_dtype = result_dtype
a = lax.convert_element_type(a, computation_dtype)
op = op if computation_dtype != np.bool_ else bool_op
# NB: in XLA, init_val must be an identity for the op, so the user-specified
# initial value must be applied afterward.
init_val = _reduction_init_val(a, init_val)
if where_ is not None:
a = where(where_, a, init_val)
if pos_dims is not dims:
if parallel_reduce is None:
raise NotImplementedError(f"Named reductions not implemented for jnp.{name}()")
result = parallel_reduce(a, dims)
else:
result = lax.reduce(a, init_val, op, dims)
if initial is not None:
result = op(lax.convert_element_type(initial, a.dtype), result)
if keepdims:
result = expand_dims(result, pos_dims)
return lax.convert_element_type(result, dtype or result_dtype)
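# A sketch of the `where`/`initial` semantics implemented above: masked-out
# elements are replaced by the op's identity before reducing, and `initial` is
# folded in afterward. For example:
#   jnp.sum(jnp.array([1., 2., 3.]), where=jnp.array([True, False, True]))  # -> 4.0
#   jnp.max(jnp.array([1, 2]), initial=5)  # -> 5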
def _canonicalize_axis_allow_named(x, rank):
return maybe_named_axis(x, lambda i: _canonicalize_axis(i, rank), lambda name: name)
def _reduction_dims(a, axis):
if axis is None:
return (tuple(range(ndim(a))),) * 2
elif not isinstance(axis, (np.ndarray, tuple, list)):
axis = (axis,)
canon_axis = tuple(_canonicalize_axis_allow_named(x, ndim(a))
for x in axis)
if len(canon_axis) != len(set(canon_axis)):
raise ValueError(f"duplicate value in 'axis': {axis}")
canon_pos_axis = tuple(x for x in canon_axis if isinstance(x, int))
if len(canon_pos_axis) != len(canon_axis):
return canon_pos_axis, canon_axis
else:
return canon_axis, canon_axis
def _reduction_init_val(a, init_val):
# This function uses np.* functions because lax pattern matches against the
# specific concrete values of the reduction inputs.
a_dtype = dtypes.canonicalize_dtype(_dtype(a))
if a_dtype == 'bool':
return np.array(init_val > 0, dtype=a_dtype)
try:
return np.array(init_val, dtype=a_dtype)
except OverflowError:
assert issubdtype(a_dtype, integer)
sign, info = np.sign(init_val), iinfo(a_dtype)
return np.array(info.min if sign < 0 else info.max, dtype=a_dtype)
def _cast_to_bool(operand):
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=np.ComplexWarning)
return lax.convert_element_type(operand, bool_)
def _ensure_optional_axes(x):
def force(x):
if x is None:
return None
try:
return operator.index(x)
except TypeError:
return tuple(i if isinstance(i, str) else operator.index(i) for i in x)
return core.concrete_or_error(
force, x, "The axis argument must be known statically.")
@partial(jit, static_argnames=('axis', 'dtype', 'keepdims'), inline=True)
def _reduce_sum(a, axis: Optional[Union[int, Tuple[int, ...]]] = None,
dtype=None, out=None, keepdims=None, initial=None, where=None):
return _reduction(a, "sum", np.sum, lax.add, 0,
bool_op=lax.bitwise_or, upcast_f16_for_computation=True,
axis=axis, dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where_=where, parallel_reduce=lax.psum)
@_wraps(np.sum, skip_params=['out'])
def sum(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
out=None, keepdims=None, initial=None, where=None):
return _reduce_sum(a, axis=_ensure_optional_axes(axis), dtype=dtype, out=out,
keepdims=keepdims, initial=initial, where=where)
@partial(jit, static_argnames=('axis', 'dtype', 'keepdims'), inline=True)
def _reduce_prod(a, axis: Optional[Union[int, Tuple[int, ...]]] = None,
dtype=None, out=None, keepdims=None, initial=None, where=None):
return _reduction(a, "prod", np.prod, lax.mul, 1,
bool_op=lax.bitwise_and, upcast_f16_for_computation=True,
axis=axis, dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where_=where)
@_wraps(np.prod, skip_params=['out'])
def prod(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
out=None, keepdims=None, initial=None, where=None):
return _reduce_prod(a, axis=_ensure_optional_axes(axis), dtype=dtype,
out=out, keepdims=keepdims, initial=initial, where=where)
@partial(jit, static_argnames=('axis', 'keepdims'), inline=True)
def _reduce_max(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
keepdims=None, initial=None, where=None):
return _reduction(a, "max", np.max, lax.max, -np.inf, has_identity=False,
axis=axis, out=out, keepdims=keepdims,
initial=initial, where_=where, parallel_reduce=lax.pmax)
@_wraps(np.max, skip_params=['out'])
def max(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
keepdims=None, initial=None, where=None):
return _reduce_max(a, axis=_ensure_optional_axes(axis), out=out,
keepdims=keepdims, initial=initial, where=where)
@partial(jit, static_argnames=('axis', 'keepdims'), inline=True)
def _reduce_min(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
keepdims=None, initial=None, where=None):
return _reduction(a, "min", np.min, lax.min, np.inf, has_identity=False,
axis=axis, out=out, keepdims=keepdims,
initial=initial, where_=where, parallel_reduce=lax.pmin)
@_wraps(np.min, skip_params=['out'])
def min(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
keepdims=None, initial=None, where=None):
return _reduce_min(a, axis=_ensure_optional_axes(axis), out=out,
keepdims=keepdims, initial=initial, where=where)
@partial(jit, static_argnames=('axis', 'keepdims'), inline=True)
def _reduce_all(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
keepdims=None, *, where=None):
return _reduction(a, "all", np.all, lax.bitwise_and, True, preproc=_cast_to_bool,
axis=axis, out=out, keepdims=keepdims, where_=where)
@_wraps(np.all, skip_params=['out'])
def all(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
keepdims=None, *, where=None):
return _reduce_all(a, axis=_ensure_optional_axes(axis), out=out,
keepdims=keepdims, where=where)
@partial(jit, static_argnames=('axis', 'keepdims'), inline=True)
def _reduce_any(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
keepdims=None, *, where=None):
return _reduction(a, "any", np.any, lax.bitwise_or, False, preproc=_cast_to_bool,
axis=axis, out=out, keepdims=keepdims, where_=where)
@_wraps(np.any, skip_params=['out'])
def any(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
keepdims=None, *, where=None):
return _reduce_any(a, axis=_ensure_optional_axes(axis), out=out,
keepdims=keepdims, where=where)
product = prod
amin = min
amax = max
alltrue = all
sometrue = any
def _axis_size(a, axis):
if not isinstance(axis, (tuple, list)):
axis = (axis,)
size = 1
a_shape = shape(a)
for a in axis:
size *= maybe_named_axis(a, lambda i: a_shape[i], lambda name: lax.psum(1, name))
return size
@_wraps(np.mean, skip_params=['out'])
def mean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
out=None, keepdims=False, *, where=None):
return _mean(a, _ensure_optional_axes(axis), dtype, out, keepdims,
where=where)
@partial(jit, static_argnames=('axis', 'dtype', 'keepdims'), inline=True)
def _mean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
out=None, keepdims=False, *, where=None):
_check_arraylike("mean", a)
lax._check_user_dtype_supported(dtype, "mean")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.mean is not supported.")
if where is None:
if axis is None:
normalizer = core.dimension_as_value(size(a))
else:
normalizer = core.dimension_as_value(_axis_size(a, axis))
else:
normalizer = sum(broadcast_to(where, shape(a)), axis, dtype=dtype, keepdims=keepdims)
if dtype is None:
if issubdtype(_dtype(a), bool_) or issubdtype(_dtype(a), integer):
dtype = float_
else:
dtype = _dtype(a)
dtype = dtypes.canonicalize_dtype(dtype)
return lax.div(
sum(a, axis, dtype=dtype, keepdims=keepdims, where=where),
lax.convert_element_type(normalizer, dtype))
@_wraps(np.average)
def average(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, weights=None,
returned=False):
return _average(a, _ensure_optional_axes(axis), weights, returned)
@partial(jit, static_argnames=('axis', 'returned'), inline=True)
def _average(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, weights=None,
returned=False):
a = asarray(a)
if weights is None: # Treat all weights as 1
avg = mean(a, axis=axis)
if axis is None:
weights_sum = full((), core.dimension_as_value(size(a)), dtype=avg.dtype)
else:
weights_sum = full_like(avg, core.dimension_as_value(a.shape[axis]), dtype=avg.dtype)
else:
weights = asarray(weights)
if issubdtype(a.dtype, inexact):
out_dtype = result_type(a.dtype, weights.dtype)
else:
out_dtype = result_type(a.dtype, weights.dtype, float_)
out_dtype = dtypes.canonicalize_dtype(out_dtype)
a_shape = shape(a)
a_ndim = len(a_shape)
weights_shape = shape(weights)
axis = None if axis is None else _canonicalize_axis(axis, a_ndim)
if a_shape != weights_shape:
# Make sure the dimensions work out
if axis is None:
raise ValueError("Axis must be specified when shapes of a and "
"weights differ.")
if len(weights_shape) != 1:
raise ValueError("1D weights expected when shapes of a and "
"weights differ.")
if not core.symbolic_equal_dim(weights_shape[0], a_shape[axis]):
raise ValueError("Length of weights not "
"compatible with specified axis.")
weights = broadcast_to(weights, (a_ndim - 1) * (1,) + weights_shape)
weights = moveaxis(weights, -1, axis)
weights_sum = sum(weights, axis=axis, dtype=out_dtype)
avg = sum(multiply(a, weights), axis=axis, dtype=out_dtype) / weights_sum
if returned:
if avg.shape != weights_sum.shape:
weights_sum = broadcast_to(weights_sum, avg.shape)
return avg, weights_sum
return avg
@_wraps(np.var, skip_params=['out'])
def var(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
out=None, ddof=0, keepdims=False, *, where=None):
return _var(a, _ensure_optional_axes(axis), dtype, out, ddof, keepdims,
where=where)
@partial(jit, static_argnames=('axis', 'dtype', 'keepdims'))
def _var(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
out=None, ddof=0, keepdims=False, *, where=None):
_check_arraylike("var", a)
lax._check_user_dtype_supported(dtype, "var")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.var is not supported.")
a_dtype, dtype = _var_promote_types(_dtype(a), dtype)
a_mean = mean(a, axis, dtype=a_dtype, keepdims=True, where=where)
centered = a - a_mean
if issubdtype(centered.dtype, complexfloating):
centered = lax.real(lax.mul(centered, lax.conj(centered)))
else:
centered = lax.square(centered)
if where is None:
if axis is None:
normalizer = core.dimension_as_value(size(a))
else:
normalizer = core.dimension_as_value(_axis_size(a, axis))
else:
normalizer = sum(broadcast_to(where, shape(a)), axis, dtype=dtype, keepdims=keepdims)
normalizer = normalizer - ddof
result = sum(centered, axis, keepdims=keepdims, where=where)
out = lax.div(result, lax.convert_element_type(normalizer, result.dtype))
return lax.convert_element_type(out, dtype)
def _var_promote_types(a_dtype, dtype):
if dtype:
if (not issubdtype(dtype, complexfloating) and
issubdtype(a_dtype, complexfloating)):
msg = ("jax.numpy.var does not yet support real dtype parameters when "
"computing the variance of an array of complex values. The "
"semantics of numpy.var seem unclear in this case. Please comment "
"on https://github.com/google/jax/issues/2283 if this behavior is "
"important to you.")
raise ValueError(msg)
a_dtype = promote_types(a_dtype, dtype)
else:
if not issubdtype(a_dtype, inexact):
dtype = a_dtype = dtypes.canonicalize_dtype(float_)
else:
dtype = _complex_elem_type(a_dtype)
a_dtype = promote_types(a_dtype, float32)
return a_dtype, dtype
@_wraps(np.std, skip_params=['out'])
def std(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
out=None, ddof=0, keepdims=False, *, where=None):
return _std(a, _ensure_optional_axes(axis), dtype, out, ddof, keepdims,
where=where)
@partial(jit, static_argnames=('axis', 'dtype', 'keepdims'))
def _std(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
out=None, ddof=0, keepdims=False, *, where=None):
_check_arraylike("std", a)
lax._check_user_dtype_supported(dtype, "std")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.std is not supported.")
return sqrt(var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where))
@_wraps(np.ptp, skip_params=['out'])
def ptp(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
keepdims=False):
return _ptp(a, _ensure_optional_axes(axis), out, keepdims)
@partial(jit, static_argnames=('axis', 'keepdims'))
def _ptp(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
keepdims=False):
_check_arraylike("ptp", a)
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.ptp is not supported.")
x = amax(a, axis=axis, keepdims=keepdims)
y = amin(a, axis=axis, keepdims=keepdims)
return lax.sub(x, y)
@_wraps(np.allclose)
@partial(jit, static_argnames=('equal_nan',))
def allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
_check_arraylike("allclose", a, b)
return all(isclose(a, b, rtol, atol, equal_nan))
@_wraps(np.count_nonzero)
@partial(jit, static_argnames=('axis', 'keepdims'))
def count_nonzero(a, axis: Optional[Union[int, Tuple[int, ...]]] = None,
keepdims=False):
_check_arraylike("count_nonzero", a)
return sum(lax.ne(a, _constant_like(a, 0)), axis=axis,
dtype=dtypes.canonicalize_dtype(np.int_), keepdims=keepdims)
_NONZERO_DOC = """\
Because the size of the output of ``nonzero`` is data-dependent, the function is not
typically compatible with JIT. The JAX version adds the optional `size` argument which
specifies the size of the output arrays: it must be specified statically for ``jnp.nonzero``
to be compiled with non-static operands. If specified, the first `size` nonzero elements
will be returned; if there are fewer nonzero elements than `size` indicates, the result
will be padded with ``fill_value``, which defaults to zero. ``fill_value`` may be a scalar,
or a tuple specifying the fill value in each dimension.
"""
@_wraps(np.nonzero, lax_description=_NONZERO_DOC)
def nonzero(a, *, size=None, fill_value=None):
a = atleast_1d(a)
mask = a != 0
if size is None:
size = mask.sum()
size = core.concrete_or_error(int, size,
"The size argument of jnp.nonzero must be statically specified "
"to use jnp.nonzero within JAX transformations.")
if a.size == 0 or size == 0:
return tuple(zeros(size, int) for dim in a.shape)
flat_indices = cumsum(bincount(cumsum(mask), length=size))
strides = (np.cumprod(a.shape[::-1])[::-1] // a.shape).astype(int_)
out = tuple((flat_indices // stride) % size for stride, size in zip(strides, a.shape))
if size is not None and fill_value is not None:
if not isinstance(fill_value, tuple):
fill_value = a.ndim * (fill_value,)
if _shape(fill_value) != (a.ndim,):
raise ValueError(f"fill_value must be a scalar or a tuple of length {a.ndim}; got {fill_value}")
fill_mask = arange(size) >= mask.sum()
out = tuple(where(fill_mask, fval, entry) for fval, entry in safe_zip(fill_value, out))
return out
@_wraps(np.flatnonzero, lax_description=_NONZERO_DOC)
def flatnonzero(a, *, size=None, fill_value=None):
return nonzero(ravel(a), size=size, fill_value=fill_value)[0]
def _nan_reduction(a, name, jnp_reduction, init_val, nan_if_all_nan,
axis=None, keepdims=None, **kwargs):
_check_arraylike(name, a)
if not issubdtype(_dtype(a), inexact):
return jnp_reduction(a, axis=axis, keepdims=keepdims, **kwargs)
out = jnp_reduction(where(isnan(a), _reduction_init_val(a, init_val), a),
axis=axis, keepdims=keepdims, **kwargs)
if nan_if_all_nan:
return where(all(isnan(a), axis=axis, keepdims=keepdims),
_constant_like(a, nan), out)
else:
return out
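# For example, NaNs are replaced by the reduction's identity, and an all-NaN
# input yields NaN when nan_if_all_nan is set:
#   jnp.nansum(jnp.array([1., jnp.nan, 2.]))   # -> 3.0
#   jnp.nanmin(jnp.array([jnp.nan, jnp.nan]))  # -> nan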
@_wraps(np.nanmin, skip_params=['out'])
@partial(jit, static_argnames=('axis', 'keepdims'))
def nanmin(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
keepdims=None):
return _nan_reduction(a, 'nanmin', min, inf, nan_if_all_nan=True,
axis=axis, out=out, keepdims=keepdims)
@_wraps(np.nanmax, skip_params=['out'])
@partial(jit, static_argnames=('axis', 'keepdims'))
def nanmax(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
keepdims=None):
return _nan_reduction(a, 'nanmax', max, -inf, nan_if_all_nan=True,
axis=axis, out=out, keepdims=keepdims)
@_wraps(np.nansum, skip_params=['out'])
@partial(jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nansum(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
out=None, keepdims=None, initial=None, where=None):
lax._check_user_dtype_supported(dtype, "nanprod")
return _nan_reduction(a, 'nansum', sum, 0, nan_if_all_nan=False,
axis=axis, dtype=dtype, out=out, keepdims=keepdims)
# Work around a sphinx documentation warning in NumPy 1.22.
nansum.__doc__ = nansum.__doc__.replace("\n\n\n", "\n\n")
@_wraps(np.nanprod, skip_params=['out'])
@partial(jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanprod(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
out=None, keepdims=None):
lax._check_user_dtype_supported(dtype, "nanprod")
return _nan_reduction(a, 'nanprod', prod, 1, nan_if_all_nan=False,
axis=axis, dtype=dtype, out=out, keepdims=keepdims)
@_wraps(np.nanmean, skip_params=['out'])
@partial(jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanmean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
out=None, keepdims=False):
_check_arraylike("nanmean", a)
lax._check_user_dtype_supported(dtype, "nanmean")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.nanmean is not supported.")
if issubdtype(_dtype(a), bool_) or issubdtype(_dtype(a), integer):
return mean(a, axis, dtype, out, keepdims)
if dtype is None:
dtype = _dtype(a)
nan_mask = logical_not(isnan(a))
normalizer = sum(nan_mask, axis=axis, dtype=int32, keepdims=keepdims)
normalizer = lax.convert_element_type(normalizer, dtype)
td = lax.div(nansum(a, axis, dtype=dtype, keepdims=keepdims), normalizer)
return td
@_wraps(np.nanvar, skip_params=['out'])
@partial(jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanvar(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
out=None, ddof=0, keepdims=False):
_check_arraylike("nanvar", a)
lax._check_user_dtype_supported(dtype, "nanvar")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.nanvar is not supported.")
a_dtype, dtype = _var_promote_types(_dtype(a), dtype)
a_mean = nanmean(a, axis, dtype=a_dtype, keepdims=True)
centered = where(isnan(a), 0, a - a_mean) # double-where trick for gradients.
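# (Replacing the NaN entries *before* squaring keeps NaN out of the backward
# pass; differentiating through the masked entries would otherwise yield NaN
# gradients even for the unmasked outputs.)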
if issubdtype(centered.dtype, complexfloating):
centered = lax.real(lax.mul(centered, lax.conj(centered)))
else:
centered = lax.square(centered)
normalizer = sum(logical_not(isnan(a)), axis=axis, keepdims=keepdims)
normalizer = normalizer - ddof
normalizer_mask = lax.le(normalizer, 0)
result = sum(centered, axis, keepdims=keepdims)
result = where(normalizer_mask, nan, result)
divisor = where(normalizer_mask, 1, normalizer)
out = lax.div(result, lax.convert_element_type(divisor, result.dtype))
return lax.convert_element_type(out, dtype)
@_wraps(np.nanstd, skip_params=['out'])
@partial(jit, static_argnames=('axis', 'dtype', 'keepdims'))
def nanstd(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
out=None, ddof=0, keepdims=False):
_check_arraylike("nanstd", a)
lax._check_user_dtype_supported(dtype, "nanstd")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.nanstd is not supported.")
return sqrt(nanvar(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims))
def _make_cumulative_reduction(np_reduction, reduction, fill_nan=False, fill_value=0):
@_wraps(np_reduction, skip_params=['out'])
def cumulative_reduction(a,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
dtype=None, out=None):
return _cumulative_reduction(a, _ensure_optional_axes(axis), dtype, out)
@partial(jit, static_argnames=('axis', 'dtype'))
def _cumulative_reduction(a,
axis: Optional[Union[int, Tuple[int, ...]]] = None,
dtype=None, out=None):
_check_arraylike(np_reduction.__name__, a)
if out is not None:
raise NotImplementedError(f"The 'out' argument to jnp.{np_reduction.__name__} "
f"is not supported.")
lax._check_user_dtype_supported(dtype, np_reduction.__name__)
if axis is None or isscalar(a):
a = ravel(a)
axis = 0
a_shape = list(shape(a))
num_dims = len(a_shape)
axis = _canonicalize_axis(axis, num_dims)
if fill_nan:
a = where(isnan(a), _constant_like(a, fill_value), a)
if not dtype and _dtype(a) == bool_:
dtype = int_
if dtype:
a = lax.convert_element_type(a, dtype)
return reduction(a, axis)
return cumulative_reduction
cumsum = _make_cumulative_reduction(np.cumsum, lax.cumsum, fill_nan=False)
cumprod = _make_cumulative_reduction(np.cumprod, lax.cumprod, fill_nan=False)
cumproduct = cumprod
nancumsum = _make_cumulative_reduction(np.nancumsum, lax.cumsum,
fill_nan=True, fill_value=0)
nancumprod = _make_cumulative_reduction(np.nancumprod, lax.cumprod,
fill_nan=True, fill_value=1)
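# For example, NaNs are replaced by the identity of the underlying scan:
#   jnp.nancumsum(jnp.array([1., jnp.nan, 2.]))   # -> [1., 1., 3.]
#   jnp.nancumprod(jnp.array([1., jnp.nan, 2.]))  # -> [1., 1., 2.]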
@_wraps(np.unwrap)
@partial(jit, static_argnames=('axis',))
def unwrap(p, discont=pi, axis: int = -1):
_check_arraylike("unwrap", p)
dd = diff(p, axis=axis)
ddmod = mod(dd + pi, 2 * pi) - pi
ddmod = where((ddmod == -pi) & (dd > 0), pi, ddmod)
ph_correct = where(abs(dd) < discont, 0, ddmod - dd)
up = concatenate((
lax.slice_in_dim(p, 0, 1, axis=axis),
lax.slice_in_dim(p, 1, None, axis=axis) + cumsum(ph_correct, axis=axis)
), axis=axis)
return up
### Array-creation functions
def _check_no_padding(axis_padding, mode):
if (axis_padding[0] > 0 or axis_padding[1] > 0):
msg = "Cannot apply '{}' padding to empty axis"
raise ValueError(msg.format(mode))
def _pad_constant(array, pad_width, constant_values):
nd = ndim(array)
constant_values = broadcast_to(asarray(constant_values), (nd, 2))
constant_values = lax._convert_element_type(constant_values, array.dtype, dtypes.is_weakly_typed(array))
for i in range(nd):
widths = [(0, 0, 0)] * nd
widths[i] = (pad_width[i, 0], 0, 0)
array = lax.pad(array, constant_values[i, 0], widths)
widths[i] = (0, pad_width[i, 1], 0)
array = lax.pad(array, constant_values[i, 1], widths)
return array
def _pad_wrap(array, pad_width):
for i in range(ndim(array)):
if array.shape[i] == 0:
_check_no_padding(pad_width[i], "wrap")
continue
size = array.shape[i]
repeats, (left_remainder, right_remainder) = _divmod(pad_width[i], size)
total_repeats = repeats.sum() + 1
parts = []
if left_remainder:
parts += [lax.slice_in_dim(array, size - left_remainder, size, axis=i)]
parts += total_repeats * [array]
if right_remainder:
parts += [lax.slice_in_dim(array, 0, right_remainder, axis=i)]
array = lax.concatenate(parts, dimension=i)
return array
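# jnp.pad dispatches here for mode='wrap'. For example:
#   jnp.pad(jnp.array([1, 2, 3]), 2, mode='wrap')  # -> [2, 3, 1, 2, 3, 1, 2]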
def _pad_symmetric_or_reflect(array, pad_width, mode, reflect_type):
assert mode in ("symmetric", "reflect")
assert reflect_type in ("even", "odd")
for i in range(ndim(array)):
if array.shape[i] == 0:
_check_no_padding(pad_width[i], mode)
continue
n = array.shape[i]
offset = 1 if (mode == "reflect" and n > 1) else 0
def build_padding(array, padding, before):
if before:
edge = lax.slice_in_dim(array, 0, 1, axis=i)
else:
edge = lax.slice_in_dim(array, -1, None, axis=i)
while padding > 0:
curr_pad = _min(padding, n - offset)
padding -= curr_pad
if before:
start = offset
stop = offset + curr_pad
else:
start = -(curr_pad + offset)
stop = None if (mode == "symmetric" or n == 1) else -1
x = lax.slice_in_dim(array, start, stop, axis=i)
x = flip(x, axis=i)
if reflect_type == 'odd':
x = 2 * edge - x
if n > 1:
if before:
edge = lax.slice_in_dim(x, 0, 1, axis=i)
else:
edge = lax.slice_in_dim(x, -1, None, axis=i)
if before:
array = lax.concatenate([x, array], dimension=i)
else:
array = lax.concatenate([array, x], dimension=i)
return array
array = build_padding(array, pad_width[i, 0], before=True)
array = build_padding(array, pad_width[i, 1], before=False)
return array
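# The offset above makes 'reflect' skip the edge sample that 'symmetric' repeats.
# For example:
#   jnp.pad(jnp.array([1, 2, 3]), 2, mode='reflect')    # -> [3, 2, 1, 2, 3, 2, 1]
#   jnp.pad(jnp.array([1, 2, 3]), 2, mode='symmetric')  # -> [2, 1, 1, 2, 3, 3, 2]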
def _pad_edge(array, pad_width):
nd = ndim(array)
for i in range(nd):
if array.shape[i] == 0:
_check_no_padding(pad_width[i], "edge")
continue
n = array.shape[i]
npad_before, npad_after = pad_width[i]
edge_before = lax.slice_in_dim(array, 0, 1, axis=i)
pad_before = repeat(edge_before, npad_before, axis=i)
edge_after = lax.slice_in_dim(array, n-1, n, axis=i)
pad_after = repeat(edge_after, npad_after, axis=i)
array = lax.concatenate([pad_before, array, pad_after], dimension=i)
return array
def _pad_linear_ramp(array, pad_width, end_values):
for axis in range(ndim(array)):
edge_before = lax.slice_in_dim(array, 0, 1, axis=axis)
edge_after = lax.slice_in_dim(array, -1, None, axis=axis)
ramp_before = linspace(
start=end_values[axis][0],
stop=edge_before.squeeze(axis), # Dimension is replaced by linspace
num=pad_width[axis][0],
endpoint=False,
dtype=array.dtype,
axis=axis
)
ramp_before = lax._convert_element_type(ramp_before, weak_type=dtypes.is_weakly_typed(array))
ramp_after = linspace(
start=end_values[axis][1],
stop=edge_after.squeeze(axis), # Dimension is replaced by linspace
num=pad_width[axis][1],
endpoint=False,
dtype=array.dtype,
axis=axis
)
ramp_after = lax._convert_element_type(ramp_after, weak_type=dtypes.is_weakly_typed(array))
# Reverse linear space in appropriate dimension
ramp_after = flip(ramp_after, axis)
array = lax.concatenate([ramp_before, array, ramp_after], dimension=axis)
return array
def _pad_stats(array, pad_width, stat_length, stat_func):
nd = ndim(array)
for i in range(nd):
if stat_length is None:
stat_before = stat_func(array, axis=i, keepdims=True)
stat_after = stat_before
else:
array_length = array.shape[i]
length_before, length_after = stat_length[i]
if length_before == 0 or length_after == 0:
raise ValueError("stat_length of 0 yields no value for padding")
# Limit stat_length to length of array.
length_before = _min(length_before, array_length)
length_after = _min(length_after, array_length)
slice_before = lax.slice_in_dim(array, 0, length_before, axis=i)
slice_after = lax.slice_in_dim(array, -length_after, None, axis=i)
stat_before = stat_func(slice_before, axis=i, keepdims=True)
stat_after = stat_func(slice_after, axis=i, keepdims=True)
if np.issubdtype(array.dtype, np.integer):
stat_before = round(stat_before)
stat_after = round(stat_after)
stat_before = lax._convert_element_type(stat_before, array.dtype, dtypes.is_weakly_typed(array))
stat_after = lax._convert_element_type(stat_after, array.dtype, dtypes.is_weakly_typed(array))
npad_before, npad_after = pad_width[i]
pad_before = repeat(stat_before, npad_before, axis=i)
pad_after = repeat(stat_after, npad_after, axis=i)
array = lax.concatenate([pad_before, array, pad_after], dimension=i)
return array
def _pad_empty(array, pad_width):
# Note: jax.numpy.empty = jax.numpy.zeros
for i in range(ndim(array)):
shape_before = array.shape[:i] + (pad_width[i][0],) + array.shape[i + 1:]
pad_before = empty_like(array, shape=shape_before)
shape_after = array.shape[:i] + (pad_width[i][1],) + array.shape[i + 1:]
pad_after = empty_like(array, shape=shape_after)
array = lax.concatenate([pad_before, array, pad_after], dimension=i)
return array
def _pad_func(array, pad_width, func, **kwargs):
pad_width = _broadcast_to_pairs(pad_width, ndim(array), "pad_width")
padded = _pad_constant(array, np.array(pad_width), 0)
for axis in range(ndim(padded)):
padded = apply_along_axis(func, axis, padded, pad_width[axis], axis, kwargs)
return padded
def _broadcast_to_pairs(nvals, nd, name):
nvals = np.asarray(tree_map(
lambda x: core.concrete_or_error(np.array, x, context=f"{name} argument of jnp.pad"),
nvals))
if nvals.dtype.kind == 'O':
raise TypeError(f'`{name}` entries must be the same shape.')
if nvals.shape == (nd, 2):
# ((before_1, after_1), ..., (before_N, after_N))
return tuple(tuple(nval) for nval in nvals)
elif nvals.shape == (1, 2):
# ((before, after),)
return tuple(tuple(nvals[0]) for i in range(nd))
elif nvals.shape == (2,):
# (before, after) (not in the numpy docstring but works anyway)
return tuple(tuple(nvals) for i in range(nd))
elif nvals.shape == (1,):
# (pad,)
return tuple((nvals[0], nvals[0]) for i in range(nd))
elif nvals.shape == ():
# pad
return tuple((nvals.flat[0], nvals.flat[0]) for i in range(nd))
else:
raise ValueError(f"jnp.pad: {name} with nd={nd} has unsupported shape {nvals.shape}. "
f"Valid shapes are ({nd}, 2), (1, 2), (2,), (1,), or ().")
@partial(jit, static_argnums=(1, 2, 4, 5, 6))
def _pad(array, pad_width, mode, constant_values, stat_length, end_values, reflect_type):
array = asarray(array)
nd = ndim(array)
if nd == 0:
return array
stat_funcs = {"maximum": amax, "minimum": amin,
"mean": mean, "median": median}
pad_width = _broadcast_to_pairs(pad_width, nd, "pad_width")
pad_width = np.array(pad_width)
assert pad_width.shape == (nd, 2), pad_width
if np.any(pad_width < 0):
raise ValueError("index can't contain negative values")
if mode == "constant":
return _pad_constant(array, pad_width, constant_values)
elif mode == "wrap":
return _pad_wrap(array, pad_width)
elif mode in ("symmetric", "reflect"):
return _pad_symmetric_or_reflect(array, pad_width, mode, reflect_type)
elif mode == "edge":
return _pad_edge(array, pad_width)
elif mode == "linear_ramp":
end_values = _broadcast_to_pairs(end_values, nd, "end_values")
return _pad_linear_ramp(array, pad_width, end_values)
elif mode in stat_funcs:
if stat_length is not None:
stat_length = _broadcast_to_pairs(stat_length, nd, "stat_length")
return _pad_stats(array, pad_width, stat_length, stat_funcs[mode])
elif mode == "empty":
return _pad_empty(array, pad_width)
else:
assert False, ("Should not be reached since pad already handled unsupported and"
"not implemented modes")
@_wraps(np.pad, lax_description="""\
Unlike numpy, JAX "function" mode's argument (which is another function) should return
the modified array. This is because Jax arrays are immutable.
(In numpy, "function" mode's argument should modify a rank 1 array in-place.)
""")
def pad(array, pad_width, mode="constant", **kwargs):
_check_arraylike("pad", array)
pad_width = _broadcast_to_pairs(pad_width, ndim(array), "pad_width")
if pad_width and np.array(pad_width).dtype.kind != 'i':
raise TypeError('`pad_width` must be of integral type.')
if callable(mode):
return _pad_func(array, pad_width, mode, **kwargs)
allowed_kwargs = {
'empty': [], 'edge': [], 'wrap': [],
'constant': ['constant_values'],
'linear_ramp': ['end_values'],
'maximum': ['stat_length'],
'mean': ['stat_length'],
'median': ['stat_length'],
'minimum': ['stat_length'],
'reflect': ['reflect_type'],
'symmetric': ['reflect_type'],
}
try:
unsupported_kwargs = set(kwargs) - set(allowed_kwargs[mode])
except KeyError:
msg = "Unimplemented padding mode '{}' for np.pad."
raise NotImplementedError(msg.format(mode))
if unsupported_kwargs:
raise ValueError("unsupported keyword arguments for mode '{}': {}"
.format(mode, unsupported_kwargs))
# Set default value if not given.
constant_values = kwargs.get('constant_values', 0)
stat_length = kwargs.get('stat_length', None)
end_values = kwargs.get('end_values', 0)
reflect_type = kwargs.get('reflect_type', "even")
return _pad(array, pad_width, mode, constant_values, stat_length, end_values, reflect_type)
@_wraps(np.stack, skip_params=['out'])
def stack(arrays, axis: int = 0, out=None):
if not len(arrays):
raise ValueError("Need at least one array to stack.")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.stack is not supported.")
if isinstance(arrays, (np.ndarray, ndarray)):
axis = _canonicalize_axis(axis, arrays.ndim)
return concatenate(expand_dims(arrays, axis + 1), axis=axis)
else:
_check_arraylike("stack", *arrays)
shape0 = shape(arrays[0])
axis = _canonicalize_axis(axis, len(shape0) + 1)
new_arrays = []
for a in arrays:
if shape(a) != shape0:
raise ValueError("All input arrays must have the same shape.")
new_arrays.append(expand_dims(a, axis))
return concatenate(new_arrays, axis=axis)
@_wraps(np.tile)
def tile(A, reps):
_stackable(A) or _check_arraylike("tile", A)
try:
iter(reps)
except TypeError:
reps = (reps,)
reps = tuple(operator.index(rep) if core.is_constant_dim(rep) else rep
for rep in reps)
A_shape = (1,) * (len(reps) - ndim(A)) + shape(A)
reps = (1,) * (len(A_shape) - len(reps)) + reps
result = broadcast_to(reshape(A, [j for i in A_shape for j in [1, i]]),
[k for pair in zip(reps, A_shape) for k in pair])
return reshape(result, tuple(np.multiply(A_shape, reps)))
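# The interleaved reshape/broadcast above amounts to repeating A along each
# axis. For example:
#   jnp.tile(jnp.array([1, 2]), 3)  # -> [1, 2, 1, 2, 1, 2]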
def _concatenate_array(arr, axis: int):
# Fast path for concatenation when the input is an ndarray rather than a list.
arr = asarray(arr)
if arr.ndim == 0 or arr.shape[0] == 0:
raise ValueError("Need at least one array to concatenate.")
if axis is None:
return lax.reshape(arr, (arr.size,))
if arr.ndim == 1:
raise ValueError("Zero-dimensional arrays cannot be concatenated.")
axis = _canonicalize_axis(axis, arr.ndim - 1)
shape = arr.shape[1:axis + 1] + (arr.shape[0] * arr.shape[axis + 1],) + arr.shape[axis + 2:]
dimensions = [*range(1, axis + 1), 0, *range(axis + 1, arr.ndim)]
return lax.reshape(arr, shape, dimensions)
@_wraps(np.concatenate)
def concatenate(arrays, axis: int = 0):
if isinstance(arrays, (np.ndarray, ndarray)):
return _concatenate_array(arrays, axis)
_stackable(*arrays) or _check_arraylike("concatenate", *arrays)
if not len(arrays):
raise ValueError("Need at least one array to concatenate.")
if ndim(arrays[0]) == 0:
raise ValueError("Zero-dimensional arrays cannot be concatenated.")
if axis is None:
return concatenate([ravel(a) for a in arrays], axis=0)
if hasattr(arrays[0], "concatenate"):
return arrays[0].concatenate(arrays[1:], axis)
axis = _canonicalize_axis(axis, ndim(arrays[0]))
arrays = _promote_dtypes(*arrays)
# lax.concatenate can be slow to compile for wide concatenations, so form a
# tree of concatenations as a workaround especially for op-by-op mode.
# (https://github.com/google/jax/issues/653).
k = 16
if len(arrays) == 1:
return asarray(arrays[0])
else:
while len(arrays) > 1:
arrays = [lax.concatenate(arrays[i:i+k], axis)
for i in range(0, len(arrays), k)]
return arrays[0]
@_wraps(np.vstack)
def vstack(tup):
if isinstance(tup, (np.ndarray, ndarray)):
arrs = jax.vmap(atleast_2d)(tup)
else:
arrs = [atleast_2d(m) for m in tup]
return concatenate(arrs, axis=0)
row_stack = vstack
@_wraps(np.hstack)
def hstack(tup):
if isinstance(tup, (np.ndarray, ndarray)):
arrs = jax.vmap(atleast_1d)(tup)
arr0_ndim = arrs.ndim - 1
else:
arrs = [atleast_1d(m) for m in tup]
arr0_ndim = arrs[0].ndim
return concatenate(arrs, axis=0 if arr0_ndim == 1 else 1)
@_wraps(np.dstack)
def dstack(tup):
if isinstance(tup, (np.ndarray, ndarray)):
arrs = jax.vmap(atleast_3d)(tup)
else:
arrs = [atleast_3d(m) for m in tup]
return concatenate(arrs, axis=2)
@_wraps(np.column_stack)
def column_stack(tup):
if isinstance(tup, (np.ndarray, ndarray)):
arrs = jax.vmap(lambda x: atleast_2d(x).T)(tup) if tup.ndim < 3 else tup
else:
arrs = [atleast_2d(arr).T if arr.ndim < 2 else arr for arr in map(asarray, tup)]
return concatenate(arrs, 1)
@_wraps(np.choose, skip_params=['out'])
def choose(a, choices, out=None, mode='raise'):
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.choose is not supported.")
_check_arraylike('choose', a, *choices)
if not issubdtype(_dtype(a), integer):
raise ValueError("`a` array must be integer typed")
N = len(choices)
if mode == 'raise':
a = core.concrete_or_error(asarray, a,
"The error occurred because jnp.choose was jit-compiled"
" with mode='raise'. Use mode='wrap' or mode='clip' instead.")
if any((a < 0) | (a >= N)):
raise ValueError("invalid entry in choice array")
elif mode == 'wrap':
a = a % N
elif mode == 'clip':
a = clip(a, 0, N - 1)
else:
raise ValueError(f"mode={mode!r} not understood. Must be 'raise', 'wrap', or 'clip'")
a, *choices = broadcast_arrays(a, *choices)
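  # The gather below picks choices[a[idx]][idx] at each position idx; e.g.
  # choose([0, 1], [[1, 2], [3, 4]]) stacks the broadcast choices into a
  # (2, 2) array and gathers [1, 4].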
return array(choices)[(a,) + indices(a.shape, sparse=True)]
def _atleast_nd(x, n):
m = ndim(x)
return lax.broadcast(x, (1,) * (n - m)) if m < n else x
def _block(xs):
if isinstance(xs, tuple):
raise ValueError("jax.numpy.block does not allow tuples, got {}"
.format(xs))
elif isinstance(xs, list):
if len(xs) == 0:
raise ValueError("jax.numpy.block does not allow empty list arguments")
xs, depths = unzip2([_block(x) for x in xs])
if _any(d != depths[0] for d in depths[1:]):
raise ValueError("Mismatched list depths in jax.numpy.block")
rank = _max(depths[0], _max(ndim(x) for x in xs))
xs = [_atleast_nd(x, rank) for x in xs]
return concatenate(xs, axis=-depths[0]), depths[0] + 1
else:
return asarray(xs), 1
@_wraps(np.block)
@jit
def block(arrays):
out, _ = _block(arrays)
return out
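# A worked example of the recursion in _block: block([[A, B], [C, D]]) first
# concatenates A with B and C with D along the last axis (the innermost
# lists), then joins the two results along the second-to-last axis.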
@_wraps(np.atleast_1d, update_doc=False, lax_description=_ARRAY_VIEW_DOC)
@jit
def atleast_1d(*arys):
if len(arys) == 1:
arr = asarray(arys[0])
return arr if ndim(arr) >= 1 else reshape(arr, -1)
else:
return [atleast_1d(arr) for arr in arys]
@_wraps(np.atleast_2d, update_doc=False, lax_description=_ARRAY_VIEW_DOC)
@jit
def atleast_2d(*arys):
if len(arys) == 1:
arr = asarray(arys[0])
if ndim(arr) >= 2:
return arr
elif ndim(arr) == 1:
return expand_dims(arr, axis=0)
else:
return expand_dims(arr, axis=(0, 1))
else:
return [atleast_2d(arr) for arr in arys]
@_wraps(np.atleast_3d, update_doc=False, lax_description=_ARRAY_VIEW_DOC)
@jit
def atleast_3d(*arys):
if len(arys) == 1:
arr = asarray(arys[0])
if ndim(arr) == 0:
arr = expand_dims(arr, axis=(0, 1, 2))
elif ndim(arr) == 1:
arr = expand_dims(arr, axis=(0, 2))
elif ndim(arr) == 2:
arr = expand_dims(arr, axis=2)
return arr
else:
return [atleast_3d(arr) for arr in arys]
_ARRAY_DOC = """
This function will create arrays on JAX's default device. For control of the
device placement of data, see :func:`jax.device_put`. More information is
available in the JAX FAQ at :ref:`faq-data-placement` (full FAQ at
https://jax.readthedocs.io/en/latest/faq.html).
"""
@_wraps(np.array, lax_description=_ARRAY_DOC)
def array(object, dtype=None, copy=True, order="K", ndmin=0):
if order is not None and order != "K":
raise NotImplementedError("Only implemented for order='K'")
# check if the given dtype is compatible with JAX
lax._check_user_dtype_supported(dtype, "array")
# Here we make a judgment call: we only return a weakly-typed array when the
# input object itself is weakly typed. That ensures asarray(x) is a no-op whenever
# x is weak, but avoids introducing weak types with something like array([1, 2, 3])
weak_type = dtype is None and dtypes.is_weakly_typed(object)
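  # E.g. array(3.0) (a weak Python scalar input) should stay weakly typed,
  # while array([3.0]) should not, since the list itself is not weakly typed.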
# For Python scalar literals, call coerce_to_array to catch any overflow errors.
# We don't use dtypes.is_python_scalar because we don't want this triggering for
# traced values. We do this here because it matters whether or not dtype is None.
# We don't assign the result because we want the raw object to be used for type
# inference below.
if isinstance(object, (bool, int, float, complex)):
_ = dtypes.coerce_to_array(object, dtype)
leaves = tree_leaves(object)
if dtype is None:
# Use lattice_result_type rather than result_type to avoid canonicalization.
# Otherwise, weakly-typed inputs would have their dtypes canonicalized.
try:
dtype = dtypes._lattice_result_type(*leaves)[0] if leaves else dtypes.float_
except TypeError:
# This happens if, e.g. one of the entries is a memoryview object.
# This is rare, so we only handle it if the normal path fails.
leaves = [_convert_to_array_if_dtype_fails(leaf) for leaf in leaves]
dtype = dtypes._lattice_result_type(*leaves)[0]
if not weak_type:
dtype = dtypes.canonicalize_dtype(dtype)
# We can't use the ndarray class because we need to handle internal buffers
# (See https://github.com/google/jax/issues/8950)
ndarray_types = (device_array.DeviceArray, core.Tracer)
if not _any(isinstance(leaf, ndarray_types) for leaf in leaves):
  # TODO(jakevdp): falling back to numpy here fails to raise an overflow error
  # for lists containing large integers; see discussion in
  # https://github.com/google/jax/pull/6047. More correct would be to call
  # coerce_to_array on each leaf, but this may have performance implications.
out = np.array(object, dtype=dtype, ndmin=ndmin, copy=False)
elif isinstance(object, ndarray_types):
if object.aval is None:
# object is a raw buffer; convert to device array on its current device.
aval = ShapedArray(object.xla_shape().dimensions(), object.dtype,
weak_type=bool(getattr(object, "weak_type", False)))
object = device_array.make_device_array(aval, object.device(), object)
out = _array_copy(object) if copy else object
elif isinstance(object, (list, tuple)):
if object:
out = stack([asarray(elt, dtype=dtype) for elt in object])
else:
out = np.array([], dtype=dtype)
else:
try:
view = memoryview(object)
except TypeError:
pass # `object` does not support the buffer interface.
else:
return array(np.asarray(view), dtype, copy, ndmin=ndmin)
raise TypeError("Unexpected input type for array: {}".format(type(object)))
out = lax._convert_element_type(out, dtype, weak_type=weak_type)
if ndmin > ndim(out):
out = lax.expand_dims(out, range(ndmin - ndim(out)))
return out
def _convert_to_array_if_dtype_fails(x):
try:
dtypes.dtype(x)
except TypeError:
return np.asarray(x)
else:
return x
@_wraps(np.asarray, lax_description=_ARRAY_DOC)
def asarray(a, dtype=None, order=None):
lax._check_user_dtype_supported(dtype, "asarray")
dtype = dtypes.canonicalize_dtype(dtype) if dtype is not None else dtype
return array(a, dtype=dtype, copy=False, order=order)
@_wraps(np.zeros_like)
def zeros_like(a, dtype=None, shape=None):
_check_arraylike("zeros_like", a)
lax._check_user_dtype_supported(dtype, "zeros_like")
if np.isscalar(shape):
shape = (shape,)
return lax.full_like(a, 0, dtype, shape)
@_wraps(np.ones_like)
def ones_like(a, dtype=None, shape=None):
_check_arraylike("ones_like", a)
lax._check_user_dtype_supported(dtype, "ones_like")
if np.isscalar(shape):
shape = (shape,)
return lax.full_like(a, 1, dtype, shape)
@_wraps(np.full)
def full(shape, fill_value, dtype=None):
lax._check_user_dtype_supported(dtype, "full")
_check_arraylike("full", fill_value)
if ndim(fill_value) == 0:
shape = (shape,) if ndim(shape) == 0 else shape
return lax.full(shape, fill_value, dtype)
else:
return broadcast_to(asarray(fill_value, dtype=dtype), shape)
@_wraps(np.full_like)
def full_like(a, fill_value, dtype=None, shape=None):
lax._check_user_dtype_supported(dtype, "full_like")
_check_arraylike("full_like", a, fill_value)
if shape is not None:
shape = (shape,) if ndim(shape) == 0 else shape
if ndim(fill_value) == 0:
return lax.full_like(a, fill_value, dtype, shape)
else:
shape = np.shape(a) if shape is None else shape
dtype = result_type(a) if dtype is None else dtype
return broadcast_to(asarray(fill_value, dtype=dtype), shape)
@_wraps(np.zeros)
def zeros(shape, dtype=None):
if isinstance(shape, types.GeneratorType):
raise TypeError("expected sequence object with len >= 0 or a single integer")
lax._check_user_dtype_supported(dtype, "zeros")
shape = canonicalize_shape((shape,) if ndim(shape) == 0 else shape)
return lax.full(shape, 0, _jnp_dtype(dtype))
@_wraps(np.ones)
def ones(shape, dtype=None):
if isinstance(shape, types.GeneratorType):
raise TypeError("expected sequence object with len >= 0 or a single integer")
lax._check_user_dtype_supported(dtype, "ones")
shape = canonicalize_shape((shape,) if ndim(shape) == 0 else shape)
return lax.full(shape, 1, _jnp_dtype(dtype))
@_wraps(np.array_equal)
def array_equal(a1, a2, equal_nan=False):
try:
a1, a2 = asarray(a1), asarray(a2)
except Exception:
return False
if shape(a1) != shape(a2):
return False
eq = asarray(a1 == a2)
if equal_nan:
eq = logical_or(eq, logical_and(isnan(a1), isnan(a2)))
return all(eq)
@_wraps(np.array_equiv)
def array_equiv(a1, a2):
try:
a1, a2 = asarray(a1), asarray(a2)
except Exception:
return False
try:
eq = equal(a1, a2)
except ValueError:
# shapes are not broadcastable
return False
return all(eq)
# We can't create uninitialized arrays in XLA; use zeros for empty.
empty_like = zeros_like
empty = zeros
@_wraps(np.eye)
def eye(N, M=None, k=0, dtype=None):
lax._check_user_dtype_supported(dtype, "eye")
N = core.canonicalize_dim(N, "'N' argument of jnp.eye()")
M = N if M is None else core.canonicalize_dim(M, "'M' argument of jnp.eye()")
if N < 0 or M < 0:
raise ValueError(f"negative dimensions are not allowed, got {N} and {M}")
k = operator.index(k)
return lax._eye(_jnp_dtype(dtype), (N, M), k)
@_wraps(np.identity)
def identity(n, dtype=None):
lax._check_user_dtype_supported(dtype, "identity")
return eye(n, dtype=dtype)
@_wraps(np.arange)
def arange(start: core.DimSize, stop: Optional[core.DimSize]=None,
step: Optional[core.DimSize]=None, dtype=None):
lax._check_user_dtype_supported(dtype, "arange")
require = partial(core.concrete_or_error, None)
msg = "It arose in jax.numpy.arange argument `{}`.".format
if _any(core.is_special_dim_size(d) for d in (start, stop, step)):
if stop is not None or step is not None:
raise ValueError(
"jax.numpy.arange supports non-constant arguments only in single-argument form. "
f"Found jax.numpy.arange(start={start}, stop={stop}, step={step})")
return lax.iota(int_, start)
if dtype is None:
dtype = result_type(start, *(x for x in [stop, step] if x is not None))
dtype = _jnp_dtype(dtype)
if stop is None and step is None:
start = require(start, msg("stop"))
start = np.ceil(start).astype(int)
return lax.iota(dtype, start)
else:
start = require(start, msg("start"))
stop = None if stop is None else require(stop, msg("stop"))
step = None if step is None else require(step, msg("step"))
return array(np.arange(start, stop=stop, step=step, dtype=dtype))
def _wrap_numpy_nullary_function(f):
"""Adapts `f` to return a DeviceArray instead of an np.ndarray.
`f` cannot have any non-static array arguments.
"""
@_wraps(f, update_doc=False)
def wrapper(*args, **kwargs):
    args = [core.concrete_or_error(None, arg, f"the error occurred in argument {i} of jnp.{f.__name__}()")
            for i, arg in enumerate(args)]
    kwargs = {key: core.concrete_or_error(None, val, f"the error occurred in argument '{key}' of jnp.{f.__name__}()")
              for key, val in kwargs.items()}
return asarray(f(*args, **kwargs))
return wrapper
@_wraps(np.linspace)
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None,
axis: int = 0):
num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.linspace")
axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.linspace")
return _linspace(start, stop, int(num), endpoint, retstep, dtype,
operator.index(axis))
@partial(jit, static_argnames=('num', 'endpoint', 'retstep', 'dtype', 'axis'))
def _linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None,
axis: int = 0):
"""Implementation of linspace differentiable in start and stop args."""
lax._check_user_dtype_supported(dtype, "linspace")
if num < 0:
raise ValueError(f"Number of samples, {num}, must be non-negative.")
_check_arraylike("linspace", start, stop)
if dtype is None:
dtype = result_type(start, stop, dtypes.canonicalize_dtype(float_))
dtype = _jnp_dtype(dtype)
computation_dtype = promote_types(dtype, dtypes.canonicalize_dtype(float_))
start = asarray(start, dtype=computation_dtype)
stop = asarray(stop, dtype=computation_dtype)
bounds_shape = list(lax.broadcast_shapes(shape(start), shape(stop)))
broadcast_start = broadcast_to(start, bounds_shape)
broadcast_stop = broadcast_to(stop, bounds_shape)
axis = len(bounds_shape) + axis + 1 if axis < 0 else axis
bounds_shape.insert(axis, 1)
div = (num - 1) if endpoint else num
if num > 1:
delta = lax.convert_element_type(stop - start, computation_dtype) / div
iota_shape = [1,] * len(bounds_shape)
iota_shape[axis] = div
# This approach recovers the endpoints with float32 arithmetic,
# but can lead to rounding errors for integer outputs.
real_dtype = finfo(computation_dtype).dtype
step = reshape(lax.iota(real_dtype, div), iota_shape) / div
out = (reshape(broadcast_start, bounds_shape) * (1 - step) +
reshape(broadcast_stop, bounds_shape) * step)
if endpoint:
out = lax.concatenate([out, lax.expand_dims(broadcast_stop, (axis,))],
_canonicalize_axis(axis, out.ndim))
elif num == 1:
delta = nan if endpoint else stop - start
out = reshape(broadcast_start, bounds_shape)
else: # num == 0 degenerate case, match numpy behavior
empty_shape = list(lax.broadcast_shapes(shape(start), shape(stop)))
empty_shape.insert(axis, 0)
delta = nan
out = reshape(array([], dtype=dtype), empty_shape)
if issubdtype(dtype, integer) and not issubdtype(out.dtype, integer):
out = lax.floor(out)
if retstep:
return lax.convert_element_type(out, dtype), delta
else:
return lax.convert_element_type(out, dtype)
@_wraps(np.logspace)
def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None,
axis: int = 0):
num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.logspace")
axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.logspace")
return _logspace(start, stop, int(num), endpoint, base, dtype,
operator.index(axis))
@partial(jit, static_argnames=('num', 'endpoint', 'dtype', 'axis'))
def _logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None,
axis: int = 0):
"""Implementation of logspace differentiable in start and stop args."""
lax._check_user_dtype_supported(dtype, "logspace")
if dtype is None:
dtype = result_type(start, stop, dtypes.canonicalize_dtype(float_))
dtype = _jnp_dtype(dtype)
computation_dtype = promote_types(dtype, dtypes.canonicalize_dtype(float_))
_check_arraylike("logspace", start, stop)
start = asarray(start, dtype=computation_dtype)
stop = asarray(stop, dtype=computation_dtype)
lin = linspace(start, stop, num,
endpoint=endpoint, retstep=False, dtype=None, axis=axis)
return lax.convert_element_type(power(base, lin), dtype)
@_wraps(np.geomspace)
def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis: int = 0):
num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.geomspace")
axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.geomspace")
return _geomspace(start, stop, int(num), endpoint, dtype,
operator.index(axis))
@partial(jit, static_argnames=('num', 'endpoint', 'dtype', 'axis'))
def _geomspace(start, stop, num=50, endpoint=True, dtype=None, axis: int = 0):
"""Implementation of geomspace differentiable in start and stop args."""
lax._check_user_dtype_supported(dtype, "geomspace")
if dtype is None:
dtype = result_type(start, stop, dtypes.canonicalize_dtype(float_))
dtype = _jnp_dtype(dtype)
computation_dtype = promote_types(dtype, dtypes.canonicalize_dtype(float_))
_check_arraylike("geomspace", start, stop)
start = asarray(start, dtype=computation_dtype)
stop = asarray(stop, dtype=computation_dtype)
# follow the numpy geomspace convention for negative and complex endpoints
signflip = 1 - (1 - sign(real(start))) * (1 - sign(real(stop))) // 2
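  # signflip is -1 where both endpoints have negative real part and +1
  # otherwise, so the log10 calls below always see nonnegative real parts.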
res = signflip * logspace(log10(signflip * start),
log10(signflip * stop), num,
endpoint=endpoint, base=10.0,
dtype=computation_dtype, axis=0)
if axis != 0:
res = moveaxis(res, 0, axis)
return lax.convert_element_type(res, dtype)
@_wraps(np.meshgrid, lax_description=_ARRAY_VIEW_DOC)
def meshgrid(*xi, copy=True, sparse=False, indexing='xy'):
_check_arraylike("meshgrid", *xi)
args = [asarray(x) for x in xi]
if not copy:
raise ValueError("jax.numpy.meshgrid only supports copy=True")
if indexing not in ["xy", "ij"]:
raise ValueError(f"Valid values for indexing are 'xy' and 'ij', got {indexing}")
if _any(a.ndim != 1 for a in args):
raise ValueError("Arguments to jax.numpy.meshgrid must be 1D, got shapes "
f"{[a.shape for a in args]}")
if indexing == "xy" and len(args) >= 2:
args[0], args[1] = args[1], args[0]
shape = [1 if sparse else a.shape[0] for a in args]
_a_shape = lambda i, a: [*shape[:i], a.shape[0], *shape[i + 1:]] if sparse else shape
output = [lax.broadcast_in_dim(a, _a_shape(i, a), (i,)) for i, a, in enumerate(args)]
if indexing == "xy" and len(args) >= 2:
output[0], output[1] = output[1], output[0]
return output
def _make_1d_grid_from_slice(s: slice, op_name: str):
start = core.concrete_or_error(None, s.start,
f"slice start of jnp.{op_name}") or 0
stop = core.concrete_or_error(None, s.stop,
f"slice stop of jnp.{op_name}")
step = core.concrete_or_error(None, s.step,
f"slice step of jnp.{op_name}") or 1
if np.iscomplex(step):
newobj = linspace(start, stop, int(_abs(step)))
else:
newobj = arange(start, stop, step)
return newobj
class _IndexGrid:
def __getitem__(self, key):
single_slice = isinstance(key, slice)
if single_slice:
key = (key,)
output = []
for k in key:
output.append(_make_1d_grid_from_slice(k, op_name=self.op_name))
if single_slice:
return output[0]
output = meshgrid(*output, indexing='ij', sparse=self.sparse)
return output if self.sparse else stack(output, 0)
class _Mgrid(_IndexGrid):
"""Return dense multi-dimensional "meshgrid".
LAX-backend implementation of :obj:`numpy.mgrid`. This is a convenience wrapper for
functionality provided by :func:`jax.numpy.meshgrid` with ``sparse=False``.
See Also:
jnp.ogrid: open/sparse version of jnp.mgrid
Examples:
Pass ``[start:stop:step]`` to generate values similar to :func:`jax.numpy.arange`:
>>> jnp.mgrid[0:4:1]
DeviceArray([0, 1, 2, 3], dtype=int32)
Passing an imaginary step generates values similar to :func:`jax.numpy.linspace`:
>>> jnp.mgrid[0:1:4j]
DeviceArray([0. , 0.33333334, 0.6666667 , 1. ], dtype=float32)
Multiple slices can be used to create broadcasted grids of indices:
>>> jnp.mgrid[:2, :3]
DeviceArray([[[0, 0, 0],
[1, 1, 1]],
[[0, 1, 2],
[0, 1, 2]]], dtype=int32)
"""
sparse = False
op_name = "mgrid"
mgrid = _Mgrid()
class _Ogrid(_IndexGrid):
"""Return open multi-dimensional "meshgrid".
LAX-backend implementation of :obj:`numpy.ogrid`. This is a convenience wrapper for
functionality provided by :func:`jax.numpy.meshgrid` with ``sparse=True``.
See Also:
jnp.mgrid: dense version of jnp.ogrid
Examples:
Pass ``[start:stop:step]`` to generate values similar to :func:`jax.numpy.arange`:
>>> jnp.ogrid[0:4:1]
DeviceArray([0, 1, 2, 3], dtype=int32)
Passing an imaginary step generates values similar to :func:`jax.numpy.linspace`:
>>> jnp.ogrid[0:1:4j]
DeviceArray([0. , 0.33333334, 0.6666667 , 1. ], dtype=float32)
Multiple slices can be used to create sparse grids of indices:
>>> jnp.ogrid[:2, :3]
[DeviceArray([[0],
[1]], dtype=int32),
DeviceArray([[0, 1, 2]], dtype=int32)]
"""
sparse = True
op_name = "ogrid"
ogrid = _Ogrid()
class _AxisConcat:
"""Concatenates slices, scalars and array-like objects along a given axis."""
def __getitem__(self, key):
if not isinstance(key, tuple):
key = (key,)
params = [self.axis, self.ndmin, self.trans1d, -1]
if isinstance(key[0], str):
# split off the directive
directive, *key = key
# check two special cases: matrix directives
if directive == "r":
params[-1] = 0
elif directive == "c":
params[-1] = 1
else:
vec = directive.split(",")
k = len(vec)
if k < 4:
vec += params[k:]
else:
# ignore everything after the first three comma-separated ints
          vec = vec[:3] + [params[-1]]
try:
params = list(map(int, vec))
except ValueError as err:
raise ValueError(
"could not understand directive {!r}".format(directive)
) from err
axis, ndmin, trans1d, matrix = params
output = []
for item in key:
if isinstance(item, slice):
newobj = _make_1d_grid_from_slice(item, op_name=self.op_name)
elif isinstance(item, str):
raise ValueError("string directive must be placed at the beginning")
else:
newobj = item
newobj = array(newobj, copy=False, ndmin=ndmin)
if trans1d != -1 and ndmin - ndim(item) > 0:
shape_obj = list(range(ndmin))
# Calculate number of left shifts, with overflow protection by mod
num_lshifts = ndmin - _abs(ndmin + trans1d + 1) % ndmin
shape_obj = tuple(shape_obj[num_lshifts:] + shape_obj[:num_lshifts])
newobj = transpose(newobj, shape_obj)
output.append(newobj)
res = concatenate(tuple(output), axis=axis)
if matrix != -1 and res.ndim == 1:
# insert 2nd dim at axis 0 or 1
res = expand_dims(res, matrix)
return res
def __len__(self):
return 0
class RClass(_AxisConcat):
"""Concatenate slices, scalars and array-like objects along the first axis.
LAX-backend implementation of :obj:`numpy.r_`.
See Also:
``jnp.c_``: Concatenates slices, scalars and array-like objects along the last axis.
Examples:
Passing slices in the form ``[start:stop:step]`` generates ``jnp.arange`` objects:
>>> jnp.r_[-1:5:1, 0, 0, jnp.array([1,2,3])]
DeviceArray([-1, 0, 1, 2, 3, 4, 0, 0, 1, 2, 3], dtype=int32)
An imaginary value for ``step`` will create a ``jnp.linspace`` object instead,
which includes the right endpoint:
>>> jnp.r_[-1:1:6j, 0, jnp.array([1,2,3])]
DeviceArray([-1. , -0.6 , -0.20000002, 0.20000005,
0.6 , 1. , 0. , 1. ,
2. , 3. ], dtype=float32)
Use a string directive of the form ``"axis,dims,trans1d"`` as the first argument to
specify concatenation axis, minimum number of dimensions, and the position of the
upgraded array's original dimensions in the resulting array's shape tuple:
>>> jnp.r_['0,2', [1,2,3], [4,5,6]] # concatenate along first axis, 2D output
DeviceArray([[1, 2, 3],
[4, 5, 6]], dtype=int32)
>>> jnp.r_['0,2,0', [1,2,3], [4,5,6]] # push last input axis to the front
DeviceArray([[1],
[2],
[3],
[4],
[5],
[6]], dtype=int32)
Negative values for ``trans1d`` offset the last axis towards the start
of the shape tuple:
>>> jnp.r_['0,2,-2', [1,2,3], [4,5,6]]
DeviceArray([[1],
[2],
[3],
[4],
[5],
[6]], dtype=int32)
Use the special directives ``"r"`` or ``"c"`` as the first argument on flat inputs
to create an array with an extra row or column axis, respectively:
>>> jnp.r_['r',[1,2,3], [4,5,6]]
DeviceArray([[1, 2, 3, 4, 5, 6]], dtype=int32)
>>> jnp.r_['c',[1,2,3], [4,5,6]]
DeviceArray([[1],
[2],
[3],
[4],
[5],
[6]], dtype=int32)
For higher-dimensional inputs (``dim >= 2``), both directives ``"r"`` and ``"c"``
give the same result.
"""
axis = 0
ndmin = 1
trans1d = -1
op_name = "r_"
r_ = RClass()
class CClass(_AxisConcat):
"""Concatenate slices, scalars and array-like objects along the last axis.
LAX-backend implementation of :obj:`numpy.c_`.
See Also:
``jnp.r_``: Concatenates slices, scalars and array-like objects along the first axis.
Examples:
>>> a = jnp.arange(6).reshape((2,3))
>>> jnp.c_[a,a]
DeviceArray([[0, 1, 2, 0, 1, 2],
[3, 4, 5, 3, 4, 5]], dtype=int32)
Use a string directive of the form ``"axis:dims:trans1d"`` as the first argument to specify
concatenation axis, minimum number of dimensions, and the position of the upgraded array's
original dimensions in the resulting array's shape tuple:
>>> jnp.c_['0,2', [1,2,3], [4,5,6]]
DeviceArray([[1],
[2],
[3],
[4],
[5],
[6]], dtype=int32)
>>> jnp.c_['0,2,-1', [1,2,3], [4,5,6]]
DeviceArray([[1, 2, 3],
[4, 5, 6]], dtype=int32)
Use the special directives ``"r"`` or ``"c"`` as the first argument on flat inputs
to create an array with inputs stacked along the last axis:
>>> jnp.c_['r',[1,2,3], [4,5,6]]
DeviceArray([[1, 4],
[2, 5],
[3, 6]], dtype=int32)
"""
axis = -1
ndmin = 2
trans1d = 0
op_name = "c_"
c_ = CClass()
s_ = np.s_
index_exp = np.index_exp
@_wraps(np.i0)
@jit
def i0(x):
x_orig = x
x, = _promote_args_inexact("i0", x)
if not issubdtype(x.dtype, np.floating):
raise ValueError(f"Unsupported input type to jax.numpy.i0: {_dtype(x_orig)}")
x = lax.abs(x)
return lax.mul(lax.exp(x), lax.bessel_i0e(x))
@_wraps(np.ix_)
def ix_(*args):
_check_arraylike("ix", *args)
n = len(args)
output = []
for i, a in enumerate(args):
a = asarray(a)
if len(a.shape) != 1:
msg = "Arguments to jax.numpy.ix_ must be 1-dimensional, got shape {}"
raise ValueError(msg.format(a.shape))
if _dtype(a) == bool_:
raise NotImplementedError(
"Boolean arguments to jax.numpy.ix_ are not implemented")
shape = [1] * n
shape[i] = a.shape[0]
if a.size == 0:
# Numpy uses an integer index type for empty arrays.
output.append(lax.full(shape, np.zeros((), np.intp)))
else:
output.append(lax.broadcast_in_dim(a, shape, (i,)))
return tuple(output)
@_wraps(np.indices)
def indices(dimensions, dtype=int32, sparse=False):
dimensions = tuple(
core.concrete_or_error(int, d, "dimensions argument of jnp.indices")
for d in dimensions)
N = len(dimensions)
output = []
s = dimensions
for i, dim in enumerate(dimensions):
idx = lax.iota(dtype, dim)
if sparse:
s = (1,)*i + (dim,) + (1,)*(N - i - 1)
output.append(lax.broadcast_in_dim(idx, s, (i,)))
if sparse:
return tuple(output)
return stack(output, 0) if output else array([], dtype=dtype)
_TOTAL_REPEAT_LENGTH_DOC = """\
Jax adds the optional `total_repeat_length` parameter, which specifies the total
number of repeats and defaults to sum(repeats). It must be specified for repeat
to be compilable. If `sum(repeats)` is larger than the specified
`total_repeat_length`, the remaining values will be discarded. If `sum(repeats)`
is smaller than the specified target length, the final value will be repeated.
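For example, with `a = [1, 2]` and `repeats = [2, 3]` (so `sum(repeats)` is 5),
`total_repeat_length=4` yields `[1, 1, 2, 2]`, while `total_repeat_length=7`
yields `[1, 1, 2, 2, 2, 2, 2]`.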
"""
@_wraps(np.repeat, lax_description=_TOTAL_REPEAT_LENGTH_DOC)
def repeat(a, repeats, axis: Optional[int] = None, *, total_repeat_length=None):
_check_arraylike("repeat", a, repeats)
if axis is None:
a = ravel(a)
axis = 0
axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.repeat()")
assert isinstance(axis, int) # to appease mypy
  # If total_repeat_length is not given, the call cannot be compiled; derive
  # the total length from concrete repeats.
if total_repeat_length is None:
repeats = core.concrete_or_error(np.array, repeats,
"When jit-compiling jnp.repeat, the total number of repeats must be static. "
"To fix this, either specify a static value for `repeats`, or pass a static "
"value to `total_repeat_length`.")
# Fast path for when repeats is a scalar.
if np.ndim(repeats) == 0 and ndim(a) != 0:
input_shape = a.shape
aux_axis = axis if axis < 0 else axis + 1
a = expand_dims(a, aux_axis)
reps = [1] * len(a.shape)
reps[aux_axis] = repeats
a = tile(a, reps)
result_shape = list(input_shape)
result_shape[axis] *= repeats
return reshape(a, result_shape)
repeats = np.ravel(repeats)
if ndim(a) != 0:
repeats = np.broadcast_to(repeats, [a.shape[axis]])
total_repeat_length = np.sum(repeats)
else:
repeats = ravel(repeats)
if ndim(a) != 0:
repeats = broadcast_to(repeats, [a.shape[axis]])
# Special case when a is a scalar.
if ndim(a) == 0:
if repeats.shape == (1,):
return full([total_repeat_length], a)
else:
raise ValueError('`repeat` with a scalar parameter `a` is only '
'implemented for scalar values of the parameter `repeats`.')
# Special case if total_repeat_length is zero.
if total_repeat_length == 0:
result_shape = list(a.shape)
result_shape[axis] = 0
return reshape(array([], dtype=a.dtype), result_shape)
  # If the repeat axis has zero size, return the array unchanged.
if a.shape[axis] == 0:
return a
  # This implementation of repeat avoids having to instantiate a large
  # intermediate tensor.
# Modify repeats from e.g. [1,2,0,5] -> [0,1,2,0] for exclusive repeat.
exclusive_repeats = roll(repeats, shift=1).at[0].set(0)
# Cumsum to get indices of new number in repeated tensor, e.g. [0, 1, 3, 3]
scatter_indices = cumsum(exclusive_repeats)
# Scatter these onto a zero buffer, e.g. [1,1,0,2,0,0,0,0]
block_split_indicators = zeros([total_repeat_length], dtype=int32)
block_split_indicators = block_split_indicators.at[scatter_indices].add(1)
# Cumsum again to get scatter indices for repeat, e.g. [0,1,1,3,3,3,3,3]
gather_indices = cumsum(block_split_indicators) - 1
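  # Gathering with these indices, e.g. take([a,b,c,d], [0,1,1,3,3,3,3,3]),
  # yields [a,b,b,d,d,d,d,d]: repeats [1,2,0,5] without a large intermediate.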
return take(a, gather_indices, axis=axis)
@_wraps(np.tri)
def tri(N, M=None, k=0, dtype=None):
lax._check_user_dtype_supported(dtype, "tri")
M = M if M is not None else N
dtype = dtype or float32
return lax._tri(dtype, (N, M), k)
@_wraps(np.tril)
@partial(jit, static_argnames=('k',))
def tril(m, k=0):
_check_arraylike("tril", m)
m_shape = shape(m)
if len(m_shape) < 2:
raise ValueError("Argument to jax.numpy.tril must be at least 2D")
mask = tri(*m_shape[-2:], k=k, dtype=bool)
return lax.select(lax.broadcast(mask, m_shape[:-2]), m, zeros_like(m))
@_wraps(np.triu, update_doc=False)
@partial(jit, static_argnames=('k',))
def triu(m, k=0):
_check_arraylike("triu", m)
m_shape = shape(m)
if len(m_shape) < 2:
raise ValueError("Argument to jax.numpy.triu must be at least 2D")
mask = tri(*m_shape[-2:], k=k - 1, dtype=bool)
return lax.select(lax.broadcast(mask, m_shape[:-2]), zeros_like(m), m)
@_wraps(np.trace, skip_params=['out'])
@partial(jit, static_argnames=('offset', 'axis1', 'axis2', 'dtype'))
def trace(a, offset=0, axis1: int = 0, axis2: int = 1, dtype=None, out=None):
_check_arraylike("trace", a)
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.trace is not supported.")
lax._check_user_dtype_supported(dtype, "trace")
axis1 = _canonicalize_axis(axis1, ndim(a))
axis2 = _canonicalize_axis(axis2, ndim(a))
a_shape = shape(a)
if dtype is None:
dtype = _dtype(a)
if issubdtype(dtype, integer):
default_int = dtypes.canonicalize_dtype(np.int_)
if iinfo(dtype).bits < iinfo(default_int).bits:
dtype = default_int
  # Move the axis1 and axis2 dimensions to the end.
perm = [i for i in range(len(a_shape)) if i != axis1 and i != axis2]
perm = perm + [axis1, axis2]
a = lax.transpose(a, perm)
# Mask out the diagonal and reduce.
a = where(eye(a_shape[axis1], a_shape[axis2], k=offset, dtype=bool),
a, zeros_like(a))
return sum(a, axis=(-2, -1), dtype=dtype)
def _wrap_indices_function(f):
@_wraps(f, update_doc=False)
def wrapper(*args, **kwargs):
args = [core.concrete_or_error(
None, arg, f"argument {i} of jnp.{f.__name__}()")
for i, arg in enumerate(args)]
kwargs = {key: core.concrete_or_error(
None, val, f"argument '{key}' of jnp.{f.__name__}()")
for key, val in kwargs.items()}
return tuple(asarray(x) for x in f(*args, **kwargs))
return wrapper
tril_indices = _wrap_indices_function(np.tril_indices)
triu_indices = _wrap_indices_function(np.triu_indices)
mask_indices = _wrap_indices_function(np.mask_indices)
@_wraps(np.triu_indices_from)
def triu_indices_from(arr, k=0):
return triu_indices(arr.shape[-2], k=k, m=arr.shape[-1])
@_wraps(np.tril_indices_from)
def tril_indices_from(arr, k=0):
return tril_indices(arr.shape[-2], k=k, m=arr.shape[-1])
@_wraps(np.diag_indices)
def diag_indices(n, ndim=2):
n = core.concrete_or_error(operator.index, n, "'n' argument of jnp.diag_indices()")
ndim = core.concrete_or_error(operator.index, ndim, "'ndim' argument of jnp.diag_indices()")
if n < 0:
raise ValueError("n argument to diag_indices must be nonnegative, got {}"
.format(n))
if ndim < 0:
raise ValueError("ndim argument to diag_indices must be nonnegative, got {}"
.format(ndim))
return (lax.iota(int_, n),) * ndim
@_wraps(np.diag_indices_from)
def diag_indices_from(arr):
_check_arraylike("diag_indices_from", arr)
if not arr.ndim >= 2:
raise ValueError("input array must be at least 2-d")
if len(set(arr.shape)) != 1:
raise ValueError("All dimensions of input must be of equal length")
return diag_indices(arr.shape[0], ndim=arr.ndim)
@_wraps(np.diagonal, lax_description=_ARRAY_VIEW_DOC)
@partial(jit, static_argnames=('offset', 'axis1', 'axis2'))
def diagonal(a, offset=0, axis1: int = 0, axis2: int = 1):
_check_arraylike("diagonal", a)
a_shape = shape(a)
a_ndims = len(a_shape)
offset = core.concrete_or_error(operator.index, offset, "'offset' argument of jnp.diagonal()")
# Move the two dimensions to the end.
axis1 = _canonicalize_axis(axis1, a_ndims)
axis2 = _canonicalize_axis(axis2, a_ndims)
perm = [i for i in range(a_ndims) if i != axis1 and i != axis2]
perm = perm + [axis1, axis2]
a = lax.transpose(a, perm)
# Mask out the diagonal and reduce over one of the axes
a = where(eye(a_shape[axis1], a_shape[axis2], k=offset, dtype=bool),
a, zeros_like(a))
reduce_axis = -2 if offset < 0 else -1
d = sum(a, axis=reduce_axis, dtype=_dtype(a))
# Slice out the correct diagonal size.
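  # E.g. a (3, 5) input with offset=2 has diagonal length min(3, 5 - 2) = 3,
  # and offset=-1 gives min(3 - 1, 5) = 2.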
diag_size = _max(0, _min(a_shape[axis1] + _min(offset, 0),
a_shape[axis2] - _max(offset, 0)))
return lax.slice_in_dim(d, 0, diag_size, axis=-1)
@_wraps(np.diag, lax_description=_ARRAY_VIEW_DOC)
def diag(v, k=0):
return _diag(v, int(k))
@partial(jit, static_argnames=('k',))
def _diag(v, k):
_check_arraylike("diag", v)
v_shape = shape(v)
if len(v_shape) == 1:
zero = lambda x: lax.full_like(x, shape=(), fill_value=0)
n = v_shape[0] + _abs(k)
v = lax.pad(v, zero(v), ((_max(0, k), _max(0, -k), 0),))
return where(eye(n, k=k, dtype=bool), v, zeros_like(v))
elif len(v_shape) == 2:
return diagonal(v, offset=k)
else:
raise ValueError("diag input must be 1d or 2d")
_SCALAR_VALUE_DOC = """\
This differs from np.diagflat for some scalar values of v:
jax always returns a two-dimensional array, whereas numpy may
return a scalar depending on the type of v.
"""
@_wraps(np.diagflat, lax_description=_SCALAR_VALUE_DOC)
def diagflat(v, k=0):
_check_arraylike("diagflat", v)
v = ravel(v)
v_length = len(v)
adj_length = v_length + _abs(k)
res = zeros(adj_length*adj_length, dtype=v.dtype)
i = arange(0, adj_length-_abs(k))
if (k >= 0):
fi = i+k+i*adj_length
else:
fi = i+(i-k)*adj_length
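  # For k >= 0, entry i lands at row i, column i + k, i.e. flat index
  # i*adj_length + i + k; for k < 0 it lands at row i - k, column i.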
res = res.at[fi].set(v)
res = res.reshape(adj_length, adj_length)
return res
_POLY_DOC = """\
This differs from np.poly when an integer array is given.
np.poly returns a result with dtype float64 in this case.
jax returns a result with an inexact type, but not necessarily
float64.
This also differs from np.poly when the input array strictly
contains pairs of complex conjugates, e.g. [1j, -1j, 1-1j, 1+1j].
np.poly returns an array with a real dtype in such cases.
jax returns an array with a complex dtype in such cases.
"""
@_wraps(np.poly, lax_description=_POLY_DOC)
@jit
def poly(seq_of_zeros):
_check_arraylike('poly', seq_of_zeros)
seq_of_zeros, = _promote_dtypes_inexact(seq_of_zeros)
seq_of_zeros = atleast_1d(seq_of_zeros)
sh = seq_of_zeros.shape
if len(sh) == 2 and sh[0] == sh[1] and sh[0] != 0:
# import at runtime to avoid circular import
from jax._src.numpy import linalg
seq_of_zeros = linalg.eigvals(seq_of_zeros)
if seq_of_zeros.ndim != 1:
raise ValueError("input must be 1d or non-empty square 2d array.")
dt = seq_of_zeros.dtype
if len(seq_of_zeros) == 0:
return ones((), dtype=dt)
a = ones((1,), dtype=dt)
for k in range(len(seq_of_zeros)):
a = convolve(a, array([1, -seq_of_zeros[k]], dtype=dt), mode='full')
return a
@_wraps(np.polyval, lax_description="""\
The ``unroll`` parameter is JAX specific. It does not affect correctness but can
have a major impact on performance for evaluating high-order polynomials. The
parameter controls the number of unrolled steps with ``lax.scan`` inside the
``polyval`` implementation. Consider setting ``unroll=128`` (or even higher) to
improve runtime performance on accelerators, at the cost of increased
compilation time.
""")
@partial(jax.jit, static_argnames=['unroll'])
def polyval(p, x, *, unroll=16):
_check_arraylike("polyval", p, x)
p, x = _promote_dtypes_inexact(p, x)
shape = lax.broadcast_shapes(p.shape[1:], x.shape)
y = lax.full_like(x, 0, shape=shape, dtype=x.dtype)
y, _ = lax.scan(lambda y, p: (y * x + p, None), y, p, unroll=unroll)
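  # The scan above is Horner's rule: e.g. for p = [3., 0., 1.] and x = 5.,
  # y evolves 0 -> 3 -> 15 -> 76, i.e. 3*5**2 + 0*5 + 1.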
return y
@_wraps(np.polyadd)
@jit
def polyadd(a1, a2):
_check_arraylike("polyadd", a1, a2)
a1, a2 = _promote_dtypes(a1, a2)
if a2.shape[0] <= a1.shape[0]:
return a1.at[-a2.shape[0]:].add(a2)
else:
return a2.at[-a1.shape[0]:].add(a1)
@_wraps(np.polyint)
@partial(jit, static_argnames=('m',))
def polyint(p, m=1, k=None):
m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyint")
k = 0 if k is None else k
_check_arraylike("polyint", p, k)
p, k = _promote_dtypes_inexact(p, k)
if m < 0:
raise ValueError("Order of integral must be positive (see polyder)")
k = atleast_1d(k)
if len(k) == 1:
k = full((m,), k[0])
if k.shape != (m,):
raise ValueError("k must be a scalar or a rank-1 array of length 1 or m.")
if m == 0:
return p
else:
coeff = maximum(1, arange(len(p) + m, 0, -1)[newaxis, :] - 1 - arange(m)[:, newaxis]).prod(0)
return true_divide(concatenate((p, k)), coeff)
@_wraps(np.polyder)
@partial(jit, static_argnames=('m',))
def polyder(p, m=1):
_check_arraylike("polyder", p)
m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyder")
p, = _promote_dtypes_inexact(p)
if m < 0:
raise ValueError("Order of derivative must be positive")
if m == 0:
return p
coeff = (arange(len(p), m, -1)[newaxis, :] - 1 - arange(m)[:, newaxis]).prod(0)
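  # E.g. for a cubic (len(p) == 4) and m=1, coeff evaluates to [3, 2, 1],
  # mapping [a, b, c, d] to the derivative coefficients [3a, 2b, c].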
return p[:-m] * coeff
@_wraps(np.trim_zeros)
def trim_zeros(filt, trim='fb'):
filt = core.concrete_or_error(asarray, filt,
"Error arose in the `filt` argument of trim_zeros()")
nz = (filt == 0)
if all(nz):
return empty(0, _dtype(filt))
start = argmin(nz) if 'f' in trim.lower() else 0
end = argmin(nz[::-1]) if 'b' in trim.lower() else 0
return filt[start:len(filt) - end]
_LEADING_ZEROS_DOC = """\
Setting trim_leading_zeros=True makes the output match that of numpy,
but prevents the function from being used in compiled code.
"""
@_wraps(np.polymul, lax_description=_LEADING_ZEROS_DOC)
def polymul(a1, a2, *, trim_leading_zeros=False):
_check_arraylike("polymul", a1, a2)
a1, a2 = _promote_dtypes_inexact(a1, a2)
if trim_leading_zeros and (len(a1) > 1 or len(a2) > 1):
a1, a2 = trim_zeros(a1, trim='f'), trim_zeros(a2, trim='f')
if len(a1) == 0:
a1 = asarray([0.])
if len(a2) == 0:
a2 = asarray([0.])
val = convolve(a1, a2, mode='full')
return val
@_wraps(np.polysub)
@jit
def polysub(a1, a2):
_check_arraylike("polysub", a1, a2)
a1, a2 = _promote_dtypes(a1, a2)
return polyadd(a1, -a2)
@_wraps(np.append)
@partial(jit, static_argnames=('axis',))
def append(arr, values, axis: Optional[int] = None):
if axis is None:
return concatenate([ravel(arr), ravel(values)], 0)
else:
return concatenate([arr, values], axis=axis)
@_wraps(np.delete)
def delete(arr, obj, axis=None):
_check_arraylike("delete", arr)
if axis is None:
arr = ravel(arr)
axis = 0
axis = _canonicalize_axis(axis, arr.ndim)
# Case 1: obj is a static integer.
try:
obj = operator.index(obj)
obj = _canonicalize_axis(obj, arr.shape[axis])
except TypeError:
pass
else:
idx = tuple(slice(None) for i in range(axis))
return concatenate([arr[idx + (slice(0, obj),)], arr[idx + (slice(obj + 1, None),)]], axis=axis)
# Case 2: obj is a static slice.
if isinstance(obj, slice):
# TODO(jakevdp): we should be able to do this dynamically with care.
indices = np.delete(np.arange(arr.shape[axis]), obj)
return take(arr, indices, axis=axis)
# Case 3: obj is an array
# NB: pass both arrays to check for appropriate error message.
_check_arraylike("delete", arr, obj)
obj = core.concrete_or_error(np.asarray, obj, "'obj' array argument of jnp.delete()")
if issubdtype(obj.dtype, integer):
# TODO(jakevdp): in theory this could be done dynamically if obj has no duplicates,
# but this would require the complement of lax.gather.
mask = np.ones(arr.shape[axis], dtype=bool)
mask[obj] = False
elif obj.dtype == bool:
if obj.shape != (arr.shape[axis],):
raise ValueError("np.delete(arr, obj): for boolean indices, obj must be one-dimensional "
"with length matching specified axis.")
mask = ~obj
else:
raise ValueError(f"np.delete(arr, obj): got obj.dtype={obj.dtype}; must be integer or bool.")
return arr[tuple(slice(None) for i in range(axis)) + (mask,)]
@_wraps(np.insert)
def insert(arr, obj, values, axis=None):
_check_arraylike("insert", arr, 0 if isinstance(obj, slice) else obj, values)
arr = asarray(arr)
values = asarray(values)
if axis is None:
arr = ravel(arr)
axis = 0
axis = core.concrete_or_error(None, axis, "axis argument of jnp.insert()")
axis = _canonicalize_axis(axis, arr.ndim)
if isinstance(obj, slice):
indices = arange(*obj.indices(arr.shape[axis]))
else:
indices = asarray(obj)
if indices.ndim > 1:
raise ValueError("jnp.insert(): obj must be a slice, a one-dimensional "
f"array, or a scalar; got {obj}")
if not np.issubdtype(indices.dtype, np.integer):
if indices.size == 0 and not isinstance(obj, ndarray):
indices = indices.astype(int)
else:
# Note: np.insert allows boolean inputs but the behavior is deprecated.
raise ValueError("jnp.insert(): index array must be "
f"integer typed; got {obj}")
values = array(values, ndmin=arr.ndim, dtype=arr.dtype, copy=False)
if indices.size == 1:
index = ravel(indices)[0]
if indices.ndim == 0:
values = moveaxis(values, 0, axis)
indices = full(values.shape[axis], index)
n_input = arr.shape[axis]
n_insert = broadcast_shapes(indices.shape, values.shape[axis])[0]
out_shape = list(arr.shape)
out_shape[axis] += n_insert
out = zeros_like(arr, shape=tuple(out_shape))
indices = where(indices < 0, indices + n_input, indices)
indices = clip(indices, 0, n_input)
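  # Offset each insertion index by the number of insertions at smaller
  # indices, so that e.g. sorted positions [1, 1, 3] occupy output slots
  # [1, 2, 5]; the remaining slots receive the original array below.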
values_ind = indices.at[argsort(indices)].add(arange(n_insert))
arr_mask = ones(n_input + n_insert, dtype=bool).at[values_ind].set(False)
arr_ind = where(arr_mask, size=n_input)[0]
out = out.at[(slice(None),) * axis + (values_ind,)].set(values)
out = out.at[(slice(None),) * axis + (arr_ind,)].set(arr)
return out
@_wraps(np.apply_along_axis)
def apply_along_axis(func1d, axis: int, arr, *args, **kwargs):
num_dims = ndim(arr)
axis = _canonicalize_axis(axis, num_dims)
func = lambda arr: func1d(arr, *args, **kwargs)
for i in range(1, num_dims - axis):
func = jax.vmap(func, in_axes=i, out_axes=-1)
for i in range(axis):
func = jax.vmap(func, in_axes=0, out_axes=0)
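  # After these vmaps, func1d only ever sees 1D slices: dimensions after
  # `axis` are mapped to trailing output axes, leading dimensions to axis 0.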
return func(arr)
@_wraps(np.apply_over_axes)
def apply_over_axes(func, a, axes):
for axis in axes:
b = func(a, axis=axis)
if b.ndim == a.ndim:
a = b
elif b.ndim == a.ndim - 1:
a = expand_dims(b, axis)
else:
raise ValueError("function is not returning an array of the correct shape")
return a
### Tensor contraction operations
@_wraps(np.dot, lax_description=_PRECISION_DOC)
@partial(jit, static_argnames=('precision',), inline=True)
def dot(a, b, *, precision=None): # pylint: disable=missing-docstring
_check_arraylike("dot", a, b)
a, b = _promote_dtypes(a, b)
a_ndim, b_ndim = ndim(a), ndim(b)
if a_ndim == 0 or b_ndim == 0:
return lax.mul(a, b)
if _max(a_ndim, b_ndim) <= 2:
return lax.dot(a, b, precision=precision)
if b_ndim == 1:
contract_dims = ((a_ndim - 1,), (0,))
else:
contract_dims = ((a_ndim - 1,), (b_ndim - 2,))
batch_dims = ((), ())
return lax.dot_general(a, b, (contract_dims, batch_dims), precision)
@_wraps(np.matmul, lax_description=_PRECISION_DOC)
@partial(jit, static_argnames=('precision',), inline=True)
def matmul(a, b, *, precision=None): # pylint: disable=missing-docstring
_check_arraylike("matmul", a, b)
for i, x in enumerate((a, b)):
if ndim(x) < 1:
msg = (f"matmul input operand {i} must have ndim at least 1, "
f"but it has ndim {ndim(x)}")
raise ValueError(msg)
a, b = _promote_dtypes(a, b)
a_is_mat, b_is_mat = (ndim(a) > 1), (ndim(b) > 1)
a_batch_dims = shape(a)[:-2] if a_is_mat else ()
b_batch_dims = shape(b)[:-2] if b_is_mat else ()
num_batch_dims = _max(len(a_batch_dims), len(b_batch_dims))
a_batch_dims = (None,) * (num_batch_dims - len(a_batch_dims)) + a_batch_dims
b_batch_dims = (None,) * (num_batch_dims - len(b_batch_dims)) + b_batch_dims
# Dimensions to squeeze from the inputs.
a_squeeze = []
b_squeeze = []
# Positions of batch dimensions in squeezed inputs.
a_batch = []
b_batch = []
# Desired index in final output of each kind of dimension, in the order that
# lax.dot_general will emit them.
idx_batch = []
idx_a_other = [] # other = non-batch, non-contracting.
idx_b_other = []
for i, (ba, bb) in enumerate(zip(a_batch_dims, b_batch_dims)):
if ba is None:
idx_b_other.append(i)
elif bb is None:
idx_a_other.append(i)
elif core.symbolic_equal_dim(ba, 1):
idx_b_other.append(i)
a_squeeze.append(len(idx_batch) + len(idx_a_other) + len(a_squeeze))
elif core.symbolic_equal_dim(bb, 1):
idx_a_other.append(i)
b_squeeze.append(len(idx_batch) + len(idx_b_other) + len(b_squeeze))
elif core.symbolic_equal_dim(ba, bb):
a_batch.append(len(idx_batch) + len(idx_a_other))
b_batch.append(len(idx_batch) + len(idx_b_other))
idx_batch.append(i)
else:
raise ValueError("Incompatible shapes for matmul arguments: {} and {}"
.format(shape(a), shape(b)))
if a_is_mat: idx_a_other.append(num_batch_dims)
if b_is_mat: idx_b_other.append(num_batch_dims + a_is_mat)
perm = np.argsort(np.concatenate([idx_batch, idx_a_other, idx_b_other]))
a = lax.squeeze(a, tuple(a_squeeze))
b = lax.squeeze(b, tuple(b_squeeze))
out = lax.dot_general(
a, b, (((ndim(a) - 1,), (ndim(b) - 1 - b_is_mat,)), (a_batch, b_batch)),
precision=precision)
return lax.transpose(out, perm)
@_wraps(np.vdot, lax_description=_PRECISION_DOC)
@partial(jit, static_argnames=('precision',), inline=True)
def vdot(a, b, *, precision=None):
_check_arraylike("vdot", a, b)
if issubdtype(_dtype(a), complexfloating):
a = conj(a)
return dot(a.ravel(), b.ravel(), precision=precision)
@_wraps(np.tensordot, lax_description=_PRECISION_DOC)
def tensordot(a, b, axes=2, *, precision=None):
_check_arraylike("tensordot", a, b)
a_ndim = ndim(a)
b_ndim = ndim(b)
a, b = _promote_dtypes(a, b)
if type(axes) is int:
if axes > _min(a_ndim, b_ndim):
msg = "Number of tensordot axes (axes {}) exceeds input ranks ({} and {})"
raise TypeError(msg.format(axes, a.shape, b.shape))
contracting_dims = tuple(range(a_ndim - axes, a_ndim)), tuple(range(axes))
elif type(axes) in (list, tuple) and len(axes) == 2:
ax1, ax2 = axes
if type(ax1) == type(ax2) == int:
contracting_dims = ((_canonicalize_axis(ax1, a_ndim),),
(_canonicalize_axis(ax2, b_ndim),))
elif type(ax1) in (list, tuple) and type(ax2) in (list, tuple):
if len(ax1) != len(ax2):
msg = "tensordot requires axes lists to have equal length, got {} and {}."
raise TypeError(msg.format(ax1, ax2))
contracting_dims = (tuple(_canonicalize_axis(i, a_ndim) for i in ax1),
tuple(_canonicalize_axis(i, b_ndim) for i in ax2))
else:
msg = ("tensordot requires both axes lists to be either ints, tuples or "
"lists, got {} and {}")
raise TypeError(msg.format(ax1, ax2))
else:
msg = ("tensordot axes argument must be an int, a pair of ints, or a pair "
"of lists/tuples of ints.")
raise TypeError(msg)
return lax.dot_general(a, b, (contracting_dims, ((), ())),
precision=precision)
_EINSUM_DOC = _PRECISION_DOC + """\
A tuple ``precision`` does not necessarily map to multiple arguments of ``einsum()``;
rather, the specified ``precision`` is forwarded to each ``dot_general`` call used in
the implementation.
"""
@_wraps(np.einsum, lax_description=_EINSUM_DOC, skip_params=['out'])
def einsum(*operands, out=None, optimize='optimal', precision=None,
_use_xeinsum=False):
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.einsum is not supported.")
if (_use_xeinsum or isinstance(operands[0], str) and '{' in operands[0] and
len(operands[1:]) == 2):
return lax.xeinsum(*operands)
optimize = 'optimal' if optimize is True else optimize
# using einsum_call=True here is an internal api for opt_einsum
# Allow handling of shape polymorphism
non_constant_dim_types = {
type(d) for op in operands if not isinstance(op, str)
for d in np.shape(op) if not core.is_constant_dim(d)
}
if not non_constant_dim_types:
einsum_contract_path_fn = opt_einsum.contract_path
else:
einsum_contract_path_fn = _polymorphic_einsum_contract_path_handlers[next(iter(non_constant_dim_types))]
operands, contractions = einsum_contract_path_fn(
*operands, einsum_call=True, use_blas=True, optimize=optimize)
contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions)
return _einsum(operands, contractions, precision)
# Enable other modules to override einsum_contract_path.
# Indexed by the type of the non-constant dimension.
_polymorphic_einsum_contract_path_handlers = {} # type: ignore
@_wraps(np.einsum_path)
def einsum_path(subscripts, *operands, optimize='greedy'):
# using einsum_call=True here is an internal api for opt_einsum
return opt_einsum.contract_path(subscripts, *operands, optimize=optimize)
def _removechars(s, chars):
return s.translate(str.maketrans(dict.fromkeys(chars)))
@partial(jit, static_argnums=(1, 2))
def _einsum(operands: Sequence,
contractions: Sequence[Tuple[Tuple[int, ...], FrozenSet[str], str]],
precision):
operands = list(_promote_dtypes(*operands))
def sum(x, axes):
return lax.reduce(x, np.array(0, x.dtype),
lax.add if x.dtype != bool_ else lax.bitwise_or, axes)
def sum_uniques(operand, names, uniques):
if uniques:
axes = [names.index(name) for name in uniques]
operand = sum(operand, axes)
names = _removechars(names, uniques)
return operand, names
def sum_repeats(operand, names, counts, keep_names):
for name, count in counts.items():
if count > 1:
axes = [i for i, n in enumerate(names) if n == name]
eye = lax._delta(operand.dtype, operand.shape, axes)
if name not in keep_names:
operand = sum(operand * eye, axes)
names = names.replace(name, '')
else:
operand = sum(operand * eye, axes[:-1])
names = names.replace(name, '', count - 1)
return operand, names
def filter_singleton_dims(operand, names, other_shape, other_names):
s = shape(operand)
new_shape = []
new_names = []
for i, d in enumerate(names):
other_i = other_names.find(d)
if not core.symbolic_equal_dim(s[i], 1) or other_i == -1 or core.symbolic_equal_dim(other_shape[other_i], 1):
new_shape.append(s[i])
new_names.append(d)
return reshape(operand, tuple(new_shape)), "".join(new_names)
for operand_indices, contracted_names_set, einstr in contractions:
contracted_names = sorted(contracted_names_set)
input_str, result_names = einstr.split('->')
input_names = input_str.split(',')
# switch on the number of operands to be processed in this loop iteration.
# every case here sets 'operand' and 'names'.
if len(operand_indices) == 1:
operand = operands.pop(operand_indices[0])
names, = input_names
counts = collections.Counter(names)
# sum out unique contracted indices with a single reduce-sum
uniques = [name for name in contracted_names if counts[name] == 1]
operand, names = sum_uniques(operand, names, uniques)
# for every repeated index, do a contraction against an identity matrix
operand, names = sum_repeats(operand, names, counts, result_names)
elif len(operand_indices) == 2:
lhs, rhs = map(operands.pop, operand_indices)
lhs_names, rhs_names = input_names
# handle cases where one side of a contracting or batch dimension is 1
# but its counterpart is not.
lhs, lhs_names = filter_singleton_dims(lhs, lhs_names, shape(rhs),
rhs_names)
rhs, rhs_names = filter_singleton_dims(rhs, rhs_names, shape(lhs),
lhs_names)
lhs_counts = collections.Counter(lhs_names)
rhs_counts = collections.Counter(rhs_names)
# sum out unique contracted indices in lhs and rhs
lhs_uniques = [name for name in contracted_names
if lhs_counts[name] == 1 and rhs_counts[name] == 0]
lhs, lhs_names = sum_uniques(lhs, lhs_names, lhs_uniques)
rhs_uniques = [name for name in contracted_names
if rhs_counts[name] == 1 and lhs_counts[name] == 0]
rhs, rhs_names = sum_uniques(rhs, rhs_names, rhs_uniques)
# for every repeated index, contract against an identity matrix
lhs, lhs_names = sum_repeats(lhs, lhs_names, lhs_counts,
result_names + rhs_names)
rhs, rhs_names = sum_repeats(rhs, rhs_names, rhs_counts,
result_names + lhs_names)
lhs_or_rhs_names = set(lhs_names) | set(rhs_names)
contracted_names = [x for x in contracted_names if x in lhs_or_rhs_names]
lhs_and_rhs_names = set(lhs_names) & set(rhs_names)
batch_names = [x for x in result_names if x in lhs_and_rhs_names]
lhs_batch, rhs_batch = unzip2((lhs_names.find(n), rhs_names.find(n))
for n in batch_names)
# NOTE(mattjj): this can fail non-deterministically in python3, maybe
# due to opt_einsum
assert _all(
name in lhs_names and name in rhs_names and
lhs.shape[lhs_names.index(name)] == rhs.shape[rhs_names.index(name)]
for name in contracted_names)
# contract using lax.dot_general
batch_names_str = ''.join(batch_names)
lhs_cont, rhs_cont = unzip2((lhs_names.index(n), rhs_names.index(n))
for n in contracted_names)
deleted_names = batch_names_str + ''.join(contracted_names)
remaining_lhs_names = _removechars(lhs_names, deleted_names)
remaining_rhs_names = _removechars(rhs_names, deleted_names)
# Try both orders of lhs and rhs, in the hope that one of them means we
# don't need an explicit transpose. opt_einsum likes to contract from
# right to left, so we expect (rhs,lhs) to have the best chance of not
# needing a transpose.
names = batch_names_str + remaining_rhs_names + remaining_lhs_names
if names == result_names:
dimension_numbers = ((rhs_cont, lhs_cont), (rhs_batch, lhs_batch))
operand = lax.dot_general(rhs, lhs, dimension_numbers, precision)
else:
names = batch_names_str + remaining_lhs_names + remaining_rhs_names
dimension_numbers = ((lhs_cont, rhs_cont), (lhs_batch, rhs_batch))
operand = lax.dot_general(lhs, rhs, dimension_numbers, precision)
else:
raise NotImplementedError # if this is actually reachable, open an issue!
# the resulting 'operand' with axis labels 'names' should be a permutation
# of the desired result
assert len(names) == len(result_names) == len(set(names))
assert set(names) == set(result_names)
if names != result_names:
perm = tuple([names.index(name) for name in result_names])
operand = lax.transpose(operand, perm)
operands.append(operand) # used in next iteration
return operands[0]
def _movechars(s, src, dst):
"""Helper for einsum string munging, like moveaxis on identifier strings."""
chars = [c for i, c in enumerate(s) if i not in src]
for i, j in sorted(zip(dst, src)):
chars.insert(i, s[j])
return ''.join(chars)
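# E.g. _movechars('abcd', src=[0], dst=[2]) returns 'bcad'.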
@_wraps(np.inner, lax_description=_PRECISION_DOC)
@partial(jit, static_argnames=('precision',), inline=True)
def inner(a, b, *, precision=None):
if ndim(a) == 0 or ndim(b) == 0:
return a * b
return tensordot(a, b, (-1, -1), precision=precision)
@_wraps(np.outer, skip_params=['out'])
@partial(jit, inline=True)
def outer(a, b, out=None):
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.outer is not supported.")
a, b = _promote_dtypes(a, b)
return ravel(a)[:, None] * ravel(b)[None, :]
@_wraps(np.cross)
@partial(jit, static_argnames=('axisa', 'axisb', 'axisc', 'axis'))
def cross(a, b, axisa: int = -1, axisb: int = -1, axisc: int = -1,
axis: Optional[int] = None):
if axis is not None:
axisa = axis
axisb = axis
axisc = axis
a = moveaxis(a, axisa, -1)
b = moveaxis(b, axisb, -1)
if a.shape[-1] not in (2, 3) or b.shape[-1] not in (2, 3):
raise ValueError("Dimension must be either 2 or 3 for cross product")
if a.shape[-1] == 2 and b.shape[-1] == 2:
return a[..., 0] * b[..., 1] - a[..., 1] * b[..., 0]
a0 = a[..., 0]
a1 = a[..., 1]
a2 = a[..., 2] if a.shape[-1] == 3 else zeros_like(a0)
b0 = b[..., 0]
b1 = b[..., 1]
b2 = b[..., 2] if b.shape[-1] == 3 else zeros_like(b0)
c = array([a1 * b2 - a2 * b1, a2 * b0 - a0 * b2, a0 * b1 - a1 * b0])
return moveaxis(c, 0, axisc)
@_wraps(np.kron)
@jit
def kron(a, b):
a, b = _promote_dtypes(a, b)
if ndim(a) < ndim(b):
a = expand_dims(a, range(ndim(b) - ndim(a)))
elif ndim(b) < ndim(a):
b = expand_dims(b, range(ndim(a) - ndim(b)))
a_reshaped = expand_dims(a, range(1, 2 * ndim(a), 2))
b_reshaped = expand_dims(b, range(0, 2 * ndim(b), 2))
out_shape = tuple(np.multiply(shape(a), shape(b)))
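  # E.g. for a of shape (2, 3) and b of shape (4, 5): a_reshaped is
  # (2, 1, 3, 1), b_reshaped is (1, 4, 1, 5), their product broadcasts to
  # (2, 4, 3, 5), and the reshape below yields the (8, 15) Kronecker product.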
return reshape(lax.mul(a_reshaped, b_reshaped), out_shape)
@_wraps(np.vander)
@partial(jit, static_argnames=('N', 'increasing'))
def vander(x, N=None, increasing=False):
_check_arraylike("vander", x)
x = asarray(x)
if x.ndim != 1:
raise ValueError("x must be a one-dimensional array")
N = x.shape[0] if N is None else core.concrete_or_error(
operator.index, N, "'N' argument of jnp.vander()")
if N < 0:
raise ValueError("N must be nonnegative")
iota = lax.iota(x.dtype, N)
if not increasing:
iota = lax.sub(lax._const(iota, N - 1), iota)
return power(x[..., None], expand_dims(iota, tuple(range(x.ndim))))
### Misc
_ARGWHERE_DOC = """\
Because the size of the output of ``argwhere`` is data-dependent, the function is not
typically compatible with JIT. The JAX version adds the optional ``size`` argument, which
specifies the size of the leading dimension of the output - it must be specified statically
for ``jnp.argwhere`` to be compiled with non-static operands. If ``size`` is specified,
the indices of the first ``size`` True elements will be returned; if there are fewer
nonzero elements than `size` indicates, the index arrays will be zero-padded.
"""
@_wraps(np.argwhere, lax_description=_ARGWHERE_DOC)
def argwhere(a, *, size=None, fill_value=None):
result = transpose(vstack(nonzero(a, size=size, fill_value=fill_value)))
if ndim(a) == 0:
return result[:0].reshape(result.shape[0], 0)
return result.reshape(result.shape[0], ndim(a))
@_wraps(np.argmax, skip_params=['out'])
def argmax(a, axis: Optional[int] = None, out=None):
return<