1# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 2 3 4# This file contains all the functions that replace one op with another in the 5# graph. The functions replacing ops for models deployed with Jarvis are grouped 6# together in class 'ReplaceOpsInGraph'. Some examples of functions in the class are 7# 1. functions that replace an ATen op with a custom op that accepts extra arguments 8# 2. functions that replace in-place variants of ATen ops with out-of-place version. 9# 3. functions that replace an ATen op with another semantically equivalent ATen op. 10# 4. functions that concretize optional args. 11 12import math 13from operator import neg 14from typing import cast, Dict, Iterable, Sequence, Set, Tuple 15 16import torch 17import torch.fx 18from executorch.backends.cadence.aot.compiler_utils import ( 19 get_shape, 20 get_tensor_from_attr, 21 get_transposed_dims, 22 get_zero_point, 23 is_node_with_op, 24 is_quantized_tensor, 25 quantize_tensor_multiplier, 26) 27from executorch.backends.cadence.aot.fuse_ops import FuseCascadedViewOps 28from executorch.backends.cadence.aot.pass_utils import ( 29 CadencePassAttribute, 30 register_cadence_pass, 31) 32from executorch.backends.cadence.aot.remove_ops import RemoveNopSelectOpPass 33from executorch.backends.cadence.aot.utils import get_edge_overload_packet 34from executorch.exir.dialects._ops import ops as exir_ops 35from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket 36from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue 37from torch._subclasses import FakeTensor 38from torch.fx.node import Argument 39 40# A map to represent ops that: 41# (a) are functionally equivalent wrt. Jarvis; and 42# (b) have identical arguments 43# An op whose target is 'key' in this dict can be replaced by the functionally equivalent 44# op whose target is 'value'. The replacement would just involve changing the op target. 
functionally_equivalent_op_targets: Dict[EdgeOpOverload, EdgeOpOverload] = {
    exir_ops.edge.aten.relu_.default: exir_ops.edge.aten.relu.default,
    exir_ops.edge.aten.unsafe_split.Tensor: exir_ops.edge.aten.split_copy.Tensor,
}


def contains_placeholder_or_param(nodes: Iterable[torch.fx.Node]) -> bool:
    """
    Return True if any node in `nodes` is a placeholder or a parameter,
    i.e., its op is "placeholder" or "get_attr".
    """
    return any(
        is_node_with_op(node, "placeholder") or is_node_with_op(node, "get_attr")
        for node in nodes
    )


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class ReplaceLogicalNotBooleanWhereWithWherePass(ExportPass):
    """
    A where op with a logical_not and a boolean tensor can be replaced
    by a where op with flipped inputs and the initial boolean tensor, i.e.,
    where(logical_not(c), a, b) becomes where(c, b, a).
    """

    def replace_logical_nop_where_with_where(
        self, graph_module: torch.fx.GraphModule
    ) -> None:
        """
        Rewrite, in place, every where(logical_not(c), a, b) node of
        `graph_module` whose `c` is a boolean tensor into where(c, b, a).
        """
        graph = graph_module.graph
        for node in graph.nodes:
            # We are only interested in where nodes
            if node.target != exir_ops.edge.aten.where.self:
                continue

            # If the condition (the first arg) is not a logical_not node, bail.
            logical_not_node = node.args[0]
            if (
                not isinstance(logical_not_node, torch.fx.Node)
                or logical_not_node.target != exir_ops.edge.aten.logical_not.default
            ):
                continue

            # Get the input of the logical_not node
            logical_not_input_tensor = (
                logical_not_node.args[0].to_tensor()
                if isinstance(logical_not_node.args[0], ProxyValue)
                else logical_not_node.args[0]
            )

            # If the logical_not input is not a boolean tensor, bail.
            if logical_not_input_tensor.meta["spec"].dtype != torch.bool:
                continue

            # Replace the where op with another one, flipping the inputs and using the boolean
            # tensor from logical_not.
            with graph.inserting_before(node):
                where_node = graph.call_function(
                    exir_ops.edge.aten.where.self,
                    args=(logical_not_node.args[0], node.args[2], node.args[1]),
                )
            # The new node computes the same value as the old one, so it can
            # reuse its metadata (same convention as the addmm -> linear pass).
            where_node.meta = node.meta
            # Replace all the uses
            node.replace_all_uses_with(where_node)

        # Drop the now-unused where/logical_not nodes before regenerating the
        # module code, so that the generated code matches the final graph.
        graph_module.graph.eliminate_dead_code()
        graph_module.recompile()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        self.replace_logical_nop_where_with_where(graph_module)
        result = super().call(graph_module)
        return result


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class ReplaceSafeSoftmaxWithSoftmax(ExportPass):  # keep
    """
    Replace _safe_softmax with _softmax
    """

    def call_operator(
        self,
        op,
        args: tuple[Argument, ...],
        kwargs: dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        if op != torch.ops.aten._safe_softmax.default:
            return super().call_operator(op, args, kwargs, meta)

        # Add False for the half_to_float argument of softmax
        softmax_args = list(args) + [False]

        return super().call_operator(
            torch.ops.aten._softmax.default,
            tuple(softmax_args),
            kwargs,
            meta,
        )


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class ReplacePT2QuantWithCadenceQuantPass(ExportPass):
    """
    Replace the pt2 quantization ops with cadence quantization ops.
    We do not link kernels to the PT2 quantization ops, so we need to
    replace them with cadence ops at all optimization levels.
    """

    def call_operator(
        self,
        op,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        if op not in {exir_ops.edge.quantized_decomposed.quantize_per_tensor.default}:
            return super().call_operator(op, args, kwargs, meta)

        # Only the target changes; the args are forwarded untouched.
        return super().call_operator(
            exir_ops.edge.cadence.quantize_per_tensor.default,
            args,
            kwargs,
            meta,
        )


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class ReplacePT2DequantWithCadenceDequantPass(ExportPass):
    """
    Replace the pt2 dequantization ops with cadence dequantization ops.
    We do not link kernels to the PT2 quantization ops, so we need to
    replace them with cadence ops at all optimization levels.
    """

    def call_operator(
        self,
        op,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        if op not in {exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default}:
            return super().call_operator(op, args, kwargs, meta)

        # Only the target changes; the args are forwarded untouched.
        return super().call_operator(
            exir_ops.edge.cadence.dequantize_per_tensor.default,
            args,
            kwargs,
            meta,
        )


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class ReplaceSqueezeAndUnsqueezeWithViewPass(ExportPass):
    """
    When the shape is static, replace squeeze_copy and unsqueeze_copy ops with
    view_copy op
    """

    def call_operator(
        self,
        op,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        # Instead of testing EdgeOpOverload, test EdgeOpOverloadPacket,
        # which allows us to cover all overloads.
        if get_edge_overload_packet(op) not in {
            exir_ops.edge.aten.squeeze_copy,
            exir_ops.edge.aten.unsqueeze_copy,
        }:
            return super().call_operator(op, args, kwargs, meta)
        # Get the output tensor shape
        out_shape = meta["val"].shape

        # Bail out if any dim is not an int (dynamic shape)
        if any(not isinstance(dim, int) for dim in out_shape):
            return super().call_operator(op, args, kwargs, meta)

        # Return a view op with the new shape
        view_args = (args[0], list(out_shape))
        return super().call_operator(
            exir_ops.edge.aten.view_copy.default, view_args, kwargs, meta
        )


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class ReplaceFunctionallyEquivalentOpTargets(ExportPass):
    """
    Replace an op with a functionally equivalent op by just switching the op
    target, but without incurring any change to the op args.
    """

    def call_operator(self, op, args, kwargs, meta):
        if op not in functionally_equivalent_op_targets:
            return super().call_operator(op, args, kwargs, meta)
        return super().call_operator(
            functionally_equivalent_op_targets[op], args, kwargs, meta
        )


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class ReplaceSelectWithViewOpPass(ExportPass):
    """
    If the size along the select dim is 1, then the select op can be replaced
    by view op.
    """

    def call_operator(self, op, args, kwargs, meta):
        if op != exir_ops.edge.aten.select_copy.int:
            return super().call_operator(op, args, kwargs, meta)

        # Glean the shape of input and output tensor
        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
        in_shape = in_tensor.shape
        out_shape = meta["val"].shape
        # Get the select dimension (normalizing a negative dim)
        select_dim = args[1] if args[1] >= 0 else args[1] + len(in_shape)

        if in_shape[select_dim] == 1:
            # Return a view op with the new shape
            view_args = (args[0], list(out_shape))
            return super().call_operator(
                exir_ops.edge.aten.view_copy.default, view_args, kwargs, meta
            )
        return super().call_operator(op, args, kwargs, meta)


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class ReplaceTCopyWithTransposePass(ExportPass):
    """
    Replace t_copy with transpose_copy.int. If the input is 1D, the t_copy is
    a nop. t_copy is not supported, so this is an opt_level=0 pass.
    """

    def call_operator(self, op, args, kwargs, meta):
        if get_edge_overload_packet(op) != exir_ops.edge.aten.t_copy:
            return super().call_operator(op, args, kwargs, meta)

        # Get the input tensor
        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]

        # If the input is a 0D/1D tensor, this t_copy is a nop, so return the input
        if in_tensor.dim() <= 1:
            return args[0]

        assert in_tensor.dim() == 2, "t_copy expects a tensor with <= 2 dimensions"
        transpose_args = (args[0], 0, 1)
        return super().call_operator(
            exir_ops.edge.aten.transpose_copy.int, transpose_args, kwargs, meta
        )


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class ReplaceMMWithAddMMPass(ExportPass):
    """
    This pass replaces mm with addmm by introducing a zero bias.
    mm is not supported, so this is an opt_level=0 pass.
    """

    def call_operator(self, op, args, kwargs, meta):
        if op != exir_ops.edge.aten.mm.default:
            return super().call_operator(op, args, kwargs, meta)

        # The mm op has two args: input, mat2
        assert len(args) == 2
        X, mat2 = args

        # Create a zero bias tensor as the result of a full op emitted before
        # the addmm. Its length is the number of columns of mat2, and its
        # dtype follows mat2 so that addmm sees operands of a single dtype.
        mat2_tensor = mat2.to_tensor() if isinstance(mat2, ProxyValue) else mat2
        bias_size = mat2_tensor.size(1)
        zero_bias = super().call_operator(
            exir_ops.edge.aten.full.default,
            ([bias_size], 0.0),
            {"dtype": mat2_tensor.dtype},
            meta,
        )

        # Replace mm with addmm
        new_args = (zero_bias, X, mat2)
        return super().call_operator(
            exir_ops.edge.aten.addmm.default, new_args, kwargs, meta
        )


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class ReplaceAddMMWithLinearPass(ExportPass):
    """
    This pass replaces addmm with linear op.
    """

    def __init__(self):
        super().__init__()
        # Used to give unique names to the buffers registered by this pass.
        self.counter = 0

    def replace_addmm_with_linear(self, graph_module: torch.fx.GraphModule):
        """
        Rewrite, in place, every eligible addmm node of `graph_module` into a
        linear node. Scaled/transposed constant operands are registered as new
        buffers on `graph_module`.
        """
        graph = graph_module.graph
        for node in graph.nodes:
            # We are only interested in addmm nodes
            if node.target != exir_ops.edge.aten.addmm.default:
                continue

            # The addmm op has three concrete args: input, mat1, mat2
            assert len(node.args) >= 3
            (bias, mat1, mat2) = node.args[0:3]
            # The other two args are optional scale args
            beta = node.kwargs.get("beta", 1.0)
            alpha = node.kwargs.get("alpha", 1.0)

            # AddMM performs beta*bias + alpha*mm(mat1, mat2). We can convert
            # it to linear op by multiplying beta to bias, and alpha to mat2.t().
            # However, the following two conditions must hold:
            # a. If bias is not a param, then beta must be 1.0
            # b. If mat2 is not a param, then mat2 must be a transpose op. Also,
            # the input to the transpose must be a param, or alpha must be 1.0.
            fit_bias = is_node_with_op(bias, "get_attr") or beta == 1.0
            fit_mat2 = is_node_with_op(mat2, "get_attr")
            transposed_mat2 = False
            if (
                not fit_mat2
                and is_node_with_op(mat2, "call_function")
                and mat2.target == exir_ops.edge.aten.transpose_copy.int
            ):
                # Look through the transpose: linear transposes its weight
                # itself, so the transpose input can be used directly.
                mat2, transposed_mat2 = mat2.args[0], True
                fit_mat2 = is_node_with_op(mat2, "get_attr") or alpha == 1.0

            if not fit_bias or not fit_mat2:
                continue

            # Multiply bias by beta
            if beta != 1.0:
                assert is_node_with_op(bias, "get_attr")
                bias_tensor = get_tensor_from_attr(graph_module, bias)
                assert isinstance(bias_tensor, torch.Tensor)
                bias_tensor = beta * bias_tensor
                with graph.inserting_before(node):
                    bias_name = f"_bias_addmm_to_linear_{self.counter}"
                    graph_module.register_buffer(bias_name, bias_tensor)
                    bias = graph.get_attr(bias_name)

            # Use associativity of scalar multiplication, and multiply alpha to mat2
            if is_node_with_op(mat2, "get_attr"):
                mat2_tensor = get_tensor_from_attr(graph_module, mat2)
                assert isinstance(mat2_tensor, torch.Tensor)
                mat2_tensor = alpha * mat2_tensor
                # transpose mat2, unless we already looked through a transpose
                mat2_tensor = mat2_tensor if transposed_mat2 else mat2_tensor.t()
                with graph.inserting_before(node):
                    mat2_name = f"_mat2_addmm_to_linear_{self.counter}"
                    graph_module.register_buffer(mat2_name, mat2_tensor)
                    mat2 = graph.get_attr(mat2_name)

            # Construct the linear node
            linear_args = (mat1, mat2, bias)
            with graph.inserting_before(node):
                linear_node = graph.call_function(
                    exir_ops.edge.aten.linear.default, args=linear_args
                )
                linear_node.meta = node.meta
            # Replace all the uses of the addmm op with linear op
            node.replace_all_uses_with(linear_node)
            self.counter += 1

        # Drop the replaced addmm/transpose nodes before regenerating the
        # module code, so that the generated code matches the final graph.
        graph_module.graph.eliminate_dead_code()
        graph_module.recompile()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        self.replace_addmm_with_linear(graph_module)
        result = super().call(graph_module)
        return result


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class ReplacePermuteWithTransposePass(ExportPass):
    """
    Replace permute op with transpose if the permutation is only along
    two dimensions. A permute that leaves every dimension in place is
    replaced by its input.
    """

    def call_operator(self, op, args, kwargs, meta):
        if op != exir_ops.edge.aten.permute_copy.default:
            return super().call_operator(op, args, kwargs, meta)

        # Get the old dim and new dim order
        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
        rank = in_tensor.dim()
        old_dims = tuple(range(rank))
        # Normalize negative dims before comparing. Otherwise a permutation
        # such as [-3, -2, 2] (the identity on a 3D tensor) would be seen as
        # differing in two positions and be turned into transpose(0, 1).
        new_dims = [dim + rank if dim < 0 else dim for dim in args[1]]

        # Compute the positions in which the old and new order differ
        diff = [od for od, nd in zip(old_dims, new_dims) if od != nd]

        # If the difference is in two dimensions, we can replace this permute op
        # with transpose op.
        if len(diff) == 2:
            new_args = (args[0], diff[0], diff[1])
            return super().call_operator(
                exir_ops.edge.aten.transpose_copy.int, new_args, kwargs, meta
            )

        # An identity permutation is a nop; anything else stays a permute.
        return (
            args[0] if len(diff) == 0 else super().call_operator(op, args, kwargs, meta)
        )


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class ReplaceConvolutionOptionalArgsWithConcreteArgsPass(ExportPass):
    """
    Replace optional tensors with concrete tensors. Currently, we
    replace the optional bias tensor with a zero tensor.
    """

    def call_operator(self, op, args, kwargs, meta):
        if get_edge_overload_packet(op) != exir_ops.edge.aten.convolution:
            return super().call_operator(op, args, kwargs, meta)

        # Check if the bias is already concrete
        assert len(args) == 9
        if args[2] is not None:
            return super().call_operator(op, args, kwargs, meta)

        # The bias length is the number of out channels.
        out_val = meta["val"]
        bias_size = out_val.shape[1]
        # Create a zero bias tensor (bias is not a constant tensor,
        # so it needs to be the result of a graph operation). Its dtype
        # follows the convolution output rather than assuming float32.
        zero_bias = super().call_operator(
            exir_ops.edge.aten.full.default,
            ([bias_size], 0.0),
            {"dtype": out_val.dtype},
            meta,
        )

        # Replace bias with zero_bias
        args = list(args)
        args[2] = zero_bias
        args = tuple(args)

        return super().call_operator(op, args, kwargs, meta)


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class ReplaceRepeatWithCatPass(ExportPass):
    """
    Replace repeat op as successive cat ops along different dimensions.
    repeat is not supported, so this is an opt_level=0 pass.
    """

    def call_operator(self, op, args, kwargs, meta):
        if op != exir_ops.edge.aten.repeat.default:
            return super().call_operator(op, args, kwargs, meta)

        # Extract the input tensor, and the repeats from the args
        in_tensor = args[0]
        repeats = args[1]

        # Glean the shapes of input tensor
        in_shape = list(
            in_tensor.to_tensor().shape
            if isinstance(in_tensor, ProxyValue)
            else in_tensor.shape
        )

        # If the size of repeats is more than the dimensionality of the tensor,
        # the output of repeat will be a higher-dimensional tensor. We reshape
        # the input so that it has the same dimensionality as the output tensor.
        diff = len(repeats) - len(in_shape)
        assert (
            diff >= 0
        ), "Repeat arg malformed: expected a repeat along each dimension of input tensor"

        if diff > 0:
            # Extend the input shape with 1's along the higher dimensions
            in_shape = ([1] * diff) + in_shape
            # Insert a view op that reshapes the input tensor to have same
            # dimensionality as the output tensor.
            in_tensor = super().call_operator(
                exir_ops.edge.aten.view_copy.default,
                (in_tensor, in_shape),
                kwargs,
                meta,
            )
            assert len(repeats) == len(in_shape)

        # Repeat op is nothing but successive cat ops along each dimension.
        for dim, repeat in reversed(list(enumerate(repeats))):
            # We do not need to do anything if repeat factor is 1
            if repeat == 1:
                continue
            cat_arg = [in_tensor] * repeat
            in_tensor = super().call_operator(
                exir_ops.edge.aten.cat.default, (cat_arg, dim), kwargs, meta
            )

        return in_tensor


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class ReplacePadWithCatPass(ExportPass):
    """
    Replace constant pad nd op that does padding on outer-most dimension
    with Cat(left_padding_constant_tensor, X, right_padding_constant_tensor)
    """

    def call_operator(self, op, args, kwargs, meta):
        if op != exir_ops.edge.aten.constant_pad_nd.default:
            return super().call_operator(op, args, kwargs, meta)

        assert len(args) >= 2
        input_node, orig_padding = args[:2]

        # if there is no padding, this op will be treated in removal pass.
        if not orig_padding:
            return super().call_operator(op, args, kwargs, meta)

        value = 0 if len(args) == 2 else args[2]

        arg_tensor = (
            input_node.to_tensor() if isinstance(input_node, ProxyValue) else input_node
        )
        arg_shape = arg_tensor.shape

        # An odd-length padding list is completed with a trailing 0.
        padding = list(orig_padding)
        if len(padding) % 2 != 0:
            padding.append(0)
        assert len(padding) >= 2
        # The padding list is ordered from the last tensor dimension backwards,
        # so its last pair applies to the outer-most dimension being padded.
        (left_padding_size, right_padding_size) = padding[-2:]
        # Replace only if that outer-most padded dimension is the only one
        # with (non-negative) padding.
        if (
            any(x != 0 for x in padding[0:-2])
            or left_padding_size < 0
            or right_padding_size < 0
        ):
            return super().call_operator(op, args, kwargs, meta)

        cat_tensors = []
        dim = len(arg_shape) - len(padding) // 2
        # The padding tensors use the input dtype: cat type-promotes its
        # operands, so a hard-coded float32 would change the output dtype for
        # any non-float32 input.
        full_kwargs = {"dtype": arg_tensor.dtype}
        # add left_padding
        if left_padding_size > 0:
            left_padding_shape = (
                arg_shape[:dim] + (left_padding_size,) + arg_shape[dim + 1 :]
            )
            # Use the edge-dialect full op, like the other passes in this file.
            left_padding_node = super().call_operator(
                exir_ops.edge.aten.full.default,
                (
                    left_padding_shape,
                    value,
                ),
                full_kwargs,
                meta,
            )
            cat_tensors.append(left_padding_node)
        # input_node
        cat_tensors.append(input_node)
        # right_padding
        if right_padding_size > 0:
            right_padding_shape = (
                arg_shape[:dim] + (right_padding_size,) + arg_shape[dim + 1 :]
            )
            right_padding_node = super().call_operator(
                exir_ops.edge.aten.full.default,
                (
                    right_padding_shape,
                    value,
                ),
                full_kwargs,
                meta,
            )
            cat_tensors.append(right_padding_node)

        assert len(cat_tensors) == 1 + (left_padding_size > 0) + (
            right_padding_size > 0
        )

        new_args = (cat_tensors, dim)
        return super().call_operator(
            exir_ops.edge.aten.cat.default,
            new_args,
            kwargs,
            meta,
        )


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class ReplaceConstantPadNdWithSlicePass(ExportPass):
    """
    Replace constant pad nd op that only has non-positive padding on the
    outer-most padded dimension (i.e., that crops the tensor) with a slice op
    along that dimension.
    """

    def call_operator(self, op, args, kwargs, meta):
        if op != exir_ops.edge.aten.constant_pad_nd.default:
            return super().call_operator(op, args, kwargs, meta)

        assert len(args) >= 2
        input_node, orig_padding = args[:2]

        # if there is no padding, this op will be treated in removal pass.
        if not orig_padding:
            return super().call_operator(op, args, kwargs, meta)

        # An odd-length padding list is completed with a trailing 0.
        padding = list(orig_padding)
        if len(padding) % 2 != 0:
            padding.append(0)
        assert len(padding) >= 2
        # Negative padding crops: -left is the slice start, and -right is the
        # number of elements dropped at the end.
        (start, diff) = map(neg, padding[-2:])
        # Replace only if the last padding pair is the only non-zero one, and
        # it does not add any element.
        if any(x != 0 for x in padding[0:-2]) or start < 0 or diff < 0:
            return super().call_operator(op, args, kwargs, meta)

        arg_tensor = (
            input_node.to_tensor() if isinstance(input_node, ProxyValue) else input_node
        )
        arg_shape = arg_tensor.shape
        dim = len(arg_shape) - len(padding) // 2
        stop = arg_shape[dim] - diff
        assert start <= stop
        new_args = (input_node, dim, start, stop)
        return super().call_operator(
            exir_ops.edge.aten.slice.Tensor,
            new_args,
            kwargs,
            meta,
        )


# Make that pass runnable standalone at opt level 0.
@register_cadence_pass(CadencePassAttribute(opt_level=0))
class ReplaceAtenConvolutionWithJarvisConvolutionPass(ExportPass):
    """
    Replace aten convolution op with jarvis-specific convolution op, since the
    aten version is not supported by jarvis.
    Also remove convolution stride if the output size along the strided dimension
    is 1. We can enable more transformations (e.g., conv -> linear replacement)
    for unit-stride convolutions.
    """

    def call_operator(self, op, args, kwargs, meta):
        if get_edge_overload_packet(op) != exir_ops.edge.aten.convolution:
            return super().call_operator(op, args, kwargs, meta)
        # There must be 9 total args.
        assert len(args) == 9

        # Unpack the args
        (
            in_tensor,
            weight,
            bias,
            stride,
            padding,
            dilation,
            transposed,
            output_padding,
            groups,
        ) = args
        # Currently we only handle conversion to conv1d and conv2d, therefore
        # verify that the stride, padding, dilation, and output_padding all
        # have len 1, or all have len 2.
        assert (
            len(stride) == len(padding) == len(dilation) == len(output_padding) == 1
        ) or (
            len(stride) == len(padding) == len(dilation) == len(output_padding) == 2
        ), "Can only map convolution to conv1d and conv2d at present"

        target = (
            exir_ops.edge.cadence.transposed_convolution.default
            if transposed
            else exir_ops.edge.cadence.convolution.default
        )

        if transposed:
            # Flip the height and width dimensions of weight, since we apply a
            # gather stencil. Also, the first two dimensions of weight must be
            # transposed/interchanged.
            # If weight is a ProxyValue, new_weight needs to be the output of a
            # graph operation (in this case a transpose_copy op) to be an explicit
            # ProxyValue as well. If not, the view op can be done directly on the
            # tensor.
            transposed_weight = (
                super().call_operator(
                    exir_ops.edge.aten.transpose_copy.int,
                    (
                        weight,
                        0,
                        1,
                    ),
                    kwargs,
                    meta,
                )
                if isinstance(weight, ProxyValue)
                else weight.transpose(0, 1)
            )

            # Use the edge-dialect flip op, like the transpose_copy above.
            flipped_weight = (
                super().call_operator(
                    exir_ops.edge.aten.flip.default,
                    (
                        transposed_weight,
                        [-1] if transposed_weight.to_tensor().dim() == 3 else [-1, -2],
                    ),
                    kwargs,
                    meta,
                )
                if isinstance(transposed_weight, ProxyValue)
                else (
                    transposed_weight.flip(-1)
                    if transposed_weight.dim() == 3
                    else transposed_weight.flip(-1, -2)
                )
            )

            # From the previous checks, if flipped_weight is a FakeTensor, it has to be
            # a constant (if not, it would be a ProxyValue). Mark it as such.
            if isinstance(flipped_weight, FakeTensor):
                flipped_weight.constant = flipped_weight
            new_args = (
                in_tensor,
                flipped_weight,
                bias,
                stride,
                padding,
                dilation,
                output_padding,
                groups,
                False,
            )
        else:
            # Verify that output_padding is 0.
            assert all(
                x == 0 for x in output_padding
            ), "Cannot handle padded output in convolution"

            # If a spatial dim of the output tensor is 1, then the stride
            # along it can be 1. Note that the first two dimensions of the
            # output tensor are batch and channel.
            # list() copies any sequence, whereas .copy() only exists on lists.
            new_stride = list(stride)
            out_shape = meta["val"].shape
            assert out_shape is not None
            for i, e in enumerate(out_shape[2:]):
                new_stride[i] = 1 if e == 1 else stride[i]

            new_args = (
                in_tensor,
                weight,
                bias,
                new_stride,
                padding,
                dilation,
                groups,
                False,
            )

        return super().call_operator(target, new_args, kwargs, meta)


# TODO(matthiascremon): this is a fuse op, not a replace op
class ReplaceConvWithChannelLastConv:
    """
    Convolution op in pytorch expects NCHW layout for input, weight, and output
    tensors. However, if the input and output to the convolution op are originally
    in NHWC layout, and are then permuted to conform to NCHW layout, we can fuse
    the two permute ops with the convolution op, and call the NHWC layout
    convolution op in Jarvis.
    """

    def __init__(self):
        # Used to give unique names to the weight buffers registered here.
        self.counter = 0
        # Set by replace_conv_with_nhwc_conv; read by conv_layout_is_nhwc.
        self.graph_module = None

    def __call__(self, graph_module: torch.fx.GraphModule):
        self.replace_conv_with_nhwc_conv(graph_module)

    def conv_layout_is_nhwc(self, node: torch.fx.Node) -> bool:
        """
        Return true if the convolution input and output are connected to permute
        ops, and the input/output to/from the permute ops is NHWC layout tensor.
        """
        # There must only be a single user of the output node (which must be a
        # permute/transpose op). The input of the convolution must be connected
        # to a permute op, and that permute op should have a single user.
        conv_inp = node.args[0]
        assert isinstance(conv_inp, torch.fx.Node)
        if len(node.users) != 1 or len(conv_inp.users) != 1:
            return False

        # Get the input and output (permute/transpose) nodes of the convolution
        conv_user = list(node.users.keys())[0]
        assert isinstance(conv_user, torch.fx.Node)
        pt_nodes: Set[torch.fx.Node] = {conv_inp, conv_user}

        # Any node in pt_nodes must not be a placeholder or a param.
        if contains_placeholder_or_param(pt_nodes):
            return False

        # Determine if the convolution is 1d or 2d. The output tensor must be
        # 3- or 4-dimensional
        out_shape = get_shape(self.graph_module, node)
        assert out_shape is not None
        out_dims = len(out_shape)
        assert out_dims in {3, 4}, "Jarvis only supports conv1d and conv2d"
        conv1d = out_dims == 3

        # Get the possible targets for the nodes in pt_nodes. Since conv1d has
        # 3-dimensional input and output tensors, the nodes in pt_nodes could
        # be either permute or transpose op. For conv2d, the nodes in pt_nodes
        # must be permute ops.
        p_target = exir_ops.edge.aten.permute_copy.default
        t_target = exir_ops.edge.aten.transpose_copy.int
        pt_targets = [p_target] + ([t_target] if conv1d else [])

        # If any node in pt_nodes is not permute op (or transpose op for conv1d),
        # bail.
        if any(x.target not in pt_targets for x in pt_nodes):
            return False

        # Now we need to determine the dimension permutations:
        # If the input had NHWC layout, which was then permuted/transposed
        # by a permute/transpose op to NCHW layout, the permutation must be
        # [0, 3, 1, 2] (or [0, 2, 1] for conv1d).
        # If the output had NCHW layout, and was then permuted to NHWC layout,
        # the permutation must be [0, 2, 3, 1] (or [0, 2, 1] for conv1d).
        nhwc_permute_order = {
            node.args[0]: [0, 2, 1] if conv1d else [0, 3, 1, 2],
            list(node.users.keys())[0]: [0, 2, 1] if conv1d else [0, 2, 3, 1],
        }
        for x in pt_nodes:
            order = (
                x.args[1]
                if x.target == p_target
                else get_transposed_dims(x, list(range(out_dims)))
            )
            if order != nhwc_permute_order[x]:
                return False

        return True

    def replace_conv_with_nhwc_conv(self, graph_module: torch.fx.GraphModule):
        """
        Fuse, in place, the permute/transpose ops around every convolution of
        `graph_module` that conv_layout_is_nhwc accepts, and mark that
        convolution as channel-last.
        """
        self.graph_module = graph_module
        graph = graph_module.graph
        for node in graph.nodes:
            # We are only interested in convolution nodes that have NHWC layout
            if node.target not in {
                exir_ops.edge.cadence.quantized_conv.default,
                exir_ops.edge.cadence.convolution.default,
                exir_ops.edge.cadence.quantized_transposed_conv.default,
                exir_ops.edge.cadence.transposed_convolution.default,
            } or not self.conv_layout_is_nhwc(node):
                continue

            # Get the args of convolution op
            args = list(node.args)
            # The input is connected to a permute/transpose op that converts the
            # NHWC layout to NCHW layout. The input of the permute op will become
            # this convolution op's input.
            in_tp = args[0]
            args[0] = in_tp.args[0]
            # The weight is in NCHW layout. Permute it to NHWC layout.
            weight_tensor = get_tensor_from_attr(graph_module, args[1])
            assert isinstance(weight_tensor, torch.Tensor)
            # We cannot directly permute a per-channel quantized tensor. We will
            # dequantize it, permute the fp32 tensor, and then requantize the
            # permuted tensor.
            if (
                is_quantized_tensor(weight_tensor)
                and weight_tensor.qscheme() == torch.per_channel_affine
            ):
                # We have already asserted during quantizing conv op that the
                # quantization axis is 0.
                dequant_weight = weight_tensor.dequantize()
                dequant_weight = (
                    dequant_weight.permute([0, 2, 1])
                    if dequant_weight.dim() == 3
                    else dequant_weight.permute([0, 2, 3, 1])
                )
                weight_tensor = torch.quantize_per_channel(
                    dequant_weight.contiguous(),
                    weight_tensor.q_per_channel_scales(),
                    weight_tensor.q_per_channel_zero_points(),
                    0,
                    weight_tensor.dtype,
                )
            else:
                weight_tensor = (
                    weight_tensor.permute([0, 2, 1])
                    if weight_tensor.dim() == 3
                    else weight_tensor.permute([0, 2, 3, 1])
                )
            # Make the weight tensor contiguous, since we have permuted it.
            weight_tensor = weight_tensor.contiguous()
            # Add the permuted weight into the graph, and update the weight in
            # args.
            with graph.inserting_before(node):
                weight_name = f"_weight_nhwc_{self.counter}"
                graph_module.register_buffer(weight_name, weight_tensor)
                weight = graph.get_attr(weight_name)
            args[1] = weight

            # The 'channel_last' arg is True. It is the last arg.
            args[-1] = True
            # Now update the convolution node args to mark it as NHWC convolution
            node.args = tuple(args)

            # Replace all the uses of the permute op connected to the output op
            # with this convolution.
            out_tp = list(node.users.keys())[0]
            out_tp.replace_all_uses_with(node)
            # The convolution now produces what the output permute used to.
            node.meta = out_tp.meta

            # Erase the permute ops connected to the input and output of the
            # convolution op.
            graph.erase_node(in_tp)
            graph.erase_node(out_tp)
            self.counter += 1

        graph_module.recompile()


# This pass needs to be reworked to be compatible with PT2. It is an optimization
# pass anyway, so move it to opt level 2.
# TODO(matthiascremon): update and improve this pass.
@register_cadence_pass(CadencePassAttribute(opt_level=2))
class ReplaceConvWithChannelLastConvPass(ExportPass):
    """
    Two-step pass: first run ReplaceAtenConvolutionWithJarvisConvolutionPass on
    the graph module, then run ReplaceConvWithChannelLastConv on the graph
    module produced by the first step.
    """

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Run both sub-passes and return the PassResult of the first one."""
        conv_result = ReplaceAtenConvolutionWithJarvisConvolutionPass()(graph_module)
        assert conv_result is not None
        # The second pass is applied to the first pass's graph module; its own
        # return value is not used.
        channel_last_pass = ReplaceConvWithChannelLastConv()
        channel_last_pass(conv_result.graph_module)
        return conv_result
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class ReplaceTrivialConvWithLinear(ExportPass):
    """
    Turn a convolution whose kernel spans the entire input into a linear op.

    When groups == 1, padding is all zeros, stride and dilation are all ones,
    and every non-batch dim of the input equals the matching non-leading dim
    of the weight, the convolution is emitted as:
        view_copy(input -> [batch, K]) -> linear -> view_copy(-> output shape)
    with the weight viewed as [out_channels, K], K being the product of the
    weight's non-leading dims.
    """

    trivial_conv_op_to_linear_op: Dict[EdgeOpOverload, EdgeOpOverload] = {
        exir_ops.edge.cadence.convolution.default: exir_ops.edge.aten.linear.default,
        exir_ops.edge.cadence.quantized_conv.default: exir_ops.edge.cadence.quantized_linear.default,
    }

    def call_operator(self, op, args, kwargs, meta):
        if op not in self.trivial_conv_op_to_linear_op:
            return super().call_operator(op, args, kwargs, meta)

        # convolution carries exactly 8 args; quantized_conv carries at least
        # 12 (the first 7 are shared, followed by quantization parameters).
        is_quantized = op == exir_ops.edge.cadence.quantized_conv.default
        arity_ok = len(args) >= 12 if is_quantized else len(args) == 8
        assert arity_ok, "Inconsistent args for convolution"
        in_tensor, weight, bias, stride, padding, dilation, groups = args[0:7]

        # Shapes of input, weight and output.
        in_shape = (
            in_tensor.to_tensor() if isinstance(in_tensor, ProxyValue) else in_tensor
        ).shape
        weight_shape = (
            weight.to_tensor() if isinstance(weight, ProxyValue) else weight
        ).shape
        out_shape = meta["val"].shape
        assert None not in {in_shape, weight_shape, out_shape}

        # The rewrite applies only to a non-grouped convolution with default
        # padding/stride/dilation whose kernel matches the input extent.
        is_trivial = (
            groups == 1
            and all(p == 0 for p in padding)
            and all(s == 1 for s in stride)
            and all(d == 1 for d in dilation)
            and list(in_shape[1:]) == list(weight_shape[1:])
        )
        if not is_trivial:
            return super().call_operator(op, args, kwargs, meta)

        # Flattened size of everything but the leading weight dim.
        K = math.prod(weight_shape[1:])

        if isinstance(weight, ProxyValue):
            # A ProxyValue weight must be reshaped through a graph op so the
            # result is a ProxyValue as well.
            linear_weight = super().call_operator(
                exir_ops.edge.aten.view_copy.default,
                (weight, [weight_shape[0], K]),
                kwargs,
                meta,
            )
        else:
            # A plain tensor can be reshaped directly.
            linear_weight = weight.contiguous().view(weight_shape[0], K)
        # A FakeTensor here did not come from a ProxyValue, so it is treated as
        # a constant and marked accordingly.
        if isinstance(linear_weight, FakeTensor):
            linear_weight.constant = linear_weight

        # Flatten the input to [batch, K].
        flat_input = super().call_operator(
            exir_ops.edge.aten.view_copy.default,
            (in_tensor, [in_shape[0], K]),
            kwargs,
            meta,
        )

        if is_quantized:
            (
                in_zero_point,
                weight_zero_point,
                bias_scale,
                out_scale,
                out_zero_point,
            ) = args[7:12]
            have_requant_tensors = (
                len(args) >= 14
                and isinstance(args[12], ProxyValue)
                and isinstance(args[13], ProxyValue)
            )
            if have_requant_tensors:
                # Reuse the multiplier/shift supplied with the convolution.
                out_multiplier, out_shift = args[12], args[13]
            else:
                # Otherwise derive them from the requantization scale.
                out_multiplier, out_shift = quantize_tensor_multiplier(
                    bias_scale / out_scale
                )
            linear_args = (
                flat_input,
                linear_weight,
                bias,
                in_zero_point,
                weight_zero_point,
                out_multiplier,
                out_shift,
                out_zero_point,
                None,
            )
        else:
            linear_args = (flat_input, linear_weight, bias)

        linear_out = super().call_operator(
            self.trivial_conv_op_to_linear_op[op],
            linear_args,
            kwargs,
            meta,
        )
        # Restore the convolution's output shape.
        return super().call_operator(
            exir_ops.edge.aten.view_copy.default,
            (linear_out, list(out_shape)),
            kwargs,
            meta,
        )
def canonicalize_transposed_dim(dim: int, shape: Sequence[int]) -> int:
    """
    Map a negative dimension index onto its non-negative counterpart.

    A negative `dim` is offset by `len(shape)`; a non-negative `dim` is returned
    unchanged. Keeping transpose dims non-negative makes transpose ops easier to
    pattern-match and fuse.
    """
    return dim + len(shape) if dim < 0 else dim
class ExportPassWithTransposeHelper(ExportPass):
    """ExportPass base class that adds a helper for emitting transpose ops."""

    def transpose_dims(
        self: ExportPass, proxy: ProxyValue, meta: NodeMetadata, dim0: int, dim1: int
    ) -> ProxyValue:
        """
        Emit a transpose_copy of `proxy` over `dim0`/`dim1` using `meta`.

        Both dims are first made non-negative relative to the rank of
        `proxy.data`, then ordered so the smaller index is passed first.
        """
        shape = proxy.data.shape
        low, high = sorted(
            (
                canonicalize_transposed_dim(dim0, shape),
                canonicalize_transposed_dim(dim1, shape),
            )
        )
        return super().call_operator(
            exir_ops.edge.aten.transpose_copy.int, (proxy, low, high), {}, meta
        )
@register_cadence_pass(CadencePassAttribute(opt_level=3))
class ForceChannelLastForConvPass(ExportPassWithTransposeHelper):
    """
    Force cadence convolution / quantized_conv ops into channel-last form.

    A convolution whose channel_last flag is absent or falsy gets its input and
    weight permuted so dim 1 becomes the last dim, is re-emitted with the flag
    set to True, and has its output permuted back.
    """

    def change_nchw_to_nhwc(self, proxy: ProxyValue, meta: NodeMetadata) -> ProxyValue:
        """Move dim 1 of `proxy` to the last position."""
        rank = len(proxy.to_tensor().shape)
        if rank == 3:
            # For 3d tensors a single transpose of dims 1 and -1 suffices.
            return self.transpose_dims(proxy, meta, 1, -1)
        dims = list(range(rank))
        order = [dims[0], *dims[2:], dims[1]]
        return super().call_operator(
            exir_ops.edge.aten.permute_copy.default, (proxy, order), {}, meta
        )

    def change_nhwc_to_nchw(self, proxy: ProxyValue, meta: NodeMetadata) -> ProxyValue:
        """Move the last dim of `proxy` to position 1."""
        rank = len(proxy.to_tensor().shape)
        if rank == 3:
            return self.transpose_dims(proxy, meta, 1, -1)
        dims = list(range(rank))
        order = [dims[0], dims[-1], *dims[1:-1]]
        return super().call_operator(
            exir_ops.edge.aten.permute_copy.default, (proxy, order), {}, meta
        )

    def call_operator(
        self,
        op,
        args: tuple[Argument, ...],
        kwargs: dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        conv_ops = {
            exir_ops.edge.cadence.convolution.default,
            exir_ops.edge.cadence.quantized_conv.default,
        }
        if op not in conv_ops:
            return super().call_operator(op, args, kwargs, meta)

        # The channel_last flag sits at index 14 for quantized_conv and at
        # index 7 for convolution; a missing flag means False (NCHW).
        is_quantized = op == exir_ops.edge.cadence.quantized_conv.default
        flag_index = 14 if is_quantized else 7
        channel_last = args[flag_index] if len(args) > flag_index else False
        if channel_last:
            # Already channel-last; nothing to do.
            return super().call_operator(op, args, kwargs, meta)

        nhwc_input = self.change_nchw_to_nhwc(cast(ProxyValue, args[0]), meta)
        nhwc_weight = self.change_nchw_to_nhwc(cast(ProxyValue, args[1]), meta)

        # NOTE(review): if a quantized_conv arrives with fewer than 14 args, the
        # slice below is shorter than expected and the appended True would not
        # land at index 14 -- confirm callers always pass every arg before the
        # channel_last flag.
        passthrough_args = tuple(args[2:flag_index])  # bias, quant params, etc.
        new_args = (nhwc_input, nhwc_weight) + passthrough_args + (True,)
        nhwc_output = super().call_operator(op, new_args, kwargs, meta)
        # Hand the result back in the layout the consumers expect.
        return self.change_nhwc_to_nchw(nhwc_output, meta)
@register_cadence_pass(CadencePassAttribute(opt_level=3))
class MakeSliceAndCatDimOutermostPass(ExportPassWithTransposeHelper):
    """
    Make cat / slice_copy operate on dim 0.

    When the op's dim is neither 0 nor preceded only by unit-sized dims, the
    inputs are transposed so that dim becomes dim 0, the op is emitted on
    dim 0, and the result is transposed back.
    """

    def call_operator(
        self,
        op,
        args: tuple[Argument, ...],
        kwargs: dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        handled_ops = {
            exir_ops.edge.aten.cat.default,
            exir_ops.edge.aten.slice_copy.Tensor,
        }
        if op not in handled_ops:
            return super().call_operator(op, args, kwargs, meta)

        # A missing dim arg defaults to 0; negative dims are made non-negative
        # relative to the rank of the output.
        output_shape = meta["val"].shape
        dim = cast(int, args[1]) if len(args) > 1 else 0
        if dim < 0:
            dim += len(output_shape)

        already_outermost = dim == 0 or math.prod(output_shape[:dim]) == 1
        if already_outermost:
            # Re-emit the op unchanged apart from the normalized dim.
            return super().call_operator(op, (args[0], dim) + args[2:], kwargs, meta)

        if op == exir_ops.edge.aten.slice_copy.Tensor:
            # Transpose the input, then slice along dim 0.
            transposed_input = self.transpose_dims(
                cast(ProxyValue, args[0]), meta, dim, 0
            )
            result = super().call_operator(
                op, (transposed_input, 0) + args[2:], kwargs, meta
            )
        else:
            # Transpose every cat input, then concatenate along dim 0.
            transposed_inputs = []
            for tensor in cast(list[ProxyValue], args[0]):
                transposed_inputs.append(self.transpose_dims(tensor, meta, dim, 0))
            result = super().call_operator(op, (transposed_inputs, 0), kwargs, meta)

        # Undo the transpose on the result.
        return self.transpose_dims(result, meta, 0, dim)
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class ReplaceConvWithIm2RowAndLinear(ExportPass):
    """
    Replace convolution where groups=1 with im2row followed by a linear op.

    Pointwise convolutions (all kernel dims 1, unit stride/dilation, zero
    padding) are left untouched. Everything else is emitted as
        im2row -> linear -> (transpose_copy if not channel_last) -> view_copy
    """

    # A map from the convolution op to the linear op that it should
    # decompose to.
    conv_op_to_linear_op: Dict[EdgeOpOverload, EdgeOpOverload] = {
        exir_ops.edge.cadence.convolution.default: exir_ops.edge.aten.linear.default,
        exir_ops.edge.cadence.quantized_conv.default: exir_ops.edge.cadence.quantized_linear.default,
    }

    def call_operator(self, op, args, kwargs, meta):
        if op not in self.conv_op_to_linear_op:
            return super().call_operator(op, args, kwargs, meta)

        # Get the relevant args from convolution node.
        quantized_op = op == exir_ops.edge.cadence.quantized_conv.default
        assert (len(args) == 8 and not quantized_op) or (
            len(args) >= 12 and quantized_op
        ), "Inconsistent args for convolution"
        (in_tensor, weight, bias, stride, padding, dilation, groups) = args[0:7]

        # We do not replace depthwise convolution with gemm yet.
        if groups != 1:
            return super().call_operator(op, args, kwargs, meta)

        weight_shape = (
            weight.to_tensor().shape if isinstance(weight, ProxyValue) else weight.shape
        )

        # Determine if the convolution is NCHW or NHWC. For the quantized op
        # the channel_last flag is the 15th argument when present and
        # implicitly False otherwise; for the non-quantized op it is the last
        # argument. This is resolved before the pointwise check because the
        # position of the kernel dims inside the weight depends on it.
        channel_last = (
            (args[14] if len(args) == 15 else False) if quantized_op else args[-1]
        )
        # The weight tensor is [out_channels, in_channels, X] for NCHW layout,
        # and [out_channels, X, in_channels] for NHWC layout. Here, X is the
        # kernel_width for conv1d, and X = kernel_height * kernel_width for
        # conv2d. We extract X as the kernel_size for im2row.
        kernel_size = list(weight_shape[1:-1] if channel_last else weight_shape[2:])

        # If this is a pointwise convolution, im2col will start dominating the
        # runtime. So we call convolution op for this case. The check uses the
        # layout-aware kernel dims so that in_channels is never mistaken for a
        # kernel dim in the channel-last case.
        if (
            all(x == 1 for x in kernel_size)
            and all(x == 1 for x in stride)
            and all(x == 0 for x in padding)
            and all(x == 1 for x in dilation)
        ):
            return super().call_operator(op, args, kwargs, meta)

        # Get the shapes
        out_shape = meta["val"].shape
        assert None not in {weight_shape, out_shape}

        # If the convolution op was quantized, we need the input tensor's
        # zero_point for im2row. Otherwise in_zero_point defaults to a zero
        # tensor.
        in_zero_point = (
            (
                super().call_operator(
                    exir_ops.edge.aten.full.default,
                    (
                        [1],
                        args[7],
                    ),
                    {"dtype": torch.int32},
                    meta,
                )
                if isinstance(in_tensor.to_tensor(), FakeTensor)
                else get_zero_point(in_tensor.to_tensor())
            )
            if quantized_op
            else torch.tensor(0, dtype=torch.int32)
        )
        # im2row expects every kernel parameter to be 2d. So we extend the
        # parameters for conv1d by prepending their default values.
        stride = ([1] + stride) if len(stride) == 1 else stride
        padding = ([0] + padding) if len(padding) == 1 else padding
        dilation = ([1] + dilation) if len(dilation) == 1 else dilation
        kernel_size = ([1] + kernel_size) if len(kernel_size) == 1 else kernel_size
        # Assert that kernel size does not have a 0
        assert 0 not in kernel_size

        # Create an im2row node with the input. This will create a 2d matrix of
        # shape [out_height*out_weight, X*in_channels]. X is as defined in the
        # comment above.
        im2row_args = (
            in_tensor,
            kernel_size,
            dilation,
            padding,
            stride,
            in_zero_point,
            channel_last,
        )
        im2row = super().call_operator(
            exir_ops.edge.cadence.im2row.default,
            im2row_args,
            kwargs,
            meta,
        )

        # Get the product of the non-leading dims of the weight
        K = math.prod(weight_shape[1:])

        # If weight is a ProxyValue, linear_weight needs to be the output of a
        # graph operation (in this case a view_copy op) to be an explicit ProxyValue
        # as well. If not, the view op can be done directly on the tensor.
        linear_weight = (
            super().call_operator(
                exir_ops.edge.aten.view_copy.default,
                (
                    weight,
                    [weight_shape[0], K],
                ),
                kwargs,
                meta,
            )
            if isinstance(weight, ProxyValue)
            else weight.contiguous().view(weight_shape[0], K)
        )
        # From the previous check, if linear_weight is a FakeTensor, it has to be
        # a constant (if not, it would be a ProxyValue). Mark it as such.
        if isinstance(linear_weight, FakeTensor):
            linear_weight.constant = linear_weight

        # Create the linear node, which multiplies the 3d input with 2d weight
        # tensors with bias addition. The outermost dimension of the input is
        # the batch size for linear op.
        if quantized_op:
            (
                in_zero_point,
                weight_zero_point,
                bias_scale,
                out_scale,
                out_zero_point,
            ) = args[7:12]
            # If the multiplier and shift tensors are provided, use them.
            if (
                len(args) >= 14
                and isinstance(args[12], ProxyValue)
                and isinstance(args[13], ProxyValue)
            ):
                out_multiplier = args[12]
                out_shift = args[13]
            # If not, compute them.
            else:
                requantize_scale = bias_scale / out_scale
                (out_multiplier, out_shift) = quantize_tensor_multiplier(
                    requantize_scale
                )
            linear_args = (
                im2row,
                linear_weight,
                bias,
                in_zero_point,
                weight_zero_point,
                out_multiplier,
                out_shift,
                out_zero_point,
                None,
            )
        else:
            linear_args = (im2row, linear_weight, bias)
        linear_res = super().call_operator(
            self.conv_op_to_linear_op[op],
            linear_args,
            kwargs,
            meta,
        )
        # The output of linear is a 3D tensor. However, the output is in NHWC
        # layout by default, because an input vector of size X is multiplied
        # with the weight matrix, i.e., column values are contiguous. If the
        # channel_last is False, we want to transpose this output.
        if not channel_last:
            linear_res = super().call_operator(
                exir_ops.edge.aten.transpose_copy.int,
                (linear_res, 1, 2),
                kwargs,
                meta,
            )
        # And finally, we want to view the 3D output of linear op as 4D tensor
        return super().call_operator(
            exir_ops.edge.aten.view_copy.default,
            (linear_res, list(out_shape)),
            kwargs,
            meta,
        )
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class ReplaceTransposedConvWithLinearPass(ExportPass):
    """
    Decompose a groups=1 transposed convolution into
        transposed_im2row -> linear -> (transpose_copy if not channel_last)
        -> view_copy
    """

    # Which linear op each transposed convolution op decomposes to.
    transposed_conv_op_to_linear_op: Dict[EdgeOpOverload, EdgeOpOverload] = {
        exir_ops.edge.cadence.transposed_convolution.default: exir_ops.edge.aten.linear.default,
        exir_ops.edge.cadence.quantized_transposed_conv.default: exir_ops.edge.cadence.quantized_linear.default,
    }

    def call_operator(self, op, args, kwargs, meta):
        if op not in self.transposed_conv_op_to_linear_op:
            return super().call_operator(op, args, kwargs, meta)

        # quantized_transposed_conv carries 16 args, transposed_convolution 9.
        is_quantized = op == exir_ops.edge.cadence.quantized_transposed_conv.default
        expected_arg_count = 16 if is_quantized else 9
        assert (
            len(args) == expected_arg_count
        ), "Inconsistent args for transposed_convolution"
        (
            in_tensor,
            weight,
            bias,
            stride,
            padding,
            dilation,
            output_padding,
            groups,
        ) = args[0:8]

        # Grouped (depthwise) transposed convolutions are not rewritten.
        if groups != 1:
            return super().call_operator(op, args, kwargs, meta)

        out_shape = meta["val"].shape
        weight_shape = (
            weight.to_tensor() if isinstance(weight, ProxyValue) else weight
        ).shape
        assert None not in {weight_shape, out_shape}

        # The last arg is the channel_last flag (NHWC when truthy).
        channel_last = args[-1]
        # The kernel dims sit between out_channels and in_channels for NHWC
        # weights ([out, X, in]) and after in_channels for NCHW ([out, in, X]).
        if channel_last:
            kernel_size = list(weight_shape[1:-1])
        else:
            kernel_size = list(weight_shape[2:])

        # transposed_im2row needs the input zero point: taken from the input
        # tensor for the quantized op, a zero int32 tensor otherwise.
        if is_quantized:
            in_zero_point = get_zero_point(in_tensor.to_tensor())
        else:
            in_zero_point = torch.tensor(0, dtype=torch.int32)

        # Every kernel parameter must be 2d; 1d parameters are extended by
        # prepending the neutral value.
        if len(stride) == 1:
            stride = [1] + stride
        if len(padding) == 1:
            padding = [0] + padding
        if len(dilation) == 1:
            dilation = [1] + dilation
        if len(output_padding) == 1:
            output_padding = [0] + output_padding
        if len(kernel_size) == 1:
            kernel_size = [1] + kernel_size
        # A zero-sized kernel dim is not valid.
        assert 0 not in kernel_size

        transposed_im2row = super().call_operator(
            exir_ops.edge.cadence.transposed_im2row.default,
            (
                in_tensor,
                kernel_size,
                dilation,
                padding,
                stride,
                output_padding,
                in_zero_point,
                channel_last,
            ),
            kwargs,
            meta,
        )

        # View the weight as [out_channels, K].
        K = math.prod(weight_shape[1:])
        if isinstance(weight, ProxyValue):
            # Keep the reshaped weight a ProxyValue by going through a graph op.
            linear_weight = super().call_operator(
                exir_ops.edge.aten.view_copy.default,
                (weight, [weight_shape[0], K]),
                kwargs,
                meta,
            )
        else:
            linear_weight = weight.contiguous().view(weight_shape[0], K)
        # A FakeTensor here did not come from a ProxyValue, so it is treated as
        # a constant and marked accordingly.
        if isinstance(linear_weight, FakeTensor):
            linear_weight.constant = linear_weight

        if is_quantized:
            (
                in_zero_point,
                weight_zero_point,
                bias_scale,
                out_scale,
                out_zero_point,
            ) = args[8:13]
            out_multiplier, out_shift = quantize_tensor_multiplier(
                bias_scale / out_scale
            )
            linear_args = (
                transposed_im2row,
                linear_weight,
                bias,
                in_zero_point,
                weight_zero_point,
                out_multiplier,
                out_shift,
                out_zero_point,
                None,
            )
        else:
            linear_args = (transposed_im2row, linear_weight, bias)

        linear_out = super().call_operator(
            self.transposed_conv_op_to_linear_op[op],
            linear_args,
            kwargs,
            meta,
        )
        # The linear output has channels innermost; swap dims 1 and 2 when the
        # convolution is not channel-last.
        if not channel_last:
            linear_out = super().call_operator(
                exir_ops.edge.aten.transpose_copy.int,
                (linear_out, 1, 2),
                kwargs,
                meta,
            )
        # View the result with the convolution's output shape.
        return super().call_operator(
            exir_ops.edge.aten.view_copy.default,
            (linear_out, list(out_shape)),
            kwargs,
            meta,
        )
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class ReplaceNopTransposeOrPermuteWithViewPass(ExportPass):
    """
    If the transpose/permute op does not change the byte order (e.g.,
    transpose/permute from Nx1xHxW to NxHx1xW), then it can be replaced
    by view op.
    """

    def call_operator(self, op, args, kwargs, meta):
        # Only proceed for transpose or permute op.
        if op not in {
            exir_ops.edge.aten.transpose_copy.int,
            exir_ops.edge.aten.permute_copy.default,
        }:
            return super().call_operator(op, args, kwargs, meta)

        # Get the input tensor and shape
        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
        in_shape = in_tensor.shape
        # Get the output tensor shape
        out_shape = meta["val"].shape

        if op == exir_ops.edge.aten.transpose_copy.int:
            # Get the two dims to be transposed
            dim0 = args[1] if args[1] >= 0 else in_tensor.dim() + args[1]
            dim1 = args[2] if args[2] >= 0 else in_tensor.dim() + args[2]
            # We can eliminate transpose if (a) the size at dim0 and dim1 is 1;
            # (b) the size at dim0 or dim1 is 1, and dim0 and dim1 are consecutive.
            both_one = in_shape[dim0] == 1 and in_shape[dim1] == 1
            either_one_and_consecutive = abs(dim0 - dim1) == 1 and (
                in_shape[dim0] == 1 or in_shape[dim1] == 1
            )
            if both_one or either_one_and_consecutive:
                new_args = (args[0], list(out_shape))
                return super().call_operator(
                    exir_ops.edge.aten.view_copy.default, new_args, kwargs, meta
                )

        elif op == exir_ops.edge.aten.permute_copy.default:
            rank = in_tensor.dim()
            old_dims = list(range(rank))
            # Normalize negative indices (as the transpose branch does) so that
            # e.g. [0, -2, -1] on a 3d tensor is recognized as the identity.
            new_dims = [dim if dim >= 0 else dim + rank for dim in args[1]]
            # If the permute does not change anything, return the input as output.
            if old_dims == new_dims:
                return args[0]
            # Get the old dim order, and the permuted dim order for all dims that
            # are not 1.
            old_order = [
                dim for dim, shape_dim in zip(old_dims, in_shape) if shape_dim != 1
            ]
            new_order = [
                dim for dim, shape_dim in zip(new_dims, out_shape) if shape_dim != 1
            ]
            # If the byte ordering for non-unit dims is unchanged, this is a nop.
            if old_order == new_order:
                new_args = (args[0], list(out_shape))
                return super().call_operator(
                    exir_ops.edge.aten.view_copy.default, new_args, kwargs, meta
                )

        return super().call_operator(op, args, kwargs, meta)

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Run this pass, then FuseCascadedViewOps on the resulting graph."""
        result = super().call(graph_module)
        result = FuseCascadedViewOps()(result.graph_module)
        assert result is not None
        return result
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class ReplaceLinearWithFullyConnectedOpPass(ExportPass):
    """
    Swap linear / quantized_linear for the matching cadence fully_connected op
    when the input is effectively a vector, i.e. the product of all of its dims
    except the last is 1.
    """

    linear_to_fc_op: Dict[EdgeOpOverload, EdgeOpOverload] = {
        exir_ops.edge.aten.linear.default: exir_ops.edge.cadence.fully_connected.default,
        exir_ops.edge.cadence.quantized_linear.default: exir_ops.edge.cadence.quantized_fully_connected.default,
    }

    def call_operator(self, op, args, kwargs, meta):
        # Anything other than linear / quantized_linear passes through.
        if op not in self.linear_to_fc_op:
            return super().call_operator(op, args, kwargs, meta)

        def as_tensor(value):
            # Unwrap a ProxyValue; other values are used as-is.
            return value.to_tensor() if isinstance(value, ProxyValue) else value

        # The input must collapse to a single row.
        in_tensor = as_tensor(args[0])
        if math.prod(in_tensor.shape[:-1]) != 1:
            return super().call_operator(op, args, kwargs, meta)

        if op == exir_ops.edge.cadence.quantized_linear.default:
            # Intended to skip per-channel quantized linear ops.
            # NOTE(review): `weight.shape` is presumably a torch.Size (a tuple
            # subclass), which never compares equal to the list `[1]`, so this
            # condition looks like it always holds and quantized_linear is
            # never converted -- confirm the intended tensor and comparison.
            weight = as_tensor(args[1])
            if weight.shape != [1]:
                return super().call_operator(op, args, kwargs, meta)

        # Same args, fully_connected target.
        return super().call_operator(self.linear_to_fc_op[op], args, kwargs, meta)
@register_cadence_pass(CadencePassAttribute(opt_level=0))
class ReplaceScalarWithTensorArgPass(ExportPass):
    """
    Rewrite add.Scalar, sub.Scalar, mul.Scalar and div.Scalar as their .Tensor
    overloads, materializing the scalar operand as a one-element aten.full
    tensor.
    """

    scalar_to_tensor_ops: Dict[EdgeOpOverload, EdgeOpOverload] = {
        exir_ops.edge.aten.add.Scalar: exir_ops.edge.aten.add.Tensor,
        exir_ops.edge.aten.sub.Scalar: exir_ops.edge.aten.sub.Tensor,
        exir_ops.edge.aten.mul.Scalar: exir_ops.edge.aten.mul.Tensor,
        exir_ops.edge.aten.div.Scalar: exir_ops.edge.aten.div.Tensor,
    }

    def get_replacement(self, op, args, kwargs, meta):
        """Emit the .Tensor overload of `op` with args[1] turned into a tensor."""
        tensor_op = self.scalar_to_tensor_ops[op]
        tensor_arg = args[0]
        # The scalar becomes a shape-(1,) full tensor carrying the dtype of the
        # tensor operand.
        scalar_as_tensor = super().call_operator(
            exir_ops.edge.aten.full.default,
            ((1,), args[1]),
            {"dtype": tensor_arg.to_tensor().dtype},
            meta,
        )
        # Any remaining positional args are forwarded untouched.
        return super().call_operator(
            tensor_op,
            (tensor_arg, scalar_as_tensor, *args[2:]),
            kwargs,
            meta,
        )

    def call_operator(self, op, args, kwargs, meta):
        if op not in self.scalar_to_tensor_ops:
            return super().call_operator(op, args, kwargs, meta)

        # Two args in general; add and sub may carry a third (alpha).
        assert len(args) == 2 or len(args) == 3

        if len(args) == 3:
            # A third arg is only accepted for scalar add/sub without an
            # 'alpha' keyword; otherwise the op is left unchanged.
            ops_with_alpha = {
                exir_ops.edge.aten.add.Scalar,
                exir_ops.edge.aten.sub.Scalar,
            }
            if op not in ops_with_alpha or "alpha" in kwargs:
                return super().call_operator(op, args, kwargs, meta)

        return self.get_replacement(op, args, kwargs, meta)
@register_cadence_pass(CadencePassAttribute(opt_level=0))
class ReplaceScalarTensorWithFullPass(ExportPass):
    """
    aten.scalar_tensor can be replaced by aten.full with a shape of [1].
    scalar_tensor is not supported, so this is an opt_level=0 pass.

    A `dtype` keyword supplied to scalar_tensor is forwarded to the full op;
    when it is absent or None the result is float32.
    """

    def call_operator(
        self,
        op,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        if op not in {
            exir_ops.edge.aten.scalar_tensor.default,
            torch.ops.aten.scalar_tensor.default,
        }:
            return super().call_operator(op, args, kwargs, meta)

        # Honor an explicitly requested dtype instead of silently forcing
        # float32, so the replacement keeps the dtype of the original node.
        dtype = kwargs.get("dtype")
        if dtype is None:
            dtype = torch.float32

        return super().call_operator(
            exir_ops.edge.aten.full.default,
            (
                [1],
                args[0],
            ),
            {"dtype": dtype},
            meta,
        )
@register_cadence_pass(CadencePassAttribute(opt_level=0))
class ReplaceFullLikeWithFullPass(ExportPass):
    """
    aten.full_like can be replaced by aten.full with the shape of the arg tensor.
    full_like is not supported, so this is an opt_level=0 pass.

    The dtype of the result is preserved: an explicit `dtype` keyword on
    full_like wins, otherwise the dtype of the "like" tensor is used.
    """

    def call_operator(self, op, args, kwargs, meta):
        if op not in {
            exir_ops.edge.aten.full_like.default,
        }:
            return super().call_operator(op, args, kwargs, meta)

        like_tensor = (
            args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
        )
        # aten.full with no dtype infers one from the fill value, whereas
        # full_like defaults to the dtype of its input. Pass the dtype
        # explicitly so the replacement does not change the result dtype.
        dtype = kwargs.get("dtype")
        if dtype is None:
            dtype = like_tensor.dtype

        # Get the shape of the "like" tensor, and pass that in to the full op.
        return super().call_operator(
            exir_ops.edge.aten.full.default,
            (
                like_tensor.shape,
                args[1],
            ),
            {"dtype": dtype},
            meta,
        )
@register_cadence_pass(CadencePassAttribute(opt_level=0))
class ReplaceInfArgInFullWithValuePass(ExportPass):
    """
    Replace an infinite fill value of aten.full with a finite one: "-inf"
    becomes the float32 minimum and "inf" the float32 maximum. The profiler
    cannot handle infinite fill values.
    """

    def call_operator(self, op, args, kwargs, meta):
        if op not in {
            exir_ops.edge.aten.full.default,
        }:
            return super().call_operator(op, args, kwargs, meta)

        float32_limits = torch.finfo(torch.float32)
        fill_value = args[1]
        if fill_value == float("-inf"):
            fill_value = float32_limits.min
        elif fill_value == float("inf"):
            fill_value = float32_limits.max

        # Only the fill value (arg 1) may differ from the original args.
        return super().call_operator(
            op, (args[0], fill_value, *args[2:]), kwargs, meta
        )


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class ReplaceAtenLinalgVectorNormWithCadenceLinalgVectorNormPass(ExportPass):
    """
    Replace aten.linalg_vector_norm with the cadence.linalg_vector_norm custom
    op. aten.linalg_vector_norm is not supported by Jarvis, so the replacement
    is done at all optimization levels. Only the single-argument form (just
    the tensor) is accepted.
    """

    def call_operator(self, op, args, kwargs, meta):
        if op != exir_ops.edge.aten.linalg_vector_norm.default:
            return super().call_operator(op, args, kwargs, meta)

        only_tensor_arg = len(args) == 1
        assert (
            only_tensor_arg
        ), "aten.linalg_vector_norm should have 1 argument (a tensor), we do not support any custom variants"

        # Same args and kwargs, custom target.
        custom_op = exir_ops.edge.cadence.linalg_vector_norm.default
        return super().call_operator(custom_op, args, kwargs, meta)
1889 replaced_scalar_args: dict[ 1890 EdgeOpOverloadPacket, tuple[EdgeOpOverload, Sequence[int]] 1891 ] = { 1892 exir_ops.edge.cadence.quantized_conv: ( 1893 exir_ops.edge.cadence.quantized_conv.per_tensor, 1894 [8, 9, 12, 13], 1895 ), 1896 exir_ops.edge.cadence.quantized_layer_norm: ( 1897 exir_ops.edge.cadence.quantized_layer_norm.per_tensor, 1898 [1, 2], 1899 ), 1900 exir_ops.edge.cadence.quantized_linear: ( 1901 exir_ops.edge.cadence.quantized_linear.per_tensor, 1902 [4, 5, 6], 1903 ), 1904 exir_ops.edge.cadence.quantized_relu: ( 1905 exir_ops.edge.cadence.quantized_relu.per_tensor, 1906 [1, 3, 4], 1907 ), 1908 } 1909 1910 def call_operator(self, op, args, kwargs, meta): 1911 op_edge_overload_packet = get_edge_overload_packet(op) 1912 1913 if op_edge_overload_packet not in self.replaced_scalar_args: 1914 return super().call_operator(op, args, kwargs, meta) 1915 1916 # Get all the args that need to be replaced. 1917 new_op, args_to_be_replaced = self.replaced_scalar_args[op_edge_overload_packet] 1918 1919 updated_args = list(args) 1920 for op_arg_index in args_to_be_replaced: 1921 arg = args[op_arg_index] 1922 if not isinstance(arg, ProxyValue): 1923 return super().call_operator(op, args, kwargs, meta) 1924 1925 if not arg.is_tensor(): 1926 return super().call_operator(op, args, kwargs, meta) 1927 1928 if get_edge_overload_packet(arg.node.target) != exir_ops.edge.aten.full: 1929 # Only replace if arg generated by a full op. 1930 return super().call_operator(op, args, kwargs, meta) 1931 1932 if tuple(arg.node.args[0]) != (1,): 1933 # Only replace if the size of the full op is [1]. 
1934 return super().call_operator(op, args, kwargs, meta) 1935 1936 updated_args[op_arg_index] = arg.node.args[1] 1937 1938 return super().call_operator( 1939 new_op, 1940 tuple(updated_args), 1941 kwargs, 1942 meta, 1943 ) 1944 1945 1946@register_cadence_pass(CadencePassAttribute(opt_level=0)) 1947class ReplaceAtenAvgPoolWithJarvisAvgPoolPass(ExportPass): 1948 """ 1949 Replace the aten avg_pool op with the jarvis custom avg_pool2d op. 1950 """ 1951 1952 def call_operator(self, op, args, kwargs, meta): 1953 # Only continue for avg_pool op 1954 if op not in { 1955 exir_ops.edge.aten.avg_pool1d.default, 1956 exir_ops.edge.aten.avg_pool2d.default, 1957 }: 1958 return super().call_operator(op, args, kwargs, meta) 1959 1960 # Determine if the op is avg_pool1d or avg_pool2d 1961 avg_pool1d: bool = op == exir_ops.edge.aten.avg_pool1d.default 1962 # Get the input tensor 1963 in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0] 1964 1965 # Replace avg_pool2d with custom avg_pool2d, and if the input tensor is 1966 # quantized, pass its zero_point tensor as arg to the custom avg_pool2d. 1967 # stride, padding, ceil_mode, count_include_pad, divisor_override, are 1968 # the native avg_pool2d args. 'channel_last' denotes NCHW vs NHWC layout, 1969 # and is False by default. 1970 kernel_size = args[1] 1971 stride = args[2] if len(args) >= 3 else [1, 1] 1972 padding = args[3] if len(args) >= 4 else [0, 0] 1973 ceil_mode = args[4] if len(args) >= 5 else False 1974 count_include_pad = args[5] if len(args) >= 6 else True 1975 divisor_override = args[6] if len(args) >= 7 else None 1976 zero_point = torch.tensor(0, dtype=torch.int32) 1977 1978 # If the op is avg_pool1d, then we need to reshape the 3d input to a 4d 1979 # tensor. 
1980 if avg_pool1d: 1981 in_shape = list(in_tensor.shape) 1982 assert len(in_shape) == 3, "Expected 3d input for avg_pool1d" 1983 in_shape.insert(2, 1) 1984 out_shape = meta["val"].shape 1985 in_view_op = super().call_operator( 1986 exir_ops.edge.aten.view_copy.default, 1987 (in_tensor, in_shape), 1988 kwargs, 1989 meta, 1990 ) 1991 # Extend the kernel_size, stride and padding to 2d 1992 kernel_size = [1] + kernel_size if len(kernel_size) == 1 else kernel_size 1993 stride = [1] + stride if len(stride) == 1 else stride 1994 padding = [0] + padding if len(padding) == 1 else padding 1995 1996 # Create a new avg_pool node with the updated args 1997 new_args = ( 1998 in_view_op if avg_pool1d else args[0], 1999 kernel_size, 2000 stride, 2001 padding, 2002 ceil_mode, 2003 count_include_pad, 2004 divisor_override, 2005 zero_point, 2006 False, 2007 ) 2008 avg_pool2d_op = super().call_operator( 2009 exir_ops.edge.cadence.avg_pool2d.default, 2010 new_args, 2011 kwargs, 2012 meta, 2013 ) 2014 2015 # If the node was avg_pool1d, we again reshape the 4d output to 3d output 2016 return ( 2017 super().call_operator( 2018 exir_ops.edge.aten.view_copy.default, 2019 (avg_pool2d_op, list(out_shape)), 2020 kwargs, 2021 meta, 2022 ) 2023 if avg_pool1d 2024 else avg_pool2d_op 2025 ) 2026 2027 2028@register_cadence_pass(CadencePassAttribute(opt_level=1)) 2029class ReplaceIm2RowWithViewPass(ExportPass): 2030 def can_replace(self, op, args, kwargs, meta) -> bool: 2031 if op != exir_ops.edge.cadence.im2row.default: 2032 return False 2033 2034 # Check if im2row applies padding. If yes, we cannot replace it with view. 2035 pad = cast(tuple[int, ...], args[3]) 2036 if any(p != 0 for p in pad): 2037 return False 2038 2039 # Check if im2row has dilation. If yes, we cannot replace it with view. 2040 dilation = cast(tuple[int, ...], args[2]) 2041 if any(d != 1 for d in dilation): 2042 return False 2043 2044 # im2row works on 3D or 4D tensors. 
2045 # Output shape[1:-1] will be unit if input spatial dimensions are the same as kernel spatial dimensions. 2046 output_shape = meta["val"].shape 2047 if math.prod(output_shape[1:-1]) == 1: 2048 return True 2049 2050 return False 2051 2052 def call_operator( 2053 self, 2054 op, 2055 args: tuple[Argument, ...], 2056 kwargs: dict[str, Argument], 2057 meta: NodeMetadata, 2058 ) -> ProxyValue: 2059 if op != exir_ops.edge.cadence.im2row.default: 2060 return super().call_operator(op, args, kwargs, meta) 2061 2062 if not self.can_replace(op, args, kwargs, meta): 2063 return super().call_operator(op, args, kwargs, meta) 2064 2065 output_shape = meta["val"].shape 2066 return super().call_operator( 2067 exir_ops.edge.aten.view_copy.default, 2068 (args[0], tuple(output_shape)), 2069 kwargs, 2070 meta, 2071 ) 2072 2073 2074# This class encapsulates all the functions that replace/switch one op in the 2075# graph with another. 2076class CadenceReplaceOpsInGraph: 2077 passes = [ 2078 ReplaceFunctionallyEquivalentOpTargets, 2079 ReplaceTCopyWithTransposePass, 2080 ReplacePermuteWithTransposePass, 2081 ReplaceScalarWithTensorArgPass, 2082 ReplaceConvolutionOptionalArgsWithConcreteArgsPass, 2083 ReplaceMMWithAddMMPass, 2084 ReplaceSqueezeAndUnsqueezeWithViewPass, 2085 ReplaceAddMMWithLinearPass, 2086 RemoveNopSelectOpPass, 2087 ReplaceSelectWithViewOpPass, 2088 ReplaceRepeatWithCatPass, 2089 ReplacePadWithCatPass, 2090 ReplaceConstantPadNdWithSlicePass, 2091 ReplaceConvWithChannelLastConvPass, 2092 ReplaceAtenConvolutionWithJarvisConvolutionPass, 2093 ForceChannelLastForConvPass, 2094 ReplaceTrivialConvWithLinear, 2095 ReplaceConvWithIm2RowAndLinear, 2096 ReplaceTransposedConvWithLinearPass, 2097 # This pass should be after passes that replace conv -> im2row + linear. 
2098 ReplaceIm2RowWithViewPass, 2099 MakeSliceAndCatDimOutermostPass, 2100 ReplaceNopTransposeOrPermuteWithViewPass, 2101 ReplaceLinearWithFullyConnectedOpPass, 2102 ReplaceScalarTensorWithFullPass, 2103 ReplaceFullLikeWithFullPass, 2104 ReplaceInfArgInFullWithValuePass, 2105 ReplaceLogicalNotBooleanWhereWithWherePass, 2106 ReplacePT2QuantWithCadenceQuantPass, 2107 ReplacePT2DequantWithCadenceDequantPass, 2108 ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass, 2109 ReplaceAtenAvgPoolWithJarvisAvgPoolPass, 2110 ReplaceAtenLinalgVectorNormWithCadenceLinalgVectorNormPass, 2111 ] 2112