// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/ForeachOpsKernels.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <vector>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/native/ForeachUtils.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Operators.h>
#else
#include <ATen/ops/_foreach_abs_native.h>
#include <ATen/ops/_foreach_acos_native.h>
#include <ATen/ops/_foreach_add_native.h>
#include <ATen/ops/_foreach_addcdiv_native.h>
#include <ATen/ops/_foreach_addcmul_native.h>
#include <ATen/ops/_foreach_asin_native.h>
#include <ATen/ops/_foreach_atan_native.h>
#include <ATen/ops/_foreach_ceil_native.h>
#include <ATen/ops/_foreach_clamp_max_native.h>
#include <ATen/ops/_foreach_clamp_min_native.h>
#include <ATen/ops/_foreach_copy_native.h>
#include <ATen/ops/_foreach_cos_native.h>
#include <ATen/ops/_foreach_cosh_native.h>
#include <ATen/ops/_foreach_div_native.h>
#include <ATen/ops/_foreach_erf_native.h>
#include <ATen/ops/_foreach_erfc_native.h>
#include <ATen/ops/_foreach_exp_native.h>
#include <ATen/ops/_foreach_expm1_native.h>
#include <ATen/ops/_foreach_floor_native.h>
#include <ATen/ops/_foreach_frac_native.h>
#include <ATen/ops/_foreach_lerp_native.h>
#include <ATen/ops/_foreach_lgamma_native.h>
#include <ATen/ops/_foreach_log10_native.h>
#include <ATen/ops/_foreach_log1p_native.h>
#include <ATen/ops/_foreach_log2_native.h>
#include <ATen/ops/_foreach_log_native.h>
#include <ATen/ops/_foreach_max_native.h>
#include <ATen/ops/_foreach_maximum_native.h>
#include <ATen/ops/_foreach_minimum_native.h>
#include <ATen/ops/_foreach_mul_native.h>
#include <ATen/ops/_foreach_neg_native.h>
#include <ATen/ops/_foreach_norm_native.h>
#include <ATen/ops/_foreach_pow_native.h>
#include <ATen/ops/_foreach_reciprocal_native.h>
#include <ATen/ops/_foreach_round_native.h>
#include <ATen/ops/_foreach_sigmoid_native.h>
#include <ATen/ops/_foreach_sign_native.h>
#include <ATen/ops/_foreach_sin_native.h>
#include <ATen/ops/_foreach_sinh_native.h>
#include <ATen/ops/_foreach_sqrt_native.h>
#include <ATen/ops/_foreach_sub_native.h>
#include <ATen/ops/_foreach_tan_native.h>
#include <ATen/ops/_foreach_tanh_native.h>
#include <ATen/ops/_foreach_trunc_native.h>
#include <ATen/ops/_foreach_zero_native.h>
#include <ATen/ops/copy.h>
#include <ATen/ops/linalg_vector_norm.h>
#include <ATen/ops/max.h>
#include <ATen/ops/maximum.h>
#include <ATen/ops/minimum.h>
#include <ATen/ops/pow.h>
#endif

65 namespace at::native {
66 
67 #define FOREACH_BINARY_OP_TENSOR(OP)                            \
68   void foreach_tensor_##OP##_tensor_kernel_slow_(               \
69       TensorList tensors, const Tensor& scalar) {               \
70     TORCH_CHECK(                                                \
71         scalar.dim() == 0 && scalar.numel() == 1,               \
72         "scalar tensor expected to be 0 dim but it has ",       \
73         scalar.dim(),                                           \
74         " dimensions and ",                                     \
75         scalar.numel(),                                         \
76         " elements.");                                          \
77     check_foreach_api_restrictions(tensors);                    \
78                                                                 \
79     for (auto& t : tensors) {                                   \
80       t.OP##_(scalar);                                          \
81     }                                                           \
82   }                                                             \
83                                                                 \
84   std::vector<Tensor> foreach_tensor_##OP##_tensor_kernel_slow( \
85       TensorList tensors, const Tensor& scalar) {               \
86     TORCH_CHECK(                                                \
87         scalar.dim() == 0 && scalar.numel() == 1,               \
88         "scalar tensor expected to be 0 dim but it has ",       \
89         scalar.dim(),                                           \
90         " dimensions and ",                                     \
91         scalar.numel(),                                         \
92         " elements.");                                          \
93     check_foreach_api_restrictions(tensors);                    \
94                                                                 \
95     std::vector<Tensor> result;                                 \
96     result.reserve(tensors.size());                             \
97     for (const auto& t : tensors) {                             \
98       result.emplace_back(t.OP(scalar));                        \
99     }                                                           \
100                                                                 \
101     return result;                                              \
102   }
103 
// Same as FOREACH_BINARY_OP_TENSOR, but for ops that additionally take a
// Scalar `alpha` multiplier (e.g. add: self + alpha * scalar).
#define FOREACH_BINARY_OP_TENSOR_ALPHA(OP)                             \
  void foreach_tensor_##OP##_tensor_kernel_slow_(                      \
      TensorList tensors, const Tensor& scalar, const Scalar& alpha) { \
    TORCH_CHECK(                                                       \
        scalar.dim() == 0 && scalar.numel() == 1,                      \
        "scalar tensor expected to be 0 dim but it has ",              \
        scalar.dim(),                                                  \
        " dimensions and ",                                            \
        scalar.numel(),                                                \
        " elements.");                                                 \
    check_foreach_api_restrictions(tensors);                           \
                                                                       \
    for (auto& t : tensors) {                                          \
      t.OP##_(scalar, alpha);                                          \
    }                                                                  \
  }                                                                    \
                                                                       \
  std::vector<Tensor> foreach_tensor_##OP##_tensor_kernel_slow(        \
      TensorList tensors, const Tensor& scalar, const Scalar& alpha) { \
    TORCH_CHECK(                                                       \
        scalar.dim() == 0 && scalar.numel() == 1,                      \
        "scalar tensor expected to be 0 dim but it has ",              \
        scalar.dim(),                                                  \
        " dimensions and ",                                            \
        scalar.numel(),                                                \
        " elements.");                                                 \
    check_foreach_api_restrictions(tensors);                           \
                                                                       \
    std::vector<Tensor> result;                                        \
    result.reserve(tensors.size());                                    \
    for (const auto& t : tensors) {                                    \
      result.emplace_back(t.OP(scalar, alpha));                        \
    }                                                                  \
                                                                       \
    return result;                                                     \
  }

// Generates in-place and out-of-place reference fallbacks for a binary
// foreach op where the second operand is a single Scalar applied to every
// tensor in the list.
#define FOREACH_BINARY_OP_SCALAR(OP)                            \
  void foreach_tensor_##OP##_scalar_kernel_slow_(               \
      TensorList tensors, const Scalar& scalar) {               \
    check_foreach_api_restrictions(tensors);                    \
                                                                \
    for (auto& t : tensors) {                                   \
      t.OP##_(scalar);                                          \
    }                                                           \
  }                                                             \
                                                                \
  std::vector<Tensor> foreach_tensor_##OP##_scalar_kernel_slow( \
      TensorList tensors, const Scalar& scalar) {               \
    check_foreach_api_restrictions(tensors);                    \
                                                                \
    std::vector<Tensor> result;                                 \
    result.reserve(tensors.size());                             \
    for (const auto& t : tensors) {                             \
      result.emplace_back(t.OP(scalar));                        \
    }                                                           \
                                                                \
    return result;                                              \
  }

// Generates in-place and out-of-place reference fallbacks for a binary
// foreach op where each tensor gets its own Scalar: tensors[i] OP scalars[i].
// check_foreach_api_restrictions also validates the two lists match in size.
#define FOREACH_BINARY_OP_SCALARLIST(OP)                            \
  void foreach_tensor_##OP##_scalarlist_kernel_slow_(               \
      TensorList tensors, at::ArrayRef<Scalar> scalars) {           \
    check_foreach_api_restrictions(tensors, scalars);               \
                                                                    \
    for (const auto i : c10::irange(tensors.size())) {              \
      tensors[i].OP##_(scalars[i]);                                 \
    }                                                               \
  }                                                                 \
                                                                    \
  std::vector<Tensor> foreach_tensor_##OP##_scalarlist_kernel_slow( \
      TensorList tensors, at::ArrayRef<Scalar> scalars) {           \
    check_foreach_api_restrictions(tensors, scalars);               \
    std::vector<Tensor> result;                                     \
    result.reserve(tensors.size());                                 \
    for (const auto i : c10::irange(tensors.size())) {              \
      result.emplace_back(tensors[i].OP(scalars[i]));               \
    }                                                               \
                                                                    \
    return result;                                                  \
  }

// Generates reference fallbacks for a binary foreach op over two tensor
// lists of equal length: tensors1[i] OP tensors2[i].
#define FOREACH_BINARY_OP_LIST(OP)                            \
  std::vector<Tensor> foreach_tensor_##OP##_list_kernel_slow( \
      TensorList tensors1, TensorList tensors2) {             \
    check_foreach_api_restrictions(tensors1, tensors2);       \
                                                              \
    std::vector<Tensor> result;                               \
    result.reserve(tensors1.size());                          \
    for (const auto i : c10::irange(tensors1.size())) {       \
      result.emplace_back(tensors1[i].OP(tensors2[i]));       \
    }                                                         \
                                                              \
    return result;                                            \
  }                                                           \
                                                              \
  void foreach_tensor_##OP##_list_kernel_slow_(               \
      TensorList tensors1, TensorList tensors2) {             \
    check_foreach_api_restrictions(tensors1, tensors2);       \
                                                              \
    for (const auto i : c10::irange(tensors1.size())) {       \
      tensors1[i].OP##_(tensors2[i]);                         \
    }                                                         \
  }

// Same as FOREACH_BINARY_OP_LIST, but the op also takes a Scalar `alpha`
// (e.g. add/sub: tensors1[i] + alpha * tensors2[i]).
#define FOREACH_BINARY_OP_LIST_ALPHA(OP)                               \
  std::vector<Tensor> foreach_tensor_##OP##_list_kernel_slow(          \
      TensorList tensors1, TensorList tensors2, const Scalar& alpha) { \
    check_foreach_api_restrictions(tensors1, tensors2);                \
                                                                       \
    std::vector<Tensor> result;                                        \
    result.reserve(tensors1.size());                                   \
    for (const auto i : c10::irange(tensors1.size())) {                \
      result.emplace_back(tensors1[i].OP(tensors2[i], alpha));         \
    }                                                                  \
                                                                       \
    return result;                                                     \
  }                                                                    \
                                                                       \
  void foreach_tensor_##OP##_list_kernel_slow_(                        \
      TensorList tensors1, TensorList tensors2, const Scalar& alpha) { \
    check_foreach_api_restrictions(tensors1, tensors2);                \
                                                                       \
    for (const auto i : c10::irange(tensors1.size())) {                \
      tensors1[i].OP##_(tensors2[i], alpha);                           \
    }                                                                  \
  }

// Generates reference fallbacks for a unary foreach op:
//   foreach_tensor_<OP>_slow   -- returns t.OP() for each tensor
//   foreach_tensor_<OP>_slow_  -- calls t.OP_() in place on each tensor
#define FOREACH_UNARY_OP(OP)                                           \
  std::vector<Tensor> foreach_tensor_##OP##_slow(TensorList tensors) { \
    check_foreach_api_restrictions(tensors);                           \
                                                                       \
    std::vector<Tensor> result;                                        \
    result.reserve(tensors.size());                                    \
    for (const auto& t : tensors) {                                    \
      result.emplace_back(t.OP());                                     \
    }                                                                  \
                                                                       \
    return result;                                                     \
  }                                                                    \
                                                                       \
  void foreach_tensor_##OP##_slow_(TensorList tensors) {               \
    check_foreach_api_restrictions(tensors);                           \
                                                                       \
    for (auto& t : tensors) {                                          \
      t.OP##_();                                                       \
    }                                                                  \
  }

// Generates reference fallbacks for a ternary pointwise foreach op
// (addcmul/addcdiv) with a single Scalar multiplier shared by all elements:
// input[i].OP(tensors1[i], tensors2[i], scalar).
#define FOREACH_POINTWISE_OP_SCALAR(OP)                                   \
  std::vector<Tensor> foreach_tensor_##OP##_scalar_slow(                  \
      TensorList input,                                                   \
      TensorList tensors1,                                                \
      TensorList tensors2,                                                \
      const Scalar& scalar) {                                             \
    check_foreach_api_restrictions(input, tensors1, tensors2);            \
                                                                          \
    std::vector<Tensor> result;                                           \
    result.reserve(input.size()); /* avoid reallocation during fill */    \
    for (const auto i : c10::irange(input.size())) {                      \
      result.emplace_back(input[i].OP(tensors1[i], tensors2[i], scalar)); \
    }                                                                     \
                                                                          \
    return result;                                                        \
  }                                                                       \
                                                                          \
  void foreach_tensor_##OP##_scalar_slow_(                                \
      TensorList input,                                                   \
      TensorList tensors1,                                                \
      TensorList tensors2,                                                \
      const Scalar& scalar) {                                             \
    check_foreach_api_restrictions(input, tensors1, tensors2);            \
                                                                          \
    for (const auto i : c10::irange(input.size())) {                      \
      input[i].OP##_(tensors1[i], tensors2[i], scalar);                   \
    }                                                                     \
  }

// Generates reference fallbacks for a ternary pointwise foreach op where
// each element has its own Scalar: input[i].OP(tensors1[i], tensors2[i],
// scalars[i]). List-length agreement is validated by
// check_foreach_api_restrictions.
#define FOREACH_POINTWISE_OP_SCALARLIST(OP)                                   \
  std::vector<Tensor> foreach_tensor_##OP##_scalarlist_slow(                  \
      TensorList input,                                                       \
      TensorList tensors1,                                                    \
      TensorList tensors2,                                                    \
      at::ArrayRef<Scalar> scalars) {                                         \
    check_foreach_api_restrictions(input, tensors1, tensors2, scalars);       \
                                                                              \
    std::vector<Tensor> result;                                               \
    result.reserve(input.size()); /* avoid reallocation during fill */        \
    for (const auto i : c10::irange(input.size())) {                          \
      result.emplace_back(input[i].OP(tensors1[i], tensors2[i], scalars[i])); \
    }                                                                         \
                                                                              \
    return result;                                                            \
  }                                                                           \
                                                                              \
  void foreach_tensor_##OP##_scalarlist_slow_(                                \
      TensorList input,                                                       \
      TensorList tensors1,                                                    \
      TensorList tensors2,                                                    \
      at::ArrayRef<Scalar> scalars) {                                         \
    check_foreach_api_restrictions(input, tensors1, tensors2, scalars);       \
                                                                              \
    for (const auto i : c10::irange(input.size())) {                          \
      input[i].OP##_(tensors1[i], tensors2[i], scalars[i]);                   \
    }                                                                         \
  }

// Generates reference fallbacks for the variant of the ternary pointwise op
// whose per-element scalars arrive packed in a Tensor. The tensor is
// converted to an ArrayRef<Scalar> (one entry per input tensor) and the
// scalarlist variant above does the actual work.
#define FOREACH_POINTWISE_OP_TENSOR(OP)                                   \
  std::vector<Tensor> foreach_tensor_##OP##_tensor_slow(                  \
      TensorList input,                                                   \
      TensorList tensors1,                                                \
      TensorList tensors2,                                                \
      const Tensor& scalars_) {                                           \
    auto scalars = convert_tensor_to_scalar_list(scalars_, input.size()); \
    check_foreach_api_restrictions(input, tensors1, tensors2, scalars);   \
    return foreach_tensor_##OP##_scalarlist_slow(                         \
        input, tensors1, tensors2, scalars);                              \
  }                                                                       \
                                                                          \
  void foreach_tensor_##OP##_tensor_slow_(                                \
      TensorList input,                                                   \
      TensorList tensors1,                                                \
      TensorList tensors2,                                                \
      const Tensor& scalars_) {                                           \
    auto scalars = convert_tensor_to_scalar_list(scalars_, input.size()); \
    check_foreach_api_restrictions(input, tensors1, tensors2, scalars);   \
    foreach_tensor_##OP##_scalarlist_slow_(                               \
        input, tensors1, tensors2, scalars);                              \
  }

332 FOREACH_BINARY_OP_LIST_ALPHA(add);
333 FOREACH_BINARY_OP_LIST_ALPHA(sub);
334 FOREACH_BINARY_OP_LIST_ALPHA(lerp);
335 
336 FOREACH_BINARY_OP_TENSOR_ALPHA(add);
337 FOREACH_BINARY_OP_TENSOR(mul);
338 FOREACH_BINARY_OP_TENSOR(div);
339 
340 FOREACH_BINARY_OP_SCALAR(add);
341 FOREACH_BINARY_OP_SCALAR(sub);
342 FOREACH_BINARY_OP_SCALAR(mul);
343 FOREACH_BINARY_OP_SCALAR(div);
344 FOREACH_BINARY_OP_SCALAR(clamp_min);
345 FOREACH_BINARY_OP_SCALAR(clamp_max);
346 FOREACH_BINARY_OP_SCALAR(pow);
347 
348 FOREACH_BINARY_OP_SCALARLIST(add);
349 FOREACH_BINARY_OP_SCALARLIST(sub);
350 FOREACH_BINARY_OP_SCALARLIST(mul);
351 FOREACH_BINARY_OP_SCALARLIST(div);
352 FOREACH_BINARY_OP_SCALARLIST(clamp_min);
353 FOREACH_BINARY_OP_SCALARLIST(clamp_max);
354 FOREACH_BINARY_OP_SCALARLIST(pow);
355 
356 FOREACH_BINARY_OP_LIST(mul);
357 FOREACH_BINARY_OP_LIST(div);
358 FOREACH_BINARY_OP_LIST(clamp_min);
359 FOREACH_BINARY_OP_LIST(clamp_max);
360 FOREACH_BINARY_OP_LIST(pow);
361 // _foreach_copy_
foreach_tensor_copy_list_kernel_slow_(TensorList self,TensorList src,const bool non_blocking)362 void foreach_tensor_copy_list_kernel_slow_(
363     TensorList self,
364     TensorList src,
365     const bool non_blocking) {
366   check_foreach_api_restrictions(self, src);
367 
368   for (const auto i : c10::irange(self.size())) {
369     self[i].copy_(src[i], non_blocking);
370   }
371 }
372 
373 FOREACH_UNARY_OP(sqrt);
374 FOREACH_UNARY_OP(exp);
375 FOREACH_UNARY_OP(abs);
376 FOREACH_UNARY_OP(acos);
377 FOREACH_UNARY_OP(asin);
378 FOREACH_UNARY_OP(atan);
379 FOREACH_UNARY_OP(ceil);
380 FOREACH_UNARY_OP(cos);
381 FOREACH_UNARY_OP(cosh);
382 FOREACH_UNARY_OP(erf);
383 FOREACH_UNARY_OP(erfc);
384 FOREACH_UNARY_OP(expm1);
385 FOREACH_UNARY_OP(floor);
386 FOREACH_UNARY_OP(log);
387 FOREACH_UNARY_OP(log10);
388 FOREACH_UNARY_OP(log1p);
389 FOREACH_UNARY_OP(log2);
390 FOREACH_UNARY_OP(neg);
391 FOREACH_UNARY_OP(tan);
392 FOREACH_UNARY_OP(tanh);
393 FOREACH_UNARY_OP(sin);
394 FOREACH_UNARY_OP(sinh);
395 FOREACH_UNARY_OP(round);
396 FOREACH_UNARY_OP(lgamma);
397 FOREACH_UNARY_OP(frac);
398 FOREACH_UNARY_OP(trunc);
399 FOREACH_UNARY_OP(reciprocal);
400 FOREACH_UNARY_OP(sigmoid);
401 FOREACH_UNARY_OP(sign);
402 
403 FOREACH_POINTWISE_OP_SCALAR(addcdiv);
404 FOREACH_POINTWISE_OP_SCALAR(addcmul);
405 
406 FOREACH_POINTWISE_OP_SCALARLIST(addcdiv);
407 FOREACH_POINTWISE_OP_SCALARLIST(addcmul);
408 
409 FOREACH_POINTWISE_OP_TENSOR(addcdiv);
410 FOREACH_POINTWISE_OP_TENSOR(addcmul);
411 
412 #define FOREACH_TERNARY_OP(OP)                                         \
413   std::vector<Tensor> foreach_tensor_ternary_##OP##_slow(              \
414       TensorList tensors1, TensorList tensors2, TensorList tensors3) { \
415     check_foreach_api_restrictions(tensors1, tensors2, tensors3);      \
416     std::vector<Tensor> result;                                        \
417     for (const auto i : c10::irange(tensors1.size())) {                \
418       result.emplace_back(tensors1[i].OP(tensors2[i], tensors3[i]));   \
419     }                                                                  \
420     return result;                                                     \
421   }                                                                    \
422                                                                        \
423   void foreach_tensor_ternary_##OP##_slow_(                            \
424       TensorList tensors1, TensorList tensors2, TensorList tensors3) { \
425     check_foreach_api_restrictions(tensors1, tensors2, tensors3);      \
426     for (const auto i : c10::irange(tensors1.size())) {                \
427       tensors1[i].OP##_(tensors2[i], tensors3[i]);                     \
428     }                                                                  \
429   }
430 
431 FOREACH_TERNARY_OP(lerp);
432 
foreach_tensor_zero_slow_(TensorList tensors)433 void foreach_tensor_zero_slow_(TensorList tensors) {
434   check_foreach_api_restrictions(tensors);
435 
436   for (auto& t : tensors) {
437     t.zero_();
438   }
439 }
440 
foreach_tensor_norm_slow(TensorList tensors,const Scalar & ord,std::optional<ScalarType> dtype)441 std::vector<Tensor> foreach_tensor_norm_slow(
442     TensorList tensors,
443     const Scalar& ord,
444     std::optional<ScalarType> dtype) {
445   check_foreach_api_restrictions(tensors);
446   std::vector<Tensor> result;
447   for (const auto& t : tensors) {
448     result.emplace_back(at::linalg_vector_norm(t, ord, {}, false, dtype));
449   }
450   return result;
451 }
452 
foreach_tensor_max_slow(TensorList tensors)453 std::vector<Tensor> foreach_tensor_max_slow(TensorList tensors) {
454   check_foreach_api_restrictions(tensors);
455   std::vector<Tensor> result;
456   for (const auto& t : tensors) {
457     result.emplace_back(at::max(t));
458   }
459   return result;
460 }
461 
foreach_scalar_pow_list_kernel_slow(const Scalar & self,TensorList exponent)462 std::vector<Tensor> foreach_scalar_pow_list_kernel_slow(
463     const Scalar& self,
464     TensorList exponent) {
465   check_foreach_api_restrictions(exponent);
466   std::vector<Tensor> result;
467   result.reserve(exponent.size());
468   for (const auto& t : exponent) {
469     result.emplace_back(at::pow(self, t));
470   }
471   return result;
472 }
473 
474 } // namespace at::native
475