# Copyright (c) Qualcomm Innovation Center, Inc. # All rights reserved # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. from typing import Dict, Optional import torch from torch._export.utils import ( get_buffer, get_lifted_tensor_constant, get_param, is_buffer, is_lifted_tensor_constant, is_param, ) def is_parameter( node: torch.fx.Node, edge_program: torch.export.ExportedProgram ) -> bool: return ( is_param(edge_program, node) or is_buffer(edge_program, node) or is_lifted_tensor_constant(edge_program, node) ) def get_parameter( node: torch.fx.Node, edge_program: torch.export.ExportedProgram ) -> torch.Tensor: param = None if is_param(edge_program, node): param = get_param(edge_program, node) if is_buffer(edge_program, node): param = get_buffer(edge_program, node) if is_lifted_tensor_constant(edge_program, node): param = get_lifted_tensor_constant(edge_program, node) if param is not None: # update node.meta["val"] to qualified QNN datatype (e.g. i64 to i32) assert isinstance(param, torch.Tensor), "Expect parameter to be tensor" param = param.type(node.meta["val"].dtype) return param def set_parameter( param: torch.Tensor, node: torch.fx.Node, edge_program: torch.export.ExportedProgram ): status = False if is_param(edge_program, node): edge_program.state_dict[ edge_program.graph_signature.inputs_to_parameters[node.name] ] = param status = True if is_buffer(edge_program, node): buffer_name = edge_program.graph_signature.inputs_to_buffers[node.name] if buffer_name in edge_program.graph_signature.non_persistent_buffers: edge_program.constants[buffer_name] = param else: edge_program.state_dict[buffer_name] = param status = True assert status, "Failed to set parameter" def is_graph_input( tensor: torch.fx.Node, edge_program: torch.export.ExportedProgram ) -> bool: """ Check if the given tensor is a graph input Args: tensor: EdgeIR Tensor that is being checked for graph input """ return tensor.op == "placeholder" and not is_parameter(tensor, edge_program) def is_graph_output(tensor: torch.fx.Node) -> bool: """ Check if the given tensor is used as a graph output Args: tensor: EdgeIR Tensor that is being checked for graph input """ for user in tensor.users.keys(): # getitem node is skiped, check the op_skip_ops.py if user.op == "output" or ( user.target.__name__ == "getitem" and is_graph_output(user) ): return True return False def is_constant( tensor: torch.fx.Node, edge_program: torch.export.ExportedProgram ) -> bool: """ Check if the given tensor is a constant Args: tensor: EdgeIR Tensor that is being checked for graph input """ # constants should not be treated as input placeholder # pay attention to the pytorch design, change this if # breakage happened: # pytorch/torch/_export/passes/lift_constant_tensor_pass.py if is_parameter(tensor, edge_program): return tensor.meta["val"].constant is not None return False def deduce_dtype( tensor: torch.Tensor, quant_infos: Optional[Dict] = None ) -> torch.dtype: if quant_infos: quant_range = quant_infos["quant_max"] - quant_infos["quant_min"] unsigned = quant_infos["quant_min"] >= 0 if quant_range <= torch.iinfo(torch.int8).max - torch.iinfo(torch.int8).min: return torch.uint8 if unsigned else torch.int8 elif quant_range <= torch.iinfo(torch.int16).max - torch.iinfo(torch.int16).min: return torch.uint16 if unsigned else torch.int16 return quant_infos["dtype"] return tensor.dtype