1*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api import dispatcher 4*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api.types import ( 5*da0073e9SAndroid Build Coastguard Worker BaseCppType, 6*da0073e9SAndroid Build Coastguard Worker BaseCType, 7*da0073e9SAndroid Build Coastguard Worker Binding, 8*da0073e9SAndroid Build Coastguard Worker boolT, 9*da0073e9SAndroid Build Coastguard Worker ConstRefCType, 10*da0073e9SAndroid Build Coastguard Worker CType, 11*da0073e9SAndroid Build Coastguard Worker longT, 12*da0073e9SAndroid Build Coastguard Worker NamedCType, 13*da0073e9SAndroid Build Coastguard Worker tensorT, 14*da0073e9SAndroid Build Coastguard Worker) 15*da0073e9SAndroid Build Coastguard Workerfrom torchgen.model import ( 16*da0073e9SAndroid Build Coastguard Worker Argument, 17*da0073e9SAndroid Build Coastguard Worker BaseTy, 18*da0073e9SAndroid Build Coastguard Worker BaseType, 19*da0073e9SAndroid Build Coastguard Worker FunctionSchema, 20*da0073e9SAndroid Build Coastguard Worker NativeFunction, 21*da0073e9SAndroid Build Coastguard Worker NativeFunctionsViewGroup, 22*da0073e9SAndroid Build Coastguard Worker) 23*da0073e9SAndroid Build Coastguard Worker 24*da0073e9SAndroid Build Coastguard Worker 25*da0073e9SAndroid Build Coastguard Worker# This file describes the translation of JIT schema to API's used 26*da0073e9SAndroid Build Coastguard Worker# when creating view lambdas that are used by the functionalization pass. 27*da0073e9SAndroid Build Coastguard Worker# There are two types of lambdas: forward lambdas and reverse lambdas. 
# These APIs mostly follow the dispatcher API, with a few quirks:
# - The lambda capture has to convert reference types to value types.
# - While the forward lambda just directly calls into the at::_ops API
#   (following the dispatcher convention), the logic here for the reverse lambda
#   is responsible for generating both the call-site and the declarations of the
#   view-inverse functions (which are implemented manually as
#   at::functionalization::FunctionalInverses::<op>_inverse).

# The lambdas generated for each view op in the functionalization pass are of the form
# [capture_arguments](outer_arguments) -> returns_type {
#     return name(inner_arguments);
# }

# Define some specific lambda input arguments.
# The base tensor; every view lambda receives it as a `const Tensor&` parameter.
base_binding = Binding(
    name="base",
    nctype=NamedCType(name="base", type=ConstRefCType(BaseCType(tensorT))),
    argument=Argument(
        name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None
    ),
    default=None,
)
# The updated view tensor; only reverse (view-inverse) lambdas receive it.
mutated_view_binding = Binding(
    name="mutated_view",
    nctype=NamedCType(name="mutated_view", type=ConstRefCType(BaseCType(tensorT))),
    argument=Argument(
        name="mutated_view",
        type=BaseType(BaseTy.Tensor),
        default=None,
        annotation=None,
    ),
    default=None,
)
# Selects which output of a multi-output view op (e.g. `split`) a lambda refers to.
# Its Argument metadata describes an int index, matching the `longT` C++ type.
mutated_view_idx_binding = Binding(
    name="mutated_view_idx",
    nctype=NamedCType(name="mutated_view_idx", type=BaseCType(longT)),
    argument=Argument(
        name="mutated_view_idx",
        type=BaseType(BaseTy.int),
        default=None,
        annotation=None,
    ),
    default=None,
)
# Runtime `reapply_views` flag, captured by forward lambdas (see capture_arguments).
reapply_views_binding = Binding(
    name="reapply_views",
    nctype=NamedCType(name="reapply_views", type=BaseCType(boolT)),
    argument=Argument(
        name="reapply_views", type=BaseType(BaseTy.bool), default=None, annotation=None
    ),
    default=None,
)

InverseReturnModeT = BaseCppType("at::functionalization", "InverseReturnMode")
# Runtime InverseReturnMode value, captured by reverse lambdas and forwarded to the
# view-inverse function (see capture_arguments / inner_arguments).
inverse_return_mode_binding = Binding(
    name="inverse_return_mode",
    nctype=NamedCType(name="inverse_return_mode", type=BaseCType(InverseReturnModeT)),
    argument=Argument(
        name="inverse_return_mode",
        # NB: not actually a bool but it doesn't matter because this isn't used
        type=BaseType(BaseTy.bool),
        default=None,
        annotation=None,
    ),
    default=None,
)


# The lambda capture itself doesn't have a name.
# The name returned here corresponds to the name of the inner function called by the lambda.
def name(
    g: NativeFunctionsViewGroup,
    *,
    is_reverse: bool,
    include_namespace: bool,
    reapply_views: bool | None = None,
) -> str:
    """Return the name of the function called inside a view lambda.

    Args:
        g: the view op group the lambda is generated for.
        is_reverse: True for the reverse (view-inverse) lambda, False for the
            forward lambda.
        include_namespace: whether to return a fully-qualified name. Must be True
            for the forward lambda, which always calls through ``at::_ops``.
        reapply_views: forward lambda only. When truthy, the original view op is
            called; otherwise its ``view_copy`` variant is. May be ``None`` only
            when ``is_reverse`` is True.
    """
    if reapply_views is None:
        # reapply_views only matters for the forward lambda: the reverse function
        # instead receives the runtime inverse_return_mode argument.
        assert is_reverse
    if is_reverse:
        return reverse_name(g.view, include_namespace)
    # The forward lambda calls straight into the at::_ops API, so the namespace
    # is always required.
    assert include_namespace
    assert g.view_copy is not None
    target = g.view if reapply_views else g.view_copy
    return f"at::_ops::{target.func.name.unambiguous_name()}::call"


def reverse_name(f: NativeFunction, include_namespace: bool) -> str:
    """Return the name of the view-inverse function for view op ``f``.

    The qualified form is used at call-sites; the bare form is used when
    generating the declarations themselves.
    """
    # One inverse function serves both the copy and non-copy variants; the
    # reverse lambda passes a runtime inverse_return_mode argument to it (see
    # inner_arguments). The alternative would require writing out twice as many
    # view inverse functions.
    inverse_name = f"{f.func.name.unambiguous_name()}_inverse"
    if not include_namespace:
        return inverse_name
    return f"at::functionalization::FunctionalInverses::{inverse_name}"


def _non_self_args(func: FunctionSchema) -> list[Argument]:
    """Return every schema argument of a view op except the leading ``self`` tensor."""
    args = func.arguments.flat_all
    # View ops take their input tensor first; the lambdas supply it themselves
    # (as ``base``), so it is stripped here.
    assert len(args) > 0 and args[0].type == BaseType(BaseTy.Tensor)
    return args[1:]


def capture_arguments(func: FunctionSchema, *, is_reverse: bool) -> list[Binding]:
    """Return the bindings captured by a view lambda.

    The first capture is ``inverse_return_mode`` (reverse lambda) or
    ``reapply_views`` (forward lambda), followed by every schema argument
    except ``self``.
    """
    # Captures must not hold C++ reference types, or the lambda would keep a
    # dangling reference; non-owning reference types (IntArrayRef) are converted
    # to value types (vector<int64_t>).
    mode_binding = inverse_return_mode_binding if is_reverse else reapply_views_binding
    return [mode_binding] + [
        dispatcher.argument(a, remove_non_owning_ref_types=True)
        for a in _non_self_args(func)
    ]


def returns_type(func: FunctionSchema) -> CType:
    """Return the C++ return type of a view lambda: always a single tensor (``tensorT``)."""
    # Every view op must return at least one output, and all outputs must be
    # tensor-like.
    assert len(func.returns) >= 1
    assert all(ret.type.is_tensor_like() for ret in func.returns)
    # Even for multi-tensor outputs the lambda returns one tensor: a separate
    # lambda is generated per output so each is tracked individually.
    return BaseCType(tensorT)


def outer_arguments(*, is_reverse: bool) -> list[Binding]:
    """Return the call parameters of a view lambda.

    Forward lambdas take ``(base, mutated_view_idx)``; reverse lambdas take
    ``(base, mutated_view, mutated_view_idx)``.
    """
    if is_reverse:
        return [base_binding, mutated_view_binding, mutated_view_idx_binding]
    return [base_binding, mutated_view_idx_binding]


def inner_call_index(func: FunctionSchema) -> Binding | None:
    """Return the output-index binding for multi-output view ops, else ``None``.

    A view op counts as multi-output if it has several returns or a single
    list-like return (like ``split``). One lambda is generated per output, so
    replaying such an op requires indexing into its result.
    """
    returns = func.returns
    multi_output = len(returns) > 1 or (
        len(returns) == 1 and returns[0].type.is_list_like()
    )
    return mutated_view_idx_binding if multi_output else None


def inner_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]:
    """Return the arguments a view lambda passes to its inner function call.

    Forward: ``base`` followed by the op's non-self arguments.
    Reverse: ``base, mutated_view, inverse_return_mode``, then the output index
    for multi-output ops, then the op's non-self arguments.
    """
    # The forward lambda calls the at::_ops API and the reverse lambda calls the
    # view inverse API; both follow the dispatcher API.
    non_self_bindings = [dispatcher.argument(a) for a in _non_self_args(func)]
    if not is_reverse:
        # The forward lambda swaps out the original tensor argument for its
        # "base" lambda arg.
        return [base_binding] + non_self_bindings
    # The reverse lambda does the same, plus the "mutated_view" arg and the
    # runtime inverse_return_mode. By convention, the view_inverse function of a
    # view op with multiple tensor outputs also takes an index argument.
    leading = [base_binding, mutated_view_binding, inverse_return_mode_binding]
    index_binding = inner_call_index(func)
    if index_binding is not None:
        leading.append(index_binding)
    return leading + non_self_bindings