xref: /aosp_15_r20/external/executorch/kernels/portable/cpu/op_clone.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8 
9 #include <cstring>
10 
11 #include <executorch/runtime/kernel/kernel_includes.h>
12 
13 namespace torch {
14 namespace executor {
15 namespace native {
16 
17 using Tensor = exec_aten::Tensor;
18 
19 // clone.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out)
20 // -> Tensor(a!)
clone_out(KernelRuntimeContext & context,const Tensor & self,exec_aten::optional<exec_aten::MemoryFormat> memory_format,Tensor & out)21 Tensor& clone_out(
22     KernelRuntimeContext& context,
23     const Tensor& self,
24     exec_aten::optional<exec_aten::MemoryFormat> memory_format,
25     Tensor& out) {
26   (void)context;
27 
28   ET_KERNEL_CHECK(
29       context,
30       resize_tensor(out, self.sizes()) == torch::executor::Error::Ok,
31       InvalidArgument,
32       out);
33 
34   // The input and out shall share same dtype and size
35   ET_KERNEL_CHECK(
36       context,
37       tensors_have_same_shape_and_dtype(self, out),
38       InvalidArgument,
39       out);
40 
41   ET_KERNEL_CHECK(
42       context, tensors_have_same_dim_order(self, out), InvalidArgument, out);
43 
44   // Right now we only focus on contiguous memory, memory_format shall always
45   // either a nullopt or exec::aten::MemoryFormat::Contiguous
46   ET_KERNEL_CHECK(
47       context,
48       !memory_format.has_value() ||
49           memory_format.value() == MemoryFormat::Contiguous,
50       InvalidArgument,
51       out);
52 
53   if (self.nbytes() > 0) {
54     // Note that this check is important. It's valid for a tensor with numel 0
55     // to have a null data pointer, but in some environments it's invalid to
56     // pass a null pointer to memcpy() even when the size is zero.
57     memcpy(out.mutable_data_ptr(), self.const_data_ptr(), self.nbytes());
58   }
59   return out;
60 }
61 
62 } // namespace native
63 } // namespace executor
64 } // namespace torch
65