// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cuda/jiterator_impl.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <ATen/jit_macros.h>
3 
4 #if AT_USE_JITERATOR()
5 
#include <ATen/native/TensorIterator.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/native/cuda/jit_utils.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <ATen/native/cuda/JitLoops.cuh>

#include <array>
#include <string>
#include <variant>
#include <vector>
15 
16 namespace at::native {
17 
18 
// X-macro helpers: expand `_` once per supported operand count (1..8).
// Used throughout this file to generate switch cases and std::variant
// alternatives for every arity the jiterator supports.
#define AT_FOR_8_CASES(_)  \
  _(1)                      \
  _(2)                      \
  _(3)                      \
  _(4)                      \
  _(5)                      \
  _(6)                      \
  _(7)                      \
  _(8)

// Same as AT_FOR_8_CASES but comma-separated, for use inside a
// template argument list (e.g. std::variant<...>).
#define AT_FOR_8_CASES_WITH_COMMA(_)  \
  _(1)     ,                           \
  _(2)     ,                           \
  _(3)     ,                           \
  _(4)     ,                           \
  _(5)     ,                           \
  _(6)     ,                           \
  _(7)     ,                           \
  _(8)
38 
// Returns, in order, the jiterator type-name string for each extra
// Scalar argument's dynamic type (via at::cuda::jit::typeName).
c10::SmallVector<std::string> get_extra_args_typenames(const c10::SmallVector<at::Scalar>& extra_args) {
  c10::SmallVector<std::string> type_names;
  type_names.reserve(extra_args.size());
  for (const auto& scalar : extra_args) {
    type_names.emplace_back(at::cuda::jit::typeName(scalar.type()));
  }
  return type_names;
}
46 
// Runtime-to-compile-time dispatch: maps the dynamic ScalarType to the
// statically typed memory::can_vectorize_up_to<ctype>, which reports the
// maximum vector width usable for `pointer` (presumably based on its
// alignment — see MemoryAccess.cuh).
int can_vectorize_up_to(at::ScalarType type, char* pointer) {
  switch(type) {
#define DEFINE_CASE(ctype, scalartype)                                   \
    case ScalarType::scalartype : return memory::can_vectorize_up_to<ctype>(pointer);

    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE)
#undef DEFINE_CASE

    default: TORCH_INTERNAL_ASSERT(false, "Unrecognized ScalarType: ", type);
  }
}
58 
59 // jitted version of the above
60 // See Note [Jiterator], this relies on the assumptions enumerated there
jitted_can_vectorize_up_to(const TensorIteratorBase & iter)61 int jitted_can_vectorize_up_to(const TensorIteratorBase& iter) {
62   const at::ScalarType common_dtype = iter.common_dtype();
63   const at::ScalarType result_dtype = common_dtype;
64 
65   // Deals with output
66   int result = can_vectorize_up_to(result_dtype, static_cast<char*>(iter.data_ptr(0)));
67 
68   // Incorporates input(s)
69   for (auto i = 1; i < iter.ntensors(); ++i) {
70     result = std::min<int>(result, can_vectorize_up_to(common_dtype, static_cast<char*>(iter.data_ptr(i))));
71   }
72 
73   return result;
74 }
75 
76 template<bool IS_INPUT, int N>
make_unique_offset_calculator(const TensorIteratorBase & iter)77 static std::unique_ptr<OffsetCalculator<N>> make_unique_offset_calculator(
78           const TensorIteratorBase& iter) {
79   // array size can not be 0, this happens when N == 0
80   constexpr int array_size = std::max<int>(N, 1);
81   TORCH_INTERNAL_ASSERT(N == (IS_INPUT ? iter.ninputs() : iter.noutputs()));
82 
83   std::array<const int64_t*, array_size> strides;
84   int64_t element_sizes[array_size];
85   for (int i = 0; i < N; i++) {
86     int index = IS_INPUT ? i + iter.noutputs() : i;
87     strides[i] = iter.strides(index).data();
88     element_sizes[i] = iter.element_size(index);
89   }
90   return std::make_unique<OffsetCalculator<N>>(iter.ndim(), iter.shape().data(), strides.data(), element_sizes);
91 }
92 
// Type-erased holder for an OffsetCalculator whose operand count (1..8)
// is only known at runtime. IS_INPUT selects whether it covers the
// iterator's inputs or its outputs.
template <bool IS_INPUT>
struct OffsetCalculatorVariant {
#define DEFINE_CASE(index) std::unique_ptr<OffsetCalculator<index>>
  using OffsetCalculatorTypes = std::variant<
    AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
  >;
#undef DEFINE_CASE

  // Picks the alternative matching the runtime tensor count; rejects
  // counts outside 1..8.
  OffsetCalculatorVariant(const TensorIteratorBase& iter) {
    int num = IS_INPUT ? iter.ninputs() : iter.noutputs();

    switch(num) {
#define DEFINE_CASE(index)        \
      case index : v = make_unique_offset_calculator<IS_INPUT, index>(iter); break;

      AT_FOR_8_CASES(DEFINE_CASE)
#undef DEFINE_CASE
      default:
        TORCH_CHECK(false, "OffsetCalculatorVariant is not implemented for num_tensor = ", num);
    }
  }

  // Untyped pointer to the active OffsetCalculator, for handing to
  // jitted code that knows the concrete type.
  void* data_ptr() {
    return std::visit([](auto & v){ return static_cast<void*>(v.get()); }, v);
  }

 private:
  OffsetCalculatorTypes v{};
};
122 
// Type-erased at::detail::Array<char*, K> of operand data pointers,
// for K in 1..16 chosen at runtime (each macro case contributes both
// `index` and `index+8`).
struct ArrayVariant {
// works for up to 8 input + 8 outputs
#define DEFINE_CASE(index) at::detail::Array<char*, index>, at::detail::Array<char*, index+8>
  using ArrayTypes = std::variant<
    AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
  >;
#undef DEFINE_CASE

  // Selects the Array size matching iter.ntensors(), then fills it with
  // every operand's data pointer (outputs first — TensorIterator order).
  ArrayVariant(const TensorIteratorBase& iter) {
    int ntensors = iter.ntensors();
    switch(ntensors) {
#define DEFINE_CASE(index)                                            \
      case index: array = at::detail::Array<char*, index>{}; break;   \
      case index+8: array = at::detail::Array<char*, index+8>{}; break;

      AT_FOR_8_CASES(DEFINE_CASE)
#undef DEFINE_CASE

      default:
        TORCH_CHECK(false, "ArrayVariant is not implemented for ntensors = ", ntensors);
    }

    std::visit([&](auto& a) {
      for (auto i = 0; i < ntensors; ++i) {
        a[i] = (char*)iter.data_ptr(i);
      }
    }, array);
  }

  // Untyped pointer to the active Array object itself (not its first
  // element), suitable for passing as an opaque kernel argument.
  void* data_ptr() {
    return std::visit([](auto & a){ return static_cast<void*>(&a); }, array);
  }

private:
  ArrayTypes array;
};
159 
160 struct TrivialOffsetCalculatorVariant {
161 #define DEFINE_CASE(index) TrivialOffsetCalculator<index>
162   using TrivialOffsetCalculatorTypes = std::variant<
163     AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
164   >;
165 #undef DEFINE_CASE
166 
TrivialOffsetCalculatorVariantTrivialOffsetCalculatorVariant167   TrivialOffsetCalculatorVariant(int num) {
168     switch(num) {
169 #define DEFINE_CASE(index)      \
170       case index: v = TrivialOffsetCalculator<index>(); break;
171 
172       AT_FOR_8_CASES(DEFINE_CASE)
173 #undef DEFINE_CASE
174 
175       default:
176         TORCH_CHECK(false, "TrivialOffsetCalculatorVariant is not implemented for num_tensors = ", num);
177     }
178   }
179 
data_ptrTrivialOffsetCalculatorVariant180   void* data_ptr() {
181     return std::visit([](auto & v){ return static_cast<void*>(&v); }, v);
182   }
183 
184 private:
185   TrivialOffsetCalculatorTypes v{};
186 };
187 
// Type-erased memory::LoadWithCast keyed on the number of inputs (1..8),
// each alternative stored behind a unique_ptr.
struct LoadWithCastVariant {
#define DEFINE_CASE(index) std::unique_ptr<memory::LoadWithCast<index>>
  using LoadWithCastPtr = std::variant<
    AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
  >;
#undef DEFINE_CASE

  // Constructs the loader matching iter.ninputs(); other arities are
  // rejected with TORCH_CHECK.
  LoadWithCastVariant(const TensorIteratorBase& iter) {
    int arity = iter.ninputs();
    switch(arity) {
#define DEFINE_CASE(index)      \
      case index: v = std::make_unique<memory::LoadWithCast<index>>(iter); break;

      AT_FOR_8_CASES(DEFINE_CASE)
#undef DEFINE_CASE

      default:
        TORCH_CHECK(false, "LoadWithCastVariant is not implemented for ninputs = ", arity);
    }
  }

  // Untyped pointer to the active loader, for handing to jitted code.
  void* data_ptr() {
    return std::visit([](auto & v){ return static_cast<void*>(v.get()); }, v);
  }

private:
  LoadWithCastPtr v{};
};
216 
// Type-erased memory::StoreWithCast keyed on the number of outputs
// (1..8), each alternative stored behind a unique_ptr. Mirrors
// LoadWithCastVariant but dispatches on noutputs().
struct StoreWithCastVariant {
#define DEFINE_CASE(index) std::unique_ptr<memory::StoreWithCast<index>>
  using StoreWithCastPtr = std::variant<
    AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
  >;
#undef DEFINE_CASE

  // Constructs the storer matching iter.noutputs(); other arities are
  // rejected with TORCH_CHECK.
  StoreWithCastVariant(const TensorIteratorBase& iter) {
    int num = iter.noutputs();
    switch(num) {
#define DEFINE_CASE(index)      \
      case index: v = std::make_unique<memory::StoreWithCast<index>>(iter); break;

      AT_FOR_8_CASES(DEFINE_CASE)
#undef DEFINE_CASE

      default:
        TORCH_CHECK(false, "StoreWithCastVariant is not implemented for noutputs = ", num);
    }
  }

  // Untyped pointer to the active storer, for handing to jitted code.
  void* data_ptr() {
    return std::visit([](auto & v){ return static_cast<void*>(v.get()); }, v);
  }

private:
  StoreWithCastPtr v{};
};
245 
246 } // namespace at::native
247 
248 
249 #endif // AT_USE_JITERATOR()
250