objax.functional package
objax.functional
Due to the large number of APIs in this section, we organized it into the following sub-sections:
Activation
| celu | Continuously-differentiable exponential linear unit activation. |
| elu | Exponential linear unit activation function. |
| leaky_relu | Leaky rectified linear unit activation function. |
| log_sigmoid | Log-sigmoid activation function. |
| log_softmax | Log-Softmax function. |
| logsumexp | Compute the log of the sum of exponentials of input elements. |
| relu | Rectified linear unit activation function. |
| selu | Scaled exponential linear unit activation. |
| sigmoid | Sigmoid activation function. |
| softmax | Softmax function. |
| softplus | Softplus activation function. |
| tanh | Elementwise hyperbolic tangent: \(\mathrm{tanh}(x)\). |
- objax.functional.celu(x, alpha=1.0)
Continuously-differentiable exponential linear unit activation.
Computes the element-wise function:
\[\begin{split}\mathrm{celu}(x) = \begin{cases} x, & x > 0\\ \alpha \left(\exp(\frac{x}{\alpha}) - 1\right), & x \le 0 \end{cases}\end{split}\]
For more information, see Continuously Differentiable Exponential Linear Units.
- Parameters:
x (Array | ndarray | bool_ | number | bool | int | float | complex) – input array
alpha (Array | ndarray | bool_ | number | bool | int | float | complex) – array or scalar (default: 1.0)
- Returns:
An array.
- Return type:
Array
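The following is a plain NumPy reference for the CELU formula above, useful for spot-checking outputs. It is a sketch and not objax's implementation, which delegates to JAX. The helper names `celu_ref` and `elu_ref` are hypothetical. The sketch also shows that CELU and ELU coincide when alpha is 1 and diverge otherwise.

```python
import numpy as np

def celu_ref(x, alpha=1.0):
    """NumPy reference for celu: x if x > 0, else alpha * (exp(x / alpha) - 1)."""
    x = np.asarray(x, dtype=np.float64)
    # np.where evaluates both branches; clipping to <= 0 keeps expm1 from
    # overflowing on large positive inputs whose branch is discarded anyway.
    return np.where(x > 0, x, alpha * np.expm1(np.minimum(x, 0.0) / alpha))

def elu_ref(x, alpha=1.0):
    """NumPy reference for elu: x if x > 0, else alpha * (exp(x) - 1)."""
    x = np.asarray(x, dtype=np.float64)
    return np.where(x > 0, x, alpha * np.expm1(np.minimum(x, 0.0)))

x = np.array([-3.0, -1.0, 0.0, 2.0])
print(celu_ref(x, alpha=2.0))   # negative inputs approach -alpha = -2
print(elu_ref(x, alpha=2.0))    # uses exp(x) rather than exp(x / alpha)
print(np.allclose(celu_ref(x), elu_ref(x)))  # True: identical when alpha == 1
```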
- objax.functional.elu(x, alpha=1.0)
Exponential linear unit activation function.
Computes the element-wise function:
\[\begin{split}\mathrm{elu}(x) = \begin{cases} x, & x > 0\\ \alpha \left(\exp(x) - 1\right), & x \le 0 \end{cases}\end{split}\]
- Parameters:
x (Array | ndarray | bool_ | number | bool | int | float | complex) – input array
alpha (Array | ndarray | bool_ | number | bool | int | float | complex) – scalar or array of alpha values (default: 1.0)
- Returns:
An array.
- Return type:
Array
- objax.functional.leaky_relu(x, negative_slope=0.01)
Leaky rectified linear unit activation function.
Computes the element-wise function:
\[\begin{split}\mathrm{leaky\_relu}(x) = \begin{cases} x, & x \ge 0\\ \alpha x, & x < 0 \end{cases}\end{split}\]
where \(\alpha\) = negative_slope.
- Parameters:
x (Array | ndarray | bool_ | number | bool | int | float | complex) – input array
negative_slope (Array | ndarray | bool_ | number | bool | int | float | complex) – array or scalar specifying the negative slope (default: 0.01)
- Returns:
An array.
- Return type:
Array
- objax.functional.log_sigmoid(x)
Log-sigmoid activation function.
Computes the element-wise function:
\[\mathrm{log\_sigmoid}(x) = \log(\mathrm{sigmoid}(x)) = -\log(1 + e^{-x})\]
- Parameters:
x (Array | ndarray | bool_ | number | bool | int | float | complex) – input array
- Returns:
An array.
- Return type:
Array
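Transcribing \(-\log(1 + e^{-x})\) literally overflows for very negative x, because \(e^{-x}\) becomes infinite. The minimal NumPy sketch below uses the identity \(\log(1 + e^{-x}) = \mathrm{logaddexp}(0, -x)\), which NumPy evaluates without overflow. The helper names are hypothetical, and this is not how objax or JAX implement the function.

```python
import numpy as np

def log_sigmoid_naive(x):
    # Direct transcription of -log(1 + e^{-x}); exp(-x) overflows for x << 0.
    return -np.log1p(np.exp(-x))

def log_sigmoid_stable(x):
    # log(1 + e^{-x}) == logaddexp(0, -x), computed without overflow.
    return -np.logaddexp(0.0, -x)

x = np.array([-1000.0, -1.0, 0.0, 1.0, 1000.0])
with np.errstate(over="ignore"):
    print(log_sigmoid_naive(x))   # first entry is -inf because exp(1000) overflows
print(log_sigmoid_stable(x))      # first entry is -1000.0, the correct value
```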
- objax.functional.log_softmax(x, axis=-1, where=None, initial=None)
Log-Softmax function.
Computes the logarithm of the softmax function, which rescales elements to the range \([-\infty, 0)\).
\[\mathrm{log\_softmax}(x)_i = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)} \right)\]
- Parameters:
x (ArrayLike) – input array
axis (int | tuple[int, ...] | None) – the axis or axes along which the log_softmax should be computed. Either an integer or a tuple of integers.
where (ArrayLike | None) – Elements to include in the log_softmax.
initial (ArrayLike | None) – The minimum value used to shift the input array. Must be present when where is not None.
- Returns:
An array.
- Return type:
Array
- objax.functional.logsumexp(a: Array | ndarray | bool_ | number | bool | int | float | complex, axis: int | Sequence[int] | None = None, b: ArrayLike | None = None, keepdims: bool = False, return_sign: Literal[False] = False) -> Array
- objax.functional.logsumexp(a: Array | ndarray | bool_ | number | bool | int | float | complex, axis: int | Sequence[int] | None = None, b: ArrayLike | None = None, keepdims: bool = False, *, return_sign: Literal[True]) -> tuple[Array, Array]
- objax.functional.logsumexp(a: Array | ndarray | bool_ | number | bool | int | float | complex, axis: int | Sequence[int] | None = None, b: ArrayLike | None = None, keepdims: bool = False, return_sign: bool = False) -> Array | tuple[Array, Array]
Compute the log of the sum of exponentials of input elements.
LAX-backend implementation of scipy.special.logsumexp().
Original docstring below.
- Parameters:
a (array_like) – Input array.
axis (None or int or tuple of ints, optional) – Axis or axes over which the sum is taken. By default axis is None, and all elements are summed.
b (array-like, optional) – Scaling factor for exp(a); must be of the same shape as a or broadcastable to a. These values may be negative in order to implement subtraction.
keepdims (bool, optional) – If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the original array.
return_sign (bool, optional) – If this is set to True, the result will be a pair containing sign information; if False, results that are negative will be returned as NaN. Default is False (no sign information).
- Returns:
res (ndarray) – The result, np.log(np.sum(np.exp(a))), calculated in a numerically more stable way. If b is given then np.log(np.sum(b*np.exp(a))) is returned.
sgn (ndarray) – If return_sign is True, this will be an array of floating-point numbers matching res and containing +1, 0, or -1 depending on the sign of the result. If False, only one result is returned.
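"Numerically more stable" usually refers to the max-shift trick. The function subtracts the maximum before exponentiating and adds it back after the logarithm. The NumPy sketch below illustrates that trick together with the b, keepdims and return_sign semantics described above. It is a sketch under those assumptions, not the JAX or SciPy source, and `logsumexp_ref` is a hypothetical name.

```python
import numpy as np

def logsumexp_ref(a, axis=None, b=None, keepdims=False, return_sign=False):
    """Sketch of log(sum(b * exp(a))) using the max-shift trick."""
    a = np.asarray(a, dtype=np.float64)
    b = np.ones_like(a) if b is None else np.broadcast_to(np.asarray(b, dtype=np.float64), a.shape)
    # Shifting by the maximum makes the largest exponent exp(0) == 1, so exp cannot overflow.
    a_max = np.max(a, axis=axis, keepdims=True)
    a_max = np.where(np.isfinite(a_max), a_max, 0.0)  # guard against non-finite maxima (e.g. all -inf)
    s = np.sum(b * np.exp(a - a_max), axis=axis, keepdims=True)
    sign = np.sign(s)
    with np.errstate(divide="ignore"):
        res = np.log(np.abs(s)) + a_max
    if not keepdims:
        res = np.squeeze(res, axis=axis)
        sign = np.squeeze(sign, axis=axis)
    if return_sign:
        return res, sign
    # Without sign information, a negative sum is reported as NaN.
    return np.where(sign < 0, np.nan, res)

big = np.array([1000.0, 1000.0])
print(logsumexp_ref(big))        # 1000 + log(2); exp(1000) alone would overflow
res, sgn = logsumexp_ref([0.0, 0.0], b=[1.0, -2.0], return_sign=True)
print(res, sgn)                  # 0.0 -1.0, since 1*e^0 - 2*e^0 = -1
print(logsumexp_ref(np.zeros((2, 3)), axis=1).shape)  # (2,)
```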
- objax.functional.relu(x)
Rectified linear unit activation function.
- Parameters:
x (Array) – input tensor.
- Returns:
tensor with the element-wise output relu(x) = max(x, 0).
- Return type:
Array
- objax.functional.selu(x)
Scaled exponential linear unit activation.
Computes the element-wise function:
\[\begin{split}\mathrm{selu}(x) = \lambda \begin{cases} x, & x > 0\\ \alpha e^x - \alpha, & x \le 0 \end{cases}\end{split}\]
where \(\lambda = 1.0507009873554804934193349852946\) and \(\alpha = 1.6732632423543772848170429916717\).
For more information, see Self-Normalizing Neural Networks.
- Parameters:
x (Array | ndarray | bool_ | number | bool | int | float | complex) – input array
- Returns:
An array.
- Return type:
Array
- objax.functional.sigmoid(x)
Sigmoid activation function.
Computes the element-wise function:
\[\mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}\]
- Parameters:
x (Array | ndarray | bool_ | number | bool | int | float | complex) – input array
- Returns:
An array.
- Return type:
Array
- objax.functional.softmax(x, axis=-1, where=None, initial=None)
Softmax function.
Computes the function which rescales elements to the range \([0, 1]\) such that the elements along axis sum to \(1\).
\[\mathrm{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}\]
- Parameters:
x (ArrayLike) – input array
axis (int | tuple[int, ...] | None) – the axis or axes along which the softmax should be computed. The softmax output summed across these dimensions should sum to \(1\). Either an integer or a tuple of integers.
where (ArrayLike | None) – Elements to include in the softmax.
initial (ArrayLike | None) – The minimum value used to shift the input array. Must be present when where is not None.
- Returns:
An array.
- Return type:
Array
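The softmax and log-softmax formulas above are usually evaluated after subtracting the per-axis maximum. That shift cancels mathematically but keeps exp() from overflowing. The NumPy sketch below shows the idea and ignores the where and initial arguments. It is an illustration only, and the `_ref` helpers are hypothetical names.

```python
import numpy as np

def log_softmax_ref(x, axis=-1):
    """NumPy sketch of log_softmax along `axis` (ignores `where` and `initial`)."""
    x = np.asarray(x, dtype=np.float64)
    # Subtracting the max along `axis` cancels out mathematically but keeps exp() from overflowing.
    shifted = x - np.max(x, axis=axis, keepdims=True)
    return shifted - np.log(np.sum(np.exp(shifted), axis=axis, keepdims=True))

def softmax_ref(x, axis=-1):
    """softmax is exp(log_softmax)."""
    return np.exp(log_softmax_ref(x, axis=axis))

logits = np.array([[1.0, 2.0, 3.0],
                   [1000.0, 1000.0, 1000.0]])
probs = softmax_ref(logits)
print(probs.sum(axis=-1))   # [1. 1.]
print(probs[1])             # [0.33333333 0.33333333 0.33333333]; no overflow
```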
Pooling
| average_pool_2d | Applies average pooling using a square 2D filter. |
| batch_to_space2d | Transfer batch dimension N into spatial dimensions (H, W). |
| channel_to_space2d | Transfer channel dimension C into spatial dimensions (H, W). |
| max_pool_2d | Applies max pooling using a square 2D filter. |
| space_to_batch2d | Transfer spatial dimensions (H, W) into batch dimension N. |
| space_to_channel2d | Transfer spatial dimensions (H, W) into channel dimension C. |
- objax.functional.average_pool_2d(x, size=2, strides=None, padding=ConvPadding.VALID)
Applies average pooling using a square 2D filter.
- Parameters:
x (Array) – input tensor of shape (N, C, H, W).
size (Tuple[int, int] | int) – size of pooling filter.
strides (Tuple[int, int] | int | None) – stride step; when None (the default), the stride equals size.
padding (ConvPadding | str | Sequence[Tuple[int, int]] | Tuple[int, int] | int) – padding of the input tensor, either ConvPadding.SAME, ConvPadding.VALID or numerical values.
- Returns:
output tensor of shape (N, C, H', W'), where the pooled height H' and width W' depend on size, strides and padding.
- Return type:
Array
For a definition of pooling, including examples, see Pooling Layer.
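To make the shape arithmetic concrete, here is a loop-based NumPy sketch of pooling on an (N, C, H, W) array with VALID padding. With VALID padding, each output dimension is (H - size) // stride + 1. Passing np.max instead of np.mean gives the max pooling described for max_pool_2d. This is an illustration, not objax's implementation, and `pool_2d_valid` is a hypothetical helper. SAME and numeric padding are not modeled.

```python
import numpy as np

def pool_2d_valid(x, size=2, strides=None, reduce=np.mean):
    """Sketch of 2D pooling over an (N, C, H, W) array with VALID padding.

    Pass reduce=np.mean for average pooling or reduce=np.max for max pooling.
    """
    kh, kw = (size, size) if isinstance(size, int) else size
    if strides is None:          # default: stride equals the filter size
        strides = (kh, kw)
    sh, sw = (strides, strides) if isinstance(strides, int) else strides
    n, c, h, w = x.shape
    # VALID padding keeps only windows that fit entirely inside the input.
    out_h = (h - kh) // sh + 1
    out_w = (w - kw) // sw + 1
    out = np.empty((n, c, out_h, out_w), dtype=np.float64)
    for i in range(out_h):
        for j in range(out_w):
            window = x[:, :, i * sh:i * sh + kh, j * sw:j * sw + kw]
            out[:, :, i, j] = reduce(window, axis=(2, 3))
    return out

x = np.arange(16, dtype=np.float64).reshape(1, 1, 4, 4)
print(pool_2d_valid(x)[0, 0])                     # values [[2.5, 4.5], [10.5, 12.5]]
print(pool_2d_valid(x, reduce=np.max)[0, 0])      # values [[5, 7], [13, 15]]
print(pool_2d_valid(x, size=3, strides=1).shape)  # (1, 1, 2, 2)
```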
- objax.functional.batch_to_space2d(x, size=2)
Transfer batch dimension N into spatial dimensions (H, W).
- Parameters:
x (Array) – input tensor of shape (N, C, H, W).
size (Tuple[int, int] | int) – size of spatial area.
- Returns:
output tensor of shape (N // (size[0] * size[1]), C, H * size[0], W * size[1]).
- Return type:
Array
- objax.functional.channel_to_space2d(x, size=2)
Transfer channel dimension C into spatial dimensions (H, W).
- Parameters:
x (Array) – input tensor of shape (N, C, H, W).
size (Tuple[int, int] | int) – size of spatial area.
- Returns:
output tensor of shape (N, C // (size[0] * size[1]), H * size[0], W * size[1]).
- Return type:
Array
- objax.functional.max_pool_2d(x, size=2, strides=None, padding=ConvPadding.VALID)
Applies max pooling using a square 2D filter.
- Parameters:
x (Array) – input tensor of shape (N, C, H, W).
size (Tuple[int, int] | int) – size of pooling filter.
strides (Tuple[int, int] | int | None) – stride step; when None (the default), the stride equals size.
padding (ConvPadding | str | Sequence[Tuple[int, int]] | Tuple[int, int] | int) – padding of the input tensor, either ConvPadding.SAME, ConvPadding.VALID or numerical values.
- Returns:
output tensor of shape (N, C, H', W'), where the pooled height H' and width W' depend on size, strides and padding.
- Return type:
Array
For a definition of pooling, including examples, see Pooling Layer.
- objax.functional.space_to_batch2d(x, size=2)
Transfer spatial dimensions (H, W) into batch dimension N.
- Parameters:
x (Array) – input tensor of shape (N, C, H, W).
size (Tuple[int, int] | int) – size of spatial area.
- Returns:
output tensor of shape (N * size[0] * size[1], C, H // size[0], W // size[1]).
- Return type:
Array
- objax.functional.space_to_channel2d(x, size=2)
Transfer spatial dimensions (H, W) into channel dimension C.
- Parameters:
x (Array) – input tensor of shape (N, C, H, W).
size (Tuple[int, int] | int) – size of spatial area.
- Returns:
output tensor of shape (N, C * size[0] * size[1], H // size[0], W // size[1]).
- Return type:
Array
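The space/channel transfers above amount to a reshape, a transpose and another reshape. The NumPy sketch below reproduces the documented shapes for space_to_channel2d and its inverse and checks the round trip. The order in which each spatial block is laid out along the channel axis is an assumption of this sketch and may differ from objax. The `_ref` names are hypothetical. space_to_batch2d is analogous, with the block moved to the batch axis instead.

```python
import numpy as np

def space_to_channel2d_ref(x, size=2):
    """(N, C, H, W) -> (N, C * sh * sw, H // sh, W // sw); block ordering is assumed."""
    sh, sw = (size, size) if isinstance(size, int) else size
    n, c, h, w = x.shape
    y = x.reshape(n, c, h // sh, sh, w // sw, sw)
    y = y.transpose(0, 1, 3, 5, 2, 4)          # (n, c, sh, sw, h // sh, w // sw)
    return y.reshape(n, c * sh * sw, h // sh, w // sw)

def channel_to_space2d_ref(x, size=2):
    """Inverse of space_to_channel2d_ref: (N, C, H, W) -> (N, C // (sh * sw), H * sh, W * sw)."""
    sh, sw = (size, size) if isinstance(size, int) else size
    n, c, h, w = x.shape
    y = x.reshape(n, c // (sh * sw), sh, sw, h, w)
    y = y.transpose(0, 1, 4, 2, 5, 3)          # (n, c // (sh * sw), h, sh, w, sw)
    return y.reshape(n, c // (sh * sw), h * sh, w * sw)

x = np.arange(2 * 3 * 4 * 6).reshape(2, 3, 4, 6)
y = space_to_channel2d_ref(x, size=(2, 3))
print(y.shape)                                                   # (2, 18, 2, 2)
print(np.array_equal(channel_to_space2d_ref(y, size=(2, 3)), x))  # True
```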
Misc
| dynamic_slice | Wraps XLA's DynamicSlice operator. |
| flatten | Flattens input tensor to a 2D tensor. |
| interpolate | Function to interpolate JaxArrays by size or scaling factor. |
| one_hot | One-hot encodes the given indices. |
| pad | Pad an array. |
| scan | Scan a function over leading array axes while carrying along state. |
| stop_gradient | Stops gradient computation. |
| top_k | Returns top k values and their indices along the last axis of operand. |
| rsqrt | Elementwise reciprocal square root: \(1 \over \sqrt{x}\). |
| upsample_2d | Function to upscale 2D images. |
| upscale_nn | Nearest neighbor upscale for image batches of shape (N, C, H, W). |
- objax.functional.dynamic_slice(operand, start_indices, slice_sizes)
Wraps XLA’s DynamicSlice operator.
- Parameters:
operand (Array | np.ndarray) – an array to slice.
start_indices (Array | np.ndarray | Sequence[ArrayLike]) – a list of scalar indices, one per dimension. These values may be dynamic.
slice_sizes (Shape) – the size of the slice. Must be a sequence of non-negative integers with length equal to ndim(operand). Inside a JIT compiled function, only static values are supported (all JAX arrays inside JIT must have statically known size).
- Returns:
An array containing the slice.
- Return type:
Array
Examples
Here is a simple two-dimensional dynamic slice:
>>> x = jnp.arange(12).reshape(3, 4)
>>> x
Array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]], dtype=int32)
>>> dynamic_slice(x, (1, 1), (2, 3))
Array([[ 5,  6,  7],
       [ 9, 10, 11]], dtype=int32)
Note the potentially surprising behavior for the case where the requested slice overruns the bounds of the array; in this case the start index is adjusted to return a slice of the requested size:
>>> dynamic_slice(x, (1, 1), (2, 4))
Array([[ 4,  5,  6,  7],
       [ 8,  9, 10, 11]], dtype=int32)
See also
jax.numpy.ndarray.at
jax.lax.slice()
jax.lax.dynamic_slice_in_dim()
jax.lax.dynamic_index_in_dim()
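The "surprising" overrun behavior in the examples above comes from clamping each start index into [0, dim - size], so the slice always has the requested size. The NumPy sketch below reproduces the two doctest results. It only models non-negative start indices, because JAX may treat negative starts differently. `dynamic_slice_ref` is a hypothetical name, not part of objax.

```python
import numpy as np

def dynamic_slice_ref(operand, start_indices, slice_sizes):
    """NumPy sketch of dynamic_slice's start-index clamping (not the XLA kernel)."""
    operand = np.asarray(operand)
    if len(start_indices) != operand.ndim or len(slice_sizes) != operand.ndim:
        raise ValueError("need one start index and one slice size per dimension")
    slices = []
    for start, size, dim in zip(start_indices, slice_sizes, operand.shape):
        if start < 0:
            raise ValueError("this sketch only models non-negative start indices")
        # Clamp so the whole slice stays in bounds: start <= dim - size.
        start = min(int(start), dim - size)
        slices.append(slice(start, start + size))
    return operand[tuple(slices)]

x = np.arange(12).reshape(3, 4)
print(dynamic_slice_ref(x, (1, 1), (2, 3)))  # values [[5, 6, 7], [9, 10, 11]]
print(dynamic_slice_ref(x, (1, 1), (2, 4)))  # start clamped to (1, 0): [[4, 5, 6, 7], [8, 9, 10, 11]]
```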
- objax.functional.flatten(x)
Flattens input tensor to a 2D tensor.
- Parameters:
x (Array) – input tensor with dimensions (n_1, n_2, …, n_k)
- Returns:
The input tensor reshaped to two dimensions (n_1, n_prod), where n_prod is equal to the product of n_2 to n_k.
- Return type:
Array
- objax.functional.interpolate(input, size=None, scale_factor=None, mode=Interpolate.BILINEAR)
Function to interpolate JaxArrays by size or scaling factor.
- Parameters:
input (Array) – input tensor.
size (int | Tuple[int, ...] | None) – int or tuple for output size.
scale_factor (int | Tuple[int, ...] | None) – int or tuple scaling factor for each dimension.
mode (Interpolate | str) – str or Interpolate interpolation method, e.g. ‘bilinear’ or ‘nearest’.
- Returns:
output JaxArray after interpolation.
- objax.functional.one_hot(x, num_classes, *, dtype=<class 'jax.numpy.float64'>, axis=-1)
One-hot encodes the given indices.
Each index in the input x is encoded as a vector of zeros of length num_classes with the element at index set to one:
>>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
Array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)
Indices outside the range [0, num_classes) will be encoded as zeros:
>>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)
- Parameters:
x (Any) – A tensor of indices.
num_classes (int) – Number of classes in the one-hot dimension.
dtype (Any) – optional, a float dtype for the returned values (default jnp.float_).
axis (int | AxisName) – the axis or axes along which the function should be computed.
- Return type:
Array
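One-hot encoding can be expressed as a broadcast equality test between the indices and the class ids 0..num_classes-1. An out-of-range index then matches no class, which is why it becomes an all-zero row in the examples above. The following is a minimal NumPy sketch of that idea, not objax's code. `one_hot_ref` is a hypothetical name.

```python
import numpy as np

def one_hot_ref(x, num_classes, dtype=np.float32, axis=-1):
    """NumPy sketch of one-hot encoding; out-of-range indices give all-zero vectors."""
    x = np.asarray(x)
    classes = np.arange(num_classes)
    # Broadcasting compares every index against every class id.
    out = (x[..., None] == classes).astype(dtype)
    # The new class dimension is created last, then moved to `axis`.
    return np.moveaxis(out, -1, axis)

print(one_hot_ref([0, 1, 2], 3))   # 3x3 identity matrix
print(one_hot_ref([-1, 3], 3))     # two all-zero rows
```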
- objax.functional.pad(array, pad_width, mode='constant', *, stat_length=None, constant_values=0, end_values=None, reflect_type=None)
Pad an array.
LAX-backend implementation of numpy.pad().
Unlike numpy, JAX “function” mode’s argument (which is another function) should return the modified array. This is because JAX arrays are immutable. (In numpy, “function” mode’s argument should modify a rank 1 array in-place.)
Original docstring below.
- Parameters:
array (array_like of rank N) – The array to pad.
pad_width ({sequence, array_like, int}) – Number of values padded to the edges of each axis.
((before_1, after_1), ... (before_N, after_N)) unique pad widths for each axis.
(before, after) or ((before, after),) yields same before and after pad for each axis.
(pad,) or int is a shortcut for before = after = pad width for all axes.
mode (str or function, optional) –
One of the following string values or a user supplied function.
- ‘constant’ (default)
Pads with a constant value.
- ‘edge’
Pads with the edge values of array.
- ‘linear_ramp’
Pads with the linear ramp between end_value and the array edge value.
- ‘maximum’
Pads with the maximum value of all or part of the vector along each axis.
- ‘mean’
Pads with the mean value of all or part of the vector along each axis.
- ‘median’
Pads with the median value of all or part of the vector along each axis.
- ‘minimum’
Pads with the minimum value of all or part of the vector along each axis.
- ‘reflect’
Pads with the reflection of the vector mirrored on the first and last values of the vector along each axis.
- ‘symmetric’
Pads with the reflection of the vector mirrored along the edge of the array.
- ‘wrap’
Pads with the wrap of the vector along the axis. The first values are used to pad the end and the end values are used to pad the beginning.
- ‘empty’
Pads with undefined values.
stat_length (sequence or int, optional) –
Used in ‘maximum’, ‘mean’, ‘median’, and ‘minimum’. Number of values at edge of each axis used to calculate the statistic value.
((before_1, after_1), ... (before_N, after_N)) unique statistic lengths for each axis.
(before, after) or ((before, after),) yields same before and after statistic lengths for each axis.
(stat_length,) or int is a shortcut for before = after = statistic length for all axes.
Default is None, to use the entire axis.
constant_values (sequence or scalar, optional) –
Used in ‘constant’. The values to set the padded values for each axis.
((before_1, after_1), ... (before_N, after_N)) unique pad constants for each axis.
(before, after) or ((before, after),) yields same before and after constants for each axis.
(constant,) or constant is a shortcut for before = after = constant for all axes.
Default is 0.
end_values (sequence or scalar, optional) –
Used in ‘linear_ramp’. The values used for the ending value of the linear_ramp and that will form the edge of the padded array.
((before_1, after_1), ... (before_N, after_N)) unique end values for each axis.
(before, after) or ((before, after),) yields same before and after end values for each axis.
(constant,) or constant is a shortcut for before = after = constant for all axes.
Default is 0.
reflect_type ({'even', 'odd'}, optional) – Used in ‘reflect’, and ‘symmetric’. The ‘even’ style is the default with an unaltered reflection around the edge value. For the ‘odd’ style, the extended part of the array is created by subtracting the reflected values from two times the edge value.
- Returns:
pad – Padded array of rank equal to array with shape increased according to pad_width.
- Return type:
ndarray
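Because this function mirrors numpy.pad(), NumPy itself can be used to preview how the string modes differ. This assumes JAX's results agree with NumPy's for these modes, which is the stated intent of the port. The snippet contrasts ‘reflect’ (the mirror excludes the edge value) with ‘symmetric’ (the mirror repeats it). It also shows reflect_type='odd' and per-axis (before, after) widths.

```python
import numpy as np

a = np.array([1, 2, 3])
modes = ("constant", "edge", "reflect", "symmetric", "wrap")
# Pad two values on each side of a 1-D array with each string mode.
padded = {mode: np.pad(a, 2, mode=mode).tolist() for mode in modes}
for mode, values in padded.items():
    print(mode, values)
# constant [0, 0, 1, 2, 3, 0, 0]
# edge [1, 1, 1, 2, 3, 3, 3]
# reflect [3, 2, 1, 2, 3, 2, 1]
# symmetric [2, 1, 1, 2, 3, 3, 2]
# wrap [2, 3, 1, 2, 3, 1, 2]

# reflect_type="odd" subtracts the reflection from twice the edge value,
# which continues the linear trend instead of mirroring it.
odd = np.pad(a, 2, mode="reflect", reflect_type="odd").tolist()
print(odd)  # [-1, 0, 1, 2, 3, 4, 5]

# Per-axis (before, after) widths with a custom constant fill.
b = np.pad(np.ones((2, 2)), ((1, 0), (0, 2)), constant_values=9)
print(b.shape)  # (3, 4)
```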
- objax.functional.scan(f, init, xs, length=None, reverse=False, unroll=1)
Scan a function over leading array axes while carrying along state.
The Haskell-like type signature in brief is
scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])
where for any array type specifier t, [t] represents the type with an additional leading axis, and if t is a pytree (container) type with array leaves then [t] represents the type with the same pytree structure and corresponding leaves each with an additional leading axis.
When the type of xs (denoted a above) is an array type or None, and the type of ys (denoted b above) is an array type, the semantics of scan() are given roughly by this Python implementation:
    def scan(f, init, xs, length=None):
      if xs is None:
        xs = [None] * length
      carry = init
      ys = []
      for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
      return carry, np.stack(ys)
Unlike that Python version, both xs and ys may be arbitrary pytree values, and so multiple arrays can be scanned over at once and produce multiple output arrays. None is actually a special case of this, as it represents an empty pytree.
Also unlike that Python version, scan() is a JAX primitive and is lowered to a single WhileOp. That makes it useful for reducing compilation times for JIT-compiled functions, since native Python loop constructs in a jit() function are unrolled, leading to large XLA computations.
Finally, the loop-carried value carry must hold a fixed shape and dtype across all iterations (and not just be consistent up to NumPy rank/shape broadcasting and dtype promotion rules, for example). In other words, the type c in the type signature above represents an array with a fixed shape and dtype (or a nested tuple/list/dict container data structure with a fixed structure and arrays with fixed shape and dtype at the leaves).
Note
scan() compiles f, so while it can be combined with jit(), it’s usually unnecessary.
- Parameters:
f (Callable[[Carry, X], tuple[Carry, Y]]) – a Python function to be scanned of type c -> a -> (c, b), meaning that f accepts two arguments where the first is a value of the loop carry and the second is a slice of xs along its leading axis, and that f returns a pair where the first element represents a new value for the loop carry and the second represents a slice of the output.
init (Carry) – an initial loop carry value of type c, which can be a scalar, array, or any pytree (nested Python tuple/list/dict) thereof, representing the initial loop carry value. This value must have the same structure as the first element of the pair returned by f.
xs (X) – the value of type [a] over which to scan along the leading axis, where [a] can be an array or any pytree (nested Python tuple/list/dict) thereof with consistent leading axis sizes.
length (int | None) – optional integer specifying the number of loop iterations, which must agree with the sizes of leading axes of the arrays in xs (but can be used to perform scans where no input xs are needed).
reverse (bool) – optional boolean specifying whether to run the scan iteration forward (the default) or in reverse, equivalent to reversing the leading axes of the arrays in both xs and in ys.
unroll (int | bool) – optional positive int or bool specifying, in the underlying operation of the scan primitive, how many scan iterations to unroll within a single iteration of a loop. If an integer is provided, it determines how many unrolled loop iterations to run within a single rolled iteration of the loop. If a boolean is provided, it determines whether the loop is completely unrolled (i.e. unroll=True) or left completely rolled (i.e. unroll=False).
- Returns:
A pair of type (c, [b]) where the first element represents the final loop carry value and the second element represents the stacked outputs of the second output of f when scanned over the leading axis of the inputs.
- Return type:
tuple[Carry, Y]
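The reference implementation quoted above omits reverse. The runnable pure-Python sketch below adds it and uses scan to compute a running sum. With reverse=True, the iteration runs from the last slice to the first, but ys[i] still lines up with xs[i], matching "reversing the leading axes of the arrays in both xs and in ys". It does not handle pytrees or compile anything, and `scan_ref` and `running_sum` are hypothetical names.

```python
import numpy as np

def scan_ref(f, init, xs, length=None, reverse=False):
    """Pure-Python sketch of scan semantics for array inputs (no pytrees, no compilation)."""
    xs = [None] * length if xs is None else list(xs)
    if reverse:
        xs = xs[::-1]                 # iterate from the last slice to the first
    carry, ys = init, []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    if reverse:
        ys = ys[::-1]                 # ys[i] still lines up with xs[i]
    return carry, np.stack(ys)

def running_sum(total, x):
    total = total + x
    return total, total               # (new carry, per-step output)

values = np.array([1.0, 2.0, 3.0, 4.0])
print(scan_ref(running_sum, 0.0, values))                # carry 10.0, outputs [1, 3, 6, 10]
print(scan_ref(running_sum, 0.0, values, reverse=True))  # carry 10.0, outputs [10, 9, 7, 4]
# With xs=None, `length` sets the number of steps.
print(scan_ref(lambda c, _: (c + 1, c), 0, None, length=3))  # carry 3, outputs [0, 1, 2]
```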
- objax.functional.stop_gradient(x)
Stops gradient computation.
Operationally stop_gradient is the identity function, that is, it returns argument x unchanged. However, stop_gradient prevents the flow of gradients during forward or reverse-mode automatic differentiation. If there are multiple nested gradient computations, stop_gradient stops gradients for all of them.
For example:
>>> jax.grad(lambda x: x**2)(3.)
Array(6., dtype=float32, weak_type=True)
>>> jax.grad(lambda x: jax.lax.stop_gradient(x)**2)(3.)
Array(0., dtype=float32, weak_type=True)
>>> jax.grad(jax.grad(lambda x: x**2))(3.)
Array(2., dtype=float32, weak_type=True)
>>> jax.grad(jax.grad(lambda x: jax.lax.stop_gradient(x)**2))(3.)
Array(0., dtype=float32, weak_type=True)
- Parameters:
x (T) –
- Return type:
T
- objax.functional.top_k(operand, k)
Returns top k values and their indices along the last axis of operand.
- Parameters:
operand (Array | ndarray | bool_ | number | bool | int | float | complex) – N-dimensional array of non-complex type.
k (int) – integer specifying the number of top entries.
- Returns:
values: array containing the top k values along the last axis.
indices: array containing the indices corresponding to values.
- Return type:
tuple[Array, Array]
See also:
- jax.lax.approx_max_k()
- jax.lax.approx_min_k()
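For intuition about the (values, indices) pair, here is a NumPy sketch that sorts each row of the last axis in descending order and keeps the first k entries. It assumes signed-integer or float input, since it negates the array, and breaks ties by the lower index. Whether jax.lax.top_k breaks ties the same way is not claimed here. `top_k_ref` is a hypothetical name.

```python
import numpy as np

def top_k_ref(operand, k):
    """NumPy sketch: k largest values along the last axis, largest first, with their indices."""
    operand = np.asarray(operand)
    # A stable argsort of the negated array gives descending order with
    # ties broken by lower index (assumes signed integer or float input).
    indices = np.argsort(-operand, axis=-1, kind="stable")[..., :k]
    values = np.take_along_axis(operand, indices, axis=-1)
    return values, indices

x = np.array([[1, 5, 3, 5],
              [9, 0, 2, 8]])
values, indices = top_k_ref(x, 2)
print(values)    # values [[5, 5], [9, 8]]
print(indices)   # indices [[1, 3], [0, 3]]
```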
- objax.functional.rsqrt(x)
Elementwise reciprocal square root: \(1 \over \sqrt{x}\).
- Parameters:
x (Array | ndarray | bool_ | number | bool | int | float | complex) –
- Return type:
Array
- objax.functional.upsample_2d(x, scale, method=Interpolate.BILINEAR)
Function to upscale 2D images.
- Parameters:
x (Array) – input tensor.
scale (Tuple[int, int] | int) – int or tuple scaling factor.
method (Interpolate | str) – str or Interpolate interpolation method, e.g. ‘bilinear’ or ‘nearest’.
- Returns:
upscaled 2d image tensor
- Return type:
Array
objax.functional.divergence
| kl | Calculates the Kullback-Leibler divergence between arrays p and q. |
- objax.functional.divergence.kl(p, q, eps=7.62939453125e-06)
Calculates the Kullback-Leibler divergence between arrays p and q.
\[kl(p,q) = p \cdot \log{\frac{p + \epsilon}{q + \epsilon}}\]
The \(\epsilon\) term is added so that neither \(p + \epsilon\) nor \(q + \epsilon\) is zero, which keeps the logarithm finite when p or q contains zeros.
- Parameters:
p (Array) –
q (Array) –
eps (float) –
- Return type:
Array
objax.functional.loss
| cross_entropy_logits | Computes the softmax cross-entropy loss on n-dimensional data. |
| cross_entropy_logits_sparse | Computes the softmax cross-entropy loss. |
| l2 | Computes the L2 loss. |
| mean_absolute_error | Computes the mean absolute error between x and y. |
| mean_squared_error | Computes the mean squared error between x and y. |
| mean_squared_log_error | Computes the mean squared logarithmic error between y_true and y_pred. |
| sigmoid_cross_entropy_logits | Computes the sigmoid cross-entropy loss. |
- objax.functional.loss.cross_entropy_logits(logits, labels)
Computes the softmax cross-entropy loss on n-dimensional data.
- Parameters:
logits (Array) – (batch, …, #class) tensor of logits.
labels (Array) – (batch, …, #class) tensor of label probabilities (i.e. labels.sum(axis=-1) must be 1)
- Returns:
(batch, …) tensor of the cross-entropies for each entry.
- Return type:
Array
Calculates the cross entropy loss, defined as follows:
\[\begin{split}\begin{eqnarray} l(y,\hat{y}) & = & - \sum_{j=1}^{q} y_j \log \frac{e^{o_j}}{\sum_{k=1}^{q} e^{o_k}} \nonumber \\ & = & \log \sum_{k=1}^{q} e^{o_k} - \sum_{j=1}^{q} y_j o_j \nonumber \end{eqnarray}\end{split}\]
where \(o_k\) are the logits and \(y_k\) are the labels.
- objax.functional.loss.cross_entropy_logits_sparse(logits, labels)
Computes the softmax cross-entropy loss.
- Parameters:
logits (Array) – (batch, …, #class) tensor of logits.
labels (Array | int) – (batch, …) integer tensor of label indexes in {0, …,#nclass-1} or just a single integer.
- Returns:
(batch, …) tensor of the cross-entropies for each entry.
- Return type:
Array
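The two lines of the cross-entropy formula, and the dense and sparse label formats, can be checked against each other numerically. The NumPy sketch below implements the second line (a stable logsumexp minus the label-weighted logits). It converts integer labels to one-hot vectors for the sparse case and compares the result with the first line, minus the log softmax probability of the true class. It is an illustration, not objax's code, and the `_ref` names are hypothetical.

```python
import numpy as np

def cross_entropy_logits_ref(logits, labels):
    """Per-entry loss from the second line: log sum_k e^{o_k} - sum_j y_j o_j."""
    logits = np.asarray(logits, dtype=np.float64)
    m = np.max(logits, axis=-1, keepdims=True)   # max-shift for a stable logsumexp
    lse = np.log(np.sum(np.exp(logits - m), axis=-1)) + m[..., 0]
    return lse - np.sum(labels * logits, axis=-1)

def cross_entropy_logits_sparse_ref(logits, labels):
    """Same loss with integer class indices, one-hot encoded via an identity matrix."""
    num_classes = np.shape(logits)[-1]
    return cross_entropy_logits_ref(logits, np.eye(num_classes)[labels])

logits = np.array([[2.0, 1.0, 0.1],
                   [0.5, 0.5, 0.5]])
labels = np.array([0, 2])
loss = cross_entropy_logits_sparse_ref(logits, labels)

# First line of the formula: minus the log softmax probability of the true class.
probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
direct = -np.log(probs[np.arange(2), labels])
print(np.allclose(loss, direct))          # True
print(np.isclose(loss[1], np.log(3.0)))   # True: uniform logits over 3 classes
```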
- objax.functional.loss.l2(x)
Computes the L2 loss.
- Parameters:
x (Array) – n-dimensional tensor of floats.
- Returns:
scalar tensor containing the l2 loss of x.
- Return type:
Array
Calculates the l2 loss, as:
- objax.functional.loss.mean_absolute_error(x, y, keep_axis=(0,))
Computes the mean absolute error between x and y.
- Parameters:
x (Array) – a tensor of shape (d0, .. dN-1).
y (Array) – a tensor of shape (d0, .. dN-1).
keep_axis (Iterable[int] | None) – a sequence of the dimensions to keep, use None to return a scalar value.
- Returns:
tensor of shape (d_i, …, for i in keep_axis) containing the mean absolute error.
- Return type:
Array
- objax.functional.loss.mean_squared_error(x, y, keep_axis=(0,))
Computes the mean squared error between x and y.
- Parameters:
x (Array) – a tensor of shape (d0, .. dN-1).
y (Array) – a tensor of shape (d0, .. dN-1).
keep_axis (Iterable[int] | None) – a sequence of the dimensions to keep, use None to return a scalar value.
- Returns:
tensor of shape (d_i, …, for i in keep_axis) containing the mean squared error.
- Return type:
Array
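The keep_axis parameter means "average over every axis except these". The NumPy sketch below spells that out for the squared error. Replacing ** 2 with np.abs would give the mean absolute error in the same way. Normalizing negative axes is an assumption of this sketch, and `mean_squared_error_ref` is a hypothetical name.

```python
import numpy as np

def mean_squared_error_ref(x, y, keep_axis=(0,)):
    """Average (x - y) ** 2 over every axis that is not listed in keep_axis."""
    sq = (np.asarray(x, dtype=np.float64) - np.asarray(y, dtype=np.float64)) ** 2
    if keep_axis is None:            # None: reduce everything to a scalar
        return sq.mean()
    keep = {axis % sq.ndim for axis in keep_axis}   # normalize negative axes
    reduce_axes = tuple(axis for axis in range(sq.ndim) if axis not in keep)
    return sq.mean(axis=reduce_axes)

x = np.arange(6.0).reshape(2, 3)
y = np.zeros_like(x)
print(mean_squared_error_ref(x, y))                  # per-row means: 5/3 and 50/3
print(mean_squared_error_ref(x, y, keep_axis=None))  # scalar: 55/6
```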
- objax.functional.loss.sigmoid_cross_entropy_logits(logits, labels)
Computes the sigmoid cross-entropy loss.
- Parameters:
logits (Array) – (batch, …, #class) tensor of logits.
labels (Array | int) – (batch, …, #class) tensor of label probabilities, each in [0, 1]; unlike the softmax case, they need not sum to 1 along the class axis.
- Returns:
(batch, …) tensor of the cross-entropies for each entry.
- Return type:
Array
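The sigmoid cross-entropy for a logit x and target z is \(-z\log\sigma(x) - (1-z)\log(1-\sigma(x))\). It can be rewritten as \(\max(x, 0) - xz + \log(1 + e^{-|x|})\), which never exponentiates a positive number. The NumPy sketch below uses that rewrite and checks it against the direct formula. It returns one loss per element, and any reduction over the class axis is left to the caller. This is an illustration with hypothetical helper names, not objax's code.

```python
import numpy as np

def sigmoid_cross_entropy_logits_ref(logits, labels):
    """Elementwise -z*log(sigmoid(x)) - (1 - z)*log(1 - sigmoid(x)), written stably."""
    x = np.asarray(logits, dtype=np.float64)
    z = np.asarray(labels, dtype=np.float64)
    # Algebraically equal to the expression above, but exp() only ever sees -|x| <= 0.
    return np.maximum(x, 0.0) - x * z + np.log1p(np.exp(-np.abs(x)))

def sigmoid_cross_entropy_naive(logits, labels):
    p = 1.0 / (1.0 + np.exp(-logits))
    return -labels * np.log(p) - (1.0 - labels) * np.log(1.0 - p)

x = np.array([-2.0, 0.0, 3.0])
z = np.array([0.0, 1.0, 0.25])    # independent per-class targets in [0, 1]
print(np.allclose(sigmoid_cross_entropy_logits_ref(x, z),
                  sigmoid_cross_entropy_naive(x, z)))  # True
print(sigmoid_cross_entropy_logits_ref(1000.0, 0.0))   # 1000.0, with no overflow
```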
objax.functional.parallel
| pmax | Compute a multi-device reduce max on x over the device axis axis_name. |
| pmean | Compute a multi-device reduce mean on x over the device axis axis_name. |
| pmin | Compute a multi-device reduce min on x over the device axis axis_name. |
| psum | Compute a multi-device reduce sum on x over the device axis axis_name. |
- objax.functional.parallel.pmax(x, axis_name='device')
Compute a multi-device reduce max on x over the device axis axis_name.
- Parameters:
x (Array) –
axis_name (str) –
- objax.functional.parallel.pmean(x, axis_name='device')
Compute a multi-device reduce mean on x over the device axis axis_name.
- Parameters:
x (Array) –
axis_name (str) –