#pragma once

#include <ATen/DimVector.h>
#include <ATen/core/Dimname.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/strides.h>

namespace at {

class Tensor;

namespace impl {

// Use this to define the prototype for a meta function. There are two
// versions: one that takes one argument (just the operator name), and a FUNC2
// variant that takes two arguments (operator name and overload name).
//
// Example usage:
//
//    TORCH_META_FUNC2(add, Tensor) (
//      const Tensor& self, const Tensor& other
//    ) {
//      ... compute sizes and options ...
//      set_output(sizes, options);
//    }
//
#define TORCH_META_FUNC(name) void structured_##name::meta
#define TORCH_META_FUNC2(name, overload) \
  void structured_##name##_##overload::meta

// These are versions of TORCH_META_FUNC(2) that include a precompute_out
// struct as a return value. They should be used when the kernel in question
// has precomputed values declared in native_functions.yaml and the
// corresponding implementation should return an instance of the
// aforementioned struct.
#define TORCH_PRECOMPUTE_META_FUNC(name) \
  structured_##name::meta_return_ty structured_##name::meta
#define TORCH_PRECOMPUTE_META_FUNC2(name, overload) \
  structured_##name##_##overload::meta_return_ty \
      structured_##name##_##overload::meta

// Use this to create a precompute struct in a meta function.
#define TORCH_PRECOMPUTE_STRUCT(name) structured_##name::precompute_out<>
#define TORCH_PRECOMPUTE_STRUCT2(name, overload) \
  structured_##name##_##overload::precompute_out<>
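
// Example usage (an illustrative sketch only: `my_op` is a hypothetical
// operator, and it assumes a precomputed value `dim` declared for it in
// native_functions.yaml, from which the generated precompute struct gets a
// `set_dim` setter):
//
//    TORCH_PRECOMPUTE_META_FUNC(my_op) (
//      const Tensor& self, int64_t dim
//    ) {
//      ... compute sizes and options ...
//      set_output(sizes, options);
//      return TORCH_PRECOMPUTE_STRUCT(my_op)().set_dim(dim);
//    }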

// Use this to define the prototype for an implementation. This takes only
// one argument, which is the name of the dispatch key entry you're
// implementing.
//
// Example usage:
//
//    TORCH_IMPL_FUNC(add_cpu) (
//      Tensor& result, const Tensor& self, const Tensor& other
//    ) {
//      ... do the actual implementation ...
//    }
//
#define TORCH_IMPL_FUNC(name) void structured_##name::impl

// Base class for all structured kernel classes. The set_output virtual
// method is varied depending on whether the operator is
// functional/out/inplace, and could also be specialized for CPU/CUDA/etc
// (although presently it isn't).
//
// A notable subclass of this interface is TensorIteratorBase.
struct TORCH_API MetaBase {
  MetaBase() = default;
  MetaBase(const MetaBase&) = default;
  MetaBase& operator=(const MetaBase&) = default;
  MetaBase(MetaBase&&) noexcept = default;
  MetaBase& operator=(MetaBase&&) noexcept = default;
  virtual const Tensor& maybe_get_output(int64_t output_idx) = 0;

  // Note: [set_output_*]
  // See: https://github.com/pytorch/pytorch/issues/69813
  // Whenever defining the output properties in the META function of a
  // structured kernel (what was previously done with `set_output`), use one
  // of these 3 variants instead. In order to decide which variant to use,
  // check the following decision tree:
  //
  // - Can the kernel you are going to implement support output tensors
  //   with arbitrary strides?
  //     |
  //     -- YES: `set_output_raw_strided`
  //     |
  //     -- NO: Should the output tensor strides be contiguous?
  //         |
  //         -- YES: `set_output_contiguous`
  //         |
  //         -- NO: `set_output_strided`
  //
  // Use this function whenever the kernel requires specific strides for the
  // output. If `strides` does not match the given output strides, proxy
  // outputs will be created and passed to the IMPL function.
  virtual void set_output_strided(
      int64_t output_idx [[maybe_unused]],
      IntArrayRef sizes [[maybe_unused]],
      IntArrayRef strides [[maybe_unused]],
      TensorOptions options [[maybe_unused]],
      DimnameList names [[maybe_unused]] = {}) {
    TORCH_INTERNAL_ASSERT(false, "set_output_strided not implemented.");
  }

  // Use this function whenever the kernel knows how to handle arbitrarily
  // strided outputs. This function has the same behavior as the old
  // `set_output`: it will only re-stride if the given output was resized.
  virtual void set_output_raw_strided(
      int64_t output_idx [[maybe_unused]],
      IntArrayRef sizes [[maybe_unused]],
      IntArrayRef strides_hint [[maybe_unused]],
      TensorOptions options [[maybe_unused]],
      DimnameList names [[maybe_unused]] = {}) {
    TORCH_INTERNAL_ASSERT(false, "set_output_raw_strided not implemented.");
  }

  // Use this function if the kernel requires contiguous strides.
  // Alias for `set_output_strided`, but with contiguous strides.
  void set_output_contiguous(
      int64_t output_idx,
      IntArrayRef sizes,
      TensorOptions options,
      DimnameList names = {}) {
    auto strides = c10::contiguous_strides(sizes);
    set_output_strided(output_idx, sizes, strides, options, names);
  }

  // Returns a reference to an undefined tensor
  // if there is no presupplied output.
  const Tensor& maybe_get_output() {
    return maybe_get_output(0);
  }
  virtual ~MetaBase() = default;
};

} // namespace impl

} // namespace at
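
// A minimal illustrative sketch (the operator `my_op` and its signature are
// hypothetical, not part of this file) of how a META function picks a variant
// from the decision tree in Note: [set_output_*] above. A kernel that
// requires contiguous output strides would write:
//
//    TORCH_META_FUNC(my_op) (
//      const Tensor& self
//    ) {
//      set_output_contiguous(0, self.sizes(), self.options());
//    }
//
// A kernel that can handle arbitrarily strided outputs would instead call
// set_output_raw_strided, passing its preferred strides only as a hint.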