/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <algorithm>
#include <array>
#include <cstdint>

#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace internal {
// NOTE: we bake ArrayRef iterators being pointers into the return
// type here because we assume that iterators are portable across
// ArrayRef copies.
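//
// Returns an iterator to the first size that is not 1, i.e. it skips any
// leading 1-dims. Illustrative example: sizes [1, 1, 3, 4] yield a pointer
// to the 3, and sizes that are all 1s yield arr.end().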
inline const Tensor::SizesType* arrayref_begin_ignoring_leading_1s(
    ArrayRef<Tensor::SizesType> arr) {
  return std::find_if(
      arr.begin(), arr.end(), [](Tensor::SizesType x) { return x != 1; });
}

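// Returns true iff the two size lists are identical once any leading 1-dims
// are stripped from each. Illustrative examples: [1, 1, 4, 5] matches [4, 5];
// [1, 4, 5] does not match [4, 1, 5] because the trimmed views differ in
// length.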
inline bool sizes_match_ignoring_leading_1s(
    ArrayRef<Tensor::SizesType> lhs,
    ArrayRef<Tensor::SizesType> rhs) {
  auto lhs_begin = arrayref_begin_ignoring_leading_1s(lhs);
  auto lhs_end = lhs.end();

  auto rhs_begin = arrayref_begin_ignoring_leading_1s(rhs);
  auto rhs_end = rhs.end();

  return ((lhs_end - lhs_begin) == (rhs_end - rhs_begin)) &&
      std::equal(lhs_begin, lhs_end, rhs_begin);
}
} // namespace internal

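// Shapes handled by each path, as selected by select_optimized_path /
// select_broadcast_optimized_path below (illustrative, ignoring leading
// 1-dims):
// - kTreatAs1d: both inputs have matching sizes, so the op can run as a flat
//   1-d loop.
// - kBroadcast2dBy1d: lhs is [m, n] and rhs is [n]; the *ReverseArguments
//   variant swaps which input is the smaller one.
// - kBroadcastNdByNd: one input has a single 1-valued dim that is not the
//   last dim, and all other dims match.
// - kBroadcastLastDim: one input has a single 1-valued dim in the last
//   position, and all other dims match.
// - kNone: fall back to the generic (unoptimized) broadcast handling.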
enum class ElementwiseOptimizedPath {
  kNone,
  kTreatAs1d,
  kBroadcast2dBy1d,
  kBroadcast2dBy1dReverseArguments,
  kBroadcastNdByNd,
  kBroadcastNdByNdReverseArguments,
  kBroadcastLastDim,
  kBroadcastLastDimReverseArguments,
};

namespace internal {

/*
Given two tensors, this function returns the broadcast dim if one exists,
expressed as a negative index counted from the last dim; e.g. if the sizes
are [a, b, c, 1, e, f], the broadcast dim is -3. Returns 0 if no single
broadcast dim is found.

This path aims to handle broadcasts of the following form:
A = [a1, a2, ..., 1, ..., an]
B = [b1, b2, ..., bm, ..., bn]
OR
A = [a1, a2, ..., am, ..., an]
B = [b1, b2, ..., 1, ..., bn]
Note that this way of determining the broadcast dim also works
when the broadcast dim is the last dim.
*/
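/*
Illustrative example: lhs sizes = [2, 3, 1, 5], rhs sizes = [2, 3, 4, 5].
Scanning from the last dim, 5 == 5, then lhs has a 1 at index -2, and the
remaining dims match, so the function returns -2.
*/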
inline int32_t get_broadcast_dim(const Tensor& lhs, const Tensor& rhs) {
  auto lhs_begin = arrayref_begin_ignoring_leading_1s(lhs.sizes());
  auto lhs_end = lhs.sizes().end();

  auto rhs_begin = arrayref_begin_ignoring_leading_1s(rhs.sizes());
  auto rhs_end = rhs.sizes().end();

  const auto lhs_size = lhs_end - lhs_begin;
  const auto rhs_size = rhs_end - rhs_begin;

  // The following case is not handled at the moment:
  // [1, 3, 4, 5]
  // [2, 3, 4, 5]
  if (lhs_size != rhs_size) {
    return 0;
  }

  int32_t broadcast_dim = 0;
  // Check:
  // 1. whether any dim value is 1 (it constitutes the broadcast dim),
  // 2. whether more than one dim value is 1 (we cannot handle that), and
  // 3. whether the non-1 dim values are equal.
  lhs_end--;
  rhs_end--;
  while (lhs_end != lhs_begin) {
    if (*lhs_end == 1 || *rhs_end == 1) {
      // If more than one broadcast dim is found, return 0.
      if (broadcast_dim != 0) {
        return 0;
      }
      // A negative index, counted from the last dim, is used.
      broadcast_dim = lhs_end - lhs.sizes().end();
    } else if (*lhs_end != *rhs_end) {
      // If non-1 dim values are not equal, return 0.
      return 0;
    }
    lhs_end--;
    rhs_end--;
  }
  return broadcast_dim;
}

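// Maps the broadcast pattern of (lhs, rhs) onto an ElementwiseOptimizedPath.
// Illustrative examples: lhs sizes [8, 16] with rhs sizes [16] select
// kBroadcast2dBy1d; lhs sizes [2, 3, 4, 5] with rhs sizes [2, 1, 4, 5] select
// kBroadcastNdByNd, and swapping the two arguments selects
// kBroadcastNdByNdReverseArguments.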
inline ElementwiseOptimizedPath select_broadcast_optimized_path(
    const Tensor& lhs,
    const Tensor& rhs) {
  auto lhs_begin = arrayref_begin_ignoring_leading_1s(lhs.sizes());
  auto lhs_end = lhs.sizes().end();

  auto rhs_begin = arrayref_begin_ignoring_leading_1s(rhs.sizes());
  auto rhs_end = rhs.sizes().end();

  const auto lhs_size = lhs_end - lhs_begin;
  const auto rhs_size = rhs_end - rhs_begin;
  if (lhs_size == 2 && rhs_size == 1 && lhs_begin[1] == rhs_begin[0]) {
    return ElementwiseOptimizedPath::kBroadcast2dBy1d;
  }

  if (lhs_size == 1 && rhs_size == 2 && rhs_begin[1] == lhs_begin[0]) {
    return ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments;
  }

  int32_t broadcast_dim = get_broadcast_dim(lhs, rhs);
  // broadcast_dim < -1 means the broadcast dim is not the last dim;
  // broadcast_dim == -1 means the broadcast is over the last dim.
  if (broadcast_dim < -1) {
    if (std::count_if(rhs_begin, rhs_end, [](Tensor::SizesType x) {
          return x == 1;
        }) == 1) {
      return ElementwiseOptimizedPath::kBroadcastNdByNd;
    } else {
      return ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments;
    }
  } else if (broadcast_dim == -1) {
    if (std::count_if(lhs_begin, lhs_end, [](Tensor::SizesType x) {
          return x == 1;
        }) == 1) {
      return ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments;
    } else {
      return ElementwiseOptimizedPath::kBroadcastLastDim;
    }
  }
  return ElementwiseOptimizedPath::kNone;
}
} // namespace internal

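// Decides how an elementwise binary op over (a, b) -> out can be executed:
// a fast contiguous 1-d loop, a specialized broadcast loop, or the generic
// fallback. A minimal usage sketch (hypothetical caller, not part of this
// header):
//
//   auto path = select_optimized_path(a, b, out);
//   if (path == ElementwiseOptimizedPath::kTreatAs1d) {
//     // vectorized flat loop over a.numel() elements
//   } else if (path == ElementwiseOptimizedPath::kNone) {
//     // generic broadcast fallback
//   } else {
//     // one of the specialized broadcast loops
//   }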
inline ElementwiseOptimizedPath select_optimized_path(
    const Tensor& a,
    const Tensor& b,
    const Tensor& out) {
  ScalarType a_type = a.scalar_type();
  ScalarType b_type = b.scalar_type();
  ScalarType out_type = out.scalar_type();

  if (a_type != b_type || a_type != out_type || a_type == ScalarType::Half ||
      a_type == ScalarType::BFloat16) {
    return ElementwiseOptimizedPath::kNone;
  }
  if (a.sizes().equals(b.sizes()) ||
      (a.numel() == b.numel() &&
       (a.numel() == out.numel() ||
        internal::sizes_match_ignoring_leading_1s(a.sizes(), b.sizes())))) {
    return ElementwiseOptimizedPath::kTreatAs1d;
  }
  return internal::select_broadcast_optimized_path(a, b);
}

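// Collapses a tensor's sizes around broadcast_dim into a 3-element
// {outer, broadcast, inner} shape, where outer is the product of all dims
// before broadcast_dim and inner is the product of all dims after it.
// Illustrative example: sizes [2, 3, 1, 5] with broadcast_dim = 2 normalize
// to {6, 1, 5}, and the matching non-broadcast tensor [2, 3, 4, 5]
// normalizes to {6, 4, 5}.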
inline std::array<int32_t, 3> get_normalized_tensor_size(
    const Tensor& a,
    const int32_t broadcast_dim) {
  ET_CHECK_MSG(
      a.dim() > broadcast_dim,
      "Tensor rank: %zd must be larger than broadcast_dim: %d",
      a.dim(),
      broadcast_dim);
  std::array<int32_t, 3> normalized_tensor_size;
  normalized_tensor_size[0] = 1;
  normalized_tensor_size[1] = a.size(broadcast_dim);
  normalized_tensor_size[2] = 1;
  for (size_t i = 0; i < broadcast_dim; i++) {
    normalized_tensor_size[0] *= a.size(i);
  }
  for (size_t i = broadcast_dim + 1; i < a.dim(); i++) {
    normalized_tensor_size[2] *= a.size(i);
  }
  return normalized_tensor_size;
}

} // namespace executor
} // namespace torch
195