xref: /aosp_15_r20/external/pytorch/test/fx/quantization.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerr"""
2*da0073e9SAndroid Build Coastguard Worker**This file is EXPERIMENTAL and is mostly used for testing purposes! Do not
3*da0073e9SAndroid Build Coastguard Workerrely on it for anything!**
4*da0073e9SAndroid Build Coastguard Worker"""
5*da0073e9SAndroid Build Coastguard Workerimport operator
6*da0073e9SAndroid Build Coastguard Workerimport sys
7*da0073e9SAndroid Build Coastguard Workerfrom typing import Optional
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Workerimport torch
10*da0073e9SAndroid Build Coastguard Workerfrom torch.fx import Graph, GraphModule, Node
11*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.graph import map_arg
12*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.proxy import Proxy
13*da0073e9SAndroid Build Coastguard Workerfrom torch.nn.utils import fuse_conv_bn_weights
14*da0073e9SAndroid Build Coastguard Worker
15*da0073e9SAndroid Build Coastguard Worker
# A quantization pattern can be a module type, a builtin function, or a string
# to match a node's target; tuples nest patterns as (node_pattern, *arg_patterns).
18*da0073e9SAndroid Build Coastguard Worker
19*da0073e9SAndroid Build Coastguard Worker
def _minmax_scale_zeropoint(
    min_val, max_val, qmin=-127, qmax=128, eps=torch.finfo(torch.float32).eps
):
    """Compute an affine (scale, zero_point) pair for a float range.

    The range is first widened to include 0.0. A degenerate range (both ends
    0.0 after widening) yields ``(1.0, 0)``.

    Args:
        min_val, max_val: observed float bounds.
        qmin, qmax: bounds of the integer target range.
        eps: lower bound on the returned scale.

    Returns:
        ``(scale, zero_point)``. ``zero_point`` is an int clamped to
        ``[qmin, qmax]``.
    """
    lo = min(0.0, min_val)
    hi = max(0.0, max_val)
    if hi == lo:
        return 1.0, 0

    scale = max((hi - lo) / float(qmax - qmin), eps)
    # Map ``lo`` onto ``qmin``, then clamp into the representable range.
    unclamped = qmin - round(lo / scale)
    zero_point = int(min(qmax, max(qmin, unclamped)))
    return scale, zero_point
35*da0073e9SAndroid Build Coastguard Worker
36*da0073e9SAndroid Build Coastguard Worker
class MinMaxObserver:
    """Track the running min/max of a node's value over calibration runs.

    ``all_tensors`` becomes False as soon as a non-tensor value is observed.
    Subclasses can check it to decline quantization.
    """

    def __init__(self, quantizer, node):
        # Start from an empty range; the first observed tensor sets both ends.
        self.min = float("inf")
        self.max = float("-inf")
        self.all_tensors = True

    def observe(self, node, env):
        value = env[node.name]
        if isinstance(value, torch.Tensor):
            self.max = max(self.max, float(value.max()))
            self.min = min(self.min, float(value.min()))
        else:
            self.all_tensors = False

    def scale_zeropoint(self):
        # [0, 255] is the unsigned 8-bit range (quint8 is what DefaultQuant emits).
        return _minmax_scale_zeropoint(self.min, self.max, qmin=0, qmax=255)
52*da0073e9SAndroid Build Coastguard Worker
53*da0073e9SAndroid Build Coastguard Worker
class NoObserver:
    """Observer base for handlers that need no calibration statistics.

    Both the constructor and ``observe`` deliberately do nothing.
    """

    def __init__(self, quantizer, node):
        return None

    def observe(self, node, env):
        return None
60*da0073e9SAndroid Build Coastguard Worker
61*da0073e9SAndroid Build Coastguard Worker
# Registry of default patterns. It maps a pattern (a module type, a builtin
# function, a target string, or a nested tuple of these) to the handler class
# that is instantiated for each match.
_DEFAULT_QUANTIZATION_PATTERNS = {}


def register_pattern(pattern):
    """Return a decorator that registers its target under ``pattern``.

    The decorated object is stored in ``_DEFAULT_QUANTIZATION_PATTERNS``,
    replacing any previous entry for ``pattern``. It is then returned
    unchanged, so the decorator can be stacked to register one handler for
    several patterns.
    """

    def _register(handler):
        _DEFAULT_QUANTIZATION_PATTERNS[pattern] = handler
        return handler

    return _register
71*da0073e9SAndroid Build Coastguard Worker
72*da0073e9SAndroid Build Coastguard Worker
@register_pattern(operator.add)
class Add(MinMaxObserver):
    """Lower ``operator.add`` to ``torch.ops.quantized.add``.

    The observed output range supplies the output scale/zero_point. If any
    observed output was not a tensor, ``quantize`` returns ``NotImplemented``
    so the caller keeps the float op instead.
    """

    def quantize(self, quantizer, node, load_arg):
        if self.all_tensors:
            scale, zeropoint = self.scale_zeropoint()
            quantized_args = load_arg(node.args)
            output_qparams = {"scale": scale, "zero_point": zeropoint}
            return quantizer.quantized_graph.create_node(
                "call_function",
                torch.ops.quantized.add,
                quantized_args,
                output_qparams,
            )
        return NotImplemented
85*da0073e9SAndroid Build Coastguard Worker
86*da0073e9SAndroid Build Coastguard Worker
class Relu(NoObserver):
    """Emit ``torch.relu`` on the quantized input (not registered for any pattern here).

    ``load_arg`` returns a ``torch.fx.Node``. A bare Node does not participate
    in ``__torch_function__`` dispatch, so it is wrapped in a ``Proxy``; that
    records a ``torch.relu`` call in the Node's graph. The new Node is
    returned, matching what the other handlers return.
    """

    def quantize(self, quantizer, node, load_arg):
        quantized_input = Proxy(load_arg(node.args[0]))
        # NOTE(review): assumes torch.relu accepts quantized tensors directly — confirm.
        return torch.relu(quantized_input).node
92*da0073e9SAndroid Build Coastguard Worker
93*da0073e9SAndroid Build Coastguard Worker
# These ops accept quantized inputs directly and need no extra information,
# so the node is copied as-is with its inputs loaded in quantized form.
@register_pattern(torch.nn.ReLU)
@register_pattern(torch.nn.AvgPool2d)
@register_pattern(torch.nn.MaxPool2d)
@register_pattern(torch.nn.AdaptiveAvgPool2d)
class CopyNode(NoObserver):
    """Copy the matched node into the quantized graph, feeding it quantized inputs."""

    def quantize(self, quantizer, node, load_arg):
        target_graph = quantizer.quantized_graph
        return target_graph.node_copy(node, load_arg)
102*da0073e9SAndroid Build Coastguard Worker
103*da0073e9SAndroid Build Coastguard Worker
class IdentityModule(torch.nn.Module):
    """Module whose ``forward`` returns its input object unchanged.

    Useful as a placeholder for a submodule that has been folded away but may
    still be referenced.
    """

    def forward(self, x):
        passthrough = x
        return passthrough
107*da0073e9SAndroid Build Coastguard Worker
108*da0073e9SAndroid Build Coastguard Worker
# Handle conv, optionally followed by batchnorm, optionally followed by relu.
@register_pattern(torch.nn.modules.conv.Conv2d)
@register_pattern((torch.nn.ReLU, torch.nn.modules.conv.Conv2d))
@register_pattern(
    (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.conv.Conv2d)
)
@register_pattern(
    (
        torch.nn.ReLU,
        (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.conv.Conv2d),
    )
)
class ConvNormRelu(MinMaxObserver):
    """Quantize a Conv2d, folding in a following BatchNorm2d and/or ReLU.

    The handler is constructed on the root of the match (the last node of the
    conv[-bn][-relu] chain). The statistics observed there become the output
    scale/zero_point of the quantized conv.
    """

    def __init__(self, quantizer, node):
        super().__init__(quantizer, node)
        # Walk backwards from the match root: relu -> bn -> conv.
        self.relu_node, self.bn_node = None, None
        if isinstance(quantizer.modules[node.target], torch.nn.ReLU):
            self.relu_node = node
            node = node.args[0]
        if isinstance(quantizer.modules[node.target], torch.nn.BatchNorm2d):
            self.bn_node = node
            self.bn = quantizer.modules[self.bn_node.target]
            node = node.args[0]
        assert isinstance(quantizer.modules[node.target], torch.nn.modules.Conv2d)
        self.conv_node = node
        self.conv = quantizer.modules[self.conv_node.target]

    def quantize(self, quantizer, node, load_arg):
        """Install a quantized conv module and emit a call to it.

        Side effects on ``quantizer.root``'s module tree:
          * the float conv submodule is replaced by the quantized conv
            (ConvReLU2d when a ReLU was fused);
          * a fused BatchNorm2d submodule is replaced by ``IdentityModule``.

        Returns:
            A ``call_module`` node in ``quantizer.quantized_graph`` that applies
            the quantized conv to the conv's (quantized) input.
        """
        mod = self.conv
        weight, bias = mod.weight, mod.bias

        if self.bn_node is not None:
            # Fold the batchnorm's affine transform into the conv weight/bias.
            weight, bias = fuse_conv_bn_weights(
                weight,
                bias,
                self.bn.running_mean,
                self.bn.running_var,
                self.bn.eps,
                self.bn.weight,
                self.bn.bias,
            )

        min_val, max_val = float(weight.min()), float(weight.max())

        act_scale, act_zp = self.scale_zeropoint()

        # Weights are quantized to qint8, whose zero point must lie in
        # [-128, 127]. The helper's default range (-127..128) can produce a
        # zero point of 128 (e.g. all weights <= 0), which qint8 rejects.
        weight_scale, weight_zp = _minmax_scale_zeropoint(
            min_val, max_val, qmin=-128, qmax=127
        )
        qweight = torch.quantize_per_tensor(
            weight, weight_scale, weight_zp, torch.qint8
        )

        ctor = (
            torch.ao.nn.intrinsic.quantized.ConvReLU2d
            if self.relu_node is not None
            else torch.ao.nn.quantized.Conv2d
        )

        qconv = ctor(
            mod.in_channels,
            mod.out_channels,
            mod.kernel_size,
            mod.stride,
            mod.padding,
            mod.dilation,
            mod.groups,
            # After BN folding there can be a bias even if the conv had none;
            # describe the bias that is actually installed below.
            bias is not None,
            mod.padding_mode,
        )

        qconv.set_weight_bias(qweight, bias)
        qconv.scale = float(act_scale)
        qconv.zero_point = int(act_zp)
        parent_name, name = _parent_name(self.conv_node.target)
        setattr(quantizer.modules[parent_name], name, qconv)
        if self.bn_node is not None:
            parent_bn, bn_name = _parent_name(self.bn_node.target)
            # We can't just delete the bn: forwards of submodules (which are no
            # longer used) may still try to call it, so replace it with
            # something that does nothing. It must be set on the bn's own
            # parent, which need not be the conv's parent.
            setattr(quantizer.modules[parent_bn], bn_name, IdentityModule())

        return quantizer.quantized_graph.create_node(
            "call_module",
            self.conv_node.target,
            (load_arg(self.conv_node.args[0]),),
            {},
        )
195*da0073e9SAndroid Build Coastguard Worker
196*da0073e9SAndroid Build Coastguard Worker
# Split a dotted module path into (parent_path, leaf_name):
# "foo.bar" -> ("foo", "bar"), "foo" -> ("", "foo").
def _parent_name(target):
    parent, _separator, leaf = target.rpartition(".")
    return parent, leaf
204*da0073e9SAndroid Build Coastguard Worker
205*da0073e9SAndroid Build Coastguard Worker
class DefaultQuant(MinMaxObserver):
    """Quantize a not-yet-quantized value using its observed min/max range.

    ``quantize`` records a ``torch.quantize_per_tensor(..., torch.quint8)``
    call through a ``Proxy`` wrapping ``input`` and returns the resulting node.
    """

    def quantize(self, input):
        assert self.all_tensors
        scale, zeropoint = self.scale_zeropoint()
        wrapped = Proxy(input)
        quantized = torch.quantize_per_tensor(wrapped, scale, zeropoint, torch.quint8)
        return quantized.node
213*da0073e9SAndroid Build Coastguard Worker
214*da0073e9SAndroid Build Coastguard Worker
def matches(modules, node, pattern, max_uses=sys.maxsize):
    """Return True if ``node`` matches ``pattern``.

    Pattern forms:
      * an ``nn.Module`` subclass: ``node`` must be a ``call_module`` of an
        instance of that class;
      * another callable: ``node`` must be a ``call_function`` of exactly it;
      * anything else: compared against ``node.target`` with ``==``;
      * a tuple ``(self_pattern, *arg_patterns)``: ``node`` matches
        ``self_pattern`` and has exactly ``len(arg_patterns)`` positional args,
        each matching its sub-pattern.

    Args:
        modules: mapping from qualified name to submodule (as from
            ``named_modules()``).
        node: the candidate node. Non-``Node`` values (e.g. literal arguments
            reached through a nested pattern) never match.
        pattern: the pattern to test.
        max_uses: the maximum number of users ``node`` may have. Nested
            arguments are limited to a single use so a match can be fused
            without other users observing the intermediate.
    """
    # Positional args may be literals (e.g. the ``1`` in ``x + 1``), which
    # have no ``users``/``op``/``target`` and can never match a pattern.
    if not isinstance(node, Node):
        return False

    if isinstance(pattern, tuple):
        self_match, *arg_matches = pattern
    else:
        self_match = pattern
        arg_matches = None

    if len(node.users) > max_uses:
        return False

    if isinstance(self_match, type) and issubclass(self_match, torch.nn.Module):
        if node.op != "call_module":
            return False
        if not isinstance(modules[node.target], self_match):
            return False
    elif callable(self_match):
        if node.op != "call_function" or node.target is not self_match:
            return False
    elif node.target != self_match:
        return False

    if not arg_matches:
        return True

    if len(arg_matches) != len(node.args):
        return False

    return all(
        matches(modules, arg, arg_match, max_uses=1)
        for arg, arg_match in zip(node.args, arg_matches)
    )
246*da0073e9SAndroid Build Coastguard Worker
247*da0073e9SAndroid Build Coastguard Worker
class Quantizer:
    """Experimental post-training quantizer for a traced ``GraphModule``.

    Workflow:
      1. ``__init__`` matches quantization patterns against the graph. It
         creates a handler object for each match and a ``quant_ctor`` object
         for each input of a matched node.
      2. ``observe(args)`` interprets the float graph on calibration inputs so
         those objects can record statistics.
      3. ``quantize()`` builds a new graph in which each matched pattern is
         replaced by its handler's lowering. Quantize/dequantize nodes are
         inserted wherever a value crosses between float and quantized form.
    """

    def __init__(
        self, mod, patterns=_DEFAULT_QUANTIZATION_PATTERNS, quant_ctor=DefaultQuant
    ):
        self.root = mod
        self.graph = mod.graph
        self.quant_ctor = quant_ctor

        # cached information for observe
        self.state_dict = self.root.state_dict()
        self.modules = dict(self.root.named_modules())

        # match the patterns that will get quantized:
        # node name -> (root node of the match, handler object)
        self.matches = self._find_matches(patterns)
        # find _inputs_ to matched nodes that are not quantized; these have to
        # be quantized, which requires measuring stats, so initialize a
        # quant_ctor object for each (node name -> quant_ctor instance)
        self.quants = self._find_quants(quant_ctor)

    def observe(self, args):
        """Run the float graph on ``args``, feeding every observer along the way.

        Args:
            args: positional inputs, consumed in order by the graph's
                placeholder nodes.

        Returns:
            The value produced by the graph's output node.

        Raises:
            RuntimeError: if the graph has no output node or contains a node
                with an unrecognized op.
        """
        # Most of this function is just an interpreter for the graph. It could
        # live behind some abstraction, but it is nice to see exactly what is
        # happening here and hack on it.
        args_iter = iter(args)
        env = {}

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        for node in self.graph.nodes:
            if node.op == "placeholder":
                result = next(args_iter)
            elif node.op == "get_attr":
                result = self.state_dict[node.target]
            elif node.op == "call_function":
                result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
            elif node.op == "call_method":
                # Distinct name: ``args`` is this method's parameter.
                self_obj, *method_args = load_arg(node.args)
                kwargs = load_arg(node.kwargs)
                result = getattr(self_obj, node.target)(*method_args, **kwargs)
            elif node.op == "call_module":
                result = self.modules[node.target](
                    *load_arg(node.args), **load_arg(node.kwargs)
                )
            elif node.op == "output":
                return load_arg(node.args[0])
            else:
                # Without this, ``result`` would silently carry over from the
                # previous node.
                raise RuntimeError(
                    f"Unsupported node op {node.op!r} for node {node.name}"
                )

            env[node.name] = result
            # A match's handler observes only at its root (the node the
            # pattern was anchored at).
            root_node, obj = self.matches.get(node.name, (None, None))
            if root_node is node:
                obj.observe(node, env)
            if node.name in self.quants:
                self.quants[node.name].observe(node, env)

        raise RuntimeError("Graph had no output node!")

    def quantize(self):
        """Build and return the quantized ``GraphModule`` (call after ``observe``).

        Each graph value is tracked in float form (``env``), quantized form
        (``quant_env``), or both. ``load_arg`` converts on first demand: it
        emits a ``dequantize`` for a quantized value needed as float, or the
        value's ``quant_ctor`` quantization for a float value needed as
        quantized. Handlers may mutate ``self.root``'s submodules, and the
        returned module wraps ``self.root``.
        """
        self.quantized_graph = Graph()

        env = {}
        quant_env = {}

        def load_arg(n, quantized):
            if not quantized:
                if n.name not in env and n.name in quant_env:
                    env[n.name] = Proxy(quant_env[n.name]).dequantize().node
                return env[n.name]
            else:
                if n.name not in quant_env and n.name in env:
                    quant_env[n.name] = self.quants[n.name].quantize(env[n.name])
                return quant_env[n.name]

        def copy_recursive(node):
            # Copy ``node`` in float form. Any input that has not been emitted
            # yet is copied first. Those are the non-root members of a match
            # whose handler declined, which the main loop below skips.
            def load_or_emit(n):
                if n.name in env or n.name in quant_env:
                    return load_arg(n, quantized=False)
                return copy_recursive(n)

            r = env[node.name] = self.quantized_graph.node_copy(node, load_or_emit)
            return r

        for node in self.graph.nodes:
            root_node, obj = self.matches.get(node.name, (None, None))
            if root_node is None:
                # not quantized, just copy it
                env[node.name] = self.quantized_graph.node_copy(
                    node, lambda n: load_arg(n, quantized=False)
                )

            elif root_node is node:
                r = obj.quantize(
                    self,
                    node,
                    lambda a: map_arg(a, lambda n: load_arg(n, quantized=True)),
                )
                if r is NotImplemented:
                    # The handler chose not to quantize: take the entire match
                    # and just copy it over in float form.
                    env[node.name] = copy_recursive(node)
                else:
                    quant_env[node.name] = r
            # Non-root members of a match are emitted by their root's handler
            # (or by copy_recursive when that handler declines).

        return GraphModule(self.root, self.quantized_graph)

    def _find_matches(self, patterns):
        """Map node names to ``(root_node, handler)`` for every pattern match.

        Nodes are visited from the end of the graph backwards, so a pattern is
        anchored at the last node of a chain. For a node not yet claimed, every
        pattern in ``patterns`` is tried in iteration order. A later match
        overwrites the entries an earlier one wrote for the same nodes.
        """
        # Same mapping as ``self.modules``, which __init__ sets just before
        # calling this.
        modules = self.modules
        match_map = {}  # node name -> (root_node, handler object)

        def apply_match(pattern, node, match):
            if isinstance(pattern, tuple):
                s, *args = pattern
                apply_match(s, node, match)
                for subpattern, arg in zip(args, node.args):
                    apply_match(subpattern, arg, match)
            else:
                match_map[node.name] = match

        for node in reversed(self.graph.nodes):
            if node.name not in match_map:
                for pattern, value in patterns.items():
                    if matches(modules, node, pattern):
                        apply_match(pattern, node, (node, value(self, node)))

        return match_map

    def _find_quants(self, quant_ctor):
        """Create a ``quant_ctor`` object for every input of every matched node."""
        quants = {}

        def visit_arg(n):
            # Note: we have to measure quantization information even for nodes
            # where we might not use it because they are already quantized.
            # This is because each match may return NotImplemented (for
            # instance, an __add__ whose data type is not appropriate).
            if n.name not in quants:
                quants[n.name] = quant_ctor(self, n)

        for node in self.graph.nodes:
            if node.name in self.matches:
                map_arg(node.args, visit_arg)
                map_arg(node.kwargs, visit_arg)
        return quants
394*da0073e9SAndroid Build Coastguard Worker        return quants
395