# Source code for objax.functional.core.pooling

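```
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.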

# Public API of this module (names exported by `from ... import *`).
__all__ = [
    'average_pool_2d',
    'batch_to_space2d',
    'channel_to_space2d',
    'max_pool_2d',
    'space_to_batch2d',
    'space_to_channel2d',
]

from typing import Union, Tuple, Optional

import numpy as np
from jax import numpy as jn, lax

def average_pool_2d(x: JaxArray,
                    size: Union[Tuple[int, int], int] = 2,
                    strides: Optional[Union[Tuple[int, int], int]] = None,
                    padding: Union[str, Tuple[Tuple[int, int], ...]] = 'VALID') -> JaxArray:
    """Applies average pooling using a square 2D filter.

    Args:
        x: input tensor of shape (N, C, H, W).
        size: size of pooling filter (int or (height, width) pair).
        strides: stride step, use size when stride is none (default).
        padding: padding spec forwarded unchanged to ``jax.lax.reduce_window``: the string
            'VALID' (default) or 'SAME', or a sequence of (low, high) pairs, one per dimension of x.

    Returns:
        output tensor of shape (N, C, H', W'); only the spatial dimensions are pooled, and
        H', W' depend on size, strides and padding.
    """
    # NOTE(review): the `padding` parameter line was missing from this copy of the file; it is
    # restored with lax's 'VALID' default. Confirm against upstream, which may normalize padding
    # through a helper before calling lax.
    size = to_tuple(size, 2)
    strides = to_tuple(strides, 2) if strides else size
    # Sum each window (N and C use a window of 1, so they are untouched), then divide by the window
    # area. reduce_window pads with the init value 0., so with 'SAME' padding the padded positions
    # add nothing to the sum but are still counted in the divisor.
    window_sum = lax.reduce_window(x, 0., lax.add, (1, 1) + size, (1, 1) + strides, padding=padding)
    return window_sum / np.prod(size)

def batch_to_space2d(x: JaxArray, size: Union[Tuple[int, int], int] = 2) -> JaxArray:
    """Move blocks of the batch dimension N into the spatial dimensions (H, W).

    This undoes ``space_to_batch2d``: batch entry ``n * size[0] * size[1] + i * size[1] + j``
    becomes the pixel at offset (i, j) inside each size[0] x size[1] tile of output entry n.

    Args:
        x: input tensor of shape (N, C, H, W); N must be a multiple of size[0] * size[1].
        size: size of spatial area (int or (height, width) pair).

    Returns:
        output tensor of shape (N // (size[0] * size[1]), C, H * size[0], W * size[1]).
    """
    size = to_tuple(size, 2)
    bh, bw = size[0], size[1]
    shape = x.shape
    n, c, h, w = shape[0], shape[1], shape[2], shape[3]
    # Split N into (N', bh, bw), then move the block offsets next to the spatial axes they expand:
    # (N', bh, bw, C, H, W) -> (N', C, H, bh, W, bw).
    tiles = x.reshape((-1, bh, bw, c, h, w)).transpose((0, 3, 4, 1, 5, 2))
    return tiles.reshape((n // (bh * bw), c, h * bh, w * bw))

def channel_to_space2d(x: JaxArray, size: Union[Tuple[int, int], int] = 2) -> JaxArray:
    """Move groups of the channel dimension C into the spatial dimensions (H, W).

    This undoes ``space_to_channel2d``.

    Args:
        x: input tensor of shape (N, C, H, W); C must be a multiple of size[0] * size[1].
        size: size of spatial area (int or (height, width) pair).

    Returns:
        output tensor of shape (N, C // (size[0] * size[1]), H * size[0], W * size[1]).
    """
    size = to_tuple(size, 2)
    bh, bw = size[0], size[1]
    shape = x.shape
    n, c, h, w = shape[0], shape[1], shape[2], shape[3]
    grouped = x.reshape((n, -1, bh, bw, h, w))
    # (N, C', bh, bw, H, W) -> (N, C', H, bh, W, bw): interleave each block offset with its pixel.
    interleaved = grouped.transpose((0, 1, 4, 2, 5, 3))
    return interleaved.reshape((n, c // (bh * bw), h * bh, w * bw))

def max_pool_2d(x: JaxArray,
                size: Union[Tuple[int, int], int] = 2,
                strides: Optional[Union[Tuple[int, int], int]] = None,
                padding: Union[str, Tuple[Tuple[int, int], ...]] = 'VALID') -> JaxArray:
    """Applies max pooling using a square 2D filter.

    Args:
        x: input tensor of shape (N, C, H, W).
        size: size of pooling filter (int or (height, width) pair).
        strides: stride step, use size when stride is none (default).
        padding: padding spec forwarded unchanged to ``jax.lax.reduce_window``: the string
            'VALID' (default) or 'SAME', or a sequence of (low, high) pairs, one per dimension of x.

    Returns:
        output tensor of shape (N, C, H', W'); only the spatial dimensions are pooled, and
        H', W' depend on size, strides and padding.
    """
    # NOTE(review): `padding` was referenced below but its parameter line was missing from this
    # copy of the file; it is restored with lax's 'VALID' default. Confirm against upstream, which
    # may normalize padding through a helper before calling lax.
    size = to_tuple(size, 2)
    strides = to_tuple(strides, 2) if strides else size
    # -inf is the identity of max, so positions padded with the init value never win a window.
    return lax.reduce_window(x, -jn.inf, lax.max, (1, 1) + size, (1, 1) + strides, padding=padding)

def space_to_batch2d(x: JaxArray, size: Union[Tuple[int, int], int] = 2) -> JaxArray:
    """Move size[0] x size[1] spatial tiles of (H, W) into the batch dimension N.

    Args:
        x: input tensor of shape (N, C, H, W).
        size: size of spatial area (int or (height, width) pair).

    Returns:
        output tensor of shape (N * size[0] * size[1], C, H // size[0], W // size[1]).
    """
    size = to_tuple(size, 2)
    bh, bw = size[0], size[1]
    shape = x.shape
    n, c, h, w = shape[0], shape[1], shape[2], shape[3]
    oh, ow = h // bh, w // bw
    tiles = x.reshape((n, c, oh, bh, ow, bw))
    # (N, C, H', bh, W', bw) -> (N, bh, bw, C, H', W'): the offset inside each tile becomes a
    # sub-index of the batch dimension.
    tiles = tiles.transpose((0, 3, 5, 1, 2, 4))
    return tiles.reshape((n * bh * bw, c, oh, ow))

def space_to_channel2d(x: JaxArray, size: Union[Tuple[int, int], int] = 2) -> JaxArray:
    """Move size[0] x size[1] spatial tiles of (H, W) into the channel dimension C.

    Args:
        x: input tensor of shape (N, C, H, W).
        size: size of spatial area (int or (height, width) pair).

    Returns:
        output tensor of shape (N, C * size[0] * size[1], H // size[0], W // size[1]).
    """
    size = to_tuple(size, 2)
    bh, bw = size[0], size[1]
    shape = x.shape
    n, c, h, w = shape[0], shape[1], shape[2], shape[3]
    oh, ow = h // bh, w // bw
    # (N, C, H', bh, W', bw) -> (N, C, bh, bw, H', W'): the tile offset joins the channel axis.
    tiles = x.reshape((n, c, oh, bh, ow, bw)).transpose((0, 1, 3, 5, 2, 4))
    return tiles.reshape((n, c * bh * bw, oh, ow))
```