xref: /aosp_15_r20/external/pytorch/torch/distributed/elastic/multiprocessing/tail_log.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2# mypy: allow-untyped-defs
3
4# Copyright (c) Facebook, Inc. and its affiliates.
5# All rights reserved.
6#
7# This source code is licensed under the BSD-style license found in the
8# LICENSE file in the root directory of this source tree.
9
10import logging
11import os
12import time
13from concurrent.futures.thread import ThreadPoolExecutor
14from threading import Event
15from typing import Dict, List, Optional, TextIO, TYPE_CHECKING
16
17
18if TYPE_CHECKING:
19    from concurrent.futures._base import Future
20
21__all__ = ["tail_logfile", "TailLog"]
22
23logger = logging.getLogger(__name__)
24
25
def tail_logfile(
    header: str, file: str, dst: TextIO, finished: Event, interval_sec: float
):
    """
    Tail ``file`` into ``dst``, prefixing every line with ``header``.

    Blocks until ``file`` exists (polling every ``interval_sec`` seconds),
    then copies lines as they are appended. Returns once ``finished`` is set
    and EOF has been reached; returns immediately if ``finished`` is set
    before the file is ever created.
    """
    # Wait for the producer to create the log file, unless we are told to stop.
    while not os.path.exists(file):
        if finished.is_set():
            return
        time.sleep(interval_sec)

    # errors="replace" keeps the tailer alive on undecodable bytes.
    with open(file, errors="replace") as fp:
        eof_and_done = False
        while not eof_and_done:
            next_line = fp.readline()
            if next_line:
                dst.write(f"{header}{next_line}")
            elif finished.is_set():
                # EOF and the producer has finished -- nothing more will come.
                eof_and_done = True
            else:
                # EOF but the producer is still running; poll again shortly.
                time.sleep(interval_sec)
48
49
class TailLog:
    """
    Tail the given log files.

    The log files do not have to exist when the ``start()`` method is called. The tail-er will gracefully wait until
    the log files are created by the producer and will tail the contents of the
    log files until the ``stop()`` method is called.

    .. warning:: ``TailLog`` will wait indefinitely for the log file to be created!

    Each log file's line will be suffixed with a header of the form: ``[{name}{idx}]:``,
    where the ``name`` is user-provided and ``idx`` is the index of the log file
    in the ``log_files`` mapping. ``log_line_prefixes`` can be used to override the
    header for each log file.

    Usage:

    ::

     log_files = {0: "/tmp/0_stdout.log", 1: "/tmp/1_stdout.log"}
     tailer = TailLog("trainer", log_files, sys.stdout).start()
     # actually run the trainers to produce 0_stdout.log and 1_stdout.log
     run_trainers()
     tailer.stop()

     # once run_trainers() start writing the ##_stdout.log files
     # the tailer will print to sys.stdout:
     # >>> [trainer0]:log_line1
     # >>> [trainer1]:log_line1
     # >>> [trainer0]:log_line2
     # >>> [trainer0]:log_line3
     # >>> [trainer1]:log_line2

    .. note:: Due to buffering log lines between files may not necessarily
              be printed out in order. You should configure your application's
              logger to suffix each log line with a proper timestamp.

    """

    def __init__(
        self,
        name: str,
        log_files: Dict[int, str],
        dst: TextIO,
        log_line_prefixes: Optional[Dict[int, str]] = None,
        interval_sec: float = 0.1,
    ):
        """
        Args:
            name: logical name used to build the default ``[{name}{rank}]:`` header.
            log_files: map of local rank -> log file path to tail.
            dst: stream that all tailed lines are written to.
            log_line_prefixes: optional per-rank override for the line header.
            interval_sec: polling interval used while waiting for file
                creation and at EOF.
        """
        n = len(log_files)
        # No threadpool is created for an empty mapping; start()/stop()
        # degrade to no-ops in that case.
        self._threadpool = None
        if n > 0:
            self._threadpool = ThreadPoolExecutor(
                max_workers=n,
                thread_name_prefix=f"{self.__class__.__qualname__}_{name}",
            )

        self._name = name
        self._dst = dst
        self._log_files = log_files
        self._log_line_prefixes = log_line_prefixes
        # One Event per local rank; stop() sets them to tell each tailer
        # thread that its producer is done.
        self._finished_events: Dict[int, Event] = {
            local_rank: Event() for local_rank in log_files.keys()
        }
        self._futs: List[Future] = []
        self._interval_sec = interval_sec
        self._stopped = False

    def start(self) -> "TailLog":
        """Submit one tailer task per log file. Returns ``self`` for chaining."""
        if not self._threadpool:
            return self

        for local_rank, file in self._log_files.items():
            header = f"[{self._name}{local_rank}]:"
            if self._log_line_prefixes and local_rank in self._log_line_prefixes:
                header = self._log_line_prefixes[local_rank]
            self._futs.append(
                self._threadpool.submit(
                    tail_logfile,
                    header=header,
                    file=file,
                    dst=self._dst,
                    finished=self._finished_events[local_rank],
                    interval_sec=self._interval_sec,
                )
            )
        return self

    def stop(self) -> None:
        """Signal all tailers to finish, wait for them, and shut the pool down."""
        for finished in self._finished_events.values():
            finished.set()

        # Futures were appended in iteration order of self._log_files, so zip
        # pairs each future with its actual local-rank key. (Using
        # enumerate() here would report list positions, which are wrong
        # whenever the log_files keys are not contiguous 0..n-1.)
        for local_rank, f in zip(self._log_files, self._futs):
            try:
                f.result()
            except Exception as e:
                logger.error(
                    "error in log tailor for %s%s. %s: %s",
                    self._name,
                    local_rank,
                    e.__class__.__qualname__,
                    e,
                )

        if self._threadpool:
            self._threadpool.shutdown(wait=True)

        self._stopped = True

    def stopped(self) -> bool:
        """Return ``True`` once ``stop()`` has completed."""
        return self._stopped
159