xref: /aosp_15_r20/external/executorch/backends/arm/tosa_mapping.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright 2023-2024 Arm Limited and/or its affiliates.
2*523fa7a6SAndroid Build Coastguard Worker#
3*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
4*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
5*523fa7a6SAndroid Build Coastguard Worker
6*523fa7a6SAndroid Build Coastguard Worker# pyre-unsafe
7*523fa7a6SAndroid Build Coastguard Worker
8*523fa7a6SAndroid Build Coastguard Worker#
9*523fa7a6SAndroid Build Coastguard Worker# PyTorch to Tosa mapping - simple mapping functions and multi-type extraction
10*523fa7a6SAndroid Build Coastguard Worker# of key information. These are used by the initial compile stage which captures
11*523fa7a6SAndroid Build Coastguard Worker# the standardised TOSA representation.
12*523fa7a6SAndroid Build Coastguard Worker#
13*523fa7a6SAndroid Build Coastguard Worker
14*523fa7a6SAndroid Build Coastguard Workerimport serializer.tosa_serializer as ts
15*523fa7a6SAndroid Build Coastguard Workerimport torch
16*523fa7a6SAndroid Build Coastguard Worker
17*523fa7a6SAndroid Build Coastguard Worker
# PyTorch dtypes that have no TOSA counterpart; map_dtype rejects these with an
# "Unsupported type" assertion. Several entries are PyTorch aliases of the entry
# just before them (torch.double is torch.float64, torch.cfloat is
# torch.complex64, torch.cdouble is torch.complex128, torch.long is
# torch.int64); they are spelled out only so either name can be searched for.
UNSUPPORTED_DTYPES = (
    # 64-bit floating point
    torch.float64, torch.double,
    # complex
    torch.complex64, torch.cfloat,
    torch.complex128, torch.cdouble,
    # unsigned 8-bit integer
    torch.uint8,
    # 64-bit integer
    torch.int64, torch.long,
)
29*523fa7a6SAndroid Build Coastguard Worker
# Supported PyTorch dtypes and the TOSA serializer dtype each one maps to.
# Each row pairs a dtype (plus any PyTorch alias of it, e.g. torch.float for
# torch.float32) with its TOSA type. Aliases are the same dtype object, so they
# collapse onto a single dict key; they are listed only for readability.
DTYPE_MAP = {
    torch_dtype: tosa_dtype
    for torch_aliases, tosa_dtype in (
        ((torch.float32, torch.float), ts.DType.FP32),
        ((torch.float16, torch.half), ts.DType.FP16),
        ((torch.bfloat16,), ts.DType.BF16),
        ((torch.int8,), ts.DType.INT8),
        ((torch.int16, torch.short), ts.DType.INT16),
        ((torch.int32, torch.int), ts.DType.INT32),
        ((torch.bool,), ts.DType.BOOL),
    )
    for torch_dtype in torch_aliases
}
43*523fa7a6SAndroid Build Coastguard Worker
44*523fa7a6SAndroid Build Coastguard Worker
def map_dtype(data_type):
    """Return the TOSA serializer dtype corresponding to a PyTorch dtype.

    Args:
        data_type: the PyTorch dtype to translate; it is looked up in
            ``UNSUPPORTED_DTYPES`` and ``DTYPE_MAP``.

    Returns:
        The ``DTYPE_MAP`` value for ``data_type``.

    Raises:
        AssertionError: "Unsupported type: ..." if ``data_type`` is in
            ``UNSUPPORTED_DTYPES``, or "Unknown type: ..." if it has no
            ``DTYPE_MAP`` entry. These are ``assert`` statements, so they are
            skipped under ``python -O``; the lookup then raises ``KeyError``.
    """
    # The deny-list is checked first so known-unsupported dtypes get a more
    # specific message than dtypes that are merely absent from the map.
    assert data_type not in UNSUPPORTED_DTYPES, f"Unsupported type: {data_type}"
    assert data_type in DTYPE_MAP, f"Unknown type: {data_type}"
    tosa_dtype = DTYPE_MAP[data_type]
    return tosa_dtype
49*523fa7a6SAndroid Build Coastguard Worker
50*523fa7a6SAndroid Build Coastguard Worker
# Returns the shape and type of a node
# TODO: other types, can be
# SymInt, FakeTensor, a List[Union[FakeTensor, SymInt]], or None
def extract_tensor_meta(meta):
    """Extract ``(dtype, shape, dim_order)`` from an FX node's ``meta`` mapping.

    Args:
        meta: the node's metadata mapping. It must hold a non-None ``"val"``
            entry that is a FakeTensor, or a plain ``tuple`` whose first
            element is a FakeTensor. It may also hold ``"tosa_dim_order"``.

    Returns:
        A 3-tuple:
        * the TOSA dtype produced by ``map_dtype`` for the tensor's dtype;
        * the tensor's sizes as a tuple;
        * ``meta["tosa_dim_order"]`` when present and not None, otherwise the
          identity order ``(0, 1, ..., len(shape) - 1)``.

    Raises:
        AssertionError: if ``"val"`` is missing/None or is not (after tuple
            unwrapping) exactly a FakeTensor, or if ``map_dtype`` rejects the
            dtype.
    """
    assert meta.get("val") is not None
    fake = meta["val"]
    # Only an exact tuple is unwrapped (not subclasses or lists).
    # TODO: should use first concrete representation
    fake = fake[0] if type(fake) is tuple else fake

    # Exact type match on purpose: subclasses of FakeTensor are not accepted.
    assert torch._subclasses.fake_tensor.FakeTensor == type(fake)
    tosa_dtype = map_dtype(fake.dtype)
    shape = tuple(fake.size())

    dim_order = meta.get("tosa_dim_order")
    if dim_order is None:
        # No explicit TOSA layout recorded: keep the natural dimension order.
        dim_order = tuple(range(len(shape)))
    return (tosa_dtype, shape, dim_order)
70*523fa7a6SAndroid Build Coastguard Worker
71*523fa7a6SAndroid Build Coastguard Worker
# Class to capture arguments and turn into tensor references for TOSA OPs
class TosaArg:
    """Normalized view of one FX node argument, used when lowering to TOSA.

    Which attributes are populated depends on the type of ``argument``:

    * ``torch.fx.node.Node``: ``name``, ``dtype``, ``shape`` and ``dim_order``,
      taken from the node's ``meta`` via ``extract_tensor_meta``.
    * ``list`` or ``tuple``: ``special``, a shallow ``list`` copy.
    * ``int`` or ``float``: ``number`` (``bool`` is an ``int`` subclass, so
      it lands here too).
    * ``None`` or ``torch.dtype``: nothing beyond the ``None`` defaults.

    Any other argument type raises ``RuntimeError``.
    """

    def __process_node(self, argument):
        """Fill the tensor-reference fields from an FX node's metadata."""
        assert isinstance(argument, torch.fx.node.Node)
        self.name = argument.name
        self.dtype, self.shape, self.dim_order = extract_tensor_meta(argument.meta)

    def __process_list(self, argument):
        """Store a shallow list copy of a sequence argument in ``special``."""
        self.special = list(argument)

    def __process_number(self, argument):
        """Store a scalar argument in ``number``."""
        self.number = argument

    def __init__(self, argument) -> None:
        """Classify ``argument`` and populate the matching attributes.

        Raises:
            RuntimeError: if ``argument`` is not None, an FX Node, a
                list/tuple, an int/float, or a ``torch.dtype``.
        """
        # Tensor-reference fields; populated only for FX Node arguments.
        self.name = None
        self.dtype = None
        self.shape = None
        self.dim_order = None
        # Populated only for list/tuple arguments.
        self.special = None
        # NOTE: ``number`` is intentionally not pre-initialized; it exists
        # only on instances built from an int/float argument.

        if argument is None:
            return

        if isinstance(argument, torch.fx.node.Node):
            self.__process_node(argument)
            return
        if isinstance(argument, (list, tuple)):
            self.__process_list(argument)
            return
        if isinstance(argument, (int, float)):
            self.__process_number(argument)
            return
        if isinstance(argument, torch.dtype):
            # A dtype argument carries no tensor reference or value to record;
            # accept it and leave every field at its default.
            return

        # Fail loudly: otherwise an unrecognised argument would yield a
        # TosaArg with no populated fields and surface as a confusing error
        # much later.
        raise RuntimeError(
            f"Unhandled node input argument: {argument}, of type {type(argument)}"
        )
111