# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import cast, Optional

import torch.fx
from executorch.backends.arm.tosa_quant_utils import is_node_quantized
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch._ops import OpOverload


def conv_remainder(input_length, pad, dilation, weight, stride):
    """
    Return how many trailing elements of the padded input a convolution skips.

    ``input_length + 2 * pad - dilation * (weight - 1) - 1`` is the numerator
    of the standard convolution output-size formula (padding applied to both
    sides). Taking it modulo ``stride`` gives the number of elements at the
    end of the padded input, along one spatial dimension, that no kernel
    window reaches.

    Args:
        input_length: Size of the input along the dimension.
        pad: Padding applied to each side of that dimension.
        dilation: Kernel dilation along that dimension.
        weight: Kernel size along that dimension.
        stride: Convolution stride along that dimension.

    Returns:
        The number of unreached trailing elements; for a positive stride this
        lies in ``[0, stride)``.
    """
    return (input_length + 2 * pad - dilation * (weight - 1) - 1) % stride


def insert_q_dq_pair(
    graph: torch.fx.Graph,
    anchor: torch.fx.Node,
    q_params: tuple,
):
    """
    Insert a quantize -> dequantize pair directly after ``anchor``.

    Every existing user of ``anchor`` is rewired to consume the new
    dequantize node, giving ``anchor -> q -> dq -> <previous users>``.

    Args:
        graph: Graph that owns ``anchor``.
        anchor: Node whose output is passed through the q/dq pair.
        q_params: Trailing arguments (everything after the input tensor)
            forwarded unchanged to both ``quantize_per_tensor`` and
            ``dequantize_per_tensor``.

    Returns:
        The inserted dequantize node.
    """
    with graph.inserting_after(anchor):
        q = create_node(
            graph=graph,
            op_target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            args=(),  # We add the argument last
        )
        # NOTE(review): q, dq and anchor end up sharing the same meta dict
        # object (not copies); a write to one node's meta is visible on all.
        q.meta = anchor.meta

    with graph.inserting_after(q):
        dq = create_node(
            graph=graph,
            op_target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
            args=(q,) + q_params,
        )
        dq.meta = q.meta

    anchor.replace_all_uses_with(dq)
    # We add this last so the replace all uses above does not replace the quantized
    # node's first use
    q.args = (anchor,) + q_params
    return dq


def create_node(
    graph: torch.fx.Graph,
    op_target: OpOverload,
    args: tuple = (),
    kwargs: Optional[dict] = None,
):
    """
    Create a ``call_function`` node for ``op_target`` in ``graph``.

    The node is placed at the graph's current insertion point. A missing
    ``kwargs`` is replaced by an empty dict.
    """
    return graph.create_node(
        "call_function",
        op_target,
        args=args,
        kwargs=kwargs or {},
    )


class SizeAdjustConv2DPass(ExportPass):
    """
    Adjust the convolution input size to match perfectly with the
    weight size, padding, stride and dilation parameters.
    This is done by inserting a slice op to remove the uneven end of the input.

    Only non-transposed convolutions are adjusted. When the value feeding a
    slice is quantized, the slice is followed by a quantize/dequantize pair
    that reuses that value's quantization parameters.
    """

    conv2d_op = exir_ops.edge.aten.convolution.default
    slice_op = exir_ops.edge.aten.slice_copy.Tensor

    def call(self, graph_module: torch.fx.GraphModule):
        """
        Slice convolution inputs whose trailing elements are never read.

        If at least one convolution is rewritten, the graph is retraced through
        ``ExportPass.call``, dead code is removed and the module recompiled.

        Returns:
            A ``PassResult`` whose ``modified`` flag is True only when some
            convolution input was sliced.
        """
        graph = graph_module.graph
        modified_graph = False
        for node in graph.nodes:
            if node.op != "call_function" or node.target != self.conv2d_op:
                continue

            conv_node = cast(torch.fx.Node, node)
            (
                input_node,
                weight,
                _,
                stride_hw,
                pad_hw,
                dilation_hw,
                transposed,
                _,
                _,
            ) = conv_node.args
            if transposed:
                # conv_remainder models a regular convolution's output size;
                # slicing a transposed convolution's input would shrink its
                # output instead of trimming unreached input.
                continue

            weight_shape = cast(torch.fx.Node, weight).meta["val"].shape
            input_shape = cast(torch.fx.Node, input_node).meta["val"].shape

            # One (dim, start, end) slice per spatial dimension whose
            # unreached trailing elements extend past the end padding.
            slice_args = []
            for stride, pad, dilation, dim in zip(
                cast(list, stride_hw),
                cast(list, pad_hw),
                cast(list, dilation_hw),
                (2, 3),
            ):
                remainder = conv_remainder(
                    input_shape[dim], pad, dilation, weight_shape[dim], stride
                )
                # Up to `pad` of the unreached elements are end padding; only
                # the excess is real input that can be dropped.
                if remainder > pad:
                    adjustment = remainder - pad
                    slice_args.append((dim, 0, input_shape[dim] - adjustment))
            if not slice_args:
                continue

            with graph.inserting_before(conv_node):
                last_node = cast(torch.fx.Node, input_node)
                for args in slice_args:
                    slice_node = create_node(graph, self.slice_op, (last_node,) + args)
                    if is_node_quantized(last_node):
                        # Assumes last_node is a quantize/dequantize node whose
                        # args[1:] are its quantization parameters.
                        q_params = last_node.args[1:]
                        last_node = insert_q_dq_pair(graph, slice_node, q_params)
                    else:
                        last_node = slice_node
                conv_node.replace_input_with(cast(torch.fx.Node, input_node), last_node)
            modified_graph = True

        if modified_graph:
            # ExportPass.call retraces into a new GraphModule, so clean up and
            # recompile that module's graph rather than the stale original.
            graph_module = super().call(graph_module).graph_module
            graph_module.graph.eliminate_dead_code()
            graph_module.recompile()
        return PassResult(graph_module, modified_graph)