1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <cstring>
10
11 #include <executorch/runtime/kernel/kernel_includes.h>
12
13 namespace torch {
14 namespace executor {
15 namespace native {
16
17 using Tensor = exec_aten::Tensor;
18
19 // clone.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out)
20 // -> Tensor(a!)
clone_out(KernelRuntimeContext & context,const Tensor & self,exec_aten::optional<exec_aten::MemoryFormat> memory_format,Tensor & out)21 Tensor& clone_out(
22 KernelRuntimeContext& context,
23 const Tensor& self,
24 exec_aten::optional<exec_aten::MemoryFormat> memory_format,
25 Tensor& out) {
26 (void)context;
27
28 ET_KERNEL_CHECK(
29 context,
30 resize_tensor(out, self.sizes()) == torch::executor::Error::Ok,
31 InvalidArgument,
32 out);
33
34 // The input and out shall share same dtype and size
35 ET_KERNEL_CHECK(
36 context,
37 tensors_have_same_shape_and_dtype(self, out),
38 InvalidArgument,
39 out);
40
41 ET_KERNEL_CHECK(
42 context, tensors_have_same_dim_order(self, out), InvalidArgument, out);
43
44 // Right now we only focus on contiguous memory, memory_format shall always
45 // either a nullopt or exec::aten::MemoryFormat::Contiguous
46 ET_KERNEL_CHECK(
47 context,
48 !memory_format.has_value() ||
49 memory_format.value() == MemoryFormat::Contiguous,
50 InvalidArgument,
51 out);
52
53 if (self.nbytes() > 0) {
54 // Note that this check is important. It's valid for a tensor with numel 0
55 // to have a null data pointer, but in some environments it's invalid to
56 // pass a null pointer to memcpy() even when the size is zero.
57 memcpy(out.mutable_data_ptr(), self.const_data_ptr(), self.nbytes());
58 }
59 return out;
60 }
61
62 } // namespace native
63 } // namespace executor
64 } // namespace torch
65