blob: 4d01f8e9cab010ca6bf37325d06b47728baeda12 [file] [log] [blame]
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Executor for eager execution."""
from tensorflow.python import pywrap_tfe
class Executor(object):
"""A class for handling eager execution.
The default behavior for asynchronous execution is to serialize all ops on
a single thread. Having different `Executor` objects in different threads
enables executing ops asynchronously in parallel:
```python
def thread_function():
executor = executor.Executor(enable_async=True):
context.set_executor(executor)
a = threading.Thread(target=thread_function)
a.start()
b = threading.Thread(target=thread_function)
b.start()
```
"""
__slots__ = ["_handle"]
def __init__(self, handle):
self._handle = handle
def __del__(self):
try:
self.wait()
pywrap_tfe.TFE_DeleteExecutor(self._handle)
except TypeError:
# Suppress some exceptions, mainly for the case when we're running on
# module deletion. Things that can go wrong include the pywrap module
# already being unloaded, self._handle. no longer being
# valid, and so on. Printing warnings in these cases is silly
# (exceptions raised from __del__ are printed as warnings to stderr).
pass # 'NoneType' object is not callable when the handle has been
# partially unloaded.
def is_async(self):
return pywrap_tfe.TFE_ExecutorIsAsync(self._handle)
def handle(self):
return self._handle
def wait(self):
"""Waits for ops dispatched in this executor to finish."""
pywrap_tfe.TFE_ExecutorWaitForAllPendingNodes(self._handle)
def clear_error(self):
"""Clears errors raised in this executor during execution."""
pywrap_tfe.TFE_ExecutorClearError(self._handle)
def new_executor(enable_async,
enable_streaming_enqueue=True,
in_flight_nodes_limit=0):
handle = pywrap_tfe.TFE_NewExecutor(enable_async, enable_streaming_enqueue,
in_flight_nodes_limit)
return Executor(handle)