xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/functional.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <type_traits>
#include <vector>
#include <c10/util/ArrayRef.h>
5 
6 namespace c10 {
7 
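// The passed-in function must take T by value (T) or by const reference
// (const T&). fmap iterates `inputs` through const references, so a function
// taking T by non-const reference fails to compile, with an error like:
//
//    error: no type named 'type' in 'class std::invoke_result<foobar::__lambda, T>'
//
// (the exact diagnostic depends on the compiler). The function overload of
// fmap needs no explicit template parameters; the constructor overload needs
// the result type R spelled out, e.g. fmap<R>(inputs).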
15 
// Overload taking an explicit function: applies `fn` to every element of
// `inputs` and collects the results, in iteration order, into a new
// std::vector. `inputs` may be any container exposing member begin()/end()
// and size() (e.g. ArrayRef or std::vector).
//
// The result's element type is the *decayed* return type of `fn`. Without the
// decay, a callable returning a reference (e.g. `-> const T&`) would name the
// ill-formed type std::vector<const T&> and this overload would silently drop
// out of overload resolution. With it, such callables yield a vector of copies.
// For callables returning by value, the type is unchanged.
template<class F, class T>
inline auto fmap(const T& inputs, const F& fn)
    -> std::vector<std::decay_t<decltype(fn(*inputs.begin()))>> {
  std::vector<std::decay_t<decltype(fn(*inputs.begin()))>> r;
  // One allocation up front: exactly one output per input.
  r.reserve(inputs.size());
  for (const auto& input : inputs) {
    r.push_back(fn(input));
  }
  return r;
}
25 
// Overload for constructor (R) application. C++ does not allow taking the
// address of a constructor, so this overload builds each output element
// directly as R(element) instead of taking a function.
//
// NOTE(review): with a single argument, R(element) is a functional-style cast,
// which has the same semantics as a C-style cast (it may perform
// static_cast/const_cast/reinterpret_cast conversions). Callers should make
// sure the conversion they rely on is the intended one.
template <typename R, typename T>
inline std::vector<R> fmap(const T& inputs) {
  std::vector<R> converted;
  converted.reserve(inputs.size());
  for (auto&& element : inputs) {
    converted.push_back(R(element));
  }
  return converted;
}
36 
37 template<typename F, typename T>
filter(at::ArrayRef<T> inputs,const F & fn)38 inline std::vector<T> filter(at::ArrayRef<T> inputs, const F& fn) {
39   std::vector<T> r;
40   r.reserve(inputs.size());
41   for(auto & input : inputs) {
42     if (fn(input)) {
43       r.push_back(input);
44     }
45   }
46   return r;
47 }
48 
49 template<typename F, typename T>
filter(const std::vector<T> & inputs,const F & fn)50 inline std::vector<T> filter(const std::vector<T>& inputs, const F& fn) {
51   return filter<F, T>(static_cast<at::ArrayRef<T>>(inputs), fn);
52 }
53 
54 } // namespace c10
55