xref: /aosp_15_r20/external/pytorch/c10/util/overloaded.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <memory>
#include <utility>
namespace c10 {
namespace detail {

// Callable that merges the call operators of every type in Ts into a single
// overload set by inheriting from each of them.
//
// The primary template is only selected for an empty pack; it has no bases
// and therefore no call operator.
template <class... Ts>
struct overloaded_t {};

// End of the recursion: exactly one callable. Inherits from it and re-exports
// its operator().
template <class T0>
struct overloaded_t<T0> : T0 {
  using T0::operator();
  // Intentionally not `explicit`: overloaded() below builds its result with a
  // braced return statement (copy-list-initialization), which cannot select
  // an explicit constructor.
  constexpr overloaded_t(T0 t0) : T0(std::move(t0)) {}
};

// Recursive step: peel off the first callable and inherit from an
// overloaded_t of the remaining ones. Both using-declarations are required so
// that the call operators from the two bases take part in the same overload
// resolution instead of making a call ambiguous between base classes.
template <class T0, class... Ts>
struct overloaded_t<T0, Ts...> : T0, overloaded_t<Ts...> {
  using T0::operator();
  using overloaded_t<Ts...>::operator();
  // Takes every callable by value and moves each one into its base
  // subobject; the tail is handed on to the next level of the recursion.
  constexpr overloaded_t(T0 t0, Ts... ts)
      : T0(std::move(t0)), overloaded_t<Ts...>(std::move(ts)...) {}
};

} // namespace detail

// Construct an overloaded callable combining multiple callables, e.g. lambdas.
//
// Every argument must be of class type (a lambda or another function object),
// because the result inherits from each argument type. The arguments are
// taken by value and moved into the returned object. Calling the result picks
// among the combined call operators by ordinary overload resolution.
//
// The function and the constructors it uses are constexpr, so the result can
// be built in a constant expression when the callables themselves allow it.
template <class... Ts>
constexpr detail::overloaded_t<Ts...> overloaded(Ts... ts) {
  return {std::move(ts)...};
}

} // namespace c10
32