1# mypy: allow-untyped-defs 2import logging 3import warnings 4from typing import Any, Dict, Iterable, Optional, Tuple 5 6import torch 7import torch.export 8import torch.export._trace 9from torch._utils_internal import log_export_usage 10 11 12log = logging.getLogger(__name__) 13 14__all__ = ["report_exportability"] 15 16 17def _generate_inputs_for_submodules( 18 model: torch.nn.Module, 19 target_submodules: Iterable[str], 20 args: Tuple[Any, ...], 21 kwargs: Optional[Dict[str, Any]] = None, 22) -> Dict[str, Tuple[Any, Any]]: 23 """ 24 Generate inputs for targeting submdoules in the given model. Note that if two submodules refer to the same obj, this 25 function doesn't work. 26 27 Args: 28 model: root model. 29 inputs: inputs to the root model. 30 target_submodules: submodules that we want to generate inputs for. 31 32 Returns: 33 A dict that maps from submodule name to its inputs. 34 """ 35 kwargs = kwargs or {} 36 37 handles = [] 38 results = {} 39 submodule_to_names = {mod: name for name, mod in model.named_modules()} 40 41 def pre_forward(module, module_args, module_kwargs): 42 results[submodule_to_names[module]] = (module_args, module_kwargs) 43 44 try: 45 for name, mod in model.named_modules(): 46 if name in target_submodules: 47 handles.append( 48 mod.register_forward_pre_hook(pre_forward, with_kwargs=True) 49 ) 50 model(*args, **kwargs) 51 except Exception as e: 52 warnings.warn( 53 f"Failed to generate submodule inputs because of the following error:\n{e}" 54 ) 55 finally: 56 for h in handles: 57 h.remove() 58 return results 59 60 61def report_exportability( 62 mod: torch.nn.Module, 63 args: Tuple[Any, ...], 64 kwargs: Optional[Dict[str, Any]] = None, 65 *, 66 strict: bool = True, 67 pre_dispatch: bool = False, 68) -> Dict[str, Optional[Exception]]: 69 """ 70 Report exportability issues for a module in one-shot. 71 72 Args: 73 mod: root module. 74 args: args to the root module. 75 kwargs: kwargs to the root module. 76 Returns: 77 A dict that maps from submodule name to the exception that was raised when trying to export it. 78 `None` means the module is exportable without issue. 79 Sample output: 80 { 81 '': UnsupportedOperatorException(func=<OpOverload(op='testlib.op_missing_meta', overload='default')>), 82 'submod_1': UnsupportedOperatorException(func=<OpOverload(op='testlib.op_missing_meta', overload='default')>), 83 'submod_2': None 84 } 85 """ 86 87 log_export_usage(event="export.report_exportability") 88 89 kwargs = kwargs or {} 90 91 all_submod_names = [name for name, _ in mod.named_modules() if name != ""] 92 submod_inputs = _generate_inputs_for_submodules(mod, all_submod_names, args, kwargs) 93 94 tried_module_types = set() 95 report: Dict[str, Optional[Exception]] = {} 96 97 def try_export(module, module_name, args, kwargs): 98 nonlocal submod_inputs, report, strict, pre_dispatch, tried_module_types 99 100 if type(module) in tried_module_types: 101 return 102 tried_module_types.add(type(module)) 103 104 if args is not None or kwargs is not None: 105 try: 106 torch.export._trace._export( 107 module, 108 args, 109 kwargs, 110 strict=strict, 111 pre_dispatch=pre_dispatch, 112 ) 113 report[module_name] = None 114 log.info("Successfully exported `%s`", module_name) 115 return 116 except Exception as e: 117 short_msg = repr(e).split("\n")[0] 118 log.warning( 119 "Failed exporting `%s` with exception: %s", module_name, short_msg 120 ) 121 report[module_name] = e 122 123 for name, submod in module.named_children(): 124 sub_module_name = name if module_name == "" else f"{module_name}.{name}" 125 126 submod_args, submod_kwargs = submod_inputs.get( 127 sub_module_name, (None, None) 128 ) 129 130 try_export(submod, sub_module_name, submod_args, submod_kwargs) 131 132 return 133 134 try_export(mod, "", args, kwargs) 135 136 unique_issues = set() 137 for exception in report.values(): 138 if exception is not None: 139 key = repr(exception).split("\\n")[0] 140 unique_issues.add(key) 141 142 log.warning("Found %d export issues:", len(unique_issues)) 143 for issue in unique_issues: 144 log.warning(issue) 145 146 return report 147