/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <cmath>
#include <cstring>

#include <executorch/runtime/kernel/kernel_includes.h>

#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>

namespace torch {
namespace executor {
namespace native {

namespace {

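// Fills `step_len` consecutive elements starting at `out_data` with `value`.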
template <typename CTYPE>
void set_all_to_value(CTYPE* out_data, size_t step_len, CTYPE value) {
  for (size_t i = 0; i < step_len; ++i) {
    out_data[i] = value;
  }
}

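// Recursively pads dimension `dim` of the output: writes `pad_before` runs of
// the constant, then copies (or recurses into) the input data along this dim,
// then writes `pad_after` runs. `last_padded_dim` is the innermost dimension
// that receives any padding; at or beyond it, the remaining data is contiguous
// in both tensors and can be bulk-copied.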
template <typename CTYPE>
void apply_padding_to_dim(
    size_t ndim,
    const CTYPE* self_data,
    IntArrayRef self_sizes,
    IntArrayRef self_strides,
    CTYPE* out_data,
    IntArrayRef out_sizes,
    IntArrayRef out_strides,
    IntArrayRef pad,
    const CTYPE value,
    size_t last_padded_dim,
    size_t dim) {
  if (dim >= ndim) {
    return;
  }

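  // `pad` stores (before, after) pairs ordered from the last dimension inward,
  // so dimension `dim` maps to pair index ndim - 1 - dim.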
  size_t pad_i = ndim - 1 - dim;

  size_t pad_before = 0;
  size_t pad_after = 0;
  // pad_i is unsigned, so only the upper bound needs checking.
  if (pad_i < pad.size() / 2) {
    pad_before = pad[2 * pad_i];
    pad_after = pad[2 * pad_i + 1];
  }

  size_t out_step_len = out_strides[dim];
  size_t in_step_len = self_strides[dim];

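  // Write the leading padding: each padded index along this dim spans
  // out_step_len contiguous output elements.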
  for (size_t i = 0; i < pad_before; ++i) {
    set_all_to_value(out_data, out_step_len, value);
    out_data += out_step_len;
  }

  // If neither this dim nor any subsequent dim is padded, the remaining
  // block is contiguous in both tensors and can be copied wholesale.
  if (dim >= last_padded_dim) {
    size_t copy_len = in_step_len * self_sizes[dim];
    size_t copy_nbytes = copy_len * sizeof(CTYPE);

    if (copy_nbytes > 0) {
      memcpy(out_data, self_data, copy_nbytes);
      out_data += copy_len;
      self_data += copy_len;
    }
  }
  // Otherwise, recurse into the next dimension once per index of this one.
  else {
    for (size_t i = 0; i < self_sizes[dim]; ++i) {
      apply_padding_to_dim(
          ndim,
          self_data,
          self_sizes,
          self_strides,
          out_data,
          out_sizes,
          out_strides,
          pad,
          value,
          last_padded_dim,
          dim + 1);

      out_data += out_step_len;
      self_data += in_step_len;
    }
  }

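  // Write the trailing padding.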
  for (size_t i = 0; i < pad_after; ++i) {
    set_all_to_value(out_data, out_step_len, value);
    out_data += out_step_len;
  }
}

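// Pads `self` into `out` with the constant `value_v`, driving the recursion
// in apply_padding_to_dim from the outermost dimension.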
template <typename CTYPE>
void constant_pad_nd_out_impl(
    const Tensor& self,
    IntArrayRef pad,
    CTYPE value_v,
    Tensor& out) {
  const CTYPE* self_data = self.const_data_ptr<CTYPE>();
  CTYPE* out_data = out.mutable_data_ptr<CTYPE>();

  size_t ndim = self.dim();

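  // A 0-dimensional tensor holds exactly one element and cannot be padded;
  // copy it through unchanged.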
  if (ndim == 0) {
    out_data[0] = self_data[0];
    return;
  }

  int64_t self_sizes[kTensorDimensionLimit];
  int64_t self_strides[kTensorDimensionLimit];
  int64_t out_sizes[kTensorDimensionLimit];
  int64_t out_strides[kTensorDimensionLimit];

  // Collect sizes and "strides" of the input and output tensors, where the
  // stride of dim i is the product of all trailing dim sizes (i.e. the
  // contiguous stride), and determine the last padded dimension.
  size_t last_padded_dim = 0;
  for (size_t i = 0; i < ndim; ++i) {
    self_sizes[i] = self.size(i);
    self_strides[i] = getTrailingDims(self, static_cast<int64_t>(i));
    out_sizes[i] = out.size(i);
    out_strides[i] = getTrailingDims(out, static_cast<int64_t>(i));

    // pad_i is unsigned, so only the upper bound needs checking.
    size_t pad_i = ndim - 1 - i;
    if (pad_i < pad.size() / 2) {
      if (pad[2 * pad_i] + pad[2 * pad_i + 1] > 0) {
        last_padded_dim = i;
      }
    }
  }

  IntArrayRef self_sizes_ref(self_sizes, ndim);
  IntArrayRef self_strides_ref(self_strides, ndim);
  IntArrayRef out_sizes_ref(out_sizes, ndim);
  IntArrayRef out_strides_ref(out_strides, ndim);

  apply_padding_to_dim(
      ndim,
      self_data,
      self_sizes_ref,
      self_strides_ref,
      out_data,
      out_sizes_ref,
      out_strides_ref,
      pad,
      value_v,
      last_padded_dim,
      0);
}

} // namespace

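// Reference kernel for aten::constant_pad_nd.out. `pad` contains
// (before, after) pairs ordered from the last dimension of `in` inward.
//
// Illustrative example: for `in` of shape {2, 3} and pad = {1, 2, 3, 4},
// the last dim gains 1 + 2 elements and the first gains 3 + 4, so `out`
// is resized to {2 + 3 + 4, 3 + 1 + 2} = {9, 6}.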
Tensor& constant_pad_nd_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    IntArrayRef pad,
    const Scalar& value,
    Tensor& out) {
  (void)ctx;

  ET_KERNEL_CHECK(
      ctx, check_constant_pad_args(in, pad, value, out), InvalidArgument, out);

  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);

  // Resize the output tensor to support dynamic shapes.
  ET_KERNEL_CHECK_MSG(
      ctx,
      resize_constant_pad_output(in, pad, out) == Error::Ok,
      InvalidArgument,
      out,
      "Failed to resize output tensor.");

  ScalarType in_type = in.scalar_type();
  ScalarType value_type = utils::get_scalar_dtype(value);

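  // Dispatch on the input dtype (all real types plus Bool). The inner switch
  // extracts the pad value in the Scalar's own type, then casts it to the
  // input's ctype before running the impl.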
  ET_SWITCH_REAL_TYPES_AND(
      Bool, in_type, ctx, "constant_pad_nd.out", CTYPE, [&]() {
        CTYPE value_v;
        ET_SWITCH_SCALAR_OBJ_TYPES(
            value_type, ctx, "constant_pad_nd.out", CTYPE_VALUE, [&]() {
              CTYPE_VALUE val;
              utils::extract_scalar(value, &val);
              value_v = static_cast<CTYPE>(val);
            });
        constant_pad_nd_out_impl<CTYPE>(in, pad, value_v, out);
      });

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch