from __future__ import annotations

import itertools
import textwrap
from dataclasses import dataclass
from typing import Literal, TYPE_CHECKING

import torchgen.api.cpp as cpp
import torchgen.api.meta as meta
import torchgen.api.structured as structured
from torchgen.api.translate import translate
from torchgen.api.types import (
    BaseCType,
    Binding,
    ConstRefCType,
    CppSignature,
    CppSignatureGroup,
    DispatcherSignature,
    Expr,
    kernel_signature,
    MutRefCType,
    NamedCType,
    NativeSignature,
    tensorT,
)
from torchgen.context import method_with_native_function, native_function_manager
from torchgen.model import (
    Argument,
    BackendIndex,
    DeviceCheckType,
    DispatchKey,
    gets_generated_out_inplace_wrapper,
    is_cuda_dispatch_key,
    NativeFunction,
    NativeFunctionsGroup,
    SchemaKind,
    TensorOptionsArguments,
)
from torchgen.utils import assert_never, mapMaybe, Target


if TYPE_CHECKING:
    from torchgen.selective_build.selector import SelectiveBuilder


def gen_registration_headers(
    backend_index: BackendIndex,
    per_operator_headers: bool,
    rocm: bool,
) -> list[str]:
    if per_operator_headers:
        headers = ["#include <ATen/ops/as_strided_native.h>"]
    else:
        headers = ["#include <ATen/NativeFunctions.h>"]

    if backend_index.dispatch_key in (DispatchKey.CPU, DispatchKey.Meta):
        headers.append("#include <ATen/EmptyTensor.h>")
    elif backend_index.dispatch_key == DispatchKey.CUDA:
        if rocm:
            headers.append("#include <ATen/hip/EmptyTensor.h>")
        else:
            headers.append("#include <ATen/cuda/EmptyTensor.h>")
    elif backend_index.dispatch_key == DispatchKey.MPS:
        headers.append("#include <ATen/mps/EmptyTensor.h>")
    elif backend_index.dispatch_key == DispatchKey.XPU:
        # XPU-specific: this header resides in third_party/torch-xpu-ops
        headers.append("#include <ATen/xpu/EmptyTensor.h>")
    elif per_operator_headers:
        headers += [
            "#include <ATen/ops/empty.h>",
            "#include <ATen/ops/empty_strided.h>",
            "#include <ATen/ops/_copy_from_and_resize.h>",
            "#include <ATen/ops/_copy_from.h>",
        ]
    else:
        headers.append("#include <ATen/Functions.h>")

    headers.append("#include <c10/macros/Macros.h>")
    return headers
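
# For illustration (not generated output verbatim): with per_operator_headers=True
# and a CUDA (non-ROCm) backend_index, the list above works out to:
#
#   #include <ATen/ops/as_strided_native.h>
#   #include <ATen/cuda/EmptyTensor.h>
#   #include <c10/macros/Macros.h>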


def gen_empty_impl_names(
    backend_index: BackendIndex,
) -> tuple[str | None, str | None]:
    empty_impl = None
    empty_strided_impl = None

    if backend_index.dispatch_key in (
        DispatchKey.Meta,
        DispatchKey.CPU,
        DispatchKey.CUDA,
        DispatchKey.MPS,
        DispatchKey.XPU,
    ):
        dispatch = str(backend_index.dispatch_key).lower()
        empty_impl = f"at::detail::empty_{dispatch}"
        empty_strided_impl = f"at::detail::empty_strided_{dispatch}"
    elif backend_index.dispatch_key in (
        DispatchKey.CompositeExplicitAutogradNonFunctional,
        DispatchKey.QuantizedCPU,
        DispatchKey.QuantizedCUDA,
    ):
        empty_impl = "at::empty"
        empty_strided_impl = "at::empty_strided"

    return empty_impl, empty_strided_impl
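
# For example (illustrative only): a CPU backend_index yields
# ("at::detail::empty_cpu", "at::detail::empty_strided_cpu"), while
# QuantizedCPU falls back to ("at::empty", "at::empty_strided").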


def gen_create_out_helper(backend_index: BackendIndex) -> list[str]:
    if backend_index.dispatch_key == DispatchKey.Meta:
        empty_options = "options.device(at::kMeta)"
    else:
        empty_options = "options"

    empty_impl, empty_strided_impl = gen_empty_impl_names(backend_index)
    if empty_impl is None:
        return []

    return [
        f"""
Tensor create_out(IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {{
  if (strides.empty()) {{
      return {empty_impl}(sizes, {empty_options});
  }} else {{
      return {empty_strided_impl}(sizes, strides, {empty_options});
  }}
}}
"""
    ]
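
# A sketch of the C++ this emits for the CPU key (assuming the names from
# gen_empty_impl_names above):
#
#   Tensor create_out(IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {
#     if (strides.empty()) {
#         return at::detail::empty_cpu(sizes, options);
#     } else {
#         return at::detail::empty_strided_cpu(sizes, strides, options);
#     }
#   }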


def gen_maybe_create_proxy_helper(backend_index: BackendIndex) -> list[str]:
    _, empty_strided_impl = gen_empty_impl_names(backend_index)
    return (
        []
        if empty_strided_impl is None
        else [
            f"""
std::optional<Tensor> maybe_create_proxy(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {{
  if (out.strides() != strides) {{
    return {empty_strided_impl}(sizes, strides, options);
  }}
  return std::nullopt;
}}
"""
        ]
    )


def gen_resize_out_helper(backend_index: BackendIndex) -> list[str]:
    if backend_index.dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional:
        # The function isn't used by this key (since only functional ops have a kernel for this key),
        # so we omit it to avoid a defined-but-not-used error.
        return []
    return [
        """
void resize_out(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {
  TORCH_CHECK(options.dtype() == out.dtype(),
      "Expected out tensor to have dtype ", options.dtype(), ", but got ", out.dtype(), " instead");
  TORCH_CHECK(options.device() == out.device(),
      "Expected out tensor to have device ", options.device(), ", but got ", out.device(), " instead");
  const bool resized = at::native::resize_output(out, sizes);
  // Only restride if a resize occurred; otherwise we ignore the (advisory)
  // strides from the meta function and directly use the output tensor's
  // preexisting strides
  if (resized) {
    if (!strides.empty()) {
      TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value());
      // TODO: avoid the redispatch here
      out.as_strided_(sizes, strides);
    } else if (options.memory_format_opt().has_value()) {
      out.unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt());
    }
  }
}
"""
    ]


def gen_check_inplace_helper(backend_index: BackendIndex) -> list[str]:
    return [
        """
void check_inplace(const Tensor &self, IntArrayRef sizes, const TensorOptions &options) {
  // These checks are needed on those operators that:
  //   1) don't use 'TensorIterator' (e.g. 'addmm' and 'baddbmm')
  //   2) have particular typing rules (e.g. 'cumsum' and 'cumprod')
  // For other operators (e.g. 'add'), 'TensorIterator' already checks
  // these things separately.
  TORCH_CHECK(options.dtype() == self.dtype(),
      "Bad in-place call: ",
      "input tensor dtype ", self.dtype(), " and output tensor dtype ", options.dtype(), " should match");
  TORCH_CHECK(options.device() == self.device(),
      "Bad in-place call: ",
      "input tensor device ", self.device(), " and output tensor device ", options.device(), " should match");
  TORCH_CHECK(sizes == self.sizes(),
      "Bad in-place call: ",
      "input tensor size ", self.sizes(), " and output tensor size ", sizes, " should match");
}
"""
    ]


def gen_registration_helpers(backend_index: BackendIndex) -> list[str]:
    return [
        'C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-function")',
        *gen_create_out_helper(backend_index),
        *gen_resize_out_helper(backend_index),
        *gen_check_inplace_helper(backend_index),
        *gen_maybe_create_proxy_helper(backend_index),
        "C10_DIAGNOSTIC_POP()",
    ]
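
# A hypothetical caller would splice these helpers verbatim into the generated
# Register<Key>.cpp, e.g.:
#
#   helper_defns = "\n".join(gen_registration_helpers(backend_index))
#
# The C10_DIAGNOSTIC push/pop pair keeps -Wunused-function quiet for helpers
# that a particular backend's kernels never end up calling.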


# Generates Register{dispatch}.cpp (e.g., RegisterCPU.cpp).
#
#   - The primary function of this file is to register all of the
#     implementations for the given dispatch key to the dispatcher,
#     so they are available for use in PyTorch.  If dispatch is
#     None, we generate schema (def) registrations and catchall
#     registrations.
#   - The secondary function of this file is to generate a wrapper
#     around functions.  In CPUType these wrappers do nothing
#     (and should be removed), but in other cases they handle
#     DeviceGuard.  A small extra benefit of wrappers is that they
#     are not overloaded, so they can be used in the registration
#     API without having to disambiguate which overload you want
#     (as would be the case if you directly registered native::
#     functions).
#   - The tertiary function of this file is to generate *static*
#     C++ API bindings which can be used to bypass the dispatcher
#     and call kernels directly, with a user-friendly C++-style API.
@dataclass(frozen=True)
class RegisterDispatchKey:
    backend_index: BackendIndex

    target: Literal[
        Target.ANONYMOUS_DEFINITION,
        Target.NAMESPACED_DEFINITION,
        Target.NAMESPACED_DECLARATION,
        Target.REGISTRATION,
    ]

    # Selector object to determine which operators to generate
    # registration code for.
    selector: SelectiveBuilder

    # Whether or not we are actually code-genning for ROCm
    rocm: bool

    # Whether to generate SymInt registrations.  External users of codegen
    # who don't care about SymInts can set this to False to get non-SymInt
    # codegen.
    symint: bool

    # The class that all unstructured native functions live under. This is used to improve
    # compiler error messages when a kernel writer adds a native function with the wrong signature.
    # This is only used in unstructured kernels, since structured kernels already live in a class.
    # Finally, this field is currently Optional because it is only used by external backends.
    # It would be nice if we could add the same logic to in-tree kernels too, but that requires updating
    # all of the existing kernel signatures scattered across aten/src/ATen/native.
    class_method_name: str | None

    # Only set to True in lightweight dispatch. If lightweight dispatch is enabled we register
    # operators into the JIT op registry instead, so we must avoid generating code that registers
    # into the dispatcher.
    skip_dispatcher_op_registration: bool

    @staticmethod
    def gen_device_check(
        type: DeviceCheckType, args: list[Argument], method_name: str
    ) -> str:
        if type == DeviceCheckType.NoCheck:
            return "  // No device check\n"

        device_check = "std::optional<Device> common_device = std::nullopt;\n"
        device_check += "(void)common_device; // Suppress unused variable warning\n"
        for arg in args:
            # Only tensor-like arguments are eligible
            if arg.type.is_tensor_like():
                device_check += f"""
  c10::impl::check_and_update_common_device(common_device, {arg.name}, "{method_name}", "{arg.name}");"""
        return device_check
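
    # Illustrative output (hypothetical op): for args [self, other] and
    # method_name "wrapper_CPU__foo", this emits roughly:
    #
    #   std::optional<Device> common_device = std::nullopt;
    #   (void)common_device; // Suppress unused variable warning
    #
    #     c10::impl::check_and_update_common_device(common_device, self, "wrapper_CPU__foo", "self");
    #     c10::impl::check_and_update_common_device(common_device, other, "wrapper_CPU__foo", "other");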

    @method_with_native_function
    def __call__(self, f: NativeFunctionsGroup | NativeFunction) -> list[str]:
        if isinstance(f, NativeFunctionsGroup):
            g: NativeFunctionsGroup = f
            # Note: We call gen_structured() if the operator is marked structured, regardless of the backend.
            # gen_structured() has special logic to handle auto-generated kernels.
            if g.structured:
                return self.gen_structured(g)
            else:
                return list(
                    mapMaybe(lambda f: self.gen_unstructured(f, g), g.functions())
                )
        elif isinstance(f, NativeFunction):
            r = self.gen_unstructured(f)
            return [] if r is None else [r]
        else:
            assert_never(f)

    def wrapper_kernel_sig(
        self, f: NativeFunction
    ) -> NativeSignature | DispatcherSignature:
        # The prefix is just to ensure uniqueness. The Dispatcher API doesn't guarantee unique kernel names.
        return DispatcherSignature.from_schema(
            f.func,
            prefix=f"wrapper_{self.backend_index.dispatch_key}_{f.func.name.overload_name}_",
            symint=self.symint,
        )
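
    # For illustration (hypothetical schema): for `foo.overload(Tensor self) -> Tensor`
    # under the CPU key, the wrapper comes out named roughly
    # `wrapper_CPU_overload_foo`; folding the overload name into the prefix is
    # what keeps wrappers for different overloads from colliding.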

    def gen_out_inplace_wrapper(
        self, f: NativeFunction, g: NativeFunctionsGroup | None
    ) -> str | None:
        if g is None:
            return None
        k = f.func.kind()
        if k is SchemaKind.inplace:
            copy_op = "at::_copy_from"
        elif k is SchemaKind.out:
            copy_op = "at::_copy_from_and_resize"
        else:
            raise AssertionError("gen_out_inplace_wrapper called on a functional op")

        sig = self.wrapper_kernel_sig(f)
        name = sig.name()

        func_res = f"{name}_tmp"
        return_names = cpp.return_names(f)
        if len(return_names) > 1:
            updates = "\n  ".join(
                f"{copy_op}(std::get<{i}>({func_res}), {ret_name});"
                for i, ret_name in enumerate(return_names)
            )
            returns = f'{sig.returns_type().cpp_type()}({", ".join(return_names)})'
        elif len(return_names) == 1:
            ret_name = return_names[0]
            updates = f"{copy_op}({func_res}, {ret_name});"
            returns = ret_name
        else:
            assert len(f.func.arguments.out) == 1
            returns = ""
            out_arg = f.func.arguments.out[0]
            if out_arg.type.is_list_like():
                updates = f"""\
    for (int64_t i = 0; i < {func_res}.size(); ++i) {{
        {copy_op}({func_res}[i], {out_arg.name}[i]);
    }}"""
            else:
                updates = f"{copy_op}({func_res}, {out_arg.name});"

        functional_sig = self.wrapper_kernel_sig(g.functional)
        wrapper_name = sig.name()

        return f"""\
{sig.defn(name=wrapper_name)} {{
  auto {func_res} = {functional_sig.name()}({", ".join(e.expr for e in translate(sig.arguments(), functional_sig.arguments()))});
  {updates}
  return {returns};
}}
"""

    def gen_structured(self, g: NativeFunctionsGroup) -> list[str]:
        metadata = self.backend_index.get_kernel(g)
        if self.backend_index.dispatch_key == DispatchKey.Meta:
            assert not self.backend_index.has_kernel(g.out), (
                "Do not explicitly specify Meta dispatch key on structured "
                "functions, they will be automatically generated for you"
            )
        elif (
            self.backend_index.dispatch_key
            == DispatchKey.CompositeExplicitAutogradNonFunctional
        ):
            assert not self.backend_index.has_kernel(g.out), (
                "Do not explicitly specify CompositeExplicitAutogradNonFunctional dispatch key on structured "
                "functions, they will be automatically generated for you"
            )
        elif metadata is None or not metadata.structured:
            return list(mapMaybe(lambda f: self.gen_unstructured(f, g), g.functions()))
        structured_gen = StructuredRegisterDispatchKey(
            self.backend_index,
            self.target,
            self.selector,
            self.rocm,
            self.symint,
            self.class_method_name,
            self.skip_dispatcher_op_registration,
            g,
        )
        return list(mapMaybe(structured_gen.gen_one, g.functions()))

    def gen_unstructured(
        self, f: NativeFunction, g: NativeFunctionsGroup | None = None
    ) -> str | None:
        with native_function_manager(f):
            inplace_meta = False
            gets_out_inplace_wrapper = False
            if not self.backend_index.has_kernel(f):
                if (
                    self.backend_index.dispatch_key == DispatchKey.Meta
                    and f.func.kind() is SchemaKind.inplace
                    and
                    # Defer to composites for meta implementation
                    not f.has_composite_kernel
                    and
                    # Inplace list operations are not supported
                    len(f.func.returns) == 1
                ):
                    inplace_meta = True
                elif (
                    not self.backend_index.use_out_as_primary
                    and g is not None
                    and gets_generated_out_inplace_wrapper(f, g, self.backend_index)
                ):
                    # Generate inplace/out wrappers for ops that don't have a kernel on this backend.
                    gets_out_inplace_wrapper = True
                else:
                    return None
            if f.manual_kernel_registration:
                return None

            if (
                self.target is Target.REGISTRATION
                and not self.selector.is_native_function_selected(f)
            ):
                return None

            sig = self.wrapper_kernel_sig(f)

            name = sig.name()
            returns_type = sig.returns_type().cpp_type()
            args = sig.arguments()
            args_str = ", ".join(a.defn() for a in args)

            # See Note [Direct dispatch bindings]
            cpp_sig_group = CppSignatureGroup.from_native_function(
                f, method=False, fallback_binding=False
            )

            # TODO: dedupe this with the structured codegen
            if self.target is Target.NAMESPACED_DECLARATION:
                result = ""
                for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
                    result += f"TORCH_API {cpp_sig.decl()};\n"
                return result
            elif self.target is Target.NAMESPACED_DEFINITION:

                def generate_defn(cpp_sig: CppSignature) -> str:
                    return f"""
{cpp_sig.defn()} {{
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""

                result = ""
                for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
                    result += generate_defn(cpp_sig)
                return result

            elif self.target is Target.ANONYMOUS_DEFINITION:
                # short circuit for inplace_meta
                if inplace_meta:
                    assert f.func.arguments.self_arg is not None
                    self_arg_name = f.func.arguments.self_arg.argument.name
                    # TODO: handle in place on tensor list
                    return f"""
{returns_type} {name}({args_str}) {{
  TORCH_CHECK_NOT_IMPLEMENTED({self_arg_name}.is_meta(),
    "Cannot inplace into non-meta tensor with meta tensor argument");
  return {self_arg_name};
}}
"""

                # short circuit for generated inplace/out wrappers
                if gets_out_inplace_wrapper:
                    return self.gen_out_inplace_wrapper(f, g)

                metadata = self.backend_index.get_kernel(f)
                if metadata is None:
                    return None
                if self.class_method_name is None:
                    impl_name = f"{metadata.cpp_namespace}::{metadata.kernel}"
                else:
                    impl_name = f"{metadata.cpp_namespace}::{self.class_method_name}::{metadata.kernel}"

                kernel_sig = kernel_signature(f, self.backend_index)

                args_exprs_str = ", ".join(
                    e.expr
                    for e in translate(
                        sig.arguments(), kernel_sig.arguments(), method=False
                    )
                )

                device_check = "  // No device check\n"
                # Backends that require device guards presumably also require device checks.
                if self.backend_index.device_guard:
                    device_check_args = itertools.chain(
                        f.func.arguments.out, f.func.arguments.flat_positional
                    )
                    device_check = RegisterDispatchKey.gen_device_check(
                        f.device_check, list(device_check_args), name
                    )
                device_guard = "// DeviceGuard omitted"  # default
                if f.device_guard and self.backend_index.device_guard:
                    has_tensor_options = any(
                        isinstance(a, TensorOptionsArguments)
                        for a in f.func.arguments.non_out
                    )
                    if has_tensor_options:
                        # kernel is creating a tensor
                        device_guard = """
  const DeviceGuard device_guard(device_or_default(device));"""

                        # CUDA requires special handling
                        if is_cuda_dispatch_key(self.backend_index.dispatch_key):
                            device_guard = (
                                f"globalContext().lazyInitCUDA();\n{device_guard}"
                            )
                    else:
                        # kernel is operating on existing tensors

                        # There is a precedence order for which argument we use
                        # for the device guard; the list below encodes it.
                        self_arg = (
                            [f.func.arguments.self_arg.argument]
                            if f.func.arguments.self_arg is not None
                            else []
                        )
                        candidate_args = itertools.chain(
                            self_arg,
                            f.func.arguments.out,
                            f.func.arguments.flat_positional,
                        )

                        # Only tensor-like arguments are eligible
                        device_of = next(
                            (
                                f"{a.name}"
                                for a in candidate_args
                                if a.type.is_tensor_like()
                            ),
                            None,
                        )
                        if device_of is not None:
                            device_guard = f"const OptionalDeviceGuard device_guard(device_of({device_of}));"
                return f"""\
namespace {{

{returns_type} {name}({args_str}) {{
  {device_check}

  {device_guard}
  return {impl_name}({args_exprs_str});
}}

}} // anonymous namespace
"""

            elif self.target is Target.REGISTRATION:
                if f.manual_kernel_registration or self.skip_dispatcher_op_registration:
                    return None
                else:
                    payload = f"TORCH_FN({name})"
                    return f'm.impl("{f.func.name}",\n{payload});\n'
            else:
                assert_never(self.target)


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                           STRUCTURED
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


@dataclass(frozen=True)
class StructuredRegisterDispatchKey(RegisterDispatchKey):
    g: NativeFunctionsGroup

    def gen_class_set_output_functions(
        self, k: SchemaKind, parent_class: str, generate_super: bool
    ) -> str:
        if generate_super:
            set_output_super = f"{parent_class}::set_output_raw_strided(output_idx, sizes, strides, options, names);"
        else:
            set_output_super = ""

        def gen_set_output_function(name: str, maybe_create_proxy: bool) -> str:
            return f"""
void set_output_{name}(
    int64_t output_idx, IntArrayRef sizes, IntArrayRef strides,
    TensorOptions options, DimnameList names
) override {{
{textwrap.indent(self.gen_class_set_output_body(k, maybe_create_proxy), "    ")}
    if (!names.empty()) {{
      namedinference::propagate_names(outputs_[output_idx], names);
    }}
    // super must happen after, so that downstream can use maybe_get_output
    // to retrieve the output
{textwrap.indent(set_output_super, "    ")}
}}
"""

        return f"""
{gen_set_output_function("strided", maybe_create_proxy=True)}
{gen_set_output_function("raw_strided", maybe_create_proxy=False)}
"""

    def gen_class_set_output_body(self, k: SchemaKind, maybe_create_proxy: bool) -> str:
        if self.backend_index.dispatch_key in [
            DispatchKey.CUDA,
            DispatchKey.MPS,
            DispatchKey.CompositeExplicitAutogradNonFunctional,
        ]:
            maybe_set_guard = """
auto current_device = guard_.current_device();
if (C10_UNLIKELY(current_device.has_value())) {
  TORCH_INTERNAL_ASSERT(*current_device == options.device(),
    "structured kernels don't support multi-device outputs");
} else {
  guard_.reset_device(options.device());
}
"""
            maybe_set_guard_line = maybe_set_guard + "\n"
        else:
            maybe_set_guard_line = maybe_set_guard = ""

        if maybe_create_proxy:
            create_proxy = """
auto maybe_proxy = maybe_create_proxy(out, sizes, strides, options);
if (C10_UNLIKELY(maybe_proxy.has_value())) {
    proxy_outputs_[output_idx] = std::move(maybe_proxy).value();
}
"""
        else:
            create_proxy = ""

        if k is SchemaKind.functional:
            assert self.backend_index.dispatch_key in (
                DispatchKey.Meta,
                DispatchKey.CPU,
                DispatchKey.CUDA,
                DispatchKey.MPS,
                DispatchKey.XPU,
                DispatchKey.CompositeExplicitAutogradNonFunctional,
            )
            return f"""{maybe_set_guard_line}
outputs_[output_idx] = create_out(sizes, strides, options);"""
        elif k is SchemaKind.inplace:
            return f"""{maybe_set_guard_line}
const auto& out = outputs_[output_idx].get();
check_inplace(out, sizes, options);
{create_proxy}"""
        elif k is SchemaKind.out:
            return f"""{maybe_set_guard_line}
const auto& out = outputs_[output_idx].get();
resize_out(out, sizes, strides, options);
{create_proxy}"""
        elif k is SchemaKind.mutable or k is SchemaKind.scratch:
            raise AssertionError(
                f"{k} structured operators are currently not supported"
            )
        else:
            assert_never(k)

    # Returns the definition of a ctor for this class; gen_one() later
    # constructs an instance of it into a variable named op
    def gen_class_ctor(self, k: SchemaKind, class_name: str, returns: int) -> str:
        if k is SchemaKind.functional:
            return ""
        elif k is SchemaKind.inplace:
            # TODO: Make sure out argument is guaranteed to be self
            return f"{class_name}(Tensor& self) : outputs_{{std::ref(self)}} {{}}"
        elif k is SchemaKind.out:
            out_args = ", ".join(f"Tensor& out{i}" for i in range(returns))
            out_refs = ", ".join(f"std::ref(out{i})" for i in range(returns))
            return f"{class_name}({out_args}) : outputs_{{ {out_refs} }} {{}}"
        elif k is SchemaKind.mutable or k is SchemaKind.scratch:
            raise AssertionError(
                f"{k} structured operators are currently not supported"
            )
        else:
            assert_never(k)
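
    # For illustration (hypothetical class_name "structured_foo_out", two
    # returns): the generated ctor is
    #
    #   structured_foo_out(Tensor& out0, Tensor& out1) : outputs_{ std::ref(out0), std::ref(out1) } {}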

    def gen_class(
        self,
        f: NativeFunction,
        k: SchemaKind,
        *,
        class_name: str,
        parent_class: str,
        generate_super: bool,
    ) -> str:
        if k is SchemaKind.functional:
            output_type = "Tensor"
            output_value = "outputs_[output_idx]"
            proxy_field = ""
        elif k is SchemaKind.inplace:
            output_type = "std::reference_wrapper<Tensor>"
            output_value = "proxy_outputs_[output_idx].has_value() ? *proxy_outputs_[output_idx] : outputs_[output_idx].get()"
            proxy_field = f"std::array<::std::optional<Tensor>, {len(f.func.returns)}> proxy_outputs_;"
        elif k is SchemaKind.out:
            output_type = "std::reference_wrapper<Tensor>"
            output_value = "proxy_outputs_[output_idx].has_value() ? *proxy_outputs_[output_idx] : outputs_[output_idx].get()"
            proxy_field = f"std::array<::std::optional<Tensor>, {len(f.func.returns)}> proxy_outputs_;"
        else:
            raise RuntimeError(f"Unsupported SchemaKind {k}")

        if self.backend_index.dispatch_key == DispatchKey.CUDA:
            if self.rocm:
                guard_field = "c10::hip::OptionalHIPGuardMasqueradingAsCUDA guard_;"
            else:
                guard_field = "c10::cuda::OptionalCUDAGuard guard_;"
        elif (
            self.backend_index.dispatch_key
            == DispatchKey.CompositeExplicitAutogradNonFunctional
        ):
            guard_field = "c10::OptionalDeviceGuard guard_;"
        elif self.backend_index.dispatch_key == DispatchKey.MPS:
            # TODO: Move to OptionalMPSGuard.
            guard_field = "c10::OptionalDeviceGuard guard_;"
        else:
            guard_field = ""

        indent = " " * 4
        class_ctor_str = self.gen_class_ctor(k, class_name, len(f.func.returns))
        lines = (
            f"struct {class_name} final : public {parent_class} {{",
            f"{textwrap.indent(class_ctor_str, indent)}",
            f"{textwrap.indent(self.gen_class_set_output_functions(k, parent_class, generate_super), indent)}",
            "    const Tensor& maybe_get_output(int64_t output_idx) override {",
            f"      return {output_value};\n",  # type: ignore[possibly-undefined]  # TODO: audit
            "    }",
            f"    std::array<{output_type}, {len(f.func.returns)}> outputs_;",
            f"{textwrap.indent(proxy_field, indent)}",  # type: ignore[possibly-undefined]  # TODO: audit
            f"{textwrap.indent(guard_field, indent)}",
            "};",
        )
        return "\n".join(line for line in lines if line)
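
    # Putting the pieces together, a structured inplace op on CPU with one
    # return yields a class shaped roughly like (hypothetical kernel "foo"):
    #
    #   struct structured_foo_inplace final : public at::native::structured_foo {
    #       structured_foo_inplace(Tensor& self) : outputs_{std::ref(self)} {}
    #       void set_output_strided(...) override { ... }
    #       void set_output_raw_strided(...) override { ... }
    #       const Tensor& maybe_get_output(int64_t output_idx) override {
    #         return proxy_outputs_[output_idx].has_value() ? *proxy_outputs_[output_idx] : outputs_[output_idx].get();
    #       }
    #       std::array<std::reference_wrapper<Tensor>, 1> outputs_;
    #       std::array<::std::optional<Tensor>, 1> proxy_outputs_;
    #   };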

    @method_with_native_function
    def gen_one(self, f: NativeFunction) -> str | None:
        assert not f.manual_kernel_registration

        if (
            self.target is Target.REGISTRATION
            and not self.selector.is_native_function_selected(f)
        ):
            return None

        # TODO: Now, there is something interesting going on here.  In the code below,
        # we generate CompositeExplicitAutogradNonFunctional implementations of functional and inplace
        # ops based on the out implementation.  But in fact, out is definable by
        # functional too (just not very efficiently), and this is honestly the
        # MORE likely situation for a backend implementor.  How do we pick?
        # Well, taking a page from Haskell type classes and default methods,
        # we could conceivably register a circular definition (out in terms
        # of functional, and functional in terms of out) and just require
        # someone to implement one or the other.  We'd have to do a little bit
        # of work to not register one of these "weak" definitions unless there
        # is a strong definition somewhere in the DAG!  So it's not implemented yet.
        if (
            self.backend_index.dispatch_key
            == DispatchKey.CompositeExplicitAutogradNonFunctional
            and f.func.kind() is SchemaKind.out
        ):
            # Never generate a default implementation for out; that's what you
            # have to define as a backend implementor.
            return None

        # Note [Direct dispatch bindings]
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Signature of the non-dispatched function we'll expose in a header
        # (e.g., at::cpu::add).  We don't generate methods (TODO: do this
        # when CPUTensor class is a thing); nor do we generate fallback
        # bindings for manual_cpp_binding functions.
        cpp_sig_group = CppSignatureGroup.from_native_function(
            f, method=False, fallback_binding=False
        )

        # Signature of the wrapper function we'll register to the dispatcher
        kern = self.backend_index.get_kernel(f)
        sig = NativeSignature(
            f.func,
            prefix=f"wrapper_{self.backend_index.dispatch_key}_",
            symint=kern is not None and kern.supports_symint(),
        )

        if self.target is Target.NAMESPACED_DECLARATION:
            result = ""
            for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
                result += f"TORCH_API {cpp_sig.decl()};\n"
            return result

        elif self.target is Target.NAMESPACED_DEFINITION:

            def generate_defn(cpp_sig: CppSignature) -> str:
                return f"""
{cpp_sig.defn()} {{
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""

            result = ""
            for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
                result += generate_defn(cpp_sig)
            return result

        elif self.target is Target.ANONYMOUS_DEFINITION:
            k = f.func.kind()

            # Construct the body of the wrapper function with signature sig
            sig_body = []
            # We'll use context to keep track of any variables we've brought
            # into scope while generating code
            context: list[Binding | Expr] = list(sig.arguments())

            # Initialize the class corresponding to this structured
            # operator; feeding it the output argument(s) if it is known
            if self.backend_index.dispatch_key is DispatchKey.Meta:
                class_name = f"structured_{meta.name(self.g)}_meta_{k.name}"
                parent_class = f"at::meta::structured_{meta.name(self.g)}"
            elif (
                self.backend_index.dispatch_key
                is DispatchKey.CompositeExplicitAutogradNonFunctional
            ):
                # TODO: dedup this branch
                class_name = f"structured_{meta.name(self.g)}_default_backend_{k.name}"
                parent_class = f"at::meta::structured_{meta.name(self.g)}"
            else:
                metadata = self.backend_index.get_kernel(self.g)
                assert metadata is not None
                class_name = f"structured_{metadata.kernel}_{k.name}"
                parent_class = f"{metadata.cpp_namespace}::structured_{metadata.kernel}"

            if self.backend_index.device_guard:
                device_check_args = itertools.chain(
                    f.func.arguments.out, f.func.arguments.flat_positional
                )
                sig_body.append(
                    RegisterDispatchKey.gen_device_check(
                        f.device_check, list(device_check_args), sig.name()
                    )
                )

            if k is SchemaKind.functional:
                sig_body.append(f"{class_name} op;")
            elif k is SchemaKind.inplace:
                sig_body.append(f"{class_name} op(self);")
            elif k is SchemaKind.out:
                out_args_str = ", ".join(a.name for a in f.func.arguments.out)
                sig_body.append(f"{class_name} op({out_args_str});")

            # Translate the input native arguments into structured
            # arguments for the meta call
            meta_exprs = ", ".join(
                e.expr
                for e in translate(
                    context, structured.meta_arguments(self.g), method=False
                )
            )

            if self.g.out.precomputed:
                # If this function group has precomputed elements, the meta function
                # returns a struct containing them which must be saved so that it
                # can be unpacked when generating code to call the impl.
                sig_body.append(f"auto precompute = op.meta({meta_exprs});")

                # Put all of the contents of the precompute struct into the context
                # so that translate will be able to return the correct args for the
                # call to the impl.
                precomputed_values = [
                    *self.g.out.precomputed.replace.values(),
                    self.g.out.precomputed.add,
                ]
                for precomputed_elems in precomputed_values:
                    for arg in precomputed_elems:
                        context.append(
                            Expr(
                                expr=f"precompute.{arg.name}",
                                type=structured.argument_type(arg, binds=arg.name),
                            )
                        )

                # Add a use of the precompute struct so FB internal compilers don't
                # complain that there is an unused variable.
                sig_body.append("(void)precompute;")
            else:
                sig_body.append(f"op.meta({meta_exprs});")

            # After running meta, op.outputs_ is guaranteed to be valid;
            # add it to the context
            out_args = structured.out_arguments(self.g)
            for i, out_arg in enumerate(out_args):
                assert ConstRefCType(BaseCType(tensorT)) == out_arg.nctype.type

                if k is SchemaKind.out:
                    expr = f"op.maybe_get_output({i})"
                else:
                    expr = f"op.outputs_[{i}]"

                context.append(
                    Expr(
                        expr=expr,
                        # TODO: Stop hardcoding that the output type is a Tensor.  Note
                        # that for the codegen here this is fine because outputs_ is
                        # hardcoded to be tensor already
                        type=NamedCType(
                            out_arg.nctype.name, MutRefCType(BaseCType(tensorT))
                        ),
                    )
                )

            # With the expanded context, do the impl call (if not a meta
            # function)
            if (
                self.backend_index.dispatch_key
                == DispatchKey.CompositeExplicitAutogradNonFunctional
            ):
                # TODO: https://github.com/pytorch/pytorch/issues/53023
                out_sig_group = CppSignatureGroup.from_native_function(
                    self.g.out, method=False, fallback_binding=f.manual_cpp_binding
                )
                out_sig = out_sig_group.most_faithful_signature()
                api_name = out_sig.name()
                out_exprs = ", ".join(
                    e.expr
                    for e in translate(context, out_sig.arguments(), method=False)
                )
                # TODO: I think this means structured won't work with method-only
                # functions (but maybe you're saved by faithful? I don't know.)
                # NB: Originally I wrote this as an at::redispatch call, but
                # I got in trouble because that meant I needed a DispatchKeySet
                # in the wrapper function, which meant I needed a DispatchKeySet
                # in the DispatchKeyFunctions declarations, but the defined API
                # there does NOT permit a dispatch key set.  I think you can
                # probably unwind this by calling some function to do the TLS
                # fetch and get the DispatchKeySet when you don't have it, but
                # I didn't do it for this version
                sig_body.append(f"at::{api_name}({out_exprs});")
            elif self.backend_index.dispatch_key != DispatchKey.Meta:
                impl_exprs = ", ".join(
                    e.expr
                    for e in translate(
                        context, structured.impl_arguments(self.g), method=False
                    )
                )
                sig_body.append(f"op.impl({impl_exprs});")

            # Go over each output, and check if there is a proxy created for it.
            # If so, copy it over to the original output.
            if k is SchemaKind.out or k is SchemaKind.inplace:
                for i in range(len(f.func.returns)):
                    sig_body.append(
                        f"if (op.proxy_outputs_[{i}].has_value()) op.outputs_[{i}].get().copy_(*op.proxy_outputs_[{i}]);"
                    )

            # Destructively return the final tensors
            # TODO: Do this in translate instead
            if k is SchemaKind.functional:
                if len(f.func.returns) == 1:
                    ret_expr = "std::move(op.outputs_[0])"  # small optimization
                else:
                    moved = ", ".join(
                        f"std::move(op.outputs_[{i}])"
                        for i in range(len(f.func.returns))
                    )
                    ret_expr = f"std::make_tuple({moved})"
            elif k is SchemaKind.inplace:
                ret_expr = "self"
            elif k is SchemaKind.out:
                if len(f.func.returns) == 1:
                    ret_expr = f.func.arguments.out[0].name
                else:
                    refs = ", ".join(a.name for a in f.func.arguments.out)
                    ret_expr = f"std::forward_as_tuple({refs})"
            sig_body.append(f"return {ret_expr};")  # type: ignore[possibly-undefined]  # TODO: audit

            sig_body_str = "\n".join(sig_body)

            # For an overview of what this template code looks like, see
            # https://github.com/pytorch/rfcs/pull/9
            return f"""\
{self.gen_class(
f, k,
class_name=class_name,
parent_class=parent_class,
generate_super=self.g.out.structured_inherits is not None
)}

{sig.defn()} {{
{sig_body_str}
}}
"""

        elif self.target is Target.REGISTRATION:
            return f'm.impl("{f.func.name}", TORCH_FN({sig.name()}));'
        else:
            assert_never(self.target)
            # Silence mypy's "Missing return statement" error
            return None
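
# For reference, a REGISTRATION line for a structured op looks like
# (hypothetical op name, CPU key):
#
#   m.impl("foo.out", TORCH_FN(wrapper_CPU_foo_out));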
1006