1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/distributed_runtime/preemption/preemption_sync_manager.h"
16
#include <algorithm>
#include <functional>
#include <memory>
#include <string>
#include <utility>

#include "absl/strings/str_cat.h"
#include "absl/synchronization/notification.h"
#include "absl/time/time.h"
#include "tensorflow/compiler/xla/pjrt/distributed/client.h"
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/coordination/coordination_service_agent.h"
#include "tensorflow/core/distributed_runtime/preemption/preemption_notifier.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/protobuf/coordination_service.pb.h"
34
35
36 namespace tensorflow {
37 namespace {
// Sentinel value of `preemption_sync_counter_` meaning "no sync point has been
// agreed on yet".
constexpr int64_t kPreemptionSyncUnsetCounter = -1;
// Coordination service key under which a task publishes the death time
// (formatted with absl::FormatTime) of a preemption notice it received.
constexpr char kPreemptionNoticeKey[] = "RECEIVED_PREEMPTION_NOTICE";
// Key directory holding every task's current call counter. Each task's key is
// this prefix followed by "/job:<job_name>/task:<task_id>".
// NOTE(review): that concatenation yields a double slash
// ("PREEMPTION_CURRENT_COUNTER//job:..."); presumably the coordination service
// normalizes keys — confirm.
constexpr char kPreemptionCounterDirKey[] = "PREEMPTION_CURRENT_COUNTER/";
// Barrier all tasks wait at after publishing their call counters.
constexpr char kPreemptionBarrier[] = "PREEMPTION_SYNC_BARRIER";
// Maximum time a task waits at kPreemptionBarrier.
constexpr absl::Duration kPreemptionBarrierTimeout = absl::Minutes(3);

// Only start protocol if death time is within `kProtocolDuration`, so that we
// don't synchronize too early.
// TODO(b/230630494): Make this configurable so that users can extend this to
// accommodate higher checkpoint durations.
constexpr absl::Duration kProtocolDuration = absl::Minutes(15);
49
// PreemptionSyncManager backed by the coordination service.
//
// A preemption notice observed by a task's PreemptionNotifier is published
// under kPreemptionNoticeKey. Every initialized task watches that key and, once
// it is set, runs ComputeSyncCallCounter() on a background thread to agree on a
// common step (`preemption_sync_counter_`) at which ReachedSyncPoint() returns
// true.
class PreemptionSyncManagerImpl : public PreemptionSyncManager {
 public:
  PreemptionSyncManagerImpl() = default;
  // Signals `shutdown_` so that a sync protocol thread still sleeping in
  // ComputeSyncCallCounter() wakes up and exits early.
  // NOTE(review): members are destroyed in reverse declaration order, so
  // `sync_protocol_thread_` (presumably joined by tensorflow::Thread's
  // destructor) goes away before `shutdown_`. Keep that member order.
  ~PreemptionSyncManagerImpl() override {
    shutdown_.Notify();
  }
  // Uses the "sigterm" preemption notifier.
  Status Initialize(CoordinationServiceAgent* agent) override;
  // Uses the coordination service agent owned by the XLA client.
  Status Initialize(xla::DistributedRuntimeClient* client) override;
  // Creates the notifier by type name via
  // PreemptionNotifier::CreatePreemptionNotifier().
  Status Initialize(CoordinationServiceAgent* agent,
                    const std::string& preemption_notifier_type) override;
  // Registers the preemption notice listeners; takes ownership of `notifier`.
  Status Initialize(CoordinationServiceAgent* agent,
                    std::unique_ptr<PreemptionNotifier> notifier) override;
  // Records `step_counter` as the current call counter and reports whether it
  // equals the agreed sync point.
  bool ReachedSyncPoint(int step_counter) override;

 private:
  // Determine the sync point upon receipt of preemption notice (death time).
  void ComputeSyncCallCounter(absl::Time death_time);
  // Notify other tasks to not wait at the barrier if the sync protocol failed
  // midway.
  void CancelPreemptionBarrier();

  // Guards `call_counter_` and `preemption_sync_counter_`.
  mutex mu_;
  // Tracks the last step_counter passed into ReachedSyncPoint().
  int64_t call_counter_ TF_GUARDED_BY(mu_) = 0;
  // If set, determines the sync point. Holds kPreemptionSyncUnsetCounter until
  // ComputeSyncCallCounter() succeeds.
  int64_t preemption_sync_counter_ TF_GUARDED_BY(mu_) =
      kPreemptionSyncUnsetCounter;
  // Key under which this task publishes its call counter:
  // kPreemptionCounterDirKey + "/job:<job_name>/task:<task_id>".
  std::string current_call_counter_key_;

  Env* env_;                         // Not owned.
  CoordinationServiceAgent* agent_;  // Not owned.
  // Notified by the destructor; waited on by ComputeSyncCallCounter().
  absl::Notification shutdown_;
  // Runs ComputeSyncCallCounter() once a preemption notice arrives.
  std::unique_ptr<Thread> sync_protocol_thread_;
  // Source of this task's own preemption notice.
  std::unique_ptr<PreemptionNotifier> preemption_notifier_;
  // Returned by GetKeyValueAsync() for the pending kPreemptionNoticeKey lookup;
  // presumably usable to cancel that call — TODO confirm.
  std::shared_ptr<CallOptions> call_opts_;
};
86
Initialize(xla::DistributedRuntimeClient * client)87 Status PreemptionSyncManagerImpl::Initialize(
88 xla::DistributedRuntimeClient* client) {
89 TF_ASSIGN_OR_RETURN(CoordinationServiceAgent * coord_agent,
90 client->GetCoordinationServiceAgent());
91 return Initialize(coord_agent);
92 }
93
Initialize(CoordinationServiceAgent * agent)94 Status PreemptionSyncManagerImpl::Initialize(CoordinationServiceAgent* agent) {
95 return Initialize(agent, "sigterm");
96 }
97
Initialize(CoordinationServiceAgent * agent,const std::string & preemption_notifier_type)98 Status PreemptionSyncManagerImpl::Initialize(
99 CoordinationServiceAgent* agent,
100 const std::string& preemption_notifier_type) {
101 TF_ASSIGN_OR_RETURN(Env * env, agent->GetEnv());
102 return Initialize(agent, PreemptionNotifier::CreatePreemptionNotifier(
103 preemption_notifier_type, env));
104 }
105
Initialize(CoordinationServiceAgent * agent,std::unique_ptr<PreemptionNotifier> notifier)106 Status PreemptionSyncManagerImpl::Initialize(
107 CoordinationServiceAgent* agent,
108 std::unique_ptr<PreemptionNotifier> notifier) {
109 TF_ASSIGN_OR_RETURN(Env * env, agent->GetEnv());
110 env_ = env;
111 agent_ = agent;
112 preemption_notifier_ = std::move(notifier);
113 TF_ASSIGN_OR_RETURN(CoordinatedTask own_task, agent->GetOwnTask());
114 const std::string task_name =
115 absl::StrCat("/job:", own_task.job_name(), "/task:", own_task.task_id());
116 current_call_counter_key_ = absl::StrCat(kPreemptionCounterDirKey, task_name);
117
118 /* Listen for preemption notice within this task, then notify coordination
119 * service when death time is within kProtocolDuration.
120 */
121 preemption_notifier_->WillBePreemptedAtAsync(
122 [agent = agent_, task_name](StatusOr<absl::Time> death_time) {
123 if (!death_time.ok()) {
124 // The preemption notifier invokes callback with Cancelled error when
125 // its being destructed.
126 if (errors::IsCancelled(death_time.status())) {
127 LOG(INFO) << "Preemption sync protocol cancelled by notifier: "
128 << death_time.status();
129 } else {
130 LOG(ERROR) << "Error from preemption notifier: "
131 << death_time.status();
132 }
133 return;
134 }
135 // Notify coordination service about preemption notice.
136 const Status s = agent->InsertKeyValue(kPreemptionNoticeKey,
137 absl::FormatTime(*death_time));
138 LOG(INFO) << "Notified coordination service that this task will "
139 "be preempted at "
140 << *death_time << ". Status: " << s;
141 });
142
143 /* Listen for preemption notice (death time) from coordination service, which
144 * triggers the sync protocol.
145 */
146 call_opts_ = agent_->GetKeyValueAsync(
147 kPreemptionNoticeKey,
148 [this, agent = agent_](StatusOr<std::string> status_or_death_time) {
149 if (errors::IsCancelled(status_or_death_time.status())) {
150 // The agent cancels pending GetKeyValue RPCs because of shutdown,
151 // so simply log and return.
152 LOG(INFO) << "Cancelled call to retrive preemption notice.";
153 return;
154 } else if (!status_or_death_time.ok()) {
155 LOG(ERROR) << "Failed to retrieve preemption notice from "
156 "coordination service: "
157 << status_or_death_time.status();
158 // Notify other tasks to not wait at the barrier. Note:
159 // CancelPreemptionBarrier() cannot be used because this may be
160 // triggered after preemption sync manager has been destroyed.
161 agent->CancelBarrierAsync(
162 kPreemptionBarrier, [](const Status& status) {
163 if (!status.ok()) {
164 LOG(ERROR)
165 << "Failed to cancel preemption barrier: " << status;
166 }
167 });
168 return;
169 }
170 std::string err;
171 absl::Time death_time;
172 if (absl::ParseTime(absl::RFC3339_full, *status_or_death_time,
173 &death_time, &err)) {
174 LOG(INFO) << "Received preemption notice with death_time "
175 << death_time;
176 } else {
177 LOG(ERROR) << "Unable to parse preemption notice's death time: "
178 << err;
179 CancelPreemptionBarrier();
180 return;
181 }
182
183 LOG(INFO) << "Received preemption notice with death time: "
184 << death_time;
185
186 // Trigger protocol in a separate thread: compute max call counter.
187 sync_protocol_thread_ = absl::WrapUnique(env_->StartThread(
188 {}, "PreemptionSyncManager_SyncProtocol",
189 std::bind(&PreemptionSyncManagerImpl::ComputeSyncCallCounter, this,
190 death_time)));
191 });
192
193 return Status::OK();
194 }
195
// Runs the preemption sync protocol. It is started on `sync_protocol_thread_`
// by the kPreemptionNoticeKey callback registered in Initialize(). On success
// it sets `preemption_sync_counter_` to one past the largest call counter
// published by any task.
//
// Steps:
//   1. Wait until `death_time` is within kProtocolDuration (or shutdown).
//   2. Publish this task's `call_counter_`.
//   3. Wait at kPreemptionBarrier.
//   4. Read all published counters.
//   5. Take their max.
//   6. Set the sync point to max + 1.
// Every failure is logged and returns early, leaving `preemption_sync_counter_`
// unchanged.
//
// Args:
//   death_time: preemption time parsed from the coordination service notice.
void PreemptionSyncManagerImpl::ComputeSyncCallCounter(absl::Time death_time) {
  // 1. If death time is in the distant future, sleep until there's
  // `kProtocolDuration` left until death time before we begin the protocol.
  // A death time already within kProtocolDuration (or in the past) skips the
  // wait entirely.
  const absl::Duration remaining_time = death_time - absl::Now();
  if (remaining_time > kProtocolDuration) {
    LOG(INFO) << "Will begin preemption sync protocol in " << remaining_time;
    const absl::Duration sleep_time = remaining_time - kProtocolDuration;

    // Returns true early if the destructor notified `shutdown_`.
    if (shutdown_.WaitForNotificationWithTimeout(sleep_time)) {
      // If shutdown is triggered midway, exit thread immediately. Cancel the
      // barrier so other tasks do not wait at it for this task.
      LOG(WARNING)
          << "Shutdown is triggered before preemption sync protocol has begun.";
      CancelPreemptionBarrier();
      return;
    }
  }

  // 2. Send coordination service the task's current call counter. Hold the lock
  // to prevent updates to `call_counter_` until the protocol completes and this
  // function exits, implying that we have decided on a new
  // `preemption_sync_counter_` or the protocol failed. This ensures correctness
  // of the preemption sync protocol.
  // Note: `l` stays held across the barrier wait below (up to
  // kPreemptionBarrierTimeout), so ReachedSyncPoint() blocks for that long.
  mutex_lock l(mu_);
  const Status notified_status = agent_->InsertKeyValue(
      current_call_counter_key_, std::to_string(call_counter_));
  if (!notified_status.ok()) {
    LOG(ERROR) << "Preemption sync failed - could not inform service of "
                  "current call counter: "
               << notified_status;
    CancelPreemptionBarrier();
    return;
  }

  // 3. Impose a barrier to wait until everybody sends their current call
  // counter. The `{}` argument is presumably the set of participating tasks,
  // with empty meaning all tasks — TODO confirm against
  // CoordinationServiceAgent::WaitAtBarrier.
  // From here on, failures do not call CancelPreemptionBarrier(); presumably
  // the barrier has already resolved for every task — confirm.
  const Status barrier_status =
      agent_->WaitAtBarrier(kPreemptionBarrier, kPreemptionBarrierTimeout, {});
  if (!barrier_status.ok()) {
    LOG(ERROR) << "Preemption sync barrier failed: " << barrier_status;
    return;
  }

  // 4. Retrieve every task's current call counter.
  StatusOr<std::vector<KeyValueEntry>> all_counters =
      agent_->GetKeyValueDir(kPreemptionCounterDirKey);
  if (!all_counters.ok()) {
    LOG(ERROR) << "Preemption sync failed - unable to retrieve call counters : "
               << all_counters.status();
    return;
  }

  // 5. Compute the fastest task's call counter.
  // Note: Each task should retrieve the same set of call counters and arrive at
  // the same maximum. We have to calculate this max within each task because
  // coordination service does not provide GetMaxKeyValue().
  // A single unparsable value aborts the protocol for this task.
  int64_t max_counter = kPreemptionSyncUnsetCounter;
  for (const auto& kv : *all_counters) {
    int64_t call_counter;
    if (!absl::SimpleAtoi(kv.value(), &call_counter)) {
      LOG(ERROR) << "Preemption sync failed - failed to parse preemption call "
                    "counter: "
                 << kv.DebugString();
      return;
    }
    max_counter = std::max(max_counter, call_counter);
  }

  // Reached when the directory is empty (or every counter equals the sentinel).
  if (max_counter == kPreemptionSyncUnsetCounter) {
    LOG(ERROR) << "Preemption sync failed - no call counters found.";
    return;
  }

  // 6. Set sync point to be the next possible call counter of the fastest task.
  preemption_sync_counter_ = max_counter + 1;
  LOG(INFO) << "Preemption sync counter is set: " << preemption_sync_counter_;
}
272
CancelPreemptionBarrier()273 void PreemptionSyncManagerImpl::CancelPreemptionBarrier() {
274 agent_->CancelBarrierAsync(kPreemptionBarrier, [](const Status& status) {
275 if (!status.ok()) {
276 LOG(ERROR) << "Failed to cancel preemption barrier: " << status;
277 }
278 });
279 }
280
ReachedSyncPoint(int step_counter)281 bool PreemptionSyncManagerImpl::ReachedSyncPoint(int step_counter) {
282 // Note: if a preemption notice has been received and ComputeSyncCallCounter()
283 // is ongoing , this method will be blocked until it acquires the lock. This
284 // prevents updates to `call_counter_` while `preemption_sync_counter_` is
285 // being computed, which ensures correctness of the preemption sync protocol.
286 mutex_lock l(mu_);
287 // Track current call.
288 call_counter_ = step_counter;
289 VLOG(3) << "Current call counter: " << call_counter_
290 << ", Preemption sync point: " << preemption_sync_counter_;
291
292 // Check if we have reached the sync point.
293 return preemption_sync_counter_ == call_counter_;
294 }
295 } // namespace
CreatePreemptionSyncManager()296 std::unique_ptr<PreemptionSyncManager> CreatePreemptionSyncManager() {
297 return std::make_unique<PreemptionSyncManagerImpl>();
298 }
299 } // namespace tensorflow
300