xref: /aosp_15_r20/external/pytorch/torch/_export/tools.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import logging
3import warnings
4from typing import Any, Dict, Iterable, Optional, Tuple
5
6import torch
7import torch.export
8import torch.export._trace
9from torch._utils_internal import log_export_usage
10
11
12log = logging.getLogger(__name__)
13
14__all__ = ["report_exportability"]
15
16
17def _generate_inputs_for_submodules(
18    model: torch.nn.Module,
19    target_submodules: Iterable[str],
20    args: Tuple[Any, ...],
21    kwargs: Optional[Dict[str, Any]] = None,
22) -> Dict[str, Tuple[Any, Any]]:
23    """
24    Generate inputs for targeting submdoules in the given model. Note that if two submodules refer to the same obj, this
25    function doesn't work.
26
27    Args:
28        model: root model.
29        inputs: inputs to the root model.
30        target_submodules: submodules that we want to generate inputs for.
31
32    Returns:
33        A dict that maps from submodule name to its inputs.
34    """
35    kwargs = kwargs or {}
36
37    handles = []
38    results = {}
39    submodule_to_names = {mod: name for name, mod in model.named_modules()}
40
41    def pre_forward(module, module_args, module_kwargs):
42        results[submodule_to_names[module]] = (module_args, module_kwargs)
43
44    try:
45        for name, mod in model.named_modules():
46            if name in target_submodules:
47                handles.append(
48                    mod.register_forward_pre_hook(pre_forward, with_kwargs=True)
49                )
50        model(*args, **kwargs)
51    except Exception as e:
52        warnings.warn(
53            f"Failed to generate submodule inputs because of the following error:\n{e}"
54        )
55    finally:
56        for h in handles:
57            h.remove()
58    return results
59
60
def _first_line_of_repr(exception: BaseException) -> str:
    """Return ``repr(exception)`` cut at the first line break, whether escaped or literal."""
    text = repr(exception)
    # The repr of a standard exception escapes newlines in its message as the two characters
    # backslash + "n", while a custom ``__repr__`` may contain real newlines; cut at either.
    for separator in ("\\n", "\n"):
        text = text.split(separator, 1)[0]
    return text


def report_exportability(
    mod: torch.nn.Module,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    strict: bool = True,
    pre_dispatch: bool = False,
) -> Dict[str, Optional[Exception]]:
    """
    Report exportability issues for a module in one-shot.

    The root module is exported first. Whenever a module fails to export, its children are tried
    in turn using the inputs they received during one forward run of the root module. A module
    that exports successfully is not descended into, each module type is attempted at most once,
    and a submodule with no captured inputs is skipped while its children are still visited.

    Args:
        mod: root module.
        args: args to the root module.
        kwargs: kwargs to the root module.
        strict: forwarded to ``torch.export._trace._export`` for every export attempt.
        pre_dispatch: forwarded to ``torch.export._trace._export`` for every export attempt.
    Returns:
        A dict that maps from submodule name to the exception that was raised when trying to export it.
        `None` means the module is exportable without issue.
    Sample output:
        {
            '': UnsupportedOperatorException(func=<OpOverload(op='testlib.op_missing_meta', overload='default')>),
            'submod_1': UnsupportedOperatorException(func=<OpOverload(op='testlib.op_missing_meta', overload='default')>),
            'submod_2': None
        }
    """

    log_export_usage(event="export.report_exportability")

    kwargs = kwargs or {}

    all_submod_names = [name for name, _ in mod.named_modules() if name != ""]
    submod_inputs = _generate_inputs_for_submodules(mod, all_submod_names, args, kwargs)

    tried_module_types = set()
    report: Dict[str, Optional[Exception]] = {}

    def try_export(module, module_name, module_args, module_kwargs):
        # Only the first instance of each module type is attempted.
        if type(module) in tried_module_types:
            return
        tried_module_types.add(type(module))

        # (None, None) marks a submodule whose inputs were not captured: it cannot be
        # exported itself, but its children may still have inputs.
        if module_args is not None or module_kwargs is not None:
            try:
                torch.export._trace._export(
                    module,
                    module_args,
                    module_kwargs,
                    strict=strict,
                    pre_dispatch=pre_dispatch,
                )
            except Exception as e:
                log.warning(
                    "Failed exporting `%s` with exception: %s",
                    module_name,
                    _first_line_of_repr(e),
                )
                report[module_name] = e
            else:
                report[module_name] = None
                log.info("Successfully exported `%s`", module_name)
                # An exportable module needs no further breakdown.
                return

        for name, submod in module.named_children():
            sub_module_name = name if module_name == "" else f"{module_name}.{name}"
            submod_args, submod_kwargs = submod_inputs.get(
                sub_module_name, (None, None)
            )
            try_export(submod, sub_module_name, submod_args, submod_kwargs)

    try_export(mod, "", args, kwargs)

    # De-duplicate by shortened repr; a dict keeps first-seen order so the summary is
    # deterministic from run to run.
    unique_issues = dict.fromkeys(
        _first_line_of_repr(exception)
        for exception in report.values()
        if exception is not None
    )

    log.warning("Found %d export issues:", len(unique_issues))
    for issue in unique_issues:
        # Pass the text as an argument so it is never interpreted as a format string.
        log.warning("%s", issue)

    return report
147