# xref: /aosp_15_r20/external/pytorch/torch/ao/ns/fx/weight_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
from typing import Callable, Dict, List, Optional

import torch
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.intrinsic.qat as nniqat
import torch.ao.nn.intrinsic.quantized as nniq
import torch.ao.nn.qat as nnqat
import torch.ao.nn.quantized as nnq
import torch.ao.nn.quantized.dynamic as nnqd
import torch.nn as nn
import torch.nn.functional as F
from torch.fx import GraphModule
from torch.fx.graph import Node

from .ns_types import NSSingleResultType, NSSingleResultValuesType
from .utils import get_target_type_str, getattr_from_fqn, return_first_non_observer_node


toq = torch.ops.quantized


def mod_weight_detach(mod: nn.Module) -> torch.Tensor:
    # module stores its weight directly (fp32 and QAT modules)
    return mod.weight.detach()  # type: ignore[operator]


def mod_0_weight_detach(mod: nn.Module) -> torch.Tensor:
    # fused module: the weight lives on the first submodule
    return mod[0].weight.detach()  # type: ignore[index]


def mod_weight_bias_0(mod: nn.Module) -> torch.Tensor:
    # quantized module: _weight_bias() returns (weight, bias)
    return mod._weight_bias()[0]  # type: ignore[operator]


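# Example (illustrative sketch, not part of the original module): the three
# helpers above cover the three common weight layouts. A float nn.Conv2d
# stores its weight directly, a fused nni.ConvReLU2d keeps it on the first
# submodule, and a quantized nnq.Conv2d exposes it via _weight_bias().
#
#     conv_fp32 = nn.Conv2d(3, 8, 3)
#     w1 = mod_weight_detach(conv_fp32)      # fp32 weight tensor
#
#     conv_relu = nni.ConvReLU2d(nn.Conv2d(3, 8, 3), nn.ReLU())
#     w2 = mod_0_weight_detach(conv_relu)    # weight of the inner conv
#
#     conv_q = nnq.Conv2d(3, 8, 3)
#     w3 = mod_weight_bias_0(conv_q)         # quantized weight tensor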
def get_lstm_weight(mod: nn.Module) -> List[torch.Tensor]:
    res = []
    for idx, param_name in enumerate(mod._flat_weights_names):  # type: ignore[arg-type]
        if "weight_ih_l" in param_name or "weight_hh_l" in param_name:
            param_value = mod._flat_weights[idx].detach()  # type: ignore[index]
            res.append(param_value)
    return res


def get_qlstm_weight(mod: nn.Module) -> List[torch.Tensor]:
    res = []
    for weight_value in mod._all_weight_values:  # type: ignore[union-attr]
        # dig the packed ih and hh weight tensors out of the serialized
        # cell params state
        res.append(weight_value.param.__getstate__()[0][4][0].__getstate__()[0][0])
        res.append(weight_value.param.__getstate__()[0][4][1].__getstate__()[0][0])
    return res


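# Example (illustrative sketch, not part of the original module): extracting
# the per-layer ih/hh weights from a float LSTM and from its dynamically
# quantized counterpart produced by torch.ao.quantization.quantize_dynamic.
#
#     lstm_fp32 = nn.LSTM(4, 4)
#     fp32_weights = get_lstm_weight(lstm_fp32)    # [weight_ih_l0, weight_hh_l0]
#
#     container = nn.Sequential(nn.LSTM(4, 4))
#     quantized = torch.ao.quantization.quantize_dynamic(
#         container, {nn.LSTM}, dtype=torch.qint8
#     )
#     q_weights = get_qlstm_weight(quantized[0])   # unpacked ih/hh weights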
def get_conv_mod_weight(mod: nn.Module) -> torch.Tensor:
    if isinstance(mod, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        return mod.weight.detach()
    elif isinstance(mod, (nni.ConvReLU1d, nni.ConvReLU2d, nni.ConvReLU3d)):
        return mod[0].weight.detach()
    else:
        return mod._weight_bias()[0]  # type: ignore[operator]


def get_linear_mod_weight(mod: nn.Module) -> torch.Tensor:
    if isinstance(mod, nn.Linear):
        return mod.weight.detach()
    elif isinstance(mod, nni.LinearReLU):
        return mod[0].weight.detach()
    else:
        return mod._weight_bias()[0]  # type: ignore[operator]


def get_lstm_mod_weights(mod: nn.Module) -> List[torch.Tensor]:
    # TODO(future PR): make more generic, handle everything
    if isinstance(mod, nn.LSTM):
        res = []
        for idx, param_name in enumerate(mod._flat_weights_names):
            if "weight_ih_l" in param_name or "weight_hh_l" in param_name:
                param_value = mod._flat_weights[idx].detach()
                res.append(param_value)
        return res
    else:
        assert isinstance(mod, nnqd.LSTM), f"type {type(mod)} not handled yet"
        res = []
        for weight_value in mod._all_weight_values:
            res.append(weight_value.param.__getstate__()[0][4][0].__getstate__()[0][0])
            res.append(weight_value.param.__getstate__()[0][4][1].__getstate__()[0][0])
        return res


def get_conv_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
    # traverse backwards from the weight arg, accounting for any observers
    weight_arg_node = node.args[1]
    assert isinstance(weight_arg_node, Node)
    weight_node = return_first_non_observer_node(weight_arg_node, gm)
    assert isinstance(weight_node, Node)
    assert weight_node.op == "get_attr"
    weight = getattr_from_fqn(gm, weight_node.target)  # type: ignore[arg-type]
    return weight.detach()


def get_qconv_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
    # qconv state is arg 1
    qconv_state_node = node.args[1]
    assert isinstance(qconv_state_node, Node)
    assert qconv_state_node.op == "get_attr"
    qconv_state_obj = getattr_from_fqn(gm, qconv_state_node.target)  # type: ignore[arg-type]
    return qconv_state_obj.weight()


def get_linear_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
    # traverse backwards from the weight arg, accounting for any observers
    # supported patterns:
    # weight -> obs -> linear
    # weight -> to(torch.float16) -> dequantize -> linear
    linear_second_arg = node.args[1]
    assert isinstance(linear_second_arg, Node)

    if linear_second_arg.op == "call_module":
        # weight -> obs -> linear
        weight_arg_node = node.args[1]
        assert isinstance(weight_arg_node, Node)
        weight_node = weight_arg_node.args[0]
        assert isinstance(weight_node, Node)
        assert weight_node.op == "get_attr"
        weight = getattr_from_fqn(gm, weight_node.target)  # type: ignore[arg-type]
        return weight.detach()
    elif linear_second_arg.op == "call_method":
        # weight -> to(torch.float16) -> dequantize -> linear
        dequant_node = node.args[1]
        assert isinstance(dequant_node, Node)
        to_fp16_node = dequant_node.args[0]
        assert isinstance(to_fp16_node, Node)
        # extract the dtype, so we can cast to it before returning
        target_dtype = to_fp16_node.args[1]
        weight_node = to_fp16_node.args[0]
        assert isinstance(weight_node, Node)
        assert weight_node.op == "get_attr"
        weight = getattr_from_fqn(gm, weight_node.target)  # type: ignore[arg-type]
        # return the weight with fp16 cast
        return weight.detach().to(target_dtype)
    else:
        assert linear_second_arg.op == "get_attr"
        weight = getattr_from_fqn(gm, linear_second_arg.target)  # type: ignore[arg-type]
        return weight.detach()


def get_qlinear_fun_weight(node: Node, gm: GraphModule) -> torch.Tensor:
    # packed weight is arg 1
    packed_weight_node = node.args[1]
    assert isinstance(packed_weight_node, Node)
    assert packed_weight_node.op == "get_attr"
    packed_weight = getattr_from_fqn(gm, packed_weight_node.target)  # type: ignore[arg-type]
    # TODO(future PR): why does packed_weight.unpack() not work?
    (weight, _bias), _name = packed_weight.__getstate__()
    return weight


def get_op_to_type_to_weight_extraction_fn() -> Dict[str, Dict[Callable, Callable]]:
    op_to_type_to_weight_extraction_fn: Dict[str, Dict[Callable, Callable]] = {
        "call_module": {
            # Conv1d
            nn.Conv1d: mod_weight_detach,
            nni.ConvReLU1d: mod_0_weight_detach,
            nnq.Conv1d: mod_weight_bias_0,
            nnqat.Conv1d: mod_weight_detach,
            nniqat.ConvBn1d: mod_weight_detach,
            nniqat.ConvBnReLU1d: mod_weight_detach,
            nniqat.ConvReLU1d: mod_weight_detach,
            nniq.ConvReLU1d: mod_weight_bias_0,
            # Conv2d
            nn.Conv2d: mod_weight_detach,
            nni.ConvReLU2d: mod_0_weight_detach,
            nnq.Conv2d: mod_weight_bias_0,
            nnqat.Conv2d: mod_weight_detach,
            nniqat.ConvBn2d: mod_weight_detach,
            nniqat.ConvBnReLU2d: mod_weight_detach,
            nniqat.ConvReLU2d: mod_weight_detach,
            nniq.ConvReLU2d: mod_weight_bias_0,
            # Conv3d
            nn.Conv3d: mod_weight_detach,
            nni.ConvReLU3d: mod_0_weight_detach,
            nnq.Conv3d: mod_weight_bias_0,
            nnqat.Conv3d: mod_weight_detach,
            nniqat.ConvBn3d: mod_weight_detach,
            nniqat.ConvBnReLU3d: mod_weight_detach,
            nniqat.ConvReLU3d: mod_weight_detach,
            nniq.ConvReLU3d: mod_weight_bias_0,
            # Linear
            nn.Linear: mod_weight_detach,
            nnq.Linear: mod_weight_bias_0,
            nni.LinearReLU: mod_0_weight_detach,
            nniq.LinearReLU: mod_weight_bias_0,
            nnqat.Linear: mod_weight_detach,
            nnqd.Linear: mod_weight_bias_0,
            nniqat.LinearReLU: mod_weight_detach,
            nniqat.LinearBn1d: mod_weight_detach,
            nn.modules.linear.NonDynamicallyQuantizableLinear: mod_weight_detach,
            # LSTM
            nn.LSTM: get_lstm_weight,
            nnqd.LSTM: get_qlstm_weight,
        },
        "call_function": {
            # Conv
            F.conv1d: get_conv_fun_weight,
            F.conv2d: get_conv_fun_weight,
            F.conv3d: get_conv_fun_weight,
            toq.conv1d: get_qconv_fun_weight,
            toq.conv2d: get_qconv_fun_weight,
            toq.conv3d: get_qconv_fun_weight,
            toq.conv1d_relu: get_qconv_fun_weight,
            toq.conv2d_relu: get_qconv_fun_weight,
            toq.conv3d_relu: get_qconv_fun_weight,
            # Linear
            F.linear: get_linear_fun_weight,
            toq.linear: get_qlinear_fun_weight,
            toq.linear_relu: get_qlinear_fun_weight,
        },
    }

    return op_to_type_to_weight_extraction_fn


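# Example (illustrative sketch, not part of the original module): the mapping
# above is also the extension point for custom layer types. `MyCustomConv2d`
# below is a hypothetical user-defined module, not a PyTorch class.
#
#     op_map = get_op_to_type_to_weight_extraction_fn()
#     op_map["call_module"][MyCustomConv2d] = mod_weight_detach
#     # pass op_map to extract_weight_from_node below via the
#     # op_to_type_to_weight_extraction_fn argument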
def extract_weight_from_node(
    node: Node,
    gm: GraphModule,
    op_to_type_to_weight_extraction_fn: Optional[
        Dict[str, Dict[Callable, Callable]]
    ] = None,
) -> Optional[NSSingleResultType]:
    res_type = NSSingleResultValuesType.WEIGHT.value

    # Not all graphmodules have _node_name_to_scope, so only fill it
    # out if it exists.
    fqn = None
    if hasattr(gm, "_node_name_to_scope"):
        fqn = gm._node_name_to_scope[node.name][0]  # type: ignore[index]

    if op_to_type_to_weight_extraction_fn is None:
        op_to_type_to_weight_extraction_fn = get_op_to_type_to_weight_extraction_fn()

    ref_node_type = get_target_type_str(node, gm)
    # for extracting weights, these are always the same
    prev_node_type = ref_node_type

    if node.op == "call_function":
        function_mapping = op_to_type_to_weight_extraction_fn["call_function"]
        for target_fn_type, weight_extraction_fn in function_mapping.items():
            if node.target == target_fn_type:
                weight = weight_extraction_fn(node, gm)
                return {
                    "type": res_type,
                    "values": [weight],
                    "prev_node_name": node.name,
                    "prev_node_target_type": prev_node_type,
                    "ref_node_name": node.name,
                    "ref_node_target_type": ref_node_type,
                    "index_within_arg": 0,
                    "index_of_arg": 0,
                    "fqn": fqn,
                }

    elif node.op == "call_module":
        # for call_module, we need to look up the modules to do the type check
        assert isinstance(node.target, str)
        mod = getattr_from_fqn(gm, node.target)
        module_mapping = op_to_type_to_weight_extraction_fn["call_module"]
        for target_mod_type, weight_extraction_fn in module_mapping.items():
            if type(mod) == target_mod_type:
                weight = weight_extraction_fn(mod)
                return {
                    "type": res_type,
                    "values": [weight],
                    "prev_node_name": node.name,
                    "prev_node_target_type": prev_node_type,
                    "ref_node_name": node.name,
                    "ref_node_target_type": ref_node_type,
                    "index_within_arg": 0,
                    "index_of_arg": 0,
                    "fqn": fqn,
                }

    return None
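# Example (illustrative sketch, not part of the original module): extracting
# weights from the nodes of a symbolically traced model. The model below is
# illustrative only; any call_module / call_function node whose type appears
# in the mapping returned by get_op_to_type_to_weight_extraction_fn works.
#
#     from torch.fx import symbolic_trace
#
#     m = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
#     gm = symbolic_trace(m)
#     for node in gm.graph.nodes:
#         result = extract_weight_from_node(node, gm)
#         if result is not None:
#             print(result["ref_node_name"], result["values"][0].shape)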