Source code for jax._src.lax.control_flow.loops

"""Module for the loop primitives."""
from __future__ import annotations

from import Sequence
from functools import partial
import inspect
import itertools
import operator
from typing import Any, Callable, TypeVar

import jax
import weakref
from jax._src import config
from jax._src import core
from jax._src import linear_util as lu
from jax._src.core import ConcreteArray, ShapedArray, raise_to_shaped
from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,
                           tree_map, tree_flatten_with_path, keystr)
from jax._src.api_util import shaped_abstractify
from jax._src.tree_util import equality_errors
from jax._src import ad_checkpoint
from jax._src import ad_util
from jax._src import api
from jax._src import dispatch
from jax._src import dtypes
from jax._src import effects
from jax._src import source_info_util
from jax._src import util
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src.lax import lax
from jax._src.lax import slicing
from jax._src.lax import windowed_reductions
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src import state
from jax._src.state import discharge as state_discharge
from jax._src.numpy.ufuncs import logaddexp
from jax._src.traceback_util import api_boundary
from jax._src.typing import Array
from jax._src.util import (partition_list, safe_map, safe_zip, split_list,
                           unzip2, weakref_lru_cache, merge_lists)
import numpy as np

from jax._src.lax.control_flow.common import (
    _abstractify, _avals_short, _check_tree_and_avals, _initial_style_jaxpr,
    _make_closed_jaxpr, _prune_zeros, _typecheck_param)

_map = safe_map
zip = safe_zip

T = TypeVar('T')
BooleanNumeric = Any  # A bool, or a Boolean array.

### Helper functions

def _promote_weak_typed_inputs(in_vals, in_avals, out_avals):
  """Promote weakly-typed in_vals to be compatible with out_avals.

    in_vals : flattened list of input values.
    in_avals : corresponding list of avals.
    out_avals : list of target output avals.
    in_vals_new : flattened list of modified in_vals with no weak types.
    changed : bool; true if in_vals required modification.
  if len(in_vals) != len(in_avals) or len(in_avals) != len(out_avals):
    # Calling function is responsible for catching this.
    return in_vals, False
  weak_mismatches = [i for i, (a1, a2) in enumerate(zip(in_avals, out_avals))
                    if getattr(a1, 'weak_type', False) and not core.typematch(a1, a2)]
  if not weak_mismatches:
    return in_vals, False
  for i in weak_mismatches:
    new_dtype = dtypes.result_type(in_vals[i], out_avals[i])
    in_vals[i] = lax.convert_element_type(in_vals[i], new_dtype)
  return in_vals, True

### scan

Carry = TypeVar('Carry')
X = TypeVar('X')
Y = TypeVar('Y')

[docs] @api_boundary def scan(f: Callable[[Carry, X], tuple[Carry, Y]], init: Carry, xs: X, length: int | None = None, reverse: bool = False, unroll: int | bool = 1) -> tuple[Carry, Y]: """Scan a function over leading array axes while carrying along state. The `Haskell-like type signature`_ in brief is .. code-block:: haskell scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b]) where for any array type specifier ``t``, ``[t]`` represents the type with an additional leading axis, and if ``t`` is a pytree (container) type with array leaves then ``[t]`` represents the type with the same pytree structure and corresponding leaves each with an additional leading axis. When the type of ``xs`` (denoted `a` above) is an array type or None, and the type of ``ys`` (denoted `b` above) is an array type, the semantics of :func:`~scan` are given roughly by this Python implementation:: def scan(f, init, xs, length=None): if xs is None: xs = [None] * length carry = init ys = [] for x in xs: carry, y = f(carry, x) ys.append(y) return carry, np.stack(ys) Unlike that Python version, both ``xs`` and ``ys`` may be arbitrary pytree values, and so multiple arrays can be scanned over at once and produce multiple output arrays. ``None`` is actually a special case of this, as it represents an empty pytree. Also unlike that Python version, :func:`~scan` is a JAX primitive and is lowered to a single WhileOp. That makes it useful for reducing compilation times for JIT-compiled functions, since native Python loop constructs in an :func:`~jax.jit` function are unrolled, leading to large XLA computations. Finally, the loop-carried value ``carry`` must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type ``c`` in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves). .. note:: :py:func:`scan` compiles ``f``, so while it can be combined with :py:func:`jit`, it's usually unnecessary. Args: f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning that ``f`` accepts two arguments where the first is a value of the loop carry and the second is a slice of ``xs`` along its leading axis, and that ``f`` returns a pair where the first element represents a new value for the loop carry and the second represents a slice of the output. init: an initial loop carry value of type ``c``, which can be a scalar, array, or any pytree (nested Python tuple/list/dict) thereof, representing the initial loop carry value. This value must have the same structure as the first element of the pair returned by ``f``. xs: the value of type ``[a]`` over which to scan along the leading axis, where ``[a]`` can be an array or any pytree (nested Python tuple/list/dict) thereof with consistent leading axis sizes. length: optional integer specifying the number of loop iterations, which must agree with the sizes of leading axes of the arrays in ``xs`` (but can be used to perform scans where no input ``xs`` are needed). reverse: optional boolean specifying whether to run the scan iteration forward (the default) or in reverse, equivalent to reversing the leading axes of the arrays in both ``xs`` and in ``ys``. unroll: optional positive int or bool specifying, in the underlying operation of the scan primitive, how many scan iterations to unroll within a single iteration of a loop. If an integer is provided, it determines how many unrolled loop iterations to run within a single rolled iteration of the loop. If a boolean is provided, it will determine if the loop is competely unrolled (i.e. `unroll=True`) or left completely unrolled (i.e. `unroll=False`). Returns: A pair of type ``(c, [b])`` where the first element represents the final loop carry value and the second element represents the stacked outputs of the second output of ``f`` when scanned over the leading axis of the inputs. .. _Haskell-like type signature: """ if not callable(f): raise TypeError("lax.scan: f argument should be a callable.") xs_flat, xs_tree = tree_flatten(xs) try: lengths = [x.shape[0] for x in xs_flat] except AttributeError as err: msg = "scan got value with no leading axis to scan over: {}." raise ValueError( msg.format(', '.join(str(x) for x in xs_flat if not hasattr(x, 'shape')))) from err if length is not None: length = int(length) if not all(length == l for l in lengths): msg = ("scan got `length` argument of {} which disagrees with " "leading axis sizes {}.") raise ValueError(msg.format(length, [x.shape[0] for x in xs_flat])) else: unique_lengths = set(lengths) if len(unique_lengths) > 1: msg = "scan got values with different leading axis sizes: {}." raise ValueError(msg.format(', '.join(str(x.shape[0]) for x in xs_flat))) elif len(unique_lengths) == 0: msg = "scan got no values to scan over and `length` not provided." raise ValueError(msg) else: length, = unique_lengths if config.disable_jit.value: if length == 0: raise ValueError("zero-length scan is not supported in disable_jit() mode because the output type is unknown.") carry = init ys = [] maybe_reversed = reversed if reverse else lambda x: x for i in maybe_reversed(range(length)): xs_slice = [_index_array(i, core.get_aval(x), x) for x in xs_flat] carry, y = f(carry, tree_unflatten(xs_tree, xs_slice)) ys.append(y) stack = lambda *ys: jax.numpy.stack(ys) stacked_y = tree_map(stack, *maybe_reversed(ys)) return carry, stacked_y xs_avals = [core.raise_to_shaped(core.get_aval(x)) for x in xs_flat] x_avals = [core.mapped_aval(length, 0, aval) for aval in xs_avals] def _create_jaxpr(init): init_flat, init_tree = tree_flatten(init) in_flat, in_tree = tree_flatten((init, xs)) carry_avals = tuple(_map(_abstractify, init_flat)) jaxpr, consts, out_tree = _initial_style_jaxpr( f, in_tree, (*carry_avals, *x_avals), "scan") out_tree_children = out_tree.children() if len(out_tree_children) != 2: msg = "scan body output must be a pair, got {}." raise TypeError(msg.format(tree_unflatten(out_tree, jaxpr.out_avals))) carry_avals_out = jaxpr.out_avals[:out_tree_children[0].num_leaves] return init_flat, carry_avals, carry_avals_out, init_tree, in_flat, jaxpr, consts, out_tree, out_tree_children # The carry input and output avals must match exactly. However, we want to account for # the case when init contains weakly-typed values (e.g. Python scalars), with avals that # may not match the output despite being compatible by virtue of their weak type. # To do this, we compute the jaxpr in two passes: first with the raw inputs, and if # necessary, a second time with modified init values. init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init) new_init_flat, changed = _promote_weak_typed_inputs(init_flat, carry_avals, carry_avals_out) if changed: init = tree_unflatten(init_tree, new_init_flat) init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init) in_flat, jaxpr, consts, out_tree, out_tree_children = rest _check_scan_carry_type(f, init, out_tree_children[0], carry_avals_out) disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(jaxpr.effects) if disallowed_effects: raise NotImplementedError( f'Effects not supported in `scan`: {disallowed_effects}') if isinstance(unroll, bool): unroll = length if unroll else 1 out = scan_p.bind(*consts, *in_flat, reverse=reverse, length=length, jaxpr=jaxpr, num_consts=len(consts), num_carry=len(init_flat), linear=(False,) * (len(consts) + len(in_flat)), unroll=unroll) return tree_unflatten(out_tree, out)
def _check_scan_carry_type(body_fun, in_carry, out_carry_tree, out_avals): try: sig = inspect.signature(body_fun) except (ValueError, TypeError): sig = None carry_name = sig and list(sig.parameters)[0] if carry_name: component = lambda p: (f'the input carry component {carry_name}{keystr(p)}' if p else f'the input carry {carry_name}') else: component = lambda p: (f'the input carry at path {keystr(p)}' if p else 'the input carry') leaves_and_paths, in_carry_tree = tree_flatten_with_path(in_carry) paths, in_carry_flat = unzip2(leaves_and_paths) in_avals = _map(_abstractify, in_carry_flat) if in_carry_tree != out_carry_tree: try: out_carry = tree_unflatten(out_carry_tree, out_avals) except: out_carry = None if out_carry is None: differences = [f'the input tree structure is:\n{in_carry_tree}\n', f'the output tree structure is:\n{out_carry_tree}\n'] else: differences = '\n'.join( f' * {component(path)} is a {thing1} but the corresponding component ' f'of the carry output is a {thing2}, so {explanation}\n' for path, thing1, thing2, explanation in equality_errors(in_carry, out_carry)) raise TypeError( "Scanned function carry input and carry output must have the same " "pytree structure, but they differ:\n" f"{differences}\n" "Revise the scanned function so that its output is a pair where the " "first element has the same pytree structure as the first argument." ) if not all(_map(core.typematch, in_avals, out_avals)): differences = '\n'.join( f' * {component(path)} has type {in_aval.str_short()}' ' but the corresponding output carry component has type ' f'{out_aval.str_short()}{_aval_mismatch_extra(in_aval, out_aval)}\n' for path, in_aval, out_aval in zip(paths, in_avals, out_avals) if not core.typematch(in_aval, out_aval)) raise TypeError( "Scanned function carry input and carry output must have equal types " "(e.g. shapes and dtypes of arrays), " "but they differ:\n" f"{differences}\n" "Revise the scanned function so that all output types (e.g. shapes " "and dtypes) match the corresponding input types." ) def _aval_mismatch_extra(a1: core.AbstractValue, a2: core.AbstractValue) -> str: assert not core.typematch(a1, a2) if isinstance(a1, core.ShapedArray) and isinstance(a2, core.ShapedArray): dtype_mismatch = a1.dtype != a2.dtype shape_mismatch = a1.shape != a2.shape return (', so ' * (dtype_mismatch or shape_mismatch) + 'the dtypes do not match' * dtype_mismatch + ' and also ' * (dtype_mismatch and shape_mismatch) + 'the shapes do not match' * shape_mismatch) return '' def _scan_impl_unrolled(*args, reverse, length, num_consts, num_carry, linear, f_impl, x_avals, y_avals): consts, init, xs = split_list(args, [num_consts, num_carry]) carry = init ys = [] for i in range(length): i_ = length - i - 1 if reverse else i x = _map(partial(_index_array, i_), x_avals, xs) out = f_impl(*consts, *carry, *x) carry, y = split_list(out, [num_carry]) ys.append(y) ys = list(reversed(ys)) if reverse else ys ys = list(zip(*ys)) ys = _map(_stack, y_avals, ys) return (*carry, *ys) def _scan_impl_loop(*args, reverse, length, num_consts, num_carry, linear, f_impl, x_avals, y_avals): consts, init, xs = split_list(args, [num_consts, num_carry]) def cond_fun(vals): i, *_ = vals return i < length def body_fun(vals): [i], carry, ys = split_list(vals, [1, num_carry]) i_ = length - i - 1 if reverse else i # TODO(jakevdp)[key-reuse]: this key reuse logic is not quite right, # because the scan body may consume any keys within it. # Import here to avoid circular imports from jax.experimental import key_reuse xs_unconsumed = _map(key_reuse.unconsumed_copy, xs) x = _map(partial(_dynamic_index_array, i_), x_avals, xs_unconsumed) out_flat = f_impl(*consts, *carry, *x) carry_out, y_updates = split_list(out_flat, [num_carry]) ys_out = _map(partial(_update_array, i_), y_avals, ys, y_updates) return [i + 1] + carry_out + ys_out # TODO(jakevdp)[key-reuse]: mark xs consumed here if f_impl consumes them. ys_init = _map(partial(_empty_array, length), y_avals) if length == 0: return init + ys_init else: init_val = [lax._const(length, 0)] + init + ys_init _, *outs = while_loop(cond_fun, body_fun, init_val) return outs def _scan_impl_block_unrolled(*args, reverse, length, num_consts, num_carry, linear, block_length, f_impl, x_avals, y_avals): consts, init, xs = split_list(args, [num_consts, num_carry]) num_blocks, rem = divmod(length, block_length) assert rem == 0 partition = partial(_partition_leading, num_blocks, block_length) xs_block = _map(partition, x_avals, xs) prepend_aval = partial(_prepend_dim_to_aval, block_length) x_block_avals = _map(prepend_aval, x_avals) y_block_avals = _map(prepend_aval, y_avals) f_impl_block = partial( _scan_impl_unrolled, reverse=reverse, length=block_length, num_consts=num_consts, num_carry=num_carry, linear=linear, f_impl=f_impl, x_avals=x_avals, y_avals=y_avals) outs = _scan_impl_loop( *consts, *init, *xs_block, reverse=reverse, length=num_blocks, num_consts=num_consts, num_carry=num_carry, linear=linear, f_impl=f_impl_block, x_avals=x_block_avals, y_avals=y_block_avals) carry, ys_blocks = split_list(outs, [num_carry]) combine = partial(_combine_leading, num_blocks, block_length) ys = _map(combine, y_avals, ys_blocks) return (*carry, *ys) def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear, unroll): _, _, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry]) _, y_avals = split_list(jaxpr.out_avals, [num_carry]) f_impl = core.jaxpr_as_fun(jaxpr) if unroll == 1: return _scan_impl_loop( *args, reverse=reverse, length=length, num_consts=num_consts, num_carry=num_carry, linear=linear, f_impl=f_impl, x_avals=x_avals, y_avals=y_avals) consts, init, xs = split_list(args, [num_consts, num_carry]) num_blocks, rem = divmod(length, unroll) length_div = num_blocks * unroll if rem > 0: if reverse: split = partial(_split_leading_dim, rem) xs_rem, xs = unzip2(_map(split, x_avals, xs)) else: split = partial(_split_leading_dim, length_div) xs, xs_rem = unzip2(_map(split, x_avals, xs)) outs = _scan_impl_block_unrolled( *consts, *init, *xs, reverse=reverse, length=length_div, num_consts=num_consts, num_carry=num_carry, linear=linear, block_length=unroll, f_impl=f_impl, x_avals=x_avals, y_avals=y_avals) carry, ys = split_list(outs, [num_carry]) if rem > 0: outs = _scan_impl_unrolled( *consts, *carry, *xs_rem, reverse=reverse, length=rem, num_consts=num_consts, num_carry=num_carry, linear=linear, f_impl=f_impl, x_avals=x_avals, y_avals=y_avals) carry, ys_rem = split_list(outs, [num_carry]) if reverse: ys = _map(_concatenate, y_avals, ys_rem, ys) else: ys = _map(_concatenate, y_avals, ys, ys_rem) return (*carry, *ys) def _stack(aval, vals): vals = [lax.expand_dims(x, (0,)) for x in vals] return lax.concatenate(vals, 0) def _concatenate(aval, x1, x2): return lax.concatenate([x1, x2], 0) def _split_leading_dim(i, aval, x): assert x.ndim >= 1 return (slicing.slice_in_dim(x, 0, i), slicing.slice_in_dim(x, i, x.shape[0])) def _dynamic_index_array(i, aval, x): return slicing.dynamic_index_in_dim(x, i, keepdims=False) def _index_array(i, aval, x): return slicing.index_in_dim(x, i, keepdims=False) def _empty_array(sz, aval): return lax.broadcast(lax.empty(aval.dtype), (sz, *aval.shape)) def _update_array(i, aval, xs, x): return slicing.dynamic_update_index_in_dim(xs, x, i, 0) def _partition_leading(sz0, sz1, aval, x): assert x.ndim >= 1 assert x.shape[0] == sz0 * sz1 return lax.reshape(x, (sz0, sz1, *x.shape[1:])) def _combine_leading(sz0, sz1, aval, x): assert x.ndim >= 2 assert x.shape[0] == sz0 assert x.shape[1] == sz1 return lax.collapse(x, 0, 2) def _prepend_dim_to_aval(sz, aval): return core.unmapped_aval(sz, core.no_axis_name, 0, aval) def _scan_abstract_eval(*args, reverse, length, num_consts, num_carry, jaxpr, linear, unroll): carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry]) ys_avals = _map(partial(_prepend_dim_to_aval, length), y_avals) return carry_avals + ys_avals, jaxpr.effects def _scan_jvp(primals, tangents, reverse, length, jaxpr, num_consts, num_carry, linear, unroll): num_xs = len(jaxpr.in_avals) - num_carry - num_consts num_ys = len(jaxpr.out_avals) - num_carry nonzeros = [type(t) is not ad_util.Zero for t in tangents] const_nz, init_nz, xs_nz = split_list(nonzeros, [num_consts, num_carry]) # Fixpoint computation of which carry are not either # non-zero from init, or the carry out is non-zero. Each iteration promotes # at least one carry to non-zero. We need at most len(carry) iterations, # but we need one last iteration to prepare the jaxpr based on the final # carry_nz. carry_nz = init_nz for _ in range(1 + len(carry_nz)): nonzeros = const_nz + carry_nz + xs_nz jaxpr_jvp, nonzeros_out = ad.jvp_jaxpr( jaxpr, nonzeros, instantiate=carry_nz + [False] * num_ys) carry_nz_out, _ = nonzeros_out[:num_carry], nonzeros_out[num_carry:] if carry_nz_out == carry_nz: break else: carry_nz = _map(operator.or_, carry_nz, carry_nz_out) else: assert False, "Fixpoint not reached" tangents = [ad.instantiate_zeros(t) if nz else t for t, nz in zip(tangents, nonzeros)] consts, init, xs = split_list(primals, [num_consts, num_carry]) all_tangents = split_list(tangents, [num_consts, num_carry]) consts_dot, init_dot, xs_dot = _map(_prune_zeros, all_tangents) jaxpr_jvp_rearranged = ad.rearrange_binders( jaxpr_jvp, [num_consts, num_carry, num_xs], [len(consts_dot), len(init_dot), len(xs_dot)], [num_carry, num_ys], [len(init_dot), sum(nonzeros_out) - len(init_dot)]) consts_linear, init_linear, xs_linear = split_list(linear, [num_consts, num_carry]) jaxpr_jvp_linear = tuple(consts_linear + [True] * len(consts_dot) + init_linear + [True] * len(init_dot) + xs_linear + [True] * len(xs_dot)) out_flat = scan_p.bind( *(consts + consts_dot + init + init_dot + xs + xs_dot), reverse=reverse, length=length, jaxpr=jaxpr_jvp_rearranged, num_consts=num_consts + len(consts_dot), num_carry=num_carry + len(init_dot), linear=jaxpr_jvp_linear, unroll=unroll) carry, carry_dot, ys, ys_dot = split_list(out_flat, [num_carry, len(init_dot), num_ys]) primals_out = carry + ys tangents_out_iter = iter(carry_dot + ys_dot) tangents_out = [next(tangents_out_iter) if nz else ad_util.Zero.from_value(p) for p, nz in zip(primals_out, nonzeros_out)] return primals_out, tangents_out def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry, jaxpr, linear, unroll): num_ys = len(jaxpr.out_avals) - num_carry unknowns = [not t.pval.is_known() for t in tracers] const_uk, init_uk, xs_uk = split_list(unknowns, [num_consts, num_carry]) # Fixpoint computation of which carry elements are unknown. Each iteration # promotes at least one carry to unknown. We need at most len(carry) # iterations, but we need one last iteration to prepare the jaxpr based on the # final carry_uk. carry_uk = init_uk for _ in range(1 + len(carry_uk)): unknowns = const_uk + carry_uk + xs_uk jaxpr_known, jaxpr_unknown, out_uk, res_avals = pe.partial_eval_jaxpr_nounits( jaxpr, unknowns, instantiate=carry_uk + [False] * num_ys) carry_uk_out, ys_uk = split_list(out_uk, [num_carry]) if carry_uk_out == carry_uk: break else: carry_uk = _map(operator.or_, carry_uk, carry_uk_out) else: assert False, "Fixpoint not reached" num_res = len(res_avals) del res_avals, carry_uk_out # Instantiate those inputs which must be treated as unknown from the fixpoint. tracers = [trace.instantiate_const(t) if uk else t for t, uk in zip(tracers, unknowns)] # The residual inputs and outputs of the jaxprs produced haven't yet been # adapted to the scan calling convention; in particular, jaxpr_known has its # residual outputs all at the end, meaning they're extensive outputs (which is # fully general but may be wasteful for residuals which are loop-invariant) # while jaxpr_unknown has its corresponding residual inputs at the front (just # as a convention with partial_eval_jaxpr_nounits), making them constant # inputs. To make them consistent, we move the residual inputs on # jaxpr_unknown to the end, even though we may move some back in the sequel. jaxpr_unknown = pe.move_binders_to_back( jaxpr_unknown, [True] * num_res + [False] * sum(unknowns)) # At this point, all residuals are treated as extensive outputs of jaxpr_known # (and extensive inputs to jaxpr_unknown). But residuals that are loop- # invariant can be hoisted out of the scan, rather than letting them get # broadcast (as in e.g. scanning multiplication by a constant matrix; we don't # want to broadcast the matrix!). So, outside the loop we perform a partial # evaluation with known 'const' inputs (but all other inputs unknown). const_pvals = [pe.PartialVal.known(t.pval.get_known()) for t in tracers[:num_consts] if t.pval.is_known()] other_pvals = [pe.PartialVal.unknown(aval) for aval in jaxpr_known.in_avals[len(const_pvals):]] with source_info_util.reset_name_stack(): jaxpr_known_, invar_pvals_out, jaxpr_known_consts = pe.trace_to_jaxpr_nounits( lu.wrap_init(core.jaxpr_as_fun(jaxpr_known)), const_pvals + other_pvals, instantiate=[True] * (len(out_uk) - sum(out_uk)) + [False] * num_res) jaxpr_known = pe.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr_known_), ()) # The above trace_to_jaxpr_nounits call computed loop-invariant residuals # (known values in invar_pvals_out) and also computed loop-invariant values # needed by the new jaxpr_known (in jaxpr_known_consts, which replace the # previous consts). We need to collect the computed inteisive residuals, and # move corresponding intensive residual binders in jaxpr_unknown to the front. res_pvals = invar_pvals_out[len(invar_pvals_out) - num_res:] intensive_res = [pval.get_known() for pval in res_pvals if pval.is_known()] jaxpr_unknown = pe.move_binders_to_front( jaxpr_unknown, [False] * sum(unknowns) + [pval.is_known() for pval in res_pvals]) del const_pvals, other_pvals, invar_pvals_out, jaxpr_known_, res_pvals # We use `jaxpr_known_consts` when we call scan_p.bind with jaxpr_known, and # we use `intensive_res` when we build the jaxpr eqn with jaxpr_unknown. # As another optimization, for any extensive inputs that are just forwarded to # extensive outputs, to avoid a copy (which would be looping over # dynamic-update-slice) we'd rather forward the input tracer/value. That means # pruning some outputs from jaxpr_known here, and updating `out_flat` below. fwds_known = pe._jaxpr_forwarding(jaxpr_known.jaxpr) # Prune fwds_known to include only extensive input to extensive output. fwds_known = [in_idx if out_idx >= num_carry - sum(carry_uk) and in_idx is not None and in_idx >= len(jaxpr_known_consts) + num_carry - sum(carry_uk) else None for out_idx, in_idx in enumerate(fwds_known)] # Drop any extensive output we can instead get by forwarding an input. # TODO(mattjj): use pe.dce_jaxpr here, though need a fixpoint jaxpr_known_, () = jaxpr_known.jaxpr, jaxpr_known.consts jaxpr_known_ = jaxpr_known_.replace( outvars=[x for x, i in zip(jaxpr_known_.outvars, fwds_known) if i is None]) jaxpr_known = core.ClosedJaxpr(jaxpr_known_, ()) del jaxpr_known_ # We use `fwds_known` below when forming the output of scanning jaxpr_known. # Run the known part of the scan (if it has any outputs or effects). known_inputs = (list(jaxpr_known_consts) + [t.pval.get_known() for t in tracers[num_consts:] if t.pval.is_known()]) if not jaxpr_known.out_avals and not jaxpr_known.effects: out_known = [] else: linear_known = [False] * len(known_inputs) # conservative! out_known = scan_p.bind( *known_inputs, reverse=reverse, length=length, jaxpr=jaxpr_known, num_consts=len(jaxpr_known_consts), num_carry=num_carry - sum(carry_uk), linear=tuple(linear_known), unroll=unroll) del linear_known # Complete the known output by filling in forwarded values using fwds_known. out_known_iter = iter(out_known) out_known = [next(out_known_iter) if f is None else _maybe_put(known_inputs[f]) for f in fwds_known] assert next(out_known_iter, None) is None del known_inputs, out_known_iter # Split known outputs from residuals. out_known, extensive_res = split_list(out_known, [len(out_uk) - sum(out_uk)]) assert len(intensive_res) + len(extensive_res) == num_res # Create input tracers for jaxpr_unknown bind. unknown_inputs = [t for t in tracers if not t.pval.is_known()] intensive_res = _map(trace.new_instantiated_const, intensive_res) extensive_res = _map(trace.new_instantiated_const, extensive_res) # Create output tracers for jaxpr_unknown bind, adapting extensive shapes. carry_avals, y_avals = split_list(jaxpr_unknown.out_avals, [sum(carry_uk)]) ys_avals = [core.unmapped_aval(length, core.no_axis_name, 0, y_aval) for y_aval in y_avals] out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None) for a in itertools.chain(carry_avals, ys_avals)] del carry_avals, y_avals # Create equation. linear_unknown = tuple([False] * len(intensive_res) + [l for l, uk in zip(linear, unknowns) if uk] + [False] * len(extensive_res)) name_stack = source_info_util.current_name_stack()[len(trace.name_stack):] source = source_info_util.current().replace(name_stack=name_stack) assert len(out_tracers) == len(jaxpr_unknown.out_avals) eqn = pe.new_eqn_recipe([*intensive_res, *unknown_inputs, *extensive_res], out_tracers, scan_p, dict(reverse=reverse, length=length, unroll=unroll, jaxpr=jaxpr_unknown, linear=linear_unknown, num_consts=len(intensive_res) + sum(const_uk), num_carry=sum(carry_uk)), jaxpr_unknown.effects, source) for t in out_tracers: t.recipe = eqn # Merge known and unknown outputs into final result. return util.merge_lists(out_uk, out_known, out_tracers) def _maybe_put(x): if isinstance(x, np.ndarray): return dispatch._put_x( x, jax.sharding.SingleDeviceSharding(jax.local_devices(backend='cpu')[0]), shaped_abstractify(x), False, ) else: return x def _scan_transpose(reduce_axes, cts, *args, reverse, length, num_consts, num_carry, jaxpr, linear, unroll): # we've only implemented transposing scans with specific lin/nonlin patterns consts_lin, init_lin, xs_lin = split_list(linear, [num_consts, num_carry]) num_ires = len(consts_lin) - sum(consts_lin) num_eres = len(xs_lin) - sum(xs_lin) if consts_lin != [False] * num_ires + [True] * (len(consts_lin) - num_ires): raise NotImplementedError if xs_lin != [True] * (len(xs_lin) - num_eres) + [False] * num_eres: raise NotImplementedError if not all(init_lin): pass # TODO(mattjj): error check consts, _, xs = split_list(args, [num_consts, num_carry]) ires, _ = split_list(consts, [num_ires]) _, eres = split_list(xs, [sum(xs_lin)]) assert not any(ad.is_undefined_primal(r) for r in ires) assert not any(ad.is_undefined_primal(r) for r in eres) carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry]) ct_carry, ct_ys = split_list(cts, [num_carry]) ct_carry = _map(ad.instantiate_zeros_aval, carry_avals, ct_carry) ct_ys_is_zeros = [type(ct_y) is ad.Zero for ct_y in ct_ys] ct_ys = [x for x in ct_ys if type(x) is not ad.Zero] ct_consts = _map(ad_util.zeros_like_aval, jaxpr.in_avals[num_ires:num_consts]) # jaxpr :: [ires, T d] -> [T c] -> [T a, eres] -> ([T c], [T b]) # jaxpr_trans :: [ires] -> [CT d, CT c] -> [CT b, eres] -> ([CT d, CT c], [CT a]) jaxpr_trans = _transpose_scan_jaxpr( num_ires, num_consts - num_ires, num_eres, jaxpr, reduce_axes, ct_ys_is_zeros) linear_trans = ([False] * num_ires + [True] * (len(ct_consts) + len(ct_carry) + len(ct_ys)) + [False] * num_eres) outs = scan_p.bind( *(ires + ct_consts + ct_carry + ct_ys + eres), reverse=not reverse, length=length, jaxpr=jaxpr_trans, num_consts=num_ires, num_carry=num_consts-num_ires+num_carry, linear=tuple(linear_trans), unroll=unroll) ct_consts, ct_init, ct_xs = split_list(outs, [num_consts - num_ires, num_carry]) return [None] * num_ires + ct_consts + ct_init + ct_xs + [None] * num_eres # transpose_scan_jaxpr :: ([res1, c, a, res2] -> b) # -> ([res1, CT c, CT b, res2] -> [CT c, CT a]) def _transpose_scan_jaxpr(num_res1, num_c, num_res2, jaxpr, reduce_axes, ct_ys_is_zeros): num_a = len(jaxpr.in_avals) - num_res1 - num_c - num_res2 # TODO: allow input cotangent avals to be batched relative to jaxpr.in_avals # if an axis isn't reduced res1_avals, c_avals, a_avals, res2_avals = split_list( jaxpr.in_avals, [num_res1, num_c, num_a]) num_ys = len(ct_ys_is_zeros) num_b = len(jaxpr.out_avals) - num_ys # TODO: Also propagate ad.Zero through b_carry_avals until fixed point. b_carry_avals, b_ys_avals = split_list(list(jaxpr.out_avals), [num_b]) b_ys_avals_stripped = [ aval for aval, is_zero in zip(b_ys_avals, ct_ys_is_zeros) if not is_zero ] @lu.wrap_init def transposed(*res1_cbar_bbar_res2): res1, c_bar, b_bar, ys_bar_stripped, res2 = split_list( res1_cbar_bbar_res2, [num_res1, num_c, num_b, len(b_ys_avals_stripped)]) ys_bar_stripped_iter = iter(ys_bar_stripped) ys_bar = [ ad.Zero(aval) if is_zero else next(ys_bar_stripped_iter) for aval, is_zero in zip(b_ys_avals, ct_ys_is_zeros) ] primals = (res1 + [ad.UndefinedPrimal(aval) for aval in c_avals] + [ad.UndefinedPrimal(aval) for aval in a_avals] + res2) cbar_abar = ad.backward_pass( jaxpr.jaxpr, reduce_axes, False, jaxpr.consts, primals, b_bar + ys_bar) _, new_c_bar, a_bar, _ = split_list(cbar_abar, [num_res1, num_c, num_a]) a_bar = _map(ad.instantiate_zeros_aval, a_avals, a_bar) c_bar = _map(ad.instantiate_zeros_aval, c_avals, _map(ad.add_tangents, c_bar, new_c_bar)) return c_bar + a_bar return _make_closed_jaxpr(transposed, res1_avals + c_avals + b_carry_avals + b_ys_avals_stripped + res2_avals) def _scan_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, args, dims, reverse, length, jaxpr, num_consts, num_carry, linear, unroll): num_ys = len(jaxpr.out_avals) - num_carry orig_batched = [d is not batching.not_mapped for d in dims] const_batched, init_batched, xs_batched = split_list(orig_batched, [num_consts, num_carry]) # Fixpoint computation of which carry are batched: either # batched from init, or the carry out is batched. Each iteration promotes # at least one carry to batched. We need at most len(carry) iterations, # but we need one last iteration to prepare the jaxpr based on the final # carry_batched. carry_batched = init_batched for _ in range(1 + len(carry_batched)): batched = const_batched + carry_batched + xs_batched jaxpr_batched, batched_out = batching.batch_jaxpr( jaxpr, axis_size, batched, instantiate=carry_batched + [False] * num_ys, axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type) carry_batched_out, ys_batched = batched_out[:num_carry], batched_out[num_carry:] if carry_batched_out == carry_batched: break else: carry_batched = _map(operator.or_, carry_batched, carry_batched_out) else: assert False, "Fixpoint not reached" consts, init, xs = split_list(args, [num_consts, num_carry]) consts_bdims, init_bdims, xs_bdims = split_list(dims, [num_consts, num_carry]) new_consts = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0 else x for x, d in zip(consts, consts_bdims)] new_init = [batching.broadcast(x, axis_size, 0) if now_batched and not was_batched else batching.moveaxis(x, d, 0) if now_batched else x for x, d, was_batched, now_batched in zip(init, init_bdims, init_batched, carry_batched)] new_xs = [batching.moveaxis(x, d, 1) if d is not batching.not_mapped and d != 1 else x for x, d in zip(xs, xs_bdims)] new_args = new_consts + new_init + new_xs outs = scan_p.bind( *new_args, reverse=reverse, length=length, jaxpr=jaxpr_batched, num_consts=num_consts, num_carry=num_carry, linear=linear, unroll=unroll) carry_bdims = [0 if b else batching.not_mapped for b in carry_batched] ys_bdims = [1 if b else batching.not_mapped for b in ys_batched] return outs, carry_bdims + ys_bdims def _scan_padding_rule(in_avals, out_avals, *args, jaxpr, **params): padded_jaxpr = core.ClosedJaxpr(*pe.pad_jaxpr(jaxpr.jaxpr, jaxpr.consts)) return scan_p.bind(*args, jaxpr=padded_jaxpr, **params) def _scan_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn ) -> tuple[list[bool], core.JaxprEqn]: jaxpr = eqn.params['jaxpr'] num_consts, num_carry = eqn.params['num_consts'], eqn.params['num_carry'] num_xs = len(jaxpr.in_avals) - num_consts - num_carry used_carry_out, used_extensive_out = split_list(used_outputs, [num_carry]) for i in range(1 + num_carry): used_outputs = used_carry_out + used_extensive_out jaxpr_dce, used_inputs = pe.dce_jaxpr( jaxpr.jaxpr, used_outputs, instantiate=[False] * num_consts + used_carry_out + [False] * num_xs) used_consts, used_carry_in, used_extensive_in = \ split_list(used_inputs, [num_consts, num_carry]) if list(used_carry_in) == list(used_carry_out): break else: used_carry_out = _map(operator.or_, used_carry_out, used_carry_in) else: assert False, "Fixpoint not reached" if config.enable_checks.value: core.check_jaxpr(jaxpr.jaxpr) new_linear = [l for l, u in zip(eqn.params['linear'], used_inputs) if u] new_params = dict(eqn.params, num_consts=sum(used_consts), num_carry=sum(used_carry_in), linear=tuple(new_linear), jaxpr=core.ClosedJaxpr(jaxpr_dce, jaxpr.consts)) # TODO(mattjj,sharadmv): don't assume effects are never DCE'd? new_invars = [v for v, used in zip(eqn.invars, used_inputs) if used] new_outvars = [v for v, used in zip(eqn.outvars, used_outputs) if used] _, new_effects = eqn.primitive.abstract_eval(*[v.aval for v in new_invars], **new_params) new_eqn = pe.new_jaxpr_eqn( new_invars, new_outvars, eqn.primitive, new_params, new_effects, eqn.source_info) assert len(new_eqn.invars ) == len(new_params['jaxpr'].in_avals ) assert len(new_eqn.outvars) == len(new_params['jaxpr'].out_avals) return used_inputs, new_eqn # TODO(mattjj): de-duplicate code with _scan_partial_eval def _scan_partial_eval_custom(saveable, unks_in, inst_in, eqn): jaxpr = eqn.params['jaxpr'] num_consts, num_carry = eqn.params['num_consts'], eqn.params['num_carry'] num_ys = len(jaxpr.out_avals) - num_carry # Fixpoint (trivial on 'inst_in', since we might as well make all inputs # available as DCE can subsequently prune any unused ones) const_uk, carry_uk, xs_uk = split_list(unks_in, [num_consts, num_carry]) for _ in range(1 + len(carry_uk)): unks_in = const_uk + carry_uk + xs_uk jaxpr_known_, jaxpr_staged_, unks_out, inst_out, num_res = \ pe.partial_eval_jaxpr_custom( jaxpr.jaxpr, in_unknowns=unks_in, in_inst=True, ensure_out_unknowns=carry_uk + [False] * num_ys, ensure_out_inst=True, saveable=saveable) carry_uk_out, ys_uk = split_list(unks_out, [num_carry]) if carry_uk_out == carry_uk: break else: carry_uk = _map(operator.or_, carry_uk, carry_uk_out) else: assert False, "Fixpoint not reached" jaxpr_known = core.ClosedJaxpr(jaxpr_known_ , jaxpr.consts) jaxpr_staged = core.ClosedJaxpr(jaxpr_staged_, jaxpr.consts) # Move all residual binders to the back of jaxpr_staged so they're extensive. # TODO(mattjj): make jaxpr_staged only take instantiated inputs res_avals = jaxpr_staged.in_avals[:num_res] jaxpr_staged = pe.move_binders_to_back( jaxpr_staged, [True] * num_res + [False] * len(jaxpr.in_avals)) # Instantiate all inputs (b/c jaxpr_staged takes all inputs, corresponding to # passing in_inst argument to partial_eval_jaxpr_custom above). new_inst = [x for x, inst in zip(eqn.invars, inst_in) if type(x) is core.Var and not inst] inst_in = [True] * len(inst_in) # As an optimization, hoist loop-invariant residuals out of the loop rather # than using extensive outputs for them. See _scan_partial_eval for comments. num_const_known = len(const_uk) - sum(const_uk) num_carry_known = len(carry_uk) - sum(carry_uk) num_xs_known = len( xs_uk) - sum( xs_uk) jaxpr_known_hoist, jaxpr_known_loop, loop_dep, consts_known_lp_avals = \ pe.partial_eval_jaxpr_nounits( jaxpr_known, [False] * num_const_known + [True] * (num_carry_known + num_xs_known), [True] * (len(unks_out) - sum(unks_out)) + [False] * num_res) # jaxpr_known_hoist produces intensive residuals followed by the constants for # jaxpr_known_loop. We adjust jaxpr_staged to accept intensive res as consts. _, loop_dep_res = split_list(loop_dep, [len(loop_dep) - num_res]) jaxpr_staged = pe.move_binders_to_front( jaxpr_staged, [False] * sum(inst_in) + _map(operator.not_, loop_dep_res)) num_intensive_res = len(loop_dep_res) - sum(loop_dep_res) del loop_dep, num_carry_known, num_xs_known, const_uk # Create residual variables. intensive_avals, ext_avals_mapped = partition_list(loop_dep_res, res_avals) ext_avals = [core.unmapped_aval(eqn.params['length'], core.no_axis_name, 0, a) for a in ext_avals_mapped] newvar = core.gensym() intensive_res = _map(newvar, intensive_avals) extensive_res = _map(newvar, ext_avals) # Create known eqn, which is a call_p combining evaluation of # jaxpr_known_hoist and a scan of jaxpr_known_loop. ins_known, _ = partition_list(unks_in, eqn.invars) out_binders_known, _ = partition_list(unks_out, eqn.outvars) # jaxpr_known_loop takes as input constants output as res by jaxpr_known_hoist # (corresponding to consts_known_lp_avals) followed by known carry and xs. linear_known_ = [l for l, uk in zip(eqn.params['linear'], unks_in) if not uk] _, linear_known_ = split_list(linear_known_, [num_const_known]) linear_known = [False] * len(consts_known_lp_avals) + linear_known_ params_known = dict(eqn.params, jaxpr=jaxpr_known_loop, num_consts=len(consts_known_lp_avals), num_carry=len(carry_uk)-sum(carry_uk), linear=tuple(linear_known)) @lu.wrap_init def known(*ins_known): consts_known_hoist, ins_known_lp = split_list(ins_known, [num_const_known]) out_hoist = core.jaxpr_as_fun(jaxpr_known_hoist)(*consts_known_hoist) intensive_res, consts_known_lp = split_list(out_hoist, [num_intensive_res]) out_loop = scan_p.bind(*consts_known_lp, *ins_known_lp, **params_known) return [*intensive_res, *out_loop] call_jaxpr_, _, call_jaxpr_consts = pe.trace_to_jaxpr_dynamic( known, [v.aval for v in ins_known]) call_jaxpr = core.ClosedJaxpr(call_jaxpr_, call_jaxpr_consts) eqn_known = pe.new_jaxpr_eqn( ins_known, [*intensive_res, *out_binders_known, *extensive_res], core.closed_call_p, dict(call_jaxpr=call_jaxpr), call_jaxpr.effects, eqn.source_info) # Create the staged eqn. _, out_binders_staged = partition_list(inst_out, eqn.outvars) linear_staged = ([False] * len(intensive_res) + list(eqn.params['linear']) + [False] * len(extensive_res)) params_staged = dict(eqn.params, jaxpr=jaxpr_staged, num_consts=len(intensive_res) + eqn.params['num_consts'], linear=tuple(linear_staged)) eqn_staged = pe.new_jaxpr_eqn([*intensive_res, *eqn.invars, *extensive_res], out_binders_staged, eqn.primitive, params_staged, jaxpr_staged.effects, eqn.source_info) new_vars = [*new_inst, *intensive_res, *extensive_res] return eqn_known, eqn_staged, unks_out, inst_out, new_vars def _scan_typecheck(bind_time, *in_atoms, reverse, length, num_consts, num_carry, jaxpr, linear, unroll): if not bind_time: _, *in_atoms = in_atoms avals = [x.aval for x in in_atoms] tc = partial(_typecheck_param, 'scan') tc(reverse, 'reverse', 'bool', type(reverse) is bool) tc(num_consts, 'num_consts', 'non-negative int', type(num_consts) is int and num_consts >= 0) tc(num_carry, 'num_carry', 'non-negative int', type(num_carry) is int and num_carry >= 0) tc(jaxpr, 'jaxpr', 'ClosedJaxpr', type(jaxpr) is core.ClosedJaxpr) tc(linear, 'linear', 'tuple of bool', type(linear) is tuple and all(type(x) is bool for x in linear)) tc(unroll, 'unroll', 'positive int', type(unroll) is int and unroll > 0) tc(length, 'length', 'non-negative int', length >= 0) if len(linear) != len(avals): raise core.JaxprTypeError( f'scan param linear has length {len(linear)} for {len(avals)} operands') const_avals, init_avals, x_avals = split_list(avals, [num_consts, num_carry]) const_avals_jaxpr, init_avals_jaxpr, x_avals_jaxpr = split_list( jaxpr.in_avals, [num_consts, num_carry]) carry_avals_jaxpr, y_avals_mapped = split_list(jaxpr.out_avals, [num_carry]) x_avals_mapped = _map(partial(core.mapped_aval, length, 0), x_avals) y_avals = [core.unmapped_aval(length, core.no_axis_name, 0, a) for a in y_avals_mapped] if not all(_map(core.typematch, init_avals_jaxpr, carry_avals_jaxpr)): raise core.JaxprTypeError( f'scan input carry input and output types mismatch: ' f'\n{_avals_short(init_avals_jaxpr)}\nvs\n{_avals_short(carry_avals_jaxpr)}') if not all(_map(core.typecompat, const_avals_jaxpr, const_avals)): raise core.JaxprTypeError( f'scan jaxpr takes input const types\n{_avals_short(const_avals_jaxpr)},\n' f'called with consts of type\n{_avals_short(const_avals)}') if not all(_map(core.typecompat, init_avals_jaxpr, init_avals)): raise core.JaxprTypeError( f'scan jaxpr takes input carry types\n{_avals_short(init_avals_jaxpr)},\n' f'called with initial carry of type\n{_avals_short(init_avals)}') if not all(_map(core.typecompat, x_avals_jaxpr, x_avals_mapped)): raise core.JaxprTypeError( f'scan jaxpr takes input sequence types\n{_avals_short(x_avals_jaxpr)},\n' f'called with sequence whose items have type\n{_avals_short(x_avals_mapped)}') return [*init_avals, *y_avals], jaxpr.effects def _scan_pp_rule(eqn, context, settings): printed_params = dict(eqn.params) del printed_params['linear'] if eqn.params['num_consts'] + eqn.params['num_carry'] == len(eqn.invars): del printed_params['length'] if printed_params['unroll'] == 1: del printed_params['unroll'] if printed_params['num_carry'] == 0: del printed_params['num_carry'] if printed_params['num_consts'] == 0: del printed_params['num_consts'] if not printed_params['reverse']: del printed_params['reverse'] return core._pp_eqn(eqn.replace(params=printed_params), context, settings) def _scan_state_discharge_rule(in_avals, out_avals, *args, jaxpr, num_consts, num_carry, linear, unroll, reverse, length): jaxpr, consts = jaxpr.jaxpr, jaxpr.consts if consts: raise NotImplementedError consts, carry, xs = split_list(args, [num_consts, num_carry]) consts_linear, carry_linear, xs_linear = split_list( linear, [num_consts, num_carry]) consts_avals, carry_avals, xs_avals = split_list(in_avals, [num_consts, num_carry]) is_ref = [isinstance(a, state.AbstractRef) for a in consts_avals] remaining_const_avals, in_ref_avals = partition_list(is_ref, consts_avals) remaining_consts, in_refs = partition_list(is_ref, consts) remaining_consts_linear, in_refs_linear = partition_list(is_ref, consts_linear) num_refs = sum(is_ref) num_extensive_in = len(in_avals) - num_carry - num_consts num_extensive_out = len(out_avals) - num_carry num_remaining_consts = num_consts - num_refs discharged_jaxpr, discharged_consts = state_discharge.discharge_state(jaxpr, ()) if discharged_consts: raise NotImplementedError("Discharged jaxpr has consts. If you see this, " "please open an issue at " "") # The discharged jaxpr will have output refs stashed at the end def wrapped(*refs_and_args): consts, refs, carry, xs = split_list(refs_and_args, [num_remaining_consts, num_refs, num_carry]) consts_with_refs = merge_lists(is_ref, consts, refs) outs_and_refs = core.eval_jaxpr(discharged_jaxpr, (), *consts_with_refs, *carry, *xs) carry, ys, out_refs = split_list(outs_and_refs, [num_carry, num_extensive_out]) assert len(out_refs) == num_refs return [*out_refs, *carry, *ys] new_in_avals = [*remaining_const_avals, *[a.inner_aval for a in in_ref_avals], *carry_avals, *[core.mapped_aval(length, 0, a) for a in xs_avals]] new_jaxpr, _, () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(wrapped), new_in_avals) new_linear = (*remaining_consts_linear, *in_refs_linear, *carry_linear, *xs_linear) all_out = scan_p.bind(*remaining_consts, *in_refs, *carry, *xs, jaxpr=core.ClosedJaxpr(new_jaxpr, ()), length=length, num_consts=num_remaining_consts, num_carry=num_refs + num_carry, unroll=unroll, reverse=reverse, linear=new_linear) refs_out, carry_out, ys_out = split_list(all_out, [num_refs, num_carry]) new_invals = [*merge_lists(is_ref, [None] * num_remaining_consts, refs_out), *[None] * num_carry, *[None] * num_extensive_in] assert len(new_invals) == len(in_avals) return new_invals, [*carry_out, *ys_out] def scan_bind(*args, **params): if config.enable_checks.value: avals = _map(core.get_aval, args) in_atoms = [core.Var(0, '', a) for a in avals] # dummies _scan_typecheck(True, *in_atoms, **params) core.check_jaxpr(params['jaxpr'].jaxpr) return core.AxisPrimitive.bind(scan_p, *args, **params) scan_p = core.AxisPrimitive("scan") scan_p.multiple_results = True scan_p.def_custom_bind(scan_bind) scan_p.def_impl(partial(dispatch.apply_primitive, scan_p)) scan_p.def_effectful_abstract_eval(_scan_abstract_eval) ad.primitive_jvps[scan_p] = _scan_jvp ad.reducing_transposes[scan_p] = _scan_transpose pe.custom_partial_eval_rules[scan_p] = _scan_partial_eval xla.register_initial_style_primitive(scan_p) mlir.register_lowering(scan_p, mlir.lower_fun(_scan_impl, multiple_results=True)) batching.axis_primitive_batchers[scan_p] = partial(_scan_batching_rule, None) batching.spmd_axis_primitive_batchers[scan_p] = _scan_batching_rule core.custom_typechecks[scan_p] = partial(_scan_typecheck, False) pe.partial_eval_jaxpr_custom_rules[scan_p] = _scan_partial_eval_custom pe.padding_rules[scan_p] = _scan_padding_rule pe.dce_rules[scan_p] = _scan_dce_rule state_discharge.register_discharge_rule(scan_p)(_scan_state_discharge_rule) # TODO(mattjj,frostig): un-comment this pp rule # core.pp_eqn_rules[scan_p] = _scan_pp_rule ### while_loop @api_boundary def while_loop(cond_fun: Callable[[T], BooleanNumeric], body_fun: Callable[[T], T], init_val: T) -> T: """Call ``body_fun`` repeatedly in a loop while ``cond_fun`` is True. The `Haskell-like type signature`_ in brief is .. code-block:: haskell while_loop :: (a -> Bool) -> (a -> a) -> a -> a The semantics of ``while_loop`` are given by this Python implementation:: def while_loop(cond_fun, body_fun, init_val): val = init_val while cond_fun(val): val = body_fun(val) return val Unlike that Python version, ``while_loop`` is a JAX primitive and is lowered to a single WhileOp. That makes it useful for reducing compilation times for jit-compiled functions, since native Python loop constructs in an ``@jit`` function are unrolled, leading to large XLA computations. Also unlike the Python analogue, the loop-carried value ``val`` must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type ``a`` in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves). Another difference from using Python-native loop constructs is that ``while_loop`` is not reverse-mode differentiable because XLA computations require static bounds on memory requirements. .. note:: :py:func:`while_loop` compiles ``cond_fun`` and ``body_fun``, so while it can be combined with :py:func:`jit`, it's usually unnecessary. Args: cond_fun: function of type ``a -> Bool``. body_fun: function of type ``a -> a``. init_val: value of type ``a``, a type that can be a scalar, array, or any pytree (nested Python tuple/list/dict) thereof, representing the initial loop carry value. Returns: The output from the final iteration of body_fun, of type ``a``. .. _Haskell-like type signature: """ if not (callable(body_fun) and callable(cond_fun)): raise TypeError("lax.while_loop: body_fun and cond_fun arguments should be callable.") if config.disable_jit.value: try: val = init_val while cond_fun(val): val = body_fun(val) return val except core.ConcretizationTypeError: # Can't run this while_loop in Python (e.g. because there's a vmap # transformation on it), so we fall back to the primitive version. pass def _create_jaxpr(init_val): init_vals, in_tree = tree_flatten((init_val,)) init_avals = tuple(_map(_abstractify, init_vals)) cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr( cond_fun, in_tree, init_avals, "while_cond") body_jaxpr, body_consts, body_tree = _initial_style_jaxpr( body_fun, in_tree, init_avals, "while_loop") if not treedef_is_leaf(cond_tree) or len(cond_jaxpr.out_avals) != 1: msg = "cond_fun must return a boolean scalar, but got pytree {}." raise TypeError(msg.format(cond_tree)) pred_aval = cond_jaxpr.out_avals[0] if (not isinstance(pred_aval, ShapedArray) or pred_aval.strip_weak_type().strip_named_shape() != ShapedArray((), np.bool_)): msg = "cond_fun must return a boolean scalar, but got output type(s) {}." raise TypeError(msg.format(cond_jaxpr.out_avals)) return init_vals, init_avals, body_jaxpr, in_tree, cond_jaxpr, cond_consts, body_consts, body_tree # The body input and output avals must match exactly. However, we want to account for # the case when init contains weakly-typed values (e.g. Python scalars), with avals that # may not match the output despite being compatible by virtue of their weak type. # To do this, we compute the jaxpr in two passes: first with the raw inputs, and if # necessary, a second time with modified init values. init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(init_val) new_init_vals, changed = _promote_weak_typed_inputs(init_vals, init_avals, body_jaxpr.out_avals) if changed: new_init_val, = tree_unflatten(in_tree, new_init_vals) init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(new_init_val) cond_jaxpr, cond_consts, body_consts, body_tree = rest in_tree_children = in_tree.children() assert len(in_tree_children) == 1 _check_tree_and_avals("body_fun output and input", body_tree, body_jaxpr.out_avals, in_tree_children[0], init_avals) joined_effects = core.join_effects(cond_jaxpr.effects, body_jaxpr.effects) disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(joined_effects) if disallowed_effects: raise NotImplementedError( f'Effects not supported in `while`: {disallowed_effects}') outs = while_p.bind(*cond_consts, *body_consts, *init_vals, cond_nconsts=len(cond_consts), cond_jaxpr=cond_jaxpr, body_nconsts=len(body_consts), body_jaxpr=body_jaxpr) return tree_unflatten(body_tree, outs) def _join_while_effects(body_jaxpr, cond_jaxpr, body_nconsts, cond_nconsts ) -> effects.Effects: joined_effects = set() for eff in cond_jaxpr.effects: if isinstance(eff, effects.JaxprInputEffect): index = eff.input_index if index >= cond_nconsts: index += body_nconsts eff = eff.replace(input_index=index) joined_effects.add(eff) for eff in body_jaxpr.effects: if isinstance(eff, effects.JaxprInputEffect): index = eff.input_index + cond_nconsts eff = eff.replace(input_index=index) joined_effects.add(eff) return joined_effects def _while_loop_abstract_eval(*avals, cond_jaxpr, body_jaxpr, body_nconsts, cond_nconsts): del avals joined_effects = _join_while_effects(body_jaxpr, cond_jaxpr, body_nconsts, cond_nconsts) disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(joined_effects) if disallowed_effects: raise NotImplementedError( f'Effects not supported in `while`: {disallowed_effects}') return _map(raise_to_shaped, body_jaxpr.out_avals), joined_effects def _while_loop_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, args, dims, cond_nconsts, cond_jaxpr, body_nconsts, body_jaxpr): from jax._src.callback import _IOEffect, _OrderedIOEffect if any(eff in branch.effects for eff in [_IOEffect, _OrderedIOEffect] for branch in [body_jaxpr, cond_jaxpr]): raise NotImplementedError( "IO effect not supported in vmap-of-while.") orig_batched = [d is not batching.not_mapped for d in dims] cconst_bat, bconst_bat, init_bat = split_list(orig_batched, [cond_nconsts, body_nconsts]) cconsts, bconsts, init = split_list(args, [cond_nconsts, body_nconsts]) cconst_dims, bconst_dims, init_dims = split_list(dims, [cond_nconsts, body_nconsts]) carry_bat = init_bat # Fixpoint computation of which carry are batched: either # batched from init, or the carry out is batched. Each iteration promotes # at least one carry to batched. We need at most len(carry) iterations to # reach a fixpoint. for _ in range(1 + len(carry_bat)): _, carry_bat_out = batching.batch_jaxpr( body_jaxpr, axis_size, bconst_bat + carry_bat, instantiate=carry_bat, axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type) if carry_bat == carry_bat_out: break carry_bat = safe_map(operator.or_, carry_bat, carry_bat_out) else: assert False, "Fixpoint not reached" # Knowing how the carry is batched now, we can determine if the predicate is # batched. _, (pred_bat,) = batching.batch_jaxpr( cond_jaxpr, axis_size, cconst_bat + carry_bat, instantiate=False, axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type) if pred_bat: # If the predicate is batched, we have to batch *all* of the carry # regardless of if the body needs it. carry_bat = [True] * len(carry_bat) carry_dims = [0] * len(carry_bat) body_jaxpr_batched, _ = batching.batch_jaxpr_axes( body_jaxpr, axis_size, bconst_dims + carry_dims, carry_dims, axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type) cond_jaxpr_batched, _ = batching.batch_jaxpr_axes( cond_jaxpr, axis_size, cconst_dims + carry_dims, [0], axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type) else: # If the predicate is not batched, we can look at the `cond_jaxpr`'s out # shape to determine the rank of the predicate. From this rank we pick the # dims of the carry to be batched to ensure that the predicate shape is a # prefix of the carry in and out shapes. We can then batch the `body_jaxpr` # according to these new batch dims. cond_rank = len(cond_jaxpr.out_avals[0].shape) carry_dims = [cond_rank if b else None for b in carry_bat] body_jaxpr_batched, _ = batching.batch_jaxpr_axes( body_jaxpr, axis_size, bconst_dims + carry_dims, carry_dims, axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type) # Now we need to rebatch the `cond_jaxpr` according to the new dims of the # carry. cond_jaxpr_batched, _ = batching.batch_jaxpr_axes( cond_jaxpr, axis_size, cconst_dims + carry_dims, (None,), axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type) # To prepare the `init` to the `while_p`, we broadcast values if they are # unbatched and need to have an out axis. If their current batch axis does not # match the one it needs to be for the translation rule to work, we move it # into place. new_init = [] for x, old_axis, new_axis in zip(init, init_dims, carry_dims): if old_axis is batching.not_mapped and new_axis is not batching.not_mapped: new_init.append(batching.broadcast(x, axis_size, new_axis)) elif old_axis is batching.not_mapped and new_axis is batching.not_mapped: new_init.append(x) else: assert new_axis is not batching.not_mapped new_init.append(batching.moveaxis(x, old_axis, new_axis)) outs = while_p.bind(*(cconsts + bconsts + new_init), cond_nconsts=cond_nconsts, cond_jaxpr=cond_jaxpr_batched, body_nconsts=body_nconsts, body_jaxpr=body_jaxpr_batched) return outs, carry_dims def _while_loop_jvp(primals, tangents, cond_nconsts, cond_jaxpr, body_nconsts, body_jaxpr): nonzeros = [type(t) is not ad_util.Zero for t in tangents] cconst_nz, bconst_nz, init_nz = split_list(nonzeros, [cond_nconsts, body_nconsts]) carry_nz = init_nz for _ in range(1 + len(carry_nz)): body_nonzeros = bconst_nz + carry_nz body_jvp, nonzeros_out = ad.jvp_jaxpr( body_jaxpr, body_nonzeros, instantiate=carry_nz) if nonzeros_out == carry_nz: break carry_nz = _map(operator.or_, carry_nz, nonzeros_out) else: assert False, "Fixpoint not reached" nonzeros = cconst_nz + body_nonzeros tangents = [ad.instantiate_zeros(t) if nz else t for t, nz in zip(tangents, nonzeros)] cconst, bconst, init = split_list(primals, [cond_nconsts, body_nconsts]) _, bconst_dot, init_dot = split_list(tangents, [cond_nconsts, body_nconsts]) bconst_dot = _prune_zeros(bconst_dot) init_dot = _prune_zeros(init_dot) num_carry = len(primals) - cond_nconsts - body_nconsts body_jvp_rearranged = ad.rearrange_binders( body_jvp, [body_nconsts, num_carry], [len(bconst_dot), len(init_dot)], [num_carry], [len(init_dot)]) newvar = core.gensym([cond_jaxpr.jaxpr]) invars_aug = ( cond_jaxpr.jaxpr.invars + [newvar(core.get_aval(x)) for x in init_dot]) cond_jaxpr_augmented = core.Jaxpr(cond_jaxpr.jaxpr.constvars, invars_aug, cond_jaxpr.jaxpr.outvars, cond_jaxpr.jaxpr.eqns, cond_jaxpr.jaxpr.effects) cond_jaxpr_augmented = core.ClosedJaxpr(cond_jaxpr_augmented, cond_jaxpr.consts) out = while_p.bind( *(cconst + bconst + bconst_dot + init + init_dot), cond_nconsts=cond_nconsts, cond_jaxpr=cond_jaxpr_augmented, body_nconsts=len(bconst) + len(bconst_dot), body_jaxpr=body_jvp_rearranged) out_carry, out_carry_dot = split_list(out, [num_carry]) out_tangents_iter = iter(out_carry_dot) out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_value(p) for p, nz in zip(out_carry, nonzeros_out)] return out_carry, out_tangents def _while_partial_eval(trace: pe.JaxprTrace, *tracers: pe.Tracer, cond_nconsts: int, cond_jaxpr: pe.ClosedJaxpr, body_nconsts: int, body_jaxpr: pe.ClosedJaxpr) -> Sequence[pe.Tracer]: # As long as some carry (and hence output) are known and the output of # `cond_jaxpr` is known, we use a portion of the loop body to compute the # known outputs of the `while_loop`. For the unknown outputs we generate a # jaxpr to run the whole while, including recomputing the known parts, # basically like building in checkpointing/rematieralization. This means that # we don't actually save any computation by partial evaluation if there are # unknown outputs. # # What this achieves is twofold: jax.linearize works, and we can give a proper # error for reverse differentiation of `while`. unknowns = [not t.pval.is_known() for t in tracers] params = dict(cond_nconsts=cond_nconsts, cond_jaxpr=cond_jaxpr, body_nconsts=body_nconsts, body_jaxpr=body_jaxpr) cond_consts_uk, body_consts_uk, carry_init_uk = \ split_list(unknowns, [cond_nconsts, body_nconsts]) # Fixpoint computation of unknown carry. Each iteration promotes at least one # carry to unknown. We need one last iteration to prepare the jaxpr. carry_uk = carry_init_uk for _ in range(1 + len(carry_uk)): body_jaxpr_known, _, carry_out_uk, body_res_avals = pe.partial_eval_jaxpr_nounits( # type: ignore body_jaxpr, body_consts_uk + carry_uk, instantiate=carry_uk) if carry_out_uk == carry_uk: break else: carry_uk = _map(operator.or_, carry_uk, carry_out_uk) else: assert False, "Fixpoint not reached" cond_jaxpr_known, _, cond_uk, _ = pe.partial_eval_jaxpr_nounits( # type: ignore cond_jaxpr, cond_consts_uk + carry_uk, instantiate=False) if cond_uk[0] or all(not uk for uk in unknowns) or all(unknowns): # If conditional is unknown, or all inputs are known, or all are unknown, # just do the default processing. return trace.default_process_primitive(while_p, tracers, params) # Run the known part of the while. in_consts = [t.pval.get_known() for uk, t in zip(cond_consts_uk + body_consts_uk + carry_uk, tracers) if not uk] cond_nconsts_known = len(cond_consts_uk) - sum(cond_consts_uk) body_nconsts_known = len(body_consts_uk) - sum(body_consts_uk) num_known_outs = len(carry_uk) - sum(carry_uk) # TODO(mattjj): use pe.dce_jaxpr to drop res computations and not just outputs body_jaxpr_known = body_jaxpr_known.replace( jaxpr=body_jaxpr_known.jaxpr.replace( outvars=body_jaxpr_known.jaxpr.outvars[:num_known_outs])) out_known = while_p.bind( *in_consts, cond_nconsts=cond_nconsts_known, cond_jaxpr=cond_jaxpr_known, body_nconsts=body_nconsts_known, body_jaxpr=body_jaxpr_known) del body_jaxpr_known # Run the whole while_loop to get all the outputs, then merge with known ones out_tracers_ = trace.default_process_primitive(while_p, tracers, params) out_tracers = [t for t, uk in zip(out_tracers_, carry_uk) if uk] return util.merge_lists(carry_uk, out_known, out_tracers) # TODO(mattjj): de-duplicate code with _while_partial_eval def _while_partial_eval_custom(saveable, unks_in, inst_in, eqn): del saveable # We can't save any residuals anyway (w/o dynamic shapes)! cond_jaxpr = eqn.params['cond_jaxpr'] cond_nconsts = eqn.params['cond_nconsts'] body_jaxpr = eqn.params['body_jaxpr'] body_nconsts = eqn.params['body_nconsts'] cond_consts_uk, body_consts_uk, carry_init_uk = \ split_list(unks_in, [cond_nconsts, body_nconsts]) # Fixpoint to compute known part of the body (trivial on 'inst_in', since we # make all inputs available as DCE can subsequently prune any unused ones) carry_uk = carry_init_uk for _ in range(1 + len(carry_uk)): body_unks_in = body_consts_uk + carry_uk jaxpr_known_, _, carry_uk_out, _, num_res = \ pe.partial_eval_jaxpr_custom( body_jaxpr.jaxpr, in_unknowns=body_unks_in, in_inst=True, ensure_out_unknowns=carry_uk, ensure_out_inst=True, saveable=ad_checkpoint.nothing_saveable) if carry_uk_out == carry_uk: break else: carry_uk = _map(operator.or_, carry_uk, carry_uk_out) else: assert False, "Fixpoint not reached" assert not num_res body_jaxpr_known = core.ClosedJaxpr(jaxpr_known_, body_jaxpr.consts) del jaxpr_known_, carry_uk_out, num_res # Instantiate all inputs (b/c jaxpr_staged will take all inputs). new_inst = [x for x, inst in zip(eqn.invars, inst_in) if type(x) is core.Var and not inst] # Compute the known part of cond_fun (basically pruning inputs on known side). cond_unks_in = cond_consts_uk + carry_uk cond_jaxpr_known_, _, [cond_uk], _, _ = \ pe.partial_eval_jaxpr_custom( cond_jaxpr.jaxpr, cond_unks_in, in_inst=True, ensure_out_unknowns=False, ensure_out_inst=True, saveable=ad_checkpoint.nothing_saveable) # NOTE(mattjj): I think it should be impossible for the condition to be # unknown, but asserting that caused a test failure in diffrax. So # we handle it: if it is unknown, stage out the whole cond function. if cond_uk: return None, eqn, [True] * len(carry_uk), [True] * len(carry_uk), new_inst cond_jaxpr_known = core.ClosedJaxpr(cond_jaxpr_known_, cond_jaxpr.consts) del cond_uk # Build the known eqn. ins_known, _ = partition_list(unks_in, eqn.invars) out_binders_known, _ = partition_list(carry_uk, eqn.outvars) params_known = dict(cond_jaxpr=cond_jaxpr_known, body_jaxpr=body_jaxpr_known, cond_nconsts=len(cond_consts_uk) - sum(cond_consts_uk), body_nconsts=len(body_consts_uk) - sum(body_consts_uk)) effects_known = core.join_effects(cond_jaxpr_known.effects, body_jaxpr_known.effects) eqn_known = pe.new_jaxpr_eqn(ins_known, out_binders_known, while_p, params_known, effects_known, eqn.source_info) # Staged eqn is same as input eqn. eqn_staged = eqn unks_out = carry_uk inst_out = [True] * len(unks_out) return eqn_known, eqn_staged, unks_out, inst_out, new_inst def _while_transpose_error(*_, **kwargs): raise ValueError("Reverse-mode differentiation does not work for " "lax.while_loop or lax.fori_loop with dynamic start/stop values. " "Try using lax.scan, or using fori_loop with static start/stop.") # For a while loop with ordered effects in the cond, we need a special # lowering. Fundamentally, we'd like to rewrite a while loop that looks like # this: # ``` # while cond(x): # x = body(x) # ``` # into something that looks like this: # ``` # while True: # token, pred = cond(token, x) # if not pred: # break # token, x = body(token, x) # ``` # Unfortunately, with a WhileOp we can't (1) return multiple values # from a `cond` and (2) can't break a while loop. We thus adopt the # following rewrite strategy: # ``` # def new_cond(pred, token, x): # return pred # token, pred = cond(token, x) # while new_cond(pred, token, x): # token, x = body(token, x) # token, pred = cond(token, x) # ``` def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts, body_nconsts): pred_aval = cond_jaxpr.out_avals[0] batched = bool(pred_aval.shape) cond_ordered_effects = effects.ordered_effects.filter_in(cond_jaxpr.effects) if cond_ordered_effects: def cond(args): # Pred can be batched pred = core.eval_jaxpr(cond_jaxpr.jaxpr, cond_jaxpr.consts, *args)[0] if batched: pred = lax._reduce_or(pred, tuple(range(len(pred_aval.shape)))) return pred def body(args): return tuple(core.eval_jaxpr(body_jaxpr.jaxpr, body_jaxpr.consts, *args)) def new_cond(pred_args): pred, _ = pred_args return pred def new_body(pred_args): _, args = pred_args args = body(args) pred = cond(args) return pred, args def fun(*args): pred = cond(args) _, out = while_loop(new_cond, new_body, (pred, args)) return out return mlir.lower_fun(fun)(ctx, *args) loop_carry_types = _map(mlir.aval_to_ir_types, ctx.avals_in) body_effects = effects.ordered_effects.filter_in(body_jaxpr.effects) num_tokens = len(body_effects) tokens = [ctx.tokens_in.get(eff) for eff in body_effects] token_types = [mlir.token_type() for _ in tokens] loop_carry_types = [*token_types, *loop_carry_types] flat_loop_carry_types = util.flatten(loop_carry_types) args = [*tokens, *args] flat_args = mlir.flatten_lowering_ir_args(args) while_op = hlo.WhileOp(flat_loop_carry_types, flat_args) # Loop condition cond_block = while_op.regions[0].blocks.append(*flat_loop_carry_types) name_stack = ctx.module_context.name_stack.extend('while') with ir.InsertionPoint(cond_block): flat_cond_args = [ cond_block.arguments[i] for i in range(len(flat_loop_carry_types)) ] cond_args = util.unflatten(flat_cond_args, _map(len, loop_carry_types)) # Remove tokens from cond args cond_args = cond_args[num_tokens:] x, _, z = util.split_list(cond_args, [cond_nconsts, body_nconsts]) cond_ctx = ctx.module_context.replace(name_stack=name_stack.extend('cond')) cond_consts = [ mlir.ir_constants(xla.canonicalize_dtype(x)) for x in cond_jaxpr.consts ] ((pred,),), _ = mlir.jaxpr_subcomp( cond_ctx, cond_jaxpr.jaxpr, mlir.TokenSet(), cond_consts, *(x + z), dim_var_values=ctx.dim_var_values, ) if batched: pred_ctx = mlir.LoweringRuleContext( module_context=ctx.module_context, primitive=None, avals_in=[pred_aval], avals_out=[pred_aval.update(shape=())], tokens_in=mlir.TokenSet(), tokens_out=None) pred, = lax._unary_reduce_lower( hlo.OrOp, lambda dtype: np.array(False, dtype), pred_ctx, pred, axes=tuple(range(len(pred_aval.shape)))) hlo.return_([pred]) # Loop body body_block = while_op.regions[1].blocks.append(*flat_loop_carry_types) with ir.InsertionPoint(body_block): flat_body_args = [ body_block.arguments[i] for i in range(len(flat_loop_carry_types)) ] body_args = util.unflatten(flat_body_args, _map(len, loop_carry_types)) # Tokens are at the front of the args list to the while loop token_args, body_args = util.split_list(body_args, [num_tokens]) tokens_in = mlir.TokenSet(zip(body_effects, token_args)) x, y, z = util.split_list(body_args, [cond_nconsts, body_nconsts]) body_ctx = ctx.module_context.replace(name_stack=name_stack.extend('body')) body_consts = [mlir.ir_constants(xla.canonicalize_dtype(x)) for x in body_jaxpr.consts] new_z, tokens_out = mlir.jaxpr_subcomp(body_ctx, body_jaxpr.jaxpr, tokens_in, body_consts, *(y + z), dim_var_values=ctx.dim_var_values) out_tokens = [tokens_out.get(eff) for eff in body_effects] if batched: body_pred_ctx = ctx.module_context.replace( name_stack=name_stack.extend('body_pred')) cond_consts = [mlir.ir_constants(xla.canonicalize_dtype(x)) for x in cond_jaxpr.consts] ((body_pred,),), _ = mlir.jaxpr_subcomp( body_pred_ctx, cond_jaxpr.jaxpr, mlir.TokenSet(), cond_consts, *(x + z), dim_var_values=ctx.dim_var_values) new_z = _map( partial(_pred_bcast_select_hlo, ctx, pred_aval, body_pred), new_z, z, body_jaxpr.out_avals) hlo.return_([*util.flatten(out_tokens), *util.flatten(x), *util.flatten(y), *util.flatten(new_z)]) outputs = util.unflatten(while_op.results, _map(len, loop_carry_types)) tokens, _, _, z = util.split_list(outputs, [num_tokens, cond_nconsts, body_nconsts]) if tokens: ctx.set_tokens_out(mlir.TokenSet(zip(body_effects, tokens))) return z def _while_typecheck(_, *in_atoms, cond_jaxpr, body_jaxpr, cond_nconsts, body_nconsts): # TODO(frostig,mattjj): check cond_jaxpr, body_jaxpr types joined_effects = _join_while_effects(body_jaxpr, cond_jaxpr, body_nconsts, cond_nconsts) disallowed_effects = effects.control_flow_allowed_effects.filter_not_in(joined_effects) if disallowed_effects: raise NotImplementedError( f'Effects not supported in `while`: {disallowed_effects}') return body_jaxpr.out_avals, joined_effects def _while_discharge_rule(in_avals, out_avals, *args, cond_jaxpr, body_jaxpr, cond_nconsts, body_nconsts): # TODO(sharadmv): enable supporting state effects in the cond if any(isinstance(eff, state.RefEffect) for eff in cond_jaxpr.effects): raise NotImplementedError cond_consts, body_consts, carry = split_list(args, [cond_nconsts, body_nconsts]) cond_consts_avals, body_consts_avals, carry_avals = split_list(in_avals, [cond_nconsts, body_nconsts]) # There shouldn't be any `Ref`s in the `cond` (because of our check above). assert not any(isinstance(aval, state.AbstractRef) for aval in cond_consts_avals) is_ref = [isinstance(aval, state.AbstractRef) for aval in body_consts_avals] remaining_body_consts, refs = partition_list(is_ref, body_consts) remaining_body_const_avals, ref_avals = partition_list(is_ref, body_consts_avals) num_refs = sum(is_ref) num_remaining_consts = body_nconsts - num_refs num_carry = len(in_avals) - body_nconsts - cond_nconsts body_jaxpr, body_jaxpr_consts = body_jaxpr.jaxpr, body_jaxpr.consts cond_jaxpr, cond_jaxpr_consts = cond_jaxpr.jaxpr, cond_jaxpr.consts if body_jaxpr_consts: raise NotImplementedError("Body jaxpr has consts. If you see this error, " "please open an issue at " "") # body_jaxpr has the signature (*body_consts, *carry) -> carry. # Some of these body_consts are actually `Ref`s so when we discharge # them, they also turn into outputs, effectively turning those consts into # carries. However this doesn't fit the expected signature for the body_jaxpr. # Therefore we need to rewrite the jaxpr to shuffle around the `Ref`s so that # they are part of the carry. discharged_body_jaxpr, discharged_consts = state_discharge.discharge_state( body_jaxpr, ()) if discharged_consts: raise NotImplementedError def new_body(*consts_refs_carry): consts, refs, carry = split_list( consts_refs_carry, [num_remaining_consts, num_refs]) consts_and_refs = merge_lists(is_ref, consts, refs) carry_refs = core.eval_jaxpr(discharged_body_jaxpr, (), *consts_and_refs, *carry) carry, refs_out = split_list(carry_refs, [num_carry]) return [*refs_out, *carry] new_body_jaxpr, _, new_body_consts = pe.trace_to_jaxpr_dynamic( lu.wrap_init(new_body), [*remaining_body_const_avals, *[a.inner_aval for a in ref_avals], *carry_avals]) if new_body_consts: raise NotImplementedError # Since some `Ref`s that were previously consts are now carries, we need to # deal with them (i.e. ignore them) in the `cond`, so we need to rewrite the # cond_jaxpr as well. def new_cond(*consts_refs_carry): consts, refs, carry = split_list( consts_refs_carry, [cond_nconsts, num_refs]) del refs # We don't use them here! return core.eval_jaxpr(cond_jaxpr, cond_jaxpr_consts, *consts, *carry) new_cond_jaxpr, _, new_cond_consts = pe.trace_to_jaxpr_dynamic( lu.wrap_init(new_cond), [*cond_consts_avals, *[a.inner_aval for a in ref_avals], *carry_avals]) if new_cond_consts: raise NotImplementedError out = while_p.bind(*cond_consts, *remaining_body_consts, *refs, *carry, body_jaxpr=core.ClosedJaxpr(new_body_jaxpr, ()), cond_jaxpr=core.ClosedJaxpr(new_cond_jaxpr, ()), body_nconsts=num_remaining_consts, cond_nconsts=cond_nconsts) refs_out, carry_out = split_list(out, [num_refs]) updated_body_consts = merge_lists(is_ref, [None] * num_remaining_consts, refs_out) invals_out = [ *[None] * cond_nconsts, *updated_body_consts, *[None] * num_carry] return invals_out, carry_out while_p = core.AxisPrimitive('while') while_p.multiple_results = True while_p.def_impl(partial(dispatch.apply_primitive, while_p)) while_p.def_effectful_abstract_eval(_while_loop_abstract_eval) ad.primitive_jvps[while_p] = _while_loop_jvp pe.custom_partial_eval_rules[while_p] = _while_partial_eval xla.register_initial_style_primitive(while_p) ad.primitive_transposes[while_p] = _while_transpose_error batching.axis_primitive_batchers[while_p] = partial(_while_loop_batching_rule, None) batching.spmd_axis_primitive_batchers[while_p] = _while_loop_batching_rule pe.partial_eval_jaxpr_custom_rules[while_p] = _while_partial_eval_custom mlir.register_lowering(while_p, _while_lowering) core.custom_typechecks[while_p] = _while_typecheck state_discharge.register_discharge_rule(while_p)(_while_discharge_rule) def _pred_bcast_select_hlo(ctx, pred_aval: core.ShapedArray, pred: ir.Value, xs: Sequence[ir.Value], ys: Sequence[ir.Value], x_y_aval: core.AbstractValue) -> Sequence[ir.Value]: if x_y_aval is core.abstract_token: x, = xs y, = ys return [hlo.AfterAllOp([x, y]).result] else: assert isinstance(x_y_aval, core.ShapedArray), x_y_aval x, = xs y, = ys assert x.type == y.type, (x.type, y.type) assert (pred_aval.shape == x_y_aval.shape[:len(pred_aval.shape)]), ( pred_aval.shape, x_y_aval) x_y_aval = core.physical_aval(x_y_aval) bcast_pred = mlir.broadcast_in_dim( ctx, pred, core.DShapedArray(x_y_aval.shape, np.dtype(np.bool_)), broadcast_dimensions=list(range(len(pred_aval.shape)))) return hlo.SelectOp(bcast_pred, x, y).results ### fori_loop def _fori_cond_fun(loop_carry): i, upper, _ = loop_carry return, upper) @weakref_lru_cache def _fori_body_fun(body_fun): body_fun = weakref.ref(body_fun) def while_body_fun(loop_carry): i, upper, x = loop_carry return lax.add(i, lax._const(i, 1)), upper, body_fun()(i, x) return while_body_fun @weakref_lru_cache def _fori_scan_body_fun(body_fun): body_fun = weakref.ref(body_fun) def scanned_fun(loop_carry, _): i, x = loop_carry return (i + 1, body_fun()(i, x)), None return scanned_fun @api_boundary def fori_loop(lower, upper, body_fun, init_val, *, unroll: int | bool | None = None): """Loop from ``lower`` to ``upper`` by reduction to :func:`jax.lax.while_loop`. The `Haskell-like type signature`_ in brief is .. code-block:: haskell fori_loop :: Int -> Int -> ((Int, a) -> a) -> a -> a The semantics of ``fori_loop`` are given by this Python implementation:: def fori_loop(lower, upper, body_fun, init_val): val = init_val for i in range(lower, upper): val = body_fun(i, val) return val As the Python version suggests, setting ``upper <= lower`` will produce no iterations. Negative or custom increments are not supported. Unlike that Python version, ``fori_loop`` is implemented in terms of either a call to :func:`jax.lax.while_loop` or a call to :func:`jax.lax.scan`. If the trip count is static (meaning known at tracing time, perhaps because ``lower`` and ``upper`` are Python integer literals) then the ``fori_loop`` is implemented in terms of :func:`~scan` and reverse-mode autodiff is supported; otherwise, a ``while_loop`` is used and reverse-mode autodiff is not supported. See those functions' docstrings for more information. Also unlike the Python analogue, the loop-carried value ``val`` must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type ``a`` in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves). .. note:: :py:func:`fori_loop` compiles ``body_fun``, so while it can be combined with :py:func:`jit`, it's usually unnecessary. Args: lower: an integer representing the loop index lower bound (inclusive) upper: an integer representing the loop index upper bound (exclusive) body_fun: function of type ``(int, a) -> a``. init_val: initial loop carry value of type ``a``. unroll: An optional integer or boolean that determines how much to unroll the loop. If an integer is provided, it determines how many unrolled loop iterations to run within a single rolled iteration of the loop. If a boolean is provided, it will determine if the loop is competely unrolled (i.e. `unroll=True`) or left completely unrolled (i.e. `unroll=False`). This argument is only applicable if the loop bounds are statically known. Returns: Loop value from the final iteration, of type ``a``. .. _Haskell-like type signature: """ if not callable(body_fun): raise TypeError("lax.fori_loop: body_fun argument should be callable.") # TODO(phawkins): perhaps do more type checking here, better error messages. lower_dtype = dtypes.canonicalize_dtype(lax.dtype(lower)) upper_dtype = dtypes.canonicalize_dtype(lax.dtype(upper)) if lower_dtype == upper_dtype: dtype = lower_dtype else: # As a special case: allow promotion of weak integers (e.g., Python scalars) # This improves the ergonomics if one but not both of the loop bounds is a # scalar. dtype = None if (np.issubdtype(lower_dtype, np.signedinteger) and np.issubdtype(upper_dtype, np.signedinteger)): lower_weak = dtypes.is_weakly_typed(lower) upper_weak = dtypes.is_weakly_typed(upper) if lower_weak and not upper_weak: dtype = upper_dtype elif not lower_weak and upper_weak: dtype = lower_dtype if dtype is None: raise TypeError("lower and upper arguments to fori_loop must have equal " f"types, got {} and {}") # If we can specialize on the trip count, call scan instead of a while_loop # to enable efficient reverse-mode differentiation. if (isinstance(core.get_aval(lower), ConcreteArray) and isinstance(core.get_aval(upper), ConcreteArray)): try: lower_ = int(lower) upper_ = int(upper) except TypeError: use_scan = False else: use_scan = True else: use_scan = False if use_scan: if unroll is None: unroll = False if config.disable_jit.value and upper_ == lower_: # non-jit implementation of scan does not support length=0 return init_val (_, result), _ = scan(_fori_scan_body_fun(body_fun), (lower_, init_val), None, length=upper_ - lower_, unroll=unroll) return result if unroll is not None: raise ValueError("Can only use `unroll` in `fori_loop` if the loop bounds " "are statically known.") if lower_dtype != dtype: lower = lax.convert_element_type(lower, dtype) # type: ignore if upper_dtype != dtype: upper = lax.convert_element_type(upper, dtype) # type: ignore _, _, result = while_loop(_fori_cond_fun, _fori_body_fun(body_fun), (lower, upper, init_val)) return result ### map and miscellaneous rules @api_boundary def map(f, xs): """Map a function over leading array axes. Like Python's builtin map, except inputs and outputs are in the form of stacked arrays. Consider using the :func:`~jax.vmap` transform instead, unless you need to apply a function element by element for reduced memory usage or heterogeneous computation with other control flow primitives. When ``xs`` is an array type, the semantics of :func:`~map` are given by this Python implementation:: def map(f, xs): return np.stack([f(x) for x in xs]) Like :func:`~scan`, :func:`~map` is implemented in terms of JAX primitives so many of the same advantages over a Python loop apply: ``xs`` may be an arbitrary nested pytree type, and the mapped computation is compiled only once. Args: f: a Python function to apply element-wise over the first axis or axes of ``xs``. xs: values over which to map along the leading axis. Returns: Mapped values. """ g = lambda _, x: ((), f(x)) _, ys = scan(g, (), xs) return ys def _rng_bit_generator_batching_rule(batched_args, batch_dims, *, shape, dtype, algorithm): """Calls RBG in a loop and stacks the results.""" key, = batched_args bd, = batch_dims if bd is batching.not_mapped: return lax.rng_bit_generator_p.bind(key, shape=shape, dtype=dtype, algorithm=algorithm), (None, None) key = batching.moveaxis(key, bd, 0) map_body = lambda k: lax.rng_bit_generator_p.bind(k, shape=shape, dtype=dtype, algorithm=algorithm) stacked_keys, stacked_bits = map(map_body, key) return (stacked_keys, stacked_bits), (0, 0) batching.primitive_batchers[lax.rng_bit_generator_p] = _rng_bit_generator_batching_rule # type: ignore ### associative_scan @api_boundary def associative_scan(fn: Callable, elems, reverse: bool = False, axis: int = 0): """Performs a scan with an associative binary operation, in parallel. For an introduction to associative scans, see [BLE1990]_. Args: fn: A Python callable implementing an associative binary operation with signature ``r = fn(a, b)``. Function `fn` must be associative, i.e., it must satisfy the equation ``fn(a, fn(b, c)) == fn(fn(a, b), c)``. The inputs and result are (possibly nested Python tree structures of) array(s) matching ``elems``. Each array has a dimension in place of the ``axis`` dimension. `fn` should be applied elementwise over the ``axis`` dimension (for example, by using :func:`jax.vmap` over the elementwise function.) The result ``r`` has the same shape (and structure) as the two inputs ``a`` and ``b``. elems: A (possibly nested Python tree structure of) array(s), each with an ``axis`` dimension of size ``num_elems``. reverse: A boolean stating if the scan should be reversed with respect to the ``axis`` dimension. axis: an integer identifying the axis over which the scan should occur. Returns: A (possibly nested Python tree structure of) array(s) of the same shape and structure as ``elems``, in which the ``k``'th element of ``axis`` is the result of recursively applying ``fn`` to combine the first ``k`` elements of ``elems`` along ``axis``. For example, given ``elems = [a, b, c, ...]``, the result would be ``[a, fn(a, b), fn(fn(a, b), c), ...]``. Example 1: partial sums of an array of numbers: >>> lax.associative_scan(jnp.add, jnp.arange(0, 4)) Array([0, 1, 3, 6], dtype=int32) Example 2: partial products of an array of matrices >>> mats = jax.random.uniform(jax.random.PRNGKey(0), (4, 2, 2)) >>> partial_prods = lax.associative_scan(jnp.matmul, mats) >>> partial_prods.shape (4, 2, 2) Example 3: reversed partial sums of an array of numbers >>> lax.associative_scan(jnp.add, jnp.arange(0, 4), reverse=True) Array([6, 6, 5, 3], dtype=int32) .. [BLE1990] Blelloch, Guy E. 1990. "Prefix Sums and Their Applications.", Technical Report CMU-CS-90-190, School of Computer Science, Carnegie Mellon University. """ if not callable(fn): raise TypeError("lax.associative_scan: fn argument should be callable.") elems_flat, tree = tree_flatten(elems) if reverse: elems_flat = [lax.rev(elem, [axis]) for elem in elems_flat] def combine(a_flat, b_flat): # Lower `fn` to operate on flattened sequences of elems. a = tree_unflatten(tree, a_flat) b = tree_unflatten(tree, b_flat) c = fn(a, b) c_flat, _ = tree_flatten(c) return c_flat # Check that all inputs have a consistent leading dimension `num_elems`. axis = util.canonicalize_axis(axis, elems_flat[0].ndim) if not core.is_constant_dim(elems_flat[0].shape[axis]): raise NotImplementedError("associative scan over axis " f"of non-constant size: {elems_flat[0].shape[axis]}. You may be " "able to avoid this on TPU. See b/274176030.") num_elems = int(elems_flat[0].shape[axis]) if not all(int(elem.shape[axis]) == num_elems for elem in elems_flat[1:]): raise ValueError('Array inputs to associative_scan must have the same ' 'first dimension. (saw: {})' .format([elem.shape for elem in elems_flat])) # Summary of algorithm: # # Consider elements of `_scan(elems)` at odd indices. That's the same as first # summing successive pairs of elements of `elems` and performing a scan on # that half sized tensor. We perform the latter scan by recursion. # # Now consider the even elements of `_scan(elems)`. These can be computed # from the odd elements of `_scan(elems)` by adding each odd element of # `_scan(elems)` to the matching even element in the original `elems`. # # We return the odd and even elements interleaved. # # For the base case of the recursion we return the first element # of `elems` followed by the sum of the first two elements computed as # a (small two-down-to-one) reduction step. def _scan(elems): """Perform scan on `elems`.""" num_elems = elems[0].shape[axis] if num_elems < 2: return elems # Combine adjacent pairs of elements. reduced_elems = combine( [slicing.slice_in_dim(elem, 0, -1, stride=2, axis=axis) for elem in elems], [slicing.slice_in_dim(elem, 1, None, stride=2, axis=axis) for elem in elems]) # Recursively compute scan for partially reduced tensors. odd_elems = _scan(reduced_elems) if num_elems % 2 == 0: even_elems = combine( [slicing.slice_in_dim(e, 0, -1, axis=axis) for e in odd_elems], [slicing.slice_in_dim(e, 2, None, stride=2, axis=axis) for e in elems]) else: even_elems = combine( odd_elems, [slicing.slice_in_dim(e, 2, None, stride=2, axis=axis) for e in elems]) # The first element of a scan is the same as the first element # of the original `elems`. even_elems = [ lax.concatenate([slicing.slice_in_dim(elem, 0, 1, axis=axis), result], dimension=axis) for (elem, result) in zip(elems, even_elems)] return list(_map(partial(_interleave, axis=axis), even_elems, odd_elems)) scans = _scan(elems_flat) if reverse: scans = [lax.rev(scanned, [axis]) for scanned in scans] return tree_unflatten(tree, scans) def _interleave(a, b, axis): """Given two Tensors of static shape, interleave them along the first axis.""" assert a.shape[axis] == b.shape[axis] or a.shape[axis] == b.shape[axis] + 1 a_pad = [(0, 0, 0)] * a.ndim b_pad = [(0, 0, 0)] * b.ndim a_pad[axis] = (0, 1 if a.shape[axis] == b.shape[axis] else 0, 1) b_pad[axis] = (1, 0 if a.shape[axis] == b.shape[axis] else 1, 1) op = lax.bitwise_or if a.dtype == np.bool_ else lax.add return op(lax.pad(a, lax._const(a, 0), a_pad), lax.pad(b, lax._const(b, 0), b_pad)) ### Cumulative reductions. def cumsum(operand: Array, axis: int = 0, reverse: bool = False) -> Array: """Computes a cumulative sum along `axis`.""" return cumsum_p.bind(operand, axis=int(axis), reverse=bool(reverse)) def cumprod(operand: Array, axis: int = 0, reverse: bool = False) -> Array: """Computes a cumulative product along `axis`.""" return cumprod_p.bind(operand, axis=int(axis), reverse=bool(reverse)) def cummax(operand: Array, axis: int = 0, reverse: bool = False) -> Array: """Computes a cumulative maximum along `axis`.""" return cummax_p.bind(operand, axis=int(axis), reverse=bool(reverse)) def cummin(operand: Array, axis: int = 0, reverse: bool = False) -> Array: """Computes a cumulative minimum along `axis`.""" return cummin_p.bind(operand, axis=int(axis), reverse=bool(reverse)) def cumlogsumexp(operand: Array, axis: int = 0, reverse: bool = False) -> Array: """Computes a cumulative logsumexp along `axis`.""" return cumlogsumexp_p.bind(operand, axis=int(axis), reverse=bool(reverse)) def _cumred_shape_rule(x, *, axis: int, reverse: bool): if axis < 0 or axis >= x.ndim: raise ValueError( f"axis {axis} is out of bounds for array of shape {x.shape}") return x.shape def _cumsum_transpose_rule(t, operand, *, axis: int, reverse: bool): return [cumsum(t, axis=axis, reverse=not reverse)] def cumred_reduce_window_impl(window_reduce: Callable, x, *, axis: int, reverse: bool): n = x.shape[axis] if n == 0: return x padding = [(0, 0)] * x.ndim padding[axis] = (0, n - 1) if reverse else (n - 1, 0) strides = [1] * x.ndim window_dims = [1] * x.ndim window_dims[axis] = n return window_reduce(x, window_dims, strides, padding) def cumred_gpu_impl(window_reduce: Callable, reduce_fn: Callable, x, *, axis: int, reverse: bool): # On GPU, reduce_window is executed in a single fusion and associative_scan # is split into multiple to materialize intermediate calculations. # On small inputs reduce_window is faster being a single fusion, # but on larger ones is slower because of O(n^2) complexity. # This conservative value of the threshold was obtained via benchmarking. if not core.is_constant_dim(x.shape[axis]): raise NotImplementedError( "associative scan reductions not implemented with shape polymorphism " "and native serialization on GPU") if x.shape[axis] > 32: return associative_scan(reduce_fn, x, reverse=reverse, axis=axis) return cumred_reduce_window_impl(window_reduce, x, axis=axis, reverse=reverse) def _cumred_batch_rule(prim, batched_args, batch_dims, *, axis: int, reverse: bool): operand, = batched_args bdim, = batch_dims axis = axis if axis < bdim else axis + 1 return prim.bind(operand, axis=axis, reverse=reverse), bdim def _cumred_dtype_rule(name, operand, *args, **kw): if not dtypes.issubdtype(operand.dtype, np.number): raise TypeError("{} does not accept dtype {}. Accepted dtypes are subtypes " "of number.".format(name, np.dtype(operand.dtype).name)) return dtypes.canonicalize_dtype(operand.dtype) def _cumulative_reduction_primitive(name, reduce_fn, reduce_window_fn): reducer_p = lax.standard_primitive( _cumred_shape_rule, partial(_cumred_dtype_rule, name), name) batching.primitive_batchers[reducer_p] = partial(_cumred_batch_rule, reducer_p) def register_lowering(fn, platform=None): mlir.register_lowering( reducer_p, mlir.cache_lowering(mlir.lower_fun(fn, multiple_results=False)), platform=platform) # Default for platforms not treated specially below. register_lowering(partial(associative_scan, reduce_fn)) # On GPU, we choose between window reduction and associative scan # based on the input size. for platform in ['cuda', 'rocm']: register_lowering( partial(cumred_gpu_impl, reduce_window_fn, reduce_fn), platform) # On TPU, an implementation using reduce_window is handled specially by the # compiler and is efficient. On other backends, it is O(n^2). register_lowering(partial(cumred_reduce_window_impl, reduce_window_fn), 'tpu') return reducer_p cumsum_p = _cumulative_reduction_primitive("cumsum", lax.add, windowed_reductions._reduce_window_sum) ad.deflinear2(cumsum_p, _cumsum_transpose_rule) cumlogsumexp_p = _cumulative_reduction_primitive( "cumlogsumexp", logaddexp, windowed_reductions._reduce_window_logaddexp) cumprod_p = _cumulative_reduction_primitive("cumprod", lax.mul, windowed_reductions._reduce_window_prod) cummax_p = _cumulative_reduction_primitive("cummax", lax.max, windowed_reductions._reduce_window_max) cummin_p = _cumulative_reduction_primitive("cummin", lax.min, windowed_reductions._reduce_window_min) def _cumulative_jvp_rule(primals, tangents, *, axis: int, reverse: bool, combine_fn: Callable): # Irrespective of backend, we always use the parallel prefix scan # implementation when differentiating because reduce_window is not # arbitrarily differentiable. return api.jvp(partial(associative_scan, combine_fn, axis=axis, reverse=reverse), primals, tangents) ad.primitive_jvps[cumlogsumexp_p] = partial(_cumulative_jvp_rule, combine_fn=logaddexp) ad.primitive_jvps[cumprod_p] = partial(_cumulative_jvp_rule, combine_fn=lax.mul) ad.primitive_jvps[cummin_p] = partial(_cumulative_jvp_rule, combine_fn=lax.min) ad.primitive_jvps[cummax_p] = partial(_cumulative_jvp_rule, combine_fn=lax.max)