# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of the SessionRunHook for preemptible Cloud TPUs."""

import logging as _logging
import os
import threading
import time

from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import session_run_hook


class CloudTPUPreemptedHook(session_run_hook.SessionRunHook):
  """The SessionRunHook for preemptible Cloud TPUs.

  This is an implementation of SessionRunHook for the pre-emptible Google Cloud
  TPU service. When running in GCE it starts a background thread that
  periodically polls the TPU; once the TPU is reported as no longer
  recoverable, that thread logs a warning and exits the whole process.

  Outside of GCE the hook is a no-op.
  """

  def __init__(self, cluster):
    """Creates the hook.

    Args:
      cluster: The cluster resolver object for the TPU. Its
        `_cloud_tpu_client` and `_tpu` attributes are read by the polling
        thread.
    """
    self._cluster = cluster
    # Stays None until a poller is started, so that `end()` is safe to call
    # when `after_create_session` did not start one (e.g. outside of GCE).
    self._tpu_poller = None

  def after_create_session(self, session, coord):
    """Starts polling the TPU state for the newly created session.

    Args:
      session: The session that has just been created.
      coord: Unused.
    """
    if not tpu_cluster_resolver.is_running_in_gce():
      return

    # This callback may run more than once (a session can be re-created);
    # stop the previous poller so threads do not accumulate.
    if self._tpu_poller is not None:
      self._tpu_poller.stop()

    self._tpu_poller = _TPUPollingThread(self._cluster, session)
    self._tpu_poller.start()

  def end(self, session):
    """Stops the polling thread, if one was started.

    Args:
      session: Unused.
    """
    if self._tpu_poller is not None:
      self._tpu_poller.stop()
      self._tpu_poller = None


class _TPUPollingThread(threading.Thread):
  """A thread that polls the state of a TPU node.

  Every `_interval` seconds the thread asks the cluster's Cloud TPU client
  whether the node is still recoverable. When it is not (according to the
  original authors: a TERMINAL state such as PREEMPTED or TERMINATED that the
  underlying infrastructure cannot recover from), the thread logs the TPU
  name and state and terminates the entire process with `os._exit(1)`.
  """

  def __init__(self, cluster, session):
    """Creates the (daemon) polling thread; call `start()` to begin polling.

    Args:
      cluster: The cluster resolver object whose `_cloud_tpu_client` is polled.
      session: The session associated with this poller. It is stored but not
        otherwise used by the code in this class.
    """
    super(_TPUPollingThread, self).__init__()

    self.daemon = True
    self._running = True
    self._session_closed = False
    self._cluster = cluster
    self._session = session
    # Seconds to wait between two consecutive polls.
    self._interval = 30
    # Set by `stop()` to wake the thread up from its inter-poll wait, so that
    # stopping does not have to sit out the remainder of the interval.
    self._stop_requested = threading.Event()

    # Some of the Google API libraries are quite chatty, so disable them.
    for name in ['googleapiclient.discovery', 'oauth2client.client']:
      _logging.getLogger(name).setLevel(_logging.WARNING)

  def stop(self):
    """Asks the polling loop to finish and waits for the thread to exit."""
    self._running = False
    self._session_closed = True
    self._stop_requested.set()
    # `join()` raises RuntimeError on a thread that was never started; a
    # thread that is not alive needs no waiting anyway.
    if self.is_alive():
      self.join()

  def run(self):
    """Polls the TPU until stopped; exits the process if it is unrecoverable."""
    if not tpu_cluster_resolver.is_running_in_gce():
      logging.warning(
          'TPUPollingThread is running in a non-GCE environment, exiting...')
      self._running = False
      return

    while self._running:
      recoverable = self._cluster._cloud_tpu_client.recoverable()  # pylint: disable=protected-access
      if not recoverable:
        logging.warning(
            'TPUPollingThread found TPU %s in state %s',
            self._cluster._tpu, self._cluster._cloud_tpu_client.state())  # pylint: disable=protected-access
        # Hard exit: skips interpreter cleanup and kills the whole process.
        os._exit(1)  # pylint: disable=protected-access
      # Interruptible replacement for time.sleep(): returns early as soon as
      # `stop()` sets the event.
      self._stop_requested.wait(self._interval)