#pragma once

#include <ATen/Device.h>
#include <ATen/Dispatch.h>
#include <ATen/ScalarType.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/utils/ParamsHash.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/result_type_native.h>
#endif

#include <unordered_map>
#include <vector>

namespace at::native {
namespace {
// Check if tensor list has either a boolean tensor or an integer tensor
inline bool has_integral_tensor(TensorList tensors, const bool includeBool) {
  return std::any_of(
      tensors.begin(), tensors.end(), [&includeBool](const auto& t) {
        return at::isIntegralType(t.scalar_type(), includeBool);
      });
}
// Check if tensor list has bool tensors
inline bool has_bool_tensor(TensorList tensors) {
  return std::any_of(tensors.begin(), tensors.end(), [](const auto& t) -> bool {
    return t.scalar_type() == ScalarType::Bool;
  });
}

// Check foreach API restrictions
// - Tensor lists must be non-empty.
// - All TensorLists and ScalarLists must have the same number of elements.
// - Corresponding tensors must have the same size.
inline void check_foreach_api_restrictions(TensorList tensors) {
  TORCH_CHECK(!tensors.empty(), "Tensor list must have at least one tensor.");
}

inline void check_foreach_api_restrictions(
    TensorList tensors,
    ArrayRef<Scalar> scalars) {
  check_foreach_api_restrictions(tensors);
  TORCH_CHECK(
      tensors.size() == scalars.size(),
      "Tensor list must have same number of elements as scalar list.");
}

inline void check_foreach_api_restrictions(
    TensorList tensors1,
    TensorList tensors2) {
  TORCH_CHECK(!tensors1.empty(), "Tensor list must have at least one tensor.");
  TORCH_CHECK(!tensors2.empty(), "Tensor list must have at least one tensor.");
  TORCH_CHECK(
      tensors1.size() == tensors2.size(),
      "Tensor lists must have the same number of tensors, got ",
      tensors1.size(),
      " and ",
      tensors2.size());
}

inline void check_foreach_api_restrictions(
    TensorList tensors1,
    TensorList tensors2,
    TensorList tensors3) {
  TORCH_CHECK(!tensors1.empty(), "Tensor list must have at least one tensor.");
  TORCH_CHECK(!tensors2.empty(), "Tensor list must have at least one tensor.");
  TORCH_CHECK(!tensors3.empty(), "Tensor list must have at least one tensor.");
  TORCH_CHECK(
      tensors1.size() == tensors2.size(),
      "Tensor lists must have the same number of tensors, got ",
      tensors1.size(),
      " and ",
      tensors2.size());
  TORCH_CHECK(
      tensors1.size() == tensors3.size(),
      "Tensor lists must have the same number of tensors, got ",
      tensors1.size(),
      " and ",
      tensors3.size());
}

inline void check_foreach_api_restrictions(
    TensorList tensors1,
    TensorList tensors2,
    TensorList tensors3,
    ArrayRef<Scalar> scalars) {
  check_foreach_api_restrictions(tensors1, tensors2, tensors3);
  TORCH_CHECK(
      tensors1.size() == scalars.size(),
      "Tensor list must have same number of elements as scalar list, got ",
      tensors1.size(),
      " and ",
      scalars.size());
}
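
// Illustrative usage of the checks above (a sketch with hypothetical tensor
// lists `xs` and `ys`; not part of this header):
//
//   check_foreach_api_restrictions(xs);      // throws if `xs` is empty
//   check_foreach_api_restrictions(xs, ys);  // additionally throws if the
//                                            // two lists differ in length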

// Helper function called in check_fast_path_restrictions to check whether all
// corresponding tensors (aligning in index across the tensorLists) share the
// same device and dtype.
inline bool _check_tensors_share_device_and_dtype(
    ArrayRef<TensorList> tensorLists,
    const bool skip_dtype_check = false) {
  const auto expected_dtype = tensorLists[0][0].dtype();
  const auto expected_device = tensorLists[0][0].device();

  auto is_tensor_okay = [&](const Tensor& tensor) {
    return (skip_dtype_check || tensor.dtype() == expected_dtype) &&
        tensor.device() == expected_device && tensor.layout() == at::kStrided &&
        tensor.is_non_overlapping_and_dense();
  };

  for (const auto& tensorList : tensorLists) {
    for (const auto& tensor : tensorList) {
      if (!is_tensor_okay(tensor)) {
        return false;
      }
    }
  }

  return true;
}
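
// For example (hypothetical tensors; a sketch, not part of this header), the
// check above rejects any mismatch against tensorLists[0][0]:
//
//   at::Tensor x = at::ones({2, 2}, at::kFloat);
//   at::Tensor y = at::ones({2, 2}, at::kFloat);
//   at::Tensor z = at::ones({2, 2}, at::kDouble);
//   _check_tensors_share_device_and_dtype({{x, y}});  // true
//   _check_tensors_share_device_and_dtype({{x, z}});  // false: dtypes differ
//
// Passing skip_dtype_check=true relaxes only the dtype comparison; the device,
// strided-layout, and non-overlapping-and-dense requirements still apply.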

// Helper function called in check_fast_path_restrictions to check if
// corresponding tensors in tensor lists have the same sizes and strides.
inline bool _check_tensors_share_sizes_and_strides(
    ArrayRef<TensorList> tensorLists) {
  auto is_diff_stride = [](const IntArrayRef& size,
                           const IntArrayRef& left_stride,
                           const IntArrayRef& right_stride) -> bool {
    const size_t size_size = size.size();
    for (const auto dim : c10::irange(size_size)) {
      if (size[dim] == 1)
        continue;
      if (left_stride[dim] != right_stride[dim]) {
        return true;
      }
    }
    return false;
  };
  for (const auto i : c10::irange(1, tensorLists.size())) {
    for (const auto j : c10::irange(tensorLists[0].size())) {
      if (tensorLists[0][j].sizes() != tensorLists[i][j].sizes() ||
          is_diff_stride(
              tensorLists[0][j].sizes(),
              tensorLists[0][j].strides(),
              tensorLists[i][j].strides())) {
        return false;
      }
    }
  }

  return true;
}
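
// For example (hypothetical tensors; a sketch, not part of this header):
//
//   at::Tensor a = at::ones({2, 3});
//   at::Tensor b = at::ones({2, 3});
//   at::Tensor c = at::ones({3, 2});
//   _check_tensors_share_sizes_and_strides({{a, b}});  // true
//   _check_tensors_share_sizes_and_strides({{a, c}});  // false: sizes differ
//
// Strides are only compared on dimensions of size > 1, since the stride of a
// size-1 dimension does not affect memory layout.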

// Helper function called in check_fast_path_restrictions to check whether
// all tensors type promote properly with the scalars in scalarList. This
// function assumes that _check_tensors_share_device_and_dtype has already been
// called so that all corresponding tensors in tensorLists have the same dtype.
// Then, it is sufficient to check the type promotion with just one tensorList.
inline bool _check_tensors_do_type_promotion_with_scalars(
    TensorList tensorList,
    ArrayRef<Scalar> scalarList = {},
    bool does_op_promote_integer_inputs_to_float = false) {
  for (const auto i : c10::irange(tensorList.size())) {
    // For division, integer inputs will result in float.
    if (does_op_promote_integer_inputs_to_float) {
      if (at::isIntegralType(
              tensorList[i].scalar_type(), /*includeBool*/ true)) {
        return false;
      }
    }
    if (!scalarList.empty()) {
      const auto& scalar =
          scalarList.size() == 1 ? scalarList[0] : scalarList[i];
      const auto& tensor = tensorList[i];
      // note(mkozuki): This check might be responsible for
      // `_foreach_add(bool_tensors, bool_tensors)` being pushed to slow path.
      if (tensor.scalar_type() != at::native::result_type(scalar, tensor)) {
        return false;
      }
    }
  }

  return true;
}
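
// For example (hypothetical tensors and scalars; a sketch, not part of this
// header):
//
//   at::Tensor f = at::ones({2}, at::kFloat);
//   at::Tensor n = at::ones({2}, at::kLong);
//   _check_tensors_do_type_promotion_with_scalars({f}, {Scalar(2.0)});  // true
//   _check_tensors_do_type_promotion_with_scalars({n}, {Scalar(2.5)});  // false
//
// The second call fails because combining a floating-point scalar with an
// integer tensor promotes the result to a floating dtype, while the fast path
// writes the result back with the input tensor's dtype.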

// To go via the 'fast' path, several conditions must be satisfied:
// - All tensors in all lists must have the same dtype.
// - All tensors must be on the same device.
// - All tensors must have strided layout.
// - All tensors must be non-overlapping and dense.
// - The resulting tensor must have the same dtype as the input one.

// [note: what's ``does_op_promote_integer_inputs_to_float=true``?]
// ``does_op_promote_integer_inputs_to_float=true`` means that the result of
// the op will be float even if the inputs are integer or boolean, which the
// fast path currently does not support. In short, when this flag is turned
// on, it gatekeeps the op from going down the fast path.

// Please make sure to call check_foreach_api_restrictions before calling this
// method; it establishes the preconditions this check relies on.
inline bool check_fast_path_restrictions(
    ArrayRef<TensorList> tensorLists,
    ArrayRef<Scalar> scalarList = {},
    bool does_op_promote_integer_inputs_to_float = false) {
  return _check_tensors_share_device_and_dtype(tensorLists) &&
      _check_tensors_share_sizes_and_strides(tensorLists) &&
      _check_tensors_do_type_promotion_with_scalars(
             tensorLists[0],
             scalarList,
             does_op_promote_integer_inputs_to_float);
}

inline std::vector<c10::Scalar> convert_tensor_to_scalar_list(
    const Tensor& scalarList_,
    int64_t expect_length) {
  std::vector<c10::Scalar> scalarList;
  TORCH_CHECK(
      scalarList_.device() == c10::kCPU,
      "Expected scalars to be on CPU, got ",
      scalarList_.device(),
      " instead.");
  TORCH_CHECK(
      scalarList_.is_contiguous(), "Expected scalars to be contiguous.");
  TORCH_CHECK(
      scalarList_.dim() == 1,
      "Expected packed scalar Tensor to be of dimension 1. Got ",
      scalarList_.dim(),
      " instead.");
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
      kComplexHalf,
      kHalf,
      kBool,
      kBFloat16,
      scalarList_.scalar_type(),
      "convert_tensor_to_scalar_list",
      [&]() {
        const scalar_t* scalar_data = scalarList_.const_data_ptr<scalar_t>();
        TORCH_CHECK(
            (expect_length == scalarList_.size(0)),
            "Expected length of scalars to match input of length ",
            expect_length,
            " but got ",
            scalarList_.size(0),
            " instead.");
        for (int64_t i = 0; i < scalarList_.size(0); i++) {
          scalarList.emplace_back(scalar_data[i]);
        }
      });
  return scalarList;
}
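
// Illustrative usage (a sketch with a hypothetical packed tensor, not part of
// this header):
//
//   at::Tensor packed = at::tensor({1.0, 2.0, 3.0});  // 1-D tensor on CPU
//   std::vector<c10::Scalar> scalars =
//       convert_tensor_to_scalar_list(packed, /*expect_length=*/3);
//   // scalars[1].toDouble() == 2.0
//
// A non-CPU, non-contiguous, non-1-D, or wrong-length input fails the
// TORCH_CHECKs above.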

// see: [note: what's ``does_op_promote_integer_inputs_to_float=true``?]
inline bool can_use_fast_route(
    ArrayRef<TensorList> tensorLists,
    ArrayRef<Scalar> scalarList = {},
    bool does_op_promote_integer_inputs_to_float = false) {
  return check_fast_path_restrictions(
      tensorLists, scalarList, does_op_promote_integer_inputs_to_float);
}

// see: [note: what's ``does_op_promote_integer_inputs_to_float=true``?]
inline bool can_use_fast_route(
    TensorList tensors1,
    TensorList tensors2,
    bool does_op_promote_integer_inputs_to_float = false) {
  return can_use_fast_route(
      {tensors1, tensors2}, {}, does_op_promote_integer_inputs_to_float);
}
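
// A typical foreach kernel entry point uses these helpers roughly as follows
// (a sketch with hypothetical names, not the actual dispatch code):
//
//   void foreach_op_(TensorList self, TensorList other) {
//     check_foreach_api_restrictions(self, other);
//     if (!can_use_fast_route(self, other)) {
//       return foreach_op_slow_(self, other);  // per-tensor fallback
//     }
//     // ... launch the fused multi-tensor kernel ...
//   }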

using DeviceDtypeKey = std::pair<at::Device, at::ScalarType>;
using IndicesT = std::vector<size_t>;
using nested_optional_tensorvec_t =
    std::vector<std::vector<std::optional<at::Tensor>>>;
using TensorsAndIndicesT = std::pair<nested_optional_tensorvec_t, IndicesT>;
using FlatMap = std::unordered_map<
    DeviceDtypeKey,
    TensorsAndIndicesT,
    ParamsHash<DeviceDtypeKey>>;

inline FlatMap _group_tensors_by_first_tensors_device_and_dtype(
    const nested_optional_tensorvec_t& nested_tensorlist,
    const bool with_indices) {
  FlatMap grouped_tensors_with_indices;

  TORCH_CHECK(!nested_tensorlist.empty());
  TORCH_CHECK(!nested_tensorlist[0].empty());
  const auto num_lists = nested_tensorlist.size();
  const auto num_tensors = nested_tensorlist[0].size();

  TORCH_CHECK(std::all_of(
      nested_tensorlist.cbegin(),
      nested_tensorlist.cend(),
      [&](const auto& tensorlist) -> bool {
        // note(crcrpar): Allow empty tensorlists following
        // ref:
        // https://github.com/pytorch/pytorch/blob/85885301fd3c6adb8b9dc3cf7afadf6945566684/torch/utils/_foreach_utils.py#L21-L24
        return tensorlist.size() == num_tensors || tensorlist.size() == 0;
      }));

  for (const auto& tensor_index : c10::irange(num_tensors)) {
    const auto key = [&]() -> DeviceDtypeKey {
      const auto t = nested_tensorlist[0][tensor_index];
      TORCH_CHECK(
          t.has_value(),
          "Tensors of the first list of nested Tensor lists are supposed to be defined but ",
          "the ",
          tensor_index,
          "-th Tensor is not.");
      return {t->device(), t->scalar_type()};
    }();
    TORCH_CHECK(
        std::all_of(
            nested_tensorlist.cbegin(),
            nested_tensorlist.cend(),
            [&](const auto& tensorlist) -> bool {
              if (tensorlist.size() == 0) {
                return true;
              }
              const auto& tensor = tensorlist[tensor_index];
              // note(crcrpar): Currently the scope of this function is
              // optimizers so there could be `state_steps` and other scalars
              // whose elements are float tensors no matter what the parameter's
              // dtype is.
              if (!tensor.has_value()) {
                return true;
              } else {
                const auto s = tensor->scalar_type();
                const auto d = tensor->device();
                // Note: `step` or `state_step` is float32 by default.
                if (key.first == d) {
                  return key.second == s || s == at::ScalarType::Float ||
                      s == at::ScalarType::Double;
                } else if (d.is_cpu()) {
                  // note(crcrpar): There are some test cases (e.g.
                  // TestOptim::test_adam) where state_steps are on CPU and the
                  // others are on CUDA. Currently a state_step Tensor has the
                  // dtype of float.
                  return s == at::ScalarType::Float ||
                      s == at::ScalarType::Double;
                } else {
                  return false;
                }
              }
            }),
        "Tensors of the same index must be on the same device and the same dtype except `step` tensors that can be CPU and float32/64 notwithstanding");
    if (!grouped_tensors_with_indices.count(key)) {
      grouped_tensors_with_indices.insert(
          {key,
           TensorsAndIndicesT{
               [&]() -> nested_optional_tensorvec_t {
                 nested_optional_tensorvec_t nested_tensorvec;
                 nested_tensorvec.reserve(num_lists);
                 for (const auto& i : c10::irange(num_lists)) {
                   std::vector<std::optional<at::Tensor>> tensors;
                   if (!nested_tensorlist[i].empty()) {
                     // NB: num_tensors is the max possible length for any of
                     // the inner lists of tensor references. Reserving the max
                     // trades memory for perf. This should not have significant
                     // impact.
                     tensors.reserve(num_tensors);
                   }
                   nested_tensorvec.emplace_back(tensors);
                 }
                 return nested_tensorvec;
               }(),
               [&]() -> IndicesT {
                 if (!with_indices) {
                   return {};
                 } else {
                   IndicesT indices;
                   indices.reserve(num_tensors);
                   return indices;
                 }
               }()}});
    }
    for (const auto& list_index : c10::irange(num_lists)) {
      if (!nested_tensorlist[list_index].empty()) {
        grouped_tensors_with_indices[key].first[list_index].emplace_back(
            nested_tensorlist[list_index][tensor_index]);
      }
    }
    if (with_indices) {
      grouped_tensors_with_indices[key].second.emplace_back(tensor_index);
    }
  }

  return grouped_tensors_with_indices;
}
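
// Illustrative usage (a sketch with hypothetical optimizer state lists
// `params`, `grads`, and `state_steps`; not part of this header):
//
//   nested_optional_tensorvec_t nested = {params, grads, state_steps};
//   auto grouped = _group_tensors_by_first_tensors_device_and_dtype(
//       nested, /*with_indices=*/true);
//   for (auto& [key, tensors_and_indices] : grouped) {
//     // key.first/key.second: device and dtype of the params in this group.
//     // tensors_and_indices.first: the nested lists restricted to this group.
//     // tensors_and_indices.second: original indices of the grouped tensors.
//   }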

} // namespace
} // namespace at::native