1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PREEMPTION_PREEMPTION_SYNC_MANAGER_H_ 16 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PREEMPTION_PREEMPTION_SYNC_MANAGER_H_ 17 18 #include <memory> 19 #include <string> 20 21 #include "tensorflow/compiler/xla/pjrt/distributed/client.h" 22 #include "tensorflow/core/distributed_runtime/coordination/coordination_service_agent.h" 23 #include "tensorflow/core/distributed_runtime/preemption/preemption_notifier.h" 24 #include "tensorflow/core/platform/status.h" 25 26 namespace tensorflow { 27 28 // Enables multiple tasks to coordinate on a safe sync point if any of the tasks 29 // receive a preemption notice. Example: tasks agree on a safe checkpointing 30 // step after a preemption notice so that training can resume with minimal 31 // disruption after the preemption. 32 // Note: the sync point can only be set once whenever the first preemption 33 // occurs. 34 // TODO(b/230630494): Add Reset() to allow multiple sync points to be set. 35 class PreemptionSyncManager { 36 public: 37 virtual ~PreemptionSyncManager() = default; 38 39 virtual Status Initialize(xla::DistributedRuntimeClient* client) = 0; 40 virtual Status Initialize(CoordinationServiceAgent* agent) = 0; 41 virtual Status Initialize(CoordinationServiceAgent* agent, 42 const std::string& preemption_notifier_type) = 0; 43 virtual Status Initialize(CoordinationServiceAgent* agent, 44 std::unique_ptr<PreemptionNotifier> notifier) = 0; 45 46 // Check if the synchronized point has been reached. When a task has been 47 // preempted, a safe sync point will be determined by using the fastest task's 48 // next possible sync point, which is then propagated to all tasks via this 49 // method. 50 // Notes: 51 // 1) This must be called during every possible sync point so that the library 52 // is aware of each task's progress. 53 // 2) This assumes that each task begins from the same point. 54 // Internally, it updates a counter to track the last `step_counter` passed 55 // in as argument to record each task's current progress. 56 // Example use case: this can be called during every training step for every 57 // task. Once a preemption notice is received, all tasks will agree on a safe 58 // step to pause training and handle the preemption (e.g. save checkpoint and 59 // exit, or wait for preempted task to restart, then resume training). 60 virtual bool ReachedSyncPoint(int step_counter) = 0; 61 }; 62 63 std::unique_ptr<PreemptionSyncManager> CreatePreemptionSyncManager(); 64 65 } // namespace tensorflow 66 67 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PREEMPTION_PREEMPTION_SYNC_MANAGER_H_ 68