xref: /aosp_15_r20/external/executorch/backends/qualcomm/builders/utils.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Qualcomm Innovation Center, Inc.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Workerfrom typing import Dict, Optional
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Workerimport torch
10*523fa7a6SAndroid Build Coastguard Workerfrom torch._export.utils import (
11*523fa7a6SAndroid Build Coastguard Worker    get_buffer,
12*523fa7a6SAndroid Build Coastguard Worker    get_lifted_tensor_constant,
13*523fa7a6SAndroid Build Coastguard Worker    get_param,
14*523fa7a6SAndroid Build Coastguard Worker    is_buffer,
15*523fa7a6SAndroid Build Coastguard Worker    is_lifted_tensor_constant,
16*523fa7a6SAndroid Build Coastguard Worker    is_param,
17*523fa7a6SAndroid Build Coastguard Worker)
18*523fa7a6SAndroid Build Coastguard Worker
19*523fa7a6SAndroid Build Coastguard Worker
20*523fa7a6SAndroid Build Coastguard Workerdef is_parameter(
21*523fa7a6SAndroid Build Coastguard Worker    node: torch.fx.Node, edge_program: torch.export.ExportedProgram
22*523fa7a6SAndroid Build Coastguard Worker) -> bool:
23*523fa7a6SAndroid Build Coastguard Worker    return (
24*523fa7a6SAndroid Build Coastguard Worker        is_param(edge_program, node)
25*523fa7a6SAndroid Build Coastguard Worker        or is_buffer(edge_program, node)
26*523fa7a6SAndroid Build Coastguard Worker        or is_lifted_tensor_constant(edge_program, node)
27*523fa7a6SAndroid Build Coastguard Worker    )
28*523fa7a6SAndroid Build Coastguard Worker
29*523fa7a6SAndroid Build Coastguard Worker
30*523fa7a6SAndroid Build Coastguard Workerdef get_parameter(
31*523fa7a6SAndroid Build Coastguard Worker    node: torch.fx.Node, edge_program: torch.export.ExportedProgram
32*523fa7a6SAndroid Build Coastguard Worker) -> torch.Tensor:
33*523fa7a6SAndroid Build Coastguard Worker    param = None
34*523fa7a6SAndroid Build Coastguard Worker    if is_param(edge_program, node):
35*523fa7a6SAndroid Build Coastguard Worker        param = get_param(edge_program, node)
36*523fa7a6SAndroid Build Coastguard Worker    if is_buffer(edge_program, node):
37*523fa7a6SAndroid Build Coastguard Worker        param = get_buffer(edge_program, node)
38*523fa7a6SAndroid Build Coastguard Worker    if is_lifted_tensor_constant(edge_program, node):
39*523fa7a6SAndroid Build Coastguard Worker        param = get_lifted_tensor_constant(edge_program, node)
40*523fa7a6SAndroid Build Coastguard Worker    if param is not None:
41*523fa7a6SAndroid Build Coastguard Worker        # update node.meta["val"] to qualified QNN datatype (e.g. i64 to i32)
42*523fa7a6SAndroid Build Coastguard Worker        assert isinstance(param, torch.Tensor), "Expect parameter to be tensor"
43*523fa7a6SAndroid Build Coastguard Worker        param = param.type(node.meta["val"].dtype)
44*523fa7a6SAndroid Build Coastguard Worker    return param
45*523fa7a6SAndroid Build Coastguard Worker
46*523fa7a6SAndroid Build Coastguard Worker
47*523fa7a6SAndroid Build Coastguard Workerdef set_parameter(
48*523fa7a6SAndroid Build Coastguard Worker    param: torch.Tensor, node: torch.fx.Node, edge_program: torch.export.ExportedProgram
49*523fa7a6SAndroid Build Coastguard Worker):
50*523fa7a6SAndroid Build Coastguard Worker    status = False
51*523fa7a6SAndroid Build Coastguard Worker    if is_param(edge_program, node):
52*523fa7a6SAndroid Build Coastguard Worker        edge_program.state_dict[
53*523fa7a6SAndroid Build Coastguard Worker            edge_program.graph_signature.inputs_to_parameters[node.name]
54*523fa7a6SAndroid Build Coastguard Worker        ] = param
55*523fa7a6SAndroid Build Coastguard Worker        status = True
56*523fa7a6SAndroid Build Coastguard Worker    if is_buffer(edge_program, node):
57*523fa7a6SAndroid Build Coastguard Worker        buffer_name = edge_program.graph_signature.inputs_to_buffers[node.name]
58*523fa7a6SAndroid Build Coastguard Worker        if buffer_name in edge_program.graph_signature.non_persistent_buffers:
59*523fa7a6SAndroid Build Coastguard Worker            edge_program.constants[buffer_name] = param
60*523fa7a6SAndroid Build Coastguard Worker        else:
61*523fa7a6SAndroid Build Coastguard Worker            edge_program.state_dict[buffer_name] = param
62*523fa7a6SAndroid Build Coastguard Worker        status = True
63*523fa7a6SAndroid Build Coastguard Worker    assert status, "Failed to set parameter"
64*523fa7a6SAndroid Build Coastguard Worker
65*523fa7a6SAndroid Build Coastguard Worker
def is_graph_input(
    tensor: torch.fx.Node, edge_program: torch.export.ExportedProgram
) -> bool:
    """
    Tell whether the given node is a user-facing input of the graph

    Args:
        tensor: EdgeIR node that is being checked for graph input
        edge_program: exported program used to rule out parameter-like
            placeholders

    Returns:
        True only for "placeholder" nodes that is_parameter rejects.
    """
    if tensor.op != "placeholder":
        return False
    # Placeholders accepted by is_parameter are not graph inputs.
    return not is_parameter(tensor, edge_program)
76*523fa7a6SAndroid Build Coastguard Worker
77*523fa7a6SAndroid Build Coastguard Worker
def is_graph_output(tensor: torch.fx.Node) -> bool:
    """
    Check if the given tensor is used as a graph output

    A node counts as a graph output when one of its users is the "output"
    node, or when one of its users is a "getitem" node that is itself a
    graph output.

    Args:
        tensor: EdgeIR Tensor that is being checked for graph output

    Returns:
        True if the node reaches the graph output directly or through a
        chain of getitem users, False otherwise.
    """
    for user in tensor.users:
        if user.op == "output":
            return True
        # getitem node is skipped (see op_skip_ops.py), so look through it.
        # The target of call_method / call_module / get_attr users is a plain
        # string without __name__, hence the defensive getattr.
        if getattr(user.target, "__name__", None) == "getitem" and is_graph_output(
            user
        ):
            return True
    return False
92*523fa7a6SAndroid Build Coastguard Worker
93*523fa7a6SAndroid Build Coastguard Worker
def is_constant(
    tensor: torch.fx.Node, edge_program: torch.export.ExportedProgram
) -> bool:
    """
    Tell whether the given node holds a constant

    Args:
        tensor: EdgeIR node that is being checked for being a constant
        edge_program: exported program used to decide whether the node is
            parameter-like

    Returns:
        True when is_parameter accepts the node and the value stored in
        its meta["val"] has a non-None `constant` attribute.
    """
    # Constants must not be treated as input placeholders. This follows the
    # pytorch design, so revisit it if breakage happens:
    # pytorch/torch/_export/passes/lift_constant_tensor_pass.py
    if not is_parameter(tensor, edge_program):
        return False

    # NOTE(review): presumably meta["val"] is a FakeTensor here - confirm.
    meta_val = tensor.meta["val"]
    return meta_val.constant is not None
111*523fa7a6SAndroid Build Coastguard Worker
112*523fa7a6SAndroid Build Coastguard Worker
def deduce_dtype(
    tensor: torch.Tensor, quant_infos: Optional[Dict] = None
) -> torch.dtype:
    """
    Pick the dtype for a tensor, taking quantization info into account

    Args:
        tensor: tensor whose own dtype is used when no quantization info
            is given
        quant_infos: optional dict read through the keys "quant_max",
            "quant_min" and "dtype"

    Returns:
        tensor.dtype when quant_infos is None or empty. Otherwise the 8-bit
        or 16-bit integer dtype whose value span covers
        quant_max - quant_min (unsigned when quant_min >= 0), falling back
        to quant_infos["dtype"] for wider ranges.
    """
    if not quant_infos:
        return tensor.dtype

    upper = quant_infos["quant_max"]
    lower = quant_infos["quant_min"]
    span = upper - lower
    non_negative = lower >= 0

    int8_limits = torch.iinfo(torch.int8)
    if span <= int8_limits.max - int8_limits.min:
        if non_negative:
            return torch.uint8
        return torch.int8

    int16_limits = torch.iinfo(torch.int16)
    if span <= int16_limits.max - int16_limits.min:
        # torch.uint16 is only touched on this path, as in the original.
        if non_negative:
            return torch.uint16
        return torch.int16

    return quant_infos["dtype"]
128