"""
Where should I add a new type? `types_base.py` vs `types.py`

This file defines data model classes for the torchgen typing system, as well as some base types such as int32_t.

`types.py` defines the ATen Tensor type and some c10 types, along with signatures that use these types.

The difference between these two files is that `types_base.py` should be implementation-agnostic, meaning it shouldn't
contain any type definition that is tied to a specific C++ library (e.g., ATen), so that it can be easily reused
if we want to generate code for another C++ library.

Add new types to `types.py` if these types are ATen/c10 related.
Add new types to `types_base.py` if they are basic and not attached to ATen/c10.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import auto, Enum
from typing import TYPE_CHECKING, Union


if TYPE_CHECKING:
    from torchgen.model import Argument, SelfArgument, TensorOptionsArguments


# An ArgName is usually just the str name of the argument in the schema,
# but in a few special circumstances we attach a little extra context.
# Every such case is a member of the SpecialArgName enum; grep for the
# construction sites of its members to see when they can occur.


class SpecialArgName(Enum):
    possibly_redundant_memory_format = auto()


ArgName = Union[str, SpecialArgName]


# This class shouldn't be created directly; instead, use/create one of the singletons below.
@dataclass(frozen=True)
class BaseCppType:
    """A non-templated C++ type name, optionally qualified by a namespace."""

    ns: str | None
    name: str

    def __str__(self) -> str:
        # A missing (None) or empty namespace means the name is emitted unqualified.
        has_namespace = bool(self.ns)
        return f"{self.ns}::{self.name}" if has_namespace else self.name


# The set of all non-templated, valid, fully-qualified names of C++ types that are used in the codegen.
# Templated types get their own dataclass, mainly to make namespace parsing easier.
byteT = BaseCppType("", "uint8_t")
charT = BaseCppType("", "int8_t")
shortT = BaseCppType("", "int16_t")
# It would be more symmetric for this to be called intT, but it is easy to mix
# this up with JIT int (which is int64_t in C++), so we intentionally don't
# define intT to make it obvious when you've stuffed it up
int32T = BaseCppType("", "int32_t")
longT = BaseCppType("", "int64_t")
doubleT = BaseCppType("", "double")
floatT = BaseCppType("", "float")
boolT = BaseCppType("", "bool")
voidT = BaseCppType("", "void")


class CType(ABC):
    """Abstract interface for a C++ type expression emitted by the codegen."""

    @abstractmethod
    def cpp_type(self, *, strip_ref: bool = False) -> str:
        """Render the C++ spelling.

        When ``strip_ref`` is true, reference wrappers drop their ``const ... &`` / ``&``.
        """
        raise NotImplementedError

    @abstractmethod
    def cpp_type_registration_declarations(self) -> str:
        """Render the spelling used for RegistrationDeclarations.yaml."""
        raise NotImplementedError

    @abstractmethod
    def remove_const_ref(self) -> CType:
        """Return this type with const-reference and mutable-reference wrappers removed."""
        return self


@dataclass(frozen=True)
class BaseCType(CType):
    """A leaf type that wraps a single BaseCppType."""

    type: BaseCppType

    def cpp_type(self, *, strip_ref: bool = False) -> str:
        # A leaf carries no reference wrapper, so `strip_ref` has nothing to strip.
        spelled = str(self.type)
        return spelled

    # For BC reasons, we don't want to introduce at:: namespaces to RegistrationDeclarations.yaml
    # TODO: Kill this when we eventually remove it!
    def cpp_type_registration_declarations(self) -> str:
        spelled = str(self.type)
        return spelled.replace("at::", "")

    def remove_const_ref(self) -> CType:
        return self


@dataclass(frozen=True)
class ConstRefCType(CType):
    """`const T &` around an inner CType."""

    elem: CType

    def cpp_type(self, *, strip_ref: bool = False) -> str:
        if not strip_ref:
            return f"const {self.elem.cpp_type()} &"
        # Stripping is forwarded so directly-nested reference wrappers are dropped too.
        return self.elem.cpp_type(strip_ref=strip_ref)

    def cpp_type_registration_declarations(self) -> str:
        inner = self.elem.cpp_type_registration_declarations()
        return f"const {inner} &"

    def remove_const_ref(self) -> CType:
        return self.elem.remove_const_ref()


@dataclass(frozen=True)
class VectorCType(CType):
    """`::std::vector<T>` of an inner CType."""

    elem: CType

    def cpp_type(self, *, strip_ref: bool = False) -> str:
        # `strip_ref` is deliberately not forwarded into the element type.
        inner = self.elem.cpp_type()
        return f"::std::vector<{inner}>"

    def cpp_type_registration_declarations(self) -> str:
        inner = self.elem.cpp_type_registration_declarations()
        return f"::std::vector<{inner}>"

    def remove_const_ref(self) -> CType:
        stripped = self.elem.remove_const_ref()
        return VectorCType(stripped)


@dataclass(frozen=True)
class ArrayCType(CType):
    """`::std::array<T,N>` of an inner CType with a fixed size."""

    elem: CType
    size: int

    def cpp_type(self, *, strip_ref: bool = False) -> str:
        # `strip_ref` is deliberately not forwarded into the element type.
        inner = self.elem.cpp_type()
        return f"::std::array<{inner},{self.size}>"

    def cpp_type_registration_declarations(self) -> str:
        inner = self.elem.cpp_type_registration_declarations()
        return f"::std::array<{inner},{self.size}>"

    def remove_const_ref(self) -> CType:
        stripped = self.elem.remove_const_ref()
        return ArrayCType(stripped, self.size)


@dataclass(frozen=True)
class TupleCType(CType):
    """`::std::tuple<T1,T2,...>` of inner CTypes, rendered without spaces after commas."""

    # NOTE(review): `elems` holds a list, so hash() on a TupleCType raises
    # TypeError even though the dataclass is frozen.
    elems: list[CType]

    def cpp_type(self, *, strip_ref: bool = False) -> str:
        # `strip_ref` is deliberately not forwarded into the element types.
        joined = ",".join(elem.cpp_type() for elem in self.elems)
        return f"::std::tuple<{joined}>"

    def cpp_type_registration_declarations(self) -> str:
        joined = ",".join(
            elem.cpp_type_registration_declarations() for elem in self.elems
        )
        return f"::std::tuple<{joined}>"

    def remove_const_ref(self) -> CType:
        stripped = [elem.remove_const_ref() for elem in self.elems]
        return TupleCType(stripped)


@dataclass(frozen=True)
class MutRefCType(CType):
    """`T &` (mutable reference) around an inner CType."""

    elem: CType

    def cpp_type(self, *, strip_ref: bool = False) -> str:
        if not strip_ref:
            return f"{self.elem.cpp_type()} &"
        # Stripping is forwarded so directly-nested reference wrappers are dropped too.
        return self.elem.cpp_type(strip_ref=strip_ref)

    def cpp_type_registration_declarations(self) -> str:
        inner = self.elem.cpp_type_registration_declarations()
        return f"{inner} &"

    def remove_const_ref(self) -> CType:
        return self.elem.remove_const_ref()


# A NamedCType is short for Named C++ semantic type. A NamedCType represents a C++ type, plus
# semantic information about what it represents.
# For example, consider the argument "bool pin_memory"; its normal C++ type is "bool", but its C++
# semantic type also keeps track that this represents a "pin_memory"; you can't
# just use a random other boolean in a context where you need a "pin_memory"!
#


@dataclass(frozen=True)
class NamedCType:
    """Pair a CType with the argument name (ArgName) that gives it its meaning."""

    name: ArgName
    type: CType

    def cpp_type(self, *, strip_ref: bool = False) -> str:
        """Render the wrapped type, forwarding ``strip_ref`` unchanged."""
        wrapped = self.type
        return wrapped.cpp_type(strip_ref=strip_ref)

    # For BC reasons, we don't want to introduce at:: namespaces to RegistrationDeclarations.yaml
    # TODO: Kill this when we eventually remove it!
    def cpp_type_registration_declarations(self) -> str:
        wrapped = self.type
        return wrapped.cpp_type_registration_declarations()

    def remove_const_ref(self) -> NamedCType:
        """Return a copy with the same name and the wrapped type's remove_const_ref() result."""
        return NamedCType(name=self.name, type=self.type.remove_const_ref())

    def with_name(self, name: str) -> NamedCType:
        """Return a copy carrying ``name`` while keeping the same type."""
        return NamedCType(name=name, type=self.type)


# A binding represents any C++ binding site for a formal parameter.
# We don't distinguish between binding sites for different APIs;
# instead, all of the important distinctions are encoded in CType,
# which you can use to figure out if a given Binding is appropriate
# for use in another context.
# (See torchgen.api.translate.)


@dataclass(frozen=True)
class Binding:
    """A C++ binding site for a formal parameter.

    Holds the binding's name, its semantic type (``nctype``), the schema
    argument it came from, and an optional default-value expression.
    """

    name: str
    nctype: NamedCType
    argument: Argument | TensorOptionsArguments | SelfArgument
    # TODO: maybe don't represent default here
    default: str | None = None

    def rename(self, name: str) -> Binding:
        """Return a copy of this binding called ``name`` (same as :meth:`with_name`)."""
        # Previously a verbatim duplicate of with_name; delegate so the two cannot drift apart.
        return self.with_name(name)

    @property
    def type(self) -> str:
        """The C++ spelling of this binding's type (``nctype.cpp_type()``)."""
        return self.nctype.cpp_type()

    def no_default(self) -> Binding:
        """Return a copy of this binding with ``default`` cleared to None."""
        return Binding(
            name=self.name,
            nctype=self.nctype,
            argument=self.argument,
            default=None,
        )

    def _default_suffix(self) -> str:
        """Return ``"=<default>"`` when a default is set, otherwise ``""``."""
        return "" if self.default is None else f"={self.default}"

    def decl(self, *, func_ptr_cast: bool = False) -> str:
        """Render a parameter declaration, e.g. ``"<type> <name>=<default>"``.

        Args:
            func_ptr_cast: when true, return only the type, with no name or default.
        """
        # casting only needs to know the type
        if func_ptr_cast:
            return self.type
        return f"{self.type} {self.name}{self._default_suffix()}"

    # For BC reasons, we don't want to introduce at:: namespaces to RegistrationDeclarations.yaml
    # TODO: Kill this when we eventually remove it!
    def decl_registration_declarations(self) -> str:
        """Like :meth:`decl`, but uses the RegistrationDeclarations spelling of the type."""
        type_s = self.nctype.cpp_type_registration_declarations()
        return f"{type_s} {self.name}{self._default_suffix()}"

    def defn(self) -> str:
        """Render the parameter for a definition: ``"<type> <name>"`` with no default."""
        return f"{self.type} {self.name}"

    def with_name(self, name: str) -> Binding:
        """Return a copy of this binding called ``name``; all other fields are kept."""
        return Binding(
            name=name, nctype=self.nctype, argument=self.argument, default=self.default
        )


# An Expr is a C++ expression. It has a C++ string representing its syntax,
# as well as a NamedCType saying what it provides.


@dataclass(frozen=True)
class Expr:
    expr: str
    type: NamedCType