# Source code for objax.zoo.vgg

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module with VGG-19 implementation.

See https://arxiv.org/abs/1409.1556 for detail.
"""

import functools
import os
from urllib import request

import jax.numpy as jn
import numpy as np

import objax

# Location quoted in the error message raised when the weights file is missing;
# the weights are not downloaded automatically by this module.
_VGG19_URL = 'https://github.com/machrisaa/tensorflow-vgg'
# Where the VGG19 weights file is expected. The path starts with './', so it is
# resolved relative to the current working directory.
_VGG19_NPY = './objax/zoo/pretrained/vgg19.npy'
# Remote synset file; it is fetched to _SYNSET_PATH when that file does not exist.
_SYNSET_URL = 'https://raw.githubusercontent.com/machrisaa/tensorflow-vgg/master/synset.txt'
# Local destination of the synset file (also relative to the working directory).
_SYNSET_PATH = './objax/zoo/pretrained/synset.txt'


def preprocess(x):
    """Reverse the first three channels of ``x`` and subtract per-channel means.

    Args:
        x: array indexed as ``x[:, channel, :, :]`` whose first three channels
            are read in the order red, green, blue.

    Returns:
        Array stacked along axis 1 in the order blue, green, red, where each
        channel has had its entry of the BGR mean subtracted.
    """
    bgr_mean = (103.939, 116.779, 123.68)
    # Output channel order is BGR, so walk the source channels as 2, 1, 0.
    shifted = [x[:, src, :, :] - mean for src, mean in zip((2, 1, 0), bgr_mean)]
    return jn.stack(shifted, axis=1)


def max_pool_2d(x):
    """Max-pool ``x`` with window size 2, stride 2 and VALID padding."""
    valid_padding = objax.constants.ConvPadding.VALID
    return objax.functional.max_pool_2d(x, size=2, strides=2, padding=valid_padding)


class VGG19(objax.nn.Sequential):
    """VGG19 implementation.

    The model is a sequence of five convolutional blocks (3x3 convolutions
    followed by ReLU, each block closed by a 2x2 max pool), a flatten op and
    three fully connected layers. Every layer is also exposed as an attribute
    (``conv1_1``, ``relu1_1``, ``pool1``, ..., ``fc8``).

    Note: ``preprocess`` is stored on the instance as ``self.preprocess`` but
    is NOT part of the sequential ops, so calling the model does not apply it.
    """

    # One entry per convolutional block:
    # (block index, input channels, output channels, number of convolutions).
    _CONV_BLOCKS = (
        (1, 3, 64, 2),
        (2, 64, 128, 2),
        (3, 128, 256, 4),
        (4, 256, 512, 4),
        (5, 512, 512, 4),
    )

    def __init__(self, pretrained=False):
        """Creates VGG19 instance.

        Args:
            pretrained: if True load weights from ImageNet pretrained model.

        Raises:
            FileNotFoundError: if ``pretrained`` is True and the weights file
                is not present at ``_VGG19_NPY``.
        """
        if os.path.exists(_VGG19_NPY):
            if not os.path.exists(_SYNSET_PATH):
                request.urlretrieve(_SYNSET_URL, _SYNSET_PATH)
            # NOTE(security): allow_pickle=True unpickles the file content, which can
            # execute arbitrary code. Only use a vgg19.npy obtained from a trusted source.
            self.data_dict = np.load(_VGG19_NPY, encoding='latin1', allow_pickle=True).item()
        elif pretrained:
            raise FileNotFoundError(
                'You must download vgg19.npy from %s and save it to %s' % (_VGG19_URL, _VGG19_NPY))
        else:
            # The weights file is only required for a pretrained model; a randomly
            # initialised network can be built without it.
            self.data_dict = {}
        self.pretrained = pretrained
        self.ops = self.build()
        super().__init__(self.ops)

    def build(self):
        """Creates all layers, registers them as attributes and returns the op list.

        Layers are created in network order (conv1_1 ... conv5_4, fc6, fc7, fc8).
        When ``self.pretrained`` is True the weights found in ``self.data_dict``
        are copied into the matching layers.

        Returns:
            List of callables in execution order, suitable for
            ``objax.nn.Sequential``.
        """
        # inputs in [0, 255]
        self.preprocess = preprocess

        ops = []
        for block, nin, nout, num_convs in self._CONV_BLOCKS:
            for idx in range(1, num_convs + 1):
                # Only the first convolution of a block changes the channel count.
                conv = objax.nn.Conv2D(nin=nin if idx == 1 else nout, nout=nout, k=3)
                setattr(self, 'conv%d_%d' % (block, idx), conv)
                setattr(self, 'relu%d_%d' % (block, idx), objax.functional.relu)
                ops += [conv, objax.functional.relu]
            setattr(self, 'pool%d' % block, max_pool_2d)
            ops.append(max_pool_2d)

        self.flatten = objax.functional.flatten
        # 512 * 7 * 7 inputs: presumably a 224x224 image after five stride-2 pools
        # -- TODO confirm the expected input resolution.
        self.fc6 = objax.nn.Linear(nin=512 * 7 * 7, nout=4096)
        self.relu6 = objax.functional.relu
        self.fc7 = objax.nn.Linear(nin=4096, nout=4096)
        self.relu7 = objax.functional.relu
        self.fc8 = objax.nn.Linear(nin=4096, nout=1000)
        ops += [self.flatten, self.fc6, self.relu6, self.fc7, self.relu7, self.fc8]

        if self.pretrained:
            self._load_pretrained()
        return ops

    def _load_pretrained(self):
        """Copies the arrays in ``self.data_dict`` into the layers of the same name."""
        for name, params in self.data_dict.items():
            if name.startswith('conv'):
                kernel, bias = params
                conv = getattr(self, name)
                conv.w = objax.TrainVar(jn.array(kernel))
                # Bias gets two trailing singleton axes: shape (channels, 1, 1).
                conv.b = objax.TrainVar(jn.array(bias[:, None, None]))
            elif name.startswith('fc'):
                kernel, bias = params
                linear = getattr(self, name)
                if name == 'fc6':
                    # Reorder the input rows from (7, 7, 512) to (512, 7, 7) ordering;
                    # presumably to match flattening of channel-first activations.
                    kernel = (kernel.reshape([7, 7, 512, -1])
                              .transpose((2, 0, 1, 3))
                              .reshape([512 * 7 * 7, -1]))
                linear.w = objax.TrainVar(jn.array(kernel))
                linear.b = objax.TrainVar(jn.array(bias))