# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest

import torch
from executorch import exir
from executorch.exir import EdgeCompileConfig, to_edge
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
from executorch.exir.tests.common import register_additional_test_aten_ops
from torch.ao.quantization import (  # @manual
    float_qparams_weight_only_qconfig,
    get_default_qconfig_mapping,
)
from torch.ao.quantization.backend_config.executorch import (
    get_executorch_backend_config,
)

from torch.ao.quantization.qconfig_mapping import QConfigMapping
from torch.ao.quantization.quantize_fx import (
    _convert_to_reference_decomposed_fx,
    prepare_fx,
)
from torch.export import export
from torch.nn import functional as F

from torch.testing import FileCheck


class TestQuantFusionPass(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        register_additional_test_aten_ops()

    def test_add(self) -> None:
        class M(torch.nn.Module):
            def forward(self, x, y):
                # edge case that doesn't work yet; we can add a fusion
                # pattern to enable it if needed:
                # return x + x
                return x + y

        example_inputs = (torch.randn(1, 5), torch.randn(1, 5))
        m = M().eval()
        # TODO: define a qconfig_mapping specifically for executorch
        qconfig_mapping = get_default_qconfig_mapping("qnnpack")
        m = prepare_fx(
            m,
            qconfig_mapping,
            example_inputs,
            backend_config=get_executorch_backend_config(),
        )
        m = _convert_to_reference_decomposed_fx(m)
        config = EdgeCompileConfig(_check_ir_validity=False)
        m = to_edge(export(m, example_inputs), compile_config=config)
        # QuantFusionPass should be part of the to_executorch() config; it is run
        # separately here so that we can check the graph.
        m = m.transform([QuantFusionPass(_fix_node_meta_val=True)])
        # check that we are using the functional variant of q/dq/add
        FileCheck().check(
            "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default"
        ).check(
            "executorch_exir_dialects_edge__ops_quantized_decomposed_add_default"
        ).check(
            "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default"
        ).run(
            m.exported_program().graph_module.code
        )
        m = m.to_executorch()
        # check that we are using the out variant of q/dq/add
        FileCheck().check("torch.ops.quantized_decomposed.add.out").run(
            m.exported_program().graph_module.code
        )
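
    # A sketch of the rewrite test_add checks above (argument lists abbreviated;
    # the real nodes also carry scale, zero_point, quant_min, quant_max, and
    # dtype arguments):
    #
    #   before fusion:
    #     x_fp = dequantize_per_tensor(x_q, ...)
    #     y_fp = dequantize_per_tensor(y_q, ...)
    #     out_fp = aten.add(x_fp, y_fp)
    #     out_q = quantize_per_tensor(out_fp, ...)
    #   after fusion:
    #     out_q = quantized_decomposed.add(x_q, y_q, ...)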

    def test_reshape(self) -> None:
        class M(torch.nn.Module):
            def forward(self, x, y):
                x = x + y
                x = x.reshape(1, x.numel())
                return x

        example_inputs = (torch.randn(3, 5), torch.randn(3, 5))
        m = M().eval()
        # TODO: define a qconfig_mapping specifically for executorch
        qconfig_mapping = get_default_qconfig_mapping("qnnpack")
        m = prepare_fx(
            m,
            qconfig_mapping,
            example_inputs,
            backend_config=get_executorch_backend_config(),
        )
        m(*example_inputs)
        m = _convert_to_reference_decomposed_fx(m)
        config = EdgeCompileConfig(_check_ir_validity=False)
        m = to_edge(export(m, example_inputs), compile_config=config)
        # QuantFusionPass should be part of the to_executorch() config; it is run
        # separately here so that we can check the graph.
        m = m.transform([QuantFusionPass(_fix_node_meta_val=True)])
        # check that we are using the functional variant of q/dq/add/reshape;
        # make sure we only have two quants and one dequant, since the q/dq pair
        # around reshape should be fused
        FileCheck().check_count(
            "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default",
            2,
            exactly=True,
        ).check(
            "executorch_exir_dialects_edge__ops_quantized_decomposed_add_default"
        ).check(
            "executorch_exir_dialects_edge__ops_aten_view_copy_default"
        ).check_count(
            "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default",
            1,
            exactly=True,
        ).run(
            m.exported_program().graph_module.code
        )

        m = m.to_executorch(exir.ExecutorchBackendConfig(remove_view_copy=False))
        # check that we are using the out variants of add and view_copy
        FileCheck().check("torch.ops.quantized_decomposed.add.out").check(
            "torch.ops.aten.view_copy.out"
        ).run(m.exported_program().graph_module.code)
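
    # Reshape is a metadata-only op, so the dq/q pair around it shares qparams
    # and can be dropped; the view_copy above then runs directly on the int8
    # tensor, which is why only the two input quants and one output dequant
    # remain in the counts checked above.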

    def test_slice(self) -> None:
        """We don't proactively quantize slice today, but we fuse the dq-slice-q
        pattern into an int8 slice operator; we can revisit this later to see
        whether proactively quantizing slice is needed.
        """

        class M(torch.nn.Module):
            def forward(self, x, y):
                x = x + y
                x = x[1:]
                y = y[1:]
                x = x + y
                return x

        example_inputs = (torch.randn(3, 5), torch.randn(3, 5))
        m = M().eval()
        # TODO: define a qconfig_mapping specifically for executorch
        qconfig_mapping = get_default_qconfig_mapping("qnnpack")
        m = prepare_fx(
            m,
            qconfig_mapping,
            example_inputs,
            backend_config=get_executorch_backend_config(),
        )
        m = _convert_to_reference_decomposed_fx(m)
        config = EdgeCompileConfig(_check_ir_validity=False)
        m = to_edge(export(m, example_inputs), compile_config=config)
        # QuantFusionPass should be part of the to_executorch() config; it is run
        # separately here so that we can check the graph.
        m = m.transform([QuantFusionPass(_fix_node_meta_val=True)])
        # check that we are using the functional variant of q/dq/add/slice;
        # the q/dq pairs around the two slices should be fused so that slice_copy
        # runs directly on int8
        FileCheck().check_count(
            "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default",
            2,
            exactly=True,
        ).check("executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor").check(
            "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default"
        ).check(
            "executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor"
        ).check(
            "executorch_exir_dialects_edge__ops_quantized_decomposed_add_default"
        ).check(
            "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default"
        ).run(
            m.exported_program().graph_module.code
        )

        m = m.to_executorch()
        # check that we are using the out variants of add and slice_copy
        FileCheck().check("torch.ops.quantized_decomposed.add.out").check(
            "torch.ops.aten.slice_copy.Tensor_out"
        ).run(m.exported_program().graph_module.code)

    def test_cat(self) -> None:
        class M(torch.nn.Module):
            def forward(self, x, y):
                x = torch.cat([x, x], dim=0)
                return x

        example_inputs = (torch.randn(3, 5), torch.randn(3, 5))
        m = M().eval()
        # TODO: define a qconfig_mapping specifically for executorch
        qconfig_mapping = get_default_qconfig_mapping("qnnpack")
        m = prepare_fx(
            m,
            qconfig_mapping,
            example_inputs,
            backend_config=get_executorch_backend_config(),
        )
        m(*example_inputs)
        m = _convert_to_reference_decomposed_fx(m)
        config = EdgeCompileConfig(_check_ir_validity=False)
        m = to_edge(export(m, example_inputs), compile_config=config)
        # QuantFusionPass should be part of the to_executorch() config; it is run
        # separately here so that we can check the graph.
        m = m.transform([QuantFusionPass()])
        # check that we are using the functional variant of q/dq/cat
        FileCheck().check_count(
            "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default",
            1,
            exactly=True,
        ).check("executorch_exir_dialects_edge__ops_aten_cat_default").check_count(
            "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default",
            1,
            exactly=True,
        ).run(
            m.exported_program().graph_module.code
        )

        m = m.to_executorch()
        # Note: quantized cat is not fused, since the qparams are the same and the
        # current subgraph_rewriter doesn't handle the case where a single graph
        # node maps to two different pattern nodes. One workaround would be to add
        # new patterns for the case where the qparams of the quantized cat pattern
        # are the same, but this may not be needed in real use cases; we can add
        # the workaround in another diff if needed.
        FileCheck().check_count(
            "torch.ops.quantized_decomposed.quantize_per_tensor.out", 1, exactly=True
        ).check("torch.ops.aten.cat.out").check_count(
            "torch.ops.quantized_decomposed.dequantize_per_tensor.out", 1, exactly=True
        ).run(
            m.exported_program().graph_module.code
        )
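
    # The embedding tests below use a weight-only qconfig:
    # float_qparams_weight_only_qconfig quantizes just the embedding weight, per
    # channel, while activations stay in float. A sketch of the expected fused
    # graph (argument lists abbreviated):
    #
    #   w_q = quantize_per_channel(weight, scales, zero_points, ...)
    #   out = quantized_decomposed.embedding_byte(w_q, scales, zero_points, ..., indices)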

    def test_embedding_byte(self) -> None:
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12)

            def forward(self, indices):
                return self.emb(indices)

        for qconfig in [float_qparams_weight_only_qconfig]:
            m = M().eval()
            indices = torch.tensor(
                [9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8,
                 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3]
            )
            example_inputs = (indices,)
            # TODO: define a qconfig_mapping specifically for executorch
            qconfig_mapping = get_default_qconfig_mapping("qnnpack")
            qconfig_mapping = qconfig_mapping.set_object_type(
                torch.nn.Embedding, qconfig
            )
            m = prepare_fx(
                m,
                qconfig_mapping,
                example_inputs,
                backend_config=get_executorch_backend_config(),
            )
            m(*example_inputs)
            m = _convert_to_reference_decomposed_fx(m)
            compile_config = EdgeCompileConfig(
                _check_ir_validity=False,
                _use_edge_ops=True,
            )
            m = to_edge(export(m, example_inputs), compile_config=compile_config)
            # QuantFusionPass should be part of the to_executorch() config; it is
            # run separately here so that we can check the graph.
            m = m.transform([QuantFusionPass(_fix_node_meta_val=True)])
            # check that we are using the functional variant of
            # quantize_per_channel/embedding_byte
            FileCheck().check(
                "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_channel_default",
            ).check(
                "executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_byte_default"
            ).run(
                m.exported_program().graph_module.code
            )

            # TODO: enable after the out variant of quantize_per_channel is supported
            # m = m.to_executorch()
            # FileCheck().check(
            #     "executorch_exir_dialects_edge__ops_quantized_decomposed.quantize_per_channel.out",
            # ).check(
            #     "executorch_exir_dialects_edge__ops_quantized_decomposed.embedding_byte.out"
            # ).run(
            #     m.dump_graph_module().code
            # )
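
    # Same fusion as above, exercised through the functional API: the weight is
    # a plain tensor attribute and the qconfig is keyed on F.embedding rather
    # than on the torch.nn.Embedding module type.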

    def test_embedding_byte_functional(self) -> None:
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.rand((3, 2))

            def forward(self, indices):
                return F.embedding(indices, self.weight)

        for qconfig in [float_qparams_weight_only_qconfig]:
            m = M().eval()
            indices = torch.tensor([0])
            example_inputs = (indices,)

            qconfig_mapping = QConfigMapping().set_object_type(
                F.embedding,
                qconfig,
            )

            m = prepare_fx(
                m,
                qconfig_mapping,
                example_inputs,
                backend_config=get_executorch_backend_config(),
            )
            m(*example_inputs)
            m = _convert_to_reference_decomposed_fx(m)
            compile_config = EdgeCompileConfig(
                _check_ir_validity=False,
                _use_edge_ops=True,
            )
            m = to_edge(export(m, example_inputs), compile_config=compile_config)
            # QuantFusionPass should be part of the to_executorch() config; it is
            # run separately here so that we can check the graph.
            m = m.transform([QuantFusionPass(_fix_node_meta_val=True)])
            # check that we are using the functional variant of
            # quantize_per_channel/embedding_byte
            FileCheck().check(
                "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_channel_default",
            ).check(
                "executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_byte_default"
            ).run(
                m.exported_program().graph_module.code
            )

            # TODO: enable after the out variant of quantize_per_channel is supported
            # m = m.to_executorch()
            # FileCheck().check(
            #     "executorch_exir_dialects_edge__ops_quantized_decomposed.quantize_per_channel.out",
            # ).check(
            #     "executorch_exir_dialects_edge__ops_quantized_decomposed.embedding_byte.out"
            # ).run(
            #     m.dump_graph_module().code
            # )
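

if __name__ == "__main__":
    # Convenience entry point: allows running this test file directly in
    # addition to the usual test runner.
    unittest.main()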