# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from enum import Enum


class ArgType(str, Enum):
    """Kinds of operator arguments, each valued by its type-string spelling.

    Because the enum mixes in ``str``, every member compares equal to its
    string value (``ArgType.TensorOpt == "Tensor?"``). A trailing ``?`` marks
    an optional variant; a ``[]`` suffix marks a list variant.
    """

    Tensor = "Tensor"
    TensorOpt = "Tensor?"
    TensorList = "Tensor[]"
    TensorOptList = "Tensor?[]"
    Scalar = "Scalar"
    ScalarOpt = "Scalar?"
    ScalarType = "ScalarType"
    ScalarTypeOpt = "ScalarType?"
    Param = "Param"

    def is_tensor(self):
        """Return True for a single (possibly optional) tensor argument."""
        return self in (ArgType.Tensor, ArgType.TensorOpt)

    def is_tensor_list(self):
        """Return True for a list of tensors, including lists of optional tensors."""
        return self in (ArgType.TensorList, ArgType.TensorOptList)

    def is_scalar(self):
        """Return True for a (possibly optional) scalar argument."""
        return self in (ArgType.Scalar, ArgType.ScalarOpt)

    def is_scalar_type(self):
        """Return True for a (possibly optional) ScalarType argument."""
        return self in (ArgType.ScalarType, ArgType.ScalarTypeOpt)

    def is_optional(self):
        """Return True when the argument itself may be absent.

        ``TensorOptList`` ("Tensor?[]") is deliberately excluded: the list is
        required even though its elements are optional.
        """
        optional_kinds = (
            ArgType.TensorOpt,
            ArgType.ScalarOpt,
            ArgType.ScalarTypeOpt,
        )
        return self in optional_kinds

    def has_dtype(self):
        """Return True if the argument belongs to any kind that carries a dtype.

        That covers tensors, tensor lists, scalars and scalar types; of the
        current members, only ``Param`` is excluded.
        """
        predicates = (
            self.is_tensor,
            self.is_tensor_list,
            self.is_scalar,
            self.is_scalar_type,
        )
        return any(check() for check in predicates)