1 // Copyright 2021 gRPC authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef GRPC_SRC_CORE_LIB_PROMISE_DETAIL_BASIC_JOIN_H
16 #define GRPC_SRC_CORE_LIB_PROMISE_DETAIL_BASIC_JOIN_H
17 
18 #include <grpc/support/port_platform.h>
19 
20 #include <assert.h>
21 #include <stddef.h>
22 
23 #include <array>
24 #include <tuple>
25 #include <type_traits>
26 #include <utility>
27 
28 #include "absl/utility/utility.h"
29 
30 #include "src/core/lib/gprpp/bitset.h"
31 #include "src/core/lib/gprpp/construct_destruct.h"
32 #include "src/core/lib/promise/detail/promise_like.h"
33 #include "src/core/lib/promise/poll.h"
34 
35 namespace grpc_core {
36 namespace promise_detail {
37 
38 // This union can either be a functor, or the result of the functor (after
39 // mapping via a trait). Allows us to remember the result of one joined functor
40 // until the rest are ready.
41 template <typename Traits, typename F>
42 union Fused {
Fused(F && f)43   explicit Fused(F&& f) : f(std::forward<F>(f)) {}
Fused(PromiseLike<F> && f)44   explicit Fused(PromiseLike<F>&& f) : f(std::forward<PromiseLike<F>>(f)) {}
~Fused()45   ~Fused() {}
46   // Wrap the functor in a PromiseLike to handle immediately returning functors
47   // and the like.
48   using Promise = PromiseLike<F>;
49   GPR_NO_UNIQUE_ADDRESS Promise f;
50   // Compute the result type: We take the result of the promise, and pass it via
51   // our traits, so that, for example, TryJoin and take a StatusOr<T> and just
52   // store a T.
53   using Result = typename Traits::template ResultType<typename Promise::Result>;
54   GPR_NO_UNIQUE_ADDRESS Result result;
55 };
56 
57 // A join gets composed of joints... these are just wrappers around a Fused for
58 // their data, with some machinery as methods to get the system working.
59 template <typename Traits, size_t kRemaining, typename... Fs>
60 struct Joint : public Joint<Traits, kRemaining - 1, Fs...> {
61   // The index into Fs for this Joint
62   static constexpr size_t kIdx = sizeof...(Fs) - kRemaining;
63   // The next join (the one we derive from)
64   using NextJoint = Joint<Traits, kRemaining - 1, Fs...>;
65   // From Fs, extract the functor for this joint.
66   using F = typename std::tuple_element<kIdx, std::tuple<Fs...>>::type;
67   // Generate the Fused type for this functor.
68   using Fsd = Fused<Traits, F>;
69   GPR_NO_UNIQUE_ADDRESS Fsd fused;
70   // Figure out what kind of bitmask will be used by the outer join.
71   using Bits = BitSet<sizeof...(Fs)>;
72   // Initialize from a tuple of pointers to Fs
JointJoint73   explicit Joint(std::tuple<Fs*...> fs)
74       : NextJoint(fs), fused(std::move(*std::get<kIdx>(fs))) {}
75   // Copy: assume that the Fuse is still in the promise state (since it's not
76   // legal to copy after the first poll!)
JointJoint77   Joint(const Joint& j) : NextJoint(j), fused(j.fused.f) {}
78   // Move: assume that the Fuse is still in the promise state (since it's not
79   // legal to move after the first poll!)
JointJoint80   Joint(Joint&& j) noexcept
81       : NextJoint(std::forward<NextJoint>(j)), fused(std::move(j.fused.f)) {}
82   // Destruct: check bits to see if we're in promise or result state, and call
83   // the appropriate destructor. Recursively, call up through the join.
DestructAllJoint84   void DestructAll(const Bits& bits) {
85     if (!bits.is_set(kIdx)) {
86       Destruct(&fused.f);
87     } else {
88       Destruct(&fused.result);
89     }
90     NextJoint::DestructAll(bits);
91   }
92   // Poll all joints up, and then call finally.
93   template <typename F>
94   auto Run(Bits* bits, F finally) -> decltype(finally()) {
95     // If we're still in the promise state...
96     if (!bits->is_set(kIdx)) {
97       // Poll the promise
98       auto r = fused.f();
99       if (auto* p = r.value_if_ready()) {
100         // If it's done, then ask the trait to unwrap it and store that result
101         // in the Fused, and continue the iteration. Note that OnResult could
102         // instead choose to return a value instead of recursing through the
103         // iteration, in that case we continue returning the same result up.
104         // Here is where TryJoin can escape out.
105         return Traits::OnResult(
106             std::move(*p), [this, bits, &finally](typename Fsd::Result result) {
107               bits->set(kIdx);
108               Destruct(&fused.f);
109               Construct(&fused.result, std::move(result));
110               return NextJoint::Run(bits, std::move(finally));
111             });
112       }
113     }
114     // That joint is still pending... we'll still poll the result of the joints.
115     return NextJoint::Run(bits, std::move(finally));
116   }
117 };
118 
119 // Terminating joint... for each of the recursions, do the thing we're supposed
120 // to do at the end.
121 template <typename Traits, typename... Fs>
122 struct Joint<Traits, 0, Fs...> {
123   explicit Joint(std::tuple<Fs*...>) {}
124   Joint(const Joint&) {}
125   Joint(Joint&&) noexcept {}
126   template <typename T>
127   void DestructAll(const T&) {}
128   template <typename F>
129   auto Run(BitSet<sizeof...(Fs)>*, F finally) -> decltype(finally()) {
130     return finally();
131   }
132 };
133 
134 template <typename Traits, typename... Fs>
135 class BasicJoin {
136  private:
137   // How many things are we joining?
138   static constexpr size_t N = sizeof...(Fs);
139   // Bitset: if a bit is 0, that joint is still in promise state. If it's 1,
140   // then the joint has a result.
141   GPR_NO_UNIQUE_ADDRESS BitSet<N> state_;
142   // The actual joints, wrapped in an anonymous union to give us control of
143   // construction/destruction.
144   union {
145     GPR_NO_UNIQUE_ADDRESS Joint<Traits, sizeof...(Fs), Fs...> joints_;
146   };
147 
148   // Access joint index I
149   template <size_t I>
150   Joint<Traits, sizeof...(Fs) - I, Fs...>* GetJoint() {
151     return static_cast<Joint<Traits, sizeof...(Fs) - I, Fs...>*>(&joints_);
152   }
153 
154   // The tuple of results of all our promises
155   using Tuple = std::tuple<typename Fused<Traits, Fs>::Result...>;
156 
157   // Collect up all the results and construct a tuple.
158   template <size_t... I>
159   Tuple Finish(absl::index_sequence<I...>) {
160     return Tuple(std::move(GetJoint<I>()->fused.result)...);
161   }
162 
163  public:
164   explicit BasicJoin(Fs&&... fs) : joints_(std::tuple<Fs*...>(&fs...)) {}
165   BasicJoin& operator=(const BasicJoin&) = delete;
166   // Copy a join - only available before polling.
167   BasicJoin(const BasicJoin& other) {
168     assert(other.state_.none());
169     Construct(&joints_, other.joints_);
170   }
171   // Move a join - only available before polling.
172   BasicJoin(BasicJoin&& other) noexcept {
173     assert(other.state_.none());
174     Construct(&joints_, std::move(other.joints_));
175   }
176   ~BasicJoin() { joints_.DestructAll(state_); }
177   using Result = decltype(Traits::Wrap(std::declval<Tuple>()));
178   // Poll the join
179   Poll<Result> operator()() {
180     // Poll the joints...
181     return joints_.Run(&state_, [this]() -> Poll<Result> {
182       // If all of them are completed, collect the results, and then ask our
183       // traits to wrap them - allowing for example TryJoin to turn tuple<A,B,C>
184       // into StatusOr<tuple<A,B,C>>.
185       if (state_.all()) {
186         return Traits::Wrap(Finish(absl::make_index_sequence<N>()));
187       } else {
188         return Pending();
189       }
190     });
191   }
192 };
193 
194 }  // namespace promise_detail
195 }  // namespace grpc_core
196 
197 #endif  // GRPC_SRC_CORE_LIB_PROMISE_DETAIL_BASIC_JOIN_H
198