Source code for objax.util.objax2tf

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

from objax.module import Module
from objax.typing import JaxArray

try:
    # TensorFlow is an optional dependency: use the real package when it can be imported.
    import tensorflow as tf

    # Hand TensorFlow an empty list of visible 'GPU' devices
    # (presumably so it does not compete with JAX for the accelerators - confirm).
    tf.config.experimental.set_visible_devices([], 'GPU')
except ImportError:
    # TensorFlow is missing: build a minimal stand-in exposing only the names this file touches at
    # import time (tf.Module as a base class, tf.function as a decorator), so importing still works.
    def _fake_tf_function(func=None, **kwargs):
        """Stand-in for tf.function which leaves the decorated callable untouched.

        Args:
            func: callable being decorated when used as a bare decorator, or None when the
                decorator is invoked with keyword options only.
            **kwargs: decorator options, accepted and discarded.

        Returns:
            func itself when it was given, otherwise an identity decorator.
        """
        del kwargs
        return func if func is not None else (lambda x: x)

    tf = type('tf', (), {'Module': object, 'function': _fake_tf_function})


class Objax2Tf(tf.Module):
    """Objax to Tensorflow converter, which converts Objax module to tf.Module."""

    def __init__(self, module: Module):
        """Create a Tensorflow module from Objax module.

        Args:
            module: Objax module to be converted to Tensorflow tf.Module.

        Raises:
            AssertionError: if Tensorflow is not installed, if its major version is lower than 2,
                or if module is not an Objax Module.
        """
        from jax.experimental import jax2tf

        assert hasattr(tf, '__version__'), 'Tensorflow must be installed for Objax2Tf to work.'
        # Compare the numeric major version. Comparing version strings is lexicographic,
        # which would wrongly reject e.g. '10.0.0' because '1' sorts before '2'.
        assert int(tf.__version__.split('.', 1)[0]) >= 2, 'Objax2Tf works only with Tensorflow 2.'
        assert isinstance(module, Module), 'Input argument to Objax2Tf must be an Objax module.'

        super().__init__()

        module_vars = module.vars()

        def wrapped_op(tensor_list: List[JaxArray], kwargs, *args):
            """Call the Objax module with its variables temporarily set to tensor_list.

            Args:
                tensor_list: values assigned to the module variables for the duration of the call.
                kwargs: dictionary of keyword arguments forwarded to the module.
                *args: positional arguments forwarded to the module.

            Returns:
                Whatever the module returns for the given arguments.
            """
            original_values = module_vars.tensors()
            try:
                module_vars.assign(tensor_list)
                return module(*args, **kwargs)
            finally:
                # Put the original values back even when the module call raises.
                module_vars.assign(original_values)

        tf_function = jax2tf.convert(wrapped_op)
        # One tf.Variable per tensor of the Objax module, in the order returned by tensors().
        self._tf_vars = [tf.Variable(v) for v in module_vars.tensors()]
        self._tf_call = tf_function

    @tf.function(autograph=False)
    def __call__(self, *args, **kwargs):
        """Calls Tensorflow function which was generated from Objax module.

        Args:
            *args: positional arguments forwarded to the converted module.
            **kwargs: keyword arguments forwarded to the converted module.

        Returns:
            Result of the converted function applied to the stored Tensorflow variables and the arguments.
        """
        return self._tf_call(self._tf_vars, kwargs, *args)