1*523fa7a6SAndroid Build Coastguard Worker# Copyright 2024 Arm Limited and/or its affiliates. 2*523fa7a6SAndroid Build Coastguard Worker# 3*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the 4*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree. 5*523fa7a6SAndroid Build Coastguard Worker# 6*523fa7a6SAndroid Build Coastguard Worker 7*523fa7a6SAndroid Build Coastguard Worker# pyre-unsafe 8*523fa7a6SAndroid Build Coastguard Workerfrom typing import cast, Dict 9*523fa7a6SAndroid Build Coastguard Worker 10*523fa7a6SAndroid Build Coastguard Workerimport numpy as np 11*523fa7a6SAndroid Build Coastguard Workerimport serializer.tosa_serializer as ts 12*523fa7a6SAndroid Build Coastguard Workerimport torch 13*523fa7a6SAndroid Build Coastguard Workerimport torch.fx 14*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.arm.operators.node_visitor import NodeVisitor 15*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.arm.tosa_mapping import map_dtype, TosaArg 16*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.arm.tosa_quant_utils import ( 17*523fa7a6SAndroid Build Coastguard Worker get_quant_arg_upstream, 18*523fa7a6SAndroid Build Coastguard Worker get_quantized_node_output_dtype, 19*523fa7a6SAndroid Build Coastguard Worker is_node_quantized, 20*523fa7a6SAndroid Build Coastguard Worker) 21*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.arm.tosa_specification import TosaSpecification 22*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.arm.tosa_utils import ( 23*523fa7a6SAndroid Build Coastguard Worker getNodeArgs, 24*523fa7a6SAndroid Build Coastguard Worker is_bias_node_for_quantized_conv, 25*523fa7a6SAndroid Build Coastguard Worker tosa_shape, 26*523fa7a6SAndroid Build Coastguard Worker) 27*523fa7a6SAndroid Build Coastguard Workerfrom torch.export.exported_program import ExportedProgram 
def process_call_function(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    node_visitors: Dict[str, NodeVisitor],
    tosa_spec: TosaSpecification,
):
    """Serialize a ``call_function`` node into the TOSA graph.

    Registers the node's output tensor in the current basic block, then
    dispatches to the ``NodeVisitor`` registered under ``node.target.__name__``.

    Args:
        node: The ``call_function`` node to serialize.
        tosa_graph: Serializer that receives the output tensor and operator.
        node_visitors: Mapping from operator target name to its visitor.
        tosa_spec: Active TOSA specification; used here only in the error
            message when no visitor is registered.

    Raises:
        RuntimeError: If no visitor is registered for ``node.target``.
    """
    # Unpack arguments and convert
    inputs = getNodeArgs(node)

    # Convert output (this node itself)
    output = TosaArg(node)

    # For quantized nodes, the output dtype comes from the quantization
    # annotation helper instead of the TosaArg's own dtype.
    is_quant_node = is_node_quantized(node)
    if is_quant_node:
        output_dtype = map_dtype(get_quantized_node_output_dtype(node))
    else:
        output_dtype = output.dtype
    tosa_graph.currRegion.currBasicBlock.addTensor(
        output.name,
        tosa_shape(output.shape, output.dim_order),
        output_dtype,
    )

    # Visiting each Node
    # pyre-ignore[16]: Undefined attribute.
    if node.target.__name__ in node_visitors:
        # pyre-ignore[16]: Undefined attribute.
        node_visitors[node.target.__name__].define_node(
            node,
            tosa_graph,
            inputs,
            output,
            is_quant_node,
        )
    else:
        raise RuntimeError(f"Unknown operator {node.target} for TOSA : {tosa_spec}")


def process_inputs(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    tosa_spec: TosaSpecification,
):
    """Serialize a user-input placeholder as a TOSA graph input tensor.

    The tensor is registered with no data and a placeholder file name of
    ``<node name>.npy``.

    Args:
        node: Placeholder node for a user input.
        tosa_graph: Serializer that receives the input tensor.
        tosa_spec: Active TOSA specification (currently unused here).

    Raises:
        RuntimeError: If the input's dim_order is not the default
            ``(0, 1, ..., rank - 1)`` (i.e. not contiguous).
    """
    # inputs need to be in default dim_order (contiguous memory format)
    meta = node.meta["val"]
    if meta.dim_order() != tuple(range(meta.dim())):
        raise RuntimeError(
            f"Arm backend only supports contiguous memory format for inputs. "
            f"Expected dim_order: {tuple(range(meta.dim()))}, but got: {meta.dim_order()} for node {node.name}"
        )
    tosa_arg = TosaArg(node)
    # Quantized inputs use the dtype from the quantization annotation.
    input_dtype = (
        map_dtype(get_quantized_node_output_dtype(node))
        if is_node_quantized(node)
        else tosa_arg.dtype
    )
    tensor = ts.TosaSerializerTensor(
        tosa_arg.name,
        tosa_shape(tosa_arg.shape, tosa_arg.dim_order),
        input_dtype,
        data=None,
        placeholderFilename=tosa_arg.name + ".npy",
    )
    tosa_graph.addInputTensor(tensor)


def process_quantized_bias(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    parameter_values,
):
    """Serialize a bias node that needs to be quantized.

    The bias is divided by ``input_scale * weight_scale``, rounded, and
    emitted as an INT32 constant named after ``node``. The scales are taken
    from the quantization parameters upstream of the first consumer's
    input and weight operands.

    Args:
        node: The bias placeholder node.
        tosa_graph: Serializer that receives the constant.
        parameter_values: Bias values as a NumPy array.

    Raises:
        ValueError: If the first consumer does not have exactly three input
            nodes (unpacked as input, weight, bias).
    """
    # NOTE(review): only the first user is consulted; this presumably holds
    # because callers gate on is_bias_node_for_quantized_conv — confirm.
    consumer_node = list(node.users)[0]
    (
        input_node,
        weight_node,
        _,
    ) = consumer_node.all_input_nodes

    input_node_scale = get_quant_arg_upstream(input_node).scale
    weight_node_scale = get_quant_arg_upstream(weight_node).scale
    bias_values_quantized = (
        (parameter_values / (input_node_scale * weight_node_scale))
        .round()
        .astype(np.int32)
    )

    tosa_graph.addConst(
        bias_values_quantized.shape,
        ts.DType.INT32,
        bias_values_quantized,
        name=node.name,
    )


def process_inputs_to_parameters(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    edge_program: ExportedProgram,
    tosa_spec: TosaSpecification,
):
    """Serialize bias and non-quantized weights.

    Looks up the parameter tensor in ``edge_program.state_dict``. A bias
    feeding a quantized conv is quantized via ``process_quantized_bias``.
    Any other parameter is transposed by the node's dim_order and emitted as
    a constant with the node's TOSA dtype.

    Raises:
        AssertionError: If the parameter is not a tensor, or if the spec
            lacks integer support (quantized bias) or float support (FP32
            parameter).
    """
    tosa_arg = TosaArg(node)
    parameter_name = edge_program.graph_signature.inputs_to_parameters[node.name]
    parameter_data = edge_program.state_dict[parameter_name]

    assert isinstance(parameter_data, torch.Tensor), "Expect Attr to be tensor"
    parameter_values = parameter_data.detach().numpy()

    if is_bias_node_for_quantized_conv(node):
        # BI bias
        assert tosa_spec.support_integer(), f"{tosa_spec} doesn't support integer"
        process_quantized_bias(node, tosa_graph, parameter_values)
    else:
        # MI weights or bias
        # TosaArg.dtype is a TOSA dtype (it is passed straight to addConst
        # below), so compare against ts.DType.FP32. A comparison with
        # torch.float32 would never match and the check would be dead.
        if tosa_arg.dtype == ts.DType.FP32:
            assert tosa_spec.support_float(), f"{tosa_spec} doesn't support float"

        parameter_values = np.transpose(parameter_values, tosa_arg.dim_order)

        tosa_graph.addConst(
            parameter_values.shape, tosa_arg.dtype, parameter_values, name=node.name
        )


def process_inputs_to_buffers(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    edge_program: ExportedProgram,
):
    """Serialize quantized weights (buffer placeholders).

    Looks up the buffer tensor in ``edge_program.state_dict``, transposes it
    by the node's dim_order, and emits it as a constant with the node's dtype.

    Raises:
        AssertionError: If the buffer entry is not a tensor.
    """
    tosa_arg = TosaArg(node)
    buffer_name = edge_program.graph_signature.inputs_to_buffers[node.name]
    buffer_data = edge_program.state_dict[buffer_name]

    assert isinstance(buffer_data, torch.Tensor), "Expect Attr to be tensor"
    buffer_values = buffer_data.detach().numpy()

    # TODO: fragile code for temporary fix
    # the mean and var tensors are also stored here but they have shape (1, )
    # we only transpose weights here
    buffer_values = np.transpose(buffer_values, tosa_arg.dim_order)

    tosa_graph.addConst(
        buffer_values.shape, tosa_arg.dtype, buffer_values, name=node.name
    )


def process_inputs_to_lifted_tensor_constants(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    edge_program: ExportedProgram,
):
    """Serialize a lifted tensor constant as a TOSA constant.

    The tensor comes from ``edge_program.tensor_constants`` and is emitted
    with its original layout. Unlike parameters and buffers, no dim_order
    transpose is applied here.
    """
    arg = TosaArg(node)
    tensor_name = edge_program.graph_signature.inputs_to_lifted_tensor_constants[
        arg.name
    ]
    tensor = edge_program.tensor_constants[tensor_name]
    tensor_data = tensor.detach().numpy()

    tosa_graph.addConst(tensor_data.shape, arg.dtype, tensor_data, name=arg.name)


def process_placeholder(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    edge_program: ExportedProgram,
    tosa_spec: TosaSpecification,
):
    """Wrapper for processing and serializing all types of placeholders.

    Dispatches on which ``graph_signature`` collection contains the node:
    user inputs, parameters, buffers, or lifted tensor constants.

    Raises:
        NotImplementedError: For lifted custom objects.
        RuntimeError: If the placeholder matches none of the known kinds.
    """
    assert node.name == node.target, "Expect placeholder name and target to match"
    assert 0 == len(node.args), "Can't handle default input values"

    if node.name in edge_program.graph_signature.user_inputs:
        process_inputs(node, tosa_graph, tosa_spec)
    elif node.name in edge_program.graph_signature.inputs_to_parameters:
        process_inputs_to_parameters(node, tosa_graph, edge_program, tosa_spec)
    elif node.name in edge_program.graph_signature.inputs_to_buffers:
        process_inputs_to_buffers(node, tosa_graph, edge_program)
    elif node.name in edge_program.graph_signature.inputs_to_lifted_tensor_constants:
        process_inputs_to_lifted_tensor_constants(node, tosa_graph, edge_program)
    elif node.name in edge_program.graph_signature.inputs_to_lifted_custom_objs:
        raise NotImplementedError(
            "Placeholder is of type 'lifted custom object' which is not supported."
        )
    else:
        raise RuntimeError(f"Placeholder '{node.name}' is of unknown type.")


def process_output(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
):
    """Mark each value returned by the graph's output node as a TOSA output.

    Every returned node must already have a tensor registered in the current
    basic block. Otherwise the dictionary lookup raises ``KeyError``.
    """
    for output in cast(tuple[torch.fx.Node, ...], node.args[0]):
        tosa_graph.addOutputTensor(
            tosa_graph.currRegion.currBasicBlock.tensors[output.name]
        )