xref: /aosp_15_r20/external/executorch/kernels/prim_ops/et_view.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/kernels/prim_ops/et_view.h>
10 
11 #include <cstring>
12 
13 #include <executorch/runtime/core/array_ref.h>
14 #include <executorch/runtime/core/exec_aten/exec_aten.h>
15 #include <executorch/runtime/core/exec_aten/util/tensor_util.h>
16 #include <executorch/runtime/platform/assert.h>
17 
18 using exec_aten::SizesType;
19 using exec_aten::Tensor;
20 using torch::executor::Error;
21 using torch::executor::resize_tensor;
22 
23 namespace torch {
24 namespace executor {
25 namespace function {
26 
// Upper bound on tensor rank supported here: the output-size scratch buffer
// in et_view() below is stack-allocated with this many entries.
constexpr size_t kTensorDimensionLimit = 16;
28 
29 namespace {
get_view_target_size(const exec_aten::Tensor self,exec_aten::ArrayRef<int64_t> size,int64_t dim,exec_aten::SizesType * out_size)30 bool get_view_target_size(
31     const exec_aten::Tensor self,
32     exec_aten::ArrayRef<int64_t> size,
33     int64_t dim,
34     exec_aten::SizesType* out_size) {
35   ET_LOG_AND_RETURN_IF_FALSE(size.size() == dim);
36   int minus1_dim = -1;
37   int n_zero = 0;
38   int64_t numel_without_minus_1 = 1;
39   for (int i = 0; i < dim; i++) {
40     if (size[i] == -1) {
41       ET_LOG_MSG_AND_RETURN_IF_FALSE(
42           minus1_dim == -1, "At most one view dim can be -1.");
43       minus1_dim = i;
44     } else {
45       // The size[i] must be non-negative now, but we check size[i] >= -1
46       // in case code is reordered in the future.
47       ET_LOG_MSG_AND_RETURN_IF_FALSE(
48           size[i] >= -1, "Negative sizes are not allowed.");
49 
50       numel_without_minus_1 *= size[i];
51       out_size[i] = static_cast<exec_aten::SizesType>(size[i]);
52 
53       if (size[i] == 0) {
54         n_zero++;
55       }
56     }
57   }
58   if (minus1_dim >= 0) {
59     ET_LOG_MSG_AND_RETURN_IF_FALSE(
60         n_zero == 0, "Cannot infer dimension size if there is a zero dim.");
61     out_size[minus1_dim] = self.numel() / numel_without_minus_1;
62   }
63   return true;
64 }
65 } // namespace
66 
et_view(KernelRuntimeContext & context,EValue ** stack)67 void et_view(KernelRuntimeContext& context, EValue** stack) {
68   (void)context;
69 
70   auto self = (*stack[0]).toTensor();
71   auto size = (*stack[1]).toIntList();
72   auto out = (*stack[2]).toTensor();
73 
74   ET_CHECK(tensors_have_same_dtype(self, out));
75 
76   // Compute output size
77   SizesType expected_output_size[kTensorDimensionLimit];
78   ET_CHECK(get_view_target_size(self, size, out.dim(), expected_output_size));
79 
80   // Resize for dynamic shape
81   ET_CHECK_MSG(
82       resize_tensor(
83           out, {expected_output_size, static_cast<size_t>(out.dim())}) ==
84           Error::Ok,
85       "Failed to resize output tensor.");
86 
87   // Do some checks
88   ET_CHECK(self.numel() == out.numel());
89 
90   // Update data ptr
91   ET_CHECK_MSG(
92       internal::set_tensor_data(
93           out,
94           /*buffer=*/self.mutable_data_ptr(),
95           /*buffer_size=*/out.nbytes()) == Error::Ok,
96       "Failed to set data_ptr for out to self.");
97 }
98 
99 } // namespace function
100 } // namespace executor
101 } // namespace torch
102