# blob: 6f366fd0b8aa05aec25b5cdc799f551b6cb94dfe [file] [log] [blame]
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Benchmarks for remote worker eager execution.

To run CPU benchmarks:
  bazel run -c opt remote_benchmarks_test -- --benchmark_filter=.

To run GPU benchmarks:
  bazel run --config=cuda -c opt --copt="-mavx" remote_benchmarks_test -- \
    --benchmark_filter=.
"""
import gc
import time
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import remote
from tensorflow.python.eager import test
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.training import server_lib
def run_benchmark(func, num_iters, execution_mode=None):
    """Time `num_iters` back-to-back calls of `func` under an execution mode.

    `func` is called once before the clock starts; that warm-up call is not
    part of the returned time. When `execution_mode` equals `context.ASYNC`,
    `ctx.executor.wait()` is called after the warm-up and again after the
    timed loop, before the clock is stopped (presumably so that work still
    queued on the executor is included in the measurement).

    Args:
      func: Zero-argument callable to benchmark.
      num_iters: Number of timed calls to `func`.
      execution_mode: Value handed to `context.execution_mode`. Defaults to
        None.

    Returns:
      Elapsed time in seconds (float) for the `num_iters` timed calls.
    """
    ctx = context.context()
    is_async = execution_mode == context.ASYNC
    with context.execution_mode(execution_mode):
        # Call func once up front to maybe warm up the GPU.
        func()
        if is_async:
            ctx.executor.wait()

        # perf_counter() is monotonic, so the measured interval cannot be
        # distorted by system clock adjustments the way time.time() can.
        start = time.perf_counter()
        for _ in range(num_iters):
            func()
        if is_async:
            ctx.executor.wait()
        return time.perf_counter() - start
class Foo(object):
    """Callable that multiplies its input by a list of scalar variables.

    The variables are not created in the constructor. They are built the
    first time the object is called while its variable list is empty.
    """

    def __init__(self, num_vars):
        """Record how many variables to create; none are created yet.

        Args:
          num_vars: Number of variables to build on first call.
        """
        self._v = []
        self._num_vars = num_vars

    def __call__(self, inputs):
        """Return `inputs` multiplied by every stored variable in order.

        Args:
          inputs: Value to multiply; it is the left operand of each `*`.

        Returns:
          The running product after all stored variables have been applied.
        """
        if not self._v:
            # Lazily build the variables: each is shape [] and initialised
            # from a shape-[] random_uniform sample.
            for _ in range(self._num_vars):
                initial_value = random_ops.random_uniform([])
                self._v.append(variables.Variable(initial_value, shape=[]))

        product = inputs
        for variable in self._v:
            product = product * variable
        return product
class RemoteWorkerMicroBenchmarks(test.Benchmark):
    """Micro-benchmarks for calling functions on remote workers in eager mode.

    Two local servers are created once per instance. Each benchmark connects
    to one or both of them with `remote.connect_to_remote_host`, builds a
    `def_function.function`, and times repeated calls to it through `_run`.
    """

    def __init__(self):
        """Create the two local servers and cache their target addresses."""
        # Run the base-class initializer so that whatever setup
        # `test.Benchmark` performs is not silently skipped.
        super().__init__()
        # used for remote benchmarks
        # The cached targets drop the first len("grpc://") characters of
        # `Server.target` (assumes the target starts with that scheme prefix).
        scheme_len = len("grpc://")
        self._cached_server1 = server_lib.Server.create_local_server()
        self._cached_server_target1 = self._cached_server1.target[scheme_len:]
        self._cached_server2 = server_lib.Server.create_local_server()
        self._cached_server_target2 = self._cached_server2.target[scheme_len:]

    def _run(self, func, num_iters=1000, execution_mode=context.ASYNC):
        """Time `func` with `run_benchmark` and report the result.

        Args:
          func: Zero-argument callable to benchmark.
          num_iters: Number of timed calls.
          execution_mode: Execution mode forwarded to `run_benchmark`.
        """
        total_time = run_benchmark(func, num_iters, execution_mode)
        # `wall_time` is reported as the mean time per call in microseconds;
        # `examples_per_sec` is calls per second over the whole run.
        mean_us = total_time * 1e6 / num_iters
        self.report_benchmark(
            iters=num_iters,
            wall_time=mean_us,
            extras={"examples_per_sec": num_iters / total_time})

    def _collect_hanging_resources(self):
        """Force a garbage collection pass at the end of a benchmark."""
        # NOTE(b/136184459): Force garbage collecting hanging resources before
        # subsequent calls to set_server_def, to ensure the destroy resource
        # ops are executed when their corresponding device and manager are
        # still available.
        gc.collect()

    def benchmark_send(self):
        """Time a matmul function called with a tensor created on the client.

        Connects to the first server only. The 2x2 input is created outside
        the worker device scope; the function is called inside it
        (presumably this measures sending the input to the worker).
        """
        remote.connect_to_remote_host(self._cached_server_target1)
        try:
            x = random_ops.random_uniform((2, 2)).cpu()

            @def_function.function
            def remote_func(m):
                return math_ops.matmul(m, m)

            def func(m):
                with ops.device("job:worker/replica:0/task:0/device:CPU:0"):
                    return remote_func(m)

            self._run(lambda: func(x))
        finally:
            # Clean up even when the benchmark raises, so a failure here does
            # not leave hanging resources for the next benchmark's connect.
            self._collect_hanging_resources()

    def benchmark_worker_recv(self):
        """Time a function on task 0 that reads a variable placed on task 1.

        Connects to both servers. The variable is created under the task-1
        CPU device scope, while the function that adds 1.0 to it is called
        under the task-0 CPU device scope.
        """
        remote.connect_to_remote_host(
            [self._cached_server_target1, self._cached_server_target2])
        try:
            with ops.device("job:worker/replica:0/task:1/device:CPU:0"):
                v = variables.Variable(1.0)

            @def_function.function
            def remote_func():
                return 1.0 + v

            def func():
                with ops.device("job:worker/replica:0/task:0/device:CPU:0"):
                    return remote_func()

            self._run(func)
        finally:
            # Clean up even when the benchmark raises, so a failure here does
            # not leave hanging resources for the next benchmark's connect.
            self._collect_hanging_resources()

    def benchmark_create_vars_inside_function(self):
        """Time building and calling a function that creates 50 variables.

        Connects to the first server only. Every timed call builds a new
        `Foo(50)` and a new `def_function.function` around it, so the
        variables are created anew on each call. Runs 100 iterations.
        """
        remote.connect_to_remote_host(self._cached_server_target1)
        try:

            def func():
                with ops.device("job:worker/replica:0/task:0/device:CPU:0"):
                    layer = Foo(50)

                    @def_function.function
                    def remote_func():
                        with ops.device(
                            "job:worker/replica:0/task:0/device:CPU:0"):
                            return layer(random_ops.random_uniform([]))

                    return remote_func()

            self._run(func, execution_mode=context.ASYNC, num_iters=100)
        finally:
            # Clean up even when the benchmark raises, so a failure here does
            # not leave hanging resources for the next benchmark's connect.
            self._collect_hanging_resources()
# Script entry point: hand control to the test runner only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    test.main()