xref: /aosp_15_r20/external/pytorch/torchgen/api/functionalization.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api import dispatcher
4*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api.types import (
5*da0073e9SAndroid Build Coastguard Worker    BaseCppType,
6*da0073e9SAndroid Build Coastguard Worker    BaseCType,
7*da0073e9SAndroid Build Coastguard Worker    Binding,
8*da0073e9SAndroid Build Coastguard Worker    boolT,
9*da0073e9SAndroid Build Coastguard Worker    ConstRefCType,
10*da0073e9SAndroid Build Coastguard Worker    CType,
11*da0073e9SAndroid Build Coastguard Worker    longT,
12*da0073e9SAndroid Build Coastguard Worker    NamedCType,
13*da0073e9SAndroid Build Coastguard Worker    tensorT,
14*da0073e9SAndroid Build Coastguard Worker)
15*da0073e9SAndroid Build Coastguard Workerfrom torchgen.model import (
16*da0073e9SAndroid Build Coastguard Worker    Argument,
17*da0073e9SAndroid Build Coastguard Worker    BaseTy,
18*da0073e9SAndroid Build Coastguard Worker    BaseType,
19*da0073e9SAndroid Build Coastguard Worker    FunctionSchema,
20*da0073e9SAndroid Build Coastguard Worker    NativeFunction,
21*da0073e9SAndroid Build Coastguard Worker    NativeFunctionsViewGroup,
22*da0073e9SAndroid Build Coastguard Worker)
23*da0073e9SAndroid Build Coastguard Worker
24*da0073e9SAndroid Build Coastguard Worker
# This file describes the translation of JIT schemas into the APIs used
# when creating the view lambdas that the functionalization pass relies on.
# There are two types of lambdas: forward lambdas and reverse lambdas.
# These APIs mostly follow the dispatcher API, with a few quirks:
# - The lambda capture has to convert reference types to value types
#   (otherwise the capture could hold a dangling reference).
# - While the forward lambda just directly calls into the at::_ops API
#   (following the dispatcher convention), the logic here for the reverse lambda
#   is responsible for generating both the call-site and the declarations
#   (the view inverses are written by hand and live under
#   at::functionalization::FunctionalInverses).
34*da0073e9SAndroid Build Coastguard Worker
35*da0073e9SAndroid Build Coastguard Worker# The lambdas generated for each view op in the functionalization pass are of the form
36*da0073e9SAndroid Build Coastguard Worker# [capture_arguments](outer_arguments) -> returns_type {
37*da0073e9SAndroid Build Coastguard Worker#     return name(inner_arguments);
38*da0073e9SAndroid Build Coastguard Worker# }
39*da0073e9SAndroid Build Coastguard Worker
40*da0073e9SAndroid Build Coastguard Worker# Define some specific lambda input arguments.
# Lambda input: the base tensor, received by const reference.
base_binding = Binding(
    default=None,
    argument=Argument(
        annotation=None,
        default=None,
        type=BaseType(BaseTy.Tensor),
        name="base",
    ),
    nctype=NamedCType(type=ConstRefCType(BaseCType(tensorT)), name="base"),
    name="base",
)
# Reverse-lambda input: the ``mutated_view`` tensor, received by const reference.
mutated_view_binding = Binding(
    name="mutated_view",
    nctype=NamedCType(name="mutated_view", type=ConstRefCType(BaseCType(tensorT))),
    # The schema-level Argument mirrors this binding's own name instead of
    # reusing "base", so the metadata describes the parameter it belongs to.
    argument=Argument(
        name="mutated_view",
        type=BaseType(BaseTy.Tensor),
        default=None,
        annotation=None,
    ),
    default=None,
)
# Lambda input: index selecting which output of a multi-output view op
# (e.g. `split`) a given lambda corresponds to.
mutated_view_idx_binding = Binding(
    name="mutated_view_idx",
    nctype=NamedCType(name="mutated_view_idx", type=BaseCType(longT)),
    # Keep the schema-level Argument consistent with the C++ type above: a
    # plain `int` (int64) index named after this binding, not a "base" Tensor.
    argument=Argument(
        name="mutated_view_idx",
        type=BaseType(BaseTy.int),
        default=None,
        annotation=None,
    ),
    default=None,
)
# Boolean `reapply_views` flag; captured by forward lambdas
# (see capture_arguments).
reapply_views_binding = Binding(
    default=None,
    argument=Argument(
        annotation=None,
        default=None,
        type=BaseType(BaseTy.bool),
        name="reapply_views",
    ),
    nctype=NamedCType(type=BaseCType(boolT), name="reapply_views"),
    name="reapply_views",
)
73*da0073e9SAndroid Build Coastguard Worker
# C++ type at::functionalization::InverseReturnMode, used for the mode value
# that reverse lambdas capture and forward to the view-inverse call.
InverseReturnModeT = BaseCppType("at::functionalization", "InverseReturnMode")
inverse_return_mode_binding = Binding(
    default=None,
    argument=Argument(
        annotation=None,
        default=None,
        # NB: not actually a bool but it doesn't matter because this isn't used
        type=BaseType(BaseTy.bool),
        name="inverse_return_mode",
    ),
    nctype=NamedCType(
        type=BaseCType(InverseReturnModeT),
        name="inverse_return_mode",
    ),
    name="inverse_return_mode",
)
87*da0073e9SAndroid Build Coastguard Worker
88*da0073e9SAndroid Build Coastguard Worker
89*da0073e9SAndroid Build Coastguard Worker# The lambda capture itself doesn't have a name.
90*da0073e9SAndroid Build Coastguard Worker# The name returned here corresponds to the name of the inner function called by the lambda.
def name(
    g: NativeFunctionsViewGroup,
    *,
    is_reverse: bool,
    include_namespace: bool,
    reapply_views: bool | None = None,
) -> str:
    """Return the name of the function called inside a functionalization lambda.

    The lambda itself is anonymous; this is the callee in its body.

    - Reverse lambdas call the view-inverse function named by
      ``reverse_name``. ``reapply_views`` is ignored (and may be ``None``),
      because the reverse function receives its mode as a runtime argument.
    - Forward lambdas call ``at::_ops::<op>::call`` directly. The namespace is
      always required. The view op is used when ``reapply_views`` is truthy,
      otherwise its ``view_copy`` counterpart.

    Raises:
        AssertionError: if ``reapply_views`` is ``None`` for a forward lambda,
            if ``include_namespace`` is false for a forward lambda, or if the
            group has no ``view_copy`` op on the forward path.
    """
    # Only the forward path needs reapply_views decided at codegen time.
    assert reapply_views is not None or is_reverse
    if is_reverse:
        return reverse_name(g.view, include_namespace)
    assert include_namespace
    assert g.view_copy is not None
    target = g.view if reapply_views else g.view_copy
    return f"at::_ops::{target.func.name.unambiguous_name()}::call"
113*da0073e9SAndroid Build Coastguard Worker
114*da0073e9SAndroid Build Coastguard Worker
def reverse_name(f: NativeFunction, include_namespace: bool) -> str:
    """Return the name of the view-inverse function for view op ``f``.

    Unlike the forward path, the copy / non-copy choice is not encoded in this
    name. The inverse receives a runtime mode argument instead
    (``inverse_return_mode`` in ``inner_arguments``), presumably so that one
    hand-written inverse per view op suffices.

    Args:
        f: The view operator whose inverse is being named.
        include_namespace: True for generated call sites, which need the fully
            qualified name. False for the generated declarations, which use the
            bare name.
    """
    inverse_fn = f"{f.func.name.unambiguous_name()}_inverse"
    if not include_namespace:
        return inverse_fn
    return "at::functionalization::FunctionalInverses::" + inverse_fn
125*da0073e9SAndroid Build Coastguard Worker
126*da0073e9SAndroid Build Coastguard Worker
def capture_arguments(func: FunctionSchema, *, is_reverse: bool) -> list[Binding]:
    """Return the bindings captured by a functionalization lambda for ``func``.

    The capture list is built in two parts:

    - First, the runtime mode binding: ``inverse_return_mode`` for reverse
      lambdas, ``reapply_views`` for forward ones.
    - Then every schema argument except the leading ``self`` tensor.

    Captured values must not be C++ reference types, or the capture could
    dangle. Non-owning references such as ``IntArrayRef`` are therefore
    converted to owning value types (e.g. ``vector<int64_t>``).

    Raises:
        AssertionError: if the first schema argument is not a plain Tensor.
    """
    all_args = func.arguments.flat_all
    assert all_args[0].type == BaseType(BaseTy.Tensor)
    mode = inverse_return_mode_binding if is_reverse else reapply_views_binding
    return [mode] + [
        dispatcher.argument(arg, remove_non_owning_ref_types=True)
        for arg in all_args[1:]
    ]
143*da0073e9SAndroid Build Coastguard Worker
144*da0073e9SAndroid Build Coastguard Worker
def returns_type(func: FunctionSchema) -> CType:
    """Return the C++ return type of a functionalization lambda for ``func``.

    Every view op must have at least one return, and every return must be
    tensor-like. The lambda nonetheless always yields one tensor
    (``BaseCType(tensorT)``). Multi-output ops get one lambda per output, so
    each output tensor can be tracked individually.

    Raises:
        AssertionError: if ``func`` has no returns or a non-tensor-like return.
    """
    assert func.returns
    assert all(ret.type.is_tensor_like() for ret in func.returns)
    return BaseCType(tensorT)
153*da0073e9SAndroid Build Coastguard Worker
154*da0073e9SAndroid Build Coastguard Worker
def outer_arguments(*, is_reverse: bool) -> list[Binding]:
    """Return the parameter list of a functionalization lambda.

    Forward lambdas take ``(base, mutated_view_idx)``. Reverse lambdas take
    ``(base, mutated_view, mutated_view_idx)``. A new list is returned on each
    call.
    """
    params = [base_binding]
    if is_reverse:
        params.append(mutated_view_binding)
    params.append(mutated_view_idx_binding)
    return params
160*da0073e9SAndroid Build Coastguard Worker
161*da0073e9SAndroid Build Coastguard Worker
def inner_call_index(func: FunctionSchema) -> Binding | None:
    """Return the output-index binding for multi-output view ops, else ``None``.

    View ops that return several tensors (like ``split``) get one lambda per
    output. When such an op is replayed, its result must be indexed to pick
    the right output. This applies when the op has more than one return, or a
    single list-like return.
    """
    outputs = func.returns
    has_many_returns = len(outputs) > 1
    returns_single_list = len(outputs) == 1 and outputs[0].type.is_list_like()
    return mutated_view_idx_binding if has_many_returns or returns_single_list else None
170*da0073e9SAndroid Build Coastguard Worker
171*da0073e9SAndroid Build Coastguard Worker
def inner_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]:
    """Return the argument bindings for the call made inside a lambda's body.

    Both callees follow the dispatcher API: the forward lambda calls the
    ``at::_ops`` API and the reverse lambda calls the view-inverse function.
    In both cases the schema's leading ``self`` tensor is replaced by the
    lambda's ``base`` argument.

    The layouts are:

    - Forward: ``(base, *rest)``.
    - Reverse:
      ``(base, mutated_view, inverse_return_mode, [mutated_view_idx], *rest)``.
      By calling convention, ``mutated_view_idx`` is present only for view ops
      with multiple outputs (see ``inner_call_index``).

    Raises:
        AssertionError: if the first schema argument is not a plain Tensor.
    """
    all_args = func.arguments.flat_all
    assert all_args[0].type == BaseType(BaseTy.Tensor)
    # Every non-self schema argument is passed through per the dispatcher API.
    rest = [dispatcher.argument(arg) for arg in all_args[1:]]
    if not is_reverse:
        # The forward lambda swaps the original tensor argument for the lambda arg "base".
        return [base_binding, *rest]
    prefix = [base_binding, mutated_view_binding, inverse_return_mode_binding]
    index_binding = inner_call_index(func)
    if index_binding is not None:
        prefix.append(index_binding)
    return prefix + rest
200