/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <cstring>

#include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>

namespace torch {
namespace executor {

using Tensor = exec_aten::Tensor;

namespace {

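// Checks that a per-dimension kernel parameter array (e.g. kernel_size,
// stride, padding, dilation) has a valid length -- 1 or `length` elements,
// optionally 0 when `allow_empty` is true -- and that every entry is at least
// `min_val`.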
bool param_array_is_valid(
    const char* name,
    IntArrayRef array,
    int64_t min_val,
    size_t length,
    bool allow_empty) {
  auto size = array.size();
  if (allow_empty) {
    ET_LOG_MSG_AND_RETURN_IF_FALSE(
        size == 0 || size == 1 || size == length,
        "Expected %s to have size 0, 1 or %zu but got %zd",
        name,
        length,
        size);
  } else {
    ET_LOG_MSG_AND_RETURN_IF_FALSE(
        size == 1 || size == length,
        "Expected %s to have size 1 or %zu but got %zd",
        name,
        length,
        size);
  }
  ET_LOG_AND_RETURN_IF_FALSE(int_array_all_ge(array, min_val));
  return true;
}

} // namespace

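// Returns true iff every element of `array` is greater than or equal to `val`.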
bool int_array_all_ge(IntArrayRef array, int64_t val) {
  for (size_t i = 0; i < array.size(); ++i) {
    if (array[i] < val) {
      ET_LOG(
          Error,
          "Expected array[%zu] >= %" PRId64 ", found %" PRId64,
          i,
          val,
          array[i]);
      return false;
    }
  }
  return true;
}

bool kernel_size_is_valid(IntArrayRef kernel_size, size_t kernel_ndim) {
  return param_array_is_valid(
      "kernel_size",
      kernel_size,
      /*min_val=*/1,
      kernel_ndim,
      /*allow_empty=*/false);
}

bool stride_is_valid(IntArrayRef stride, size_t kernel_ndim, bool allow_empty) {
  return param_array_is_valid(
      "stride", stride, /*min_val=*/1, kernel_ndim, allow_empty);
}

bool padding_is_valid(
    IntArrayRef padding,
    IntArrayRef kernel_size,
    size_t kernel_ndim,
    bool enforce_half_kernel) {
  bool valid = param_array_is_valid(
      "padding", padding, /*min_val=*/0, kernel_ndim, /*allow_empty=*/false);
  if (!valid) {
    return false;
  }

  if (enforce_half_kernel) {
    // Padding must be at most half of kernel size.
    for (size_t i = 0; i < padding.size(); i++) {
      if (padding[i] > val_at(kernel_size, i) / 2) {
        ET_LOG(
            Error,
            "Padding should be at most half of kernel size, "
            "but got padding[%zu] = %" PRId64 " > kernel_size[%zu] = %" PRId64,
            i,
            padding[i],
            i,
            val_at(kernel_size, i));
        return false;
      }
    }
  }
  return true;
}

bool dilation_is_valid(IntArrayRef dilation, size_t kernel_ndim) {
  return param_array_is_valid(
      "dilation", dilation, /*min_val=*/1, kernel_ndim, /*allow_empty=*/false);
}

bool output_padding_is_valid(
    IntArrayRef output_padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    size_t kernel_ndim) {
  ET_LOG_AND_RETURN_IF_FALSE(param_array_is_valid(
      "output_padding",
      output_padding,
      /*min_val=*/0,
      kernel_ndim,
      /*allow_empty=*/false));

  for (size_t i = 0; i < kernel_ndim; i++) {
    const int64_t op_i = val_at(output_padding, i);
    const int64_t s_i = val_at(stride, i);
    const int64_t d_i = val_at(dilation, i);
    ET_LOG_MSG_AND_RETURN_IF_FALSE(
        op_i < s_i || op_i < d_i,
        "output padding must be smaller than either stride or dilation");
  }
  return true;
}

bool output_size_is_valid(
    exec_aten::ArrayRef<exec_aten::SizesType> output_size,
    size_t kernel_ndim) {
  bool valid = true;
  size_t out_dim = output_size.size();
  for (size_t i = 0; i < out_dim - kernel_ndim; i++) {
    if (output_size[i] < 0) {
      valid = false;
    }
  }
  for (size_t i = out_dim - kernel_ndim; i < out_dim; i++) {
    if (output_size[i] <= 0) {
      valid = false;
    }
  }
  if (!valid) {
    ET_LOG(
        Error,
        "The provided combination of input and kernel parameters "
        "produces an invalid output size:");
    for (size_t d = 0; d < output_size.size(); ++d) {
      ET_LOG(
          Error, " size(%zu): %zu", d, static_cast<size_t>(output_size[d]));
    }
  }
  return valid;
}

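// Writes the sizes of `t` with a size-1 dimension inserted at `unsqueeze_dim`
// into `sizes_arr`, and sets `ndim` to t.dim() + 1. `sizes_arr` must have room
// for at least t.dim() + 1 entries.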
void get_unsqueezed_sizes(
    const Tensor& t,
    int64_t unsqueeze_dim,
    exec_aten::SizesType* sizes_arr,
    size_t& ndim) {
  ndim = t.dim() + 1;
  for (int d = 0; d < unsqueeze_dim; ++d) {
    sizes_arr[d] = t.size(d);
  }
  sizes_arr[unsqueeze_dim] = 1;
  for (int d = (unsqueeze_dim + 1); d < ndim; d++) {
    sizes_arr[d] = t.size(d - 1);
  }
}

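// Writes the dim order of `t` after inserting a size-1 dimension at
// `unsqueeze_dim` into `dim_order_arr`. Original dim indices greater than or
// equal to `unsqueeze_dim` are shifted up by one, and the new dimension is
// placed immediately before the dimension it displaced in the memory order.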
void get_unsqueezed_dim_order(
    const Tensor& t,
    exec_aten::DimOrderType unsqueeze_dim,
    exec_aten::DimOrderType* dim_order_arr) {
  int offset = 0;
  for (int i = 0; i < t.dim(); ++i) {
    exec_aten::DimOrderType dim = t.dim_order()[i];
    if (dim == unsqueeze_dim) {
      dim_order_arr[i] = dim;
      dim_order_arr[i + 1] = dim + 1;
      offset = 1;
    } else {
      dim_order_arr[i + offset] = dim > unsqueeze_dim ? dim + 1 : dim;
    }
  }
  return;
}

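// Computes the output extent along a single dimension for a convolution or
// pooling kernel. For the regular (non-transposed) case this is
//   floor((inputSize + 2 * pad - dilation * (kernelSize - 1) - 1) / stride) + 1
// (rounded up instead when ceil_mode is set, while still ensuring the last
// window starts inside the padded input). For the transposed case it is
//   (inputSize - 1) * stride - 2 * pad + dilation * (kernelSize - 1)
//     + output_padding + 1.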
int64_t _kernel_output_size_helper(
    size_t inputSize,
    int64_t kernelSize,
    int64_t pad,
    int64_t stride,
    int64_t dilation,
    bool ceil_mode,
    bool transposed,
    int64_t output_padding) {
  if (transposed) {
    return (inputSize - 1) * stride - 2 * pad + dilation * (kernelSize - 1) +
        output_padding + 1;
  }
  int64_t numerator = inputSize + 2 * pad - dilation * (kernelSize - 1) - 1 +
      (ceil_mode ? stride - 1 : 0);
  int64_t outputSize = numerator / stride + 1;
  if (ceil_mode) {
    // Ensure that the last pooling window starts inside the image; this is
    // needed to avoid problems in ceil mode.
    if ((outputSize - 1) * stride >= inputSize + pad) {
      --outputSize;
    }
  }
  return outputSize;
}

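// Fills the trailing `kernel_ndim` entries of `out_sizes` with the output
// extents produced by applying the given kernel parameters to the trailing
// dimensions of `in`. Missing per-dimension values fall back to their usual
// defaults: stride defaults to the kernel size, dilation to 1, and padding
// (and output padding) to 0.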
void calculate_kernel_output_sizes(
    const Tensor& in,
    size_t kernel_ndim,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    exec_aten::SizesType* out_sizes,
    bool ceil_mode,
    bool transposed,
    IntArrayRef output_padding) {
  for (size_t i = 0; i < kernel_ndim; ++i) {
    auto dim = in.dim() - (kernel_ndim - i);
    int64_t k = val_at(kernel_size, i);
    int64_t s = val_at(stride, i, /*default_value=*/k);
    int64_t d = val_at(dilation, i, /*default_value=*/1);
    int64_t p = val_at(padding, i, /*default_value=*/0);
    int64_t op =
        transposed ? val_at(output_padding, i, /*default_value=*/0) : 0;

    out_sizes[dim] = _kernel_output_size_helper(
        in.size(dim), k, p, s, d, ceil_mode, transposed, op);
  }
}

bool check_arange_args(double start, double end, double step, Tensor& out) {
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      out.dim() == 1,
      "out should be a 1-d tensor, but got a %zu-d tensor",
      out.dim());

  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      (step > 0 && (end >= start)) || (step < 0 && (end <= start)),
      "upper bound and lower bound are inconsistent with step sign");

  return true;
}

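// Validates the arguments of avg_pool2d: matching dtypes, default or
// channels-last dim order, a 3-D or 4-D input whose non-batch dimensions are
// non-zero, well-formed kernel parameters, and a non-zero divisor_override if
// one is provided.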
bool check_avg_pool2d_args(
    const Tensor& in,
    const IntArrayRef kernel_size,
    const IntArrayRef stride,
    const IntArrayRef padding,
    const bool ceil_mode,
    const bool count_include_pad,
    const exec_aten::optional<int64_t>& divisor_override,
    const Tensor& out) {
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));

  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_or_channels_last_dim_order(in));
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_or_channels_last_dim_order(out));

  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      (in.dim() == 3 && in.size(0) > 0 && in.size(1) > 0 && in.size(2) > 0) ||
          (in.dim() == 4 && in.size(1) > 0 && in.size(2) > 0 && in.size(3) > 0),
      "Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input");

  ET_LOG_AND_RETURN_IF_FALSE(
      kernel_size_is_valid(kernel_size, /*kernel_ndim=*/2));
  ET_LOG_AND_RETURN_IF_FALSE(
      stride_is_valid(stride, /*kernel_ndim=*/2, /*allow_empty=*/true));
  ET_LOG_AND_RETURN_IF_FALSE(padding_is_valid(
      padding, kernel_size, /*kernel_ndim=*/2, /*enforce_half_kernel=*/true));

  if (divisor_override.has_value()) {
    ET_LOG_MSG_AND_RETURN_IF_FALSE(
        divisor_override.value() != 0,
        "divisor_override must be non-zero, but found %" PRId64,
        divisor_override.value());
  }

  return true;
}

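// Computes the expected output sizes of avg_pool2d for `in`: the (optional)
// batch and channel dimensions are carried over, and the two spatial
// dimensions are derived from the kernel parameters.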
void get_avg_pool2d_out_target_size(
    const Tensor& in,
    const IntArrayRef kernel_size,
    const IntArrayRef stride,
    const IntArrayRef padding,
    const bool ceil_mode,
    exec_aten::SizesType* const out_sizes,
    size_t* const out_ndim) {
  *out_ndim = in.dim();

  // Batch dim is optional, so in can be either 3 or 4 dim.
  if (in.dim() == 4) {
    out_sizes[0] = in.size(0);
    out_sizes[1] = in.size(1);
  } else {
    out_sizes[0] = in.size(0);
  }

  calculate_kernel_output_sizes(
      in, 2, kernel_size, stride, padding, {}, out_sizes, ceil_mode);
}

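// Validates the arguments of a (possibly transposed) convolution: dtypes and
// dim orders must match, input must be 3-D or 4-D with weight and out of the
// same rank, bias (if present) must be 1-D with one entry per output channel,
// the kernel parameters must be well formed, and the channel counts must be
// consistent with `groups`.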
bool check_convolution_args(
    const Tensor& in,
    const Tensor& weight,
    const exec_aten::optional<Tensor>& bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool transposed,
    IntArrayRef output_padding,
    int64_t groups,
    const Tensor& out) {
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, weight, out));

  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_or_channels_last_dim_order(in));
  ET_LOG_AND_RETURN_IF_FALSE(
      tensor_is_default_or_channels_last_dim_order(weight));
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_or_channels_last_dim_order(out));

  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      in.dim() == 3 || in.dim() == 4,
      "Expected input tensor to be 3-D or 4-D, but got %zu.",
      static_cast<size_t>(in.dim()));
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(weight, in.dim()));
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(out, in.dim()));

  if (bias.has_value()) {
    ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(bias.value(), 1));
    ET_LOG_MSG_AND_RETURN_IF_FALSE(
        bias.value().size(0) ==
            (transposed ? groups * weight.size(1) : weight.size(0)),
        "bias length must equal number of output channels, but got %zd",
        bias.value().size(0));
  }

  int64_t kernel_size[2];
  size_t kernel_ndim = 2;
  if (weight.dim() == 3) {
    kernel_size[0] = weight.size(2);
    kernel_ndim = 1;
  } else {
    kernel_size[0] = weight.size(2);
    kernel_size[1] = weight.size(3);
  }
  ET_LOG_AND_RETURN_IF_FALSE(
      stride_is_valid(stride, kernel_ndim, /*allow_empty=*/false));
  ET_LOG_AND_RETURN_IF_FALSE(
      padding_is_valid(padding, {kernel_size, kernel_ndim}, kernel_ndim));
  ET_LOG_AND_RETURN_IF_FALSE(dilation_is_valid(dilation, kernel_ndim));
  if (transposed) {
    ET_LOG_AND_RETURN_IF_FALSE(
        output_padding_is_valid(output_padding, stride, dilation, kernel_ndim));
  }

  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      weight.size(0) >= groups,
      "Given groups=%" PRId64 ", expected weight to be at least %" PRId64
      " at dimension 0, but got weight.size(0) = %zd instead",
      groups,
      groups,
      weight.size(0));
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      weight.size(0) % groups == 0,
      "Given groups=%" PRId64 ", expected weight to be divisible by %" PRId64
      " at dimension 0, but got weight.size(0) = %zd instead",
      groups,
      groups,
      weight.size(0));

  if (!transposed) {
    ET_LOG_MSG_AND_RETURN_IF_FALSE(
        in.size(1) == groups * weight.size(1),
        "Given groups=%" PRId64
        " and weight.size(1) = %zd, expected input to have %" PRId64
        " channels, but got %zd",
        groups,
        weight.size(1),
        groups * weight.size(1),
        in.size(1));
  } else {
    ET_LOG_MSG_AND_RETURN_IF_FALSE(
        in.size(1) == weight.size(0),
        "input channels must match weight.size(0) in transposed convolution");
  }

  return true;
}

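// Computes the expected output sizes of a (possibly transposed) convolution:
// the batch size is carried over from the input, the channel count comes from
// the weight tensor (weight.size(0), or groups * weight.size(1) when
// transposed), and the spatial dimensions are derived from the kernel
// parameters.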
void get_convolution_out_target_size(
    const Tensor& in,
    const Tensor& weight,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool transposed,
    IntArrayRef output_padding,
    int64_t groups,
    exec_aten::SizesType* out_sizes,
    size_t* out_ndim) {
  *out_ndim = in.dim();

  // batch dim
  out_sizes[0] = in.size(0);

  // channel dim
  if (!transposed) {
    out_sizes[1] = in.size(1) == 0 ? 0 : weight.size(0);
  } else {
    out_sizes[1] = in.size(1) == 0 ? 0 : groups * weight.size(1);
  }

  int64_t kernel_size[2];
  size_t kernel_ndim = 2;
  if (weight.dim() == 3) {
    kernel_size[0] = weight.size(2);
    kernel_ndim = 1;
  } else {
    kernel_size[0] = weight.size(2);
    kernel_size[1] = weight.size(3);
  }
  calculate_kernel_output_sizes(
      in,
      kernel_ndim,
      {kernel_size, kernel_ndim},
      stride,
      padding,
      dilation,
      out_sizes,
      /*ceil_mode=*/false,
      transposed,
      output_padding);
}

bool check_cumsum_args(
    const Tensor& in,
    int64_t dim,
    optional<ScalarType> dtype,
    Tensor& out) {
  ET_LOG_AND_RETURN_IF_FALSE(dim_is_valid(dim, in.dim()));

  if (dtype.has_value()) {
    ET_LOG_AND_RETURN_IF_FALSE(dtype.value() == out.scalar_type());
  }

  return true;
}

bool check_max_pool2d_with_indices_args(
    const Tensor& in,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode,
    Tensor& out,
    Tensor& indices) {
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      indices.scalar_type() == ScalarType::Long,
      "Expected indices to have type of Long, but found %s",
      toString(indices.scalar_type()));

  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_or_channels_last_dim_order(in));
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_or_channels_last_dim_order(out));

  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      (in.dim() == 3 && in.size(0) > 0 && in.size(1) > 0 && in.size(2) > 0) ||
          (in.dim() == 4 && in.size(1) > 0 && in.size(2) > 0 && in.size(3) > 0),
      "Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input");

  ET_LOG_AND_RETURN_IF_FALSE(
      kernel_size_is_valid(kernel_size, /*kernel_ndim=*/2));
  ET_LOG_AND_RETURN_IF_FALSE(
      stride_is_valid(stride, /*kernel_ndim=*/2, /*allow_empty=*/true));
  ET_LOG_AND_RETURN_IF_FALSE(padding_is_valid(
      padding, kernel_size, /*kernel_ndim=*/2, /*enforce_half_kernel=*/true));
  ET_LOG_AND_RETURN_IF_FALSE(dilation_is_valid(dilation, /*kernel_ndim=*/2));

  return true;
}

void get_max_pool2d_with_indices_out_target_size(
    const Tensor& in,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode,
    exec_aten::SizesType* out_sizes,
    size_t* out_ndim) {
  *out_ndim = in.dim();

  // Batch dim is optional, so in can be either 3 or 4 dim.
  if (in.dim() == 4) {
    out_sizes[0] = in.size(0);
    out_sizes[1] = in.size(1);
  } else {
    out_sizes[0] = in.size(0);
  }

  calculate_kernel_output_sizes(
      in, 2, kernel_size, stride, padding, dilation, out_sizes, ceil_mode);
}

bool check_masked_fill_args(
    const Tensor& in,
    const Tensor& mask,
    const Scalar& value,
    Tensor& out) {
  (void)value;

  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
  ET_LOG_AND_RETURN_IF_FALSE(mask.scalar_type() == ScalarType::Bool);

  return true;
}

bool check_constant_pad_args(
    const Tensor& in,
    IntArrayRef pad,
    const Scalar& value,
    Tensor& out) {
  (void)value;

  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));

  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_rank(in, out));

  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      pad.size() % 2 == 0,
      "Padding array must have an even number of elements");

  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      pad.size() / 2 <= in.dim(), "Padding array contains too many elements");

  return true;
}

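// Resizes `out` to the shape produced by constant padding: `pad` holds
// (before, after) pairs starting from the last dimension of `in`, and each
// padded dimension grows by the sum of its pair.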
Error resize_constant_pad_output(
    const Tensor& in,
    IntArrayRef pad,
    Tensor& out) {
  Tensor::SizesType expected_output_size[kTensorDimensionLimit];

  int pad_i = in.dim() - 1;
  for (size_t i = 0; i < in.dim(); ++i, --pad_i) {
    expected_output_size[i] = in.size(i);
    if (pad_i >= 0 && pad_i < pad.size() / 2) {
      expected_output_size[i] += pad[2 * pad_i] + pad[2 * pad_i + 1];
    }
  }

  ArrayRef<Tensor::SizesType> output_size{
      expected_output_size, static_cast<size_t>(in.dim())};
  auto error = resize_tensor(out, output_size);

  return error;
}

bool check_embedding_args(
    const Tensor& weight,
    const Tensor& indices,
    const Tensor& out) {
  // Ensure weight is 2-D. It could be empty.
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      weight.dim() == 2, "weight.dim() %zd != 2", weight.dim());

  // Ensure out is a (k+1)-dimensional tensor, where k is indices.dim().
  // The first k dimensions of out must match indices, and the last dimension
  // must equal weight's last dimension.
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      out.dim() == indices.dim() + 1,
      "out.dim() %zd != indices.dim() %zd + 1",
      out.dim(),
      indices.dim());

  // Ensure out and weight have the same dtype.
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(weight, out));

  return true;
}

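// Resizes `out` for embedding lookup: the leading dimensions match `indices`
// and the final dimension is the embedding width, weight.size(1).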
Error resize_embedding_output(
    const Tensor& weight,
    const Tensor& indices,
    const Tensor& out) {
  Tensor::SizesType expected_output_size[kTensorDimensionLimit];
  for (size_t i = 0; i < indices.dim(); i++) {
    expected_output_size[i] = indices.size(i);
  }
  const size_t embedding_dim = weight.size(1);
  expected_output_size[out.dim() - 1] = embedding_dim;

  ArrayRef<Tensor::SizesType> output_size{
      expected_output_size, static_cast<size_t>(out.dim())};

  return resize_tensor(out, output_size);
}

bool check_alpha_type(
    const ScalarType alpha_type,
    const ScalarType common_type) {
  // Verify that alpha type is compatible with common type,
  // as used by ops such as add and sub.
  ET_LOG_AND_RETURN_IF_FALSE(
      canCast(alpha_type, common_type) ||
      (common_type == ScalarType::Bool && isIntegralType(alpha_type, true)));

  return true;
}

} // namespace executor
} // namespace torch