# mypy: allow-untyped-defs
import functools
import hashlib
import itertools
import json
import logging
import os
import os.path
import pathlib
import re
import sys
import tempfile
from dataclasses import dataclass, field
from importlib import __import__
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from weakref import WeakSet

import torch._logging.structured
from torch._utils_internal import log_trace_structured_event
from torch.utils._traceback import CapturedTraceback


log = logging.getLogger(__name__)

# This is a synthetic logger which doesn't correspond to an actual logger,
# but handles all of our "tracing" logging, which is structured and doesn't go
# to stderr but always goes to a dedicated log file.  We don't put these
# loggers in the classic module hierarchy, because we don't want a suppression
# of logs to also cause a trace to get suppressed (traces typically are not
# collected, unless we are in prod, in which case they always are collected.)
#
# TODO: Maybe we should allow for some sub-hierarchy so you can control which
# traces you want to collect, for performance reasons.
#
# See https://docs.google.com/document/d/1CX_hJ0PNy9f3R1y8TJrfkSeLkvGjjjLU84BSXgS2AZ8/edit
trace_log = logging.getLogger("torch.__trace")
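
# Note: structured trace records are produced by trace_structured() below and
# written by LazyTraceHandler to the directory named by TORCH_TRACE (or an
# internal default location), rather than to stderr like the regular logs.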

DEFAULT_LOG_LEVEL = logging.WARNING
LOG_ENV_VAR = "TORCH_LOGS"
LOG_OUT_ENV_VAR = "TORCH_LOGS_OUT"
LOG_FORMAT_ENV_VAR = "TORCH_LOGS_FORMAT"
TRACE_ENV_VAR = "TORCH_TRACE"


@dataclass
class LogRegistry:
    # shorthand name to log qualified name
    # Note: this only contains loggers registered
    # from register_log
    # e.g. "dynamo" -> "torch._dynamo"
    log_alias_to_log_qnames: Dict[str, List[str]] = field(default_factory=dict)

    # artifact logger qualified names;
    # populated lazily, as getArtifactLogger is called.
    # Currently formatted as <module>.__<artifact_name>,
    # e.g. "torch._dynamo.convert_frame.__guards"
    artifact_log_qnames: Set[str] = field(default_factory=set)

    # child logs of registered logs if specified via open
    # registration by the user (i.e. placing "torch._dynamo.output_graph" in the env var)
    # these need to be tracked so their levels can be reset properly
    # e.g. "torch._dynamo.output_graph"
    child_log_qnames: Set[str] = field(default_factory=set)

    # artifact names, populated by register_artifact
    # e.g. "guards"
    artifact_names: Set[str] = field(default_factory=set)

    # Artifacts that should be visible by default in the error message
    visible_artifacts: Set[str] = field(default_factory=set)

    # A short description of each artifact
    artifact_descriptions: Dict[str, str] = field(default_factory=dict)

    # artifacts which are not displayed unless explicitly named in the
    # settings. Ex. output_code is NOT displayed even if the inductor
    # log level is set to DEBUG. It must be explicitly named in the settings
    off_by_default_artifact_names: Set[str] = field(default_factory=set)

    # logging format string for artifacts
    artifact_log_formatters: Dict[str, logging.Formatter] = field(default_factory=dict)

    def is_artifact(self, name):
        return name in self.artifact_names

    def is_log(self, alias):
        return alias in self.log_alias_to_log_qnames

    # register a log with an alias
    def register_log(self, alias, log_qnames: Union[str, List[str]]):
        if isinstance(log_qnames, str):
            log_qnames = [log_qnames]
        self.log_alias_to_log_qnames[alias] = log_qnames

    # register an artifact name
    def register_artifact_name(
        self, name, description, visible, off_by_default, log_format
    ):
        self.artifact_names.add(name)
        if visible:
            self.visible_artifacts.add(name)
        self.artifact_descriptions[name] = description

        # if off by default, don't enable it
        # when log_name's log_level is set to DEBUG
        if off_by_default:
            self.off_by_default_artifact_names.add(name)

        if log_format is not None:
            self.artifact_log_formatters[name] = logging.Formatter(log_format)

    # register the qualified name of an artifact log
    # this is needed to know which logs need to be reset
    # whenever the log_state is changed
    def register_artifact_log(self, artifact_log_qname):
        self.artifact_log_qnames.add(artifact_log_qname)

    def register_child_log(self, log_qname):
        self.child_log_qnames.add(log_qname)

    # flattens all the qnames together (TODO: consider memoizing?)
    def get_log_qnames(self) -> Set[str]:
        return {
            qname
            for qnames in self.log_alias_to_log_qnames.values()
            for qname in qnames
        }

    def get_artifact_log_qnames(self):
        return set(self.artifact_log_qnames)

    def get_child_log_qnames(self):
        return set(self.child_log_qnames)

    def is_off_by_default(self, artifact_qname):
        return artifact_qname in self.off_by_default_artifact_names


@dataclass
class LogState:
    # qualified log names -> currently set log level
    log_qname_to_level: Dict[str, str] = field(default_factory=dict)

    # the set of currently enabled artifacts
    artifact_names: Set[str] = field(default_factory=set)

    def enable_artifact(self, artifact_name):
        self.artifact_names.add(artifact_name)

    def is_artifact_enabled(self, name):
        return name in self.artifact_names

    def enable_log(self, log_qnames, log_level):
        if isinstance(log_qnames, str):
            log_qnames = [log_qnames]
        for log_qname in log_qnames:
            self.log_qname_to_level[log_qname] = log_level

    def get_log_level_pairs(self):
        """Returns all qualified module names for which the user requested
        explicit logging settings.

        .. warning::

            This function used to return all loggers, regardless of whether
            the user specified them; it now only returns logs which were
            explicitly mentioned by the user (and torch, which is always
            implicitly requested when we initialize our logging subsystem).
        """
        return self.log_qname_to_level.items()

    def clear(self):
        self.log_qname_to_level.clear()
        self.artifact_names.clear()


log_registry = LogRegistry()
log_state = LogState()

# sample usage: torch._logging.set_logs(**torch._logging.DEFAULT_LOGGING)
DEFAULT_LOGGING = {
    "dynamo": logging.INFO,
    "aot": logging.INFO,
    "inductor": logging.INFO,
    "fsdp": logging.INFO,
    "ddp_graphs": True,
    "graph_breaks": True,
    "guards": True,
    "recompiles": True,
    "dynamic": logging.INFO,
}
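
# A roughly equivalent TORCH_LOGS setting for the defaults above (illustrative
# only; per the parsing rules below, unprefixed component names map to INFO and
# bare artifact names simply enable that artifact):
#   TORCH_LOGS="dynamo,aot,inductor,fsdp,dynamic,ddp_graphs,graph_breaks,guards,recompiles"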


def set_logs(
    *,
    all: Optional[int] = None,
    dynamo: Optional[int] = None,
    aot: Optional[int] = None,
    autograd: Optional[int] = None,
    dynamic: Optional[int] = None,
    inductor: Optional[int] = None,
    distributed: Optional[int] = None,
    c10d: Optional[int] = None,
    ddp: Optional[int] = None,
    fsdp: Optional[int] = None,
    dtensor: Optional[int] = None,
    onnx: Optional[int] = None,
    bytecode: bool = False,
    aot_graphs: bool = False,
    aot_joint_graph: bool = False,
    ddp_graphs: bool = False,
    graph: bool = False,
    graph_code: bool = False,
    graph_breaks: bool = False,
    graph_sizes: bool = False,
    guards: bool = False,
    recompiles: bool = False,
    recompiles_verbose: bool = False,
    trace_source: bool = False,
    trace_call: bool = False,
    trace_bytecode: bool = False,
    output_code: bool = False,
    kernel_code: bool = False,
    schedule: bool = False,
    perf_hints: bool = False,
    post_grad_graphs: bool = False,
    onnx_diagnostics: bool = False,
    fusion: bool = False,
    overlap: bool = False,
    export: Optional[int] = None,
    modules: Optional[Dict[str, Union[int, bool]]] = None,
    cudagraphs: bool = False,
    sym_node: bool = False,
    compiled_autograd: bool = False,
    compiled_autograd_verbose: bool = False,
    cudagraph_static_inputs: bool = False,
    benchmarking: bool = False,
):
    """
    Sets the log level for individual components and toggles individual log
    artifact types.

    .. warning:: This feature is a prototype and may have compatibility
        breaking changes in the future.

    .. note:: The ``TORCH_LOGS`` environment variable has complete precedence
        over this function, so if it was set, this function does nothing.

    A component is a set of related features in PyTorch. All of the log
    messages emitted from a given component have their own log levels. If the
    log level of a particular message has priority greater than or equal to its
    component's log level setting, it is emitted. Otherwise, it is suppressed.
    This allows you to, for instance, silence large groups of log messages that
    are not relevant to you and increase verbosity of logs for components that
    are relevant. The expected log level values, ordered from highest to lowest
    priority, are:

        * ``logging.CRITICAL``
        * ``logging.ERROR``
        * ``logging.WARNING``
        * ``logging.INFO``
        * ``logging.DEBUG``
        * ``logging.NOTSET``

    See documentation for the Python ``logging`` module for more information on
    log levels: `<https://docs.python.org/3/library/logging.html#logging-levels>`_

    An artifact is a particular type of log message. Each artifact is assigned
    to a parent component. A component can emit many different kinds of
    artifacts. In general, an artifact is emitted if either its corresponding
    setting in the argument list below is turned on or if its parent component
    is set to a log level less than or equal to the log level of the artifact.

    Keyword args:
        all (:class:`Optional[int]`):
            The default log level for all components. Default: ``logging.WARN``

        dynamo (:class:`Optional[int]`):
            The log level for the TorchDynamo component. Default: ``logging.WARN``

        aot (:class:`Optional[int]`):
            The log level for the AOTAutograd component. Default: ``logging.WARN``

        autograd (:class:`Optional[int]`):
            The log level for autograd. Default: ``logging.WARN``

        inductor (:class:`Optional[int]`):
            The log level for the TorchInductor component. Default: ``logging.WARN``

        dynamic (:class:`Optional[int]`):
            The log level for dynamic shapes. Default: ``logging.WARN``

        distributed (:class:`Optional[int]`):
            The log level for c10d communication operations and other debug info
            from PyTorch Distributed components. Default: ``logging.WARN``

        c10d (:class:`Optional[int]`):
            The log level for debug info related to c10d communication operations
            in PyTorch Distributed components. Default: ``logging.WARN``

        ddp (:class:`Optional[int]`):
            The log level for debug info related to ``DistributedDataParallel`` (DDP)
            in PyTorch Distributed components. Default: ``logging.WARN``

        fsdp (:class:`Optional[int]`):
            The log level for debug info related to ``FullyShardedDataParallel`` (FSDP)
            in PyTorch Distributed components. Default: ``logging.WARN``

        dtensor (:class:`Optional[int]`):
            The log level for debug info related to ``DTensor`` in PyTorch Distributed
            components. Default: ``logging.WARN``

        onnx (:class:`Optional[int]`):
            The log level for the ONNX exporter component. Default: ``logging.WARN``

        bytecode (:class:`bool`):
            Whether to emit the original and generated bytecode from TorchDynamo.
            Default: ``False``

        aot_graphs (:class:`bool`):
            Whether to emit the graphs generated by AOTAutograd. Default: ``False``

        aot_joint_graph (:class:`bool`):
            Whether to emit the joint forward-backward graph generated by AOTAutograd. Default: ``False``

        ddp_graphs (:class:`bool`):
            Whether to emit graphs generated by DDPOptimizer. Default: ``False``

        graph (:class:`bool`):
            Whether to emit the graph captured by TorchDynamo in tabular format.
            Default: ``False``

        graph_code (:class:`bool`):
            Whether to emit the python source of the graph captured by TorchDynamo.
            Default: ``False``

        graph_breaks (:class:`bool`):
            Whether to emit the graph breaks encountered by TorchDynamo.
            Default: ``False``

        graph_sizes (:class:`bool`):
            Whether to emit tensor sizes of the graph captured by TorchDynamo.
            Default: ``False``

        guards (:class:`bool`):
            Whether to emit the guards generated by TorchDynamo for each compiled
            function. Default: ``False``

        recompiles (:class:`bool`):
            Whether to emit a guard failure reason and message every time
            TorchDynamo recompiles a function. Default: ``False``

        recompiles_verbose (:class:`bool`):
            Whether to emit all guard failure reasons when TorchDynamo recompiles
            a function, even those that are not actually run. Default: ``False``

        trace_source (:class:`bool`):
            Whether to emit when TorchDynamo begins tracing a new line. Default: ``False``

        trace_call (:class:`bool`):
            Whether to emit detailed line location when TorchDynamo creates an FX node
            corresponding to a function call. Python 3.11+ only. Default: ``False``

        trace_bytecode (:class:`bool`):
            Whether to emit bytecode instructions and traced stack state as TorchDynamo
            traces bytecode. Default: ``False``

        output_code (:class:`bool`):
            Whether to emit the TorchInductor output code on a per-graph basis. Default: ``False``

        kernel_code (:class:`bool`):
            Whether to emit the TorchInductor output code on a per-kernel basis. Default: ``False``

        schedule (:class:`bool`):
            Whether to emit the TorchInductor schedule. Default: ``False``

        perf_hints (:class:`bool`):
            Whether to emit the TorchInductor perf hints. Default: ``False``

        post_grad_graphs (:class:`bool`):
            Whether to emit the graphs generated after the post-grad passes. Default: ``False``

        onnx_diagnostics (:class:`bool`):
            Whether to emit the ONNX exporter diagnostics in logging. Default: ``False``

        fusion (:class:`bool`):
            Whether to emit detailed Inductor fusion decisions. Default: ``False``

        overlap (:class:`bool`):
            Whether to emit detailed Inductor compute/comm overlap decisions. Default: ``False``

        sym_node (:class:`bool`):
            Whether to emit debug info for various SymNode operations. Default: ``False``

        export (:class:`Optional[int]`):
            The log level for export. Default: ``logging.WARN``

        benchmarking (:class:`bool`):
            Whether to emit detailed Inductor benchmarking information. Default: ``False``

        modules (dict):
            This argument provides an alternate way to specify the above log
            component and artifact settings, in the format of a keyword args
            dictionary given as a single argument. There are two cases
            where this is useful: (1) if a new log component or artifact has
            been registered but a keyword argument for it has not been added
            to this function, and (2) if the log level for an unregistered module
            needs to be set. This can be done by providing the fully-qualified module
            name as the key, with the log level as the value. Default: ``None``

        cudagraph_static_inputs (:class:`bool`):
            Whether to emit debug info for cudagraph static input detection. Default: ``False``


    Example::

        >>> # xdoctest: +SKIP
        >>> import logging

        # The following changes the "dynamo" component to emit DEBUG-level
        # logs, and to emit "graph_code" artifacts.

        >>> torch._logging.set_logs(dynamo=logging.DEBUG, graph_code=True)

        # The following enables the logs for a different module

        >>> torch._logging.set_logs(modules={"unregistered.module.name": logging.DEBUG})
    """
    # ignore if env var is set
    if LOG_ENV_VAR in os.environ:
        log.warning(
            "Using TORCH_LOGS environment variable for log settings, ignoring call to set_logs"
        )
        return

    log_state.clear()

    modules = modules or {}

    def _set_logs(**kwargs):
        for alias, val in itertools.chain(kwargs.items(), modules.items()):  # type: ignore[union-attr]
            if val is None:
                continue

            if log_registry.is_artifact(alias):
                if not isinstance(val, bool):
                    raise ValueError(
                        f"Expected bool to enable artifact {alias}, received {val}"
                    )

                if val:
                    log_state.enable_artifact(alias)
            elif log_registry.is_log(alias) or alias in log_registry.child_log_qnames:
                if val not in logging._levelToName:
                    raise ValueError(
                        f"Unrecognized log level for log {alias}: {val}, valid level values "
                        f"are: {','.join([str(k) for k in logging._levelToName.keys()])}"
                    )

                log_state.enable_log(
                    log_registry.log_alias_to_log_qnames.get(alias, alias), val
                )
            else:
                raise ValueError(
                    f"Unrecognized log or artifact name passed to set_logs: {alias}"
                )

        _init_logs()

    _set_logs(
        torch=all,
        dynamo=dynamo,
        aot=aot,
        autograd=autograd,
        inductor=inductor,
        dynamic=dynamic,
        bytecode=bytecode,
        aot_graphs=aot_graphs,
        aot_joint_graph=aot_joint_graph,
        ddp_graphs=ddp_graphs,
        distributed=distributed,
        c10d=c10d,
        ddp=ddp,
        fsdp=fsdp,
        dtensor=dtensor,
        graph=graph,
        graph_code=graph_code,
        graph_breaks=graph_breaks,
        graph_sizes=graph_sizes,
        guards=guards,
        recompiles=recompiles,
        recompiles_verbose=recompiles_verbose,
        trace_source=trace_source,
        trace_call=trace_call,
        trace_bytecode=trace_bytecode,
        output_code=output_code,
        kernel_code=kernel_code,
        schedule=schedule,
        perf_hints=perf_hints,
        post_grad_graphs=post_grad_graphs,
        onnx=onnx,
        onnx_diagnostics=onnx_diagnostics,
        fusion=fusion,
        overlap=overlap,
        sym_node=sym_node,
        export=export,
        cudagraphs=cudagraphs,
        compiled_autograd=compiled_autograd,
        compiled_autograd_verbose=compiled_autograd_verbose,
        cudagraph_static_inputs=cudagraph_static_inputs,
        benchmarking=benchmarking,
    )


def get_loggers():
    """
    Returns: a list of all registered loggers
    """
    return [logging.getLogger(qname) for qname in log_registry.get_log_qnames()]


def register_log(setting_name, log_name):
    """
    Enables a log to be controlled by the env var and user API with the given setting_name

    Args:
        setting_name: the shorthand name used in the env var and user API
        log_name: the log name that the setting_name is associated with
    """
    log_registry.register_log(setting_name, log_name)


def register_artifact(
    setting_name, description, visible=False, off_by_default=False, log_format=None
):
    """
    Enables an artifact to be controlled by the env var and user API with the given setting_name

    Args:
        setting_name: the shorthand name used in the env var and user API
        description: A description of what this outputs
        visible: Whether it gets suggested to users by default
        off_by_default: whether this artifact should be suppressed (rather than logged)
            when its ancestor loggers are enabled at level DEBUG; if True, it is only
            emitted when explicitly named in the settings
        log_format: an optional format string applied to this artifact's log records
    """
    log_registry.register_artifact_name(
        setting_name, description, visible, off_by_default, log_format
    )
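
# Illustrative registration calls (the argument values here are hypothetical;
# the real registrations live in torch._logging's registration module):
#
#   register_log("dynamo", "torch._dynamo")
#   register_artifact(
#       "guards",
#       "Print the guards generated for each compiled function",
#       visible=True,
#   )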


def getArtifactLogger(module_qname, artifact_name):
    if artifact_name not in log_registry.artifact_names:
        raise ValueError(
            f"Artifact name: {repr(artifact_name)} not registered, "
            f"please call register_artifact({repr(artifact_name)}) in torch._logging.registrations."
        )
    qname = module_qname + f".__{artifact_name}"
    log = logging.getLogger(qname)
    log.artifact_name = artifact_name  # type: ignore[attr-defined]
    log_registry.register_artifact_log(qname)
    configure_artifact_log(log)
    return log
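
# Illustrative call-site usage (the artifact and variable names here are just
# examples): a module typically creates its artifact logger once at import time
# and guards expensive message construction with isEnabledFor:
#
#   graph_breaks_log = getArtifactLogger(__name__, "graph_breaks")
#   ...
#   if graph_breaks_log.isEnabledFor(logging.DEBUG):
#       graph_breaks_log.debug("graph break: %s", reason)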


INCR_VERBOSITY_CHAR = "+"
DECR_VERBOSITY_CHAR = "-"
VERBOSITY_REGEX = (
    "("
    + "|".join([re.escape(INCR_VERBOSITY_CHAR), re.escape(DECR_VERBOSITY_CHAR)])
    + "?)"
)


def configure_artifact_log(log):
    # If the artifact is off by default, then it should only be logged when explicitly
    # enabled; set propagate to False so that this artifact is not propagated
    # to its ancestor logger
    if log_registry.is_off_by_default(log.artifact_name):
        log.propagate = False

    # enable artifact logging when explicitly enabled
    if log_state.is_artifact_enabled(log.artifact_name):
        log.setLevel(logging.DEBUG)
        log.propagate = True


# match a comma separated list of loggable names (whitespace allowed after commas)
def _gen_settings_regex():
    return re.compile(r"((\+|-)?[\w\.]+,\s*)*(\+|-)?[\w\.]+?")


def _validate_settings(settings):
    return re.fullmatch(_gen_settings_regex(), settings) is not None


def help_message(verbose=False):
    def pad_to(s, length=30):
        assert len(s) <= length
        return s + " " * (length - len(s))

    if verbose:
        printed_artifacts = log_registry.artifact_names
    else:
        printed_artifacts = log_registry.visible_artifacts

    if verbose:
        heading = "All registered names"
    else:
        heading = "Visible registered names (use TORCH_LOGS='+help' for full list)"
    lines = (
        ["all"]
        + sorted(log_registry.log_alias_to_log_qnames.keys())
        + sorted(
            [
                f"{pad_to(name)}\t{log_registry.artifact_descriptions[name]}"
                for name in printed_artifacts
            ]
        )
    )
    setting_info = "  " + "\n  ".join(lines)
    examples = """
Examples:
  TORCH_LOGS="+dynamo,aot" will set the log level of TorchDynamo to
  logging.DEBUG and AOT to logging.INFO

  TORCH_LOGS="-dynamo,+inductor" will set the log level of TorchDynamo to
  logging.ERROR and TorchInductor to logging.DEBUG

  TORCH_LOGS="aot_graphs" will enable the aot_graphs artifact

  TORCH_LOGS="+dynamo,schedule" will set the log level of TorchDynamo
  to logging.DEBUG and enable the schedule artifact

  TORCH_LOGS="+some.random.module,schedule" will set the log level of
  some.random.module to logging.DEBUG and enable the schedule artifact

  TORCH_LOGS_FORMAT="%(levelname)s: %(message)s" or any provided format
  string will set the output format
  Valid keys are "levelname", "message", "pathname", "levelno", "lineno",
  "filename" and "name".

  TORCH_LOGS_OUT=/tmp/output.txt will output the logs to /tmp/output.txt as
  well. This is useful when the output is long.
"""  # flake8: noqa: B950
    msg = f"""
TORCH_LOGS Info
{examples}

{heading}
{setting_info}
"""
    return msg


def _invalid_settings_err_msg(settings, verbose=False):
    valid_settings = ", ".join(
        ["all"]
        + list(log_registry.log_alias_to_log_qnames.keys())
        + list(log_registry.artifact_names)
    )
    msg = f"""
Invalid log settings: {settings}, must be a comma separated list of fully
qualified module names, registered log names or registered artifact names.
For more info on various settings, try TORCH_LOGS="help"
Valid settings:
{valid_settings}
"""
    return msg


@functools.lru_cache
def _parse_log_settings(settings):
    if settings == "":
        return {}

    if settings == "help":
        raise ValueError(help_message(verbose=False))
    elif settings == "+help":
        raise ValueError(help_message(verbose=True))
    if not _validate_settings(settings):
        raise ValueError(_invalid_settings_err_msg(settings))

    settings = re.sub(r"\s+", "", settings)
    log_names = settings.split(",")

    def get_name_level_pair(name):
        clean_name = name.replace(INCR_VERBOSITY_CHAR, "")
        clean_name = clean_name.replace(DECR_VERBOSITY_CHAR, "")

        if name[0] == INCR_VERBOSITY_CHAR:
            level = logging.DEBUG
        elif name[0] == DECR_VERBOSITY_CHAR:
            level = logging.ERROR
        else:
            level = logging.INFO

        return clean_name, level

    log_state = LogState()

    for name in log_names:
        name, level = get_name_level_pair(name)

        if name == "all":
            name = "torch"

        if log_registry.is_log(name):
            assert level is not None
            log_qnames = log_registry.log_alias_to_log_qnames[name]
            log_state.enable_log(log_qnames, level)
        elif log_registry.is_artifact(name):
            log_state.enable_artifact(name)
        elif _is_valid_module(name):
            if not _has_registered_parent(name):
                log_registry.register_log(name, name)
            else:
                log_registry.register_child_log(name)
            log_state.enable_log(name, level)
        else:
            raise ValueError(_invalid_settings_err_msg(settings))

    return log_state
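
# A quick illustration of the parsing rules above (the setting value is an
# example only): TORCH_LOGS="+dynamo,guards,torch.fx" would
#   * set every qualified name registered under the "dynamo" alias to DEBUG
#     (the "+" prefix),
#   * enable the "guards" artifact (assuming it is registered), and
#   * register the importable module "torch.fx" (as a child log if it already
#     has a registered parent) and set it to INFO (no prefix defaults to INFO).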


def _is_valid_module(qname):
    try:
        __import__(qname)
        return True
    except ImportError:
        return False


def _update_log_state_from_env():
    global log_state
    log_setting = os.environ.get(LOG_ENV_VAR, None)
    if log_setting is not None:
        log_state = _parse_log_settings(log_setting)


def _has_registered_parent(log_qname):
    cur_log = logging.getLogger(log_qname)

    registered_log_qnames = log_registry.get_log_qnames()

    while cur_log.parent:
        if cur_log.name in registered_log_qnames:
            return True
        cur_log = cur_log.parent

    return False


def make_module_path_relative(abs_path):
    """
    Given an absolute filepath corresponding to a Python module which was
    loaded via normal import mechanisms using sys.path, convert it into
    a path relative to one of the Python search paths.
    """

    abs_path = pathlib.Path(abs_path).resolve()

    for path in sys.path:
        try:
            rel_path = abs_path.relative_to(path)
        except ValueError:
            continue
        else:
            return str(rel_path)

    return str(abs_path)
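
# For example (paths here are hypothetical), if "/usr/lib/python3.11/site-packages"
# is on sys.path, then
#   make_module_path_relative("/usr/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py")
# returns "torch/_dynamo/convert_frame.py"; paths not under any sys.path entry
# are returned as resolved absolute paths.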


# apply custom formats to artifacts when necessary
class TorchLogsFormatter(logging.Formatter):
    def __init__(self, *, trace: bool = False):
        super().__init__()
        self._is_trace = trace

    def format(self, record):
        artifact_name = getattr(logging.getLogger(record.name), "artifact_name", None)
        if artifact_name is not None:
            artifact_formatter = log_registry.artifact_log_formatters.get(
                artifact_name, None
            )
            if artifact_formatter is not None:
                return artifact_formatter.format(record)

        record.message = record.getMessage()
        record.asctime = self.formatTime(record, "%m%d %H:%M:%S")

        # exception handling - copied from logging.Formatter.format
        s = record.message
        if record.exc_info:
            # Cache the traceback text to avoid converting it multiple times
            # (it's constant anyway)
            if not record.exc_text:
                record.exc_text = self.formatException(record.exc_info)
        if record.exc_text:
            if s[-1:] != "\n":
                s = s + "\n"
            s = s + record.exc_text
        if record.stack_info:
            if s[-1:] != "\n":
                s = s + "\n"
            s = s + self.formatStack(record.stack_info)

        record.rankprefix = ""
        if not self._is_trace and dist.is_available() and dist.is_initialized():
            record.rankprefix = f"[rank{dist.get_rank()}]:"

        record.traceid = ""
        if (
            not self._is_trace
            and (trace_id := torch._guards.CompileContext.current_trace_id())
            is not None
        ):
            record.traceid = f" [{trace_id}]"

        glog_level_to_abbr = {
            "DEBUG": "V",  # V is for VERBOSE in glog
            "INFO": "I",
            "WARNING": "W",
            "ERROR": "E",
            "CRITICAL": "C",
        }

        shortlevel = glog_level_to_abbr.get(record.levelname, record.levelname)

        record.artifactprefix = ""
        if artifact_name is not None:
            record.artifactprefix = f" [__{artifact_name}]"

        filepath = make_module_path_relative(record.pathname)

        prefix = (
            f"{record.rankprefix}{shortlevel}{record.asctime}.{int(record.msecs*1000):06d} {record.process} "
            f"{filepath}:"
            f"{record.lineno}]{record.traceid}{record.artifactprefix}"
        )
        if self._is_trace:
            assert s == ""
            try:
                r = f"{prefix} {json.dumps(record.metadata)}"
            except TypeError:
                log.warning("failing metadata: %r", record.metadata)
                raise
            if record.payload is not None:
                r += "".join(f"\n\t{l}" for l in record.payload.split("\n"))
            return r
        else:
            lines = s.split("\n")
            return "\n".join(f"{prefix} {l}" for l in lines)
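
# An illustrative (non-trace) line produced by this formatter, assuming rank 0,
# pid 12345, and an active compile with trace id [0/0]:
#
#   [rank0]:I0527 10:15:42.000123 12345 torch/_dynamo/convert_frame.py:123] [0/0] hello from dynamo
#
# i.e. <rank prefix><glog-style level/date/time.microseconds> <pid> <relative path>:<lineno>] <trace id> <message>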


def _default_formatter():
    fmt = os.environ.get(LOG_FORMAT_ENV_VAR, None)
    if fmt is None:
        return TorchLogsFormatter()
    else:
        if fmt in ("short", "basic"):
            fmt = logging.BASIC_FORMAT
        return logging.Formatter(fmt)


DEFAULT_FORMATTER = _default_formatter()


def _setup_handlers(create_handler_fn, log):
    debug_handler = _track_handler(create_handler_fn())
    debug_handler.setFormatter(DEFAULT_FORMATTER)
    debug_handler.setLevel(logging.DEBUG)
    log.addHandler(debug_handler)


handlers = WeakSet()  # type: ignore[var-annotated]


# mark handlers that we've created
# so we don't modify user handlers
def _track_handler(handler):
    handlers.add(handler)
    return handler


def _is_torch_handler(handler):
    return handler in handlers


# clears all torch handlers on specified loggers
def _clear_handlers(log):
    to_remove = [handler for handler in log.handlers if _is_torch_handler(handler)]
    for handler in to_remove:
        log.removeHandler(handler)


def _reset_logs():
    # reset all registered logs
    for log_qname in log_registry.get_log_qnames():
        log = logging.getLogger(log_qname)
        log.setLevel(logging.WARNING)
        log.propagate = False
        _clear_handlers(log)

    # reset all artifact and child logs
    for artifact_log_qname in itertools.chain(
        log_registry.get_artifact_log_qnames(), log_registry.get_child_log_qnames()
    ):
        log = logging.getLogger(artifact_log_qname)
        log.setLevel(logging.NOTSET)
        log.propagate = True

    trace_log.propagate = False
    _clear_handlers(trace_log)


def _get_log_state():
    return log_state


def _set_log_state(state):
    global log_state
    log_state = state


def _init_logs(log_file_name=None):
    _reset_logs()
    _update_log_state_from_env()

    out = os.environ.get(LOG_OUT_ENV_VAR, None)
    if out is not None:
        log_file_name = out

    # First, reset all known (registered) loggers to NOTSET, so that they
    # respect their parent log level
    for log_qname in log_registry.get_log_qnames():
        # But not the top level torch level: this defaults to WARNING so
        # that our log messages don't leak to the lower levels
        if log_qname == "torch":
            continue
        log = logging.getLogger(log_qname)
        log.setLevel(logging.NOTSET)

    # Now, for all loggers which the user requested to have non-standard
    # logging behavior, modify their log levels
    for log_qname, level in log_state.get_log_level_pairs():
        log = logging.getLogger(log_qname)
        log.setLevel(level)

    # Finally, setup handlers for all registered loggers
    for log_qname in log_registry.get_log_qnames():
        log = logging.getLogger(log_qname)
        _setup_handlers(
            logging.StreamHandler,
            log,
        )

        if log_file_name is not None:
            _setup_handlers(
                lambda: logging.FileHandler(log_file_name),
                log,
            )

    # configure artifact loggers, note: this must happen last
    # since the levels of ancestor loggers are taken into account
    for artifact_log_qname in log_registry.get_artifact_log_qnames():
        log = logging.getLogger(artifact_log_qname)
        configure_artifact_log(log)

    # Setup handler for the special trace_log, with different default
    # configuration
    trace_dir_name = os.environ.get(TRACE_ENV_VAR, None)
    # This handler may remove itself if trace_dir_name is None and we are not
    # actually in an FB environment.  This allows us to defer actually
    # initializing it until we actually need to log anything.  This is
    # important because JK initializes a C++ singleton, which will pork our
    # process if we subsequently fork.
    handler = LazyTraceHandler(trace_dir_name)
    # This log is ALWAYS at debug level.  We will additionally test if there
    # are any handlers before deciding to actually call logging on this.  Do
    # not manually call
    trace_log.setLevel(logging.DEBUG)
    trace_log_handler = _track_handler(handler)
    trace_log_handler.setFormatter(TorchLogsFormatter(trace=True))
    trace_log.addHandler(trace_log_handler)


class LazyTraceHandler(logging.StreamHandler):
    """Like FileHandler, but the file is allocated lazily only upon the first log message"""

    def __init__(self, root_dir: Optional[str]):
        # This is implemented in the same way that delay is implemented on
        # FileHandler
        self.root_dir = root_dir
        logging.Handler.__init__(self)
        self.stream = None
        self._builtin_open = open

    # cloned from FileHandler in cpython
    def close(self):
        self.acquire()
        try:
            try:
                if self.stream:
                    try:
                        self.flush()
                    finally:
                        stream = self.stream
                        self.stream = None
                        if hasattr(stream, "close"):
                            stream.close()
            finally:
                # Issue #19523: call unconditionally to
                # prevent a handler leak when delay is set
                # Also see Issue #42378: we also rely on
                # self._closed being set to True there
                logging.StreamHandler.close(self)
        finally:
            self.release()

    def emit(self, record):
        if self.stream is None:
            ok = False
            if self.root_dir is None:
                TRACE_LOG_DIR = "/logs"
                open_func = self._builtin_open

                import torch.version as torch_version

                if (
                    hasattr(torch_version, "git_version")
                    and os.getenv("MAST_HPC_JOB_NAME") is None
                ):
                    log.info(
                        "LazyTraceHandler: disabled because not fbcode or conda on mast"
                    )
                elif not torch._utils_internal.justknobs_check("pytorch/trace:enable"):
                    log.info(
                        "LazyTraceHandler: disabled because justknobs_check('pytorch/trace:enable') returned False"
                    )
                elif not os.path.exists(TRACE_LOG_DIR):
                    log.info(
                        "LazyTraceHandler: disabled because %s does not exist",
                        TRACE_LOG_DIR,
                    )
                elif not os.access(TRACE_LOG_DIR, os.W_OK):
                    log.info(
                        "LazyTraceHandler: disabled because %s is not writeable",
                        TRACE_LOG_DIR,
                    )
                else:
                    self.root_dir = TRACE_LOG_DIR

            if self.root_dir is not None:
                os.makedirs(self.root_dir, exist_ok=True)
                ranksuffix = ""
                if dist.is_available() and dist.is_initialized():
                    ranksuffix = f"rank_{dist.get_rank()}_"
                self.stream = tempfile.NamedTemporaryFile(
                    mode="w+",
                    suffix=".log",
                    prefix=f"dedicated_log_torch_trace_{ranksuffix}",
                    dir=self.root_dir,
                    delete=False,
                )
                log.info("LazyTraceHandler: logging to %s", self.stream.name)
            else:
                # We go poof, remove and no-op
                trace_log.removeHandler(self)
                return
        if self.stream:
            super().emit(record)


@functools.lru_cache(None)
def warning_once(logger_obj, *args, **kwargs):
    """
    This function is similar to `logger.warning()`, but will emit the warning with the same message only once.
    Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the cache.
    The assumption here is that all warning messages are unique across the code. If they aren't, then we need to switch
    to another type of cache that includes the caller frame information in the hashing function.
    """
    logger_obj.warning(*args, **kwargs)
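
# Illustrative usage (the message and arguments are examples only); note that
# all arguments must be hashable, since together with logger_obj they form the
# lru_cache key:
#
#   warning_once(log, "config %s is deprecated, use %s instead", "old", "new")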


class LazyString:
    def __init__(self, func, *args, **kwargs):
        self.func = func
        self.args = args
        self.kwargs = kwargs

    def __str__(self):
        return self.func(*self.args, **self.kwargs)
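
# Illustrative usage (the logger and callable are examples only): pass a
# LazyString as a %-style argument so that the wrapped function only runs if
# the record is actually formatted by a handler:
#
#   log.debug("%s", LazyString(lambda: some_expensive_repr(obj)))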


def trace_structured(
    name: str,
    # NB: metadata expected to be dict so adding more info is forward compatible
    # Tuple[str, int] is a special case for string interning
    metadata_fn: Callable[[], Union[Dict[str, Any], Tuple[str, int]]] = dict,
    *,
    payload_fn: Callable[[], Optional[Union[str, object]]] = lambda: None,
    suppress_context: bool = False,
    expect_trace_id: bool = True,  # Whether or not we expect to have a current trace id
):
    """
    metadata is an arbitrary JSON-compatible struct, but it's expected to not be
    too long (e.g., less than 1MB)

    payload is an arbitrary string, which can be arbitrarily long (but expected to
    contain newlines so that no single line is too long)
    """
    # these keys are reserved for the record fields populated below
    assert name not in ["rank", "frame_id", "frame_compile_id", "attempt"]
    assert callable(
        metadata_fn
    ), f"metadata_fn should be callable, but got {type(metadata_fn)}"
    assert callable(
        payload_fn
    ), f"payload_fn should be callable, but got {type(payload_fn)}"
    # trace_log never propagates and is ALWAYS DEBUG, so also check that there
    # are handlers instead of checking the log level
    if trace_log.handlers:
        record: Dict[str, object] = {}
        record[name] = metadata_fn()
        if not suppress_context:
            # TODO: Actually, the rank probably should just be emitted once at
            # the top, and not repeatedly spammed in all the logs, since it
            # never changes and we assume no interleaving
            if dist.is_available() and dist.is_initialized():
                record["rank"] = dist.get_rank()
            if (
                trace_id := torch._guards.CompileContext.current_trace_id()
            ) is not None:
                record["frame_id"] = trace_id.compile_id.frame_id
                record["frame_compile_id"] = trace_id.compile_id.frame_compile_id
                record["attempt"] = trace_id.attempt
            else:
                if expect_trace_id:
                    # Record the stack of the log call to better diagnose why we
                    # don't have a frame id for it
                    record["stack"] = torch._logging.structured.from_traceback(
                        CapturedTraceback.extract(skip=1).summary()
                    )
        payload = payload_fn()
        if payload is not None:
            if not isinstance(payload, str):
                if isinstance(payload, list):
                    # special case to look better
                    payload = "[\n" + ",\n".join(json.dumps(i) for i in payload) + "\n]"
                else:
                    # force newlines so we are unlikely to overflow line limit
                    payload = json.dumps(payload, indent=0)
            h = hashlib.md5()
            h.update(payload.encode("utf-8"))
            record["has_payload"] = h.hexdigest()
        trace_log.debug(
            "", extra={"metadata": record, "payload": payload}, stacklevel=2
        )
        log_trace_structured_event(name, record)
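
# Illustrative call (the event name and metadata fields are examples only):
#
#   trace_structured(
#       "my_debug_event",
#       metadata_fn=lambda: {"shape": [2, 3], "reason": "example"},
#       payload_fn=lambda: some_long_string,
#   )
#
# metadata_fn and payload_fn are callables so that the (possibly expensive)
# record construction is skipped entirely when no trace handler is attached.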


import torch._guards
import torch._utils_internal
import torch.distributed as dist