1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <cstring>
10
11 #include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
12 #include <executorch/runtime/kernel/kernel_includes.h>
13
14 namespace torch {
15 namespace executor {
16 namespace native {
17
// Local shorthand so the kernel below can spell the tensor type as `Tensor`.
using Tensor = exec_aten::Tensor;
19
stack_out(KernelRuntimeContext & ctx,exec_aten::ArrayRef<Tensor> tensors,int64_t dim,Tensor & out)20 Tensor& stack_out(
21 KernelRuntimeContext& ctx,
22 exec_aten::ArrayRef<Tensor> tensors,
23 int64_t dim,
24 Tensor& out) {
25 (void)ctx;
26
27 if (dim < 0) {
28 dim += out.dim();
29 }
30
31 ET_KERNEL_CHECK(
32 ctx, check_stack_args(tensors, dim, out), InvalidArgument, out);
33
34 for (size_t i = 0; i < tensors.size(); ++i) {
35 ET_KERNEL_CHECK(
36 ctx,
37 tensors_have_same_dim_order(tensors[i], out),
38 InvalidArgument,
39 out);
40 }
41
42 ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(out), InvalidArgument, out);
43
44 Tensor::SizesType expected_out_size[kTensorDimensionLimit];
45 size_t expected_out_dim = 0;
46 get_stack_out_target_size(tensors, dim, expected_out_size, &expected_out_dim);
47 ET_KERNEL_CHECK(
48 ctx,
49 resize_tensor(out, {expected_out_size, expected_out_dim}) == Error::Ok,
50 InvalidArgument,
51 out);
52
53 const size_t outer = getLeadingDims(out, dim);
54 const size_t inner = getTrailingDims(out, dim);
55 const size_t ninputs = tensors.size();
56
57 const auto out_type = out.scalar_type();
58 ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, "stack.out", CTYPE_OUT, [&] {
59 CTYPE_OUT* out_ptr = out.mutable_data_ptr<CTYPE_OUT>();
60 for (size_t i = 0; i < outer; ++i) {
61 for (size_t j = 0; j < ninputs; ++j) {
62 const auto in_type = tensors[j].scalar_type();
63 ET_SWITCH_REAL_TYPES_AND(
64 Bool, in_type, ctx, "stack.out", CTYPE_IN, [&] {
65 const CTYPE_IN* const in_ptr =
66 tensors[j].const_data_ptr<CTYPE_IN>() + i * inner;
67
68 for (size_t k = 0; k < inner; ++k) {
69 out_ptr[k] = static_cast<CTYPE_OUT>(in_ptr[k]);
70 }
71 out_ptr += inner;
72 });
73 }
74 }
75 });
76
77 return out;
78 }
79
80 } // namespace native
81 } // namespace executor
82 } // namespace torch
83