xref: /aosp_15_r20/external/executorch/backends/apple/mps/utils/mps_utils.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1#
2#  Copyright (c) 2023 Apple Inc. All rights reserved.
3#  Provided subject to the LICENSE file in the top level directory.
4#
5
6from typing import cast, Optional, Union
7
8import torch
9from executorch.backends.apple.mps.serialization.mps_graph_schema import MPSDataType
10from executorch.exir import ExportedProgram
11from torch._export.utils import get_buffer, get_param, is_buffer, is_param
12
13
def get_input_node(node: torch.fx.Node, input_index: int) -> Union[torch.fx.Node, None]:
    """Fetch one positional argument of an FX node, typed as a node.

    Args:
        node: The FX node to read from; ``None`` is tolerated.
        input_index: Position within ``node.args`` to read.

    Returns:
        ``node.args[input_index]``, or ``None`` when ``node`` is ``None``.
        The ``cast`` only informs type checkers; no runtime check is made
        that the argument really is a ``torch.fx.Node``. Indexing errors
        from ``node.args`` propagate unchanged.
    """
    if node is None:
        return None
    positional_args = node.args
    return cast(torch.fx.Node, positional_args[input_index])
16
17
def get_scalar_val(node: torch.fx.Node, input_index: int) -> Union[float, int]:
    """Return the positional argument at ``input_index`` of ``node``.

    The value is returned exactly as stored in ``node.args``. The
    ``float | int`` annotation documents the intended use (scalar operands)
    but is not enforced at runtime.
    """
    args = node.args
    scalar = args[input_index]
    return scalar
20
21
def edge_dtype_to_mps_dtype(dtype: torch.dtype):
    """Translate a ``torch.dtype`` into the corresponding ``MPSDataType`` value.

    Note that ``torch.float64`` is mapped to ``mps_data_type_float32``
    (presumably because the MPS backend has no 64-bit float type — confirm).

    Args:
        dtype: The edge-dialect tensor dtype to translate.

    Returns:
        The matching ``MPSDataType`` member.

    Raises:
        RuntimeError: If ``dtype`` has no MPS equivalent in the table below.
    """
    # The lookup table is built lazily on the first call and cached as an
    # attribute on this function, so MPSDataType is only touched when needed.
    if not hasattr(edge_dtype_to_mps_dtype, "map"):
        edge_dtype_to_mps_dtype.map = {
            torch.float16: MPSDataType.mps_data_type_float16,
            torch.float32: MPSDataType.mps_data_type_float32,
            # Deliberately narrowed to 32-bit float.
            torch.float64: MPSDataType.mps_data_type_float32,
            torch.bfloat16: MPSDataType.mps_data_type_bfloat16,
            torch.int8: MPSDataType.mps_data_type_int8,
            torch.int16: MPSDataType.mps_data_type_int16,
            torch.int32: MPSDataType.mps_data_type_int32,
            torch.int64: MPSDataType.mps_data_type_int64,
            torch.uint8: MPSDataType.mps_data_type_uint8,
            torch.bool: MPSDataType.mps_data_type_bool,
            torch.cfloat: MPSDataType.mps_data_type_complex_float32,
            torch.chalf: MPSDataType.mps_data_type_complex_float16,
        }
    try:
        return edge_dtype_to_mps_dtype.map[dtype]
    except KeyError:
        # Suppress the internal KeyError context so the traceback shows only
        # the meaningful error instead of "During handling of ...".
        raise RuntimeError(f"Invalid data type: {dtype}") from None
42
43
44def get_param_tensor(
45    exp_prog: ExportedProgram, node: torch.fx.Node
46) -> Optional[torch.Tensor]:
47    if node is None:
48        return None
49    elif is_param(exp_prog, node):
50        return get_param(exp_prog, node)
51    elif is_buffer(exp_prog, node):
52        return get_buffer(exp_prog, node)
53    elif is_get_attr(node):
54        # Support both lifted and unlifted graph
55        try:
56            # Unlifted graph (coming from old exir.capture API)
57            return getattr(node.graph.owning_module, node.target)
58        except AttributeError:
59            return getattr(exp_prog.graph_module, node.target)
60    raise RuntimeError(f"unsupported param type, {node.op}.")
61
62
def is_get_attr(node: torch.fx.Node):
    """Tell whether ``node`` is an FX node whose opcode is ``"get_attr"``.

    Anything that is not a ``torch.fx.Node`` (including ``None``) yields
    ``False``. The kind of value the attribute refers to is not inspected.
    """
    if not isinstance(node, torch.fx.Node):
        return False
    return node.op == "get_attr"
68
69
def is_parameter(exp_prog: torch.export.ExportedProgram, node: torch.fx.Node) -> bool:
    """
    Check whether a node refers to constant model data (e.g. weights and biases).

    Despite the name, this matches three kinds of node:

    - a ``get_attr`` node, i.e. constant data held on the module (unlifted graph);
    - a node that ``is_param`` reports as a lifted parameter of ``exp_prog``;
    - a node that ``is_buffer`` reports as a lifted buffer of ``exp_prog``.

    Args:
        exp_prog (torch.export.ExportedProgram): The exported program whose
            graph signature is consulted for lifted parameters and buffers.
        node (torch.fx.Node): The graph node to classify.

    Returns:
        bool: True if any of the three checks above matches, else False.
    """
    return is_get_attr(node) or is_param(exp_prog, node) or is_buffer(exp_prog, node)
83