xref: /aosp_15_r20/external/executorch/backends/arm/tosa_mapping.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2023-2024 Arm Limited and/or its affiliates.
2#
3# This source code is licensed under the BSD-style license found in the
4# LICENSE file in the root directory of this source tree.
5
6# pyre-unsafe
7
#
# PyTorch to TOSA mapping: simple dtype-mapping helpers, plus extraction of key
# tensor information (dtype, shape, dim order) from several kinds of node
# argument. These are used by the initial compile stage, which captures the
# standardised TOSA representation.
#
13
14import serializer.tosa_serializer as ts
15import torch
16
17
# Torch dtypes that map_dtype() rejects with an "Unsupported type" assertion.
# Several entries are alias spellings of the same dtype object; both spellings
# are kept so either name can be found when reading this list.
UNSUPPORTED_DTYPES = (
    # 64-bit floating point (torch.double is torch.float64).
    torch.float64,
    torch.double,
    # Complex types (torch.cfloat is torch.complex64, torch.cdouble is
    # torch.complex128).
    torch.complex64,
    torch.cfloat,
    torch.complex128,
    torch.cdouble,
    # Integer types without a TOSA mapping here (torch.long is torch.int64).
    torch.uint8,
    torch.int64,
    torch.long,
)
29
# Torch -> TOSA serializer dtype table consumed by map_dtype().
# torch.float, torch.half, torch.short and torch.int are the very same dtype
# objects as torch.float32, torch.float16, torch.int16 and torch.int32, so the
# canonical keys below already cover those spellings (listing them again in a
# dict literal would only overwrite the same entry with the same value).
DTYPE_MAP = {
    # Floating point
    torch.float32: ts.DType.FP32,
    torch.float16: ts.DType.FP16,
    torch.bfloat16: ts.DType.BF16,
    # Integer
    torch.int8: ts.DType.INT8,
    torch.int16: ts.DType.INT16,
    torch.int32: ts.DType.INT32,
    # Boolean
    torch.bool: ts.DType.BOOL,
}
43
44
def map_dtype(data_type):
    """Look up the TOSA serializer dtype corresponding to a torch dtype.

    Args:
        data_type: The torch dtype to translate.

    Returns:
        The ``ts.DType`` entry stored for ``data_type`` in ``DTYPE_MAP``.

    Raises:
        AssertionError: "Unsupported type: ..." when ``data_type`` appears in
            ``UNSUPPORTED_DTYPES``, or "Unknown type: ..." when it has no
            entry in ``DTYPE_MAP``. These are plain asserts, so under
            ``python -O`` the checks are skipped and an unmapped dtype
            surfaces as a ``KeyError`` from the final lookup instead.
    """
    # Consult the deny-list first so known-unsupported dtypes get the more
    # specific message rather than the generic "Unknown type".
    is_rejected = data_type in UNSUPPORTED_DTYPES
    assert not is_rejected, f"Unsupported type: {data_type}"
    is_known = data_type in DTYPE_MAP
    assert is_known, f"Unknown type: {data_type}"
    tosa_dtype = DTYPE_MAP[data_type]
    return tosa_dtype
49
50
# TODO: meta["val"] may also be a SymInt, a List[Union[FakeTensor, SymInt]],
# or None; only a FakeTensor (or a tuple whose first item is one) is handled.
def extract_tensor_meta(meta):
    """Return ``(dtype, shape, dim_order)`` for the value recorded in ``meta``.

    Args:
        meta: An FX node's ``meta`` mapping. It must contain a non-None
            ``"val"`` entry and may contain a ``"tosa_dim_order"`` entry
            (presumably written by an earlier layout pass -- confirm).

    Returns:
        A 3-tuple of:
          * the TOSA dtype of the value, as returned by ``map_dtype``;
          * the value's shape, as a tuple of ``size()`` entries;
          * the dimension order: ``meta["tosa_dim_order"]`` when present and
            not None, otherwise the identity order ``(0, 1, ..., rank - 1)``.

    Raises:
        AssertionError: if ``"val"`` is missing or None, if the (first) value
            is not exactly a ``FakeTensor``, or if ``map_dtype`` rejects its
            dtype.
    """
    assert meta.get("val") is not None
    fake_val = meta["val"]

    # A tuple value is reduced to its first element (presumably the node has
    # several outputs -- confirm against callers).
    # TODO: should use first concrete representation
    if type(fake_val) is tuple:
        fake_val = fake_val[0]

    # Exact type comparison: a subclass of FakeTensor fails this assert too.
    assert torch._subclasses.fake_tensor.FakeTensor == type(fake_val)
    tosa_dtype = map_dtype(fake_val.dtype)
    dims = tuple(fake_val.size())

    dim_order = meta.get("tosa_dim_order")
    if dim_order is None:
        # No explicit order recorded: fall back to the identity permutation.
        dim_order = tuple(range(len(dims)))
    return tosa_dtype, dims, dim_order
70
71
# Class to capture arguments and turn into tensor references for TOSA OPs
class TosaArg:
    """Normalised view of one FX node argument, used for TOSA lowering.

    Which attributes are populated depends on the type of ``argument``:

    * ``torch.fx.node.Node``: ``name``, ``dtype``, ``shape`` and
      ``dim_order``, taken from the node's ``meta`` via
      ``extract_tensor_meta``.
    * ``list``: ``special`` holds a shallow copy of the list.
    * ``int`` or ``float`` (``bool`` included, being an ``int`` subclass):
      ``number`` holds the value. ``number`` is only set for these inputs.
    * ``None`` or a ``torch.dtype``: only the ``None`` defaults are set.

    Any other argument type raises ``RuntimeError``.
    """

    def __process_node(self, argument):
        """Fill name/dtype/shape/dim_order from an FX node's metadata."""
        assert isinstance(argument, torch.fx.node.Node)
        self.name = argument.name
        self.dtype, self.shape, self.dim_order = extract_tensor_meta(argument.meta)

    def __process_list(self, argument):
        """Store a shallow copy of a list argument in ``special``."""
        self.special = list(argument)

    def __process_number(self, argument):
        """Store a scalar (int or float) argument in ``number``."""
        self.number = argument

    def __init__(self, argument) -> None:
        self.name = None
        self.dtype = None
        self.shape = None
        self.dim_order = None
        self.special = None

        if argument is None:
            return

        if isinstance(argument, torch.fx.node.Node):
            self.__process_node(argument)
            return
        if isinstance(argument, list):
            self.__process_list(argument)
            return
        if isinstance(argument, (int, float)):
            self.__process_number(argument)
            return
        if isinstance(argument, torch.dtype):
            # A bare dtype carries no tensor reference; tensor dtypes are read
            # from node metadata in __process_node. Accept it as a no-op so
            # callers passing dtype-valued arguments keep working.
            return

        # Fail loudly instead of returning a half-initialised TosaArg.
        raise RuntimeError(
            f"Unhandled node input argument: {argument}, of type {type(argument)}"
        )
111