xref: /aosp_15_r20/external/tensorflow/tensorflow/python/distribute/coordinator/watchdog.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Watchdog that monitors activity of ClusterCoordinator."""
16
17import faulthandler
18import os
19import sys
20import threading
21import time
22from absl import logging
23
24
class WatchDog(object):
  """A class to dump stack traces if no activity happens in ClusterCoordinator.

  When the effective timeout is positive, a daemon thread polls about ten
  times per timeout period. If no activity has been recorded through
  `report_closure_done()` for at least `timeout` seconds, it calls
  `on_triggered` (if given), dumps the stack traces of all threads to
  `traceback_file` using `faulthandler`, writes an end marker, and restarts
  the inactivity window.

  The environment variable `TF_CLUSTER_COORDINATOR_WATCH_DOG_TIMEOUT`, when set
  to a non-negative decimal integer, overrides the `timeout` argument.
  """

  def __init__(self, timeout=-1, traceback_file=sys.stdout, on_triggered=None):
    """Creates the watchdog and starts its thread when the timeout is positive.

    Args:
      timeout: Inactivity period, in seconds, after which stack traces are
        dumped. Values <= 0 disable the watchdog (no thread is started).
        Overridden by `TF_CLUSTER_COORDINATOR_WATCH_DOG_TIMEOUT` if that
        environment variable holds a decimal integer.
      traceback_file: File object the traces are written to. It is passed to
        `faulthandler.dump_traceback`, so it needs a usable `fileno()`. The
        default is the `sys.stdout` object bound when this module is imported.
      on_triggered: Optional zero-argument callable invoked each time the
        watchdog fires, before the traces are dumped.
    """
    env_timeout = os.environ.get("TF_CLUSTER_COORDINATOR_WATCH_DOG_TIMEOUT", "")
    # `isdecimal()` (unlike `isnumeric()`) accepts only characters `int()` can
    # parse; e.g. "²" is "numeric" but would make `int()` raise ValueError.
    if env_timeout.isdecimal():
      timeout = int(env_timeout)
    self._timeout = timeout
    # Monotonic clock: wall-clock adjustments (NTP, manual changes) must not
    # spuriously fire the watchdog or postpone it.
    self._last_activity_time = time.monotonic()
    self._traceback_file = traceback_file
    self._on_triggered = on_triggered
    self._stopped = False
    # Set by `stop()`; wakes the polling thread immediately instead of letting
    # it finish its current sleep (and possibly fire once more).
    self._stop_event = threading.Event()
    self._watchdog_thread = None
    if timeout > 0:
      self._watchdog_thread = threading.Thread(
          target=self._watchdog_function, name="WatchDog", daemon=True)
      self._watchdog_thread.start()

  def stop(self):
    """Signals the watchdog thread to exit.

    Returns without joining the thread. A dump already in progress is allowed
    to finish, but no new check is started afterwards.
    """
    self._stopped = True
    self._stop_event.set()

  def _watchdog_function(self):
    """The watchdog thread: periodically checks for inactivity."""
    logging.info("Starting watchdog thread with timeout %r", self._timeout)
    poll_interval = self._timeout / 10.0
    # `Event.wait` returns True as soon as `stop()` is called, so the loop ends
    # promptly and never fires after being stopped mid-wait.
    while not (self._stop_event.wait(poll_interval) or self._stopped):
      current_time = time.monotonic()
      if current_time - self._last_activity_time < self._timeout:
        continue
      logging.warning(
          "No activity for ClusterCoordinator for %r seconds. "
          "Dumping stack traces.", self._timeout)
      if self._on_triggered:
        try:
          self._on_triggered()
        except Exception:  # pylint: disable=broad-except
          # A failing callback must neither kill this thread nor suppress the
          # stack dump, which is the watchdog's primary diagnostic output.
          logging.exception("WatchDog on_triggered callback raised.")
      faulthandler.dump_traceback(file=self._traceback_file)
      self._traceback_file.write("==== End of stack traces ====\n")
      # faulthandler writes directly to the file descriptor; flush the buffered
      # marker too so it is visible even if the process is killed soon after.
      self._traceback_file.flush()
      self._last_activity_time = current_time

  def report_closure_done(self):
    """Records activity, restarting the inactivity window (if enabled)."""
    if self._timeout > 0:
      self._last_activity_time = time.monotonic()
64