1 #pragma once
2
#include <type_traits>
#include <vector>

#include <c10/util/ArrayRef.h>
5
6 namespace c10 {
7
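// The function passed to fmap must take its argument by value (T) or by
// const reference (const T&). fmap iterates its input through a const
// reference, so for standard containers and ArrayRef the elements are const
// and a function taking T by non-const reference (T&) cannot be called on
// them. Return-type deduction then fails, typically reported as:
//
// error: no matching function for call to 'fmap'
//
// All template parameters are deduced from the arguments, except for the
// constructor overload fmap<R>(inputs), where R must be given explicitly.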
15
// Applies `fn` to every element of `inputs` and collects the results, in
// iteration order, into a std::vector.
//
// `inputs` may be any range exposing begin()/end() and size() (for example
// std::vector or ArrayRef); `fn` is invoked once per element with a const
// element reference.
//
// The element type of the result is the *decayed* return type of `fn`. A
// callable that returns a reference (e.g. a field accessor returning
// `const U&`) therefore yields std::vector<U> holding copies. A plain
// decltype would name the ill-formed std::vector<const U&> instead.
//
// Returns a vector with one entry per element of `inputs`.
template<class F, class T>
inline auto fmap(const T& inputs, const F& fn)
    -> std::vector<std::decay_t<decltype(fn(*inputs.begin()))>> {
  using Result = std::decay_t<decltype(fn(*inputs.begin()))>;
  std::vector<Result> r;
  // The output has exactly one entry per input, so allocate once up front.
  r.reserve(inputs.size());
  for (const auto& input : inputs) {
    r.push_back(fn(input));
  }
  return r;
}
25
// C++ forbids taking the address of a constructor, so this overload lets a
// caller "map a constructor" over a range: fmap<R>(inputs) builds one R from
// each element of `inputs`, in iteration order.
//
// Template parameters:
//   R - element type of the result; must be static_cast-able from an element.
//   T - any range exposing begin()/end() and size() (e.g. std::vector).
// Returns a std::vector<R> with one entry per element of `inputs`.
//
// The conversion is spelled static_cast<R>(input) on purpose. The
// single-argument functional cast R(input) is a C-style cast, which can
// silently fall back to const_cast or reinterpret_cast; for example, mapping
// `const X*` to `X*` would then compile and drop const. static_cast still
// allows explicit constructors, explicit conversion operators, integer->enum
// and other static conversions, but rejects those unsafe reinterpretations.
template<typename R, typename T>
inline std::vector<R> fmap(const T& inputs) {
  std::vector<R> r;
  r.reserve(inputs.size());
  for (auto& input : inputs) {
    r.push_back(static_cast<R>(input));
  }
  return r;
}
36
37 template<typename F, typename T>
filter(at::ArrayRef<T> inputs,const F & fn)38 inline std::vector<T> filter(at::ArrayRef<T> inputs, const F& fn) {
39 std::vector<T> r;
40 r.reserve(inputs.size());
41 for(auto & input : inputs) {
42 if (fn(input)) {
43 r.push_back(input);
44 }
45 }
46 return r;
47 }
48
49 template<typename F, typename T>
filter(const std::vector<T> & inputs,const F & fn)50 inline std::vector<T> filter(const std::vector<T>& inputs, const F& fn) {
51 return filter<F, T>(static_cast<at::ArrayRef<T>>(inputs), fn);
52 }
53
54 } // namespace c10
55