# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
import logging
import re
import sys
from contextlib import contextmanager
from typing import Dict, Generator, List, Optional, Tuple, TypeVar, Union

import torch


@contextmanager
def no_dispatch() -> Generator[None, None, None]:
    """
    Disable torch dispatch for the duration of the ``with`` block.

    A ``torch._C._DisableTorchDispatch`` guard object is created on entry and
    its reference is dropped on exit (including when the body raises).
    """
    guard = torch._C._DisableTorchDispatch()
    try:
        yield
    finally:
        # Dropping the last reference is what ends the guard's effect.
        del guard


def get_schema_for_operators(ops: List[str]) -> Dict[str, str]:
    r"""
    Map operator names taken from a Graph Module to their schema strings.

    Each entry of ``ops`` is expected to look like
    ``torch.ops.aten.cat.default``; it is rewritten to
    ``namespace::op_name.overload_name`` (e.g. ``aten::cat.default``) and used
    as a key in the returned dict. Names that do not start with
    ``torch.ops.<namespace>.`` are used as keys unchanged. Operators with no
    matching registered schema map to ``"<No Schema Found>"``.

    Note: This method should only be used for debugging errors in export, and
    not in a production context.
    """
    # Dots are escaped so that only a literal "torch.ops." prefix is rewritten.
    pat = re.compile(r"^torch\.ops\.([^.]+)\.(.*)")
    aten_ops = [pat.sub(r"\1::\2", op) for op in ops]

    # Registered schemas use an empty overload name for the default overload;
    # normalize that to ".default" so keys match torch.ops.*.default names.
    # Schemas are stringified lazily, only for the operators actually requested.
    schema_dict: Dict[str, torch._C.FunctionSchema] = {}
    for s in torch._C._jit_get_all_schemas():
        overload = s.overload_name if s.overload_name != "" else "default"
        schema_dict[s.name + "." + overload] = s

    return {
        op: str(schema_dict[op]) if op in schema_dict else "<No Schema Found>"
        for op in aten_ops
    }


T = TypeVar("T")  # Declare type variable


def extract_out_arguments(
    schema: torch._C.FunctionSchema, keyword_args: Dict[str, T]
) -> Union[Tuple[str, T], List[Tuple[str, T]]]:
    """
    Collect the out arguments of ``schema`` that are present in ``keyword_args``.

    Returns a single ``(name, value)`` tuple when exactly one out argument is
    found, otherwise a list of such tuples (possibly empty), in schema order.
    """
    out_args: List[Tuple[str, T]] = [
        (arg.name, keyword_args[arg.name])
        for arg in schema.arguments
        if arg.is_out and arg.name in keyword_args
    ]

    # TODO (tmanlaibaatar) There are 3 ops with TensorList as the storage for aliased tensor
    # which was added after is_out logic. Until we fix that implementation,
    # hack to manually add out args
    if not out_args and "out" in keyword_args:
        out_args.append(("out", keyword_args["out"]))

    if len(out_args) == 1:
        return out_args[0]

    return out_args


def format_schema_name(schema: torch._C.FunctionSchema) -> str:
    """
    Return ``name.overload_name`` for ``schema``, or just ``name`` when the
    overload name is empty.
    """
    if schema.overload_name != "":
        return schema.name + "." + schema.overload_name
    return schema.name


@contextmanager
def override_logger(
    newLevel: int = logging.DEBUG,
    fmtstr: str = "%(message)s",
    filename: Optional[str] = None,
) -> Generator[None, None, None]:
    """
    Temporarily override the root logger's level and handler formatting.

    Args:
        newLevel: level set on the root logger inside the context.
        fmtstr: if non-empty, a ``%``-style format applied to every root
            handler (and to the file handler, if one is created).
        filename: if a non-empty string is given, the log is also written to
            that file (opened with mode "w") in addition to the root logger's
            existing handlers.

    The original level and formatters are restored on exit. Any file handler
    added here is removed and closed on exit, even if setup or the body raises.

    Raises:
        ValueError: propagated from ``logging.Formatter`` if ``fmtstr`` is not
            a valid ``%``-style format.
    """
    root = logging.root
    oldLevel = root.level
    # Remember each handler together with its formatter, so restoration does
    # not depend on the root handler list keeping the same order or length.
    oldFormatters: List[Tuple[logging.Handler, Optional[logging.Formatter]]] = []
    filehandler: Optional[logging.FileHandler] = None
    try:
        root.setLevel(newLevel)
        newformatter: Optional[logging.Formatter] = None
        if fmtstr:
            newformatter = logging.Formatter(fmtstr, None, "%")
            for handler in root.handlers:
                oldFormatters.append((handler, handler.formatter))
                handler.setFormatter(newformatter)
        if filename:
            filehandler = logging.FileHandler(filename, mode="w")
            if newformatter is not None:
                # Keep file output consistent with the other root handlers.
                filehandler.setFormatter(newformatter)
            root.addHandler(filehandler)
        yield
    finally:
        root.setLevel(oldLevel)
        for handler, formatter in oldFormatters:
            handler.setFormatter(formatter)
        if filehandler is not None:
            root.removeHandler(filehandler)
            # Release the underlying file descriptor.
            filehandler.close()


@contextmanager
def setting_python_recursive_limit(limit: int = 10000) -> Generator[None, None, None]:
    """
    Temporarily increase the python interpreter stack recursion limit.
    This is mostly used for pickling large scale modules.

    The limit is only raised (never lowered); the previous limit is restored
    on exit.
    """
    default = sys.getrecursionlimit()
    if limit > default:
        sys.setrecursionlimit(limit)
    try:
        yield
    finally:
        sys.setrecursionlimit(default)