xref: /aosp_15_r20/external/executorch/backends/arm/process_node.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#

# pyre-unsafe
8*523fa7a6SAndroid Build Coastguard Workerfrom typing import cast, Dict
9*523fa7a6SAndroid Build Coastguard Worker
10*523fa7a6SAndroid Build Coastguard Workerimport numpy as np
11*523fa7a6SAndroid Build Coastguard Workerimport serializer.tosa_serializer as ts
12*523fa7a6SAndroid Build Coastguard Workerimport torch
13*523fa7a6SAndroid Build Coastguard Workerimport torch.fx
14*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.arm.operators.node_visitor import NodeVisitor
15*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
16*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.arm.tosa_quant_utils import (
17*523fa7a6SAndroid Build Coastguard Worker    get_quant_arg_upstream,
18*523fa7a6SAndroid Build Coastguard Worker    get_quantized_node_output_dtype,
19*523fa7a6SAndroid Build Coastguard Worker    is_node_quantized,
20*523fa7a6SAndroid Build Coastguard Worker)
21*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.arm.tosa_specification import TosaSpecification
22*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.arm.tosa_utils import (
23*523fa7a6SAndroid Build Coastguard Worker    getNodeArgs,
24*523fa7a6SAndroid Build Coastguard Worker    is_bias_node_for_quantized_conv,
25*523fa7a6SAndroid Build Coastguard Worker    tosa_shape,
26*523fa7a6SAndroid Build Coastguard Worker)
27*523fa7a6SAndroid Build Coastguard Workerfrom torch.export.exported_program import ExportedProgram
28*523fa7a6SAndroid Build Coastguard Worker
29*523fa7a6SAndroid Build Coastguard Worker
def process_call_function(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    node_visitors: Dict[str, NodeVisitor],
    tosa_spec: TosaSpecification,
):
    """Serialize a call_function node by dispatching to its NodeVisitor.

    The node's output tensor is first registered in the current TOSA basic
    block (with the quantized output dtype when the node is quantized). The
    visitor registered under ``node.target.__name__`` then emits the operator.

    Raises:
        RuntimeError: if no visitor is registered for the node's target.
    """
    # Arguments of the node, converted to TosaArg form.
    node_inputs = getNodeArgs(node)
    # The node itself describes the output tensor.
    node_output = TosaArg(node)

    quantized = is_node_quantized(node)
    tensor_dtype = (
        map_dtype(get_quantized_node_output_dtype(node))
        if quantized
        else node_output.dtype
    )
    tosa_graph.currRegion.currBasicBlock.addTensor(
        node_output.name,
        tosa_shape(node_output.shape, node_output.dim_order),
        tensor_dtype,
    )

    # pyre-ignore[16]: Undefined attribute.
    visitor_key = node.target.__name__
    if visitor_key not in node_visitors:
        raise RuntimeError(f"Unknown operator {node.target} for TOSA : {tosa_spec}")

    node_visitors[visitor_key].define_node(
        node,
        tosa_graph,
        node_inputs,
        node_output,
        quantized,
    )
66*523fa7a6SAndroid Build Coastguard Worker
67*523fa7a6SAndroid Build Coastguard Worker
def process_inputs(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    tosa_spec: TosaSpecification,
):
    """Serialize a user-input placeholder as a TOSA graph input tensor.

    Inputs must use the default (contiguous) dim_order. The tensor is added
    with no data and a ``<name>.npy`` placeholder filename. Its dtype is the
    quantized output dtype when the node is quantized, otherwise the node's
    own dtype. ``tosa_spec`` is accepted for signature parity and is not used.

    Raises:
        RuntimeError: if the input's dim_order is not the contiguous default.
    """
    fake_val = node.meta["val"]
    contiguous_order = tuple(range(fake_val.dim()))
    if fake_val.dim_order() != contiguous_order:
        raise RuntimeError(
            "Arm backend only supports contiguous memory format for inputs. "
            f"Expected dim_order: {contiguous_order}, but got: {fake_val.dim_order()} for node {node.name}"
        )

    arg = TosaArg(node)
    if is_node_quantized(node):
        dtype = map_dtype(get_quantized_node_output_dtype(node))
    else:
        dtype = arg.dtype

    input_tensor = ts.TosaSerializerTensor(
        arg.name,
        tosa_shape(arg.shape, arg.dim_order),
        dtype,
        data=None,
        placeholderFilename=arg.name + ".npy",
    )
    tosa_graph.addInputTensor(input_tensor)
96*523fa7a6SAndroid Build Coastguard Worker
97*523fa7a6SAndroid Build Coastguard Worker
def process_quantized_bias(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    parameter_values,
):
    """
    Serialize a bias node that must be quantized to INT32.

    The bias scale is the product of the input and weight scales of the first
    user of ``node``. ``parameter_values`` is divided by that scale, rounded,
    cast to int32 and added to ``tosa_graph`` as a constant named after
    ``node``.
    """
    consumer = tuple(node.users)[0]
    # The consumer is expected to have exactly three input nodes:
    # (input, weight, bias).
    input_node, weight_node, _bias = consumer.all_input_nodes

    input_scale = get_quant_arg_upstream(input_node).scale
    weight_scale = get_quant_arg_upstream(weight_node).scale
    scaled = parameter_values / (input_scale * weight_scale)
    quantized_bias = scaled.round().astype(np.int32)

    tosa_graph.addConst(
        quantized_bias.shape,
        ts.DType.INT32,
        quantized_bias,
        name=node.name,
    )
127*523fa7a6SAndroid Build Coastguard Worker
128*523fa7a6SAndroid Build Coastguard Worker
def process_inputs_to_parameters(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    edge_program: ExportedProgram,
    tosa_spec: TosaSpecification,
):
    """Serialize bias and non-quantized weights.

    The parameter tensor is looked up in ``edge_program.state_dict`` through
    the graph signature. Biases of quantized convolutions are quantized to
    INT32 by ``process_quantized_bias``. Every other parameter is transposed
    into the node's ``dim_order`` and added as a constant with the node's dtype.

    Raises:
        AssertionError: if the parameter is not a tensor, if ``tosa_spec``
            lacks integer support for a quantized bias, or if it lacks float
            support for an FP32 parameter.
    """
    arg = TosaArg(node)
    parameter_name = edge_program.graph_signature.inputs_to_parameters[node.name]
    parameter_data = edge_program.state_dict[parameter_name]

    assert isinstance(parameter_data, torch.Tensor), "Expect Attr to be tensor"
    parameter_values = parameter_data.detach().numpy()

    if is_bias_node_for_quantized_conv(node):
        # BI bias
        assert tosa_spec.support_integer(), f"{tosa_spec} doesn't support integer"
        process_quantized_bias(node, tosa_graph, parameter_values)
        return

    # MI weights or bias.
    # arg.dtype is a TOSA dtype: it is handed directly to addConst below.
    # It must therefore be compared with ts.DType.FP32. Comparing it with
    # torch.float32 never matches, which silently skipped this check.
    if arg.dtype == ts.DType.FP32:
        assert tosa_spec.support_float(), f"{tosa_spec} doesn't support float"

    parameter_values = np.transpose(parameter_values, arg.dim_order)

    tosa_graph.addConst(
        parameter_values.shape, arg.dtype, parameter_values, name=node.name
    )
157*523fa7a6SAndroid Build Coastguard Worker
158*523fa7a6SAndroid Build Coastguard Worker
def process_inputs_to_buffers(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    edge_program: ExportedProgram,
):
    """Serialize a buffer placeholder (e.g. quantized weights) as a TOSA constant.

    The buffer is looked up in ``edge_program.state_dict`` through the graph
    signature, transposed into the node's ``dim_order``, and added to
    ``tosa_graph`` under the node's name with the node's dtype.
    """
    arg = TosaArg(node)
    signature = edge_program.graph_signature
    buffer_data = edge_program.state_dict[signature.inputs_to_buffers[node.name]]

    assert isinstance(buffer_data, torch.Tensor), "Expect Attr to be tensor"

    # TODO: fragile code for temporary fix.
    # Mean and var tensors are also stored as buffers, but with shape (1,).
    # For them the transpose is effectively a no-op; only weights are really
    # reordered here.
    transposed = np.transpose(buffer_data.detach().numpy(), arg.dim_order)

    tosa_graph.addConst(transposed.shape, arg.dtype, transposed, name=node.name)
180*523fa7a6SAndroid Build Coastguard Worker
181*523fa7a6SAndroid Build Coastguard Worker
def process_inputs_to_lifted_tensor_constants(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    edge_program: ExportedProgram,
):
    """Serialize a lifted tensor constant placeholder as a TOSA constant.

    The constant is looked up in ``edge_program.tensor_constants`` through
    the graph signature. It is transposed into the node's ``dim_order``, in
    the same way as parameters and buffers, and added under the node's name
    with the node's dtype.
    """
    arg = TosaArg(node)
    tensor_name = edge_program.graph_signature.inputs_to_lifted_tensor_constants[
        arg.name
    ]
    tensor = edge_program.tensor_constants[tensor_name]
    # Lay the data out in the node's dim_order so it agrees with how
    # parameters and buffers are serialized. For the default (identity)
    # dim_order this is a no-op.
    tensor_data = np.transpose(tensor.detach().numpy(), arg.dim_order)

    tosa_graph.addConst(tensor_data.shape, arg.dtype, tensor_data, name=arg.name)
195*523fa7a6SAndroid Build Coastguard Worker
196*523fa7a6SAndroid Build Coastguard Worker
def process_placeholder(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
    edge_program: ExportedProgram,
    tosa_spec: TosaSpecification,
):
    """Serialize a placeholder by dispatching on its graph-signature role.

    The node is checked, in order, as a user input, a parameter, a buffer,
    and a lifted tensor constant, and is handed to the matching serializer.

    Raises:
        NotImplementedError: for lifted custom objects.
        RuntimeError: if the placeholder matches none of the known kinds.
    """
    assert node.name == node.target, "Expect placeholder name and target to match"
    assert 0 == len(node.args), "Can't handle default input values"

    name = node.name
    signature = edge_program.graph_signature

    if name in signature.user_inputs:
        process_inputs(node, tosa_graph, tosa_spec)
        return
    if name in signature.inputs_to_parameters:
        process_inputs_to_parameters(node, tosa_graph, edge_program, tosa_spec)
        return
    if name in signature.inputs_to_buffers:
        process_inputs_to_buffers(node, tosa_graph, edge_program)
        return
    if name in signature.inputs_to_lifted_tensor_constants:
        process_inputs_to_lifted_tensor_constants(node, tosa_graph, edge_program)
        return
    if name in signature.inputs_to_lifted_custom_objs:
        raise NotImplementedError(
            "Placeholder is of type 'lifted custom object' which is not supported."
        )
    raise RuntimeError(f"Placeholder '{node.name}' is of unknown type.")
221*523fa7a6SAndroid Build Coastguard Worker
222*523fa7a6SAndroid Build Coastguard Worker
def process_output(
    node: torch.fx.Node,
    tosa_graph: ts.TosaSerializer,
):
    """Register each node returned by the graph's output node as a TOSA output.

    ``node.args[0]`` holds the returned nodes. Each one's tensor is looked up
    by name in the current basic block and added as a graph output, in order.
    """
    returned_nodes = cast(tuple[torch.fx.Node, ...], node.args[0])
    for returned in returned_nodes:
        block_tensors = tosa_graph.currRegion.currBasicBlock.tensors
        tosa_graph.addOutputTensor(block_tensors[returned.name])
231