r"""
**This file is EXPERIMENTAL and is mostly used for testing purposes! Do not
rely on it for anything!**
"""
import operator
import sys
from typing import Optional

import torch
from torch.fx import Graph, GraphModule, Node
from torch.fx.graph import map_arg
from torch.fx.proxy import Proxy
from torch.nn.utils import fuse_conv_bn_weights


# A pattern can be a module type, a builtin function, or a string to match
# against a node's target. A tuple pattern is ``(pattern, *arg_patterns)``:
# the first element matches the node itself and the remaining elements match
# the node's positional arguments (see ``matches``).


def _minmax_scale_zeropoint(
    min_val, max_val, qmin=-127, qmax=128, eps=torch.finfo(torch.float32).eps
):
    """Derive an affine ``(scale, zero_point)`` pair from an observed range.

    The range is first widened so that it always contains ``0.0``.

    Args:
        min_val: smallest observed value.
        max_val: largest observed value.
        qmin: smallest representable quantized integer.
        qmax: largest representable quantized integer.
        eps: lower bound applied to the computed scale.

    Returns:
        ``(scale, zero_point)``; ``(1.0, 0)`` when the widened range is
        degenerate (``min_val == max_val``). ``zero_point`` is an ``int``
        clamped to ``[qmin, qmax]``.

    NOTE(review): the default ``[qmin, qmax]`` of ``[-127, 128]`` is kept for
    backward compatibility but is not the ``torch.qint8`` range; callers
    targeting a specific dtype should pass its range explicitly.
    """
    min_val = min(0.0, min_val)
    max_val = max(0.0, max_val)
    if max_val == min_val:
        return 1.0, 0

    scale = max((max_val - min_val) / float(qmax - qmin), eps)
    zero_point = qmin - round(min_val / scale)
    # clamp into the representable integer range
    zero_point = min(qmax, max(qmin, zero_point))
    return scale, int(zero_point)


class MinMaxObserver:
    """Tracks the running min/max of the tensor values produced by a node."""

    def __init__(self, quantizer, node):
        # start with an "empty" range so the first observation always wins
        self.min, self.max = float("inf"), float("-inf")
        # flips to False as soon as a non-tensor value is observed
        self.all_tensors = True

    def observe(self, node, env):
        """Fold the value stored in ``env[node.name]`` into the running range."""
        value = env[node.name]
        if not isinstance(value, torch.Tensor):
            self.all_tensors = False
            return
        if value.numel() == 0:
            # an empty tensor has no min/max (the reduction would raise) and
            # carries no range information, so there is nothing to record
            return
        self.max = max(self.max, float(value.max()))
        self.min = min(self.min, float(value.min()))

    def scale_zeropoint(self):
        """Return ``(scale, zero_point)`` for the observed range over [0, 255]."""
        return _minmax_scale_zeropoint(self.min, self.max, qmin=0, qmax=255)


class NoObserver:
    """Observer for handlers that need no statistics; every hook is a no-op."""

    def __init__(self, quantizer, node):
        pass

    def observe(self, node, env):
        pass


# registry: pattern -> handler class used to observe and quantize the match
_DEFAULT_QUANTIZATION_PATTERNS = {}


def register_pattern(pattern):
    """Class decorator that registers the decorated handler for ``pattern``.

    The handler is stored in ``_DEFAULT_QUANTIZATION_PATTERNS`` and returned
    unchanged, so the decorator can be stacked to register several patterns.
    """

    def insert(fn):
        _DEFAULT_QUANTIZATION_PATTERNS[pattern] = fn
        return fn

    return insert


@register_pattern(operator.add)
class Add(MinMaxObserver):
    """Handler that rewrites ``operator.add`` into ``torch.ops.quantized.add``."""

    def quantize(self, quantizer, node, load_arg):
        """Emit the quantized add, or ``NotImplemented`` if a non-tensor was seen."""
        if not self.all_tensors:
            return NotImplemented
        scale, zeropoint = self.scale_zeropoint()
        return quantizer.quantized_graph.create_node(
            "call_function",
            torch.ops.quantized.add,
            load_arg(node.args),
            {"scale": scale, "zero_point": zeropoint},
        )


class Relu(NoObserver):
    """Handler for relu.

    NOTE(review): this class is not registered for any pattern in this file.
    """

    def quantize(self, quantizer, node, load_arg):
        return torch.relu(
            load_arg(node.args[0])
        )  # torch.relu works directly on quantized tensors?


# these ops have quantized equivalents that do not need any extra information
@register_pattern(torch.nn.ReLU)
@register_pattern(torch.nn.AvgPool2d)
@register_pattern(torch.nn.MaxPool2d)
@register_pattern(torch.nn.AdaptiveAvgPool2d)
class CopyNode(NoObserver):
    """Handler that copies the node as-is, feeding it quantized inputs."""

    def quantize(self, quantizer, node, load_arg):
        return quantizer.quantized_graph.node_copy(node, load_arg)


class IdentityModule(torch.nn.Module):
    """Module that returns its input unchanged."""

    def forward(self, x):
        return x


# handle conv, maybe followed by bn, maybe followed by relu
@register_pattern(torch.nn.modules.conv.Conv2d)
@register_pattern((torch.nn.ReLU, torch.nn.modules.conv.Conv2d))
@register_pattern(
    (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.conv.Conv2d)
)
@register_pattern(
    (
        torch.nn.ReLU,
        (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.conv.Conv2d),
    )
)
class ConvNormRelu(MinMaxObserver):
    """Handler that fuses Conv2d [+ BatchNorm2d] [+ ReLU] into one quantized conv."""

    def __init__(self, quantizer, node):
        """Walk backwards from the match root to locate the relu/bn/conv nodes."""
        super().__init__(quantizer, node)
        self.relu_node: Optional[Node] = None
        self.bn_node: Optional[Node] = None
        if isinstance(quantizer.modules[node.target], torch.nn.ReLU):
            self.relu_node = node
            node = node.args[0]
        if isinstance(quantizer.modules[node.target], torch.nn.BatchNorm2d):
            self.bn_node = node
            self.bn = quantizer.modules[self.bn_node.target]
            node = node.args[0]
        assert isinstance(quantizer.modules[node.target], torch.nn.modules.Conv2d)
        self.conv_node = node
        self.conv = quantizer.modules[self.conv_node.target]

    def quantize(self, quantizer, node, load_arg):
        """Replace the conv module with its quantized version and emit the call.

        The float conv module is swapped in place (on its parent module) for a
        quantized conv carrying the (optionally BN-fused) quantized weights;
        a fused BatchNorm module is replaced by an ``IdentityModule``.

        Returns:
            The ``call_module`` node created in ``quantizer.quantized_graph``.
        """
        mod = self.conv
        weight, bias = mod.weight, mod.bias

        if self.bn_node is not None:
            weight, bias = fuse_conv_bn_weights(
                weight,
                bias,
                self.bn.running_mean,
                self.bn.running_var,
                self.bn.eps,
                self.bn.weight,
                self.bn.bias,
            )

        min_val, max_val = float(weight.min()), float(weight.max())

        act_scale, act_zp = self.scale_zeropoint()

        # The weights are stored as qint8, whose integer range is [-128, 127];
        # pass it explicitly so the zero point can never fall outside it.
        weight_scale, weight_zp = _minmax_scale_zeropoint(
            min_val, max_val, qmin=-128, qmax=127
        )
        qweight = torch.quantize_per_tensor(
            weight, weight_scale, weight_zp, torch.qint8
        )

        ctor = (
            torch.ao.nn.intrinsic.quantized.ConvReLU2d
            if self.relu_node is not None
            else torch.ao.nn.quantized.Conv2d
        )

        qconv = ctor(
            mod.in_channels,
            mod.out_channels,
            mod.kernel_size,
            mod.stride,
            mod.padding,
            mod.dilation,
            mod.groups,
            mod.bias is not None,
            mod.padding_mode,
        )

        qconv.set_weight_bias(qweight, bias)
        qconv.scale = float(act_scale)
        qconv.zero_point = int(act_zp)
        parent_name, name = _parent_name(self.conv_node.target)
        setattr(quantizer.modules[parent_name], name, qconv)
        if self.bn_node is not None:
            parent_bn, bn_name = _parent_name(self.bn_node.target)
            # we can't just delete this because submodules' forwards (which
            # are no longer used) try to call it, so replace it with something
            # that does nothing. The replacement goes on the BN's own parent,
            # which is not necessarily the conv's parent.
            setattr(quantizer.modules[parent_bn], bn_name, IdentityModule())

        return quantizer.quantized_graph.create_node(
            "call_module",
            self.conv_node.target,
            (load_arg(self.conv_node.args[0]),),
            {},
        )


# turn foo.bar -> ('foo', 'bar')
def _parent_name(target):
    """Split a dotted module path into ``(parent_path, attribute_name)``.

    A path without a dot has the empty string as its parent.
    """
    head, sep, tail = target.rpartition(".")
    if not sep:
        return "", tail
    return head, tail


class DefaultQuant(MinMaxObserver):
    """Observer used for float inputs that must be quantized before a match."""

    def quantize(self, input):
        """Emit a ``quantize_per_tensor`` (quint8) of ``input`` and return its node."""
        assert self.all_tensors
        scale, zeropoint = self.scale_zeropoint()
        return torch.quantize_per_tensor(
            Proxy(input), scale, zeropoint, torch.quint8
        ).node


def matches(modules, node, pattern, max_uses=sys.maxsize):
    """Return True if ``node`` (and, for tuple patterns, its args) match ``pattern``.

    Args:
        modules: mapping from module path to module, used for ``call_module``.
        node: the graph node to test.
        pattern: a module type, a callable, any other value compared against
            ``node.target``, or a tuple ``(pattern, *arg_patterns)``.
        max_uses: the node fails to match if it has more users than this.
    """
    if isinstance(pattern, tuple):
        self_match, *arg_matches = pattern
    else:
        self_match = pattern
        arg_matches = None

    if len(node.users) > max_uses:
        return False

    if isinstance(self_match, type) and issubclass(self_match, torch.nn.Module):
        if node.op != "call_module":
            return False
        if not isinstance(modules[node.target], self_match):
            return False
    elif callable(self_match):
        if node.op != "call_function" or node.target is not self_match:
            return False
    elif node.target != self_match:
        return False

    if not arg_matches:
        return True

    if len(arg_matches) != len(node.args):
        return False

    # interior nodes of a match may only be consumed by the match itself
    return all(
        matches(modules, arg, arg_match, max_uses=1)
        for arg, arg_match in zip(node.args, arg_matches)
    )


class Quantizer:
    """Observe a traced module on sample inputs, then emit a quantized graph."""

    def __init__(
        self, mod, patterns=_DEFAULT_QUANTIZATION_PATTERNS, quant_ctor=DefaultQuant
    ):
        """Find the pattern matches in ``mod.graph`` and set up their observers.

        Args:
            mod: module exposing ``graph``, ``state_dict()`` and
                ``named_modules()``.
            patterns: mapping from pattern to handler class.
            quant_ctor: observer class created for every input of a matched node.
        """
        self.root = mod
        self.graph = mod.graph
        self.quant_ctor = quant_ctor

        # cached information for observe
        self.state_dict = self.root.state_dict()
        self.modules = dict(self.root.named_modules())

        # match the patterns that will get quantized
        self.matches = self._find_matches(patterns)
        # find _inputs_ to matched nodes that are not quantized, these
        # have to be quantized, which requires measuring stats,
        # initialize a quant_ctor object for each
        self.quants = self._find_quants(quant_ctor)

    def observe(self, args):
        """Interpret the graph on ``args``, feeding every observer its node value.

        Returns:
            The value of the graph's output node.

        Raises:
            RuntimeError: if the graph contains no output node.
        """
        # most of this function is just an interpreter for the graph
        # it would be possible to put this in some abstraction, but
        # it is pretty nice to just be able to see exactly what is happening here
        # and hack on it.
        # maybe we should just provide an example interpreter that people copy/paste
        # then edit.
        args_iter = iter(args)
        env = {}

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        for node in self.graph.nodes:
            if node.op == "placeholder":
                result = next(args_iter)
            elif node.op == "get_attr":
                result = self.state_dict[node.target]
            elif node.op == "call_function":
                result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
            elif node.op == "call_method":
                self_obj, *method_args = load_arg(node.args)
                method_kwargs = load_arg(node.kwargs)
                result = getattr(self_obj, node.target)(*method_args, **method_kwargs)
            elif node.op == "call_module":
                result = self.modules[node.target](
                    *load_arg(node.args), **load_arg(node.kwargs)
                )
            elif node.op == "output":
                return load_arg(node.args[0])

            env[node.name] = result
            # only the root of a match observes; interior nodes are skipped
            root_node, obj = self.matches.get(node.name, (None, None))
            if root_node is node:
                obj.observe(node, env)
            if node.name in self.quants:
                self.quants[node.name].observe(node, env)

        raise RuntimeError("Graph had no output node!")

    def quantize(self):
        """Build the quantized graph and return it wrapped in a ``GraphModule``."""
        self.quantized_graph = Graph()

        env = {}  # node name -> float-valued node in the new graph
        quant_env = {}  # node name -> quantized-valued node in the new graph

        def load_arg(n, quantized):
            # convert lazily between the two environments, caching the result
            if not quantized:
                if n.name not in env and n.name in quant_env:
                    env[n.name] = Proxy(quant_env[n.name]).dequantize().node
                return env[n.name]
            else:
                if n.name not in quant_env and n.name in env:
                    quant_env[n.name] = self.quants[n.name].quantize(env[n.name])
                return quant_env[n.name]

        def copy_recursive(node):
            def load_or_emit(n):
                if n.name in env or n.name in quant_env:
                    return load_arg(n, quantized=False)
                # interior node of the match that was never emitted: copy it now
                return copy_recursive(n)

            r = env[node.name] = self.quantized_graph.node_copy(node, load_or_emit)
            return r

        for node in self.graph.nodes:
            root_node, obj = self.matches.get(node.name, (None, None))
            if root_node is None:
                # not quantized just copy it
                env[node.name] = self.quantized_graph.node_copy(
                    node, lambda n: load_arg(n, quantized=False)
                )

            elif root_node is node:
                r = obj.quantize(
                    self,
                    node,
                    lambda a: map_arg(a, lambda n: load_arg(n, quantized=True)),
                )
                if r is NotImplemented:
                    # the handler chose not to quantize the node: take the
                    # entire match and just copy it over
                    env[node.name] = copy_recursive(node)
                else:
                    quant_env[node.name] = r

        return GraphModule(self.root, self.quantized_graph)

    def _find_matches(self, patterns):
        """Map every node name covered by a match to ``(root_node, handler)``."""
        modules = dict(self.root.named_modules())
        match_map = {}  # node name -> (root_node, match_value?)

        def apply_match(pattern, node, match):
            if isinstance(pattern, tuple):
                s, *args = pattern
                apply_match(s, node, match)
                for subpattern, arg in zip(args, node.args):
                    apply_match(subpattern, arg, match)
            else:
                match_map[node.name] = match

        # visit outputs first so a node already claimed as the interior of a
        # later match is not matched again on its own
        for node in reversed(self.graph.nodes):
            if node.name not in match_map:
                # no early exit: when several patterns match, the one
                # registered last overwrites the earlier entries
                for pattern, value in patterns.items():
                    if matches(modules, node, pattern):
                        apply_match(pattern, node, (node, value(self, node)))

        return match_map

    def _find_quants(self, quant_ctor):
        """Create one ``quant_ctor`` observer per input of every matched node."""
        quants = {}

        def visit_arg(n):
            # note: we have to measure quantization information
            # even for nodes where we might not use it because it is already
            # quantized. This is because each match has the option to
            # say NotImplemented (if for instance, it is an __add__ and the data type is not appropriate)
            if n.name not in quants:
                quants[n.name] = quant_ctor(self, n)

        for node in self.graph.nodes:
            if node.name in self.matches:
                map_arg(node.args, visit_arg)
                map_arg(node.kwargs, visit_arg)
        return quants