1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 // ${generated_comment}
3
4 #include <ATen/core/LegacyTypeDispatch.h>
5 #include <ATen/EmptyTensor.h>
6 #include <ATen/FunctionalTensorWrapper.h>
7 #include <ATen/FunctionalInverses.h>
8 #include <ATen/MemoryOverlap.h>
9 #include <torch/library.h>
10
11 #ifndef AT_PER_OPERATOR_HEADERS
12 #include <ATen/Operators.h>
13 #include <ATen/NativeFunctions.h>
14 #else
15 // needed for the meta tensor calls to get stride info in functionalization
16 #include <ATen/ops/empty_strided_native.h>
17 // needed for special handling of copy_().
18 // See Note [functionalizating copy_() and not preserving strides]
19 #include <ATen/ops/to_ops.h>
20 #include <ATen/ops/expand_copy_ops.h>
21
22 $ops_headers
23 #endif
24
25 namespace at {
26 namespace functionalization {
27
28 // This keyset is used by functionalization when it calls into meta kernels
29 // to accurately propagate stride metadata.
30 // Exclude any modes: the purpose of calling into meta kernels is only as an implementation
31 // detail to perform shape inference, and we don't want any modal keys to run.
32 // Specifically, we want to prevent functionalization and Python modes from running.
33 constexpr auto exclude_keys_for_meta_dispatch =
34 c10::functorch_transforms_ks |
35 c10::DispatchKeySet({
36 c10::DispatchKey::FuncTorchDynamicLayerBackMode,
37 c10::DispatchKey::FuncTorchDynamicLayerFrontMode,
38 c10::DispatchKey::Python,
39 c10::DispatchKey::PreDispatch,
40
41 });
42
43 // Helper around at::has_internal_overlap.
44 // The ATen util is used in hot-path eager mode: it's always fast,
45 // but might return TOO_HARD sometimes.
46 // During functionalization, we're ok taking a bit longer
47 // to detect memory overlap.
has_internal_overlap_helper(const at::Tensor t)48 inline bool has_internal_overlap_helper(const at::Tensor t) {
49 auto has_overlap = at::has_internal_overlap(t);
50 if (has_overlap == at::MemOverlap::Yes) return true;
51 if (has_overlap == at::MemOverlap::No) return false;
52 return false;
53 }
54
55
to_meta(const Tensor & t)56 inline Tensor to_meta(const Tensor& t) {
57 if (!t.defined()) return t;
58 return at::native::empty_strided_meta_symint(t.sym_sizes(), t.sym_strides(),
59 /*dtype=*/std::make_optional(t.scalar_type()), /*layout=*/std::make_optional(t.layout()),
60 /*device=*/std::make_optional(c10::Device(kMeta)), /*pin_memory=*/std::nullopt);
61 }
62
to_meta(const std::optional<Tensor> & t)63 inline std::optional<Tensor> to_meta(const std::optional<Tensor>& t) {
64 if (t.has_value()) {
65 return std::make_optional<Tensor>(to_meta(*t));
66 }
67 return std::nullopt;
68 }
69
to_meta(at::ITensorListRef t_list)70 inline std::vector<Tensor> to_meta(at::ITensorListRef t_list) {
71 std::vector<Tensor> outputs;
72 outputs.reserve(t_list.size());
73 for (const auto& tensor : t_list) {
74 outputs.push_back(to_meta(tensor));
75 }
76 return outputs;
77 }
78
to_meta(const c10::List<Tensor> & t_list)79 inline c10::List<Tensor> to_meta(const c10::List<Tensor>& t_list) {
80 c10::List<Tensor> outputs;
81 outputs.reserve(t_list.size());
82 for (const auto i : c10::irange(t_list.size())) {
83 outputs.push_back(to_meta(t_list[i]));
84 }
85 return outputs;
86 }
87
to_meta(const c10::List<::std::optional<Tensor>> & t_list)88 inline c10::List<::std::optional<Tensor>> to_meta(const c10::List<::std::optional<Tensor>>& t_list) {
89 c10::List<::std::optional<Tensor>> outputs;
90 outputs.reserve(t_list.size());
91 for (const auto i : c10::irange(t_list.size())) {
92 outputs.push_back(to_meta(t_list[i]));
93 }
94 return outputs;
95 }
96
97
98 ${func_definitions}
99
100 } // namespace functionalization
101
102 namespace {
103
TORCH_LIBRARY_IMPL(aten,Functionalize,m)104 TORCH_LIBRARY_IMPL(aten, Functionalize, m) {
105 ${func_registrations};
106 }
107
108 } // namespace
109
110 } // namespace at
111