/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <algorithm>
#include <array>
#include <cstdint>

#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace internal {
// NOTE: we bake ArrayRef iterators being pointers into the return
// type here because we assume that iterators are portable across
// ArrayRef copies.
inline const Tensor::SizesType* arrayref_begin_ignoring_leading_1s(
    ArrayRef<Tensor::SizesType> arr) {
  return std::find_if(
      arr.begin(), arr.end(), [](Tensor::SizesType x) { return x != 1; });
}
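
// Illustration of the helper above (hypothetical sizes): for
// arr = [1, 1, 4, 5] the returned iterator points at the 4; for
// arr = [1, 1, 1] it points at arr.end().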

inline bool sizes_match_ignoring_leading_1s(
    ArrayRef<Tensor::SizesType> lhs,
    ArrayRef<Tensor::SizesType> rhs) {
  auto lhs_begin = arrayref_begin_ignoring_leading_1s(lhs);
  auto lhs_end = lhs.end();

  auto rhs_begin = arrayref_begin_ignoring_leading_1s(rhs);
  auto rhs_end = rhs.end();

  return ((lhs_end - lhs_begin) == (rhs_end - rhs_begin)) &&
      std::equal(lhs_begin, lhs_end, rhs_begin);
}
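
// Quick sanity check of the helper above (hypothetical sizes, shown only
// as an illustration):
//   sizes_match_ignoring_leading_1s([1, 1, 4, 5], [4, 5]) -> true
//   sizes_match_ignoring_leading_1s([1, 4, 5], [4, 6])    -> false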
} // namespace internal

// Tags describing which optimized elementwise path, if any, applies to a
// given pair of input tensors.
enum class ElementwiseOptimizedPath {
  kNone,
  kTreatAs1d,
  kBroadcast2dBy1d,
  kBroadcast2dBy1dReverseArguments,
  kBroadcastNdByNd,
  kBroadcastNdByNdReverseArguments,
  kBroadcastLastDim,
  kBroadcastLastDimReverseArguments,
};

namespace internal {

/*
  Given two tensors, this function returns the broadcast dim if it exists.
  Returns 0 if no broadcast dim is found.
  Otherwise a negative index is used to indicate the broadcast dim,
  e.g. if size = [a, b, c, 1, e, f] then the broadcast dim is -3.

  This path aims to handle broadcasts of the following form:
  A = [a1, a2, .., 1, .., an]
  B = [b1, b2, .., bm, .., bn]
  OR
  A = [a1, a2, .., am, .., an]
  B = [b1, b2, .., 1, .., bn]
  Note that this way of determining the broadcast dim also works
  when the broadcast dim is the last dim.
*/
int32_t inline get_broadcast_dim(const Tensor& lhs, const Tensor& rhs) {
  auto lhs_begin = arrayref_begin_ignoring_leading_1s(lhs.sizes());
  auto lhs_end = lhs.sizes().end();

  auto rhs_begin = arrayref_begin_ignoring_leading_1s(rhs.sizes());
  auto rhs_end = rhs.sizes().end();

  const auto lhs_size = lhs_end - lhs_begin;
  const auto rhs_size = rhs_end - rhs_begin;

  // The following example is not handled at the moment:
  // [1, 3, 4, 5]
  // [2, 3, 4, 5]
  if (lhs_size != rhs_size) {
    return 0;
  }

  int32_t broadcast_dim = 0;
  // Check:
  // 1. if any dim value is 1 (it constitutes a broadcast dim),
  // 2. if more than one dim value is 1 (we cannot handle that), and
  // 3. if the non-1 dim values are equal.
  lhs_end--;
  rhs_end--;
  while (lhs_end != lhs_begin) {
    if (*lhs_end == 1 || *rhs_end == 1) {
      // If more than one broadcast dim is found, return 0.
      if (broadcast_dim != 0) {
        return 0;
      }
      // A negative index is used.
      broadcast_dim = lhs_end - lhs.sizes().end();
    } else if (*lhs_end != *rhs_end) {
      // If non-1 dim values are not equal, return 0.
      return 0;
    }
    lhs_end--;
    rhs_end--;
  }
  return broadcast_dim;
}
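
// Hypothetical illustration of the convention above (not exhaustive):
//   lhs.sizes() = [2, 3, 1, 5], rhs.sizes() = [2, 3, 4, 5] -> returns -2
//   lhs.sizes() = [2, 3, 4, 5], rhs.sizes() = [2, 3, 4, 5] -> returns 0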

inline ElementwiseOptimizedPath select_broadcast_optimized_path(
    const Tensor& lhs,
    const Tensor& rhs) {
  auto lhs_begin = arrayref_begin_ignoring_leading_1s(lhs.sizes());
  auto lhs_end = lhs.sizes().end();

  auto rhs_begin = arrayref_begin_ignoring_leading_1s(rhs.sizes());
  auto rhs_end = rhs.sizes().end();

  const auto lhs_size = lhs_end - lhs_begin;
  const auto rhs_size = rhs_end - rhs_begin;
  if (lhs_size == 2 && rhs_size == 1 && lhs_begin[1] == rhs_begin[0]) {
    return ElementwiseOptimizedPath::kBroadcast2dBy1d;
  }

  if (lhs_size == 1 && rhs_size == 2 && rhs_begin[1] == lhs_begin[0]) {
    return ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments;
  }

  int32_t broadcast_dim = get_broadcast_dim(lhs, rhs);
  // broadcast_dim < -1 means the broadcast is along a non-last dim; the
  // last-dim case (broadcast_dim == -1) is handled separately below.
  if (broadcast_dim < -1) {
    if (std::count_if(rhs_begin, rhs_end, [](Tensor::SizesType x) {
          return x == 1;
        }) == 1) {
      return ElementwiseOptimizedPath::kBroadcastNdByNd;
    } else {
      return ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments;
    }
  } else if (broadcast_dim == -1) {
    if (std::count_if(lhs_begin, lhs_end, [](Tensor::SizesType x) {
          return x == 1;
        }) == 1) {
      return ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments;
    } else {
      return ElementwiseOptimizedPath::kBroadcastLastDim;
    }
  }
  return ElementwiseOptimizedPath::kNone;
}
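
// Rough illustration of the dispatch above (hypothetical sizes):
//   lhs = [6, 4],    rhs = [4]        -> kBroadcast2dBy1d
//   lhs = [2, 1, 4], rhs = [2, 3, 4]  -> kBroadcastNdByNdReverseArguments
//   lhs = [2, 3, 1], rhs = [2, 3, 4]  -> kBroadcastLastDimReverseArguments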
} // namespace internal

ElementwiseOptimizedPath inline select_optimized_path(
    const Tensor& a,
    const Tensor& b,
    const Tensor& out) {
  ScalarType a_type = a.scalar_type();
  ScalarType b_type = b.scalar_type();
  ScalarType out_type = out.scalar_type();

  if (a_type != b_type || a_type != out_type || a_type == ScalarType::Half ||
      a_type == ScalarType::BFloat16) {
    return ElementwiseOptimizedPath::kNone;
  }
  if (a.sizes().equals(b.sizes()) ||
      (a.numel() == b.numel() &&
       (a.numel() == out.numel() ||
        internal::sizes_match_ignoring_leading_1s(a.sizes(), b.sizes())))) {
    return ElementwiseOptimizedPath::kTreatAs1d;
  }
  return internal::select_broadcast_optimized_path(a, b);
}
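
// Sketch of how a binary kernel might consume this helper (assumed usage,
// not taken from this file):
//   auto path = select_optimized_path(a, b, out);
//   if (path == ElementwiseOptimizedPath::kTreatAs1d) {
//     // vectorize over a flat [0, out.numel()) range
//   } else if (path != ElementwiseOptimizedPath::kNone) {
//     // set up the broadcast loop, e.g. via get_normalized_tensor_size()
//   }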

std::array<int32_t, 3> inline get_normalized_tensor_size(
    const Tensor& a,
    const int32_t broadcast_dim) {
  ET_CHECK_MSG(
      a.dim() > broadcast_dim,
      "Size of tensor: %zd, must be larger than broadcast_dim: %d",
      a.dim(),
      broadcast_dim);
  std::array<int32_t, 3> normalized_tensor_size;
  normalized_tensor_size[0] = 1;
  normalized_tensor_size[1] = a.size(broadcast_dim);
  normalized_tensor_size[2] = 1;
  for (size_t i = 0; i < broadcast_dim; i++) {
    normalized_tensor_size[0] *= a.size(i);
  }
  for (size_t i = broadcast_dim + 1; i < a.dim(); i++) {
    normalized_tensor_size[2] *= a.size(i);
  }
  return normalized_tensor_size;
}
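
// Hypothetical example of the normalization above, assuming the caller has
// already converted the negative broadcast-dim convention into a
// non-negative dim index:
//   a.sizes() = [2, 3, 4, 5], broadcast_dim = 1 -> returns {2, 3, 20}
//   (leading dims collapsed into the first entry, trailing dims into the
//   last).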

} // namespace executor
} // namespace torch