# Source code for jax._src.lax.slicing

# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import enum
from functools import partial
from typing import Any, NamedTuple, Optional, Sequence, Union

import numpy as np

from jax import core
from jax._src import ad_util
from jax._src import dtypes
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import masking
from jax.interpreters import mlir
from jax.interpreters import xla
from jax._src.lax.utils import (
    _argnum_weak_type,
    _input_dtype,
    standard_primitive,
)
from jax._src.lax import lax
from jax._src import util
from jax._src.util import safe_zip
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib import xla_bridge
from jax._src.lib import xla_client

# Short aliases for the XLA bridge/client modules and the client's op builders.
xb, xc = xla_bridge, xla_client
xops = xc.ops

# Type aliases used in the signatures of this module.
Array = Any
Shape = core.Shape


def slice(operand: Array, start_indices: Sequence[int],
          limit_indices: Sequence[int],
          strides: Optional[Sequence[int]] = None) -> Array:
  """Wraps XLA's `Slice
  <https://www.tensorflow.org/xla/operation_semantics#slice>`_
  operator.

  The index sequences are frozen into tuples before binding so that the
  primitive's parameters are hashable; ``strides`` stays ``None`` when it is
  not given.
  """
  starts = tuple(start_indices)
  limits = tuple(limit_indices)
  if strides is None:
    steps = None
  else:
    steps = tuple(strides)
  return slice_p.bind(operand,
                      start_indices=starts,
                      limit_indices=limits,
                      strides=steps)

def dynamic_slice(operand: Array, start_indices: Sequence[Array],
                  slice_sizes: Shape) -> Array:
  """Wraps XLA's `DynamicSlice
  <https://www.tensorflow.org/xla/operation_semantics#dynamicslice>`_
  operator.

  Args:
    operand: an array to slice.
    start_indices: a list of scalar indices, one per dimension. These values
      may be dynamic.
    slice_sizes: the size of the slice. Must be a sequence of non-negative
      integers with length equal to `ndim(operand)`. Inside a JIT compiled
      function, only static values are supported (all JAX arrays inside JIT
      must have statically known size).

  Returns:
    An array containing the slice.

  Examples:
    Here is a simple two-dimensional dynamic slice:

    >>> x = jnp.arange(12).reshape(3, 4)
    >>> x
    DeviceArray([[ 0,  1,  2,  3],
                 [ 4,  5,  6,  7],
                 [ 8,  9, 10, 11]], dtype=int32)

    >>> dynamic_slice(x, (1, 1), (2, 3))
    DeviceArray([[ 5,  6,  7],
                 [ 9, 10, 11]], dtype=int32)

    Note the potentially surprising behavior for the case where the requested
    slice overruns the bounds of the array; in this case the start index is
    adjusted to return a slice of the requested size:

    >>> dynamic_slice(x, (1, 1), (2, 4))
    DeviceArray([[ 4,  5,  6,  7],
                 [ 8,  9, 10, 11]], dtype=int32)
  """
  # Normalize the user-supplied indices (helper defined elsewhere in this
  # module) before binding them as positional operands of the primitive.
  start_indices = _dynamic_slice_indices(operand, start_indices)
  return dynamic_slice_p.bind(operand, *start_indices,
                              slice_sizes=core.canonicalize_shape(slice_sizes))
def dynamic_update_slice(operand: Array, update: Array, start_indices: Array) -> Array: """Wraps XLA's `DynamicUpdateSlice <https://www.tensorflow.org/xla/operation_semantics#dynamicupdateslice>`_ operator. Args: operand: an array to slice. update: an array containing the new values to write onto `operand`. start_indices: a list of scalar indices, one per dimension. Returns: An array containing the slice. Examples: Here is an example of updating a one-dimensional slice update: >>> x = jnp.zeros(6) >>> y = jnp.ones(3) >>> dynamic_update_slice(x, y, (2,)) DeviceArray([0., 0., 1., 1., 1., 0.], dtype=float32) If the update slice is too large to fit in the array, the start index will be adjusted to make it fit >>> dynamic_update_slice(x, y, (3,)) DeviceArray([0., 0., 0., 1., 1., 1.], dtype=float32) >>> dynamic_update_slice(x, y, (5,)) DeviceArray([0., 0., 0., 1., 1., 1.], dtype=float32) Here is an example of a two-dimensional slice update: >>> x = jnp.zeros((4, 4)) >>> y = jnp.ones((2, 2)) >>> dynamic_update_slice(x, y, (1, 2)) DeviceArray([[0., 0., 0., 0.], [0., 0., 1., 1.], [0., 0., 1., 1.], [0., 0., 0., 0.]], dtype=float32) """ start_indices = _dynamic_slice_indices(operand, start_indices) return dynamic_update_slice_p.bind(operand, update, *start_indices) class GatherDimensionNumbers(NamedTuple): """ Describes the dimension number arguments to an `XLA's Gather operator <https://www.tensorflow.org/xla/operation_semantics#gather>`_. See the XLA documentation for more details of what the dimension numbers mean. Args: offset_dims: the set of dimensions in the `gather` output that offset into an array sliced from `operand`. Must be a tuple of integers in ascending order, each representing a dimension number of the output. collapsed_slice_dims: the set of dimensions `i` in `operand` that have `slice_sizes[i] == 1` and that should not have a corresponding dimension in the output of the gather. Must be a tuple of integers in ascending order. 
start_index_map: for each dimension in `start_indices`, gives the corresponding dimension in `operand` that is to be sliced. Must be a tuple of integers with size equal to `start_indices.shape[-1]`. Unlike XLA's `GatherDimensionNumbers` structure, `index_vector_dim` is implicit; there is always an index vector dimension and it must always be the last dimension. To gather scalar indices, add a trailing dimension of size 1. """ offset_dims: Sequence[int] collapsed_slice_dims: Sequence[int] start_index_map: Sequence[int] class GatherScatterMode(enum.Enum): """ Describes how to handle out-of-bounds indices in a gather or scatter. Possible values are: CLIP: Indices will be clamped to the nearest in-range value, i.e., such that the entire window to be gathered is in-range. FILL_OR_DROP: If any part of a gathered window is out of bounds, the entire window that is returned, even those elements that were otherwise in-bounds, will be filled with a constant. If any part of a scattered window is out of bounds, the entire window will be discarded. PROMISE_IN_BOUNDS: The user promises that indices are in bounds. No additional checking will be performed. In practice, with the current XLA implementation this means that, out-of-bounds gathers will be clamped but out-of-bounds scatters will be discarded. Gradients will not be correct if indices are out-of-bounds. 
""" CLIP = enum.auto() FILL_OR_DROP = enum.auto() PROMISE_IN_BOUNDS = enum.auto() @staticmethod def from_any(s: Optional[Union[str, 'GatherScatterMode']]): if isinstance(s, GatherScatterMode): return s if s == "clip": return GatherScatterMode.CLIP if s == "fill" or s == "drop": return GatherScatterMode.FILL_OR_DROP if s is None or s == "promise_in_bounds": return GatherScatterMode.PROMISE_IN_BOUNDS else: raise ValueError(f'Unknown gather mode "{s}"') def gather(operand: Array, start_indices: Array, dimension_numbers: GatherDimensionNumbers, slice_sizes: Shape, *, unique_indices: bool = False, indices_are_sorted: bool = False, mode: Optional[Union[str, GatherScatterMode]] = None, fill_value = None) -> Array: """Gather operator. Wraps `XLA's Gather operator <https://www.tensorflow.org/xla/operation_semantics#gather>`_. The semantics of gather are complicated, and its API might change in the future. For most use cases, you should prefer `Numpy-style indexing <https://numpy.org/doc/stable/reference/arrays.indexing.html>`_ (e.g., `x[:, (1,4,7), ...]`), rather than using `gather` directly. Args: operand: an array from which slices should be taken start_indices: the indices at which slices should be taken dimension_numbers: a `lax.GatherDimensionNumbers` object that describes how dimensions of `operand`, `start_indices` and the output relate. slice_sizes: the size of each slice. Must be a sequence of non-negative integers with length equal to `ndim(operand)`. indices_are_sorted: whether `indices` is known to be sorted. If true, may improve performance on some backends. unique_indices: whether the indices in ``operand`` are guaranteed to not overlap with each other. If true, may improve performance on some backends. mode: how to handle indices that are out of bounds: when set to ``'clip'``, indices are clamped so that the slice is within bounds, and when set to ``'fill'`` or ``'drop'`` gather returns a slice full of ``fill_value`` for the affected slice. 
The behavior for out-of-bounds indices when set to ``'promise_in_bounds'`` is implementation-defined. fill_value: the fill value to return for out-of-bounds slices when `mode` is ``'fill'``. Ignored otherwise. Defaults to ``NaN`` for inexact types, the largest negative value for signed types, the largest positive value for unsigned types, and ``True`` for booleans. Returns: An array containing the gather output. """ parsed_mode = GatherScatterMode.from_any(mode) if parsed_mode == GatherScatterMode.FILL_OR_DROP: if fill_value is None: dtype = lax._dtype(operand) if dtypes.issubdtype(dtype, np.inexact): fill_value = np.nan elif dtypes.issubdtype(dtype, np.signedinteger): fill_value = dtypes.iinfo(dtype).min elif dtypes.issubdtype(dtype, np.unsignedinteger): fill_value = dtypes.iinfo(dtype).max elif dtype == dtypes.bool_: fill_value = True else: raise ValueError(f"Unsupported dtype for gather fill_value {dtype}") else: fill_value = None return gather_p.bind( operand, start_indices, dimension_numbers=dimension_numbers, slice_sizes=core.canonicalize_shape(slice_sizes), unique_indices=bool(unique_indices), indices_are_sorted=bool(indices_are_sorted), mode=parsed_mode, fill_value=fill_value) class ScatterDimensionNumbers(NamedTuple): """ Describes the dimension number arguments to an `XLA's Scatter operator <https://www.tensorflow.org/xla/operation_semantics#scatter>`_. See the XLA documentation for more details of what the dimension numbers mean. Args: update_window_dims: the set of dimensions in the `updates` that are window dimensions. Must be a tuple of integers in ascending order, each representing a dimension number. inserted_window_dims: the set of size 1 window dimensions that must be inserted into the shape of `updates`. Must be a tuple of integers in ascending order, each representing a dimension number of the output. These are the mirror image of `collapsed_slice_dims` in the case of `gather`. 
scatter_dims_to_operand_dims: for each dimension in `scatter_indices`, gives the corresponding dimension in `operand`. Must be a sequence of integers with size equal to indices.shape[-1]. Unlike XLA's `ScatterDimensionNumbers` structure, `index_vector_dim` is implicit; there is always an index vector dimension and it must always be the last dimension. To scatter scalar indices, add a trailing dimension of size 1. """ update_window_dims: Sequence[int] inserted_window_dims: Sequence[int] scatter_dims_to_operand_dims: Sequence[int] def scatter_add( operand: Array, scatter_indices: Array, updates: Array, dimension_numbers: ScatterDimensionNumbers, *, indices_are_sorted: bool = False, unique_indices: bool = False, mode: Optional[Union[str, GatherScatterMode]] = None) -> Array: """Scatter-add operator. Wraps `XLA's Scatter operator <https://www.tensorflow.org/xla/operation_semantics#scatter>`_, where addition is used to combine updates and values from `operand`. The semantics of scatter are complicated, and its API might change in the future. For most use cases, you should prefer the :attr:`jax.numpy.ndarray.at` property on JAX arrays which uses the familiar NumPy indexing syntax. Args: operand: an array to which the scatter should be applied scatter_indices: an array that gives the indices in `operand` to which each update in `updates` should be applied. updates: the updates that should be scattered onto `operand`. dimension_numbers: a `lax.ScatterDimensionNumbers` object that describes how dimensions of `operand`, `start_indices`, `updates` and the output relate. indices_are_sorted: whether `scatter_indices` is known to be sorted. If true, may improve performance on some backends. unique_indices: whether the indices to be updated in ``operand`` are guaranteed to not overlap with each other. If true, may improve performance on some backends. 
mode: how to handle indices that are out of bounds: when set to 'clip', indices are clamped so that the slice is within bounds, and when set to 'fill' or 'drop' out-of-bounds updates are dropped. The behavior for out-of-bounds indices when set to 'promise_in_bounds' is implementation-defined. Returns: An array containing the sum of `operand` and the scattered updates. """ jaxpr, consts = lax._reduction_jaxpr(lax.add, lax._abstractify(lax._const(operand, 0))) return scatter_add_p.bind( operand, scatter_indices, updates, update_jaxpr=jaxpr, update_consts=consts, dimension_numbers=dimension_numbers, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=GatherScatterMode.from_any(mode)) def scatter_mul( operand: Array, scatter_indices: Array, updates: Array, dimension_numbers: ScatterDimensionNumbers, *, indices_are_sorted: bool = False, unique_indices: bool = False, mode: Optional[Union[str, GatherScatterMode]] = None) -> Array: """Scatter-multiply operator. Wraps `XLA's Scatter operator <https://www.tensorflow.org/xla/operation_semantics#scatter>`_, where multiplication is used to combine updates and values from `operand`. The semantics of scatter are complicated, and its API might change in the future. For most use cases, you should prefer the :attr:`jax.numpy.ndarray.at` property on JAX arrays which uses the familiar NumPy indexing syntax. Args: operand: an array to which the scatter should be applied scatter_indices: an array that gives the indices in `operand` to which each update in `updates` should be applied. updates: the updates that should be scattered onto `operand`. dimension_numbers: a `lax.ScatterDimensionNumbers` object that describes how dimensions of `operand`, `start_indices`, `updates` and the output relate. indices_are_sorted: whether `scatter_indices` is known to be sorted. If true, may improve performance on some backends. 
unique_indices: whether the indices to be updated in ``operand`` are guaranteed to not overlap with each other. If true, may improve performance on some backends. mode: how to handle indices that are out of bounds: when set to 'clip', indices are clamped so that the slice is within bounds, and when set to 'fill' or 'drop' out-of-bounds updates are dropped. The behavior for out-of-bounds indices when set to 'promise_in_bounds' is implementation-defined. Returns: An array containing the sum of `operand` and the scattered updates. """ jaxpr, consts = lax._reduction_jaxpr(lax.mul, lax._abstractify(lax._const(operand, 1))) return scatter_mul_p.bind( operand, scatter_indices, updates, update_jaxpr=jaxpr, update_consts=consts, dimension_numbers=dimension_numbers, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=GatherScatterMode.from_any(mode)) def scatter_min( operand: Array, scatter_indices: Array, updates: Array, dimension_numbers: ScatterDimensionNumbers, *, indices_are_sorted: bool = False, unique_indices: bool = False, mode: Optional[Union[str, GatherScatterMode]] = None) -> Array: """Scatter-min operator. Wraps `XLA's Scatter operator <https://www.tensorflow.org/xla/operation_semantics#scatter>`_, where the `min` function is used to combine updates and values from `operand`. The semantics of scatter are complicated, and its API might change in the future. For most use cases, you should prefer the :attr:`jax.numpy.ndarray.at` property on JAX arrays which uses the familiar NumPy indexing syntax. Args: operand: an array to which the scatter should be applied scatter_indices: an array that gives the indices in `operand` to which each update in `updates` should be applied. updates: the updates that should be scattered onto `operand`. dimension_numbers: a `lax.ScatterDimensionNumbers` object that describes how dimensions of `operand`, `start_indices`, `updates` and the output relate. 
indices_are_sorted: whether `scatter_indices` is known to be sorted. If true, may improve performance on some backends. unique_indices: whether the indices to be updated in ``operand`` are guaranteed to not overlap with each other. If true, may improve performance on some backends. mode: how to handle indices that are out of bounds: when set to 'clip', indices are clamped so that the slice is within bounds, and when set to 'fill' or 'drop' out-of-bounds updates are dropped. The behavior for out-of-bounds indices when set to 'promise_in_bounds' is implementation-defined. Returns: An array containing the sum of `operand` and the scattered updates. """ jaxpr, consts = lax._reduction_jaxpr(lax.min, lax._abstractify(lax._const(operand, 0))) return scatter_min_p.bind( operand, scatter_indices, updates, update_jaxpr=jaxpr, update_consts=consts, dimension_numbers=dimension_numbers, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=GatherScatterMode.from_any(mode)) def scatter_max( operand: Array, scatter_indices: Array, updates: Array, dimension_numbers: ScatterDimensionNumbers, *, indices_are_sorted: bool = False, unique_indices: bool = False, mode: Optional[Union[str, GatherScatterMode]] = None) -> Array: """Scatter-max operator. Wraps `XLA's Scatter operator <https://www.tensorflow.org/xla/operation_semantics#scatter>`_, where the `max` function is used to combine updates and values from `operand`. The semantics of scatter are complicated, and its API might change in the future. For most use cases, you should prefer the :attr:`jax.numpy.ndarray.at` property on JAX arrays which uses the familiar NumPy indexing syntax. Args: operand: an array to which the scatter should be applied scatter_indices: an array that gives the indices in `operand` to which each update in `updates` should be applied. updates: the updates that should be scattered onto `operand`. 
dimension_numbers: a `lax.ScatterDimensionNumbers` object that describes how dimensions of `operand`, `start_indices`, `updates` and the output relate. indices_are_sorted: whether `scatter_indices` is known to be sorted. If true, may improve performance on some backends. unique_indices: whether the indices to be updated in ``operand`` are guaranteed to not overlap with each other. If true, may improve performance on some backends. mode: how to handle indices that are out of bounds: when set to 'clip', indices are clamped so that the slice is within bounds, and when set to 'fill' or 'drop' out-of-bounds updates are dropped. The behavior for out-of-bounds indices when set to 'promise_in_bounds' is implementation-defined. Returns: An array containing the sum of `operand` and the scattered updates. """ jaxpr, consts = lax._reduction_jaxpr(lax.max, lax._abstractify(lax._const(operand, 0))) return scatter_max_p.bind( operand, scatter_indices, updates, update_jaxpr=jaxpr, update_consts=consts, dimension_numbers=dimension_numbers, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=GatherScatterMode.from_any(mode)) # Define this outside of scatter to ensure cache hits. _scatter_reduction_computation = lambda x, y: y def scatter( operand: Array, scatter_indices: Array, updates: Array, dimension_numbers: ScatterDimensionNumbers, *, indices_are_sorted: bool = False, unique_indices: bool = False, mode: Optional[Union[str, GatherScatterMode]] = None) -> Array: """Scatter-update operator. Wraps `XLA's Scatter operator <https://www.tensorflow.org/xla/operation_semantics#scatter>`_, where updates replace values from `operand`. If multiple updates are performed to the same index of operand, they may be applied in any order. The semantics of scatter are complicated, and its API might change in the future. For most use cases, you should prefer the :attr:`jax.numpy.ndarray.at` property on JAX arrays which uses the familiar NumPy indexing syntax. 
Args: operand: an array to which the scatter should be applied scatter_indices: an array that gives the indices in `operand` to which each update in `updates` should be applied. updates: the updates that should be scattered onto `operand`. dimension_numbers: a `lax.ScatterDimensionNumbers` object that describes how dimensions of `operand`, `start_indices`, `updates` and the output relate. indices_are_sorted: whether `scatter_indices` is known to be sorted. If true, may improve performance on some backends. unique_indices: whether the indices to be updated in ``operand`` are guaranteed to not overlap with each other. If true, may improve performance on some backends. mode: how to handle indices that are out of bounds: when set to 'clip', indices are clamped so that the slice is within bounds, and when set to 'fill' or 'drop' out-of-bounds updates are dropped. The behavior for out-of-bounds indices when set to 'promise_in_bounds' is implementation-defined. Returns: An array containing the sum of `operand` and the scattered updates. 
""" jaxpr, consts = lax._reduction_jaxpr(_scatter_reduction_computation, lax._abstractify(lax._const(operand, 0))) return scatter_p.bind( operand, scatter_indices, updates, update_jaxpr=jaxpr, update_consts=consts, dimension_numbers=dimension_numbers, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=GatherScatterMode.from_any(mode)) def index_take(src: Array, idxs: Array, axes: Sequence[int]) -> Array: indices = lax.concatenate([lax.expand_dims(i, (1,)) for i in idxs], 1) indices = indices % np.array([src.shape[ax] for ax in axes]) slice_sizes = list(src.shape) for ax in axes: slice_sizes[ax] = 1 offset_dims = tuple(range(1, src.ndim - indices.shape[1] + 1)) dnums = GatherDimensionNumbers( offset_dims=offset_dims, collapsed_slice_dims=axes, start_index_map=axes) return gather(src, indices, dimension_numbers=dnums, slice_sizes=tuple(slice_sizes)) ### convenience wrappers around traceables def slice_in_dim(operand: Array, start_index: Optional[int], limit_index: Optional[int], stride: int = 1, axis: int = 0) -> Array: """Convenience wrapper around slice applying to only one dimension.""" start_indices = [0] * operand.ndim limit_indices = list(operand.shape) strides = [1] * operand.ndim # translate `None` len_axis = operand.shape[axis] start_index_int = (core._canonicalize_dimension(start_index) if start_index is not None else 0) limit_index_int = (core._canonicalize_dimension(limit_index) if limit_index is not None else len_axis) # translate negative indices if start_index_int < 0: start_index_int = start_index_int + len_axis if limit_index_int < 0: limit_index_int = limit_index_int + len_axis axis = int(axis) start_indices[axis] = start_index_int limit_indices[axis] = limit_index_int strides[axis] = int(stride) return slice(operand, start_indices, limit_indices, strides) def index_in_dim(operand: Array, index: int, axis: int = 0, keepdims: bool = True) -> Array: """Convenience wrapper around slice to perform int indexing.""" index, axis = 
core._canonicalize_dimension(index), int(axis) axis_size = operand.shape[axis] wrapped_index = index + axis_size if index < 0 else index if not 0 <= wrapped_index < axis_size: msg = 'index {} is out of bounds for axis {} with size {}' raise IndexError(msg.format(index, axis, axis_size)) result = slice_in_dim(operand, wrapped_index, wrapped_index + 1, 1, axis) if keepdims: return result else: return lax.squeeze(result, (axis,)) def dynamic_slice_in_dim(operand: Array, start_index: Array, slice_size: int, axis: int = 0) -> Array: """Convenience wrapper around dynamic_slice applying to one dimension.""" start_indices = [lax._zero(start_index)] * operand.ndim slice_sizes = list(operand.shape) axis = int(axis) start_indices[axis] = start_index slice_sizes[axis] = core._canonicalize_dimension(slice_size) return dynamic_slice(operand, start_indices, slice_sizes) def dynamic_index_in_dim(operand: Array, index: Array, axis: int = 0, keepdims: bool = True) -> Array: """Convenience wrapper around dynamic_slice to perform int indexing.""" result = dynamic_slice_in_dim(operand, index, 1, axis) if keepdims: return result else: return lax.squeeze(result, (axis,)) def dynamic_update_slice_in_dim(operand: Array, update: Array, start_index: Array, axis: int) -> Array: """Convenience wrapper around :func:`dynamic_update_slice` to update a slice in a single ``axis``. """ axis = int(axis) start_indices = [lax._zero(start_index)] * lax._ndim(operand) start_indices[axis] = start_index return dynamic_update_slice(operand, update, start_indices) def dynamic_update_index_in_dim(operand: Array, update: Array, index: Array, axis: int) -> Array: """Convenience wrapper around :func:`dynamic_update_slice` to update a slice of size 1 in a single ``axis``. 
""" axis = int(axis) if lax._ndim(update) != lax._ndim(operand): assert lax._ndim(update) + 1 == lax._ndim(operand) update = lax.expand_dims(update, (axis,)) return dynamic_update_slice_in_dim(operand, update, index, axis) def _slice_shape_rule(operand, *, start_indices, limit_indices, strides): lax._check_shapelike("slice", "start_indices", start_indices) lax._check_shapelike("slice", "limit_indices", limit_indices) if operand.ndim != len(start_indices): msg = ("slice start_indices must have length equal to the number of " "dimensions of the operand, got indices {} for operand shape {}.") raise TypeError(msg.format(start_indices, operand.shape)) if len(start_indices) != len(limit_indices): msg = ("slice limit_indices must have the same length as start_indices, " "got start_indices {} and limit_indices {}.") raise TypeError(msg.format(start_indices, limit_indices)) if not core.greater_equal_shape(operand.shape, limit_indices): msg = ("slice limit_indices must be less than or equal to operand shape, " "got limit_indices {} for operand shape {}.") raise TypeError(msg.format(limit_indices, operand.shape)) if not all(core.greater_equal_dim(si, 0) for si in start_indices): msg = ("slice start_indices must be greater than or equal to zero, " "got start_indices of {}.") raise TypeError(msg.format(start_indices)) if not core.greater_equal_shape(limit_indices, start_indices): msg = ("slice limit_indices must be greater than or equal to start_indices," " got start_indices {} and limit_indices {}.") raise TypeError(msg.format(start_indices, limit_indices)) if strides is None: strides = np.ones(operand.ndim, np.int32) else: lax._check_shapelike("slice", "strides", strides) if len(strides) != operand.ndim: msg = ("slice strides must have length equal to the number of dimensions " "of the operand, got strides {} for operand shape {}.") raise TypeError(msg.format(strides, operand.shape)) if not core.greater_equal_shape(strides, (0,) * len(strides)): msg = "slice strides must be 
positive, got {}" raise TypeError(msg.format(strides)) diff = core.diff_shape(limit_indices, start_indices) return core.stride_shape(diff, (1,) * len(diff), strides) def _slice_translation_rule(ctx, avals_in, avals_out, operand, *, start_indices, limit_indices, strides): return [xops.Slice(operand, start_indices, limit_indices, strides or [1] * len(start_indices))] def _slice_transpose_rule(t, operand, *, start_indices, limit_indices, strides): assert ad.is_undefined_primal(operand) operand_shape = operand.aval.shape if strides is None or np.all(np.equal(strides, 1)): pads = zip(start_indices, np.subtract(operand_shape, limit_indices), (0,) * len(start_indices)) else: real_limits = np.add( start_indices, np.where(np.array(t.shape) == 0, 0, np.add(1, np.multiply(np.subtract(t.shape, 1), strides)))) pads = safe_zip(start_indices, np.subtract(operand_shape, real_limits), np.subtract(strides, 1)) result = lax.pad(t, lax._const(t, 0), pads) assert result.shape == operand_shape, ( f"result.shape={result.shape} operand_shape={operand_shape}") return [result] def _slice_batching_rule(batched_args, batch_dims, *, start_indices, limit_indices, strides): operand, = batched_args bdim, = batch_dims new_start_indices = list(start_indices) new_start_indices.insert(bdim, 0) new_limit_indices = list(limit_indices) new_limit_indices.insert(bdim, operand.shape[bdim]) if strides is None: new_strides = None else: new_strides = list(strides) new_strides.insert(bdim, 1) out = slice(operand, new_start_indices, new_limit_indices, new_strides) return out, bdim def _slice_masking_rule( padded_vals, logical_shapes, start_indices, limit_indices, strides): operand, = padded_vals strides = masking.padded_shape_as_value(strides) if strides else None return slice(operand, start_indices=masking.padded_shape_as_value(start_indices), limit_indices=masking.padded_shape_as_value(limit_indices), strides=strides) slice_p = standard_primitive(_slice_shape_rule, _input_dtype, 'slice', 
_slice_translation_rule) ad.deflinear2(slice_p, _slice_transpose_rule) batching.primitive_batchers[slice_p] = _slice_batching_rule masking.masking_rules[slice_p] = _slice_masking_rule def _slice_lower(ctx, x, *, start_indices, limit_indices, strides): aval_out, = ctx.avals_out strides = strides or [1] * len(start_indices) return mhlo.SliceOp(x, mlir.dense_int_elements(start_indices), mlir.dense_int_elements(limit_indices), mlir.dense_int_elements(strides)).results mlir.register_lowering(slice_p, _slice_lower) def _dynamic_slice_shape_rule(operand, *start_indices, slice_sizes): if operand.ndim != len(start_indices): msg = ("dynamic_slice start_indices must have length equal to the number " "of dimensions of the operand, got indices {} for operand shape {}.") raise TypeError(msg.format(start_indices, operand.shape)) if len(start_indices) != len(slice_sizes): msg = ("dynamic_slice slice_sizes must have the same length as " "start_indices, got start_indices length {} and slice_sizes {}.") raise TypeError(msg.format(len(start_indices), slice_sizes)) if not core.greater_equal_shape(operand.shape, slice_sizes): msg = ("slice slice_sizes must be less than or equal to operand shape, " "got slice_sizes {} for operand shape {}.") raise TypeError(msg.format(slice_sizes, operand.shape)) if not all(core.greater_equal_dim(ssz, 0) for ssz in slice_sizes): msg = ("slice slice_sizes must be greater than or equal to zero, " "got slice_sizes of {}.") raise TypeError(msg.format(slice_sizes)) return tuple(slice_sizes) def _dynamic_slice_dtype_rule(operand, *start_indices, slice_sizes): if any(i.dtype != start_indices[0].dtype or not dtypes.issubdtype(i.dtype, np.integer) for i in start_indices): msg = ("index arguments to dynamic_slice must be integers of the same " "type, got: {}") raise TypeError(msg.format(", ".join(i.dtype.name for i in start_indices))) return operand.dtype def _dynamic_slice_translation_rule(ctx, avals_in, avals_out, operand, *start_indices, slice_sizes): return 
[xops.DynamicSlice(operand, start_indices, slice_sizes)] def _dynamic_slice_jvp(primals, tangents, *, slice_sizes): tangent_out = tangents[0] if type(tangent_out) is not ad_util.Zero: tangent_out = dynamic_slice(tangent_out, primals[1:], slice_sizes) return dynamic_slice(primals[0], primals[1:], slice_sizes), tangent_out def _dynamic_slice_transpose_rule(t, operand, *start_indices, slice_sizes): assert ad.is_undefined_primal(operand) assert all(not ad.is_undefined_primal(s) for s in start_indices) operand_shape, operand_dtype = operand.aval.shape, operand.aval.dtype if type(t) is ad_util.Zero: return [ad_util.Zero(operand.aval)] + [None] * len(start_indices) else: zeros = lax.full(operand_shape, 0, operand_dtype) return ([dynamic_update_slice(zeros, t, start_indices)] + [None] * len(start_indices)) def _batch_dynamic_slice_indices(indices, bdims): if len(indices) == 0: return np.array([], 'int32'), None empty_marker = object() size = next((x.shape[i] for x, i in zip(indices, bdims) if i is not None), empty_marker) if size is empty_marker: return lax.concatenate([lax.broadcast(i, (1,)) for i in indices], 0), None indices = lax.concatenate( [lax.broadcast_in_dim(x, (size, 1), broadcast_dimensions=((0,) if i is not None else ())) for x, i in zip(indices, bdims)], dimension=1) return indices, 0 def _dynamic_slice_batching_rule(batched_args, batch_dims, *, slice_sizes): # A dynamic slice is a special case of gather; we can delegate to the gather # batching rule. # TODO(phawkins): consider removing dynamic_slice entirely and using gather # always. 
operand, *start_indices = batched_args operand_bd, *start_idx_bds = batch_dims operand_shape = (operand.shape if operand_bd is batching.not_mapped else tuple(np.delete(operand.shape, operand_bd))) dims = tuple(range(len(operand_shape))) dnums = GatherDimensionNumbers(offset_dims=dims, collapsed_slice_dims=(), start_index_map=dims) index, index_bdim = _batch_dynamic_slice_indices(start_indices, start_idx_bds) return _gather_batching_rule( [operand, index], [operand_bd, index_bdim], dimension_numbers=dnums, slice_sizes=slice_sizes, unique_indices=True, indices_are_sorted=True, mode=GatherScatterMode.PROMISE_IN_BOUNDS, fill_value=None) dynamic_slice_p = standard_primitive( _dynamic_slice_shape_rule, _dynamic_slice_dtype_rule, 'dynamic_slice', _dynamic_slice_translation_rule, weak_type_rule=_argnum_weak_type(0)) ad.primitive_jvps[dynamic_slice_p] = _dynamic_slice_jvp # TODO ad.primitive_transposes[dynamic_slice_p] = _dynamic_slice_transpose_rule batching.primitive_batchers[dynamic_slice_p] = _dynamic_slice_batching_rule def _dynamic_slice_lower(ctx, x, *start_indices, slice_sizes): aval_out, = ctx.avals_out return mhlo.DynamicSliceOp(mlir.aval_to_ir_type(aval_out), x, start_indices, mlir.dense_int_elements(slice_sizes)).results mlir.register_lowering(dynamic_slice_p, _dynamic_slice_lower) def _dynamic_update_slice_shape_rule(operand, update, *start_indices): if operand.ndim != update.ndim: msg = ("dynamic_update_slice update must have the same rank as operand, " "got update shape {} for operand shape {}.") raise TypeError(msg.format(update.shape, operand.shape)) if operand.ndim != len(start_indices): msg = ("dynamic_update_slice start_indices must have length equal to the " "rank of operand, got indices {} for operand shape {}.") raise TypeError(msg.format(start_indices, operand.shape)) if not core.greater_equal_shape(operand.shape, update.shape): msg = ("dynamic_update_slice update shape must be smaller than operand " "shape, got update shape {} for operand shape 
{}.") raise TypeError(msg.format(update.shape, operand.shape)) return operand.shape def _dynamic_update_slice_dtype_rule(operand, update, *start_indices): lax._check_same_dtypes("dynamic_update_slice", False, operand.dtype, update.dtype) if any(i.dtype != start_indices[0].dtype or not dtypes.issubdtype(i.dtype, np.integer) for i in start_indices): msg = ("index arguments to dynamic_update_slice must be integers of the " "same type, got {}") raise TypeError(msg.format(", ".join(i.dtype.name for i in start_indices))) return operand.dtype def _dynamic_update_slice_jvp(primals, tangents): operand, update = primals[:2] start_indices = primals[2:] g_operand, g_update = tangents[:2] val_out = dynamic_update_slice(operand, update, start_indices) if type(g_operand) is ad_util.Zero and type(g_update) is ad_util.Zero: tangent_out = ad_util.Zero.from_value(val_out) else: g_operand = ad.instantiate_zeros(g_operand) g_update = ad.instantiate_zeros(g_update) tangent_out = dynamic_update_slice(g_operand, g_update, start_indices) return val_out, tangent_out def _dynamic_update_slice_transpose_rule(t, operand, update, *start_indices): assert all(not ad.is_undefined_primal(x) for x in start_indices) if ad.is_undefined_primal(update): update_shape = update.aval.shape else: update_shape = update.shape if type(t) is ad_util.Zero: operand_t = ad_util.Zero(operand.aval) if ad.is_undefined_primal(operand) else None update_t = ad_util.Zero(update.aval) if ad.is_undefined_primal(update) else None else: dus = dynamic_update_slice ds = dynamic_slice zeros = lax._zeros(t, shape=update_shape) operand_t = dus(t, zeros, start_indices) if ad.is_undefined_primal(operand) else None update_t = ds(t, start_indices, update_shape) if ad.is_undefined_primal(update) else None return [operand_t, update_t] + [None] * len(start_indices) def _dynamic_update_slice_translation_rule(ctx, avals_in, avals_out, operand, update, *start_indices): return [xops.DynamicUpdateSlice(operand, update, start_indices)] def 
_dynamic_update_slice_batching_rule(batched_args, batch_dims): # A dynamic update slice is a special case of scatter; we can delegate to the # scatter batching rule. # TODO(phawkins): consider removing dynamic_update_slice entirely and using # scatter always. operand, update, *start_idx = batched_args operand_bd, update_bd, *start_idx_bd = batch_dims update_shape = (np.shape(update) if update_bd is batching.not_mapped else tuple(np.delete(np.shape(update), update_bd))) dims = tuple(range(len(update_shape))) dnums = ScatterDimensionNumbers(update_window_dims=dims, inserted_window_dims=(), scatter_dims_to_operand_dims=dims) index, index_bdim = _batch_dynamic_slice_indices(start_idx, start_idx_bd) return _scatter_batching_rule( scatter, (operand, index, update), (operand_bd, index_bdim, update_bd), update_jaxpr=None, update_consts=None, dimension_numbers=dnums, indices_are_sorted=True, unique_indices=True, mode=GatherScatterMode.PROMISE_IN_BOUNDS) dynamic_update_slice_p = standard_primitive( _dynamic_update_slice_shape_rule, _dynamic_update_slice_dtype_rule, 'dynamic_update_slice', _dynamic_update_slice_translation_rule) ad.primitive_jvps[dynamic_update_slice_p] = _dynamic_update_slice_jvp ad.primitive_transposes[dynamic_update_slice_p] = \ _dynamic_update_slice_transpose_rule batching.primitive_batchers[dynamic_update_slice_p] = \ _dynamic_update_slice_batching_rule def _dynamic_update_slice_lower(ctx, x, update, *start_indices): aval_out, = ctx.avals_out return mhlo.DynamicUpdateSliceOp(mlir.aval_to_ir_type(aval_out), x, update, start_indices).results mlir.register_lowering(dynamic_update_slice_p, _dynamic_update_slice_lower) def _gather_dimensions_proto( indices_shape: Sequence[int], dimension_numbers: GatherDimensionNumbers ) -> xla_client.GatherDimensionNumbers: assert type(dimension_numbers) is GatherDimensionNumbers proto = xla_client.GatherDimensionNumbers() proto.offset_dims.extend(dimension_numbers.offset_dims) 
proto.collapsed_slice_dims.extend(dimension_numbers.collapsed_slice_dims) proto.start_index_map.extend(dimension_numbers.start_index_map) assert len(indices_shape) > 0, indices_shape proto.index_vector_dim = len(indices_shape) - 1 return proto def _gather_dtype_rule(operand, indices, *, fill_value, **kwargs): if not dtypes.issubdtype(indices.dtype, np.integer): raise ValueError("indices must have an integer type") return dtypes.canonicalize_dtype(operand.dtype) _rank = lambda arr: len(arr.shape) def _is_sorted(dims, op_name, name): for i in range(1, len(dims)): if dims[i] < dims[i - 1]: raise TypeError(f"{name} in {op_name} op must be sorted; got {dims}") def _sorted_dims_in_range(dims, rank, op_name, name): if len(dims) == 0: return invalid_dim = None if dims[0] < 0: invalid_dim = dims[0] elif dims[-1] >= rank: invalid_dim = dims[-1] if invalid_dim: raise TypeError(f"Invalid {name} set in {op_name} op; valid range is " f"[0, {rank}); got: {invalid_dim}.") def _no_duplicate_dims(dims, op_name, name): if len(set(dims)) != len(dims): raise TypeError(f"{name} in {op_name} op must not repeat; got: {dims}.") def _gather_shape_rule(operand, indices, *, dimension_numbers, slice_sizes, unique_indices, indices_are_sorted, mode, fill_value): """Validates the well-formedness of the arguments to Gather. The code implements the checks based on the detailed operation semantics of XLA's `Gather <https://www.tensorflow.org/xla/operation_semantics#gather>`_ operator and following the outline of the implementation of ShapeInference::InferGatherShape in TensorFlow. """ offset_dims = dimension_numbers.offset_dims collapsed_slice_dims = dimension_numbers.collapsed_slice_dims start_index_map = dimension_numbers.start_index_map # Note: in JAX, index_vector_dim is always computed as below, cf. the # documentation of the GatherDimensionNumbers class. 
index_vector_dim = _rank(indices) - 1 # This case should never happen in JAX, due to the implicit construction of # index_vector_dim, but is included for completeness. if _rank(indices) < index_vector_dim or index_vector_dim < 0: raise TypeError(f"Gather index leaf dimension must be within [0, rank(" f"indices) + 1). rank(indices) is {_rank(indices)} and " f"gather index leaf dimension is {index_vector_dim}.") expanded_indices_shape = list(indices.shape) # This case should never happen in JAX, due to the implicit construction of # index_vector_dim, but is included for completeness. if len(expanded_indices_shape) == index_vector_dim: expanded_indices_shape.append(1) # Start ValidateGatherDimensions # In the error messages output by XLA, "offset_dims" is called "Output window # dimensions" in error messages. For consistency's sake, our error messages # stick to "offset_dims". _is_sorted(offset_dims, "gather", "offset_dims") _no_duplicate_dims(offset_dims, "gather", "offset_dims") output_offset_dim_count = len(offset_dims) output_shape_rank = len(offset_dims) + _rank(indices) - 1 for i in range(output_offset_dim_count): offset_dim = offset_dims[i] if offset_dim < 0 or offset_dim >= output_shape_rank: raise TypeError(f"Offset dimension {i} in gather op is out of bounds; " f"got {offset_dim}, but should have been in " f"[0, {output_shape_rank})") if len(start_index_map) != indices.shape[index_vector_dim]: raise TypeError(f"Gather op has {len(start_index_map)} elements in " f"start_index_map and the bound of dimension " f"index_vector_dim={index_vector_dim} of indices is " f"{indices.shape[index_vector_dim]}. 
These two " f"numbers must be equal.") for i in range(len(start_index_map)): operand_dim_for_start_index_i = start_index_map[i] if (operand_dim_for_start_index_i < 0 or operand_dim_for_start_index_i >= _rank(operand)): raise TypeError(f"Invalid start_index_map; domain is " f"[0, {_rank(operand)}), got: " f"{i}->{operand_dim_for_start_index_i}.") _no_duplicate_dims(start_index_map, "gather", "start_index_map") # _is_sorted and _sorted_dims_in_range are checked in the opposite order # compared to the XLA implementation. In cases when the input is not sorted # AND there are problematic collapsed_slice_dims, the error message will thus # be different. _is_sorted(collapsed_slice_dims, "gather", "collapsed_slice_dims") _sorted_dims_in_range(collapsed_slice_dims, _rank(operand), "gather", "collapsed_slice_dims") _no_duplicate_dims(collapsed_slice_dims, "gather", "collapsed_slice_dims") # End ValidateGatherDimensions if _rank(operand) != len(slice_sizes): raise TypeError(f"Gather op must have one slice size for every input " f"dimension; got: len(slice_sizes)={len(slice_sizes)}, " f"input_shape.rank={_rank(operand)}") if len(slice_sizes) != len(offset_dims) + len(collapsed_slice_dims): raise TypeError(f"All components of the offset index in a gather op must " f"either be a offset dimension or explicitly collapsed; " f"got len(slice_sizes)={len(slice_sizes)}, " f"output_slice_sizes={offset_dims}, collapsed_slice_dims=" f"{collapsed_slice_dims}.") for i in range(len(slice_sizes)): slice_size = slice_sizes[i] corresponding_input_size = operand.shape[i] if not (core.greater_equal_dim(slice_size, 0) and core.greater_equal_dim(corresponding_input_size, slice_size)): raise TypeError(f"Slice size at index {i} in gather op is out of range, " f"must be within [0, {corresponding_input_size} + 1), " f"got {slice_size}.") for i in range(len(collapsed_slice_dims)): bound = slice_sizes[collapsed_slice_dims[i]] if bound != 1: raise TypeError(f"Gather op can only collapse slice dims with 
bound 1, " f"but bound is {bound} for index " f"{collapsed_slice_dims[i]} at position {i}.") expanded_indices_shape.pop(index_vector_dim) indices_shape = iter(expanded_indices_shape) slice_sizes = iter(np.delete(slice_sizes, collapsed_slice_dims)) return tuple(next(slice_sizes) if i in offset_dims else next(indices_shape) for i in range(output_shape_rank)) def _gather_fill(operand, indices, *, dimension_numbers, slice_sizes, unique_indices, indices_are_sorted, fill_value, output_shape): """Lowers a FILL_OR_DROP gather as a PROMISE_IN_BOUNDS gather with masking.""" dnums = dimension_numbers intarray = partial(np.array, dtype=np.int64) operand_dims = lax._shape_as_value(operand.shape) indices = lax.convert_element_type(indices, np.int64) num_batch_dims = len(indices.shape) - 1 upper_bound = (operand_dims[intarray(dnums.start_index_map)] - intarray(slice_sizes)[intarray(dnums.start_index_map)]) mask = lax.bitwise_and( lax.ge(indices, np.int64(0)), lax.le(indices, lax.expand_dims(upper_bound, tuple(range(num_batch_dims))))) mask = lax._reduce_and(mask, [num_batch_dims]) # Computes the output shape and the positions of the batch dimensions in the # output output_ndims = num_batch_dims + len(dnums.offset_dims) batch_dims_in_output = np.delete(np.arange(output_ndims), dnums.offset_dims) # We don't consume unique_indices directly in gather(), only in its transpose # (scatter). 
gather_out = gather(operand, indices, dnums, slice_sizes, indices_are_sorted=indices_are_sorted, mode=GatherScatterMode.PROMISE_IN_BOUNDS) return lax.select( lax.broadcast_in_dim(mask, output_shape, batch_dims_in_output), gather_out, lax.full_like(gather_out, fill_value=fill_value)) def _gather_translation_rule(ctx, avals_in, avals_out, operand, indices, *, dimension_numbers, slice_sizes, unique_indices, indices_are_sorted, mode, fill_value): aval_out, = avals_out if mode == GatherScatterMode.FILL_OR_DROP: gather_fill_fn = xla.lower_fun(_gather_fill, multiple_results=False, new_style=True) return gather_fill_fn( ctx, avals_in, avals_out, operand, indices, dimension_numbers=dimension_numbers, slice_sizes=slice_sizes, unique_indices=unique_indices, indices_are_sorted=indices_are_sorted, fill_value=fill_value, output_shape=aval_out.shape) operand_aval, indices_aval = avals_in dimensions = _gather_dimensions_proto(indices_aval.shape, dimension_numbers) assert (mode == GatherScatterMode.CLIP or mode == GatherScatterMode.PROMISE_IN_BOUNDS), mode # XLA's Gather has clamp semantics, so we can just call it directly. 
return [xops.Gather(operand, indices, dimensions, slice_sizes, indices_are_sorted=indices_are_sorted)] def _gather_jvp_rule(g, operand, indices, *, dimension_numbers, slice_sizes, unique_indices, indices_are_sorted, mode, fill_value): return gather(g, indices, dimension_numbers, slice_sizes, unique_indices=unique_indices, indices_are_sorted=indices_are_sorted, mode=mode, fill_value=0) def _gather_transpose_rule(t, operand, indices, *, dimension_numbers, slice_sizes, unique_indices, indices_are_sorted, mode, fill_value): assert ad.is_undefined_primal(operand) operand_shape = operand.aval.shape if type(t) is ad_util.Zero: out = ad_util.Zero(operand.aval) else: zeros = lax.full(operand_shape, lax._zero(t)) scatter_dnums = ScatterDimensionNumbers( update_window_dims=dimension_numbers.offset_dims, inserted_window_dims=dimension_numbers.collapsed_slice_dims, scatter_dims_to_operand_dims=dimension_numbers.start_index_map) out = scatter_add(zeros, indices, t, scatter_dnums, unique_indices=unique_indices, indices_are_sorted=indices_are_sorted, mode=mode) return [out, None] def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers, slice_sizes, unique_indices, indices_are_sorted, mode, fill_value): operand, indices = batched_args operand_bdim, indices_bdim = batch_dims if operand_bdim is not None and indices_bdim is None: operand = batching.moveaxis(operand, operand_bdim, 0) slice_sizes = (operand.shape[0],) + slice_sizes offset_dims = (0,) + tuple(np.add(1, dimension_numbers.offset_dims)) collapsed_slice_dims = tuple(np.add(1, dimension_numbers.collapsed_slice_dims)) start_index_map = tuple(np.add(1, dimension_numbers.start_index_map)) dnums = GatherDimensionNumbers( offset_dims=offset_dims, collapsed_slice_dims=collapsed_slice_dims, start_index_map=start_index_map) return gather(operand, indices, dimension_numbers=dnums, slice_sizes=slice_sizes, unique_indices=unique_indices, indices_are_sorted=indices_are_sorted, mode=mode, fill_value=fill_value), 0 elif 
operand_bdim is None and indices_bdim is not None: indices = batching.moveaxis(indices, indices_bdim, 0) offset_dims = tuple(np.add(1, dimension_numbers.offset_dims)) dnums = GatherDimensionNumbers( offset_dims=offset_dims, collapsed_slice_dims=dimension_numbers.collapsed_slice_dims, start_index_map=dimension_numbers.start_index_map) # If batching indexed accesses into the same array, the batched gather may # no longer have sorted or unique indices. return gather(operand, indices, dimension_numbers=dnums, slice_sizes=slice_sizes, unique_indices=False, indices_are_sorted=False, mode=mode, fill_value=fill_value), 0 else: # move batch dimensions to the front to simplify logic operand = batching.moveaxis(operand, operand_bdim, 0) indices = batching.moveaxis(indices, indices_bdim, 0) # This slightly awkward special case is needed because the shape rule for # gather does not allow size-1 slices out of a size-0 dimension, even if # the number of slices is zero. Likely the best fix would be to change the # definition of gather() so it can be batched without the construction of # an explicit iota of size-1 slices. if core.symbolic_equal_dim(operand.shape[0], 0): output_shape = _gather_shape_rule( core.ShapedArray(operand.shape[1:], operand.dtype), core.ShapedArray(indices.shape[1:], dtypes.canonicalize_dtype(indices.dtype)), dimension_numbers=dimension_numbers, slice_sizes=slice_sizes, unique_indices=unique_indices, indices_are_sorted=indices_are_sorted, mode=mode, fill_value=fill_value) return lax.full((0,) + output_shape, lax._zero(operand)), 0 # Example: user code had indices shape (3, 4, 5), and we have to deal with # indices shape (7, 3, 4, 5). We transform that to indices of shape # (7, 3, 4, 6) where we concatenated an iota that counts along our batch # dimension to the front of the ndindex. 
count_shape = list(indices.shape) count_shape[-1] = 1 counts = lax.broadcasted_iota(indices.dtype, tuple(count_shape), 0) indices = lax.concatenate([counts, indices], len(count_shape) - 1) slice_sizes = (1,) + slice_sizes collapsed_slice_dims = (0,) + tuple(np.add(1, dimension_numbers.collapsed_slice_dims)) offset_dims = tuple(np.add(1, dimension_numbers.offset_dims)) start_index_map = (0,) + tuple(np.add(1, dimension_numbers.start_index_map)) dnums = GatherDimensionNumbers( offset_dims=offset_dims, collapsed_slice_dims=collapsed_slice_dims, start_index_map=start_index_map) return gather(operand, indices, dimension_numbers=dnums, slice_sizes=slice_sizes, unique_indices=unique_indices, indices_are_sorted=indices_are_sorted, mode=mode, fill_value=fill_value), 0 gather_p = standard_primitive( _gather_shape_rule, _gather_dtype_rule, 'gather', _gather_translation_rule, weak_type_rule=_argnum_weak_type(0)) ad.defjvp(gather_p, _gather_jvp_rule, None) ad.primitive_transposes[gather_p] = _gather_transpose_rule batching.primitive_batchers[gather_p] = _gather_batching_rule def _gather_lower(ctx, operand, indices, *, dimension_numbers, slice_sizes, unique_indices, indices_are_sorted, mode, fill_value): aval_out, = ctx.avals_out if mode == GatherScatterMode.FILL_OR_DROP: gather_fill_fn = mlir.lower_fun(_gather_fill, multiple_results=False) return gather_fill_fn( ctx, operand, indices, dimension_numbers=dimension_numbers, slice_sizes=slice_sizes, unique_indices=unique_indices, indices_are_sorted=indices_are_sorted, fill_value=fill_value, output_shape=aval_out.shape) assert mode in (GatherScatterMode.PROMISE_IN_BOUNDS, GatherScatterMode.CLIP), mode dnums = mhlo.GatherDimensionNumbers.get( collapsed_slice_dims=list(dimension_numbers.collapsed_slice_dims), index_vector_dim=len(ctx.avals_in[1].shape) - 1, offset_dims=list(dimension_numbers.offset_dims), start_index_map=list(dimension_numbers.start_index_map)) return mhlo.GatherOp(operand, indices, dnums, 
mlir.dense_int_elements(slice_sizes), ir.BoolAttr.get(indices_are_sorted)).results mlir.register_lowering(gather_p, _gather_lower) def _scatter_dimensions_proto( indices_shape: Sequence[int], dimension_numbers: ScatterDimensionNumbers ) -> xla_client.ScatterDimensionNumbers: assert type(dimension_numbers) is ScatterDimensionNumbers proto = xla_client.ScatterDimensionNumbers() proto.update_window_dims.extend(dimension_numbers.update_window_dims) proto.inserted_window_dims.extend(dimension_numbers.inserted_window_dims) proto.scatter_dims_to_operand_dims.extend( dimension_numbers.scatter_dims_to_operand_dims) assert len(indices_shape) > 0, indices_shape proto.index_vector_dim = len(indices_shape) - 1 return proto def _scatter_dtype_rule(operand, indices, updates, **kwargs): if not dtypes.issubdtype(indices.dtype, np.integer): raise ValueError("indices must have an integer type") lax._check_same_dtypes("scatter", False, operand.dtype, updates.dtype) return dtypes.canonicalize_dtype(operand.dtype) def _scatter_shape_rule(operand, indices, updates, *, update_jaxpr, update_consts, dimension_numbers, indices_are_sorted, unique_indices, mode): """Validates the well-formedness of the ``dimension_numbers`` argument to Scatter. The code implements the checks based on the detailed operation semantics of XLA's `Scatter <https://www.tensorflow.org/xla/operation_semantics#scatter>`_ operator and following the outline of the implementation of ShapeInference::InferScatterShape in TensorFlow. """ update_window_dims = dimension_numbers.update_window_dims inserted_window_dims = dimension_numbers.inserted_window_dims scatter_dims_to_operand_dims = dimension_numbers.scatter_dims_to_operand_dims # Note: in JAX, index_vector_dim is always computed as below, cf. the # documentation of the ScatterDimensionNumbers class. index_vector_dim = _rank(indices) - 1 # This case should never happen in JAX, due to the implicit construction of # index_vector_dim, but is included for completeness. 
if _rank(indices) < index_vector_dim or index_vector_dim < 0: raise TypeError(f"Scatter index leaf dimension must be within [0, " f"rank(indices) + 1). rank(indices) is {_rank(indices)} " f"and scatter index leaf dimension is {index_vector_dim}.") expanded_indices_shape = list(indices.shape) # This case should never happen in JAX, due to the implicit construction of # index_vector_dim, but is included for completeness. if len(expanded_indices_shape) == index_vector_dim: expanded_indices_shape.append(1) expected_updates_rank = (len(expanded_indices_shape) - 1 + len(update_window_dims)) if _rank(updates) != expected_updates_rank: raise TypeError(f"Updates tensor must be of rank {expected_updates_rank}; " f"got {_rank(updates)}.") # Validate update_window_dims _is_sorted(update_window_dims, "scatter", "update_window_dims") _no_duplicate_dims(update_window_dims, "scatter", "update_window_dims") _sorted_dims_in_range(update_window_dims, _rank(updates), "scatter", "update_window_dims") # Validate inserted_window_dims _is_sorted(inserted_window_dims, "scatter", "inserted_window_dims") _no_duplicate_dims(inserted_window_dims, "scatter", "inserted_window_dims") _sorted_dims_in_range(inserted_window_dims, _rank(operand), "scatter", "inserted_window_dims") # Validate window_size window_size = len(update_window_dims) + len(inserted_window_dims) if _rank(operand) != window_size: raise TypeError(f"Scatter op has window of size {window_size}; doesn't " f"match operand of rank {_rank(operand)}.") # Validate scatter_dims_to_operand_dims if (len(scatter_dims_to_operand_dims) != indices.shape[index_vector_dim]): raise TypeError(f"Scatter op has {len(scatter_dims_to_operand_dims)} " f"elements in scatter_dims_to_operand_dims and the bound " f"of dimension index_vector_dim={index_vector_dim} of " f"indices is {indices.shape[index_vector_dim]}. 
These two " f"numbers must be equal") for i in range(len(scatter_dims_to_operand_dims)): dim = scatter_dims_to_operand_dims[i] if dim < 0 or dim >= _rank(operand): raise TypeError(f"Invalid scatter_dims_to_operand_dims mapping; domain " f"is [0, {_rank(operand)}), got: {i}->{dim}.") _no_duplicate_dims(scatter_dims_to_operand_dims, "scatter", "scatter_dims_to_operand_dims") max_update_slice_sizes = [operand.shape[i] for i in range(len(operand.shape)) if not i in set(inserted_window_dims)] for i in range(len(update_window_dims)): update_window_dim = update_window_dims[i] if not core.greater_equal_dim(max_update_slice_sizes[i], updates.shape[update_window_dim]): raise TypeError(f"Bounds of the window dimensions of updates must not " f"exceed the bounds of the corresponding dimensions of " f"operand. For dimension {update_window_dim}, updates " f"bound is {updates.shape[update_window_dim]}, operand " f"bound is {max_update_slice_sizes[i]}.") update_scatter_dims = [dim for dim in range(_rank(updates)) if dim not in set(update_window_dims)] scatter_dims_seen = 0 for i in update_scatter_dims: if scatter_dims_seen == index_vector_dim: scatter_dims_seen += 1 if updates.shape[i] != expanded_indices_shape[scatter_dims_seen]: raise TypeError(f"Bounds of the scatter dimensions of updates must be " f"the same as the bounds of the corresponding dimensions " f"of scatter indices. 
For scatter dimension {i}, updates " f"bound is {updates.shape[i]}, indices bound is " f"{expanded_indices_shape[scatter_dims_seen]}.") scatter_dims_seen += 1 return operand.shape def _clamp_scatter_indices(operand, indices, updates, *, dnums): """Clamps `indices` to be in-range for a scatter.""" slice_sizes = [] pos = 0 for i in range(len(operand.shape)): if i in dnums.inserted_window_dims: slice_sizes.append(1) else: slice_sizes.append(updates.shape[dnums.update_window_dims[pos]]) pos += 1 upper_bound = np.array([operand.shape[i] - slice_sizes[i] for i in dnums.scatter_dims_to_operand_dims], np.int64) upper_bound = np.minimum(upper_bound, np.iinfo(indices.dtype).max) upper_bound = lax.broadcast_in_dim(upper_bound, indices.shape, (len(indices.shape) - 1,)) return lax.clamp(np.int64(0), lax.convert_element_type(indices, np.int64), upper_bound) def _scatter_translation_rule(ctx, avals_in, avals_out, operand, indices, updates, *, update_jaxpr, update_consts, dimension_numbers, indices_are_sorted, unique_indices, mode): operand_aval, indices_aval, updates_aval = avals_in if mode == GatherScatterMode.CLIP: clip_fn = xla.lower_fun(_clamp_scatter_indices, multiple_results=False, new_style=True) indices, = clip_fn(ctx, avals_in, None, operand, indices, updates, dnums=dimension_numbers) c = ctx.builder init_value = xla.pyval_to_ir_constant(c, np.array(0, operand_aval.dtype)) update_computation = lax._reduction_computation( ctx, update_jaxpr, update_consts, init_value) return [xops.Scatter( operand, indices, updates, update_computation, _scatter_dimensions_proto(indices_aval.shape, dimension_numbers), indices_are_sorted, unique_indices)] def _scatter_add_translation_rule( ctx, avals_in, avals_out, operand, indices, updates, *, update_jaxpr, update_consts, dimension_numbers, indices_are_sorted, unique_indices, mode, expand_complex128=False): operand_aval, indices_aval, updates_aval = avals_in if mode == GatherScatterMode.CLIP: clip_fn = xla.lower_fun(_clamp_scatter_indices, 
multiple_results=False, new_style=True)
    indices, = clip_fn(ctx, avals_in, None, operand, indices, updates,
                       dnums=dimension_numbers)

  dtype = operand_aval.dtype
  scatter_dims = _scatter_dimensions_proto(
      indices_aval.shape, dimension_numbers)

  # Builds an XLA computation adding two scalars of `dtype`; used as the
  # scatter combiner.
  def _make_reducer(dtype):
    subc = xc.XlaBuilder("scatter_add_reducer")
    shape = xc.Shape.array_shape(np.dtype(dtype), ())
    args = [xla.parameter(subc, 0, shape), xla.parameter(subc, 1, shape)]
    out = xops.Add(args[0], args[1])
    return subc.build(out)

  if expand_complex128 and dtype == np.complex128:
    # Scatter the real and imaginary parts separately with a float64 adder,
    # then recombine. This path is registered for platform='gpu' below.
    update_computation = _make_reducer(np.float64)
    re = xops.Scatter(xops.Real(operand), indices, xops.Real(updates),
                      update_computation, scatter_dims, indices_are_sorted,
                      unique_indices)
    im = xops.Scatter(xops.Imag(operand), indices, xops.Imag(updates),
                      update_computation, scatter_dims, indices_are_sorted,
                      unique_indices)
    return [xops.Complex(re, im)]
  else:
    update_computation = _make_reducer(dtype)
    return [xops.Scatter(operand, indices, updates, update_computation,
                         scatter_dims, indices_are_sorted, unique_indices)]


def _scatter_add_jvp(primals, tangents, *, update_jaxpr, update_consts,
                     dimension_numbers, indices_are_sorted, unique_indices,
                     mode):
  """JVP of scatter_add.

  The primal output is recomputed with ``scatter_add_p``. The tangent output
  is the scatter-add of the operand and updates tangents at the same indices
  (scatter-add is linear in operand and updates). Index tangents are ignored.
  If both tangents are symbolic zeros, the output tangent is a symbolic zero.
  """
  operand, indices, updates = primals
  g_operand, g_indices, g_updates = tangents
  del g_indices  # ignored
  val_out = scatter_add_p.bind(
      operand, indices, updates, update_jaxpr=update_jaxpr,
      update_consts=update_consts, dimension_numbers=dimension_numbers,
      indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
      mode=mode)
  if type(g_operand) is ad_util.Zero and type(g_updates) is ad_util.Zero:
    tangent_out = ad_util.Zero.from_value(val_out)
  else:
    g_operand = ad.instantiate_zeros(g_operand)
    g_updates = ad.instantiate_zeros(g_updates)
    tangent_out = scatter_add_p.bind(
        g_operand, indices, g_updates, update_jaxpr=update_jaxpr,
        update_consts=update_consts, dimension_numbers=dimension_numbers,
        indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
        mode=mode)
  return val_out, tangent_out

def _scatter_add_transpose_rule(t, operand, indices, updates, *,
                                update_jaxpr, update_consts, dimension_numbers,
                                indices_are_sorted, unique_indices, mode):
  """Transpose of scatter_add w.r.t. the operand and/or the updates.

  - Operand cotangent: ``t`` itself.
  - Updates cotangent: ``t`` gathered at ``indices`` with the mirrored
    dimension numbers; out-of-bounds positions are filled with 0.
  - Indices: no cotangent (``None``).
  """
  assert not ad.is_undefined_primal(indices)
  if ad.is_undefined_primal(updates):
    updates_shape = updates.aval.shape
  else:
    updates_shape = updates.shape
  if type(t) is ad_util.Zero:
    operand_t = ad_util.Zero(operand.aval) if ad.is_undefined_primal(operand) else None
    update_t = ad_util.Zero(updates.aval) if ad.is_undefined_primal(updates) else None
  else:
    operand_t = update_t = None
    if ad.is_undefined_primal(operand):
      operand_t = t

    if ad.is_undefined_primal(updates):
      # The gather that reads back exactly the windows the scatter wrote.
      gather_dnums = GatherDimensionNumbers(
        offset_dims=dimension_numbers.update_window_dims,
        collapsed_slice_dims=dimension_numbers.inserted_window_dims,
        start_index_map=dimension_numbers.scatter_dims_to_operand_dims)
      # Window size per operand dim: 1 for inserted dims, otherwise the
      # matching updates window dimension.
      slice_sizes = []
      pos = 0
      for i in range(len(t.shape)):
        if i in dimension_numbers.inserted_window_dims:
          slice_sizes.append(1)
        else:
          slice_sizes.append(updates_shape[dimension_numbers.update_window_dims[pos]])
          pos += 1
      update_t = gather(t, indices, dimension_numbers=gather_dnums,
                        slice_sizes=slice_sizes, mode=mode, fill_value=0)
  return [operand_t, None, update_t]

def _scatter_mul_transpose_rule(t, operand, indices, updates, *,
                                update_jaxpr, update_consts, dimension_numbers,
                                indices_are_sorted, unique_indices, mode):
  """Transpose of scatter_mul w.r.t. the operand or the updates.

  - Operand cotangent: ``t`` scatter-multiplied by the same updates.
  - Updates cotangent: ``t * operand`` gathered at ``indices``. This reads the
    primal ``operand``.

  NOTE(review): the updates branch appears exact only when each operand
  location receives at most one update (unique indices). With duplicate
  indices, the product rule would contribute cross terms. Confirm before
  relying on it with ``unique_indices=False``.
  """
  assert not ad.is_undefined_primal(indices)
  if ad.is_undefined_primal(updates):
    updates_shape = updates.aval.shape
  else:
    updates_shape = updates.shape
  if type(t) is ad_util.Zero:
    operand_t = ad_util.Zero(operand.aval) if ad.is_undefined_primal(operand) else None
    update_t = ad_util.Zero(updates.aval) if ad.is_undefined_primal(updates) else None
  else:
    operand_t = update_t = None
    if ad.is_undefined_primal(operand):
      operand_t = scatter_mul(
          t, indices, updates, dimension_numbers=dimension_numbers,
          indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
          mode=mode)
    if ad.is_undefined_primal(updates):
      gather_dnums = GatherDimensionNumbers(
        offset_dims=dimension_numbers.update_window_dims,
        collapsed_slice_dims=dimension_numbers.inserted_window_dims,
        start_index_map=dimension_numbers.scatter_dims_to_operand_dims)
      slice_sizes = []
      pos = 0
      for i in range(len(t.shape)):
        if i in dimension_numbers.inserted_window_dims:
          slice_sizes.append(1)
        else:
          slice_sizes.append(updates_shape[dimension_numbers.update_window_dims[pos]])
          pos += 1
      update_t = gather(lax.mul(t, operand), indices,
                        dimension_numbers=gather_dnums, slice_sizes=slice_sizes,
                        mode=mode, fill_value=0)
  return [operand_t, None, update_t]


def _scatter_batching_rule(scatter_op, batched_args, batch_dims, *,
                           update_jaxpr, update_consts, dimension_numbers,
                           indices_are_sorted, unique_indices, mode):
  """Batching rule shared by the scatter primitives; the result batch dim is 0.

  ``scatter_op`` is the user-level scatter function used to emit the batched
  op (bound with ``partial`` at registration below).
  - Unbatched indices: the batch dim becomes a leading update window dim.
  - Batched indices: an iota over the batch is prepended to each index
    vector.
  """
  operand, indices, updates = batched_args
  operand_bdim, indices_bdim, updates_bdim = batch_dims
  del update_jaxpr, update_consts  # Unused.

  # move the operand batch dim to the front if it is not None, otherwise create
  # it at the front (so that we can scatter into it)
  size = next(x.shape[ax] for x, ax in zip(batched_args, batch_dims)
              if ax is not None)
  operand = batching.bdim_at_front(operand, operand_bdim, size)
  operand_bdim = 0

  updates = batching.bdim_at_front(updates, updates_bdim, size)

  if indices_bdim is None:
    inserted_window_dims = tuple(np.add(1, dimension_numbers.inserted_window_dims))
    update_window_dims = (0,) + tuple(np.add(1, dimension_numbers.update_window_dims))
    scatter_dims_to_operand_dims = tuple(np.add(1, dimension_numbers.scatter_dims_to_operand_dims))
    dnums = ScatterDimensionNumbers(
        update_window_dims=update_window_dims,
        inserted_window_dims=inserted_window_dims,
        scatter_dims_to_operand_dims=scatter_dims_to_operand_dims)
    return scatter_op(
      operand, indices, updates, dnums,
      indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
      mode=mode), 0


  # see the third case in _gather_batching_rule for comparison and comments
  indices = batching.bdim_at_front(indices, indices_bdim, size)

  count_shape = list(indices.shape)
  count_shape[-1] = 1
  counts = lax.broadcasted_iota(indices.dtype, tuple(count_shape), 0)
  indices = lax.concatenate([counts, indices], len(count_shape) - 1)

  update_window_dims = tuple(np.add(1, dimension_numbers.update_window_dims))
  inserted_window_dims = (0,) + tuple(np.add(1, dimension_numbers.inserted_window_dims))
  scatter_dims_to_operand_dims = (0,) + tuple(np.add(1, dimension_numbers.scatter_dims_to_operand_dims))

  dnums = ScatterDimensionNumbers(
      update_window_dims=update_window_dims,
      inserted_window_dims=inserted_window_dims,
      scatter_dims_to_operand_dims=scatter_dims_to_operand_dims)
  return scatter_op(
      operand, indices, updates, dnums,
      indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
      mode=mode), 0

scatter_add_p = standard_primitive(
    _scatter_shape_rule, _scatter_dtype_rule, 'scatter-add',
    _scatter_add_translation_rule, weak_type_rule=_argnum_weak_type(0))
ad.primitive_jvps[scatter_add_p] = _scatter_add_jvp
ad.primitive_transposes[scatter_add_p] = _scatter_add_transpose_rule
batching.primitive_batchers[scatter_add_p] = (
  partial(_scatter_batching_rule, scatter_add))

# On GPU, complex128 scatter-add is split into real/imaginary float64
# scatters (see expand_complex128 in _scatter_add_translation_rule).
xla.register_translation(
    scatter_add_p,
    partial(_scatter_add_translation_rule, expand_complex128=True),
    platform='gpu')


scatter_mul_p = standard_primitive(
    _scatter_shape_rule, _scatter_dtype_rule, 'scatter-mul',
    _scatter_translation_rule, weak_type_rule=_argnum_weak_type(0))

# JVP of scatter_mul w.r.t. the updates `y` given tangent `g`: computed as
# x * scatter_add(zeros_like(x), i, g).
# NOTE(review): this matches d(x * y) only when each location of x receives at
# most one update; with duplicate indices the product rule adds cross terms.
# Confirm the intended contract for unique_indices=False.
def _scatter_mul_jvp_rhs(g, x, i, y, *, dimension_numbers,
                         indices_are_sorted, unique_indices, mode, **kw):
  return lax.mul(x, scatter_add(
      lax.zeros_like_array(x), i, g, dimension_numbers=dimension_numbers,
      indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
      mode=mode))

# JVP w.r.t. the operand: scatter_mul is linear in the operand, so the
# operand tangent is scatter-multiplied by the same updates.
ad.defjvp(scatter_mul_p,
          lambda g, x, i, y, **kw: scatter_mul_p.bind(g, i, y, **kw),
          None,
          _scatter_mul_jvp_rhs)
ad.primitive_transposes[scatter_mul_p] = _scatter_mul_transpose_rule
batching.primitive_batchers[scatter_mul_p] = (
  partial(_scatter_batching_rule, scatter_mul))

def _scatter_extremal_jvp(scatter_op, primals,
tangents, update_jaxpr, update_consts, dimension_numbers, indices_are_sorted, unique_indices, mode): operand, indices, updates = primals g_operand, g_indices, g_updates = tangents scatter_dnums = dimension_numbers updates_shape = updates.shape val_out = scatter_op.bind( operand, indices, updates, update_jaxpr=update_jaxpr, update_consts=update_consts, dimension_numbers=scatter_dnums, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=mode) if type(g_operand) is ad_util.Zero and type(g_updates) is ad_util.Zero: tangent_out = ad_util.Zero.from_value(val_out) else: g_operand = ad.instantiate_zeros(g_operand) g_updates = ad.instantiate_zeros(g_updates) # gather_dnums and slice_sizes define the gather op that is the inverse of # the scatter op specified by scatter_dnums gather_dnums = GatherDimensionNumbers( offset_dims=scatter_dnums.update_window_dims, collapsed_slice_dims=scatter_dnums.inserted_window_dims, start_index_map=scatter_dnums.scatter_dims_to_operand_dims) slice_sizes = [] pos = 0 for i in range(len(operand.shape)): if i in scatter_dnums.inserted_window_dims: slice_sizes.append(1) else: slice_sizes.append(updates_shape[scatter_dnums.update_window_dims[pos]]) pos += 1 # For consistency with other max operations, if there are two or more values # in updates that are contending to replace the same index location, the # resulting tangent at that location will be the average of the associated # tangents for the values in updates. 
initial_vals = gather( operand, indices, gather_dnums, np.array(slice_sizes)) target_vals = gather( val_out, indices, gather_dnums, np.array(slice_sizes)) successful_updates = (updates == target_vals) retained_values = (initial_vals == target_vals) num_updates = gather( scatter_add( lax._zeros(operand), indices, lax.select(successful_updates, lax._ones(updates), lax._zeros(updates)), scatter_dnums), indices, gather_dnums, np.array(slice_sizes)) num_refs = gather( scatter_add(lax._zeros(operand), indices, lax._ones(updates), scatter_dnums), indices, gather_dnums, np.array(slice_sizes)) updates_normalizer = lax.select(retained_values, 1.0 / (num_updates + 1), 1.0 / num_updates) updates_coef = lax.select(successful_updates, updates_normalizer, lax._zeros(updates)) operand_normalizer = lax.select(retained_values, 1.0 / (num_updates + 1), lax._zeros(num_updates)) operand_coef = (-1.0 + operand_normalizer) / num_refs # This can be simplified once scatter has transpose implemented target_tangents = gather( g_operand, indices, gather_dnums, np.array(slice_sizes)) tangent_updates = (target_tangents * operand_coef + g_updates * updates_coef) tangent_out = scatter_add(g_operand, indices, tangent_updates, scatter_dnums, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=mode) return val_out, tangent_out scatter_min_p = standard_primitive( _scatter_shape_rule, _scatter_dtype_rule, 'scatter-min', _scatter_translation_rule, weak_type_rule=_argnum_weak_type(0)) batching.primitive_batchers[scatter_min_p] = ( partial(_scatter_batching_rule, scatter_min)) ad.primitive_jvps[scatter_min_p] = partial(_scatter_extremal_jvp, scatter_min_p) scatter_max_p = standard_primitive( _scatter_shape_rule, _scatter_dtype_rule, 'scatter-max', _scatter_translation_rule, weak_type_rule=_argnum_weak_type(0)) batching.primitive_batchers[scatter_max_p] = ( partial(_scatter_batching_rule, scatter_max)) ad.primitive_jvps[scatter_max_p] = partial(_scatter_extremal_jvp, scatter_max_p) 
def _scatter_jvp(primals, tangents, *, update_jaxpr, update_consts,
                 dimension_numbers, indices_are_sorted, unique_indices,
                 mode):
  """JVP rule for ``scatter_p`` (a scatter whose updates overwrite targets).

  With ``unique_indices`` the tangent is simply the same scatter applied to
  the tangents. Otherwise, the update that "won" each location is identified
  with per-update IDs, and the primal and tangent outputs are rebuilt with
  masked scatter-adds (steps a-d below).
  """
  operand, indices, updates = primals
  g_operand, g_indices, g_updates = tangents
  dnums = dimension_numbers

  if type(g_operand) is ad_util.Zero and type(g_updates) is ad_util.Zero:
    val_out = scatter_p.bind(
      operand, indices, updates, update_jaxpr=update_jaxpr,
      update_consts=update_consts, dimension_numbers=dnums,
      indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
      mode=mode)
    return val_out, ad_util.Zero.from_value(val_out)

  g_operand = ad.instantiate_zeros(g_operand)
  g_updates = ad.instantiate_zeros(g_updates)

  if unique_indices:
    # If the user has promised that the updates don't overlap, we can use a much
    # simpler JVP.
    val_out = scatter_p.bind(
      operand, indices, updates, update_jaxpr=update_jaxpr,
      update_consts=update_consts, dimension_numbers=dnums,
      indices_are_sorted=indices_are_sorted, unique_indices=True,
      mode=mode)
    tangent_out = scatter_p.bind(
      g_operand, indices, g_updates, update_jaxpr=update_jaxpr,
      update_consts=update_consts, dimension_numbers=dnums,
      indices_are_sorted=indices_are_sorted, unique_indices=True,
      mode=mode)
    return val_out, tangent_out

  # If there are overlapping indices in the scatter, it is unspecified which
  # update "wins". So we use the following perhaps surprising scheme:
  # a) attach a positive ID to each update in updates, and perform the scatter
  #    on the IDs
  # b) perform the inverse gather on the scattered IDs (similar to
  #    _scatter_add_transpose).
  # c) use the gathered IDs to mask the primal and tangent values.
  # d) perform a scatter-add on the masked primal and tangent values. A benefit
  #    of using scatter-add here is that we don't need a `scatter` transpose
  #    rule.

  # a) attach a positive ID to each update in `updates`, and perform a scatter
  #    on the IDs.
  ids_shape = np.array(updates.shape, dtype=np.int64)
  ids_shape[dnums.update_window_dims,] = 1
  num_ids = np.prod(ids_shape)
  id_dtype = np.uint32 if (num_ids + 1) < np.iinfo(np.uint32).max else np.uint64
  # IDs start at 1 so that 0 in `scattered_ids` means "no update landed here".
  update_ids = lax.add(lax.reshape(lax.iota(id_dtype, num_ids), ids_shape),
                       lax._ones(updates, dtype=id_dtype))

  scattered_ids = scatter(lax.full(operand.shape, 0, id_dtype), indices,
                          update_ids, dnums, indices_are_sorted=indices_are_sorted,
                          unique_indices=unique_indices, mode=mode)

  # b) compute the inverse gather that "undoes" the scatter on the id values.
  gather_dnums = GatherDimensionNumbers(
    offset_dims=dnums.update_window_dims,
    collapsed_slice_dims=dnums.inserted_window_dims,
    start_index_map=dnums.scatter_dims_to_operand_dims)
  slice_sizes = []
  pos = 0
  for i in range(len(scattered_ids.shape)):
    if i in dnums.inserted_window_dims:
      slice_sizes.append(1)
    else:
      slice_sizes.append(updates.shape[dnums.update_window_dims[pos]])
      pos += 1
  # Use the primal's `mode` so indices are handled exactly as in the scatter
  # above (e.g. clamped under CLIP). Otherwise a clamped update's ID could be
  # read as a fill value and the winning update masked out.
  gathered_update_ids = gather(scattered_ids, indices,
                               dimension_numbers=gather_dnums,
                               slice_sizes=slice_sizes, mode=mode)

  # c) mask off input elements that do not correspond to a primal output.
  masked_operand = lax.select(lax.eq(scattered_ids, lax._zeros(scattered_ids)),
                              operand, lax._zeros(operand))
  masked_updates = lax.select(lax.eq(update_ids, gathered_update_ids),
                              updates, lax._zeros(updates))
  masked_g_operand = lax.select(lax.eq(scattered_ids, lax._zeros(scattered_ids)),
                                g_operand, lax._zeros(g_operand))
  masked_g_updates = lax.select(lax.eq(update_ids, gathered_update_ids),
                                g_updates, lax._zeros(g_updates))

  # d) perform scatter-adds to compute the primal and tangent outputs.
  val_out = scatter_add(masked_operand, indices, masked_updates,
                        dimension_numbers=dnums,
                        indices_are_sorted=indices_are_sorted,
                        unique_indices=unique_indices, mode=mode)
  tangent_out = scatter_add(masked_g_operand, indices, masked_g_updates,
                            dimension_numbers=dnums,
                            indices_are_sorted=indices_are_sorted,
                            unique_indices=unique_indices, mode=mode)
  return val_out, tangent_out


def _scatter_transpose_rule(t, operand, indices, updates, *,
                            update_jaxpr, update_consts, dimension_numbers,
                            indices_are_sorted, unique_indices, mode):
  """Transpose rule for ``scatter_p``; only defined for unique indices.

  The operand cotangent is ``t`` with the overwritten locations zeroed. The
  updates cotangent gathers ``t`` at the locations each update was written to.

  Raises:
    NotImplementedError: if ``unique_indices`` is False.
  """
  if not unique_indices:
    raise NotImplementedError("scatter transpose is only implemented where "
                              "unique_indices=True")
  assert not ad.is_undefined_primal(indices)
  if ad.is_undefined_primal(updates):
    updates_shape = updates.aval.shape
  else:
    updates_shape = updates.shape
  if type(t) is ad_util.Zero:
    operand_t = ad_util.Zero(operand.aval) if ad.is_undefined_primal(operand) else None
    update_t = ad_util.Zero(updates.aval) if ad.is_undefined_primal(updates) else None
  else:
    operand_t = update_t = None
    if ad.is_undefined_primal(operand):
      # Zero out gradient entries that correspond to updated indices.
      mask = scatter(lax._ones(t, dtype=np.bool_), indices,
                     lax.full(updates_shape, False),
                     dimension_numbers=dimension_numbers,
                     indices_are_sorted=indices_are_sorted,
                     unique_indices=True, mode=mode)
      operand_t = lax.select(mask, t, lax._zeros(t))

    if ad.is_undefined_primal(updates):
      gather_dnums = GatherDimensionNumbers(
        offset_dims=dimension_numbers.update_window_dims,
        collapsed_slice_dims=dimension_numbers.inserted_window_dims,
        start_index_map=dimension_numbers.scatter_dims_to_operand_dims)
      slice_sizes = []
      pos = 0
      for i in range(len(t.shape)):
        if i in dimension_numbers.inserted_window_dims:
          slice_sizes.append(1)
        else:
          slice_sizes.append(updates_shape[dimension_numbers.update_window_dims[pos]])
          pos += 1
      update_t = gather(t, indices, dimension_numbers=gather_dnums,
                        slice_sizes=slice_sizes, mode=mode, fill_value=0)
  return [operand_t, None, update_t]


scatter_p = standard_primitive(
    _scatter_shape_rule, _scatter_dtype_rule, 'scatter',
    _scatter_translation_rule, weak_type_rule=_argnum_weak_type(0))
ad.primitive_jvps[scatter_p] = _scatter_jvp
ad.primitive_transposes[scatter_p] = _scatter_transpose_rule
batching.primitive_batchers[scatter_p] = (
  partial(_scatter_batching_rule, scatter))


def _scatter_lower(ctx, operand, indices, updates, *,
                   update_jaxpr, update_consts, dimension_numbers,
                   indices_are_sorted, unique_indices, mode):
  """MLIR lowering shared by the scatter primitives.

  Emits one ``mhlo.ScatterOp`` whose update computation is ``update_jaxpr``
  lowered inline. Under ``GatherScatterMode.CLIP``, indices are clamped first.
  """
  if mode == GatherScatterMode.CLIP:
    clip_fn = mlir.lower_fun(_clamp_scatter_indices, multiple_results=False)
    (indices,), = clip_fn(ctx.replace(avals_out=None), operand, indices,
                          updates, dnums=dimension_numbers)

  aval_out, = ctx.avals_out
  dnums = dimension_numbers
  scatter_dnums = mhlo.ScatterDimensionNumbers.get(
    update_window_dims=list(dnums.update_window_dims),
    inserted_window_dims=list(dnums.inserted_window_dims),
    scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims),
    index_vector_dim=len(ctx.avals_in[1].shape) - 1)
  op = mhlo.ScatterOp(mlir.aval_to_ir_type(aval_out), operand, indices,
                      updates, scatter_dnums,
                      ir.BoolAttr.get(indices_are_sorted),
                      ir.BoolAttr.get(unique_indices))
  scalar_type = mlir.aval_to_ir_type(core.ShapedArray((), aval_out.dtype))
  update = op.update_computation.blocks.append(scalar_type, scalar_type)
  with ir.InsertionPoint(update):
    update_ctx = ctx.module_context.replace(name_stack='')
    out_nodes = mlir.jaxpr_subcomp(
        update_ctx, update_jaxpr, update_consts,
        (update.arguments[0],), (update.arguments[1],))
    mhlo.ReturnOp(util.flatten(out_nodes))
  return op.results


mlir.register_lowering(scatter_p, _scatter_lower)
mlir.register_lowering(scatter_add_p, _scatter_lower)
mlir.register_lowering(scatter_mul_p, _scatter_lower)
mlir.register_lowering(scatter_min_p, _scatter_lower)
mlir.register_lowering(scatter_max_p, _scatter_lower)


def _real_dtype(dtype):
  """Return the real floating dtype matching ``dtype`` (via ``np.finfo``)."""
  return np.finfo(dtype).dtype


def _scatter_add_lower_gpu(ctx, operand, indices, updates,
                           *, update_jaxpr, update_consts, dimension_numbers,
                           indices_are_sorted, unique_indices, mode):
  """GPU lowering for ``scatter_add_p``.

  Non-complex128 operands use the generic ``_scatter_lower``. For complex128,
  the real and imaginary parts are scatter-added separately as real arrays and
  then recombined with ``mhlo.ComplexOp``.
  """
  operand_aval_in, _, _ = ctx.avals_in
  if operand_aval_in.dtype != np.complex128:
    return _scatter_lower(ctx, operand, indices, updates,
                          update_jaxpr=update_jaxpr,
                          update_consts=update_consts,
                          dimension_numbers=dimension_numbers,
                          indices_are_sorted=indices_are_sorted,
                          unique_indices=unique_indices, mode=mode)

  if mode == GatherScatterMode.CLIP:
    # Same calling convention as in `_scatter_lower`.
    clip_fn = mlir.lower_fun(_clamp_scatter_indices, multiple_results=False)
    (indices,), = clip_fn(ctx.replace(avals_out=None), operand, indices,
                          updates, dnums=dimension_numbers)

  aval_out, = ctx.avals_out
  dnums = dimension_numbers
  scatter_dnums = mhlo.ScatterDimensionNumbers.get(
    update_window_dims=list(dnums.update_window_dims),
    inserted_window_dims=list(dnums.inserted_window_dims),
    scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims),
    index_vector_dim=len(ctx.avals_in[1].shape) - 1)
  real_dtype = _real_dtype(aval_out.dtype)
  operand_type_part = mlir.aval_to_ir_type(
      core.ShapedArray(aval_out.shape, real_dtype))

  def _scatter(operand_part, updates_part):
    # Real-valued scatter with an addition reducer.
    op = mhlo.ScatterOp(operand_type_part, operand_part, indices,
                        updates_part, scatter_dnums,
                        ir.BoolAttr.get(indices_are_sorted),
                        ir.BoolAttr.get(unique_indices))
    scalar_type = mlir.aval_to_ir_type(core.ShapedArray((), real_dtype))
    reducer = op.regions[0].blocks.append(scalar_type, scalar_type)
    with ir.InsertionPoint(reducer):
      add = mhlo.AddOp(*reducer.arguments).result
      mhlo.ReturnOp([add])
    return op.result

  real = _scatter(mhlo.RealOp(operand).result, mhlo.RealOp(updates).result)
  imag = _scatter(mhlo.ImagOp(operand).result, mhlo.ImagOp(updates).result)
  return mhlo.ComplexOp(real, imag).results


mlir.register_lowering(scatter_add_p, _scatter_add_lower_gpu, platform="gpu")


def _dynamic_slice_indices(operand, start_indices: Any):
  """Normalize ``start_indices`` w.r.t. ``operand.shape``.

  Each negative index ``i`` for a dimension of size ``d`` is mapped to
  ``i + d``. Static Python/NumPy integers paired with a constant dimension are
  converted eagerly with NumPy. All other indices get the same wrap via
  ``lax.select(i < 0, i + d, i)``.

  Args:
    operand: the array being sliced.
    start_indices: a tuple/list of scalar indices, or a 1-D array of them,
      with one entry per dimension of ``operand``.

  Returns:
    A list with one normalized index per dimension of ``operand``.

  Raises:
    ValueError: if array-valued ``start_indices`` is not 1-D, or if the number
      of indices differs from ``operand.ndim``.
  """
  if not isinstance(start_indices, (tuple, list)):
    # Check the rank of array-like indices before calling len(): len() raises
    # TypeError on a 0-d array and miscounts an N-d array.
    if getattr(start_indices, "ndim", 1) != 1:
      raise ValueError("Slice indices must be a 1D sequence, got {}"
                       .format(start_indices.shape))
    start_indices = list(start_indices)
  if len(start_indices) != operand.ndim:
    msg = ("Length of slice indices must match number of operand dimensions ({} "
          "vs {})")
    raise ValueError(msg.format(len(start_indices), operand.shape))
  return [np.asarray(i + d if i < 0 else i, lax._dtype(i))
          if isinstance(i, (int, np.integer)) and core.is_constant_dim(d)
          else lax.select(
              lax.lt(i, lax._const(i, 0)),
              lax.add(i, lax.convert_element_type(core.dimension_as_value(d),
                                                  lax._dtype(i))),
              i)
          for i, d in zip(start_indices, operand.shape)]