#pragma once

#include <ATen/Device.h>
#include <ATen/Dispatch.h>
#include <ATen/ScalarType.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/utils/ParamsHash.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/result_type_native.h>
#endif

#include <unordered_map>
#include <vector>

namespace at::native {
namespace {
// Check if the tensor list has either a boolean tensor or an integer tensor.
inline bool has_integral_tensor(TensorList tensors, const bool includeBool) {
  return std::any_of(
      tensors.begin(), tensors.end(), [&includeBool](const auto& t) {
        return at::isIntegralType(t.scalar_type(), includeBool);
      });
}
// Check if the tensor list has any bool tensors.
inline bool has_bool_tensor(TensorList tensors) {
  return std::any_of(tensors.begin(), tensors.end(), [](const auto& t) -> bool {
    return t.scalar_type() == ScalarType::Bool;
  });
}

// Check foreach API restrictions
// - Tensor lists must be non-empty.
// - All TensorLists and ScalarLists must have the same number of elements.
// - Corresponding tensors must have the same size.
inline void check_foreach_api_restrictions(TensorList tensors) {
  TORCH_CHECK(!tensors.empty(), "Tensor list must have at least one tensor.");
}

inline void check_foreach_api_restrictions(
    TensorList tensors,
    ArrayRef<Scalar> scalars) {
  check_foreach_api_restrictions(tensors);
  TORCH_CHECK(
      tensors.size() == scalars.size(),
      "Tensor list must have same number of elements as scalar list.");
}

inline void check_foreach_api_restrictions(
    TensorList tensors1,
    TensorList tensors2) {
  TORCH_CHECK(!tensors1.empty(), "Tensor list must have at least one tensor.");
  TORCH_CHECK(!tensors2.empty(), "Tensor list must have at least one tensor.");
  TORCH_CHECK(
      tensors1.size() == tensors2.size(),
      "Tensor lists must have the same number of tensors, got ",
      tensors1.size(),
      " and ",
      tensors2.size());
}

inline void check_foreach_api_restrictions(
    TensorList tensors1,
    TensorList tensors2,
    TensorList tensors3) {
  TORCH_CHECK(!tensors1.empty(), "Tensor list must have at least one tensor.");
  TORCH_CHECK(!tensors2.empty(), "Tensor list must have at least one tensor.");
  TORCH_CHECK(!tensors3.empty(), "Tensor list must have at least one tensor.");
  TORCH_CHECK(
      tensors1.size() == tensors2.size(),
      "Tensor lists must have the same number of tensors, got ",
      tensors1.size(),
      " and ",
      tensors2.size());
  TORCH_CHECK(
      tensors1.size() == tensors3.size(),
      "Tensor lists must have the same number of tensors, got ",
      tensors1.size(),
      " and ",
      tensors3.size());
}

inline void check_foreach_api_restrictions(
    TensorList tensors1,
    TensorList tensors2,
    TensorList tensors3,
    ArrayRef<Scalar> scalars) {
  check_foreach_api_restrictions(tensors1, tensors2, tensors3);
  TORCH_CHECK(
      tensors1.size() == scalars.size(),
      "Tensor list must have same number of elements as scalar list, got ",
      tensors1.size(),
      " and ",
      scalars.size());
}
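// Illustrative usage (a minimal sketch, not part of this header's API): a
// foreach op typically validates its inputs up front before choosing a path,
// e.g.
//
//   void my_foreach_op(TensorList self, TensorList other) {
//     check_foreach_api_restrictions(self, other);
//     // ... then decide between the fast and slow paths ...
//   }
//
// `my_foreach_op` is a hypothetical name used only for this example.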

// Helper function called in check_fast_path_restrictions to check whether all
// corresponding tensors (aligning in index across the tensorLists) share the
// same device and dtype.
inline bool _check_tensors_share_device_and_dtype(
    ArrayRef<TensorList> tensorLists,
    const bool skip_dtype_check = false) {
  const auto expected_dtype = tensorLists[0][0].dtype();
  const auto expected_device = tensorLists[0][0].device();

  auto is_tensor_okay = [&](const Tensor& tensor) {
    return (skip_dtype_check || tensor.dtype() == expected_dtype) &&
        tensor.device() == expected_device && tensor.layout() == at::kStrided &&
        tensor.is_non_overlapping_and_dense();
  };

  for (const auto& tensorList : tensorLists) {
    for (const auto& tensor : tensorList) {
      if (!is_tensor_okay(tensor)) {
        return false;
      }
    }
  }

  return true;
}
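// For intuition (an illustrative sketch, not exercised by this header):
//
//   auto base = at::randn({4, 4});
//   // A transpose is still strided, non-overlapping, and dense, so it can
//   // pass is_tensor_okay:
//   _check_tensors_share_device_and_dtype({{base.t()}});               // true
//   // A step-2 slice skips elements, so it is not dense and fails:
//   _check_tensors_share_device_and_dtype({{base.slice(0, 0, 4, 2)}}); // false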

// Helper function called in check_fast_path_restrictions to check if
// corresponding tensors in tensor lists have the same sizes and strides.
inline bool _check_tensors_share_sizes_and_strides(
    ArrayRef<TensorList> tensorLists) {
  auto is_diff_stride = [](const IntArrayRef& size,
                           const IntArrayRef& left_stride,
                           const IntArrayRef& right_stride) -> bool {
    const size_t size_size = size.size();
    for (const auto dim : c10::irange(size_size)) {
      if (size[dim] == 1)
        continue;
      if (left_stride[dim] != right_stride[dim]) {
        return true;
      }
    }
    return false;
  };
  for (const auto i : c10::irange(1, tensorLists.size())) {
    for (const auto j : c10::irange(tensorLists[0].size())) {
      if (tensorLists[0][j].sizes() != tensorLists[i][j].sizes() ||
          is_diff_stride(
              tensorLists[0][j].sizes(),
              tensorLists[0][j].strides(),
              tensorLists[i][j].strides())) {
        return false;
      }
    }
  }

  return true;
}
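// Why size-1 dimensions are skipped above: when a dimension has extent 1, its
// stride never affects addressing, so two tensors describing the same memory
// walk may legitimately disagree there. An illustrative sketch with
// hypothetical inputs:
//
//   auto a = at::randn({1, 3});                  // sizes {1, 3}, strides {3, 1}
//   auto b = at::randn({3, 1}).transpose(0, 1);  // sizes {1, 3}, strides {1, 1}
//   // The strides differ only in the size-1 dimension, so is_diff_stride
//   // returns false and the pair is not rejected.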

// Helper function called in check_fast_path_restrictions to check whether
// all tensors type promote properly with the scalars in scalarList. This
// function assumes that _check_tensors_share_device_and_dtype has already been
// called so that all corresponding tensors in tensorLists have the same dtype.
// Then, it is sufficient to check the type promotion with just one tensorList.
inline bool _check_tensors_do_type_promotion_with_scalars(
    TensorList tensorList,
    ArrayRef<Scalar> scalarList = {},
    bool does_op_promote_integer_inputs_to_float = false) {
  for (const auto i : c10::irange(tensorList.size())) {
    // For division, integer inputs will result in float.
    if (does_op_promote_integer_inputs_to_float) {
      if (at::isIntegralType(
              tensorList[i].scalar_type(), /*includeBool*/ true)) {
        return false;
      }
    }
    if (!scalarList.empty()) {
      const auto& scalar =
          scalarList.size() == 1 ? scalarList[0] : scalarList[i];
      const auto& tensor = tensorList[i];
      // note(mkozuki): This check might be responsible for
      // `_foreach_add(bool_tensors, bool_tensors)` being pushed to slow path.
      if (tensor.scalar_type() != at::native::result_type(scalar, tensor)) {
        return false;
      }
    }
  }

  return true;
}
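// Concretely (an illustrative sketch): for something like
// `_foreach_add(int_tensors, 1.5)` the promoted result type is Float, which no
// longer matches the Int inputs, so the check returns false and the op takes
// the slow path:
//
//   auto t = at::ones({2}, at::kInt);
//   at::native::result_type(Scalar(1.5), t);  // kFloat != kInt -> slow path
//   at::native::result_type(Scalar(2), t);    // kInt  == kInt  -> still eligible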

// To go via the 'fast' path, several conditions must be satisfied:
// - All tensors in all lists must have the same dtype.
// - All tensors must be on the same device.
// - All tensors must have a strided layout.
// - All tensors must be non-overlapping and dense.
// - The resulting tensor must have the same dtype as the input ones.

// [note: what's ``does_op_promote_integer_inputs_to_float=true``?]
//     ``does_op_promote_integer_inputs_to_float=true`` means that the result of
//     the op will be float even if the inputs are integer or boolean, which the
//     fast path currently does not support. In short, when this flag is turned
//     on, it keeps the op from going down the fast path.

// Please make sure to call check_foreach_api_restrictions before calling this
// method; it establishes the preconditions this check relies on.
inline bool check_fast_path_restrictions(
    ArrayRef<TensorList> tensorLists,
    ArrayRef<Scalar> scalarList = {},
    bool does_op_promote_integer_inputs_to_float = false) {
  return _check_tensors_share_device_and_dtype(tensorLists) &&
      _check_tensors_share_sizes_and_strides(tensorLists) &&
      _check_tensors_do_type_promotion_with_scalars(
             tensorLists[0],
             scalarList,
             does_op_promote_integer_inputs_to_float);
}

inline std::vector<c10::Scalar> convert_tensor_to_scalar_list(
    const Tensor& scalarList_,
    int64_t expect_length) {
  std::vector<c10::Scalar> scalarList;
  TORCH_CHECK(
      scalarList_.device() == c10::kCPU,
      "Expected scalars to be on CPU, got ",
      scalarList_.device(),
      " instead.");
  TORCH_CHECK(
      scalarList_.is_contiguous(), "Expected scalars to be contiguous.");
  TORCH_CHECK(
      scalarList_.dim() == 1,
      "Expected packed scalar Tensor to be of dimension 1. Got ",
      scalarList_.dim(),
      " instead.");
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
      kComplexHalf,
      kHalf,
      kBool,
      kBFloat16,
      scalarList_.scalar_type(),
      "convert_tensor_to_scalar_list",
      [&]() {
        const scalar_t* scalar_data = scalarList_.const_data_ptr<scalar_t>();
        TORCH_CHECK(
            (expect_length == scalarList_.size(0)),
            "Expected length of scalars to match input of length ",
            expect_length,
            " but got ",
            scalarList_.size(0),
            " instead.");
        for (int64_t i = 0; i < scalarList_.size(0); i++) {
          scalarList.emplace_back(scalar_data[i]);
        }
      });
  return scalarList;
}
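// Illustrative usage (a minimal sketch with hypothetical inputs): unpacking a
// 1-D CPU tensor of per-tensor scalars before forwarding to a scalarlist
// overload, e.g.
//
//   at::Tensor packed = at::tensor({1.0, 2.0, 3.0});  // CPU, contiguous, dim 1
//   std::vector<c10::Scalar> scalars =
//       convert_tensor_to_scalar_list(packed, /*expect_length=*/3);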

// see: [note: what's ``does_op_promote_integer_inputs_to_float=true``?]
inline bool can_use_fast_route(
    ArrayRef<TensorList> tensorLists,
    ArrayRef<Scalar> scalarList = {},
    bool does_op_promote_integer_inputs_to_float = false) {
  return check_fast_path_restrictions(
      tensorLists, scalarList, does_op_promote_integer_inputs_to_float);
}

// see: [note: what's ``does_op_promote_integer_inputs_to_float=true``?]
inline bool can_use_fast_route(
    TensorList tensors1,
    TensorList tensors2,
    bool does_op_promote_integer_inputs_to_float = false) {
  return can_use_fast_route(
      {tensors1, tensors2}, {}, does_op_promote_integer_inputs_to_float);
}
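// Typical dispatch pattern (an illustrative sketch; the kernel names below are
// hypothetical placeholders, not declarations from this header):
//
//   check_foreach_api_restrictions(self, other);
//   if (can_use_fast_route(
//           self, other, /*does_op_promote_integer_inputs_to_float=*/false)) {
//     return my_op_fast_kernel(self, other);   // single fused kernel launch
//   }
//   return my_op_slow_fallback(self, other);   // per-tensor loop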

using DeviceDtypeKey = std::pair<at::Device, at::ScalarType>;
using IndicesT = std::vector<size_t>;
using nested_optional_tensorvec_t =
    std::vector<std::vector<std::optional<at::Tensor>>>;
using TensorsAndIndicesT = std::pair<nested_optional_tensorvec_t, IndicesT>;
using FlatMap = std::unordered_map<
    DeviceDtypeKey,
    TensorsAndIndicesT,
    ParamsHash<DeviceDtypeKey>>;

inline FlatMap _group_tensors_by_first_tensors_device_and_dtype(
    const nested_optional_tensorvec_t& nested_tensorlist,
    const bool with_indices) {
  FlatMap grouped_tensors_with_indices;

  TORCH_CHECK(!nested_tensorlist.empty());
  TORCH_CHECK(!nested_tensorlist[0].empty());
  const auto num_lists = nested_tensorlist.size();
  const auto num_tensors = nested_tensorlist[0].size();

  TORCH_CHECK(std::all_of(
      nested_tensorlist.cbegin(),
      nested_tensorlist.cend(),
      [&](const auto& tensorlist) -> bool {
        // note(crcrpar): Allow empty tensorlists following
        // ref:
        // https://github.com/pytorch/pytorch/blob/85885301fd3c6adb8b9dc3cf7afadf6945566684/torch/utils/_foreach_utils.py#L21-L24
        return tensorlist.size() == num_tensors || tensorlist.size() == 0;
      }));

  for (const auto& tensor_index : c10::irange(num_tensors)) {
    const auto key = [&]() -> DeviceDtypeKey {
      const auto t = nested_tensorlist[0][tensor_index];
      TORCH_CHECK(
          t.has_value(),
          "Tensors of the first list of nested Tensor lists are supposed to be defined but ",
          "the ",
          tensor_index,
          "-th Tensor is not.");
      return {t->device(), t->scalar_type()};
    }();
    TORCH_CHECK(
        std::all_of(
            nested_tensorlist.cbegin(),
            nested_tensorlist.cend(),
            [&](const auto& tensorlist) -> bool {
              if (tensorlist.size() == 0) {
                return true;
              }
              const auto& tensor = tensorlist[tensor_index];
              // note(crcrpar): Currently the scope of this function is
              // optimizers so there could be `state_steps` and other scalars
              // whose elements are float tensors no matter what the parameter's
              // dtype is.
              if (!tensor.has_value()) {
                return true;
              } else {
                const auto s = tensor->scalar_type();
                const auto d = tensor->device();
                // Note: `step` or `state_step` is float32 by default.
                if (key.first == d) {
                  return key.second == s || s == at::ScalarType::Float ||
                      s == at::ScalarType::Double;
                } else if (d.is_cpu()) {
                  // note(crcrpar): There are some test cases (e.g.
                  // TestOptim::test_adam) where state_steps are on CPU and the
                  // others are on CUDA. Currently a state_step Tensor has the
                  // dtype of float.
                  return s == at::ScalarType::Float ||
                      s == at::ScalarType::Double;
                } else {
                  return false;
                }
              }
            }),
        "Tensors of the same index must be on the same device and the same dtype except `step` tensors that can be CPU and float32/64 notwithstanding");
    if (!grouped_tensors_with_indices.count(key)) {
      grouped_tensors_with_indices.insert(
          {key,
           TensorsAndIndicesT{
               [&]() -> nested_optional_tensorvec_t {
                 nested_optional_tensorvec_t nested_tensorvec;
                 nested_tensorvec.reserve(num_lists);
                 for (const auto& i : c10::irange(num_lists)) {
                   std::vector<std::optional<at::Tensor>> tensors;
                   if (!nested_tensorlist[i].empty()) {
                     // NB: num_tensors is the max possible length for any of
                     // the inner lists of tensor references. Reserving the max
                     // trades memory for perf. This should not have significant
                     // impact.
                     tensors.reserve(num_tensors);
                   }
                   nested_tensorvec.emplace_back(tensors);
                 }
                 return nested_tensorvec;
               }(),
               [&]() -> IndicesT {
                 if (!with_indices) {
                   return {};
                 } else {
                   IndicesT indices;
                   indices.reserve(num_tensors);
                   return indices;
                 }
               }()}});
    }
    for (const auto& list_index : c10::irange(num_lists)) {
      if (!nested_tensorlist[list_index].empty()) {
        grouped_tensors_with_indices[key].first[list_index].emplace_back(
            nested_tensorlist[list_index][tensor_index]);
      }
    }
    if (with_indices) {
      grouped_tensors_with_indices[key].second.emplace_back(tensor_index);
    }
  }

  return grouped_tensors_with_indices;
}
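// Illustrative usage (a minimal sketch; the variable names are hypothetical):
// a fused optimizer groups params, grads, and optimizer state by
// (device, dtype) so that each group can be handled with foreach kernels, e.g.
//
//   nested_optional_tensorvec_t nested{params, grads, exp_avgs};
//   const auto grouped = _group_tensors_by_first_tensors_device_and_dtype(
//       nested, /*with_indices=*/false);
//   for (const auto& [device_and_dtype, tensors_and_indices] : grouped) {
//     // tensors_and_indices.first[0] holds the params for this group,
//     // .first[1] the grads, .first[2] the exp_avgs.
//   }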

} // namespace
} // namespace at::native