from __future__ import annotations

import itertools
import textwrap
from dataclasses import dataclass
from typing import Literal, TYPE_CHECKING

import torchgen.api.cpp as cpp
import torchgen.api.meta as meta
import torchgen.api.structured as structured
from torchgen.api.translate import translate
from torchgen.api.types import (
    BaseCType,
    Binding,
    ConstRefCType,
    CppSignature,
    CppSignatureGroup,
    DispatcherSignature,
    Expr,
    kernel_signature,
    MutRefCType,
    NamedCType,
    NativeSignature,
    tensorT,
)
from torchgen.context import method_with_native_function, native_function_manager
from torchgen.model import (
    Argument,
    BackendIndex,
    DeviceCheckType,
    DispatchKey,
    gets_generated_out_inplace_wrapper,
    is_cuda_dispatch_key,
    NativeFunction,
    NativeFunctionsGroup,
    SchemaKind,
    TensorOptionsArguments,
)
from torchgen.utils import assert_never, mapMaybe, Target


if TYPE_CHECKING:
    from torchgen.selective_build.selector import SelectiveBuilder


def gen_registration_headers(
    backend_index: BackendIndex,
    per_operator_headers: bool,
    rocm: bool,
) -> list[str]:
    if per_operator_headers:
        headers = ["#include <ATen/ops/as_strided_native.h>"]
    else:
        headers = ["#include <ATen/NativeFunctions.h>"]

    if backend_index.dispatch_key in (DispatchKey.CPU, DispatchKey.Meta):
        headers.append("#include <ATen/EmptyTensor.h>")
    elif backend_index.dispatch_key == DispatchKey.CUDA:
        if rocm:
            headers.append("#include <ATen/hip/EmptyTensor.h>")
        else:
            headers.append("#include <ATen/cuda/EmptyTensor.h>")
    elif backend_index.dispatch_key == DispatchKey.MPS:
        headers.append("#include <ATen/mps/EmptyTensor.h>")
    elif backend_index.dispatch_key == DispatchKey.XPU:
        # XPU specific, this header resides in third_party/torch-xpu-ops
        headers.append("#include <ATen/xpu/EmptyTensor.h>")
    elif per_operator_headers:
        headers += [
            "#include <ATen/ops/empty.h>",
            "#include <ATen/ops/empty_strided.h>",
            "#include <ATen/ops/_copy_from_and_resize.h>",
            "#include <ATen/ops/_copy_from.h>",
        ]
    else:
        headers.append("#include <ATen/Functions.h>")

    headers.append("#include <c10/macros/Macros.h>")
    return headers


def gen_empty_impl_names(
    backend_index: BackendIndex,
) -> tuple[str | None, str | None]:
    empty_impl = None
    empty_strided_impl = None

    if backend_index.dispatch_key in (
        DispatchKey.Meta,
        DispatchKey.CPU,
        DispatchKey.CUDA,
        DispatchKey.MPS,
        DispatchKey.XPU,
    ):
        dispatch = str(backend_index.dispatch_key).lower()
        empty_impl = f"at::detail::empty_{dispatch}"
        empty_strided_impl = f"at::detail::empty_strided_{dispatch}"
    elif backend_index.dispatch_key in (
        DispatchKey.CompositeExplicitAutogradNonFunctional,
        DispatchKey.QuantizedCPU,
        DispatchKey.QuantizedCUDA,
        DispatchKey.XPU,
    ):
        empty_impl = "at::empty"
        empty_strided_impl = "at::empty_strided"

    return empty_impl, empty_strided_impl


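# For example, with DispatchKey.CPU the names above resolve to
# "at::detail::empty_cpu" / "at::detail::empty_strided_cpu" (cf. the
# ATen/EmptyTensor.h include added by gen_registration_headers), which the
# create_out helper below forwards to.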
def gen_create_out_helper(backend_index: BackendIndex) -> list[str]:
    if backend_index.dispatch_key == DispatchKey.Meta:
        empty_options = "options.device(at::kMeta)"
    else:
        empty_options = "options"

    empty_impl, empty_strided_impl = gen_empty_impl_names(backend_index)
    if empty_impl is None:
        return []

    return [
        f"""
Tensor create_out(IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {{
  if (strides.empty()) {{
    return {empty_impl}(sizes, {empty_options});
  }} else {{
    return {empty_strided_impl}(sizes, strides, {empty_options});
  }}
}}
"""
    ]


def gen_maybe_create_proxy_helper(backend_index: BackendIndex) -> list[str]:
    _, empty_strided_impl = gen_empty_impl_names(backend_index)
    return (
        []
        if empty_strided_impl is None
        else [
            f"""
std::optional<Tensor> maybe_create_proxy(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {{
  if (out.strides() != strides) {{
    return {empty_strided_impl}(sizes, strides, options);
  }}
  return std::nullopt;
}}
"""
        ]
    )


def gen_resize_out_helper(backend_index: BackendIndex) -> list[str]:
    if backend_index.dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional:
        # This function isn't used by this key (since only functional ops have a kernel
        # for this key), so we omit it to avoid a defined-but-not-used error.
        return []
    return [
        """
void resize_out(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {
  TORCH_CHECK(options.dtype() == out.dtype(),
      "Expected out tensor to have dtype ", options.dtype(), ", but got ", out.dtype(), " instead");
  TORCH_CHECK(options.device() == out.device(),
      "Expected out tensor to have device ", options.device(), ", but got ", out.device(), " instead");
  const bool resized = at::native::resize_output(out, sizes);
  // Only restride if a resize occurred; otherwise we ignore the (advisory)
  // strides from the meta function and directly use the output tensor's
  // preexisting strides
  if (resized) {
    if (!strides.empty()) {
      TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value());
      // TODO: avoid the redispatch here
      out.as_strided_(sizes, strides);
    } else if (options.memory_format_opt().has_value()) {
      out.unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt());
    }
  }
}
"""
    ]


def gen_check_inplace_helper(backend_index: BackendIndex) -> list[str]:
    return [
        """
void check_inplace(const Tensor &self, IntArrayRef sizes, const TensorOptions &options) {
  // These checks are needed on those operators that:
  // 1) don't use 'TensorIterator' (e.g. 'addmm' and 'baddbmm')
  // 2) have particular typing rules (e.g. 'cumsum' and 'cumprod')
  // For other operators (e.g. 'add'), 'TensorIterator' already checks
  // these things separately.
  TORCH_CHECK(options.dtype() == self.dtype(),
      "Bad in-place call: ",
      "input tensor dtype ", self.dtype(), " and output tensor dtype ", options.dtype(), " should match");
  TORCH_CHECK(options.device() == self.device(),
      "Bad in-place call: ",
      "input tensor device ", self.device(), " and output tensor device ", options.device(), " should match");
  TORCH_CHECK(sizes == self.sizes(),
      "Bad in-place call: ",
      "input tensor size ", self.sizes(), " and output tensor size ", sizes, " should match");
}
"""
    ]


def gen_registration_helpers(backend_index: BackendIndex) -> list[str]:
    return [
        'C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-function")',
        *gen_create_out_helper(backend_index),
        *gen_resize_out_helper(backend_index),
        *gen_check_inplace_helper(backend_index),
        *gen_maybe_create_proxy_helper(backend_index),
        "C10_DIAGNOSTIC_POP()",
    ]


# Generates Register{dispatch}.cpp (e.g., RegisterCPU.cpp).
#
#   - The primary function of this file is to register all of the
#     implementations for the given dispatch key to the dispatcher,
#     so they are available for use in PyTorch.  If dispatch is
#     None, we generate schema (def) registrations and catchall
#     registrations.
#   - The secondary function of this file is to generate a wrapper
#     around functions.  In CPUType these wrappers do nothing
#     (and should be removed), but in other cases they handle
#     DeviceGuard.  A small extra benefit of wrappers is that they
#     are not overloaded, so they can be used in the registration
#     API without having to disambiguate which overload you want
#     (as would be the case if you directly registered native::
#     functions).
#   - The tertiary function of this file is to generate *static*
#     cpp API bindings which can be used to bypass the dispatcher
#     and call kernels directly, but with a user-friendly cpp-style API.
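#
# As a rough, illustrative sketch only (hypothetical op "foo" with abbreviated
# signatures; the real wrapper names and argument lists come from the
# signatures computed below), the generated pieces look something like:
#
#   namespace {
#   at::Tensor wrapper_CPU_foo(const at::Tensor & self) {
#     // device check + optional device guard
#     return at::native::foo(self);                    // ANONYMOUS_DEFINITION
#   }
#   } // anonymous namespace
#
#   m.impl("foo", TORCH_FN(wrapper_CPU_foo));          // REGISTRATION
#   // (typically spliced into a TORCH_LIBRARY_IMPL block elsewhere)
#
#   namespace at { namespace cpu {
#   at::Tensor foo(const at::Tensor & self) {
#     return wrapper_CPU_foo(self);                    // NAMESPACED_DEFINITION
#   }
#   }} // namespace at::cpu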
@dataclass(frozen=True)
class RegisterDispatchKey:
    backend_index: BackendIndex

    target: Literal[
        Target.ANONYMOUS_DEFINITION,
        Target.NAMESPACED_DEFINITION,
        Target.NAMESPACED_DECLARATION,
        Target.REGISTRATION,
    ]

    # Selector object to determine which operators to generate
    # registration code for.
    selector: SelectiveBuilder

    # Whether or not we are actually code-genning for ROCm
    rocm: bool

    # Whether or not to generate symint registrations.  External users
    # of codegen who don't care about symints can set this to false to get
    # non-SymInt codegen
    symint: bool

    # The class that all unstructured native functions live under. This is used to improve
    # compiler error messages when a kernel writer adds a native function with the wrong signature.
    # This is only used in unstructured kernels, since structured kernels already live in a class.
    # Finally, this field is currently Optional because it is only used by external backends.
    # It would be nice if we could add the same logic to in-tree kernels too, but that requires updating
    # all of the existing kernel signatures scattered across aten/src/ATen/native.
    class_method_name: str | None

    # Only set to true in lightweight dispatch. If lightweight dispatch is enabled we are registering
    # operators into the JIT op registry, thus we need to avoid generating code to register into the dispatcher.
    skip_dispatcher_op_registration: bool

    @staticmethod
    def gen_device_check(
        type: DeviceCheckType, args: list[Argument], method_name: str
    ) -> str:
        if type == DeviceCheckType.NoCheck:
            return "  // No device check\n"

        device_check = "std::optional<Device> common_device = std::nullopt;\n"
        device_check += "(void)common_device; // Suppress unused variable warning\n"
        for arg in args:
            # Only tensor like arguments are eligible
            if arg.type.is_tensor_like():
                device_check += f"""
  c10::impl::check_and_update_common_device(common_device, {arg.name}, "{method_name}", "{arg.name}");"""
        return device_check

    @method_with_native_function
    def __call__(self, f: NativeFunctionsGroup | NativeFunction) -> list[str]:
        if isinstance(f, NativeFunctionsGroup):
            g: NativeFunctionsGroup = f
            # Note: We call gen_structured() if the operator is marked structured, regardless of the backend.
            # gen_structured() has special logic to handle auto-generated kernels.
            if g.structured:
                return self.gen_structured(g)
            else:
                return list(
                    mapMaybe(lambda f: self.gen_unstructured(f, g), g.functions())
                )
        elif isinstance(f, NativeFunction):
            r = self.gen_unstructured(f)
            return [] if r is None else [r]
        else:
            assert_never(f)

    def wrapper_kernel_sig(
        self, f: NativeFunction
    ) -> NativeSignature | DispatcherSignature:
        # The prefix is just to ensure uniqueness. The Dispatcher API doesn't guarantee unique kernel names.
        return DispatcherSignature.from_schema(
            f.func,
            prefix=f"wrapper_{self.backend_index.dispatch_key}_{f.func.name.overload_name}_",
            symint=self.symint,
        )

    def gen_out_inplace_wrapper(
        self, f: NativeFunction, g: NativeFunctionsGroup | None
    ) -> str | None:
        if g is None:
            return None
        k = f.func.kind()
        if k is SchemaKind.inplace:
            copy_op = "at::_copy_from"
        elif k is SchemaKind.out:
            copy_op = "at::_copy_from_and_resize"
        else:
            raise AssertionError("gen_out_inplace_wrapper called on a functional op")

        sig = self.wrapper_kernel_sig(f)
        name = sig.name()

        func_res = f"{name}_tmp"
        return_names = cpp.return_names(f)
        if len(return_names) > 1:
            updates = "\n  ".join(
                f"{copy_op}(std::get<{i}>({func_res}), {ret_name});"
                for i, ret_name in enumerate(return_names)
            )
            returns = f'{sig.returns_type().cpp_type()}({", ".join(return_names)})'
        elif len(return_names) == 1:
            ret_name = return_names[0]
            updates = f"{copy_op}({func_res}, {ret_name});"
            returns = ret_name
        else:
            assert len(f.func.arguments.out) == 1
            returns = ""
            out_arg = f.func.arguments.out[0]
            if out_arg.type.is_list_like():
                updates = f"""\
    for (int64_t i = 0; i < {func_res}.size(); ++i) {{
        {copy_op}({func_res}[i], {out_arg.name}[i]);
    }}"""
            else:
                updates = f"{copy_op}({func_res}, {out_arg.name});"

        functional_sig = self.wrapper_kernel_sig(g.functional)
        wrapper_name = sig.name()

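        # Illustrative shape of the emitted wrapper (hypothetical out= op "foo"
        # whose backend only provides the functional kernel):
        #
        #   at::Tensor & wrapper_<key>_foo_out(..., at::Tensor & out) {
        #     auto tmp = wrapper_<key>_foo(...);
        #     at::_copy_from_and_resize(tmp, out);
        #     return out;
        #   }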
        return f"""\
{sig.defn(name=wrapper_name)} {{
  auto {func_res} = {functional_sig.name()}({", ".join(e.expr for e in translate(sig.arguments(), functional_sig.arguments()))});
  {updates}
  return {returns};
}}
"""

    def gen_structured(self, g: NativeFunctionsGroup) -> list[str]:
        metadata = self.backend_index.get_kernel(g)
        if self.backend_index.dispatch_key == DispatchKey.Meta:
            assert not self.backend_index.has_kernel(g.out), (
                "Do not explicitly specify Meta dispatch key on structured "
                "functions, they will be automatically generated for you"
            )
        elif (
            self.backend_index.dispatch_key
            == DispatchKey.CompositeExplicitAutogradNonFunctional
        ):
            assert not self.backend_index.has_kernel(g.out), (
                "Do not explicitly specify CompositeExplicitAutograd dispatch key on structured "
                "functions, they will be automatically generated for you"
            )
        elif metadata is None or not metadata.structured:
            return list(mapMaybe(lambda f: self.gen_unstructured(f, g), g.functions()))
        structured_gen = StructuredRegisterDispatchKey(
            self.backend_index,
            self.target,
            self.selector,
            self.rocm,
            self.symint,
            self.class_method_name,
            self.skip_dispatcher_op_registration,
            g,
        )
        return list(mapMaybe(structured_gen.gen_one, g.functions()))

    def gen_unstructured(
        self, f: NativeFunction, g: NativeFunctionsGroup | None = None
    ) -> str | None:
        with native_function_manager(f):
            inplace_meta = False
            gets_out_inplace_wrapper = False
            if not self.backend_index.has_kernel(f):
                if (
                    self.backend_index.dispatch_key == DispatchKey.Meta
                    and f.func.kind() is SchemaKind.inplace
                    and
                    # Defer to composites for meta implementation
                    not f.has_composite_kernel
                    and
                    # Inplace list operations are not supported
                    len(f.func.returns) == 1
                ):
                    inplace_meta = True
                elif (
                    not self.backend_index.use_out_as_primary
                    and g is not None
                    and gets_generated_out_inplace_wrapper(f, g, self.backend_index)
                ):
                    # We want to generate inplace/out wrappers for ops that don't
                    # have a kernel registered for this backend.
                    gets_out_inplace_wrapper = True
                else:
                    return None
            if f.manual_kernel_registration:
                return None

            if (
                self.target is Target.REGISTRATION
                and not self.selector.is_native_function_selected(f)
            ):
                return None

            sig = self.wrapper_kernel_sig(f)

            name = sig.name()
            returns_type = sig.returns_type().cpp_type()
            args = sig.arguments()
            args_str = ", ".join(a.defn() for a in args)

            # See Note [Direct dispatch bindings]
            cpp_sig_group = CppSignatureGroup.from_native_function(
                f, method=False, fallback_binding=False
            )

            # TODO: dedupe this with the structured codegen
            if self.target is Target.NAMESPACED_DECLARATION:
                result = ""
                for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
                    result += f"TORCH_API {cpp_sig.decl()};\n"
                return result
            elif self.target is Target.NAMESPACED_DEFINITION:

                def generate_defn(cpp_sig: CppSignature) -> str:
                    return f"""
{cpp_sig.defn()} {{
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""

                result = ""
                for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
                    result += generate_defn(cpp_sig)
                return result

            elif self.target is Target.ANONYMOUS_DEFINITION:
                # short circuit for inplace_meta
                if inplace_meta:
                    assert f.func.arguments.self_arg is not None
                    self_arg_name = f.func.arguments.self_arg.argument.name
                    # TODO: handle in place on tensor list
                    return f"""
{returns_type} {name}({args_str}) {{
  TORCH_CHECK_NOT_IMPLEMENTED({self_arg_name}.is_meta(),
    "Cannot inplace into non-meta tensor with meta tensor argument");
  return {self_arg_name};
}}
"""

                # short circuit for generated inplace/out wrappers
                if gets_out_inplace_wrapper:
                    return self.gen_out_inplace_wrapper(f, g)

                metadata = self.backend_index.get_kernel(f)
                if metadata is None:
                    return None
                if self.class_method_name is None:
                    impl_name = f"{metadata.cpp_namespace}::{metadata.kernel}"
                else:
                    impl_name = f"{metadata.cpp_namespace}::{self.class_method_name}::{metadata.kernel}"

                kernel_sig = kernel_signature(f, self.backend_index)

                args_exprs_str = ", ".join(
                    e.expr
                    for e in translate(
                        sig.arguments(), kernel_sig.arguments(), method=False
                    )
                )

                device_check = "  // No device check\n"
                # Backends that require device guards presumably also require device checks.
                if self.backend_index.device_guard:
                    device_check_args = itertools.chain(
                        f.func.arguments.out, f.func.arguments.flat_positional
                    )
                    device_check = RegisterDispatchKey.gen_device_check(
                        f.device_check, list(device_check_args), name
                    )

                device_guard = "// DeviceGuard omitted"  # default
                if f.device_guard and self.backend_index.device_guard:
                    has_tensor_options = any(
                        isinstance(a, TensorOptionsArguments)
                        for a in f.func.arguments.non_out
                    )
                    if has_tensor_options:
                        # kernel is creating a tensor
                        device_guard = """
  const DeviceGuard device_guard(device_or_default(device));"""

                        # CUDA requires special handling
                        if is_cuda_dispatch_key(self.backend_index.dispatch_key):
                            device_guard = (
                                f"globalContext().lazyInitCUDA();\n{device_guard}"
                            )
                    else:
                        # kernel is operating on existing tensors

                        # There is a precedence order for which argument we use to do the
                        # device guard.  This describes that order.
                        self_arg = (
                            [f.func.arguments.self_arg.argument]
                            if f.func.arguments.self_arg is not None
                            else []
                        )
                        candidate_args = itertools.chain(
                            self_arg,
                            f.func.arguments.out,
                            f.func.arguments.flat_positional,
                        )

                        # Only tensor like arguments are eligible
                        device_of = next(
                            (
                                f"{a.name}"
                                for a in candidate_args
                                if a.type.is_tensor_like()
                            ),
                            None,
                        )
                        if device_of is not None:
                            device_guard = f"const OptionalDeviceGuard device_guard(device_of({device_of}));"

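                # When enabled, {device_check} expands to something like (names
                # illustrative):
                #   std::optional<Device> common_device = std::nullopt;
                #   (void)common_device;
                #   c10::impl::check_and_update_common_device(common_device, self, "<wrapper name>", "self");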
                return f"""\
namespace {{

{returns_type} {name}({args_str}) {{
  {device_check}

  {device_guard}
  return {impl_name}({args_exprs_str});
}}

}} // anonymous namespace
"""

            elif self.target is Target.REGISTRATION:
                if f.manual_kernel_registration or self.skip_dispatcher_op_registration:
                    return None
                else:
                    payload = f"TORCH_FN({name})"
                    return f'm.impl("{f.func.name}",\n{payload});\n'
            else:
                assert_never(self.target)


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                           STRUCTURED
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


@dataclass(frozen=True)
class StructuredRegisterDispatchKey(RegisterDispatchKey):
    g: NativeFunctionsGroup

    def gen_class_set_output_functions(
        self, k: SchemaKind, parent_class: str, generate_super: bool
    ) -> str:
        if generate_super:
            set_output_super = f"{parent_class}::set_output_raw_strided(output_idx, sizes, strides, options, names);"
        else:
            set_output_super = ""

        def gen_set_output_function(name: str, maybe_create_proxy: bool) -> str:
            return f"""
void set_output_{name}(
    int64_t output_idx, IntArrayRef sizes, IntArrayRef strides,
    TensorOptions options, DimnameList names
) override {{
{textwrap.indent(self.gen_class_set_output_body(k, maybe_create_proxy), "    ")}
    if (!names.empty()) {{
      namedinference::propagate_names(outputs_[output_idx], names);
    }}
    // super must happen after, so that downstream can use maybe_get_output
    // to retrieve the output
{textwrap.indent(set_output_super, "    ")}
}}
"""

        return f"""
{gen_set_output_function("strided", maybe_create_proxy=True)}
{gen_set_output_function("raw_strided", maybe_create_proxy=False)}
"""

    def gen_class_set_output_body(self, k: SchemaKind, maybe_create_proxy: bool) -> str:
        if self.backend_index.dispatch_key in [
            DispatchKey.CUDA,
            DispatchKey.MPS,
            DispatchKey.CompositeExplicitAutogradNonFunctional,
        ]:
            maybe_set_guard = """
auto current_device = guard_.current_device();
if (C10_UNLIKELY(current_device.has_value())) {
  TORCH_INTERNAL_ASSERT(*current_device == options.device(),
    "structured kernels don't support multi-device outputs");
} else {
  guard_.reset_device(options.device());
}
"""
            maybe_set_guard_line = maybe_set_guard + "\n"
        else:
            maybe_set_guard_line = maybe_set_guard = ""

        if maybe_create_proxy:
            create_proxy = """
auto maybe_proxy = maybe_create_proxy(out, sizes, strides, options);
if (C10_UNLIKELY(maybe_proxy.has_value())) {
    proxy_outputs_[output_idx] = std::move(maybe_proxy).value();
}
"""
        else:
            create_proxy = ""

        if k is SchemaKind.functional:
            assert self.backend_index.dispatch_key in (
                DispatchKey.Meta,
                DispatchKey.CPU,
                DispatchKey.CUDA,
                DispatchKey.MPS,
                DispatchKey.XPU,
                DispatchKey.CompositeExplicitAutogradNonFunctional,
            )
            return f"""{maybe_set_guard_line}
outputs_[output_idx] = create_out(sizes, strides, options);"""
        elif k is SchemaKind.inplace:
            return f"""{maybe_set_guard_line}
const auto& out = outputs_[output_idx].get();
check_inplace(out, sizes, options);
{create_proxy}"""
        elif k is SchemaKind.out:
            return f"""{maybe_set_guard_line}
const auto& out = outputs_[output_idx].get();
resize_out(out, sizes, strides, options);
{create_proxy}"""
        elif k is SchemaKind.mutable or k is SchemaKind.scratch:
            raise AssertionError(
                f"{k} structured operators are currently not supported"
            )
        else:
            assert_never(k)

    # returns the definition of a ctor, as well as how to construct
    # this class to a variable named op
    def gen_class_ctor(self, k: SchemaKind, class_name: str, returns: int) -> str:
        if k is SchemaKind.functional:
            return ""
        elif k is SchemaKind.inplace:
            # TODO: Make sure out argument is guaranteed to be self
            return f"{class_name}(Tensor& self) : outputs_{{std::ref(self)}} {{}}"
        elif k is SchemaKind.out:
            out_args = ", ".join(f"Tensor& out{i}" for i in range(returns))
            out_refs = ", ".join(f"std::ref(out{i})" for i in range(returns))
            return f"{class_name}({out_args}) : outputs_{{ {out_refs} }} {{}}"
        elif k is SchemaKind.mutable or k is SchemaKind.scratch:
            raise AssertionError(
                f"{k} structured operators are currently not supported"
            )
        else:
            assert_never(k)

    def gen_class(
        self,
        f: NativeFunction,
        k: SchemaKind,
        *,
        class_name: str,
        parent_class: str,
        generate_super: bool,
    ) -> str:
        if k is SchemaKind.functional:
            output_type = "Tensor"
            output_value = "outputs_[output_idx]"
            proxy_field = ""
        elif k is SchemaKind.inplace:
            output_type = "std::reference_wrapper<Tensor>"
            output_value = "proxy_outputs_[output_idx].has_value() ? *proxy_outputs_[output_idx] : outputs_[output_idx].get()"
            proxy_field = f"std::array<::std::optional<Tensor>, {len(f.func.returns)}> proxy_outputs_;"
        elif k is SchemaKind.out:
            output_type = "std::reference_wrapper<Tensor>"
            output_value = "proxy_outputs_[output_idx].has_value() ? *proxy_outputs_[output_idx] : outputs_[output_idx].get()"
            proxy_field = f"std::array<::std::optional<Tensor>, {len(f.func.returns)}> proxy_outputs_;"
        else:
            raise RuntimeError(f"Unsupported SchemaKind {k}")

        if self.backend_index.dispatch_key == DispatchKey.CUDA:
            if self.rocm:
                guard_field = "c10::hip::OptionalHIPGuardMasqueradingAsCUDA guard_;"
            else:
                guard_field = "c10::cuda::OptionalCUDAGuard guard_;"
        elif (
            self.backend_index.dispatch_key
            == DispatchKey.CompositeExplicitAutogradNonFunctional
        ):
            guard_field = "c10::OptionalDeviceGuard guard_;"
        elif self.backend_index.dispatch_key == DispatchKey.MPS:
            # TODO: Move to OptionalMPSGuard.
            guard_field = "c10::OptionalDeviceGuard guard_;"
        else:
            guard_field = ""

        indent = " " * 4
        class_ctor_str = self.gen_class_ctor(k, class_name, len(f.func.returns))
        lines = (
            f"struct {class_name} final : public {parent_class} {{",
            f"{textwrap.indent(class_ctor_str, indent)}",
            f"{textwrap.indent(self.gen_class_set_output_functions(k, parent_class, generate_super), indent)}",
            "    const Tensor& maybe_get_output(int64_t output_idx) override {",
            f"      return {output_value};\n",  # type: ignore[possibly-undefined]  # TODO: audit
            "    }",
            # type: ignore[possibly-undefined]  # TODO: audit
            f"    std::array<{output_type}, {len(f.func.returns)}> outputs_;",
            f"{textwrap.indent(proxy_field, indent)}",  # type: ignore[possibly-undefined]  # TODO: audit
            f"{textwrap.indent(guard_field, indent)}",
            "};",
        )
        return "\n".join(line for line in lines if line)

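    # Rough shape of the class emitted by gen_class above for an out= overload
    # with a single output (illustrative only; see the f-strings above for the
    # authoritative pieces):
    #
    #   struct structured_<kernel>_out final : public <parent_class> {
    #     structured_<kernel>_out(Tensor& out0) : outputs_{ std::ref(out0) } {}
    #     void set_output_strided(...) override { /* resize_out + maybe_create_proxy */ }
    #     void set_output_raw_strided(...) override { /* resize_out */ }
    #     const Tensor& maybe_get_output(int64_t output_idx) override { ... }
    #     std::array<std::reference_wrapper<Tensor>, 1> outputs_;
    #     std::array<::std::optional<Tensor>, 1> proxy_outputs_;
    #     c10::cuda::OptionalCUDAGuard guard_;  // backend dependent; may be absent
    #   };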
    @method_with_native_function
    def gen_one(self, f: NativeFunction) -> str | None:
        assert not f.manual_kernel_registration

        if (
            self.target is Target.REGISTRATION
            and not self.selector.is_native_function_selected(f)
        ):
            return None

        # TODO: Now, there is something interesting going on here.  In the code below,
        # we generate CompositeExplicitAutogradNonFunctional implementations of functional and inplace
        # based on the out implementation.  But in fact, out is definable by
        # functional too (just not very efficiently), and this is honestly the
        # MORE likely situation for a backend implementor.  How do we pick?
        # Well, taking a page from Haskell type classes and default methods,
        # we could conceivably register a circular definition (out in terms
        # of functional, and functional in terms of out) and just require
        # someone to implement one or the other.  We'd have to do a little bit
        # of work to not register one of these "weak" definitions unless there
        # is a strong definition somewhere in the DAG!  So it's not implemented yet.
        if (
            self.backend_index.dispatch_key
            == DispatchKey.CompositeExplicitAutogradNonFunctional
            and f.func.kind() is SchemaKind.out
        ):
            # Never generate a default implementation for out, that's what you
            # have to define as a backend implementor
            return None

        # Note [Direct dispatch bindings]
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Signature of the non-dispatched function we'll expose in a header
        # (e.g., at::cpu::add).  We don't generate methods (TODO: do this
        # when CPUTensor class is a thing); nor do we generate fallback
        # bindings for manual_cpp_binding functions.
        cpp_sig_group = CppSignatureGroup.from_native_function(
            f, method=False, fallback_binding=False
        )

        # Signature of the wrapper function we'll register to the dispatcher
        kern = self.backend_index.get_kernel(f)
        sig = NativeSignature(
            f.func,
            prefix=f"wrapper_{self.backend_index.dispatch_key}_",
            symint=kern is not None and kern.supports_symint(),
        )

        if self.target is Target.NAMESPACED_DECLARATION:
            result = ""
            for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
                result += f"TORCH_API {cpp_sig.decl()};\n"
            return result

        elif self.target is Target.NAMESPACED_DEFINITION:

            def generate_defn(cpp_sig: CppSignature) -> str:
                return f"""
{cpp_sig.defn()} {{
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""

            result = ""
            for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
                result += generate_defn(cpp_sig)
            return result

        elif self.target is Target.ANONYMOUS_DEFINITION:
            k = f.func.kind()

            # Construct the body of the wrapper function with signature sig
            sig_body = []
            # We'll use context to keep track of any variables we've brought
            # into scope while generating code
            context: list[Binding | Expr] = list(sig.arguments())

            # Initialize the class corresponding to this structured
            # operator; feeding it the output argument(s) if it is known
            if self.backend_index.dispatch_key is DispatchKey.Meta:
                class_name = f"structured_{meta.name(self.g)}_meta_{k.name}"
                parent_class = f"at::meta::structured_{meta.name(self.g)}"
            elif (
                self.backend_index.dispatch_key
                is DispatchKey.CompositeExplicitAutogradNonFunctional
            ):
                # TODO: dedup this branch
                class_name = f"structured_{meta.name(self.g)}_default_backend_{k.name}"
                parent_class = f"at::meta::structured_{meta.name(self.g)}"
            else:
                metadata = self.backend_index.get_kernel(self.g)
                assert metadata is not None
                class_name = f"structured_{metadata.kernel}_{k.name}"
                parent_class = f"{metadata.cpp_namespace}::structured_{metadata.kernel}"

            if self.backend_index.device_guard:
                device_check_args = itertools.chain(
                    f.func.arguments.out, f.func.arguments.flat_positional
                )
                sig_body.append(
                    RegisterDispatchKey.gen_device_check(
                        f.device_check, list(device_check_args), sig.name()
                    )
                )

            if k is SchemaKind.functional:
                sig_body.append(f"{class_name} op;")
            elif k is SchemaKind.inplace:
                sig_body.append(f"{class_name} op(self);")
            elif k is SchemaKind.out:
                out_args_str = ", ".join(a.name for a in f.func.arguments.out)
                sig_body.append(f"{class_name} op({out_args_str});")

            # Translate the input native arguments into structured
            # arguments for the meta call
            meta_exprs = ", ".join(
                e.expr
                for e in translate(
                    context, structured.meta_arguments(self.g), method=False
                )
            )

            if self.g.out.precomputed:
                # If this function group has precomputed elements, the meta function
                # returns a struct containing them which must be saved so that it
                # can be unpacked when generating code to call the impl.
                sig_body.append(f"auto precompute = op.meta({meta_exprs});")

                # Put all of the contents of the precompute struct into the context
                # so that translate will be able to return the correct args for the
                # call to the impl.
                precomputed_values = [
                    *self.g.out.precomputed.replace.values(),
                    self.g.out.precomputed.add,
                ]
                for precomputed_elems in precomputed_values:
                    for arg in precomputed_elems:
                        context.append(
                            Expr(
                                expr=f"precompute.{arg.name}",
                                type=structured.argument_type(arg, binds=arg.name),
                            )
                        )

                # Add a use of the precompute struct so FB internal compilers don't
                # complain that there is an unused variable.
                sig_body.append("(void)precompute;")
            else:
                sig_body.append(f"op.meta({meta_exprs});")
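            # At this point sig_body reads roughly (illustrative, precomputed
            # case for an out= overload):
            #   structured_<kernel>_out op(out);
            #   auto precompute = op.meta(self, ...);
            #   (void)precompute;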

            # After running meta, op.outputs_ is guaranteed to be valid;
            # add it to the context
            out_args = structured.out_arguments(self.g)
            for i, out_arg in enumerate(out_args):
                assert ConstRefCType(BaseCType(tensorT)) == out_arg.nctype.type

                if k is SchemaKind.out:
                    expr = f"op.maybe_get_output({i})"
                else:
                    expr = f"op.outputs_[{i}]"

                context.append(
                    Expr(
                        expr=expr,
                        # TODO: Stop hardcoding that the output type is a Tensor.  Note
                        # that for the codegen here this is fine because outputs_ is
                        # hardcoded to be tensor already
                        type=NamedCType(
                            out_arg.nctype.name, MutRefCType(BaseCType(tensorT))
                        ),
                    )
                )

            # With the expanded context, do the impl call (if not a meta
            # function)
            if (
                self.backend_index.dispatch_key
                == DispatchKey.CompositeExplicitAutogradNonFunctional
            ):
                # TODO: https://github.com/pytorch/pytorch/issues/53023
                out_sig_group = CppSignatureGroup.from_native_function(
                    self.g.out, method=False, fallback_binding=f.manual_cpp_binding
                )
                out_sig = out_sig_group.most_faithful_signature()
                api_name = out_sig.name()
                out_exprs = ", ".join(
                    e.expr
                    for e in translate(context, out_sig.arguments(), method=False)
                )
                # TODO: I think this means structured won't work with method
                # only functions (but maybe you're saved by faithful? iunno.)
                # NB: Originally I wrote this as an at::redispatch call, but
                # I got in trouble because that meant I needed a DispatchKeySet
                # in the wrapper function, which meant I needed a DispatchKeySet
                # in the DispatchKeyFunctions declarations, but the defined API
                # there does NOT permit a dispatch key set.  I think you can
                # probably unwind this by calling some function to do the TLS
                # fetch and get the DispatchKeySet when you don't have it, but
                # I didn't do it for this version
                sig_body.append(f"at::{api_name}({out_exprs});")
            elif self.backend_index.dispatch_key != DispatchKey.Meta:
                impl_exprs = ", ".join(
                    e.expr
                    for e in translate(
                        context, structured.impl_arguments(self.g), method=False
                    )
                )
                sig_body.append(f"op.impl({impl_exprs});")

            # Go over each output, and check if there is a proxy created for it.
            # If so, copy it over to the original output.
            if k is SchemaKind.out or k is SchemaKind.inplace:
                for i in range(len(f.func.returns)):
                    sig_body.append(
                        f"if (op.proxy_outputs_[{i}].has_value()) op.outputs_[{i}].get().copy_(*op.proxy_outputs_[{i}]);"
                    )

            # Destructively return the final tensors
            # TODO: Do this in translate instead
            if k is SchemaKind.functional:
                if len(f.func.returns) == 1:
                    ret_expr = "std::move(op.outputs_[0])"  # small optimization
                else:
                    moved = ", ".join(
                        f"std::move(op.outputs_[{i}])"
                        for i in range(len(f.func.returns))
                    )
                    ret_expr = f"std::make_tuple({moved})"
            elif k is SchemaKind.inplace:
                ret_expr = "self"
            elif k is SchemaKind.out:
                if len(f.func.returns) == 1:
                    ret_expr = f.func.arguments.out[0].name
                else:
                    refs = ", ".join(a.name for a in f.func.arguments.out)
                    ret_expr = f"std::forward_as_tuple({refs})"
            sig_body.append(f"return {ret_expr};")  # type: ignore[possibly-undefined]  # TODO: audit

            sig_body_str = "\n".join(sig_body)

            # For an overview of what this template code looks like, see
            # https://github.com/pytorch/rfcs/pull/9
            return f"""\
{self.gen_class(
    f, k,
    class_name=class_name,
    parent_class=parent_class,
    generate_super=self.g.out.structured_inherits is not None
)}

{sig.defn()} {{
{sig_body_str}
}}
"""

        elif self.target is Target.REGISTRATION:
            return f'm.impl("{f.func.name}", TORCH_FN({sig.name()}));'
        else:
            assert_never(self.target)
            # Silence mypy's "Missing return statement" error
            return None