# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import cast, Dict

import torch
from executorch.backends.xnnpack.operators.node_visitor import (
    get_tensor_value,
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
    XNNArgMaxPooling2d,
    XNNGraph,
    XNNMaxPooling2d,
    XNode,
)
from executorch.backends.xnnpack.utils.utils import check_or_raise, get_input_node


@register_node_visitor
class MaxDim(NodeVisitor):
    """Serializes ``aten.amax.default`` as an XNNPACK max-pooling-2d node.

    amax over dim 2 or 3 of a 4-D input (with keep_dim == True) is lowered
    to a max pool whose window spans the entire reduced dimension, so that
    exactly one output element is produced along that axis.
    """

    target = "aten.amax.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        """Append an XNNMaxPooling2d node for this amax to ``xnn_graph``.

        Raises (via check_or_raise) when keep_dim is absent, dim is not
        2 or 3, or the input is not 4-dimensional.
        """
        # Three args means (input, dim, keep_dim); keep_dim must be passed
        # (and True) for the pooling lowering to preserve rank.
        check_or_raise(
            len(node.args) == 3,
            "amax.default only supports keep_dim == True",
        )

        dim_val = cast(int, node.args[1])
        check_or_raise(
            dim_val == 2 or dim_val == 3,
            "amax.default only supports dim == 2 or dim == 3",
        )

        # Define the (NHWC-converted) tensor values for this node and its
        # inputs FIRST, so vals_to_ids is fully populated before the id
        # lookups below. (Previously input_id was read before this call,
        # unlike ArgMaxDim below, and could miss ids created by the NHWC
        # conversion.)
        self.define_nodes_tensor_inputs_outputs(
            node, xnn_graph, vals_to_ids, convert_to_nhwc=True
        )

        input_id = vals_to_ids[get_input_node(node, 0)]
        output_id = vals_to_ids[node]

        input_shape = get_tensor_value(xnn_graph.xvalues[input_id]).dims
        check_or_raise(
            len(input_shape) == 4,
            "Require input to amax.default be 4 dimensional",
        )

        # input_shape is in NHWC order: dims = (N, H, W, C).
        # The pooling window covers the whole reduced dimension; the stride
        # matches the window size so only one window fits along that axis.
        pooling_height = 1
        pooling_width = 1
        stride_height = 1
        stride_width = 1
        if dim_val == 2:
            # Reducing NCHW dim 2 (H) -> NHWC dims[1].
            pooling_height = input_shape[1]
            pooling_width = 1
            stride_height = input_shape[1]
        elif dim_val == 3:
            # Reducing NCHW dim 3 (W) -> NHWC dims[2].
            pooling_height = 1
            pooling_width = input_shape[2]
            stride_width = input_shape[2]

        ser_node = XNode(
            xnode_union=XNNMaxPooling2d(
                padding_top=0,
                padding_right=0,
                padding_bottom=0,
                padding_left=0,
                pooling_height=pooling_height,
                pooling_width=pooling_width,
                stride_height=stride_height,
                stride_width=stride_width,
                dilation_height=1,
                dilation_width=1,
                input_id=input_id,
                output_id=output_id,
                flags=0,
            ),
            debug_handle=debug_handle,
        )

        xnn_graph.xnodes.append(ser_node)


@register_node_visitor
class ArgMaxDim(NodeVisitor):
    """Serializes ``aten.max.dim`` as an XNNPACK argmax-pooling-2d node.

    max.dim returns both values and indices; its two getitem users (args 0
    and 1 respectively) become the node's value and index outputs.
    """

    target = "aten.max.dim"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        """Append an XNNArgMaxPooling2d node for this max.dim to ``xnn_graph``.

        Raises (via check_or_raise) for unsupported dim/keep_dim/input rank,
        and AssertionError when the node's users are not exactly the two
        getitem nodes selecting outputs 0 and 1.
        """
        check_or_raise(
            len(node.args) == 3,
            "max.dim only supports keep_dim == True",
        )

        dim_val = cast(int, node.args[1])
        check_or_raise(
            dim_val == 2 or dim_val == 3,
            "max.dim only supports dim == 2 or dim == 3",
        )

        # node.meta["val"] is a tuple (values_tensor, indices_tensor).
        # define_nodes_tensor_inputs_outputs expects a single tensor, so we
        # temporarily substitute the values tensor (the choice is arbitrary)
        # and restore the tuple afterwards.
        original_val = node.meta["val"]
        node.meta["val"] = original_val[0]

        self.define_nodes_tensor_inputs_outputs(
            node, xnn_graph, vals_to_ids, convert_to_nhwc=True
        )
        # The getitem users carry the real (values, indices) outputs; define
        # an NHWC tensor value for each of them as well.
        for user in node.users:
            self.define_nodes_tensor_inputs_outputs(
                user, xnn_graph, vals_to_ids, convert_to_nhwc=True
            )

        # Restore node.meta["val"]
        node.meta["val"] = original_val

        input_id = vals_to_ids[get_input_node(node, 0)]

        input_shape = get_tensor_value(xnn_graph.xvalues[input_id]).dims
        check_or_raise(
            len(input_shape) == 4, "Require input to max.dim be 4 dimensional"
        )

        users = list(node.users.keys())

        if len(users) != 2:
            raise AssertionError(
                f"Invalid number of users for max.dim (Expected 2, Got: {len(users)})"
            )

        values_node = None
        indices_node = None

        # getitem arg 0 selects the values output, arg 1 the indices output.
        for getitem_node in users:
            target_name = cast(torch._ops.OpOverload, getitem_node.target).__name__
            if target_name != "getitem":
                raise AssertionError(
                    f"Expected max node's user to be getitem, got: {target_name}"
                )

            if getitem_node.args[1] == 0:
                values_node = getitem_node
            elif getitem_node.args[1] == 1:
                indices_node = getitem_node

        if values_node is None or indices_node is None:
            # The checks above look for getitem indices 0 and 1; the previous
            # message incorrectly said "1 and 2".
            raise AssertionError(
                f"Expected max node's getitem args to be 0 and 1, got: {[user.args[1] for user in users]}"
            )

        output_index_id = vals_to_ids[indices_node]
        output_value_id = vals_to_ids[values_node]

        # input_shape is in NHWC order: dims = (N, H, W, C).
        # The pooling window covers the whole reduced dimension.
        pooling_height = 1
        pooling_width = 1
        if dim_val == 2:
            # Reducing NCHW dim 2 (H) -> NHWC dims[1].
            pooling_height = input_shape[1]
            pooling_width = 1
        elif dim_val == 3:
            # Reducing NCHW dim 3 (W) -> NHWC dims[2].
            pooling_height = 1
            pooling_width = input_shape[2]

        ser_node = XNode(
            xnode_union=XNNArgMaxPooling2d(
                padding_top=0,
                padding_right=0,
                padding_bottom=0,
                padding_left=0,
                pooling_height=pooling_height,
                pooling_width=pooling_width,
                input_id=input_id,
                output_value_id=output_value_id,
                output_index_id=output_index_id,
                flags=0,
            ),
            debug_handle=debug_handle,
        )

        xnn_graph.xnodes.append(ser_node)