/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_VIEW_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_VIEW_H_

#include <memory>
#include <vector>

#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

class Device;
class Graph;
class Node;
class OpKernel;
class Tensor;

// Represents a single data edge in a `NodeItem`.
struct EdgeInfo {
  // The node ID of the destination in the containing `GraphView`.
  int dst_id;
  // The index of the output that produces values on this edge.
  int output_slot : 31;
  // true if this is the last info for output_slot in the EdgeInfo list.
  bool is_last : 1;
  // The index of the input that consumes values on this edge.
  int input_slot;
};

// Represents a single control edge in a `NodeItem`.
struct ControlEdgeInfo {
  // The node ID of the destination in the containing `GraphView`.
  int dst_id;
};

// Compact structure representing a graph node and its associated kernel.
//
// Each NodeItem is an element of exactly one GraphView.
struct NodeItem {
  // The index of this node's item in its GraphView.
  int node_id = -1;

  // Cached attributes of this node for fast lookup.
  bool kernel_is_async : 1;     // True iff kernel->AsAsync() != nullptr
  bool is_merge : 1;            // True iff IsMerge(node)
  bool is_enter : 1;            // True iff IsEnter(node)
  bool is_constant_enter : 1;   // True iff IsEnter(node) and
                                // node->GetAttr("is_constant") == true.
  bool is_exit : 1;             // True iff IsExit(node)
  bool is_control_trigger : 1;  // True iff IsControlTrigger(node)
  bool is_source : 1;           // True iff IsSource(node)
  // True iff IsEnter(node) || IsExit(node) || IsNextIteration(node)
  bool is_enter_exit_or_next_iter : 1;
  bool is_transfer_node : 1;      // True iff IsTransferNode(node)
  bool is_initialization_op : 1;  // True iff IsInitializationOp(node)
  bool is_recv_or_switch : 1;     // True iff IsRecv(node) || IsSwitch(node)
  bool is_next_iteration : 1;     // True iff IsNextIteration(node)
  bool is_noop : 1;  // True iff item->kernel->type_string_view() == "NoOp"
  bool
      is_any_consumer_merge_or_control_trigger : 1;  // True iff the destination
                                                     // of any output edge is a
                                                     // merge or control trigger
                                                     // node.
  bool is_any_input_ref_typed : 1;  // True iff any IsRefType(dt) for dt in this
                                    // node's input types.
  bool is_distributed_communication : 1;  // True iff the op is registered to
                                          // use distributed communication.

  // The kernel for this node.
  OpKernel* kernel = nullptr;

  // If the kernel is a Const op, this points to the constant tensor.
  const Tensor* const_tensor = nullptr;

  // Cached values of node->num_inputs() and node->num_outputs(), to
  // avoid levels of indirection.
  int num_inputs;
  int num_outputs;

  // ExecutorImpl::tensors_[input_start] is the 1st positional input
  // for this node.
  int input_start = 0;

  // Number of output edges, excluding control edges.
  int32 num_output_edges;

  // Number of output control edges.
  int32 num_output_control_edges;

  // If non-null, contains an array of num_outputs bools, where the ith bool
  // is true if and only if the ith output is consumed by another node.
  std::unique_ptr<bool[]> outputs_required;

  gtl::MutableArraySlice<EdgeInfo> mutable_output_edges() {
    return gtl::MutableArraySlice<EdgeInfo>(output_edge_base(),
                                            num_output_edges);
  }

  gtl::ArraySlice<EdgeInfo> output_edges() const {
    return gtl::ArraySlice<EdgeInfo>(output_edge_base(), num_output_edges);
  }

  gtl::ArraySlice<ControlEdgeInfo> output_control_edges() const {
    return gtl::ArraySlice<const ControlEdgeInfo>(output_control_edge_base(),
                                                  num_output_control_edges);
  }
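
  // Illustrative sketch only (not part of this interface; `item` and `gview`
  // are hypothetical locals): a consumer typically walks this node's outgoing
  // edges roughly like so:
  //
  //   for (const EdgeInfo& e : item.output_edges()) {
  //     const NodeItem& dst = gview.node_ref(e.dst_id);
  //     // The value produced at output `e.output_slot` feeds input
  //     // `e.input_slot` of `dst`.
  //   }
  //   for (const ControlEdgeInfo& c : item.output_control_edges()) {
  //     const NodeItem& dst = gview.node_ref(c.dst_id);  // control-only edge
  //   }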

  DataType input_type(int i) const {
    DCHECK_LT(i, num_inputs);
    return static_cast<DataType>(input_type_base()[i]);
  }
  DataType output_type(int i) const {
    DCHECK_LT(i, num_outputs);
    return static_cast<DataType>(output_type_base()[i]);
  }

  // Return array of per-output allocator attributes.
  const AllocatorAttributes* output_attrs() const { return output_attr_base(); }

  // Return the array of expected input indices from which each output should
  // be forwarded:
  //   kNeverForward (-2): do not forward (must allocate).
  //   kNoReservation (-1): no expected forwarding.
  //   0...: forward from that input.
  const int* forward_from() const { return forward_from_base(); }
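
  // Illustrative sketch only (hypothetical, not part of this interface):
  // reading the forwarding hints for a node item `item`:
  //
  //   const int* ff = item.forward_from();
  //   for (int out = 0; out < item.num_outputs; ++out) {
  //     if (ff[out] >= 0) {
  //       // Output `out` is expected to reuse the buffer of input ff[out].
  //     }
  //   }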

  string DebugString() const;

 private:
  friend class GraphView;

  NodeItem() {}

  // Variable length section starts immediately after *this
  // (uint8 is enough for DataType).
  //   EdgeInfo            out_edges[num_output_edges];
  //   ControlEdgeInfo     out_control_edges[num_output_control_edges];
  //   AllocatorAttributes output_attr[num_outputs];
  //   int                 forward_from[num_outputs];
  //   uint8               input_type[num_inputs];
  //   uint8               output_type[num_outputs];
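  //
  // For illustration only (the authoritative computation lives in
  // GraphView::NodeItemBytes, which may also account for alignment): the
  // total allocation for one item is roughly
  //
  //   sizeof(NodeItem)
  //       + sizeof(EdgeInfo) * num_output_edges
  //       + sizeof(ControlEdgeInfo) * num_output_control_edges
  //       + sizeof(AllocatorAttributes) * num_outputs
  //       + sizeof(int) * num_outputs
  //       + sizeof(uint8) * (num_inputs + num_outputs);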

  // Return pointer to variable length section.
  char* var() const {
    return const_cast<char*>(reinterpret_cast<const char*>(this) +
                             sizeof(NodeItem));
  }

  EdgeInfo* output_edge_base() const {
    return reinterpret_cast<EdgeInfo*>(var());
  }

  ControlEdgeInfo* output_control_edge_base() const {
    return reinterpret_cast<ControlEdgeInfo*>(var() + sizeof(EdgeInfo) *
                                                          num_output_edges);
  }

  AllocatorAttributes* output_attr_base() const {
    return reinterpret_cast<AllocatorAttributes*>(
        var() + sizeof(EdgeInfo) * num_output_edges +
        sizeof(ControlEdgeInfo) * num_output_control_edges);
  }
  int* forward_from_base() const {
    return reinterpret_cast<int*>(var() + sizeof(EdgeInfo) * num_output_edges +
                                  sizeof(ControlEdgeInfo) *
                                      num_output_control_edges +
                                  sizeof(AllocatorAttributes) * num_outputs);
  }
  uint8* input_type_base() const {
    return reinterpret_cast<uint8*>(
        var() + sizeof(EdgeInfo) * num_output_edges +
        sizeof(ControlEdgeInfo) * num_output_control_edges +
        sizeof(AllocatorAttributes) * num_outputs + sizeof(int) * num_outputs);
  }
  uint8* output_type_base() const {
    return reinterpret_cast<uint8*>(
        var() + sizeof(EdgeInfo) * num_output_edges +
        sizeof(ControlEdgeInfo) * num_output_control_edges +
        sizeof(AllocatorAttributes) * num_outputs + sizeof(int) * num_outputs +
        sizeof(uint8) * num_inputs);
  }

  TF_DISALLOW_COPY_AND_ASSIGN(NodeItem);
};

// Immutable view of a Graph organized for efficient execution.
//
// TODO(b/152651962): Add independent unit tests for this class.
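//
// Usage sketch (illustrative only; `graph` and `device` are hypothetical
// locals, error handling abbreviated):
//
//   GraphView gview;
//   TF_RETURN_IF_ERROR(gview.Initialize(graph));
//   TF_RETURN_IF_ERROR(gview.SetAllocAttrs(graph, device));
//   for (int32_t id = 0; id < gview.num_nodes(); ++id) {
//     const NodeItem* item = gview.node(id);
//     if (item == nullptr) continue;  // no node with this id
//     // Inspect item->kernel, item->output_edges(), etc.
//   }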
class GraphView {
 public:
  GraphView() : space_(nullptr) {}
  ~GraphView();

  Status Initialize(const Graph* g);
  Status SetAllocAttrs(const Graph* g, const Device* device);
  void SetScopedAllocatorAttrs(const std::vector<const Node*>& sa_nodes);

  // Returns a mutable pointer to the `NodeItem` with the given `id` if it
  // exists in the graph, or `nullptr` if it does not.
  NodeItem* node(int32_t id) const {
    DCHECK_GE(id, 0);
    DCHECK_LT(id, num_nodes_);
    uint32 offset = node_offsets_[id];
    return ((offset == kuint32max)
                ? nullptr
                : reinterpret_cast<NodeItem*>(space_ + node_offsets_[id]));
  }

  // Returns the `NodeItem` with the given `id`.
  //
  // REQUIRES: `id` must be the ID of a valid node in the graph.
  const NodeItem& node_ref(int32_t id) const {
    DCHECK_GE(id, 0);
    DCHECK_LT(id, num_nodes_);
    uint32 offset = node_offsets_[id];
    DCHECK_NE(offset, kuint32max);
    return *reinterpret_cast<NodeItem*>(space_ + node_offsets_[id]);
  }

  int32 num_nodes() const { return num_nodes_; }

 private:
  char* InitializeNode(char* ptr, const Node* n);
  size_t NodeItemBytes(const Node* n);

  int32 num_nodes_ = 0;
  uint32* node_offsets_ = nullptr;  // array of size "num_nodes_"
  // node_offsets_[id] holds the byte offset for node w/ "id" in space_

  char* space_;  // NodeItem objects are allocated here

  TF_DISALLOW_COPY_AND_ASSIGN(GraphView);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_VIEW_H_