# Copyright 2023-2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

#
# PyTorch to Tosa mapping - simple mapping functions and multi-type extraction
# of key information. These are used by the initial compile stage which captures
# the standardised TOSA representation.
#

import serializer.tosa_serializer as ts
import torch


# Torch dtypes rejected by map_dtype. Several entries are aliases of one
# another (e.g. torch.double is torch.float64, torch.long is torch.int64);
# both spellings are listed only for readability.
UNSUPPORTED_DTYPES = (
    torch.float64,
    torch.double,
    torch.complex64,
    torch.cfloat,
    torch.complex128,
    torch.cdouble,
    torch.uint8,
    torch.int64,
    torch.long,
)

# Torch dtype -> TOSA serializer dtype. Alias keys (torch.float, torch.half,
# torch.short, torch.int) are the same objects as their canonical names, so
# they collapse onto a single dict entry each.
DTYPE_MAP = {
    torch.float32: ts.DType.FP32,
    torch.float: ts.DType.FP32,
    torch.float16: ts.DType.FP16,
    torch.half: ts.DType.FP16,
    torch.bfloat16: ts.DType.BF16,
    torch.int8: ts.DType.INT8,
    torch.int16: ts.DType.INT16,
    torch.short: ts.DType.INT16,
    torch.int32: ts.DType.INT32,
    torch.int: ts.DType.INT32,
    torch.bool: ts.DType.BOOL,
}


def map_dtype(data_type):
    """Return the TOSA serializer dtype corresponding to a torch dtype.

    Raises:
        AssertionError: if ``data_type`` is listed in UNSUPPORTED_DTYPES, or
            is not a key of DTYPE_MAP.
    """
    assert data_type not in UNSUPPORTED_DTYPES, f"Unsupported type: {data_type}"
    assert data_type in DTYPE_MAP, f"Unknown type: {data_type}"
    return DTYPE_MAP[data_type]


# Returns the TOSA dtype, shape and dim order of a node
# TODO: other types, can be
# SymInt, FakeTensor, a List[Union[FakeTensor, SymInt]], or None
def extract_tensor_meta(meta):
    """Return ``(tosa_dtype, shape, dim_order)`` from a node's ``meta`` dict.

    ``meta["val"]`` must be a FakeTensor, or a tuple whose first element is a
    FakeTensor (only that first element is inspected). ``shape`` is the
    tensor size as a tuple. ``dim_order`` is ``meta["tosa_dim_order"]`` when
    that key is present and not None, otherwise the identity order
    ``(0, 1, ..., rank - 1)``.

    Raises:
        AssertionError: if ``meta["val"]`` is missing/None, is not a
            FakeTensor, or has a dtype rejected by map_dtype.
    """
    assert meta.get("val") is not None
    val = meta["val"]
    if type(val) is tuple:
        # TODO: should use first concrete representation
        val = val[0]

    assert torch._subclasses.fake_tensor.FakeTensor == type(val)
    dtype = map_dtype(val.dtype)
    shape = tuple(val.size())

    dim_order = meta.get("tosa_dim_order")
    if dim_order is None:
        dim_order = tuple(range(len(shape)))
    return (dtype, shape, dim_order)


# Class to capture arguments and turn into tensor references for TOSA OPs
class TosaArg:
    """Describe one FX node argument as an operand for TOSA op lowering.

    Which attributes are populated depends on the argument's type:

    * ``torch.fx.node.Node``: ``name``, ``dtype``, ``shape`` and
      ``dim_order`` (read from the node's ``meta`` via extract_tensor_meta).
    * ``list`` or ``tuple``: ``special`` holds a list copy of the elements.
    * ``int`` or ``float`` (``bool`` included, as it subclasses int):
      ``number`` holds the value.
    * ``None``: only the None defaults are set.

    ``name``, ``dtype``, ``shape``, ``dim_order`` and ``special`` always exist
    (defaulting to None). ``number`` is assigned only for numeric arguments;
    for any other kind, accessing it raises AttributeError.

    Raises:
        RuntimeError: if the argument is of any other type.
    """

    def __process_node(self, argument):
        assert isinstance(argument, torch.fx.node.Node)
        self.name = argument.name
        self.dtype, self.shape, self.dim_order = extract_tensor_meta(argument.meta)

    def __process_list(self, argument):
        self.special = list(argument)

    def __process_number(self, argument):
        self.number = argument

    def __init__(self, argument) -> None:
        self.name = None
        self.dtype = None
        self.shape = None
        self.dim_order = None
        self.special = None

        if argument is None:
            return

        if isinstance(argument, torch.fx.node.Node):
            self.__process_node(argument)
            return
        if isinstance(argument, (list, tuple)):
            self.__process_list(argument)
            return
        if isinstance(argument, (int, float)):
            self.__process_number(argument)
            return

        # Fail loudly: without the raise, an unrecognised argument would
        # yield a TosaArg whose attributes are all None.
        raise RuntimeError(
            f"Unhandled node input argument: {argument}, of type {type(argument)}"
        )