1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PREEMPTION_PREEMPTION_NOTIFIER_H_
16 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PREEMPTION_PREEMPTION_NOTIFIER_H_
17 
18 #include <functional>
19 #include <memory>
20 #include <string>
21 #include <unordered_map>
22 #include <utility>
23 #include <vector>
24 
25 #include "absl/strings/str_join.h"
26 #include "absl/time/time.h"
27 #include "tensorflow/core/platform/env.h"
28 #include "tensorflow/core/platform/mutex.h"
29 #include "tensorflow/core/platform/statusor.h"
30 
31 namespace tensorflow {
32 
// Static registration for preemption notifiers.
//
// `REGISTER_PREEMPTION_NOTIFIER(notifier_type_name, factory_fn)` defines a
// static bool whose initializer runs a lambda that calls
// `::tensorflow::PreemptionNotifier::RegisterPreemptionNotifier` with the given
// name and factory, and then yields `true`. The variable exists only to trigger
// that call; it is tagged `TF_ATTRIBUTE_UNUSED` because nothing reads it.
//
// Why three macros: an argument that is a direct operand of `##` is NOT
// macro-expanded before pasting. Pasting inside `..._UNIQ_HELPER` would
// therefore produce the literal identifier
// `static_preemption_notifier___COUNTER__` for every use, so two registrations
// in the same translation unit would redefine the same variable. Forwarding
// the counter through `..._UNIQ_HELPER` (where it is not a `##` operand) lets
// `__COUNTER__` expand to a number first; `..._UNIQ` then pastes that number.
#define REGISTER_PREEMPTION_NOTIFIER(notifier_type_name, factory_fn)        \
  REGISTER_PREEMPTION_NOTIFIER_UNIQ_HELPER(__COUNTER__, notifier_type_name, \
                                           factory_fn)
#define REGISTER_PREEMPTION_NOTIFIER_UNIQ_HELPER(counter, notifier_type_name, \
                                                 factory_fn)                  \
  REGISTER_PREEMPTION_NOTIFIER_UNIQ(counter, notifier_type_name, factory_fn)
#define REGISTER_PREEMPTION_NOTIFIER_UNIQ(counter, notifier_type_name,   \
                                          factory_fn)                    \
  static bool static_preemption_notifier_##counter TF_ATTRIBUTE_UNUSED = \
      []() {                                                             \
        ::tensorflow::PreemptionNotifier::RegisterPreemptionNotifier(    \
            notifier_type_name, factory_fn);                             \
        return true;                                                     \
      }()
45 
46 // Base class for listening and propagating task preemption notices.
47 //
48 // This class provides common mechanism to block on waiting for preemption
49 // signals, or register callbacks that will be triggered upon preemption.
50 //
51 // Example:
52 //
53 //    // Monitors the SIGTERM preemption signal
54 //    notifier = PreemptionNotifier::CreatePreemptionNotifier("sigterm", env);
55 //
56 //    // Register callback that will be invoked once preempted
57 //    notifier->WillBePreemptedAtAsync(
58 //      [](StatusOr<absl::Time> status_or_time) {
59 //        if (status_or_time.ok()) {
60 //          LOG(INFO) << "Preempted at time: " << status_or_time.value();
61 //        } else {
62 //          LOG(ERROR) << "Received error: " << status_or_time.status();
63 //        }
64 //      });
65 //
66 //    // Block current thread until preemption
67 //    absl::Time preempt_time = notifier->WillBePreemptedAt().ValueOrDie();
68 //
69 // Users can extend this class to support custom preemption signals, by subclass
70 // `PreemptionNotifier` with a custom constructor, register its creator (factory
71 // function) with `REGISTER_PREEMPTION_NOTIFIER`. The custom constructor should
72 // set up the communication with the cluster scheduler, and invoke the
73 // `NotifyRegisteredListeners` method once a preemption signal is received.
74 // See `SigtermNotifier` as an example.
75 
76 class PreemptionNotifier {
77  public:
78   typedef std::function<void(StatusOr<absl::Time>)> PreemptTimeCallback;
79   using PreemptionNotifierFactory =
80       std::function<std::unique_ptr<PreemptionNotifier>(Env* env)>;
81 
PreemptionNotifier(Env * env)82   explicit PreemptionNotifier(Env* env) : env_(env) {}
83   virtual ~PreemptionNotifier() = default;
84 
RegisterPreemptionNotifier(const std::string & notifier_type_name,PreemptionNotifierFactory factory_fn)85   static void RegisterPreemptionNotifier(const std::string& notifier_type_name,
86                                          PreemptionNotifierFactory factory_fn) {
87     GetPreemptionNotifierFactories()->emplace(notifier_type_name,
88                                               std::move(factory_fn));
89   }
90 
CreatePreemptionNotifier(const std::string & notifier_type,Env * env)91   static std::unique_ptr<PreemptionNotifier> CreatePreemptionNotifier(
92       const std::string& notifier_type, Env* env) {
93     const auto* factories = GetPreemptionNotifierFactories();
94     auto it = factories->find(notifier_type);
95     if (it == factories->end()) {
96       std::vector<std::string> registered_types;
97       registered_types.reserve(factories->size());
98       for (auto& kv : *factories) {
99         registered_types.push_back(kv.first);
100       }
101       LOG(ERROR) << "No preemption notifier factory found for notifier type "
102                  << notifier_type
103                  << ". All registered preemption notifier types are: "
104                  << absl::StrJoin(registered_types, ", ")
105                  << ". Make sure the library is loaded to the program.";
106       return nullptr;
107     }
108     return it->second(env);
109   }
110 
111   // This is a blocking call that returns a death time when preemption /
112   // termination will occur once the listener receives the preemption
113   // notification. If no death time is specified, absl::Now() is returned.
114   // Returns error::Cancelled if UnregisterListeners() is called.
115   StatusOr<absl::Time> WillBePreemptedAt();
116 
117   // Registers a callback that takes the death time as input once the listener
118   // receives the preemption notification.
119   // If no death time is specified, absl::Now() is specified as input.
120   // Note: callback should be kept as simple and fast as possible (e.g. simply
121   // retrieve result). It should not wait for work done by another callback, and
122   // invoke ahy PreemptionNotifier method (e.g. Reset(), destructor).
123   void WillBePreemptedAtAsync(PreemptTimeCallback callback);
124 
125  protected:
GetEnv()126   Env* GetEnv() { return env_; }
127   // Invokes all pending callbacks upon receipt of preemption notice with death
128   // time or errors (e.g. cancellation during shutdown).
129   void NotifyRegisteredListeners(StatusOr<absl::Time> death_time);
130 
131  private:
132   static std::unordered_map<std::string, PreemptionNotifierFactory>*
GetPreemptionNotifierFactories()133   GetPreemptionNotifierFactories() {
134     static auto* preemption_notifier_factories =
135         new std::unordered_map<std::string, PreemptionNotifierFactory>();
136     return preemption_notifier_factories;
137   }
138 
139   Env* env_;  // Not owned.
140   mutex mu_;
141   absl::Time death_time_ TF_GUARDED_BY(mu_) = absl::InfinitePast();
142   std::vector<PreemptTimeCallback> callbacks_ TF_GUARDED_BY(mu_);
143 };
144 
145 }  // namespace tensorflow
146 
147 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PREEMPTION_PREEMPTION_NOTIFIER_H_
148