# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Watchdog that monitors activity of ClusterCoordinator."""

import faulthandler
import os
import sys
import threading
import time
from absl import logging


class WatchDog(object):
  """A class to dump stack traces if no activity happens in ClusterCoordinator.

  When enabled (timeout > 0), a daemon thread polls every `timeout / 10`
  seconds. If `report_closure_done()` has not been called for at least
  `timeout` seconds, the thread logs a warning, invokes the optional
  `on_triggered` callback, dumps the stack traces of all threads to
  `traceback_file`, and restarts the inactivity window.
  """

  def __init__(self, timeout=-1, traceback_file=sys.stdout, on_triggered=None):
    """Creates the watchdog and, if enabled, starts its monitoring thread.

    Args:
      timeout: Inactivity threshold in seconds. Values <= 0 disable the
        watchdog. Overridden by the environment variable
        `TF_CLUSTER_COORDINATOR_WATCH_DOG_TIMEOUT` when that variable holds a
        non-negative decimal integer.
      traceback_file: File object the stack traces are written to. It is
        passed to `faulthandler.dump_traceback`, so it needs a real file
        descriptor (`fileno()`).
      on_triggered: Optional zero-argument callable invoked each time the
        watchdog fires, before the stack traces are dumped.
    """
    env_timeout = os.environ.get("TF_CLUSTER_COORDINATOR_WATCH_DOG_TIMEOUT", "")
    # isdecimal() rather than isnumeric(): the latter also accepts characters
    # such as fractions or superscripts that int() cannot parse, which would
    # raise ValueError here for a malformed environment value.
    if env_timeout.isdecimal():
      timeout = int(env_timeout)
    self._timeout = timeout
    # Monotonic clock: elapsed-time checks must not be affected by wall-clock
    # adjustments.
    self._last_activity_time = time.monotonic()
    self._traceback_file = traceback_file
    self._on_triggered = on_triggered
    self._stop_event = threading.Event()
    # Stays None when the watchdog is disabled so the attribute always exists.
    self._watchdog_thread = None
    if timeout > 0:
      self._watchdog_thread = threading.Thread(
          target=self._watchdog_function, name="WatchDog", daemon=True)
      self._watchdog_thread.start()

  def stop(self):
    """Asks the watchdog thread to exit; no further dumps are started."""
    self._stop_event.set()

  def _watchdog_function(self):
    """The watchdog thread."""
    logging.info("Starting watchdog thread with timeout %r", self._timeout)
    poll_interval = self._timeout / 10.0
    # Event.wait() doubles as the sleep between polls and returns True as soon
    # as stop() is called, so a stopped watchdog neither lingers for the rest
    # of the interval nor fires one last time after being stopped.
    while not self._stop_event.wait(poll_interval):
      current_time = time.monotonic()
      if current_time - self._last_activity_time >= self._timeout:
        logging.warning(
            "No activity for ClusterCoordinator for %r seconds. "
            "Dumping stack traces.", self._timeout)
        if self._on_triggered:
          self._on_triggered()
        faulthandler.dump_traceback(file=self._traceback_file)
        self._traceback_file.write("==== End of stack traces ====\n")
        # The marker goes through the file object's buffer; flush so it shows
        # up right after the dump instead of at some later point.
        self._traceback_file.flush()
        # Restart the inactivity window so the next dump happens only after
        # another full timeout without activity.
        self._last_activity_time = current_time

  def report_closure_done(self):
    """Records activity, restarting the inactivity window."""
    if self._timeout > 0:
      self._last_activity_time = time.monotonic()