xref: /aosp_15_r20/external/pytorch/aten/src/ATen/templates/CompositeViewCopyKernels.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 // ${generated_comment}
3 
4 #include <ATen/InferSize.h>
5 #include <ATen/Tensor.h>
6 #include <ATen/native/Resize.h>
7 
8 #ifndef AT_PER_OPERATOR_HEADERS
9 #include <ATen/Operators.h>
10 #else
11 #include <ATen/ops/clone.h>
12 $ops_headers
13 #endif
14 
15 namespace at {
16 namespace native {
17 
18 // This file contains a number of kernels for aten functions that are fully code-generated.
19 // TODO: rename this file to something more generic.
20 
21 namespace {
// Produces an independent copy of a single tensor argument by calling
// Tensor::clone() with its default arguments.
at::Tensor clone_arg(const at::Tensor& t) {
    at::Tensor copy = t.clone();
    return copy;
}
25 
clone_arg(const at::TensorList & t_list)26 std::vector<at::Tensor> clone_arg(const at::TensorList& t_list) {
27     std::vector<at::Tensor> out(t_list.size());
28     for (const auto& i : c10::irange(t_list.size())) {
29         out[i] = t_list[i].clone();
30     }
31     return out;
32 }
33 
34 // duped with gen_resize_out_helper from structured kernels
copy_arg(const at::Tensor & dst,const at::Tensor & src)35 void copy_arg(const at::Tensor& dst, const at::Tensor& src) {
36     TORCH_CHECK(src.dtype() == dst.dtype(),
37         "Expected out tensor to have dtype ", src.dtype(), ", but got ", dst.dtype(), " instead");
38     TORCH_CHECK(src.device() == dst.device(),
39         "Expected out tensor to have device ", src.device(), ", but got ", dst.device(), " instead");
40     dst.copy_(src);
41 }
42 
copy_arg(const at::TensorList & dst,const at::TensorList & src)43 void copy_arg(const at::TensorList& dst, const at::TensorList& src) {
44     TORCH_INTERNAL_ASSERT(dst.size() == src.size());
45     for (const auto& i : c10::irange(dst.size())) {
46         copy_arg(dst[i], src[i]);
47     }
48 }
49 
50 // TODO: this doesn't handle restriding empty tensors correctly; see
51 // gen_resize_out_helper for the correct algorithm
52 
resize_out_helper(const at::Tensor & dst,const at::Tensor & src)53 void resize_out_helper(const at::Tensor& dst, const at::Tensor& src) {
54     at::native::resize_output(dst, src.sizes());
55 }
56 
resize_out_helper(const at::TensorList & dst,const at::TensorList & src)57 void resize_out_helper(const at::TensorList& dst, const at::TensorList& src) {
58     TORCH_INTERNAL_ASSERT(dst.size() == src.size());
59     for (const auto& i : c10::irange(dst.size())) {
60         at::native::resize_output(dst[i], src[i].sizes());
61     }
62 }
63 }
64 
65 
66 ${CompositeViewCopyKernel_Definitions}
67 
68 ${GeneratedCompositeFunctional_Definitions}
69 
70 ${GeneratedCompositeOut_Definitions}
71 
72 } // namespace native
73 } // namespace at
74