# blob: c9bedb9343e76295ddf4c4db064f65768aebdd1f [file] [log] [blame]
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of the SessionRunHook for preemptible Cloud TPUs."""
import logging as _logging
import os
import threading
import time
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import session_run_hook
class CloudTPUPreemptedHook(session_run_hook.SessionRunHook):
  """The SessionRunHook for preemptible Cloud TPUs.

  When running on GCE, this hook starts a background `_TPUPollingThread` once
  the session has been created and stops that thread when the session ends.
  Outside of GCE no thread is started and the hook is a no-op.

  Args:
    cluster: The cluster object handed to the polling thread; the thread
      queries it for the state of the TPU.
  """

  def __init__(self, cluster):
    self._cluster = cluster
    # Stays None until a polling thread is actually started, so that `end()`
    # is safe to call when `after_create_session()` did nothing (non-GCE) or
    # was never invoked.
    self._tpu_poller = None

  def after_create_session(self, session, coord):
    """Starts polling the TPU state, but only when running in GCE."""
    if tpu_cluster_resolver.is_running_in_gce():
      self._tpu_poller = _TPUPollingThread(self._cluster, session)
      self._tpu_poller.start()

  def end(self, session):
    """Stops the polling thread, if one was started."""
    poller = self._tpu_poller
    if poller is None:
      return
    # Drop the reference first so a repeated `end()` call is a no-op.
    self._tpu_poller = None
    poller.stop()
class _TPUPollingThread(threading.Thread):
  """A daemon thread that polls the state of a TPU node.

  Every `_interval` seconds the thread asks the cluster's Cloud TPU client
  whether the node is still recoverable. As soon as it is not (the node is in
  a state the underlying infrastructure will not recover from), the thread
  logs the TPU name and state and terminates the whole process with
  `os._exit(1)`.

  NOTE(review): the session is stored on the instance but is not used by the
  code in this class; no `session.close()` is attempted here.

  Args:
    cluster: Object exposing `_cloud_tpu_client` (with `recoverable()` and
      `state()`) and `_tpu`.
    session: The session associated with the TPU being watched.
  """

  def __init__(self, cluster, session):
    super(_TPUPollingThread, self).__init__()

    self.daemon = True
    self._running = True
    self._session_closed = False
    self._cluster = cluster
    self._session = session
    self._interval = 30
    # Set by `stop()` to wake the polling loop immediately instead of letting
    # it finish a full `_interval` sleep.
    self._stop_requested = threading.Event()

    # Some of the Google API libraries are quite chatty, so disable them.
    for name in ['googleapiclient.discovery', 'oauth2client.client']:
      _logging.getLogger(name).setLevel(_logging.WARNING)

  def stop(self):
    """Asks the polling loop to finish and waits for the thread to exit.

    Safe to call more than once, and safe to call on a thread that was never
    started.
    """
    self._running = False
    self._session_closed = True
    self._stop_requested.set()
    # Joining a thread that was never started raises RuntimeError.
    if self.is_alive():
      self.join()

  def run(self):
    """Polls the TPU until stopped; exits the process if it is unrecoverable."""
    if not tpu_cluster_resolver.is_running_in_gce():
      logging.warning(
          'TPUPollingThread is running in a non-GCE environment, exiting...')
      self._running = False
      return

    while self._running:
      recoverable = self._cluster._cloud_tpu_client.recoverable()  # pylint: disable=protected-access
      if not recoverable:
        logging.warning(
            'TPUPollingThread found TPU %s in state %s',
            self._cluster._tpu, self._cluster._cloud_tpu_client.state())  # pylint: disable=protected-access
        os._exit(1)  # pylint: disable=protected-access
      # Interruptible sleep: returns early once `stop()` sets the event, so
      # `stop()` does not block for up to `_interval` seconds in `join()`.
      self._stop_requested.wait(self._interval)