xref: /aosp_15_r20/external/executorch/backends/cadence/aot/replace_ops.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2
3
4# This file contains all the functions that replace one op with another in the
5# graph. The functions replacing ops for models deployed with Jarvis are grouped
6# together in class 'ReplaceOpsInGraph'. Some examples of functions in the class are
7# 1. functions that replace an ATen op with a custom op that accepts extra arguments.
8# 2. functions that replace in-place variants of ATen ops with out-of-place versions.
9# 3. functions that replace an ATen op with another semantically equivalent ATen op.
10# 4. functions that concretize optional args.
11
12import math
13from operator import neg
14from typing import cast, Dict, Iterable, Sequence, Set, Tuple
15
16import torch
17import torch.fx
18from executorch.backends.cadence.aot.compiler_utils import (
19    get_shape,
20    get_tensor_from_attr,
21    get_transposed_dims,
22    get_zero_point,
23    is_node_with_op,
24    is_quantized_tensor,
25    quantize_tensor_multiplier,
26)
27from executorch.backends.cadence.aot.fuse_ops import FuseCascadedViewOps
28from executorch.backends.cadence.aot.pass_utils import (
29    CadencePassAttribute,
30    register_cadence_pass,
31)
32from executorch.backends.cadence.aot.remove_ops import RemoveNopSelectOpPass
33from executorch.backends.cadence.aot.utils import get_edge_overload_packet
34from executorch.exir.dialects._ops import ops as exir_ops
35from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
36from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
37from torch._subclasses import FakeTensor
38from torch.fx.node import Argument
39
40# A map to represent ops that:
41# (a) are functionally equivalent wrt. Jarvis; and
42# (b) have identical arguments
43# An op whose target is 'key' in this dict can be replaced by the functionally equivalent
44# op whose target is 'value'. The replacement would just involve changing the op target.
45functionally_equivalent_op_targets: Dict[EdgeOpOverload, EdgeOpOverload] = {
46    exir_ops.edge.aten.relu_.default: exir_ops.edge.aten.relu.default,
47    exir_ops.edge.aten.unsafe_split.Tensor: exir_ops.edge.aten.split_copy.Tensor,
48}
49
50
51def contains_placeholder_or_param(nodes: Iterable[torch.fx.Node]) -> bool:
52    """
53    Return true if any node in the incoming nodes list is a placeholder
54    or parameter.
55    """
56    return any(
57        is_node_with_op(node, "placeholder") or is_node_with_op(node, "get_attr")
58        for node in nodes
59    )
60
61
62@register_cadence_pass(CadencePassAttribute(opt_level=0))
63class ReplaceLogicalNotBooleanWhereWithWherePass(ExportPass):
64    """
65    A where op with a logical_not and a boolean tensor can be replaced
66    by a where op with flipped inputs and the initial boolean tensor.
67    """
68
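    # Illustrative rewrite performed by this pass (assuming `cond` is a boolean
    # tensor): where(logical_not(cond), x, y)  ->  where(cond, y, x)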
69    def replace_logical_not_where_with_where(
70        self, graph_module: torch.fx.GraphModule
71    ) -> None:
72        graph = graph_module.graph
73        for node in graph.nodes:
74            # We are only interested in where nodes
75            if node.target != exir_ops.edge.aten.where.self:
76                continue
77
78            # If the first arg (the condition) is not a logical_not, bail.
79            if node.args[0].target != exir_ops.edge.aten.logical_not.default:
80                continue
81
82            # Get the logical_not node and its input
83            logical_not_node = node.args[0]
84            logical_not_input_tensor = (
85                logical_not_node.args[0].to_tensor()
86                if isinstance(logical_not_node.args[0], ProxyValue)
87                else logical_not_node.args[0]
88            )
89
90            # If the logical_not input is not a boolean tensor, bail.
91            if logical_not_input_tensor.meta["spec"].dtype != torch.bool:
92                continue
93
94            # Replace the where op with another one, flipping the inputs and using the boolean
95            # tensor from logical_not.
96            with graph.inserting_before(node):
97                new_where_node = graph.call_function(
98                    exir_ops.edge.aten.where.self,
99                    args=(logical_not_node.args[0], node.args[2], node.args[1]),
100                )
101            # Replace all the uses of the original where op
102            node.replace_all_uses_with(new_where_node)
103
104        graph_module.recompile()
105        graph_module.graph.eliminate_dead_code()
106
107    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
108        self.replace_logical_not_where_with_where(graph_module)
109        result = super().call(graph_module)
110        return result
111
112
113@register_cadence_pass(CadencePassAttribute(opt_level=0))
114class ReplaceSafeSoftmaxWithSoftmax(ExportPass):  # keep
115    """
116    Replace _safe_softmax with _softmax
117    """
118
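    # Illustrative rewrite: _safe_softmax(x, dim) -> _softmax(x, dim, False),
    # where the appended False is the half_to_float argument of _softmax.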
119    def call_operator(
120        self,
121        op,
122        args: tuple[Argument, ...],
123        kwargs: dict[str, Argument],
124        meta: NodeMetadata,
125    ) -> ProxyValue:
126        if op != torch.ops.aten._safe_softmax.default:
127            return super().call_operator(op, args, kwargs, meta)
128
129        # Add False for the half_to_float argument of softmax
130        softmax_args = list(args) + [False]
131
132        return super().call_operator(
133            torch.ops.aten._softmax.default,
134            tuple(softmax_args),
135            kwargs,
136            meta,
137        )
138
139
140@register_cadence_pass(CadencePassAttribute(opt_level=0))
141class ReplacePT2QuantWithCadenceQuantPass(ExportPass):
142    """
143    Replace the pt2 quantization ops with cadence quantization ops.
144    We do not link kernels to the PT2 quantization ops, so we need to
145    replace them with cadence ops at all optimization levels.
146    """
147
148    def call_operator(
149        self,
150        op,
151        args: Tuple[Argument, ...],
152        kwargs: Dict[str, Argument],
153        meta: NodeMetadata,
154    ) -> ProxyValue:
155        if op not in {exir_ops.edge.quantized_decomposed.quantize_per_tensor.default}:
156            return super().call_operator(op, args, kwargs, meta)
157
158        return super().call_operator(
159            exir_ops.edge.cadence.quantize_per_tensor.default,
160            args,
161            kwargs,
162            meta,
163        )
164
165
166@register_cadence_pass(CadencePassAttribute(opt_level=0))
167class ReplacePT2DequantWithCadenceDequantPass(ExportPass):
168    """
169    Replace the pt2 dequantization ops with cadence dequantization ops.
170    We do not link kernels to the PT2 quantization ops, so we need to
171    replace them with cadence ops at all optimization levels.
172    """
173
174    def call_operator(
175        self,
176        op,
177        args: Tuple[Argument, ...],
178        kwargs: Dict[str, Argument],
179        meta: NodeMetadata,
180    ) -> ProxyValue:
181        if op not in {exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default}:
182            return super().call_operator(op, args, kwargs, meta)
183
184        return super().call_operator(
185            exir_ops.edge.cadence.dequantize_per_tensor.default,
186            args,
187            kwargs,
188            meta,
189        )
190
191
192@register_cadence_pass(CadencePassAttribute(opt_level=0))
193class ReplaceSqueezeAndUnsqueezeWithViewPass(ExportPass):
194    """
195    When the shape is static, replace squeeze_copy and unsqueeze_copy ops with
196    view_copy op
197    """
198
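    # Illustrative example: squeeze_copy on a [1, 4, 1, 8] input whose static
    # output shape is [4, 8] becomes view_copy(input, [4, 8]); unsqueeze_copy is
    # handled the same way using its static output shape.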
199    def call_operator(
200        self,
201        op,
202        args: Tuple[Argument, ...],
203        kwargs: Dict[str, Argument],
204        meta: NodeMetadata,
205    ) -> ProxyValue:
206        # Instead of testing EdgeOpOverload, test EdgeOpOverloadPacket,
207        # which allows us to cover all overloads.
208        if get_edge_overload_packet(op) not in {
209            exir_ops.edge.aten.squeeze_copy,
210            exir_ops.edge.aten.unsqueeze_copy,
211        }:
212            return super().call_operator(op, args, kwargs, meta)
213        # Get the output tensor shape
214        out_shape = meta["val"].shape
215
216        # Bail out if any dim is not an int (dynamic shape)
217        for dim in list(out_shape):
218            if not isinstance(dim, int):
219                return super().call_operator(op, args, kwargs, meta)
220
221        # Return a view op with the new shape
222        view_args = (args[0], list(out_shape))
223        return super().call_operator(
224            exir_ops.edge.aten.view_copy.default, view_args, kwargs, meta
225        )
226
227
228@register_cadence_pass(CadencePassAttribute(opt_level=0))
229class ReplaceFunctionallyEquivalentOpTargets(ExportPass):
230    """
231    Replace an op with a functionally equivalent op by just switching the op
232    target, but without incurring any change to the op args.
233    """
234
235    def call_operator(self, op, args, kwargs, meta):
236        if op not in functionally_equivalent_op_targets:
237            return super().call_operator(op, args, kwargs, meta)
238        return super().call_operator(
239            functionally_equivalent_op_targets[op], args, kwargs, meta
240        )
241
242
243@register_cadence_pass(CadencePassAttribute(opt_level=1))
244class ReplaceSelectWithViewOpPass(ExportPass):
245    """
246    If the size along the select dim is 1, then the select op can be replaced
247    by a view op.
248    """
249
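    # Illustrative example: select_copy(x, dim=1, index=0) on a [2, 1, 8] input
    # (size 1 along the select dim) becomes view_copy(x, [2, 8]).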
250    def call_operator(self, op, args, kwargs, meta):
251        if op != exir_ops.edge.aten.select_copy.int:
252            return super().call_operator(op, args, kwargs, meta)
253
254        # Glean the shape of input and output tensor
255        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
256        in_shape = in_tensor.shape
257        out_shape = meta["val"].shape
258        # Get the select dimension
259        select_dim = args[1] if args[1] >= 0 else args[1] + len(in_shape)
260
261        if in_shape[select_dim] == 1:
262            # Return a view op with the new shape
263            view_args = (args[0], list(out_shape))
264            return super().call_operator(
265                exir_ops.edge.aten.view_copy.default, view_args, kwargs, meta
266            )
267        return super().call_operator(op, args, kwargs, meta)
268
269
270@register_cadence_pass(CadencePassAttribute(opt_level=0))
271class ReplaceTCopyWithTransposePass(ExportPass):
272    """
273    Replace t_copy with transpose_copy.int. If the input is 1D, the t_copy is
274    a nop. t_copy is not supported, so this is an opt_level=0 pass.
275    """
276
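    # Illustrative rewrite: t_copy(x) on a 2d input becomes transpose_copy(x, 0, 1);
    # for a 0d/1d input the op is a nop and the input is returned unchanged.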
277    def call_operator(self, op, args, kwargs, meta):
278        if get_edge_overload_packet(op) != exir_ops.edge.aten.t_copy:
279            return super().call_operator(op, args, kwargs, meta)
280
281        # Get the input tensor shape
282        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
283
284        # If the input is a 1D tensor, this t_copy is a nop, so return the input
285        if in_tensor.dim() <= 1:
286            return args[0]
287
288        assert in_tensor.dim() == 2, "t_copy expects a tensor with <= 2 dimensions"
289        transpose_args = (args[0], 0, 1)
290        return super().call_operator(
291            exir_ops.edge.aten.transpose_copy.int, transpose_args, kwargs, meta
292        )
293
294
295@register_cadence_pass(CadencePassAttribute(opt_level=0))
296class ReplaceMMWithAddMMPass(ExportPass):
297    """
298    This pass replaces mm with addmm by introducing a zero bias.
299    mm is not supported, so this is an opt_level=0 pass.
300    """
301
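    # Illustrative rewrite: mm(X, mat2) -> addmm(zero_bias, X, mat2), where
    # zero_bias is a [mat2.size(1)]-shaped zero tensor materialized in the graph
    # with a full.default op.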
302    def call_operator(self, op, args, kwargs, meta):
303        if op != exir_ops.edge.aten.mm.default:
304            return super().call_operator(op, args, kwargs, meta)
305
306        # The mm op has two args: input, mat2
307        assert len(args) == 2
308        X, mat2 = args
309
310        # Create a zero bias tensor, and insert it as a graph buffer before the
311        # current node
312        mat2_tensor = mat2.to_tensor() if isinstance(mat2, ProxyValue) else mat2
313        bias_size = mat2_tensor.size(1)
314        zero_bias = super().call_operator(
315            exir_ops.edge.aten.full.default,
316            ([bias_size], 0.0),
317            {"dtype": torch.float32},
318            meta,
319        )
320
321        # Replace mm with addmm
322        new_args = (zero_bias, X, mat2)
323        return super().call_operator(
324            exir_ops.edge.aten.addmm.default, new_args, kwargs, meta
325        )
326
327
328@register_cadence_pass(CadencePassAttribute(opt_level=1))
329class ReplaceAddMMWithLinearPass(ExportPass):
330    """
331    This pass replaces addmm with linear op.
332    """
333
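    # Illustrative rewrite (when bias and mat2 are graph params):
    #   addmm(bias, mat1, mat2, beta=b, alpha=a) == b * bias + a * (mat1 @ mat2)
    #   -> linear(mat1, (a * mat2).t(), b * bias)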
334    def __init__(self):
335        super().__init__()
336        self.counter = 0
337
338    def replace_addmm_with_linear(self, graph_module: torch.fx.GraphModule):
339        graph = graph_module.graph
340        for node in graph.nodes:
341            # We are only interested in addmm nodes
342            if node.target != exir_ops.edge.aten.addmm.default:
343                continue
344
345            # The addmm op has three concrete args: input, mat1, mat2
346            assert len(node.args) >= 3
347            (bias, mat1, mat2) = node.args[0:3]
348            # The other two args are optional scale args
349            beta = node.kwargs.get("beta", 1.0)
350            alpha = node.kwargs.get("alpha", 1.0)
351
352            # AddMM performs beta*bias + alpha*mm(mat1, mat2). We can convert
353            # it to a linear op by folding beta into bias and alpha into mat2.t().
354            # However, the following two conditions must hold:
355            # a. If bias is not a param, then beta must be 1.0
356            # b. If mat2 is not a param, then mat2 must be a transpose op. Also,
357            # the input to the transpose must be a param, or alpha must be 1.0.
358            fit_bias = is_node_with_op(bias, "get_attr") or beta == 1.0
359            fit_mat2 = is_node_with_op(mat2, "get_attr")
360            transposed_mat2 = False
361            if (
362                not fit_mat2
363                and is_node_with_op(mat2, "call_function")
364                and mat2.target == exir_ops.edge.aten.transpose_copy.int
365            ):
366                mat2, transposed_mat2 = mat2.args[0], True
367                fit_mat2 = is_node_with_op(mat2, "get_attr") or alpha == 1.0
368
369            if not fit_bias or not fit_mat2:
370                continue
371
372            # Multiply bias by beta
373            if beta != 1.0:
374                assert is_node_with_op(bias, "get_attr")
375                bias_tensor = get_tensor_from_attr(graph_module, bias)
376                assert isinstance(bias_tensor, torch.Tensor)
377                bias_tensor = beta * bias_tensor
378                with graph.inserting_before(node):
379                    bias_name = f"_bias_addmm_to_linear_{self.counter}"
380                    graph_module.register_buffer(bias_name, bias_tensor)
381                    bias = graph.get_attr(bias_name)
382
383            # Use associativity of scalar multiplication, and multiply mat2 by alpha
384            if is_node_with_op(mat2, "get_attr"):
385                mat2_tensor = get_tensor_from_attr(graph_module, mat2)
386                assert isinstance(mat2_tensor, torch.Tensor)
387                mat2_tensor = alpha * mat2_tensor
388                # transpose mat2
389                mat2_tensor = mat2_tensor if transposed_mat2 else mat2_tensor.t()
390                with graph.inserting_before(node):
391                    mat2_name = f"_mat2_addmm_to_linear_{self.counter}"
392                    graph_module.register_buffer(mat2_name, mat2_tensor)
393                    mat2 = graph.get_attr(mat2_name)
394
395            # Construct the linear node
396            linear_args = (mat1, mat2, bias)
397            with graph.inserting_before(node):
398                linear_node = graph.call_function(
399                    exir_ops.edge.aten.linear.default, args=linear_args
400                )
401                linear_node.meta = node.meta
402            # Replace all the uses of the addmm op with linear op
403            node.replace_all_uses_with(linear_node)
404            self.counter += 1
405
406        graph_module.recompile()
407        graph_module.graph.eliminate_dead_code()
408
409    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
410        self.replace_addmm_with_linear(graph_module)
411        result = super().call(graph_module)
412        return result
413
414
415@register_cadence_pass(CadencePassAttribute(opt_level=1))
416class ReplacePermuteWithTransposePass(ExportPass):
417    """
418    Replace permute op with transpose if the permutation is only along
419    two dimensions.
420    """
421
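    # Illustrative example: permute_copy(x, [0, 2, 1, 3]) differs from the identity
    # order only at dims 1 and 2, so it becomes transpose_copy(x, 1, 2); a permute
    # with the identity order is dropped entirely.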
422    def call_operator(self, op, args, kwargs, meta):
423        if op != exir_ops.edge.aten.permute_copy.default:
424            return super().call_operator(op, args, kwargs, meta)
425
426        # Get the old dim and new dim order
427        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
428        old_dims = tuple(range(in_tensor.dim()))
429        new_dims = args[1]
430
431        # Compute the positions in which the old and new dim orders differ
432        diff = [od for od, nd in zip(old_dims, new_dims) if od != nd]
433
434        # If the difference is in two dimensions, we can replace this permute op
435        # with transpose op.
436        if len(diff) == 2:
437            new_args = (args[0], diff[0], diff[1])
438            return super().call_operator(
439                exir_ops.edge.aten.transpose_copy.int, new_args, kwargs, meta
440            )
441
442        return (
443            args[0] if len(diff) == 0 else super().call_operator(op, args, kwargs, meta)
444        )
445
446
447@register_cadence_pass(CadencePassAttribute(opt_level=0))
448class ReplaceConvolutionOptionalArgsWithConcreteArgsPass(ExportPass):
449    """
450    Replace optional tensors with concrete tensors. Currently, we
451    replace the optional bias tensor with a zero tensor.
452    """
453
454    def call_operator(self, op, args, kwargs, meta):
455        if get_edge_overload_packet(op) != exir_ops.edge.aten.convolution:
456            return super().call_operator(op, args, kwargs, meta)
457
458        # Check if the bias is already concrete
459        assert len(args) == 9
460        if args[2] is not None:
461            return super().call_operator(op, args, kwargs, meta)
462
463        # The bias length is the number of out channels.
464        out_shape = meta["val"].shape
465        bias_size = out_shape[1]
466        # Create a zero bias tensor (bias is not a constant tensor,
467        # so it needs to be the result of a graph operation).
468        zero_bias = super().call_operator(
469            exir_ops.edge.aten.full.default,
470            ([bias_size], 0.0),
471            {"dtype": torch.float32},
472            meta,
473        )
474
475        # Replace bias with zero_bias
476        args = list(args)
477        args[2] = zero_bias
478        args = tuple(args)
479
480        return super().call_operator(op, args, kwargs, meta)
481
482
483@register_cadence_pass(CadencePassAttribute(opt_level=0))
484class ReplaceRepeatWithCatPass(ExportPass):
485    """
486    Replace a repeat op with successive cat ops along different dimensions.
487    repeat is not supported, so this is an opt_level=0 pass.
488    """
489
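    # Illustrative example: repeat(x, [2, 3]) on a [4, 5] input becomes
    #   y = cat([x, x, x], dim=1)   # repeat factor 3 along the inner dim
    #   z = cat([y, y], dim=0)      # repeat factor 2 along the outer dim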
490    def call_operator(self, op, args, kwargs, meta):
491        if op != exir_ops.edge.aten.repeat.default:
492            return super().call_operator(op, args, kwargs, meta)
493
494        # Extract the input tensor, and the repeats from the args
495        in_tensor = args[0]
496        repeats = args[1]
497
498        # Glean the shape of the input tensor
499        in_shape = list(
500            in_tensor.to_tensor().shape
501            if isinstance(in_tensor, ProxyValue)
502            else in_tensor.shape
503        )
504
505        # If the size of repeats is more than the dimensionality of the tensor,
506        # the output of repeat will be a higher-dimensional tensor. We reshape
507        # the input so that it has the same dimensionality as the output tensor.
508        diff = len(repeats) - len(in_shape)
509        assert (
510            diff >= 0
511        ), "Repeat arg malformed: expected a repeat along each dimension of input tensor"
512
513        if diff > 0:
514            # Extend the input shape with 1's along the higher dimensions
515            in_shape = ([1] * diff) + in_shape
516            # Insert a view op that reshapes the input tensor to have same
517            # dimensionality as the output tensor.
518            in_tensor = super().call_operator(
519                exir_ops.edge.aten.view_copy.default,
520                (in_tensor, in_shape),
521                kwargs,
522                meta,
523            )
524            assert len(repeats) == len(in_shape)
525
526        # Repeat op is nothing but successive cat ops along each dimension.
527        for dim, repeat in reversed(list(enumerate(repeats))):
528            # We do not need to do anything if repeat factor is 1
529            if repeat == 1:
530                continue
531            cat_arg = [in_tensor] * repeat
532            in_tensor = super().call_operator(
533                exir_ops.edge.aten.cat.default, (cat_arg, dim), kwargs, meta
534            )
535
536        return in_tensor
537
538
539@register_cadence_pass(CadencePassAttribute(opt_level=1))
540class ReplacePadWithCatPass(ExportPass):
541    """
542    Replace a constant_pad_nd op that pads only the outermost padded dimension
543    with cat(left_padding_constant_tensor, X, right_padding_constant_tensor).
544    """
545
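    # Illustrative example: constant_pad_nd(x, [1, 2], value) on a [4, 8] input
    # pads only dim 1 (the sole padded dim here), and becomes
    # cat([full([4, 1], value), x, full([4, 2], value)], dim=1).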
546    def call_operator(self, op, args, kwargs, meta):
547        if op != exir_ops.edge.aten.constant_pad_nd.default:
548            return super().call_operator(op, args, kwargs, meta)
549
550        assert len(args) >= 2
551        input_node, orig_padding = args[:2]
552
553        # If there is no padding, this op will be handled by the removal pass.
554        if not orig_padding:
555            return super().call_operator(op, args, kwargs, meta)
556
557        value = 0 if len(args) == 2 else args[2]
558
559        arg_shape = input_node.to_tensor().shape
560
561        padding = orig_padding + ([0] * (len(orig_padding) % 2 != 0))
562        assert len(padding) >= 2
563        (left_padding_size, right_padding_size) = padding[-2:]
564        # Replace only if the nonzero padding is the last pair in the padding list
565        # (i.e., it is applied along the outermost padded dimension).
565        if (
566            any(x != 0 for x in padding[0:-2])
567            or left_padding_size < 0
568            or right_padding_size < 0
569        ):
570            return super().call_operator(op, args, kwargs, meta)
571
572        cat_tensors = []
573        dim = len(arg_shape) - len(padding) // 2
574        # add left_padding
575        if left_padding_size > 0:
576            left_padding_shape = (
577                arg_shape[:dim] + (left_padding_size,) + arg_shape[dim + 1 :]
578            )
579            left_padding_node = super().call_operator(
580                torch.ops.aten.full.default,
581                (
582                    left_padding_shape,
583                    value,
584                ),
585                {"dtype": torch.float32},
586                meta,
587            )
588            cat_tensors.append(left_padding_node)
589        # input_node
590        cat_tensors.append(input_node)
591        # right_padding
592        if right_padding_size > 0:
593            right_padding_shape = (
594                arg_shape[:dim] + (right_padding_size,) + arg_shape[dim + 1 :]
595            )
596            right_padding_node = super().call_operator(
597                torch.ops.aten.full.default,
598                (
599                    right_padding_shape,
600                    value,
601                ),
602                {"dtype": torch.float32},
603                meta,
604            )
605            cat_tensors.append(right_padding_node)
606
607        assert len(cat_tensors) == 1 + (left_padding_size > 0) + (
608            right_padding_size > 0
609        )
610
611        new_args = (cat_tensors, dim)
612        return super().call_operator(
613            exir_ops.edge.aten.cat.default,
614            new_args,
615            kwargs,
616            meta,
617        )
618
619
620@register_cadence_pass(CadencePassAttribute(opt_level=1))
621class ReplaceConstantPadNdWithSlicePass(ExportPass):
622    """
623    Replace a constant_pad_nd op that applies negative padding along the outermost
624    padded dimension with a slice op that trims the padded amounts off that dimension.
625    """
626
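    # Illustrative example: constant_pad_nd(x, [-1, -2]) on a [4, 8] input trims one
    # element from the start and two from the end of dim 1, and becomes
    # slice(x, dim=1, start=1, stop=6).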
627    def call_operator(self, op, args, kwargs, meta):
628        if op != exir_ops.edge.aten.constant_pad_nd.default:
629            return super().call_operator(op, args, kwargs, meta)
630
631        assert len(args) >= 2
632        input_node, orig_padding = args[:2]
633
634        # If there is no padding, this op will be handled by the removal pass.
635        if not orig_padding:
636            return super().call_operator(op, args, kwargs, meta)
637
638        padding = orig_padding + ([0] * (len(orig_padding) % 2 != 0))
639        assert len(padding) >= 2
640        (start, diff) = map(neg, padding[-2:])
641        # Replace only if the nonzero padding is the last pair in the padding list
642        # (i.e., it is applied along the outermost padded dimension).
642        if any(x != 0 for x in padding[0:-2]) or start < 0 or diff < 0:
643            return super().call_operator(op, args, kwargs, meta)
644
645        arg_shape = input_node.to_tensor().shape
646        dim = len(arg_shape) - len(padding) // 2
647        stop = arg_shape[dim] - diff
648        assert start <= stop
649        new_args = (input_node, dim, start, stop)
650        return super().call_operator(
651            exir_ops.edge.aten.slice.Tensor,
652            new_args,
653            kwargs,
654            meta,
655        )
656
657
658# Make this pass runnable standalone at opt level 0.
659@register_cadence_pass(CadencePassAttribute(opt_level=0))
660class ReplaceAtenConvolutionWithJarvisConvolutionPass(ExportPass):
661    """
662    Replace the aten convolution op with the Jarvis-specific convolution op, since
663    the aten version is not supported by Jarvis.
664    Also remove convolution stride if the output size along the strided dimension
665    is 1. We can enable more transformations (e.g., conv -> linear replacement)
666    for unit-stride convolutions.
667    """
668
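    # Illustrative mapping for the non-transposed case: the aten args
    #   (x, w, b, stride, padding, dilation, transposed, output_padding, groups)
    # become cadence.convolution args
    #   (x, w, b, stride, padding, dilation, groups, channel_last=False),
    # with the stride forced to 1 along any output dim whose size is 1.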
669    def call_operator(self, op, args, kwargs, meta):
670        if get_edge_overload_packet(op) != exir_ops.edge.aten.convolution:
671            return super().call_operator(op, args, kwargs, meta)
672        # There must be 9 total args.
673        assert len(args) == 9
674
675        # Unpack the args
676        (
677            in_tensor,
678            weight,
679            bias,
680            stride,
681            padding,
682            dilation,
683            transposed,
684            output_padding,
685            groups,
686        ) = args
687        # Currently we only handle conversion to conv1d and conv2d, therefore
688        # verify that the stride, padding, dilation, and output_padding have
689        # len <=2.
690        assert (
691            len(stride) == len(padding) == len(dilation) == len(output_padding) == 1
692        ) or (
693            len(stride) == len(padding) == len(dilation) == len(output_padding) == 2
694        ), "Can only map convolution to conv1d and conv2d at present"
695
696        target = (
697            exir_ops.edge.cadence.transposed_convolution.default
698            if transposed
699            else exir_ops.edge.cadence.convolution.default
700        )
701
702        if transposed:
703            # Flip the height and width dimensions of weight, since we apply a
704            # gather stencil. Also, the first two dimensions of weight must be
705            # transposed/interchanged.
706            # If weight is a ProxyValue, new_weight needs to be the output of a
707            # graph operation (in this case a transpose_copy op) to be an explicit
708            # ProxyValue as well. If not, the view op can be done directly on the
709            # tensor.
710            transposed_weight = (
711                super().call_operator(
712                    exir_ops.edge.aten.transpose_copy.int,
713                    (
714                        weight,
715                        0,
716                        1,
717                    ),
718                    kwargs,
719                    meta,
720                )
721                if isinstance(weight, ProxyValue)
722                else weight.transpose(0, 1)
723            )
724
725            flipped_weight = (
726                super().call_operator(
727                    torch.ops.aten.flip.default,
728                    (
729                        transposed_weight,
730                        [-1] if transposed_weight.to_tensor().dim() == 3 else [-1, -2],
731                    ),
732                    kwargs,
733                    meta,
734                )
735                if isinstance(transposed_weight, ProxyValue)
736                else (
737                    transposed_weight.flip(-1)
738                    if transposed_weight.dim() == 3
739                    else transposed_weight.flip(-1, -2)
740                )
741            )
742
743            # From the previous checks, if flipped_weight is a FakeTensor, it has to be
744            # a constant (if not, it would be a ProxyValue). Mark it as such.
745            if isinstance(flipped_weight, FakeTensor):
746                flipped_weight.constant = flipped_weight
747            new_args = (
748                in_tensor,
749                flipped_weight,
750                bias,
751                stride,
752                padding,
753                dilation,
754                output_padding,
755                groups,
756                False,
757            )
758        else:
759            # Verify that output_padding is 0.
760            assert all(
761                x == 0 for x in output_padding
762            ), "Cannot handle padded output in convolution"
763
764            # If the output size along a strided (spatial) dimension is 1, the
765            # stride along that dimension can be set to 1. The first two output
766            # dimensions are batch and channel, so the spatial dims start at index 2.
767            new_stride = stride.copy()
768            out_shape = meta["val"].shape
769            assert out_shape is not None
770            for i, e in enumerate(out_shape[2:]):
771                new_stride[i] = 1 if e == 1 else stride[i]
772
773            new_args = (
774                in_tensor,
775                weight,
776                bias,
777                new_stride,
778                padding,
779                dilation,
780                groups,
781                False,
782            )
783
784        return super().call_operator(target, new_args, kwargs, meta)
785
786
787# TODO(matthiascremon): this is a fuse op, not a replace op
788class ReplaceConvWithChannelLastConv:
789    """
790    Convolution op in pytorch expects NCHW layout for input, weight, and output
791    tensors. However, if the input and output to the convolution op are originally
792    in NHWC layout, and are then permuted to conform to NCHW layout, we can fuse
793    the two permute ops with the convolution op, and call the NHWC layout
794    convolution op in Jarvis.
795    """
796
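    # Illustrative pattern targeted by this fusion (conv2d case):
    #   x_nchw = permute(x_nhwc, [0, 3, 1, 2])
    #   y_nchw = convolution(x_nchw, w, ...)
    #   y_nhwc = permute(y_nchw, [0, 2, 3, 1])
    # Both permutes are folded away and the convolution is marked channel_last,
    # operating directly on x_nhwc and producing y_nhwc.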
797    def __init__(self):
798        self.counter = 0
799        self.graph_module = None
800
801    def __call__(self, graph_module: torch.fx.GraphModule):
802        self.replace_conv_with_nhwc_conv(graph_module)
803
804    def conv_layout_is_nhwc(self, node: torch.fx.Node) -> bool:
805        """
806        Return true if the convolution input and output are connected to permute
807        ops, and the input/output to/from the permute ops is an NHWC-layout tensor.
808        """
809        # There must only be a single user of the output node (which must be a
810        # permute/transpose op). The input of the convolution must be connected
811        # to a permute op, and that permute op should have a single user.
812        conv_inp = node.args[0]
813        assert isinstance(conv_inp, torch.fx.Node)
814        if len(node.users) != 1 or len(conv_inp.users) != 1:
815            return False
816
817        # Get the input and output (permute/transpose) nodes of the convolution
818        conv_user = list(node.users.keys())[0]
819        assert isinstance(conv_user, torch.fx.Node)
820        pt_nodes: Set[torch.fx.Node] = {conv_inp, conv_user}
821
822        # Any node in pt_nodes must not be a placeholder.
823        if contains_placeholder_or_param(pt_nodes):
824            return False
825
826        # Determine if the convolution is 1d or 2d. The output tensor must be
827        # 3- or 4-dimensional
828        out_shape = get_shape(self.graph_module, node)
829        assert out_shape is not None
830        out_dims = len(out_shape)
831        assert out_dims in {3, 4}, "Jarvis only supports conv1d and conv2d"
832        conv1d = out_dims == 3
833
834        # Get the possible targets for the nodes in pt_nodes. Since conv1d has
835        # 3-dimensional input and output tensors, the nodes in pt_nodes could
836        # be either permute or transpose op. For conv2d, the nodes in pt_nodes
837        # must be permute ops.
838        p_target = exir_ops.edge.aten.permute_copy.default
839        t_target = exir_ops.edge.aten.transpose_copy.int
840        pt_targets = [p_target] + ([t_target] if conv1d else [])
841
842        # If any node in pt_nodes is not a permute op (or a transpose op for conv1d),
843        # bail.
844        if any(x.target not in pt_targets for x in pt_nodes):
845            return False
846
847        # Now we need to determine the dimension permutations:
848        # If the input had NHWC layout, which was then permuted/transposed
849        # by a permute/transpose op to NCHW layout, the permutation must be
850        # [0, 3, 1, 2] (or [0, 2, 1] for conv1d).
851        # If the output had NCHW layout, and was then permuted to NHWC layout,
852        # the permutation must be [0, 2, 3, 1] (or [0, 2, 1] for conv1d).
853        nhwc_permute_order = {
854            node.args[0]: [0, 2, 1] if conv1d else [0, 3, 1, 2],
855            list(node.users.keys())[0]: [0, 2, 1] if conv1d else [0, 2, 3, 1],
856        }
857        for x in pt_nodes:
858            order = (
859                x.args[1]
860                if x.target == p_target
861                else get_transposed_dims(x, list(range(out_dims)))
862            )
863            if order != nhwc_permute_order[x]:
864                return False
865
866        return True
867
868    def replace_conv_with_nhwc_conv(self, graph_module: torch.fx.GraphModule):
869        self.graph_module = graph_module
870        graph = graph_module.graph
871        for node in graph.nodes:
872            # We are only interested in convolution nodes that have NHWC layout
873            if node.target not in {
874                exir_ops.edge.cadence.quantized_conv.default,
875                exir_ops.edge.cadence.convolution.default,
876                exir_ops.edge.cadence.quantized_transposed_conv.default,
877                exir_ops.edge.cadence.transposed_convolution.default,
878            } or not self.conv_layout_is_nhwc(node):
879                continue
880
881            # Get the args of convolution op
882            args = list(node.args)
883            # The input is connected to a permute/transpose op that converts the
884            # NHWC layout to NCHW layout. The input of the permute op will become
885            # this convolution op's input.
886            in_tp = args[0]
887            args[0] = in_tp.args[0]
888            # The weight is in NCHW layout. Permute it to NHWC layout.
889            weight_tensor = get_tensor_from_attr(graph_module, args[1])
890            assert isinstance(weight_tensor, torch.Tensor)
891            # We cannot directly permute a per-channel quantized tensor. We will
892            # dequantize it, permute the fp32 tensor, and then requantize the
893            # permuted tensor.
894            if (
895                is_quantized_tensor(weight_tensor)
896                and weight_tensor.qscheme() == torch.per_channel_affine
897            ):
898                # We have already asserted during quantizing conv op that the
899                # quantization axis is 0.
900                dequant_weight = weight_tensor.dequantize()
901                dequant_weight = (
902                    dequant_weight.permute([0, 2, 1])
903                    if dequant_weight.dim() == 3
904                    else dequant_weight.permute([0, 2, 3, 1])
905                )
906                weight_tensor = torch.quantize_per_channel(
907                    dequant_weight.contiguous(),
908                    weight_tensor.q_per_channel_scales(),
909                    weight_tensor.q_per_channel_zero_points(),
910                    0,
911                    weight_tensor.dtype,
912                )
913            else:
914                weight_tensor = (
915                    weight_tensor.permute([0, 2, 1])
916                    if weight_tensor.dim() == 3
917                    else weight_tensor.permute([0, 2, 3, 1])
918                )
919            # Make the weight tensor contiguous, since we have permuted it.
920            weight_tensor = weight_tensor.contiguous()
921            # Add the permuted weight into the graph, and update the weight in
922            # args.
923            with graph.inserting_before(node):
924                weight_name = f"_weight_nhwc_{self.counter}"
925                graph_module.register_buffer(weight_name, weight_tensor)
926                weight = graph.get_attr(weight_name)
927            args[1] = weight
928
929            # Set the 'channel_last' arg (the last arg) to True.
930            args[-1] = True
931            # Now update the convolution node args to mark it as NHWC convolution
932            node.args = tuple(args)
933
934            # Replace all the uses of the permute op connected to the output op
935            # with this convolution.
936            out_tp = list(node.users.keys())[0]
937            out_tp.replace_all_uses_with(node)
938            node.meta = out_tp.meta
939
940            # Erase the permute ops connected to the input and output of the
941            # convolution op.
942            graph.erase_node(in_tp)
943            graph.erase_node(out_tp)
944            self.counter += 1
945
946        graph_module.recompile()
947
948
949# This pass needs to be reworked to be compatible with PT2. It is an optimization
950# pass anyway, so move it to opt level 2.
951# TODO(matthiascremon): update and improve this pass.
952@register_cadence_pass(CadencePassAttribute(opt_level=2))
953class ReplaceConvWithChannelLastConvPass(ExportPass):
954    """
955    Replace the ATen convolution op with custom conv op with NCHW or NHWC layout
956    input tensors, depending on the presence of permute/transpose ops connected
957    to the input tensor.
958    """
959
960    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
961        result = ReplaceAtenConvolutionWithJarvisConvolutionPass()(graph_module)
962        assert result is not None
963        ReplaceConvWithChannelLastConv()(result.graph_module)
964        return result
965
966
967@register_cadence_pass(CadencePassAttribute(opt_level=1))
968class ReplaceTrivialConvWithLinear(ExportPass):
969    """
970    In nn.Conv1d, the operand shapes are:
971        input - [batch, in_channels, in_length]
972        weight - [out_channels, in_channels, weight_length]
973        output - [batch, out_channels, out_length]
974    When in_length == weight_length, out_length = 1. In this scenario, we can
975    view the input as a tensor shaped [batch, K], and weight as a tensor
976    shaped [out_channels, K], and replace nn.Conv1d with nn.Linear. This
977    optimization can be extended to nn.Conv2d as well, where the input is a 2d
978    image and the weight is a 2d filter with the same spatial shape as
979    the image.
980    """
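    # Illustrative shapes for the conv1d case: a [batch, C, L] input convolved with
    # an [O, C, L] weight (stride 1, no padding/dilation, groups=1) yields a
    # [batch, O, 1] output, computed here as
    #   view(linear(view(x, [batch, C*L]), view(w, [O, C*L]), bias), [batch, O, 1])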
981
982    trivial_conv_op_to_linear_op: Dict[EdgeOpOverload, EdgeOpOverload] = {
983        exir_ops.edge.cadence.convolution.default: exir_ops.edge.aten.linear.default,
984        exir_ops.edge.cadence.quantized_conv.default: exir_ops.edge.cadence.quantized_linear.default,
985    }
986
987    def call_operator(self, op, args, kwargs, meta):
988        if op not in self.trivial_conv_op_to_linear_op:
989            return super().call_operator(op, args, kwargs, meta)
990
991        # Parse the necessary args of the convolution node. Both convolution
992        # and quantized_conv have the same first 8 args. The quantized op has
993        # extra args holding at least the zero point and scale of input, weight, bias,
994        # and output tensor.
995        quantized_op = op == exir_ops.edge.cadence.quantized_conv.default
996        assert (len(args) == 8 and not quantized_op) or (
997            len(args) >= 12 and quantized_op
998        ), "Inconsistent args for convolution"
999        (in_tensor, weight, bias, stride, padding, dilation, groups) = args[0:7]
1000
1001        # Glean the shapes of input, weight, and output
1002        in_shape = (
1003            in_tensor.to_tensor().shape
1004            if isinstance(in_tensor, ProxyValue)
1005            else in_tensor.shape
1006        )
1007
1008        weight_shape = (
1009            weight.to_tensor().shape if isinstance(weight, ProxyValue) else weight.shape
1010        )
1011        out_shape = meta["val"].shape
1012        assert None not in {in_shape, weight_shape, out_shape}
1013
1014        # Check the condition under which conv can be replaced by linear: (1) this
1015        # should not be a depthwise convolution; (2) the padding, stride, and dilation
1016        # should be standard; (3) The [channels, height, width] of input must match the
1017        # [channel, kernel_height, kernel_width] of the weight. These conditions would
1018        # ensure that output height and width are 1, and the convolution can be replaced
1019        # by linear.
1020        if (
1021            groups != 1
1022            or any(x != 0 for x in padding)
1023            or any(x != 1 for x in stride)
1024            or any(x != 1 for x in dilation)
1025            or (list(in_shape[1:]) != list(weight_shape[1:]))
1026        ):
1027            return super().call_operator(op, args, kwargs, meta)
1028
1029        # Reshape the weight to [out_channels, in_channels * X]
1030        K = math.prod(weight_shape[1:])
1031
1032        # If weight is a ProxyValue, linear_weight needs to be the output of a
1033        # graph operation (in this case a view_copy op) to be an explicit ProxyValue
1034        # as well. If not, the view op can be done directly on the tensor.
1035        linear_weight = (
1036            super().call_operator(
1037                exir_ops.edge.aten.view_copy.default,
1038                (
1039                    weight,
1040                    [weight_shape[0], K],
1041                ),
1042                kwargs,
1043                meta,
1044            )
1045            if isinstance(weight, ProxyValue)
1046            else weight.contiguous().view(weight_shape[0], K)
1047        )
1048        # From the previous check, if linear_weight is a FakeTensor, it has to be
1049        # a constant (if not, it would be a ProxyValue). Mark it as such.
1050        if isinstance(linear_weight, FakeTensor):
1051            linear_weight.constant = linear_weight
1052
1053        # Reshape the input from 3d to 2d tensor
1054        in_view = super().call_operator(
1055            exir_ops.edge.aten.view_copy.default,
1056            (
1057                in_tensor,
1058                [in_shape[0], K],
1059            ),
1060            kwargs,
1061            meta,
1062        )
1063        # Create the linear node, which multiplies the 2d input and weight
1064        # tensors, and adds the 1d bias to produce a 2d output.
1065        if quantized_op:
1066            (
1067                in_zero_point,
1068                weight_zero_point,
1069                bias_scale,
1070                out_scale,
1071                out_zero_point,
1072            ) = args[7:12]
1073            # If the multiplier and shift tensors are provided, use them.
1074            if (
1075                len(args) >= 14
1076                and isinstance(args[12], ProxyValue)
1077                and isinstance(args[13], ProxyValue)
1078            ):
1079                out_multiplier = args[12]
1080                out_shift = args[13]
1081            # If not, compute them.
1082            else:
1083                requantize_scale = bias_scale / out_scale
1084                (out_multiplier, out_shift) = quantize_tensor_multiplier(
1085                    requantize_scale
1086                )
1087            linear_args = (
1088                in_view,
1089                linear_weight,
1090                bias,
1091                in_zero_point,
1092                weight_zero_point,
1093                out_multiplier,
1094                out_shift,
1095                out_zero_point,
1096                None,
1097            )
1098        else:
1099            linear_args = (in_view, linear_weight, bias)
1100
1101        linear_res = super().call_operator(
1102            self.trivial_conv_op_to_linear_op[op],
1103            linear_args,
1104            kwargs,
1105            meta,
1106        )
1107        # Reshape the output of linear from 2d to 3d tensor
1108        out_res = super().call_operator(
1109            exir_ops.edge.aten.view_copy.default,
1110            (linear_res, list(out_shape)),
1111            kwargs,
1112            meta,
1113        )
1114        return out_res
1115
1116
1117def canonicalize_transposed_dim(dim: int, shape: Sequence[int]) -> int:
1118    """Canonicalize transpose ops so it gets easier to pattern-match and fuse transpose ops."""
1119    if dim < 0:
1120        # Keep transpose dimensions positive.
1121        dim += len(shape)
1122    return dim
1123
1124
1125class ExportPassWithTransposeHelper(ExportPass):
1126    def transpose_dims(
1127        self: ExportPass, proxy: ProxyValue, meta: NodeMetadata, dim0: int, dim1: int
1128    ) -> ProxyValue:
1129        """Helper function to transpose dims of a `proxy` with given `meta`."""
1130        shape = proxy.data.shape
1131        dim0, dim1 = (
1132            canonicalize_transposed_dim(dim0, shape),
1133            canonicalize_transposed_dim(dim1, shape),
1134        )
1135        dim0, dim1 = min(dim0, dim1), max(dim0, dim1)
1136        return super().call_operator(
1137            exir_ops.edge.aten.transpose_copy.int, (proxy, dim0, dim1), {}, meta
1138        )
1139
1140
1141@register_cadence_pass(CadencePassAttribute(opt_level=3))
1142class ForceChannelLastForConvPass(ExportPassWithTransposeHelper):
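    # Illustrative permutations for a 4d [N, C, H, W] tensor:
    #   NCHW -> NHWC uses permute order [0, 2, 3, 1]
    #   NHWC -> NCHW uses permute order [0, 3, 1, 2]
    # 3d tensors are handled with a single transpose of dims 1 and -1.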
1143    def change_nchw_to_nhwc(self, proxy: ProxyValue, meta: NodeMetadata) -> ProxyValue:
1144        shape = proxy.to_tensor().shape
1145        if len(shape) == 3:
1146            return self.transpose_dims(proxy, meta, 1, -1)
1147        indices = list(range(len(shape)))
1148        permute_indices = [indices[0]] + indices[2:] + [indices[1]]
1149        return super().call_operator(
1150            exir_ops.edge.aten.permute_copy.default, (proxy, permute_indices), {}, meta
1151        )
1152
1153    def change_nhwc_to_nchw(self, proxy: ProxyValue, meta: NodeMetadata) -> ProxyValue:
1154        shape = proxy.to_tensor().shape
1155        if len(shape) == 3:
1156            return self.transpose_dims(proxy, meta, 1, -1)
1157        indices = list(range(len(shape)))
1158        permute_indices = [indices[0], indices[-1]] + indices[1:-1]
1159        return super().call_operator(
1160            exir_ops.edge.aten.permute_copy.default, (proxy, permute_indices), {}, meta
1161        )
1162
1163    def call_operator(
1164        self,
1165        op,
1166        args: tuple[Argument, ...],
1167        kwargs: dict[str, Argument],
1168        meta: NodeMetadata,
1169    ) -> ProxyValue:
1170        if op not in {
1171            exir_ops.edge.cadence.convolution.default,
1172            exir_ops.edge.cadence.quantized_conv.default,
1173        }:
1174            return super().call_operator(op, args, kwargs, meta)
1175
1176        quantized_op = op == exir_ops.edge.cadence.quantized_conv.default
1177        channel_last_arg_index = 14 if quantized_op else 7
1178        channel_last = (
1179            args[channel_last_arg_index]
1180            if len(args) > channel_last_arg_index
1181            # Default is false (NCHW).
1182            else False
1183        )
1184        if channel_last:
1185            return super().call_operator(op, args, kwargs, meta)
1186
1187        input_proxy = cast(ProxyValue, args[0])
1188        weight_proxy = cast(ProxyValue, args[1])
1189        input_proxy = self.change_nchw_to_nhwc(input_proxy, meta)
1190        weight_proxy = self.change_nchw_to_nhwc(weight_proxy, meta)
1191
1192        new_args = (
1193            # Transposed input/weights.
1194            (input_proxy, weight_proxy)
1195            # All other args (bias, quant params, etc)
1196            + tuple(args[2:channel_last_arg_index])
1197            # Channel last.
1198            + (True,)
1199        )
1200        output_proxy = super().call_operator(op, new_args, kwargs, meta)
1201        nchw_proxy = self.change_nhwc_to_nchw(output_proxy, meta)
1202        return nchw_proxy
1203
1204
1205@register_cadence_pass(CadencePassAttribute(opt_level=3))
1206class MakeSliceAndCatDimOutermostPass(ExportPassWithTransposeHelper):
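    # Illustrative rewrite for slice_copy(x, dim=2, ...) on a 4d input:
    #   transpose(x, 0, 2) -> slice_copy(..., dim=0, ...) -> transpose(result, 0, 2)
    # so that the sliced (or concatenated) dimension becomes the outermost one.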
1207    def call_operator(
1208        self,
1209        op,
1210        args: tuple[Argument, ...],
1211        kwargs: dict[str, Argument],
1212        meta: NodeMetadata,
1213    ) -> ProxyValue:
1214        if op not in {
1215            exir_ops.edge.aten.cat.default,
1216            exir_ops.edge.aten.slice_copy.Tensor,
1217        }:
1218            return super().call_operator(op, args, kwargs, meta)
1219        dim = cast(int, args[1]) if len(args) > 1 else 0
1220        output_shape = meta["val"].shape
1221        if dim < 0:
1222            # Keep dim positive.
1223            dim += len(output_shape)
1224
1225        if dim == 0 or math.prod(output_shape[:dim]) == 1:
1226            # Not needed if dim is already outermost or all dims before it are 1.
1227            return super().call_operator(op, (args[0], dim) + args[2:], kwargs, meta)
1228
1229        if op == exir_ops.edge.aten.slice_copy.Tensor:
1230            # Transpose -> slice.
1231            slice_args = (
1232                self.transpose_dims(cast(ProxyValue, args[0]), meta, dim, 0),
1233                0,
1234            ) + args[2:]
1235            new_op = super().call_operator(op, slice_args, kwargs, meta)
1236        else:
1237            # (Transpose input0, Transpose input1, ...) -> cat.
1238            cat_in_tensors = [
1239                self.transpose_dims(t, meta, dim, 0)
1240                for t in cast(list[ProxyValue], args[0])
1241            ]
1242            new_op = super().call_operator(op, (cat_in_tensors, 0), kwargs, meta)
1243        # slice/cat -> transpose.
1244        return self.transpose_dims(new_op, meta, 0, dim)
1245
1246
1247@register_cadence_pass(CadencePassAttribute(opt_level=1))
1248class ReplaceConvWithIm2RowAndLinear(ExportPass):
1249    """
1250    Replace convolution where groups=1 with im2row followed by a linear op.
1251    """
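    # Illustrative shapes for an NCHW conv2d with weight [O, C, kh, kw]: im2row
    # unfolds the input into rows of length C*kh*kw, one row per output position
    # (out_h * out_w of them), and the linear op with the weight viewed as
    # [O, C*kh*kw] then produces the O output channels for each position.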
1252
1253    # A map from the convolution op to the linear op that it should
1254    # decompose to.
1255    conv_op_to_linear_op: Dict[EdgeOpOverload, EdgeOpOverload] = {
1256        exir_ops.edge.cadence.convolution.default: exir_ops.edge.aten.linear.default,
1257        exir_ops.edge.cadence.quantized_conv.default: exir_ops.edge.cadence.quantized_linear.default,
1258    }
1259
1260    def call_operator(self, op, args, kwargs, meta):
1261        if op not in self.conv_op_to_linear_op:
1262            return super().call_operator(op, args, kwargs, meta)
1263
1264        # Get the relevant args from convolution node.
1265        quantized_op = op == exir_ops.edge.cadence.quantized_conv.default
1266        assert (len(args) == 8 and not quantized_op) or (
1267            len(args) >= 12 and quantized_op
1268        ), "Inconsistent args for convolution"
1269        (in_tensor, weight, bias, stride, padding, dilation, groups) = args[0:7]
1270
1271        # We do not replace depthwise convolution with gemm yet.
1272        if groups != 1:
1273            return super().call_operator(op, args, kwargs, meta)
1274
1275        weight_shape = (
1276            weight.to_tensor().shape if isinstance(weight, ProxyValue) else weight.shape
1277        )
1278        # If this is a pointwise convolution, im2col will start dominating the
1279        # runtime. So we call convolution op for this case.
1280        if (
1281            all(x == 1 for x in weight_shape[2:])
1282            and all(x == 1 for x in stride)
1283            and all(x == 0 for x in padding)
1284            and all(x == 1 for x in dilation)
1285        ):
1286            return super().call_operator(op, args, kwargs, meta)
1287
1288        # Get the shapes
1289        out_shape = meta["val"].shape
1290        assert None not in {weight_shape, out_shape}
1291
1292        # Determine if the convolution is NCHW or NHWC. The NHWC, i.e., the
1293        # channel_last layout is specified by the channel_last arg of conv
1294        # op. For the quantized op it is the 15th argument (implicitly False
1295        # if absent); otherwise it is the last argument.
1296        channel_last = (
1297            (args[14] if len(args) == 15 else False) if quantized_op else args[-1]
1298        )
1299        # The weight tensor is [out_channels, in_channels, X] for NCHW layout,
1300        # and [out_channels, X, in_channels] for NHWC layout. Here, X is the
1301        # kernel_width for conv1d, and X = kernel_height * kernel_width for
1302        # conv2d. We extract X as the kernel_size for im2row.
1303        kernel_size = list(weight_shape[1:-1] if channel_last else weight_shape[2:])
1304        # If the convolution op was quantized, we need the input tensor's
1305        # zero_point for im2row. Otherwise in_zero_point defaults to a zero
1306        # tensor.
1307        in_zero_point = (
1308            (
1309                super().call_operator(
1310                    exir_ops.edge.aten.full.default,
1311                    (
1312                        [1],
1313                        args[7],
1314                    ),
1315                    {"dtype": torch.int32},
1316                    meta,
1317                )
1318                if isinstance(in_tensor.to_tensor(), FakeTensor)
1319                else get_zero_point(in_tensor.to_tensor())
1320            )
1321            if quantized_op
1322            else torch.tensor(0, dtype=torch.int32)
1323        )
1324        # im2row expects every kernel parameter to be 2d. So we extend the
1325        # parameters for conv1d by prepending their default values.
1326        stride = ([1] + stride) if len(stride) == 1 else stride
1327        padding = ([0] + padding) if len(padding) == 1 else padding
1328        dilation = ([1] + dilation) if len(dilation) == 1 else dilation
1329        kernel_size = ([1] + kernel_size) if len(kernel_size) == 1 else kernel_size
1330        # Assert that kernel size does not have a 0
1331        assert 0 not in kernel_size
1332
1333        # Create an im2row node with the input. This will create a 2d matrix of
1334        # shape [out_height*out_width, X*in_channels]. X is as defined in the
1335        # comment above.
1336        im2row_args = (
1337            in_tensor,
1338            kernel_size,
1339            dilation,
1340            padding,
1341            stride,
1342            in_zero_point,
1343            channel_last,
1344        )
1345        im2row = super().call_operator(
1346            exir_ops.edge.cadence.im2row.default,
1347            im2row_args,
1348            kwargs,
1349            meta,
1350        )
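        # Rough shape example (hypothetical values): for an NCHW input of shape
        # [1, 3, 8, 8] and a 3x3 kernel with stride 1, no padding and no dilation,
        # im2row produces 6*6 = 36 output positions, each a flattened 3*3*3 = 27
        # element patch, i.e. roughly a [1, 36, 27] tensor.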
1351
1352        # Compute K, the product of all weight dims except the outermost (out_channels) dim
1353        K = math.prod(weight_shape[1:])
1354
1355        # If weight is a ProxyValue, linear_weight needs to be the output of a
1356        # graph operation (in this case a view_copy op) to be an explicit ProxyValue
1357        # as well. If not, the view op can be done directly on the tensor.
1358        linear_weight = (
1359            super().call_operator(
1360                exir_ops.edge.aten.view_copy.default,
1361                (
1362                    weight,
1363                    [weight_shape[0], K],
1364                ),
1365                kwargs,
1366                meta,
1367            )
1368            if isinstance(weight, ProxyValue)
1369            else weight.contiguous().view(weight_shape[0], K)
1370        )
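        # Continuing the example above (hypothetical shapes): a [16, 3, 3, 3]
        # weight is flattened to a [16, 27] linear weight, matching the 27-wide
        # rows produced by im2row.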
1371        # From the previous check, if linear_weight is a FakeTensor, it has to be
1372        # a constant (if not, it would be a ProxyValue). Mark it as such.
1373        if isinstance(linear_weight, FakeTensor):
1374            linear_weight.constant = linear_weight
1375
1376        # Create the linear node, which multiplies the 3d input with 2d weight
1377        # tensors with bias addition. The outermost dimension of the input is
1378        # the batch size for linear op.
1379        if quantized_op:
1380            (
1381                in_zero_point,
1382                weight_zero_point,
1383                bias_scale,
1384                out_scale,
1385                out_zero_point,
1386            ) = args[7:12]
1387            # If the multiplier and shift tensors are provided, use them.
1388            if (
1389                len(args) >= 14
1390                and isinstance(args[12], ProxyValue)
1391                and isinstance(args[13], ProxyValue)
1392            ):
1393                out_multiplier = args[12]
1394                out_shift = args[13]
1395            # If not, compute them.
1396            else:
1397                requantize_scale = bias_scale / out_scale
1398                (out_multiplier, out_shift) = quantize_tensor_multiplier(
1399                    requantize_scale
1400                )
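                # quantize_tensor_multiplier decomposes the floating-point
                # requantization scale into an integer (fixed-point) multiplier
                # and shift pair, so the quantized_linear kernel can requantize
                # with integer multiply-and-shift instead of float math.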
1401            linear_args = (
1402                im2row,
1403                linear_weight,
1404                bias,
1405                in_zero_point,
1406                weight_zero_point,
1407                out_multiplier,
1408                out_shift,
1409                out_zero_point,
1410                None,
1411            )
1412        else:
1413            linear_args = (im2row, linear_weight, bias)
1414        linear_res = super().call_operator(
1415            self.conv_op_to_linear_op[op],
1416            linear_args,
1417            kwargs,
1418            meta,
1419        )
1420        # The output of linear is a 3D tensor. However, the output is in NHWC
1421        # layout by default, because an input vector of size X is multiplied
1422        # with the weight matrix, i.e., the output channel values are contiguous.
1423        # If channel_last is False, we need to transpose this output.
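        # Continuing the running example (hypothetical shapes): im2row output
        # [1, 36, 27] x weight [16, 27]^T gives [1, 36, 16] (channels last). For
        # an NCHW conv, the transpose below yields [1, 16, 36], and the final
        # view_copy reshapes it to the expected [1, 16, 6, 6] output.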
1424        if not channel_last:
1425            linear_res = super().call_operator(
1426                exir_ops.edge.aten.transpose_copy.int,
1427                (linear_res, 1, 2),
1428                kwargs,
1429                meta,
1430            )
1431        # And finally, we want to view the 3D output of the linear op as a 4D tensor
1432        return super().call_operator(
1433            exir_ops.edge.aten.view_copy.default,
1434            (linear_res, list(out_shape)),
1435            kwargs,
1436            meta,
1437        )
1438
1439
1440@register_cadence_pass(CadencePassAttribute(opt_level=1))
1441class ReplaceTransposedConvWithLinearPass(ExportPass):
1442    """
1443    Replace transposed convolution where groups=1 with transposed_im2row
1444    followed by a linear op.
1445    """
1446
1447    # A map from the transposed_convolution op to the linear op that it should
1448    # decompose to.
1449    transposed_conv_op_to_linear_op: Dict[EdgeOpOverload, EdgeOpOverload] = {
1450        exir_ops.edge.cadence.transposed_convolution.default: exir_ops.edge.aten.linear.default,
1451        exir_ops.edge.cadence.quantized_transposed_conv.default: exir_ops.edge.cadence.quantized_linear.default,
1452    }
1453
1454    def call_operator(self, op, args, kwargs, meta):
1455        if op not in self.transposed_conv_op_to_linear_op:
1456            return super().call_operator(op, args, kwargs, meta)
1457
1458        # Get the relevant args from transposed_convolution node.
1459        quantized_op = op == exir_ops.edge.cadence.quantized_transposed_conv.default
1460        assert len(args) == (
1461            16 if quantized_op else 9
1462        ), "Inconsistent args for transposed_convolution"
1463        (
1464            in_tensor,
1465            weight,
1466            bias,
1467            stride,
1468            padding,
1469            dilation,
1470            output_padding,
1471            groups,
1472        ) = args[0:8]
1473
1474        # We do not replace depthwise transposed_convolution with gemm yet.
1475        if groups != 1:
1476            return super().call_operator(op, args, kwargs, meta)
1477
1478        # Get the shapes
1479        out_shape = meta["val"].shape
1480        weight_shape = (
1481            weight.to_tensor().shape if isinstance(weight, ProxyValue) else weight.shape
1482        )
1483        assert None not in {weight_shape, out_shape}
1484
1485        # Determine if the transposed_convolution is NCHW or NHWC. The NHWC,
1486        # i.e., channel_last layout is specified by the channel_last arg
1487        # of the transposed_conv op, which is the last argument.
1488        channel_last = args[-1]
1489        # The weight tensor is [out_channels, in_channels, X] for NCHW layout,
1490        # and [out_channels, X, in_channels] for NHWC layout. Here, X is the
1491        # kernel_width for conv1d, and X = kernel_height * kernel_width for
1492        # conv2d. We extract X as the kernel_size for im2row.
1493        kernel_size = list(weight_shape[1:-1] if channel_last else weight_shape[2:])
1494        # If the transposed_convolution op was quantized, we need the input tensor's
1495        # zero_point for im2row. Otherwise in_zero_point defaults to a zero
1496        # tensor.
1497        in_zero_point = (
1498            get_zero_point(in_tensor.to_tensor())
1499            if quantized_op
1500            else torch.tensor(0, dtype=torch.int32)
1501        )
1502        # transposed_im2row expects every kernel parameter to be 2d. So we extend the
1503        # parameters for conv1d by prepending their default values.
1504        stride = ([1] + stride) if len(stride) == 1 else stride
1505        padding = ([0] + padding) if len(padding) == 1 else padding
1506        dilation = ([1] + dilation) if len(dilation) == 1 else dilation
1507        output_padding = (
1508            ([0] + output_padding) if len(output_padding) == 1 else output_padding
1509        )
1510        kernel_size = ([1] + kernel_size) if len(kernel_size) == 1 else kernel_size
1511        # Assert that kernel size does not have a 0
1512        assert 0 not in kernel_size
1513
1514        # Create a transposed_im2row node with the input. This will create a 2d
1515        # matrix of shape [out_height*out_width, X*in_channels]. X is as
1516        # defined in the comment above.
1517        transposed_im2row_args = (
1518            in_tensor,
1519            kernel_size,
1520            dilation,
1521            padding,
1522            stride,
1523            output_padding,
1524            in_zero_point,
1525            channel_last,
1526        )
1527        transposed_im2row = super().call_operator(
1528            exir_ops.edge.cadence.transposed_im2row.default,
1529            transposed_im2row_args,
1530            kwargs,
1531            meta,
1532        )
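        # Conceptually, transposed_im2row plays the same role as im2row above,
        # but gathers the patches needed for a transposed (fractionally-strided)
        # convolution, so the op again reduces to a single matrix multiply
        # against the flattened weight.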
1533        # Reshape the weight to [out_channels, in_channels * X]
1534        K = math.prod(weight_shape[1:])
1535
1536        # If weight is a ProxyValue, linear_weight needs to be the output of a
1537        # graph operation (in this case a view_copy op) to be an explicit ProxyValue
1538        # as well. If not, the view op can be done directly on the tensor.
1539        linear_weight = (
1540            super().call_operator(
1541                exir_ops.edge.aten.view_copy.default,
1542                (
1543                    weight,
1544                    [weight_shape[0], K],
1545                ),
1546                kwargs,
1547                meta,
1548            )
1549            if isinstance(weight, ProxyValue)
1550            else weight.contiguous().view(weight_shape[0], K)
1551        )
1552        # From the previous check, if linear_weight is a FakeTensor, it has to be
1553        # a constant (if not, it would be a ProxyValue). Mark it as such.
1554        if isinstance(linear_weight, FakeTensor):
1555            linear_weight.constant = linear_weight
1556
1557        # Create the linear node, which multiplies the 3d input with 2d weight
1558        # tensors with bias addition. The outermost dimension of the input is
1559        # the batch size for linear op.
1560        if quantized_op:
1561            (
1562                in_zero_point,
1563                weight_zero_point,
1564                bias_scale,
1565                out_scale,
1566                out_zero_point,
1567            ) = args[8:13]
1568            requantize_scale = bias_scale / out_scale
1569            (out_multiplier, out_shift) = quantize_tensor_multiplier(requantize_scale)
1570            linear_args = (
1571                transposed_im2row,
1572                linear_weight,
1573                bias,
1574                in_zero_point,
1575                weight_zero_point,
1576                out_multiplier,
1577                out_shift,
1578                out_zero_point,
1579                None,
1580            )
1581        else:
1582            linear_args = (transposed_im2row, linear_weight, bias)
1583        linear_res = super().call_operator(
1584            self.transposed_conv_op_to_linear_op[op],
1585            linear_args,
1586            kwargs,
1587            meta,
1588        )
1589        # The output of linear is a 3D tensor. However, the output is in NHWC
1590        # layout by default, because an input vector of size X is multiplied
1591        # with the weight matrix, i.e., the output channel values are contiguous.
1592        # If channel_last is False, we need to transpose this output.
1593        if not channel_last:
1594            linear_res = super().call_operator(
1595                exir_ops.edge.aten.transpose_copy.int,
1596                (linear_res, 1, 2),
1597                kwargs,
1598                meta,
1599            )
1600        # And finally, we want to view the 3D output of the linear op as a 4D tensor
1601        return super().call_operator(
1602            exir_ops.edge.aten.view_copy.default,
1603            (linear_res, list(out_shape)),
1604            kwargs,
1605            meta,
1606        )
1607
1608
1609@register_cadence_pass(CadencePassAttribute(opt_level=1))
1610class ReplaceNopTransposeOrPermuteWithViewPass(ExportPass):
1611    """
1612    If the transpose/permute op does not change the byte order (e.g.,
1613    transpose/permute from Nx1xHxW to NxHx1xW), then it can be replaced
1614    by view op.
1615    """
1616
1617    def call_operator(self, op, args, kwargs, meta):
1618        # Only proceed for transpose or permute op.
1619        if op not in {
1620            exir_ops.edge.aten.transpose_copy.int,
1621            exir_ops.edge.aten.permute_copy.default,
1622        }:
1623            return super().call_operator(op, args, kwargs, meta)
1624
1625        # Get the input tensor and shape
1626        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
1627        in_shape = in_tensor.shape
1628        # Get the output tensor shape
1629        out_shape = meta["val"].shape
1630
1631        if op == exir_ops.edge.aten.transpose_copy.int:
1632            # Get the two dims to be transposed
1633            dim0 = args[1] if args[1] >= 0 else in_tensor.dim() + args[1]
1634            dim1 = args[2] if args[2] >= 0 else in_tensor.dim() + args[2]
1635            # We can eliminate the transpose if (a) the sizes at both dim0 and
1636            # dim1 are 1; or (b) the size at dim0 or dim1 is 1, and dim0 and dim1 are consecutive.
1637            both_one = in_shape[dim0] == 1 and in_shape[dim1] == 1
1638            either_one_and_consecutive = abs(dim0 - dim1) == 1 and (
1639                in_shape[dim0] == 1 or in_shape[dim1] == 1
1640            )
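            # E.g., transposing dims 1 and 2 of a [N, 1, H, W] tensor gives
            # [N, H, 1, W]: dim 1 has size 1 and the dims are consecutive, so the
            # underlying byte order is unchanged and a view_copy suffices.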
1641            if both_one or either_one_and_consecutive:
1642                new_args = (args[0], list(out_shape))
1643                return super().call_operator(
1644                    exir_ops.edge.aten.view_copy.default, new_args, kwargs, meta
1645                )
1646
1647        elif op == exir_ops.edge.aten.permute_copy.default:
1648            old_dims = list(range(in_tensor.dim()))
1649            new_dims = args[1]
1650            # If the permute does not change anything, return the input as output.
1651            if old_dims == new_dims:
1652                return args[0]
1653            # Get the old dim order, and the permuted dim order for all dims that
1654            # are not 1.
1655            old_order = [
1656                dim for dim, shape_dim in zip(old_dims, in_shape) if shape_dim != 1
1657            ]
1658            new_order = [
1659                dim for dim, shape_dim in zip(new_dims, out_shape) if shape_dim != 1
1660            ]
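            # E.g., permuting [N, 1, H, W] with dims [0, 2, 1, 3] gives
            # [N, H, 1, W]; the non-unit dims keep their relative order
            # ([0, 2, 3] in both cases), so the permute is a nop and becomes a view.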
1661            # If the byte ordering for non-unit dims is unchanged, this is a nop.
1662            if old_order == new_order:
1663                new_args = (args[0], list(out_shape))
1664                return super().call_operator(
1665                    exir_ops.edge.aten.view_copy.default, new_args, kwargs, meta
1666                )
1667
1668        return super().call_operator(op, args, kwargs, meta)
1669
1670    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
1671        result = super().call(graph_module)
1672        result = FuseCascadedViewOps()(result.graph_module)
1673        assert result is not None
1674        return result
1675
1676
1677@register_cadence_pass(CadencePassAttribute(opt_level=1))
1678class ReplaceLinearWithFullyConnectedOpPass(ExportPass):
1679    """
1680    If the input of linear/quantized_linear op is a vector, replace it with
1681    fully_connected op.
1682    """
1683
1684    linear_to_fc_op: Dict[EdgeOpOverload, EdgeOpOverload] = {
1685        exir_ops.edge.aten.linear.default: exir_ops.edge.cadence.fully_connected.default,
1686        exir_ops.edge.cadence.quantized_linear.default: exir_ops.edge.cadence.quantized_fully_connected.default,
1687    }
1688
1689    def call_operator(self, op, args, kwargs, meta):
1690        # Only proceed for linear or quantized_linear ops.
1691        if op not in self.linear_to_fc_op:
1692            return super().call_operator(op, args, kwargs, meta)
1693
1694        # Extract the input tensor
1695        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
1696        leading_dims = math.prod(in_tensor.shape[:-1])
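        # E.g., an input of shape [1, 1, 128] has leading_dims = 1 and can use
        # fully_connected, whereas a [4, 128] input has leading_dims = 4 and is
        # left as a linear op.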
1697        # If the tensor is not a vector, do nothing.
1698        if leading_dims != 1:
1699            return super().call_operator(op, args, kwargs, meta)
1700
1701        # If the op is quantized::linear, but per-channel quantized, bail.
1702        if op == exir_ops.edge.cadence.quantized_linear.default:
1703            weight = args[1].to_tensor() if isinstance(args[1], ProxyValue) else args[1]
1704            if weight.shape != [1]:
1705                return super().call_operator(op, args, kwargs, meta)
1706
1707        # Replace the linear with fully connected op
1708        return super().call_operator(
1709            self.linear_to_fc_op[op],
1710            args,
1711            kwargs,
1712            meta,
1713        )
1714
1715
1716@register_cadence_pass(CadencePassAttribute(opt_level=0))
1717class ReplaceScalarWithTensorArgPass(ExportPass):
1718    """
1719    For binary ops like add.Scalar, sub.Scalar, mul.Scalar, and div.Scalar,
1720    replace the scalar arg with a Tensor arg.
1721    """
1722
1723    scalar_to_tensor_ops: Dict[EdgeOpOverload, EdgeOpOverload] = {
1724        exir_ops.edge.aten.add.Scalar: exir_ops.edge.aten.add.Tensor,
1725        exir_ops.edge.aten.sub.Scalar: exir_ops.edge.aten.sub.Tensor,
1726        exir_ops.edge.aten.mul.Scalar: exir_ops.edge.aten.mul.Tensor,
1727        exir_ops.edge.aten.div.Scalar: exir_ops.edge.aten.div.Tensor,
1728    }
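    # For illustration: add.Scalar(x, 2.0) is rewritten below as
    # add.Tensor(x, full((1,), 2.0, dtype=x.dtype)).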
1729
1730    def get_replacement(self, op, args, kwargs, meta):
1731        return super().call_operator(
1732            # Replace with .Tensor variant.
1733            op=self.scalar_to_tensor_ops[op],
1734            args=(
1735                # Tensor arg.
1736                args[0],
1737                # Scalar arg - replace with aten.full tensor.
1738                super().call_operator(
1739                    exir_ops.edge.aten.full.default,
1740                    args=(
1741                        (1,),
1742                        args[1],
1743                    ),
1744                    kwargs={"dtype": args[0].to_tensor().dtype},
1745                    meta=meta,
1746                ),
1747                # Other args.
1748                *args[2:],
1749            ),
1750            kwargs=kwargs,
1751            meta=meta,
1752        )
1753
1754    def call_operator(self, op, args, kwargs, meta):
1755        if op not in self.scalar_to_tensor_ops:
1756            return super().call_operator(op, args, kwargs, meta)
1757
1758        # There must be exactly 2 args (3 for add and sub, which take an alpha)
1759        assert len(args) == 2 or len(args) == 3
1760
1761        # If there are two args, just replace the op.
1762        if len(args) == 2:
1763            return self.get_replacement(op, args, kwargs, meta)
1764
1765        # If the op has three args, it must be a scalar add/sub op.
1766        if (
1767            op not in {exir_ops.edge.aten.add.Scalar, exir_ops.edge.aten.sub.Scalar}
1768            or "alpha" in kwargs
1769        ):
1770            return super().call_operator(op, args, kwargs, meta)
1771
1772        return self.get_replacement(op, args, kwargs, meta)
1773
1774
1775@register_cadence_pass(CadencePassAttribute(opt_level=0))
1776class ReplaceScalarTensorWithFullPass(ExportPass):
1777    """
1778    aten.scalar_tensor can be replaced by aten.full with a shape of [1].
1779    scalar_tensor is not supported, so this is an opt_level=0 pass.
1780    """
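    # E.g., scalar_tensor(3.5) becomes full([1], 3.5) (with dtype float32, as
    # hard-coded below).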
1781
1782    def call_operator(
1783        self,
1784        op,
1785        args: Tuple[Argument, ...],
1786        kwargs: Dict[str, Argument],
1787        meta: NodeMetadata,
1788    ) -> ProxyValue:
1789        if op not in {
1790            exir_ops.edge.aten.scalar_tensor.default,
1791            torch.ops.aten.scalar_tensor.default,
1792        }:
1793            return super().call_operator(op, args, kwargs, meta)
1794
1795        return super().call_operator(
1796            exir_ops.edge.aten.full.default,
1797            (
1798                [1],
1799                args[0],
1800            ),
1801            {"dtype": torch.float32},
1802            meta,
1803        )
1804
1805
1806@register_cadence_pass(CadencePassAttribute(opt_level=0))
1807class ReplaceFullLikeWithFullPass(ExportPass):
1808    """
1809    aten.full_like can be replaced by aten.full with the shape of the arg tensor.
1810    full_like is not supported, so this is an opt_level=0 pass.
1811    """
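    # E.g., full_like(t, 0.0) for a tensor t of shape [2, 3] becomes
    # full([2, 3], 0.0).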
1812
1813    def call_operator(self, op, args, kwargs, meta):
1814        if op not in {
1815            exir_ops.edge.aten.full_like.default,
1816        }:
1817            return super().call_operator(op, args, kwargs, meta)
1818
1819        # Get the shape of the "like" tensor, and pass that in to the full op.
1820        return super().call_operator(
1821            exir_ops.edge.aten.full.default,
1822            (
1823                (
1824                    args[0].to_tensor().shape
1825                    if isinstance(args[0], ProxyValue)
1826                    else args[0].shape
1827                ),
1828                args[1],
1829            ),
1830            {},
1831            meta,
1832        )
1833
1834
1835@register_cadence_pass(CadencePassAttribute(opt_level=0))
1836class ReplaceInfArgInFullWithValuePass(ExportPass):
1837    """
1838    aten.full allows "-inf" and "inf" as inputs. The profiler cannot
1839    handle that, so replace them with the minimum/maximum finite value of the type.
1840    """
1841
1842    def call_operator(self, op, args, kwargs, meta):
1843        if op not in {
1844            exir_ops.edge.aten.full.default,
1845        }:
1846            return super().call_operator(op, args, kwargs, meta)
1847
1848        new_args = list(args)
1849
1850        if args[1] == float("-inf"):
1851            new_args[1] = torch.finfo(torch.float32).min
1852        elif args[1] == float("inf"):
1853            new_args[1] = torch.finfo(torch.float32).max
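        # E.g., full([8], float("-inf")) becomes full([8], torch.finfo(torch.float32).min),
        # and full([8], float("inf")) becomes full([8], torch.finfo(torch.float32).max).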
1854
1855        return super().call_operator(op, tuple(new_args), kwargs, meta)
1856
1857
1858@register_cadence_pass(CadencePassAttribute(opt_level=0))
1859class ReplaceAtenLinalgVectorNormWithCadenceLinalgVectorNormPass(ExportPass):
1860    """
1861    Replace the aten.linalg_vector_norm op with a custom op.
1862    aten.linalg_vector_norm is not supported by Jarvis, so we need to
1863    replace it with the cadence.linalg_vector_norm custom op at all optimization levels.
1864    """
1865
1866    def call_operator(self, op, args, kwargs, meta):
1867        if op != exir_ops.edge.aten.linalg_vector_norm.default:
1868            return super().call_operator(op, args, kwargs, meta)
1869
1870        assert (
1871            len(args) == 1
1872        ), "aten.linalg_vector_norm should have 1 argument (a tensor), we do not support any custom variants"
1873
1874        return super().call_operator(
1875            exir_ops.edge.cadence.linalg_vector_norm.default,
1876            args,
1877            kwargs,
1878            meta,
1879        )
1880
1881
1882@register_cadence_pass(CadencePassAttribute(opt_level=1))
1883class ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass(ExportPass):
1884    """
1885    Replace ops with single element arguments (size = [1]) with overloads that accept scalar ints/floats.
1886    """
1887
1888    # Keep track of which operators and arguments are being replaced.
1889    replaced_scalar_args: dict[
1890        EdgeOpOverloadPacket, tuple[EdgeOpOverload, Sequence[int]]
1891    ] = {
1892        exir_ops.edge.cadence.quantized_conv: (
1893            exir_ops.edge.cadence.quantized_conv.per_tensor,
1894            [8, 9, 12, 13],
1895        ),
1896        exir_ops.edge.cadence.quantized_layer_norm: (
1897            exir_ops.edge.cadence.quantized_layer_norm.per_tensor,
1898            [1, 2],
1899        ),
1900        exir_ops.edge.cadence.quantized_linear: (
1901            exir_ops.edge.cadence.quantized_linear.per_tensor,
1902            [4, 5, 6],
1903        ),
1904        exir_ops.edge.cadence.quantized_relu: (
1905            exir_ops.edge.cadence.quantized_relu.per_tensor,
1906            [1, 3, 4],
1907        ),
1908    }
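    # For illustration (argument indices as listed above): if the
    # weight_zero_point argument (index 4) of quantized_linear is produced by
    # full([1], 3), it is replaced by the scalar 3 and the op is switched to
    # quantized_linear.per_tensor.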
1909
1910    def call_operator(self, op, args, kwargs, meta):
1911        op_edge_overload_packet = get_edge_overload_packet(op)
1912
1913        if op_edge_overload_packet not in self.replaced_scalar_args:
1914            return super().call_operator(op, args, kwargs, meta)
1915
1916        # Get all the args that need to be replaced.
1917        new_op, args_to_be_replaced = self.replaced_scalar_args[op_edge_overload_packet]
1918
1919        updated_args = list(args)
1920        for op_arg_index in args_to_be_replaced:
1921            arg = args[op_arg_index]
1922            if not isinstance(arg, ProxyValue):
1923                return super().call_operator(op, args, kwargs, meta)
1924
1925            if not arg.is_tensor():
1926                return super().call_operator(op, args, kwargs, meta)
1927
1928            if get_edge_overload_packet(arg.node.target) != exir_ops.edge.aten.full:
1929                # Only replace if arg generated by a full op.
1930                return super().call_operator(op, args, kwargs, meta)
1931
1932            if tuple(arg.node.args[0]) != (1,):
1933                # Only replace if the size of the full op is [1].
1934                return super().call_operator(op, args, kwargs, meta)
1935
1936            updated_args[op_arg_index] = arg.node.args[1]
1937
1938        return super().call_operator(
1939            new_op,
1940            tuple(updated_args),
1941            kwargs,
1942            meta,
1943        )
1944
1945
1946@register_cadence_pass(CadencePassAttribute(opt_level=0))
1947class ReplaceAtenAvgPoolWithJarvisAvgPoolPass(ExportPass):
1948    """
1949    Replace the aten avg_pool op with the jarvis custom avg_pool2d op.
1950    """
1951
1952    def call_operator(self, op, args, kwargs, meta):
1953        # Only continue for avg_pool op
1954        if op not in {
1955            exir_ops.edge.aten.avg_pool1d.default,
1956            exir_ops.edge.aten.avg_pool2d.default,
1957        }:
1958            return super().call_operator(op, args, kwargs, meta)
1959
1960        # Determine if the op is avg_pool1d or avg_pool2d
1961        avg_pool1d: bool = op == exir_ops.edge.aten.avg_pool1d.default
1962        # Get the input tensor
1963        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
1964
1965        # Replace avg_pool2d with custom avg_pool2d, and if the input tensor is
1966        # quantized, pass its zero_point tensor as arg to the custom avg_pool2d.
1967        # stride, padding, ceil_mode, count_include_pad, and divisor_override are
1968        # the native avg_pool2d args. 'channel_last' denotes NCHW vs NHWC layout,
1969        # and is False by default.
1970        kernel_size = args[1]
1971        stride = args[2] if len(args) >= 3 else [1, 1]
1972        padding = args[3] if len(args) >= 4 else [0, 0]
1973        ceil_mode = args[4] if len(args) >= 5 else False
1974        count_include_pad = args[5] if len(args) >= 6 else True
1975        divisor_override = args[6] if len(args) >= 7 else None
1976        zero_point = torch.tensor(0, dtype=torch.int32)
1977
1978        # If the op is avg_pool1d, then we need to reshape the 3d input to a 4d
1979        # tensor.
1980        if avg_pool1d:
1981            in_shape = list(in_tensor.shape)
1982            assert len(in_shape) == 3, "Expected 3d input for avg_pool1d"
1983            in_shape.insert(2, 1)
1984            out_shape = meta["val"].shape
1985            in_view_op = super().call_operator(
1986                exir_ops.edge.aten.view_copy.default,
1987                (in_tensor, in_shape),
1988                kwargs,
1989                meta,
1990            )
1991            # Extend the kernel_size, stride and padding to 2d
1992            kernel_size = [1] + kernel_size if len(kernel_size) == 1 else kernel_size
1993            stride = [1] + stride if len(stride) == 1 else stride
1994            padding = [0] + padding if len(padding) == 1 else padding
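            # E.g., a [N, C, L] avg_pool1d input with kernel_size=[3], stride=[2]
            # and padding=[1] is viewed as [N, C, 1, L] and pooled with
            # kernel_size=[1, 3], stride=[1, 2], padding=[0, 1].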
1995
1996        # Create a new avg_pool node with the updated args
1997        new_args = (
1998            in_view_op if avg_pool1d else args[0],
1999            kernel_size,
2000            stride,
2001            padding,
2002            ceil_mode,
2003            count_include_pad,
2004            divisor_override,
2005            zero_point,
2006            False,
2007        )
2008        avg_pool2d_op = super().call_operator(
2009            exir_ops.edge.cadence.avg_pool2d.default,
2010            new_args,
2011            kwargs,
2012            meta,
2013        )
2014
2015        # If the node was avg_pool1d, we again reshape the 4d output to 3d output
2016        return (
2017            super().call_operator(
2018                exir_ops.edge.aten.view_copy.default,
2019                (avg_pool2d_op, list(out_shape)),
2020                kwargs,
2021                meta,
2022            )
2023            if avg_pool1d
2024            else avg_pool2d_op
2025        )
2026
2027
2028@register_cadence_pass(CadencePassAttribute(opt_level=1))
2029class ReplaceIm2RowWithViewPass(ExportPass):
2030    def can_replace(self, op, args, kwargs, meta) -> bool:
2031        if op != exir_ops.edge.cadence.im2row.default:
2032            return False
2033
2034        # Check if im2row applies padding. If yes, we cannot replace it with view.
2035        pad = cast(tuple[int, ...], args[3])
2036        if any(p != 0 for p in pad):
2037            return False
2038
2039        # Check if im2row has dilation. If yes, we cannot replace it with view.
2040        dilation = cast(tuple[int, ...], args[2])
2041        if any(d != 1 for d in dilation):
2042            return False
2043
2044        # im2row works on 3D or 4D tensors.
2045        # Output shape[1:-1] will be all ones if the input spatial dimensions are the same as the kernel spatial dimensions.
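        # For instance (hypothetical shapes): if the input spatial dims equal the
        # kernel dims (say a 5x5 input with a 5x5 kernel, no padding or dilation),
        # there is exactly one output position, so im2row merely flattens the
        # input and a view_copy is sufficient.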
2046        output_shape = meta["val"].shape
2047        if math.prod(output_shape[1:-1]) == 1:
2048            return True
2049
2050        return False
2051
2052    def call_operator(
2053        self,
2054        op,
2055        args: tuple[Argument, ...],
2056        kwargs: dict[str, Argument],
2057        meta: NodeMetadata,
2058    ) -> ProxyValue:
2059        if op != exir_ops.edge.cadence.im2row.default:
2060            return super().call_operator(op, args, kwargs, meta)
2061
2062        if not self.can_replace(op, args, kwargs, meta):
2063            return super().call_operator(op, args, kwargs, meta)
2064
2065        output_shape = meta["val"].shape
2066        return super().call_operator(
2067            exir_ops.edge.aten.view_copy.default,
2068            (args[0], tuple(output_shape)),
2069            kwargs,
2070            meta,
2071        )
2072
2073
2074# This class encapsulates all the passes that replace/switch one op in the
2075# graph with another.
2076class CadenceReplaceOpsInGraph:
2077    passes = [
2078        ReplaceFunctionallyEquivalentOpTargets,
2079        ReplaceTCopyWithTransposePass,
2080        ReplacePermuteWithTransposePass,
2081        ReplaceScalarWithTensorArgPass,
2082        ReplaceConvolutionOptionalArgsWithConcreteArgsPass,
2083        ReplaceMMWithAddMMPass,
2084        ReplaceSqueezeAndUnsqueezeWithViewPass,
2085        ReplaceAddMMWithLinearPass,
2086        RemoveNopSelectOpPass,
2087        ReplaceSelectWithViewOpPass,
2088        ReplaceRepeatWithCatPass,
2089        ReplacePadWithCatPass,
2090        ReplaceConstantPadNdWithSlicePass,
2091        ReplaceConvWithChannelLastConvPass,
2092        ReplaceAtenConvolutionWithJarvisConvolutionPass,
2093        ForceChannelLastForConvPass,
2094        ReplaceTrivialConvWithLinear,
2095        ReplaceConvWithIm2RowAndLinear,
2096        ReplaceTransposedConvWithLinearPass,
2097        # This pass should be after passes that replace conv -> im2row + linear.
2098        ReplaceIm2RowWithViewPass,
2099        MakeSliceAndCatDimOutermostPass,
2100        ReplaceNopTransposeOrPermuteWithViewPass,
2101        ReplaceLinearWithFullyConnectedOpPass,
2102        ReplaceScalarTensorWithFullPass,
2103        ReplaceFullLikeWithFullPass,
2104        ReplaceInfArgInFullWithValuePass,
2105        ReplaceLogicalNotBooleanWhereWithWherePass,
2106        ReplacePT2QuantWithCadenceQuantPass,
2107        ReplacePT2DequantWithCadenceDequantPass,
2108        ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass,
2109        ReplaceAtenAvgPoolWithJarvisAvgPoolPass,
2110        ReplaceAtenLinalgVectorNormWithCadenceLinalgVectorNormPass,
2111    ]
2112