# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass

import torch
from executorch.backends.example.example_operators.op_base import OpBase
from executorch.backends.example.example_operators.utils import (
    _annotate_nodes,
    _nodes_are_annotated,
)


def _annotate_linear(partitions, quant_config):
    """Attach quantization annotations to a matched ``torch.nn.Linear`` partition.

    The first output node of ``partitions[0]`` is taken as the linear node.
    Its first argument is annotated with ``quant_config.input_quant_spec``,
    its second argument with ``quant_config.weight_quant_spec``, and the node
    itself with ``quant_config.output_quant_spec``. If
    ``_nodes_are_annotated`` reports the linear node as already annotated,
    the function returns early without annotating anything.

    NOTE(review): the argument indices assume the ``aten.linear`` layout
    ``(input, weight, bias)``. If the partition instead ends in
    ``aten.addmm`` -- e.g. ``addmm(fn_bias, arg2_1, permute_copy(fn_weight,
    [1, 0]))``, whose layout is ``(bias, input, weight)`` -- these indices
    would select the bias and the input instead. Confirm against how the
    partitions are produced.

    Args:
        partitions: Sequence of matched partitions; only ``partitions[0]``
            is used, via its ``output_nodes[0]``.
        quant_config: Object exposing ``input_quant_spec``,
            ``weight_quant_spec`` and ``output_quant_spec``.
    """
    linear_node = partitions[0].output_nodes[0]
    if _nodes_are_annotated([linear_node]):
        return

    input_node = linear_node.args[0]
    weight_node = linear_node.args[1]
    # The bias operand, if present, is not annotated here.

    _annotate_nodes(
        [(linear_node, input_node)], quant_config.input_quant_spec, input_node=True
    )
    _annotate_nodes(
        [(linear_node, weight_node)], quant_config.weight_quant_spec, input_node=True
    )
    _annotate_nodes([(linear_node,)], quant_config.output_quant_spec)


@dataclass
class LinearNode(OpBase):
    """Example-backend operator entry for ``torch.nn.Linear``.

    Passes ``(torch.nn.Linear,)`` as the pattern and ``_annotate_linear`` as
    the annotation handler to ``OpBase.__init__``.
    """

    def __init__(self):
        super().__init__(
            pattern=(torch.nn.Linear,),
            annotate_handle=_annotate_linear,
        )