blob: 685a4c5d70a1ac02adc908a928a846097bb20cf6 [file] [log] [blame]
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Wrapper for Python TPU tests.
The py_tpu_test macro will actually use this file as its main, building and
executing the user-provided test file as a py_binary instead. This lets us do
important work behind the scenes, without complicating the tests themselves.
The main responsibilities of this file are:
  - Define the standard set of model flags if the test did not. This allows us
    to safely set flags at the Bazel invocation level using --test_arg.
  - Pick a random directory on GCS to use for each test case, and set it as the
    default value of --model_dir. This is similar to how Bazel provides each
    test with a fresh local directory in $TEST_TMPDIR.
"""
import ast
import importlib
import os
import sys
import uuid
from tensorflow.python.platform import flags
from tensorflow.python.util import tf_inspect
# Shared flag registry. It is parsed in two passes by the `__main__` block at
# the bottom of this file.
FLAGS = flags.FLAGS
# Relative module name of the wrapped test; read by import_user_module().
# The default is None, so a value is presumably always passed in by the
# py_tpu_test macro -- TODO confirm.
flags.DEFINE_string(
'wrapped_tpu_test_module_relative', None,
'The Python-style relative path to the user-given test. If test is in same '
'directory as BUILD file as is common, then "test.py" would be ".test".')
# Root under which set_random_test_dir() picks a per-run directory. The default
# is read from the environment once, at definition time, and is None when
# TEST_UNDECLARED_OUTPUTS_DIR is not set.
flags.DEFINE_string('test_dir_base',
os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR'),
'GCS path to root directory for temporary test files.')
# Marker that calculate_parent_python_path() searches for in the invoked
# filepath; everything from its last occurrence onward becomes the import path.
flags.DEFINE_string(
'bazel_repo_root', 'tensorflow/python',
'Substring of a bazel filepath beginning the python absolute import path.')
# List of flags which all TPU tests should accept. maybe_define_flags() defines
# any of these that are still missing after the wrapped test has been imported.
REQUIRED_FLAGS = ['tpu', 'zone', 'project', 'model_dir']
def maybe_define_flags():
  """Ensures that every flag named in REQUIRED_FLAGS exists.

  Each name is registered as a string flag with a None default. A name that is
  already registered raises DuplicateFlagError, which is deliberately ignored
  so the existing definition is left in place.
  """
  for flag_name in REQUIRED_FLAGS:
    try:
      flags.DEFINE_string(flag_name, None, 'flag defined by test lib')
    except flags.DuplicateFlagError:
      # Already defined elsewhere; nothing to do for this name.
      continue
def set_random_test_dir():
  """Pick a random GCS directory under --test_dir_base, set as --model_dir.

  The directory name is a fresh uuid4 hex string, and it is installed as the
  *default* of --model_dir, so an explicitly passed --model_dir still wins.

  Raises:
    ValueError: if --test_dir_base is None, which happens when the flag was not
      passed and TEST_UNDECLARED_OUTPUTS_DIR was absent from the environment
      when the flag was defined.
  """
  base_dir = FLAGS.test_dir_base
  if base_dir is None:
    # Without this check os.path.join(None, ...) fails with an opaque TypeError
    # that does not mention which flag is missing.
    raise ValueError(
        'Cannot choose a default --model_dir: --test_dir_base is not set '
        '(its default comes from the TEST_UNDECLARED_OUTPUTS_DIR environment '
        'variable).')
  FLAGS.set_default('model_dir', os.path.join(base_dir, uuid.uuid4().hex))
def calculate_parent_python_path(test_filepath):
  """Derives the dotted import path of the directory containing the test.

  Args:
    test_filepath: The filepath which Bazel invoked
      (ex: /filesystem/path/tensorflow/tensorflow/python/tpu/tpu_test)

  Returns:
    Absolute import path of parent (ex: tensorflow.python.tpu).

  Raises:
    ValueError: if bazel_repo_root does not appear within test_filepath.
  """
  repo_root = FLAGS.bazel_repo_root
  # Split once from the right so that only the last occurrence of the repo root
  # counts; a single piece means the root never appeared.
  pieces = test_filepath.rsplit(repo_root, 1)
  if len(pieces) < 2:
    raise ValueError(
        f'Filepath "{test_filepath}" does not contain repo root "{FLAGS.bazel_repo_root}"'
    )
  # Re-attach the root, then strip the final component (the test wrapper name).
  parent_dir = (repo_root + pieces[1]).rsplit('/', 1)[0]
  # Turn directory separators into package separators.
  return '.'.join(parent_dir.split('/'))
def import_user_module():
  """Imports the user test module named by --wrapped_tpu_test_module_relative.

  Importing executes the module's top-level statements, which is how its flag
  definitions get registered.

  Returns:
    The user test module.
  """
  relative_name = FLAGS.wrapped_tpu_test_module_relative
  # The relative name is resolved against the package of the invoked binary.
  parent_package = calculate_parent_python_path(sys.argv[0])
  return importlib.import_module(relative_name, parent_package)
def _is_test_class(obj):
  """Tells whether an object is a test class (as opposed to a test instance).

  Args:
    obj: An arbitrary object from within a module.

  Returns:
    True iff obj is a class whose MRO contains a class named "TestCase". The
    match is by name rather than identity because tests are written against
    different underlying test libraries.
  """
  if not tf_inspect.isclass(obj):
    return False
  return any(
      base.__name__ == 'TestCase' for base in tf_inspect.getmro(obj))
# Namespace dict of this module, captured at import time so that
# move_test_classes_into_scope() can add names to it.
module_variables = vars()


def move_test_classes_into_scope(wrapped_test_module):
  """Re-exports the wrapped module's test classes from this module.

  The test runner works by inspecting the main module for TestCase classes, so
  giving each wrapped TestCase a module-level name here causes it to be run.
  Each class is stored under its original name prefixed with
  "tpu_test_imported_".

  Args:
    wrapped_test_module: The user-provided test code to run.
  """
  for attr_name, attr in vars(wrapped_test_module).items():
    if not _is_test_class(attr):
      continue
    module_variables['tpu_test_imported_%s' % attr_name] = attr
def run_user_main(wrapped_test_module):
  """Runs the "if __name__ == '__main__'" at the bottom of a module.

  TensorFlow practice is to have a main if at the bottom of the module which
  might call an API compat function before calling test.main().

  Since this is a statement, not a function, we can't cleanly reference it, but
  we can inspect it from the user module and run it in the context of that
  module so all imports and variables are available to it.

  Args:
    wrapped_test_module: The user-provided test code to run.

  Raises:
    NotImplementedError: If main block was not found in module. This should not
      be caught, as it is likely an error on the user's part -- absltest is all
      too happy to report a successful status (and zero tests executed) if a
      user forgets to end a class with "test.main()".
  """
  tree = ast.parse(tf_inspect.getsource(wrapped_test_module))

  # Dump of just the condition `__name__ == "__main__"`. Comparing dumps makes
  # the match independent of quoting style and position in the file.
  target = ast.dump(ast.parse('if __name__ == "__main__": pass').body[0].test)

  # `tree.body` is a list of top-level statements in the module, like imports
  # and class definitions. We search for our main block, starting from the end.
  main_block = next(
      (stmt for stmt in reversed(tree.body)
       if isinstance(stmt, ast.If) and ast.dump(stmt.test) == target),
      None)
  if main_block is None:
    raise NotImplementedError(
        f'Could not find `if __name__ == "__main__":` block in {wrapped_test_module.__name__}.'
    )

  # Wrap only the body of the `if` in a fresh module so it can be compiled and
  # executed unconditionally.
  new_ast = ast.Module(body=main_block.body, type_ignores=[])
  # Names are looked up in the wrapped module's namespace first (passed as
  # locals), then in this wrapper's globals.
  exec(  # pylint:disable=exec-used
      compile(new_ast, '<ast>', 'exec'),
      globals(),
      wrapped_test_module.__dict__,
  )
if __name__ == '__main__':
# The steps below are order-dependent: flags are parsed in two passes around
# the import of the user module.
# Partially parse flags, since module to import is specified by flag.
# The returned arguments are kept and parsed again further down.
unparsed = FLAGS(sys.argv, known_only=True)
# Importing runs the user module's top-level statements, including its own
# flag definitions.
user_module = import_user_module()
# Define whichever REQUIRED_FLAGS are still missing after that import.
maybe_define_flags()
# Parse remaining flags.
FLAGS(unparsed)
# Runs after the full parse, so it reads the final value of --test_dir_base.
set_random_test_dir()
# Give the user's test classes module-level names in this (main) module.
move_test_classes_into_scope(user_module)
# Execute the body of the user module's own `if __name__ == '__main__':` block.
run_user_main(user_module)