from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, TYPE_CHECKING

from torchgen.api import cpp, dispatcher
from torchgen.api.translate import translate
from torchgen.api.types import (
    BaseCType,
    Binding,
    CType,
    DispatcherSignature,
    FunctionalizationLambda,
    iTensorListRefT,
    NativeSignature,
    OptionalCType,
    optionalSymIntArrayRefT,
    symIntArrayRefT,
    SymIntT,
    tensorListT,
    tensorT,
    VectorCType,
    ViewInverseSignature,
)
from torchgen.context import (
    method_with_native_function,
    native_function_manager,
    with_native_function,
    with_native_function_and,
)
from torchgen.model import (
    Argument,
    BackendIndex,
    BaseTy,
    BaseType,
    FunctionSchema,
    ListType,
    NativeFunction,
    NativeFunctionsGroup,
    NativeFunctionsViewGroup,
    Return,
    SchemaKind,
    SelfArgument,
    TensorOptionsArguments,
)
from torchgen.native_function_generation import (
    INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY,
    MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT,
    OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY,
)
from torchgen.utils import dataclass_repr


if TYPE_CHECKING:
    from torchgen.selective_build.selector import SelectiveBuilder


# Note: [Mutable Ops Not Using Functionalization]
# Ops in this list currently do not work with functionalization and should be fixed.
MUTABLE_OPS_NOT_USING_FUNCTIONALIZATION = (
    OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY
    + MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT
    + INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY
    + [
        # It will be BC-breaking, but we should fix their schemas.
        # should be inplace?
        "record_stream",
        # See Note [resize_ in Functionalization]
        "resize_",
        "resize_as_",
        # This function is used for testing purposes only.
        "_fill_mem_eff_dropout_mask_",
    ]
)

# This file contains codegen that relates to the functionalization pass.
# It includes:
# - gen_functionalization_definition
#     Generates dispatcher kernel definitions for the functionalization pass.
# - gen_functionalization_registration
#     Generates dispatcher kernel registrations for the functionalization pass.
# - gen_functionalization_view_inverse_declaration
#     Generates a declaration for an "inverse view", for every view op
#     that is needed in functionalization. We manually implement their definitions.
# - gen_composite_view_copy_kernel
#     Generates view_copy() composite kernels for all view_copy operators.
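

# As a rough illustration (a sketch, not output copied from a real build): for a view op
# like diagonal(), the default composite kernel emitted by GenCompositeViewCopyKernel
# below looks approximately like
#
#   at::Tensor diagonal_copy(const at::Tensor & self, int64_t offset, int64_t dim1, int64_t dim2) {
#     auto output = at::_ops::diagonal::call(self, offset, dim1, dim2);
#     return output.clone(/*memory_format=*/at::MemoryFormat::Contiguous);
#   }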


# Generates the body of the default composite C++ kernel for a {view}_copy NativeFunction
# See Note [view_copy NativeFunctions]
@dataclass(frozen=True)
class GenCompositeViewCopyKernel:
    backend_index: BackendIndex

    @method_with_native_function
    def __call__(self, g: NativeFunctionsViewGroup) -> str | None:
        if g.view_copy is None:
            return None
        elif g.view_copy.func.name.name.base != f"{g.view.func.name.name}_copy":
            # If the view_copy doesn't match the standard naming scheme of <op>_copy,
            # assume it already exists and doesn't need to be generated.
            # Example: slice_inverse() with the copy variant named slice_scatter()
            # instead of slice_inverse_copy()
            return None

        metadata = self.backend_index.get_kernel(g.view_copy)
        assert metadata is not None

        # We can make view_copy work in more cases by using reshape()
        # when a normal view call would ordinarily fail.
        # This also makes LTC more efficient, because they don't need to include
        # clone() calls in their graph (which is normally needed by reshape).
        if str(g.view_copy.func.name) == "view_copy":
            assert metadata.kernel == "view_copy_symint"
            return """\
at::Tensor view_copy_symint(const at::Tensor & self, at::SymIntArrayRef size) {
  c10::SymDimVector shape = infer_size_dv(size, self.sym_numel());
  if (!at::detail::computeStride(self.sym_sizes(), self.sym_strides(), shape).has_value()) {
    return self.reshape_symint(size);
  } else {
    auto output = at::_ops::view::call(self, size);
    return output.clone(/*memory_format=*/at::MemoryFormat::Contiguous);
  }
}
"""
        # view_copy is a native signature, since we're generating an at::native:: kernel
        # Functionalization always operates on symints though
        view_copy_sig = NativeSignature(
            g.view_copy.func, symint=metadata.supports_symint()
        )

        # view is a dispatcher signature, since we're calling into the at::_ops API
        view_sig = DispatcherSignature(g.view.func)

        view_api_name = g.view.func.name.unambiguous_name()
        exprs = ", ".join(
            [e.expr for e in translate(view_copy_sig.arguments(), view_sig.arguments())]
        )

        # view ops today always return either a Tensor or a list of Tensors
        assert len(g.view.func.returns) == 1
        assert g.view.func.returns[0].type == BaseType(
            BaseTy.Tensor
        ) or g.view.func.returns[0].type == ListType(BaseType(BaseTy.Tensor), None)

        if g.view.func.returns[0].type == BaseType(BaseTy.Tensor):
            return_cloned_output = """\
  return output.clone(/*memory_format=*/at::MemoryFormat::Contiguous);"""
        else:
            # If the return type is a list, we need to clone each tensor in the list.
            return_cloned_output = f"""\
  {view_copy_sig.returns_type().cpp_type()} out_clone;
  for (const auto i : c10::irange(output.size())) {{
    out_clone.push_back(output[i].clone(/*memory_format=*/at::MemoryFormat::Contiguous));
  }}
  return out_clone;"""

        # The default generated composite kernel for {view}_copy() operators just clones
        # the input tensor, and runs the underlying view on the clone.
        return f"""
{view_copy_sig.defn(name=metadata.kernel)} {{
  auto output = at::_ops::{view_api_name}::call({exprs});
{return_cloned_output}
}}
"""


def return_str(rets: tuple[Return, ...], names: list[str]) -> str:
    assert len(rets) == len(names)
    if len(rets) == 0:
        return ""
    elif len(rets) == 1:
        return f"return {names[0]};"
    else:
        return f"return {dispatcher.returns_type(rets).cpp_type()}({', '.join(names)});"


def modifies_arguments(f: NativeFunction) -> bool:
    return any(
        a.annotation is not None and a.annotation.is_write
        for a in f.func.arguments.flat_all
    )


def wrapper_name(func: FunctionSchema) -> str:
    if func.name.overload_name:
        return f"{cpp.name(func)}_{func.name.overload_name}"
    else:
        return cpp.name(func)


def is_tensor_like(a: Argument | TensorOptionsArguments | SelfArgument) -> bool:
    return isinstance(a, SelfArgument) or (
        isinstance(a, Argument) and a.type.is_tensor_like()
    )


# We need to wrap / unwrap various arguments from the op in the functionalization kernels.
# Some op schemas include non-owning types though (like TensorList),
# and when we unwrap them we expect to get out an owning type!
# We also return a lambda that tells you how to convert the non-owning type argument into the owning type.
def get_owning_type(t: CType) -> tuple[CType, Callable[[str], str]]:
    if t == BaseCType(tensorListT):
        return VectorCType(BaseCType(tensorT)), lambda x: f"{x}.vec()"
    if t == BaseCType(iTensorListRefT):
        return VectorCType(BaseCType(tensorT)), lambda x: f"{{{x}.begin(), {x}.end()}}"
    # There are technically other non-owning types out there (like IntArrayRef),
    # but functionalization only actually cares about the ones involving tensors.
    return t, lambda x: x
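

# For illustration (the argument name here is hypothetical; the mapping follows
# get_owning_type() above): a non-owning at::TensorList argument named "tensors" unwraps
# into an owning ::std::vector<at::Tensor>, converted via "tensors.vec()", while an
# at::ITensorListRef argument is materialized via "{tensors.begin(), tensors.end()}".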


# unwraps all tensor-like arguments, returning:
# (1) a string containing all of the logic that does the unwrapping
# (2) a context, to be used by translate(), with all of the relevant bindings.
def unwrap_tensor_args(
    sig: DispatcherSignature, *, is_view_op: bool
) -> tuple[str, list[Binding]]:
    context: list[Binding] = []
    unwrapped_tensor_args: list[str] = []
    for arg in sig.arguments():
        if is_tensor_like(arg.argument):
            # for tensor inputs, we want to unwrap them before passing them into the redispatch calls.
            unwrapped_name = f"{arg.name}_"
            # For most ops, the functionalization pass needs to sync any pending updates on the input tensors
            # before calling the operator, since otherwise the operator will act on stale data.
            # For view ops though, we can continue to defer syncing until the tensor is used by
            # a non-view operator.
            maybe_sync_input = (
                "" if is_view_op else f"at::functionalization::impl::sync({arg.name});"
            )
            unwrapped_type, conversion_fn = get_owning_type(
                arg.nctype.remove_const_ref().type
            )
            unwrapped_tensor_args.append(
                f"""
      {unwrapped_type.cpp_type()} {unwrapped_name};
      if (at::functionalization::impl::isFunctionalTensor({arg.name})) {{
        {maybe_sync_input}
        {unwrapped_name} = at::functionalization::impl::from_functional_tensor({arg.name});
      }} else {{
        {unwrapped_name} = {conversion_fn(arg.name)};
      }}"""
            )
            context.append(arg.with_name(unwrapped_name))
        else:
            # for non-tensor inputs, we want to pass them directly into the redispatch calls.
            context.append(arg)
    unwrap_tensor_args_str = "\n        ".join(unwrapped_tensor_args)
    return unwrap_tensor_args_str, context


# converts all tensor-like arguments to meta tensors, which are used to compute stride info. Returns:
# (1) a string containing all of the logic that does the conversions.
# (2) a context, to be used by translate(), with all of the relevant bindings.
def convert_to_meta_tensors(sig: DispatcherSignature) -> tuple[str, list[Binding]]:
    context: list[Binding] = []
    unwrapped_tensor_args: list[str] = []
    for arg in sig.arguments():
        if is_tensor_like(arg.argument):
            # for tensor inputs, we want to unwrap them before passing them into the redispatch calls.
            a_ = arg.name
            unwrapped_name = f"{arg.name}_meta"
            unwrapped_tensor_args.append(f"auto {unwrapped_name} = to_meta({a_});")
            context.append(arg.with_name(unwrapped_name))
        else:
            # for non-tensor inputs, we want to pass them directly into the redispatch calls.
            context.append(arg)
    unwrap_tensor_args_str = "\n        ".join(unwrapped_tensor_args)
    return unwrap_tensor_args_str, context


# The functionalization codegen currently expects view op schemas to have this form:
# foo(Tensor(a), ...) -> Tensor(a) (e.g. transpose)
# foo(Tensor(a!), ...) -> Tensor(a!) (e.g. transpose_)
def assert_view_op_properties(func: FunctionSchema) -> None:
    def is_alias(a: Argument) -> bool:
        return a.annotation is not None

    args = func.arguments.flat_non_out
    # The first argument is a tensor with alias semantics (annotations)
    assert len(args) > 0 and args[0].type == BaseType(
        BaseTy.Tensor
    ), f"""In the functionalization codegen, we expect the first argument of every view operator to be a tensor,
but found an argument of type {str(args[0].type)} for operator: {str(func.name)}."""
    # No other arguments have aliasing semantics
    assert is_alias(args[0]) and not any(
        is_alias(a) for a in args[1:]
    ), """In the functionalization codegen, we expect the first argument of every view operator to alias the output.
View operators with multiple aliasing inputs aren't supported yet. Found an operator that doesn't satisfy this constraint."""


# One-liner expression for checking if an expression expr of type type has any
# symbolic values.
def emit_expr_has_symbolic_values(expr: str, type: CType) -> str:
    if type == BaseCType(SymIntT):
        return f"{expr}.is_symbolic()"

    if isinstance(type, OptionalCType):
        innerexpr = f"(*{expr})"
        return f"{expr}.has_value() ? {emit_expr_has_symbolic_values(innerexpr, type.elem)} : false"

    if type == BaseCType(optionalSymIntArrayRefT):
        return emit_expr_has_symbolic_values(
            expr, OptionalCType(BaseCType(symIntArrayRefT))
        )

    if type in (BaseCType(symIntArrayRefT), VectorCType(BaseCType(SymIntT))):
        argname = "arg"
        lambda_check = emit_expr_has_symbolic_values(argname, BaseCType(SymIntT))
        return (
            "std::any_of("
            f"{expr}.begin(), {expr}.end(), "
            f"[=](auto& {argname}) {{ return {lambda_check}; }})"
        )

    raise ValueError(
        "unsupported type for has_symbolic_values check. "
        "It should be a SymInt or a collection of those. "
        f"Got: {type.cpp_type()}"
    )


# Detects whether any of the SymInt arguments are, in fact, symbolic values.
# This is used in the constructor of ViewMeta.
def emit_has_symbolic_inputs(sig: DispatcherSignature) -> tuple[str, str]:
    name = "has_symbolic_inputs"
    statements = [
        f"{name} = {name} | ({emit_expr_has_symbolic_values(binding.name, binding.nctype.type)});"
        for binding in sig.arguments()
        if (
            isinstance(binding.argument, Argument)
            and binding.argument.type.is_symint_like()
        )
    ]
    body = "\n      ".join(statements)
    return (
        name,
        f"""
      bool {name} = false;
      {body}""",
    )
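

# For example (an illustrative sketch, not verbatim generator output): for a single
# SymIntArrayRef argument named "size", the two helpers above produce a check roughly like
#
#   bool has_symbolic_inputs = false;
#   has_symbolic_inputs = has_symbolic_inputs | (std::any_of(size.begin(), size.end(),
#       [=](auto& arg) { return arg.is_symbolic(); }));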


# Generates the Functionalization kernel for:
# - ops that create aliases (e.g. transpose())
# - ops that are views AND mutations (e.g. transpose_())
def emit_view_functionalization_body(
    g: NativeFunctionsViewGroup, *, view_inplace: bool
) -> str:
    if view_inplace:
        # This op is both an inplace op AND a view op.
        # See Note [Functionalization Pass - Inplace View Ops] for details.
        # I currently have the view meta call into the out-of-place variant of the view, to avoid
        # having to define an extra ~20 inplace {view}_inverse_ functions.
        # Most view ops don't have NativeFunctionsGroups, because we don't define out= variants for view ops.
        # I'm assuming that every inplace-view op has a corresponding out-of-place view op,
        # with the same name but the trailing underscore removed.
        # This is currently asserted at parse time in gen.py (see error_check_native_functions).
        assert g.view_inplace is not None
        f = g.view_inplace
    else:
        f = g.view

    assert g.view_copy is not None
    with native_function_manager(f):
        call_sig = DispatcherSignature.from_schema(g.view_copy.func)

        # the "view_copy" op name that the functionalization kernels need to call
        api_name = g.view_copy.func.name.unambiguous_name()
        # Sometimes the functionalization pass needs to no-op (e.g. if it was passed non-functional tensors)
        # "no-op"ing in this context is just redispatching to the original op.
        noop_api_name = f.func.name.unambiguous_name()

        dispatcher_sig = DispatcherSignature.from_schema(f.func)
        assert_view_op_properties(f.func)
        view_tensor_name = dispatcher_sig.arguments()[0].name

        return_type = dispatcher_sig.returns_type().remove_const_ref().cpp_type()

        unwrap_tensor_args_str, unwrapped_args_ctx = unwrap_tensor_args(
            dispatcher_sig, is_view_op=True
        )
        view_redispatch_args = [
            e.expr
            for e in translate(unwrapped_args_ctx, call_sig.arguments(), method=False)
        ]

        forward_lambda = FunctionalizationLambda.from_func(g, is_reverse=False)
        reverse_lambda = FunctionalizationLambda.from_func(g, is_reverse=True)

        # The meta API call should use the same arguments, but convert all tensors to meta tensors first.
        meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)
        meta_call_args = [
            e.expr for e in translate(meta_call_ctx, call_sig.arguments(), method=False)
        ]

        (
            symbolic_inputs_varname,
            symbolic_inputs_check,
        ) = emit_has_symbolic_inputs(call_sig)

        if "inplace_view" in f.tags:
            # See Note [Functionalization Pass - Inplace View Ops] for more details
            return f"""
    {dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{
      if (!at::functionalization::impl::isFunctionalTensor({view_tensor_name})) {{
        // functionalization is re-entrant, but will no-op if it wasn't passed a FunctionalTensorWrapper.
        {unwrap_tensor_args_str}
        at::AutoDispatchSkipFunctionalize guard;
        return at::_ops::{noop_api_name}::call({', '.join(view_redispatch_args)});
      }}
      auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
      auto inverse_return_mode = (
          reapply_views ? at::functionalization::InverseReturnMode::ViewOrScatterInverse
                        : at::functionalization::InverseReturnMode::NeverView
      );
      {symbolic_inputs_check}
      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
        {forward_lambda.decl()} {{
          if (reapply_views) {{
            return {forward_lambda.inner_call(reapply_views=True)}
          }} else {{
            return {forward_lambda.inner_call(reapply_views=False)}
          }}
        }},
        {reverse_lambda.decl()} {{
          return {reverse_lambda.inner_call()}
        }},
        /*has_symbolic_inputs=*/{symbolic_inputs_varname}
      );
      auto compute_reference_meta =
        {view_tensor_name}.key_set().has_backend(c10::BackendComponent::XLABit) ||
        {view_tensor_name}.key_set().has_backend(c10::BackendComponent::LazyBit);
      {return_type} reference_tensor_output;
      if (compute_reference_meta) {{
        {meta_conversion_str}
        at::AutoDispatchSkipFunctionalize func_guard;
        c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch);
        reference_tensor_output = at::_ops::{noop_api_name}::call({', '.join(meta_call_args)});
      }}
      // This function adds the above view meta to the current tensor and replays them off the base,
      // mutating the size/stride info of the current FunctionalTensorWrapper.
      // Because of this, we need to make sure to run the reference shape function above,
      // BEFORE doing this (otherwise we'll end up running the reference function using the wrong sizes/strides)
      at::functionalization::impl::mutate_view_meta({view_tensor_name}, view_meta);
      // See Note [Propagating strides in the functionalization pass]
      // XLA/LTC don't implement the logic to propagate strides correctly, so we need to rely
      // on a reference implementation here (instead of relying on the output from the forward lambda
      // having the correct stride info)
      if (compute_reference_meta) {{
        at::functionalization::impl::set_sizes_strides_offset({view_tensor_name}, reference_tensor_output);
      }}
      return {view_tensor_name};
    }}
"""

        else:
            is_multi_output_view = isinstance(f.func.returns[0].type, ListType)
            return f"""
    {dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{
      {unwrap_tensor_args_str}
      if (!at::functionalization::impl::isFunctionalTensor({view_tensor_name})) {{
        // functionalization is re-entrant, but will no-op if it wasn't passed a FunctionalTensorWrapper.
        at::AutoDispatchSkipFunctionalize guard;
        return at::_ops::{noop_api_name}::call({', '.join(view_redispatch_args)});
      }}
      auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
      auto inverse_return_mode = (
          reapply_views ? at::functionalization::InverseReturnMode::ViewOrScatterInverse
                        : at::functionalization::InverseReturnMode::NeverView
      );
      auto compute_reference_meta =
        {view_tensor_name}.key_set().has_backend(c10::BackendComponent::XLABit) ||
        {view_tensor_name}.key_set().has_backend(c10::BackendComponent::LazyBit);
      {return_type} reference_tensor_output;
      if (compute_reference_meta) {{
        {meta_conversion_str}
        at::AutoDispatchSkipFunctionalize func_guard;
        c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch);
        reference_tensor_output = at::_ops::{noop_api_name}::call({', '.join(meta_call_args)});
      }}
      {return_type} tmp_output;
      {{
        at::AutoDispatchSkipFunctionalize guard;
        if (reapply_views) {{
          tmp_output = at::_ops::{noop_api_name}::call({', '.join(view_redispatch_args)});
        }} else {{
          tmp_output = at::_ops::{api_name}::call({', '.join(view_redispatch_args)});
        }}
      }}
      {symbolic_inputs_check}
      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
        {forward_lambda.decl()} {{
          if (reapply_views) {{
            return {forward_lambda.inner_call(reapply_views=True)}
          }} else {{
            return {forward_lambda.inner_call(reapply_views=False)}
          }}
        }},
        {reverse_lambda.decl()} {{
          return {reverse_lambda.inner_call()}
        }},
        /*has_symbolic_inputs=*/{symbolic_inputs_varname},
        /*is_multi_output=*/{str(is_multi_output_view).lower()},
        /*is_as_strided=*/{str(str(f.func.name) == 'as_strided').lower()}
      );
      auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, {view_tensor_name}, view_meta);
      // See Note [Propagating strides in the functionalization pass]
      if (compute_reference_meta) {{
        at::functionalization::impl::set_sizes_strides_offset(out, reference_tensor_output);
      }}
      return out;
    }}
"""


def maybe_create_output(f: NativeFunction, var_name: str) -> str:
    if len(f.func.returns) == 0:
        return ""
    return_type = dispatcher.returns_type(f.func.returns).remove_const_ref().cpp_type()
    return f"{return_type} {var_name} = "


# Given a NativeFunction, and a variable name corresponding to the output of redispatching on the function,
# this returns two lists of names, consisting of:
# - the names of returns corresponding to the original (mutable) inputs of the outer function
# - the names of returns corresponding to the (immutable) outputs of the inner redispatched function
def get_mutable_redispatch_return_names(
    f: NativeFunction, inner_return_var: str
) -> tuple[list[str], list[str]]:
    aliased_returns = []
    non_aliased_returns = []
    for i, name in enumerate(f.func.aliased_return_names()):
        if name is not None:
            aliased_returns.append(name)
        else:
            non_aliased_returns.append(
                inner_return_var
                if len(f.func.returns) == 1
                else f"std::get<{i}>({inner_return_var})"
            )
    return aliased_returns, non_aliased_returns
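

# As a sketch (assuming a standard out= schema): for an overload like
# add.out(..., Tensor(a!) out) -> Tensor(a!), the single return aliases "out", so this
# helper returns (["out"], []); for the functional add.Tensor, nothing aliases, so it
# returns ([], ["tmp_output"]) when inner_return_var is "tmp_output".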


# When functionalization "no-op's" and redispatches on a mutable operator, we need to take care so that:
# - For fresh outputs, we return the result of the redispatch (without wrapping outputs)
# - For outputs that were aliased to inputs, we return the inputs directly (since some of them might have been wrapped)
def return_from_mutable_noop_redispatch(
    f: NativeFunction, inner_return_var: str
) -> str:
    aliased, non_aliased = get_mutable_redispatch_return_names(f, inner_return_var)
    # Just get all of the return names, and immediately return them
    return return_str(f.func.returns, aliased + non_aliased)


def wrap_propagate_mutations_and_return(
    f: NativeFunction, functional_op: NativeFunction, inner_return_var: str
) -> str:
    mutable_arg_names = f.func.arguments.mutable_arg_names()
    (
        aliased_outer_rets,
        non_aliased_outer_rets,
    ) = get_mutable_redispatch_return_names(f, inner_return_var)
    _, non_aliased_inner_rets = get_mutable_redispatch_return_names(
        functional_op, inner_return_var
    )
    # The outer function may have a mix of aliased and non-aliased outputs,
    # but the inner functional op that we're transforming to should only have non-aliased outputs
    assert len(mutable_arg_names) + len(non_aliased_outer_rets) == len(
        non_aliased_inner_rets
    )

    # First, take all of the newly created outputs from the inner call and wrap them into functional tensors
    updates = []
    non_aliased_wrapped_ret_names = []
    for i, inner_ret in enumerate(
        non_aliased_inner_rets[: len(non_aliased_outer_rets)]
    ):
        ret_name = f"output_{i}"
        updates.append(
            f"""\
      auto output_{i} = at::functionalization::impl::to_functional_tensor({inner_ret});"""
        )
        non_aliased_wrapped_ret_names.append(ret_name)

    # Next, take all of the mutated outputs from the inner call corresponding to mutated inputs,
    # and propagate the mutations
    for outer_arg, inner_ret in zip(
        mutable_arg_names, non_aliased_inner_rets[len(non_aliased_outer_rets) :]
    ):
        updates.append(
            f"""\
      auto {outer_arg}_inner = at::functionalization::impl::from_functional_tensor({outer_arg});
      at::functionalization::impl::replace_({outer_arg}, {inner_ret});
      at::functionalization::impl::commit_update({outer_arg});
      at::functionalization::impl::sync({outer_arg});
      auto {outer_arg}_inner_updated = at::functionalization::impl::from_functional_tensor({outer_arg});
      at::functionalization::impl::propagate_xla_data_direct({outer_arg}_inner, {outer_arg}_inner_updated);"""
        )

    # Finally, we return:
    # - Any mutable arguments that are also returned
    # - Any immutable returns that were created by wrapping the output from the inner call
    returns_str = return_str(
        f.func.returns, aliased_outer_rets + non_aliased_wrapped_ret_names
    )
    updates_str = "\n".join(updates)
    return f"""\
{updates_str}
    {returns_str}"""
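

# Roughly (an illustrative expansion, assuming an inplace op that mutates and returns "self",
# whose functional counterpart returns one fresh tensor), the update block generated above is:
#
#   auto self_inner = at::functionalization::impl::from_functional_tensor(self);
#   at::functionalization::impl::replace_(self, tmp_output);
#   at::functionalization::impl::commit_update(self);
#   at::functionalization::impl::sync(self);
#   auto self_inner_updated = at::functionalization::impl::from_functional_tensor(self);
#   at::functionalization::impl::propagate_xla_data_direct(self_inner, self_inner_updated);
#   return self;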


# Generates the Functionalization kernel for:
# - mutation ops (inplace and out= ops)
@with_native_function_and
def emit_inplace_functionalization_body(
    f: NativeFunction, g: NativeFunctionsGroup
) -> str:
    # mutation case
    assert modifies_arguments(f)

    dispatcher_sig = DispatcherSignature.from_schema(f.func)

    unwrap_tensor_args_str, unwrapped_args_ctx = unwrap_tensor_args(
        dispatcher_sig, is_view_op=False
    )

    mutated_names = [
        a.name
        for a in f.func.arguments.flat_all
        if a.type.is_tensor_like() and a.annotation is not None
    ]
    non_mutated_names = [
        a.name
        for a in f.func.arguments.flat_all
        if a.type.is_tensor_like() and a.annotation is None
    ]
    non_mutated_tensor_names = [
        a.name
        for a in f.func.arguments.flat_all
        if a.type == BaseType(BaseTy.Tensor) and a.annotation is None
    ]
    # all mutable inputs must be functional tensors in order to participate in functionalization
    check_all_mutated_args_are_functional = " && ".join(
        ["true"]
        + [
            f"at::functionalization::impl::isFunctionalTensor({a})"
            for a in mutated_names
        ]
    )
    check_any_non_mutated_args_are_functional = " || ".join(
        ["false"]
        + [
            f"at::functionalization::impl::isFunctionalTensor({a})"
            for a in non_mutated_names
        ]
    )

    check_any_non_mutated_tensors_are_xla = " || ".join(
        ["false"]
        + [
            f"{a}.device().type() == c10::DeviceType::XLA"
            for a in non_mutated_tensor_names
        ]
    )
    # These are used in the cases where we don't functionalize and redispatch to the inplace op
    # case 1: we hit an inplace op that doesn't have an out-of-place equivalent
    # case 2: we hit an inplace op but our inputs are not functional tensors (in which case our kernel just no-ops)
    inplace_exprs = [
        e.expr
        for e in translate(unwrapped_args_ctx, dispatcher_sig.arguments(), method=False)
    ]

    # call the out-of-place variant of the op
    return_type = (
        dispatcher.returns_type(g.functional.func.returns).remove_const_ref().cpp_type()
    )
    functional_sig = DispatcherSignature.from_schema(g.functional.func)
    functional_exprs = [
        e.expr
        for e in translate(unwrapped_args_ctx, functional_sig.arguments(), method=False)
    ]

    if f.func.is_out_fn():
        mutable_input_post_processing = "\n".join(
            [
                f"""
      at::functionalization::impl::replace_(
        {a.name}, {'std::get<' + str(i) + '>(tmp_output)' if len(f.func.returns) > 1 else 'tmp_output'});
      at::functionalization::impl::commit_update({a.name});"""
                for (i, a) in enumerate(f.func.arguments.out)
                if a.annotation and a.annotation.is_write and a.type.is_tensor_like()
            ]
        )
    else:
        mutable_input_post_processing = "\n".join(
            [
                f"""
      at::functionalization::impl::replace_({a.name}, tmp_output);
      at::functionalization::impl::commit_update({a.name});"""
                for a in f.func.arguments.flat_all
                if a.annotation and a.annotation.is_write and a.type.is_tensor_like()
            ]
        )

    meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)
    # We don't want to run the inplace meta func for ops like .set_(), because:
    # (1) they're unnecessary: inplace meta checks are only useful for ops like add_(),
    #     where broadcasting will work for the out-of-place case but should fail on the inplace call
    # (2) they'll also fail without adding extra infra: we'd need to convert the input storage argument
    #     into a meta storage
    any_storage_args = any(
        a.type == BaseType(BaseTy.Storage) for a in f.func.arguments.flat_all
    )

    return f"""
    {dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{
      if ({str(not any_storage_args and f.func.kind() == SchemaKind.inplace).lower()}) {{
        // Before converting the mutable op to its functional variant, run meta tensors through the original op.
        // This will help us catch shape errors that apply to inplace ops that wouldn't apply to their functional variants.
        // (We can only do this for inplace ops today though, because they technically all support meta tensors).
        {meta_conversion_str}
        at::AutoDispatchSkipFunctionalize func_guard;
        c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch);
        at::_ops::{f.func.name.unambiguous_name()}::call({', '.join(a.name for a in meta_call_ctx)});
      }}
      {unwrap_tensor_args_str}
      if (!({check_all_mutated_args_are_functional})) {{
        // We want to disable this check if there are any XLA tensors.
        // cpu_tensor.copy_(xla_tensor) is valid code.
        if (!({check_any_non_mutated_tensors_are_xla}) && ({check_any_non_mutated_args_are_functional})) {{
          // case 1: trying to mutate a non functional tensor with a functional tensor is an error
          TORCH_INTERNAL_ASSERT(false,
            "mutating a non-functional tensor with a functional tensor is not allowed.",
            " Please ensure that all of your inputs are wrapped inside of a functionalize() call.");
        }} else {{
          // case 2: arguments are not functional tensors, so we no-op and redispatch.
          at::AutoDispatchSkipFunctionalize guard;
          {maybe_create_output(f, 'tmp_output')}at::_ops::{f.func.name.unambiguous_name()}::call({', '.join(inplace_exprs)});
          {return_from_mutable_noop_redispatch(f, 'tmp_output')}
        }}
      }} else {{
        {return_type} tmp_output;
        {{
          at::AutoDispatchSkipFunctionalize guard;
          tmp_output = at::_ops::{g.functional.func.name.unambiguous_name()}::call({', '.join(functional_exprs)});
        }}
        {wrap_propagate_mutations_and_return(f, g.functional, 'tmp_output')}
      }}
    }}"""


# The below functions generate RegisterFunctionalization.cpp
# These files provide the kernels that run the functionalization pass, which can be opted into
# per backend (e.g. XLA or Vulkan), or as a composable transform (functionalize() in functorch).


# See Note [Functionalization Pass: View Inverses].
def gen_functionalization_view_inverse_declaration(
    selector: SelectiveBuilder, g: NativeFunctionsViewGroup
) -> str | None:
    # For every (non-composite) view op, we need a corresponding "inverse view" function.
    # This generates the declarations so we get a good compiler error when someone adds a new view.
    @with_native_function
    def emit_decl_helper(g: NativeFunctionsViewGroup) -> str | None:
        if g.view.has_composite_implicit_autograd_kernel:
            return None
        view_inverse_sig = ViewInverseSignature(g)
        return view_inverse_sig.decl()

    return emit_decl_helper(g)


def gen_functionalization_registration(
    selector: SelectiveBuilder,
    g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
    composite_implicit_autograd_index: BackendIndex,
) -> list[str]:
    @with_native_function
    def emit_registration_helper(f: NativeFunction) -> str:
        assert not f.has_composite_implicit_autograd_kernel
        registration_str = f"TORCH_FN(functionalization::{wrapper_name(f.func)})"
        return f'm.impl("{f.func.name}", {registration_str});'

    # Don't generate kernels in mobile build
    if not selector.include_all_operators:
        return []

    if isinstance(g, NativeFunctionsViewGroup):
        # functionalization needs to register kernels for view + view_inplace ops
        # See Note [Functionalization <> torch.Tensor constructor]
        if str(g.view.func.name) == "lift_fresh":
            return []
        view_str = []
        if not g.view.has_composite_implicit_autograd_kernel:
            view_str.append(emit_registration_helper(g.view))
        if (
            g.view_inplace is not None
            and not g.view_inplace.has_composite_implicit_autograd_kernel
        ):
            assert g.view_inplace.is_view_op
            view_str.append(emit_registration_helper(g.view_inplace))
        return view_str

    elif isinstance(g, NativeFunctionsGroup):
        # Gets a hand-written functionalization kernel
        if g.inplace is not None and str(g.inplace.func.name) == "set_.source_Tensor":
            fns = []
        else:
            fns = list(g.functions())
    else:
        if str(g.func.name) in MUTABLE_OPS_NOT_USING_FUNCTIONALIZATION:
            return []
        fns = [g]

    registrations = []
    for f in fns:
        if f.has_composite_implicit_autograd_kernel:
            continue
        if str(f.func.name) == "lift":
            # See Note [Functionalization <> torch.Tensor constructor]
            return []
        if str(f.func.name) == "resize_":
            # See Note [resize_ in Functionalization]
            return []
        if str(f.func.name.name) != "set_":
            assert not f.is_view_op
        # functionalization needs to generate and register kernels for inplace ops.
        # We *also* need to directly register CompositeImplicitAutograd kernels
        # so that they decompose properly before functionalization.
        if modifies_arguments(f):
            registrations.append(emit_registration_helper(f))
    return registrations
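

# For reference (an illustrative sketch, not output pasted from a real build), a registration
# emitted by emit_registration_helper above looks roughly like:
#
#   m.impl("add_.Tensor", TORCH_FN(functionalization::add__Tensor));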


def gen_functionalization_definition(
    selector: SelectiveBuilder,
    # Note: Ideally this code should never have to look at NativeFunction
    # (and instead only need to operate on grouped NativeFunctions).
    # The only reason currently is because we need to emit direct dispatch registrations
    # for CompositeImplicitAutograd operators, which are potentially ungrouped.
    g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
) -> list[str]:
    # Don't generate kernels in mobile build
    if not selector.include_all_operators:
        return []

    if isinstance(g, NativeFunctionsViewGroup):
        # Case 1: emit view -> view_copy kernels for the functionalization pass
        view_defs = []
        if not g.composite:
            # invariant: NativeFunctionsViewGroup's always have a view_copy operator
            # if the view is not composite (implicit autograd)
            assert g.view_copy is not None, dataclass_repr(g, indent=1)
            view_defs.append(emit_view_functionalization_body(g, view_inplace=False))
        if g.view_inplace is not None:
            view_defs.append(emit_view_functionalization_body(g, view_inplace=True))
        return view_defs
    elif isinstance(g, NativeFunction):
        # Invariant: all mutable operators that we need to handle in functionalization
        # should have been properly grouped up.
        # TODO: The below ops all have "problematic" schemas that prevent them from
        # getting functionalized. Instead of bending over backwards to get things to work,
        # I think we should either:
        # (1) fix their schemas (BC-breaking)
        # (2) hand-write their functionalization kernels
        if (
            str(g.func.name) not in MUTABLE_OPS_NOT_USING_FUNCTIONALIZATION
            and str(g.func.name.name) not in MUTABLE_OPS_NOT_USING_FUNCTIONALIZATION
        ):
            assert g.has_composite_implicit_autograd_kernel or not modifies_arguments(g)
        return []
    else:
        # Case 2: emit inplace -> out-of-place kernels for the functionalization pass
        mutation_defs = []
        mutation_defs.append(emit_inplace_functionalization_body(g.out, g))
        if g.inplace is not None:
            mutation_defs.append(emit_inplace_functionalization_body(g.inplace, g))
        if g.mutable is not None:
            mutation_defs.append(emit_inplace_functionalization_body(g.mutable, g))
        return mutation_defs
    return []