1# Copyright (c) Qualcomm Innovation Center, Inc. 2# All rights reserved 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6 7from typing import Dict, Optional 8 9import torch 10from torch._export.utils import ( 11 get_buffer, 12 get_lifted_tensor_constant, 13 get_param, 14 is_buffer, 15 is_lifted_tensor_constant, 16 is_param, 17) 18 19 20def is_parameter( 21 node: torch.fx.Node, edge_program: torch.export.ExportedProgram 22) -> bool: 23 return ( 24 is_param(edge_program, node) 25 or is_buffer(edge_program, node) 26 or is_lifted_tensor_constant(edge_program, node) 27 ) 28 29 30def get_parameter( 31 node: torch.fx.Node, edge_program: torch.export.ExportedProgram 32) -> torch.Tensor: 33 param = None 34 if is_param(edge_program, node): 35 param = get_param(edge_program, node) 36 if is_buffer(edge_program, node): 37 param = get_buffer(edge_program, node) 38 if is_lifted_tensor_constant(edge_program, node): 39 param = get_lifted_tensor_constant(edge_program, node) 40 if param is not None: 41 # update node.meta["val"] to qualified QNN datatype (e.g. i64 to i32) 42 assert isinstance(param, torch.Tensor), "Expect parameter to be tensor" 43 param = param.type(node.meta["val"].dtype) 44 return param 45 46 47def set_parameter( 48 param: torch.Tensor, node: torch.fx.Node, edge_program: torch.export.ExportedProgram 49): 50 status = False 51 if is_param(edge_program, node): 52 edge_program.state_dict[ 53 edge_program.graph_signature.inputs_to_parameters[node.name] 54 ] = param 55 status = True 56 if is_buffer(edge_program, node): 57 buffer_name = edge_program.graph_signature.inputs_to_buffers[node.name] 58 if buffer_name in edge_program.graph_signature.non_persistent_buffers: 59 edge_program.constants[buffer_name] = param 60 else: 61 edge_program.state_dict[buffer_name] = param 62 status = True 63 assert status, "Failed to set parameter" 64 65 66def is_graph_input( 67 tensor: torch.fx.Node, edge_program: torch.export.ExportedProgram 68) -> bool: 69 """ 70 Check if the given tensor is a graph input 71 72 Args: 73 tensor: EdgeIR Tensor that is being checked for graph input 74 """ 75 return tensor.op == "placeholder" and not is_parameter(tensor, edge_program) 76 77 78def is_graph_output(tensor: torch.fx.Node) -> bool: 79 """ 80 Check if the given tensor is used as a graph output 81 82 Args: 83 tensor: EdgeIR Tensor that is being checked for graph input 84 """ 85 for user in tensor.users.keys(): 86 # getitem node is skiped, check the op_skip_ops.py 87 if user.op == "output" or ( 88 user.target.__name__ == "getitem" and is_graph_output(user) 89 ): 90 return True 91 return False 92 93 94def is_constant( 95 tensor: torch.fx.Node, edge_program: torch.export.ExportedProgram 96) -> bool: 97 """ 98 Check if the given tensor is a constant 99 100 Args: 101 tensor: EdgeIR Tensor that is being checked for graph input 102 """ 103 # constants should not be treated as input placeholder 104 # pay attention to the pytorch design, change this if 105 # breakage happened: 106 # pytorch/torch/_export/passes/lift_constant_tensor_pass.py 107 if is_parameter(tensor, edge_program): 108 return tensor.meta["val"].constant is not None 109 110 return False 111 112 113def deduce_dtype( 114 tensor: torch.Tensor, quant_infos: Optional[Dict] = None 115) -> torch.dtype: 116 if quant_infos: 117 quant_range = quant_infos["quant_max"] - quant_infos["quant_min"] 118 unsigned = quant_infos["quant_min"] >= 0 119 if quant_range <= torch.iinfo(torch.int8).max - torch.iinfo(torch.int8).min: 120 return torch.uint8 if unsigned else torch.int8 121 122 elif quant_range <= torch.iinfo(torch.int16).max - torch.iinfo(torch.int16).min: 123 return torch.uint16 if unsigned else torch.int16 124 125 return quant_infos["dtype"] 126 127 return tensor.dtype 128