| # Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Implementation of the SessionRunHook for preemptible Cloud TPUs.""" |
| |
| import logging as _logging |
| import os |
| import threading |
| import time |
| |
| from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver |
| from tensorflow.python.platform import tf_logging as logging |
| from tensorflow.python.training import session_run_hook |
| |
| |
class CloudTPUPreemptedHook(session_run_hook.SessionRunHook):
  """The SessionRunHook for preemptible Cloud TPUs.

  This is an implementation of SessionRunHook for the pre-emptible Google Cloud
  TPU service. When running in GCE, it starts a background `_TPUPollingThread`
  once the session has been created; that thread polls the TPU and terminates
  the coordinator process when the TPU is reported as no longer recoverable.
  Outside of GCE the hook is a no-op.
  """

  def __init__(self, cluster):
    """Creates the hook.

    Args:
      cluster: The cluster resolver object handed to the polling thread, which
        uses it to query the state of the TPU.
    """
    self._cluster = cluster
    # Stays None unless a polling thread is actually started, so that `end`
    # is safe to call when not running in GCE or before any session exists.
    self._tpu_poller = None

  def after_create_session(self, session, coord):
    """Starts the TPU polling thread when running in GCE.

    Args:
      session: The session that was just created; passed to the poller.
      coord: Unused.
    """
    if tpu_cluster_resolver.is_running_in_gce():
      self._tpu_poller = _TPUPollingThread(self._cluster, session)
      self._tpu_poller.start()

  def end(self, session):
    """Stops the polling thread, if one was started.

    Args:
      session: Unused.
    """
    # getattr guards against subclasses that override __init__ without
    # calling it, in which case the attribute would not exist at all.
    poller = getattr(self, '_tpu_poller', None)
    if poller is None:
      return
    poller.stop()
    self._tpu_poller = None
| |
| |
class _TPUPollingThread(threading.Thread):
  """A daemon thread that polls the state of a TPU node.

  Every `_interval` seconds the thread asks the cluster's Cloud TPU client
  whether the node is still recoverable. As soon as the client reports that it
  is not (presumably a terminal state such as PREEMPTED or TERMINATED -- see
  the Cloud TPU client for the exact definition), the TPU name and state are
  logged and the entire process is terminated with `os._exit(1)`.

  NOTE: the session is stored but no attempt is made to close it before
  exiting.
  """

  def __init__(self, cluster, session):
    """Initializes the polling thread without starting it.

    Args:
      cluster: Cluster resolver object; its `_cloud_tpu_client` and `_tpu`
        attributes are read while polling.
      session: The session associated with the TPU being watched.
    """
    super(_TPUPollingThread, self).__init__()

    self.daemon = True
    self._running = True
    self._session_closed = False
    self._cluster = cluster
    self._session = session
    self._interval = 30
    # Set by `stop` to wake the thread out of its inter-poll wait right away,
    # so that `stop` does not have to wait out a full polling interval.
    self._stop_event = threading.Event()

    # Some of the Google API libraries are quite chatty, so raise their log
    # level to WARNING.
    for name in ['googleapiclient.discovery', 'oauth2client.client']:
      _logging.getLogger(name).setLevel(_logging.WARNING)

  def stop(self):
    """Asks the polling loop to finish and waits for the thread to exit."""
    self._running = False
    self._session_closed = True
    self._stop_event.set()
    self.join()

  def run(self):
    """Polls the TPU until stopped; exits the process if it is unrecoverable."""
    if not tpu_cluster_resolver.is_running_in_gce():
      logging.warning(
          'TPUPollingThread is running in a non-GCE environment, exiting...')
      self._running = False
      return

    while self._running:
      recoverable = self._cluster._cloud_tpu_client.recoverable()  # pylint: disable=protected-access
      if not recoverable:
        logging.warning(
            'TPUPollingThread found TPU %s in state %s',
            self._cluster._tpu, self._cluster._cloud_tpu_client.state())  # pylint: disable=protected-access
        # Hard exit: skips interpreter cleanup so the process goes away even
        # if other threads are blocked.
        os._exit(1)  # pylint: disable=protected-access
      # Interruptible replacement for time.sleep(): returns early once
      # `stop` sets the event, otherwise after `_interval` seconds.
      self._stop_event.wait(self._interval)