/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <cmath>
#include <cstring>

#include <executorch/runtime/kernel/kernel_includes.h>

#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>

namespace torch {
namespace executor {
namespace native {

namespace {

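// Fills `step_len` contiguous elements of `out_data` with the padding value.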
template <typename CTYPE>
void set_all_to_value(CTYPE* out_data, size_t step_len, CTYPE value) {
  for (size_t i = 0; i < step_len; ++i) {
    out_data[i] = value;
  }
}

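// Recursively writes the padded view of `self` into `out` along dimension
// `dim`. `pad` follows the aten::constant_pad_nd convention of (before, after)
// pairs ordered from the last dimension backwards; e.g. for a 3-D input,
// pad = {1, 2, 3, 4} pads dim 2 by 1 before / 2 after and dim 1 by 3 before /
// 4 after, while dim 0 is left unpadded.
//
// Each call writes `pad_before` runs of the fill value, then the input data
// for this dimension (a single memcpy once no deeper dimension is padded,
// i.e. dim >= last_padded_dim, otherwise one recursive call per index), and
// finally `pad_after` runs of the fill value.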
template <typename CTYPE>
void apply_padding_to_dim(
    size_t ndim,
    const CTYPE* self_data,
    IntArrayRef self_sizes,
    IntArrayRef self_strides,
    CTYPE* out_data,
    IntArrayRef out_sizes,
    IntArrayRef out_strides,
    IntArrayRef pad,
    const CTYPE value,
    size_t last_padded_dim,
    size_t dim) {
  if (dim >= ndim) {
    return;
  }

  size_t pad_i = ndim - 1 - dim;

  size_t pad_before = 0;
  size_t pad_after = 0;
  if (pad_i < pad.size() / 2) {
    pad_before = pad[2 * pad_i];
    pad_after = pad[2 * pad_i + 1];
  }

  size_t out_step_len = out_strides[dim];
  size_t in_step_len = self_strides[dim];

  for (size_t i = 0; i < pad_before; ++i) {
    set_all_to_value(out_data, out_step_len, value);
    out_data += out_step_len;
  }

  // If subsequent dims are not padded, then the whole block of memory can be
  // copied.
  if (dim >= last_padded_dim) {
    size_t copy_len = in_step_len * self_sizes[dim];
    size_t copy_nbytes = copy_len * sizeof(CTYPE);

    if (copy_nbytes > 0) {
      memcpy(out_data, self_data, copy_nbytes);
      out_data += copy_len;
      self_data += copy_len;
    }
  }
  // Otherwise, call this function recursively
  else {
    for (size_t i = 0; i < self_sizes[dim]; ++i) {
      apply_padding_to_dim(
          ndim,
          self_data,
          self_sizes,
          self_strides,
          out_data,
          out_sizes,
          out_strides,
          pad,
          value,
          last_padded_dim,
          dim + 1);

      out_data += out_step_len;
      self_data += in_step_len;
    }
  }

  for (size_t i = 0; i < pad_after; ++i) {
    set_all_to_value(out_data, out_step_len, value);
    out_data += out_step_len;
  }
}

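// Pads `self` into `out` with the fill value `value_v`. Gathers the sizes of
// both tensors, computes contiguous strides, finds the innermost dimension
// that actually receives padding (so everything deeper can be bulk-copied),
// and then kicks off apply_padding_to_dim at dim 0.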
template <typename CTYPE>
void constant_pad_nd_out_impl(
    const Tensor& self,
    IntArrayRef pad,
    CTYPE value_v,
    Tensor& out) {
  const CTYPE* self_data = self.const_data_ptr<CTYPE>();
  CTYPE* out_data = out.mutable_data_ptr<CTYPE>();

  size_t ndim = self.dim();

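  // A zero-dimensional tensor holds a single element and has no dimensions to
  // pad, so just copy the value through.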
  if (ndim == 0) {
    out_data[0] = self_data[0];
    return;
  }

  int64_t self_sizes[kTensorDimensionLimit];
  int64_t self_strides[kTensorDimensionLimit];
  int64_t out_sizes[kTensorDimensionLimit];
  int64_t out_strides[kTensorDimensionLimit];

  // Collect sizes and strides of input and output tensors and determine the
  // last padded dimension.
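  // Strides are recomputed here as products of trailing sizes (contiguous
  // strides) via getTrailingDims, rather than taken from the tensors
  // themselves.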
  size_t last_padded_dim = 0;
  for (size_t i = 0; i < ndim; ++i) {
    self_sizes[i] = self.size(i);
    self_strides[i] = getTrailingDims(self, static_cast<int64_t>(i));
    out_sizes[i] = out.size(i);
    out_strides[i] = getTrailingDims(out, static_cast<int64_t>(i));

    size_t pad_i = ndim - 1 - i;
    if (pad_i < pad.size() / 2) {
      if (pad[2 * pad_i] + pad[2 * pad_i + 1] > 0) {
        last_padded_dim = i;
      }
    }
  }

  IntArrayRef self_sizes_ref(self_sizes, ndim);
  IntArrayRef self_strides_ref(self_strides, ndim);
  IntArrayRef out_sizes_ref(out_sizes, ndim);
  IntArrayRef out_strides_ref(out_strides, ndim);

  apply_padding_to_dim(
      ndim,
      self_data,
      self_sizes_ref,
      self_strides_ref,
      out_data,
      out_sizes_ref,
      out_strides_ref,
      pad,
      value_v,
      last_padded_dim,
      0);
}

} // namespace

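// constant_pad_nd.out: pads `in` with the constant `value` and writes the
// result to `out`. `pad` holds (before, after) pairs starting from the last
// dimension, so for a 2-D input with pad = {1, 2, 3, 4} the output shape is
// {in.size(0) + 3 + 4, in.size(1) + 1 + 2}.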
Tensor& constant_pad_nd_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    IntArrayRef pad,
    const Scalar& value,
    Tensor& out) {
  (void)ctx;

  ET_KERNEL_CHECK(
      ctx, check_constant_pad_args(in, pad, value, out), InvalidArgument, out);

  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);

  // resize out tensor for dynamic shapes
  ET_KERNEL_CHECK_MSG(
      ctx,
      resize_constant_pad_output(in, pad, out) == Error::Ok,
      InvalidArgument,
      out,
      "Failed to resize output tensor.");

  ScalarType in_type = in.scalar_type();
  ScalarType value_type = utils::get_scalar_dtype(value);

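  // Dispatch on the input dtype (real types plus Bool). Inside each branch,
  // the fill Scalar is extracted in its own dtype and cast to the input's
  // CTYPE before running the padding implementation.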
  ET_SWITCH_REAL_TYPES_AND(
      Bool, in_type, ctx, "constant_pad_nd.out", CTYPE, [&]() {
        CTYPE value_v;
        ET_SWITCH_SCALAR_OBJ_TYPES(
            value_type, ctx, "constant_pad_nd.out", CTYPE_VALUE, [&]() {
              CTYPE_VALUE val;
              utils::extract_scalar(value, &val);
              value_v = static_cast<CTYPE>(val);
            });
        constant_pad_nd_out_impl<CTYPE>(in, pad, value_v, out);
      });

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch