# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
import logging
import re
import sys
from contextlib import contextmanager
from typing import Dict, Generator, List, Optional, Tuple, TypeVar, Union

import torch

@contextmanager
def no_dispatch() -> Generator[None, None, None]:
    """Temporarily disable torch dispatch while the context is active."""
    guard = torch._C._DisableTorchDispatch()
    try:
        yield
    finally:
        del guard

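# Example (illustrative sketch): a typical use is inside a tensor subclass's
# `__torch_dispatch__` handler, where calling an op normally would re-enter
# the handler. The tensors `a` and `b` here are placeholders.
#
#     with no_dispatch():
#         out = torch.add(a, b)  # runs the regular kernel without re-dispatch
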
def get_schema_for_operators(ops: List[str]) -> Dict[str, str]:
    r"""
    Accept a list of operator names fetched from the Graph Module (these are
    of the form torch.ops.aten.cat.default) and return a dict mapping each
    operator name (in the form namespace::op_name.overload_name) to its
    operator schema string.

    Note: This method should only be used for debugging errors in export, and
    not in a production context.
    """
    d = {}
    pat = re.compile(r"^torch\.ops\.([^.]+)\.(.*)")

    # Normalize "torch.ops.<namespace>.<op>" to "<namespace>::<op>".
    aten_ops = []
    for op in ops:
        aten_ops.append(re.sub(pat, r"\1::\2", op))

    all_schemas = torch._C._jit_get_all_schemas()

    # Index every registered schema by "name.overload_name", using ".default"
    # when the overload name is empty.
    schema_dict = {}
    for s in all_schemas:
        n = s.name
        if s.overload_name != "":
            n = n + "." + s.overload_name
        else:
            n = n + ".default"
        schema_dict[n] = str(s)

    for op in aten_ops:
        d[op] = "<No Schema Found>"
        if op in schema_dict:
            d[op] = schema_dict[op]

    return d

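# Example (illustrative): given operator names as they appear on GraphModule
# nodes, the returned dict keys use the namespace::name.overload form; the
# schema string shown is what str(schema) typically produces.
#
#     get_schema_for_operators(["torch.ops.aten.add.Tensor"])
#     # -> {"aten::add.Tensor":
#     #     "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"}
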
T = TypeVar("T")  # Declare type variable

def extract_out_arguments(
    schema: torch._C.FunctionSchema, keyword_args: Dict[str, T]
) -> Union[Tuple[str, T], List[Tuple[str, T]]]:
    # Given a possibly out-variant schema, find all out arguments and return
    # them as (arg name, value) tuples.
    out_args = []
    for arg in schema.arguments:
        name = arg.name
        if arg.is_out and name in keyword_args:
            out_args.append((name, keyword_args[name]))

    # TODO (tmanlaibaatar) There are 3 ops whose aliased tensors are stored in
    # a TensorList, which was added after the is_out logic. Until that
    # implementation is fixed, manually add the out args as a workaround.
    if len(out_args) == 0:
        if "out" in keyword_args:
            out_args.append(("out", keyword_args["out"]))

    # A single out argument is returned as a bare tuple rather than a list.
    if len(out_args) == 1:
        return out_args[0]

    return out_args

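# Example (illustrative sketch): for an out-variant schema such as
# aten::add.out, passing the out tensor among the keyword args returns the
# matching (name, value) pair; `out_tensor` is a placeholder.
#
#     schema = torch.ops.aten.add.out._schema  # _schema is a private attribute
#     extract_out_arguments(schema, {"out": out_tensor})
#     # -> ("out", out_tensor)
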
def format_schema_name(schema: torch._C.FunctionSchema) -> str:
    if schema.overload_name != "":
        return schema.name + "." + schema.overload_name
    return schema.name

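# Example: format_schema_name(torch.ops.aten.add.Tensor._schema) returns
# "aten::add.Tensor"; a schema with an empty overload name yields just the
# base name, e.g. "aten::cat".
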
@contextmanager
def override_logger(
    newLevel: int = logging.DEBUG,
    fmtstr: str = "%(message)s",
    filename: Optional[str] = None,
) -> Generator[None, None, None]:
    """
    Temporarily override the root logger's level and handler format. If a
    nonempty filename string is provided, the log will also be written to
    that file in addition to stderr.
    """
    # Initialize the saved state before the try block so the finally clause
    # can reference it even if setup fails partway through.
    oldLevel = logging.root.level
    oldFormatters = []
    filehandler = None
    try:
        logging.root.setLevel(newLevel)
        if fmtstr:
            newformatter = logging.Formatter(fmtstr, None, "%")
            for handler in logging.root.handlers:
                oldFormatters.append(handler.formatter)
                handler.formatter = newformatter
        if filename:
            filehandler = logging.FileHandler(filename, mode="w")
            logging.root.addHandler(filehandler)
        yield
    finally:
        logging.root.setLevel(oldLevel)
        if fmtstr:
            for handler, formatter in zip(logging.root.handlers, oldFormatters):
                handler.formatter = formatter
        if filehandler:
            logging.root.removeHandler(filehandler)

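# Example (illustrative): capture debug-level output to a file for the
# duration of a block; the path "/tmp/export_debug.log" is arbitrary.
#
#     with override_logger(newLevel=logging.DEBUG, filename="/tmp/export_debug.log"):
#         logging.debug("visible on stderr and written to the file")
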
@contextmanager
def setting_python_recursive_limit(limit: int = 10000) -> Generator[None, None, None]:
    """
    Temporarily increase the Python interpreter's stack recursion limit.
    This is mostly used for pickling large-scale modules.
    """
    default = sys.getrecursionlimit()
    if limit > default:
        sys.setrecursionlimit(limit)
    try:
        yield
    finally:
        sys.setrecursionlimit(default)

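# Example (illustrative sketch): raise the interpreter's recursion limit while
# pickling a deeply nested structure; `deep_module` is a placeholder object.
#
#     import pickle
#     with setting_python_recursive_limit(20000):
#         data = pickle.dumps(deep_module)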