from torch.fx.graph_module import GraphModule
from typing import Any, Callable, Dict, List, Tuple, Type
import torch
import torch.nn as nn

from torch.fx._compatibility import compatibility

__all__ = ['default_matching', 'extract_attrs_for_lowering', 'lift_lowering_attrs_to_nodes']


# A matching method maps an attribute name as listed in `module_fetch_book` to the
# attribute name used by the module at version `target_version`.
@compatibility(is_backward_compatible=False)
def default_matching(name: str, target_version: int) -> str:
    """Default matching method: return ``name`` unchanged for every version.

    Args:
        name: Attribute name as listed in ``module_fetch_book``.
        target_version: The module's ``_version``; ignored by this method.

    Returns:
        ``name`` itself.
    """
    return name


# Maps an nn.Module class to the attributes we want to fetch for lowering.
# Each value is a tuple (version, attribute_names, matching_method):
#   * version: the module class's ``_version`` at the time the attribute list was written.
#   * attribute_names: the attributes read from a module instance.
#   * matching_method: maps a listed name to the name used at the module's current version.
# If a module's ``_version`` is newer than the recorded version, the listed names may no
# longer match the module, so `extract_attrs_for_lowering` refuses to fetch them.
module_fetch_book: Dict[Type, Tuple[int, List[str], Callable[[str, int], str]]] = {
    torch.nn.modules.linear.Linear: (1, ["weight", "bias"], default_matching),
    torch.nn.modules.conv.Conv2d: (
        1, ["weight", "bias", "kernel_size", "stride", "padding", "dilation", "groups", "padding_mode"], default_matching
    ),
    torch.nn.modules.batchnorm.BatchNorm2d: (2, ["weight", "bias", "running_mean", "running_var", "eps"], default_matching),
    torch.nn.modules.pooling.AdaptiveAvgPool2d: (1, [], default_matching),
    torch.nn.modules.pooling.MaxPool2d: (
        1, ["kernel_size", "stride", "padding", "dilation", "return_indices", "ceil_mode"], default_matching
    ),
    torch.nn.modules.activation.ReLU: (1, ["inplace"], default_matching),
}


@compatibility(is_backward_compatible=False)
def extract_attrs_for_lowering(mod: nn.Module) -> Dict[str, Any]:
    """Fetch the attributes of ``mod`` that are listed in ``module_fetch_book``.

    The book is looked up by the exact type of ``mod``, so subclasses of a listed
    module class are not matched. Before fetching, the module's ``_version`` is
    checked against the version recorded in the book.

    Args:
        mod: The leaf module whose attributes should be fetched.

    Returns:
        A dict with the key ``"name"`` (``torch.typename(mod)``) plus one entry per
        attribute listed for ``type(mod)`` in ``module_fetch_book``.

    Raises:
        RuntimeError: If ``type(mod)`` is not in ``module_fetch_book``, or if
            ``mod._version`` is newer than the version recorded in the book.
    """
    mod_name = torch.typename(mod)

    # Single lookup by exact type (not isinstance), matching the book's keys.
    entry = module_fetch_book.get(type(mod))
    if entry is None:
        raise RuntimeError(f"{mod_name} is not in the module_fetch_book yet, "
                           "please add it to the module_fetch_book, open an issue and @842974287 "
                           "or report a bug to AIACC team directly.")

    version, param_to_fetch, matching_method = entry
    if version < mod._version:
        # The attribute list was written for an older module version; its names may be stale.
        raise RuntimeError(f"Fetcher version {version} try to fetch {mod_name} version {mod._version}, "
                           "please upgrade the module_fetch_book, open an issue and @842974287 "
                           "or report a bug to AIACC team directly.")

    attrs_for_lowering: Dict[str, Any] = {"name": mod_name}
    for attr in param_to_fetch:
        attrs_for_lowering[attr] = getattr(mod, matching_method(attr, mod._version))
    return attrs_for_lowering


@compatibility(is_backward_compatible=False)
def lift_lowering_attrs_to_nodes(fx_module: GraphModule) -> None:
    """Attach lowering attributes to the ``call_module`` nodes of ``fx_module``, in place.

    For each ``call_module`` node, if the called submodule is itself a
    ``GraphModule`` this function recurses into that submodule's graph. Otherwise
    the result of ``extract_attrs_for_lowering`` is stored on the node as
    ``node.attrs_for_lowering``.

    Args:
        fx_module: The graph module whose nodes are annotated.

    Raises:
        RuntimeError: Propagated from ``extract_attrs_for_lowering`` for leaf
            modules that are unsupported or whose version is newer than the book's.
    """
    # Qualified name -> module, so dotted call_module targets resolve directly.
    submodules = dict(fx_module.named_modules())

    for node in fx_module.graph.nodes:
        if node.op != "call_module":
            continue
        submodule = submodules[node.target]
        if isinstance(submodule, GraphModule):
            lift_lowering_attrs_to_nodes(submodule)
        else:
            node.attrs_for_lowering = extract_attrs_for_lowering(submodule)