xref: /aosp_15_r20/external/tensorflow/tensorflow/python/tpu/preempted_hook.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Implementation of the SessionRunHook for preemptible Cloud TPUs."""
16
17import logging as _logging
18import os
19import threading
20import time
21
22from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
23from tensorflow.python.platform import tf_logging as logging
24from tensorflow.python.training import session_run_hook
25
26
class CloudTPUPreemptedHook(session_run_hook.SessionRunHook):
  """The SessionRunHook for preemptible Cloud TPUs.

  When the process is running in GCE, this hook starts a background
  `_TPUPollingThread` once the session has been created and stops it when the
  session ends. The polling thread watches the TPU and terminates the process
  if the TPU is reported as no longer recoverable.
  """

  def __init__(self, cluster):
    """Creates the hook.

    Args:
      cluster: The cluster object handed to the polling thread, which uses it
        to query the TPU state.
    """
    self._cluster = cluster
    # Stays None until after_create_session() actually starts a poller, so
    # end() can tell whether there is anything to stop.
    self._tpu_poller = None

  def after_create_session(self, session, coord):
    """Starts polling the TPU state, but only when running in GCE."""
    if tpu_cluster_resolver.is_running_in_gce():
      self._tpu_poller = _TPUPollingThread(self._cluster, session)
      self._tpu_poller.start()

  def end(self, session):
    """Stops the polling thread if one was started."""
    poller = self._tpu_poller
    if poller is not None:
      # Outside GCE no poller is ever created; calling stop() unconditionally
      # would raise AttributeError here.
      self._tpu_poller = None
      poller.stop()
45
46
class _TPUPollingThread(threading.Thread):
  """A daemon thread that polls the state of a TPU node.

  Every `_interval` seconds the thread asks the cluster's Cloud TPU client
  whether the TPU is still recoverable. When it is not, the thread logs the
  TPU name and state and terminates the whole process with `os._exit(1)`.
  """

  def __init__(self, cluster, session):
    """Creates the polling thread (it is not started here).

    Args:
      cluster: Object exposing `_cloud_tpu_client` and `_tpu`; used to query
        and report the TPU state.
      session: The session associated with this poller. It is only stored.
    """
    super(_TPUPollingThread, self).__init__()

    self.daemon = True
    self._running = True
    self._session_closed = False
    self._cluster = cluster
    self._session = session
    # Seconds to wait between two consecutive polls.
    self._interval = 30
    # Set by stop() so that the wait between polls can be interrupted instead
    # of blocking the caller for up to a full interval.
    self._stop_event = threading.Event()

    # Some of the Google API libraries are quite chatty, so raise their log
    # threshold to WARNING.
    for name in ['googleapiclient.discovery', 'oauth2client.client']:
      _logging.getLogger(name).setLevel(_logging.WARNING)

  def stop(self):
    """Asks the polling loop to exit and waits for the thread to finish.

    Safe to call on a thread that was never started or has already finished.
    """
    self._running = False
    self._session_closed = True
    self._stop_event.set()
    # join() raises RuntimeError on a thread that was never started and on the
    # current thread, so only join a live thread other than ourselves.
    if self.is_alive() and threading.current_thread() is not self:
      self.join()

  def run(self):
    """Polls the TPU until stopped; exits the process if it is unrecoverable."""
    if not tpu_cluster_resolver.is_running_in_gce():
      logging.warning(
          'TPUPollingThread is running in a non-GCE environment, exiting...')
      self._running = False
      return

    while self._running:
      recoverable = self._cluster._cloud_tpu_client.recoverable()  # pylint: disable=protected-access
      if not recoverable:
        logging.warning(
            'TPUPollingThread found TPU %s in state %s',
            self._cluster._tpu, self._cluster._cloud_tpu_client.state())  # pylint: disable=protected-access
        # Hard exit: skips normal interpreter shutdown on purpose.
        os._exit(1)  # pylint: disable=protected-access
      # Interruptible replacement for time.sleep(): returns True as soon as
      # stop() sets the event, False when the interval elapses.
      if self._stop_event.wait(self._interval):
        break
90