1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/portable/cpu/util/transpose_util.h>
10 #include <executorch/runtime/kernel/kernel_includes.h>
11 #include <cstring>
12
13 namespace torch {
14 namespace executor {
15 namespace native {
16
17 using SizesType = exec_aten::SizesType;
18 using StridesType = exec_aten::StridesType;
19 using Tensor = exec_aten::Tensor;
20
21 /**
22 * Expects input to be <= 2-D tensor and transposes dimensions 0 and 1.
23 * 0-D and 1-D tensors are returned as is. When input is a 2-D tensor this
24 * is equivalent to transpose(input, 0, 1).
25 * t_copy.out(Tensor self, Tensor(a!) out)
26 */
t_copy_out(KernelRuntimeContext & ctx,const Tensor & in,Tensor & out)27 Tensor& t_copy_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
28 (void)ctx;
29
30 ET_KERNEL_CHECK(ctx, check_t_copy_args(in, out), InvalidArgument, out);
31
32 ScalarType in_type = in.scalar_type();
33
34 if (in.dim() < 2) {
35 // Resize for dynamic shape
36 ET_KERNEL_CHECK(
37 ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
38
39 if (in.numel() > 0) {
40 ET_SWITCH_ALL_TYPES(in_type, ctx, __func__, CTYPE, [&]() {
41 const CTYPE* in_data = in.const_data_ptr<CTYPE>();
42 CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
43 memcpy(out_data, in_data, in.nbytes());
44 });
45 }
46
47 return out;
48 }
49
50 ET_KERNEL_CHECK(
51 ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
52
53 ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);
54
55 Tensor::SizesType expected_out_size[kTensorDimensionLimit];
56 size_t expected_out_dim = 0;
57 get_transpose_out_target_size(in, 1, 0, expected_out_size, &expected_out_dim);
58
59 // Resize for dynamic shape
60 ET_KERNEL_CHECK(
61 ctx,
62 resize_tensor(out, {expected_out_size, expected_out_dim}) == Error::Ok,
63 InvalidArgument,
64 out);
65
66 ET_SWITCH_ALL_TYPES(in_type, ctx, __func__, CTYPE, [&] {
67 transpose_tensors<CTYPE>(in, 1, 0, out);
68 });
69
70 return out;
71 }
72
73 } // namespace native
74 } // namespace executor
75 } // namespace torch
76