objax.functional package

objax.functional

Due to the large number of APIs in this section, we organized it into the following sub-sections:

Activation

celu(x[, alpha])

Continuously-differentiable exponential linear unit activation.

elu(x[, alpha])

Exponential linear unit activation function.

leaky_relu(x[, negative_slope])

Leaky rectified linear unit activation function.

log_sigmoid(x)

Log-sigmoid activation function.

log_softmax(x[, axis, where, initial])

Log-Softmax function.

logsumexp(a[, axis, b, keepdims, return_sign])

Compute the log of the sum of exponentials of input elements.

relu(x)

Rectified linear unit activation function.

selu(x)

Scaled exponential linear unit activation.

sigmoid(x)

Sigmoid activation function.

softmax(x[, axis, where, initial])

Softmax function.

softplus(x)

Softplus activation function.

tanh(x)

Elementwise hyperbolic tangent: \(\mathrm{tanh}(x)\).

objax.functional.celu(x, alpha=1.0)[source]

Continuously-differentiable exponential linear unit activation.

Computes the element-wise function:

\[\begin{split}\mathrm{celu}(x) = \begin{cases} x, & x > 0\\ \alpha \left(\exp(\frac{x}{\alpha}) - 1\right), & x \le 0 \end{cases}\end{split}\]

For more information, see Continuously Differentiable Exponential Linear Units.

Parameters
  • x (Any) – input array

  • alpha (Any) – array or scalar (default: 1.0)

Return type

Any

objax.functional.elu(x, alpha=1.0)[source]

Exponential linear unit activation function.

Computes the element-wise function:

\[\begin{split}\mathrm{elu}(x) = \begin{cases} x, & x > 0\\ \alpha \left(\exp(x) - 1\right), & x \le 0 \end{cases}\end{split}\]
Parameters
  • x (Any) – input array

  • alpha (Any) – scalar or array of alpha values (default: 1.0)

Return type

Any

objax.functional.leaky_relu(x, negative_slope=0.01)[source]

Leaky rectified linear unit activation function.

Computes the element-wise function:

\[\begin{split}\mathrm{leaky\_relu}(x) = \begin{cases} x, & x \ge 0\\ \alpha x, & x < 0 \end{cases}\end{split}\]

where \(\alpha\) = negative_slope.

Parameters
  • x (Any) – input array

  • negative_slope (Any) – array or scalar specifying the negative slope (default: 0.01)

Return type

Any
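
For a concrete sense of these piecewise definitions, here is a minimal sketch applying celu, elu, and leaky_relu to a small array; all three act as the identity for positive inputs and damp negative inputs according to the formulas above:

import jax.numpy as jnp
import objax.functional as F

x = jnp.array([-2.0, -0.5, 0.0, 1.5])

print(F.celu(x))                             # x for x > 0; alpha * (exp(x / alpha) - 1) for x <= 0
print(F.elu(x, alpha=1.0))                   # x for x > 0; alpha * (exp(x) - 1) for x <= 0
print(F.leaky_relu(x, negative_slope=0.1))   # [-0.2, -0.05, 0.0, 1.5]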

objax.functional.log_sigmoid(x)[source]

Log-sigmoid activation function.

Computes the element-wise function:

\[\mathrm{log\_sigmoid}(x) = \log(\mathrm{sigmoid}(x)) = -\log(1 + e^{-x})\]
Parameters

x (Any) – input array

Return type

Any

objax.functional.log_softmax(x, axis=-1, where=None, initial=None)[source]

Log-Softmax function.

Computes the logarithm of the softmax function, which rescales elements to the range \([-\infty, 0)\).

\[\mathrm{log\_softmax}(x) = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)} \right)\]
Parameters
  • x (Any) – input array

  • axis (Optional[Union[int, Tuple[int, ...]]]) – the axis or axes along which the log_softmax should be computed. Either an integer or a tuple of integers.

  • where (Optional[Any]) – Elements to include in the log_softmax.

  • initial (Optional[Any]) – The minimum value used to shift the input array. Must be present when where is not None.

Return type

Any

objax.functional.logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False)[source]

Compute the log of the sum of exponentials of input elements.

LAX-backend implementation of logsumexp().

Original docstring below.

Parameters
  • a (array_like) – Input array.

  • axis (None or int or tuple of ints, optional) – Axis or axes over which the sum is taken. By default axis is None, and all elements are summed.

  • keepdims (bool, optional) – If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the original array.

  • b (array_like, optional) – Scaling factor for exp(a); must be of the same shape as a or broadcastable to a. These values may be negative in order to implement subtraction.

  • return_sign (bool, optional) – If this is set to True, the result will be a pair containing sign information; if False, results that are negative will be returned as NaN. Default is False (no sign information).

Returns

  • res (ndarray) – The result, np.log(np.sum(np.exp(a))) calculated in a numerically more stable way. If b is given then np.log(np.sum(b*np.exp(a))) is returned.

  • sgn (ndarray) – If return_sign is True, this will be an array of floating-point numbers matching res, containing +1, 0, or -1 depending on the sign of the result. If False, only one result is returned.
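
A minimal sketch of why logsumexp is preferable to the naive formula: for large inputs, exp overflows in float32, while logsumexp shifts by the maximum internally and stays finite.

import jax.numpy as jnp
import objax.functional as F

a = jnp.array([1000.0, 1000.0])

print(jnp.log(jnp.sum(jnp.exp(a))))  # inf: exp(1000) overflows in float32
print(F.logsumexp(a))                # ~1000.6931 = 1000 + log(2), computed stably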

objax.functional.relu(x)[source]

Rectified linear unit activation function.

Parameters

x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]) – input tensor.

Returns

tensor with the element-wise output relu(x) = max(x, 0).

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]

objax.functional.selu(x)[source]

Scaled exponential linear unit activation.

Computes the element-wise function:

\[\begin{split}\mathrm{selu}(x) = \lambda \begin{cases} x, & x > 0\\ \alpha e^x - \alpha, & x \le 0 \end{cases}\end{split}\]

where \(\lambda = 1.0507009873554804934193349852946\) and \(\alpha = 1.6732632423543772848170429916717\).

For more information, see Self-Normalizing Neural Networks.

Parameters

x (Any) – input array

Return type

Any

objax.functional.sigmoid(x)[source]

Sigmoid activation function.

Computes the element-wise function:

\[\mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}\]
Parameters

x (Any) – input array

Return type

Any

objax.functional.softmax(x, axis=-1, where=None, initial=None)[source]

Softmax function.

Computes the function which rescales elements to the range \([0, 1]\) such that the elements along axis sum to \(1\).

\[\mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}\]
Parameters
  • x (Any) – input array

  • axis (Optional[Union[int, Tuple[int, ...]]]) – the axis or axes along which the softmax should be computed. The softmax output summed across these dimensions should sum to \(1\). Either an integer or a tuple of integers.

  • where (Optional[Any]) – Elements to include in the softmax.

  • initial (Optional[Any]) – The minimum value used to shift the input array. Must be present when where is not None.

Return type

Any
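
For illustration, the sketch below checks the defining property that softmax outputs sum to 1 along axis, and that log_softmax agrees (up to floating-point tolerance) with taking the log of softmax:

import jax.numpy as jnp
import objax.functional as F

x = jnp.array([[1.0, 2.0, 3.0],
               [0.0, 0.0, 0.0]])

p = F.softmax(x, axis=-1)
print(p.sum(axis=-1))                              # [1. 1.]
print(jnp.allclose(F.log_softmax(x), jnp.log(p)))  # True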

objax.functional.softplus(x)[source]

Softplus activation function.

Computes the element-wise function

\[\mathrm{softplus}(x) = \log(1 + e^x)\]
Parameters

x (Any) – input array

Return type

Any

objax.functional.tanh(x)[source]

Elementwise hyperbolic tangent: \(\mathrm{tanh}(x)\).

Parameters

x (Any) –

Return type

Any

Pooling

average_pool_2d(x[, size, strides, padding])

Applies average pooling using a square 2D filter.

batch_to_space2d(x[, size])

Transfer batch dimension N into spatial dimensions (H, W).

channel_to_space2d(x[, size])

Transfer channel dimension C into spatial dimensions (H, W).

max_pool_2d(x[, size, strides, padding])

Applies max pooling using a square 2D filter.

space_to_batch2d(x[, size])

Transfer spatial dimensions (H, W) into batch dimension N.

space_to_channel2d(x[, size])

Transfer spatial dimensions (H, W) into channel dimension C.

objax.functional.average_pool_2d(x, size=2, strides=None, padding=ConvPadding.VALID)[source]

Applies average pooling using a square 2D filter.

Parameters
  • x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]) – input tensor of shape (N, C, H, W).

  • size (Union[Tuple[int, int], int]) – size of pooling filter.

  • strides (Optional[Union[Tuple[int, int], int]]) – stride step; defaults to size when None.

  • padding (Union[objax.constants.ConvPadding, str, Sequence[Tuple[int, int]], Tuple[int, int], int]) – padding of the input tensor, either Padding.SAME or Padding.VALID or numerical values.

Returns

output tensor of shape (N, C, H, W).

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]

For a definition of pooling, including examples see Pooling Layer.

objax.functional.batch_to_space2d(x, size=2)[source]

Transfer batch dimension N into spatial dimensions (H, W).

Parameters
  • x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]) – input tensor of shape (N, C, H, W).

  • size (Union[Tuple[int, int], int]) – size of spatial area.

Returns

output tensor of shape (N // (size[0] * size[1]), C, H * size[0], W * size[1]).

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]

objax.functional.channel_to_space2d(x, size=2)[source]

Transfer channel dimension C into spatial dimensions (H, W).

Parameters
  • x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]) – input tensor of shape (N, C, H, W).

  • size (Union[Tuple[int, int], int]) – size of spatial area.

Returns

output tensor of shape (N, C // (size[0] * size[1]), H * size[0], W * size[1]).

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]

objax.functional.max_pool_2d(x, size=2, strides=None, padding=ConvPadding.VALID)[source]

Applies max pooling using a square 2D filter.

Parameters
  • x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]) – input tensor of shape (N, C, H, W).

  • size (Union[Tuple[int, int], int]) – size of pooling filter.

  • strides (Optional[Union[Tuple[int, int], int]]) – stride step; defaults to size when None.

  • padding (Union[objax.constants.ConvPadding, str, Sequence[Tuple[int, int]], Tuple[int, int], int]) – padding of the input tensor, either Padding.SAME or Padding.VALID or numerical values.

Returns

output tensor of shape (N, C, H, W).

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]

For a definition of pooling, including examples see Pooling Layer.
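
A small shape-oriented sketch of both pooling functions on the (N, C, H, W) layout documented above; with size=2 and the default VALID padding, the spatial dimensions halve, and with strides=1 they shrink to H - size + 1:

import jax.numpy as jnp
import objax.functional as F

x = jnp.arange(2 * 3 * 8 * 8, dtype=jnp.float32).reshape(2, 3, 8, 8)  # (N, C, H, W)

print(F.average_pool_2d(x, size=2).shape)         # (2, 3, 4, 4)
print(F.max_pool_2d(x, size=2).shape)             # (2, 3, 4, 4)
print(F.max_pool_2d(x, size=2, strides=1).shape)  # (2, 3, 7, 7)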

objax.functional.space_to_batch2d(x, size=2)[source]

Transfer spatial dimensions (H, W) into batch dimension N.

Parameters
  • x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]) – input tensor of shape (N, C, H, W).

  • size (Union[Tuple[int, int], int]) – size of spatial area.

Returns

output tensor of shape (N * size[0] * size[1], C, H // size[0], W // size[1]).

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]

objax.functional.space_to_channel2d(x, size=2)[source]

Transfer spatial dimensions (H, W) into channel dimension C.

Parameters
  • x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]) – input tensor of shape (N, C, H, W).

  • size (Union[Tuple[int, int], int]) – size of spatial area.

Returns

output tensor of shape (N, C * size[0] * size[1], H // size[0], W // size[1]).

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]
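
As a sketch of the shape bookkeeping, space_to_channel2d and channel_to_space2d with the same size undo each other:

import jax.numpy as jnp
import objax.functional as F

x = jnp.arange(1 * 4 * 6 * 6, dtype=jnp.float32).reshape(1, 4, 6, 6)

y = F.space_to_channel2d(x, size=2)
print(y.shape)                # (1, 16, 3, 3): C * 4, H // 2, W // 2

z = F.channel_to_space2d(y, size=2)
print(z.shape)                # (1, 4, 6, 6)
print(jnp.array_equal(x, z))  # True: the two transfers invert each other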

Misc

dynamic_slice(operand, start_indices, ...)

Wraps XLA's DynamicSlice operator.

flatten(x)

Flattens input tensor to a 2D tensor.

interpolate(input[, size, scale_factor, mode])

Interpolates input tensors by output size or scale factor.

one_hot(x, num_classes, *[, dtype, axis])

One-hot encodes the given indices.

pad(array, pad_width[, mode, stat_length, ...])

Pad an array.

scan(f, init, xs[, length, reverse, unroll])

Scan a function over leading array axes while carrying along state.

stop_gradient(x)

Stops gradient computation.

top_k(operand, k)

Returns top k values and their indices along the last axis of operand.

rsqrt(x)

Elementwise reciprocal square root: \(1 \over \sqrt{x}\).

upsample_2d(x, scale[, method])

Function to upscale 2D images.

upscale_nn(x[, scale])

Nearest neighbor upscale for image batches of shape (N, C, H, W).

objax.functional.dynamic_slice(operand, start_indices, slice_sizes)[source]

Wraps XLA’s DynamicSlice operator.

Parameters
  • operand (Any) – an array to slice.

  • start_indices (Sequence[Any]) – a list of scalar indices, one per dimension. These values may be dynamic.

  • slice_sizes (Sequence[Union[int, Any]]) – the size of the slice. Must be a sequence of non-negative integers with length equal to ndim(operand). Inside a JIT compiled function, only static values are supported (all JAX arrays inside JIT must have statically known size).

Returns

An array containing the slice.

Return type

Any

Examples

Here is a simple two-dimensional dynamic slice:

>>> x = jnp.arange(12).reshape(3, 4)
>>> x
DeviceArray([[ 0,  1,  2,  3],
             [ 4,  5,  6,  7],
             [ 8,  9, 10, 11]], dtype=int32)
>>> dynamic_slice(x, (1, 1), (2, 3))
DeviceArray([[ 5,  6,  7],
             [ 9, 10, 11]], dtype=int32)

Note the potentially surprising behavior for the case where the requested slice overruns the bounds of the array; in this case the start index is adjusted to return a slice of the requested size:

>>> dynamic_slice(x, (1, 1), (2, 4))
DeviceArray([[ 4,  5,  6,  7],
             [ 8,  9, 10, 11]], dtype=int32)
objax.functional.flatten(x)[source]

Flattens input tensor to a 2D tensor.

Parameters

x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]) – input tensor with dimensions (n_1, n_2, …, n_k)

Returns

The input tensor reshaped to two dimensions (n_1, n_prod), where n_prod is equal to the product of n_2 to n_k.

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]
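
flatten is typically used between convolutional and linear layers; a minimal sketch:

import jax.numpy as jnp
import objax.functional as F

x = jnp.zeros((32, 16, 4, 4))  # e.g. a (N, C, H, W) feature map
print(F.flatten(x).shape)      # (32, 256): batch dimension kept, remaining dims multiplied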

objax.functional.interpolate(input, size=None, scale_factor=None, mode=Interpolate.BILINEAR)[source]

Interpolates input tensors by output size or scale factor.

Parameters
  • input (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]) – input tensor.

  • size (Optional[Union[int, Tuple[int, ...]]]) – int or tuple for the output size.

  • scale_factor (Optional[Union[int, Tuple[int, ...]]]) – int or tuple scaling factor for each dimension.

  • mode (Union[objax.constants.Interpolate, str]) – str or Interpolate interpolation method, e.g. ['bilinear', 'nearest'].

Returns

output JaxArray after interpolation.

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]

objax.functional.one_hot(x, num_classes, *, dtype=<class 'jax._src.numpy.lax_numpy.float64'>, axis=-1)[source]

One-hot encodes the given indices.

Each index in the input x is encoded as a vector of zeros of length num_classes with the element at index set to one:

>>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
DeviceArray([[1., 0., 0.],
              [0., 1., 0.],
              [0., 0., 1.]], dtype=float32)

Indices outside the range [0, num_classes) will be encoded as zeros:

>>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
DeviceArray([[0., 0., 0.],
             [0., 0., 0.]], dtype=float32)
Parameters
  • x (Any) – A tensor of indices.

  • num_classes (int) – Number of classes in the one-hot dimension.

  • dtype (Any) – optional, a float dtype for the returned values (default jnp.float_).

  • axis (Union[int, Hashable]) – the axis or axes along which the function should be computed.

Return type

Any

objax.functional.pad(array, pad_width, mode='constant', *, stat_length=None, constant_values=0, end_values=None, reflect_type=None)[source]

Pad an array.

LAX-backend implementation of pad().

Unlike numpy, JAX “function” mode’s argument (which is another function) should return the modified array. This is because Jax arrays are immutable. (In numpy, “function” mode’s argument should modify a rank 1 array in-place.)

Original docstring below.

Parameters
  • array (array_like of rank N) – The array to pad.

  • pad_width ({sequence, array_like, int}) –

    Number of values padded to the edges of each axis.

    ((before_1, after_1), … (before_N, after_N)) unique pad widths for each axis.

    ((before, after),) yields same before and after pad for each axis.

    (pad,) or int is a shortcut for before = after = pad width for all axes.

  • mode (str or function, optional) –

    One of the following string values or a user supplied function.

    ’constant’ (default)

    Pads with a constant value.

    ’edge’

    Pads with the edge values of array.

    ’linear_ramp’

    Pads with the linear ramp between end_value and the array edge value.

    ’maximum’

    Pads with the maximum value of all or part of the vector along each axis.

    ’mean’

    Pads with the mean value of all or part of the vector along each axis.

    ’median’

    Pads with the median value of all or part of the vector along each axis.

    ’minimum’

    Pads with the minimum value of all or part of the vector along each axis.

    ’reflect’

    Pads with the reflection of the vector mirrored on the first and last values of the vector along each axis.

    ’symmetric’

    Pads with the reflection of the vector mirrored along the edge of the array.

    ’wrap’

    Pads with the wrap of the vector along the axis. The first values are used to pad the end and the end values are used to pad the beginning.

    ’empty’

    Pads with undefined values.

  • stat_length (sequence or int, optional) –

    Used in ‘maximum’, ‘mean’, ‘median’, and ‘minimum’. Number of values at edge of each axis used to calculate the statistic value.

    ((before_1, after_1), … (before_N, after_N)) unique statistic lengths for each axis.

    ((before, after),) yields same before and after statistic lengths for each axis.

    (stat_length,) or int is a shortcut for before = after = statistic length for all axes.

    Default is None, to use the entire axis.

  • constant_values (sequence or scalar, optional) –

    Used in ‘constant’. The values to set the padded values for each axis.

    ((before_1, after_1), ... (before_N, after_N)) unique pad constants for each axis.

    ((before, after),) yields same before and after constants for each axis.

    (constant,) or constant is a shortcut for before = after = constant for all axes.

    Default is 0.

  • end_values (sequence or scalar, optional) –

    Used in ‘linear_ramp’. The values used for the ending value of the linear_ramp and that will form the edge of the padded array.

    ((before_1, after_1), ... (before_N, after_N)) unique end values for each axis.

    ((before, after),) yields same before and after end values for each axis.

    (constant,) or constant is a shortcut for before = after = constant for all axes.

    Default is 0.

  • reflect_type ({'even', 'odd'}, optional) – Used in ‘reflect’, and ‘symmetric’. The ‘even’ style is the default with an unaltered reflection around the edge value. For the ‘odd’ style, the extended part of the array is created by subtracting the reflected values from two times the edge value.

Returns

pad – Padded array of rank equal to array with shape increased according to pad_width.

Return type

ndarray
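
A short sketch of three common modes (the behavior matches numpy.pad):

import jax.numpy as jnp
import objax.functional as F

a = jnp.array([1, 2, 3])

print(F.pad(a, (1, 2)))               # [0 1 2 3 0 0]: 'constant' mode, default value 0
print(F.pad(a, (1, 2), mode='edge'))  # [1 1 2 3 3 3]
print(F.pad(a, 2, mode='reflect'))    # [3 2 1 2 3 2 1]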

objax.functional.scan(f, init, xs, length=None, reverse=False, unroll=1)[source]

Scan a function over leading array axes while carrying along state.

The type signature in brief is

scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])

where we use [t] here to denote the type t with an additional leading axis. That is, if t is an array type then [t] represents the type with an additional leading axis, and if t is a pytree (container) type with array leaves then [t] represents the type with the same pytree structure and corresponding leaves each with an additional leading axis.

When a is an array type or None, and b is an array type, the semantics of scan are given roughly by this Python implementation:

import numpy as np

def scan(f, init, xs, length=None):
  if xs is None:
    xs = [None] * length
  carry = init
  ys = []
  for x in xs:
    carry, y = f(carry, x)
    ys.append(y)
  return carry, np.stack(ys)

Unlike that Python version, both a and b may be arbitrary pytree types, and so multiple arrays can be scanned over at once and produce multiple output arrays. (None is actually an empty pytree.)

Also unlike that Python version, scan is a JAX primitive and is lowered to a single XLA While HLO. That makes it useful for reducing compilation times for jit-compiled functions, since native Python loop constructs in an @jit function are unrolled, leading to large XLA computations.

Finally, the loop-carried value carry must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type c in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).

Parameters
  • f (Callable[[jax._src.lax.control_flow.Carry, jax._src.lax.control_flow.X], Tuple[jax._src.lax.control_flow.Carry, jax._src.lax.control_flow.Y]]) – a Python function to be scanned of type c -> a -> (c, b), meaning that f accepts two arguments where the first is a value of the loop carry and the second is a slice of xs along its leading axis, and that f returns a pair where the first element represents a new value for the loop carry and the second represents a slice of the output.

  • init (jax._src.lax.control_flow.Carry) – an initial loop carry value of type c, which can be a scalar, array, or any pytree (nested Python tuple/list/dict) thereof, representing the initial loop carry value. This value must have the same structure as the first element of the pair returned by f.

  • xs (jax._src.lax.control_flow.X) – the value of type [a] over which to scan along the leading axis, where [a] can be an array or any pytree (nested Python tuple/list/dict) thereof with consistent leading axis sizes.

  • length (Optional[int]) – optional integer specifying the number of loop iterations, which must agree with the sizes of leading axes of the arrays in xs (but can be used to perform scans where no input xs are needed).

  • reverse (bool) – optional boolean specifying whether to run the scan iteration forward (the default) or in reverse, equivalent to reversing the leading axes of the arrays in both xs and in ys.

  • unroll (int) – optional positive int specifying, in the underlying operation of the scan primitive, how many scan iterations to unroll within a single iteration of a loop.

Returns

A pair of type (c, [b]) where the first element represents the final loop carry value and the second element represents the stacked outputs of the second output of f when scanned over the leading axis of the inputs.

Return type

Tuple[jax._src.lax.control_flow.Carry, jax._src.lax.control_flow.Y]
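
A minimal sketch: cumulative sum expressed as a scan, where the carry is the running total and each output slice is the total so far.

import jax.numpy as jnp
import objax.functional as F

def step(carry, x):
    total = carry + x
    return total, total  # (new carry, output slice)

total, cumulative = F.scan(step, jnp.array(0.0), jnp.array([1.0, 2.0, 3.0, 4.0]))
print(total)       # 10.0
print(cumulative)  # [ 1.  3.  6. 10.]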

objax.functional.stop_gradient(x)[source]

Stops gradient computation.

Operationally stop_gradient is the identity function, that is, it returns argument x unchanged. However, stop_gradient prevents the flow of gradients during forward or reverse-mode automatic differentiation. If there are multiple nested gradient computations, stop_gradient stops gradients for all of them.

For example:

>>> jax.grad(lambda x: x**2)(3.)
DeviceArray(6., dtype=float32, weak_type=True)
>>> jax.grad(lambda x: jax.lax.stop_gradient(x)**2)(3.)
DeviceArray(0., dtype=float32, weak_type=True)
>>> jax.grad(jax.grad(lambda x: x**2))(3.)
DeviceArray(2., dtype=float32, weak_type=True)
>>> jax.grad(jax.grad(lambda x: jax.lax.stop_gradient(x)**2))(3.)
DeviceArray(0., dtype=float32, weak_type=True)
Parameters

x (jax._src.lax.lax.T) –

Return type

jax._src.lax.lax.T

objax.functional.top_k(operand, k)[source]

Returns top k values and their indices along the last axis of operand.

Parameters
  • operand (Any) –

  • k (int) –

Return type

Tuple[Any, Any]
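
For illustration:

import jax.numpy as jnp
import objax.functional as F

scores = jnp.array([0.1, 0.7, 0.2, 0.9])
values, indices = F.top_k(scores, 2)
print(values)   # [0.9 0.7]: values are returned in descending order
print(indices)  # [3 1]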

objax.functional.rsqrt(x)[source]

Elementwise reciprocal square root: \(1 \over \sqrt{x}\).

Parameters

x (Any) –

Return type

Any

objax.functional.upsample_2d(x, scale, method=Interpolate.BILINEAR)[source]

Function to upscale 2D images.

Parameters
  • x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]) – input tensor.

  • scale (Union[Tuple[int, int], int]) – int or tuple scaling factor

  • method (Union[objax.constants.Interpolate, str]) – str or Interpolate interpolation method, e.g. ['bilinear', 'nearest'].

Returns

upscaled 2d image tensor

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]

objax.functional.upscale_nn(x, scale=2)[source]

Nearest neighbor upscale for image batches of shape (N, C, H, W).

Parameters
  • x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]) – input tensor of shape (N, C, H, W).

  • scale (int) – integer scaling factor.

Returns

Output tensor of shape (N, C, H * scale, W * scale).

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]
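
A shape sketch for the two upscaling helpers on an (N, C, H, W) batch; per the signatures above, method accepts either an Interpolate constant or its string name:

import jax.numpy as jnp
import objax.functional as F

x = jnp.arange(1 * 3 * 4 * 4, dtype=jnp.float32).reshape(1, 3, 4, 4)

print(F.upscale_nn(x, scale=2).shape)                      # (1, 3, 8, 8)
print(F.upsample_2d(x, scale=2, method='bilinear').shape)  # (1, 3, 8, 8)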

objax.functional.divergence

kl(p, q[, eps])

Calculates the Kullback-Leibler divergence between arrays p and q.

objax.functional.divergence.kl(p, q, eps=7.62939453125e-06)[source]

Calculates the Kullback-Leibler divergence between arrays p and q.

\[kl(p,q) = p \cdot \log{\frac{p + \epsilon}{q + \epsilon}}\]

The \(\epsilon\) term is added for numerical stability, avoiding division by zero and the logarithm of zero when entries of p or q are zero.

Parameters
  • p (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]) –

  • q (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]) –

  • eps (float) –

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]
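
A minimal sketch, treating each input as a probability vector: the divergence vanishes when p and q coincide and is positive when they differ (up to the \(\epsilon\) smoothing):

import jax.numpy as jnp
import objax

p = jnp.array([0.5, 0.3, 0.2])
q = jnp.array([0.4, 0.4, 0.2])

print(objax.functional.divergence.kl(p, p))  # 0: identical distributions
print(objax.functional.divergence.kl(p, q))  # small positive value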

objax.functional.loss

cross_entropy_logits(logits, labels)

Computes the softmax cross-entropy loss on n-dimensional data.

cross_entropy_logits_sparse(logits, labels)

Computes the softmax cross-entropy loss.

l2(x)

Computes the L2 loss.

mean_absolute_error(x, y[, keep_axis])

Computes the mean absolute error between x and y.

mean_squared_error(x, y[, keep_axis])

Computes the mean squared error between x and y.

mean_squared_log_error(y_true, y_pred[, ...])

Computes the mean squared logarithmic error between y_true and y_pred.

sigmoid_cross_entropy_logits(logits, labels)

Computes the sigmoid cross-entropy loss.

objax.functional.loss.cross_entropy_logits(logits, labels)[source]

Computes the softmax cross-entropy loss on n-dimensional data.

Parameters
  • logits (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]) – (batch, …, #class) tensor of logits.

  • labels (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]) – (batch, …, #class) tensor of label probabilities (i.e. labels.sum(axis=-1) must equal 1).

Returns

(batch, …) tensor of the cross-entropies for each entry.

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]

Calculates the cross entropy loss, defined as follows:

\[\begin{split}\begin{eqnarray} l(y,\hat{y}) & = & - \sum_{j=1}^{q} y_j \log \frac{e^{o_j}}{\sum_{k=1}^{q} e^{o_k}} \nonumber \\ & = & \log \sum_{k=1}^{q} e^{o_k} - \sum_{j=1}^{q} y_j o_j \nonumber \end{eqnarray}\end{split}\]

where \(o_k\) are the logits and \(y_k\) are the labels.

objax.functional.loss.cross_entropy_logits_sparse(logits, labels)[source]

Computes the softmax cross-entropy loss.

Parameters
  • logits (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]) – (batch, …, #class) tensor of logits.

  • labels (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase, int]) – (batch, …) integer tensor of label indices in {0, …, #class-1} or just a single integer.

Returns

(batch, …) tensor of the cross-entropies for each entry.

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]
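
For illustration, the dense and sparse variants agree when the dense labels are one-hot encodings of the sparse indices:

import jax.numpy as jnp
import objax

logits = jnp.array([[2.0, 0.5, -1.0],
                    [0.1, 0.2, 0.3]])
labels_sparse = jnp.array([0, 2])
labels_dense = objax.functional.one_hot(labels_sparse, 3)

dense = objax.functional.loss.cross_entropy_logits(logits, labels_dense)
sparse = objax.functional.loss.cross_entropy_logits_sparse(logits, labels_sparse)
print(jnp.allclose(dense, sparse))  # True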

objax.functional.loss.l2(x)[source]

Computes the L2 loss.

Parameters

x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]) – n-dimensional tensor of floats.

Returns

scalar tensor containing the l2 loss of x.

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]

Calculates the l2 loss, as:

\[l_2 = \frac{\sum_{i} x_{i}^2}{2}\]
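
From this formula, the L2 loss of [3, 4] is (9 + 16) / 2 = 12.5; the division by two makes the gradient of the loss simply x.

import jax.numpy as jnp
import objax

w = jnp.array([3.0, 4.0])
print(objax.functional.loss.l2(w))  # 12.5
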
objax.functional.loss.mean_absolute_error(x, y, keep_axis=(0,))[source]

Computes the mean absolute error between x and y.

Parameters
  • x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]) – a tensor of shape (d0, .. dN-1).

  • y (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]) – a tensor of shape (d0, .. dN-1).

  • keep_axis (Optional[Iterable[int]]) – a sequence of the dimensions to keep, use None to return a scalar value.

Returns

tensor of shape (d_i, …, for i in keep_axis) containing the mean absolute error.

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]

objax.functional.loss.mean_squared_error(x, y, keep_axis=(0,))[source]

Computes the mean squared error between x and y.

Parameters
  • x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]) – a tensor of shape (d0, .. dN-1).

  • y (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]) – a tensor of shape (d0, .. dN-1).

  • keep_axis (Optional[Iterable[int]]) – a sequence of the dimensions to keep, use None to return a scalar value.

Returns

tensor of shape (d_i, …, for i in keep_axis) containing the mean squared error.

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]
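
A sketch of the keep_axis behavior shared by mean_absolute_error and mean_squared_error: the default keep_axis=(0,) returns one value per batch element, while keep_axis=None reduces to a scalar.

import jax.numpy as jnp
import objax

x = jnp.zeros((4, 8))
y = jnp.ones((4, 8))

print(objax.functional.loss.mean_squared_error(x, y).shape)             # (4,)
print(objax.functional.loss.mean_squared_error(x, y, keep_axis=None))   # 1.0 (scalar)
print(objax.functional.loss.mean_absolute_error(x, y, keep_axis=None))  # 1.0 (scalar)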

objax.functional.loss.sigmoid_cross_entropy_logits(logits, labels)[source]

Computes the sigmoid cross-entropy loss.

Parameters
  • logits (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]) – (batch, …, #class) tensor of logits.

  • labels (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase, int]) – (batch, …, #class) tensor of label probabilities (i.e. labels.sum(axis=-1) must equal 1).

Returns

(batch, …) tensor of the cross-entropies for each entry.

Return type

Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]

objax.functional.parallel

pmax(x[, axis_name])

Compute a multi-device reduce max on x over the device axis axis_name.

pmean(x[, axis_name])

Compute a multi-device reduce mean on x over the device axis axis_name.

pmin(x[, axis_name])

Compute a multi-device reduce min on x over the device axis axis_name.

psum(x[, axis_name])

Compute a multi-device reduce sum on x over the device axis axis_name.

objax.functional.parallel.pmax(x, axis_name='device')[source]

Compute a multi-device reduce max on x over the device axis axis_name.

Parameters
  • x (jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase) –

  • axis_name (str) –

objax.functional.parallel.pmean(x, axis_name='device')[source]

Compute a multi-device reduce mean on x over the device axis axis_name.

Parameters
  • x (jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase) –

  • axis_name (str) –

objax.functional.parallel.pmin(x, axis_name='device')[source]

Compute a multi-device reduce min on x over the device axis axis_name.

Parameters
  • x (jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase) –

  • axis_name (str) –

objax.functional.parallel.psum(x, axis_name='device')[source]

Compute a multi-device reduce sum on x over the device axis axis_name.

Parameters
  • x (jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase) –

  • axis_name (str) –
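
These reductions are meant to run inside multi-device computations where the device axis is named 'device', as objax.Parallel arranges. A minimal sketch using jax.pmap directly under that assumption (it also runs on a single device):

import jax
import jax.numpy as jnp
import objax

def replica_sum(x):
    # Inside the pmapped function, sum x across all participating devices.
    return objax.functional.parallel.psum(x)

n = jax.local_device_count()
x = jnp.arange(n, dtype=jnp.float32)  # one scalar per device
y = jax.pmap(replica_sum, axis_name='device')(x)
print(y)  # every device holds the same total: 0 + 1 + ... + (n - 1)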