# torchgen/api/lazy.py
from __future__ import annotations

from typing import Any

from torchgen.api.types import (
    BaseCppType,
    BaseCType,
    boolT,
    CType,
    deviceT,
    doubleT,
    generatorT,
    layoutT,
    ListCType,
    longT,
    memoryFormatT,
    NamedCType,
    OptionalCType,
    scalarT,
    scalarTypeT,
    stringT,
    SymIntT,
    VectorCType,
)
from torchgen.model import (
    Argument,
    BaseTy,
    BaseType,
    FunctionSchema,
    ListType,
    OperatorName,
    OptionalType,
    Return,
    TensorOptionsArguments,
    Type,
)


_valueT: BaseCppType | None = None


# A ValueT is an IR type which represents the computation of a Tensor.  In other
# words, a PyTorch user will do operations on lazy tensors, and each output lazy
# tensor internally tracks a ValueT representing the IR node that would have
# actually produced the value of this tensor for real.
#
# This is configurable because different lazy tensor backends (LTC vs XLA) will
# have different IR representations.  (Though, arguably, after unification they
# shouldn't!)
def getValueT() -> BaseCppType:
    global _valueT
    if not _valueT:
        raise NotImplementedError(
            "The value type needs to be set with setValueT() in run_gen_lazy_tensor()"
        )

    return _valueT


def setValueT(val: BaseCppType) -> None:
    global _valueT
    _valueT = val
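

# Usage sketch (illustrative, not part of this module): a lazy tensor backend
# registers its IR value type once before codegen runs; run_gen_lazy_tensor()
# is the entry point expected to do this.
#
#   setValueT(BaseCppType("torch::lazy", "Value"))  # e.g. the LTC value type
#   getValueT()  # -> BaseCppType for torch::lazy::Value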


# This is a bad hack. The data model should be refactored to represent each arg
# in the schema as an object, making it easier to represent special properties
# of an arg.
tensorListValueT = BaseCppType("torch::lazy", "Value")


def process_ir_type(
    typ: Type, properties: LazyIrProperties, *, symint: bool
) -> BaseCType | VectorCType | OptionalCType | ListCType:
    """
    This function takes a type from NativeFunctions and converts it for use with
    lazy tensor codegen.

    Type conversion for lazy currently consists of
     (1) changing at::Tensors into lazy::Values
     (2) wrapping everything in a BaseCType
     (3) making cpp-reference types into cpp-value types (e.g. vector instead of IntArrayRef)

    (1) converts at::Tensors to lazy::Values (which wrap lazy::Nodes, the representation
    Lazy IR uses for tensors). There is special handling for Optional[Tensor],
    List[Tensor], etc. - hence 'tensor-like'.

    This is incomplete; assertions are left in places where more types are expected
    to be added as the codegen is used with more operators.
    """
    if isinstance(typ, BaseType):
        if typ.name == BaseTy.Tensor:
            return BaseCType(getValueT())
        elif typ.name == BaseTy.Scalar:
            if properties.TreatScalarsAsConstants:
                return BaseCType(scalarT)
            # at::Scalar has special handling and is wrapped in a lazy::Value,
            # just like at::Tensor
            return BaseCType(getValueT())
        elif typ.name == BaseTy.ScalarType:
            return BaseCType(scalarTypeT)
        elif typ.name == BaseTy.int:
            return BaseCType(longT)
        elif typ.name == BaseTy.SymInt:
            if symint:
                return BaseCType(getValueT())
            else:
                return BaseCType(longT)
        elif typ.name == BaseTy.bool:
            return BaseCType(boolT)
        elif typ.name == BaseTy.float:
            return BaseCType(doubleT)
        elif typ.name == BaseTy.str:
            return BaseCType(stringT)
        elif typ.name == BaseTy.Device:
            return BaseCType(deviceT)
        elif typ.name == BaseTy.Generator:
            return BaseCType(generatorT)
        elif typ.name == BaseTy.Layout:
            return BaseCType(layoutT)
        elif typ.name == BaseTy.MemoryFormat:
            return BaseCType(memoryFormatT)
        else:
            raise AssertionError(f"TODO add support for type {repr(typ)}")
    elif isinstance(typ, OptionalType):
        return OptionalCType(process_ir_type(typ.elem, properties, symint=symint))
    elif isinstance(typ, ListType):
        if str(typ.elem) == "Tensor?":
            # TODO(whc) is this actually correct? or should it use a Vector like above
            return ListCType(OptionalCType(BaseCType(getValueT())))
        elif str(typ.elem) == "Tensor":
            # this is a TensorList which comes in from GetTensorList as a Value
            return BaseCType(tensorListValueT)
        elif typ.elem == BaseType(BaseTy.SymInt):
            # TODO: return a value type.  The problem here is analogous to
            # the problem with tensorListValueT: if you have SymInt[] you
            # cannot conveniently save the list of Value directly, as nodes
            # expect to save values as a vector for ALL arguments.  So you
            # need a separate IR node that represents all of the size nodes
            # assembled into a list.  I'm not an LTC dev so I don't want to
            # figure it out right now.  Y'all figure it out...
            return VectorCType(BaseCType(longT))
        else:
            return VectorCType(process_ir_type(typ.elem, properties, symint=symint))
    else:
        raise AssertionError(f"unrecognized type {repr(typ)}")
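
# A few illustrative mappings (assuming setValueT() has been called with a
# backend Value type, and properties without TreatScalarsAsConstants):
#
#   Tensor   -> BaseCType(<backend Value>)
#   Tensor?  -> OptionalCType(BaseCType(<backend Value>))
#   Tensor[] -> BaseCType(tensorListValueT)
#   Scalar   -> BaseCType(<backend Value>)   (a "wrapped scalar")
#   int      -> BaseCType(longT)
#   int[]    -> VectorCType(BaseCType(longT))
#   SymInt   -> BaseCType(<backend Value>) if symint else BaseCType(longT)
#   SymInt[] -> VectorCType(BaseCType(longT))  (see the TODO above)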


# TODO: Determining this based off of CType is bad; this should be computed
# from Type directly; then the same logic as process_ir_type can be used
#
# Invariant: passed typ should be an *owning* CType (e.g., we will report
# that ArrayRef<Value> is NOT a value type)
def isValueType(typ: CType, properties: LazyIrProperties | None = None) -> bool:
    """
    Given a type, determine if it is a Value-like type.  This is equivalent to
    being Tensor-like, but assumes the type has already been transformed.
    """
    if isinstance(typ, BaseCType):
        # I am regretting my naming conventions, but now we are wrapping at::Scalar
        # in a lazy value, while preserving other 'scalar' types as scalars in the IR
        treat_scalars_as_constants = properties and properties.TreatScalarsAsConstants
        return (
            typ.type == getValueT()
            or (typ.type == scalarT and not treat_scalars_as_constants)
            or typ.type == SymIntT
        )
    elif typ == VectorCType(BaseCType(SymIntT)):
        # TODO: report True for this
        return False
    elif isinstance(typ, (OptionalCType, ListCType, VectorCType)):
        return isValueType(typ.elem, properties)
    return False
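
# For instance (illustrative, after process_ir_type has run):
#   isValueType(BaseCType(getValueT()))                -> True
#   isValueType(OptionalCType(BaseCType(getValueT()))) -> True
#   isValueType(BaseCType(longT))                      -> False
#   isValueType(VectorCType(BaseCType(SymIntT)))       -> False  (see TODO above)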


def isSymIntType(typ: Type) -> bool:
    return isinstance(typ, BaseType) and typ.name == BaseTy.SymInt


def isWrappedScalarType(typ: Type) -> bool:
    """
    Given a type, determine if it is a c10::Scalar which we will wrap in a lazy Value.
    Since we literally change the type from scalarT to valueT, information is lost.
    This function helps build a list of wrapped scalars to save that information.
    """
    if isinstance(typ, BaseType):
        # I am regretting my naming conventions, but now we are wrapping at::Scalar
        # in a lazy value, while preserving other 'scalar' types as scalars in the IR
        return typ.name == BaseTy.Scalar
    elif isinstance(typ, (OptionalType, ListType)):
        return isWrappedScalarType(typ.elem)
    return False
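
# For instance (illustrative; Type.parse is torchgen's schema-type parser):
#   isWrappedScalarType(Type.parse("Scalar"))  -> True
#   isWrappedScalarType(Type.parse("Scalar?")) -> True
#   isWrappedScalarType(Type.parse("int"))     -> False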


# TODO: dedupe with Type.is_generator_like
def isGeneratorType(typ: Type) -> bool:
    if isinstance(typ, BaseType):
        return typ.name == BaseTy.Generator
    elif isinstance(typ, OptionalType):
        return isGeneratorType(typ.elem)
    return False


# This class caches a few derived properties computed from an Argument
# and LazyIrProperties
class LazyArgument:
    name: str
    orig_type: Type
    lazy_type_: CType | None
    is_optional: bool
    is_wrapped_scalar: bool
    is_generator: bool
    # TODO: this is a lie; it is False for symint lists
    is_symint_or_list: bool

    # Whether we are treating this argument as symint or not
    symint: bool

    # True if this argument is or contains a lazy IR value
    is_lazy_value: bool

    def __init__(
        self, arg: Argument, properties: LazyIrProperties, *, symint: bool
    ) -> None:
        self.name = arg.name
        self.orig_type = arg.type
        self.symint = symint
        self.is_optional = isinstance(arg.type, OptionalType)
        self.is_generator = isGeneratorType(arg.type)
        self.lazy_type_ = process_ir_type(arg.type, properties, symint=symint)
        self.is_wrapped_scalar = isWrappedScalarType(arg.type)
        self.is_symint_or_list = symint and (
            isSymIntType(arg.type)
            or (isinstance(arg.type, OptionalType) and isSymIntType(arg.type.elem))
            # TODO: lists of symints are not currently treated as value types
            # or (isinstance(arg.type, ListType) and isSymIntType(arg.type.elem))
        )

        self.is_lazy_value = isValueType(self.lazy_type, properties)

    @property
    def lazy_type(self) -> CType:
        assert (
            self.lazy_type_ is not None
        ), f"Attempted to access lazy_type for invalid argument {self.name}"
        return self.lazy_type_
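
# Usage sketch (illustrative; assumes a backend has already called setValueT(),
# and Argument.parse is torchgen's schema-argument parser):
#   arg = Argument.parse("Tensor self")
#   la = LazyArgument(arg, LazyIrProperties("ShapePrecompute", "Lower"), symint=False)
#   la.lazy_type      # BaseCType(<backend Value>)
#   la.is_lazy_value  # True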


class LazyIrProperties:
    """Collection of properties for an IR node

    The property groups are listed below. Each group is mutually
    exclusive, meaning that only one property from each group can be True
    at any one time. The properties can be accessed as if they were normal
    attributes. The mutual exclusivity is automatically handled.
    """

    Properties: tuple[tuple[str, ...], ...] = (
        (
            "ShapePrecompute",  # Assume shape has been precomputed
            "ShapeCompute",  # Need to compute the shape on construction
            "ShapeCache",  # Utilize the shape cache to defer computation
        ),
        (
            "Lower",  # Codegen full lower function
            "LowerDeclOnly",  # Codegen only lower function declaration
        ),
        (
            "CanBeReused",  # Codegen full reuse function
            "CanBeReusedDeclOnly",  # Codegen only reuse function declaration
        ),
        (
            "CreateFn",  # Codegen full create function
            "CreateFnDeclOnly",  # Codegen only create function declaration
        ),
        (
            "TreatScalarsAsConstants",  # Treat Scalars as constants instead of handling like values
        ),
    )

    def __init__(self, *default_properties: str) -> None:
        properties: dict[tuple[str, ...], str | None] = dict.fromkeys(
            LazyIrProperties.Properties
        )
        self.__dict__["properties"] = properties
        for p in default_properties:
            setattr(self, p, True)

    def __getattr__(self, key: str) -> Any:
        properties = self.__dict__["properties"]
        for values in LazyIrProperties.Properties:
            if key in values:
                return properties[values] == key

        return self.__getattribute__(key)

    def __setattr__(self, key: str, value: Any) -> Any:
        properties = self.__dict__["properties"]
        for values in LazyIrProperties.Properties:
            if key in values:
                properties[values] = key if value else None
                return value

        raise KeyError(f"Invalid property: {key}")
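
# For example (illustrative), properties within one group displace each other:
#   props = LazyIrProperties("ShapePrecompute", "Lower")
#   props.ShapePrecompute     # True
#   props.ShapeCache = True   # same group, so ShapePrecompute flips to False
#   props.ShapePrecompute     # False
#   props.Lower               # True (different group, unaffected)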


# Inspired by a FunctionSchema object, a LazyIrSchema holds the schema of a Lazy IR node.
# Unlike a FunctionSchema, it has no round-trippable string form (relating to the YAML),
# but it carries the type information from a native FunctionSchema, modified for use
# with IR nodes, while preserving the original argument names.
#
# TODO: This is not idiomatic with how other torchgen APIs transform on schema.
class LazyIrSchema:
    # The name of the operator this function schema describes.
    name: OperatorName

    positional_args: tuple[LazyArgument, ...]
    keyword_args: tuple[LazyArgument, ...]

    # TODO: Need to handle collisions with argument names at some point
    returns: tuple[Return, ...]

    # If this schema has a Generator arg, list its orig ctype/name, but don't
    # build a LazyArgument since lazy IR doesn't support it
    generator_arg: NamedCType | None = None

    # The original function schema
    func: FunctionSchema

    # Whether we are code-genning for SymInt or not
    symint: bool

    properties: LazyIrProperties = LazyIrProperties(
        # default properties
        "ShapePrecompute",
        "Lower",
        "CanBeReused",
    )
    opkind: str | None = None

    def __init__(
        self,
        func: FunctionSchema,
        properties: LazyIrProperties | None = None,
        *,
        symint: bool,
    ) -> None:
        if properties:
            self.properties = properties

        self.func = func
        self.symint = symint
        positional_args: list[LazyArgument] = []
        for arg_field in ["pre_self_positional", "self_arg", "post_self_positional"]:
            if arg_field == "self_arg" and func.arguments.self_arg is not None:
                arg = func.arguments.self_arg.argument
                positional_args.append(
                    LazyArgument(arg, self.properties, symint=symint)
                )
            elif getattr(func.arguments, arg_field) is not None:
                positional_args.extend(
                    LazyArgument(arg, self.properties, symint=symint)
                    for arg in getattr(func.arguments, arg_field)
                )
        self.positional_args = tuple(positional_args)

        keyword_args: list[LazyArgument] = []
        for arg_field in [
            "pre_tensor_options_kwarg_only",
            "tensor_options",
            "post_tensor_options_kwarg_only",
            "out",
        ]:
            curr_args = getattr(func.arguments, arg_field)
            if curr_args is not None:
                if isinstance(curr_args, TensorOptionsArguments):
                    curr_args = curr_args.all()
                for arg in curr_args:
                    if isGeneratorType(arg.type):
                        assert (
                            self.generator_arg is None
                        ), "We expect there is only one generator arg"
                        self.generator_arg = NamedCType(
                            arg.name, arg.type  # type:ignore[arg-type]
                        )
                keyword_args.extend(
                    LazyArgument(arg, self.properties, symint=symint)
                    for arg in curr_args
                )
        self.keyword_args = tuple(keyword_args)
        self.name = func.name
        self.returns = func.returns

    @property
    def node_name(self) -> str:
        """
        Return the camel-case version of the op in the node.

        Note: This function also appends any `overload_name` in the operation.
        For example, if the op is `bitwise_and.Tensor`, the returned name
        will be `BitwiseAndTensor`.
        """
        op_name = f"{self.name.name}_{self.name.overload_name}".lower()
        return "".join(word.capitalize() for word in op_name.split("_"))

    @property
    def aten_name(self) -> str:
        return str(self.name.name)

    @property
    def base_name(self) -> str:
        return self.name.name.base

    def filtered_args(
        self,
        positional: bool = True,
        keyword: bool = True,
        values: bool = True,
        scalars: bool = True,
        generator: bool = True,
    ) -> list[LazyArgument]:
        # This function maintains the sorted order of arguments but provides different filtered views.
        # Some parts of the code care about kwargs vs args (TS lowerings),
        # other parts care about whether they need to wrap the arg in a lazy value or leave it alone.
        # Generators are special-cased, as they are needed for fallback/shape-inference but not supported
        # in TS lowerings, and are therefore also omitted from lazy IR.
        args: list[LazyArgument] = []
        if positional:
            args.extend(self.positional_args)
        if keyword:
            args.extend(self.keyword_args)

        if values and scalars and generator:
            return args
        elif values and scalars:
            return [a for a in args if not a.is_generator]
        elif values:
            return [a for a in args if a.is_lazy_value]
        elif scalars:
            return [
                a
                for a in args
                if not a.is_lazy_value and (generator or not a.is_generator)
            ]

        return []
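
    # For instance (illustrative), for an op like
    # topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True):
    #   filtered_args(values=True, scalars=False) -> [self]
    #   filtered_args(values=False, scalars=True) -> [k, dim, largest, sorted]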

    @property
    def positional_values(self) -> list[LazyArgument]:
        return self.filtered_args(
            positional=True, keyword=False, values=True, scalars=False
        )

    @property
    def positional_scalars(self) -> list[LazyArgument]:
        return self.filtered_args(
            positional=True, keyword=False, values=False, scalars=True
        )

    @property
    def keyword_values(self) -> list[LazyArgument]:
        return self.filtered_args(
            positional=False, keyword=True, values=True, scalars=False
        )

    @property
    def keyword_scalars(self) -> list[LazyArgument]:
        return self.filtered_args(
            positional=False, keyword=True, values=False, scalars=True
        )