from __future__ import annotations

from typing import Any

from torchgen.api.types import (
    BaseCppType,
    BaseCType,
    boolT,
    CType,
    deviceT,
    doubleT,
    generatorT,
    layoutT,
    ListCType,
    longT,
    memoryFormatT,
    NamedCType,
    OptionalCType,
    scalarT,
    scalarTypeT,
    stringT,
    SymIntT,
    VectorCType,
)
from torchgen.model import (
    Argument,
    BaseTy,
    BaseType,
    FunctionSchema,
    ListType,
    OperatorName,
    OptionalType,
    Return,
    TensorOptionsArguments,
    Type,
)


# Module-level registry for the backend's IR value type. It starts out unset
# and must be populated via setValueT() before getValueT() is called.
_valueT: BaseCppType | None = None


# A ValueT is an IR type which represents the computation of a Tensor. In other
# words, a PyTorch user will do operations on lazy tensors, and each output lazy
# tensor internally tracks a ValueT representing the IR node that would have
# actually produced the value of this tensor for real.
#
# This is configurable because different lazy tensor backends (LTC vs XLA) will
# have different IR representations. (Though, arguably, after unification they
# shouldn't!)
def getValueT() -> BaseCppType:
    """Return the registered IR value type.

    Raises:
        NotImplementedError: if setValueT() has not been called yet.
    """
    global _valueT
    if not _valueT:
        raise NotImplementedError(
            "The value type needs to be set with setValueT() in run_gen_lazy_tensor()"
        )

    return _valueT


def setValueT(val: BaseCppType) -> None:
    """Register the backend-specific IR value type used by getValueT()."""
    global _valueT
    _valueT = val


# this is a bad hack. I need to refactor the data model to represent each arg in the schema as an object,
# making it easier to represent special properties of an arg.
tensorListValueT = BaseCppType("torch::lazy", "Value")


def process_ir_type(
    typ: Type, properties: LazyIrProperties, *, symint: bool
) -> BaseCType | VectorCType | OptionalCType | ListCType:
    """
    This function takes a type from NativeFunctions and converts it for use with
    lazy tensor codegen.

    Type conversion for lazy currently consists of
      (1) changing at::Tensors into lazy::Values
      (2) wrapping everything in a BaseCType
      (3) making cpp-reference types into cpp-value types (e.g. vector instead of IntArrayRef)

    (1) converts at::Tensors to lazy::Values (which wrap lazy::Nodes, with which Lazy IR represents tensors.)
    There is special handling for Optional[Tensor] or List[Tensor], etc- hence 'tensor-like'

    This is incomplete- there are assertions in places that it's expected to need to add
    more types as the codegen is used with more operators.

    Args:
        typ: the NativeFunctions argument type to convert.
        properties: IR-node properties; TreatScalarsAsConstants controls
            whether Scalars stay scalarT or become IR values.
        symint: when True, SymInt (and Tensor-like SymInt uses) map to the
            IR value type instead of plain int64_t.
    """
    if isinstance(typ, BaseType):
        if typ.name == BaseTy.Tensor:
            return BaseCType(getValueT())
        elif typ.name == BaseTy.Scalar:
            if properties.TreatScalarsAsConstants:
                return BaseCType(scalarT)
            # at::scalar has special handling,
            # and is wrapped in an lazy::Value just like at::tensor
            return BaseCType(getValueT())
        elif typ.name == BaseTy.ScalarType:
            return BaseCType(scalarTypeT)
        elif typ.name == BaseTy.int:
            return BaseCType(longT)
        elif typ.name == BaseTy.SymInt:
            # SymInt is a value in the IR only when symint codegen is enabled;
            # otherwise it degrades to a plain int64_t.
            if symint:
                return BaseCType(getValueT())
            else:
                return BaseCType(longT)
        elif typ.name == BaseTy.bool:
            return BaseCType(boolT)
        elif typ.name == BaseTy.float:
            return BaseCType(doubleT)
        elif typ.name == BaseTy.str:
            return BaseCType(stringT)
        elif typ.name == BaseTy.Device:
            return BaseCType(deviceT)
        elif typ.name == BaseTy.Generator:
            return BaseCType(generatorT)
        elif typ.name == BaseTy.Layout:
            return BaseCType(layoutT)
        elif typ.name == BaseTy.MemoryFormat:
            return BaseCType(memoryFormatT)
        else:
            raise AssertionError(f"TODO add support for type {repr(typ)}")
    elif isinstance(typ, OptionalType):
        # Optional wraps the converted element type.
        return OptionalCType(process_ir_type(typ.elem, properties, symint=symint))
    elif isinstance(typ, ListType):
        if str(typ.elem) == "Tensor?":
            # TODO(whc) is this actually correct? or should it use a Vector like above
            return ListCType(OptionalCType(BaseCType(getValueT())))
        elif str(typ.elem) == "Tensor":
            # this is a TensorList which comes in from GetTensorList as a Value
            return BaseCType(tensorListValueT)
        elif typ.elem == BaseType(BaseTy.SymInt):
            # TODO: return a value type. The problem here is analogous to
            # the problem with tensorListValueT: if you have SymInt[] you
            # cannot conveniently save the list of Value directly, as nodes
            # expect to save values as a vector for ALL arguments. So you
            # need a separate IR node that represents all of the size nodes
            # assembled into a list. I'm not an LTC dev so I don't want to
            # figure it out right now. Y'all figure it out...
            return VectorCType(BaseCType(longT))

        else:
            # Generic list: convert the element and store by value in a vector.
            return VectorCType(process_ir_type(typ.elem, properties, symint=symint))
    else:
        raise AssertionError(f"unrecognized type {repr(typ)}")


# TODO: Determining this based off of CType is bad; this should be computed
# from Type directly; then the same logic as process_ir_type can be used
#
# Invariant: passed typ should be an *owning* CType (e.g., we will report
# that ArrayRef<Value> is NOT a value type)
def isValueType(typ: CType, properties: LazyIrProperties | None = None) -> bool:
    """
    Given a type, determine if it is a Value-like type. This is equivalent to
    being Tensor-like, but assumes the type has already been transformed
    (i.e. by process_ir_type).
    """
    if isinstance(typ, BaseCType):
        # I am regretting my naming conventions, but now we are wrapping at::scalar in
        # lazy value, while preserving other 'scalar' types as scalars in the IR
        # NOTE: `properties and ...` may yield None (falsy) when properties is
        # None, which the `not` below treats the same as False.
        treat_scalars_as_constants = properties and properties.TreatScalarsAsConstants
        return (
            typ.type == getValueT()
            or (typ.type == scalarT and not treat_scalars_as_constants)
            or typ.type == SymIntT
        )
    elif typ == VectorCType(BaseCType(SymIntT)):
        # TODO: report True for this
        return False
    elif isinstance(typ, (OptionalCType, ListCType, VectorCType)):
        # Containers are value-like iff their element type is.
        return isValueType(typ.elem, properties)
    return False


def isSymIntType(typ: Type) -> bool:
    """Return True iff typ is exactly the base SymInt type (not wrapped)."""
    return isinstance(typ, BaseType) and typ.name == BaseTy.SymInt


def isWrappedScalarType(typ: Type) -> bool:
    """
    Given a type, determine if it is a c10::scalar which we will wrap in a lazy Value.
    Since we literally change the type from scalarT to valueT, information is lost.
    This function helps build a list of wrapped scalars to save that information
    """
    if isinstance(typ, BaseType):
        # I am regretting my naming conventions, but now we are wrapping at::scalar in
        # lazy value, while preserving other 'scalar' types as scalars in the IR
        return typ.name == BaseTy.Scalar
    elif isinstance(typ, (OptionalType, ListType)):
        # Optional[Scalar] and Scalar[] count as wrapped scalars too.
        return isWrappedScalarType(typ.elem)
    return False


# TODO: dedupe with Type.is_generator_like
def isGeneratorType(typ: Type) -> bool:
    """Return True iff typ is Generator or Optional[Generator]."""
    if isinstance(typ, BaseType):
        return typ.name == BaseTy.Generator
    elif isinstance(typ, (OptionalType)):
        return isGeneratorType(typ.elem)
    return False


# This class caches a few derived properties computed from an Argument
# and LazyIrProperties
class LazyArgument:
    # original schema argument name
    name: str
    # original schema argument type (pre-conversion)
    orig_type: Type
    # converted lazy IR type; access via the lazy_type property
    lazy_type_: CType | None
    is_wrapped_scalar: bool
    is_generator: bool
    # TODO: this is lies, it is false for symint list
    is_symint_or_list: bool

    # Whether or not we are treating this as symint or not
    symint: bool

    # whether the original type was Optional[...]
    is_optional: bool

    # true if this argument is or contains a lazy IR value
    is_lazy_value: bool

    def __init__(
        self, arg: Argument, properties: LazyIrProperties, *, symint: bool
    ) -> None:
        self.name = arg.name
        self.orig_type = arg.type
        self.symint = symint
        self.is_optional = isinstance(arg.type, OptionalType)
        self.is_generator = isGeneratorType(arg.type)
        self.lazy_type_ = process_ir_type(arg.type, properties, symint=symint)
        self.is_wrapped_scalar = isWrappedScalarType(arg.type)
        # Only SymInt and Optional[SymInt] qualify; see TODO below for lists.
        self.is_symint_or_list = symint and (
            isSymIntType(arg.type)
            or (isinstance(arg.type, OptionalType) and isSymIntType(arg.type.elem))
            # TODO: lists of symints are not currently treated as value types
            # or (isinstance(arg.type, ListType) and isSymIntType(arg.type.elem))
        )

        self.is_lazy_value = isValueType(self.lazy_type, properties)

    @property
    def lazy_type(self) -> CType:
        # Guard against access before/without conversion; in this __init__
        # lazy_type_ is always assigned, so this mainly documents the contract.
        assert (
            self.lazy_type_ is not None
        ), f"Attempted to access lazy_type for invalid argument {self.name}"
        return self.lazy_type_


class LazyIrProperties:
    """Collection of properties for an IR node

    The property groups are listed below. Each group is mutually
    exclusive, meaning that only one property from each group can be True
    at any one time. The properties can be accessed as if they were normal
    attributes. The mutual exclusivity is automatically handled.
    """

    # Each inner tuple is one mutually-exclusive group; the per-instance
    # `properties` dict maps each group tuple to the currently-enabled
    # property name in that group (or None).
    Properties: tuple[tuple[str, ...], ...] = (
        (
            "ShapePrecompute",  # Assume shape has been precomputed
            "ShapeCompute",  # Need to compute the shape on construction
            "ShapeCache",  # Utilize the shape cache to defer computation
        ),
        (
            "Lower",  # Codegen full lower function
            "LowerDeclOnly",  # Codegen only lower function declaration
        ),
        (
            "CanBeReused",  # Codegen full reuse function
            "CanBeReusedDeclOnly",  # Codegen only reuse function declaration
        ),
        (
            "CreateFn",  # Codegen full create function
            "CreateFnDeclOnly",  # Codegen only create function declaration
        ),
        (
            "TreatScalarsAsConstants",  # Treat Scalars as constants instead of handling like values
        ),
    )

    def __init__(self, *default_properties: str) -> None:
        properties: dict[tuple[str, ...], str | None] = dict.fromkeys(
            LazyIrProperties.Properties
        )
        # Write through __dict__ directly to avoid triggering our own
        # __setattr__ (which would reject the key "properties").
        self.__dict__["properties"] = properties
        for p in default_properties:
            setattr(self, p, True)

    def __getattr__(self, key: str) -> Any:
        # Only invoked when normal attribute lookup fails, i.e. for the
        # property names themselves. True iff `key` is the enabled member
        # of its group.
        properties = self.__dict__["properties"]
        for values in LazyIrProperties.Properties:
            if key in values:
                return properties[values] == key

        # Unknown key: fall back so the usual AttributeError is raised.
        return self.__getattribute__(key)

    def __setattr__(self, key: str, value: Any) -> Any:
        # Setting a property to a truthy value selects it within its group
        # (implicitly disabling its siblings); falsy clears the group.
        properties = self.__dict__["properties"]
        for values in LazyIrProperties.Properties:
            if key in values:
                properties[values] = key if value else None
                return value

        raise KeyError(f"Invalid property: {key}")


# Inspired by a FunctionSchema object, a LazyIrSchema holds the schema of a Lazy IR node.
# Unlike a FunctionSchema, it has no round-trippable string form (relating to the YAML),
# but carries type information from a native FunctionSchema modified for use with IR nodes,
# and preserving original argument names.
#
# TODO: This is not idiomatic with how other torchgen APIs transform on schema.
class LazyIrSchema:
    # The operator name this schema describes.
    name: OperatorName

    positional_args: tuple[LazyArgument, ...]
    keyword_args: tuple[LazyArgument, ...]

    # TODO: Need to handle collisions with argument names at some point
    returns: tuple[Return, ...]

    # if this schema has a Generator arg, list its orig ctype/name but don't
    # build a LazyArgument since lazy IR doesn't support it
    generator_arg: NamedCType | None = None

    # original function schema
    func: FunctionSchema

    # Whether or not we are code-genning for SymInt or not
    symint: bool

    properties: LazyIrProperties = LazyIrProperties(
        # default properties
        "ShapePrecompute",
        "Lower",
        "CanBeReused",
    )
    opkind: str | None = None

    def __init__(
        self,
        func: FunctionSchema,
        properties: LazyIrProperties | None = None,
        *,
        symint: bool,
    ) -> None:
        if properties:
            self.properties = properties

        self.func = func
        self.symint = symint

        # Collect positional arguments, keeping schema order around `self`.
        positionals: list[LazyArgument] = []
        for field in ("pre_self_positional", "self_arg", "post_self_positional"):
            group = getattr(func.arguments, field)
            if group is None:
                continue
            if field == "self_arg":
                # self_arg is a single wrapped argument, not an iterable.
                positionals.append(
                    LazyArgument(group.argument, self.properties, symint=symint)
                )
            else:
                positionals.extend(
                    LazyArgument(a, self.properties, symint=symint) for a in group
                )
        self.positional_args = tuple(positionals)

        # Collect keyword arguments; a TensorOptionsArguments bundle is
        # flattened into its component arguments first.
        keywords: list[LazyArgument] = []
        for field in (
            "pre_tensor_options_kwarg_only",
            "tensor_options",
            "post_tensor_options_kwarg_only",
            "out",
        ):
            group = getattr(func.arguments, field)
            if group is None:
                continue
            if isinstance(group, TensorOptionsArguments):
                group = group.all()
            for a in group:
                if isGeneratorType(a.type):
                    # Record the generator separately; lazy IR doesn't
                    # support it as a node input.
                    assert (
                        self.generator_arg is None
                    ), "We expect there is only one generator arg"
                    self.generator_arg = NamedCType(
                        a.name, a.type  # type:ignore[arg-type]
                    )
            keywords.extend(
                LazyArgument(a, self.properties, symint=symint) for a in group
            )
        self.keyword_args = tuple(keywords)
        self.name = func.name
        self.returns = func.returns

    @property
    def node_name(self) -> str:
        """
        Return camel-case version of op in node.

        Note: This function also appends any `overload_name` in the operation.
        For example, if the op is `bitwise_and.Tensor`, the returned name
        will be `BitwiseAndTensor`.
        """
        snake = f"{self.name.name}_{self.name.overload_name}".lower()
        return "".join(part.capitalize() for part in snake.split("_"))

    @property
    def aten_name(self) -> str:
        # Base ATen operator name, without overload suffix.
        return str(self.name.name)

    @property
    def base_name(self) -> str:
        return f"{self.name.name.base}"

    def filtered_args(
        self,
        positional: bool = True,
        keyword: bool = True,
        values: bool = True,
        scalars: bool = True,
        generator: bool = True,
    ) -> list[LazyArgument]:
        # This function maintains the sorted order of arguments but provides different filtered views.
        # Some parts of the code care about kwargs vs args (TS lowerings),
        # other parts care about whether they need to wrap the arg in a lazy value or leave it alone.
        # Generators are special cased, as they are needed for fallback/shape-inference but not supported
        # in TS lowerings and therefore also omitted from lazy IR.
        selected: list[LazyArgument] = []
        if positional:
            selected += self.positional_args
        if keyword:
            selected += self.keyword_args

        if values and scalars and generator:
            return selected
        if values and scalars:
            return [a for a in selected if not a.is_generator]
        if values:
            return [a for a in selected if a.is_lazy_value]
        if scalars:
            return [
                a
                for a in selected
                if not a.is_lazy_value and (generator or not a.is_generator)
            ]

        return []

    @property
    def positional_values(self) -> list[LazyArgument]:
        return self.filtered_args(
            positional=True, keyword=False, values=True, scalars=False
        )

    @property
    def positional_scalars(self) -> list[LazyArgument]:
        return self.filtered_args(
            positional=True, keyword=False, values=False, scalars=True
        )

    @property
    def keyword_values(self) -> list[LazyArgument]:
        return self.filtered_args(
            positional=False, keyword=True, values=True, scalars=False
        )

    @property
    def keyword_scalars(self) -> list[LazyArgument]:
        return self.filtered_args(
            positional=False, keyword=True, values=False, scalars=True
        )