1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/distributed_runtime/preemption/preemption_sync_manager.h"
16 
17 #include <algorithm>
18 #include <functional>
19 #include <memory>
20 #include <string>
21 #include <utility>
22 
23 #include "absl/strings/str_cat.h"
24 #include "absl/synchronization/notification.h"
25 #include "absl/time/time.h"
26 #include "tensorflow/compiler/xla/pjrt/distributed/client.h"
27 #include "tensorflow/core/distributed_runtime/call_options.h"
28 #include "tensorflow/core/distributed_runtime/coordination/coordination_service_agent.h"
29 #include "tensorflow/core/distributed_runtime/preemption/preemption_notifier.h"
30 #include "tensorflow/core/platform/env.h"
31 #include "tensorflow/core/platform/mutex.h"
32 #include "tensorflow/core/platform/statusor.h"
33 #include "tensorflow/core/protobuf/coordination_service.pb.h"
34 
35 
36 namespace tensorflow {
37 namespace {
// Sentinel for "no sync point has been decided yet". Used as the initial value
// of the sync counter and of the running maximum in ComputeSyncCallCounter().
constexpr int64_t kPreemptionSyncUnsetCounter = -1;
// Coordination service key under which a task's death time is published and
// which this manager watches to trigger the sync protocol.
constexpr char kPreemptionNoticeKey[] = "RECEIVED_PREEMPTION_NOTICE";
// Key prefix (directory) under which each task inserts its current call
// counter; the whole directory is read back after the barrier.
constexpr char kPreemptionCounterDirKey[] = "PREEMPTION_CURRENT_COUNTER/";
// Identifier of the barrier used to wait until tasks have reported their call
// counters.
constexpr char kPreemptionBarrier[] = "PREEMPTION_SYNC_BARRIER";
// Timeout passed to WaitAtBarrier() for `kPreemptionBarrier`.
constexpr absl::Duration kPreemptionBarrierTimeout = absl::Minutes(3);

// Only start protocol if death time is within `kProtocolDuration`, so that we
// don't synchronize too early.
// TODO(b/230630494): Make this configurable so that users can extend this to
// accommodate higher checkpoint durations.
constexpr absl::Duration kProtocolDuration = absl::Minutes(15);
49 
50 class PreemptionSyncManagerImpl : public PreemptionSyncManager {
51  public:
52   PreemptionSyncManagerImpl() = default;
~PreemptionSyncManagerImpl()53   ~PreemptionSyncManagerImpl() override {
54     shutdown_.Notify();
55   }
56   Status Initialize(CoordinationServiceAgent* agent) override;
57   Status Initialize(xla::DistributedRuntimeClient* client) override;
58   Status Initialize(CoordinationServiceAgent* agent,
59                     const std::string& preemption_notifier_type) override;
60   Status Initialize(CoordinationServiceAgent* agent,
61                     std::unique_ptr<PreemptionNotifier> notifier) override;
62   bool ReachedSyncPoint(int step_counter) override;
63 
64  private:
65   // Determine the sync point upon receipt of preemption notice (death time).
66   void ComputeSyncCallCounter(absl::Time death_time);
67   // Notify other tasks to not wait at the barrier if the sync protocol failed
68   // midway.
69   void CancelPreemptionBarrier();
70 
71   mutex mu_;
72   // Tracks the last step_counter passed into ReachedSyncPoint();
73   int64_t call_counter_ TF_GUARDED_BY(mu_) = 0;
74   // If set, determines the sync point.
75   int64_t preemption_sync_counter_ TF_GUARDED_BY(mu_) =
76       kPreemptionSyncUnsetCounter;
77   std::string current_call_counter_key_;
78 
79   Env* env_;                         // Not owned;
80   CoordinationServiceAgent* agent_;  // Not owned.
81   absl::Notification shutdown_;
82   std::unique_ptr<Thread> sync_protocol_thread_;
83   std::unique_ptr<PreemptionNotifier> preemption_notifier_;
84   std::shared_ptr<CallOptions> call_opts_;
85 };
86 
Initialize(xla::DistributedRuntimeClient * client)87 Status PreemptionSyncManagerImpl::Initialize(
88     xla::DistributedRuntimeClient* client) {
89   TF_ASSIGN_OR_RETURN(CoordinationServiceAgent * coord_agent,
90                       client->GetCoordinationServiceAgent());
91   return Initialize(coord_agent);
92 }
93 
Initialize(CoordinationServiceAgent * agent)94 Status PreemptionSyncManagerImpl::Initialize(CoordinationServiceAgent* agent) {
95   return Initialize(agent, "sigterm");
96 }
97 
Initialize(CoordinationServiceAgent * agent,const std::string & preemption_notifier_type)98 Status PreemptionSyncManagerImpl::Initialize(
99     CoordinationServiceAgent* agent,
100     const std::string& preemption_notifier_type) {
101   TF_ASSIGN_OR_RETURN(Env * env, agent->GetEnv());
102   return Initialize(agent, PreemptionNotifier::CreatePreemptionNotifier(
103                                preemption_notifier_type, env));
104 }
105 
Initialize(CoordinationServiceAgent * agent,std::unique_ptr<PreemptionNotifier> notifier)106 Status PreemptionSyncManagerImpl::Initialize(
107     CoordinationServiceAgent* agent,
108     std::unique_ptr<PreemptionNotifier> notifier) {
109   TF_ASSIGN_OR_RETURN(Env * env, agent->GetEnv());
110   env_ = env;
111   agent_ = agent;
112   preemption_notifier_ = std::move(notifier);
113   TF_ASSIGN_OR_RETURN(CoordinatedTask own_task, agent->GetOwnTask());
114   const std::string task_name =
115       absl::StrCat("/job:", own_task.job_name(), "/task:", own_task.task_id());
116   current_call_counter_key_ = absl::StrCat(kPreemptionCounterDirKey, task_name);
117 
118   /* Listen for preemption notice within this task, then notify coordination
119    * service when death time is within kProtocolDuration.
120    */
121   preemption_notifier_->WillBePreemptedAtAsync(
122       [agent = agent_, task_name](StatusOr<absl::Time> death_time) {
123         if (!death_time.ok()) {
124           // The preemption notifier invokes callback with Cancelled error when
125           // its being destructed.
126           if (errors::IsCancelled(death_time.status())) {
127             LOG(INFO) << "Preemption sync protocol cancelled by notifier: "
128                       << death_time.status();
129           } else {
130             LOG(ERROR) << "Error from preemption notifier: "
131                        << death_time.status();
132           }
133           return;
134         }
135         // Notify coordination service about preemption notice.
136         const Status s = agent->InsertKeyValue(kPreemptionNoticeKey,
137                                                absl::FormatTime(*death_time));
138         LOG(INFO) << "Notified coordination service that this task will "
139                      "be preempted at "
140                   << *death_time << ". Status: " << s;
141       });
142 
143   /* Listen for preemption notice (death time) from coordination service, which
144    * triggers the sync protocol.
145    */
146   call_opts_ = agent_->GetKeyValueAsync(
147       kPreemptionNoticeKey,
148       [this, agent = agent_](StatusOr<std::string> status_or_death_time) {
149         if (errors::IsCancelled(status_or_death_time.status())) {
150           // The agent cancels pending GetKeyValue RPCs because of shutdown,
151           // so simply log and return.
152           LOG(INFO) << "Cancelled call to retrive preemption notice.";
153           return;
154         } else if (!status_or_death_time.ok()) {
155           LOG(ERROR) << "Failed to retrieve preemption notice from "
156                         "coordination service: "
157                      << status_or_death_time.status();
158           // Notify other tasks to not wait at the barrier. Note:
159           // CancelPreemptionBarrier() cannot be used because this may be
160           // triggered after preemption sync manager has been destroyed.
161           agent->CancelBarrierAsync(
162               kPreemptionBarrier, [](const Status& status) {
163                 if (!status.ok()) {
164                   LOG(ERROR)
165                       << "Failed to cancel preemption barrier: " << status;
166                 }
167               });
168           return;
169         }
170         std::string err;
171         absl::Time death_time;
172         if (absl::ParseTime(absl::RFC3339_full, *status_or_death_time,
173                             &death_time, &err)) {
174           LOG(INFO) << "Received preemption notice with death_time "
175                     << death_time;
176         } else {
177           LOG(ERROR) << "Unable to parse preemption notice's death time: "
178                      << err;
179           CancelPreemptionBarrier();
180           return;
181         }
182 
183         LOG(INFO) << "Received preemption notice with death time: "
184                   << death_time;
185 
186         // Trigger protocol in a separate thread: compute max call counter.
187         sync_protocol_thread_ = absl::WrapUnique(env_->StartThread(
188             {}, "PreemptionSyncManager_SyncProtocol",
189             std::bind(&PreemptionSyncManagerImpl::ComputeSyncCallCounter, this,
190                       death_time)));
191       });
192 
193   return Status::OK();
194 }
195 
ComputeSyncCallCounter(absl::Time death_time)196 void PreemptionSyncManagerImpl::ComputeSyncCallCounter(absl::Time death_time) {
197   // 1. If death time is in the distant future, sleep until there's
198   // `kProtocolDuration` left until death time before we begin the protocol.
199   const absl::Duration remaining_time = death_time - absl::Now();
200   if (remaining_time > kProtocolDuration) {
201     LOG(INFO) << "Will begin preemption sync protocol in " << remaining_time;
202     const absl::Duration sleep_time = remaining_time - kProtocolDuration;
203 
204     if (shutdown_.WaitForNotificationWithTimeout(sleep_time)) {
205       // If shutdown is triggered midway, exit thread immediately.
206       LOG(WARNING)
207           << "Shutdown is triggered before preemption sync protocol has begun.";
208       CancelPreemptionBarrier();
209       return;
210     }
211   }
212 
213   // 2. Send coordination service the task's current call counter. Hold the lock
214   // to prevent updates to `call_counter_` until the protocol completes and this
215   // function exits, implying that we have decided on a new
216   // `preemption_sync_counter_` or the protocol failed. This ensures correctness
217   // of the preemption sync protocol.
218   mutex_lock l(mu_);
219   const Status notified_status = agent_->InsertKeyValue(
220       current_call_counter_key_, std::to_string(call_counter_));
221   if (!notified_status.ok()) {
222     LOG(ERROR) << "Preemption sync failed - could not inform service of "
223                   "current call counter: "
224                << notified_status;
225     CancelPreemptionBarrier();
226     return;
227   }
228 
229   // 3. Impose a barrier to wait until everybody sends their current call
230   // counter.
231   const Status barrier_status =
232       agent_->WaitAtBarrier(kPreemptionBarrier, kPreemptionBarrierTimeout, {});
233   if (!barrier_status.ok()) {
234     LOG(ERROR) << "Preemption sync barrier failed: " << barrier_status;
235     return;
236   }
237 
238   // 4. Retrieve every task's current call counter.
239   StatusOr<std::vector<KeyValueEntry>> all_counters =
240       agent_->GetKeyValueDir(kPreemptionCounterDirKey);
241   if (!all_counters.ok()) {
242     LOG(ERROR) << "Preemption sync failed - unable to retrieve call counters : "
243                << all_counters.status();
244     return;
245   }
246 
247   // 5. Compute the fastest task's call counter.
248   // Note: Each task should retrieve the same set of call counters and arrive at
249   // the same maximum. We have to calculate this max within each task because
250   // coordination service does not provide GetMaxKeyValue().
251   int64_t max_counter = kPreemptionSyncUnsetCounter;
252   for (const auto& kv : *all_counters) {
253     int64_t call_counter;
254     if (!absl::SimpleAtoi(kv.value(), &call_counter)) {
255       LOG(ERROR) << "Preemption sync failed - failed to parse preemption call "
256                     "counter: "
257                  << kv.DebugString();
258       return;
259     }
260     max_counter = std::max(max_counter, call_counter);
261   }
262 
263   if (max_counter == kPreemptionSyncUnsetCounter) {
264     LOG(ERROR) << "Preemption sync failed - no call counters found.";
265     return;
266   }
267 
268   // 6. Set sync point to be the next possible call counter of the fastest task.
269   preemption_sync_counter_ = max_counter + 1;
270   LOG(INFO) << "Preemption sync counter is set: " << preemption_sync_counter_;
271 }
272 
CancelPreemptionBarrier()273 void PreemptionSyncManagerImpl::CancelPreemptionBarrier() {
274   agent_->CancelBarrierAsync(kPreemptionBarrier, [](const Status& status) {
275     if (!status.ok()) {
276       LOG(ERROR) << "Failed to cancel preemption barrier: " << status;
277     }
278   });
279 }
280 
ReachedSyncPoint(int step_counter)281 bool PreemptionSyncManagerImpl::ReachedSyncPoint(int step_counter) {
282   // Note: if a preemption notice has been received and ComputeSyncCallCounter()
283   // is ongoing , this method will be blocked until it acquires the lock. This
284   // prevents updates to `call_counter_` while `preemption_sync_counter_` is
285   // being computed, which ensures correctness of the preemption sync protocol.
286   mutex_lock l(mu_);
287   // Track current call.
288   call_counter_ = step_counter;
289   VLOG(3) << "Current call counter: " << call_counter_
290           << ", Preemption sync point: " << preemption_sync_counter_;
291 
292   // Check if we have reached the sync point.
293   return preemption_sync_counter_ == call_counter_;
294 }
295 }  // namespace
CreatePreemptionSyncManager()296 std::unique_ptr<PreemptionSyncManager> CreatePreemptionSyncManager() {
297   return std::make_unique<PreemptionSyncManagerImpl>();
298 }
299 }  // namespace tensorflow
300