/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <cstring>

#include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>

namespace torch {
namespace executor {

using Tensor = exec_aten::Tensor;

namespace {

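// Checks that a kernel parameter array (kernel_size, stride, padding, etc.)
// has a valid length (1 or `length`, plus 0 when `allow_empty` is true) and
// that every entry is at least `min_val`. Logs and returns false otherwise.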
bool param_array_is_valid(
    const char* name,
    IntArrayRef array,
    int64_t min_val,
    size_t length,
    bool allow_empty) {
  auto size = array.size();
  if (allow_empty) {
    ET_LOG_MSG_AND_RETURN_IF_FALSE(
        size == 0 || size == 1 || size == length,
        "Expected %s to have size 0, 1 or %zu but got %zd",
        name,
        length,
        size);
  } else {
    ET_LOG_MSG_AND_RETURN_IF_FALSE(
        size == 1 || size == length,
        "Expected %s to have size 1 or %zu but got %zd",
        name,
        length,
        size);
  }
  ET_LOG_AND_RETURN_IF_FALSE(int_array_all_ge(array, min_val));
  return true;
}

} // namespace

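// Returns true if every element of `array` is >= `val`; otherwise logs the
// first offending index and returns false.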
bool int_array_all_ge(IntArrayRef array, int64_t val) {
  for (size_t i = 0; i < array.size(); ++i) {
    if (array[i] < val) {
      ET_LOG(
          Error,
          "Expected array[%zu] >= %" PRId64 ", found %" PRId64,
          i,
          val,
          array[i]);
      return false;
    }
  }
  return true;
}

bool kernel_size_is_valid(IntArrayRef kernel_size, size_t kernel_ndim) {
  return param_array_is_valid(
      "kernel_size",
      kernel_size,
      /*min_val=*/1,
      kernel_ndim,
      /*allow_empty=*/false);
}

bool stride_is_valid(IntArrayRef stride, size_t kernel_ndim, bool allow_empty) {
  return param_array_is_valid(
      "stride", stride, /*min_val=*/1, kernel_ndim, allow_empty);
}

bool padding_is_valid(
    IntArrayRef padding,
    IntArrayRef kernel_size,
    size_t kernel_ndim,
    bool enforce_half_kernel) {
  bool valid = param_array_is_valid(
      "padding", padding, /*min_val=*/0, kernel_ndim, /*allow_empty=*/false);
  if (!valid) {
    return false;
  }

  if (enforce_half_kernel) {
    // Padding must be at most half of kernel size.
    for (size_t i = 0; i < padding.size(); i++) {
      if (padding[i] > val_at(kernel_size, i) / 2) {
        ET_LOG(
            Error,
            "Padding should be at most half of kernel size, "
            "but got padding[%zu] = %" PRId64 " > kernel_size[%zu] = %" PRId64,
            i,
            padding[i],
            i,
            val_at(kernel_size, i));
        return false;
      }
    }
  }
  return true;
}

bool dilation_is_valid(IntArrayRef dilation, size_t kernel_ndim) {
  return param_array_is_valid(
      "dilation", dilation, /*min_val=*/1, kernel_ndim, /*allow_empty=*/false);
}

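// Checks that output_padding (used only for transposed convolution) is
// non-negative and that each entry is strictly smaller than either the
// stride or the dilation along the same dimension.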
bool output_padding_is_valid(
    IntArrayRef output_padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    size_t kernel_ndim) {
  ET_LOG_AND_RETURN_IF_FALSE(param_array_is_valid(
      "output_padding",
      output_padding,
      /*min_val=*/0,
      kernel_ndim,
      /*allow_empty=*/false));

  for (size_t i = 0; i < kernel_ndim; i++) {
    const int64_t op_i = val_at(output_padding, i);
    const int64_t s_i = val_at(stride, i);
    const int64_t d_i = val_at(dilation, i);
    ET_LOG_MSG_AND_RETURN_IF_FALSE(
        op_i < s_i || op_i < d_i,
        "output padding must be smaller than either stride or dilation");
  }
  return true;
}

bool output_size_is_valid(
    exec_aten::ArrayRef<exec_aten::SizesType> output_size,
    size_t kernel_ndim) {
  bool valid = true;
  size_t out_dim = output_size.size();
  for (size_t i = 0; i < out_dim - kernel_ndim; i++) {
    if (output_size[i] < 0) {
      valid = false;
    }
  }
  for (size_t i = out_dim - kernel_ndim; i < out_dim; i++) {
    if (output_size[i] <= 0) {
      valid = false;
    }
  }
  if (!valid) {
    ET_LOG(
        Error,
        "The provided combination of input and kernel parameters "
        "produces an invalid output size:");
    for (size_t d = 0; d < output_size.size(); ++d) {
      ET_LOG(
          Error, "    size(%zu): %zu", d, static_cast<size_t>(output_size[d]));
    }
  }
  return valid;
}

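// Writes the sizes of `t` with an extra dimension of size 1 inserted at
// `unsqueeze_dim` into `sizes_arr`, and sets `ndim` to t.dim() + 1.
// For example, sizes {2, 3} unsqueezed at dim 1 become {2, 1, 3}.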
void get_unsqueezed_sizes(
    const Tensor& t,
    int64_t unsqueeze_dim,
    exec_aten::SizesType* sizes_arr,
    size_t& ndim) {
  ndim = t.dim() + 1;
  for (int d = 0; d < unsqueeze_dim; ++d) {
    sizes_arr[d] = t.size(d);
  }
  sizes_arr[unsqueeze_dim] = 1;
  for (int d = (unsqueeze_dim + 1); d < ndim; d++) {
    sizes_arr[d] = t.size(d - 1);
  }
}

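// Builds the dim order of `t` after unsqueezing at `unsqueeze_dim`: the new
// dimension is placed directly after `unsqueeze_dim` in the order, and every
// dim index greater than `unsqueeze_dim` is shifted up by one. For example,
// dim order {0, 2, 3, 1} unsqueezed at dim 0 becomes {0, 1, 3, 4, 2}.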
void get_unsqueezed_dim_order(
    const Tensor& t,
    exec_aten::DimOrderType unsqueeze_dim,
    exec_aten::DimOrderType* dim_order_arr) {
  int offset = 0;
  for (int i = 0; i < t.dim(); ++i) {
    exec_aten::DimOrderType dim = t.dim_order()[i];
    if (dim == unsqueeze_dim) {
      dim_order_arr[i] = dim;
      dim_order_arr[i + 1] = dim + 1;
      offset = 1;
    } else {
      dim_order_arr[i + offset] = dim > unsqueeze_dim ? dim + 1 : dim;
    }
  }
}

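// Computes the output extent of one spatial dimension for a pooling or
// convolution kernel. For the non-transposed case this is the usual
//
//   out = floor((in + 2 * pad - dilation * (kernel - 1) - 1) / stride) + 1
//
// with `stride - 1` added to the numerator when ceil_mode is set (and the
// result clamped so the last window starts inside the padded input). For the
// transposed case it is the inverse relation
//
//   out = (in - 1) * stride - 2 * pad + dilation * (kernel - 1)
//         + output_padding + 1
//
// As a quick sanity check (illustrative values, not from the original
// source): in = 5, kernel = 3, pad = 1, stride = 2, dilation = 1,
// ceil_mode = false gives (5 + 2 - 2 - 1) / 2 + 1 = 3.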
int64_t _kernel_output_size_helper(
    size_t inputSize,
    int64_t kernelSize,
    int64_t pad,
    int64_t stride,
    int64_t dilation,
    bool ceil_mode,
    bool transposed,
    int64_t output_padding) {
  if (transposed) {
    return (inputSize - 1) * stride - 2 * pad + dilation * (kernelSize - 1) +
        output_padding + 1;
  }
  int64_t numerator = inputSize + 2 * pad - dilation * (kernelSize - 1) - 1 +
      (ceil_mode ? stride - 1 : 0);
  int64_t outputSize = numerator / stride + 1;
  if (ceil_mode) {
    // ensure that the last pooling starts inside the image
    // needed to avoid problems in ceil mode
    if ((outputSize - 1) * stride >= inputSize + pad) {
      --outputSize;
    }
  }
  return outputSize;
}

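// Fills the trailing `kernel_ndim` entries of `out_sizes` by applying
// _kernel_output_size_helper to each spatial dimension of `in`. Missing
// stride entries default to the kernel size, dilation to 1, and padding and
// output_padding to 0, matching the val_at default values passed below.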
void calculate_kernel_output_sizes(
    const Tensor& in,
    size_t kernel_ndim,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    exec_aten::SizesType* out_sizes,
    bool ceil_mode,
    bool transposed,
    IntArrayRef output_padding) {
  for (size_t i = 0; i < kernel_ndim; ++i) {
    auto dim = in.dim() - (kernel_ndim - i);
    int64_t k = val_at(kernel_size, i);
    int64_t s = val_at(stride, i, /*default_value=*/k);
    int64_t d = val_at(dilation, i, /*default_value=*/1);
    int64_t p = val_at(padding, i, /*default_value=*/0);
    int64_t op =
        transposed ? val_at(output_padding, i, /*default_value=*/0) : 0;

    out_sizes[dim] = _kernel_output_size_helper(
        in.size(dim), k, p, s, d, ceil_mode, transposed, op);
  }
}

bool check_arange_args(double start, double end, double step, Tensor& out) {
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      out.dim() == 1,
      "out should be a 1-d tensor, but got a %zu-d tensor",
      out.dim());

  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      (step > 0 && (end >= start)) || (step < 0 && (end <= start)),
      "upper and lower bounds inconsistent with step sign");

  return true;
}

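// Validates the arguments of avg_pool2d: dtypes match, tensors use the
// default or channels-last dim order, the input is 3-D (CHW) or 4-D (NCHW)
// with non-empty spatial dims, the kernel parameters are well formed, and
// divisor_override (if present) is non-zero.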
bool check_avg_pool2d_args(
    const Tensor& in,
    const IntArrayRef kernel_size,
    const IntArrayRef stride,
    const IntArrayRef padding,
    const bool ceil_mode,
    const bool count_include_pad,
    const exec_aten::optional<int64_t>& divisor_override,
    const Tensor& out) {
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));

  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_or_channels_last_dim_order(in));
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_or_channels_last_dim_order(out));

  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      (in.dim() == 3 && in.size(0) > 0 && in.size(1) > 0 && in.size(2) > 0) ||
          (in.dim() == 4 && in.size(1) > 0 && in.size(2) > 0 && in.size(3) > 0),
      "Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input");

  ET_LOG_AND_RETURN_IF_FALSE(
      kernel_size_is_valid(kernel_size, /*kernel_ndim=*/2));
  ET_LOG_AND_RETURN_IF_FALSE(
      stride_is_valid(stride, /*kernel_ndim=*/2, /*allow_empty=*/true));
  ET_LOG_AND_RETURN_IF_FALSE(padding_is_valid(
      padding, kernel_size, /*kernel_ndim=*/2, /*enforce_half_kernel=*/true));

  if (divisor_override.has_value()) {
    ET_LOG_MSG_AND_RETURN_IF_FALSE(
        divisor_override.value() != 0,
        "divisor_override must be non-zero, but found %" PRId64,
        divisor_override.value());
  }

  return true;
}

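// Computes the expected output shape of avg_pool2d: the leading (batch and)
// channel dims are copied from the input and the two trailing spatial dims
// come from calculate_kernel_output_sizes. For example (illustrative values,
// not from the original source), an NCHW input of {1, 3, 8, 8} with
// kernel_size {2, 2}, stride {2, 2}, and padding {0, 0} yields {1, 3, 4, 4}.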
void get_avg_pool2d_out_target_size(
    const Tensor& in,
    const IntArrayRef kernel_size,
    const IntArrayRef stride,
    const IntArrayRef padding,
    const bool ceil_mode,
    exec_aten::SizesType* const out_sizes,
    size_t* const out_ndim) {
  *out_ndim = in.dim();

  // Batch dim is optional, so in can be either 3 or 4 dim.
  if (in.dim() == 4) {
    out_sizes[0] = in.size(0);
    out_sizes[1] = in.size(1);
  } else {
    out_sizes[0] = in.size(0);
  }

  calculate_kernel_output_sizes(
      in, 2, kernel_size, stride, padding, {}, out_sizes, ceil_mode);
}

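// Validates the arguments of convolution: dtypes and dim orders match, the
// input is 3-D or 4-D with weight and out of the same rank, the optional bias
// has the right length, the kernel parameters are well formed, and the input,
// weight, and groups are mutually consistent.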
bool check_convolution_args(
    const Tensor& in,
    const Tensor& weight,
    const exec_aten::optional<Tensor>& bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool transposed,
    IntArrayRef output_padding,
    int64_t groups,
    const Tensor& out) {
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, weight, out));

  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_or_channels_last_dim_order(in));
  ET_LOG_AND_RETURN_IF_FALSE(
      tensor_is_default_or_channels_last_dim_order(weight));
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_or_channels_last_dim_order(out));

  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      in.dim() == 3 || in.dim() == 4,
      "Expect input tensor to be 3-D or 4-D, but got %zu.",
      static_cast<size_t>(in.dim()));
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(weight, in.dim()));
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(out, in.dim()));

  if (bias.has_value()) {
    ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(bias.value(), 1));
    ET_LOG_MSG_AND_RETURN_IF_FALSE(
        bias.value().size(0) ==
            (transposed ? groups * weight.size(1) : weight.size(0)),
        "bias length must equal number of output channels, but got %zd",
        bias.value().size(0));
  }

  int64_t kernel_size[2];
  size_t kernel_ndim = 2;
  if (weight.dim() == 3) {
    kernel_size[0] = weight.size(2);
    kernel_ndim = 1;
  } else {
    kernel_size[0] = weight.size(2);
    kernel_size[1] = weight.size(3);
  }
  ET_LOG_AND_RETURN_IF_FALSE(
      stride_is_valid(stride, kernel_ndim, /*allow_empty=*/false));
  ET_LOG_AND_RETURN_IF_FALSE(
      padding_is_valid(padding, {kernel_size, kernel_ndim}, kernel_ndim));
  ET_LOG_AND_RETURN_IF_FALSE(dilation_is_valid(dilation, kernel_ndim));
  if (transposed) {
    ET_LOG_AND_RETURN_IF_FALSE(
        output_padding_is_valid(output_padding, stride, dilation, kernel_ndim));
  }

  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      weight.size(0) >= groups,
      "Given groups=%" PRId64 ", expected weight to be at least %" PRId64
      " at dimension 0, but got weight.size(0) = %zd instead",
      groups,
      groups,
      weight.size(0));
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      weight.size(0) % groups == 0,
      "Given groups=%" PRId64 ", expected weight to be divisible by %" PRId64
      " at dimension 0, but got weight.size(0) = %zd instead",
      groups,
      groups,
      weight.size(0));

  if (!transposed) {
    ET_LOG_MSG_AND_RETURN_IF_FALSE(
        in.size(1) == groups * weight.size(1),
        "Given groups=%" PRId64
        " and weight.size(1) = %zd, expected input to have %" PRId64
        " channels, but got %zd",
        groups,
        weight.size(1),
        groups * weight.size(1),
        in.size(1));
  } else {
    ET_LOG_MSG_AND_RETURN_IF_FALSE(
        in.size(1) == weight.size(0),
        "input channels must match weight.size(0) in transposed convolution");
  }

  return true;
}

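// Computes the expected convolution output shape: the batch size is copied
// from the input, the channel dim is weight.size(0) (or groups *
// weight.size(1) when transposed; 0 if the input has 0 channels), and the
// spatial dims come from calculate_kernel_output_sizes. For example
// (illustrative values, not from the original source), in = {1, 3, 8, 8},
// weight = {16, 3, 3, 3}, stride = {1, 1}, padding = {1, 1},
// dilation = {1, 1}, transposed = false yields {1, 16, 8, 8}.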
void get_convolution_out_target_size(
    const Tensor& in,
    const Tensor& weight,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool transposed,
    IntArrayRef output_padding,
    int64_t groups,
    exec_aten::SizesType* out_sizes,
    size_t* out_ndim) {
  *out_ndim = in.dim();

  // batch dim
  out_sizes[0] = in.size(0);

  // channel dim
  if (!transposed) {
    out_sizes[1] = in.size(1) == 0 ? 0 : weight.size(0);
  } else {
    out_sizes[1] = in.size(1) == 0 ? 0 : groups * weight.size(1);
  }

  int64_t kernel_size[2];
  size_t kernel_ndim = 2;
  if (weight.dim() == 3) {
    kernel_size[0] = weight.size(2);
    kernel_ndim = 1;
  } else {
    kernel_size[0] = weight.size(2);
    kernel_size[1] = weight.size(3);
  }
  calculate_kernel_output_sizes(
      in,
      kernel_ndim,
      {kernel_size, kernel_ndim},
      stride,
      padding,
      dilation,
      out_sizes,
      false,
      transposed,
      output_padding);
}

bool check_cumsum_args(
    const Tensor& in,
    int64_t dim,
    optional<ScalarType> dtype,
    Tensor& out) {
  ET_LOG_AND_RETURN_IF_FALSE(dim_is_valid(dim, in.dim()));

  if (dtype.has_value()) {
    ET_LOG_AND_RETURN_IF_FALSE(dtype.value() == out.scalar_type());
  }

  return true;
}

bool check_max_pool2d_with_indices_args(
    const Tensor& in,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode,
    Tensor& out,
    Tensor& indices) {
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      indices.scalar_type() == ScalarType::Long,
      "Expected indices to have type of Long, but found %s",
      toString(indices.scalar_type()));

  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_or_channels_last_dim_order(in));
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_or_channels_last_dim_order(out));

  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      (in.dim() == 3 && in.size(0) > 0 && in.size(1) > 0 && in.size(2) > 0) ||
          (in.dim() == 4 && in.size(1) > 0 && in.size(2) > 0 && in.size(3) > 0),
      "Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input");

  ET_LOG_AND_RETURN_IF_FALSE(
      kernel_size_is_valid(kernel_size, /*kernel_ndim=*/2));
  ET_LOG_AND_RETURN_IF_FALSE(
      stride_is_valid(stride, /*kernel_ndim=*/2, /*allow_empty=*/true));
  ET_LOG_AND_RETURN_IF_FALSE(padding_is_valid(
      padding, kernel_size, /*kernel_ndim=*/2, /*enforce_half_kernel=*/true));
  ET_LOG_AND_RETURN_IF_FALSE(dilation_is_valid(dilation, /*kernel_ndim=*/2));

  return true;
}

void get_max_pool2d_with_indices_out_target_size(
    const Tensor& in,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode,
    exec_aten::SizesType* out_sizes,
    size_t* out_ndim) {
  *out_ndim = in.dim();

  // Batch dim is optional, so in can be either 3 or 4 dim.
  if (in.dim() == 4) {
    out_sizes[0] = in.size(0);
    out_sizes[1] = in.size(1);
  } else {
    out_sizes[0] = in.size(0);
  }

  calculate_kernel_output_sizes(
      in, 2, kernel_size, stride, padding, dilation, out_sizes, ceil_mode);
}

bool check_masked_fill_args(
    const Tensor& in,
    const Tensor& mask,
    const Scalar& value,
    Tensor& out) {
  (void)value;

  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
  ET_LOG_AND_RETURN_IF_FALSE(mask.scalar_type() == ScalarType::Bool);

  return true;
}

bool check_constant_pad_args(
    const Tensor& in,
    IntArrayRef pad,
    const Scalar& value,
    Tensor& out) {
  (void)value;

  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));

  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_rank(in, out));

  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      pad.size() % 2 == 0, "Padding array length must be a multiple of 2");

  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      pad.size() / 2 <= in.dim(), "Padding array contains too many elements");

  return true;
}

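// Resizes `out` to the input sizes expanded by `pad`. Pad values come in
// (before, after) pairs ordered from the last dimension backwards, so
// pad[0]/pad[1] apply to the last dim, pad[2]/pad[3] to the second-to-last,
// and so on. For example (illustrative values, not from the original source),
// in = {2, 3} with pad = {1, 2} resizes out to {2, 6}.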
Error resize_constant_pad_output(
    const Tensor& in,
    IntArrayRef pad,
    Tensor& out) {
  Tensor::SizesType expected_output_size[kTensorDimensionLimit];

  int pad_i = in.dim() - 1;
  for (size_t i = 0; i < in.dim(); ++i, --pad_i) {
    expected_output_size[i] = in.size(i);
    if (pad_i >= 0 && pad_i < pad.size() / 2) {
      expected_output_size[i] += pad[2 * pad_i] + pad[2 * pad_i + 1];
    }
  }

  ArrayRef<Tensor::SizesType> output_size{
      expected_output_size, static_cast<size_t>(in.dim())};
  auto error = resize_tensor(out, output_size);

  return error;
}

bool check_embedding_args(
    const Tensor& weight,
    const Tensor& indices,
    const Tensor& out) {
  // Ensure weight is 2-D. It could be empty.
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      weight.dim() == 2, "weight.dim() %zd != 2", weight.dim());

  // Ensure out is a (k + 1)-dimensional tensor, where k is indices.dim().
  // out's first k dimensions must match indices, and its last dim must equal
  // weight's last dim.
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      out.dim() == indices.dim() + 1,
      "out.dim() %zd != indices.dim() %zd + 1",
      out.dim(),
      indices.dim());

  // Ensure dtype is the same for out and weight.
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(weight, out));

  return true;
}

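// Resizes `out` to indices.sizes() with weight.size(1) (the embedding dim)
// appended. For example (illustrative values, not from the original source),
// weight = {10, 4} and indices = {2, 3} resize out to {2, 3, 4}.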
Error resize_embedding_output(
    const Tensor& weight,
    const Tensor& indices,
    const Tensor& out) {
  Tensor::SizesType expected_output_size[kTensorDimensionLimit];
  for (size_t i = 0; i < indices.dim(); i++) {
    expected_output_size[i] = indices.size(i);
  }
  const size_t embedding_dim = weight.size(1);
  expected_output_size[out.dim() - 1] = embedding_dim;

  ArrayRef<Tensor::SizesType> output_size{
      expected_output_size, static_cast<size_t>(out.dim())};

  return resize_tensor(out, output_size);
}

bool check_alpha_type(
    const ScalarType alpha_type,
    const ScalarType common_type) {
  // Verify that alpha type is compatible with common type,
  // as used by ops such as add and sub.
  ET_LOG_AND_RETURN_IF_FALSE(
      canCast(alpha_type, common_type) ||
      (common_type == ScalarType::Bool && isIntegralType(alpha_type, true)));

  return true;
}

} // namespace executor
} // namespace torch