# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import cast, Dict, List, Tuple, Union

import torch
from executorch.backends.xnnpack.operators.node_visitor import (
    check_or_raise,
    get_tensor_value,
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
    XNNGraph,
    XNNMaxPooling2d,
    XNode,
)
from executorch.backends.xnnpack.utils.xnnpack_constants import XNN_FLAG_KEEP_DIMS


def _to_pair(value: Union[int, List[int]], arg_name: str) -> Tuple[int, int]:
    """Normalize an ATen ``int[2]`` argument to an explicit ``(height, width)`` pair.

    ATen ``int[2]`` arguments may be supplied as a single int or a one-element
    list that applies to both spatial dimensions; both forms are expanded here.

    Args:
        value: The raw argument taken from ``node.args``.
        arg_name: Argument name used in the error message.

    Returns:
        A ``(height, width)`` tuple.

    Raises:
        Whatever ``check_or_raise`` raises when the list has more than two
        (or zero) elements.
    """
    if isinstance(value, int):
        return value, value
    values = list(value)
    if len(values) == 1:
        return values[0], values[0]
    check_or_raise(
        len(values) == 2,
        f"Expected {arg_name} to have 1 or 2 elements, got {values}",
    )
    return values[0], values[1]


@register_node_visitor
class MaxPooling2d(NodeVisitor):
    """Serializes ``aten.max_pool2d.default`` into an ``XNNMaxPooling2d`` node."""

    target = "aten.max_pool2d.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        """Append an ``XNNMaxPooling2d`` node for ``node`` to ``xnn_graph``.

        Positional arguments follow the ATen schema
        ``max_pool2d(input, kernel_size, stride=[], padding=0, dilation=1,
        ceil_mode=False)``; trailing arguments may be omitted from
        ``node.args`` and then fall back to those defaults.

        Args:
            node: The ``aten.max_pool2d.default`` FX node to serialize.
            xnn_graph: Graph under construction; the new node is appended to
                ``xnn_graph.xnodes``.
            vals_to_ids: Mapping from FX nodes to XNNPACK value ids.
            debug_handle: Debug handle recorded on the emitted ``XNode``.

        Raises:
            Whatever ``check_or_raise`` raises if the input is not 4-D, an
            ``int[2]`` argument has an unexpected length, or ``ceil_mode`` is
            set.
        """
        # Input/output tensors are defined in NHWC layout.
        self.define_nodes_tensor_inputs_outputs(
            node, xnn_graph, vals_to_ids, convert_to_nhwc=True
        )
        args = node.args

        input_id = vals_to_ids[node.all_input_nodes[0]]
        output_id = vals_to_ids[node]

        input_shape = get_tensor_value(xnn_graph.xvalues[input_id]).dims
        check_or_raise(len(input_shape) == 4, "Require input to be 4 dimensional")

        # Kernel size is always present in the schema.
        pool_h, pool_w = _to_pair(cast(List[int], args[1]), "kernel_size")

        # ATen's default stride (an empty list, or the argument omitted
        # entirely) means "same as the kernel size".
        stride_arg = args[2] if len(args) > 2 else None
        if not isinstance(stride_arg, int) and not stride_arg:
            stride_h, stride_w = pool_h, pool_w
        else:
            stride_h, stride_w = _to_pair(stride_arg, "stride")

        # Padding is symmetric: top/bottom use the height value, left/right
        # use the width value.
        pad_h, pad_w = _to_pair(args[3], "padding") if len(args) > 3 else (0, 0)

        dilation_h, dilation_w = (
            _to_pair(args[4], "dilation") if len(args) > 4 else (1, 1)
        )

        # XNNMaxPooling2d (built below) has no field for ceil_mode, so reject
        # it instead of emitting a node whose output size may not match.
        ceil_mode = cast(bool, args[5]) if len(args) > 5 else False
        check_or_raise(
            not ceil_mode,
            "ceil_mode=True is not supported for XNNPACK max_pool2d",
        )

        ser_node = XNode(
            xnode_union=XNNMaxPooling2d(
                input_id=input_id,
                output_id=output_id,
                pooling_height=pool_h,
                pooling_width=pool_w,
                stride_height=stride_h,
                stride_width=stride_w,
                padding_top=pad_h,
                padding_right=pad_w,
                padding_bottom=pad_h,
                padding_left=pad_w,
                dilation_height=dilation_h,
                dilation_width=dilation_w,
                flags=XNN_FLAG_KEEP_DIMS,
            ),
            debug_handle=debug_handle,
        )
        xnn_graph.xnodes.append(ser_node)