#pragma once
#include <ATen/jit_macros.h>

#if AT_USE_JITERATOR()

#include <ATen/native/TensorIterator.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/native/cuda/jit_utils.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <ATen/native/cuda/JitLoops.cuh>

#include <string>
#include <variant>
#include <vector>

namespace at::native {

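// X-macro helpers: AT_FOR_8_CASES(_) expands the given macro once for each
// argument count 1..8 (the _WITH_COMMA form separates the expansions with
// commas). The variant types and switch statements below use them to cover
// every supported tensor count without repetition.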
#define AT_FOR_8_CASES(_)  \
  _(1)                     \
  _(2)                     \
  _(3)                     \
  _(4)                     \
  _(5)                     \
  _(6)                     \
  _(7)                     \
  _(8)

#define AT_FOR_8_CASES_WITH_COMMA(_)  \
  _(1) ,                              \
  _(2) ,                              \
  _(3) ,                              \
  _(4) ,                              \
  _(5) ,                              \
  _(6) ,                              \
  _(7) ,                              \
  _(8)

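// Maps each extra scalar argument to the string name of its C++ type
// (via at::cuda::jit::typeName), so the generated kernel source can declare
// parameters of the right type.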
c10::SmallVector<std::string> get_extra_args_typenames(const c10::SmallVector<at::Scalar>& extra_args) {
  c10::SmallVector<std::string> args_typenames(extra_args.size());
  for (const auto i : c10::irange(extra_args.size())) {
    args_typenames[i] = at::cuda::jit::typeName(extra_args[i].type());
  }
  return args_typenames;
}

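// Dispatches on the scalar type and asks memory::can_vectorize_up_to<ctype>
// how many elements can be loaded/stored per vectorized access given the
// pointer's alignment.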
int can_vectorize_up_to(at::ScalarType type, char* pointer) {
  switch(type) {
#define DEFINE_CASE(ctype, scalartype) \
    case ScalarType::scalartype : return memory::can_vectorize_up_to<ctype>(pointer);

    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE)
#undef DEFINE_CASE

    default: TORCH_INTERNAL_ASSERT(false, "Unrecognized ScalarType: ", type);
  }
}

// jitted version of the above
// See Note [Jiterator]; this relies on the assumptions enumerated there
int jitted_can_vectorize_up_to(const TensorIteratorBase& iter) {
  const at::ScalarType common_dtype = iter.common_dtype();
  const at::ScalarType result_dtype = common_dtype;

  // Deals with output
  int result = can_vectorize_up_to(result_dtype, static_cast<char*>(iter.data_ptr(0)));

  // Incorporates input(s)
  for (auto i = 1; i < iter.ntensors(); ++i) {
    result = std::min<int>(result, can_vectorize_up_to(common_dtype, static_cast<char*>(iter.data_ptr(i))));
  }

  return result;
}

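// Builds an OffsetCalculator for either the N inputs or the N outputs of the
// iterator (selected by IS_INPUT); the strided launch path uses it to turn a
// linear index into per-tensor offsets.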
template<bool IS_INPUT, int N>
static std::unique_ptr<OffsetCalculator<N>> make_unique_offset_calculator(
    const TensorIteratorBase& iter) {
  // std::array size cannot be 0, so clamp to 1 for the N == 0 case
  constexpr int array_size = std::max<int>(N, 1);
  TORCH_INTERNAL_ASSERT(N == (IS_INPUT ? iter.ninputs() : iter.noutputs()));

  std::array<const int64_t*, array_size> strides;
  int64_t element_sizes[array_size];
  for (int i = 0; i < N; i++) {
    int index = IS_INPUT ? i + iter.noutputs() : i;
    strides[i] = iter.strides(index).data();
    element_sizes[i] = iter.element_size(index);
  }
  return std::make_unique<OffsetCalculator<N>>(iter.ndim(), iter.shape().data(), strides.data(), element_sizes);
}

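// The structs below are type-erasing wrappers: the jitted kernel is compiled
// for a concrete number of tensors, while the host-side launcher only deals
// with these variants. Each wrapper holds the concrete type for 1..8 tensors
// and exposes data_ptr() so the underlying object can be handed to the kernel
// launch as an opaque pointer.
//
// OffsetCalculatorVariant holds an OffsetCalculator<N> (N = 1..8) for either
// the inputs or the outputs of the iterator, allocated on the heap by
// make_unique_offset_calculator.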
template <bool IS_INPUT>
struct OffsetCalculatorVariant {
#define DEFINE_CASE(index) std::unique_ptr<OffsetCalculator<index>>
  using OffsetCalculatorTypes = std::variant<
    AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
  >;
#undef DEFINE_CASE

  OffsetCalculatorVariant(const TensorIteratorBase& iter) {
    int num = IS_INPUT ? iter.ninputs() : iter.noutputs();

    switch(num) {
#define DEFINE_CASE(index) \
      case index : v = make_unique_offset_calculator<IS_INPUT, index>(iter); break;

      AT_FOR_8_CASES(DEFINE_CASE)
#undef DEFINE_CASE
      default:
        TORCH_CHECK(false, "OffsetCalculatorVariant is not implemented for num_tensor = ", num);
    }
  }

  void* data_ptr() {
    return std::visit([](auto& v) { return static_cast<void*>(v.get()); }, v);
  }

 private:
  OffsetCalculatorTypes v{};
};

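// ArrayVariant holds the data pointers of all tensors in the iterator as an
// at::detail::Array<char*, N>, where N can be 1..8 or 9..16 (up to 8 inputs
// plus 8 outputs).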
struct ArrayVariant {
  // works for up to 8 inputs + 8 outputs
#define DEFINE_CASE(index) at::detail::Array<char*, index>, at::detail::Array<char*, index+8>
  using ArrayTypes = std::variant<
    AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
  >;
#undef DEFINE_CASE

  ArrayVariant(const TensorIteratorBase& iter) {
    int ntensors = iter.ntensors();
    switch(ntensors) {
#define DEFINE_CASE(index)                                          \
      case index: array = at::detail::Array<char*, index>{}; break; \
      case index+8: array = at::detail::Array<char*, index+8>{}; break;

      AT_FOR_8_CASES(DEFINE_CASE)
#undef DEFINE_CASE

      default:
        TORCH_CHECK(false, "ArrayVariant is not implemented for ntensors = ", ntensors);
    }

    std::visit([&](auto& a) {
      for (auto i = 0; i < ntensors; ++i) {
        a[i] = (char*)iter.data_ptr(i);
      }
    }, array);
  }

  void* data_ptr() {
    return std::visit([](auto& a) { return static_cast<void*>(&a); }, array);
  }

 private:
  ArrayTypes array;
};

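// TrivialOffsetCalculatorVariant holds a TrivialOffsetCalculator<N>
// (N = 1..8), the offset calculator used on the contiguous path, where every
// tensor's offset is just the linear index.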
struct TrivialOffsetCalculatorVariant {
#define DEFINE_CASE(index) TrivialOffsetCalculator<index>
  using TrivialOffsetCalculatorTypes = std::variant<
    AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
  >;
#undef DEFINE_CASE

  TrivialOffsetCalculatorVariant(int num) {
    switch(num) {
#define DEFINE_CASE(index) \
      case index: v = TrivialOffsetCalculator<index>(); break;

      AT_FOR_8_CASES(DEFINE_CASE)
#undef DEFINE_CASE

      default:
        TORCH_CHECK(false, "TrivialOffsetCalculatorVariant is not implemented for num_tensors = ", num);
    }
  }

  void* data_ptr() {
    return std::visit([](auto& v) { return static_cast<void*>(&v); }, v);
  }

 private:
  TrivialOffsetCalculatorTypes v{};
};

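// LoadWithCastVariant holds a memory::LoadWithCast<N> (N = 1..8), used when
// the inputs need a dynamic cast from their storage dtype to the common
// computation dtype at load time.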
struct LoadWithCastVariant {
#define DEFINE_CASE(index) std::unique_ptr<memory::LoadWithCast<index>>
  using LoadWithCastPtr = std::variant<
    AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
  >;
#undef DEFINE_CASE

  LoadWithCastVariant(const TensorIteratorBase& iter) {
    int arity = iter.ninputs();
    switch(arity) {
#define DEFINE_CASE(index) \
      case index: v = std::make_unique<memory::LoadWithCast<index>>(iter); break;

      AT_FOR_8_CASES(DEFINE_CASE)
#undef DEFINE_CASE

      default:
        TORCH_CHECK(false, "LoadWithCastVariant is not implemented for ninputs = ", arity);
    }
  }

  void* data_ptr() {
    return std::visit([](auto& v) { return static_cast<void*>(v.get()); }, v);
  }

 private:
  LoadWithCastPtr v{};
};

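// StoreWithCastVariant holds a memory::StoreWithCast<N> (N = 1..8), the
// store-side counterpart of LoadWithCastVariant for casting results back to
// the output dtype(s).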
struct StoreWithCastVariant {
#define DEFINE_CASE(index) std::unique_ptr<memory::StoreWithCast<index>>
  using StoreWithCastPtr = std::variant<
    AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
  >;
#undef DEFINE_CASE

  StoreWithCastVariant(const TensorIteratorBase& iter) {
    int num = iter.noutputs();
    switch(num) {
#define DEFINE_CASE(index) \
      case index: v = std::make_unique<memory::StoreWithCast<index>>(iter); break;

      AT_FOR_8_CASES(DEFINE_CASE)
#undef DEFINE_CASE

      default:
        TORCH_CHECK(false, "StoreWithCastVariant is not implemented for noutputs = ", num);
    }
  }

  void* data_ptr() {
    return std::visit([](auto& v) { return static_cast<void*>(v.get()); }, v);
  }

 private:
  StoreWithCastPtr v{};
};

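// Rough usage sketch for the wrappers above, given a configured
// TensorIteratorBase `iter` (illustrative only; the actual launch path lives
// in the Jiterator implementation, and `launch_args` is a hypothetical
// stand-in for whatever argument buffer is handed to the compiled kernel):
//
//   ArrayVariant data(iter);                           // tensor data pointers
//   OffsetCalculatorVariant<true> input_offsets(iter);
//   OffsetCalculatorVariant<false> output_offsets(iter);
//   LoadWithCastVariant loader(iter);                  // only when casting is needed
//   StoreWithCastVariant storer(iter);
//
//   std::vector<void*> launch_args = {
//       data.data_ptr(),
//       input_offsets.data_ptr(),
//       output_offsets.data_ptr(),
//       loader.data_ptr(),
//       storer.data_ptr()};
//   // launch_args would then be passed to the jitted kernel via the CUDA
//   // driver API.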
} // namespace at::native


#endif // AT_USE_JITERATOR()