xref: /aosp_15_r20/external/executorch/backends/arm/_passes/size_adjust_conv2d_pass.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2024 Arm Limited and/or its affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-unsafe
8
9from typing import cast, Optional
10
11import torch.fx
12from executorch.backends.arm.tosa_quant_utils import is_node_quantized
13from executorch.exir.dialects._ops import ops as exir_ops
14from executorch.exir.pass_base import ExportPass, PassResult
15from torch._ops import OpOverload
16
17
def conv_remainder(input_length, pad, dilation, weight, stride):
    """
    Return how many trailing positions of a padded convolution input are
    left over after the last full stride step.

    Computes ``(input_length + 2 * pad - dilation * (weight - 1) - 1) % stride``.
    The numerator is the padded length minus the dilated kernel extent. A
    result of 0 means the strided windows tile the padded input exactly. A
    non-zero result is the number of trailing positions lying beyond the end
    of the last window.

    Args:
        input_length: Size of the input along the dimension being checked.
        pad: Padding added on each side of that dimension (counted twice).
        dilation: Dilation of the kernel along that dimension.
        weight: Kernel size along that dimension.
        stride: Stride along that dimension.

    Returns:
        The leftover position count, in ``[0, stride)`` for positive ``stride``.
    """
    return (input_length + 2 * pad - dilation * (weight - 1) - 1) % stride
23
24
def insert_q_dq_pair(
    graph: torch.fx.Graph,
    anchor: torch.fx.Node,
    q_params: tuple,
):
    """
    Insert a quantize -> dequantize node pair directly after ``anchor``.

    Every existing user of ``anchor`` is rewired to consume the new dequantize
    node, so the result is the chain ``anchor -> q -> dq -> (former users)``.

    Args:
        graph: Graph that owns ``anchor``; the new nodes are created in it.
        anchor: Node whose output gets the q/dq pair appended.
        q_params: Positional arguments appended after the input tensor for both
            ``quantize_per_tensor`` and ``dequantize_per_tensor``.

    Returns:
        The newly created dequantize node.
    """
    with graph.inserting_after(anchor):
        # q is created without args so that replace_all_uses_with below does
        # not rewire q itself onto dq.
        q = create_node(
            graph=graph,
            op_target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            args=(),  # We add the argument last
        )
        # Copy instead of aliasing so each node owns an independent meta dict;
        # updating one node's metadata must not leak into the others.
        q.meta = anchor.meta.copy()

    with graph.inserting_after(q):
        dq = create_node(
            graph=graph,
            op_target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
            args=(q,) + q_params,
        )
        dq.meta = q.meta.copy()

    anchor.replace_all_uses_with(dq)
    # We add this last so the replace all uses above does not replace the quantized
    # node's first use
    q.args = (anchor,) + q_params
    return dq
51
52
def create_node(
    graph: torch.fx.Graph,
    op_target: OpOverload,
    args: tuple = (),
    kwargs: Optional[dict] = None,
):
    """
    Create a ``call_function`` node in ``graph`` that calls ``op_target``.

    The node is placed at the graph's current insertion point (as set by
    ``inserting_before`` / ``inserting_after``).

    Args:
        graph: Graph to add the node to.
        op_target: Operator the node calls.
        args: Positional arguments of the call.
        kwargs: Keyword arguments of the call; ``None`` or an empty dict both
            result in a fresh empty dict.

    Returns:
        The newly created node.
    """
    call_kwargs = kwargs if kwargs else {}
    new_node = graph.create_node(
        "call_function",
        op_target,
        args=args,
        kwargs=call_kwargs,
    )
    return new_node
65
66
class SizeAdjustConv2DPass(ExportPass):
    """
    Adjust the convolution input size to match perfectly with the
    weight size, padding, stride and dilation parameters.
    This is done by inserting a slice op to remove the uneven end of the input.

    For every non-transposed ``aten.convolution`` node, each spatial dimension
    (input/weight dims 2 and 3) is checked with ``conv_remainder``. When the
    remainder exceeds the padding on that dimension, the input is sliced so
    that ``remainder - pad`` trailing elements are dropped. If the tensor being
    sliced is produced by a quantize/dequantize node, a q/dq pair with the same
    quantization parameters is inserted after each slice.
    """

    conv2d_op = exir_ops.edge.aten.convolution.default
    slice_op = exir_ops.edge.aten.slice_copy.Tensor
    # q/dq ops whose args[1:] can be re-used as the q_params of a new q/dq
    # pair (insert_q_dq_pair appends them after the input for both ops).
    qdq_ops = (
        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
    )

    def call(self, graph_module: torch.fx.GraphModule):
        """
        Insert the slice (and, when quantized, q/dq) nodes in front of every
        convolution whose input does not line up with its window parameters.

        Args:
            graph_module: Module whose graph is rewritten in place and then
                retraced if anything changed.

        Returns:
            PassResult with the (possibly retraced) module and whether any
            convolution input was adjusted.
        """
        graph = graph_module.graph
        modified_graph = False
        for node in graph.nodes:
            if node.op != "call_function" or node.target != self.conv2d_op:
                continue

            conv_node = cast(torch.fx.Node, node)
            (
                input_node,
                weight,
                _,
                stride_hw,
                pad_hw,
                dilation_hw,
                transposed,
                _,
                _,
            ) = conv_node.args
            # conv_remainder models a regular (forward) convolution; trimming
            # the input of a transposed convolution would change its output.
            if transposed:
                continue

            weight_shape = cast(torch.fx.Node, weight).meta["val"].shape
            input_shape = cast(torch.fx.Node, input_node).meta["val"].shape

            # (dim, start, end) arguments for each slice_copy to insert.
            slice_args = []
            for stride, pad, dilation, dim in zip(
                cast(list, stride_hw),
                cast(list, pad_hw),
                cast(list, dilation_hw),
                (2, 3),
            ):
                remainder = conv_remainder(
                    input_shape[dim], pad, dilation, weight_shape[dim], stride
                )
                # Only the part of the remainder exceeding `pad` is cut from
                # the input. NOTE(review): presumably the trailing padding is
                # expected to absorb the rest during lowering — confirm.
                if remainder > pad:
                    adjustment = remainder - pad
                    slice_args.append((dim, 0, input_shape[dim] - adjustment))
            if not slice_args:
                continue

            last_node = cast(torch.fx.Node, input_node)
            with graph.inserting_before(conv_node):
                for args in slice_args:
                    slice_node = graph.create_node(
                        "call_function", self.slice_op, (last_node,) + args
                    )
                    # Only q/dq nodes carry quantization parameters in
                    # args[1:]; reading them from any other node would build
                    # a malformed q/dq pair.
                    if (
                        is_node_quantized(last_node)
                        and last_node.target in self.qdq_ops
                    ):
                        q_params = last_node.args[1:]
                        last_node = insert_q_dq_pair(graph, slice_node, q_params)
                    else:
                        last_node = slice_node
            conv_node.replace_input_with(cast(torch.fx.Node, input_node), last_node)
            modified_graph = True

        if modified_graph:
            graph_module = super().call(graph_module).graph_module
            # Clean up the retraced module's graph: the `graph` local still
            # refers to the pre-retrace graph.
            graph_module.graph.eliminate_dead_code()
            graph_module.recompile()
        return PassResult(graph_module, modified_graph)
132