xref: /aosp_15_r20/external/executorch/exir/common.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Worker# pyre-strict
8*523fa7a6SAndroid Build Coastguard Workerimport logging
9*523fa7a6SAndroid Build Coastguard Workerimport re
10*523fa7a6SAndroid Build Coastguard Workerimport sys
11*523fa7a6SAndroid Build Coastguard Workerfrom contextlib import contextmanager
12*523fa7a6SAndroid Build Coastguard Workerfrom typing import Dict, Generator, List, Optional, Tuple, TypeVar, Union
13*523fa7a6SAndroid Build Coastguard Worker
14*523fa7a6SAndroid Build Coastguard Workerimport torch
15*523fa7a6SAndroid Build Coastguard Worker
16*523fa7a6SAndroid Build Coastguard Worker
@contextmanager
def no_dispatch() -> Generator[None, None, None]:
    """
    Hold a ``torch._C._DisableTorchDispatch`` guard for the ``with`` body.

    The guard object is constructed on entry and the local reference to it is
    dropped on exit, whether the body finishes normally or raises.

    NOTE(review): judging by its name, the guard presumably turns off torch
    dispatch handling while it is alive -- confirm against the torch sources.
    """
    dispatch_guard = torch._C._DisableTorchDispatch()
    try:
        yield
    finally:
        # Release our reference to the guard once the body is done.
        del dispatch_guard
24*523fa7a6SAndroid Build Coastguard Worker
25*523fa7a6SAndroid Build Coastguard Worker
def get_schema_for_operators(ops: List[str]) -> Dict[str, str]:
    r"""
    Look up the registered schema string for each operator name.

    Accepts a list of operator names fetched from the Graph Module (these are
    of the form ``torch.ops.aten.cat.default``) and returns a dict mapping the
    operator name (in the form ``namespace::op_name.overload_name``) to the
    operator schema string.

    Names that do not begin with the literal prefix ``torch.ops.`` are looked
    up unchanged. Any name with no matching schema maps to
    ``"<No Schema Found>"``. Schemas with an empty overload name are indexed
    under the ``.default`` suffix.

    Note: This method should only be used for debugging errors in export, and
    not in a production context.

    Args:
        ops: Operator names, typically ``torch.ops.<namespace>.<op>.<overload>``.

    Returns:
        Dict from rewritten operator name to schema string, in the order the
        names appear in ``ops``.
    """
    # The dots of the "torch.ops." prefix are escaped so that only a literal
    # "torch.ops." is rewritten; an unescaped "." would match any character.
    pat = re.compile(r"^torch\.ops\.([^\.]+)\.(.*)")
    qualified_names = [pat.sub(r"\1::\2", op) for op in ops]

    # Index every registered schema by "<name>.<overload or 'default'>".
    schema_dict: Dict[str, str] = {}
    for s in torch._C._jit_get_all_schemas():
        overload = s.overload_name if s.overload_name != "" else "default"
        schema_dict[s.name + "." + overload] = str(s)

    return {
        name: schema_dict.get(name, "<No Schema Found>")
        for name in qualified_names
    }
59*523fa7a6SAndroid Build Coastguard Worker
60*523fa7a6SAndroid Build Coastguard Worker
T = TypeVar("T")  # Declare type variable


def extract_out_arguments(
    schema: torch._C.FunctionSchema, keyword_args: Dict[str, T]
) -> Union[Tuple[str, T], List[Tuple[str, T]]]:
    """
    Collect the out arguments of a (possibly out-variant) schema.

    Every schema argument flagged ``is_out`` whose name is present in
    ``keyword_args`` is paired with its value as ``(name, value)``.

    Returns:
        The single ``(name, value)`` tuple when exactly one out argument is
        found; otherwise the list of tuples (which may be empty).
    """
    found = [
        (arg.name, keyword_args[arg.name])
        for arg in schema.arguments
        if arg.is_out and arg.name in keyword_args
    ]

    # TODO (tmanlaibaatar) There are 3 ops with TensorList as the storage for aliased tensor
    # which was added after is_out logic. Until we fix that implementation,
    # hack to manually add out args
    if not found and "out" in keyword_args:
        found = [("out", keyword_args["out"])]

    # A lone out argument is returned bare rather than wrapped in a list.
    return found[0] if len(found) == 1 else found
86*523fa7a6SAndroid Build Coastguard Worker
87*523fa7a6SAndroid Build Coastguard Worker
def format_schema_name(schema: torch._C.FunctionSchema) -> str:
    """
    Return the schema's name, suffixed with ``.<overload_name>`` when the
    overload name is non-empty.
    """
    overload = schema.overload_name
    if overload == "":
        return schema.name
    return schema.name + "." + overload
92*523fa7a6SAndroid Build Coastguard Worker
93*523fa7a6SAndroid Build Coastguard Worker
@contextmanager
def override_logger(
    newLevel: int = logging.DEBUG,
    fmtstr: str = "%(message)s",
    filename: Optional[str] = None,
) -> Generator[None, None, None]:
    """
    Temporarily reconfigure the root logger for the duration of the ``with`` body.

    On entry the root logger level is set to ``newLevel``. If ``fmtstr`` is
    nonempty, every handler currently attached to the root logger is switched
    to a formatter built from it. If a nonempty ``filename`` is provided, the
    log will also be written to that file (opened in ``"w"`` mode, so existing
    content is truncated) in addition to the existing root handlers.

    On exit the previous level and each handler's previous formatter are
    restored, and the file handler, if one was created, is detached and
    closed.

    Args:
        newLevel: Level to set on the root logger while the context is active.
        fmtstr: ``%``-style format string for the temporary formatter; an
            empty string leaves the handlers' formatters untouched.
        filename: Optional path of a file to additionally log to.
    """
    root = logging.root
    old_level = root.level
    # Remember each handler together with its own formatter so restoration
    # does not depend on the handler list staying the same inside the body.
    saved_formatters: List[Tuple[logging.Handler, Optional[logging.Formatter]]] = []
    filehandler: Optional[logging.FileHandler] = None
    try:
        root.setLevel(newLevel)
        if fmtstr:
            new_formatter = logging.Formatter(fmtstr, None, "%")
            for handler in list(root.handlers):
                saved_formatters.append((handler, handler.formatter))
                handler.formatter = new_formatter
        if filename:
            # Added after the formatter swap, so this handler keeps its
            # default formatter, as before.
            filehandler = logging.FileHandler(filename, mode="w")
            root.addHandler(filehandler)
        yield
    finally:
        root.setLevel(old_level)
        for handler, formatter in saved_formatters:
            handler.formatter = formatter
        if filehandler is not None:
            root.removeHandler(filehandler)
            # Close the underlying file so the handle is not leaked.
            filehandler.close()
126*523fa7a6SAndroid Build Coastguard Worker
127*523fa7a6SAndroid Build Coastguard Worker
@contextmanager
def setting_python_recursive_limit(limit: int = 10000) -> Generator[None, None, None]:
    """
    Raise the interpreter's recursion limit for the duration of the ``with`` body.

    The limit is only changed when ``limit`` exceeds the current value; it is
    never lowered on entry. On exit the limit recorded at entry is reinstated.
    This is mostly used for pickling large scale modules.
    """
    limit_on_entry = sys.getrecursionlimit()
    should_raise = limit_on_entry < limit
    if should_raise:
        sys.setrecursionlimit(limit)
    try:
        yield
    finally:
        # Always put back whatever limit was in effect at entry.
        sys.setrecursionlimit(limit_on_entry)
141