#
# Copyright (c) 2023 Apple Inc. All rights reserved.
# Provided subject to the LICENSE file in the top level directory.
#

from typing import cast, Optional, Union

import torch
from executorch.backends.apple.mps.serialization.mps_graph_schema import MPSDataType
from executorch.exir import ExportedProgram
from torch._export.utils import get_buffer, get_param, is_buffer, is_param


def get_input_node(node: torch.fx.Node, input_index: int) -> Union[torch.fx.Node, None]:
    """
    Return the positional argument of ``node`` at ``input_index`` as an FX node.

    Args:
        node (torch.fx.Node): Node whose ``args`` are inspected. May be None.
        input_index (int): Index into ``node.args``.

    Returns:
        Union[torch.fx.Node, None]: ``node.args[input_index]``, or None when
        ``node`` is None. The value is only statically cast to
        ``torch.fx.Node``; no runtime type check is performed.
    """
    return None if node is None else cast(torch.fx.Node, node.args[input_index])


def get_scalar_val(node: torch.fx.Node, input_index: int) -> Union[float, int]:
    """
    Return the positional argument of ``node`` at ``input_index`` unchanged.

    Intended for arguments holding Python scalars rather than FX nodes; the
    value is not validated.

    Args:
        node (torch.fx.Node): Node whose ``args`` are inspected.
        input_index (int): Index into ``node.args``.

    Returns:
        Union[float, int]: ``node.args[input_index]``.
    """
    return node.args[input_index]


def edge_dtype_to_mps_dtype(dtype: torch.dtype) -> MPSDataType:
    """
    Map a torch dtype to the corresponding MPS serialization dtype.

    The lookup table is built lazily on the first call and cached on the
    function object as ``edge_dtype_to_mps_dtype.map``.

    Note: ``torch.float64`` is mapped to float32 — presumably because MPS has
    no double-precision support (NOTE(review): confirm).

    Args:
        dtype (torch.dtype): The dtype to translate.

    Returns:
        MPSDataType: The matching MPS dtype.

    Raises:
        RuntimeError: If ``dtype`` has no entry in the mapping.
    """
    if not hasattr(edge_dtype_to_mps_dtype, "map"):
        edge_dtype_to_mps_dtype.map = {
            torch.float16: MPSDataType.mps_data_type_float16,
            torch.float32: MPSDataType.mps_data_type_float32,
            torch.float64: MPSDataType.mps_data_type_float32,
            torch.bfloat16: MPSDataType.mps_data_type_bfloat16,
            torch.int8: MPSDataType.mps_data_type_int8,
            torch.int16: MPSDataType.mps_data_type_int16,
            torch.int32: MPSDataType.mps_data_type_int32,
            torch.int64: MPSDataType.mps_data_type_int64,
            torch.uint8: MPSDataType.mps_data_type_uint8,
            torch.bool: MPSDataType.mps_data_type_bool,
            torch.cfloat: MPSDataType.mps_data_type_complex_float32,
            torch.chalf: MPSDataType.mps_data_type_complex_float16,
        }
    try:
        return edge_dtype_to_mps_dtype.map[dtype]
    except KeyError:
        # The message already names the dtype; drop the internal KeyError context.
        raise RuntimeError(f"Invalid data type: {dtype}") from None


def _fetch_qualified_attr(root: object, target: str):
    """
    Resolve a possibly dotted attribute path (e.g. ``"sub.weight"``) from ``root``.

    Raises:
        AttributeError: If any component of ``target`` is missing, exactly like
            a plain ``getattr`` call would.
    """
    obj = root
    for atom in target.split("."):
        obj = getattr(obj, atom)
    return obj


def get_param_tensor(
    exp_prog: ExportedProgram, node: torch.fx.Node
) -> Optional[torch.Tensor]:
    """
    Return the constant data backing ``node``.

    Handles lifted parameters and buffers (looked up through ``exp_prog``'s
    signature helpers) as well as ``get_attr`` nodes from unlifted graphs.

    Args:
        exp_prog (ExportedProgram): Program that owns ``node``.
        node (torch.fx.Node): Node to resolve. May be None.

    Returns:
        Optional[torch.Tensor]: The parameter, buffer or fetched attribute, or
        None when ``node`` is None. In the ``get_attr`` case whatever object
        the attribute holds is returned as-is.

    Raises:
        AttributeError: If a ``get_attr`` target cannot be resolved on either
            the node's owning module or ``exp_prog.graph_module``.
        RuntimeError: If ``node`` is neither a parameter, a buffer nor a
            ``get_attr`` node.
    """
    if node is None:
        return None
    elif is_param(exp_prog, node):
        return get_param(exp_prog, node)
    elif is_buffer(exp_prog, node):
        return get_buffer(exp_prog, node)
    elif is_get_attr(node):
        # Support both lifted and unlifted graph.
        # get_attr targets are fully-qualified names that may contain dots for
        # nested submodules, so resolve them one component at a time; a plain
        # getattr(module, "a.b") would always fail.
        try:
            # Unlifted graph (coming from old exir.capture API)
            return _fetch_qualified_attr(node.graph.owning_module, node.target)
        except AttributeError:
            return _fetch_qualified_attr(exp_prog.graph_module, node.target)
    raise RuntimeError(f"unsupported param type, {node.op}.")


def is_get_attr(node: torch.fx.Node) -> bool:
    """
    Return True if ``node`` is a ``torch.fx.Node`` whose op is ``"get_attr"``.

    Such nodes fetch an attribute of the owning module (expected here to be a
    tensor of the model). Non-Node inputs, including None, return False.
    """
    return isinstance(node, torch.fx.Node) and node.op == "get_attr"


def is_parameter(exp_prog: torch.export.ExportedProgram, node: torch.fx.Node) -> bool:
    """
    Check whether ``node`` refers to static model data such as weights or biases.

    This covers both lifted graphs, where such data is supplied as graph inputs
    (parameters or buffers), and unlifted graphs, where it is fetched through
    ``get_attr`` nodes.

    Args:
        exp_prog (torch.export.ExportedProgram): Program that owns ``node``.
        node (torch.fx.Node): Node to classify.

    Returns:
        bool: True if ``node`` is a ``get_attr`` node, a lifted parameter or a
        lifted buffer of ``exp_prog``; False otherwise.
    """
    return is_get_attr(node) or is_param(exp_prog, node) or is_buffer(exp_prog, node)