xref: /aosp_15_r20/system/media/audio/include/system/elementwise_op.h (revision b9df5ad1c9ac98a7fefaac271a55f7ae3db05414)
1 /*
2  * Copyright 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
19 #ifdef __cplusplus
20 
21 #include <algorithm>
22 #include <optional>
23 #include <type_traits>
24 #include <vector>
25 
26 #include <audio_utils/template_utils.h>
27 #include <android/binder_enums.h>
28 
29 namespace android::audio_utils {
30 
31 using android::audio_utils::has_tag_and_get_tag_v;
32 using android::audio_utils::is_specialization_v;
33 using android::audio_utils::op_aggregate;
34 
35 /**
36  * Type of elements needs custom comparison for the elementwise ops.
37  * When `CustomOpElementTypes` evaluated to true, custom comparison is implemented in this header.
38  * When `CustomOpElementTypes` evaluated to false, fallback to std implementation.
39  */
40 template <typename T>
41 concept CustomOpElementTypes =
42     (std::is_class_v<T> && std::is_aggregate_v<T>) ||
43     is_specialization_v<T, std::vector> || has_tag_and_get_tag_v<T>;
44 
45 /**
46  * Find the underlying value of AIDL union objects, and run an `op` with the underlying values
47  */
48 template <typename Op, typename T, std::size_t... Is, typename... Ts>
aidl_union_op_helper(Op && op,std::index_sequence<Is...>,const T & first,const Ts &...rest)49 void aidl_union_op_helper(Op&& op, std::index_sequence<Is...>, const T& first, const Ts&... rest) {
50   (([&]() -> bool {
51      const typename T::Tag TAG = static_cast<typename T::Tag>(Is);
52      if (((first.getTag() == TAG) && ... && (rest.getTag() == TAG))) {
53        // handle the case of a sub union class inside another union
54        using FieldType = decltype(first.template get<TAG>());
55        if constexpr (has_tag_and_get_tag_v<FieldType>) {
56          static constexpr std::size_t tagSize = std::ranges::distance(
57              ndk::enum_range<typename FieldType::Tag>().begin(),
58              ndk::enum_range<typename FieldType::Tag>().end());
59          return aidl_union_op_helper(op, std::make_index_sequence<tagSize>{},
60                                      first.template get<TAG>(),
61                                      rest.template get<TAG>()...);
62        } else {
63          op.template operator()<TAG>(first.template get<TAG>(), rest.template get<TAG>()...);
64          // exit the index sequence
65          return true;
66        }
67      } else {
68        return false;
69      }
70    }()) ||
71    ...);
72 }
73 
74 // check if the class `T` is an AIDL union with `has_tag_and_get_tag_v`
75 template <typename Op, typename T, typename... Ts>
requires(has_tag_and_get_tag_v<T> &&...&& has_tag_and_get_tag_v<Ts>)76   requires(has_tag_and_get_tag_v<T> && ... && has_tag_and_get_tag_v<Ts>)
77 void aidl_union_op(Op&& op, const T& first, const Ts&... rest) {
78   static constexpr std::size_t tagSize =
79       std::ranges::distance(ndk::enum_range<typename T::Tag>().begin(),
80                             ndk::enum_range<typename T::Tag>().end());
81   aidl_union_op_helper(op, std::make_index_sequence<tagSize>{}, first, rest...);
82 }
83 
84 /**
85  * Utility functions for clamping values of different types within a specified
86  * range of [min, max]. Supported types are evaluated with
87  * `CustomOpElementTypes`.
88  *
89  * - For **structures**, each member is clamped individually and reassembled
90  *   after clamping.
91  * - For **vectors**, the `min` and `max` ranges (if defined) may have either
92  *   one element or match the size of the target vector. If `min`/`max` have
93  *   only one element, each target vector element is clamped within that range.
94  *   If `min`/`max` match the target's size, each target element is clamped
95  *   within the corresponding `min`/`max` elements.
96  * - For **AIDL union** class, `aidl_union_op` is used to find the underlying
97  *   value automatically first, and then do `elementwise_clamp` on the
98  *   underlying value.
 * - For all other types, `std::clamp` is used directly; `std::string` values
 *   are compared and clamped lexicographically.
101  *
102  * The maximum number of members supported in a structure is `kMaxStructMember`
103  * as defined in the template_utils.h header.
104  */
105 
/**
 * @brief Clamp function for aggregate types (structs).
 *
 * Forward declaration only; the definition appears further down in this
 * header. Declaring it here makes the overload visible to the templates
 * defined in between (such as the `std::vector` overload, which clamps its
 * elements through `elementwise_clamp`).
 */
template <typename T>
  requires std::is_class_v<T> && std::is_aggregate_v<T>
[[nodiscard]]
std::optional<T> elementwise_clamp(const T& target, const T& min, const T& max);

/**
 * @brief Clamp function for AIDL union types (types satisfying
 * `has_tag_and_get_tag_v`).
 *
 * Forward declaration only; the definition appears further down in this
 * header.
 */
template <typename T>
  requires has_tag_and_get_tag_v<T>
[[nodiscard]]
std::optional<T> elementwise_clamp(const T& target, const T& min, const T& max);
118 
119 /**
120  * @brief Clamp function for all other types, `std::clamp` is used.
121  */
122 template <typename T>
123   requires(!CustomOpElementTypes<T>)
124 [[nodiscard]]
elementwise_clamp(const T & target,const T & min,const T & max)125 std::optional<T> elementwise_clamp(const T& target, const T& min, const T& max) {
126   if (min > max) {
127     return std::nullopt;
128   }
129   return std::clamp(target, min, max);
130 }
131 
132 /**
133  * @brief Clamp function for vectors.
134  *
135  * Clamping each vector element within a specified range. The `min` and `max`
136  * vectors may have either one element or the same number of elements as the
137  * `target` vector.
138  *
139  * - If `min` or `max` contain only one element, each element in `target` is
140  *   clamped by this single value.
141  * - If `min` or `max` match `target` in size, each element in `target` is
142  *   clamped by the corresponding elements in `min` and `max`.
143  * - If size of `min` or `max` vector is neither 1 nor same size as `target`,
144  *   the range will be considered as invalid, and `std::nullopt` will be
145  *   returned.
146  *
147  * Some examples:
148  * std::vector<int> target({3, 0, 5, 2});
149  * std::vector<int> min({1});
150  * std::vector<int> max({3});
151  * elementwise_clamp(target, min, max) result will be std::vector({3, 1, 3, 2})
152  *
153  * std::vector<int> target({3, 0, 5, 2});
154  * std::vector<int> min({1, 2, 3, 4});
155  * std::vector<int> max({3, 4, 5, 6});
156  * elementwise_clamp(target, min, max) result will be std::vector({3, 2, 5, 4})
157  *
158  * std::vector<int> target({3, 0, 5, 2});
159  * std::vector<int> min({});
160  * std::vector<int> max({3, 4});
161  * elementwise_clamp(target, min, max) result will be std::nullopt
162  */
163 template <typename T>
164   requires is_specialization_v<T, std::vector>
165 [[nodiscard]]
elementwise_clamp(const T & target,const T & min,const T & max)166 std::optional<T> elementwise_clamp(const T& target, const T& min, const T& max) {
167   using ElemType = typename T::value_type;
168 
169   const size_t min_size = min.size(), max_size = max.size(),
170                target_size = target.size();
171   if (min_size == 0 || max_size == 0 || target_size == 0) {
172     return std::nullopt;
173   }
174 
175   T result;
176   result.reserve(target_size);
177 
178   if (min_size == 1 && max_size == 1) {
179     const ElemType clamp_min = min[0], clamp_max = max[0];
180     for (size_t i = 0; i < target_size; ++i) {
181       auto clamped_elem = elementwise_clamp(target[i], clamp_min, clamp_max);
182       if (clamped_elem) {
183         result.emplace_back(*clamped_elem);
184       } else {
185         return std::nullopt;
186       }
187     }
188   } else if (min_size == target_size && max_size == target_size) {
189     for (size_t i = 0; i < target_size; ++i) {
190       auto clamped_elem = elementwise_clamp(target[i], min[i], max[i]);
191       if (clamped_elem) {
192         result.emplace_back(*clamped_elem);
193       } else {
194         return std::nullopt;
195       }
196     }
197   } else if (min_size == 1 && max_size == target_size) {
198     const ElemType clamp_min = min[0];
199     for (size_t i = 0; i < target_size; ++i) {
200       auto clamped_elem = elementwise_clamp(target[i], clamp_min, max[i]);
201       if (clamped_elem) {
202         result.emplace_back(*clamped_elem);
203       } else {
204         return std::nullopt;
205       }
206     }
207   } else if (min_size == target_size && max_size == 1) {
208     const ElemType clamp_max = max[0];
209     for (size_t i = 0; i < target_size; ++i) {
210       auto clamped_elem = elementwise_clamp(target[i], min[i], clamp_max);
211       if (clamped_elem) {
212         result.emplace_back(*clamped_elem);
213       } else {
214         return std::nullopt;
215       }
216     }
217   } else {
218     // incompatible size
219     return std::nullopt;
220   }
221 
222   return result;
223 }
224 
225 /**
226  * @brief Clamp function for class and aggregate type (structs).
227  *
228  * Uses `opAggregate` with elementwise_clamp_op to perform clamping on each member of the
229  * aggregate type `T`, the max number of supported members in `T` is
230  * `kMaxStructMember`.
231  */
232 template <typename T>
233   requires std::is_class_v<T> && std::is_aggregate_v<T>
234 [[nodiscard]]
elementwise_clamp(const T & target,const T & min,const T & max)235 std::optional<T> elementwise_clamp(const T& target, const T& min, const T& max) {
236   const auto elementwise_clamp_op = [](const auto& a, const auto& b, const auto& c) {
237     return elementwise_clamp(a, b, c);
238   };
239   return op_aggregate(elementwise_clamp_op, target, min, max);
240 }
241 
242 template <typename T>
243   requires has_tag_and_get_tag_v<T>
244 [[nodiscard]]
elementwise_clamp(const T & target,const T & min,const T & max)245 std::optional<T> elementwise_clamp(const T& target, const T& min, const T& max) {
246   std::optional<T> ret = std::nullopt;
247 
248   auto elementwise_clamp_op = [&]<typename T::Tag TAG, typename P>(
249                           const P& p1, const P& p2, const P& p3) {
250     auto p = elementwise_clamp(p1, p2, p3);
251     if (!p) return;
252     ret = T::template make<TAG>(*p);
253   };
254 
255   aidl_union_op(elementwise_clamp_op, target, min, max);
256   return ret;
257 }
258 
259 /**
260  * Utility functions to determine the element-wise min/max of two values with
261  * same type. The `elementwise_min` function accepts two inputs and return
262  * the element-wise min of them, while the `elementwise_max` function
263  * calculates the element-wise max.
264  *
265  * - For **vectors**, the two input vectors may have either `0`, `1`, or `n`
266  *   elements. If both input vectors have more than one element, their sizes
267  *   must match. If either input vector has only one element, it is compared
268  *   with each element of the other input vector.
269  * - For **structures (aggregate types)**, each element field is compared
270  *   individually, and the final result is reassembled from the element field
271  *   comparison result.
272  * - For **AIDL union** class, `aidl_union_op` is used to find the underlying
273  *   value automatically first, and then do elementwise min/max on the
274  *   underlying value.
 * - For all other types, `std::min`/`std::max` is used directly; `std::string`
 *   values are compared lexicographically.
277  *
278  * The maximum number of element fields supported in a structure is defined by
279  * `android::audio_utils::kMaxStructMember` as defined in the `template_utils.h`
280  * header.
281  */
282 
// Forward declarations of the aggregate and AIDL union overloads of
// `elementwise_min` / `elementwise_max`. The definitions appear further down
// in this header; declaring them here makes them visible to the overloads
// defined in between (such as the `std::vector` ones, which process their
// elements through `elementwise_min` / `elementwise_max`).

// Element-wise min for aggregate class types (structs).
template <typename T>
  requires std::is_class_v<T> && std::is_aggregate_v<T>
[[nodiscard]]
std::optional<T> elementwise_min(const T& a, const T& b);

// Element-wise min for AIDL union types (types satisfying `has_tag_and_get_tag_v`).
template <typename T>
  requires has_tag_and_get_tag_v<T>
[[nodiscard]]
std::optional<T> elementwise_min(const T& a, const T& b);

// Element-wise max for aggregate class types (structs).
template <typename T>
  requires std::is_class_v<T> && std::is_aggregate_v<T>
[[nodiscard]]
std::optional<T> elementwise_max(const T& a, const T& b);

// Element-wise max for AIDL union types (types satisfying `has_tag_and_get_tag_v`).
template <typename T>
  requires has_tag_and_get_tag_v<T>
[[nodiscard]]
std::optional<T> elementwise_max(const T& a, const T& b);
302 
303 /**
304  * @brief Determines the min/max for all other type values.
305  *
306  * @tparam T The target type.
307  * @param a The first value.
308  * @param b The second value.
309  * @return The min/max of the two inputs.
310  *
311  * Example:
312  * int a = 3;
313  * int b = 5;
314  * auto result = elementwise_min(a, b);  // result will be 3
315  * auto result = elementwise_max(a, b);  // result will be 5
316  */
317 template <typename T>
318   requires(!CustomOpElementTypes<T>)
319 [[nodiscard]]
elementwise_min(const T & a,const T & b)320 std::optional<T> elementwise_min(const T& a, const T& b) {
321   return std::min(a, b);
322 }
323 
324 template <typename T>
325   requires(!CustomOpElementTypes<T>)
326 [[nodiscard]]
elementwise_max(const T & a,const T & b)327 std::optional<T> elementwise_max(const T& a, const T& b) {
328   return std::max(a, b);
329 }
330 
331 /**
332  * @brief Determines the element-wise min/max of two vectors by comparing
333  * each corresponding element.
334  *
335  * This function calculates the element-wise min/max of two input vectors. The
336  * valid sizes for input vectors `a` and `b` can be 0, 1, or `n` (where `n >
337  * 1`). If both `a` and `b` contain more than one element, their sizes must be
338  * equal. If either vector has only one element, that value will be compared
339  * with each element of the other vector.
340  *
341  * Some examples:
342  * std::vector<int> a({1, 2, 3, 4});
343  * std::vector<int> b({3, 4, 5, 0});
344  * elementwise_min(a, b) result will be std::vector({1, 2, 3, 0})
345  * elementwise_max(a, b) result will be std::vector({3, 4, 5, 4})
346  *
347  * std::vector<int> a({1});
348  * std::vector<int> b({3, 4, 5, 0});
349  * elementwise_min(a, b) result will be std::vector({1, 1, 1, 0})
350  * elementwise_max(a, b) result will be std::vector({3, 4, 5, 1})
351  *
352  * std::vector<int> a({1, 2, 3});
353  * std::vector<int> b({});
354  * elementwise_min(a, b) result will be std::vector({})
355  * elementwise_max(a, b) result will be std::vector({1, 2, 3})
356  *
357  * std::vector<int> a({1, 2, 3, 4});
358  * std::vector<int> b({3, 4, 0});
359  * elementwise_min(a, b) and elementwise_max(a, b) result will be std::nullopt
360  *
361  * @tparam T The vector type.
362  * @param a The first vector.
363  * @param b The second vector.
364  * @return A vector representing the element-wise min/max, or `std::nullopt` if
365  * sizes are incompatible.
366  */
367 template <typename T>
368   requires is_specialization_v<T, std::vector>
369 [[nodiscard]]
elementwise_min(const T & a,const T & b)370 std::optional<T> elementwise_min(const T& a, const T& b) {
371   T result;
372   const size_t a_size = a.size(), b_size = b.size();
373   if (a_size == 0 || b_size == 0) {
374     return result;
375   }
376 
377   if (a_size == b_size) {
378     for (size_t i = 0; i < a_size; ++i) {
379       auto lower_elem = elementwise_min(a[i], b[i]);
380       if (lower_elem) {
381         result.emplace_back(*lower_elem);
382       }
383     }
384   } else if (a_size == 1) {
385     for (size_t i = 0; i < b_size; ++i) {
386       auto lower_elem = elementwise_min(a[0], b[i]);
387       if (lower_elem) {
388         result.emplace_back(*lower_elem);
389       }
390     }
391   } else if (b_size == 1) {
392     for (size_t i = 0; i < a_size; ++i) {
393       auto lower_elem = elementwise_min(a[i], b[0]);
394       if (lower_elem) {
395         result.emplace_back(*lower_elem);
396       }
397     }
398   } else {
399     // incompatible size
400     return std::nullopt;
401   }
402 
403   return result;
404 }
405 
406 template <typename T>
407   requires is_specialization_v<T, std::vector>
408 [[nodiscard]]
elementwise_max(const T & a,const T & b)409 std::optional<T> elementwise_max(const T& a, const T& b) {
410   T result;
411   const size_t a_size = a.size(), b_size = b.size();
412   if (a_size == 0) {
413     result = b;
414   } else if (b_size == 0) {
415     result = a;
416   } else if (a_size == b_size) {
417     for (size_t i = 0; i < a_size; ++i) {
418       auto upper_elem = elementwise_max(a[i], b[i]);
419       if (upper_elem) result.emplace_back(*upper_elem);
420     }
421   } else if (a_size == 1) {
422     for (size_t i = 0; i < b_size; ++i) {
423       auto upper_elem = elementwise_max(a[0], b[i]);
424       if (upper_elem) result.emplace_back(*upper_elem);
425     }
426   } else if (b_size == 1) {
427     for (size_t i = 0; i < a_size; ++i) {
428       auto upper_elem = elementwise_max(a[i], b[0]);
429       if (upper_elem) result.emplace_back(*upper_elem);
430     }
431   } else {
432     // incompatible size
433     return std::nullopt;
434   }
435 
436   return result;
437 }
438 
439 /**
440  * @brief Determines the element-wise min/max of two aggregate type values
441  * by comparing each corresponding element.
442  *
443  * @tparam T The type of the aggregate values.
444  * @param a The first aggregate.
445  * @param b The second aggregate.
446  * @return A new aggregate representing the element-wise min/max of the two
447  * inputs, or `std::nullopt` if the element-wise comparison fails.
448  *
449  * Example:
450  * struct Point {
451  *   int x;
452  *   int y;
453  * };
454  * Point p1{3, 5};
455  * Point p2{4, 2};
456  * auto result = elementwise_min(p1, p2);  // result will be Point{3, 2}
457  * auto result = elementwise_max(p1, p2);  // result will be Point{4, 5}
458  */
459 template <typename T>
460   requires std::is_class_v<T> && std::is_aggregate_v<T>
461 [[nodiscard]]
elementwise_min(const T & a,const T & b)462 std::optional<T> elementwise_min(const T& a, const T& b) {
463   const auto elementwise_min_op = [](const auto& a_member, const auto& b_member) {
464     return elementwise_min(a_member, b_member);
465   };
466   return op_aggregate(elementwise_min_op, a, b);
467 }
468 
469 template <typename T>
470   requires std::is_class_v<T> && std::is_aggregate_v<T>
471 [[nodiscard]]
elementwise_max(const T & a,const T & b)472 std::optional<T> elementwise_max(const T& a, const T& b) {
473   const auto elementwise_max_op = [](const auto& a_member, const auto& b_member) {
474     return elementwise_max(a_member, b_member);
475   };
476   return op_aggregate(elementwise_max_op, a, b);
477 }
478 
479 template <typename T>
480   requires has_tag_and_get_tag_v<T>
481 [[nodiscard]]
elementwise_min(const T & a,const T & b)482 std::optional<T> elementwise_min(const T& a, const T& b) {
483   std::optional<T> ret = std::nullopt;
484   auto elementwise_min_op = [&]<typename T::Tag TAG, typename P>(const P& p1, const P& p2) {
485     auto p = elementwise_min(p1, p2);
486     if (!p) return;
487     ret = T::template make<TAG>(*p);
488   };
489   aidl_union_op(elementwise_min_op, a, b);
490   return ret;
491 }
492 
493 // handle the case of a sub union class inside another union
494 template <typename T>
495   requires has_tag_and_get_tag_v<T>
496 [[nodiscard]]
elementwise_max(const T & a,const T & b)497 std::optional<T> elementwise_max(const T& a, const T& b) {
498   std::optional<T> ret = std::nullopt;
499   auto elementwise_max_op = [&]<typename T::Tag TAG, typename P>(const P& p1, const P& p2) {
500     auto p = elementwise_max(p1, p2);
501     if (!p) return;
502     ret = T::template make<TAG>(*p);
503   };
504   aidl_union_op(elementwise_max_op, a, b);
505   return ret;
506 }
507 
508 }  // namespace android::audio_utils
509 
510 #endif  // __cplusplus