# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ['find_used_variables']
import inspect
import ast
import re
import tokenize
from io import StringIO
from typing import Callable
from objax.variable import BaseVar, VarCollection
from objax.module import Module
def getanno(node, key, field_name='__anno'):
"""Gets annotation of the AST node."""
return getattr(node, field_name)[key]
def hasanno(node, key, field_name='__anno'):
"""Returns whether AST node has an annotation."""
return hasattr(node, field_name) and key in getattr(node, field_name)
def setanno(node, key, value, field_name='__anno'):
"""Sets annotation on AST node."""
annotations = getattr(node, field_name, {})
setattr(node, field_name, annotations)
annotations[key] = value
class AnalyzeUserVariablesNodeTransformer(ast.NodeTransformer):
def __init__(self, closure_vars, global_vars):
self.closure_vars = closure_vars
self.global_vars = global_vars
self.vc = VarCollection()
def check_objax_var_module(self, node):
if not hasanno(node, 'value'):
return
v = getanno(node, 'value')
v_name = getanno(node, 'name')
if v is None:
return
if isinstance(v, Module):
self.vc.update(v.vars(scope=v_name + '.'))
setanno(node, 'value', None)
if isinstance(v, BaseVar):
if v_name in self.vc and self.vc[v_name] is not v:
# This generally should not happen and probably indication of a bug somewhere.
raise ValueError(
f'Variable tracing failed because two variables were found with the same name {v_name}')
else:
self.vc[v_name] = v
setanno(node, 'value', None)
def visit_Name(self, node):
node = self.generic_visit(node)
if isinstance(node.ctx, ast.Load):
if node.id in self.closure_vars:
setanno(node, 'name', node.id)
setanno(node, 'value', self.closure_vars[node.id])
self.check_objax_var_module(node)
elif node.id in self.global_vars:
setanno(node, 'name', node.id)
setanno(node, 'value', self.global_vars[node.id])
self.check_objax_var_module(node)
return node
def visit_Attribute(self, node):
node = self.generic_visit(node)
if isinstance(node.ctx, ast.Load) and hasanno(node.value, 'value'):
parent_value = getanno(node.value, 'value')
if parent_value is not None and hasattr(parent_value, node.attr):
setanno(node, 'name', getanno(node.value, 'name') + '.' + node.attr)
setanno(node, 'value', getattr(parent_value, node.attr))
self.check_objax_var_module(node)
return node
_LEADING_WHITESPACE = re.compile(r'\s*')
def dedent_block(code_string):
"""Dedents a code so that its first line starts at row zero."""
# Removes any backslash line continuations from the code
code_string = code_string.replace('\\\n', '')
token_gen = tokenize.generate_tokens(StringIO(code_string).readline)
block_indentation = None
tokens = []
try:
for tok in token_gen:
tokens.append(tok)
except tokenize.TokenError:
# Resolution of lambda functions may yield incomplete code, which can
# in turn generate this error. We silently ignore this error because the
# parser may still be able to deal with it.
pass
for tok in tokens:
tok_type, tok_string, _, _, _ = tok
if tok_type == tokenize.INDENT:
block_indentation = tok_string
break
elif tok_type not in (tokenize.NL, tokenize.NEWLINE, tokenize.STRING, tokenize.COMMENT):
block_indentation = ''
break
if not block_indentation:
return code_string
block_level = len(block_indentation)
first_indent_uses_tabs = '\t' in block_indentation
for i, tok in enumerate(tokens):
tok_type, tok_string, _, _, _ = tok
if tok_type == tokenize.INDENT:
if ((' ' in tok_string and first_indent_uses_tabs) or ('\t' in tok_string and not first_indent_uses_tabs)):
raise ValueError('Code mixing tabs and spaces for indentation is not allowed')
if len(tok_string) >= block_level:
tok_string = tok_string[block_level:]
tokens[i] = (tok_type, tok_string)
new_code = tokenize.untokenize(tokens)
# Note: untokenize respects the line structure, but not the whitespace within
# lines. For example, `def foo()` may be untokenized as `def foo ()`
# So instead of using the output of dedent, we match the leading whitespace
# on each line.
dedented_code = []
for line, new_line in zip(code_string.split('\n'), new_code.split('\n')):
original_indent = re.match(_LEADING_WHITESPACE, line).group()
new_indent = re.match(_LEADING_WHITESPACE, new_line).group()
if len(original_indent) > len(new_indent):
dedented_line = line[len(original_indent) - len(new_indent):]
else:
dedented_line = line
dedented_code.append(dedented_line)
new_code = '\n'.join(dedented_code)
return new_code
[docs]def find_used_variables(fn: Callable) -> VarCollection:
"""Finds all Objax variables which are used by a given callable.
Args:
fn: input function or callable.
Returns:
Variable collection with all variables used by input function.
"""
if not hasattr(fn, '__code__'):
raise ValueError('Can not determine variables used by a function. Function does not have __code__ attribute.')
try:
src = inspect.getsource(fn)
except OSError:
raise ValueError('Can not determine variables used by a function. Code of the function can not be retrieved.')
src = dedent_block(src)
main_node = ast.parse(src)
main_node = main_node.body[0]
if fn.__closure__:
closure_vars = {name: cell.cell_contents for name, cell in zip(fn.__code__.co_freevars, fn.__closure__)}
else:
closure_vars = {}
analyzer = AnalyzeUserVariablesNodeTransformer(closure_vars=closure_vars,
global_vars=fn.__globals__)
analyzer.visit(main_node)
return analyzer.vc