#include <vector>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/native/ForeachUtils.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Operators.h>
#else
#include <ATen/ops/_foreach_abs_native.h>
#include <ATen/ops/_foreach_acos_native.h>
#include <ATen/ops/_foreach_add_native.h>
#include <ATen/ops/_foreach_addcdiv_native.h>
#include <ATen/ops/_foreach_addcmul_native.h>
#include <ATen/ops/_foreach_asin_native.h>
#include <ATen/ops/_foreach_atan_native.h>
#include <ATen/ops/_foreach_ceil_native.h>
#include <ATen/ops/_foreach_clamp_max_native.h>
#include <ATen/ops/_foreach_clamp_min_native.h>
#include <ATen/ops/_foreach_copy_native.h>
#include <ATen/ops/_foreach_cos_native.h>
#include <ATen/ops/_foreach_cosh_native.h>
#include <ATen/ops/_foreach_div_native.h>
#include <ATen/ops/_foreach_erf_native.h>
#include <ATen/ops/_foreach_erfc_native.h>
#include <ATen/ops/_foreach_exp_native.h>
#include <ATen/ops/_foreach_expm1_native.h>
#include <ATen/ops/_foreach_floor_native.h>
#include <ATen/ops/_foreach_frac_native.h>
#include <ATen/ops/_foreach_lerp_native.h>
#include <ATen/ops/_foreach_lgamma_native.h>
#include <ATen/ops/_foreach_log10_native.h>
#include <ATen/ops/_foreach_log1p_native.h>
#include <ATen/ops/_foreach_log2_native.h>
#include <ATen/ops/_foreach_log_native.h>
#include <ATen/ops/_foreach_max_native.h>
#include <ATen/ops/_foreach_maximum_native.h>
#include <ATen/ops/_foreach_minimum_native.h>
#include <ATen/ops/_foreach_mul_native.h>
#include <ATen/ops/_foreach_neg_native.h>
#include <ATen/ops/_foreach_norm_native.h>
#include <ATen/ops/_foreach_pow_native.h>
#include <ATen/ops/_foreach_reciprocal_native.h>
#include <ATen/ops/_foreach_round_native.h>
#include <ATen/ops/_foreach_sigmoid_native.h>
#include <ATen/ops/_foreach_sign_native.h>
#include <ATen/ops/_foreach_sin_native.h>
#include <ATen/ops/_foreach_sinh_native.h>
#include <ATen/ops/_foreach_sqrt_native.h>
#include <ATen/ops/_foreach_sub_native.h>
#include <ATen/ops/_foreach_tan_native.h>
#include <ATen/ops/_foreach_tanh_native.h>
#include <ATen/ops/_foreach_trunc_native.h>
#include <ATen/ops/_foreach_zero_native.h>
#include <ATen/ops/copy.h>
#include <ATen/ops/linalg_vector_norm.h>
#include <ATen/ops/max.h>
#include <ATen/ops/maximum.h>
#include <ATen/ops/minimum.h>
#include <ATen/ops/pow.h>
#endif

namespace at::native {

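// Reference ("slow path") implementations of the _foreach_* operators: each
// kernel below simply loops over the tensor list(s) and calls the
// corresponding per-tensor op. Device-specific fast paths (e.g. fused CUDA
// kernels) dispatch to these when their preconditions are not met.

// Generates the in-place and out-of-place slow kernels for a binary op whose
// second operand is a 0-dim, single-element "scalar" tensor.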
#define FOREACH_BINARY_OP_TENSOR(OP)                            \
  void foreach_tensor_##OP##_tensor_kernel_slow_(               \
      TensorList tensors, const Tensor& scalar) {               \
    TORCH_CHECK(                                                \
        scalar.dim() == 0 && scalar.numel() == 1,               \
        "scalar tensor expected to be 0 dim but it has ",       \
        scalar.dim(),                                           \
        " dimensions and ",                                     \
        scalar.numel(),                                         \
        " elements.");                                          \
    check_foreach_api_restrictions(tensors);                    \
                                                                \
    for (auto& t : tensors) {                                   \
      t.OP##_(scalar);                                          \
    }                                                           \
  }                                                             \
                                                                \
  std::vector<Tensor> foreach_tensor_##OP##_tensor_kernel_slow( \
      TensorList tensors, const Tensor& scalar) {               \
    TORCH_CHECK(                                                \
        scalar.dim() == 0 && scalar.numel() == 1,               \
        "scalar tensor expected to be 0 dim but it has ",       \
        scalar.dim(),                                           \
        " dimensions and ",                                     \
        scalar.numel(),                                         \
        " elements.");                                          \
    check_foreach_api_restrictions(tensors);                    \
                                                                \
    std::vector<Tensor> result;                                 \
    result.reserve(tensors.size());                             \
    for (const auto& t : tensors) {                             \
      result.emplace_back(t.OP(scalar));                        \
    }                                                           \
                                                                \
    return result;                                              \
  }

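// Same as FOREACH_BINARY_OP_TENSOR, but the generated kernels also take an
// alpha Scalar that is forwarded to each per-tensor call (e.g. add's alpha).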
#define FOREACH_BINARY_OP_TENSOR_ALPHA(OP)                           \
  void foreach_tensor_##OP##_tensor_kernel_slow_(                    \
      TensorList tensors, const Tensor& scalar, const Scalar& alpha) { \
    TORCH_CHECK(                                                     \
        scalar.dim() == 0 && scalar.numel() == 1,                    \
        "scalar tensor expected to be 0 dim but it has ",            \
        scalar.dim(),                                                \
        " dimensions and ",                                          \
        scalar.numel(),                                              \
        " elements.");                                               \
    check_foreach_api_restrictions(tensors);                         \
                                                                     \
    for (auto& t : tensors) {                                        \
      t.OP##_(scalar, alpha);                                        \
    }                                                                \
  }                                                                  \
                                                                     \
  std::vector<Tensor> foreach_tensor_##OP##_tensor_kernel_slow(      \
      TensorList tensors, const Tensor& scalar, const Scalar& alpha) { \
    TORCH_CHECK(                                                     \
        scalar.dim() == 0 && scalar.numel() == 1,                    \
        "scalar tensor expected to be 0 dim but it has ",            \
        scalar.dim(),                                                \
        " dimensions and ",                                          \
        scalar.numel(),                                              \
        " elements.");                                               \
    check_foreach_api_restrictions(tensors);                         \
                                                                     \
    std::vector<Tensor> result;                                      \
    result.reserve(tensors.size());                                  \
    for (const auto& t : tensors) {                                  \
      result.emplace_back(t.OP(scalar, alpha));                      \
    }                                                                \
                                                                     \
    return result;                                                   \
  }

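// Generates slow kernels for a binary op that applies a single Scalar to
// every tensor in the list.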
#define FOREACH_BINARY_OP_SCALAR(OP)                            \
  void foreach_tensor_##OP##_scalar_kernel_slow_(               \
      TensorList tensors, const Scalar& scalar) {               \
    check_foreach_api_restrictions(tensors);                    \
                                                                \
    for (auto& t : tensors) {                                   \
      t.OP##_(scalar);                                          \
    }                                                           \
  }                                                             \
                                                                \
  std::vector<Tensor> foreach_tensor_##OP##_scalar_kernel_slow( \
      TensorList tensors, const Scalar& scalar) {               \
    check_foreach_api_restrictions(tensors);                    \
                                                                \
    std::vector<Tensor> result;                                 \
    result.reserve(tensors.size());                             \
    for (const auto& t : tensors) {                             \
      result.emplace_back(t.OP(scalar));                        \
    }                                                           \
                                                                \
    return result;                                              \
  }

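// Generates slow kernels for a binary op with one Scalar per tensor;
// check_foreach_api_restrictions also verifies the scalar list matches the
// tensor list in length.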
#define FOREACH_BINARY_OP_SCALARLIST(OP)                            \
  void foreach_tensor_##OP##_scalarlist_kernel_slow_(               \
      TensorList tensors, at::ArrayRef<Scalar> scalars) {           \
    check_foreach_api_restrictions(tensors, scalars);               \
                                                                    \
    for (const auto i : c10::irange(tensors.size())) {              \
      tensors[i].OP##_(scalars[i]);                                 \
    }                                                               \
  }                                                                 \
                                                                    \
  std::vector<Tensor> foreach_tensor_##OP##_scalarlist_kernel_slow( \
      TensorList tensors, at::ArrayRef<Scalar> scalars) {           \
    check_foreach_api_restrictions(tensors, scalars);               \
    std::vector<Tensor> result;                                     \
    result.reserve(tensors.size());                                 \
    for (const auto i : c10::irange(tensors.size())) {              \
      result.emplace_back(tensors[i].OP(scalars[i]));               \
    }                                                               \
                                                                    \
    return result;                                                  \
  }

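// Generates slow kernels for an elementwise binary op over two equal-length
// tensor lists: result[i] = tensors1[i] OP tensors2[i].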
#define FOREACH_BINARY_OP_LIST(OP)                            \
  std::vector<Tensor> foreach_tensor_##OP##_list_kernel_slow( \
      TensorList tensors1, TensorList tensors2) {             \
    check_foreach_api_restrictions(tensors1, tensors2);       \
                                                              \
    std::vector<Tensor> result;                               \
    result.reserve(tensors1.size());                          \
    for (const auto i : c10::irange(tensors1.size())) {       \
      result.emplace_back(tensors1[i].OP(tensors2[i]));       \
    }                                                         \
                                                              \
    return result;                                            \
  }                                                           \
                                                              \
  void foreach_tensor_##OP##_list_kernel_slow_(               \
      TensorList tensors1, TensorList tensors2) {             \
    check_foreach_api_restrictions(tensors1, tensors2);       \
                                                              \
    for (const auto i : c10::irange(tensors1.size())) {       \
      tensors1[i].OP##_(tensors2[i]);                         \
    }                                                         \
  }

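// As FOREACH_BINARY_OP_LIST, plus an alpha Scalar forwarded to each call.
// Note that for lerp the third argument is the interpolation weight rather
// than an add/sub-style alpha multiplier.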
#define FOREACH_BINARY_OP_LIST_ALPHA(OP)                               \
  std::vector<Tensor> foreach_tensor_##OP##_list_kernel_slow(          \
      TensorList tensors1, TensorList tensors2, const Scalar& alpha) { \
    check_foreach_api_restrictions(tensors1, tensors2);                \
                                                                       \
    std::vector<Tensor> result;                                        \
    result.reserve(tensors1.size());                                   \
    for (const auto i : c10::irange(tensors1.size())) {                \
      result.emplace_back(tensors1[i].OP(tensors2[i], alpha));         \
    }                                                                  \
                                                                       \
    return result;                                                     \
  }                                                                    \
                                                                       \
  void foreach_tensor_##OP##_list_kernel_slow_(                        \
      TensorList tensors1, TensorList tensors2, const Scalar& alpha) { \
    check_foreach_api_restrictions(tensors1, tensors2);                \
                                                                       \
    for (const auto i : c10::irange(tensors1.size())) {                \
      tensors1[i].OP##_(tensors2[i], alpha);                           \
    }                                                                  \
  }

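// Generates slow kernels for a unary op applied to each tensor in the list.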
#define FOREACH_UNARY_OP(OP)                                           \
  std::vector<Tensor> foreach_tensor_##OP##_slow(TensorList tensors) { \
    check_foreach_api_restrictions(tensors);                           \
                                                                       \
    std::vector<Tensor> result;                                        \
    result.reserve(tensors.size());                                    \
    for (const auto& t : tensors) {                                    \
      result.emplace_back(t.OP());                                     \
    }                                                                  \
                                                                       \
    return result;                                                     \
  }                                                                    \
                                                                       \
  void foreach_tensor_##OP##_slow_(TensorList tensors) {               \
    check_foreach_api_restrictions(tensors);                           \
                                                                       \
    for (auto& t : tensors) {                                          \
      t.OP##_();                                                       \
    }                                                                  \
  }

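// Generates slow kernels for pointwise ternary ops of the form
// input[i].OP(tensors1[i], tensors2[i], scalar), i.e. addcdiv/addcmul with a
// Scalar multiplier shared across the lists.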
#define FOREACH_POINTWISE_OP_SCALAR(OP)                                   \
  std::vector<Tensor> foreach_tensor_##OP##_scalar_slow(                  \
      TensorList input,                                                   \
      TensorList tensors1,                                                \
      TensorList tensors2,                                                \
      const Scalar& scalar) {                                             \
    check_foreach_api_restrictions(input, tensors1, tensors2);            \
                                                                          \
    std::vector<Tensor> result;                                           \
    for (const auto i : c10::irange(input.size())) {                      \
      result.emplace_back(input[i].OP(tensors1[i], tensors2[i], scalar)); \
    }                                                                     \
                                                                          \
    return result;                                                        \
  }                                                                       \
                                                                          \
  void foreach_tensor_##OP##_scalar_slow_(                                \
      TensorList input,                                                   \
      TensorList tensors1,                                                \
      TensorList tensors2,                                                \
      const Scalar& scalar) {                                             \
    check_foreach_api_restrictions(input, tensors1, tensors2);            \
                                                                          \
    for (const auto i : c10::irange(input.size())) {                      \
      input[i].OP##_(tensors1[i], tensors2[i], scalar);                   \
    }                                                                     \
  }

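// Same as FOREACH_POINTWISE_OP_SCALAR, but with one Scalar per tensor triple.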
#define FOREACH_POINTWISE_OP_SCALARLIST(OP)                                   \
  std::vector<Tensor> foreach_tensor_##OP##_scalarlist_slow(                  \
      TensorList input,                                                       \
      TensorList tensors1,                                                    \
      TensorList tensors2,                                                    \
      at::ArrayRef<Scalar> scalars) {                                         \
    check_foreach_api_restrictions(input, tensors1, tensors2, scalars);       \
                                                                              \
    std::vector<Tensor> result;                                               \
    for (const auto i : c10::irange(input.size())) {                          \
      result.emplace_back(input[i].OP(tensors1[i], tensors2[i], scalars[i])); \
    }                                                                         \
                                                                              \
    return result;                                                            \
  }                                                                           \
                                                                              \
  void foreach_tensor_##OP##_scalarlist_slow_(                                \
      TensorList input,                                                       \
      TensorList tensors1,                                                    \
      TensorList tensors2,                                                    \
      at::ArrayRef<Scalar> scalars) {                                         \
    check_foreach_api_restrictions(input, tensors1, tensors2, scalars);       \
                                                                              \
    for (const auto i : c10::irange(input.size())) {                          \
      input[i].OP##_(tensors1[i], tensors2[i], scalars[i]);                   \
    }                                                                         \
  }

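// Accepts the per-tensor scalars packed into a single tensor; converts it to
// a list of Scalars and forwards to the scalarlist kernels above.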
#define FOREACH_POINTWISE_OP_TENSOR(OP)                                    \
  std::vector<Tensor> foreach_tensor_##OP##_tensor_slow(                   \
      TensorList input,                                                    \
      TensorList tensors1,                                                 \
      TensorList tensors2,                                                 \
      const Tensor& scalars_) {                                            \
    auto scalars = convert_tensor_to_scalar_list(scalars_, input.size());  \
    check_foreach_api_restrictions(input, tensors1, tensors2, scalars);    \
    return foreach_tensor_##OP##_scalarlist_slow(                          \
        input, tensors1, tensors2, scalars);                               \
  }                                                                        \
                                                                           \
  void foreach_tensor_##OP##_tensor_slow_(                                 \
      TensorList input,                                                    \
      TensorList tensors1,                                                 \
      TensorList tensors2,                                                 \
      const Tensor& scalars_) {                                            \
    auto scalars = convert_tensor_to_scalar_list(scalars_, input.size());  \
    check_foreach_api_restrictions(input, tensors1, tensors2, scalars);    \
    foreach_tensor_##OP##_scalarlist_slow_(                                \
        input, tensors1, tensors2, scalars);                               \
  }

FOREACH_BINARY_OP_LIST_ALPHA(add);
FOREACH_BINARY_OP_LIST_ALPHA(sub);
FOREACH_BINARY_OP_LIST_ALPHA(lerp);

FOREACH_BINARY_OP_TENSOR_ALPHA(add);
FOREACH_BINARY_OP_TENSOR(mul);
FOREACH_BINARY_OP_TENSOR(div);

FOREACH_BINARY_OP_SCALAR(add);
FOREACH_BINARY_OP_SCALAR(sub);
FOREACH_BINARY_OP_SCALAR(mul);
FOREACH_BINARY_OP_SCALAR(div);
FOREACH_BINARY_OP_SCALAR(clamp_min);
FOREACH_BINARY_OP_SCALAR(clamp_max);
FOREACH_BINARY_OP_SCALAR(pow);

FOREACH_BINARY_OP_SCALARLIST(add);
FOREACH_BINARY_OP_SCALARLIST(sub);
FOREACH_BINARY_OP_SCALARLIST(mul);
FOREACH_BINARY_OP_SCALARLIST(div);
FOREACH_BINARY_OP_SCALARLIST(clamp_min);
FOREACH_BINARY_OP_SCALARLIST(clamp_max);
FOREACH_BINARY_OP_SCALARLIST(pow);

FOREACH_BINARY_OP_LIST(mul);
FOREACH_BINARY_OP_LIST(div);
FOREACH_BINARY_OP_LIST(clamp_min);
FOREACH_BINARY_OP_LIST(clamp_max);
FOREACH_BINARY_OP_LIST(pow);
// _foreach_copy_
void foreach_tensor_copy_list_kernel_slow_(
    TensorList self,
    TensorList src,
    const bool non_blocking) {
  check_foreach_api_restrictions(self, src);

  for (const auto i : c10::irange(self.size())) {
    self[i].copy_(src[i], non_blocking);
  }
}

FOREACH_UNARY_OP(sqrt);
FOREACH_UNARY_OP(exp);
FOREACH_UNARY_OP(abs);
FOREACH_UNARY_OP(acos);
FOREACH_UNARY_OP(asin);
FOREACH_UNARY_OP(atan);
FOREACH_UNARY_OP(ceil);
FOREACH_UNARY_OP(cos);
FOREACH_UNARY_OP(cosh);
FOREACH_UNARY_OP(erf);
FOREACH_UNARY_OP(erfc);
FOREACH_UNARY_OP(expm1);
FOREACH_UNARY_OP(floor);
FOREACH_UNARY_OP(log);
FOREACH_UNARY_OP(log10);
FOREACH_UNARY_OP(log1p);
FOREACH_UNARY_OP(log2);
FOREACH_UNARY_OP(neg);
FOREACH_UNARY_OP(tan);
FOREACH_UNARY_OP(tanh);
FOREACH_UNARY_OP(sin);
FOREACH_UNARY_OP(sinh);
FOREACH_UNARY_OP(round);
FOREACH_UNARY_OP(lgamma);
FOREACH_UNARY_OP(frac);
FOREACH_UNARY_OP(trunc);
FOREACH_UNARY_OP(reciprocal);
FOREACH_UNARY_OP(sigmoid);
FOREACH_UNARY_OP(sign);

FOREACH_POINTWISE_OP_SCALAR(addcdiv);
FOREACH_POINTWISE_OP_SCALAR(addcmul);

FOREACH_POINTWISE_OP_SCALARLIST(addcdiv);
FOREACH_POINTWISE_OP_SCALARLIST(addcmul);

FOREACH_POINTWISE_OP_TENSOR(addcdiv);
FOREACH_POINTWISE_OP_TENSOR(addcmul);

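// Generates slow kernels for a ternary op over three equal-length tensor
// lists; in this file it is only instantiated for lerp with tensor weights.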
#define FOREACH_TERNARY_OP(OP)                                           \
  std::vector<Tensor> foreach_tensor_ternary_##OP##_slow(                \
      TensorList tensors1, TensorList tensors2, TensorList tensors3) {   \
    check_foreach_api_restrictions(tensors1, tensors2, tensors3);        \
    std::vector<Tensor> result;                                          \
    for (const auto i : c10::irange(tensors1.size())) {                  \
      result.emplace_back(tensors1[i].OP(tensors2[i], tensors3[i]));     \
    }                                                                    \
    return result;                                                       \
  }                                                                      \
                                                                         \
  void foreach_tensor_ternary_##OP##_slow_(                              \
      TensorList tensors1, TensorList tensors2, TensorList tensors3) {   \
    check_foreach_api_restrictions(tensors1, tensors2, tensors3);        \
    for (const auto i : c10::irange(tensors1.size())) {                  \
      tensors1[i].OP##_(tensors2[i], tensors3[i]);                       \
    }                                                                    \
  }

FOREACH_TERNARY_OP(lerp);

void foreach_tensor_zero_slow_(TensorList tensors) {
  check_foreach_api_restrictions(tensors);

  for (auto& t : tensors) {
    t.zero_();
  }
}

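// Computes each tensor's vector norm via at::linalg_vector_norm over all
// dimensions (dim = {}, keepdim = false), optionally in the given dtype.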
std::vector<Tensor> foreach_tensor_norm_slow(
    TensorList tensors,
    const Scalar& ord,
    std::optional<ScalarType> dtype) {
  check_foreach_api_restrictions(tensors);
  std::vector<Tensor> result;
  for (const auto& t : tensors) {
    result.emplace_back(at::linalg_vector_norm(t, ord, {}, false, dtype));
  }
  return result;
}

std::vector<Tensor> foreach_tensor_max_slow(TensorList tensors) {
  check_foreach_api_restrictions(tensors);
  std::vector<Tensor> result;
  for (const auto& t : tensors) {
    result.emplace_back(at::max(t));
  }
  return result;
}

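// pow with a Scalar base and a list of tensor exponents:
// result[i] = self ** exponent[i].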
std::vector<Tensor> foreach_scalar_pow_list_kernel_slow(
    const Scalar& self,
    TensorList exponent) {
  check_foreach_api_restrictions(exponent);
  std::vector<Tensor> result;
  result.reserve(exponent.size());
  for (const auto& t : exponent) {
    result.emplace_back(at::pow(self, t));
  }
  return result;
}

} // namespace at::native