1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PREEMPTION_PREEMPTION_NOTIFIER_H_ 16 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PREEMPTION_PREEMPTION_NOTIFIER_H_ 17 18 #include <functional> 19 #include <memory> 20 #include <string> 21 #include <unordered_map> 22 #include <utility> 23 #include <vector> 24 25 #include "absl/strings/str_join.h" 26 #include "absl/time/time.h" 27 #include "tensorflow/core/platform/env.h" 28 #include "tensorflow/core/platform/mutex.h" 29 #include "tensorflow/core/platform/statusor.h" 30 31 namespace tensorflow { 32 33 // Static registration for preemption notifiers. 34 #define REGISTER_PREEMPTION_NOTIFIER(notifier_type_name, factory_fn) \ 35 REGISTER_PREEMPTION_NOTIFIER_UNIQ_HELPER(__COUNTER__, notifier_type_name, \ 36 factory_fn) 37 #define REGISTER_PREEMPTION_NOTIFIER_UNIQ_HELPER(counter, notifier_type_name, \ 38 factory_fn) \ 39 static bool static_preemption_notifier_##counter TF_ATTRIBUTE_UNUSED = \ 40 []() { \ 41 ::tensorflow::PreemptionNotifier::RegisterPreemptionNotifier( \ 42 notifier_type_name, factory_fn); \ 43 return true; \ 44 }() 45 46 // Base class for listening and propagating task preemption notices. 
//
// This class provides a common mechanism to block while waiting for a
// preemption signal, or to register callbacks that will be triggered upon
// preemption.
//
// Example:
//
//    // Monitors the SIGTERM preemption signal. Returns nullptr if no notifier
//    // is registered under "sigterm".
//    notifier = PreemptionNotifier::CreatePreemptionNotifier("sigterm", env);
//
//    // Register callback that will be invoked once preempted
//    notifier->WillBePreemptedAtAsync(
//      [](StatusOr<absl::Time> status_or_time) {
//        if (status_or_time.ok()) {
//          LOG(INFO) << "Preempted at time: " << status_or_time.value();
//        } else {
//          LOG(ERROR) << "Received error: " << status_or_time.status();
//        }
//      });
//
//    // Block current thread until preemption
//    absl::Time preempt_time = notifier->WillBePreemptedAt().value();
//
// Users can extend this class to support custom preemption signals by
// subclassing `PreemptionNotifier` with a custom constructor and registering
// its creator (factory function) with `REGISTER_PREEMPTION_NOTIFIER`. The
// custom constructor should set up the communication with the cluster
// scheduler, and invoke the `NotifyRegisteredListeners` method once a
// preemption signal is received. See `SigtermNotifier` as an example.
75 76 class PreemptionNotifier { 77 public: 78 typedef std::function<void(StatusOr<absl::Time>)> PreemptTimeCallback; 79 using PreemptionNotifierFactory = 80 std::function<std::unique_ptr<PreemptionNotifier>(Env* env)>; 81 PreemptionNotifier(Env * env)82 explicit PreemptionNotifier(Env* env) : env_(env) {} 83 virtual ~PreemptionNotifier() = default; 84 RegisterPreemptionNotifier(const std::string & notifier_type_name,PreemptionNotifierFactory factory_fn)85 static void RegisterPreemptionNotifier(const std::string& notifier_type_name, 86 PreemptionNotifierFactory factory_fn) { 87 GetPreemptionNotifierFactories()->emplace(notifier_type_name, 88 std::move(factory_fn)); 89 } 90 CreatePreemptionNotifier(const std::string & notifier_type,Env * env)91 static std::unique_ptr<PreemptionNotifier> CreatePreemptionNotifier( 92 const std::string& notifier_type, Env* env) { 93 const auto* factories = GetPreemptionNotifierFactories(); 94 auto it = factories->find(notifier_type); 95 if (it == factories->end()) { 96 std::vector<std::string> registered_types; 97 registered_types.reserve(factories->size()); 98 for (auto& kv : *factories) { 99 registered_types.push_back(kv.first); 100 } 101 LOG(ERROR) << "No preemption notifier factory found for notifier type " 102 << notifier_type 103 << ". All registered preemption notifier types are: " 104 << absl::StrJoin(registered_types, ", ") 105 << ". Make sure the library is loaded to the program."; 106 return nullptr; 107 } 108 return it->second(env); 109 } 110 111 // This is a blocking call that returns a death time when preemption / 112 // termination will occur once the listener receives the preemption 113 // notification. If no death time is specified, absl::Now() is returned. 114 // Returns error::Cancelled if UnregisterListeners() is called. 115 StatusOr<absl::Time> WillBePreemptedAt(); 116 117 // Registers a callback that takes the death time as input once the listener 118 // receives the preemption notification. 
119 // If no death time is specified, absl::Now() is specified as input. 120 // Note: callback should be kept as simple and fast as possible (e.g. simply 121 // retrieve result). It should not wait for work done by another callback, and 122 // invoke ahy PreemptionNotifier method (e.g. Reset(), destructor). 123 void WillBePreemptedAtAsync(PreemptTimeCallback callback); 124 125 protected: GetEnv()126 Env* GetEnv() { return env_; } 127 // Invokes all pending callbacks upon receipt of preemption notice with death 128 // time or errors (e.g. cancellation during shutdown). 129 void NotifyRegisteredListeners(StatusOr<absl::Time> death_time); 130 131 private: 132 static std::unordered_map<std::string, PreemptionNotifierFactory>* GetPreemptionNotifierFactories()133 GetPreemptionNotifierFactories() { 134 static auto* preemption_notifier_factories = 135 new std::unordered_map<std::string, PreemptionNotifierFactory>(); 136 return preemption_notifier_factories; 137 } 138 139 Env* env_; // Not owned. 140 mutex mu_; 141 absl::Time death_time_ TF_GUARDED_BY(mu_) = absl::InfinitePast(); 142 std::vector<PreemptTimeCallback> callbacks_ TF_GUARDED_BY(mu_); 143 }; 144 145 } // namespace tensorflow 146 147 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PREEMPTION_PREEMPTION_NOTIFIER_H_ 148