xref: /aosp_15_r20/external/tensorflow/tensorflow/core/common_runtime/executor.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/executor.h"
17 
18 #include <algorithm>
19 #include <atomic>
20 #include <memory>
21 #include <vector>
22 
23 #include "absl/memory/memory.h"
24 #include "absl/time/time.h"
25 #include "absl/types/optional.h"
26 #include "tensorflow/core/activity_watcher/activity.h"
27 #include "tensorflow/core/common_runtime/costmodel_manager.h"
28 #include "tensorflow/core/common_runtime/entry.h"
29 #include "tensorflow/core/common_runtime/executor_factory.h"
30 #include "tensorflow/core/common_runtime/graph_view.h"
31 #include "tensorflow/core/common_runtime/immutable_executor_state.h"
32 #include "tensorflow/core/common_runtime/pending_counts.h"
33 #include "tensorflow/core/common_runtime/propagator_state.h"
34 #include "tensorflow/core/common_runtime/renamed_device.h"
35 #include "tensorflow/core/common_runtime/simple_propagator_state.h"
36 #include "tensorflow/core/common_runtime/step_stats_collector.h"
37 #include "tensorflow/core/framework/allocator.h"
38 #include "tensorflow/core/framework/cancellation.h"
39 #include "tensorflow/core/framework/collective.h"
40 #include "tensorflow/core/framework/control_flow.h"
41 #include "tensorflow/core/framework/device_attributes.pb.h"
42 #include "tensorflow/core/framework/log_memory.h"
43 #include "tensorflow/core/framework/metrics.h"
44 #include "tensorflow/core/framework/node_def_util.h"
45 #include "tensorflow/core/framework/op_kernel.h"
46 #include "tensorflow/core/framework/op_segment.h"
47 #include "tensorflow/core/framework/tensor.h"
48 #include "tensorflow/core/framework/tensor_reference.h"
49 #include "tensorflow/core/framework/types.h"
50 #include "tensorflow/core/framework/types.pb.h"
51 #include "tensorflow/core/graph/edgeset.h"
52 #include "tensorflow/core/graph/graph.h"
53 #include "tensorflow/core/graph/graph_node_util.h"
54 #include "tensorflow/core/lib/core/errors.h"
55 #include "tensorflow/core/lib/core/notification.h"
56 #include "tensorflow/core/lib/core/status.h"
57 #include "tensorflow/core/lib/core/threadpool.h"
58 #include "tensorflow/core/lib/gtl/flatmap.h"
59 #include "tensorflow/core/lib/gtl/inlined_vector.h"
60 #include "tensorflow/core/lib/gtl/manual_constructor.h"
61 #include "tensorflow/core/lib/hash/hash.h"
62 #include "tensorflow/core/platform/context.h"
63 #include "tensorflow/core/platform/env.h"
64 #include "tensorflow/core/platform/errors.h"
65 #include "tensorflow/core/platform/logging.h"
66 #include "tensorflow/core/platform/macros.h"
67 #include "tensorflow/core/platform/mutex.h"
68 #include "tensorflow/core/platform/profile_utils/cpu_utils.h"
69 #include "tensorflow/core/platform/status.h"
70 #include "tensorflow/core/platform/thread_annotations.h"
71 #include "tensorflow/core/platform/tracing.h"
72 #include "tensorflow/core/platform/types.h"
73 #include "tensorflow/core/profiler/lib/annotated_traceme.h"
74 #include "tensorflow/core/profiler/lib/connected_traceme.h"
75 #include "tensorflow/core/profiler/lib/scoped_annotation.h"
76 #include "tensorflow/core/profiler/lib/traceme_encode.h"
77 #include "tensorflow/core/protobuf/error_codes.pb.h"
78 #include "tensorflow/core/util/determinism.h"
79 #include "tensorflow/core/util/managed_stack_trace.h"
80 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
81 
82 namespace tensorflow {
83 
84 namespace {
85 
86 // 1-D, 0 element tensor.
87 static const Tensor* const kEmptyTensor = new Tensor;
88 
89 // Helper routines for collecting step stats.
90 namespace nodestats {
91 inline int64_t NowInNsec() { return EnvTime::NowNanos(); }
92 
93 void SetScheduled(NodeExecStatsInterface* stats, int64_t micros) {
94   if (!stats) return;
95   stats->SetScheduled(micros * EnvTime::kMicrosToNanos);
96 }
97 
98 void SetAllStart(NodeExecStatsInterface* stats) {
99   if (!stats) return;
100   stats->RecordExecutorStarted();
101 }
102 
103 void SetOpStart(NodeExecStatsInterface* stats) {
104   if (!stats) return;
105   stats->RecordComputeStarted();
106 }
107 
108 void SetOpEnd(NodeExecStatsInterface* stats) {
109   if (!stats) return;
110   stats->RecordComputeEnded();
111 }
112 
113 void SetAllEnd(NodeExecStatsInterface* stats) {
114   if (!stats) return;
115   stats->RecordExecutorEnded();
116 }
117 
118 void SetOutput(NodeExecStatsInterface* stats, int slot, const Tensor* v) {
119   if (!stats) return;
120   stats->SetOutput(slot, v);
121 }
122 
123 void SetMemory(NodeExecStatsInterface* stats, OpKernelContext* ctx) {
124   if (!stats) return;
125   stats->SetMemory(ctx);
126 }
127 
128 }  // namespace nodestats
129 
130 // Time the execution of kernels (in CPU cycles).  Used to dynamically identify
131 // inexpensive kernels which can be dispatched inline.
132 struct KernelTimer {
133   uint64 start_cycles = profile_utils::CpuUtils::GetCurrentClockCycle();
134 
135   uint64 ElapsedCycles() {
136     return profile_utils::CpuUtils::GetCurrentClockCycle() - start_cycles;
137   }
138 };
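// (Usage note: ProcessSync below starts a KernelTimer around Device::Compute
// and feeds ElapsedCycles() into KernelStats::UpdateCostEstimate.)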
139 
140 // TODO(b/152925936): Re-evaluate these constants with current usage patterns.
141 typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec;
142 typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;
143 
144 class ExecutorImpl : public Executor {
145  public:
146   explicit ExecutorImpl(const LocalExecutorParams& p) : immutable_state_(p) {}
147 
148   Status Initialize(const Graph& graph) {
149     TF_RETURN_IF_ERROR(immutable_state_.Initialize(graph));
150     kernel_stats_.Initialize(immutable_state_.graph_view());
151     return OkStatus();
152   }
153 
154   void RunAsync(const Args& args, DoneCallback done) override;
155 
156  private:
157   template <class PropagatorStateType>
158   friend class ExecutorState;
159 
160   // Stores execution time information about the kernels in an executor's graph.
161   class KernelStats {
162    public:
163     KernelStats() = default;
164 
165     void Initialize(const GraphView& gview) {
166       is_expensive_.resize(gview.num_nodes());
167       cost_estimates_ =
168           std::make_unique<std::atomic_uint_fast64_t[]>(gview.num_nodes());
169       for (int32_t i = 0; i < gview.num_nodes(); ++i) {
170         if (gview.node(i)) {
171           is_expensive_[i] =
172               gview.node(i)->kernel && gview.node(i)->kernel->IsExpensive();
173           cost_estimates_[i] = kInitialCostEstimateCycles;
174         }
175       }
176     }
177 
178     // Returns true iff the given node is considered "expensive". The
179     // executor uses this flag to optimize graph execution, for example
180     // by "inlining" inexpensive kernels.
181     bool IsExpensive(const NodeItem& node) const {
182       return is_expensive_[node.node_id] &&
183              (cost_estimates_[node.node_id].load(std::memory_order_relaxed) >
184               kOpIsExpensiveThresholdCycles);
185     }
186 
187     // Returns the value of kernel->IsExpensive().
188     bool HasExpensiveMarker(const NodeItem& node) const {
189       return is_expensive_[node.node_id];
190     }
191 
192     // Updates the dynamic cost estimate, which is used to determine whether the
193     // given node is expensive. The new cost estimate is a weighted average of
194     // the old cost estimate and the latest cost. We only update cost estimates
195     // for kernels for which IsExpensive() returns true.
196     void UpdateCostEstimate(const NodeItem& node, uint64 elapsed_cycles) {
197       // N.B. Updates to `cost_estimate` are atomic but unlocked.  Simultaneous
198       // updates may result in one or more updates being ignored.  This does not
199       // affect correctness but may slow down the update frequency.
200       std::atomic_uint_fast64_t& cost_estimate = cost_estimates_[node.node_id];
201       auto prev_estimate = cost_estimate.load(std::memory_order_relaxed);
202 
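      // The update below is an exponential moving average that gives the new
      // sample a weight of 1/kCostDecay. For example, with kCostDecay == 10,
      // prev_estimate == 100,000,000 (the initial estimate) and
      // elapsed_cycles == 5,000, the new estimate is
      // (9 * 100,000,000 + 5,000) / 10 == 90,000,500. Repeated cheap runs
      // eventually drive the estimate below kOpIsExpensiveThresholdCycles, at
      // which point IsExpensive() starts returning false for the node.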
203       uint64 new_estimate =
204           ((kCostDecay - 1) * prev_estimate + elapsed_cycles) / kCostDecay;
205 
206       cost_estimate.store(new_estimate, std::memory_order_relaxed);
207     }
208 
209    private:
210     // Initial time (in CPU cycles) we expect an operation to take.  Used to
211     // determine whether an operation should be placed in a threadpool.
212     // Operations start out "expensive".
213     static constexpr uint64 kInitialCostEstimateCycles = 100 * 1000 * 1000;
214     static constexpr uint64 kOpIsExpensiveThresholdCycles = 8000;
215     static constexpr uint64 kCostDecay = 10;
216 
217     std::vector<bool> is_expensive_;
218     // std::unique_ptr<std::atomic<bool>[]> is_expensive_;
219     std::unique_ptr<std::atomic_uint_fast64_t[]> cost_estimates_;
220   };
221 
222   ImmutableExecutorState immutable_state_;
223   KernelStats kernel_stats_;
224 
225   TF_DISALLOW_COPY_AND_ASSIGN(ExecutorImpl);
226 };
227 
228 // The state associated with one invocation of ExecutorImpl::Run.
229 //
230 // ExecutorState dispatches nodes when they become ready, and delegates to an
231 // instance of `PropagatorStateType` to keep track of how many predecessors of a
232 // are still pending.
233 //
234 // The template argument `class PropagatorStateType` must define the following
235 // public members:
236 // * A type `TaggedNode`, representing a node to be processed, with public
237 //   members:
238 //   * `const NodeItem& get_node_item() const`
239 //   * `bool get_is_dead() const`
240 // * A type `TaggedNodeReadyQueue`, representing a queue of nodes to be
241 //   processed, with public members (having the same meanings as in an
242 //   `std::vector<TaggedNode>`):
243 //   * `void push_back(const TaggedNode& node)`
244 //   * `TaggedNode front() const`
245 //   * `void pop_front()`
246 //   * `bool empty() const`
247 // * A type `TaggedNodeSeq`, representing a list of nodes to be scheduled, with
248 //   public members (having the same meanings as in an
249 //   `std::vector<TaggedNode>`):
250 //   * `size_t size() const`
251 //   * `bool empty() const`
252 //   * `void clear()`
253 //   * `const_iterator begin() const`
254 //   * `const_iterator end() const`
255 // * A public constructor, `PropagatorStateType(const ImmutableExecutorState&
256 //   immutable_state, int64 step_id, bool vlog)`.
257 // * The following public methods:
258 //   * `void ActivateRoots(gtl::ArraySlice<const NodeItem*> roots,
259 //     TaggedNodeSeq* ready)`, which creates `TaggedNode` instances for the
260 //     nodes in `roots` and adds them to `*ready`
261 //   * `void PropagateOutputs(const TaggedNode& tagged_node, EntryVector*
262 //     outputs, TaggedNodeSeq* ready)`, which propagates `outputs` from the
263 //     given `tagged_node` to the destinations of its output edges, and adds
264 //     any newly runnable nodes to `*ready`
265 //   * `Entry* GetInputTensors(const TaggedNode& tagged_node) const`, which
266 //     returns a pointer to the input tensors for the given `tagged_node`
267 //   * `FrameAndIter GetFrameAndIter(const TaggedNode& tagged_node) const`,
268 //     which creates a `FrameAndIter` for the given `tagged_node`
269 //   * `void DumpState()`, which dumps the dynamic state of the executing graph
270 //   * `void MaybeMarkStarted(const TaggedNode& tagged_node)`, which records
271 //     that a node has started
272 //   * `void MaybeMarkCompleted(const TaggedNode& tagged_node)`, which records
273 //     that a node has completed
274 //
275 // See `PropagatorState` in "./propagator_state.h" for an example of a type that
276 // can be used to instantiate `PropagatorStateType`.
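//
// For illustration only (not part of this file), a hypothetical conforming
// propagator type would have roughly the following shape:
//
//   class MyPropagatorState {
//    public:
//     struct TaggedNode {
//       const NodeItem& get_node_item() const;
//       bool get_is_dead() const;
//       // ...
//     };
//     using TaggedNodeSeq = std::vector<TaggedNode>;
//     using TaggedNodeReadyQueue = std::deque<TaggedNode>;
//
//     MyPropagatorState(const ImmutableExecutorState& immutable_state,
//                       int64_t step_id, bool vlog);
//
//     void ActivateRoots(gtl::ArraySlice<const NodeItem*> roots,
//                        TaggedNodeSeq* ready);
//     void PropagateOutputs(const TaggedNode& tagged_node, EntryVector* outputs,
//                           TaggedNodeSeq* ready);
//     Entry* GetInputTensors(const TaggedNode& tagged_node) const;
//     FrameAndIter GetFrameAndIter(const TaggedNode& tagged_node) const;
//     void DumpState();
//     void MaybeMarkStarted(const TaggedNode& tagged_node);
//     void MaybeMarkCompleted(const TaggedNode& tagged_node);
//   };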
277 template <class PropagatorStateType>
278 class ExecutorState {
279  public:
280   ExecutorState(const Executor::Args& args,
281                 const ImmutableExecutorState& immutable_state_,
282                 ExecutorImpl::KernelStats* kernel_stats_);
283   ~ExecutorState();
284 
285   void RunAsync(Executor::DoneCallback done);
286 
287  private:
288   // Use `TaggedNode` types defined by `PropagatorStateType`.
289   typedef typename PropagatorStateType::TaggedNode TaggedNode;
290   typedef
291       typename PropagatorStateType::TaggedNodeReadyQueue TaggedNodeReadyQueue;
292   typedef typename PropagatorStateType::TaggedNodeSeq TaggedNodeSeq;
293 
294   struct AsyncState;
295 
296   // Process a ready node in the current thread.
297   void Process(TaggedNode node, int64_t scheduled_nsec);
298 
299   void ProcessInline(TaggedNodeReadyQueue* inline_ready,
300                      int64_t scheduled_nsec);
301 
302   Status ProcessSync(const NodeItem& item, OpKernelContext::Params* params,
303                      EntryVector* outputs, NodeExecStatsInterface* stats);
304   void ProcessAsync(const NodeItem& item, const OpKernelContext::Params& params,
305                     const TaggedNode& tagged_node, Entry* first_input,
306                     NodeExecStatsInterface* stats,
307                     activity_watcher::ActivityId activity_id);
308   void ProcessNoop(NodeExecStatsInterface* stats);
309   void ProcessConstTensor(const NodeItem& item, EntryVector* outputs,
310                           NodeExecStatsInterface* stats);
311 
312   // Before invoking item->kernel, fills in its "inputs".
313   Status PrepareInputs(const NodeItem& item, Entry* first_input,
314                        TensorValueVec* inputs,
315                        AllocatorAttributeVec* input_alloc_attrs,
316                        bool* is_input_dead);
317 
318   // After item->kernel computation is done, processes its outputs.
319   Status ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
320                         Entry* outputs, NodeExecStatsInterface* stats);
321 
322   // Called after each node finishes. Takes ownership of "stats". Returns true
323   // if execution has completed.
324   //
325   // This method will clear `*ready` before returning.
326   bool NodeDone(const Status& s, TaggedNodeSeq* ready,
327                 NodeExecStatsInterface* stats,
328                 TaggedNodeReadyQueue* inline_ready);
329 
330   // Schedule all the expensive nodes in '*ready', and put all the inexpensive
331   // nodes in 'ready' into 'inline_ready'.
332   //
333   // This method will clear `*ready` before returning.
334   //
335   // REQUIRES: `!ready->empty()`.
336   void ScheduleReady(TaggedNodeSeq* ready, TaggedNodeReadyQueue* inline_ready);
337 
338   // A wrapper for runner_ to keep track of the pending queue length. Op
339   // execution should dispatch work using this function instead of using runner_
340   // directly.
341   template <typename Closure>
342   void RunTask(Closure&& c, int sample_rate = 0);
343 
344   // Clean up when this executor is done.
345   void Finish();
346   void ScheduleFinish();
347 
348   // Contains the device context assigned by the device at the beginning of a
349   // step.
350   DeviceContext* device_context_ = nullptr;
351 
352   const bool vlog_;  // true if VLOG_IS_ON(1). Used to check vlog cheaply.
353 
354   // true if LogMemory::IsEnabled(). Used to check memory enabled cheaply.
355   const bool log_memory_;
356 
357   int64_t step_id_;
358   int64_t start_time_usecs_ = 0;
359   // The deadline for the session to complete by. Empty if unspecified.
360   absl::optional<absl::Time> deadline_;
361 
362   // Maximum number of kernels that can be scheduled inline. If lots of kernels
363   // are ready at the same time, scheduling them in one thread can be very slow.
364   // TODO(fishx): Make it configurable if necessary.
365   static constexpr uint64 kInlineScheduleReadyThreshold = 500;
366 
367   // Not owned.
368   RendezvousInterface* rendezvous_;
369   CollectiveExecutor* collective_executor_ = nullptr;
370   SessionState* session_state_;
371   string session_handle_;
372   const SessionMetadata* session_metadata_ = nullptr;
373   TensorStore* tensor_store_;
374   // Step-local container.
375   ScopedStepContainer* step_container_;
376   StepStatsCollectorInterface* const stats_collector_;
377   const tracing::EventCollector* const event_collector_;
378   Context context_;
379 
380   // QUESTION: Make it a checkpoint::TensorSliceReaderCacheWrapper
381   // instead of a pointer?  (avoids having to delete).
382   checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache_;
383   CallFrameInterface* call_frame_;
384   const ImmutableExecutorState& immutable_state_;
385   ExecutorImpl::KernelStats* const kernel_stats_;
386   CancellationManager* cancellation_manager_;
387   CoordinationServiceAgent* coordination_service_agent_;
388   absl::optional<ManagedStackTrace> stack_trace_ = absl::nullopt;
389   // If not null, use this device to schedule intra-op operations.
390   std::unique_ptr<DeviceBase> user_device_;
391   Executor::Args::Runner runner_;
392   bool sync_on_finish_;
393   const bool run_all_kernels_inline_;
394 
395   PropagatorStateType propagator_;
396 
397   // Invoked when the execution finishes.
398   Executor::DoneCallback done_cb_;
399 
400   std::atomic_int_fast32_t num_outstanding_ops_;
401 
402   // Available via OpKernelContext to every OpKernel invocation.
403   mutex num_deferred_ops_mu_;
404   int64_t num_deferred_ops_ TF_GUARDED_BY(num_deferred_ops_mu_) = 0;
405   bool finish_when_deferred_ops_done_ TF_GUARDED_BY(num_deferred_ops_mu_) =
406       false;
407 
408   mutex mu_;
409   Status status_ TF_GUARDED_BY(mu_);
410 };
411 
412 template <class PropagatorStateType>
413 ExecutorState<PropagatorStateType>::ExecutorState(
414     const Executor::Args& args, const ImmutableExecutorState& immutable_state,
415     ExecutorImpl::KernelStats* kernel_stats)
416     : vlog_(VLOG_IS_ON(1)),
417       log_memory_(LogMemory::IsEnabled()),
418       step_id_(args.step_id),
419       start_time_usecs_(args.start_time_usecs),
420       deadline_(args.deadline),
421       rendezvous_(args.rendezvous),
422       collective_executor_(args.collective_executor),
423       session_state_(args.session_state),
424       session_handle_(args.session_handle),
425       session_metadata_(immutable_state.params().session_metadata),
426       tensor_store_(args.tensor_store),
427       step_container_(args.step_container),
428       stats_collector_(args.stats_collector),
429       event_collector_(
430           tracing::GetEventCollector(tracing::EventCategory::kCompute)),
431       context_(ContextKind::kThread),
432       slice_reader_cache_(new checkpoint::TensorSliceReaderCacheWrapper),
433       call_frame_(args.call_frame),
434       immutable_state_(immutable_state),
435       kernel_stats_(kernel_stats),
436       cancellation_manager_(args.cancellation_manager),
437       coordination_service_agent_(args.coordination_service_agent),
438       stack_trace_(args.stack_trace),
439       runner_(args.runner),
440       sync_on_finish_(args.sync_on_finish),
441       run_all_kernels_inline_(args.run_all_kernels_inline),
442       propagator_(immutable_state, step_id_, vlog_),
443       num_outstanding_ops_(0) {
444   if (args.user_intra_op_threadpool != nullptr) {
445     Device* device = immutable_state_.params().device;
446     user_device_ = RenamedDevice::NewRenamedDevice(
447         device->name(), device, false, false, args.user_intra_op_threadpool);
448   }
449 }
450 
451 template <class PropagatorStateType>
452 ExecutorState<PropagatorStateType>::~ExecutorState() {
453   if (device_context_) {
454     device_context_->Unref();
455   }
456   delete slice_reader_cache_;
457 }
458 
459 template <class PropagatorStateType>
460 template <typename Closure>
461 void ExecutorState<PropagatorStateType>::RunTask(Closure&& c, int sample_rate) {
462   // Align the atomic variables at 64 bytes to avoid false-sharing, assuming the
463   // cacheline size is 64 bytes or smaller.
464   alignas(64) static std::atomic<int64_t> num_enqueue_ops{0};
465   alignas(64) static std::atomic<int64_t> num_dequeue_ops{0};
466 
467   auto n_enqueues = num_enqueue_ops.fetch_add(1, std::memory_order_relaxed);
468   // Sample the queue length once every max(16, sample_rate) enqueue
469   // operations, which amortizes the cost of the metric update.
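  // (Callers that dispatch a batch of nodes pass the batch size as
  // `sample_rate`; see the RunTask calls in ScheduleReady below.)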
470   if (n_enqueues % std::max(16, sample_rate) == 0) {
471     auto n_dequeues = num_dequeue_ops.load(std::memory_order_relaxed);
472     metrics::UpdateGraphPendingQueueLength(n_enqueues - n_dequeues);
473   }
474 
475   // mutable is needed because std::forward<Closure> in the lambda body may move
476   // the Closure `c`.
477   runner_([c = std::forward<Closure>(c)]() mutable {
478     num_dequeue_ops.fetch_add(1, std::memory_order_relaxed);
479     std::forward<Closure>(c)();
480   });
481 }
482 
483 template <class PropagatorStateType>
484 void ExecutorState<PropagatorStateType>::RunAsync(Executor::DoneCallback done) {
485   TaggedNodeSeq ready;
486 
487   // Ask the device for the device context to use for this step.
488   Device* device = immutable_state_.params().device;
489   const Status get_context_status =
490       device->TryGetDeviceContext(&device_context_);
491   if (!get_context_status.ok()) {
492     delete this;
493     done(get_context_status);
494     return;
495   }
496 
497   // Initialize the ready queue.
498   ready.reserve(immutable_state_.root_nodes().size());
499   propagator_.ActivateRoots(immutable_state_.root_nodes(), &ready);
500   num_outstanding_ops_ = ready.size();
501   if (ready.empty()) {
502     delete this;
503     done(OkStatus());
504   } else {
505     done_cb_ = std::move(done);
506     // Schedule to run all the ready ops in the thread pool.
507     ScheduleReady(&ready, nullptr);
508   }
509 }
510 
511 // State kept alive for executing an asynchronous node in another
512 // thread.  NOTE: We need to make a copy of p.inputs and p.input_alloc_attrs for
513 // asynchronous kernels because OpKernelContext methods like input_type(i) need
514 // the params to point to valid input vectors. It's not an issue for
515 // sync kernels because these vectors are kept on the stack.
516 template <class PropagatorStateType>
517 struct ExecutorState<PropagatorStateType>::AsyncState {
518   AsyncState(const OpKernelContext::Params& p, const TaggedNode& _tagged_node,
519              const NodeItem* _item, Entry* _first_input,
520              NodeExecStatsInterface* _stats)
521       : saved_inputs(p.inputs.begin(), p.inputs.end()),
522         saved_input_alloc_attrs(p.input_alloc_attrs.begin(),
523                                 p.input_alloc_attrs.end()),
524         params(p),
525         tagged_node(_tagged_node),
526         item(_item),
527         first_input(_first_input),
528         // ParamsButClearingEigenGPUDevice does the equivalent of
529         //   params.eigen_gpu_device = nullptr;
530         ctx(ParamsButClearingEigenGPUDevice(&params), item->num_outputs),
531         stats(_stats) {
532     params.inputs = saved_inputs;
533     params.input_alloc_attrs = saved_input_alloc_attrs;
534   }
535 
536   TensorValueVec saved_inputs;
537   AllocatorAttributeVec saved_input_alloc_attrs;
538   OpKernelContext::Params params;
539   TaggedNode tagged_node;
540   const NodeItem* item;
541   Entry* first_input;
542   OpKernelContext ctx;
543   NodeExecStatsInterface* stats;
544 
545  private:
546   OpKernelContext::Params* ParamsButClearingEigenGPUDevice(
547       OpKernelContext::Params* p) {
548     // Ensure OpKernelContext constructor will make a new eigen GPU device if
549     // necessary.
550     p->eigen_gpu_device = nullptr;  // Force allocation
551     return p;
552   }
553 };
554 
555 // Returns true if `item` might be traced by the given trace and event
556 // collectors. Returns false only if `item` definitely will not be traced.
557 bool MightTrace(const tracing::EventCollector* event_collector,
558                 bool is_expensive) {
559   // Tracing will only be enabled if either `event_collector` is non-null,
560   // or `trace_collector` is non-null and enabled for this particular kernel.
561   // Although `profiler::TraceMe`, `profiler::ScopedAnnotation`, and
562   // `tracing::ScopedRegion` check subsets of these properties internally in
563   // their constructors, the cost of passing the necessary arguments to them can
564   // be significant, so we avoid constructing them in the common case (when we
565   // know they will not be used).
566   if (event_collector != nullptr) {
567     return true;
568   }
569 
570   if (profiler::ScopedAnnotation::IsEnabled()) return true;
571 
572   return profiler::TraceMe::Active(profiler::GetTFTraceMeLevel(is_expensive));
573 }
574 
575 template <class PropagatorStateType>
576 Status ExecutorState<PropagatorStateType>::ProcessSync(
577     const NodeItem& item, OpKernelContext::Params* params, EntryVector* outputs,
578     NodeExecStatsInterface* stats) {
579   Status s;
580   OpKernelContext ctx(params, item.num_outputs);
581   nodestats::SetOpStart(stats);
582 
583   OpKernel* op_kernel = item.kernel;
584   Device* device = immutable_state_.params().device;
585   const bool is_expensive = kernel_stats_->IsExpensive(item);
586 
587   if (TF_PREDICT_FALSE(MightTrace(event_collector_, is_expensive))) {
588     tracing::ScopedRegion region(tracing::EventCategory::kCompute,
589                                  op_kernel->name_view());
590     profiler::AnnotatedTraceMe activity(
591         [op_kernel, &ctx] {
592           return op_kernel->TraceString(
593               ctx, /*verbose=*/profiler::TfOpDetailsEnabled());
594         },
595         profiler::GetTFTraceMeLevel(is_expensive));
596     device->Compute(op_kernel, &ctx);
597   } else if (kernel_stats_->HasExpensiveMarker(item)) {
598     KernelTimer timer;
599     device->Compute(op_kernel, &ctx);
600     // For expensive kernels, always update the cost estimate. For inexpensive
601     // kernels, update the cost estimate with ~1/16 probability. This assumes
602     // that the last 4 bits of the CPU cycle count are uniformly distributed.
603     constexpr int kKernelExecutionTrackingInvocationSkipCount = 16;
604     if (is_expensive ||
605         timer.start_cycles % kKernelExecutionTrackingInvocationSkipCount == 0) {
606       kernel_stats_->UpdateCostEstimate(item, timer.ElapsedCycles());
607     }
608   } else {
609     device->Compute(op_kernel, &ctx);
610   }
611   nodestats::SetOpEnd(stats);
612   if (outputs->size() < item.num_outputs) outputs->resize(item.num_outputs);
613   s = ProcessOutputs(item, &ctx, outputs->data(), stats);
614   nodestats::SetMemory(stats, &ctx);
615   return s;
616 }
617 
618 template <class PropagatorStateType>
619 void ExecutorState<PropagatorStateType>::ProcessAsync(
620     const NodeItem& item, const OpKernelContext::Params& params,
621     const TaggedNode& tagged_node, Entry* first_input,
622     NodeExecStatsInterface* stats, activity_watcher::ActivityId activity_id) {
623   AsyncOpKernel* async_kernel = item.kernel->AsAsync();
624   DCHECK(async_kernel != nullptr);
625   AsyncState* state =
626       new AsyncState(params, tagged_node, &item, first_input, stats);
627 
628   auto done = [this, state, activity_id]() {
629     Device* device = immutable_state_.params().device;
630     NodeExecStatsInterface* stats = state->stats;  // Shorthand
631     Entry* first_input = state->first_input;       // Shorthand
632 
633     nodestats::SetOpEnd(stats);
634     EntryVector outputs(state->item->num_outputs);
635     Status s = ProcessOutputs(*state->item, &state->ctx, outputs.data(), stats);
636     nodestats::SetMemory(stats, &state->ctx);
637     if (vlog_) {
638       VLOG(2) << "Async kernel done: " << state->item->node_id << " step "
639               << step_id_ << " " << SummarizeNodeDef(state->item->kernel->def())
640               << (state->tagged_node.get_is_dead() ? " is dead" : "")
641               << " device: " << device->name();
642     }
643 
644     // Clears inputs.
645     const int num_inputs = state->item->num_inputs;
646     for (int i = 0; i < num_inputs; ++i) {
647       (first_input + i)->ClearVal();
648     }
649     propagator_.MaybeMarkCompleted(state->tagged_node);
650     activity_watcher::ActivityEnd(activity_id);
651     TaggedNodeSeq ready;
652     if (s.ok()) {
653       propagator_.PropagateOutputs(state->tagged_node, &outputs, &ready);
654     }
655     outputs.clear();
656     const bool completed = NodeDone(s, &ready, stats, nullptr);
657     delete state;
658     if (completed) ScheduleFinish();
659   };
660   nodestats::SetOpStart(stats);
661   {
662     profiler::AnnotatedTraceMe activity(
663         [async_kernel, state] {
664           return async_kernel->TraceString(
665               state->ctx, /*verbose=*/profiler::TfOpDetailsEnabled());
666         },
667         profiler::GetTFTraceMeLevel(kernel_stats_->IsExpensive(item)));
668     immutable_state_.params().device->ComputeAsync(async_kernel, &state->ctx,
669                                                    std::move(done));
670   }
671 }
672 
673 template <class PropagatorStateType>
674 void ExecutorState<PropagatorStateType>::ProcessNoop(
675     NodeExecStatsInterface* stats) {
676   nodestats::SetOpStart(stats);
677   nodestats::SetOpEnd(stats);
678 }
679 
680 template <class PropagatorStateType>
681 void ExecutorState<PropagatorStateType>::ProcessConstTensor(
682     const NodeItem& item, EntryVector* outputs, NodeExecStatsInterface* stats) {
683   nodestats::SetOpStart(stats);
684   nodestats::SetOpEnd(stats);
685   Entry& output = (*outputs)[0];
686   output.state = Entry::State::HAS_CONST_TENSOR;
687   output.const_tensor = item.const_tensor;
688   output.alloc_attr = item.output_attrs()[0];
689 }
690 
691 template <class PropagatorStateType>
692 void ExecutorState<PropagatorStateType>::Process(TaggedNode tagged_node,
693                                                  int64_t scheduled_nsec) {
694   profiler::TraceMeConsumer activity(
695       // From TraceMeProducer in DirectSession::RunInternal,
696       // GraphMgr::ExecuteAsync, or FunctionLibraryRuntime::Run.
697       [&] {
698         // NOTE: This tracing uses the iteration number from the first tagged
699         // node that executes during this call to `Process()`. In principle,
700         // subsequent nodes could have different values of `iter_num` that
701         // will not be traced.
702         return profiler::TraceMeEncode(
703             "ExecutorState::Process",
704             {{"id", step_id_}, {"iter_num", tagged_node.get_iter_num()}});
705       },
706       profiler::ContextType::kTfExecutor, step_id_,
707       profiler::TraceMeLevel::kInfo);
708   TaggedNodeReadyQueue inline_ready;
709   inline_ready.push_back(tagged_node);
710   return ProcessInline(&inline_ready, scheduled_nsec);
711 }
712 
713 template <class PropagatorStateType>
714 void ExecutorState<PropagatorStateType>::ProcessInline(
715     TaggedNodeReadyQueue* inline_ready, int64_t scheduled_nsec) {
716   WithContext wc(context_);
717   TaggedNodeSeq ready;
718 
719   // Parameters passed to OpKernel::Compute.
720   TensorValueVec inputs;
721   AllocatorAttributeVec input_alloc_attrs;
722 
723   OpKernelContext::Params params;
724   params.step_id = step_id_;
725   // Override device's threadpool if user provides an intra_op_threadpool
726   Device* device = immutable_state_.params().device;
727   if (user_device_) {
728     params.device = user_device_.get();
729   } else {
730     params.device = device;
731   }
732   params.start_time_usecs = start_time_usecs_;
733   params.deadline = deadline_;
734   params.log_memory = log_memory_;
735   params.rendezvous = rendezvous_;
736   params.collective_executor = collective_executor_;
737   params.session_state = session_state_;
738   params.session_handle = session_handle_;
739   params.session_metadata = session_metadata_;
740   params.tensor_store = tensor_store_;
741   params.cancellation_manager = cancellation_manager_;
742   params.coordination_service_agent = coordination_service_agent_;
743   params.stack_trace = stack_trace_;
744   params.call_frame = call_frame_;
745   params.function_library = immutable_state_.params().function_library;
746   params.resource_manager = device->resource_manager();
747   params.step_container = step_container_;
748   params.slice_reader_cache = slice_reader_cache_;
749   params.runner = &runner_;
750   params.run_all_kernels_inline = run_all_kernels_inline_;
751   params.stats_collector = stats_collector_;
752   params.inc_num_deferred_ops_function = [this]() {
753     mutex_lock lock(num_deferred_ops_mu_);
754     num_deferred_ops_++;
755   };
756   params.dec_num_deferred_ops_function = [this]() {
757     bool finish_when_deferred_ops_done = false;
758     {
759       mutex_lock lock(num_deferred_ops_mu_);
760       num_deferred_ops_--;
761       if (num_deferred_ops_ == 0) {
762         finish_when_deferred_ops_done = finish_when_deferred_ops_done_;
763       }
764     }
765     // Invoke Finish if the graph processing has completed. Finish is always
766     // called exactly once per ExecutorState, either here if there are any
767     // deferred ops, or in ScheduleFinish if there aren't any deferred ops.
768     if (finish_when_deferred_ops_done) Finish();
769   };
770 
771   // Set the device_context for this device, if it exists.
772   params.op_device_context = device_context_;
773 
774   Status s;
775   NodeExecStatsInterface* stats = nullptr;
776 
777   EntryVector outputs(1);
778 
779   bool completed = false;
780   while (!inline_ready->empty()) {
781     TaggedNode tagged_node = inline_ready->front();
782     inline_ready->pop_front();
783     const NodeItem& item = tagged_node.get_node_item();
784     const int id = item.node_id;
785 
786     propagator_.MaybeMarkStarted(tagged_node);
787     const activity_watcher::ActivityId activity_id =
788         activity_watcher::ActivityStart(
789             [&]() {
790               return std::make_unique<activity_watcher::Activity>(
791                   "ExecutorState::Process",
792                   activity_watcher::ActivityCategory::kMisc,
793                   activity_watcher::Activity::Attributes{
794                       {"node_name", item.kernel->def().name()},
795                       {"op", item.kernel->def().op()},
796                       {"iter_num", absl::StrCat(tagged_node.get_iter_num())},
797                       {"step_id", absl::StrCat(params.step_id)},
798                       {"node_id", absl::StrCat(id)},
799                       {"device", device->name()},
800                   });
801             },
802             /*level=*/2);
803 
804     params.track_allocations = false;
805     stats = nullptr;
806     if (stats_collector_ && !tagged_node.get_is_dead()) {
807       stats = stats_collector_->CreateNodeExecStats(&item.kernel->def());
808       // Track allocations if and only if we are collecting statistics, and
809       // the `stats` object is expecting allocations to be tracked.
810       params.track_allocations = stats ? stats->TrackAllocations() : false;
811       nodestats::SetScheduled(stats, scheduled_nsec);
812       nodestats::SetAllStart(stats);
813     }
814 
815     if (vlog_) {
816       VLOG(1) << "Process node: " << id << " step " << params.step_id << " "
817               << SummarizeNodeDef(item.kernel->def())
818               << (tagged_node.get_is_dead() ? " is dead" : "")
819               << " device: " << device->name();
820     }
821 
822     Entry* first_input = propagator_.GetInputTensors(tagged_node);
823 
824     // Only execute this node if it is not dead or it is a send/recv
825     // transfer node. For transfer nodes, we need to propagate the "dead"
826     // bit even when the node is dead.
827     bool launched_asynchronously = false;
828     if (tagged_node.get_is_dead() && !item.is_transfer_node) {
829       if (outputs.size() < item.num_outputs) outputs.resize(item.num_outputs);
830     } else if (TF_PREDICT_FALSE(item.is_noop)) {
831       ProcessNoop(stats);
832     } else if (item.const_tensor != nullptr && !params.track_allocations) {
833       ProcessConstTensor(item, &outputs, stats);
834     } else {
835       // Prepares inputs.
836       bool is_input_dead = false;
837       s = PrepareInputs(item, first_input, &inputs, &input_alloc_attrs,
838                         &is_input_dead);
839       if (!s.ok()) {
840         // Clear inputs.
841         const int num_inputs = item.num_inputs;
842         for (int i = 0; i < num_inputs; ++i) {
843           (first_input + i)->ClearVal();
844         }
845         propagator_.MaybeMarkCompleted(tagged_node);
846         activity_watcher::ActivityEnd(activity_id);
847         // Continue to process the nodes in 'inline_ready'.
848         completed = NodeDone(s, &ready, stats, inline_ready);
849         continue;
850       }
851 
852       // Set up compute params.
853       params.op_kernel = item.kernel;
854       params.frame_iter = propagator_.GetFrameAndIter(tagged_node);
855       params.is_input_dead = is_input_dead;
856       params.output_attr_array = item.output_attrs();
857       params.forward_from_array = item.forward_from();
858       params.outputs_required_array = item.outputs_required.get();
859       params.inputs = inputs;
860       params.input_alloc_attrs = input_alloc_attrs;
861 
862       if (item.kernel_is_async) {
863         ProcessAsync(item, params, tagged_node, first_input, stats,
864                      activity_id);
865         launched_asynchronously = true;
866       } else {
867         s = ProcessSync(item, &params, &outputs, stats);
868       }
869     }
870 
871     if (!launched_asynchronously) {
872       if (vlog_) {
873         VLOG(2) << "Synchronous kernel done: " << id << " step "
874                 << params.step_id << " " << SummarizeNodeDef(item.kernel->def())
875                 << (tagged_node.get_is_dead() ? " is dead: " : "")
876                 << " device: " << device->name();
877       }
878 
879       // Clears inputs.
880       const int num_inputs = item.num_inputs;
881       for (int i = 0; i < num_inputs; ++i) {
882         (first_input + i)->ClearVal();
883       }
884       propagator_.MaybeMarkCompleted(tagged_node);
885       activity_watcher::ActivityEnd(activity_id);
886       // Propagates outputs.
887       if (s.ok()) {
888         propagator_.PropagateOutputs(tagged_node, &outputs, &ready);
889       }
890 
891       // Clear outputs without deallocating the `outputs` vector.
892       const int num_outputs = item.num_outputs;
893       for (int i = 0; i < num_outputs; ++i) {
894         outputs[i].ClearVal();
895       }
896 
897       if (stats) {
898         scheduled_nsec = nodestats::NowInNsec();
899       }
900       // Postprocess.
901       completed = NodeDone(s, &ready, stats, inline_ready);
902     }
903   }  // while !inline_ready.empty()
904 
905   // This thread of computation is done if completed = true.
906   if (completed) ScheduleFinish();
907 }
908 
909 template <class PropagatorStateType>
910 Status ExecutorState<PropagatorStateType>::PrepareInputs(
911     const NodeItem& item, Entry* first_input, TensorValueVec* inputs,
912     AllocatorAttributeVec* input_alloc_attrs, bool* is_input_dead) {
913   inputs->resize(item.num_inputs);
914   input_alloc_attrs->resize(item.num_inputs);
915 
916   *is_input_dead = false;
917 
918   for (int i = 0; i < item.num_inputs; ++i) {
919     const bool expect_ref = TF_PREDICT_FALSE(item.is_any_input_ref_typed) &&
920                             IsRefType(item.input_type(i));
921     Entry* entry = first_input + i;
922     (*input_alloc_attrs)[i] = entry->alloc_attr;
923 
924     // i-th input.
925     TensorValue* inp = &(*inputs)[i];
926 
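    // Convert the entry into a TensorValue according to its state:
    //  * NO_VALUE: a dead input (nullptr for Merge nodes, an empty const
    //    tensor for transfer nodes).
    //  * HAS_VALUE / HAS_CONST_TENSOR: a plain (possibly const) tensor.
    //  * HAS_REF_TENSOR: a ref-typed tensor guarded by a mutex, dereferenced
    //    here unless the kernel expects a ref input.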
927     switch (entry->state) {
928       case Entry::State::NO_VALUE: {
929         // Only merge and transfer nodes can have no-value inputs.
930         inp->mutex_if_ref = nullptr;
931         if (item.is_merge) {
932           inp->tensor = nullptr;
933         } else {
934           DCHECK(item.is_transfer_node)
935               << item.kernel->name() << " - input " << i;
936           entry->state = Entry::State::HAS_CONST_TENSOR;
937           entry->const_tensor = kEmptyTensor;
938           // NOTE(mrry): This `const_cast` is necessary because `TensorValue`
939           // stores a non-const `Tensor*`, and relies on the `OpKernelContext`
940           // accessors making dynamic checks that prevent using an immutable
941           // tensor as a mutable tensor.
942           inp->tensor = const_cast<Tensor*>(kEmptyTensor);
943           *is_input_dead = true;
944         }
945         break;
946       }
947 
948       case Entry::State::HAS_VALUE: {
949         if (TF_PREDICT_FALSE(expect_ref)) {
950           return AttachDef(
951               errors::InvalidArgument(i, "-th input expects a ref type"),
952               item.kernel->def());
953         }
954         inp->mutex_if_ref = nullptr;
955         inp->tensor = entry->val.get();
956         break;
957       }
958 
959       case Entry::State::HAS_CONST_TENSOR: {
960         if (TF_PREDICT_FALSE(expect_ref)) {
961           return AttachDef(
962               errors::InvalidArgument(i, "-th input expects a ref type"),
963               item.kernel->def());
964         }
965         // NOTE(mrry): This `const_cast` is necessary because `TensorValue`
966         // stores a non-const `Tensor*`, and relies on the `OpKernelContext`
967         // accessors making dynamic checks that prevent using an immutable
968         // tensor as a mutable tensor.
969         inp->mutex_if_ref = nullptr;
970         inp->tensor = const_cast<Tensor*>(entry->const_tensor);
971         break;
972       }
973 
974       case Entry::State::HAS_REF_TENSOR: {
975         {
976           tf_shared_lock ml(*entry->ref_tensor.mu);
977           if (TF_PREDICT_FALSE(!entry->ref_tensor.tensor->IsInitialized() &&
978                                !item.is_initialization_op)) {
979             return AttachDef(errors::FailedPrecondition(
980                                  "Attempting to use uninitialized value ",
981                                  item.kernel->requested_input(i)),
982                              item.kernel->def());
983           }
984         }
985 
986         if (expect_ref) {
987           inp->mutex_if_ref = entry->ref_tensor.mu;
988           inp->tensor = entry->ref_tensor.tensor;
989         } else {
990           // Automatically deref the tensor ref when the op expects a
991           // tensor but is given a ref to a tensor.  Need to deref it
992           // under the mutex.
993           {
994             mutex* ref_mu = entry->ref_tensor.mu;
995             Tensor* ref_tensor = entry->ref_tensor.tensor;
996             tf_shared_lock l(*ref_mu);
997             entry->val.Init(*ref_tensor);
998           }
999           entry->state = Entry::State::HAS_VALUE;
1000 
1001           inp->mutex_if_ref = nullptr;
1002           inp->tensor = entry->val.get();
1003           // The dtype of entry->ref_tensor.tensor could have been changed by
1004           // another operation that ran after the operation that "produced" it
1005           // executed, so re-validate that the type of the dereferenced tensor
1006           // matches the expected input type.
1007           if (TF_PREDICT_FALSE(item.input_type(i) != inp->tensor->dtype())) {
1008             return AttachDef(
1009                 errors::InvalidArgument(
1010                     i, "-th input expects type ",
1011                     DataTypeString(item.input_type(i)),
1012                     " but automatically dereferenced input tensor has type ",
1013                     DataTypeString(inp->tensor->dtype())),
1014                 item.kernel->def());
1015           }
1016         }
1017         break;
1018       }
1019     }
1020   }
1021   return OkStatus();
1022 }
1023 
1024 template <class PropagatorStateType>
1025 Status ExecutorState<PropagatorStateType>::ProcessOutputs(
1026     const NodeItem& item, OpKernelContext* ctx, Entry* outputs,
1027     NodeExecStatsInterface* stats) {
1028   Status s = ctx->status();
1029   if (!s.ok()) {
1030     s = AttachDef(s, item.kernel->def());
1031     // TODO(misard) Replace with a finer-grain enabling flag once we
1032     // add better optional debugging support.
1033     if (vlog_ && VLOG_IS_ON(1)) {
1034       LOG(WARNING) << this << " Compute status: " << s;
1035     }
1036     if (s.code() == error::RESOURCE_EXHAUSTED) {
1037       if (stats_collector_) {
1038         string err = stats_collector_->ReportAllocsOnResourceExhausted(
1039             s.error_message());
1040         s = errors::CreateWithUpdatedMessage(
1041             s, strings::StrCat(s.error_message(), err));
1042       } else {
1043         s = errors::CreateWithUpdatedMessage(
1044             s,
1045             strings::StrCat(
1046                 s.error_message(),
1047                 "\nHint: If you want to see a list of allocated tensors when "
1048                 "OOM happens, add report_tensor_allocations_upon_oom "
1049                 "to RunOptions for current allocation info. This isn't "
1050                 "available when running in Eager mode.\n"));
1051       }
1052     } else if (s.code() == error::UNAVAILABLE &&
1053                !item.is_distributed_communication) {
1054       s = errors::ReplaceErrorFromNonCommunicationOps(s, item.kernel->name());
1055     }
1056     return s;
1057   }
1058 
1059   for (int i = 0; i < item.num_outputs; ++i) {
1060     const TensorValue val = ctx->release_output(i);
1061     Entry* out = &outputs[i];
1062     DCHECK(out->state == Entry::State::NO_VALUE);
1063 
1064     if (val.tensor == nullptr) {
1065       // Unless it's a Switch or a Recv, or the executor has marked the output
1066       // as not required, the node must produce a tensor value at the i-th output.
1067       if (!(item.is_recv_or_switch ||
1068             (item.outputs_required && !item.outputs_required[i]))) {
1069         s.Update(errors::Internal("Missing ", i, "-th output from ",
1070                                   FormatNodeDefForError(item.kernel->def())));
1071       }
1072     } else {
1073       // Set the allocator attributes of the output entry.
1074       out->alloc_attr = ctx->output_alloc_attr(i);
1075 
1076       // Sanity check of output tensor types. We need to inspect this safely as
1077       // we are in the tensor buffer.
1078       DataType dtype = val.dtype_safe();
1079       if (dtype == item.output_type(i)) {
1080         if (stats && val.tensor->IsInitialized()) {
1081           nodestats::SetOutput(stats, i, val.tensor);
1082         }
1083         if (val.is_ref()) {
1084           out->state = Entry::State::HAS_REF_TENSOR;
1085           out->ref_tensor.tensor = val.tensor;
1086           out->ref_tensor.mu = val.mutex_if_ref;
1087           if (log_memory_) {
1088             Tensor to_log;
1089             {
1090               // Dereference the tensor under the lock.
1091               tf_shared_lock l(*out->ref_tensor.mu);
1092               to_log = *out->ref_tensor.tensor;
1093             }
1094             LogMemory::RecordTensorOutput(ctx->op_kernel().name(),
1095                                           ctx->step_id(), i, to_log);
1096           }
1097         } else {
1098           // NOTE that std::move is used here, so val.tensor goes to
1099           // uninitialized state (val.tensor->IsInitialized() returns false).
1100           out->state = Entry::State::HAS_VALUE;
1101           out->val.Init(std::move(*val.tensor));
1102           if (log_memory_) {
1103             LogMemory::RecordTensorOutput(ctx->op_kernel().name(),
1104                                           ctx->step_id(), i, *out->val);
1105           }
1106         }
1107       } else {
1108         s.Update(
1109             errors::Internal("Output ", i, " of type ", DataTypeString(dtype),
1110                              " does not match declared output type ",
1111                              DataTypeString(item.output_type(i)), " for node ",
1112                              FormatNodeDefForError(item.kernel->def())));
1113       }
1114     }
1115     if (!val.is_ref()) {
1116       // If OpKernelContext returns outputs via pass-by-value, we
1117       // don't need this trouble.
1118       delete val.tensor;
1119     }
1120   }
1121   return s;
1122 }
1123 
1124 template <class PropagatorStateType>
1125 bool ExecutorState<PropagatorStateType>::NodeDone(
1126     const Status& s, TaggedNodeSeq* ready, NodeExecStatsInterface* stats,
1127     TaggedNodeReadyQueue* inline_ready) {
1128   if (stats) {
1129     nodestats::SetAllEnd(stats);
1130     DCHECK_NE(stats_collector_, nullptr);
1131     stats->Done(immutable_state_.params().device->name());
1132   }
1133 
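  // `num_outstanding_ops_` counts nodes that are ready or currently running.
  // The node that just finished accounts for one decrement, and each newly
  // ready node accounts for one increment, so the net adjustment is
  // `ready->size() - 1`. Execution is complete once the counter reaches zero,
  // i.e. when fetch_sub(1) returns 1.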
1134   if (TF_PREDICT_TRUE(s.ok())) {
1135     const size_t ready_size = ready->size();
1136     if (ready_size == 0) {
1137       return num_outstanding_ops_.fetch_sub(1) == 1;
1138     } else {
1139       // NOTE: Avoid touching the atomic counter if only one node becomes ready.
1140       if (ready_size > 1) {
1141         num_outstanding_ops_.fetch_add(ready_size - 1,
1142                                        std::memory_order_relaxed);
1143       }
1144 
1145       // Schedule the ready nodes in 'ready'.
1146       ScheduleReady(ready, inline_ready);
1147 
1148       return false;
1149     }
1150   } else {
1151     bool abort_run = false;
1152     Status maybe_derived_s(s);
1153 
1154     // Some error happened. This thread of computation is done.
1155     {
1156       mutex_lock l(mu_);
1157       if (status_.ok()) {
1158         // If this is the first node to fail in this run, we are responsible for
1159         // aborting all other execution in the step.
1160         abort_run = true;
1161 
1162         // If execution has been cancelled, mark cancelled or aborted errors as
1163         // being derived. Note that the original node that fails might also
1164         // trigger cancellation, and here we make sure the original error is
1165         // exposed to users and not buried as a derived error.
1166         if (cancellation_manager_ && cancellation_manager_->IsCancelled() &&
1167             (errors::IsCancelled(s) || errors::IsAborted(s))) {
1168           status_ = StatusGroup::MakeDerived(s);
1169           maybe_derived_s = status_;
1170         } else {
1171           status_ = s;
1172         }
1173       }
1174     }
1175 
1176     if (abort_run) {
1177       TRACEPRINTF("StartAbort: %s", s.ToString().c_str());
1178       if (cancellation_manager_) {
1179         // Only log when the abort happens during the actual run time.
1180         // Use VLOG instead of LOG(warning) because error status is expected
1181         // when the executor is run under the grappler optimization phase or
1182         // when iterating through a tf.data input pipeline.
1183         VLOG(1) << "[" << immutable_state_.params().device->name()
1184                 << "] Executor start aborting: " << s;
1185       }
1186 
1187       if (rendezvous_) {
1188         rendezvous_->StartAbort(s);
1189       }
1190       if (cancellation_manager_) {
1191         cancellation_manager_->StartCancelWithStatus(maybe_derived_s);
1192       } else if (collective_executor_) {
1193         // If there's a cancellation_manager_, collective ops abort
1194         // collective_executor_ upon cancellation; otherwise we need to abort
1195         // here.
1196         collective_executor_->StartAbort(s);
1197       }
1198     }
1199 
1200     return num_outstanding_ops_.fetch_sub(1) == 1;
1201   }
1202 }
1203 
1204 template <class PropagatorStateType>
1205 void ExecutorState<PropagatorStateType>::ScheduleReady(
1206     TaggedNodeSeq* ready, TaggedNodeReadyQueue* inline_ready) {
1207   profiler::TraceMe activity(
1208       [&]() {
1209         return strings::StrCat(
1210             "ExecutorState::ScheduleReady#",
1211             "ready_size=", (ready == nullptr ? -1 : ready->size()),
1212             ",inline_ready_size=",
1213             (inline_ready == nullptr ? -1 : inline_ready->size()), "#");
1214       },
1215       profiler::GetTFTraceMeLevel(/*is_expensive=*/false));
1216   DCHECK(!ready->empty());
1217 
1218   int64_t scheduled_nsec = 0;
1219   if (stats_collector_) {
1220     scheduled_nsec = nodestats::NowInNsec();
1221   }
1222 
1223   if (run_all_kernels_inline_) {
1224     if (inline_ready == nullptr) {
1225       // Schedule all ready kernels from a single closure. This ensures that,
1226       // regardless of the `runner_` implementation, all kernels will run
1227       // sequentially on the same thread, and thread wakeup overhead and
1228       // executor mutex contention will be minimized.
1229       RunTask([this, ready = std::move(*ready), scheduled_nsec]() {
1230         for (auto& tagged_node : ready) {
1231           Process(tagged_node, scheduled_nsec);
1232         }
1233       });
1234     } else {
1235       for (auto& tagged_node : *ready) {
1236         inline_ready->push_back(tagged_node);
1237       }
1238     }
1239   } else {
1240     const TaggedNode* curr_expensive_node = nullptr;
1241     TaggedNodeSeq expensive_nodes;
1242     if (inline_ready == nullptr) {
1243       // Schedule to run all the ready ops in the thread pool.
1244       for (auto& tagged_node : *ready) {
1245         RunTask([=]() { Process(tagged_node, scheduled_nsec); },
1246                 /*sample_rate=*/ready->size());
1247       }
1248     } else {
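      // Partition the ready nodes: dead or inexpensive nodes are queued on
      // `inline_ready` to run on this thread, while expensive nodes are
      // dispatched to the threadpool. The most recently seen expensive node
      // is held back in `curr_expensive_node` so that it can run inline if
      // nothing else was inlined.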
1249       for (auto& tagged_node : *ready) {
1250         const NodeItem& item = *tagged_node.node_item;
1251         if (tagged_node.get_is_dead() || !kernel_stats_->IsExpensive(item)) {
1252           // Inline this inexpensive node.
1253           inline_ready->push_back(tagged_node);
1254         } else {
1255           if (curr_expensive_node) {
1256             expensive_nodes.push_back(*curr_expensive_node);
1257           }
1258           curr_expensive_node = &tagged_node;
1259         }
1260       }
1261     }
1262     if (curr_expensive_node) {
1263       if (inline_ready->empty()) {
1264         inline_ready->push_back(*curr_expensive_node);
1265       } else {
1266         // There are inline nodes to run already. We dispatch this expensive
1267         // node to another thread.
1268         expensive_nodes.push_back(*curr_expensive_node);
1269       }
1270     }
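    // Worked example (editor's note, illustrative only): if `ready` holds
    // {cheap1, exp1, cheap2, exp2} and `inline_ready` is non-null, the loop
    // above pushes cheap1 and cheap2 onto `inline_ready`, moves exp1 into
    // `expensive_nodes` when exp2 is encountered, and leaves exp2 as
    // `curr_expensive_node`. Because `inline_ready` is then non-empty, exp2
    // is also dispatched through `expensive_nodes` below.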
1271     if (!expensive_nodes.empty()) {
1272       if (expensive_nodes.size() < kInlineScheduleReadyThreshold) {
1273         for (auto& tagged_node : expensive_nodes) {
1274           RunTask(std::bind(&ExecutorState::Process, this, tagged_node,
1275                             scheduled_nsec),
1276                   /*sample_rate=*/expensive_nodes.size());
1277         }
1278       } else {
1279         // There are too many ready expensive nodes. Schedule them in child
1280         // threads.
1281         // TODO(fishx): Apply the same optimization to cheap ops as well, since
1282         // executing lots of cheap ops in one thread can also become the
1283         // bottleneck.
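        // Worked example (editor's note; the threshold value is defined
        // earlier in this file and is assumed to be 4 here purely for
        // illustration): with 10 ready expensive nodes, the loop below hands
        // out three child tasks holding chunks of 4, 4, and 2 nodes, and each
        // child task then re-dispatches one task per node in its chunk.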
1284         auto it = expensive_nodes.begin();
1285         while (it < expensive_nodes.end()) {
1286           auto end = it;
1287           std::advance(end, kInlineScheduleReadyThreshold);
1288           if (end > expensive_nodes.end()) {
1289             end = expensive_nodes.end();
1290           }
1291           TaggedNodeSeq ready_chunk{it, end};
1292           RunTask(
1293               [this, ready_chunk = std::move(ready_chunk), scheduled_nsec]() {
1294                 profiler::TraceMe activity(
1295                     [&]() {
1296                       return strings::StrCat(
1297                           "ExecutorState::ScheduleReady::"
1298                           "ChildThreadExpensiveNodes#",
1299                           "ready_chunk_size=", ready_chunk.size(), "#");
1300                     },
1301                     profiler::GetTFTraceMeLevel(/*is_expensive=*/false));
1302                 for (auto& tagged_node : ready_chunk) {
1303                   RunTask(std::bind(&ExecutorState::Process, this, tagged_node,
1304                                     scheduled_nsec),
1305                           /*sample_rate=*/ready_chunk.size());
1306                 }
1307               });
1308           it = end;
1309         }
1310       }
1311     }
1312   }
1313   ready->clear();
1314 }
1315 
1316 template <class PropagatorStateType>
1317 void ExecutorState<PropagatorStateType>::ScheduleFinish() {
1318   // Checks the condition to decide whether to invoke Finish(). If there are
1319   // in-flight deferred ops, wait until `num_deferred_ops_` reaches 0 to
1320   // invoke Finish(). Otherwise, invoke Finish() directly.
1321   // Note that it is critical that the ScheduleFinish / Finish codepath does not
1322   // block, otherwise we might deadlock.  See b/124523000 for details.
1323   {
1324     mutex_lock lock(num_deferred_ops_mu_);
1325     if (num_deferred_ops_ > 0) {
1326       finish_when_deferred_ops_done_ = true;
1327       return;
1328     }
1329   }
1330   // Finish is always called exactly once per ExecutorState, either here if
1331   // there aren't any deferred ops, or in the dec_num_deferred_ops_function if
1332   // there are deferred ops.
1333   Finish();
1334 }
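// Editor's note (sketch, not part of the original source): the matching
// decrement side is wired into the op parameters earlier in this file. Its
// intended shape is roughly the following, with Finish() invoked outside the
// lock once the last deferred op retires:
//
//   bool finish_now = false;
//   {
//     mutex_lock lock(num_deferred_ops_mu_);
//     if (--num_deferred_ops_ == 0) {
//       finish_now = finish_when_deferred_ops_done_;
//     }
//   }
//   if (finish_now) Finish();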
1335 
1336 template <class PropagatorStateType>
1337 void ExecutorState<PropagatorStateType>::Finish() {
1338   mu_.lock();
1339   auto status = status_;
1340   auto done_cb = std::move(done_cb_);
1341   auto runner = std::move(runner_);
1342   mu_.unlock();
1343   int64_t step_id = step_id_;
1344   CHECK(done_cb != nullptr);
1345   Device* device = immutable_state_.params().device;
1346 
1347   if (vlog_ && !status.ok() && VLOG_IS_ON(1)) {
1348     // Logs verbose information about the current state of active and pending
1349     // nodes in the propagator.
1350     propagator_.DumpState();
1351   }
1352 
1353   // There are several potential race conditions below. To name a few:
1354   // 1. Even if the device's status is OK at the precise moment when
1355   // num_deferred_ops_ reaches 0, it could go bad before device->RefreshStatus()
1356   // is called below, caused by work enqueued onto the same device by other
1357   // concurrent ExecutorState objects.
1358   // 2. Some implementations of Device::RefreshStatus, such as
1359   // XlaDevice::RefreshStatus, may be inherently racy because they release the
1360   // device mutex after a stream pointer is acquired and before the stream is
1361   // queried for status.
1362   // 3. It's the same for some implementations of Device::Sync, such as
1363   // XlaDevice::Sync.
1364   //
1365   // However, these race conditions are acceptable because a stream (and
1366   // therefore an XlaDevice) can only go from OK to not-OK, never the opposite,
1367   // which means we will at worst report errors when there isn't any, never the
1368   // opposite.
1369 
1370   // An early exit for devices that don't allow sync on completion. Ops that run on
1371   // these devices should have used num_deferred_ops correctly to ensure the
1372   // device has finished all relevant work at this point.
1373   if (!device->AllowsSyncOnCompletion()) {
1374     status.Update(device->RefreshStatus());
1375     if (!status.ok()) {
1376       // In device async execution mode, it's possible for device execution to
1377       // lag behind ExecutorState scheduling so much that this is the first
1378       // place a device execution error surfaces.
1379       // If so, all ExecutorState::NodeDone calls have already happened with OK
1380       // status. This is the last defense where StartCancel must be called to
1381       // abort all computation still running on any device.
1382       // TODO(b/124523000): Always call Finish in a separate thread, so even if
1383       // StartCancel blocks the current thread's execution, we won't encounter
1384       // deadlocks caused by inter-op thread exhaustion.
1385       if (rendezvous_) {
1386         rendezvous_->StartAbort(status);
1387       }
1388       if (cancellation_manager_) {
1389         cancellation_manager_->StartCancelWithStatus(status);
1390       } else if (collective_executor_) {
1391         // If there's a cancellation_manager_, collective ops abort
1392         // collective_executor_ upon cancellation; otherwise we need to abort
1393         // here.
1394         collective_executor_->StartAbort(status);
1395       }
1396     }
1397     delete this;
1398     runner([step_id, status, done_cb = std::move(done_cb)]() {
1399       profiler::TraceMeConsumer activity(
1400           // From TraceMeProducer in KernelAndDeviceFunc::RunAsync,
1401           // DirectSession::RunInternal or GraphMgr::ExecuteAsync.
1402           [&] {
1403             return profiler::TraceMeEncode("ExecutorDoneCallback",
1404                                            {{"id", step_id}});
1405           },
1406           profiler::ContextType::kTfExecutor, step_id,
1407           profiler::TraceMeLevel::kInfo);
1408       done_cb(status);
1409     });
1410     return;
1411   }
1412 
1413   if (sync_on_finish_ && status.ok()) {
1414     // Block until the device has finished all queued operations. For
1415     // devices like GPUs that continue to execute Ops after their Compute
1416     // methods have completed, this ensures that control is not returned to
1417     // the user until the step (and its side-effects) has actually completed.
1418     device->Sync([this, step_id, runner = std::move(runner),
1419                   done_cb = std::move(done_cb)](const Status& status) mutable {
1420       delete this;
1421       runner([step_id, status, done_cb = std::move(done_cb)]() {
1422         profiler::TraceMeConsumer activity(
1423             // From TraceMeProducer in KernelAndDeviceFunc::RunAsync,
1424             // DirectSession::RunInternal or GraphMgr::ExecuteAsync.
1425             [&] {
1426               return profiler::TraceMeEncode("ExecutorDoneCallback",
1427                                              {{"id", step_id}});
1428             },
1429             profiler::ContextType::kTfExecutor, step_id,
1430             profiler::TraceMeLevel::kInfo);
1431         done_cb(status);
1432       });
1433     });
1434   } else {
1435     delete this;
1436     runner([step_id, status, done_cb = std::move(done_cb)]() {
1437       profiler::TraceMeConsumer activity(
1438           // From TraceMeProducer in KernelAndDeviceFunc::RunAsync,
1439           // DirectSession::RunInternal or GraphMgr::ExecuteAsync.
1440           [&] {
1441             return profiler::TraceMeEncode("ExecutorDoneCallback",
1442                                            {{"id", step_id}});
1443           },
1444           profiler::ContextType::kTfExecutor, step_id,
1445           profiler::TraceMeLevel::kInfo);
1446       done_cb(status);
1447     });
1448   }
1449 }
1450 
1451 void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) {
1452   if (OpOrderDeterminismRequired()) {
1453     (new ExecutorState<OrderedPropagatorState>(args, immutable_state_,
1454                                                &kernel_stats_))
1455         ->RunAsync(std::move(done));
1456   } else if (immutable_state_.requires_control_flow_support()) {
1457     (new ExecutorState<PropagatorState>(args, immutable_state_, &kernel_stats_))
1458         ->RunAsync(std::move(done));
1459   } else {
1460     (new ExecutorState<SimplePropagatorState>(args, immutable_state_,
1461                                               &kernel_stats_))
1462         ->RunAsync(std::move(done));
1463   }
1464 }
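// Editor's note (sketch, not part of the original source): callers drive an
// executor either through RunAsync with an explicit done callback or through
// the blocking Executor::Run wrapper declared in executor.h. A minimal sketch,
// with every argument value assumed purely for illustration:
//
//   Executor::Args args;
//   args.step_id = 1;                    // assumed step id
//   args.rendezvous = rendezvous;        // assumed caller-owned rendezvous
//   args.runner = [](std::function<void()> fn) { fn(); };  // run closures inline
//   Status s = executor->Run(args);      // blocks until the done callback fires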
1465 
1466 }  // namespace
1467 
1468 Status NewLocalExecutor(const LocalExecutorParams& params, const Graph& graph,
1469                         Executor** executor) {
1470   ExecutorImpl* impl = new ExecutorImpl(params);
1471   const Status s = impl->Initialize(graph);
1472   if (s.ok()) {
1473     *executor = impl;
1474   } else {
1475     delete impl;
1476   }
1477   return s;
1478 }
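// Editor's note (sketch, not part of the original source): one minimal way to
// build an executor with this helper, assuming `device`, `flib`, and `graph`
// are already set up by the caller and using the kernel helpers defined just
// below:
//
//   LocalExecutorParams params;
//   params.device = device;
//   params.function_library = flib;
//   params.create_kernel =
//       [device, flib, &graph](const std::shared_ptr<const NodeProperties>& props,
//                              OpKernel** kernel) {
//         return CreateNonCachedKernel(device, flib, props,
//                                      graph.versions().producer(), kernel);
//       };
//   params.delete_kernel = [](OpKernel* kernel) { DeleteNonCachedKernel(kernel); };
//   Executor* raw_executor = nullptr;
//   TF_RETURN_IF_ERROR(NewLocalExecutor(params, graph, &raw_executor));
//   std::unique_ptr<Executor> executor(raw_executor);  // caller owns on success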
1479 
1480 Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib,
1481                              const std::shared_ptr<const NodeProperties>& props,
1482                              int graph_def_version, OpKernel** kernel) {
1483   const auto device_type = DeviceType(device->attributes().device_type());
1484   auto allocator = device->GetAllocator(AllocatorAttributes());
1485   return CreateOpKernel(device_type, device, allocator, flib,
1486                         device->resource_manager(), props, graph_def_version,
1487                         kernel);
1488 }
1489 
1490 void DeleteNonCachedKernel(OpKernel* kernel) { delete kernel; }
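// Editor's note: CreateNonCachedKernel and DeleteNonCachedKernel are intended
// as a matched pair, e.g. installed as `params.create_kernel` and
// `params.delete_kernel` in the sketch above; kernels created this way are not
// shared through an OpSegment, so the caller is responsible for deleting them
// via DeleteNonCachedKernel when they are no longer needed.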
1491 
1492 namespace {
1493 
1494 class DefaultExecutorRegistrar {
1495  public:
1496   DefaultExecutorRegistrar() {
1497     Factory* factory = new Factory;
1498     ExecutorFactory::Register("", factory);
1499     ExecutorFactory::Register("DEFAULT", factory);
1500   }
1501 
1502  private:
1503   class Factory : public ExecutorFactory {
1504     Status NewExecutor(const LocalExecutorParams& params, const Graph& graph,
1505                        std::unique_ptr<Executor>* out_executor) override {
1506       Executor* ret = nullptr;
1507       TF_RETURN_IF_ERROR(NewLocalExecutor(params, std::move(graph), &ret));
1508       out_executor->reset(ret);
1509       return OkStatus();
1510     }
1511   };
1512 };
1513 static DefaultExecutorRegistrar registrar;
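// Editor's note (sketch, not part of the original source): with this registrar
// in place, both the empty string and "DEFAULT" resolve to the factory above.
// Assuming the static lookup helper declared in executor_factory.h, a caller
// can go through the factory path roughly like this:
//
//   std::unique_ptr<Executor> executor;
//   TF_RETURN_IF_ERROR(ExecutorFactory::NewExecutor(
//       /*executor_type=*/"DEFAULT", params, graph, &executor));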
1514 
1515 }  // namespace
1516 
1517 }  // namespace tensorflow
1518