1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/prim_ops/et_view.h>
10
11 #include <cstring>
12
13 #include <executorch/runtime/core/array_ref.h>
14 #include <executorch/runtime/core/exec_aten/exec_aten.h>
15 #include <executorch/runtime/core/exec_aten/util/tensor_util.h>
16 #include <executorch/runtime/platform/assert.h>
17
18 using exec_aten::SizesType;
19 using exec_aten::Tensor;
20 using torch::executor::Error;
21 using torch::executor::resize_tensor;
22
23 namespace torch {
24 namespace executor {
25 namespace function {
26
// Upper bound on the tensor rank handled in this file. It sizes the
// fixed-length scratch array in which et_view() assembles the resolved output
// shape, so the output rank must never exceed it.
constexpr size_t kTensorDimensionLimit = 16;
28
29 namespace {
get_view_target_size(const exec_aten::Tensor self,exec_aten::ArrayRef<int64_t> size,int64_t dim,exec_aten::SizesType * out_size)30 bool get_view_target_size(
31 const exec_aten::Tensor self,
32 exec_aten::ArrayRef<int64_t> size,
33 int64_t dim,
34 exec_aten::SizesType* out_size) {
35 ET_LOG_AND_RETURN_IF_FALSE(size.size() == dim);
36 int minus1_dim = -1;
37 int n_zero = 0;
38 int64_t numel_without_minus_1 = 1;
39 for (int i = 0; i < dim; i++) {
40 if (size[i] == -1) {
41 ET_LOG_MSG_AND_RETURN_IF_FALSE(
42 minus1_dim == -1, "At most one view dim can be -1.");
43 minus1_dim = i;
44 } else {
45 // The size[i] must be non-negative now, but we check size[i] >= -1
46 // in case code is reordered in the future.
47 ET_LOG_MSG_AND_RETURN_IF_FALSE(
48 size[i] >= -1, "Negative sizes are not allowed.");
49
50 numel_without_minus_1 *= size[i];
51 out_size[i] = static_cast<exec_aten::SizesType>(size[i]);
52
53 if (size[i] == 0) {
54 n_zero++;
55 }
56 }
57 }
58 if (minus1_dim >= 0) {
59 ET_LOG_MSG_AND_RETURN_IF_FALSE(
60 n_zero == 0, "Cannot infer dimension size if there is a zero dim.");
61 out_size[minus1_dim] = self.numel() / numel_without_minus_1;
62 }
63 return true;
64 }
65 } // namespace
66
et_view(KernelRuntimeContext & context,EValue ** stack)67 void et_view(KernelRuntimeContext& context, EValue** stack) {
68 (void)context;
69
70 auto self = (*stack[0]).toTensor();
71 auto size = (*stack[1]).toIntList();
72 auto out = (*stack[2]).toTensor();
73
74 ET_CHECK(tensors_have_same_dtype(self, out));
75
76 // Compute output size
77 SizesType expected_output_size[kTensorDimensionLimit];
78 ET_CHECK(get_view_target_size(self, size, out.dim(), expected_output_size));
79
80 // Resize for dynamic shape
81 ET_CHECK_MSG(
82 resize_tensor(
83 out, {expected_output_size, static_cast<size_t>(out.dim())}) ==
84 Error::Ok,
85 "Failed to resize output tensor.");
86
87 // Do some checks
88 ET_CHECK(self.numel() == out.numel());
89
90 // Update data ptr
91 ET_CHECK_MSG(
92 internal::set_tensor_data(
93 out,
94 /*buffer=*/self.mutable_data_ptr(),
95 /*buffer_size=*/out.nbytes()) == Error::Ok,
96 "Failed to set data_ptr for out to self.");
97 }
98
99 } // namespace function
100 } // namespace executor
101 } // namespace torch
102