from typing import Callable, List

import torch
from torch.ao.nn.intrinsic import _FusedModule
from torch.fx._symbolic_trace import Tracer
from torch.fx.proxy import Scope


__all__ = [
    "QuantizationTracer",
]


class ScopeContextManager(torch.fx.proxy.ScopeContextManager):
    def __init__(
        self, scope: Scope, current_module: torch.nn.Module, current_module_path: str
    ):
        super().__init__(scope, Scope(current_module_path, type(current_module)))


class QuantizationTracer(Tracer):
    def __init__(
        self, skipped_module_names: List[str], skipped_module_classes: List[Callable]
    ):
        super().__init__()
        self.skipped_module_names = skipped_module_names
        self.skipped_module_classes = skipped_module_classes
        # NB: initialize the module_type of the top-level module to None.
        # We assume people won't configure the model with the type of the
        # top-level module here, since they can use "" for the global config.
        # We can change this if there is a use case that configures qconfig
        # using the top-level module type.
        self.scope = Scope("", None)
        self.record_stack_traces = True

    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        # Treat as leaves (i.e., do not trace into): built-in torch.nn /
        # torch.ao.nn modules except containers like Sequential, modules
        # explicitly skipped by name or class, and fused modules. Leaves
        # show up as single call_module nodes in the traced graph.
        return (
            (
                (
                    m.__module__.startswith("torch.nn")
                    or m.__module__.startswith("torch.ao.nn")
                )
                and not isinstance(m, torch.nn.Sequential)
            )
            or module_qualified_name in self.skipped_module_names
            or type(m) in self.skipped_module_classes
            or isinstance(m, _FusedModule)
        )
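

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library's public API.
    # Assumptions: `SmallModel` is a hypothetical module defined here purely
    # for illustration, and the tracer arguments mirror what an internal
    # caller might pass; this demo only shows the leaf-module behavior.
    import torch.fx

    class SmallModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x):
            return torch.relu(self.linear(x))

    model = SmallModel()
    # Skip nothing by name or class; torch.nn.Linear is still treated as a
    # leaf because its __module__ starts with "torch.nn".
    tracer = QuantizationTracer(skipped_module_names=[], skipped_module_classes=[])
    graph = tracer.trace(model)
    traced = torch.fx.GraphModule(tracer.root, graph)
    # `linear` appears as a single call_module node rather than being
    # traced through into its functional internals.
    print(traced.graph)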