1# mypy: allow-untyped-defs 2import functools 3import hashlib 4import itertools 5import json 6import logging 7import os 8import os.path 9import pathlib 10import re 11import sys 12import tempfile 13from dataclasses import dataclass, field 14from importlib import __import__ 15from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union 16from weakref import WeakSet 17 18import torch._logging.structured 19from torch._utils_internal import log_trace_structured_event 20from torch.utils._traceback import CapturedTraceback 21 22 23log = logging.getLogger(__name__) 24 25# This is a synthetic logger which doesn't correspond to an actual logger, 26# but handles all of our "tracing" logging, which is structured and doesn't go 27# to stderr but always goes to a dedicated log file. We don't put these 28# loggers in the classic module hierarchy, because we don't want a suppression 29# of logs to also cause a trace to get suppressed (traces typically are not 30# collected, unless we are in prod, in which case they always are collected.) 31# 32# TODO: Maybe we should allow for some sub-hierarchy so you can control which 33# traces you want to collect, for performance reasons. 34# 35# See https://docs.google.com/document/d/1CX_hJ0PNy9f3R1y8TJrfkSeLkvGjjjLU84BSXgS2AZ8/edit 36trace_log = logging.getLogger("torch.__trace") 37 38DEFAULT_LOG_LEVEL = logging.WARNING 39LOG_ENV_VAR = "TORCH_LOGS" 40LOG_OUT_ENV_VAR = "TORCH_LOGS_OUT" 41LOG_FORMAT_ENV_VAR = "TORCH_LOGS_FORMAT" 42TRACE_ENV_VAR = "TORCH_TRACE" 43 44 45@dataclass 46class LogRegistry: 47 # shorthand name to log qualified name 48 # Note: this only contains loggers registered 49 # from register_log 50 # e.g. "dynamo" -> "torch._dynamo" 51 log_alias_to_log_qnames: Dict[str, List[str]] = field(default_factory=dict) 52 53 # artifact logger qualified names, 54 # this is populated lazily, as calls to getArtifactLogger 55 # currently formatted as <module>.__<artifact_name> 56 # e.g. "torch._dynamo.convert_frame.__guards" 57 artifact_log_qnames: Set[str] = field(default_factory=set) 58 59 # child logs of registered logs if specified via open 60 # registration by the user (ie placing "torch._dynamo.output_graph" in the env var) 61 # these need to be tracked so their levels can be reset properly 62 # e.g. "torch._dynamo.output_graph" 63 child_log_qnames: Set[str] = field(default_factory=set) 64 65 # artifact names, populated by register_artifact 66 # e.g. "guards" 67 artifact_names: Set[str] = field(default_factory=set) 68 69 # Artifacts that should be visible by default in the error message 70 visible_artifacts: Set[str] = field(default_factory=set) 71 72 # A short description of each artifact 73 artifact_descriptions: Dict[str, str] = field(default_factory=dict) 74 75 # artifacts which are not displayed unless explicitly named in the 76 # settings. Ex. output_code is NOT displayed even if the inductor 77 # log level is set to DEBUG. It must be explicitly named in the settings 78 off_by_default_artifact_names: Set[str] = field(default_factory=set) 79 80 # logging format string for artifacts 81 artifact_log_formatters: Dict[str, logging.Formatter] = field(default_factory=dict) 82 83 def is_artifact(self, name): 84 return name in self.artifact_names 85 86 def is_log(self, alias): 87 return alias in self.log_alias_to_log_qnames 88 89 # register a log with an alias 90 def register_log(self, alias, log_qnames: Union[str, List[str]]): 91 if isinstance(log_qnames, str): 92 log_qnames = [log_qnames] 93 self.log_alias_to_log_qnames[alias] = log_qnames 94 95 # register an artifact name 96 def register_artifact_name( 97 self, name, description, visible, off_by_default, log_format 98 ): 99 self.artifact_names.add(name) 100 if visible: 101 self.visible_artifacts.add(name) 102 self.artifact_descriptions[name] = description 103 104 # if off by default, don't enable it 105 # when log_name's log_level is set to DEBUG 106 if off_by_default: 107 self.off_by_default_artifact_names.add(name) 108 109 if log_format is not None: 110 self.artifact_log_formatters[name] = logging.Formatter(log_format) 111 112 # register the qualified name of an artifact log 113 # this is needed to know which logs need to be reset 114 # whenever the log_state is changed 115 def register_artifact_log(self, artifact_log_qname): 116 self.artifact_log_qnames.add(artifact_log_qname) 117 118 def register_child_log(self, log_qname): 119 self.child_log_qnames.add(log_qname) 120 121 # flattens all the qnames together (TODO: consider memoizing?) 122 def get_log_qnames(self) -> Set[str]: 123 return { 124 qname 125 for qnames in self.log_alias_to_log_qnames.values() 126 for qname in qnames 127 } 128 129 def get_artifact_log_qnames(self): 130 return set(self.artifact_log_qnames) 131 132 def get_child_log_qnames(self): 133 return set(self.child_log_qnames) 134 135 def is_off_by_default(self, artifact_qname): 136 return artifact_qname in self.off_by_default_artifact_names 137 138 139@dataclass 140class LogState: 141 # qualified log names -> currently set log level 142 log_qname_to_level: Dict[str, str] = field(default_factory=dict) 143 144 # the set of currently enabled artifacts 145 artifact_names: Set[str] = field(default_factory=set) 146 147 def enable_artifact(self, artifact_name): 148 self.artifact_names.add(artifact_name) 149 150 def is_artifact_enabled(self, name): 151 return name in self.artifact_names 152 153 def enable_log(self, log_qnames, log_level): 154 if isinstance(log_qnames, str): 155 log_qnames = [log_qnames] 156 for log_qname in log_qnames: 157 self.log_qname_to_level[log_qname] = log_level 158 159 def get_log_level_pairs(self): 160 """Returns all qualified module names for which the user requested 161 explicit logging settings. 162 163 .. warning: 164 165 This function used to return all loggers, regardless of whether 166 or not the user specified them or not; it now only returns logs 167 which were explicitly mentioned by the user (and torch, which 168 always is implicitly requested when we initialize our logging 169 subsystem.) 170 """ 171 return self.log_qname_to_level.items() 172 173 def clear(self): 174 self.log_qname_to_level.clear() 175 self.artifact_names.clear() 176 177 178log_registry = LogRegistry() 179log_state = LogState() 180 181# sample usage: torch._logging.set_logs(**torch._logging.DEFAULT_LOGGING) 182DEFAULT_LOGGING = { 183 "dynamo": logging.INFO, 184 "aot": logging.INFO, 185 "inductor": logging.INFO, 186 "fsdp": logging.INFO, 187 "ddp_graphs": True, 188 "graph_breaks": True, 189 "guards": True, 190 "recompiles": True, 191 "dynamic": logging.INFO, 192} 193 194 195def set_logs( 196 *, 197 all: Optional[int] = None, 198 dynamo: Optional[int] = None, 199 aot: Optional[int] = None, 200 autograd: Optional[int] = None, 201 dynamic: Optional[int] = None, 202 inductor: Optional[int] = None, 203 distributed: Optional[int] = None, 204 c10d: Optional[int] = None, 205 ddp: Optional[int] = None, 206 fsdp: Optional[int] = None, 207 dtensor: Optional[int] = None, 208 onnx: Optional[int] = None, 209 bytecode: bool = False, 210 aot_graphs: bool = False, 211 aot_joint_graph: bool = False, 212 ddp_graphs: bool = False, 213 graph: bool = False, 214 graph_code: bool = False, 215 graph_breaks: bool = False, 216 graph_sizes: bool = False, 217 guards: bool = False, 218 recompiles: bool = False, 219 recompiles_verbose: bool = False, 220 trace_source: bool = False, 221 trace_call: bool = False, 222 trace_bytecode: bool = False, 223 output_code: bool = False, 224 kernel_code: bool = False, 225 schedule: bool = False, 226 perf_hints: bool = False, 227 post_grad_graphs: bool = False, 228 onnx_diagnostics: bool = False, 229 fusion: bool = False, 230 overlap: bool = False, 231 export: Optional[int] = None, 232 modules: Optional[Dict[str, Union[int, bool]]] = None, 233 cudagraphs: bool = False, 234 sym_node: bool = False, 235 compiled_autograd: bool = False, 236 compiled_autograd_verbose: bool = False, 237 cudagraph_static_inputs: bool = False, 238 benchmarking: bool = False, 239): 240 """ 241 Sets the log level for individual components and toggles individual log 242 artifact types. 243 244 .. warning:: This feature is a prototype and may have compatibility 245 breaking changes in the future. 246 247 .. note:: The ``TORCH_LOGS`` environment variable has complete precedence 248 over this function, so if it was set, this function does nothing. 249 250 A component is a set of related features in PyTorch. All of the log 251 messages emitted from a given component have their own log levels. If the 252 log level of a particular message has priority greater than or equal to its 253 component's log level setting, it is emitted. Otherwise, it is suppressed. 254 This allows you to, for instance, silence large groups of log messages that 255 are not relevant to you and increase verbosity of logs for components that 256 are relevant. The expected log level values, ordered from highest to lowest 257 priority, are: 258 259 * ``logging.CRITICAL`` 260 * ``logging.ERROR`` 261 * ``logging.WARNING`` 262 * ``logging.INFO`` 263 * ``logging.DEBUG`` 264 * ``logging.NOTSET`` 265 266 See documentation for the Python ``logging`` module for more information on 267 log levels: `<https://docs.python.org/3/library/logging.html#logging-levels>`_ 268 269 An artifact is a particular type of log message. Each artifact is assigned 270 to a parent component. A component can emit many different kinds of 271 artifacts. In general, an artifact is emitted if either its corresponding 272 setting in the argument list below is turned on or if its parent component 273 is set to a log level less than or equal to the log level of the artifact. 274 275 Keyword args: 276 all (:class:`Optional[int]`): 277 The default log level for all components. Default: ``logging.WARN`` 278 279 dynamo (:class:`Optional[int]`): 280 The log level for the TorchDynamo component. Default: ``logging.WARN`` 281 282 aot (:class:`Optional[int]`): 283 The log level for the AOTAutograd component. Default: ``logging.WARN`` 284 285 autograd (:class:`Optional[int]`): 286 The log level for autograd. Default: ``logging.WARN`` 287 288 inductor (:class:`Optional[int]`): 289 The log level for the TorchInductor component. Default: ``logging.WARN`` 290 291 dynamic (:class:`Optional[int]`): 292 The log level for dynamic shapes. Default: ``logging.WARN`` 293 294 distributed (:class:`Optional[int]`): 295 Whether to log c10d communication operations and other debug info from PyTorch Distributed components. 296 Default: ``logging.WARN`` 297 298 c10d (:class:`Optional[int]`): 299 Whether to log c10d communication operations related debug info in PyTorch Distributed components. 300 Default: ``logging.WARN`` 301 302 ddp (:class:`Optional[int]`): 303 Whether to log debug info related to ``DistributedDataParallel``(DDP) from PyTorch Distributed components. 304 Default: ``logging.WARN`` 305 306 fsdp (:class:`Optional[int]`): 307 Whether to log debug info related to ``FullyShardedDataParallel``(FSDP) in PyTorch Distributed components. 308 Default: ``logging.WARN`` 309 310 dtensor (:class:`Optional[int]`): 311 Whether to log debug info related to ``DTensor``(DTensor) in PyTorch Distributed components. 312 Default: ``logging.WARN`` 313 314 onnx (:class:`Optional[int]`): 315 The log level for the ONNX exporter component. Default: ``logging.WARN`` 316 317 bytecode (:class:`bool`): 318 Whether to emit the original and generated bytecode from TorchDynamo. 319 Default: ``False`` 320 321 aot_graphs (:class:`bool`): 322 Whether to emit the graphs generated by AOTAutograd. Default: ``False`` 323 324 aot_joint_graph (:class:`bool`): 325 Whether to emit the joint forward-backward graph generated by AOTAutograd. Default: ``False`` 326 327 ddp_graphs (:class:`bool`): 328 Whether to emit graphs generated by DDPOptimizer. Default: ``False`` 329 330 graph (:class:`bool`): 331 Whether to emit the graph captured by TorchDynamo in tabular format. 332 Default: ``False`` 333 334 graph_code (:class:`bool`): 335 Whether to emit the python source of the graph captured by TorchDynamo. 336 Default: ``False`` 337 338 graph_breaks (:class:`bool`): 339 Whether to emit the graph breaks encountered by TorchDynamo. 340 Default: ``False`` 341 342 graph_sizes (:class:`bool`): 343 Whether to emit tensor sizes of the graph captured by TorchDynamo. 344 Default: ``False`` 345 346 guards (:class:`bool`): 347 Whether to emit the guards generated by TorchDynamo for each compiled 348 function. Default: ``False`` 349 350 recompiles (:class:`bool`): 351 Whether to emit a guard failure reason and message every time 352 TorchDynamo recompiles a function. Default: ``False`` 353 354 recompiles_verbose (:class:`bool`): 355 Whether to emit all guard failure reasons when TorchDynamo recompiles 356 a function, even those that are not actually run. Default: ``False`` 357 358 trace_source (:class:`bool`): 359 Whether to emit when TorchDynamo begins tracing a new line. Default: ``False`` 360 361 trace_call (:class:`bool`): 362 Whether to emit detailed line location when TorchDynamo creates an FX node 363 corresponding to function call. Python 3.11+ only. Default: ``False`` 364 365 trace_bytecode (:class:`bool`): 366 Whether to emit bytecode instructions and traced stack state as TorchDynamo 367 traces bytecode. Default: ``False`` 368 369 output_code (:class:`bool`): 370 Whether to emit the TorchInductor output code on a per-graph basis. Default: ``False`` 371 372 kernel_code (:class:`bool`): 373 Whether to emit the TorchInductor output code on a per-kernel bases. Default: ``False`` 374 375 schedule (:class:`bool`): 376 Whether to emit the TorchInductor schedule. Default: ``False`` 377 378 perf_hints (:class:`bool`): 379 Whether to emit the TorchInductor perf hints. Default: ``False`` 380 381 post_grad_graphs (:class:`bool`): 382 Whether to emit the graphs generated by after post grad passes. Default: ``False`` 383 384 onnx_diagnostics (:class:`bool`): 385 Whether to emit the ONNX exporter diagnostics in logging. Default: ``False`` 386 387 fusion (:class:`bool`): 388 Whether to emit detailed Inductor fusion decisions. Default: ``False`` 389 390 overlap (:class:`bool`): 391 Whether to emit detailed Inductor compute/comm overlap decisions. Default: ``False`` 392 393 sym_node (:class:`bool`): 394 Whether to emit debug info for various SymNode opterations. Default: ``False`` 395 396 export (:class:`Optional[int]`): 397 The log level for export. Default: ``logging.WARN`` 398 399 benchmarking (:class:`bool`): 400 Whether to emit detailed Inductor benchmarking information. Default: ``False`` 401 402 modules (dict): 403 This argument provides an alternate way to specify the above log 404 component and artifact settings, in the format of a keyword args 405 dictionary given as a single argument. There are two cases 406 where this is useful (1) if a new log component or artifact has 407 been registered but a keyword argument for it has not been added 408 to this function and (2) if the log level for an unregistered module 409 needs to be set. This can be done by providing the fully-qualified module 410 name as the key, with the log level as the value. Default: ``None`` 411 412 cudagraph_static_inputs (:class:`bool`): 413 Whether to emit debug info for cudagraph static input detection. Default: ``False`` 414 415 416 Example:: 417 418 >>> # xdoctest: +SKIP 419 >>> import logging 420 421 # The following changes the "dynamo" component to emit DEBUG-level 422 # logs, and to emit "graph_code" artifacts. 423 424 >>> torch._logging.set_logs(dynamo=logging.DEBUG, graph_code=True) 425 426 # The following enables the logs for a different module 427 428 >>> torch._logging.set_logs(modules={"unregistered.module.name": logging.DEBUG}) 429 """ 430 # ignore if env var is set 431 if LOG_ENV_VAR in os.environ: 432 log.warning( 433 "Using TORCH_LOGS environment variable for log settings, ignoring call to set_logs" 434 ) 435 return 436 437 log_state.clear() 438 439 modules = modules or {} 440 441 def _set_logs(**kwargs): 442 for alias, val in itertools.chain(kwargs.items(), modules.items()): # type: ignore[union-attr] 443 if val is None: 444 continue 445 446 if log_registry.is_artifact(alias): 447 if not isinstance(val, bool): 448 raise ValueError( 449 f"Expected bool to enable artifact {alias}, received {val}" 450 ) 451 452 if val: 453 log_state.enable_artifact(alias) 454 elif log_registry.is_log(alias) or alias in log_registry.child_log_qnames: 455 if val not in logging._levelToName: 456 raise ValueError( 457 f"Unrecognized log level for log {alias}: {val}, valid level values " 458 f"are: {','.join([str(k) for k in logging._levelToName.keys()])}" 459 ) 460 461 log_state.enable_log( 462 log_registry.log_alias_to_log_qnames.get(alias, alias), val 463 ) 464 else: 465 raise ValueError( 466 f"Unrecognized log or artifact name passed to set_logs: {alias}" 467 ) 468 469 _init_logs() 470 471 _set_logs( 472 torch=all, 473 dynamo=dynamo, 474 aot=aot, 475 autograd=autograd, 476 inductor=inductor, 477 dynamic=dynamic, 478 bytecode=bytecode, 479 aot_graphs=aot_graphs, 480 aot_joint_graph=aot_joint_graph, 481 ddp_graphs=ddp_graphs, 482 distributed=distributed, 483 c10d=c10d, 484 ddp=ddp, 485 fsdp=fsdp, 486 dtensor=dtensor, 487 graph=graph, 488 graph_code=graph_code, 489 graph_breaks=graph_breaks, 490 graph_sizes=graph_sizes, 491 guards=guards, 492 recompiles=recompiles, 493 recompiles_verbose=recompiles_verbose, 494 trace_source=trace_source, 495 trace_call=trace_call, 496 trace_bytecode=trace_bytecode, 497 output_code=output_code, 498 kernel_code=kernel_code, 499 schedule=schedule, 500 perf_hints=perf_hints, 501 post_grad_graphs=post_grad_graphs, 502 onnx=onnx, 503 onnx_diagnostics=onnx_diagnostics, 504 fusion=fusion, 505 overlap=overlap, 506 sym_node=sym_node, 507 export=export, 508 cudagraphs=cudagraphs, 509 compiled_autograd=compiled_autograd, 510 compiled_autograd_verbose=compiled_autograd_verbose, 511 cudagraph_static_inputs=cudagraph_static_inputs, 512 benchmarking=benchmarking, 513 ) 514 515 516def get_loggers(): 517 """ 518 Returns: a list of all registered loggers 519 """ 520 return [logging.getLogger(qname) for qname in log_registry.get_log_qnames()] 521 522 523def register_log(setting_name, log_name): 524 """ 525 Enables a log to be controlled by the env var and user API with the setting_name 526 Args: 527 setting_name: the shorthand name used in the env var and user API 528 log_name: the log name that the setting_name is associated with 529 """ 530 log_registry.register_log(setting_name, log_name) 531 532 533def register_artifact( 534 setting_name, description, visible=False, off_by_default=False, log_format=None 535): 536 """ 537 Enables an artifact to be controlled by the env var and user API with name 538 Args: 539 setting_name: the shorthand name used in the env var and user API 540 description: A description of what this outputs 541 visible: Whether it gets suggested to users by default 542 off_by_default: whether this artifact should be logged when the ancestor loggers 543 are enabled at level DEBUG 544 """ 545 log_registry.register_artifact_name( 546 setting_name, description, visible, off_by_default, log_format 547 ) 548 549 550def getArtifactLogger(module_qname, artifact_name): 551 if artifact_name not in log_registry.artifact_names: 552 raise ValueError( 553 f"Artifact name: {repr(artifact_name)} not registered," 554 f"please call register_artifact({repr(artifact_name)}) in torch._logging.registrations." 555 ) 556 qname = module_qname + f".__{artifact_name}" 557 log = logging.getLogger(qname) 558 log.artifact_name = artifact_name # type: ignore[attr-defined] 559 log_registry.register_artifact_log(qname) 560 configure_artifact_log(log) 561 return log 562 563 564INCR_VERBOSITY_CHAR = "+" 565DECR_VERBOSITY_CHAR = "-" 566VERBOSITY_REGEX = ( 567 "(" 568 + "|".join([re.escape(INCR_VERBOSITY_CHAR), re.escape(DECR_VERBOSITY_CHAR)]) 569 + "?)" 570) 571 572 573def configure_artifact_log(log): 574 # If the artifact is off by default, then it should only be logged when explicitly 575 # enabled; set propagate to False so that this artifact is not propagated 576 # to its ancestor logger 577 if log_registry.is_off_by_default(log.artifact_name): 578 log.propagate = False 579 580 # enable artifact logging when explicitly enabled 581 if log_state.is_artifact_enabled(log.artifact_name): 582 log.setLevel(logging.DEBUG) 583 log.propagate = True 584 585 586# match a comma separated list of loggable names (whitespace allowed after commas) 587def _gen_settings_regex(): 588 return re.compile(r"((\+|-)?[\w\.]+,\s*)*(\+|-)?[\w\.]+?") 589 590 591def _validate_settings(settings): 592 return re.fullmatch(_gen_settings_regex(), settings) is not None 593 594 595def help_message(verbose=False): 596 def pad_to(s, length=30): 597 assert len(s) <= length 598 return s + " " * (length - len(s)) 599 600 if verbose: 601 printed_artifacts = log_registry.artifact_names 602 else: 603 printed_artifacts = log_registry.visible_artifacts 604 605 if verbose: 606 heading = "All registered names" 607 else: 608 heading = "Visible registered names (use TORCH_LOGS='+help' for full list)" 609 lines = ( 610 ["all"] 611 + sorted(log_registry.log_alias_to_log_qnames.keys()) 612 + sorted( 613 [ 614 f"{pad_to(name)}\t{log_registry.artifact_descriptions[name]}" 615 for name in printed_artifacts 616 ] 617 ) 618 ) 619 setting_info = " " + "\n ".join(lines) 620 examples = """ 621Examples: 622 TORCH_LOGS="+dynamo,aot" will set the log level of TorchDynamo to 623 logging.DEBUG and AOT to logging.INFO 624 625 TORCH_LOGS="-dynamo,+inductor" will set the log level of TorchDynamo to 626 logging.ERROR and TorchInductor to logging.DEBUG 627 628 TORCH_LOGS="aot_graphs" will enable the aot_graphs artifact 629 630 TORCH_LOGS="+dynamo,schedule" will enable set the log level of TorchDynamo 631 to logging.DEBUG and enable the schedule artifact 632 633 TORCH_LOGS="+some.random.module,schedule" will set the log level of 634 some.random.module to logging.DEBUG and enable the schedule artifact 635 636 TORCH_LOGS_FORMAT="%(levelname)s: %(message)s" or any provided format 637 string will set the output format 638 Valid keys are "levelname", "message", "pathname", "levelno", "lineno", 639 "filename" and "name". 640 641 TORCH_LOGS_OUT=/tmp/output.txt will output the logs to /tmp/output.txt as 642 well. This is useful when the output is long. 643""" # flake8: noqa: B950 644 msg = f""" 645TORCH_LOGS Info 646{examples} 647 648{heading} 649{setting_info} 650""" 651 return msg 652 653 654def _invalid_settings_err_msg(settings, verbose=False): 655 valid_settings = ", ".join( 656 ["all"] 657 + list(log_registry.log_alias_to_log_qnames.keys()) 658 + list(log_registry.artifact_names) 659 ) 660 msg = f""" 661Invalid log settings: {settings}, must be a comma separated list of fully 662qualified module names, registered log names or registered artifact names. 663For more info on various settings, try TORCH_LOGS="help" 664Valid settings: 665{valid_settings} 666""" 667 return msg 668 669 670@functools.lru_cache 671def _parse_log_settings(settings): 672 if settings == "": 673 return {} 674 675 if settings == "help": 676 raise ValueError(help_message(verbose=False)) 677 elif settings == "+help": 678 raise ValueError(help_message(verbose=True)) 679 if not _validate_settings(settings): 680 raise ValueError(_invalid_settings_err_msg(settings)) 681 682 settings = re.sub(r"\s+", "", settings) 683 log_names = settings.split(",") 684 685 def get_name_level_pair(name): 686 clean_name = name.replace(INCR_VERBOSITY_CHAR, "") 687 clean_name = clean_name.replace(DECR_VERBOSITY_CHAR, "") 688 689 if name[0] == INCR_VERBOSITY_CHAR: 690 level = logging.DEBUG 691 elif name[0] == DECR_VERBOSITY_CHAR: 692 level = logging.ERROR 693 else: 694 level = logging.INFO 695 696 return clean_name, level 697 698 log_state = LogState() 699 700 for name in log_names: 701 name, level = get_name_level_pair(name) 702 703 if name == "all": 704 name = "torch" 705 706 if log_registry.is_log(name): 707 assert level is not None 708 log_qnames = log_registry.log_alias_to_log_qnames[name] 709 log_state.enable_log(log_qnames, level) 710 elif log_registry.is_artifact(name): 711 log_state.enable_artifact(name) 712 elif _is_valid_module(name): 713 if not _has_registered_parent(name): 714 log_registry.register_log(name, name) 715 else: 716 log_registry.register_child_log(name) 717 log_state.enable_log(name, level) 718 else: 719 raise ValueError(_invalid_settings_err_msg(settings)) 720 721 return log_state 722 723 724def _is_valid_module(qname): 725 try: 726 __import__(qname) 727 return True 728 except ImportError: 729 return False 730 731 732def _update_log_state_from_env(): 733 global log_state 734 log_setting = os.environ.get(LOG_ENV_VAR, None) 735 if log_setting is not None: 736 log_state = _parse_log_settings(log_setting) 737 738 739def _has_registered_parent(log_qname): 740 cur_log = logging.getLogger(log_qname) 741 742 registered_log_qnames = log_registry.get_log_qnames() 743 744 while cur_log.parent: 745 if cur_log.name in registered_log_qnames: 746 return True 747 cur_log = cur_log.parent 748 749 return False 750 751 752def make_module_path_relative(abs_path): 753 """ 754 Given an absolute filepath corresponding to a Python module which was 755 loaded via normal import mechanisms using sys.path, convert it into 756 a relative path relative to one of the Python search paths. 757 """ 758 759 abs_path = pathlib.Path(abs_path).resolve() 760 761 for path in sys.path: 762 try: 763 rel_path = abs_path.relative_to(path) 764 except ValueError: 765 continue 766 else: 767 return str(rel_path) 768 769 return str(abs_path) 770 771 772# apply custom formats to artifacts when necessary 773class TorchLogsFormatter(logging.Formatter): 774 def __init__(self, *, trace: bool = False): 775 super().__init__() 776 self._is_trace = trace 777 778 def format(self, record): 779 artifact_name = getattr(logging.getLogger(record.name), "artifact_name", None) 780 if artifact_name is not None: 781 artifact_formatter = log_registry.artifact_log_formatters.get( 782 artifact_name, None 783 ) 784 if artifact_formatter is not None: 785 return artifact_formatter.format(record) 786 787 record.message = record.getMessage() 788 record.asctime = self.formatTime(record, "%m%d %H:%M:%S") 789 790 # exception handling - copied from logging.Formatter.format 791 s = record.message 792 if record.exc_info: 793 # Cache the traceback text to avoid converting it multiple times 794 # (it's constant anyway) 795 if not record.exc_text: 796 record.exc_text = self.formatException(record.exc_info) 797 if record.exc_text: 798 if s[-1:] != "\n": 799 s = s + "\n" 800 s = s + record.exc_text 801 if record.stack_info: 802 if s[-1:] != "\n": 803 s = s + "\n" 804 s = s + self.formatStack(record.stack_info) 805 806 record.rankprefix = "" 807 if not self._is_trace and dist.is_available() and dist.is_initialized(): 808 record.rankprefix = f"[rank{dist.get_rank()}]:" 809 810 record.traceid = "" 811 if ( 812 not self._is_trace 813 and (trace_id := torch._guards.CompileContext.current_trace_id()) 814 is not None 815 ): 816 record.traceid = f" [{trace_id}]" 817 818 glog_level_to_abbr = { 819 "DEBUG": "V", # V is for VERBOSE in glog 820 "INFO": "I", 821 "WARNING": "W", 822 "ERROR": "E", 823 "CRITICAL": "C", 824 } 825 826 shortlevel = glog_level_to_abbr.get(record.levelname, record.levelname) 827 828 record.artifactprefix = "" 829 if artifact_name is not None: 830 record.artifactprefix = f" [__{artifact_name}]" 831 832 filepath = make_module_path_relative(record.pathname) 833 834 prefix = ( 835 f"{record.rankprefix}{shortlevel}{record.asctime}.{int(record.msecs*1000):06d} {record.process} " 836 f"{filepath}:" 837 f"{record.lineno}]{record.traceid}{record.artifactprefix}" 838 ) 839 if self._is_trace: 840 assert s == "" 841 try: 842 r = f"{prefix} {json.dumps(record.metadata)}" 843 except TypeError: 844 log.warning("failing metadata: %r", record.metadata) 845 raise 846 if record.payload is not None: 847 r += "".join(f"\n\t{l}" for l in record.payload.split("\n")) 848 return r 849 else: 850 lines = s.split("\n") 851 return "\n".join(f"{prefix} {l}" for l in lines) 852 853 854def _default_formatter(): 855 fmt = os.environ.get(LOG_FORMAT_ENV_VAR, None) 856 if fmt is None: 857 return TorchLogsFormatter() 858 else: 859 if fmt in ("short", "basic"): 860 fmt = logging.BASIC_FORMAT 861 return logging.Formatter(fmt) 862 863 864DEFAULT_FORMATTER = _default_formatter() 865 866 867def _setup_handlers(create_handler_fn, log): 868 debug_handler = _track_handler(create_handler_fn()) 869 debug_handler.setFormatter(DEFAULT_FORMATTER) 870 debug_handler.setLevel(logging.DEBUG) 871 log.addHandler(debug_handler) 872 873 874handlers = WeakSet() # type: ignore[var-annotated] 875 876 877# mark handlers that we've created 878# so we don't modify user handlers 879def _track_handler(handler): 880 handlers.add(handler) 881 return handler 882 883 884def _is_torch_handler(handler): 885 return handler in handlers 886 887 888# clears all torch handlers on specified loggers 889def _clear_handlers(log): 890 to_remove = [handler for handler in log.handlers if _is_torch_handler(handler)] 891 for handler in to_remove: 892 log.removeHandler(handler) 893 894 895def _reset_logs(): 896 # reset all registered logs 897 for log_qname in log_registry.get_log_qnames(): 898 log = logging.getLogger(log_qname) 899 log.setLevel(logging.WARNING) 900 log.propagate = False 901 _clear_handlers(log) 902 903 # reset all artifact and child logs 904 for artifact_log_qname in itertools.chain( 905 log_registry.get_artifact_log_qnames(), log_registry.get_child_log_qnames() 906 ): 907 log = logging.getLogger(artifact_log_qname) 908 log.setLevel(logging.NOTSET) 909 log.propagate = True 910 911 trace_log.propagate = False 912 _clear_handlers(trace_log) 913 914 915def _get_log_state(): 916 return log_state 917 918 919def _set_log_state(state): 920 global log_state 921 log_state = state 922 923 924def _init_logs(log_file_name=None): 925 _reset_logs() 926 _update_log_state_from_env() 927 928 out = os.environ.get(LOG_OUT_ENV_VAR, None) 929 if out is not None: 930 log_file_name = out 931 932 # First, reset all known (registered) loggers to NOTSET, so that they 933 # respect their parent log level 934 for log_qname in log_registry.get_log_qnames(): 935 # But not the top level torch level: this defaults to WARNING so 936 # that our log messages don't leak to the lower levels 937 if log_qname == "torch": 938 continue 939 log = logging.getLogger(log_qname) 940 log.setLevel(logging.NOTSET) 941 942 # Now, for all loggers which the user requested to have non-standard 943 # logging behavior, modify their log levels 944 for log_qname, level in log_state.get_log_level_pairs(): 945 log = logging.getLogger(log_qname) 946 log.setLevel(level) 947 948 # Finally, setup handlers for all registered loggers 949 for log_qname in log_registry.get_log_qnames(): 950 log = logging.getLogger(log_qname) 951 _setup_handlers( 952 logging.StreamHandler, 953 log, 954 ) 955 956 if log_file_name is not None: 957 _setup_handlers( 958 lambda: logging.FileHandler(log_file_name), 959 log, 960 ) 961 962 # configure artifact loggers, note: this must happen last 963 # since the levels of ancestor loggers are taken into account 964 for artifact_log_qname in log_registry.get_artifact_log_qnames(): 965 log = logging.getLogger(artifact_log_qname) 966 configure_artifact_log(log) 967 968 # Setup handler for the special trace_log, with different default 969 # configuration 970 trace_dir_name = os.environ.get(TRACE_ENV_VAR, None) 971 # This handler may remove itself if trace_dir_name is None and we are not 972 # actually in an FB environment. This allows us to defer actually 973 # initializing it until we actually need to log anything. This is 974 # important because JK initializes a C++ singleton, which will pork our 975 # process if we subsequently fork. 976 handler = LazyTraceHandler(trace_dir_name) 977 # This log is ALWAYS at debug level. We will additionally test if there 978 # are any handlers before deciding to actually call logging on this. Do 979 # not manually call 980 trace_log.setLevel(logging.DEBUG) 981 trace_log_handler = _track_handler(handler) 982 trace_log_handler.setFormatter(TorchLogsFormatter(trace=True)) 983 trace_log.addHandler(trace_log_handler) 984 985 986class LazyTraceHandler(logging.StreamHandler): 987 """Like FileHandler, but the file is allocated lazily only upon the first log message""" 988 989 def __init__(self, root_dir: Optional[str]): 990 # This is implemented in the same way that delay is implemented on 991 # FileHandler 992 self.root_dir = root_dir 993 logging.Handler.__init__(self) 994 self.stream = None 995 self._builtin_open = open 996 997 # cloned from FileHandler in cpython 998 def close(self): 999 self.acquire() 1000 try: 1001 try: 1002 if self.stream: 1003 try: 1004 self.flush() 1005 finally: 1006 stream = self.stream 1007 self.stream = None 1008 if hasattr(stream, "close"): 1009 stream.close() 1010 finally: 1011 # Issue #19523: call unconditionally to 1012 # prevent a handler leak when delay is set 1013 # Also see Issue #42378: we also rely on 1014 # self._closed being set to True there 1015 logging.StreamHandler.close(self) 1016 finally: 1017 self.release() 1018 1019 def emit(self, record): 1020 if self.stream is None: 1021 ok = False 1022 if self.root_dir is None: 1023 TRACE_LOG_DIR = "/logs" 1024 open_func = self._builtin_open 1025 1026 import torch.version as torch_version 1027 1028 if ( 1029 hasattr(torch_version, "git_version") 1030 and os.getenv("MAST_HPC_JOB_NAME") is None 1031 ): 1032 log.info( 1033 "LazyTraceHandler: disabled because not fbcode or conda on mast" 1034 ) 1035 elif not torch._utils_internal.justknobs_check("pytorch/trace:enable"): 1036 log.info( 1037 "LazyTraceHandler: disabled because justknobs_check('pytorch/trace:enable') returned False" 1038 ) 1039 elif not os.path.exists(TRACE_LOG_DIR): 1040 log.info( 1041 "LazyTraceHandler: disabled because %s does not exist", 1042 TRACE_LOG_DIR, 1043 ) 1044 elif not os.access(TRACE_LOG_DIR, os.W_OK): 1045 log.info( 1046 "LazyTraceHandler: disabled because %s is not writeable", 1047 TRACE_LOG_DIR, 1048 ) 1049 else: 1050 self.root_dir = TRACE_LOG_DIR 1051 1052 if self.root_dir is not None: 1053 os.makedirs(self.root_dir, exist_ok=True) 1054 ranksuffix = "" 1055 if dist.is_available() and dist.is_initialized(): 1056 ranksuffix = f"rank_{dist.get_rank()}_" 1057 self.stream = tempfile.NamedTemporaryFile( 1058 mode="w+", 1059 suffix=".log", 1060 prefix=f"dedicated_log_torch_trace_{ranksuffix}", 1061 dir=self.root_dir, 1062 delete=False, 1063 ) 1064 log.info("LazyTraceHandler: logging to %s", self.stream.name) 1065 else: 1066 # We go poof, remove and no-op 1067 trace_log.removeHandler(self) 1068 return 1069 if self.stream: 1070 super().emit(record) 1071 1072 1073@functools.lru_cache(None) 1074def warning_once(logger_obj, *args, **kwargs): 1075 """ 1076 This function is similar to `logger.warning()`, but will emit the warning with the same message only once 1077 Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the cache. 1078 The assumption here is that all warning messages are unique across the code. If they aren't then need to switch to 1079 another type of cache that includes the caller frame information in the hashing function. 1080 """ 1081 logger_obj.warning(*args, **kwargs) 1082 1083 1084class LazyString: 1085 def __init__(self, func, *args, **kwargs): 1086 self.func = func 1087 self.args = args 1088 self.kwargs = kwargs 1089 1090 def __str__(self): 1091 return self.func(*self.args, **self.kwargs) 1092 1093 1094def trace_structured( 1095 name: str, 1096 # NB: metadata expected to be dict so adding more info is forward compatible 1097 # Tuple[str, int] is a special case for string interning 1098 metadata_fn: Callable[[], Union[Dict[str, Any], Tuple[str, int]]] = dict, 1099 *, 1100 payload_fn: Callable[[], Optional[Union[str, object]]] = lambda: None, 1101 suppress_context: bool = False, 1102 expect_trace_id: bool = True, # Whether or not we expect to have a current trace id 1103): 1104 """ 1105 metadata is an arbitrary JSON compatible struct, but it's expected to not be 1106 too long (e.g., less than 1MB) 1107 1108 payload is an arbitrary string, which can be arbitrarily long (but expected to have 1109 newlines so no lines are too long) 1110 """ 1111 assert "name" not in ["rank", "frame_id", "frame_compile_id", "attempt"] 1112 assert callable( 1113 metadata_fn 1114 ), f"metadata_fn should be callable, but got {type(metadata_fn)}" 1115 assert callable( 1116 payload_fn 1117 ), f"payload_fn should be callable, but got {type(payload_fn)}" 1118 # trace_log never propagates and is ALWAYS DEBUG, so also check that there 1119 # are handlers instead of checking the log level 1120 if trace_log.handlers: 1121 record: Dict[str, object] = {} 1122 record[name] = metadata_fn() 1123 if not suppress_context: 1124 # TODO: Actually, the rank probably should just be emitted once at 1125 # the top, and not repeatedly spammed in all the logs, since it 1126 # never changes and we assume no interleaving 1127 if dist.is_available() and dist.is_initialized(): 1128 record["rank"] = dist.get_rank() 1129 if ( 1130 trace_id := torch._guards.CompileContext.current_trace_id() 1131 ) is not None: 1132 record["frame_id"] = trace_id.compile_id.frame_id 1133 record["frame_compile_id"] = trace_id.compile_id.frame_compile_id 1134 record["attempt"] = trace_id.attempt 1135 else: 1136 if expect_trace_id: 1137 # Record the stack of the log call to better diagnose why we 1138 # don't have a frame id for it 1139 record["stack"] = torch._logging.structured.from_traceback( 1140 CapturedTraceback.extract(skip=1).summary() 1141 ) 1142 payload = payload_fn() 1143 if payload is not None: 1144 if not isinstance(payload, str): 1145 if isinstance(payload, list): 1146 # special case to look better 1147 payload = "[\n" + ",\n".join(json.dumps(i) for i in payload) + "\n]" 1148 else: 1149 # force newlines so we are unlikely to overflow line limit 1150 payload = json.dumps(payload, indent=0) 1151 h = hashlib.md5() 1152 h.update(payload.encode("utf-8")) 1153 record["has_payload"] = h.hexdigest() 1154 trace_log.debug( 1155 "", extra={"metadata": record, "payload": payload}, stacklevel=2 1156 ) 1157 log_trace_structured_event(name, record) 1158 1159 1160import torch._guards 1161import torch._utils_internal 1162import torch.distributed as dist 1163