# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
import logging
import re
import sys
from contextlib import contextmanager
from typing import Dict, Generator, List, Optional, Tuple, TypeVar, Union

import torch


@contextmanager
def no_dispatch() -> Generator[None, None, None]:
    """
    Run the enclosed code while holding a ``torch._C._DisableTorchDispatch``
    guard.

    The guard object is created on entry and released on exit, even if the
    body raises.
    """
    guard = torch._C._DisableTorchDispatch()
    try:
        yield
    finally:
        del guard


def get_schema_for_operators(ops: List[str]) -> Dict[str, str]:
    r"""
    Map operator names taken from a Graph Module to their schema strings.

    ``ops`` holds names of the form ``torch.ops.aten.cat.default``. Each name is
    rewritten to ``namespace::op_name.overload_name`` (e.g.
    ``aten::cat.default``) and used as a key of the returned dict. The value is
    the operator's schema string, or ``"<No Schema Found>"`` when no registered
    schema matches. Names that do not start with a literal ``torch.ops.`` prefix
    are used as keys unchanged.

    Note: This method should only be used for debugging errors in export, and
    not in a production context.
    """
    # Dots are escaped so that only a literal "torch.ops." prefix is rewritten.
    pat = re.compile(r"^torch\.ops\.([^.]+)\.(.*)")
    aten_ops = [pat.sub(r"\1::\2", op) for op in ops]
    wanted = set(aten_ops)

    schema_dict: Dict[str, str] = {}
    for s in torch._C._jit_get_all_schemas():
        # Schemas without an overload name are spelled ".default" in torch.ops.
        overload = s.overload_name or "default"
        n = f"{s.name}.{overload}"
        # Only stringify the schemas that were actually requested.
        if n in wanted:
            schema_dict[n] = str(s)

    return {op: schema_dict.get(op, "<No Schema Found>") for op in aten_ops}


T = TypeVar("T")  # Declare type variable


def extract_out_arguments(
    schema: torch._C.FunctionSchema, keyword_args: Dict[str, T]
) -> Union[Tuple[str, T], List[Tuple[str, T]]]:
    """
    Collect the out arguments of a (possibly out-variant) schema.

    For every argument the schema marks as ``is_out`` that is also present in
    ``keyword_args``, a ``(arg_name, value)`` pair is collected. If exactly one
    pair is found it is returned directly; otherwise (zero or several) the list
    of pairs is returned.
    """
    out_args = [
        (arg.name, keyword_args[arg.name])
        for arg in schema.arguments
        if arg.is_out and arg.name in keyword_args
    ]

    # TODO (tmanlaibaatar) There are 3 ops with TensorList as the storage for aliased tensor
    # which was added after is_out logic. Until we fix that implementation,
    # hack to manually add out args
    if not out_args and "out" in keyword_args:
        out_args.append(("out", keyword_args["out"]))

    if len(out_args) == 1:
        return out_args[0]

    return out_args


def format_schema_name(schema: torch._C.FunctionSchema) -> str:
    """
    Return ``name.overload_name`` for ``schema``, or just ``name`` when the
    overload name is empty.
    """
    if schema.overload_name != "":
        return schema.name + "." + schema.overload_name
    return schema.name


@contextmanager
def override_logger(
    newLevel: int = logging.DEBUG,
    fmtstr: str = "%(message)s",
    filename: Optional[str] = None,
) -> Generator[None, None, None]:
    """
    Temporarily reconfigure the root logger.

    While the context is active:
    - the root logger's level is ``newLevel``;
    - if ``fmtstr`` is non-empty, every root handler (including the optional
      file handler below) formats records with it;
    - if a non-empty ``filename`` is provided, records are also written to
      that file (opened in "w" mode), in addition to the root logger's
      existing handlers.

    Everything is undone on exit, even if the body raises. The level and the
    previous formatters are restored, and the file handler is removed and
    closed.

    Raises:
        ValueError: if ``fmtstr`` is not a valid %-style format string. This is
            raised before any logging state is modified.
    """
    # Build the formatter before touching global state, so that a bad format
    # string cannot leave the root logger half-configured or surface as an
    # unrelated error from the cleanup code.
    new_formatter = logging.Formatter(fmtstr, None, "%") if fmtstr else None

    root = logging.root
    old_level = root.level
    # Keep each handler paired with its previous formatter, so restoration does
    # not depend on the root handler list keeping the same order.
    old_formatters: List[Tuple[logging.Handler, Optional[logging.Formatter]]] = []
    file_handler: Optional[logging.FileHandler] = None
    try:
        root.setLevel(newLevel)
        if new_formatter is not None:
            for handler in root.handlers:
                old_formatters.append((handler, handler.formatter))
                handler.setFormatter(new_formatter)
        if filename:
            file_handler = logging.FileHandler(filename, mode="w")
            if new_formatter is not None:
                file_handler.setFormatter(new_formatter)
            root.addHandler(file_handler)
        yield
    finally:
        root.setLevel(old_level)
        for handler, formatter in old_formatters:
            handler.setFormatter(formatter)
        if file_handler is not None:
            root.removeHandler(file_handler)
            # removeHandler does not close the handler; close it so the file
            # descriptor is not leaked.
            file_handler.close()


@contextmanager
def setting_python_recursive_limit(limit: int = 10000) -> Generator[None, None, None]:
    """
    Temporarily increase the python interpreter stack recursion limit.
    This is mostly used for pickling large scale modules.

    The limit is only ever raised: if ``limit`` is not above the current limit,
    the current limit is kept. The previous limit is restored on exit.
    """
    default = sys.getrecursionlimit()
    if limit > default:
        sys.setrecursionlimit(limit)
    try:
        yield
    finally:
        sys.setrecursionlimit(default)