1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <cstring>
10
11 #include <executorch/kernels/portable/cpu/util/broadcast_util.h>
12 #include <executorch/kernels/portable/cpu/util/elementwise_util.h>
13 #include <executorch/runtime/kernel/kernel_includes.h>
14
15 namespace torch {
16 namespace executor {
17 namespace native {
18
19 using Tensor = exec_aten::Tensor;
20
// copy.out(const Tensor& in, const Tensor& src, bool non_blocking, Tensor(a!)
// out) -> Tensor(a!), see caffe2/aten/src/ATen/native/Copy.cpp
// TODO: With proper functionalization we should never actually see this op,
// so it needs to be deleted.
copy_out(KernelRuntimeContext & ctx,const Tensor & in,const Tensor & src,bool non_blocking,Tensor & out)25 Tensor& copy_out(
26 KernelRuntimeContext& ctx,
27 const Tensor& in,
28 const Tensor& src,
29 bool non_blocking,
30 Tensor& out) {
31 (void)ctx;
32 // Right now we only support blocking data transfer
33 ET_KERNEL_CHECK(ctx, non_blocking == false, InvalidArgument, out);
34
35 ET_KERNEL_CHECK(ctx, tensors_have_same_dtype(in, out), InvalidArgument, out);
36
37 ET_KERNEL_CHECK(
38 ctx, tensor_is_broadcastable_to(src, in), InvalidArgument, out);
39
40 ET_KERNEL_CHECK(
41 ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
42
43 ET_KERNEL_CHECK(
44 ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
45
46 // @lint-ignore CLANGTIDY facebook-hte-CArray
47 static constexpr const char op_name[] = "copy.out";
48
49 ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy.out", CTYPE, [&]() {
50 utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
51 [](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
52 ctx,
53 in,
54 utils::SupportedTensorDtypes::REALHBBF16,
55 src,
56 utils::SupportedTensorDtypes::REALHBBF16,
57 out,
58 utils::SupportedTensorDtypes::REALHBBF16);
59 });
60
61 return out;
62 }
63
copy_(KernelRuntimeContext & ctx,Tensor & in,const Tensor & src,bool non_blocking)64 Tensor& copy_(
65 KernelRuntimeContext& ctx,
66 Tensor& in,
67 const Tensor& src,
68 bool non_blocking) {
69 (void)ctx;
70 // Right now we only support blocking data transfer
71 ET_KERNEL_CHECK(ctx, non_blocking == false, InvalidArgument, in);
72
73 ET_KERNEL_CHECK(
74 ctx, tensor_is_broadcastable_to(src, in), InvalidArgument, in);
75
76 ET_KERNEL_CHECK(
77 ctx, tensors_have_same_dim_order(in, src), InvalidArgument, in);
78
79 // @lint-ignore CLANGTIDY facebook-hte-CArray
80 static constexpr const char op_name[] = "copy_";
81
82 ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy_", CTYPE, [&]() {
83 utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
84 [](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
85 ctx,
86 in,
87 utils::SupportedTensorDtypes::REALHBBF16,
88 src,
89 utils::SupportedTensorDtypes::REALHBBF16,
90 in,
91 utils::SupportedTensorDtypes::REALHBBF16);
92 });
93
94 return in;
95 }
96
97 } // namespace native
98 } // namespace executor
99 } // namespace torch
100