1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_VIEW_H_ 17 #define TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_VIEW_H_ 18 19 #include <memory> 20 #include <vector> 21 22 #include "tensorflow/core/framework/allocator.h" 23 #include "tensorflow/core/framework/types.h" 24 #include "tensorflow/core/lib/core/status.h" 25 #include "tensorflow/core/lib/gtl/array_slice.h" 26 #include "tensorflow/core/platform/logging.h" 27 #include "tensorflow/core/platform/macros.h" 28 #include "tensorflow/core/platform/types.h" 29 30 namespace tensorflow { 31 32 class Device; 33 class Graph; 34 class Node; 35 class OpKernel; 36 class Tensor; 37 38 // Represents a single data edge in a `NodeItem`. 39 struct EdgeInfo { 40 // The node ID of the destination in the containing `GraphView`. 41 int dst_id; 42 // The index of the output that produces values on this edge. 43 int output_slot : 31; 44 // true if this is the last info for output_slot in the EdgeInfo list. 45 bool is_last : 1; 46 // The index of the input that consumes values on this edge. 47 int input_slot; 48 }; 49 50 // Represents a single control edge in a `NodeItem`. 51 struct ControlEdgeInfo { 52 // The node ID of the destination in the containing `GraphView`. 53 int dst_id; 54 }; 55 56 // Compact structure representing a graph node and its associated kernel. 57 // 58 // Each NodeItem is an element of exactly one GraphView. 59 struct NodeItem { 60 // The index of this node's item in its GraphView. 61 int node_id = -1; 62 63 // Cached attributes of this node for fast lookup. 64 bool kernel_is_async : 1; // True iff kernel->AsAsync() != nullptr 65 bool is_merge : 1; // True iff IsMerge(node) 66 bool is_enter : 1; // True iff IsEnter(node) 67 bool is_constant_enter : 1; // True iff IsEnter(node) and 68 // node->GetAttr("is_constant") == true. 69 bool is_exit : 1; // True iff IsExit(node) 70 bool is_control_trigger : 1; // True iff IsControlTrigger(node) 71 bool is_source : 1; // True iff IsSource(node) 72 // True iff IsEnter(node) || IsExit(node) || IsNextIteration(node) 73 bool is_enter_exit_or_next_iter : 1; 74 bool is_transfer_node : 1; // True iff IsTransferNode(node) 75 bool is_initialization_op : 1; // True iff IsInitializationOp(node) 76 bool is_recv_or_switch : 1; // True iff IsRecv(node) || IsSwitch(node) 77 bool is_next_iteration : 1; // True iff IsNextIteration(node) 78 bool is_noop : 1; // True iff item->kernel->type_string_view() == "NoOp") 79 bool 80 is_any_consumer_merge_or_control_trigger : 1; // True iff the destination 81 // of any output edge is a 82 // merge or control trigger 83 // node. 84 bool is_any_input_ref_typed : 1; // True iff any IsRefType(dt) for dt in this 85 // node's input types. 86 bool is_distributed_communication : 1; // True iff the op is registered to 87 // use distributed communication. 88 89 // The kernel for this node. 90 OpKernel* kernel = nullptr; 91 92 // If the kernel is a Const op, this containts points to the constant tensor. 93 const Tensor* const_tensor = nullptr; 94 95 // Cached values of node->num_inputs() and node->num_outputs(), to 96 // avoid levels of indirection. 97 int num_inputs; 98 int num_outputs; 99 100 // ExecutorImpl::tensors_[input_start] is the 1st positional input 101 // for this node. 102 int input_start = 0; 103 104 // Number of output edges, excluding control edges. 105 int32 num_output_edges; 106 107 // Number of output control edges. 108 int32 num_output_control_edges; 109 110 // If non-null, contains an array of num_outputs bools, where the ith bool 111 // is true if and only if the ith output is consumed by another node. 112 std::unique_ptr<bool[]> outputs_required; 113 mutable_output_edgesNodeItem114 gtl::MutableArraySlice<EdgeInfo> mutable_output_edges() { 115 return gtl::MutableArraySlice<EdgeInfo>(output_edge_base(), 116 num_output_edges); 117 } 118 output_edgesNodeItem119 gtl::ArraySlice<EdgeInfo> output_edges() const { 120 return gtl::ArraySlice<EdgeInfo>(output_edge_base(), num_output_edges); 121 } 122 output_control_edgesNodeItem123 gtl::ArraySlice<ControlEdgeInfo> output_control_edges() const { 124 return gtl::ArraySlice<const ControlEdgeInfo>(output_control_edge_base(), 125 num_output_control_edges); 126 } 127 input_typeNodeItem128 DataType input_type(int i) const { 129 DCHECK_LT(i, num_inputs); 130 return static_cast<DataType>(input_type_base()[i]); 131 } output_typeNodeItem132 DataType output_type(int i) const { 133 DCHECK_LT(i, num_outputs); 134 return static_cast<DataType>(output_type_base()[i]); 135 } 136 137 // Return array of per-output allocator attributes. output_attrsNodeItem138 const AllocatorAttributes* output_attrs() const { return output_attr_base(); } 139 140 // Return array of expected input index from which each output should 141 // be forwarded: 142 // kNeverForward (-2) for DO NOT FORWARD (must allocate). 143 // kNoReservation (-1) for no expected forwarding. 144 // 0... for forward from that input. forward_fromNodeItem145 const int* forward_from() const { return forward_from_base(); } 146 147 string DebugString() const; 148 149 private: 150 friend class GraphView; 151 NodeItemNodeItem152 NodeItem() {} 153 154 // Variable length section starts immediately after *this 155 // (uint8 is enough for DataType). 156 // EdgeInfo out_edges[num_output_edges]; 157 // ControlEdgeInfo out_control_edges[num_output_control_edges]; 158 // AllocatorAttributes output_attr[num_outputs]; 159 // int forward_from[num_outputs]; 160 // uint8 input_type[num_inputs]; 161 // uint8 output_type[num_outputs]; 162 163 // Return pointer to variable length section. varNodeItem164 char* var() const { 165 return const_cast<char*>(reinterpret_cast<const char*>(this) + 166 sizeof(NodeItem)); 167 } 168 output_edge_baseNodeItem169 EdgeInfo* output_edge_base() const { 170 return reinterpret_cast<EdgeInfo*>(var()); 171 } 172 output_control_edge_baseNodeItem173 ControlEdgeInfo* output_control_edge_base() const { 174 return reinterpret_cast<ControlEdgeInfo*>(var() + sizeof(EdgeInfo) * 175 num_output_edges); 176 } 177 output_attr_baseNodeItem178 AllocatorAttributes* output_attr_base() const { 179 return reinterpret_cast<AllocatorAttributes*>( 180 var() + sizeof(EdgeInfo) * num_output_edges + 181 sizeof(ControlEdgeInfo) * num_output_control_edges); 182 } forward_from_baseNodeItem183 int* forward_from_base() const { 184 return reinterpret_cast<int*>(var() + sizeof(EdgeInfo) * num_output_edges + 185 sizeof(ControlEdgeInfo) * 186 num_output_control_edges + 187 sizeof(AllocatorAttributes) * num_outputs); 188 } input_type_baseNodeItem189 uint8* input_type_base() const { 190 return reinterpret_cast<uint8*>( 191 var() + sizeof(EdgeInfo) * num_output_edges + 192 sizeof(ControlEdgeInfo) * num_output_control_edges + 193 sizeof(AllocatorAttributes) * num_outputs + sizeof(int) * num_outputs); 194 } output_type_baseNodeItem195 uint8* output_type_base() const { 196 return reinterpret_cast<uint8*>( 197 var() + sizeof(EdgeInfo) * num_output_edges + 198 sizeof(ControlEdgeInfo) * num_output_control_edges + 199 sizeof(AllocatorAttributes) * num_outputs + sizeof(int) * num_outputs + 200 sizeof(uint8) * num_inputs); 201 } 202 203 TF_DISALLOW_COPY_AND_ASSIGN(NodeItem); 204 }; 205 206 // Immutable view of a Graph organized for efficient execution. 207 // 208 // TODO(b/152651962): Add independent unit tests for this class. 209 class GraphView { 210 public: GraphView()211 GraphView() : space_(nullptr) {} 212 ~GraphView(); 213 214 Status Initialize(const Graph* g); 215 Status SetAllocAttrs(const Graph* g, const Device* device); 216 void SetScopedAllocatorAttrs(const std::vector<const Node*>& sa_nodes); 217 218 // Returns a mutable pointer to the `NodeItem` with the given `id` if it 219 // exists in the graph, or `nullptr` if it does not. node(int32_t id)220 NodeItem* node(int32_t id) const { 221 DCHECK_GE(id, 0); 222 DCHECK_LT(id, num_nodes_); 223 uint32 offset = node_offsets_[id]; 224 return ((offset == kuint32max) 225 ? nullptr 226 : reinterpret_cast<NodeItem*>(space_ + node_offsets_[id])); 227 } 228 229 // Returns the `NodeItem` with the given `id`. 230 // 231 // REQUIRES: `id` must be the ID of a valid node in the graph. node_ref(int32_t id)232 const NodeItem& node_ref(int32_t id) const { 233 DCHECK_GE(id, 0); 234 DCHECK_LT(id, num_nodes_); 235 uint32 offset = node_offsets_[id]; 236 DCHECK_NE(offset, kuint32max); 237 return *reinterpret_cast<NodeItem*>(space_ + node_offsets_[id]); 238 } 239 num_nodes()240 int32 num_nodes() const { return num_nodes_; } 241 242 private: 243 char* InitializeNode(char* ptr, const Node* n); 244 size_t NodeItemBytes(const Node* n); 245 246 int32 num_nodes_ = 0; 247 uint32* node_offsets_ = nullptr; // array of size "num_nodes_" 248 // node_offsets_[id] holds the byte offset for node w/ "id" in space_ 249 250 char* space_; // NodeItem objects are allocated here 251 252 TF_DISALLOW_COPY_AND_ASSIGN(GraphView); 253 }; 254 255 } // namespace tensorflow 256 257 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_VIEW_H_ 258