objax.functional package
objax.functional
Due to the large number of APIs in this section, we organized it into the following sub-sections:
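Activation
| objax.functional.celu | Continuously-differentiable exponential linear unit activation. |
| objax.functional.elu | Exponential linear unit activation function. |
| objax.functional.leaky_relu | Leaky rectified linear unit activation function. |
| objax.functional.log_sigmoid | Log-sigmoid activation function. |
| objax.functional.log_softmax | Log-Softmax function. |
| objax.functional.logsumexp | Compute the log of the sum of exponentials of input elements. |
| objax.functional.relu | Rectified linear unit activation function. |
| objax.functional.selu | Scaled exponential linear unit activation. |
| objax.functional.sigmoid | Sigmoid activation function. |
| objax.functional.softmax | Softmax function. |
| objax.functional.softplus | Softplus activation function. |
| objax.functional.tanh | Elementwise hyperbolic tangent: \(\mathrm{tanh}(x)\). |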
- objax.functional.celu(x, alpha=1.0)
Continuously-differentiable exponential linear unit activation.
Computes the element-wise function:
\[\begin{split}\mathrm{celu}(x) = \begin{cases} x, & x > 0\\ \alpha \left(\exp(\frac{x}{\alpha}) - 1\right), & x \le 0 \end{cases}\end{split}\]
For more information, see Continuously Differentiable Exponential Linear Units.
- Parameters
x (Any) – input array
alpha (Any) – array or scalar (default: 1.0)
- Return type
Any
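The piecewise definition above can be sketched for a single scalar in plain Python. This is only a reference for the formula, not the objax implementation, which operates element-wise on arrays.

```python
import math

def celu(x, alpha=1.0):
    """Scalar reference for the piecewise CELU formula."""
    if x > 0:
        return x
    # Negative branch: alpha * (exp(x / alpha) - 1)
    return alpha * (math.exp(x / alpha) - 1.0)

values = [celu(v) for v in (-2.0, 0.0, 3.0)]
```

Both branches meet at 0, and for very negative inputs the output approaches -alpha.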
- objax.functional.elu(x, alpha=1.0)
Exponential linear unit activation function.
Computes the element-wise function:
\[\begin{split}\mathrm{elu}(x) = \begin{cases} x, & x > 0\\ \alpha \left(\exp(x) - 1\right), & x \le 0 \end{cases}\end{split}\]
- Parameters
x (Any) – input array
alpha (Any) – scalar or array of alpha values (default: 1.0)
- Return type
Any
- objax.functional.leaky_relu(x, negative_slope=0.01)
Leaky rectified linear unit activation function.
Computes the element-wise function:
\[\begin{split}\mathrm{leaky\_relu}(x) = \begin{cases} x, & x \ge 0\\ \alpha x, & x < 0 \end{cases}\end{split}\]
where \(\alpha\) = negative_slope.
- Parameters
x (Any) – input array
negative_slope (Any) – array or scalar specifying the negative slope (default: 0.01)
- Return type
Any
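As a rough illustration of the formula, a scalar version in plain Python could look like the following. It is a sketch of the definition only; the library function applies the same rule element-wise to arrays.

```python
def leaky_relu(x, negative_slope=0.01):
    """Scalar reference: x for x >= 0, negative_slope * x otherwise."""
    return x if x >= 0 else negative_slope * x

outputs = [leaky_relu(v) for v in (-10.0, 0.0, 5.0)]
```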
- objax.functional.log_sigmoid(x)
Log-sigmoid activation function.
Computes the element-wise function:
\[\mathrm{log\_sigmoid}(x) = \log(\mathrm{sigmoid}(x)) = -\log(1 + e^{-x})\]
- Parameters
x (Any) – input array
- Return type
Any
- objax.functional.log_softmax(x, axis=-1, where=None, initial=None)
Log-Softmax function.
Computes the logarithm of the softmax function, which rescales elements to the range \([-\infty, 0)\).
\[\mathrm{log\_softmax}(x)_i = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)} \right)\]
- Parameters
x (Any) – input array
axis (Optional[Union[int, Tuple[int, ...]]]) – the axis or axes along which the log_softmax should be computed. Either an integer or a tuple of integers.
where (Optional[Any]) – Elements to include in the log_softmax.
initial (Optional[Any]) – The minimum value used to shift the input array. Must be present when where is not None.
- Return type
Any
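A minimal plain-Python sketch of the formula for a 1-D list follows. Subtracting the maximum before exponentiating is a common stabilisation trick and is an assumption here, not a statement about how objax computes the result.

```python
import math

def log_softmax(xs):
    """Log-softmax of a 1-D list, shifted by the max to avoid overflow."""
    shift = max(xs)
    log_norm = shift + math.log(sum(math.exp(v - shift) for v in xs))
    return [v - log_norm for v in xs]

# Large logits would overflow math.exp without the shift.
out = log_softmax([1000.0, 1001.0, 1002.0])
```

Exponentiating the outputs recovers probabilities that sum to 1.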
- objax.functional.logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False)
Compute the log of the sum of exponentials of input elements.
LAX-backend implementation of logsumexp(). Original docstring below.
- Parameters
a (array_like) – Input array.
axis (None or int or tuple of ints, optional) – Axis or axes over which the sum is taken. By default axis is None, and all elements are summed.
keepdims (bool, optional) – If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the original array.
b (array-like, optional) – Scaling factor for exp(a) must be of the same shape as a or broadcastable to a. These values may be negative in order to implement subtraction.
return_sign (bool, optional) – If this is set to True, the result will be a pair containing sign information; if False, results that are negative will be returned as NaN. Default is False (no sign information).
- Returns
res (ndarray) – The result, np.log(np.sum(np.exp(a))) calculated in a numerically more stable way. If b is given then np.log(np.sum(b*np.exp(a))) is returned.
sgn (ndarray) – If return_sign is True, this will be an array of floating-point numbers matching res and +1, 0, or -1 depending on the sign of the result. If False, only one result is returned.
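The "numerically more stable way" is usually a max shift. The sketch below shows that idea for 1-D lists in plain Python, including the optional scaling factors b. It is an illustration only and does not handle axis, keepdims, or return_sign. Unlike the documented behaviour (NaN), this sketch raises ValueError from math.log when the weighted sum is negative.

```python
import math

def logsumexp(a, b=None):
    """log(sum(b * exp(a))) for 1-D lists, computed with a max shift."""
    if b is None:
        b = [1.0] * len(a)
    shift = max(a)
    # exp(v - shift) <= 1, so the sum cannot overflow.
    total = sum(w * math.exp(v - shift) for v, w in zip(a, b))
    return shift + math.log(total)

stable = logsumexp([1000.0, 1000.0])
```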
- objax.functional.relu(x)
Rectified linear unit activation function.
- Parameters
x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]) – input tensor.
- Returns
tensor with the element-wise output relu(x) = max(x, 0).
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]
- objax.functional.selu(x)
Scaled exponential linear unit activation.
Computes the element-wise function:
\[\begin{split}\mathrm{selu}(x) = \lambda \begin{cases} x, & x > 0\\ \alpha e^x - \alpha, & x \le 0 \end{cases}\end{split}\]
where \(\lambda = 1.0507009873554804934193349852946\) and \(\alpha = 1.6732632423543772848170429916717\).
For more information, see Self-Normalizing Neural Networks.
- Parameters
x (Any) – input array
- Return type
Any
- objax.functional.sigmoid(x)
Sigmoid activation function.
Computes the element-wise function:
\[\mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}\]
- Parameters
x (Any) – input array
- Return type
Any
- objax.functional.softmax(x, axis=-1, where=None, initial=None)
Softmax function.
Computes the function which rescales elements to the range \([0, 1]\) such that the elements along axis sum to \(1\).
\[\mathrm{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}\]
- Parameters
x (Any) – input array
axis (Optional[Union[int, Tuple[int, ...]]]) – the axis or axes along which the softmax should be computed. The softmax output summed across these dimensions should sum to \(1\). Either an integer or a tuple of integers.
where (Optional[Any]) – Elements to include in the softmax.
initial (Optional[Any]) – The minimum value used to shift the input array. Must be present when where is not None.
- Return type
Any
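For a single 1-D list, the softmax formula can be sketched in plain Python as below. The max shift is a standard stabilisation step assumed for this illustration; it leaves the result unchanged because the shift cancels in the ratio.

```python
import math

def softmax(xs):
    """Softmax of a 1-D list; the max shift cancels in the ratio."""
    shift = max(xs)
    exps = [math.exp(v - shift) for v in xs]
    total = sum(exps)
    return [e / total for e in exps]

probs = softmax([1.0, 2.0, 3.0])
```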
Pooling
| objax.functional.average_pool_2d | Applies average pooling using a square 2D filter. |
| objax.functional.batch_to_space2d | Transfer batch dimension N into spatial dimensions (H, W). |
| objax.functional.channel_to_space2d | Transfer channel dimension C into spatial dimensions (H, W). |
| objax.functional.max_pool_2d | Applies max pooling using a square 2D filter. |
| objax.functional.space_to_batch2d | Transfer spatial dimensions (H, W) into batch dimension N. |
| objax.functional.space_to_channel2d | Transfer spatial dimensions (H, W) into channel dimension C. |
- objax.functional.average_pool_2d(x, size=2, strides=None, padding=ConvPadding.VALID)
Applies average pooling using a square 2D filter.
- Parameters
x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]) – input tensor of shape (N, C, H, W).
size (Union[Tuple[int, int], int]) – size of pooling filter.
strides (Optional[Union[Tuple[int, int], int]]) – stride step; uses size when strides is None (default).
padding (Union[objax.constants.ConvPadding, str, Sequence[Tuple[int, int]], Tuple[int, int], int]) – padding of the input tensor, either ConvPadding.SAME or ConvPadding.VALID or numerical values.
- Returns
output tensor of shape (N, C, H, W).
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]
For a definition of pooling, including examples see Pooling Layer.
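To make the operation concrete, here is a plain-Python sketch of average pooling on one (H, W) channel, assuming a square window, a stride equal to the window size, and VALID padding. The helper name and the nested-list input are illustrative assumptions; the library function works on (N, C, H, W) arrays.

```python
def average_pool_2d(image, size=2):
    """Average-pool one (H, W) channel: size x size window, stride = size, VALID padding."""
    out_h, out_w = len(image) // size, len(image[0]) // size
    pooled = []
    for i in range(out_h):
        row = []
        for j in range(out_w):
            # Collect the size x size window whose top-left corner is (i*size, j*size).
            window = [image[i * size + di][j * size + dj]
                      for di in range(size) for dj in range(size)]
            row.append(sum(window) / len(window))
        pooled.append(row)
    return pooled

image = [[1, 2, 3, 4],
         [5, 6, 7, 8],
         [9, 10, 11, 12],
         [13, 14, 15, 16]]
pooled = average_pool_2d(image)
```

In this sketch a 4 x 4 channel becomes 2 x 2, since each non-overlapping window collapses to one value.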
- objax.functional.batch_to_space2d(x, size=2)
Transfer batch dimension N into spatial dimensions (H, W).
- Parameters
x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]) – input tensor of shape (N, C, H, W).
size (Union[Tuple[int, int], int]) – size of spatial area.
- Returns
output tensor of shape (N // (size[0] * size[1]), C, H * size[0], W * size[1]).
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]
- objax.functional.channel_to_space2d(x, size=2)
Transfer channel dimension C into spatial dimensions (H, W).
- Parameters
x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]) – input tensor of shape (N, C, H, W).
size (Union[Tuple[int, int], int]) – size of spatial area.
- Returns
output tensor of shape (N, C // (size[0] * size[1]), H * size[0], W * size[1]).
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]
- objax.functional.max_pool_2d(x, size=2, strides=None, padding=ConvPadding.VALID)
Applies max pooling using a square 2D filter.
- Parameters
x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]) – input tensor of shape (N, C, H, W).
size (Union[Tuple[int, int], int]) – size of pooling filter.
strides (Optional[Union[Tuple[int, int], int]]) – stride step; uses size when strides is None (default).
padding (Union[objax.constants.ConvPadding, str, Sequence[Tuple[int, int]], Tuple[int, int], int]) – padding of the input tensor, either ConvPadding.SAME or ConvPadding.VALID or numerical values.
- Returns
output tensor of shape (N, C, H, W).
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]
For a definition of pooling, including examples see Pooling Layer.
- objax.functional.space_to_batch2d(x, size=2)
Transfer spatial dimensions (H, W) into batch dimension N.
- Parameters
x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]) – input tensor of shape (N, C, H, W).
size (Union[Tuple[int, int], int]) – size of spatial area.
- Returns
output tensor of shape (N * size[0] * size[1], C, H // size[0], W // size[1]).
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]
- objax.functional.space_to_channel2d(x, size=2)
Transfer spatial dimensions (H, W) into channel dimension C.
- Parameters
x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]) – input tensor of shape (N, C, H, W).
size (Union[Tuple[int, int], int]) – size of spatial area.
- Returns
output tensor of shape (N, C * size[0] * size[1], H // size[0], W // size[1]).
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]
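The documented output shapes of the four space/batch/channel transforms can be checked with a small shape calculator. The helper names below are hypothetical and only compute shapes from the formulas given in this section; they do not move any data.

```python
def _pair(size):
    """Expand an int size to (size, size); pass tuples through."""
    return (size, size) if isinstance(size, int) else tuple(size)

def space_to_batch2d_shape(shape, size=2):
    n, c, h, w = shape
    s0, s1 = _pair(size)
    return (n * s0 * s1, c, h // s0, w // s1)

def batch_to_space2d_shape(shape, size=2):
    n, c, h, w = shape
    s0, s1 = _pair(size)
    return (n // (s0 * s1), c, h * s0, w * s1)

def space_to_channel2d_shape(shape, size=2):
    n, c, h, w = shape
    s0, s1 = _pair(size)
    return (n, c * s0 * s1, h // s0, w // s1)

def channel_to_space2d_shape(shape, size=2):
    n, c, h, w = shape
    s0, s1 = _pair(size)
    return (n, c // (s0 * s1), h * s0, w * s1)

shape = (8, 3, 32, 32)
```

Each pair of functions is a shape-level inverse of the other, so the total number of elements is preserved.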
Misc
| objax.functional.dynamic_slice | Wraps XLA's DynamicSlice operator. |
| objax.functional.flatten | Flattens input tensor to a 2D tensor. |
| objax.functional.interpolate | Function to interpolate JaxArrays by size or scaling factor. |
| objax.functional.one_hot | One-hot encodes the given indices. |
| objax.functional.pad | Pad an array. |
| objax.functional.scan | Scan a function over leading array axes while carrying along state. |
| objax.functional.stop_gradient | Stops gradient computation. |
| objax.functional.top_k | Returns top k values and their indices along the last axis of operand. |
| objax.functional.rsqrt | Elementwise reciprocal square root: \(1 \over \sqrt{x}\). |
| objax.functional.upsample_2d | Function to upscale 2D images. |
| objax.functional.upscale_nn | Nearest neighbor upscale for image batches of shape (N, C, H, W). |
- objax.functional.dynamic_slice(operand, start_indices, slice_sizes)
Wraps XLA’s DynamicSlice operator.
- Parameters
operand (Any) – an array to slice.
start_indices (Sequence[Any]) – a list of scalar indices, one per dimension. These values may be dynamic.
slice_sizes (Sequence[Union[int, Any]]) – the size of the slice. Must be a sequence of non-negative integers with length equal to ndim(operand). Inside a JIT compiled function, only static values are supported (all JAX arrays inside JIT must have statically known size).
- Returns
An array containing the slice.
- Return type
Any
Examples
Here is a simple two-dimensional dynamic slice:
>>> x = jnp.arange(12).reshape(3, 4)
>>> x
DeviceArray([[ 0,  1,  2,  3],
             [ 4,  5,  6,  7],
             [ 8,  9, 10, 11]], dtype=int32)
>>> dynamic_slice(x, (1, 1), (2, 3))
DeviceArray([[ 5,  6,  7],
             [ 9, 10, 11]], dtype=int32)
Note the potentially surprising behavior for the case where the requested slice overruns the bounds of the array; in this case the start index is adjusted to return a slice of the requested size:
>>> dynamic_slice(x, (1, 1), (2, 4))
DeviceArray([[ 4,  5,  6,  7],
             [ 8,  9, 10, 11]], dtype=int32)
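The start-index adjustment can be modelled as clamping each start index into the range [0, dim - slice_size]. The plain-Python sketch below reproduces the two examples above on nested lists; the helper names are hypothetical and this is not the XLA implementation.

```python
def clamp_start_indices(shape, start_indices, slice_sizes):
    """Clamp each start so the slice of the requested size fits in the array."""
    return tuple(max(0, min(start, dim - size))
                 for dim, start, size in zip(shape, start_indices, slice_sizes))

def dynamic_slice_2d(rows, start_indices, slice_sizes):
    """Slice a 2-D nested list after clamping the start indices."""
    shape = (len(rows), len(rows[0]))
    r0, c0 = clamp_start_indices(shape, start_indices, slice_sizes)
    return [row[c0:c0 + slice_sizes[1]] for row in rows[r0:r0 + slice_sizes[0]]]

x = [[0, 1, 2, 3],
     [4, 5, 6, 7],
     [8, 9, 10, 11]]
in_bounds = dynamic_slice_2d(x, (1, 1), (2, 3))
overrun = dynamic_slice_2d(x, (1, 1), (2, 4))
```

In the overrunning case the column start is clamped from 1 to 0, which is why the result begins at column 0.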
- objax.functional.flatten(x)
Flattens input tensor to a 2D tensor.
- Parameters
x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]) – input tensor with dimensions (n_1, n_2, …, n_k)
- Returns
The input tensor reshaped to two dimensions (n_1, n_prod), where n_prod is equal to the product of n_2 to n_k.
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]
- objax.functional.interpolate(input, size=None, scale_factor=None, mode=Interpolate.BILINEAR)
Function to interpolate JaxArrays by size or scaling factor.
- Parameters
input (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]) – input tensor
size (Optional[Union[int, Tuple[int, ...]]]) – int or tuple for output size
scale_factor (Optional[Union[int, Tuple[int, ...]]]) – int or tuple scaling factor for each dimension
mode (Union[objax.constants.Interpolate, str]) – str or Interpolate interpolation method, e.g. [‘bilinear’, ‘nearest’]
- Returns
output JaxArray after interpolation
- objax.functional.one_hot(x, num_classes, *, dtype=<class 'jax._src.numpy.lax_numpy.float64'>, axis=-1)
One-hot encodes the given indices.
Each index in the input x is encoded as a vector of zeros of length num_classes with the element at index set to one:
>>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
DeviceArray([[1., 0., 0.],
             [0., 1., 0.],
             [0., 0., 1.]], dtype=float32)
Indices outside the range [0, num_classes) will be encoded as zeros:
>>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
DeviceArray([[0., 0., 0.],
             [0., 0., 0.]], dtype=float32)
- Parameters
x (Any) – A tensor of indices.
num_classes (int) – Number of classes in the one-hot dimension.
dtype (Any) – optional, a float dtype for the returned values (default jnp.float_).
axis (Union[int, Hashable]) – the axis or axes along which the function should be computed.
- Return type
Any
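The encoding rule, including the all-zeros rows for out-of-range indices, can be sketched in plain Python for a 1-D list of indices. This is an illustration of the behaviour described above, not the library code, and it ignores the dtype and axis arguments.

```python
def one_hot(indices, num_classes):
    """Encode each index as a row with a single 1.0; out-of-range rows stay all zero."""
    return [[1.0 if k == i else 0.0 for k in range(num_classes)]
            for i in indices]

encoded = one_hot([0, 1, 2], 3)
out_of_range = one_hot([-1, 3], 3)
```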
- objax.functional.pad(array, pad_width, mode='constant', *, stat_length=None, constant_values=0, end_values=None, reflect_type=None)
Pad an array.
LAX-backend implementation of pad().
Unlike numpy, JAX “function” mode’s argument (which is another function) should return the modified array. This is because JAX arrays are immutable. (In numpy, “function” mode’s argument should modify a rank 1 array in-place.)
Original docstring below.
- Parameters
array (array_like of rank N) – The array to pad.
pad_width ({sequence, array_like, int}) – Number of values padded to the edges of each axis. ((before_1, after_1), … (before_N, after_N)) unique pad widths for each axis. ((before, after),) yields same before and after pad for each axis. (pad,) or int is a shortcut for before = after = pad width for all axes.
mode (str or function, optional) –
One of the following string values or a user supplied function.
- ’constant’ (default)
Pads with a constant value.
- ’edge’
Pads with the edge values of array.
- ’linear_ramp’
Pads with the linear ramp between end_value and the array edge value.
- ’maximum’
Pads with the maximum value of all or part of the vector along each axis.
- ’mean’
Pads with the mean value of all or part of the vector along each axis.
- ’median’
Pads with the median value of all or part of the vector along each axis.
- ’minimum’
Pads with the minimum value of all or part of the vector along each axis.
- ’reflect’
Pads with the reflection of the vector mirrored on the first and last values of the vector along each axis.
- ’symmetric’
Pads with the reflection of the vector mirrored along the edge of the array.
- ’wrap’
Pads with the wrap of the vector along the axis. The first values are used to pad the end and the end values are used to pad the beginning.
- ’empty’
Pads with undefined values.
stat_length (sequence or int, optional) –
Used in ‘maximum’, ‘mean’, ‘median’, and ‘minimum’. Number of values at edge of each axis used to calculate the statistic value.
((before_1, after_1), … (before_N, after_N)) unique statistic lengths for each axis.
((before, after),) yields same before and after statistic lengths for each axis.
(stat_length,) or int is a shortcut for before = after = statistic length for all axes.
Default is None, to use the entire axis.
constant_values (sequence or scalar, optional) –
Used in ‘constant’. The values to set the padded values for each axis.
((before_1, after_1), ... (before_N, after_N)) unique pad constants for each axis.
((before, after),) yields same before and after constants for each axis.
(constant,) or constant is a shortcut for before = after = constant for all axes.
Default is 0.
end_values (sequence or scalar, optional) –
Used in ‘linear_ramp’. The values used for the ending value of the linear_ramp and that will form the edge of the padded array.
((before_1, after_1), ... (before_N, after_N)) unique end values for each axis.
((before, after),) yields same before and after end values for each axis.
(constant,) or constant is a shortcut for before = after = constant for all axes.
Default is 0.
reflect_type ({'even', 'odd'}, optional) – Used in ‘reflect’, and ‘symmetric’. The ‘even’ style is the default with an unaltered reflection around the edge value. For the ‘odd’ style, the extended part of the array is created by subtracting the reflected values from two times the edge value.
- Returns
pad – Padded array of rank equal to array with shape increased according to pad_width.
- Return type
ndarray
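To illustrate how three of the modes differ, the following plain-Python sketch pads a 1-D list in 'constant', 'edge', and 'reflect' mode. The helper pad_1d is hypothetical, covers only these modes, and assumes the pad widths are smaller than the list length for 'reflect'.

```python
def pad_1d(values, before, after, mode="constant", constant_value=0):
    """Pad a 1-D list on both sides; supports 'constant', 'edge', and 'reflect'."""
    values = list(values)
    if mode == "constant":
        left, right = [constant_value] * before, [constant_value] * after
    elif mode == "edge":
        left, right = [values[0]] * before, [values[-1]] * after
    elif mode == "reflect":
        # Mirror around the first and last values without repeating them.
        left = values[before:0:-1]
        right = values[-2:-after - 2:-1]
    else:
        raise ValueError(f"unsupported mode: {mode}")
    return left + values + right

data = [1, 2, 3, 4]
padded_constant = pad_1d(data, 2, 2)
padded_edge = pad_1d(data, 2, 3, mode="edge")
padded_reflect = pad_1d(data, 2, 2, mode="reflect")
```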
- objax.functional.scan(f, init, xs, length=None, reverse=False, unroll=1)
Scan a function over leading array axes while carrying along state.
The type signature in brief is
scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])
where we use [t] here to denote the type t with an additional leading axis. That is, if t is an array type then [t] represents the type with an additional leading axis, and if t is a pytree (container) type with array leaves then [t] represents the type with the same pytree structure and corresponding leaves each with an additional leading axis.
When a is an array type or None, and b is an array type, the semantics of scan are given roughly by this Python implementation:
    def scan(f, init, xs, length=None):
        if xs is None:
            xs = [None] * length
        carry = init
        ys = []
        for x in xs:
            carry, y = f(carry, x)
            ys.append(y)
        return carry, np.stack(ys)
Unlike that Python version, both a and b may be arbitrary pytree types, and so multiple arrays can be scanned over at once and produce multiple output arrays. (None is actually an empty pytree.)
Also unlike that Python version, scan is a JAX primitive and is lowered to a single XLA While HLO. That makes it useful for reducing compilation times for jit-compiled functions, since native Python loop constructs in an @jit function are unrolled, leading to large XLA computations.
Finally, the loop-carried value carry must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type c in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).
- Parameters
f (Callable[[jax._src.lax.control_flow.Carry, jax._src.lax.control_flow.X], Tuple[jax._src.lax.control_flow.Carry, jax._src.lax.control_flow.Y]]) – a Python function to be scanned of type c -> a -> (c, b), meaning that f accepts two arguments where the first is a value of the loop carry and the second is a slice of xs along its leading axis, and that f returns a pair where the first element represents a new value for the loop carry and the second represents a slice of the output.
init (jax._src.lax.control_flow.Carry) – an initial loop carry value of type c, which can be a scalar, array, or any pytree (nested Python tuple/list/dict) thereof, representing the initial loop carry value. This value must have the same structure as the first element of the pair returned by f.
xs (jax._src.lax.control_flow.X) – the value of type [a] over which to scan along the leading axis, where [a] can be an array or any pytree (nested Python tuple/list/dict) thereof with consistent leading axis sizes.
length (Optional[int]) – optional integer specifying the number of loop iterations, which must agree with the sizes of leading axes of the arrays in xs (but can be used to perform scans where no input xs are needed).
reverse (bool) – optional boolean specifying whether to run the scan iteration forward (the default) or in reverse, equivalent to reversing the leading axes of the arrays in both xs and in ys.
unroll (int) – optional positive int specifying, in the underlying operation of the scan primitive, how many scan iterations to unroll within a single iteration of a loop.
- Returns
A pair of type (c, [b]) where the first element represents the final loop carry value and the second element represents the stacked outputs of the second output of f when scanned over the leading axis of the inputs.
- Return type
Tuple[jax._src.lax.control_flow.Carry, jax._src.lax.control_flow.Y]
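As a usage illustration, the reference semantics of scan can be exercised in plain Python with a running sum. The scan below follows the rough Python implementation quoted in this entry but returns a list instead of a stacked array, which is a simplification made for this sketch; cumsum_step is a hypothetical step function.

```python
def scan(f, init, xs, length=None):
    """Pure-Python reference loop; returns (final_carry, list_of_outputs)."""
    if xs is None:
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys

def cumsum_step(carry, x):
    """Step of type c -> a -> (c, b): update the running total and emit it."""
    total = carry + x
    return total, total

final, partials = scan(cumsum_step, 0, [1, 2, 3, 4])

# With xs=None, length alone sets the number of iterations.
doubled, history = scan(lambda c, _: (c * 2, c), 1, None, length=3)
```

The carry plays the role of loop state, while the second element returned by the step function is collected into the per-iteration outputs.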
- objax.functional.stop_gradient(x)
Stops gradient computation.
Operationally stop_gradient is the identity function, that is, it returns argument x unchanged. However, stop_gradient prevents the flow of gradients during forward or reverse-mode automatic differentiation. If there are multiple nested gradient computations, stop_gradient stops gradients for all of them.
For example:
>>> jax.grad(lambda x: x**2)(3.)
DeviceArray(6., dtype=float32, weak_type=True)
>>> jax.grad(lambda x: jax.lax.stop_gradient(x)**2)(3.)
DeviceArray(0., dtype=float32, weak_type=True)
>>> jax.grad(jax.grad(lambda x: x**2))(3.)
DeviceArray(2., dtype=float32, weak_type=True)
>>> jax.grad(jax.grad(lambda x: jax.lax.stop_gradient(x)**2))(3.)
DeviceArray(0., dtype=float32, weak_type=True)
- Parameters
x (jax._src.lax.lax.T) –
- Return type
jax._src.lax.lax.T
- objax.functional.top_k(operand, k)
Returns top k values and their indices along the last axis of operand.
- Parameters
operand (Any) –
k (int) –
- Return type
Tuple[Any, Any]
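For a 1-D list, the returned pair of values and indices can be sketched in plain Python as follows. This is an illustration of the contract only; how the library orders ties is not specified here and is not modelled.

```python
def top_k(values, k):
    """Return (top-k values, their indices) for a 1-D list, largest first."""
    order = sorted(range(len(values)), key=lambda i: values[i], reverse=True)[:k]
    return [values[i] for i in order], order

top_values, top_indices = top_k([3, 1, 4, 1, 5], 2)
```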
- objax.functional.rsqrt(x)
Elementwise reciprocal square root: \(1 \over \sqrt{x}\).
- Parameters
x (Any) –
- Return type
Any
- objax.functional.upsample_2d(x, scale, method=Interpolate.BILINEAR)
Function to upscale 2D images.
- Parameters
x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]) – input tensor.
scale (Union[Tuple[int, int], int]) – int or tuple scaling factor
method (Union[objax.constants.Interpolate, str]) – str or Interpolate interpolation method, e.g. [‘bilinear’, ‘nearest’].
- Returns
upscaled 2d image tensor
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]
- objax.functional.upscale_nn(x, scale=2)
Nearest neighbor upscale for image batches of shape (N, C, H, W).
- Parameters
x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]) – input tensor of shape (N, C, H, W).
scale (int) – integer scaling factor.
- Returns
Output tensor of shape (N, C, H * scale, W * scale).
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]
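Nearest neighbor upscaling repeats every pixel scale times along both spatial axes. The plain-Python sketch below shows this for a single (H, W) image stored as nested lists; the helper name is hypothetical, and the library function applies the same idea to (N, C, H, W) batches.

```python
def upscale_nn_2d(image, scale=2):
    """Repeat each pixel `scale` times along both height and width."""
    return [[pixel for pixel in row for _ in range(scale)]
            for row in image for _ in range(scale)]

upscaled = upscale_nn_2d([[1, 2],
                          [3, 4]])
```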
objax.functional.divergence
| objax.functional.divergence.kl | Calculates the Kullback-Leibler divergence between arrays p and q. |
- objax.functional.divergence.kl(p, q, eps=7.62939453125e-06)
Calculates the Kullback-Leibler divergence between arrays p and q.
\[kl(p,q) = p \cdot \log{\frac{p + \epsilon}{q + \epsilon}}\]
The \(\epsilon\) term is added so that the logarithm stays finite when p or q contains zeros.
- Parameters
p (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]) –
q (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]) –
eps (float) –
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]
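A plain-Python sketch of the formula for two 1-D probability lists follows. It reads the dot in the formula as a sum over the entries, which is an assumption of this sketch rather than something stated in this entry. The default eps equals 2**-17, matching the value in the signature.

```python
import math

def kl(p, q, eps=2 ** -17):
    """Sum of p_i * log((p_i + eps) / (q_i + eps)) over a 1-D list."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q))

same = kl([0.5, 0.5], [0.5, 0.5])
with_zero = kl([1.0, 0.0], [0.5, 0.5])
```

Thanks to eps, the zero entry in p contributes 0 instead of raising a math domain error.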
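objax.functional.loss
| objax.functional.loss.cross_entropy_logits | Computes the softmax cross-entropy loss on n-dimensional data. |
| objax.functional.loss.cross_entropy_logits_sparse | Computes the softmax cross-entropy loss. |
| objax.functional.loss.l2 | Computes the L2 loss. |
| objax.functional.loss.mean_absolute_error | Computes the mean absolute error between x and y. |
| objax.functional.loss.mean_squared_error | Computes the mean squared error between x and y. |
| objax.functional.loss.mean_squared_logarithmic_error | Computes the mean squared logarithmic error between y_true and y_pred. |
| objax.functional.loss.sigmoid_cross_entropy_logits | Computes the sigmoid cross-entropy loss. |
- objax.functional.loss.cross_entropy_logits(logits, labels)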
Computes the softmax cross-entropy loss on n-dimensional data.
- Parameters
logits (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]) – (batch, …, #class) tensor of logits.
labels (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]) – (batch, …, #class) tensor of label probabilities (e.g. labels.sum(axis=-1) must be 1)
- Returns
(batch, …) tensor of the cross-entropies for each entry.
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]
Calculates the cross entropy loss, defined as follows:
\[\begin{split}\begin{eqnarray} l(y,\hat{y}) & = & - \sum_{j=1}^{q} y_j \log \frac{e^{o_j}}{\sum_{k=1}^{q} e^{o_k}} \nonumber \\ & = & \log \sum_{k=1}^{q} e^{o_k} - \sum_{j=1}^{q} y_j o_j \nonumber \end{eqnarray}\end{split}\]
where \(o_k\) are the logits and \(y_k\) are the labels.
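The two lines of the derivation can be checked numerically with a small plain-Python sketch for a single example. The second form relies on the labels summing to 1, as stated for the labels parameter. The max shift in the first function is a stabilisation assumption of this sketch, and cross_entropy_direct is a hypothetical helper that evaluates the first line of the equation literally.

```python
import math

def cross_entropy_logits(logits, labels):
    """Second form: log-sum-exp of the logits minus the label-weighted logits."""
    shift = max(logits)
    log_norm = shift + math.log(sum(math.exp(o - shift) for o in logits))
    return log_norm - sum(y * o for y, o in zip(labels, logits))

def cross_entropy_direct(logits, labels):
    """First form: negative label-weighted log of the softmax probabilities."""
    norm = sum(math.exp(o) for o in logits)
    return -sum(y * math.log(math.exp(o) / norm) for y, o in zip(labels, logits))

logits = [2.0, 1.0, 0.1]
labels = [0.7, 0.2, 0.1]  # probabilities that sum to 1
loss_a = cross_entropy_logits(logits, labels)
loss_b = cross_entropy_direct(logits, labels)
```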
- objax.functional.loss.cross_entropy_logits_sparse(logits, labels)
Computes the softmax cross-entropy loss.
- Parameters
logits (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]) – (batch, …, #class) tensor of logits.
labels (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase, int]) – (batch, …) integer tensor of label indexes in {0, …,#nclass-1} or just a single integer.
- Returns
(batch, …) tensor of the cross-entropies for each entry.
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]
- objax.functional.loss.l2(x)
Computes the L2 loss.
- Parameters
x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]) – n-dimensional tensor of floats.
- Returns
scalar tensor containing the l2 loss of x.
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]
Calculates the l2 loss, as:
\[l_2 = \frac{\sum_{i} x_{i}^2}{2}\]
- objax.functional.loss.mean_absolute_error(x, y, keep_axis=(0,))
Computes the mean absolute error between x and y.
- Parameters
x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]) – a tensor of shape (d0, .. dN-1).
y (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]) – a tensor of shape (d0, .. dN-1).
keep_axis (Optional[Iterable[int]]) – a sequence of the dimensions to keep, use None to return a scalar value.
- Returns
tensor of shape (d_i, …, for i in keep_axis) containing the mean absolute error.
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]
- objax.functional.loss.mean_squared_error(x, y, keep_axis=(0,))
Computes the mean squared error between x and y.
- Parameters
x (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]) – a tensor of shape (d0, .. dN-1).
y (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]) – a tensor of shape (d0, .. dN-1).
keep_axis (Optional[Iterable[int]]) – a sequence of the dimensions to keep, use None to return a scalar value.
- Returns
tensor of shape (d_i, …, for i in keep_axis) containing the mean squared error.
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]
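The effect of keep_axis can be illustrated for 2-D inputs of shape (d0, d1) in plain Python. The keep_batch flag below is a hypothetical stand-in: True mimics keep_axis=(0,), giving one value per row, and False mimics keep_axis=None, giving a scalar.

```python
def mean_squared_error(x, y, keep_batch=True):
    """MSE for nested lists of shape (d0, d1); keep_batch=True keeps axis 0."""
    per_row = [sum((a - b) ** 2 for a, b in zip(row_x, row_y)) / len(row_x)
               for row_x, row_y in zip(x, y)]
    if keep_batch:
        return per_row
    # Rows have equal length, so the mean of row means is the overall mean.
    return sum(per_row) / len(per_row)

x = [[1.0, 2.0], [3.0, 4.0]]
y = [[1.0, 0.0], [0.0, 4.0]]
per_example = mean_squared_error(x, y)
overall = mean_squared_error(x, y, keep_batch=False)
```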
- objax.functional.loss.sigmoid_cross_entropy_logits(logits, labels)
Computes the sigmoid cross-entropy loss.
- Parameters
logits (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]) – (batch, …, #class) tensor of logits.
labels (Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase, int]) – (batch, …, #class) tensor of label probabilities (e.g. labels.sum(axis=-1) must be 1)
- Returns
(batch, …) tensor of the cross-entropies for each entry.
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jaxlib.xla_extension.DeviceArrayBase, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase]
objax.functional.parallel
| objax.functional.parallel.pmax | Compute a multi-device reduce max on x over the device axis axis_name. |
| objax.functional.parallel.pmean | Compute a multi-device reduce mean on x over the device axis axis_name. |
| objax.functional.parallel.pmin | Compute a multi-device reduce min on x over the device axis axis_name. |
| objax.functional.parallel.psum | Compute a multi-device reduce sum on x over the device axis axis_name. |
- objax.functional.parallel.pmax(x, axis_name='device')
Compute a multi-device reduce max on x over the device axis axis_name.
- Parameters
x (jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase) –
axis_name (str) –
- objax.functional.parallel.pmean(x, axis_name='device')
Compute a multi-device reduce mean on x over the device axis axis_name.
- Parameters
x (jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase) –
axis_name (str) –