/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_PROPAGATOR_STATE_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_PROPAGATOR_STATE_H_

#include <atomic>
#include <cstddef>
#include <memory>
#include <vector>

#include "tensorflow/core/common_runtime/entry.h"
#include "tensorflow/core/common_runtime/immutable_executor_state.h"
#include "tensorflow/core/common_runtime/pending_counts.h"
#include "tensorflow/core/framework/control_flow.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// Represents the ephemeral "edge state" associated with one invocation of
// `Executor::Run()`.
//
// NOTE: `SimplePropagatorState` does not support "v1-style" control flow,
// including "dead tensors", "Switch" and "Merge" nodes, and cycles in the
// graph. Use `PropagatorState` for graphs with those features.
// `SimplePropagatorState` *does* support "v2-style" or "functional" control
// flow.
41 // 42 // `SimplePropagatorState` is responsible for propagating values along dataflow 43 // edges in a TensorFlow graph and determining which nodes are runnable. The 44 // executor primarily updates `SimplePropagatorState` by calling 45 // `PropagateOutputs()` after processing a node, and `SimplePropagatorState` 46 // dispatches `TaggedNode`s by adding them to a `TaggedNodeSeq`. 47 class SimplePropagatorState { 48 public: 49 SimplePropagatorState(const ImmutableExecutorState& immutable_state, 50 int64_t step_id, bool vlog); 51 ~SimplePropagatorState(); 52 53 // A `TaggedNode` corresponds to a single invocation of a node's kernel, 54 // and it is created when the kernel becomes runnable. 55 struct TaggedNode { 56 const NodeItem* node_item; 57 TaggedNodeTaggedNode58 explicit TaggedNode(const NodeItem* node_item) : node_item(node_item) {} 59 get_node_itemTaggedNode60 const NodeItem& get_node_item() const { return *node_item; } 61 get_is_deadTaggedNode62 bool get_is_dead() const { return false; } get_iter_numTaggedNode63 int64_t get_iter_num() const { return 0; } 64 }; 65 66 // A drop-in replacement for std::deque<TaggedNode>. We typically don't 67 // have that many nodes in the ready queue, so we just use a vector and 68 // don't free up memory from the queue as we consume nodes. 69 // TODO(mrry): Extract this and share it with the version in 70 // `PropagatorState`. The correct constants might be different, since 71 // sizeof(TaggedNode) is smaller in this version. 
72 class TaggedNodeReadyQueue { 73 public: TaggedNodeReadyQueue()74 TaggedNodeReadyQueue() : front_index_(0) {} 75 push_back(const TaggedNode & node)76 void push_back(const TaggedNode& node) { ready_.push_back(node); } front()77 TaggedNode front() const { 78 DCHECK_LT(front_index_, ready_.size()); 79 return ready_[front_index_]; 80 } pop_front()81 void pop_front() { 82 DCHECK_LT(front_index_, ready_.size()); 83 front_index_++; 84 if ((front_index_ == ready_.size()) || (front_index_ > kSpillThreshold)) { 85 if (front_index_ == ready_.size()) { 86 ready_.clear(); 87 } else { 88 // Lots of unused entries at beginning of vector: move everything 89 // down to start of vector. 90 ready_.erase(ready_.begin(), ready_.begin() + front_index_); 91 } 92 front_index_ = 0; 93 } 94 } empty()95 bool empty() const { return ready_.empty(); } size()96 int size() const { return ready_.size() - front_index_; } 97 98 private: 99 // TODO(b/152925936): Re-evaluate these constants with current usage 100 // patterns. 101 static constexpr int kSpillThreshold = 16384; 102 gtl::InlinedVector<TaggedNode, 16> ready_; 103 int front_index_; 104 }; 105 106 // TODO(b/152925936): Re-evaluate this constant with current usage patterns. 107 typedef gtl::InlinedVector<TaggedNode, 8> TaggedNodeSeq; 108 109 // Creates and adds a `TaggedNode` for each node in `roots` to `*ready`. 110 void ActivateRoots(gtl::ArraySlice<const NodeItem*> roots, 111 TaggedNodeSeq* ready); 112 113 // After processing the outputs, propagates the outputs to their dsts. 114 // Contents of *outputs are left in an indeterminate state after 115 // returning from this method. 116 void PropagateOutputs(const TaggedNode& tagged_node, EntryVector* outputs, 117 TaggedNodeSeq* ready); 118 119 // Returns an array of `Entry` objects corresponding to the inputs of 120 // `tagged_node`. 
GetInputTensors(const TaggedNode & tagged_node)121 Entry* GetInputTensors(const TaggedNode& tagged_node) { 122 #if defined(THREAD_SANITIZER) || defined(DEBUG) 123 // NOTE: This read of `pending_[...]` works around a limitation in TSAN. 124 // To avoid false positive data race reports, we need to perform an atomic 125 // object access that will establish the happens-before relation between 126 // the write to input_tensors_ in `PropagateOutputs()` and the read in 127 // `PrepareInputs()`. 128 CHECK_EQ(pending_[tagged_node.node_item->node_id], 0); 129 #endif // defined(THREAD_SANITIZER) || defined(DEBUG) 130 return input_tensors_.data() + tagged_node.node_item->input_start; 131 } 132 GetFrameAndIter(const TaggedNode & tagged_node)133 FrameAndIter GetFrameAndIter(const TaggedNode& tagged_node) const { 134 return {0, 0}; 135 } 136 137 // Provide debugging output of the state of the executor. 138 void DumpState(); 139 140 // For debugging/logging only. MaybeMarkStarted(const TaggedNode & tagged_node)141 void MaybeMarkStarted(const TaggedNode& tagged_node) { 142 // TODO(misard) Replace with a finer-grain enabling flag once we add better 143 // optional debugging support. 144 if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) { 145 mutex_lock l(mu_); 146 (*active_)[tagged_node.node_item->node_id] = true; 147 } 148 } MaybeMarkCompleted(const TaggedNode & tagged_node)149 void MaybeMarkCompleted(const TaggedNode& tagged_node) { 150 // TODO(misard) Replace with a finer-grain enabling flag once we add better 151 // optional debugging support. 
152 if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) { 153 mutex_lock l(mu_); 154 (*active_)[tagged_node.node_item->node_id] = false; 155 } 156 } 157 158 private: 159 SimplePropagatorState(const ImmutableExecutorState& immutable_state_, 160 int64_t step_id, 161 const ImmutableExecutorState::FrameInfo& finfo, 162 bool vlog); 163 164 const ImmutableExecutorState& immutable_state_; 165 const int64_t step_id_; 166 const bool vlog_; 167 168 // The i-th node's j-th input is stored at 169 // `input_tensors[impl_->nodes[i].input_start + j]`. 170 // 171 // NOTE: No need to protect input_tensors[i] by any locks because it 172 // is resized once. Each element of input_tensors is written once by the 173 // source node of an edge and is cleared by the destination of the same 174 // edge. The destination node always runs after the source node, so there 175 // is never concurrent access to the same entry. 176 std::vector<Entry> input_tensors_; 177 178 std::unique_ptr<std::atomic<int32>[]> pending_; 179 180 // If `vlog_` is true, this stores a bit vector of active nodes, indexed by 181 // node ID. 182 mutex mu_; 183 std::unique_ptr<std::vector<bool>> active_ TF_GUARDED_BY(mu_); 184 185 const std::vector<const NodeItem*>* const nodes_; 186 }; 187 188 } // namespace tensorflow 189 190 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_PROPAGATOR_STATE_H_ 191