xref: /aosp_15_r20/external/tensorflow/tensorflow/core/framework/cancellation.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/cancellation.h"
17 
18 #include <forward_list>
19 
20 #include "absl/memory/memory.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/platform/logging.h"
23 #include "tensorflow/core/platform/status.h"
24 
25 namespace tensorflow {
26 
27 const CancellationToken CancellationManager::kInvalidToken = -1;
28 
CancellationManager()29 CancellationManager::CancellationManager()
30     : is_cancelling_(false),
31       is_cancelled_(false),
32       next_cancellation_token_(0) {}
33 
CancellationManager(CancellationManager * parent)34 CancellationManager::CancellationManager(CancellationManager* parent)
35     : is_cancelling_(false), next_cancellation_token_(0), parent_(parent) {
36   is_cancelled_ = parent->RegisterChild(this);
37 }
38 
StartCancel()39 void CancellationManager::StartCancel() {
40   // An "OK" status will not be logged by a callback registered by
41   // RegisterCallbackWithErrorLogging.
42   StartCancelWithStatus(OkStatus());
43 }
44 
StartCancelWithStatus(const Status & status)45 void CancellationManager::StartCancelWithStatus(const Status& status) {
46   gtl::FlatMap<CancellationToken, CallbackConfiguration> callbacks_to_run;
47   std::forward_list<CancellationManager*> children_to_cancel;
48   Notification* cancelled_notification = nullptr;
49   {
50     mutex_lock l(mu_);
51     if (is_cancelled_.load(std::memory_order_relaxed) || is_cancelling_) {
52       return;
53     }
54     is_cancelling_ = true;
55     if (state_) {
56       std::swap(state_->callbacks, callbacks_to_run);
57 
58       // Remove all children from the list of children.
59       CancellationManager* child = state_->first_child;
60       while (child != nullptr) {
61         children_to_cancel.push_front(child);
62         child->is_removed_from_parent_ = true;
63         child = child->next_sibling_;
64       }
65       state_->first_child = nullptr;
66 
67       cancelled_notification = &state_->cancelled_notification;
68     }
69   }
70   // We call these callbacks without holding mu_, so that concurrent
71   // calls to DeregisterCallback, which can happen asynchronously, do
72   // not block. The callbacks remain valid because any concurrent call
73   // to DeregisterCallback will block until the
74   // cancelled_notification_ is notified.
75   for (auto key_and_value : callbacks_to_run) {
76     CallbackConfiguration& config = key_and_value.second;
77     if (!status.ok() && config.log_error) {
78       LOG(WARNING) << "Cancellation callback \"" << config.name
79                    << "\" is triggered due to a "
80                    << (StatusGroup::IsDerived(status) ? "derived" : "root")
81                    << " error: " << status.ToString();
82     }
83     config.callback();
84   }
85   for (CancellationManager* child : children_to_cancel) {
86     child->StartCancelWithStatus(status);
87   }
88   {
89     mutex_lock l(mu_);
90     is_cancelling_ = false;
91     is_cancelled_.store(true, std::memory_order_release);
92   }
93   if (cancelled_notification) {
94     cancelled_notification->Notify();
95   }
96 }
97 
RegisterCallback(CancellationToken token,CancelCallback callback)98 bool CancellationManager::RegisterCallback(CancellationToken token,
99                                            CancelCallback callback) {
100   return RegisterCallbackConfig(
101       token, CallbackConfiguration{callback, "", false});
102 }
103 
RegisterCallbackWithErrorLogging(CancellationToken token,CancelCallback callback,tensorflow::StringPiece callback_name)104 bool CancellationManager::RegisterCallbackWithErrorLogging(
105     CancellationToken token, CancelCallback callback,
106     tensorflow::StringPiece callback_name) {
107   return RegisterCallbackConfig(
108       token, CallbackConfiguration{callback, std::string(callback_name), true});
109 }
110 
RegisterCallbackConfig(CancellationToken token,CallbackConfiguration config)111 bool CancellationManager::RegisterCallbackConfig(CancellationToken token,
112                                                  CallbackConfiguration config) {
113   DCHECK_LT(token, next_cancellation_token_) << "Invalid cancellation token";
114   mutex_lock l(mu_);
115   bool should_register = !is_cancelled_ && !is_cancelling_;
116   if (should_register) {
117     if (!state_) {
118       state_ = absl::make_unique<State>();
119     }
120     std::swap(state_->callbacks[token], config);
121   }
122   return should_register;
123 }
124 
DeregisterCallback(CancellationToken token)125 bool CancellationManager::DeregisterCallback(CancellationToken token) {
126   mu_.lock();
127   if (is_cancelled_) {
128     mu_.unlock();
129     return false;
130   } else if (is_cancelling_) {
131     Notification* cancelled_notification =
132         state_ ? &state_->cancelled_notification : nullptr;
133     mu_.unlock();
134     // Wait for all of the cancellation callbacks to be called. This
135     // wait ensures that the caller of DeregisterCallback does not
136     // return immediately and free objects that may be used in the
137     // execution of any currently pending callbacks in StartCancel.
138     if (cancelled_notification) {
139       cancelled_notification->WaitForNotification();
140     }
141     return false;
142   } else {
143     if (state_) {
144       state_->callbacks.erase(token);
145     }
146     mu_.unlock();
147     return true;
148   }
149 }
150 
RegisterChild(CancellationManager * child)151 bool CancellationManager::RegisterChild(CancellationManager* child) {
152   mutex_lock l(mu_);
153   if (is_cancelled_.load(std::memory_order_relaxed) || is_cancelling_) {
154     child->is_removed_from_parent_ = true;
155     return true;
156   }
157 
158   if (!state_) {
159     state_ = absl::make_unique<State>();
160   }
161 
162   // Push `child` onto the front of the list of children.
163   CancellationManager* current_head = state_->first_child;
164   state_->first_child = child;
165   child->prev_sibling_ = nullptr;
166   child->next_sibling_ = current_head;
167   if (current_head) {
168     current_head->prev_sibling_ = child;
169   }
170 
171   return false;
172 }
173 
DeregisterChild(CancellationManager * child)174 void CancellationManager::DeregisterChild(CancellationManager* child) {
175   DCHECK_EQ(child->parent_, this);
176   Notification* cancelled_notification = nullptr;
177   {
178     mutex_lock l(mu_);
179     if (!child->is_removed_from_parent_) {
180       // Remove the child from this manager's list of children.
181       DCHECK(state_);
182 
183       if (child->prev_sibling_ == nullptr) {
184         // The child was at the head of the list.
185         DCHECK_EQ(state_->first_child, child);
186         state_->first_child = child->next_sibling_;
187       } else {
188         child->prev_sibling_->next_sibling_ = child->next_sibling_;
189       }
190 
191       if (child->next_sibling_ != nullptr) {
192         child->next_sibling_->prev_sibling_ = child->prev_sibling_;
193       }
194 
195       child->is_removed_from_parent_ = true;
196     }
197     if (is_cancelling_) {
198       cancelled_notification = &state_->cancelled_notification;
199     }
200   }
201 
202   // Wait for an ongoing call to StartCancel() to finish. This wait ensures that
203   // the caller of DeregisterChild does not return immediately and free a child
204   // that may currently be being cancelled by StartCancel().
205   if (cancelled_notification) {
206     cancelled_notification->WaitForNotification();
207   }
208 }
209 
TryDeregisterCallback(CancellationToken token)210 bool CancellationManager::TryDeregisterCallback(CancellationToken token) {
211   mutex_lock lock(mu_);
212   if (is_cancelled_ || is_cancelling_) {
213     return false;
214   } else {
215     if (state_) {
216       state_->callbacks.erase(token);
217     }
218     return true;
219   }
220 }
221 
~CancellationManager()222 CancellationManager::~CancellationManager() {
223   if (parent_) {
224     parent_->DeregisterChild(this);
225   }
226   if (state_) {
227     StartCancel();
228   }
229 }
230 
IsCancelling()231 bool CancellationManager::IsCancelling() {
232   mutex_lock lock(mu_);
233   return is_cancelling_;
234 }
235 
RegisterCancellationCallback(CancellationManager * cancellation_manager,CancelCallback callback,std::function<void ()> * deregister_fn)236 Status RegisterCancellationCallback(CancellationManager* cancellation_manager,
237                                     CancelCallback callback,
238                                     std::function<void()>* deregister_fn) {
239   if (cancellation_manager) {
240     CancellationToken token = cancellation_manager->get_cancellation_token();
241     if (!cancellation_manager->RegisterCallback(token, std::move(callback))) {
242       return errors::Cancelled("Operation was cancelled");
243     }
244     *deregister_fn = [cancellation_manager, token]() {
245       cancellation_manager->DeregisterCallback(token);
246     };
247   } else {
248     VLOG(1) << "Cancellation manager is not set. Cancellation callback will "
249                "not be registered.";
250     *deregister_fn = []() {};
251   }
252   return OkStatus();
253 }
254 
255 }  // end namespace tensorflow
256