# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from dataclasses import dataclass

import torch
from executorch.backends.example.example_operators.op_base import OpBase
from executorch.backends.example.example_operators.utils import (
    _annotate_nodes,
    _nodes_are_annotated,
)

logger = logging.getLogger(__name__)


def _annotate_mean(partitions, quant_config):
    """
    Attach quantization annotations to a matched adaptive_avg_pool2d partition.

    The node annotated is ``partitions[0].output_nodes[0]``. Its first
    positional argument (``args[0]``) is treated as the pool's input. The
    (node, input) edge is annotated with ``quant_config.input_quant_spec``,
    and the node itself with ``quant_config.output_quant_spec``. If the node
    is already annotated, the function returns without changing anything.

    Despite its name, this function is the annotate handle registered for
    ``torch.nn.AdaptiveAvgPool2d`` by ``AdaptiveAvgPool2dNode`` in this
    module. The name is left unchanged so existing references keep working.

    Args:
        partitions: Sequence of matched partitions; only the first is used.
            Presumably these are ``torch.fx`` source-matcher
            ``SourcePartition`` objects (they must expose ``output_nodes``).
            TODO: confirm against the caller of ``annotate_handle``.
        quant_config: Object exposing ``input_quant_spec`` and
            ``output_quant_spec`` attributes.

    Raises:
        IndexError: presumably, if ``partitions`` or the first partition's
            ``output_nodes`` is empty (plain indexing, no guard).
    """
    adaptive_avg_pool2d_node = partitions[0].output_nodes[0]
    adaptive_avg_pool2d_node_input = adaptive_avg_pool2d_node.args[0]

    # Diagnostics go through the module logger instead of stdout, so they
    # only appear when DEBUG logging is enabled for this module.
    logger.debug("partitions: %s", partitions)
    logger.debug("adaptive_avg_pool2d_node: %s", adaptive_avg_pool2d_node)

    # Never re-annotate a node that already carries annotations.
    if _nodes_are_annotated([adaptive_avg_pool2d_node]):
        return

    # Annotate the edge from the pool's input into the pool node.
    _annotate_nodes(
        [(adaptive_avg_pool2d_node, adaptive_avg_pool2d_node_input)],
        quant_config.input_quant_spec,
        input_node=True,
    )
    # Annotate the pool node's own output.
    _annotate_nodes([(adaptive_avg_pool2d_node,)], quant_config.output_quant_spec)


@dataclass
class AdaptiveAvgPool2dNode(OpBase):
    """
    Example-backend operator entry for ``torch.nn.AdaptiveAvgPool2d``.

    Passes ``(torch.nn.AdaptiveAvgPool2d,)`` as the ``pattern``, and
    ``_annotate_mean`` as the ``annotate_handle``, to ``OpBase``.
    """

    # @dataclass does not overwrite an __init__ defined in the class body, so
    # this explicit constructor is the one that runs.
    def __init__(self):
        super().__init__(
            pattern=(torch.nn.AdaptiveAvgPool2d,),
            annotate_handle=_annotate_mean,
        )