xref: /aosp_15_r20/external/pytorch/aten/src/ATen/TensorMeta.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/DimVector.h>
4 #include <ATen/core/Dimname.h>
5 #include <c10/core/TensorOptions.h>
6 #include <c10/util/strides.h>
7 
8 namespace at {
9 
10 class Tensor;
11 
12 namespace impl {
13 
// Use this to define the prototype for a meta function.  There are two
// versions; one that takes one argument (just the operator name), or FUNC2
// variant that takes two arguments (operator name and overload name).
//
// The macros expand to the qualified signature of the `meta` method on the
// codegen-produced `structured_<op>[_<overload>]` class, so the user-written
// body that follows becomes that method's out-of-line definition.
//
// Example usage:
//
//    TORCH_META_FUNC2(add, Tensor) (
//      const Tensor& self, const Tensor& other
//    ) {
//      ... compute sizes and options ...
//      set_output(sizes, options);
//    }
//
#define TORCH_META_FUNC(name) void structured_##name::meta
#define TORCH_META_FUNC2(name, overload) \
  void structured_##name##_##overload::meta

// These are versions of TORCH_META_FUNC(2) that include a precompute_out struct
// as a return value. They should be used when the kernel in question has
// precomputed values declared in native_functions.yaml and the corresponding
// implementation should return an instance of the aforementioned struct.
// `meta_return_ty` is the codegen-declared alias for that struct type.
#define TORCH_PRECOMPUTE_META_FUNC(name) \
  structured_##name::meta_return_ty structured_##name::meta
#define TORCH_PRECOMPUTE_META_FUNC2(name, overload) \
  structured_##name##_##overload::meta_return_ty    \
      structured_##name##_##overload::meta

// Use this to create a precompute struct in a meta function.
// The empty `<>` selects the primary (all-fields-unset) template instance;
// fields are then filled in via its chained setters.
#define TORCH_PRECOMPUTE_STRUCT(name) structured_##name::precompute_out<>
#define TORCH_PRECOMPUTE_STRUCT2(name, overload) \
  structured_##name##_##overload::precompute_out<>

// Use this to define the prototype for an implementation.  This takes only
// one argument, which is the name of the dispatch key entry you're
// implementing.
//
// Example usage:
//
//    TORCH_IMPL_FUNC(add_cpu) (
//      Tensor& result, const Tensor& self, const Tensor& other
//    ) {
//      ... do the actual implementation ...
//    }
//
#define TORCH_IMPL_FUNC(name) void structured_##name::impl
59 
60 // Base class for all structured kernel classes.  The set_output virtual
61 // method is varied depending whether or not the operator is
62 // functional/out/inplace, and could also be specialized for CPU/CUDA/etc
63 // (although presently it isn't).
64 //
65 // A notable subclass of this interface is TensorIteratorBase.
66 struct TORCH_API MetaBase {
67   MetaBase() = default;
68   MetaBase(const MetaBase&) = default;
69   MetaBase& operator=(const MetaBase&) = default;
70   MetaBase(MetaBase&&) noexcept = default;
71   MetaBase& operator=(MetaBase&&) noexcept = default;
72   virtual const Tensor& maybe_get_output(int64_t output_idx) = 0;
73 
74   // Note: [set_output_*]
75   // See: https://github.com/pytorch/pytorch/issues/69813
76   // Whenever defining the output properties in the META function of a
77   // structured kernel (what was usually done with `set_output`), use one of
78   // these 3 variants, instead. In order to decide which variant to use, check
79   // the following decision tree:
80   //
81   // - Can the kernel you are going to implement support output tensors
82   //   with arbitrary strides?
83   //     |
84   //     -- YES: `set_output_raw_strided`
85   //     |
86   //     -- NO: Should the output tensor strides be contiguous?
87   //         |
88   //         -- YES: `set_output_contiguous`
89   //         |
90   //         -- NO: `set_output_strided`
91   //
92   // Use this function whenever the kernel requires specific strides for the
93   // output. If `strides` does not match the given output strides, proxy outputs
94   // will be created and passed to the IMPL function.
95   virtual void set_output_strided(
96       int64_t output_idx [[maybe_unused]],
97       IntArrayRef sizes [[maybe_unused]],
98       IntArrayRef strides [[maybe_unused]],
99       TensorOptions options [[maybe_unused]],
100       DimnameList names [[maybe_unused]] = {}) {
101     TORCH_INTERNAL_ASSERT(false, "set_output_strided not implemented.");
102   }
103 
104   // Use this function whenever the kernel knows how to handle arbitrary strided
105   // outputs. This function has the same behavior as the old `set_output`: it
106   // will only re-stride if the given output was resized.
107   virtual void set_output_raw_strided(
108       int64_t output_idx [[maybe_unused]],
109       IntArrayRef sizes [[maybe_unused]],
110       IntArrayRef strides_hint [[maybe_unused]],
111       TensorOptions options [[maybe_unused]],
112       DimnameList names [[maybe_unused]] = {}) {
113     TORCH_INTERNAL_ASSERT(false, "set_output_strided not implemented.");
114   }
115 
116   // Use this function if the kernel requires contiguous strides.
117   // Alias for `set_output_strided`, but with contiguous strides.
118   void set_output_contiguous(
119       int64_t output_idx,
120       IntArrayRef sizes,
121       TensorOptions options,
122       DimnameList names = {}) {
123     auto strides = c10::contiguous_strides(sizes);
124     set_output_strided(output_idx, sizes, strides, options, names);
125   }
126 
127   // Returns a reference to an undefined tensor if there is no presupplied
128   // output
maybe_get_outputMetaBase129   const Tensor& maybe_get_output() {
130     return maybe_get_output(0);
131   }
132   virtual ~MetaBase() = default;
133 };
134 
135 } // namespace impl
136 
137 } // namespace at
138