# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test utilities."""
import collections
import dataclasses
import functools
import io
import itertools
import threading
from absl import app
from tensorflow.python.compat import v2_compat
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.distribute import values
from tensorflow.python.eager import context
from tensorflow.python.framework import config
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.util import nest
import objgraph # pylint:disable=g-import-not-at-top
except ImportError:
objgraph = None
@dataclasses.dataclass
class TestClusterParams:
  """Parameters (and lazily created state) of an in-process test cluster.

  Instances are mutable on purpose: `get_cluster_def` stores the created
  cluster back into `cluster` the first time it is needed.
  """
  # Cluster spec indexed by job name ("worker" / "ps"). May be None until
  # `get_cluster_def` creates the cluster on first use.
  cluster: dict
  # Upper bound on the number of workers a caller may request.
  max_num_worker: int
  # Upper bound on the number of parameter servers a caller may request.
  max_num_ps: int
def get_cluster_def(cluster_params, num_workers, num_ps):
  """Returns a cluster definition with the requested number of tasks.

  The in-process cluster is created at most once, sized by the maxima in
  `cluster_params`, and cached on `cluster_params.cluster`. Every call returns
  a prefix of the cached "worker" and "ps" task lists.

  Args:
    cluster_params: an object with `cluster`, `max_num_worker` and
      `max_num_ps` attributes (e.g. `TestClusterParams`). `cluster` is
      populated in place when it is None.
    num_workers: number of worker tasks to include in the result.
    num_ps: number of parameter server tasks to include in the result.

  Returns:
    A dict with "worker" and "ps" keys holding the first `num_workers` and
    `num_ps` entries of the cached cluster.

  Raises:
    ValueError: if `num_workers` or `num_ps` exceeds the configured maximum.
  """
  if (num_workers > cluster_params.max_num_worker or
      num_ps > cluster_params.max_num_ps):
    raise ValueError("Requesting more servers than the maximum, adjust "
                     "cluster params' max_num_ps and max_num_worker")
  if cluster_params.cluster is None:
    # Create the cluster at full size so later calls with different (smaller
    # or equal) task counts can reuse it.
    cluster_params.cluster = multi_worker_test_base.create_in_process_cluster(
        num_workers=cluster_params.max_num_worker,
        num_ps=cluster_params.max_num_ps)
  return {
      "worker": cluster_params.cluster["worker"][:num_workers],
      "ps": cluster_params.cluster["ps"][:num_ps],
  }
def gather(strategy, value):
  """Gathers value from all workers.

  This is intended for tests before we implement an official all-gather API.

  Args:
    strategy: a `tf.distribute.Strategy`.
    value: a nested structure of n-dim `tf.distribute.DistributedValue` of
      `tf.Tensor`, or of a `tf.Tensor` if the strategy only has one replica.
      Cannot contain tf.sparse.SparseTensor.

  Returns:
    a (n+1)-dim `tf.Tensor`.
  """
  # `_gather` handles a single leaf; map it over every leaf of the structure.
  return nest.map_structure(functools.partial(_gather, strategy), value)
def _gather(strategy, value):
  """Gathers a single value.

  Args:
    strategy: the strategy whose `extended` object decides how to gather.
    value: a `values.DistributedValues`, or anything accepted by
      `ops.convert_to_tensor` (treated as a single-replica value).

  Returns:
    The per-replica components combined along a new leading axis.
  """
  # pylint: disable=protected-access
  if not isinstance(value, values.DistributedValues):
    # Wrap a plain value so both cases below can read `value._values`.
    value = values.PerReplica([ops.convert_to_tensor(value)])
  if not isinstance(strategy.extended,
                    collective_all_reduce_strategy.CollectiveAllReduceExtended):
    # All components are local, so stacking them is enough.
    return array_ops_stack.stack(value._values)
  assert len(strategy.extended.worker_devices) == len(value._values)
  # Add a leading axis to each component, then gather along it.
  inputs = [array_ops.expand_dims_v2(v, axis=0) for v in value._values]
  return strategy.gather(values.PerReplica(inputs), axis=0)
  # pylint: enable=protected-access
def set_logical_devices_to_at_least(device, num):
  """Create logical devices of at least a given number.

  Args:
    device: device type passed to `config.list_physical_devices`, e.g. "GPU"
      or "CPU".
    num: the minimum number of logical devices wanted; must be at least 1.

  Raises:
    ValueError: if `num` is less than 1.
    RuntimeError: if there is no physical device of the given type.
  """
  if num < 1:
    raise ValueError("`num` must be at least 1 not %r" % (num,))
  physical_devices = config.list_physical_devices(device)
  if not physical_devices:
    raise RuntimeError("No {} found".format(device))
  if len(physical_devices) >= num:
    # There are already enough devices; nothing to configure.
    return
  # By default each physical device corresponds to one logical device. We create
  # multiple logical devices for the last physical device so that we have `num`
  # logical devices.
  num = num - len(physical_devices) + 1
  logical_devices = []
  for _ in range(num):
    if device.upper() == "GPU":
      # NOTE(review): the per-device memory limit is assumed to be 2048 (MB);
      # confirm against the upstream value.
      logical_devices.append(
          context.LogicalDeviceConfiguration(memory_limit=2048))
    else:
      logical_devices.append(context.LogicalDeviceConfiguration())
  # Create logical devices from the last device since sometimes the first GPU
  # is the primary graphic card and may have less memory available.
  config.set_logical_device_configuration(physical_devices[-1], logical_devices)
def _set_logical_devices():
  """Ensures at least two logical devices per available device type.

  GPU is handled first, then CPU. A device type with no physical devices is
  skipped.
  """
  for device_type in ("GPU", "CPU"):
    if config.list_physical_devices(device_type):
      set_logical_devices_to_at_least(device_type, 2)
def main(enable_v2_behavior=True, config_logical_devices=True):
  """All-in-one main function for tf.distribute tests.

  Args:
    enable_v2_behavior: if True, TF2 behavior is enabled before the tests run;
      otherwise it is disabled.
    config_logical_devices: if True, `_set_logical_devices` is scheduled to
      run so that tests see multiple logical devices.
  """
  # NOTE(review): the calls below are restored from the modules this file
  # imports for this purpose (`app`, `v2_compat`, `multi_process_runner`);
  # confirm against upstream.
  if config_logical_devices:
    # Deferred until absl initialization has finished.
    app.call_after_init(_set_logical_devices)
  if enable_v2_behavior:
    v2_compat.enable_v2_behavior()
  else:
    v2_compat.disable_v2_behavior()
  multi_process_runner.test_main()
def _op_dependencies(op):
  """Returns the data and control dependencies of a tf.Operation combined.

  Args:
    op: the operation whose `inputs` and `control_inputs` are inspected.

  Returns:
    A list of operations: data inputs first (each tensor replaced by the op
    producing it), followed by control inputs. Duplicates are not removed.
  """
  deps = []
  for node in itertools.chain(op.inputs, op.control_inputs):
    if isinstance(node, ops.Tensor):
      # Data inputs are tensors; the dependency is on the op producing them.
      node = node.op
    assert isinstance(node, ops.Operation)
    deps.append(node)
  return deps
def topological_sort_operations(operations):
  """Topological sorts a list of operations.

  This does a topological sort of the operations in a graph. The edges include
  both data dependencies and control dependencies. Note that the edge goes from
  an operation to its dependencies.

  The sort is intentionally unstable, reversing orders of operations and
  dependencies on ties.

  Args:
    operations: a list of tf.Operation in the same graph.

  Returns:
    A map from a tf.Operation to its topological order.
  """
  # Number of incoming edges per op, i.e. how many ops depend on it. Insertion
  # order matters because it determines how ties are broken below.
  in_degrees = collections.OrderedDict()
  for op in reversed(operations):
    if op not in in_degrees:
      in_degrees[op] = 0
    for next_op in reversed(_op_dependencies(op)):
      in_degrees[next_op] = in_degrees.get(next_op, 0) + 1
  # FIFO queue of ops whose dependents have all been ordered already.
  nexts = collections.deque(
      op for op, in_degree in in_degrees.items() if in_degree == 0)
  order = {}
  next_order = 0
  while nexts:
    op = nexts.popleft()
    order[op] = next_order
    next_order += 1
    for next_op in reversed(_op_dependencies(op)):
      in_degrees[next_op] -= 1
      if in_degrees[next_op] == 0:
        nexts.append(next_op)
  # Every op must have been ordered.
  assert len(order) == len(operations)
  return order
def _exists_dependency(start, end):
  """Returns whether there exists a dependency chain from start to end.

  Args:
    start: the operation to start searching from.
    end: the operation to look for among the (transitive) dependencies of
      `start`.

  Returns:
    True if `end` is reachable from `start` by following one or more
    dependency edges, False otherwise.
  """
  nexts = collections.deque([start])
  # Each op only needs expanding once; this keeps the search linear when many
  # ops share the same dependencies.
  visited = {start}
  while nexts:
    op = nexts.popleft()
    for next_op in _op_dependencies(op):
      if next_op == end:
        return True
      if next_op not in visited:
        visited.add(next_op)
        nexts.append(next_op)
  return False
def assert_sequential_execution(order, operations):
  """Asserts there's a deterministic execution order between the operations.

  Args:
    order: a map from a tf.Operation to its topological order.
    operations: a list of operations that should be executed sequentially. It
      can be given in any order.

  Raises:
    AssertionError: if two operations that are adjacent in topological order
      are not connected by a chain of dependencies.
  """
  # Topological ordering guarantees that, if there's a dependency from N_a to
  # N_b, then order[N_a] < order[N_b]. If there does exist a path of
  # dependencies among the operations, it always goes from an operation with a
  # smaller topological order to one with a larger topological order.
  # Therefore, we only need to sort the operations by their topological orders,
  # and verify that there's a path of dependency between adjacent pairs.
  operations = sorted(operations, key=lambda op: order[op])
  for current_op, following_op in zip(operations, operations[1:]):
    if not _exists_dependency(current_op, following_op):
      # Dump the graph so the failure can be diagnosed from the test log.
      print(current_op.graph.as_graph_def())
      raise AssertionError(
          "No dependency between {} and {}. Graph is dumped to stdout.".format(
              current_op.name, following_op.name))
def get_running_threads():
  """Returns a set of all running thread names.

  Returns:
    A set with the `name` of every thread reported by `threading.enumerate()`,
    skipping threads whose name is None.
  """
  return {
      thread.name
      for thread in threading.enumerate()
      if thread.name is not None
  }
def has_thread(prefix, running_threads):
  """Returns whether any 'running_threads' is prefixed with 'prefix'.

  Args:
    prefix: The prefix of the expected thread name.
    running_threads: A collection of the running thread names.

  Returns:
    True if at least one name in `running_threads` starts with `prefix`,
    False otherwise (including when `running_threads` is empty).
  """
  return any(thread.startswith(prefix) for thread in running_threads)
def show_backref(target, max_depth=3):
  """Returns a dot graph of all the objects that are referencing the target.

  An object referencing graph is useful to debug memory leaks like circular
  references. objgraph provides a better visualization of the memory graph than
  most python built-in utilities like gc.get_referrers(), which are not
  human-readable sometimes.

  The dot graph will be written to a string IO object, and can be rendered with
  graphviz in operating system.
  E.g. dot -Tpng {$dot_graph} -o output.png

  Args:
    target: The target object for the memory graph.
    max_depth: The maximum depth of the graph. By default 3 layers of references
      are used. Increasing this a lot may result in the graph growing too big.

  Returns:
    A string that contains the object reference graph.

  Raises:
    NotImplementedError: if objgraph is not installed.
  """
  if objgraph is None:
    raise NotImplementedError("objgraph is not installed.")
  # The context manager closes the buffer even if objgraph raises.
  with io.StringIO() as string_io:
    objgraph.show_backrefs(target, max_depth=max_depth, output=string_io)
    return string_io.getvalue()
def create_per_replica(strategy, value_list):
  """Creates a PerReplica of Tensors from the value_list.

  Args:
    strategy: a strategy whose `extended.worker_devices` lists the devices to
      place the values on.
    value_list: one value per worker device, in the same order as
      `strategy.extended.worker_devices`.

  Returns:
    A `values.PerReplica` holding one tensor per worker device.

  Raises:
    ValueError: if `value_list` and the worker devices differ in length.
  """
  if len(strategy.extended.worker_devices) != len(value_list):
    raise ValueError(
        "the length of values must be the same as the number of worker devices")
  tensors = []
  for device, value in zip(strategy.extended.worker_devices, value_list):
    # Convert under the device scope so each tensor lands on its own device.
    with ops.device(device):
      tensors.append(ops.convert_to_tensor(value))
  return values.PerReplica(tensors)
def is_tpu_strategy(strategy):
  """Returns whether the strategy is a TPU strategy.

  Args:
    strategy: the object to check.

  Returns:
    True if `strategy` is an instance of one of the TPU strategy classes in
    `tpu_strategy`, False otherwise.
  """
  # NOTE(review): the last entry of the tuple is assumed to be
  # `TPUStrategyV2`; confirm against upstream.
  return isinstance(strategy,
                    (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1,
                     tpu_strategy.TPUStrategyV2))
def reset_context():
  """Resets eager context."""
  # The reset hook is private to the eager `context` module, hence the pylint
  # suppression on the call.
  context._reset_context()  # pylint: disable=protected-access