xref: /aosp_15_r20/external/tensorflow/tensorflow/core/common_runtime/graph_view.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/graph_view.h"

#include <atomic>
#include <deque>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/edgeset.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/util/determinism.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {

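// Produces a short summary of the node, e.g. (illustrative)
// "{name:'add_1' id:3 def:{...}}" for a regular node, or
// "{name:'add_1' id:3 source}" when `is_source` is set.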
string NodeItem::DebugString() const {
  string ret = strings::StrCat("{name:'", kernel->name(), "' id:", node_id);
  if (is_source) {
    strings::StrAppend(&ret, " source}");
  } else {
    strings::StrAppend(&ret, " def:{", SummarizeNodeDef(kernel->def()), "}}");
  }
  return ret;
}

GraphView::~GraphView() {
  static_assert(std::is_trivially_destructible<AllocatorAttributes>::value,
                "Update code if AllocatorAttributes gains a destructor");
  static_assert(std::is_trivially_destructible<EdgeInfo>::value,
                "Update code if EdgeInfo gains a destructor");
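  // The trailing EdgeInfo/AllocatorAttributes arrays are placement-constructed
  // in space_ and never individually destroyed, so they must stay trivially
  // destructible; only NodeItem itself gets an explicit destructor call below.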
  for (int i = 0; i < num_nodes_; i++) {
    NodeItem* n = node(i);
    if (n != nullptr) {
      n->NodeItem::~NodeItem();
      // Memory for "n" itself is held in space_ & gets cleaned up below
    }
  }
  delete[] node_offsets_;
  delete[] space_;
}

namespace {
typedef std::tuple<int32, int32> OutputAndControlEdges;

OutputAndControlEdges CountOutputEdges(const Node* n) {
  DCHECK_LE(n->out_edges().size(), kint32max);
  int32_t num_output_edges = 0;
  int32_t num_output_control_edges = 0;
  for (auto e : n->out_edges()) {
    if (IsSink(e->dst())) continue;
    if (e->IsControlEdge()) {
      ++num_output_control_edges;
    } else {
      ++num_output_edges;
    }
  }
  return OutputAndControlEdges(num_output_edges, num_output_control_edges);
}
}  // namespace

size_t GraphView::NodeItemBytes(const Node* n) {
  int32_t num_output_edges;
  int32_t num_output_control_edges;
  std::tie(num_output_edges, num_output_control_edges) = CountOutputEdges(n);
  const int num_inputs = n->num_inputs();
  const int num_outputs = n->num_outputs();

  // Compute number of bytes needed for NodeItem and variable length data.
  // We do not subtract sizeof(var) since num_inputs/num_outputs might
  // both be zero.
  const size_t raw_bytes =
      sizeof(NodeItem)                             // Fixed
      + num_output_edges * sizeof(EdgeInfo)        // output_edges[...]
      + num_output_control_edges *                 //
            sizeof(ControlEdgeInfo)                // output_control_edges[...]
      + num_outputs * sizeof(AllocatorAttributes)  // output_attr[...]
      + num_outputs * sizeof(int)                  // forward_from[num_outputs]
      + num_inputs * sizeof(uint8)                 // input_type[num_inputs]
      + num_outputs * sizeof(uint8);               // output_type[num_outputs]
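  // Illustrative example: a node with one input, one output, a single output
  // data edge and no control edges needs
  //   sizeof(NodeItem) + sizeof(EdgeInfo) + sizeof(AllocatorAttributes)
  //       + sizeof(int) + 2 * sizeof(uint8)
  // bytes before alignment (exact totals are platform-dependent).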
  static constexpr size_t kItemAlignment = sizeof(NodeItem*);
  static_assert(kItemAlignment % alignof(NodeItem) == 0,
                "NodeItem must be aligned with kItemAlignment");
  static_assert(kItemAlignment % alignof(EdgeInfo) == 0,
                "EdgeInfo must be aligned with kItemAlignment");
  static_assert(kItemAlignment % alignof(ControlEdgeInfo) == 0,
                "ControlEdgeInfo must be aligned with kItemAlignment");
  static_assert(kItemAlignment % alignof(AllocatorAttributes) == 0,
                "AllocatorAttributes must be aligned with kItemAlignment");
  static_assert(sizeof(NodeItem) % alignof(EdgeInfo) == 0,
                "NodeItem must be aligned with EdgeInfo");
  static_assert(sizeof(NodeItem) % alignof(AllocatorAttributes) == 0,
                "NodeItem must be aligned with AllocatorAttributes");
  static_assert(sizeof(EdgeInfo) % alignof(AllocatorAttributes) == 0,
                "EdgeInfo must be aligned with AllocatorAttributes");
  const size_t bytes =
      ((raw_bytes + kItemAlignment - 1) / kItemAlignment) * kItemAlignment;
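  // e.g. with 8-byte pointers (kItemAlignment == 8), raw_bytes == 61 rounds up
  // to 64.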
  return bytes;
}

char* GraphView::InitializeNode(char* ptr, const Node* n) {
  const int id = n->id();
  CHECK(node_offsets_[id] == kuint32max);  // Initial value in constructor

  const size_t bytes = NodeItemBytes(n);
  constexpr size_t kItemAlignment = sizeof(NodeItem*);
  CHECK_EQ(reinterpret_cast<uintptr_t>(ptr) % kItemAlignment, 0);
  NodeItem* item = reinterpret_cast<NodeItem*>(ptr);

  // We store a 32-bit offset relative to the beginning of space_, so that we
  // only need an array of 32-bit values to map from node id to the NodeItem*,
  // (versus 64 bits on most machines if we just stored an array of NodeItem*
  // pointers). Casting to int64 is needed on 32bit CPU to avoid comparing
  // values as "int" vs "size_t" in CHECK_LE.
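  // (With 64-bit pointers this halves the lookup table: 4 bytes per node
  // instead of 8.)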
  CHECK_LE(static_cast<int64_t>(ptr - space_), kuint32max);
  const uint32 offset = static_cast<uint32>(ptr - space_);
  node_offsets_[id] = offset;
  ptr += bytes;

  int32_t num_output_edges;
  int32_t num_output_control_edges;
  std::tie(num_output_edges, num_output_control_edges) = CountOutputEdges(n);
  const int num_inputs = n->num_inputs();
  const int num_outputs = n->num_outputs();

  new (item) NodeItem();
  item->num_inputs = num_inputs;
  item->num_outputs = num_outputs;
  item->num_output_edges = num_output_edges;
  item->num_output_control_edges = num_output_control_edges;

  // Fill output edges.
  // Keep track of the last EdgeInfo in the EdgeInfo array that references
  // a given output slot.  For all but the last, we need to do a copy of the
  // Tensor when propagating results downstream in the graph, but for the
  // last one, we can just do a move of the Tensor object to propagate it.
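  // For example, if output slot 0 feeds three destination nodes, the first two
  // EdgeInfos copy the tensor and the third (marked is_last) may move it.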
  gtl::InlinedVector<EdgeInfo*, 4> last_indices(num_outputs, nullptr);
  EdgeInfo* dst_edge = item->output_edge_base();
  for (auto e : n->out_edges()) {
    if (e->IsControlEdge()) continue;
    dst_edge->dst_id = e->dst()->id();
    CHECK_LE(e->src_output(), 0x3FFFFFFF);  // Must fit in 31 bits
    dst_edge->output_slot = e->src_output();
    dst_edge->is_last = false;
    const int output_slot = dst_edge->output_slot;
    if (output_slot >= 0) {
      last_indices[output_slot] = dst_edge;
    }
    // NOTE: The `input_slot` will be rewritten to the frame-wide offset later
    // in `ExecutorImpl::Initialize()`.
    dst_edge->input_slot = e->dst_input();
    dst_edge++;
  }
  for (EdgeInfo* edge_info : last_indices) {
    if (edge_info != nullptr) {
      edge_info->is_last = true;
    }
  }
  ControlEdgeInfo* dst_control_edge = item->output_control_edge_base();
  for (auto e : n->out_edges()) {
    if (!e->IsControlEdge() || IsSink(e->dst())) continue;
    dst_control_edge->dst_id = e->dst()->id();
    dst_control_edge++;
  }

  AllocatorAttributes* output_attrs = item->output_attr_base();
  for (int i = 0; i < num_outputs; i++) {
    new (&output_attrs[i]) AllocatorAttributes();
  }

  DCHECK_LT(DataType_MAX, 255);  // Must fit in uint8
  uint8* input_types = item->input_type_base();
  item->is_any_input_ref_typed = false;
  for (int i = 0; i < num_inputs; i++) {
    input_types[i] = static_cast<uint8>(n->input_type(i));
    DCHECK_EQ(item->input_type(i), n->input_type(i));
    item->is_any_input_ref_typed |= IsRefType(n->input_type(i));
  }

  // Check ScopedAllocatorAttrs and forward_from.  Also assign output_types.
  {
    std::vector<int> forward_input;
    Status fwd_status =
        GetNodeAttr(n->attrs(), "_forward_input", &forward_input);
    std::vector<int> scoped_allocator_attrs;
    Status sa_status =
        GetNodeAttr(n->attrs(), "_scoped_allocator", &scoped_allocator_attrs);
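    // Illustrative attr values: "_scoped_allocator" == {0, 2} means output 0
    // must be allocated from the ScopedAllocator with scope_id 2, while
    // "_forward_input" == {1, 0} means output 0 may forward the buffer of
    // input 1.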

    int* forward_from = item->forward_from_base();
    uint8* output_types = item->output_type_base();
    for (int i = 0; i < num_outputs; ++i) {
      output_types[i] = static_cast<uint8>(n->output_type(i));
      DCHECK_EQ(item->output_type(i), n->output_type(i));

      forward_from[i] = OpKernelContext::Params::kNoReservation;
      if (sa_status.ok()) {
        for (int j = 0; j < scoped_allocator_attrs.size(); j += 2) {
          if (scoped_allocator_attrs[j] == i) {
            // This output slot must be explicitly allocated from a
            // ScopedAllocator.
            forward_from[i] = OpKernelContext::Params::kNeverForward;
            DCHECK_EQ(output_attrs[i].scope_id, 0);
            output_attrs[i].scope_id = scoped_allocator_attrs[j + 1];
          }
        }
      }
      if (fwd_status.ok() &&
          forward_from[i] == OpKernelContext::Params::kNoReservation) {
        DCHECK_EQ(forward_input.size() % 2, 0);
        for (int j = 0; j < forward_input.size(); j += 2) {
          if (forward_input[j + 1] == i) {
            DCHECK_EQ(forward_from[i], OpKernelContext::Params::kNoReservation);
            forward_from[i] = forward_input[j];
            break;
          }
        }
      }
    }
  }

  return ptr;
}

Status GraphView::Initialize(const Graph* g) {
  CHECK(node_offsets_ == nullptr);
  const int num_nodes = g->num_node_ids();
  num_nodes_ = num_nodes;
  size_t total_bytes = 0;
  for (const Node* n : g->nodes()) {
    if (n->out_edges().size() > kint32max) {
      return errors::InvalidArgument(
          "The executor cannot handle nodes with more than ", kint32max,
          " output edges. Node ", n->name(), " had ", n->out_edges().size(),
          " output edges.");
    }
    total_bytes += NodeItemBytes(n);
  }

  node_offsets_ = new uint32[num_nodes];
  for (int i = 0; i < num_nodes; i++) {
    node_offsets_[i] = kuint32max;
  }

  space_ = new char[total_bytes];  // NodeItem objects are allocated here
  char* ptr = space_;
  auto it = g->nodes();
  if (OpOrderDeterminismRequired()) {
    // For OpOrder determinism, we need node_id's to be stable across runs. We
    // assign node_ids in the order in which `InitializeNode` is called on each
    // node. However, `g` exposes a NodeIter of nodes, which does not guarantee
    // a deterministic ordering across runs. Since NodeIter is immutable, we
    // must sort a local copy. We sort by node_name, which is set in the
    // GraphDef, so must be stable across runs.
    std::vector<Node*> nodes(it.begin(), it.end());
    std::sort(nodes.begin(), nodes.end(), NodeComparatorName());
    for (const Node* n : nodes) {
      ptr = InitializeNode(ptr, n);
    }
  } else {
    for (const Node* n : it) {
      ptr = InitializeNode(ptr, n);
    }
  }
  CHECK_EQ(ptr, space_ + total_bytes);
  return OkStatus();
}
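// A minimal usage sketch (illustrative; assumes `graph` and `device` exist and
// that each NodeItem's `kernel` has been filled in before SetAllocAttrs runs):
//
//   GraphView gview;
//   TF_RETURN_IF_ERROR(gview.Initialize(&graph));
//   TF_RETURN_IF_ERROR(gview.SetAllocAttrs(&graph, device));
//   NodeItem* item = gview.node(some_node->id());
//
// Initialize() lays out one NodeItem per node inside space_; SetAllocAttrs()
// then fills in the per-output AllocatorAttributes using the helpers below.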

namespace {
// If a Node has been marked to use a ScopedAllocator x for output i, then
// sc_attr will contain the subsequence (i, x) at an even offset.  This function
// extracts and transfers that ScopedAllocator id to alloc_attr.  For now, we
// only allow one ScopedAllocator use per Node.
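// For example, sc_attr == {1, 7} transfers ScopedAllocator id 7 to alloc_attr
// when output_index == 1.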
bool ExtractScopedAllocatorAttr(const std::vector<int>& sc_attr,
                                int output_index,
                                AllocatorAttributes* alloc_attr) {
  DCHECK_LE(2, sc_attr.size());
  for (int i = 0; i < sc_attr.size(); i += 2) {
    if (sc_attr[i] == output_index) {
      CHECK_EQ(alloc_attr->scope_id, 0);
      alloc_attr->scope_id = sc_attr[i + 1];
      return true;
    }
  }
  return false;
}
}  // namespace

void GraphView::SetScopedAllocatorAttrs(
    const std::vector<const Node*>& sa_nodes) {
  for (const Node* sa : sa_nodes) {
    NodeItem* sa_item = node(sa->id());
    AllocatorAttributes* sa_attrs = sa_item->output_attr_base();
    // Control edges out of the ScopedAllocator should be use instances, but may
    // include a few other nodes.
    for (const auto& e : sa->out_edges()) {
      if (IsSink(e->dst()) || !e->IsControlEdge()) {
        continue;
      }
      Node* use_node = e->dst();
      NodeItem* item = node(use_node->id());
      AllocatorAttributes* use_attrs = item->output_attr_base();
      std::vector<int> scoped_allocator_attrs;
      Status s = GetNodeAttr(use_node->attrs(), "_scoped_allocator",
                             &scoped_allocator_attrs);
      if (!s.ok()) {
        VLOG(2) << "Failed to find expected ScopedAllocator attr on "
                << use_node->name();
        continue;
      }
      // There can be more than one output using ScopedAllocation, but this
      // analysis assumes they use the same ScopedAllocator.
      for (const auto& e : use_node->out_edges()) {
        if (IsSink(e->dst()) || !e->IsControlEdge()) {
          AllocatorAttributes attr;
          if (ExtractScopedAllocatorAttr(scoped_allocator_attrs,
                                         e->src_output(), &attr)) {
            // Set the scope_id on this use instance node.
            (use_attrs + e->src_output())->Merge(attr);
            // Propagate the other attributes of this node back to the SA node.
            attr = *(use_attrs + e->src_output());
            attr.scope_id = 0;
            sa_attrs->Merge(attr);
          }
        }
      }
    }
  }
}

namespace {
Status InferAllocAttr(const Node* n, const Node* dst,
                      const DeviceNameUtils::ParsedName& local_dev_name,
                      AllocatorAttributes* attr) {
  Status s;
  // Note that it's possible for *n to be a Recv and *dst to be a Send,
  // so these two cases are not mutually exclusive.
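  // Illustrative outcomes: a Recv whose send_device is in a different address
  // space (e.g. another job/task) marks the value nic_compatible, while a Recv
  // on a CPU device receiving from a local non-CPU device marks it
  // gpu_compatible; the Send cases below are symmetric.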
  if (IsRecv(n)) {
    string src_name;
    s = GetNodeAttr(n->attrs(), "send_device", &src_name);
    if (!s.ok()) return s;
    DeviceNameUtils::ParsedName parsed_src_name;
    if (!DeviceNameUtils::ParseFullName(src_name, &parsed_src_name)) {
      s = errors::Internal("Bad send_device attr '", src_name, "' in node ",
                           n->name());
      return s;
    }
    if (!DeviceNameUtils::IsSameAddressSpace(parsed_src_name, local_dev_name)) {
      // Value is going to be the sink of an RPC.
      attr->set_nic_compatible(true);
      VLOG(2) << "node " << n->name() << " is the sink of an RPC in";
    } else if ((local_dev_name.type == "CPU" || n->IsHostRecv()) &&
               parsed_src_name.type != "CPU") {
      // Value is going to be the sink of a local DMA from GPU to CPU (or
      // other types of accelerators).
      attr->set_gpu_compatible(true);
      VLOG(2) << "node " << n->name() << " is the sink of a gpu->cpu copy";
    } else {
      VLOG(2) << "default alloc case local type " << local_dev_name.type
              << " remote type " << parsed_src_name.type;
    }
  }
  if (IsSend(dst)) {
    string dst_name;
    s = GetNodeAttr(dst->attrs(), "recv_device", &dst_name);
    if (!s.ok()) return s;
    DeviceNameUtils::ParsedName parsed_dst_name;
    if (!DeviceNameUtils::ParseFullName(dst_name, &parsed_dst_name)) {
      s = errors::Internal("Bad recv_device attr '", dst_name, "' in node ",
                           n->name());
      return s;
    }
    if (!DeviceNameUtils::IsSameAddressSpace(parsed_dst_name, local_dev_name)) {
      // Value is going to be the source of an RPC.
      attr->set_nic_compatible(true);
      VLOG(2) << "node " << n->name() << " is the source of an RPC out";
    } else if ((local_dev_name.type == "CPU" || dst->IsHostSend()) &&
               parsed_dst_name.type != "CPU") {
      // Value is going to be the source of a local DMA from CPU to GPU (or
      // other types of accelerators).
      // Note that this does not cover the case where the allocation of the
      // output tensor is not generated by the src: n.
      attr->set_gpu_compatible(true);
      VLOG(2) << "node " << n->name() << " is the source of a cpu->gpu copy";
    } else {
      VLOG(2) << "default alloc case local type " << local_dev_name.type
              << " remote type " << parsed_dst_name.type;
    }
  }
  if (n->IsCollective()) {
    // We'll make the sweeping assumption that any collective op is going
    // to be involved in network i/o.
    attr->set_nic_compatible(true);
  }
  return s;
}
}  // namespace

Status GraphView::SetAllocAttrs(const Graph* g, const Device* device) {
  Status s;
  const DeviceNameUtils::ParsedName& local_dev_name = device->parsed_name();

  std::vector<const Node*> scoped_allocator_instances;
  for (const Node* n : g->nodes()) {
    NodeItem* item = node(n->id());
    AllocatorAttributes* attrs = item->output_attr_base();
    if (IsScopedAllocator(n)) {
      scoped_allocator_instances.push_back(n);
    }

    // Examine the out edges of each node looking for special use
    // cases that may affect memory allocation attributes.
    for (const auto& e : n->out_edges()) {
      if (!e->IsControlEdge()) {
        AllocatorAttributes attr;
        s = InferAllocAttr(n, e->dst(), local_dev_name, &attr);
        if (!s.ok()) return s;
        if (attr.value != 0 || attr.scope_id != 0) {
          attrs[e->src_output()].Merge(attr);
        }
      }
    }

    for (int out = 0; out < n->num_outputs(); out++) {
      const OpKernel* op_kernel = item->kernel;
      DCHECK_LT(out, op_kernel->output_memory_types().size());
      bool on_host = op_kernel->output_memory_types()[out] == HOST_MEMORY;
      if (on_host) {
        AllocatorAttributes h;
        h.set_on_host(on_host);
        attrs[out].Merge(h);
      }
    }
  }
  SetScopedAllocatorAttrs(scoped_allocator_instances);
  return s;
}

}  // namespace tensorflow