xref: /aosp_15_r20/external/pytorch/aten/src/ATen/templates/RegisterFunctionalization.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 // ${generated_comment}
3 
4 #include <ATen/core/LegacyTypeDispatch.h>
5 #include <ATen/EmptyTensor.h>
6 #include <ATen/FunctionalTensorWrapper.h>
7 #include <ATen/FunctionalInverses.h>
8 #include <ATen/MemoryOverlap.h>
9 #include <torch/library.h>
10 
11 #ifndef AT_PER_OPERATOR_HEADERS
12 #include <ATen/Operators.h>
13 #include <ATen/NativeFunctions.h>
14 #else
15 // needed for the meta tensor calls to get stride info in functionalization
16 #include <ATen/ops/empty_strided_native.h>
17 // needed for special handling of copy_().
18 // See Note [functionalizating copy_() and not preserving strides]
19 #include <ATen/ops/to_ops.h>
20 #include <ATen/ops/expand_copy_ops.h>
21 
22 $ops_headers
23 #endif
24 
25 namespace at {
26 namespace functionalization {
27 
28 // This keyset is used by functionalization when it calls into meta kernels
29 // to accurately propagate stride metadata.
30 // Exclude any modes: the purpose of calling into meta kernels is only as an implementation
31 // detail to perform shape inference, and we don't want any modal keys to run.
32 // Specifically, we want to prevent functionalization and Python modes from running.
33 constexpr auto exclude_keys_for_meta_dispatch =
34     c10::functorch_transforms_ks |
35     c10::DispatchKeySet({
36         c10::DispatchKey::FuncTorchDynamicLayerBackMode,
37         c10::DispatchKey::FuncTorchDynamicLayerFrontMode,
38         c10::DispatchKey::Python,
39         c10::DispatchKey::PreDispatch,
40 
41     });
42 
43 // Helper around at::has_internal_overlap.
44 // The ATen util is used in hot-path eager mode: it's always fast,
45 // but might return TOO_HARD sometimes.
46 // During functionalization, we're ok taking a bit longer
47 // to detect memory overlap.
has_internal_overlap_helper(const at::Tensor t)48 inline bool has_internal_overlap_helper(const at::Tensor t) {
49   auto has_overlap = at::has_internal_overlap(t);
50   if (has_overlap == at::MemOverlap::Yes) return true;
51   if (has_overlap == at::MemOverlap::No) return false;
52   return false;
53 }
54 
55 
// Build a meta-device tensor that mirrors t's symbolic sizes, strides,
// dtype, and layout. An undefined tensor is returned unchanged.
inline Tensor to_meta(const Tensor& t) {
  if (!t.defined()) {
    return t;
  }
  return at::native::empty_strided_meta_symint(
      t.sym_sizes(),
      t.sym_strides(),
      /*dtype=*/std::make_optional(t.scalar_type()),
      /*layout=*/std::make_optional(t.layout()),
      /*device=*/std::make_optional(c10::Device(kMeta)),
      /*pin_memory=*/std::nullopt);
}
62 
to_meta(const std::optional<Tensor> & t)63 inline std::optional<Tensor> to_meta(const std::optional<Tensor>& t) {
64   if (t.has_value()) {
65     return std::make_optional<Tensor>(to_meta(*t));
66   }
67   return std::nullopt;
68 }
69 
to_meta(at::ITensorListRef t_list)70 inline std::vector<Tensor> to_meta(at::ITensorListRef t_list) {
71   std::vector<Tensor> outputs;
72   outputs.reserve(t_list.size());
73   for (const auto& tensor : t_list) {
74     outputs.push_back(to_meta(tensor));
75   }
76   return outputs;
77 }
78 
to_meta(const c10::List<Tensor> & t_list)79 inline c10::List<Tensor> to_meta(const c10::List<Tensor>& t_list) {
80   c10::List<Tensor> outputs;
81   outputs.reserve(t_list.size());
82   for (const auto i : c10::irange(t_list.size())) {
83     outputs.push_back(to_meta(t_list[i]));
84   }
85   return outputs;
86 }
87 
to_meta(const c10::List<::std::optional<Tensor>> & t_list)88 inline c10::List<::std::optional<Tensor>> to_meta(const c10::List<::std::optional<Tensor>>& t_list) {
89   c10::List<::std::optional<Tensor>> outputs;
90   outputs.reserve(t_list.size());
91   for (const auto i : c10::irange(t_list.size())) {
92     outputs.push_back(to_meta(t_list[i]));
93   }
94   return outputs;
95 }
96 
97 
98 ${func_definitions}
99 
100 }  // namespace functionalization
101 
102 namespace {
103 
// Register every generated functionalization kernel against the
// Functionalize dispatch key of the aten library. The registration list is
// filled in by the code generator.
TORCH_LIBRARY_IMPL(aten, Functionalize, m) {
  ${func_registrations};
}
107 
108 }  // namespace
109 
110 } // namespace at
111