# mypy: allow-untyped-defs
import copy
from collections import OrderedDict
from typing import Any, Dict

from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize
from torch.ao.quantization.observer import ObserverBase
from torch.ao.quantization.utils import Pattern


__all__ = [
    "get_default_fusion_patterns",
    "get_default_quant_patterns",
    "get_default_output_activation_post_process_map",
]

# TODO(future PR): fix the typing on QuantizeHandler (currently a circular dependency)
QuantizeHandler = Any

# pattern for conv bn fusion
_DEFAULT_FUSION_PATTERNS: Dict[Pattern, QuantizeHandler] = OrderedDict()


def _register_fusion_pattern(pattern):
    def insert(fn):
        _DEFAULT_FUSION_PATTERNS[pattern] = fn
        return fn

    return insert


def get_default_fusion_patterns() -> Dict[Pattern, QuantizeHandler]:
    return copy.copy(_DEFAULT_FUSION_PATTERNS)


_DEFAULT_QUANTIZATION_PATTERNS: Dict[Pattern, QuantizeHandler] = OrderedDict()

# Mapping from pattern to activation_post_process (observer/fake_quant) constructor
# for the output activation,
# e.g. pattern: torch.sigmoid,
#      output_activation_post_process: default_fixed_qparams_range_0to1_fake_quant
_DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP: Dict[Pattern, QuantizeHandler] = {}
_DEFAULT_OUTPUT_OBSERVER_MAP: Dict[Pattern, QuantizeHandler] = {}


# Register a pattern for both static quantization and QAT
def _register_quant_pattern(pattern, fixed_qparams_observer=None):
    def insert(fn):
        _DEFAULT_QUANTIZATION_PATTERNS[pattern] = fn
        if fixed_qparams_observer is not None:
            _DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP[
                pattern
            ] = FixedQParamsFakeQuantize.with_args(observer=fixed_qparams_observer)
            _DEFAULT_OUTPUT_OBSERVER_MAP[pattern] = fixed_qparams_observer
        return fn

    return insert


# Get patterns for both static quantization and QAT
def get_default_quant_patterns() -> Dict[Pattern, QuantizeHandler]:
    return copy.copy(_DEFAULT_QUANTIZATION_PATTERNS)


# A map from pattern to output activation post process constructor,
# e.g. torch.sigmoid -> default_fixed_qparams_range_0to1_fake_quant
def get_default_output_activation_post_process_map(
    is_training,
) -> Dict[Pattern, ObserverBase]:
    if is_training:
        return copy.copy(_DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP)
    else:
        return copy.copy(_DEFAULT_OUTPUT_OBSERVER_MAP)


# Example use of a register pattern function (the whole pattern, including nested
# tuples, is passed as a single argument):
# @_register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
# class ConvOrLinearBNReLUFusion():
#     def __init__(...):
#         ...
#
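# Illustrative example of _register_quant_pattern with a fixed-qparams observer
# (a sketch only; nothing is registered here, and the handler class name and the
# observer default_fixed_qparams_range_0to1_observer are assumed for illustration):
# @_register_quant_pattern(torch.sigmoid, default_fixed_qparams_range_0to1_observer)
# class SigmoidQuantizeHandler(QuantizeHandler):
#     ...
#
# Registering with a fixed_qparams_observer also records the output
# activation_post_process constructor for the pattern: a FixedQParamsFakeQuantize
# in _DEFAULT_OUTPUT_FAKE_QUANTIZE_MAP (returned when is_training=True) and the
# observer itself in _DEFAULT_OUTPUT_OBSERVER_MAP (returned when is_training=False).
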
def _sorted_patterns_dict(
    patterns_dict: Dict[Pattern, QuantizeHandler]
) -> Dict[Pattern, QuantizeHandler]:
    """
    Return a sorted version of the patterns dictionary such that longer patterns are matched first,
    e.g. match (F.relu, F.linear) before F.relu.
    This works for current use cases, but we may need a smarter way to sort
    things to handle more complex patterns.
    """

    def get_len(pattern):
        """Calculate the length of the pattern by counting all the entries in it.

        This ensures (nn.ReLU, (nn.BatchNorm, nn.Conv2d)) comes before
        (nn.BatchNorm, nn.Conv2d) so that the longer pattern is matched first.
        """
        pattern_len = 0
        if isinstance(pattern, tuple):
            for item in pattern:
                pattern_len += get_len(item)
        else:
            pattern_len += 1
        return pattern_len

    return OrderedDict(
        sorted(
            patterns_dict.items(),
            key=lambda kv: -get_len(kv[0]) if isinstance(kv[0], tuple) else 1,
        )
    )
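
# Illustrative sketch of the ordering produced by _sorted_patterns_dict (the
# handler names below are placeholders, not registrations made in this file):
#
#   patterns = OrderedDict([
#       (torch.nn.Conv2d, conv_handler),
#       ((torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)), conv_bn_relu_handler),
#       ((torch.nn.BatchNorm2d, torch.nn.Conv2d), conv_bn_handler),
#   ])
#
# _sorted_patterns_dict(patterns) yields the keys in the order
# (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d)), then (nn.BatchNorm2d, nn.Conv2d),
# then nn.Conv2d: tuple patterns sort by descending entry count (negative
# get_len), while non-tuple patterns all share the sort key 1 and keep their
# relative insertion order, so the longest fusion patterns are matched first.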