xref: /aosp_15_r20/external/grpc-grpc/src/core/load_balancing/ring_hash/ring_hash.cc (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
1 //
2 // Copyright 2018 gRPC authors.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 
17 #include <grpc/support/port_platform.h>
18 
19 #include "src/core/load_balancing/ring_hash/ring_hash.h"
20 
21 #include <inttypes.h>
22 #include <stdlib.h>
23 
24 #include <algorithm>
25 #include <cmath>
26 #include <map>
27 #include <memory>
28 #include <string>
29 #include <utility>
30 #include <vector>
31 
32 #include "absl/base/attributes.h"
33 #include "absl/container/inlined_vector.h"
34 #include "absl/status/status.h"
35 #include "absl/status/statusor.h"
36 #include "absl/strings/str_cat.h"
37 #include "absl/strings/string_view.h"
38 #include "absl/types/optional.h"
39 
40 #include <grpc/impl/channel_arg_names.h>
41 #include <grpc/impl/connectivity_state.h>
42 #include <grpc/support/json.h>
43 #include <grpc/support/log.h>
44 
45 #include "src/core/client_channel/client_channel_internal.h"
46 #include "src/core/load_balancing/pick_first/pick_first.h"
47 #include "src/core/lib/address_utils/sockaddr_utils.h"
48 #include "src/core/lib/channel/channel_args.h"
49 #include "src/core/lib/config/core_configuration.h"
50 #include "src/core/lib/debug/trace.h"
51 #include "src/core/lib/gprpp/crash.h"
52 #include "src/core/lib/gprpp/debug_location.h"
53 #include "src/core/lib/gprpp/orphanable.h"
54 #include "src/core/lib/gprpp/ref_counted.h"
55 #include "src/core/lib/gprpp/ref_counted_ptr.h"
56 #include "src/core/lib/gprpp/unique_type_name.h"
57 #include "src/core/lib/gprpp/work_serializer.h"
58 #include "src/core/lib/gprpp/xxhash_inline.h"
59 #include "src/core/lib/iomgr/closure.h"
60 #include "src/core/lib/iomgr/error.h"
61 #include "src/core/lib/iomgr/exec_ctx.h"
62 #include "src/core/lib/iomgr/pollset_set.h"
63 #include "src/core/lib/iomgr/resolved_address.h"
64 #include "src/core/lib/json/json.h"
65 #include "src/core/lib/transport/connectivity_state.h"
66 #include "src/core/load_balancing/delegating_helper.h"
67 #include "src/core/load_balancing/lb_policy.h"
68 #include "src/core/load_balancing/lb_policy_factory.h"
69 #include "src/core/load_balancing/lb_policy_registry.h"
70 #include "src/core/resolver/endpoint_addresses.h"
71 
72 namespace grpc_core {
73 
74 TraceFlag grpc_lb_ring_hash_trace(false, "ring_hash_lb");
75 
TypeName()76 UniqueTypeName RequestHashAttribute::TypeName() {
77   static UniqueTypeName::Factory kFactory("request_hash");
78   return kFactory.Create();
79 }
80 
// JSON loading and post-load validation for RingHashConfig.
82 
JsonLoader(const JsonArgs &)83 const JsonLoaderInterface* RingHashConfig::JsonLoader(const JsonArgs&) {
84   static const auto* loader =
85       JsonObjectLoader<RingHashConfig>()
86           .OptionalField("minRingSize", &RingHashConfig::min_ring_size)
87           .OptionalField("maxRingSize", &RingHashConfig::max_ring_size)
88           .Finish();
89   return loader;
90 }
91 
JsonPostLoad(const Json &,const JsonArgs &,ValidationErrors * errors)92 void RingHashConfig::JsonPostLoad(const Json&, const JsonArgs&,
93                                   ValidationErrors* errors) {
94   {
95     ValidationErrors::ScopedField field(errors, ".minRingSize");
96     if (!errors->FieldHasErrors() &&
97         (min_ring_size == 0 || min_ring_size > 8388608)) {
98       errors->AddError("must be in the range [1, 8388608]");
99     }
100   }
101   {
102     ValidationErrors::ScopedField field(errors, ".maxRingSize");
103     if (!errors->FieldHasErrors() &&
104         (max_ring_size == 0 || max_ring_size > 8388608)) {
105       errors->AddError("must be in the range [1, 8388608]");
106     }
107   }
108   if (min_ring_size > max_ring_size) {
109     errors->AddError("max_ring_size cannot be smaller than min_ring_size");
110   }
111 }
112 
113 namespace {
114 
115 constexpr absl::string_view kRingHash = "ring_hash_experimental";
116 
117 class RingHashLbConfig final : public LoadBalancingPolicy::Config {
118  public:
RingHashLbConfig(size_t min_ring_size,size_t max_ring_size)119   RingHashLbConfig(size_t min_ring_size, size_t max_ring_size)
120       : min_ring_size_(min_ring_size), max_ring_size_(max_ring_size) {}
name() const121   absl::string_view name() const override { return kRingHash; }
min_ring_size() const122   size_t min_ring_size() const { return min_ring_size_; }
max_ring_size() const123   size_t max_ring_size() const { return max_ring_size_; }
124 
125  private:
126   size_t min_ring_size_;
127   size_t max_ring_size_;
128 };
129 
130 //
131 // ring_hash LB policy
132 //
133 
134 constexpr size_t kRingSizeCapDefault = 4096;
135 
136 class RingHash final : public LoadBalancingPolicy {
137  public:
138   explicit RingHash(Args args);
139 
name() const140   absl::string_view name() const override { return kRingHash; }
141 
142   absl::Status UpdateLocked(UpdateArgs args) override;
143   void ResetBackoffLocked() override;
144 
145  private:
146   // A ring computed based on a config and address list.
147   class Ring final : public RefCounted<Ring> {
148    public:
149     struct RingEntry {
150       uint64_t hash;
151       size_t endpoint_index;  // Index into RingHash::endpoints_.
152     };
153 
154     Ring(RingHash* ring_hash, RingHashLbConfig* config);
155 
ring() const156     const std::vector<RingEntry>& ring() const { return ring_; }
157 
158    private:
159     std::vector<RingEntry> ring_;
160   };
161 
162   // State for a particular endpoint.  Delegates to a pick_first child policy.
163   class RingHashEndpoint final : public InternallyRefCounted<RingHashEndpoint> {
164    public:
165     // index is the index into RingHash::endpoints_ of this endpoint.
RingHashEndpoint(RefCountedPtr<RingHash> ring_hash,size_t index)166     RingHashEndpoint(RefCountedPtr<RingHash> ring_hash, size_t index)
167         : ring_hash_(std::move(ring_hash)), index_(index) {}
168 
169     void Orphan() override;
170 
index() const171     size_t index() const { return index_; }
172 
173     void UpdateLocked(size_t index);
174 
connectivity_state() const175     grpc_connectivity_state connectivity_state() const {
176       return connectivity_state_;
177     }
178 
179     // Returns info about the endpoint to be stored in the picker.
180     struct EndpointInfo {
181       RefCountedPtr<RingHashEndpoint> endpoint;
182       RefCountedPtr<SubchannelPicker> picker;
183       grpc_connectivity_state state;
184       absl::Status status;
185     };
GetInfoForPicker()186     EndpointInfo GetInfoForPicker() {
187       return {Ref(), picker_, connectivity_state_, status_};
188     }
189 
190     void ResetBackoffLocked();
191 
192     // If the child policy does not yet exist, creates it; otherwise,
193     // asks the child to exit IDLE.
194     void RequestConnectionLocked();
195 
196    private:
197     class Helper;
198 
199     void CreateChildPolicy();
200     void UpdateChildPolicyLocked();
201 
202     // Called when the child policy reports a connectivity state update.
203     void OnStateUpdate(grpc_connectivity_state new_state,
204                        const absl::Status& status,
205                        RefCountedPtr<SubchannelPicker> picker);
206 
207     // Ref to our parent.
208     RefCountedPtr<RingHash> ring_hash_;
209     size_t index_;  // Index into RingHash::endpoints_ of this endpoint.
210 
211     // The pick_first child policy.
212     OrphanablePtr<LoadBalancingPolicy> child_policy_;
213 
214     grpc_connectivity_state connectivity_state_ = GRPC_CHANNEL_IDLE;
215     absl::Status status_;
216     RefCountedPtr<SubchannelPicker> picker_;
217   };
218 
219   class Picker final : public SubchannelPicker {
220    public:
Picker(RefCountedPtr<RingHash> ring_hash)221     explicit Picker(RefCountedPtr<RingHash> ring_hash)
222         : ring_hash_(std::move(ring_hash)),
223           ring_(ring_hash_->ring_),
224           endpoints_(ring_hash_->endpoints_.size()) {
225       for (const auto& p : ring_hash_->endpoint_map_) {
226         endpoints_[p.second->index()] = p.second->GetInfoForPicker();
227       }
228     }
229 
230     PickResult Pick(PickArgs args) override;
231 
232    private:
233     // A fire-and-forget class that schedules endpoint connection attempts
234     // on the control plane WorkSerializer.
235     class EndpointConnectionAttempter final {
236      public:
EndpointConnectionAttempter(RefCountedPtr<RingHash> ring_hash,RefCountedPtr<RingHashEndpoint> endpoint)237       EndpointConnectionAttempter(RefCountedPtr<RingHash> ring_hash,
238                                   RefCountedPtr<RingHashEndpoint> endpoint)
239           : ring_hash_(std::move(ring_hash)), endpoint_(std::move(endpoint)) {
240         // Hop into ExecCtx, so that we're not holding the data plane mutex
241         // while we run control-plane code.
242         GRPC_CLOSURE_INIT(&closure_, RunInExecCtx, this, nullptr);
243         ExecCtx::Run(DEBUG_LOCATION, &closure_, absl::OkStatus());
244       }
245 
246      private:
RunInExecCtx(void * arg,grpc_error_handle)247       static void RunInExecCtx(void* arg, grpc_error_handle /*error*/) {
248         auto* self = static_cast<EndpointConnectionAttempter*>(arg);
249         self->ring_hash_->work_serializer()->Run(
250             [self]() {
251               if (!self->ring_hash_->shutdown_) {
252                 self->endpoint_->RequestConnectionLocked();
253               }
254               delete self;
255             },
256             DEBUG_LOCATION);
257       }
258 
259       RefCountedPtr<RingHash> ring_hash_;
260       RefCountedPtr<RingHashEndpoint> endpoint_;
261       grpc_closure closure_;
262     };
263 
264     RefCountedPtr<RingHash> ring_hash_;
265     RefCountedPtr<Ring> ring_;
266     std::vector<RingHashEndpoint::EndpointInfo> endpoints_;
267   };
268 
269   ~RingHash() override;
270 
271   void ShutdownLocked() override;
272 
273   // Updates the aggregate policy's connectivity state based on the
274   // endpoint list's state counters, creating a new picker.
275   // entered_transient_failure is true if the endpoint has just
276   // entered TRANSIENT_FAILURE state.
277   // If the call to this method is triggered by an endpoint entering
278   // TRANSIENT_FAILURE, then status is the status reported by the endpoint.
279   void UpdateAggregatedConnectivityStateLocked(bool entered_transient_failure,
280                                                absl::Status status);
281 
282   // Current endpoint list, channel args, and ring.
283   EndpointAddressesList endpoints_;
284   ChannelArgs args_;
285   RefCountedPtr<Ring> ring_;
286 
287   std::map<EndpointAddressSet, OrphanablePtr<RingHashEndpoint>> endpoint_map_;
288 
289   // TODO(roth): If we ever change the helper UpdateState() API to not
290   // need the status reported for TRANSIENT_FAILURE state (because
291   // it's not currently actually used for anything outside of the picker),
292   // then we will no longer need this data member.
293   absl::Status last_failure_;
294 
295   // indicating if we are shutting down.
296   bool shutdown_ = false;
297 };
298 
299 //
300 // RingHash::Picker
301 //
302 
Pick(PickArgs args)303 RingHash::PickResult RingHash::Picker::Pick(PickArgs args) {
304   auto* call_state = static_cast<ClientChannelLbCallState*>(args.call_state);
305   auto* hash_attribute = static_cast<RequestHashAttribute*>(
306       call_state->GetCallAttribute(RequestHashAttribute::TypeName()));
307   if (hash_attribute == nullptr) {
308     return PickResult::Fail(absl::InternalError("hash attribute not present"));
309   }
310   uint64_t request_hash = hash_attribute->request_hash();
311   const auto& ring = ring_->ring();
312   // Find the index in the ring to use for this RPC.
313   // Ported from https://github.com/RJ/ketama/blob/master/libketama/ketama.c
314   // (ketama_get_server) NOTE: The algorithm depends on using signed integers
315   // for lowp, highp, and index. Do not change them!
316   int64_t lowp = 0;
317   int64_t highp = ring.size();
318   int64_t index = 0;
319   while (true) {
320     index = (lowp + highp) / 2;
321     if (index == static_cast<int64_t>(ring.size())) {
322       index = 0;
323       break;
324     }
325     uint64_t midval = ring[index].hash;
326     uint64_t midval1 = index == 0 ? 0 : ring[index - 1].hash;
327     if (request_hash <= midval && request_hash > midval1) {
328       break;
329     }
330     if (midval < request_hash) {
331       lowp = index + 1;
332     } else {
333       highp = index - 1;
334     }
335     if (lowp > highp) {
336       index = 0;
337       break;
338     }
339   }
340   // Find the first endpoint we can use from the selected index.
341   for (size_t i = 0; i < ring.size(); ++i) {
342     const auto& entry = ring[(index + i) % ring.size()];
343     const auto& endpoint_info = endpoints_[entry.endpoint_index];
344     switch (endpoint_info.state) {
345       case GRPC_CHANNEL_READY:
346         return endpoint_info.picker->Pick(args);
347       case GRPC_CHANNEL_IDLE:
348         new EndpointConnectionAttempter(
349             ring_hash_.Ref(DEBUG_LOCATION, "EndpointConnectionAttempter"),
350             endpoint_info.endpoint);
351         ABSL_FALLTHROUGH_INTENDED;
352       case GRPC_CHANNEL_CONNECTING:
353         return PickResult::Queue();
354       default:
355         break;
356     }
357   }
358   return PickResult::Fail(absl::UnavailableError(absl::StrCat(
359       "ring hash cannot find a connected endpoint; first failure: ",
360       endpoints_[ring[index].endpoint_index].status.message())));
361 }
362 
363 //
364 // RingHash::Ring
365 //
366 
Ring(RingHash * ring_hash,RingHashLbConfig * config)367 RingHash::Ring::Ring(RingHash* ring_hash, RingHashLbConfig* config) {
368   // Store the weights while finding the sum.
369   struct EndpointWeight {
370     std::string address;  // Key by endpoint's first address.
371     // Default weight is 1 for the cases where a weight is not provided,
372     // each occurrence of the address will be counted a weight value of 1.
373     uint32_t weight = 1;
374     double normalized_weight;
375   };
376   std::vector<EndpointWeight> endpoint_weights;
377   size_t sum = 0;
378   const EndpointAddressesList& endpoints = ring_hash->endpoints_;
379   endpoint_weights.reserve(endpoints.size());
380   for (const auto& endpoint : endpoints) {
381     EndpointWeight endpoint_weight;
382     endpoint_weight.address =
383         grpc_sockaddr_to_string(&endpoint.addresses().front(), false).value();
384     // Weight should never be zero, but ignore it just in case, since
385     // that value would screw up the ring-building algorithm.
386     auto weight_arg = endpoint.args().GetInt(GRPC_ARG_ADDRESS_WEIGHT);
387     if (weight_arg.value_or(0) > 0) {
388       endpoint_weight.weight = *weight_arg;
389     }
390     sum += endpoint_weight.weight;
391     endpoint_weights.push_back(std::move(endpoint_weight));
392   }
393   // Calculating normalized weights and find min and max.
394   double min_normalized_weight = 1.0;
395   double max_normalized_weight = 0.0;
396   for (auto& endpoint_weight : endpoint_weights) {
397     endpoint_weight.normalized_weight =
398         static_cast<double>(endpoint_weight.weight) / sum;
399     min_normalized_weight =
400         std::min(endpoint_weight.normalized_weight, min_normalized_weight);
401     max_normalized_weight =
402         std::max(endpoint_weight.normalized_weight, max_normalized_weight);
403   }
404   // Scale up the number of hashes per host such that the least-weighted host
405   // gets a whole number of hashes on the ring. Other hosts might not end up
406   // with whole numbers, and that's fine (the ring-building algorithm below can
407   // handle this). This preserves the original implementation's behavior: when
408   // weights aren't provided, all hosts should get an equal number of hashes. In
409   // the case where this number exceeds the max_ring_size, it's scaled back down
410   // to fit.
411   const size_t ring_size_cap =
412       ring_hash->args_.GetInt(GRPC_ARG_RING_HASH_LB_RING_SIZE_CAP)
413           .value_or(kRingSizeCapDefault);
414   const size_t min_ring_size = std::min(config->min_ring_size(), ring_size_cap);
415   const size_t max_ring_size = std::min(config->max_ring_size(), ring_size_cap);
416   const double scale = std::min(
417       std::ceil(min_normalized_weight * min_ring_size) / min_normalized_weight,
418       static_cast<double>(max_ring_size));
419   // Reserve memory for the entire ring up front.
420   const uint64_t ring_size = std::ceil(scale);
421   ring_.reserve(ring_size);
422   // Populate the hash ring by walking through the (host, weight) pairs in
423   // normalized_host_weights, and generating (scale * weight) hashes for each
424   // host. Since these aren't necessarily whole numbers, we maintain running
425   // sums -- current_hashes and target_hashes -- which allows us to populate the
426   // ring in a mostly stable way.
427   absl::InlinedVector<char, 196> hash_key_buffer;
428   double current_hashes = 0.0;
429   double target_hashes = 0.0;
430   uint64_t min_hashes_per_host = ring_size;
431   uint64_t max_hashes_per_host = 0;
432   for (size_t i = 0; i < endpoints.size(); ++i) {
433     const std::string& address_string = endpoint_weights[i].address;
434     hash_key_buffer.assign(address_string.begin(), address_string.end());
435     hash_key_buffer.emplace_back('_');
436     auto offset_start = hash_key_buffer.end();
437     target_hashes += scale * endpoint_weights[i].normalized_weight;
438     size_t count = 0;
439     while (current_hashes < target_hashes) {
440       const std::string count_str = absl::StrCat(count);
441       hash_key_buffer.insert(offset_start, count_str.begin(), count_str.end());
442       absl::string_view hash_key(hash_key_buffer.data(),
443                                  hash_key_buffer.size());
444       const uint64_t hash = XXH64(hash_key.data(), hash_key.size(), 0);
445       ring_.push_back({hash, i});
446       ++count;
447       ++current_hashes;
448       hash_key_buffer.erase(offset_start, hash_key_buffer.end());
449     }
450     min_hashes_per_host =
451         std::min(static_cast<uint64_t>(i), min_hashes_per_host);
452     max_hashes_per_host =
453         std::max(static_cast<uint64_t>(i), max_hashes_per_host);
454   }
455   std::sort(ring_.begin(), ring_.end(),
456             [](const RingEntry& lhs, const RingEntry& rhs) -> bool {
457               return lhs.hash < rhs.hash;
458             });
459 }
460 
461 //
462 // RingHash::RingHashEndpoint::Helper
463 //
464 
465 class RingHash::RingHashEndpoint::Helper final
466     : public LoadBalancingPolicy::DelegatingChannelControlHelper {
467  public:
Helper(RefCountedPtr<RingHashEndpoint> endpoint)468   explicit Helper(RefCountedPtr<RingHashEndpoint> endpoint)
469       : endpoint_(std::move(endpoint)) {}
470 
~Helper()471   ~Helper() override { endpoint_.reset(DEBUG_LOCATION, "Helper"); }
472 
UpdateState(grpc_connectivity_state state,const absl::Status & status,RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> picker)473   void UpdateState(
474       grpc_connectivity_state state, const absl::Status& status,
475       RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> picker) override {
476     endpoint_->OnStateUpdate(state, status, std::move(picker));
477   }
478 
479  private:
parent_helper() const480   LoadBalancingPolicy::ChannelControlHelper* parent_helper() const override {
481     return endpoint_->ring_hash_->channel_control_helper();
482   }
483 
484   RefCountedPtr<RingHashEndpoint> endpoint_;
485 };
486 
487 //
488 // RingHash::RingHashEndpoint
489 //
490 
Orphan()491 void RingHash::RingHashEndpoint::Orphan() {
492   if (child_policy_ != nullptr) {
493     // Remove pollset_set linkage.
494     grpc_pollset_set_del_pollset_set(child_policy_->interested_parties(),
495                                      ring_hash_->interested_parties());
496     child_policy_.reset();
497     picker_.reset();
498   }
499   Unref();
500 }
501 
UpdateLocked(size_t index)502 void RingHash::RingHashEndpoint::UpdateLocked(size_t index) {
503   index_ = index;
504   if (child_policy_ != nullptr) UpdateChildPolicyLocked();
505 }
506 
ResetBackoffLocked()507 void RingHash::RingHashEndpoint::ResetBackoffLocked() {
508   if (child_policy_ != nullptr) child_policy_->ResetBackoffLocked();
509 }
510 
RequestConnectionLocked()511 void RingHash::RingHashEndpoint::RequestConnectionLocked() {
512   if (child_policy_ == nullptr) {
513     CreateChildPolicy();
514   } else {
515     child_policy_->ExitIdleLocked();
516   }
517 }
518 
CreateChildPolicy()519 void RingHash::RingHashEndpoint::CreateChildPolicy() {
520   GPR_ASSERT(child_policy_ == nullptr);
521   LoadBalancingPolicy::Args lb_policy_args;
522   lb_policy_args.work_serializer = ring_hash_->work_serializer();
523   lb_policy_args.args =
524       ring_hash_->args_
525           .Set(GRPC_ARG_INTERNAL_PICK_FIRST_ENABLE_HEALTH_CHECKING, true)
526           .Set(GRPC_ARG_INTERNAL_PICK_FIRST_OMIT_STATUS_MESSAGE_PREFIX, true);
527   lb_policy_args.channel_control_helper =
528       std::make_unique<Helper>(Ref(DEBUG_LOCATION, "Helper"));
529   child_policy_ =
530       CoreConfiguration::Get().lb_policy_registry().CreateLoadBalancingPolicy(
531           "pick_first", std::move(lb_policy_args));
532   if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_ring_hash_trace)) {
533     const EndpointAddresses& endpoint = ring_hash_->endpoints_[index_];
534     gpr_log(GPR_INFO,
535             "[RH %p] endpoint %p (index %" PRIuPTR " of %" PRIuPTR
536             ", %s): created child policy %p",
537             ring_hash_.get(), this, index_, ring_hash_->endpoints_.size(),
538             endpoint.ToString().c_str(), child_policy_.get());
539   }
540   // Add our interested_parties pollset_set to that of the newly created
541   // child policy. This will make the child policy progress upon activity on
542   // this policy, which in turn is tied to the application's call.
543   grpc_pollset_set_add_pollset_set(child_policy_->interested_parties(),
544                                    ring_hash_->interested_parties());
545   UpdateChildPolicyLocked();
546 }
547 
UpdateChildPolicyLocked()548 void RingHash::RingHashEndpoint::UpdateChildPolicyLocked() {
549   // Construct pick_first config.
550   auto config =
551       CoreConfiguration::Get().lb_policy_registry().ParseLoadBalancingConfig(
552           Json::FromArray(
553               {Json::FromObject({{"pick_first", Json::FromObject({})}})}));
554   GPR_ASSERT(config.ok());
555   // Update child policy.
556   LoadBalancingPolicy::UpdateArgs update_args;
557   update_args.addresses =
558       std::make_shared<SingleEndpointIterator>(ring_hash_->endpoints_[index_]);
559   update_args.args = ring_hash_->args_;
560   update_args.config = std::move(*config);
561   // TODO(roth): If the child reports a non-OK status with the update,
562   // we need to propagate that back to the resolver somehow.
563   (void)child_policy_->UpdateLocked(std::move(update_args));
564 }
565 
OnStateUpdate(grpc_connectivity_state new_state,const absl::Status & status,RefCountedPtr<SubchannelPicker> picker)566 void RingHash::RingHashEndpoint::OnStateUpdate(
567     grpc_connectivity_state new_state, const absl::Status& status,
568     RefCountedPtr<SubchannelPicker> picker) {
569   if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_ring_hash_trace)) {
570     gpr_log(
571         GPR_INFO,
572         "[RH %p] connectivity changed for endpoint %p (%s, child_policy=%p): "
573         "prev_state=%s new_state=%s (%s)",
574         ring_hash_.get(), this,
575         ring_hash_->endpoints_[index_].ToString().c_str(), child_policy_.get(),
576         ConnectivityStateName(connectivity_state_),
577         ConnectivityStateName(new_state), status.ToString().c_str());
578   }
579   if (child_policy_ == nullptr) return;  // Already orphaned.
580   // Update state.
581   const bool entered_transient_failure =
582       connectivity_state_ != GRPC_CHANNEL_TRANSIENT_FAILURE &&
583       new_state == GRPC_CHANNEL_TRANSIENT_FAILURE;
584   connectivity_state_ = new_state;
585   status_ = status;
586   picker_ = std::move(picker);
587   // Update the aggregated connectivity state.
588   ring_hash_->UpdateAggregatedConnectivityStateLocked(entered_transient_failure,
589                                                       status);
590 }
591 
592 //
593 // RingHash
594 //
595 
RingHash(Args args)596 RingHash::RingHash(Args args) : LoadBalancingPolicy(std::move(args)) {
597   if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_ring_hash_trace)) {
598     gpr_log(GPR_INFO, "[RH %p] Created", this);
599   }
600 }
601 
~RingHash()602 RingHash::~RingHash() {
603   if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_ring_hash_trace)) {
604     gpr_log(GPR_INFO, "[RH %p] Destroying Ring Hash policy", this);
605   }
606 }
607 
ShutdownLocked()608 void RingHash::ShutdownLocked() {
609   if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_ring_hash_trace)) {
610     gpr_log(GPR_INFO, "[RH %p] Shutting down", this);
611   }
612   shutdown_ = true;
613   endpoint_map_.clear();
614 }
615 
ResetBackoffLocked()616 void RingHash::ResetBackoffLocked() {
617   for (const auto& p : endpoint_map_) {
618     p.second->ResetBackoffLocked();
619   }
620 }
621 
UpdateLocked(UpdateArgs args)622 absl::Status RingHash::UpdateLocked(UpdateArgs args) {
623   // Check address list.
624   if (args.addresses.ok()) {
625     if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_ring_hash_trace)) {
626       gpr_log(GPR_INFO, "[RH %p] received update", this);
627     }
628     // De-dup endpoints, taking weight into account.
629     endpoints_.clear();
630     std::map<EndpointAddressSet, size_t> endpoint_indices;
631     (*args.addresses)->ForEach([&](const EndpointAddresses& endpoint) {
632       const EndpointAddressSet key(endpoint.addresses());
633       auto p = endpoint_indices.emplace(key, endpoints_.size());
634       if (!p.second) {
635         // Duplicate endpoint.  Combine weights and skip the dup.
636         EndpointAddresses& prev_endpoint = endpoints_[p.first->second];
637         int weight_arg =
638             endpoint.args().GetInt(GRPC_ARG_ADDRESS_WEIGHT).value_or(1);
639         int prev_weight_arg =
640             prev_endpoint.args().GetInt(GRPC_ARG_ADDRESS_WEIGHT).value_or(1);
641         if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_ring_hash_trace)) {
642           gpr_log(GPR_INFO,
643                   "[RH %p] merging duplicate endpoint for %s, combined "
644                   "weight %d",
645                   this, key.ToString().c_str(), weight_arg + prev_weight_arg);
646         }
647         prev_endpoint = EndpointAddresses(
648             prev_endpoint.addresses(),
649             prev_endpoint.args().Set(GRPC_ARG_ADDRESS_WEIGHT,
650                                      weight_arg + prev_weight_arg));
651       } else {
652         endpoints_.push_back(endpoint);
653       }
654     });
655   } else {
656     if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_ring_hash_trace)) {
657       gpr_log(GPR_INFO, "[RH %p] received update with addresses error: %s",
658               this, args.addresses.status().ToString().c_str());
659     }
660     // If we already have an endpoint list, then keep using the existing
661     // list, but still report back that the update was not accepted.
662     if (!endpoints_.empty()) return args.addresses.status();
663   }
664   // Save channel args.
665   args_ = std::move(args.args);
666   // Build new ring.
667   ring_ = MakeRefCounted<Ring>(
668       this, static_cast<RingHashLbConfig*>(args.config.get()));
669   // Update endpoint map.
670   std::map<EndpointAddressSet, OrphanablePtr<RingHashEndpoint>> endpoint_map;
671   for (size_t i = 0; i < endpoints_.size(); ++i) {
672     const EndpointAddresses& addresses = endpoints_[i];
673     const EndpointAddressSet address_set(addresses.addresses());
674     // If present in old map, retain it; otherwise, create a new one.
675     auto it = endpoint_map_.find(address_set);
676     if (it != endpoint_map_.end()) {
677       it->second->UpdateLocked(i);
678       endpoint_map.emplace(address_set, std::move(it->second));
679     } else {
680       endpoint_map.emplace(address_set, MakeOrphanable<RingHashEndpoint>(
681                                             RefAsSubclass<RingHash>(), i));
682     }
683   }
684   endpoint_map_ = std::move(endpoint_map);
685   // If the address list is empty, report TRANSIENT_FAILURE.
686   if (endpoints_.empty()) {
687     absl::Status status =
688         args.addresses.ok() ? absl::UnavailableError(absl::StrCat(
689                                   "empty address list: ", args.resolution_note))
690                             : args.addresses.status();
691     channel_control_helper()->UpdateState(
692         GRPC_CHANNEL_TRANSIENT_FAILURE, status,
693         MakeRefCounted<TransientFailurePicker>(status));
694     return status;
695   }
696   // Return a new picker.
697   UpdateAggregatedConnectivityStateLocked(/*entered_transient_failure=*/false,
698                                           absl::OkStatus());
699   return absl::OkStatus();
700 }
701 
UpdateAggregatedConnectivityStateLocked(bool entered_transient_failure,absl::Status status)702 void RingHash::UpdateAggregatedConnectivityStateLocked(
703     bool entered_transient_failure, absl::Status status) {
704   // Count the number of endpoints in each state.
705   size_t num_idle = 0;
706   size_t num_connecting = 0;
707   size_t num_ready = 0;
708   size_t num_transient_failure = 0;
709   for (const auto& p : endpoint_map_) {
710     switch (p.second->connectivity_state()) {
711       case GRPC_CHANNEL_READY:
712         ++num_ready;
713         break;
714       case GRPC_CHANNEL_IDLE:
715         ++num_idle;
716         break;
717       case GRPC_CHANNEL_CONNECTING:
718         ++num_connecting;
719         break;
720       case GRPC_CHANNEL_TRANSIENT_FAILURE:
721         ++num_transient_failure;
722         break;
723       default:
724         Crash("child policy should never report SHUTDOWN");
725     }
726   }
727   // The overall aggregation rules here are:
728   // 1. If there is at least one endpoint in READY state, report READY.
729   // 2. If there are 2 or more endpoints in TRANSIENT_FAILURE state, report
730   //    TRANSIENT_FAILURE.
731   // 3. If there is at least one endpoint in CONNECTING state, report
732   //    CONNECTING.
733   // 4. If there is one endpoint in TRANSIENT_FAILURE state and there is
734   //    more than one endpoint, report CONNECTING.
735   // 5. If there is at least one endpoint in IDLE state, report IDLE.
736   // 6. Otherwise, report TRANSIENT_FAILURE.
737   //
738   // We set start_connection_attempt to true if we match rules 2, 4, or 6.
739   grpc_connectivity_state state;
740   bool start_connection_attempt = false;
741   if (num_ready > 0) {
742     state = GRPC_CHANNEL_READY;
743   } else if (num_transient_failure >= 2) {
744     state = GRPC_CHANNEL_TRANSIENT_FAILURE;
745     start_connection_attempt = true;
746   } else if (num_connecting > 0) {
747     state = GRPC_CHANNEL_CONNECTING;
748   } else if (num_transient_failure == 1 && endpoints_.size() > 1) {
749     state = GRPC_CHANNEL_CONNECTING;
750     start_connection_attempt = true;
751   } else if (num_idle > 0) {
752     state = GRPC_CHANNEL_IDLE;
753   } else {
754     state = GRPC_CHANNEL_TRANSIENT_FAILURE;
755     start_connection_attempt = true;
756   }
757   if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_ring_hash_trace)) {
758     gpr_log(GPR_INFO,
759             "[RH %p] setting connectivity state to %s (num_idle=%" PRIuPTR
760             ", num_connecting=%" PRIuPTR ", num_ready=%" PRIuPTR
761             ", num_transient_failure=%" PRIuPTR ", size=%" PRIuPTR
762             ") -- start_connection_attempt=%d",
763             this, ConnectivityStateName(state), num_idle, num_connecting,
764             num_ready, num_transient_failure, endpoints_.size(),
765             start_connection_attempt);
766   }
767   // In TRANSIENT_FAILURE, report the last reported failure.
768   // Otherwise, report OK.
769   if (state == GRPC_CHANNEL_TRANSIENT_FAILURE) {
770     if (!status.ok()) {
771       last_failure_ = absl::UnavailableError(absl::StrCat(
772           "no reachable endpoints; last error: ", status.message()));
773     }
774     status = last_failure_;
775   } else {
776     status = absl::OkStatus();
777   }
778   // Generate new picker and return it to the channel.
779   // Note that we use our own picker regardless of connectivity state.
780   channel_control_helper()->UpdateState(
781       state, status,
782       MakeRefCounted<Picker>(
783           RefAsSubclass<RingHash>(DEBUG_LOCATION, "RingHashPicker")));
784   // While the ring_hash policy is reporting TRANSIENT_FAILURE, it will
785   // not be getting any pick requests from the priority policy.
786   // However, because the ring_hash policy does not attempt to
787   // reconnect to endpoints unless it is getting pick requests,
788   // it will need special handling to ensure that it will eventually
789   // recover from TRANSIENT_FAILURE state once the problem is resolved.
790   // Specifically, it will make sure that it is attempting to connect to
791   // at least one endpoint at any given time.  But we don't want to just
792   // try to connect to only one endpoint, because if that particular
793   // endpoint happens to be down but the rest are reachable, we would
794   // incorrectly fail to recover.
795   //
796   // So, to handle this, whenever an endpoint initially enters
797   // TRANSIENT_FAILURE state (i.e., its initial connection attempt has
798   // failed), if there are no endpoints currently in CONNECTING state
799   // (i.e., they are still trying their initial connection attempt),
800   // then we will trigger a connection attempt for the first endpoint
801   // that is currently in state IDLE, if any.
802   //
803   // Note that once an endpoint enters TRANSIENT_FAILURE state, it will
804   // stay in that state and automatically retry after appropriate backoff,
805   // never stopping until it establishes a connection.  This means that
806   // if we stay in TRANSIENT_FAILURE for a long period of time, we will
807   // eventually be trying *all* endpoints, which probably isn't ideal.
808   // But it's no different than what can happen if ring_hash is the root
809   // LB policy and we keep getting picks, so it's not really a new
810   // problem.  If/when it becomes an issue, we can figure out how to
811   // address it.
812   //
813   // Note that we do the same thing when the policy is in state
814   // CONNECTING, just to ensure that we don't remain in CONNECTING state
815   // indefinitely if there are no new picks coming in.
816   if (start_connection_attempt && entered_transient_failure) {
817     size_t first_idle_index = endpoints_.size();
818     for (size_t i = 0; i < endpoints_.size(); ++i) {
819       auto it =
820           endpoint_map_.find(EndpointAddressSet(endpoints_[i].addresses()));
821       GPR_ASSERT(it != endpoint_map_.end());
822       if (it->second->connectivity_state() == GRPC_CHANNEL_CONNECTING) {
823         first_idle_index = endpoints_.size();
824         break;
825       }
826       if (first_idle_index == endpoints_.size() &&
827           it->second->connectivity_state() == GRPC_CHANNEL_IDLE) {
828         first_idle_index = i;
829       }
830     }
831     if (first_idle_index != endpoints_.size()) {
832       auto it = endpoint_map_.find(
833           EndpointAddressSet(endpoints_[first_idle_index].addresses()));
834       GPR_ASSERT(it != endpoint_map_.end());
835       if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_ring_hash_trace)) {
836         gpr_log(GPR_INFO,
837                 "[RH %p] triggering internal connection attempt for endpoint "
838                 "%p (%s) (index %" PRIuPTR " of %" PRIuPTR ")",
839                 this, it->second.get(),
840                 endpoints_[first_idle_index].ToString().c_str(),
841                 first_idle_index, endpoints_.size());
842       }
843       it->second->RequestConnectionLocked();
844     }
845   }
846 }
847 
848 //
849 // factory
850 //
851 
852 class RingHashFactory final : public LoadBalancingPolicyFactory {
853  public:
CreateLoadBalancingPolicy(LoadBalancingPolicy::Args args) const854   OrphanablePtr<LoadBalancingPolicy> CreateLoadBalancingPolicy(
855       LoadBalancingPolicy::Args args) const override {
856     return MakeOrphanable<RingHash>(std::move(args));
857   }
858 
name() const859   absl::string_view name() const override { return kRingHash; }
860 
861   absl::StatusOr<RefCountedPtr<LoadBalancingPolicy::Config>>
ParseLoadBalancingConfig(const Json & json) const862   ParseLoadBalancingConfig(const Json& json) const override {
863     auto config = LoadFromJson<RingHashConfig>(
864         json, JsonArgs(), "errors validating ring_hash LB policy config");
865     if (!config.ok()) return config.status();
866     return MakeRefCounted<RingHashLbConfig>(config->min_ring_size,
867                                             config->max_ring_size);
868   }
869 };
870 
871 }  // namespace
872 
RegisterRingHashLbPolicy(CoreConfiguration::Builder * builder)873 void RegisterRingHashLbPolicy(CoreConfiguration::Builder* builder) {
874   builder->lb_policy_registry()->RegisterLoadBalancingPolicyFactory(
875       std::make_unique<RingHashFactory>());
876 }
877 
878 }  // namespace grpc_core
879