Source code for objax.functional.core.pooling

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = ['average_pool_2d', 'batch_to_space2d', 'channel_to_space2d', 'max_pool_2d', 'space_to_batch2d',
           'space_to_channel2d']

from typing import Union, Tuple, Optional

import numpy as np
from jax import numpy as jn, lax

from objax.constants import ConvPadding
from objax.typing import JaxArray, ConvPaddingInt
from objax.util import to_padding, to_tuple


def average_pool_2d(x: JaxArray,
                    size: Union[Tuple[int, int], int] = 2,
                    strides: Optional[Union[Tuple[int, int], int]] = None,
                    padding: Union[ConvPadding, str, ConvPaddingInt] = ConvPadding.VALID) -> JaxArray:
    """Applies average pooling over the spatial dimensions of a 4D tensor.

    Args:
        x: input tensor of shape (N, C, H, W).
        size: size of the pooling window, a single int or a (height, width) pair.
        strides: step between consecutive windows; when left as None (default),
            the window size is used as the stride.
        padding: padding of the input tensor, either Padding.SAME or Padding.VALID or numerical values.

    Returns:
        pooled tensor of shape (N, C, H', W'), where the spatial dimensions H' and W'
        depend on size, strides and padding.
    """
    window = to_tuple(size, 2)
    if strides:
        step = to_tuple(strides, 2)
    else:
        step = window
    pad = to_padding(padding, 2)
    if isinstance(pad, tuple):
        # Explicit spatial padding: prepend "no padding" entries for the N and C axes.
        pad = ((0, 0), (0, 0)) + pad
    # The window spans a single element along N and C, so only H and W are reduced.
    window_sum = lax.reduce_window(x, 0, lax.add, (1, 1) + window, (1, 1) + step, padding=pad)
    # The sum is always divided by the full window area.
    return window_sum / np.prod(window)
def batch_to_space2d(x: JaxArray, size: Union[Tuple[int, int], int] = 2) -> JaxArray:
    """Moves blocks of the batch dimension N into the spatial dimensions (H, W).

    Args:
        x: input tensor of shape (N, C, H, W).
        size: size of spatial area, a single int or a (height, width) pair.

    Returns:
        output tensor of shape (N // (size[0] * size[1]), C, H * size[0], W * size[1]).
    """
    block = to_tuple(size, 2)
    block_h, block_w = block[0], block[1]
    dims = x.shape
    batch, channels, height, width = dims[0], dims[1], dims[2], dims[3]
    # Split the batch axis into (remaining batch, block row, block column).
    split = x.reshape((-1, block_h, block_w, channels, height, width))
    # Place each block axis right after the spatial axis it will be merged into.
    interleaved = split.transpose((0, 3, 4, 1, 5, 2))
    out_shape = (batch // (block_h * block_w), channels, height * block_h, width * block_w)
    return interleaved.reshape(out_shape)
def channel_to_space2d(x: JaxArray, size: Union[Tuple[int, int], int] = 2) -> JaxArray:
    """Moves blocks of the channel dimension C into the spatial dimensions (H, W).

    Args:
        x: input tensor of shape (N, C, H, W).
        size: size of spatial area, a single int or a (height, width) pair.

    Returns:
        output tensor of shape (N, C // (size[0] * size[1]), H * size[0], W * size[1]).
    """
    block = to_tuple(size, 2)
    block_h, block_w = block[0], block[1]
    dims = x.shape
    batch, channels, height, width = dims[0], dims[1], dims[2], dims[3]
    # Split the channel axis into (remaining channels, block row, block column).
    split = x.reshape((batch, -1, block_h, block_w, height, width))
    # Place each block axis right after the spatial axis it will be merged into.
    interleaved = split.transpose((0, 1, 4, 2, 5, 3))
    out_shape = (batch, channels // (block_h * block_w), height * block_h, width * block_w)
    return interleaved.reshape(out_shape)
def max_pool_2d(x: JaxArray,
                size: Union[Tuple[int, int], int] = 2,
                strides: Optional[Union[Tuple[int, int], int]] = None,
                padding: Union[ConvPadding, str, ConvPaddingInt] = ConvPadding.VALID) -> JaxArray:
    """Applies max pooling over the spatial dimensions of a 4D tensor.

    Args:
        x: input tensor of shape (N, C, H, W).
        size: size of the pooling window, a single int or a (height, width) pair.
        strides: step between consecutive windows; when left as None (default),
            the window size is used as the stride.
        padding: padding of the input tensor, either Padding.SAME or Padding.VALID or numerical values.

    Returns:
        pooled tensor of shape (N, C, H', W'), where the spatial dimensions H' and W'
        depend on size, strides and padding.
    """
    window = to_tuple(size, 2)
    if strides:
        step = to_tuple(strides, 2)
    else:
        step = window
    pad = to_padding(padding, 2)
    if isinstance(pad, tuple):
        # Explicit spatial padding: prepend "no padding" entries for the N and C axes.
        pad = ((0, 0), (0, 0)) + pad
    # The window spans a single element along N and C, so only H and W are reduced.
    # The reduction starts from -inf, the identity element of max.
    return lax.reduce_window(x, -jn.inf, lax.max, (1, 1) + window, (1, 1) + step, padding=pad)
def space_to_batch2d(x: JaxArray, size: Union[Tuple[int, int], int] = 2) -> JaxArray:
    """Moves blocks of the spatial dimensions (H, W) into the batch dimension N.

    Args:
        x: input tensor of shape (N, C, H, W).
        size: size of spatial area, a single int or a (height, width) pair.

    Returns:
        output tensor of shape (N * size[0] * size[1], C, H // size[0], W // size[1]).
    """
    block = to_tuple(size, 2)
    block_h, block_w = block[0], block[1]
    dims = x.shape
    batch, channels = dims[0], dims[1]
    out_h = dims[2] // block_h
    out_w = dims[3] // block_w
    # Split each spatial axis into (output position, offset inside the block).
    tiled = x.reshape((batch, channels, out_h, block_h, out_w, block_w))
    # Bring both in-block offsets next to the batch axis so they can be folded into it.
    regrouped = tiled.transpose((0, 3, 5, 1, 2, 4))
    return regrouped.reshape((batch * block_h * block_w, channels, out_h, out_w))
def space_to_channel2d(x: JaxArray, size: Union[Tuple[int, int], int] = 2) -> JaxArray:
    """Moves blocks of the spatial dimensions (H, W) into the channel dimension C.

    Args:
        x: input tensor of shape (N, C, H, W).
        size: size of spatial area, a single int or a (height, width) pair.

    Returns:
        output tensor of shape (N, C * size[0] * size[1], H // size[0], W // size[1]).
    """
    block = to_tuple(size, 2)
    block_h, block_w = block[0], block[1]
    dims = x.shape
    batch, channels = dims[0], dims[1]
    out_h = dims[2] // block_h
    out_w = dims[3] // block_w
    # Split each spatial axis into (output position, offset inside the block).
    tiled = x.reshape((batch, channels, out_h, block_h, out_w, block_w))
    # Bring both in-block offsets next to the channel axis so they can be folded into it.
    regrouped = tiled.transpose((0, 1, 3, 5, 2, 4))
    return regrouped.reshape((batch, channels * block_h * block_w, out_h, out_w))