/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/immutable_executor_state.h"

#include "absl/memory/memory.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/edgeset.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_node_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

namespace {
bool IsInitializationOp(const Node* node) {
  return node->op_def().allows_uninitialized_input();
}
}  // namespace

ImmutableExecutorState::~ImmutableExecutorState() {
  for (int32_t i = 0; i < gview_.num_nodes(); i++) {
    NodeItem* item = gview_.node(i);
    if (item != nullptr) {
      params_.delete_kernel(item->kernel);
    }
  }
}

namespace {
void GetMaxPendingCounts(const Node* n, size_t* max_pending,
                         size_t* max_dead_count) {
  const size_t num_in_edges = n->in_edges().size();
  size_t initial_count;
  if (IsMerge(n)) {
    // A merge node waits for all of its control inputs, so we initialize the
    // pending count to the number of control edges.
    int32_t num_control_edges = 0;
    for (const Edge* edge : n->in_edges()) {
      if (edge->IsControlEdge()) {
        num_control_edges++;
      }
    }
    // Use bit 0 to indicate if we are waiting for a ready live data input.
    initial_count = 1 + (num_control_edges << 1);
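    // For example, a Merge node with two control inputs and three data inputs
    // starts with initial_count = 1 + (2 << 1) = 5 (bit 0 set while no live
    // data input is ready) and max_dead_count = 5 (all five in-edges).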
  } else {
    initial_count = num_in_edges;
  }

  *max_pending = initial_count;
  *max_dead_count = num_in_edges;
}
}  // namespace

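// Returns the FrameInfo for the frame named `fname`, creating it on first
// use. Note that the map key is a string_view that points into the name
// string owned by the FrameInfo itself, so the key stays valid after the
// FrameInfo is moved into the map.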
ImmutableExecutorState::FrameInfo* ImmutableExecutorState::EnsureFrameInfo(
    const string& fname) {
  auto iter = frame_info_.find(fname);
  if (iter != frame_info_.end()) {
    return iter->second.get();
  } else {
    auto frame_info = std::make_unique<FrameInfo>(fname);
    absl::string_view fname_view = frame_info->name;
    auto emplace_result =
        frame_info_.emplace(fname_view, std::move(frame_info));
    return emplace_result.first->second.get();
  }
}

Status ImmutableExecutorState::Initialize(const Graph& graph) {
  TF_RETURN_IF_ERROR(gview_.Initialize(&graph));

  // Build the information about frames in this subgraph.
  ControlFlowInfo cf_info;
  TF_RETURN_IF_ERROR(BuildControlFlowInfo(&graph, &cf_info));

  for (auto& it : cf_info.unique_frame_names) {
    EnsureFrameInfo(it)->nodes =
        std::make_unique<std::vector<const NodeItem*>>();
  }
  root_frame_info_ = frame_info_[""].get();

  pending_ids_.resize(gview_.num_nodes());

  // Preprocess every node in the graph to create an instance of op
  // kernel for each node.
  requires_control_flow_ = false;
  for (const Node* n : graph.nodes()) {
    if (IsSink(n)) continue;
    if (IsSwitch(n) || IsMerge(n) || IsEnter(n) || IsExit(n)) {
      requires_control_flow_ = true;
    } else if (IsRecv(n)) {
      // A Recv node from a different device may produce dead tensors from
      // non-local control-flow nodes.
      //
      // TODO(mrry): Track whether control flow was present in the
      // pre-partitioned graph, and enable the caller (e.g.
      // `DirectSession`) to relax this constraint.
      string send_device;
      string recv_device;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "send_device", &send_device));
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "recv_device", &recv_device));
      if (send_device != recv_device) {
        requires_control_flow_ = true;
      }
    }

    const int id = n->id();
    const string& frame_name = cf_info.frame_names[id];
    FrameInfo* frame_info = EnsureFrameInfo(frame_name);

    NodeItem* item = gview_.node(id);
    item->node_id = id;

    item->input_start = frame_info->total_inputs;
    frame_info->total_inputs += n->num_inputs();

    Status s = params_.create_kernel(n->properties(), &item->kernel);
    if (!s.ok()) {
      params_.delete_kernel(item->kernel);
      item->kernel = nullptr;
      s = AttachDef(s, *n);
      return s;
    }
    CHECK(item->kernel);
    item->kernel_is_async = (item->kernel->AsAsync() != nullptr);
    item->is_merge = IsMerge(n);
    item->is_any_consumer_merge_or_control_trigger = false;
    for (const Node* consumer : n->out_nodes()) {
      if (IsMerge(consumer) || IsControlTrigger(consumer)) {
        item->is_any_consumer_merge_or_control_trigger = true;
        break;
      }
    }
    const Tensor* const_tensor = item->kernel->const_tensor();
    if (const_tensor) {
      // Hold onto a shallow copy of the constant tensor in `*this` so that the
      // reference count does not drop to 1. This prevents the constant tensor
      // from being forwarded, and its buffer reused.
      const_tensors_.emplace_back(*const_tensor);
    }
    item->const_tensor = const_tensor;
    item->is_noop = (item->kernel->type_string_view() == "NoOp");
    item->is_enter = IsEnter(n);
    if (item->is_enter) {
      bool is_constant_enter;
      TF_RETURN_IF_ERROR(
          GetNodeAttr(n->attrs(), "is_constant", &is_constant_enter));
      item->is_constant_enter = is_constant_enter;

      string frame_name;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "frame_name", &frame_name));
      FrameInfo* frame_info = frame_info_[frame_name].get();

      int parallel_iterations;
      TF_RETURN_IF_ERROR(
          GetNodeAttr(n->attrs(), "parallel_iterations", &parallel_iterations));

      if (frame_info->parallel_iterations == -1) {
        frame_info->parallel_iterations = parallel_iterations;
      } else if (frame_info->parallel_iterations != parallel_iterations) {
        LOG(WARNING) << "Loop frame \"" << frame_name
                     << "\" had two different values for parallel_iterations: "
                     << frame_info->parallel_iterations << " vs. "
                     << parallel_iterations << ".";
      }

      if (enter_frame_info_.size() <= id) {
        enter_frame_info_.resize(id + 1);
      }
      enter_frame_info_[id] = frame_info;
    } else {
      item->is_constant_enter = false;
    }
    item->is_exit = IsExit(n);
    item->is_control_trigger = IsControlTrigger(n);
    item->is_source = IsSource(n);
    item->is_enter_exit_or_next_iter =
        (IsEnter(n) || IsExit(n) || IsNextIteration(n));
    item->is_transfer_node = IsTransferNode(n);
    item->is_initialization_op = IsInitializationOp(n);
    item->is_recv_or_switch = IsRecv(n) || IsSwitch(n);
    item->is_next_iteration = IsNextIteration(n);
    item->is_distributed_communication = IsDistributedCommunication(n);

    // Compute the maximum values we'll store for this node in the
    // pending counts data structure, and allocate a handle in
    // that frame's pending counts data structure that has enough
    // space to store these maximal count values.
    size_t max_pending, max_dead;
    GetMaxPendingCounts(n, &max_pending, &max_dead);
    pending_ids_[id] =
        frame_info->pending_counts_layout.CreateHandle(max_pending, max_dead);

    // See if this node is a root node, and if so, add item to root_nodes_.
    if (n->in_edges().empty()) {
      root_nodes_.push_back(item);
    }

    // Initialize static information about the frames in the graph.
    frame_info->nodes->push_back(item);
    if (item->is_enter) {
      string enter_name;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "frame_name", &enter_name));
      EnsureFrameInfo(enter_name)->input_count++;
    }

    // Record information about whether each output of the op is used.
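    // outputs_required[i] becomes true iff output i has at least one consumer
    // other than the sink node. The array is attached to the item only when
    // some output is unused, so in the common fully-consumed case
    // item->outputs_required stays unset (nullptr).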
    std::unique_ptr<bool[]> outputs_required(new bool[n->num_outputs()]);
    std::fill(&outputs_required[0], &outputs_required[n->num_outputs()], false);
    int32_t unused_outputs = n->num_outputs();
    for (const Edge* e : n->out_edges()) {
      if (IsSink(e->dst())) continue;
      if (e->src_output() >= 0) {
        if (!outputs_required[e->src_output()]) {
          --unused_outputs;
          outputs_required[e->src_output()] = true;
        }
      }
    }
    if (unused_outputs > 0) {
      for (int i = 0; i < n->num_outputs(); ++i) {
        if (!outputs_required[i]) {
          metrics::RecordUnusedOutput(n->type_string());
        }
      }
      item->outputs_required = std::move(outputs_required);
    }
  }

  // Rewrite each `EdgeInfo::input_slot` member to refer directly to the input
  // location.
  for (const Node* n : graph.nodes()) {
    if (IsSink(n)) continue;
    const int id = n->id();
    NodeItem* item = gview_.node(id);

    for (EdgeInfo& e : item->mutable_output_edges()) {
      const int dst_id = e.dst_id;
      NodeItem* dst_item = gview_.node(dst_id);
      e.input_slot += dst_item->input_start;
    }
  }

  // Initialize PendingCounts only after pending_ids_[node.id] is initialized
  // for all nodes.
  InitializePending(&graph, cf_info);
  return gview_.SetAllocAttrs(&graph, params_.device);
}

namespace {
// If a Node has been marked to use a ScopedAllocator x for output i, then
// sc_attr will contain the subsequence (i, x) at an even offset.  This function
// extracts and transfers that ScopedAllocator id to alloc_attr.  For now, we
// only allow one ScopedAllocator use per Node.
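// For example, sc_attr = {1, 37} marks output 1 as using ScopedAllocator 37:
// a call with output_index == 1 sets alloc_attr->scope_id = 37 and returns
// true, while any other output_index returns false.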
bool ExtractScopedAllocatorAttr(const std::vector<int>& sc_attr,
                                int output_index,
                                AllocatorAttributes* alloc_attr) {
  DCHECK_LE(2, sc_attr.size());
  for (int i = 0; i < sc_attr.size(); i += 2) {
    if (sc_attr[i] == output_index) {
      CHECK_EQ(alloc_attr->scope_id, 0);
      alloc_attr->scope_id = sc_attr[i + 1];
      return true;
    }
  }
  return false;
}
}  // namespace

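// Walks the graph breadth-first from its root nodes and records, for every
// node, the name of the innermost frame (while-loop context) that contains
// it. Enter nodes start a child frame named by their "frame_name" attribute;
// Exit nodes return to the parent frame.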
Status ImmutableExecutorState::BuildControlFlowInfo(const Graph* g,
                                                    ControlFlowInfo* cf_info) {
  const int num_nodes = g->num_node_ids();
  cf_info->frame_names.resize(num_nodes);
  std::vector<Node*> parent_nodes;
  parent_nodes.resize(num_nodes);
  std::vector<bool> visited;
  visited.resize(num_nodes);

  string frame_name;
  std::deque<Node*> ready;

  // Initialize with the root nodes.
  for (Node* n : g->nodes()) {
    if (n->in_edges().empty()) {
      visited[n->id()] = true;
      cf_info->unique_frame_names.insert(frame_name);
      ready.push_back(n);
    }
  }

  while (!ready.empty()) {
    Node* curr_node = ready.front();
    int curr_id = curr_node->id();
    ready.pop_front();

    Node* parent = nullptr;
    if (IsEnter(curr_node)) {
      // Enter a child frame.
      TF_RETURN_IF_ERROR(
          GetNodeAttr(curr_node->attrs(), "frame_name", &frame_name));
      parent = curr_node;
    } else if (IsExit(curr_node)) {
      // Exit to the parent frame.
      parent = parent_nodes[curr_id];
      if (!parent) {
        return errors::InvalidArgument(
            "Invalid Exit op: Cannot find a corresponding Enter op.");
      }
      frame_name = cf_info->frame_names[parent->id()];
      parent = parent_nodes[parent->id()];
    } else {
      parent = parent_nodes[curr_id];
      frame_name = cf_info->frame_names[curr_id];
    }

    for (const Edge* out_edge : curr_node->out_edges()) {
      Node* out = out_edge->dst();
      if (IsSink(out)) continue;
      const int out_id = out->id();

      // Add to ready queue if not visited.
      bool is_visited = visited[out_id];
      if (!is_visited) {
        ready.push_back(out);
        visited[out_id] = true;

        // Process the node 'out'.
        cf_info->frame_names[out_id] = frame_name;
        parent_nodes[out_id] = parent;
        cf_info->unique_frame_names.insert(frame_name);
      }
    }
  }

  return OkStatus();
}

void ImmutableExecutorState::InitializePending(const Graph* graph,
                                               const ControlFlowInfo& cf_info) {
  for (auto& it : cf_info.unique_frame_names) {
    FrameInfo* finfo = EnsureFrameInfo(it);
    DCHECK_EQ(finfo->pending_counts.get(), nullptr);
    finfo->pending_counts =
        std::make_unique<PendingCounts>(finfo->pending_counts_layout);
  }

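  // When the graph contains no control flow, readiness can also be tracked
  // with a flat array of atomic pending counts, one counter per node.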
  if (!requires_control_flow_) {
    atomic_pending_counts_.reset(new std::atomic<int32>[gview_.num_nodes()]);
    std::fill(atomic_pending_counts_.get(),
              atomic_pending_counts_.get() + gview_.num_nodes(), 0);
  }

  for (const Node* n : graph->nodes()) {
    if (IsSink(n)) continue;
    const int id = n->id();
    const string& name = cf_info.frame_names[id];
    size_t max_pending, max_dead;
    GetMaxPendingCounts(n, &max_pending, &max_dead);
    auto& counts = EnsureFrameInfo(name)->pending_counts;
    counts->set_initial_count(pending_ids_[id], max_pending);
    if (!requires_control_flow_) {
      atomic_pending_counts_[id] = max_pending;
    }
  }
}
}  // namespace tensorflow
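
// Rough usage sketch (illustrative only, assuming the constructor declared in
// immutable_executor_state.h takes a LocalExecutorParams): the executor that
// owns this state typically constructs it once and initializes it with the
// graph it will run, e.g.
//
//   ImmutableExecutorState immutable_state(params);
//   TF_RETURN_IF_ERROR(immutable_state.Initialize(*graph));
//
// After Initialize() succeeds the state is treated as read-only and can be
// shared by concurrent executions of the same graph.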