objax.functional package¶
objax.functional¶
Due to the large number of APIs in this section, we organized it into the following sub-sections:
Activation¶
celu | Continuously-differentiable exponential linear unit activation.
elu | Exponential linear unit activation function.
leaky_relu | Leaky rectified linear unit activation function.
log_sigmoid | Log-sigmoid activation function.
log_softmax | Log-Softmax function.
logsumexp | Compute the log of the sum of exponentials of input elements.
relu | Rectified linear unit activation function.
selu | Scaled exponential linear unit activation.
sigmoid | Sigmoid activation function.
softmax | Softmax function.
softplus | Softplus activation function.
tanh | Elementwise hyperbolic tangent: \(\mathrm{tanh}(x)\).
objax.functional.celu(x, alpha=1.0)[source]¶
Continuously-differentiable exponential linear unit activation.
Computes the element-wise function:
\[\begin{split}\mathrm{celu}(x) = \begin{cases} x, & x > 0\\ \alpha \left(\exp(\frac{x}{\alpha}) - 1\right), & x \le 0 \end{cases}\end{split}\]
For more information, see Continuously Differentiable Exponential Linear Units.
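For example, a minimal usage sketch (the values in the comment are approximate):
>>> import jax.numpy as jn
>>> from objax.functional import celu
>>> celu(jn.array([-2., 0., 2.]))  # approximately [-0.865, 0., 2.]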
objax.functional.elu(x, alpha=1.0)[source]¶
Exponential linear unit activation function.
Computes the element-wise function:
\[\begin{split}\mathrm{elu}(x) = \begin{cases} x, & x > 0\\ \alpha \left(\exp(x) - 1\right), & x \le 0 \end{cases}\end{split}\]
objax.functional.leaky_relu(x, negative_slope=0.01)[source]¶
Leaky rectified linear unit activation function.
Computes the element-wise function:
\[\begin{split}\mathrm{leaky\_relu}(x) = \begin{cases} x, & x \ge 0\\ \alpha x, & x < 0 \end{cases}\end{split}\]
where \(\alpha\) = negative_slope.
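For example, a minimal usage sketch:
>>> import jax.numpy as jn
>>> from objax.functional import leaky_relu
>>> leaky_relu(jn.array([-1., 2.]), negative_slope=0.1)  # [-0.1, 2.]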
objax.functional.log_sigmoid(x)[source]¶
Log-sigmoid activation function.
Computes the element-wise function:
\[\mathrm{log\_sigmoid}(x) = \log(\mathrm{sigmoid}(x)) = -\log(1 + e^{-x})\]
objax.functional.log_softmax(x, axis=-1)[source]¶
Log-Softmax function.
Computes the logarithm of the softmax function, which rescales elements to the range \([-\infty, 0)\).
\[\mathrm{log\_softmax}(x)_i = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)} \right)\]
- Parameters
axis – the axis or axes along which the log_softmax should be computed. Either an integer or a tuple of integers.
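As a quick sanity check, log_softmax agrees with composing the two operations directly:
>>> import jax.numpy as jn
>>> from objax.functional import log_softmax, softmax
>>> x = jn.array([1., 2., 3.])
>>> jn.allclose(log_softmax(x), jn.log(softmax(x)))  # True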
objax.functional.logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False)[source]¶
Compute the log of the sum of exponentials of input elements.
LAX-backend implementation of logsumexp(). Original docstring below.
- Parameters
a (array_like) – Input array.
axis (None or int or tuple of ints, optional) – Axis or axes over which the sum is taken. By default axis is None, and all elements are summed.
b (array_like, optional) – Scaling factor for exp(a); must be of the same shape as a or broadcastable to it. These values may be negative in order to implement subtraction.
keepdims (bool, optional) – If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the original array.
return_sign (bool, optional) – If this is set to True, the result will be a pair containing sign information; if False, results that are negative will be returned as NaN. Default is False (no sign information).
- Returns
res (ndarray) – The result, np.log(np.sum(np.exp(a))), calculated in a numerically more stable way. If b is given then np.log(np.sum(b*np.exp(a))) is returned.
sgn (ndarray) – If return_sign is True, this will be an array of floating-point numbers matching res, equal to +1, 0, or -1 depending on the sign of the result. If False, only one result is returned.
See also
numpy.logaddexp, numpy.logaddexp2
Notes
NumPy has a logaddexp function which is very similar to logsumexp, but only handles two arguments. logaddexp.reduce is similar to this function, but may be less stable.
Examples
>>> from scipy.special import logsumexp
>>> a = np.arange(10)
>>> np.log(np.sum(np.exp(a)))
9.4586297444267107
>>> logsumexp(a)
9.4586297444267107
With weights
>>> a = np.arange(10)
>>> b = np.arange(10, 0, -1)
>>> logsumexp(a, b=b)
9.9170178533034665
>>> np.log(np.sum(b*np.exp(a)))
9.9170178533034647
Returning a sign flag
>>> logsumexp([1, 2], b=[1, -1], return_sign=True)
(1.5413248546129181, -1.0)
Notice that logsumexp does not directly support masked arrays. To use it on a masked array, convert the mask into zero weights:
>>> a = np.ma.array([np.log(2), 2, np.log(3)],
...                 mask=[False, True, False])
>>> b = (~a.mask).astype(int)
>>> logsumexp(a.data, b=b), np.log(5)
(1.6094379124341005, 1.6094379124341005)
objax.functional.relu(x)[source]¶
Rectified linear unit activation function.
- Parameters
x (Union[jax._src.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor.
- Returns
tensor with the element-wise output relu(x) = max(x, 0).
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
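For example:
>>> import jax.numpy as jn
>>> from objax.functional import relu
>>> relu(jn.array([-1., 0., 2.]))  # [0., 0., 2.]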
objax.functional.selu(x)[source]¶
Scaled exponential linear unit activation.
Computes the element-wise function:
\[\begin{split}\mathrm{selu}(x) = \lambda \begin{cases} x, & x > 0\\ \alpha e^x - \alpha, & x \le 0 \end{cases}\end{split}\]
where \(\lambda = 1.0507009873554804934193349852946\) and \(\alpha = 1.6732632423543772848170429916717\).
For more information, see Self-Normalizing Neural Networks.
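For example, a minimal usage sketch (the values in the comment are approximate):
>>> import jax.numpy as jn
>>> from objax.functional import selu
>>> selu(jn.array([-1., 0., 1.]))  # approximately [-1.111, 0., 1.051]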
objax.functional.sigmoid(x)[source]¶
Sigmoid activation function.
Computes the element-wise function:
\[\mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}\]
objax.functional.softmax(x, axis=-1)[source]¶
Softmax function.
Computes the function which rescales elements to the range \([0, 1]\) such that the elements along axis sum to \(1\).
\[\mathrm{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}\]
- Parameters
axis – the axis or axes along which the softmax should be computed. The softmax output summed across these dimensions should sum to \(1\). Either an integer or a tuple of integers.
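For example, the outputs along the reduced axis form a probability distribution:
>>> import jax.numpy as jn
>>> from objax.functional import softmax
>>> p = softmax(jn.array([1., 2., 3.]))
>>> p.sum()  # 1.0 up to floating-point rounding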
Pooling¶
average_pool_2d | Applies average pooling using a square 2D filter.
batch_to_space2d | Transfer batch dimension N into spatial dimensions (H, W).
channel_to_space2d | Transfer channel dimension C into spatial dimensions (H, W).
max_pool_2d | Applies max pooling using a square 2D filter.
space_to_batch2d | Transfer spatial dimensions (H, W) into batch dimension N.
space_to_channel2d | Transfer spatial dimensions (H, W) into channel dimension C.
objax.functional.average_pool_2d(x, size=2, strides=None, padding=<ConvPadding.VALID: 'VALID'>)[source]¶
Applies average pooling using a square 2D filter.
- Parameters
x (Union[jax._src.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor of shape (N, C, H, W).
size (Union[Tuple[int, int], int]) – size of pooling filter.
strides (Optional[Union[Tuple[int, int], int]]) – stride step; if None (default), size is used.
padding (Union[objax.constants.ConvPadding, str, Sequence[Tuple[int, int]], Tuple[int, int], int]) – padding of the input tensor, either Padding.SAME or Padding.VALID or numerical values.
- Returns
output tensor of shape (N, C, H, W).
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
For a definition of pooling, including examples, see Pooling Layer.
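For example, a minimal shape check:
>>> import jax.numpy as jn
>>> from objax.functional import average_pool_2d
>>> x = jn.arange(16.).reshape((1, 1, 4, 4))  # (N, C, H, W)
>>> average_pool_2d(x, size=2).shape  # VALID padding halves H and W
(1, 1, 2, 2)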
objax.functional.batch_to_space2d(x, size=2)[source]¶
Transfer batch dimension N into spatial dimensions (H, W).
- Parameters
x (Union[jax._src.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor of shape (N, C, H, W).
size (Union[Tuple[int, int], int]) – size of spatial area.
- Returns
output tensor of shape (N // (size[0] * size[1]), C, H * size[0], W * size[1]).
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
objax.functional.channel_to_space2d(x, size=2)[source]¶
Transfer channel dimension C into spatial dimensions (H, W).
- Parameters
x (Union[jax._src.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor of shape (N, C, H, W).
size (Union[Tuple[int, int], int]) – size of spatial area.
- Returns
output tensor of shape (N, C // (size[0] * size[1]), H * size[0], W * size[1]).
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
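For example, a minimal shape check:
>>> import jax.numpy as jn
>>> from objax.functional import channel_to_space2d
>>> channel_to_space2d(jn.zeros((8, 16, 4, 4)), size=2).shape  # C: 16 -> 4, H and W double
(8, 4, 8, 8)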
objax.functional.max_pool_2d(x, size=2, strides=None, padding=<ConvPadding.VALID: 'VALID'>)[source]¶
Applies max pooling using a square 2D filter.
- Parameters
x (Union[jax._src.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor of shape (N, C, H, W).
size (Union[Tuple[int, int], int]) – size of pooling filter.
strides (Optional[Union[Tuple[int, int], int]]) – stride step; if None (default), size is used.
padding (Union[objax.constants.ConvPadding, str, Sequence[Tuple[int, int]], Tuple[int, int], int]) – padding of the input tensor, either Padding.SAME or Padding.VALID or numerical values.
- Returns
output tensor of shape (N, C, H, W).
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
For a definition of pooling, including examples, see Pooling Layer.
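For example, a minimal usage sketch (the comment values are exact for this input):
>>> import jax.numpy as jn
>>> from objax.functional import max_pool_2d
>>> x = jn.arange(16.).reshape((1, 1, 4, 4))
>>> max_pool_2d(x, size=2)  # the max of each 2x2 block: [[5., 7.], [13., 15.]]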
objax.functional.space_to_batch2d(x, size=2)[source]¶
Transfer spatial dimensions (H, W) into batch dimension N.
- Parameters
x (Union[jax._src.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor of shape (N, C, H, W).
size (Union[Tuple[int, int], int]) – size of spatial area.
- Returns
output tensor of shape (N * size[0] * size[1], C, H // size[0], W // size[1]).
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
objax.functional.space_to_channel2d(x, size=2)[source]¶
Transfer spatial dimensions (H, W) into channel dimension C.
- Parameters
x (Union[jax._src.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor of shape (N, C, H, W).
size (Union[Tuple[int, int], int]) – size of spatial area.
- Returns
output tensor of shape (N, C * size[0] * size[1], H // size[0], W // size[1]).
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
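For example, space_to_channel2d and channel_to_space2d invert each other's shapes:
>>> import jax.numpy as jn
>>> from objax.functional import space_to_channel2d, channel_to_space2d
>>> y = space_to_channel2d(jn.zeros((8, 3, 32, 32)), size=2)
>>> y.shape
(8, 12, 16, 16)
>>> channel_to_space2d(y, size=2).shape
(8, 3, 32, 32)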
Misc¶
dynamic_slice | Wraps XLA’s DynamicSlice operator.
flatten | Flattens input tensor to a 2D tensor.
one_hot | One-hot encodes the given indices.
pad | Pad an array.
stop_gradient | Stops gradient computation.
top_k | Returns top k values and their indices along the last axis of operand.
rsqrt | Elementwise reciprocal square root: \(1 / \sqrt{x}\).
upscale_nn | Nearest neighbor upscale for image batches of shape (N, C, H, W).
objax.functional.dynamic_slice(operand, start_indices, slice_sizes)[source]¶
Wraps XLA’s DynamicSlice operator.
- Parameters
operand (Any) – an array to slice.
start_indices (Sequence[Any]) – a list of scalar indices, one per dimension. These values may be dynamic.
slice_sizes (Sequence[int]) – the size of the slice. Must be a sequence of non-negative integers with length equal to ndim(operand). Inside a JIT compiled function, only static values are supported (all JAX arrays inside JIT must have statically known size).
- Returns
An array containing the slice.
- Return type
Any
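For example, a minimal usage sketch (the comment values are exact for this input):
>>> import jax.numpy as jn
>>> from objax.functional import dynamic_slice
>>> x = jn.arange(12).reshape((3, 4))
>>> dynamic_slice(x, (1, 1), (2, 2))  # [[5, 6], [9, 10]]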
objax.functional.flatten(x)[source]¶
Flattens input tensor to a 2D tensor.
- Parameters
x (Union[jax._src.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor with dimensions (n_1, n_2, …, n_k)
- Returns
The input tensor reshaped to two dimensions (n_1, n_prod), where n_prod is equal to the product of n_2 to n_k.
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
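For example, a minimal shape check:
>>> import jax.numpy as jn
>>> from objax.functional import flatten
>>> flatten(jn.zeros((4, 3, 8, 8))).shape  # first dimension kept, the rest merged
(4, 192)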
objax.functional.one_hot(x, num_classes, *, dtype=<class 'jax._src.numpy.lax_numpy.float64'>)[source]¶
One-hot encodes the given indices.
Each index in the input x is encoded as a vector of zeros of length num_classes with the element at index set to one:
>>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
DeviceArray([[1., 0., 0.],
             [0., 1., 0.],
             [0., 0., 1.]], dtype=float32)
Indices outside the range [0, num_classes) will be encoded as zeros:
>>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
DeviceArray([[0., 0., 0.],
             [0., 0., 0.]], dtype=float32)
- Parameters
x – A tensor of indices.
num_classes – Number of classes in the one-hot dimension.
dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
objax.functional.pad(array, pad_width, mode='constant', constant_values=0)[source]¶
Pad an array.
LAX-backend implementation of pad(). Original docstring below.
- Parameters
array (array_like of rank N) – The array to pad.
pad_width ({sequence, array_like, int}) – Number of values padded to the edges of each axis. ((before_1, after_1), … (before_N, after_N)) unique pad widths for each axis. ((before, after),) yields same before and after pad for each axis. (pad,) or int is a shortcut for before = after = pad width for all axes.
mode (str or function, optional) – One of the following string values or a user supplied function.
constant_values (sequence or scalar, optional) – Used in ‘constant’. The values to set the padded values for each axis.
- Returns
pad – Padded array of rank equal to array with shape increased according to pad_width.
- Return type
ndarray
Notes
New in version 1.7.0.
For an array with rank greater than 1, some of the padding of later axes is calculated from padding of previous axes. This is easiest to think about with a rank 2 array where the corners of the padded array are calculated by using padded values from the first axis.
The padding function, if used, should modify a rank 1 array in-place. It has the following signature:
padding_func(vector, iaxis_pad_width, iaxis, kwargs)
where
vector (ndarray) – A rank 1 array already padded with zeros. Padded values are vector[:iaxis_pad_width[0]] and vector[-iaxis_pad_width[1]:].
iaxis_pad_width (tuple) – A 2-tuple of ints, where iaxis_pad_width[0] represents the number of values padded at the beginning of vector and iaxis_pad_width[1] represents the number of values padded at the end of vector.
iaxis (int) – The axis currently being calculated.
kwargs (dict) – Any keyword arguments the function requires.
Examples
>>> a = [1, 2, 3, 4, 5]
>>> np.pad(a, (2, 3), 'constant', constant_values=(4, 6))
array([4, 4, 1, ..., 6, 6, 6])
>>> np.pad(a, (2, 3), 'edge')
array([1, 1, 1, ..., 5, 5, 5])
>>> np.pad(a, (2, 3), 'linear_ramp', end_values=(5, -4))
array([ 5,  3,  1,  2,  3,  4,  5,  2, -1, -4])
>>> np.pad(a, (2,), 'maximum')
array([5, 5, 1, 2, 3, 4, 5, 5, 5])
>>> np.pad(a, (2,), 'mean')
array([3, 3, 1, 2, 3, 4, 5, 3, 3])
>>> np.pad(a, (2,), 'median')
array([3, 3, 1, 2, 3, 4, 5, 3, 3])
>>> a = [[1, 2], [3, 4]]
>>> np.pad(a, ((3, 2), (2, 3)), 'minimum')
array([[1, 1, 1, 2, 1, 1, 1],
       [1, 1, 1, 2, 1, 1, 1],
       [1, 1, 1, 2, 1, 1, 1],
       [1, 1, 1, 2, 1, 1, 1],
       [3, 3, 3, 4, 3, 3, 3],
       [1, 1, 1, 2, 1, 1, 1],
       [1, 1, 1, 2, 1, 1, 1]])
>>> a = [1, 2, 3, 4, 5]
>>> np.pad(a, (2, 3), 'reflect')
array([3, 2, 1, 2, 3, 4, 5, 4, 3, 2])
>>> np.pad(a, (2, 3), 'reflect', reflect_type='odd')
array([-1,  0,  1,  2,  3,  4,  5,  6,  7,  8])
>>> np.pad(a, (2, 3), 'symmetric')
array([2, 1, 1, 2, 3, 4, 5, 5, 4, 3])
>>> np.pad(a, (2, 3), 'symmetric', reflect_type='odd')
array([0, 1, 1, 2, 3, 4, 5, 5, 6, 7])
>>> np.pad(a, (2, 3), 'wrap')
array([4, 5, 1, 2, 3, 4, 5, 1, 2, 3])
>>> def pad_with(vector, pad_width, iaxis, kwargs):
...     pad_value = kwargs.get('padder', 10)
...     vector[:pad_width[0]] = pad_value
...     vector[-pad_width[1]:] = pad_value
>>> a = np.arange(6)
>>> a = a.reshape((2, 3))
>>> np.pad(a, 2, pad_with)
array([[10, 10, 10, 10, 10, 10, 10],
       [10, 10, 10, 10, 10, 10, 10],
       [10, 10,  0,  1,  2, 10, 10],
       [10, 10,  3,  4,  5, 10, 10],
       [10, 10, 10, 10, 10, 10, 10],
       [10, 10, 10, 10, 10, 10, 10]])
>>> np.pad(a, 2, pad_with, padder=100)
array([[100, 100, 100, 100, 100, 100, 100],
       [100, 100, 100, 100, 100, 100, 100],
       [100, 100,   0,   1,   2, 100, 100],
       [100, 100,   3,   4,   5, 100, 100],
       [100, 100, 100, 100, 100, 100, 100],
       [100, 100, 100, 100, 100, 100, 100]])
objax.functional.stop_gradient(x)[source]¶
Stops gradient computation.
Operationally stop_gradient is the identity function, that is, it returns argument x unchanged. However, stop_gradient prevents the flow of gradients during forward or reverse-mode automatic differentiation. If there are multiple nested gradient computations, stop_gradient stops gradients for all of them.
For example:
>>> jax.grad(lambda x: x**2)(3.)
array(6., dtype=float32)
>>> jax.grad(lambda x: jax.lax.stop_gradient(x)**2)(3.)
array(0., dtype=float32)
>>> jax.grad(jax.grad(lambda x: x**2))(3.)
array(2., dtype=float32)
>>> jax.grad(jax.grad(lambda x: jax.lax.stop_gradient(x)**2))(3.)
array(0., dtype=float32)
objax.functional.top_k(operand, k)[source]¶
Returns top k values and their indices along the last axis of operand.
- Parameters
operand (Any) –
k (int) –
- Return type
Tuple[Any, Any]
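For example, a minimal usage sketch (the comment values are exact for this input):
>>> import jax.numpy as jn
>>> from objax.functional import top_k
>>> values, indices = top_k(jn.array([9., 3., 7., 5.]), 2)
>>> values, indices  # ([9., 7.], [0, 2])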
objax.functional.rsqrt(x)[source]¶
Elementwise reciprocal square root: \(1 / \sqrt{x}\).
- Parameters
x (Any) –
- Return type
Any
objax.functional.upscale_nn(x, scale=2)[source]¶
Nearest neighbor upscale for image batches of shape (N, C, H, W).
- Parameters
x (Union[jax._src.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – input tensor of shape (N, C, H, W).
scale (int) – integer scaling factor.
- Returns
Output tensor of shape (N, C, H * scale, W * scale).
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
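For example:
>>> import jax.numpy as jn
>>> from objax.functional import upscale_nn
>>> x = jn.array([[[[1., 2.], [3., 4.]]]])  # shape (1, 1, 2, 2)
>>> upscale_nn(x)  # each pixel becomes a 2x2 block; the result has shape (1, 1, 4, 4)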
objax.functional.divergence¶
kl | Calculates the Kullback-Leibler divergence between arrays p and q.
objax.functional.divergence.kl(p, q, eps=7.62939453125e-06)[source]¶
Calculates the Kullback-Leibler divergence between arrays p and q.
\[kl(p,q) = p \cdot \log{\frac{p + \epsilon}{q + \epsilon}}\]
The \(\epsilon\) term keeps the logarithm well defined even when entries of p or q are zero.
- Parameters
p (Union[jax._src.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) –
q (Union[jax._src.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) –
eps (float) –
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
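For example, a minimal sketch (the comment value is approximate and assumes the reduction over the class axis implied by the dot product above):
>>> import jax.numpy as jn
>>> from objax.functional.divergence import kl
>>> kl(jn.array([0.5, 0.5]), jn.array([0.9, 0.1]))  # approximately 0.51; kl(p, p) is approximately 0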
objax.functional.loss¶
cross_entropy_logits | Computes the softmax cross-entropy loss on n-dimensional data.
cross_entropy_logits_sparse | Computes the softmax cross-entropy loss.
l2 | Computes the L2 loss.
mean_absolute_error | Computes the mean absolute error between x and y.
mean_squared_error | Computes the mean squared error between x and y.
mean_squared_logarithmic_error | Computes the mean squared logarithmic error between y_true and y_pred.
sigmoid_cross_entropy_logits | Computes the sigmoid cross-entropy loss.
objax.functional.loss.cross_entropy_logits(logits, labels)[source]¶
Computes the softmax cross-entropy loss on n-dimensional data.
- Parameters
logits (Union[jax._src.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – (batch, …, #class) tensor of logits.
labels (Union[jax._src.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – (batch, …, #class) tensor of label probabilities (i.e. labels.sum(axis=-1) must equal 1)
- Returns
(batch, …) tensor of the cross-entropies for each entry.
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
Calculates the cross entropy loss, defined as follows:
\[\begin{split}\begin{eqnarray} l(y,\hat{y}) & = & - \sum_{j=1}^{q} y_j \log \frac{e^{o_j}}{\sum_{k=1}^{q} e^{o_k}} \nonumber \\ & = & \log \sum_{k=1}^{q} e^{o_k} - \sum_{j=1}^{q} y_j o_j \nonumber \end{eqnarray}\end{split}\]
where \(o_k\) are the logits and \(y_k\) are the labels.
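For example, a minimal sketch (the comment value is approximate):
>>> import jax.numpy as jn
>>> from objax.functional.loss import cross_entropy_logits
>>> logits = jn.array([[2., 0., -1.]])
>>> labels = jn.array([[1., 0., 0.]])  # one-hot label probabilities
>>> cross_entropy_logits(logits, labels)  # approximately [0.17]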
objax.functional.loss.cross_entropy_logits_sparse(logits, labels)[source]¶
Computes the softmax cross-entropy loss.
- Parameters
logits (Union[jax._src.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – (batch, …, #class) tensor of logits.
labels (Union[jax._src.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray, int]) – (batch, …) integer tensor of label indexes in {0, …, #class-1} or just a single integer.
- Returns
(batch, …) tensor of the cross-entropies for each entry.
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
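The integer-label form of the example above gives the same result (a hedged sketch; the comment value is approximate):
>>> import jax.numpy as jn
>>> from objax.functional.loss import cross_entropy_logits_sparse
>>> logits = jn.array([[2., 0., -1.]])
>>> cross_entropy_logits_sparse(logits, jn.array([0]))  # approximately [0.17]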
objax.functional.loss.l2(x)[source]¶
Computes the L2 loss.
- Parameters
x (Union[jax._src.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – n-dimensional tensor of floats.
- Returns
scalar tensor containing the l2 loss of x.
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
Calculates the l2 loss, as:
\[l2(x) = \frac{1}{2} \sum_i x_i^2\]
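For example, a minimal sketch assuming the halved sum of squares shown above:
>>> import jax.numpy as jn
>>> from objax.functional.loss import l2
>>> l2(jn.array([3., 4.]))  # 0.5 * (9 + 16) = 12.5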
objax.functional.loss.mean_absolute_error(x, y, keep_axis=(0,))[source]¶
Computes the mean absolute error between x and y.
- Parameters
x (Union[jax._src.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – a tensor of shape (d0, .. dN-1).
y (Union[jax._src.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – a tensor of shape (d0, .. dN-1).
keep_axis (Optional[Iterable[int]]) – a sequence of the dimensions to keep; use None to return a scalar value.
- Returns
tensor of shape (d_i, …, for i in keep_axis) containing the mean absolute error.
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
objax.functional.loss.mean_squared_error(x, y, keep_axis=(0,))[source]¶
Computes the mean squared error between x and y.
- Parameters
x (Union[jax._src.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – a tensor of shape (d0, .. dN-1).
y (Union[jax._src.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – a tensor of shape (d0, .. dN-1).
keep_axis (Optional[Iterable[int]]) – a sequence of the dimensions to keep; use None to return a scalar value.
- Returns
tensor of shape (d_i, …, for i in keep_axis) containing the mean squared error.
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
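For example, a minimal sketch of the keep_axis behavior:
>>> import jax.numpy as jn
>>> from objax.functional.loss import mean_squared_error
>>> x, y = jn.zeros((4, 3)), jn.ones((4, 3))
>>> mean_squared_error(x, y)  # keep_axis=(0,) keeps the batch axis: [1., 1., 1., 1.]
>>> mean_squared_error(x, y, keep_axis=None)  # scalar: 1.0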
objax.functional.loss.sigmoid_cross_entropy_logits(logits, labels)[source]¶
Computes the sigmoid cross-entropy loss.
- Parameters
logits (Union[jax._src.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]) – (batch, …, #class) tensor of logits.
labels (Union[jax._src.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray, int]) – (batch, …, #class) tensor of labels, each in [0, 1]. Unlike the softmax losses above, classes are independent binary targets, so labels need not sum to 1.
- Returns
(batch, …) tensor of the cross-entropies for each entry.
- Return type
Union[jax._src.numpy.lax_numpy.ndarray, jax.interpreters.xla.DeviceArray, jax.interpreters.pxla.ShardedDeviceArray]
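For example, a minimal sketch with independent binary targets:
>>> import jax.numpy as jn
>>> from objax.functional.loss import sigmoid_cross_entropy_logits
>>> logits = jn.array([[0., 2.]])
>>> labels = jn.array([[0., 1.]])  # each class is an independent binary target
>>> sigmoid_cross_entropy_logits(logits, labels)  # one cross-entropy per entry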
objax.functional.parallel¶
pmax | Compute a multi-device reduce max on x over the device axis axis_name.
pmean | Compute a multi-device reduce mean on x over the device axis axis_name.
pmin | Compute a multi-device reduce min on x over the device axis axis_name.
psum | Compute a multi-device reduce sum on x over the device axis axis_name.
objax.functional.parallel.pmax(x, axis_name='device')[source]¶
Compute a multi-device reduce max on x over the device axis axis_name.
- Parameters
x (jax.interpreters.pxla.ShardedDeviceArray) –
axis_name (str) –
objax.functional.parallel.pmean(x, axis_name='device')[source]¶
Compute a multi-device reduce mean on x over the device axis axis_name.
- Parameters
x (jax.interpreters.pxla.ShardedDeviceArray) –
axis_name (str) –
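These reductions are only meaningful inside a replicated computation that defines the named device axis, for instance under objax.Parallel or, as in this hedged sketch, plain jax.pmap:
>>> import jax, jax.numpy as jn
>>> from objax.functional.parallel import pmean
>>> n = jax.local_device_count()
>>> f = jax.pmap(pmean, axis_name='device')
>>> f(jn.arange(float(n)))  # every device receives the same cross-device mean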