xref: /aosp_15_r20/external/tensorflow/tensorflow/core/common_runtime/simple_propagator_state.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_PROPAGATOR_STATE_H_
16 #define TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_PROPAGATOR_STATE_H_
17 
#include <atomic>
#include <memory>
#include <vector>

#include "tensorflow/core/common_runtime/entry.h"
#include "tensorflow/core/common_runtime/immutable_executor_state.h"
#include "tensorflow/core/common_runtime/pending_counts.h"
#include "tensorflow/core/framework/control_flow.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
30 
31 namespace tensorflow {
32 
33 // Represents the ephemeral "edge state" associated with one invocation of
34 // `Executor::Run()`.
35 //
36 // NOTE: `SimplePropagatorState` does not support "v1-style" control flow,
37 // including "dead tensors", "Switch" and "Merge" nodes, and cycles in the
38 // graph. Use `PropagatorState` for graphs with those features.
39 // `SimplePropagatorState` *does* support "v2-style" or "functional" control
40 // flow.
41 //
42 // `SimplePropagatorState` is responsible for propagating values along dataflow
43 // edges in a TensorFlow graph and determining which nodes are runnable. The
44 // executor primarily updates `SimplePropagatorState` by calling
45 // `PropagateOutputs()` after processing a node, and `SimplePropagatorState`
46 // dispatches `TaggedNode`s by adding them to a `TaggedNodeSeq`.
47 class SimplePropagatorState {
48  public:
49   SimplePropagatorState(const ImmutableExecutorState& immutable_state,
50                         int64_t step_id, bool vlog);
51   ~SimplePropagatorState();
52 
53   // A `TaggedNode` corresponds to a single invocation of a node's kernel,
54   // and it is created when the kernel becomes runnable.
55   struct TaggedNode {
56     const NodeItem* node_item;
57 
TaggedNodeTaggedNode58     explicit TaggedNode(const NodeItem* node_item) : node_item(node_item) {}
59 
get_node_itemTaggedNode60     const NodeItem& get_node_item() const { return *node_item; }
61 
get_is_deadTaggedNode62     bool get_is_dead() const { return false; }
get_iter_numTaggedNode63     int64_t get_iter_num() const { return 0; }
64   };
65 
66   // A drop-in replacement for std::deque<TaggedNode>.  We typically don't
67   // have that many nodes in the ready queue, so we just use a vector and
68   // don't free up memory from the queue as we consume nodes.
69   // TODO(mrry): Extract this and share it with the version in
70   // `PropagatorState`. The correct constants might be different, since
71   // sizeof(TaggedNode) is smaller in this version.
72   class TaggedNodeReadyQueue {
73    public:
TaggedNodeReadyQueue()74     TaggedNodeReadyQueue() : front_index_(0) {}
75 
push_back(const TaggedNode & node)76     void push_back(const TaggedNode& node) { ready_.push_back(node); }
front()77     TaggedNode front() const {
78       DCHECK_LT(front_index_, ready_.size());
79       return ready_[front_index_];
80     }
pop_front()81     void pop_front() {
82       DCHECK_LT(front_index_, ready_.size());
83       front_index_++;
84       if ((front_index_ == ready_.size()) || (front_index_ > kSpillThreshold)) {
85         if (front_index_ == ready_.size()) {
86           ready_.clear();
87         } else {
88           // Lots of unused entries at beginning of vector: move everything
89           // down to start of vector.
90           ready_.erase(ready_.begin(), ready_.begin() + front_index_);
91         }
92         front_index_ = 0;
93       }
94     }
empty()95     bool empty() const { return ready_.empty(); }
size()96     int size() const { return ready_.size() - front_index_; }
97 
98    private:
99     // TODO(b/152925936): Re-evaluate these constants with current usage
100     // patterns.
101     static constexpr int kSpillThreshold = 16384;
102     gtl::InlinedVector<TaggedNode, 16> ready_;
103     int front_index_;
104   };
105 
106   // TODO(b/152925936): Re-evaluate this constant with current usage patterns.
107   typedef gtl::InlinedVector<TaggedNode, 8> TaggedNodeSeq;
108 
109   // Creates and adds a `TaggedNode` for each node in `roots` to `*ready`.
110   void ActivateRoots(gtl::ArraySlice<const NodeItem*> roots,
111                      TaggedNodeSeq* ready);
112 
113   // After processing the outputs, propagates the outputs to their dsts.
114   // Contents of *outputs are left in an indeterminate state after
115   // returning from this method.
116   void PropagateOutputs(const TaggedNode& tagged_node, EntryVector* outputs,
117                         TaggedNodeSeq* ready);
118 
119   // Returns an array of `Entry` objects corresponding to the inputs of
120   // `tagged_node`.
GetInputTensors(const TaggedNode & tagged_node)121   Entry* GetInputTensors(const TaggedNode& tagged_node) {
122 #if defined(THREAD_SANITIZER) || defined(DEBUG)
123     // NOTE: This read of `pending_[...]` works around a limitation in TSAN.
124     // To avoid false positive data race reports, we need to perform an atomic
125     // object access that will establish the happens-before relation between
126     // the write to input_tensors_ in `PropagateOutputs()` and the read in
127     // `PrepareInputs()`.
128     CHECK_EQ(pending_[tagged_node.node_item->node_id], 0);
129 #endif  // defined(THREAD_SANITIZER) || defined(DEBUG)
130     return input_tensors_.data() + tagged_node.node_item->input_start;
131   }
132 
GetFrameAndIter(const TaggedNode & tagged_node)133   FrameAndIter GetFrameAndIter(const TaggedNode& tagged_node) const {
134     return {0, 0};
135   }
136 
137   // Provide debugging output of the state of the executor.
138   void DumpState();
139 
140   // For debugging/logging only.
MaybeMarkStarted(const TaggedNode & tagged_node)141   void MaybeMarkStarted(const TaggedNode& tagged_node) {
142     // TODO(misard) Replace with a finer-grain enabling flag once we add better
143     // optional debugging support.
144     if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) {
145       mutex_lock l(mu_);
146       (*active_)[tagged_node.node_item->node_id] = true;
147     }
148   }
MaybeMarkCompleted(const TaggedNode & tagged_node)149   void MaybeMarkCompleted(const TaggedNode& tagged_node) {
150     // TODO(misard) Replace with a finer-grain enabling flag once we add better
151     // optional debugging support.
152     if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) {
153       mutex_lock l(mu_);
154       (*active_)[tagged_node.node_item->node_id] = false;
155     }
156   }
157 
158  private:
159   SimplePropagatorState(const ImmutableExecutorState& immutable_state_,
160                         int64_t step_id,
161                         const ImmutableExecutorState::FrameInfo& finfo,
162                         bool vlog);
163 
164   const ImmutableExecutorState& immutable_state_;
165   const int64_t step_id_;
166   const bool vlog_;
167 
168   // The i-th node's j-th input is stored at
169   // `input_tensors[impl_->nodes[i].input_start + j]`.
170   //
171   // NOTE: No need to protect input_tensors[i] by any locks because it
172   // is resized once. Each element of input_tensors is written once by the
173   // source node of an edge and is cleared by the destination of the same
174   // edge. The destination node always runs after the source node, so there
175   // is never concurrent access to the same entry.
176   std::vector<Entry> input_tensors_;
177 
178   std::unique_ptr<std::atomic<int32>[]> pending_;
179 
180   // If `vlog_` is true, this stores a bit vector of active nodes, indexed by
181   // node ID.
182   mutex mu_;
183   std::unique_ptr<std::vector<bool>> active_ TF_GUARDED_BY(mu_);
184 
185   const std::vector<const NodeItem*>* const nodes_;
186 };
187 
188 }  // namespace tensorflow
189 
190 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_PROPAGATOR_STATE_H_
191