/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/executor.h"

#include <algorithm>
#include <atomic>
#include <memory>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/time/time.h"
#include "absl/types/optional.h"
#include "tensorflow/core/activity_watcher/activity.h"
#include "tensorflow/core/common_runtime/costmodel_manager.h"
#include "tensorflow/core/common_runtime/entry.h"
#include "tensorflow/core/common_runtime/executor_factory.h"
#include "tensorflow/core/common_runtime/graph_view.h"
#include "tensorflow/core/common_runtime/immutable_executor_state.h"
#include "tensorflow/core/common_runtime/pending_counts.h"
#include "tensorflow/core/common_runtime/propagator_state.h"
#include "tensorflow/core/common_runtime/renamed_device.h"
#include "tensorflow/core/common_runtime/simple_propagator_state.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/control_flow.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_segment.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_reference.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/edgeset.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_node_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/manual_constructor.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/context.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/profile_utils/cpu_utils.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/annotated_traceme.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h"
#include "tensorflow/core/profiler/lib/scoped_annotation.h"
#include "tensorflow/core/profiler/lib/traceme_encode.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/util/determinism.h"
#include "tensorflow/core/util/managed_stack_trace.h"
#include "tensorflow/core/util/tensor_slice_reader_cache.h"

namespace tensorflow {

namespace {

// 1-D, 0 element tensor.
static const Tensor* const kEmptyTensor = new Tensor;

// Helper routines for collecting step stats.
namespace nodestats {
inline int64_t NowInNsec() { return EnvTime::NowNanos(); }

void SetScheduled(NodeExecStatsInterface* stats, int64_t micros) {
  if (!stats) return;
  stats->SetScheduled(micros * EnvTime::kMicrosToNanos);
}

void SetAllStart(NodeExecStatsInterface* stats) {
  if (!stats) return;
  stats->RecordExecutorStarted();
}

void SetOpStart(NodeExecStatsInterface* stats) {
  if (!stats) return;
  stats->RecordComputeStarted();
}

void SetOpEnd(NodeExecStatsInterface* stats) {
  if (!stats) return;
  stats->RecordComputeEnded();
}

void SetAllEnd(NodeExecStatsInterface* stats) {
  if (!stats) return;
  stats->RecordExecutorEnded();
}

void SetOutput(NodeExecStatsInterface* stats, int slot, const Tensor* v) {
  if (!stats) return;
  stats->SetOutput(slot, v);
}

void SetMemory(NodeExecStatsInterface* stats, OpKernelContext* ctx) {
  if (!stats) return;
  stats->SetMemory(ctx);
}

}  // namespace nodestats

// Time the execution of kernels (in CPU cycles). Used to dynamically identify
// inexpensive kernels which can be dispatched inline.
struct KernelTimer {
  uint64 start_cycles = profile_utils::CpuUtils::GetCurrentClockCycle();

  uint64 ElapsedCycles() {
    return profile_utils::CpuUtils::GetCurrentClockCycle() - start_cycles;
  }
};
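
// Illustrative usage sketch (not part of the executor logic itself): timing a
// kernel invocation and feeding the result back into the cost model.
//
//   KernelTimer timer;
//   device->Compute(op_kernel, &ctx);
//   kernel_stats->UpdateCostEstimate(item, timer.ElapsedCycles());
//
// See `ProcessSync()` below for the real call site, which additionally samples
// updates for kernels that are currently considered inexpensive.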

// TODO(b/152925936): Re-evaluate these constants with current usage patterns.
typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec;
typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;

class ExecutorImpl : public Executor {
 public:
  explicit ExecutorImpl(const LocalExecutorParams& p) : immutable_state_(p) {}

  Status Initialize(const Graph& graph) {
    TF_RETURN_IF_ERROR(immutable_state_.Initialize(graph));
    kernel_stats_.Initialize(immutable_state_.graph_view());
    return OkStatus();
  }

  void RunAsync(const Args& args, DoneCallback done) override;

 private:
  template <class PropagatorStateType>
  friend class ExecutorState;

  // Stores execution time information about the kernels in an executor's
  // graph.
  class KernelStats {
   public:
    KernelStats() = default;

    void Initialize(const GraphView& gview) {
      is_expensive_.resize(gview.num_nodes());
      cost_estimates_ =
          std::make_unique<std::atomic_uint_fast64_t[]>(gview.num_nodes());
      for (int32_t i = 0; i < gview.num_nodes(); ++i) {
        if (gview.node(i)) {
          is_expensive_[i] =
              gview.node(i)->kernel && gview.node(i)->kernel->IsExpensive();
          cost_estimates_[i] = kInitialCostEstimateCycles;
        }
      }
    }

    // Returns true iff the given node is considered "expensive". The
    // executor uses this flag to optimize graph execution, for example
    // by "inlining" inexpensive kernels.
    bool IsExpensive(const NodeItem& node) const {
      return is_expensive_[node.node_id] &&
             (cost_estimates_[node.node_id].load(std::memory_order_relaxed) >
              kOpIsExpensiveThresholdCycles);
    }
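
    // NOTE: `kInitialCostEstimateCycles` (1e8 cycles) is far above
    // `kOpIsExpensiveThresholdCycles` (8000 cycles), so a kernel whose
    // `IsExpensive()` marker is set starts out treated as expensive and stays
    // that way until measured costs pull its moving-average estimate down to
    // the threshold or below.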

    // Returns the value of kernel->IsExpensive().
    bool HasExpensiveMarker(const NodeItem& node) const {
      return is_expensive_[node.node_id];
    }

    // Updates the dynamic cost estimate, which is used to determine whether
    // the given node is expensive. The new cost estimate is a weighted average
    // of the old cost estimate and the latest cost. We only update cost
    // estimates for kernels for which IsExpensive() returns true.
    void UpdateCostEstimate(const NodeItem& node, uint64 elapsed_cycles) {
      // N.B. Updates to `cost_estimate` are atomic but unlocked. Simultaneous
      // updates may result in one or more updates being ignored. This does not
      // affect correctness but may slow down the update frequency.
      std::atomic_uint_fast64_t& cost_estimate = cost_estimates_[node.node_id];
      auto prev_estimate = cost_estimate.load(std::memory_order_relaxed);

      uint64 new_estimate =
          ((kCostDecay - 1) * prev_estimate + elapsed_cycles) / kCostDecay;

      cost_estimate.store(new_estimate, std::memory_order_relaxed);
    }
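
    // For example, with kCostDecay = 10, a previous estimate of 10000 cycles
    // and a measured cost of 2000 cycles yields
    // (9 * 10000 + 2000) / 10 = 9200 cycles, i.e. an exponential moving
    // average that weights the newest measurement by 1/kCostDecay.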

   private:
    // Initial time (in CPU cycles) we expect an operation to take. Used to
    // determine whether an operation should be placed in a threadpool.
    // Operations start out "expensive".
    static constexpr uint64 kInitialCostEstimateCycles = 100 * 1000 * 1000;
    static constexpr uint64 kOpIsExpensiveThresholdCycles = 8000;
    static constexpr uint64 kCostDecay = 10;

    std::vector<bool> is_expensive_;
    // std::unique_ptr<std::atomic<bool>[]> is_expensive_;
    std::unique_ptr<std::atomic_uint_fast64_t[]> cost_estimates_;
  };

  ImmutableExecutorState immutable_state_;
  KernelStats kernel_stats_;

  TF_DISALLOW_COPY_AND_ASSIGN(ExecutorImpl);
};

// The state associated with one invocation of ExecutorImpl::Run.
//
// ExecutorState dispatches nodes when they become ready, and delegates to an
// instance of `PropagatorStateType` to keep track of how many predecessors of
// a node are still pending.
//
// The template argument `class PropagatorStateType` must define the following
// public members:
// * A type `TaggedNode`, representing a node to be processed, with public
//   members:
//   * `const NodeItem& get_node_item() const`
//   * `bool get_is_dead() const`
// * A type `TaggedNodeReadyQueue`, representing a queue of nodes to be
//   processed, with public members (having the same meanings as in an
//   `std::vector<TaggedNode>`):
//   * `void push_back(const TaggedNode& node)`
//   * `TaggedNode front() const`
//   * `void pop_front()`
//   * `bool empty() const`
// * A type `TaggedNodeSeq`, representing a list of nodes to be scheduled, with
//   public members (having the same meanings as in an
//   `std::vector<TaggedNode>`):
//   * `size_t size() const`
//   * `bool empty() const`
//   * `void clear()`
//   * `const_iterator begin() const`
//   * `const_iterator end() const`
// * A public constructor, `PropagatorStateType(const ImmutableExecutorState&
//   immutable_state, int64 step_id)`.
// * The following public methods:
//   * `void ActivateRoots(gtl::ArraySlice<const NodeItem*> roots,
//     TaggedNodeSeq* ready)`, which creates `TaggedNode` instances for the
//     nodes in `roots` and adds them to `*ready`
//   * `void PropagateOutputs(const TaggedNode& tagged_node, EntryVector*
//     outputs, TaggedNodeSeq* ready)`, which propagates `outputs` from the
//     given `tagged_node` to the destinations of its output edges, and adds
//     any newly runnable nodes to `*ready`
//   * `Entry* GetInputTensors(const TaggedNode& tagged_node) const`, which
//     returns a pointer to the input tensors for the given `tagged_node`
//   * `FrameAndIter GetFrameAndIter(const TaggedNode& tagged_node) const`,
//     which creates a `FrameAndIter` for the given `tagged_node`
//   * `void DumpState()`, which dumps the dynamic state of the executing graph
//   * `void MaybeMarkStarted(const TaggedNode& tagged_node)`, which records
//     that a node has started
//   * `void MaybeMarkCompleted(const TaggedNode& tagged_node)`, which records
//     that a node has completed
//
// See `PropagatorState` in "./propagator_state.h" for an example of a type
// that can be used to instantiate `PropagatorStateType`.
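//
// For orientation only, a skeleton satisfying the contract above would look
// roughly like the following (illustrative sketch, not a real propagator):
//
//   class MyPropagatorState {
//    public:
//     struct TaggedNode {
//       const NodeItem& get_node_item() const;
//       bool get_is_dead() const;
//     };
//     using TaggedNodeSeq = std::vector<TaggedNode>;
//     class TaggedNodeReadyQueue { /* push_back/front/pop_front/empty */ };
//
//     MyPropagatorState(const ImmutableExecutorState& immutable_state,
//                       int64_t step_id);
//     void ActivateRoots(gtl::ArraySlice<const NodeItem*> roots,
//                        TaggedNodeSeq* ready);
//     void PropagateOutputs(const TaggedNode& tagged_node,
//                           EntryVector* outputs, TaggedNodeSeq* ready);
//     Entry* GetInputTensors(const TaggedNode& tagged_node) const;
//     FrameAndIter GetFrameAndIter(const TaggedNode& tagged_node) const;
//     void DumpState();
//     void MaybeMarkStarted(const TaggedNode& tagged_node);
//     void MaybeMarkCompleted(const TaggedNode& tagged_node);
//   };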
template <class PropagatorStateType>
class ExecutorState {
 public:
  ExecutorState(const Executor::Args& args,
                const ImmutableExecutorState& immutable_state_,
                ExecutorImpl::KernelStats* kernel_stats_);
  ~ExecutorState();

  void RunAsync(Executor::DoneCallback done);

 private:
  // Use `TaggedNode` types defined by `PropagatorStateType`.
  typedef typename PropagatorStateType::TaggedNode TaggedNode;
  typedef
      typename PropagatorStateType::TaggedNodeReadyQueue TaggedNodeReadyQueue;
  typedef typename PropagatorStateType::TaggedNodeSeq TaggedNodeSeq;

  struct AsyncState;

  // Process a ready node in current thread.
  void Process(TaggedNode node, int64_t scheduled_nsec);

  void ProcessInline(TaggedNodeReadyQueue* inline_ready,
                     int64_t scheduled_nsec);

  Status ProcessSync(const NodeItem& item, OpKernelContext::Params* params,
                     EntryVector* outputs, NodeExecStatsInterface* stats);
  void ProcessAsync(const NodeItem& item, const OpKernelContext::Params& params,
                    const TaggedNode& tagged_node, Entry* first_input,
                    NodeExecStatsInterface* stats,
                    activity_watcher::ActivityId activity_id);
  void ProcessNoop(NodeExecStatsInterface* stats);
  void ProcessConstTensor(const NodeItem& item, EntryVector* outputs,
                          NodeExecStatsInterface* stats);

  // Before invoking item->kernel, fills in its "inputs".
  Status PrepareInputs(const NodeItem& item, Entry* first_input,
                       TensorValueVec* inputs,
                       AllocatorAttributeVec* input_alloc_attrs,
                       bool* is_input_dead);

  // After item->kernel computation is done, processes its outputs.
  Status ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
                        Entry* outputs, NodeExecStatsInterface* stats);

  // Called after each node finishes. Takes ownership of "stats". Returns true
  // if execution has completed.
  //
  // This method will clear `*ready` before returning.
  bool NodeDone(const Status& s, TaggedNodeSeq* ready,
                NodeExecStatsInterface* stats,
                TaggedNodeReadyQueue* inline_ready);

  // Schedule all the expensive nodes in '*ready', and put all the inexpensive
  // nodes in 'ready' into 'inline_ready'.
  //
  // This method will clear `*ready` before returning.
  //
  // REQUIRES: `!ready->empty()`.
  void ScheduleReady(TaggedNodeSeq* ready, TaggedNodeReadyQueue* inline_ready);

  // A wrapper for runner_ to keep track of the pending queue length. Op
  // execution should dispatch work using this function instead of using
  // runner_ directly.
  template <typename Closure>
  void RunTask(Closure&& c, int sample_rate = 0);

  // Clean up when this executor is done.
  void Finish();
  void ScheduleFinish();

  // Contains the device context assigned by the device at the beginning of a
  // step.
  DeviceContext* device_context_ = nullptr;

  const bool vlog_;  // true if VLOG_IS_ON(1). Used to check vlog cheaply.

  // true if LogMemory::IsEnabled(). Used to check memory enabled cheaply.
  const bool log_memory_;

  int64_t step_id_;
  int64_t start_time_usecs_ = 0;
  // The deadline for the session to complete by. Empty if unspecified.
  absl::optional<absl::Time> deadline_;

  // Maximum number of kernels that can be scheduled inline. If lots of kernels
  // are ready at the same time, scheduling them in one thread can be very
  // slow.
  // TODO(fishx): Make it configurable if necessary.
  static constexpr uint64 kInlineScheduleReadyThreshold = 500;

  // Not owned.
  RendezvousInterface* rendezvous_;
  CollectiveExecutor* collective_executor_ = nullptr;
  SessionState* session_state_;
  string session_handle_;
  const SessionMetadata* session_metadata_ = nullptr;
  TensorStore* tensor_store_;
  // Step-local container.
  ScopedStepContainer* step_container_;
  StepStatsCollectorInterface* const stats_collector_;
  const tracing::EventCollector* const event_collector_;
  Context context_;

  // QUESTION: Make it a checkpoint::TensorSliceReaderCacheWrapper
  // instead of a pointer? (avoids having to delete).
  checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache_;
  CallFrameInterface* call_frame_;
  const ImmutableExecutorState& immutable_state_;
  ExecutorImpl::KernelStats* const kernel_stats_;
  CancellationManager* cancellation_manager_;
  CoordinationServiceAgent* coordination_service_agent_;
  absl::optional<ManagedStackTrace> stack_trace_ = absl::nullopt;
  // If not null, use this device to schedule intra-op operations.
  std::unique_ptr<DeviceBase> user_device_;
  Executor::Args::Runner runner_;
  bool sync_on_finish_;
  const bool run_all_kernels_inline_;

  PropagatorStateType propagator_;

  // Invoked when the execution finishes.
  Executor::DoneCallback done_cb_;

  std::atomic_int_fast32_t num_outstanding_ops_;

  // Available via OpKernelContext to every OpKernel invocation.
  mutex num_deferred_ops_mu_;
  int64_t num_deferred_ops_ TF_GUARDED_BY(num_deferred_ops_mu_) = 0;
  bool finish_when_deferred_ops_done_ TF_GUARDED_BY(num_deferred_ops_mu_) =
      false;

  mutex mu_;
  Status status_ TF_GUARDED_BY(mu_);
};

template <class PropagatorStateType>
ExecutorState<PropagatorStateType>::ExecutorState(
    const Executor::Args& args, const ImmutableExecutorState& immutable_state,
    ExecutorImpl::KernelStats* kernel_stats)
    : vlog_(VLOG_IS_ON(1)),
      log_memory_(LogMemory::IsEnabled()),
      step_id_(args.step_id),
      start_time_usecs_(args.start_time_usecs),
      deadline_(args.deadline),
      rendezvous_(args.rendezvous),
      collective_executor_(args.collective_executor),
      session_state_(args.session_state),
      session_handle_(args.session_handle),
      session_metadata_(immutable_state.params().session_metadata),
      tensor_store_(args.tensor_store),
      step_container_(args.step_container),
      stats_collector_(args.stats_collector),
      event_collector_(
          tracing::GetEventCollector(tracing::EventCategory::kCompute)),
      context_(ContextKind::kThread),
      slice_reader_cache_(new checkpoint::TensorSliceReaderCacheWrapper),
      call_frame_(args.call_frame),
      immutable_state_(immutable_state),
      kernel_stats_(kernel_stats),
      cancellation_manager_(args.cancellation_manager),
      coordination_service_agent_(args.coordination_service_agent),
      stack_trace_(args.stack_trace),
      runner_(args.runner),
      sync_on_finish_(args.sync_on_finish),
      run_all_kernels_inline_(args.run_all_kernels_inline),
      propagator_(immutable_state, step_id_, vlog_),
      num_outstanding_ops_(0) {
  if (args.user_intra_op_threadpool != nullptr) {
    Device* device = immutable_state_.params().device;
    user_device_ = RenamedDevice::NewRenamedDevice(
        device->name(), device, false, false, args.user_intra_op_threadpool);
  }
}

template <class PropagatorStateType>
ExecutorState<PropagatorStateType>::~ExecutorState() {
  if (device_context_) {
    device_context_->Unref();
  }
  delete slice_reader_cache_;
}

template <class PropagatorStateType>
template <typename Closure>
void ExecutorState<PropagatorStateType>::RunTask(Closure&& c, int sample_rate) {
  // Align the atomic variables at 64 bytes to avoid false-sharing, assuming
  // the cacheline size is 64 bytes or smaller.
  alignas(64) static std::atomic<int64_t> num_enqueue_ops{0};
  alignas(64) static std::atomic<int64_t> num_dequeue_ops{0};

  auto n_enqueues = num_enqueue_ops.fetch_add(1, std::memory_order_relaxed);
  // Sample the queue length at most once per `max(16, sample_rate)` enqueue
  // operations. This amortizes the cost of metric updates across many
  // operations.
  if (n_enqueues % std::max(16, sample_rate) == 0) {
    auto n_dequeues = num_dequeue_ops.load(std::memory_order_relaxed);
    metrics::UpdateGraphPendingQueueLength(n_enqueues - n_dequeues);
  }

  // mutable is needed because std::forward<Closure> in the lambda body may
  // move the Closure `c`.
  runner_([c = std::forward<Closure>(c)]() mutable {
    num_dequeue_ops.fetch_add(1, std::memory_order_relaxed);
    std::forward<Closure>(c)();
  });
}
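
// NOTE on `RunTask()` sampling: `ScheduleReady()` passes the size of the batch
// being dispatched as `sample_rate`, so with N ready nodes the
// pending-queue-length metric is refreshed roughly once per max(16, N)
// enqueues rather than once per closure.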

template <class PropagatorStateType>
void ExecutorState<PropagatorStateType>::RunAsync(Executor::DoneCallback done) {
  TaggedNodeSeq ready;

  // Ask the device to fill in the device context map.
  Device* device = immutable_state_.params().device;
  const Status get_context_status =
      device->TryGetDeviceContext(&device_context_);
  if (!get_context_status.ok()) {
    delete this;
    done(get_context_status);
    return;
  }

  // Initialize the ready queue.
  ready.reserve(immutable_state_.root_nodes().size());
  propagator_.ActivateRoots(immutable_state_.root_nodes(), &ready);
  num_outstanding_ops_ = ready.size();
  if (ready.empty()) {
    delete this;
    done(OkStatus());
  } else {
    done_cb_ = std::move(done);
    // Schedule to run all the ready ops in thread pool.
    ScheduleReady(&ready, nullptr);
  }
}

// State kept alive for executing an asynchronous node in another thread.
// NOTE: We need to make a copy of p.inputs and p.input_alloc_attrs for
// asynchronous kernels because OpKernelContext methods like input_type(i) need
// the params to point to valid input vectors. This is not an issue for sync
// kernels because these vectors are kept on the stack.
template <class PropagatorStateType>
struct ExecutorState<PropagatorStateType>::AsyncState {
  AsyncState(const OpKernelContext::Params& p, const TaggedNode& _tagged_node,
             const NodeItem* _item, Entry* _first_input,
             NodeExecStatsInterface* _stats)
      : saved_inputs(p.inputs.begin(), p.inputs.end()),
        saved_input_alloc_attrs(p.input_alloc_attrs.begin(),
                                p.input_alloc_attrs.end()),
        params(p),
        tagged_node(_tagged_node),
        item(_item),
        first_input(_first_input),
        // ParamsButClearingEigenGPUDevice does equivalent of
        // params.eigen_gpu_device = nullptr;
        ctx(ParamsButClearingEigenGPUDevice(&params), item->num_outputs),
        stats(_stats) {
    params.inputs = saved_inputs;
    params.input_alloc_attrs = saved_input_alloc_attrs;
  }

  TensorValueVec saved_inputs;
  AllocatorAttributeVec saved_input_alloc_attrs;
  OpKernelContext::Params params;
  TaggedNode tagged_node;
  const NodeItem* item;
  Entry* first_input;
  OpKernelContext ctx;
  NodeExecStatsInterface* stats;

 private:
  OpKernelContext::Params* ParamsButClearingEigenGPUDevice(
      OpKernelContext::Params* p) {
    // Ensure OpKernelContext constructor will make a new eigen GPU device if
    // necessary.
    p->eigen_gpu_device = nullptr;  // Force allocation
    return p;
  }
};
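
// NOTE: `AsyncState` is heap-allocated in `ProcessAsync()` and deleted by the
// `done` callback after the asynchronous kernel's outputs have been
// propagated.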

// Returns true if `item` might be traced by the given trace and event
// collectors. Returns false only if `item` definitely will not be traced.
bool MightTrace(const tracing::EventCollector* event_collector,
                bool is_expensive) {
  // Tracing will only be enabled if either `event_collector` is non-null,
  // or `trace_collector` is non-null and enabled for this particular kernel.
  // Although `profiler::TraceMe`, `profiler::ScopedAnnotation`, and
  // `tracing::ScopedRegion` check subsets of these properties internally in
  // their constructors, the cost of passing the necessary arguments to them
  // can be significant, so we avoid constructing them in the common case (when
  // we know they will not be used).
  if (event_collector != nullptr) {
    return true;
  }

  if (profiler::ScopedAnnotation::IsEnabled()) return true;

  return profiler::TraceMe::Active(profiler::GetTFTraceMeLevel(is_expensive));
}

template <class PropagatorStateType>
Status ExecutorState<PropagatorStateType>::ProcessSync(
    const NodeItem& item, OpKernelContext::Params* params, EntryVector* outputs,
    NodeExecStatsInterface* stats) {
  Status s;
  OpKernelContext ctx(params, item.num_outputs);
  nodestats::SetOpStart(stats);

  OpKernel* op_kernel = item.kernel;
  Device* device = immutable_state_.params().device;
  const bool is_expensive = kernel_stats_->IsExpensive(item);

  if (TF_PREDICT_FALSE(MightTrace(event_collector_, is_expensive))) {
    tracing::ScopedRegion region(tracing::EventCategory::kCompute,
                                 op_kernel->name_view());
    profiler::AnnotatedTraceMe activity(
        [op_kernel, &ctx] {
          return op_kernel->TraceString(
              ctx, /*verbose=*/profiler::TfOpDetailsEnabled());
        },
        profiler::GetTFTraceMeLevel(is_expensive));
    device->Compute(op_kernel, &ctx);
  } else if (kernel_stats_->HasExpensiveMarker(item)) {
    KernelTimer timer;
    device->Compute(op_kernel, &ctx);
    // For expensive kernels, always update the cost estimate. For inexpensive
    // kernels, update the cost estimate with ~1/16 probability. This assumes
    // that the last 4 bits of the CPU cycle count are uniformly distributed.
    constexpr int kKernelExecutionTrackingInvocationSkipCount = 16;
    if (is_expensive ||
        timer.start_cycles % kKernelExecutionTrackingInvocationSkipCount == 0) {
      kernel_stats_->UpdateCostEstimate(item, timer.ElapsedCycles());
    }
  } else {
    device->Compute(op_kernel, &ctx);
  }
  nodestats::SetOpEnd(stats);
  if (outputs->size() < item.num_outputs) outputs->resize(item.num_outputs);
  s = ProcessOutputs(item, &ctx, outputs->data(), stats);
  nodestats::SetMemory(stats, &ctx);
  return s;
}

template <class PropagatorStateType>
void ExecutorState<PropagatorStateType>::ProcessAsync(
    const NodeItem& item, const OpKernelContext::Params& params,
    const TaggedNode& tagged_node, Entry* first_input,
    NodeExecStatsInterface* stats, activity_watcher::ActivityId activity_id) {
  AsyncOpKernel* async_kernel = item.kernel->AsAsync();
  DCHECK(async_kernel != nullptr);
  AsyncState* state =
      new AsyncState(params, tagged_node, &item, first_input, stats);

  auto done = [this, state, activity_id]() {
    Device* device = immutable_state_.params().device;
    NodeExecStatsInterface* stats = state->stats;  // Shorthand
    Entry* first_input = state->first_input;       // Shorthand

    nodestats::SetOpEnd(stats);
    EntryVector outputs(state->item->num_outputs);
    Status s = ProcessOutputs(*state->item, &state->ctx, outputs.data(), stats);
    nodestats::SetMemory(stats, &state->ctx);
    if (vlog_) {
      VLOG(2) << "Async kernel done: " << state->item->node_id << " step "
              << step_id_ << " " << SummarizeNodeDef(state->item->kernel->def())
              << (state->tagged_node.get_is_dead() ? " is dead" : "")
              << " device: " << device->name();
    }

    // Clears inputs.
    const int num_inputs = state->item->num_inputs;
    for (int i = 0; i < num_inputs; ++i) {
      (first_input + i)->ClearVal();
    }
    propagator_.MaybeMarkCompleted(state->tagged_node);
    activity_watcher::ActivityEnd(activity_id);
    TaggedNodeSeq ready;
    if (s.ok()) {
      propagator_.PropagateOutputs(state->tagged_node, &outputs, &ready);
    }
    outputs.clear();
    const bool completed = NodeDone(s, &ready, stats, nullptr);
    delete state;
    if (completed) ScheduleFinish();
  };
  nodestats::SetOpStart(stats);
  {
    profiler::AnnotatedTraceMe activity(
        [async_kernel, state] {
          return async_kernel->TraceString(
              state->ctx, /*verbose=*/profiler::TfOpDetailsEnabled());
        },
        profiler::GetTFTraceMeLevel(kernel_stats_->IsExpensive(item)));
    immutable_state_.params().device->ComputeAsync(async_kernel, &state->ctx,
                                                   std::move(done));
  }
}

template <class PropagatorStateType>
void ExecutorState<PropagatorStateType>::ProcessNoop(
    NodeExecStatsInterface* stats) {
  nodestats::SetOpStart(stats);
  nodestats::SetOpEnd(stats);
}

template <class PropagatorStateType>
void ExecutorState<PropagatorStateType>::ProcessConstTensor(
    const NodeItem& item, EntryVector* outputs, NodeExecStatsInterface* stats) {
  nodestats::SetOpStart(stats);
  nodestats::SetOpEnd(stats);
  Entry& output = (*outputs)[0];
  output.state = Entry::State::HAS_CONST_TENSOR;
  output.const_tensor = item.const_tensor;
  output.alloc_attr = item.output_attrs()[0];
}

template <class PropagatorStateType>
void ExecutorState<PropagatorStateType>::Process(TaggedNode tagged_node,
                                                 int64_t scheduled_nsec) {
  profiler::TraceMeConsumer activity(
      // From TraceMeProducer in DirectSession::RunInternal,
      // GraphMgr::ExecuteAsync, or FunctionLibraryRuntime::Run.
      [&] {
        // NOTE: This tracing uses the iteration number from the first tagged
        // node that executes during this call to `Process()`. In principle,
        // subsequent nodes could have different values of `iter_num` that
        // will not be traced.
        return profiler::TraceMeEncode(
            "ExecutorState::Process",
            {{"id", step_id_}, {"iter_num", tagged_node.get_iter_num()}});
      },
      profiler::ContextType::kTfExecutor, step_id_,
      profiler::TraceMeLevel::kInfo);
  TaggedNodeReadyQueue inline_ready;
  inline_ready.push_back(tagged_node);
  return ProcessInline(&inline_ready, scheduled_nsec);
}

template <class PropagatorStateType>
void ExecutorState<PropagatorStateType>::ProcessInline(
    TaggedNodeReadyQueue* inline_ready, int64_t scheduled_nsec) {
  WithContext wc(context_);
  TaggedNodeSeq ready;

  // Parameters passed to OpKernel::Compute.
  TensorValueVec inputs;
  AllocatorAttributeVec input_alloc_attrs;

  OpKernelContext::Params params;
  params.step_id = step_id_;
  // Override the device's threadpool if the user provides an
  // intra_op_threadpool.
  Device* device = immutable_state_.params().device;
  if (user_device_) {
    params.device = user_device_.get();
  } else {
    params.device = device;
  }
  params.start_time_usecs = start_time_usecs_;
  params.deadline = deadline_;
  params.log_memory = log_memory_;
  params.rendezvous = rendezvous_;
  params.collective_executor = collective_executor_;
  params.session_state = session_state_;
  params.session_handle = session_handle_;
  params.session_metadata = session_metadata_;
  params.tensor_store = tensor_store_;
  params.cancellation_manager = cancellation_manager_;
  params.coordination_service_agent = coordination_service_agent_;
  params.stack_trace = stack_trace_;
  params.call_frame = call_frame_;
  params.function_library = immutable_state_.params().function_library;
  params.resource_manager = device->resource_manager();
  params.step_container = step_container_;
  params.slice_reader_cache = slice_reader_cache_;
  params.runner = &runner_;
  params.run_all_kernels_inline = run_all_kernels_inline_;
  params.stats_collector = stats_collector_;
  params.inc_num_deferred_ops_function = [this]() {
    mutex_lock lock(num_deferred_ops_mu_);
    num_deferred_ops_++;
  };
  params.dec_num_deferred_ops_function = [this]() {
    bool finish_when_deferred_ops_done = false;
    {
      mutex_lock lock(num_deferred_ops_mu_);
      num_deferred_ops_--;
      if (num_deferred_ops_ == 0) {
        finish_when_deferred_ops_done = finish_when_deferred_ops_done_;
      }
    }
    // Invoke Finish if the graph processing has completed. Finish is always
    // called exactly once per ExecutorState, either here if there are any
    // deferred ops, or in ScheduleFinish if there aren't any deferred ops.
    if (finish_when_deferred_ops_done) Finish();
  };

  // Set the device_context for this device, if it exists.
  params.op_device_context = device_context_;

  Status s;
  NodeExecStatsInterface* stats = nullptr;

  EntryVector outputs(1);

  bool completed = false;
  while (!inline_ready->empty()) {
    TaggedNode tagged_node = inline_ready->front();
    inline_ready->pop_front();
    const NodeItem& item = tagged_node.get_node_item();
    const int id = item.node_id;

    propagator_.MaybeMarkStarted(tagged_node);
    const activity_watcher::ActivityId activity_id =
        activity_watcher::ActivityStart(
            [&]() {
              return std::make_unique<activity_watcher::Activity>(
                  "ExecutorState::Process",
                  activity_watcher::ActivityCategory::kMisc,
                  activity_watcher::Activity::Attributes{
                      {"node_name", item.kernel->def().name()},
                      {"op", item.kernel->def().op()},
                      {"iter_num", absl::StrCat(tagged_node.get_iter_num())},
                      {"step_id", absl::StrCat(params.step_id)},
                      {"node_id", absl::StrCat(id)},
                      {"device", device->name()},
                  });
            },
            /*level=*/2);

    params.track_allocations = false;
    stats = nullptr;
    if (stats_collector_ && !tagged_node.get_is_dead()) {
      stats = stats_collector_->CreateNodeExecStats(&item.kernel->def());
      // Track allocations if and only if we are collecting statistics, and
      // the `stats` object expects allocations to be tracked.
      params.track_allocations = stats ? stats->TrackAllocations() : false;
      nodestats::SetScheduled(stats, scheduled_nsec);
      nodestats::SetAllStart(stats);
    }

    if (vlog_) {
      VLOG(1) << "Process node: " << id << " step " << params.step_id << " "
              << SummarizeNodeDef(item.kernel->def())
              << (tagged_node.get_is_dead() ? " is dead" : "")
              << " device: " << device->name();
    }

    Entry* first_input = propagator_.GetInputTensors(tagged_node);

    // Only execute this node if it is not dead or it is a send/recv
    // transfer node. For transfer nodes, we need to propagate the "dead"
    // bit even when the node is dead.
    bool launched_asynchronously = false;
    if (tagged_node.get_is_dead() && !item.is_transfer_node) {
      if (outputs.size() < item.num_outputs) outputs.resize(item.num_outputs);
    } else if (TF_PREDICT_FALSE(item.is_noop)) {
      ProcessNoop(stats);
    } else if (item.const_tensor != nullptr && !params.track_allocations) {
      ProcessConstTensor(item, &outputs, stats);
    } else {
      // Prepares inputs.
      bool is_input_dead = false;
      s = PrepareInputs(item, first_input, &inputs, &input_alloc_attrs,
                        &is_input_dead);
      if (!s.ok()) {
        // Clear inputs.
        const int num_inputs = item.num_inputs;
        for (int i = 0; i < num_inputs; ++i) {
          (first_input + i)->ClearVal();
        }
        propagator_.MaybeMarkCompleted(tagged_node);
        activity_watcher::ActivityEnd(activity_id);
        // Continue to process the nodes in 'inline_ready'.
        completed = NodeDone(s, &ready, stats, inline_ready);
        continue;
      }

      // Set up compute params.
      params.op_kernel = item.kernel;
      params.frame_iter = propagator_.GetFrameAndIter(tagged_node);
      params.is_input_dead = is_input_dead;
      params.output_attr_array = item.output_attrs();
      params.forward_from_array = item.forward_from();
      params.outputs_required_array = item.outputs_required.get();
      params.inputs = inputs;
      params.input_alloc_attrs = input_alloc_attrs;

      if (item.kernel_is_async) {
        ProcessAsync(item, params, tagged_node, first_input, stats,
                     activity_id);
        launched_asynchronously = true;
      } else {
        s = ProcessSync(item, &params, &outputs, stats);
      }
    }

    if (!launched_asynchronously) {
      if (vlog_) {
        VLOG(2) << "Synchronous kernel done: " << id << " step "
                << params.step_id << " " << SummarizeNodeDef(item.kernel->def())
                << (tagged_node.get_is_dead() ? " is dead: " : "")
                << " device: " << device->name();
      }

      // Clears inputs.
      const int num_inputs = item.num_inputs;
      for (int i = 0; i < num_inputs; ++i) {
        (first_input + i)->ClearVal();
      }
      propagator_.MaybeMarkCompleted(tagged_node);
      activity_watcher::ActivityEnd(activity_id);
      // Propagates outputs.
      if (s.ok()) {
        propagator_.PropagateOutputs(tagged_node, &outputs, &ready);
      }

      // Clear outputs without deallocating the `outputs` vector.
      const int num_outputs = item.num_outputs;
      for (int i = 0; i < num_outputs; ++i) {
        outputs[i].ClearVal();
      }

      if (stats) {
        scheduled_nsec = nodestats::NowInNsec();
      }
      // Postprocess.
      completed = NodeDone(s, &ready, stats, inline_ready);
    }
  }  // while !inline_ready.empty()

  // This thread of computation is done if completed = true.
  if (completed) ScheduleFinish();
}

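// Summary of the input handling below: const tensors and plain values are
// passed through; ref tensors are either handed to the kernel as refs (when
// the kernel declares a ref input) or dereferenced into a value under the ref
// mutex; missing values are only tolerated for merge and transfer nodes.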
template <class PropagatorStateType>
Status ExecutorState<PropagatorStateType>::PrepareInputs(
    const NodeItem& item, Entry* first_input, TensorValueVec* inputs,
    AllocatorAttributeVec* input_alloc_attrs, bool* is_input_dead) {
  inputs->resize(item.num_inputs);
  input_alloc_attrs->resize(item.num_inputs);

  *is_input_dead = false;

  for (int i = 0; i < item.num_inputs; ++i) {
    const bool expect_ref = TF_PREDICT_FALSE(item.is_any_input_ref_typed) &&
                            IsRefType(item.input_type(i));
    Entry* entry = first_input + i;
    (*input_alloc_attrs)[i] = entry->alloc_attr;

    // i-th input.
    TensorValue* inp = &(*inputs)[i];

    switch (entry->state) {
      case Entry::State::NO_VALUE: {
        // Only merge and transfer nodes can have no-value inputs.
        inp->mutex_if_ref = nullptr;
        if (item.is_merge) {
          inp->tensor = nullptr;
        } else {
          DCHECK(item.is_transfer_node)
              << item.kernel->name() << " - input " << i;
          entry->state = Entry::State::HAS_CONST_TENSOR;
          entry->const_tensor = kEmptyTensor;
          // NOTE(mrry): This `const_cast` is necessary because `TensorValue`
          // stores a non-const `Tensor*`, and relies on the `OpKernelContext`
          // accessors making dynamic checks that prevent using an immutable
          // tensor as a mutable tensor.
          inp->tensor = const_cast<Tensor*>(kEmptyTensor);
          *is_input_dead = true;
        }
        break;
      }

      case Entry::State::HAS_VALUE: {
        if (TF_PREDICT_FALSE(expect_ref)) {
          return AttachDef(
              errors::InvalidArgument(i, "-th input expects a ref type"),
              item.kernel->def());
        }
        inp->mutex_if_ref = nullptr;
        inp->tensor = entry->val.get();
        break;
      }

      case Entry::State::HAS_CONST_TENSOR: {
        if (TF_PREDICT_FALSE(expect_ref)) {
          return AttachDef(
              errors::InvalidArgument(i, "-th input expects a ref type"),
              item.kernel->def());
        }
        // NOTE(mrry): This `const_cast` is necessary because `TensorValue`
        // stores a non-const `Tensor*`, and relies on the `OpKernelContext`
        // accessors making dynamic checks that prevent using an immutable
        // tensor as a mutable tensor.
        inp->mutex_if_ref = nullptr;
        inp->tensor = const_cast<Tensor*>(entry->const_tensor);
        break;
      }

      case Entry::State::HAS_REF_TENSOR: {
        {
          tf_shared_lock ml(*entry->ref_tensor.mu);
          if (TF_PREDICT_FALSE(!entry->ref_tensor.tensor->IsInitialized() &&
                               !item.is_initialization_op)) {
            return AttachDef(errors::FailedPrecondition(
                                 "Attempting to use uninitialized value ",
                                 item.kernel->requested_input(i)),
                             item.kernel->def());
          }
        }

        if (expect_ref) {
          inp->mutex_if_ref = entry->ref_tensor.mu;
          inp->tensor = entry->ref_tensor.tensor;
        } else {
          // Automatically deref the tensor ref when the op expects a
          // tensor but is given a ref to a tensor. Need to deref it
          // under the mutex.
          {
            mutex* ref_mu = entry->ref_tensor.mu;
            Tensor* ref_tensor = entry->ref_tensor.tensor;
            tf_shared_lock l(*ref_mu);
            entry->val.Init(*ref_tensor);
          }
          entry->state = Entry::State::HAS_VALUE;

          inp->mutex_if_ref = nullptr;
          inp->tensor = entry->val.get();
          // The dtype of entry->ref_tensor.tensor could have been changed by
          // another operation that ran after the operation that "produced" it
          // executed, so re-validate that the type of the dereferenced tensor
          // matches the expected input type.
          if (TF_PREDICT_FALSE(item.input_type(i) != inp->tensor->dtype())) {
            return AttachDef(
                errors::InvalidArgument(
                    i, "-th input expects type ",
                    DataTypeString(item.input_type(i)),
                    " but automatically dereferenced input tensor has type ",
                    DataTypeString(inp->tensor->dtype())),
                item.kernel->def());
          }
        }
        break;
      }
    }
  }
  return OkStatus();
}

template <class PropagatorStateType>
Status ExecutorState<PropagatorStateType>::ProcessOutputs(
    const NodeItem& item, OpKernelContext* ctx, Entry* outputs,
    NodeExecStatsInterface* stats) {
  Status s = ctx->status();
  if (!s.ok()) {
    s = AttachDef(s, item.kernel->def());
    // TODO(misard) Replace with a finer-grain enabling flag once we
    // add better optional debugging support.
    if (vlog_ && VLOG_IS_ON(1)) {
      LOG(WARNING) << this << " Compute status: " << s;
    }
    if (s.code() == error::RESOURCE_EXHAUSTED) {
      if (stats_collector_) {
        string err = stats_collector_->ReportAllocsOnResourceExhausted(
            s.error_message());
        s = errors::CreateWithUpdatedMessage(
            s, strings::StrCat(s.error_message(), err));
      } else {
        s = errors::CreateWithUpdatedMessage(
            s,
            strings::StrCat(
                s.error_message(),
                "\nHint: If you want to see a list of allocated tensors when "
                "OOM happens, add report_tensor_allocations_upon_oom "
                "to RunOptions for current allocation info. This isn't "
                "available when running in Eager mode.\n"));
      }
    } else if (s.code() == error::UNAVAILABLE &&
               !item.is_distributed_communication) {
      s = errors::ReplaceErrorFromNonCommunicationOps(s, item.kernel->name());
    }
    return s;
  }

  for (int i = 0; i < item.num_outputs; ++i) {
    const TensorValue val = ctx->release_output(i);
    Entry* out = &outputs[i];
    DCHECK(out->state == Entry::State::NO_VALUE);

    if (val.tensor == nullptr) {
      // Unless it's a Switch or a Recv, or the executor has marked the output
      // as not required, the node must produce a tensor value at the i-th
      // output.
      if (!(item.is_recv_or_switch ||
            (item.outputs_required && !item.outputs_required[i]))) {
        s.Update(errors::Internal("Missing ", i, "-th output from ",
                                  FormatNodeDefForError(item.kernel->def())));
      }
    } else {
      // Set the allocator attributes of the output entry.
      out->alloc_attr = ctx->output_alloc_attr(i);

      // Sanity check of output tensor types. We need to inspect this safely
      // as we are in the tensor buffer.
      DataType dtype = val.dtype_safe();
      if (dtype == item.output_type(i)) {
        if (stats && val.tensor->IsInitialized()) {
          nodestats::SetOutput(stats, i, val.tensor);
        }
        if (val.is_ref()) {
          out->state = Entry::State::HAS_REF_TENSOR;
          out->ref_tensor.tensor = val.tensor;
          out->ref_tensor.mu = val.mutex_if_ref;
          if (log_memory_) {
            Tensor to_log;
            {
              // Dereference the tensor under the lock.
              tf_shared_lock l(*out->ref_tensor.mu);
              to_log = *out->ref_tensor.tensor;
            }
            LogMemory::RecordTensorOutput(ctx->op_kernel().name(),
                                          ctx->step_id(), i, to_log);
          }
        } else {
          // NOTE that std::move is used here, so val.tensor goes to
          // uninitialized state (val.tensor->IsInitialized returns false).
          out->state = Entry::State::HAS_VALUE;
          out->val.Init(std::move(*val.tensor));
          if (log_memory_) {
            LogMemory::RecordTensorOutput(ctx->op_kernel().name(),
                                          ctx->step_id(), i, *out->val);
          }
        }
      } else {
        s.Update(
            errors::Internal("Output ", i, " of type ", DataTypeString(dtype),
                             " does not match declared output type ",
                             DataTypeString(item.output_type(i)), " for node ",
                             FormatNodeDefForError(item.kernel->def())));
      }
    }
    if (!val.is_ref()) {
      // If OpKernelContext returns outputs via pass-by-value, we
      // don't need this trouble.
      delete val.tensor;
    }
  }
  return s;
}

template <class PropagatorStateType>
bool ExecutorState<PropagatorStateType>::NodeDone(
    const Status& s, TaggedNodeSeq* ready, NodeExecStatsInterface* stats,
    TaggedNodeReadyQueue* inline_ready) {
  if (stats) {
    nodestats::SetAllEnd(stats);
    DCHECK_NE(stats_collector_, nullptr);
    stats->Done(immutable_state_.params().device->name());
  }

  if (TF_PREDICT_TRUE(s.ok())) {
    const size_t ready_size = ready->size();
    if (ready_size == 0) {
      return num_outstanding_ops_.fetch_sub(1) == 1;
    } else {
      // NOTE: Avoid touching the atomic counter if only one node becomes
      // ready.
      if (ready_size > 1) {
        num_outstanding_ops_.fetch_add(ready_size - 1,
                                       std::memory_order_relaxed);
      }

      // Schedule the ready nodes in 'ready'.
      ScheduleReady(ready, inline_ready);

      return false;
    }
  } else {
    bool abort_run = false;
    Status maybe_derived_s(s);

    // Some error happened. This thread of computation is done.
    {
      mutex_lock l(mu_);
      if (status_.ok()) {
        // If this is the first node to fail in this run, we are responsible
        // for aborting all other execution in the step.
        abort_run = true;

        // If execution has been cancelled, mark cancelled or aborted errors as
        // being derived. Note that the original node that fails might also
        // trigger cancellation, and here we make sure the original error is
        // exposed to users and not buried as a derived error.
        if (cancellation_manager_ && cancellation_manager_->IsCancelled() &&
            (errors::IsCancelled(s) || errors::IsAborted(s))) {
          status_ = StatusGroup::MakeDerived(s);
          maybe_derived_s = status_;
        } else {
          status_ = s;
        }
      }
    }

    if (abort_run) {
      TRACEPRINTF("StartAbort: %s", s.ToString().c_str());
      if (cancellation_manager_) {
        // Only log when the abort happens during the actual run time.
        // Use VLOG instead of LOG(warning) because error status is expected
        // when the executor is run under the grappler optimization phase or
        // when iterating through a tf.data input pipeline.
        VLOG(1) << "[" << immutable_state_.params().device->name()
                << "] Executor start aborting: " << s;
      }

      if (rendezvous_) {
        rendezvous_->StartAbort(s);
      }
      if (cancellation_manager_) {
        cancellation_manager_->StartCancelWithStatus(maybe_derived_s);
      } else if (collective_executor_) {
        // If there's a cancellation_manager_, collective ops abort
        // collective_executor_ upon cancellation; otherwise we need to abort
        // here.
        collective_executor_->StartAbort(s);
      }
    }

    return num_outstanding_ops_.fetch_sub(1) == 1;
  }
}

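// Scheduling policy implemented below: inexpensive (or dead) nodes are
// appended to `inline_ready` and processed on the current thread, while
// expensive nodes are dispatched via `RunTask`, keeping at most one expensive
// node for inline execution when nothing else is queued. If more than
// `kInlineScheduleReadyThreshold` expensive nodes are ready at once, they are
// re-dispatched in chunks from child closures so that scheduling work is not
// serialized on one thread.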
template <class PropagatorStateType>
void ExecutorState<PropagatorStateType>::ScheduleReady(
    TaggedNodeSeq* ready, TaggedNodeReadyQueue* inline_ready) {
  profiler::TraceMe activity(
      [&]() {
        return strings::StrCat(
            "ExecutorState::ScheduleReady#",
            "ready_size=", (ready == nullptr ? -1 : ready->size()),
            ",inline_ready_size=",
            (inline_ready == nullptr ? -1 : inline_ready->size()), "#");
      },
      profiler::GetTFTraceMeLevel(/*is_expensive=*/false));
  DCHECK(!ready->empty());

  int64_t scheduled_nsec = 0;
  if (stats_collector_) {
    scheduled_nsec = nodestats::NowInNsec();
  }

  if (run_all_kernels_inline_) {
    if (inline_ready == nullptr) {
      // Schedule all ready kernels from a single closure. This ensures that,
      // regardless of the `runner_` implementation, all kernels will run
      // sequentially on the same thread, and thread wakeup overhead and
      // executor mutex contention will be minimized.
      RunTask([this, ready = std::move(*ready), scheduled_nsec]() {
        for (auto& tagged_node : ready) {
          Process(tagged_node, scheduled_nsec);
        }
      });
    } else {
      for (auto& tagged_node : *ready) {
        inline_ready->push_back(tagged_node);
      }
    }
  } else {
    const TaggedNode* curr_expensive_node = nullptr;
    TaggedNodeSeq expensive_nodes;
    if (inline_ready == nullptr) {
      // Schedule to run all the ready ops in thread pool.
      for (auto& tagged_node : *ready) {
        RunTask([=]() { Process(tagged_node, scheduled_nsec); },
                /*sample_rate=*/ready->size());
      }
    } else {
      for (auto& tagged_node : *ready) {
        const NodeItem& item = *tagged_node.node_item;
        if (tagged_node.get_is_dead() || !kernel_stats_->IsExpensive(item)) {
          // Inline this inexpensive node.
          inline_ready->push_back(tagged_node);
        } else {
          if (curr_expensive_node) {
            expensive_nodes.push_back(*curr_expensive_node);
          }
          curr_expensive_node = &tagged_node;
        }
      }
    }
    if (curr_expensive_node) {
      if (inline_ready->empty()) {
        inline_ready->push_back(*curr_expensive_node);
      } else {
        // There are inline nodes to run already. We dispatch this expensive
        // node to another thread.
        expensive_nodes.push_back(*curr_expensive_node);
      }
    }
    if (!expensive_nodes.empty()) {
      if (expensive_nodes.size() < kInlineScheduleReadyThreshold) {
        for (auto& tagged_node : expensive_nodes) {
          RunTask(std::bind(&ExecutorState::Process, this, tagged_node,
                            scheduled_nsec),
                  /*sample_rate=*/expensive_nodes.size());
        }
      } else {
        // There are too many ready expensive nodes. Schedule them in child
        // threads.
        // TODO(fishx): Apply the same optimization to cheap ops as well since
        // executing lots of cheap ops in one thread can potentially be the
        // bottleneck as well.
        auto it = expensive_nodes.begin();
        while (it < expensive_nodes.end()) {
          auto end = it;
          std::advance(end, kInlineScheduleReadyThreshold);
          if (end > expensive_nodes.end()) {
            end = expensive_nodes.end();
          }
          TaggedNodeSeq ready_chunk{it, end};
          RunTask(
              [this, ready_chunk = std::move(ready_chunk), scheduled_nsec]() {
                profiler::TraceMe activity(
                    [&]() {
                      return strings::StrCat(
                          "ExecutorState::ScheduleReady::"
                          "ChildThreadExpensiveNodes#",
                          "ready_chunk_size=", ready_chunk.size(), "#");
                    },
                    profiler::GetTFTraceMeLevel(/*is_expensive=*/false));
                for (auto& tagged_node : ready_chunk) {
                  RunTask(std::bind(&ExecutorState::Process, this, tagged_node,
                                    scheduled_nsec),
                          /*sample_rate=*/ready_chunk.size());
                }
              });
          it = end;
        }
      }
    }
  }
  ready->clear();
}

template <class PropagatorStateType>
void ExecutorState<PropagatorStateType>::ScheduleFinish() {
  // Checks conditions to decide whether Finish() needs to be invoked. If there
  // are in-flight deferred ops, wait until `num_deferred_ops_` reaches 0 to
  // invoke Finish(). Otherwise, invoke Finish() directly.
  // Note that it is critical that the ScheduleFinish / Finish codepath does
  // not block, otherwise we might deadlock. See b/124523000 for details.
  {
    mutex_lock lock(num_deferred_ops_mu_);
    if (num_deferred_ops_ > 0) {
      finish_when_deferred_ops_done_ = true;
      return;
    }
  }
  // Finish is always called exactly once per ExecutorState, either here if
  // there aren't any deferred ops, or in the dec_num_deferred_ops_function if
  // there are deferred ops.
  Finish();
}

template <class PropagatorStateType>
void ExecutorState<PropagatorStateType>::Finish() {
  mu_.lock();
  auto status = status_;
  auto done_cb = std::move(done_cb_);
  auto runner = std::move(runner_);
  mu_.unlock();
  int64_t step_id = step_id_;
  CHECK(done_cb != nullptr);
  Device* device = immutable_state_.params().device;

  if (vlog_ && !status.ok() && VLOG_IS_ON(1)) {
    // Logs verbose information about the current state of active and pending
    // nodes in the propagator.
    propagator_.DumpState();
  }

  // There are several potential race conditions below. To name a few:
  // 1. Even if the device's status is OK at the precise moment when
  // num_deferred_ops_ reaches 0, it could go bad before
  // device->RefreshStatus() is called below, caused by work enqueued onto the
  // same device by other concurrent ExecutorState objects.
  // 2. Some implementations of Device::RefreshStatus, such as
  // XlaDevice::RefreshStatus, may be inherently racy because it releases the
  // device mutex after a stream pointer is acquired and before the stream is
  // queried for status.
  // 3. It's the same for some implementations of Device::Sync, such as
  // XlaDevice::Sync.
  //
  // However, these race conditions are acceptable because a stream (and
  // therefore an XlaDevice) can only go from OK to not-OK, never the opposite,
  // which means we will at worst report errors when there isn't any, never the
  // opposite.

  // An early exit for devices that don't allow sync on completion. Ops that
  // run on these devices should have used num_deferred_ops correctly to ensure
  // the device has finished all relevant work at this point.
  if (!device->AllowsSyncOnCompletion()) {
    status.Update(device->RefreshStatus());
    if (!status.ok()) {
      // In device async execution mode, it's possible for device execution to
      // lag behind ExecutorState scheduling so much that this is the first
      // place a device execution error surfaces.
      // If so, all ExecutorState::NodeDone calls have already happened with OK
      // status. This is the last defense where StartCancel must be called to
      // abort all computation still running on any device.
      // TODO(b/124523000): Always call Finish in a separate thread, so even if
      // StartCancel blocks the current thread's execution, we won't encounter
      // deadlocks caused by inter-op thread exhaustion.
      if (rendezvous_) {
        rendezvous_->StartAbort(status);
      }
      if (cancellation_manager_) {
        cancellation_manager_->StartCancelWithStatus(status);
      } else if (collective_executor_) {
        // If there's a cancellation_manager_, collective ops abort
        // collective_executor_ upon cancellation; otherwise we need to abort
        // here.
        collective_executor_->StartAbort(status);
      }
    }
    delete this;
    runner([step_id, status, done_cb = std::move(done_cb)]() {
      profiler::TraceMeConsumer activity(
          // From TraceMeProducer in KernelAndDeviceFunc::RunAsync,
          // DirectSession::RunInternal or GraphMgr::ExecuteAsync.
          [&] {
            return profiler::TraceMeEncode("ExecutorDoneCallback",
                                           {{"id", step_id}});
          },
          profiler::ContextType::kTfExecutor, step_id,
          profiler::TraceMeLevel::kInfo);
      done_cb(status);
    });
    return;
  }

  if (sync_on_finish_ && status.ok()) {
    // Block until the device has finished all queued operations. For
    // devices like GPUs that continue to execute Ops after their Compute
    // methods have completed, this ensures that control is not returned to
    // the user until the step (and its side-effects) has actually completed.
    device->Sync([this, step_id, runner = std::move(runner),
                  done_cb = std::move(done_cb)](const Status& status) mutable {
      delete this;
      runner([step_id, status, done_cb = std::move(done_cb)]() {
        profiler::TraceMeConsumer activity(
            // From TraceMeProducer in KernelAndDeviceFunc::RunAsync,
            // DirectSession::RunInternal or GraphMgr::ExecuteAsync.
            [&] {
              return profiler::TraceMeEncode("ExecutorDoneCallback",
                                             {{"id", step_id}});
            },
            profiler::ContextType::kTfExecutor, step_id,
            profiler::TraceMeLevel::kInfo);
        done_cb(status);
      });
    });
  } else {
    delete this;
    runner([step_id, status, done_cb = std::move(done_cb)]() {
      profiler::TraceMeConsumer activity(
          // From TraceMeProducer in KernelAndDeviceFunc::RunAsync,
          // DirectSession::RunInternal or GraphMgr::ExecuteAsync.
          [&] {
            return profiler::TraceMeEncode("ExecutorDoneCallback",
                                           {{"id", step_id}});
          },
          profiler::ContextType::kTfExecutor, step_id,
          profiler::TraceMeLevel::kInfo);
      done_cb(status);
    });
  }
}

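// Selects a propagator implementation for this step: `OrderedPropagatorState`
// when deterministic op ordering is required, the full `PropagatorState` when
// the graph needs control-flow support (frames/iterations), and the cheaper
// `SimplePropagatorState` otherwise. The `ExecutorState` deletes itself when
// the step finishes.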
void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) {
  if (OpOrderDeterminismRequired()) {
    (new ExecutorState<OrderedPropagatorState>(args, immutable_state_,
                                               &kernel_stats_))
        ->RunAsync(std::move(done));
  } else if (immutable_state_.requires_control_flow_support()) {
    (new ExecutorState<PropagatorState>(args, immutable_state_, &kernel_stats_))
        ->RunAsync(std::move(done));
  } else {
    (new ExecutorState<SimplePropagatorState>(args, immutable_state_,
                                              &kernel_stats_))
        ->RunAsync(std::move(done));
  }
}

}  // namespace

Status NewLocalExecutor(const LocalExecutorParams& params, const Graph& graph,
                        Executor** executor) {
  ExecutorImpl* impl = new ExecutorImpl(params);
  const Status s = impl->Initialize(graph);
  if (s.ok()) {
    *executor = impl;
  } else {
    delete impl;
  }
  return s;
}
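
// Illustrative usage sketch (not taken from a caller in this file; the exact
// field sets of `LocalExecutorParams` and `Executor::Args`, including the
// required kernel create/delete callbacks, are defined in executor.h):
//
//   LocalExecutorParams params;
//   params.device = device;                      // device that runs the graph
//   params.function_library = function_library;  // may be null
//   Executor* executor = nullptr;
//   TF_RETURN_IF_ERROR(NewLocalExecutor(params, graph, &executor));
//
//   Executor::Args args;
//   args.step_id = 1;
//   args.runner = [](std::function<void()> closure) { closure(); };
//   executor->RunAsync(args, [](const Status& s) { /* step is done */ });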

Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib,
                             const std::shared_ptr<const NodeProperties>& props,
                             int graph_def_version, OpKernel** kernel) {
  const auto device_type = DeviceType(device->attributes().device_type());
  auto allocator = device->GetAllocator(AllocatorAttributes());
  return CreateOpKernel(device_type, device, allocator, flib,
                        device->resource_manager(), props, graph_def_version,
                        kernel);
}

void DeleteNonCachedKernel(OpKernel* kernel) { delete kernel; }

namespace {

class DefaultExecutorRegistrar {
 public:
  DefaultExecutorRegistrar() {
    Factory* factory = new Factory;
    ExecutorFactory::Register("", factory);
    ExecutorFactory::Register("DEFAULT", factory);
  }

 private:
  class Factory : public ExecutorFactory {
    Status NewExecutor(const LocalExecutorParams& params, const Graph& graph,
                       std::unique_ptr<Executor>* out_executor) override {
      Executor* ret = nullptr;
      TF_RETURN_IF_ERROR(NewLocalExecutor(params, std::move(graph), &ret));
      out_executor->reset(ret);
      return OkStatus();
    }
  };
};
static DefaultExecutorRegistrar registrar;

}  // namespace

}  // namespace tensorflow