Source code for objax.nn.init

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = ['gain_leaky_relu', 'kaiming_normal', 'kaiming_normal_gain', 'kaiming_truncated_normal', 'truncated_normal',
           'xavier_normal', 'xavier_normal_gain', 'xavier_truncated_normal']

from typing import Tuple

import numpy as np
import scipy.stats

from objax import random
from objax.typing import JaxArray


# Expected format for init APIs
# Convolution: HWIO (number of output channels at the end)
# Linear: IO (number of output dimensions at the end)
# In general: the last dimension is the number of output channels/dimensions.
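#
# For example (illustrative shapes, not part of the original comment), a 3x3
# convolution from 64 to 128 channels has HWIO shape (3, 3, 64, 128): its
# fan-in is 3 * 3 * 64 = 576 and its output channel count is 128. A linear
# layer from 256 to 512 units has IO shape (256, 512).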

def gain_leaky_relu(relu_slope: float = 0.1):
    """The recommended gain value for leaky_relu.

    Args:
        relu_slope: negative slope of leaky_relu.

    Returns:
        The recommended gain value for leaky_relu.
    """
    return np.sqrt(2 / (1 + relu_slope ** 2))
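
# Usage sketch (assumed call site, not part of the original module): pass the
# gain to a Kaiming initializer when the layer feeds a leaky_relu activation.
#
#   >>> gain_leaky_relu(0.1)                  # sqrt(2 / 1.01) ~= 1.407
#   >>> w = kaiming_normal((256, 512), gain=gain_leaky_relu(0.1))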


def kaiming_normal(shape: Tuple[int, ...], gain: float = 1) -> JaxArray:
    """Returns a tensor with values assigned using Kaiming He normal initializer from
    `Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification
    <https://arxiv.org/abs/1502.01852>`_.

    Args:
        shape: shape of the output tensor.
        gain: optional scaling factor.

    Returns:
        Tensor initialized with normal random variables with standard deviation (gain * kaiming_normal_gain).
    """
    return random.normal(shape, stddev=gain * kaiming_normal_gain(shape))
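
# Example (illustrative): initialize an HWIO convolution kernel; the standard
# deviation below follows from kaiming_normal_gain.
#
#   >>> w = kaiming_normal((3, 3, 64, 128))   # stddev = sqrt(1 / (3 * 3 * 64))
#   >>> w.shape
#   (3, 3, 64, 128)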


def kaiming_normal_gain(shape: Tuple[int, ...]) -> float:
    """Returns Kaiming He gain from
    `Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification
    <https://arxiv.org/abs/1502.01852>`_.

    Args:
        shape: shape of the output tensor.

    Returns:
        Scalar, the standard deviation gain.
    """
    fan_in = np.prod(shape[:-1])
    return np.sqrt(1 / fan_in)
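
# Worked example (illustrative): for an HWIO kernel of shape (3, 3, 64, 128),
# fan_in = 3 * 3 * 64 = 576, so kaiming_normal_gain returns
# sqrt(1 / 576) = 1 / 24 ~= 0.0417.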


def kaiming_truncated_normal(shape: Tuple[int, ...], lower: float = -2, upper: float = 2,
                             gain: float = 1) -> JaxArray:
    """Returns a tensor with values assigned using Kaiming He truncated normal initializer from
    `Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification
    <https://arxiv.org/abs/1502.01852>`_.

    Args:
        shape: shape of the output tensor.
        lower: lower truncation of the normal.
        upper: upper truncation of the normal.
        gain: optional scaling factor.

    Returns:
        Tensor initialized with truncated normal random variables with standard
        deviation (gain * kaiming_normal_gain) and support [lower, upper].
    """
    truncated_std = scipy.stats.truncnorm.std(a=lower, b=upper, loc=0., scale=1)
    stddev = gain * kaiming_normal_gain(shape) / truncated_std
    return random.truncated_normal(shape, stddev=stddev, lower=lower, upper=upper)
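
# Note on the correction above (explanatory, not part of the original module):
# truncating a unit normal to [lower, upper] shrinks its standard deviation
# (to scipy.stats.truncnorm.std(a=-2, b=2) ~= 0.88 for the defaults), so
# dividing by truncated_std rescales the draws back to the Kaiming target.
#
#   >>> w = kaiming_truncated_normal((3, 3, 64, 128))
#   >>> float(w.std())    # ~= kaiming_normal_gain((3, 3, 64, 128)) ~= 0.0417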


def xavier_normal(shape: Tuple[int, ...], gain: float = 1) -> JaxArray:
    """Returns a tensor with values assigned using Xavier Glorot normal initializer from
    `Understanding the difficulty of training deep feedforward neural networks
    <http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf>`_.

    Args:
        shape: shape of the output tensor.
        gain: optional scaling factor.

    Returns:
        Tensor initialized with normal random variables with standard deviation (gain * xavier_normal_gain).
    """
    return random.normal(shape, stddev=gain * xavier_normal_gain(shape))
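
# Example (illustrative): initialize an IO linear weight matrix.
#
#   >>> w = xavier_normal((256, 512))         # stddev = sqrt(2 / (256 + 512))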


def truncated_normal(shape: Tuple[int, ...], lower: float = -2, upper: float = 2, stddev: float = 1) -> JaxArray:
    """Returns a tensor with values assigned using truncated normal initialization.

    Args:
        shape: shape of the output tensor.
        lower: lower truncation of the normal.
        upper: upper truncation of the normal.
        stddev: expected standard deviation.

    Returns:
        Tensor initialized with truncated normal random variables with standard
        deviation stddev and support [lower, upper].
    """
    truncated_std = scipy.stats.truncnorm.std(a=lower, b=upper, loc=0., scale=1)
    stddev /= truncated_std
    return random.truncated_normal(shape, stddev=stddev, lower=lower, upper=upper)
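
# Example (illustrative): thanks to the correction, the empirical standard
# deviation matches the requested one despite the truncation to [-2, 2].
#
#   >>> w = truncated_normal((1024, 1024), stddev=0.02)
#   >>> float(w.std())                        # ~= 0.02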


def xavier_normal_gain(shape: Tuple[int, ...]) -> float:
    """Returns Xavier Glorot gain from
    `Understanding the difficulty of training deep feedforward neural networks
    <http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf>`_.

    Args:
        shape: shape of the output tensor.

    Returns:
        Scalar, the standard deviation gain.
    """
    fan_in, fan_out = np.prod(shape[:-1]), shape[-1]
    return np.sqrt(2 / (fan_in + fan_out))
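
# Worked example (illustrative): for an IO shape (256, 512), fan_in = 256 and
# fan_out = 512, so xavier_normal_gain returns sqrt(2 / 768) ~= 0.051.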


def xavier_truncated_normal(shape: Tuple[int, ...], lower: float = -2, upper: float = 2,
                            gain: float = 1) -> JaxArray:
    """Returns a tensor with values assigned using Xavier Glorot truncated normal initializer from
    `Understanding the difficulty of training deep feedforward neural networks
    <http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf>`_.

    Args:
        shape: shape of the output tensor.
        lower: lower truncation of the normal.
        upper: upper truncation of the normal.
        gain: optional scaling factor.

    Returns:
        Tensor initialized with truncated normal random variables with standard
        deviation (gain * xavier_normal_gain) and support [lower, upper].
    """
    truncated_std = scipy.stats.truncnorm.std(a=lower, b=upper, loc=0., scale=1)
    stddev = gain * xavier_normal_gain(shape) / truncated_std
    return random.truncated_normal(shape, stddev=stddev, lower=lower, upper=upper)
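
# Usage sketch (assumed call site, not part of the original module):
#
#   >>> w = xavier_truncated_normal((256, 512))
#   >>> b = truncated_normal((512,), stddev=0.01)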