from __future__ import annotations

from typing import Sequence

from torchgen import local
from torchgen.api.types import (
    ArgName,
    ArrayCType,
    ArrayRefCType,
    BaseCType,
    BaseTypeToCppMapping,
    Binding,
    boolT,
    ConstRefCType,
    CType,
    dimnameListT,
    intArrayRefT,
    iTensorListRefT,
    ListCType,
    longT,
    MutRefCType,
    NamedCType,
    OptionalCType,
    optionalIntArrayRefT,
    optionalSymIntArrayRefT,
    scalarT,
    SpecialArgName,
    symIntArrayRefT,
    SymIntT,
    tensorListT,
    tensorOptionsT,
    tensorT,
    TupleCType,
    VectorCType,
    voidT,
)
from torchgen.model import (
    Argument,
    Arguments,
    BaseTy,
    BaseType,
    FunctionSchema,
    ListType,
    NativeFunction,
    OptionalType,
    Return,
    SelfArgument,
    TensorOptionsArguments,
    Type,
)
from torchgen.utils import assert_never


# This file describes the translation of JIT schema to the public C++
# API, which is what people use when they call functions like at::add.
#
# Prominent characteristics of the C++ API:
#
#   - dtype, layout, device and pin_memory are collected into
#     a single C++ type TensorOptions (the native functions API
#     also has this, but tensor options is really most relevant
#     for the C++ API; it makes calling kwarg factory functions
#     pleasant)
#
#   - defaulting lives here (in fact, the dispatcher is completely
#     oblivious of defaults!)
#
# BTW: policy on name collisions: we try not to have types with
# collisions, but functions are fair game to collide


def name(
    func: FunctionSchema,
    *,
    faithful_name_for_out_overloads: bool = False,
    symint_overload: bool = False,
) -> str:
    name = str(func.name.name)
    if symint_overload:
        name += "_symint"
    if func.is_out_fn():
        if faithful_name_for_out_overloads:
            name += "_outf"
        else:
            name += "_out"

    return name


# Translation of "value types" in JIT schema to C++ API type. Value
# types look the same no matter if they are argument types or return
# types. Returns None if the type in question is not a value type.
def valuetype_type(
    t: Type,
    *,
    binds: ArgName,
    mutable: bool = True,
    remove_non_owning_ref_types: bool = False,
    symint: bool = False,
) -> NamedCType | None:
    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor or t.name == BaseTy.Scalar:
            return None
        elif str(t) == "SymInt":
            if symint:
                return NamedCType(binds, BaseCType(SymIntT))
            else:
                return NamedCType(binds, BaseCType(longT))
        if remove_non_owning_ref_types:
            if t.name == BaseTy.str:
                raise AssertionError(
                    "string ref->value conversion: not implemented yet"
                )
        # All other BaseType currently map directly to BaseCppTypes.
        return NamedCType(binds, BaseCType(BaseTypeToCppMapping[t.name]))
    elif isinstance(t, OptionalType):
        elem = valuetype_type(t.elem, binds=binds, mutable=mutable, symint=symint)
        if elem is None:
            return None
        return NamedCType(binds, OptionalCType(elem.type))
    elif isinstance(t, ListType):
        if str(t.elem) == "bool":
            assert t.size is not None
            return NamedCType(binds, ArrayCType(BaseCType(boolT), t.size))
        else:
            return None
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")


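# A hedged, illustrative sketch of the value-type mapping (model types are
# constructed directly via the torchgen.model dataclasses imported above;
# not an exhaustive list of cases):
#
#   valuetype_type(BaseType(BaseTy.int), binds="x")
#       -> NamedCType("x", BaseCType(longT))    # plain ints map to int64_t
#   valuetype_type(BaseType(BaseTy.Tensor), binds="x")
#       -> None                                 # Tensor is not a value type

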
# Translation of types occurring in JIT arguments to a C++ argument type.
# If remove_non_owning_ref_types is set, we'll guarantee that the outputted
# CType is not a non-owning reference type.
# For example, we'll return std::vector<int> instead of IntArrayRef.
# See Note [translation from C++ reference to value types]
def argumenttype_type(
    t: Type,
    *,
    mutable: bool,
    binds: ArgName,
    remove_non_owning_ref_types: bool = False,
    symint: bool = False,
) -> NamedCType:
    # If it's a value type, do the value type translation
    r = valuetype_type(
        t,
        binds=binds,
        mutable=mutable,
        symint=symint,
        remove_non_owning_ref_types=remove_non_owning_ref_types,
    )
    if r is not None:
        return r

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            if mutable and not local.use_const_ref_for_mutable_tensors():
                return NamedCType(binds, MutRefCType(BaseCType(tensorT)))
            else:
                return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
        elif t.name == BaseTy.Scalar:
            return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
        else:
            raise AssertionError(f"base type should have been value type {t}")
    elif isinstance(t, OptionalType):
        if str(t.elem) == "Tensor":
            if mutable and not local.use_const_ref_for_mutable_tensors():
                return NamedCType(
                    binds, MutRefCType(BaseCType(tensorT))
                )  # TODO: fix this discrepancy
            else:
                return NamedCType(
                    binds, ConstRefCType(OptionalCType(BaseCType(tensorT)))
                )
        elif str(t.elem) == "Scalar":
            return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
        elif isinstance(t.elem, ListType) and str(t.elem.elem) == "int":
            return NamedCType(binds, BaseCType(optionalIntArrayRefT))
        elif isinstance(t.elem, ListType) and str(t.elem.elem) == "SymInt":
            if symint:
                return NamedCType(binds, BaseCType(optionalSymIntArrayRefT))
            else:
                return NamedCType(binds, BaseCType(optionalIntArrayRefT))
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds, symint=symint)
        return NamedCType(binds, OptionalCType(elem.type))
    elif isinstance(t, ListType):
        # TODO: remove these special cases, ArrayRef fallthrough works fine
        if str(t.elem) == "int":
            if remove_non_owning_ref_types:
                return NamedCType(binds, VectorCType(BaseCType(longT)))
            else:
                return NamedCType(binds, BaseCType(intArrayRefT))
        if str(t.elem) == "SymInt":
            if remove_non_owning_ref_types:
                if symint:
                    return NamedCType(binds, VectorCType(BaseCType(SymIntT)))
                else:
                    return NamedCType(binds, VectorCType(BaseCType(longT)))
            else:
                if symint:
                    return NamedCType(binds, BaseCType(symIntArrayRefT))
                else:
                    return NamedCType(binds, BaseCType(intArrayRefT))
        if str(t.elem) == "Tensor":
            if local.use_ilistref_for_tensor_lists():
                return NamedCType(binds, ConstRefCType(BaseCType(iTensorListRefT)))
            else:
                return NamedCType(binds, BaseCType(tensorListT))
        elif str(t.elem) == "Scalar":
            return NamedCType(binds, ArrayRefCType(BaseCType(scalarT)))
        elif str(t.elem) == "Dimname":
            return NamedCType(binds, BaseCType(dimnameListT))
        elif str(t.elem) == "Tensor?":
            return NamedCType(
                binds, ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT))))
            )
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds, symint=symint)
        return NamedCType(binds, ArrayRefCType(elem.type))
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")


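# A hedged, illustrative sketch of the argument-type mapping (keyword
# construction of ListType is assumed from the torchgen.model dataclass;
# not exhaustive):
#
#   argumenttype_type(BaseType(BaseTy.Tensor), mutable=False, binds="self")
#       -> NamedCType("self", ConstRefCType(BaseCType(tensorT)))   # const Tensor&
#   argumenttype_type(ListType(elem=BaseType(BaseTy.int), size=None),
#                     mutable=False, binds="size")
#       -> NamedCType("size", BaseCType(intArrayRefT))             # IntArrayRef

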
# Translate a JIT argument into its C++ type
def argument_type(a: Argument, *, binds: ArgName, symint: bool = False) -> NamedCType:
    return argumenttype_type(a.type, mutable=a.is_write, symint=symint, binds=binds)


# Translation of a (non-multi) return type from JIT to C++
# N.B: returntype_type returns a CType, not a NamedCType.
# This is mostly because of the mismatch between return types and return names.
# e.g. a function with a return type of 'void' has 0 return names,
# and a function with a return type of 'std::tuple' has >1 return name.
def returntype_type(t: Type, *, mutable: bool, symint: bool = False) -> CType:
    # placeholder is ignored
    # NB: symint is ALWAYS respected for return types. So symint argument
    # here is IGNORED
    r = valuetype_type(t, binds="__placeholder__", mutable=mutable, symint=True)
    if r is not None:
        return r.type

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            if mutable:
                if local.use_const_ref_for_mutable_tensors():
                    return ConstRefCType(BaseCType(tensorT))
                else:
                    return MutRefCType(BaseCType(tensorT))
            else:
                # Note [Tensor Copy Returns]
                # Currently, we use "Argument.is_write" to determine
                # whether or not Tensor return types should be copies or references.
                # If that ever changes, take a look at other locations of this note!
                return BaseCType(tensorT)
        elif t.name == BaseTy.Scalar:
            return BaseCType(scalarT)
    elif isinstance(t, ListType):
        assert (
            not mutable
        ), "Native functions should never return a mutable tensor list. They should return void."
        elem = returntype_type(t.elem, mutable=False)
        assert t.size is None, f"fixed size list returns not supported: {t}"
        return VectorCType(elem)
    elif isinstance(t, OptionalType):
        elem = returntype_type(t.elem, mutable=mutable)
        if str(t.elem) == "Tensor":
            return OptionalCType(elem)

    raise AssertionError(f"unrecognized return type {t}")


# Translation of a single return to its C++ type
def return_type(r: Return, *, symint: bool = False) -> CType:
    return returntype_type(r.type, mutable=r.is_write, symint=symint)


# Translation of a full (possibly multi) return from JIT to its C++ type
def returns_type(rs: Sequence[Return], *, symint: bool = False) -> CType:
    if len(rs) == 0:
        return BaseCType(voidT)
    elif len(rs) == 1:
        return return_type(rs[0], symint=symint)
    else:
        return TupleCType([return_type(r, symint=symint) for r in rs])


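# A hedged, illustrative sketch of return translation (Return is constructed
# directly from the torchgen.model dataclass, with no output annotation):
#
#   returns_type([]) -> BaseCType(voidT)        # void
#   returns_type([Return(name=None, type=BaseType(BaseTy.Tensor), annotation=None)])
#       -> BaseCType(tensorT)                   # non-mutable Tensor, returned by value

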
def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]:
    returns: list[str] = []
    for i, r in enumerate(f.func.returns):
        # If we have an inplace function, the return argument is
        # implicitly named self.
        # TODO: Consider incorporating this into the data model
        if f.func.name.name.inplace:
            assert i == 0, "illegal inplace function with multiple returns"
            name = "self"
        # If this is an out function, the name is the name of the
        # corresponding output argument (r.name will get recorded
        # in field_name later.)
        elif f.func.is_out_fn():
            name = f.func.arguments.out[i].name
        # If the return argument is explicitly named...
        elif r.name:
            name_conflict = any(
                r.name == a.name for a in f.func.schema_order_arguments()
            )
            if name_conflict and not f.func.is_out_fn():
                name = f"{r.name}_return"
            else:
                name = r.name
        # If there is no explicit name and no fallback name was passed in,
        # we just name the output result, unless it's a multi-return,
        # in which case it's result0, result1, etc (zero-indexed)
        else:
            name = fallback_name if len(f.func.returns) == 1 else f"{fallback_name}{i}"
        returns.append(name)
    return returns


JIT_TO_CPP_DEFAULT = {
    "False": "false",
    "True": "true",
    "None": "::std::nullopt",  # UGH this one is type directed
    "Mean": "at::Reduction::Mean",
    "[]": "{}",
    "contiguous_format": "c10::MemoryFormat::Contiguous",
    "long": "at::kLong",
}


# Convert a JIT default into C++ expression representing the default
def default_expr(d: str, t: Type, *, symint: bool) -> str:
    if d == "None" and str(t) == "Tensor?":
        return "{}"
    if isinstance(t, BaseType) and t.name is BaseTy.str:
        # Schema allows single quotes but C++ needs double
        if len(d) >= 2 and d[0] == "'" and d[-1] == "'":
            s = ""
            i = 1
            while i + 1 < len(d):
                if d[i] != "\\":
                    if d[i] == '"':
                        s += '\\"'
                    else:
                        s += d[i]
                    i += 1
                else:
                    if d[i + 1] == "'":
                        s += "'"
                    else:
                        s += d[i : i + 2]
                    i += 2

            return f'"{s}"'

    if isinstance(t, OptionalType):
        if d == "None":
            return "::std::nullopt"

        return default_expr(d, t.elem, symint=symint)

    if isinstance(t, ListType):
        if d.startswith("[") and d.endswith("]"):
            return "{" + d[1:-1] + "}"
        elif symint and d.isdigit() and str(t.elem) == "SymInt":
            return f"c10::SymInt({d})"
        elif t.size is None:
            # NOTE: Sized lists can have scalar defaults
            raise ValueError(f"Expected a list default '[...]' but found: '{d}'")

    return JIT_TO_CPP_DEFAULT.get(d, d)


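# A hedged, illustrative sketch of default translation (model types constructed
# directly; outputs shown are the C++ expressions this function would emit):
#
#   default_expr("None", OptionalType(BaseType(BaseTy.int)), symint=False)
#       -> "::std::nullopt"
#   default_expr("[0, 1]", ListType(elem=BaseType(BaseTy.int), size=None), symint=False)
#       -> "{0, 1}"
#   default_expr("True", BaseType(BaseTy.bool), symint=False)
#       -> "true"                               # via JIT_TO_CPP_DEFAULT

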
# Convert an argument into its C++ API form


def argument(
    a: Argument | TensorOptionsArguments | SelfArgument,
    *,
    cpp_no_default_args: set[str],
    method: bool,
    faithful: bool,
    symint: bool = False,
    has_tensor_options: bool,
) -> list[Binding]:
    def sub_argument(
        a: Argument | TensorOptionsArguments | SelfArgument,
    ) -> list[Binding]:
        return argument(
            a,
            cpp_no_default_args=cpp_no_default_args,
            method=method,
            faithful=faithful,
            symint=symint,
            has_tensor_options=has_tensor_options,
        )

    if isinstance(a, Argument):
        binds: ArgName
        if a.name == "memory_format" and has_tensor_options:
            binds = SpecialArgName.possibly_redundant_memory_format
        else:
            binds = a.name
        default: str | None = None
        if a.name not in cpp_no_default_args and a.default is not None:
            default = default_expr(a.default, a.type, symint=symint)
        return [
            Binding(
                nctype=argument_type(a, binds=binds, symint=symint),
                name=a.name,
                default=default,
                argument=a,
            )
        ]
    elif isinstance(a, TensorOptionsArguments):
        if faithful:
            return (
                sub_argument(a.dtype)
                + sub_argument(a.layout)
                + sub_argument(a.device)
                + sub_argument(a.pin_memory)
            )
        else:
            default = None
            # Enforced by NativeFunction.__post_init__
            assert "options" not in cpp_no_default_args
            if all(x.default == "None" for x in a.all()):
                default = "{}"
            elif a.dtype.default == "long":
                default = "at::kLong"  # TODO: this is wrong
            return [
                Binding(
                    nctype=NamedCType("options", BaseCType(tensorOptionsT)),
                    name="options",
                    default=default,
                    argument=a,
                )
            ]
    elif isinstance(a, SelfArgument):
        if method:
            # Caller is responsible for installing implicit this in context!
            return []
        else:
            return sub_argument(a.argument)
    else:
        assert_never(a)


def arguments(
    arguments: Arguments,
    *,
    faithful: bool,
    symint: bool = False,
    method: bool,
    cpp_no_default_args: set[str],
) -> list[Binding]:
    args: list[Argument | TensorOptionsArguments | SelfArgument] = []
    if faithful:
        args.extend(arguments.non_out)
        args.extend(arguments.out)
    else:
        args.extend(arguments.out)
        args.extend(arguments.non_out)
    return [
        r.no_default() if faithful else r
        for a in args
        for r in argument(
            a,
            faithful=faithful,
            symint=symint,
            method=method,
            has_tensor_options=arguments.tensor_options is not None,
            cpp_no_default_args=cpp_no_default_args,
        )
    ]


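# A hedged, end-to-end sketch (assumes a parsed NativeFunction `f` is available
# in the caller's context; cpp_no_default_args is its set of suppressed defaults):
#
#   # Non-faithful C++ signature: out arguments come first, defaults are kept.
#   arguments(f.func.arguments, faithful=False, method=False,
#             cpp_no_default_args=f.cpp_no_default_args)
#
#   # Faithful C++ signature: schema order (out arguments last), defaults dropped.
#   arguments(f.func.arguments, faithful=True, method=False,
#             cpp_no_default_args=f.cpp_no_default_args)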