1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Qualcomm Innovation Center, Inc. 2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved 3*523fa7a6SAndroid Build Coastguard Worker# 4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the 5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree. 6*523fa7a6SAndroid Build Coastguard Worker 7*523fa7a6SAndroid Build Coastguard Workerfrom typing import Dict, Optional 8*523fa7a6SAndroid Build Coastguard Worker 9*523fa7a6SAndroid Build Coastguard Workerimport torch 10*523fa7a6SAndroid Build Coastguard Workerfrom torch._export.utils import ( 11*523fa7a6SAndroid Build Coastguard Worker get_buffer, 12*523fa7a6SAndroid Build Coastguard Worker get_lifted_tensor_constant, 13*523fa7a6SAndroid Build Coastguard Worker get_param, 14*523fa7a6SAndroid Build Coastguard Worker is_buffer, 15*523fa7a6SAndroid Build Coastguard Worker is_lifted_tensor_constant, 16*523fa7a6SAndroid Build Coastguard Worker is_param, 17*523fa7a6SAndroid Build Coastguard Worker) 18*523fa7a6SAndroid Build Coastguard Worker 19*523fa7a6SAndroid Build Coastguard Worker 20*523fa7a6SAndroid Build Coastguard Workerdef is_parameter( 21*523fa7a6SAndroid Build Coastguard Worker node: torch.fx.Node, edge_program: torch.export.ExportedProgram 22*523fa7a6SAndroid Build Coastguard Worker) -> bool: 23*523fa7a6SAndroid Build Coastguard Worker return ( 24*523fa7a6SAndroid Build Coastguard Worker is_param(edge_program, node) 25*523fa7a6SAndroid Build Coastguard Worker or is_buffer(edge_program, node) 26*523fa7a6SAndroid Build Coastguard Worker or is_lifted_tensor_constant(edge_program, node) 27*523fa7a6SAndroid Build Coastguard Worker ) 28*523fa7a6SAndroid Build Coastguard Worker 29*523fa7a6SAndroid Build Coastguard Worker 30*523fa7a6SAndroid Build Coastguard Workerdef get_parameter( 31*523fa7a6SAndroid Build Coastguard Worker node: torch.fx.Node, edge_program: torch.export.ExportedProgram 32*523fa7a6SAndroid Build Coastguard Worker) -> torch.Tensor: 33*523fa7a6SAndroid Build Coastguard Worker param = None 34*523fa7a6SAndroid Build Coastguard Worker if is_param(edge_program, node): 35*523fa7a6SAndroid Build Coastguard Worker param = get_param(edge_program, node) 36*523fa7a6SAndroid Build Coastguard Worker if is_buffer(edge_program, node): 37*523fa7a6SAndroid Build Coastguard Worker param = get_buffer(edge_program, node) 38*523fa7a6SAndroid Build Coastguard Worker if is_lifted_tensor_constant(edge_program, node): 39*523fa7a6SAndroid Build Coastguard Worker param = get_lifted_tensor_constant(edge_program, node) 40*523fa7a6SAndroid Build Coastguard Worker if param is not None: 41*523fa7a6SAndroid Build Coastguard Worker # update node.meta["val"] to qualified QNN datatype (e.g. i64 to i32) 42*523fa7a6SAndroid Build Coastguard Worker assert isinstance(param, torch.Tensor), "Expect parameter to be tensor" 43*523fa7a6SAndroid Build Coastguard Worker param = param.type(node.meta["val"].dtype) 44*523fa7a6SAndroid Build Coastguard Worker return param 45*523fa7a6SAndroid Build Coastguard Worker 46*523fa7a6SAndroid Build Coastguard Worker 47*523fa7a6SAndroid Build Coastguard Workerdef set_parameter( 48*523fa7a6SAndroid Build Coastguard Worker param: torch.Tensor, node: torch.fx.Node, edge_program: torch.export.ExportedProgram 49*523fa7a6SAndroid Build Coastguard Worker): 50*523fa7a6SAndroid Build Coastguard Worker status = False 51*523fa7a6SAndroid Build Coastguard Worker if is_param(edge_program, node): 52*523fa7a6SAndroid Build Coastguard Worker edge_program.state_dict[ 53*523fa7a6SAndroid Build Coastguard Worker edge_program.graph_signature.inputs_to_parameters[node.name] 54*523fa7a6SAndroid Build Coastguard Worker ] = param 55*523fa7a6SAndroid Build Coastguard Worker status = True 56*523fa7a6SAndroid Build Coastguard Worker if is_buffer(edge_program, node): 57*523fa7a6SAndroid Build Coastguard Worker buffer_name = edge_program.graph_signature.inputs_to_buffers[node.name] 58*523fa7a6SAndroid Build Coastguard Worker if buffer_name in edge_program.graph_signature.non_persistent_buffers: 59*523fa7a6SAndroid Build Coastguard Worker edge_program.constants[buffer_name] = param 60*523fa7a6SAndroid Build Coastguard Worker else: 61*523fa7a6SAndroid Build Coastguard Worker edge_program.state_dict[buffer_name] = param 62*523fa7a6SAndroid Build Coastguard Worker status = True 63*523fa7a6SAndroid Build Coastguard Worker assert status, "Failed to set parameter" 64*523fa7a6SAndroid Build Coastguard Worker 65*523fa7a6SAndroid Build Coastguard Worker 66*523fa7a6SAndroid Build Coastguard Workerdef is_graph_input( 67*523fa7a6SAndroid Build Coastguard Worker tensor: torch.fx.Node, edge_program: torch.export.ExportedProgram 68*523fa7a6SAndroid Build Coastguard Worker) -> bool: 69*523fa7a6SAndroid Build Coastguard Worker """ 70*523fa7a6SAndroid Build Coastguard Worker Check if the given tensor is a graph input 71*523fa7a6SAndroid Build Coastguard Worker 72*523fa7a6SAndroid Build Coastguard Worker Args: 73*523fa7a6SAndroid Build Coastguard Worker tensor: EdgeIR Tensor that is being checked for graph input 74*523fa7a6SAndroid Build Coastguard Worker """ 75*523fa7a6SAndroid Build Coastguard Worker return tensor.op == "placeholder" and not is_parameter(tensor, edge_program) 76*523fa7a6SAndroid Build Coastguard Worker 77*523fa7a6SAndroid Build Coastguard Worker 78*523fa7a6SAndroid Build Coastguard Workerdef is_graph_output(tensor: torch.fx.Node) -> bool: 79*523fa7a6SAndroid Build Coastguard Worker """ 80*523fa7a6SAndroid Build Coastguard Worker Check if the given tensor is used as a graph output 81*523fa7a6SAndroid Build Coastguard Worker 82*523fa7a6SAndroid Build Coastguard Worker Args: 83*523fa7a6SAndroid Build Coastguard Worker tensor: EdgeIR Tensor that is being checked for graph input 84*523fa7a6SAndroid Build Coastguard Worker """ 85*523fa7a6SAndroid Build Coastguard Worker for user in tensor.users.keys(): 86*523fa7a6SAndroid Build Coastguard Worker # getitem node is skiped, check the op_skip_ops.py 87*523fa7a6SAndroid Build Coastguard Worker if user.op == "output" or ( 88*523fa7a6SAndroid Build Coastguard Worker user.target.__name__ == "getitem" and is_graph_output(user) 89*523fa7a6SAndroid Build Coastguard Worker ): 90*523fa7a6SAndroid Build Coastguard Worker return True 91*523fa7a6SAndroid Build Coastguard Worker return False 92*523fa7a6SAndroid Build Coastguard Worker 93*523fa7a6SAndroid Build Coastguard Worker 94*523fa7a6SAndroid Build Coastguard Workerdef is_constant( 95*523fa7a6SAndroid Build Coastguard Worker tensor: torch.fx.Node, edge_program: torch.export.ExportedProgram 96*523fa7a6SAndroid Build Coastguard Worker) -> bool: 97*523fa7a6SAndroid Build Coastguard Worker """ 98*523fa7a6SAndroid Build Coastguard Worker Check if the given tensor is a constant 99*523fa7a6SAndroid Build Coastguard Worker 100*523fa7a6SAndroid Build Coastguard Worker Args: 101*523fa7a6SAndroid Build Coastguard Worker tensor: EdgeIR Tensor that is being checked for graph input 102*523fa7a6SAndroid Build Coastguard Worker """ 103*523fa7a6SAndroid Build Coastguard Worker # constants should not be treated as input placeholder 104*523fa7a6SAndroid Build Coastguard Worker # pay attention to the pytorch design, change this if 105*523fa7a6SAndroid Build Coastguard Worker # breakage happened: 106*523fa7a6SAndroid Build Coastguard Worker # pytorch/torch/_export/passes/lift_constant_tensor_pass.py 107*523fa7a6SAndroid Build Coastguard Worker if is_parameter(tensor, edge_program): 108*523fa7a6SAndroid Build Coastguard Worker return tensor.meta["val"].constant is not None 109*523fa7a6SAndroid Build Coastguard Worker 110*523fa7a6SAndroid Build Coastguard Worker return False 111*523fa7a6SAndroid Build Coastguard Worker 112*523fa7a6SAndroid Build Coastguard Worker 113*523fa7a6SAndroid Build Coastguard Workerdef deduce_dtype( 114*523fa7a6SAndroid Build Coastguard Worker tensor: torch.Tensor, quant_infos: Optional[Dict] = None 115*523fa7a6SAndroid Build Coastguard Worker) -> torch.dtype: 116*523fa7a6SAndroid Build Coastguard Worker if quant_infos: 117*523fa7a6SAndroid Build Coastguard Worker quant_range = quant_infos["quant_max"] - quant_infos["quant_min"] 118*523fa7a6SAndroid Build Coastguard Worker unsigned = quant_infos["quant_min"] >= 0 119*523fa7a6SAndroid Build Coastguard Worker if quant_range <= torch.iinfo(torch.int8).max - torch.iinfo(torch.int8).min: 120*523fa7a6SAndroid Build Coastguard Worker return torch.uint8 if unsigned else torch.int8 121*523fa7a6SAndroid Build Coastguard Worker 122*523fa7a6SAndroid Build Coastguard Worker elif quant_range <= torch.iinfo(torch.int16).max - torch.iinfo(torch.int16).min: 123*523fa7a6SAndroid Build Coastguard Worker return torch.uint16 if unsigned else torch.int16 124*523fa7a6SAndroid Build Coastguard Worker 125*523fa7a6SAndroid Build Coastguard Worker return quant_infos["dtype"] 126*523fa7a6SAndroid Build Coastguard Worker 127*523fa7a6SAndroid Build Coastguard Worker return tensor.dtype 128