# Copyright 2023-2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

#
# PyTorch to Tosa mapping - simple mapping functions and multi-type extraction
# of key information. These are used by the initial compile stage which captures
# the standardised TOSA representation.
#

import serializer.tosa_serializer as ts
import torch


# torch dtypes that map_dtype rejects outright.
UNSUPPORTED_DTYPES = (
    torch.float64,
    torch.double,
    torch.complex64,
    torch.cfloat,
    torch.complex128,
    torch.cdouble,
    torch.uint8,
    torch.int64,
    torch.long,
)

# torch dtype -> TOSA serializer dtype. Aliases (torch.float, torch.half,
# torch.short, torch.int) are listed alongside their canonical names.
DTYPE_MAP = {
    torch.float32: ts.DType.FP32,
    torch.float: ts.DType.FP32,
    torch.float16: ts.DType.FP16,
    torch.half: ts.DType.FP16,
    torch.bfloat16: ts.DType.BF16,
    torch.int8: ts.DType.INT8,
    torch.int16: ts.DType.INT16,
    torch.short: ts.DType.INT16,
    torch.int32: ts.DType.INT32,
    torch.int: ts.DType.INT32,
    torch.bool: ts.DType.BOOL,
}


def map_dtype(data_type):
    """Map a torch dtype to the corresponding TOSA serializer dtype.

    Args:
        data_type: a ``torch.dtype``.

    Returns:
        The ``ts.DType`` value from ``DTYPE_MAP``.

    Raises:
        AssertionError: if ``data_type`` is in ``UNSUPPORTED_DTYPES`` or is
            not a key of ``DTYPE_MAP``.
    """
    assert data_type not in UNSUPPORTED_DTYPES, f"Unsupported type: {data_type}"
    assert data_type in DTYPE_MAP, f"Unknown type: {data_type}"
    return DTYPE_MAP[data_type]


# Returns the TOSA dtype, shape and dim order of a node.
# TODO: other types, can be
# SymInt, FakeTensor, a List[Union[FakeTensor, SymInt]], or None
def extract_tensor_meta(meta):
    """Extract ``(dtype, shape, dim_order)`` from an FX node's ``meta`` dict.

    ``meta["val"]`` must be a FakeTensor, or a plain tuple whose first element
    is a FakeTensor (only that first element is used).

    Returns:
        dtype: the TOSA dtype, via ``map_dtype``.
        shape: ``tuple`` of the tensor's sizes.
        dim_order: ``meta["tosa_dim_order"]`` if present and not None,
            otherwise the identity order ``(0, 1, ..., rank - 1)``.
    """
    assert meta.get("val") is not None
    val = meta["val"]
    if type(val) is tuple:
        # TODO: should use first concrete representation
        val = val[0]

    assert torch._subclasses.fake_tensor.FakeTensor == type(val)
    dtype = map_dtype(val.dtype)
    shape = tuple(val.size())

    if meta.get("tosa_dim_order") is not None:
        dim_order = meta["tosa_dim_order"]
    else:
        dim_order = tuple(range(len(shape)))
    return (dtype, shape, dim_order)


# Class to capture arguments and turn into tensor references for TOSA OPs
class TosaArg:
    """Wrap a single FX node argument for use when emitting TOSA ops.

    Which attributes are populated depends on the argument's type:

    * ``torch.fx.node.Node`` -> ``name``, ``dtype``, ``shape``, ``dim_order``
      (taken from the node's ``meta`` via ``extract_tensor_meta``)
    * ``list``               -> ``special`` (a shallow copy of the list)
    * ``int`` / ``float``    -> ``number`` (``bool`` counts as ``int``)
    * ``None`` / ``torch.dtype`` -> nothing

    Attributes that are not populated remain ``None``.

    Raises:
        RuntimeError: for any other argument type.
    """

    def __process_node(self, argument):
        assert isinstance(argument, torch.fx.node.Node)
        self.name = argument.name
        self.dtype, self.shape, self.dim_order = extract_tensor_meta(argument.meta)

    def __process_list(self, argument):
        self.special = list(argument)

    def __process_number(self, argument):
        self.number = argument

    def __init__(self, argument) -> None:
        self.name = None
        self.dtype = None
        self.shape = None
        self.dim_order = None
        self.special = None
        # Default like the other fields so reading .number on a non-numeric
        # argument yields None rather than raising AttributeError.
        self.number = None

        if argument is None:
            return

        if isinstance(argument, torch.fx.node.Node):
            self.__process_node(argument)
            return
        if isinstance(argument, list):
            self.__process_list(argument)
            return
        # bool is a subclass of int, so True/False are stored as numbers too.
        if isinstance(argument, (int, float)):
            self.__process_number(argument)
            return
        if isinstance(argument, torch.dtype):
            # Accepted without populating any field, as before the error below
            # was actually raised. NOTE(review): presumably the dtype is
            # recovered from tensor metadata instead; confirm no other argument
            # types (str, tuple, ...) legitimately reach this constructor.
            return

        # Fail loudly instead of returning a TosaArg with every field unset.
        raise RuntimeError(
            f"Unhandled node input argument: {argument}, of type {type(argument)}"
        )