xref: /aosp_15_r20/external/tensorflow/tensorflow/core/common_runtime/device_propagation.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/device_propagation.h"
17 
18 #include <string>
19 #include <utility>
20 
21 #include "absl/container/flat_hash_set.h"
22 #include "tensorflow/core/framework/types.h"
23 #include "tensorflow/core/graph/algorithm.h"
24 #include "tensorflow/core/graph/graph.h"
25 
26 namespace tensorflow {
27 
28 namespace {
29 
AssignedOrRequestedDevice(const Node & node)30 const std::string& AssignedOrRequestedDevice(const Node& node) {
31   if (!node.assigned_device_name().empty()) {
32     return node.assigned_device_name();
33   }
34   return node.requested_device();
35 }
36 
UpdateDeviceFromInputs(const device_propagation::NodeFilter & node_filter,const device_propagation::DeviceFilter & device_filter,Node * node)37 void UpdateDeviceFromInputs(
38     const device_propagation::NodeFilter& node_filter,
39     const device_propagation::DeviceFilter& device_filter, Node* node) {
40   if (!AssignedOrRequestedDevice(*node).empty() || !node_filter(*node)) {
41     return;
42   }
43   string proposed_device = "";
44   Node* proposed_src = nullptr;
45   // Scan the input edges, propagate device assignment from its inputs to this
46   // node iff all input nodes has the same device assignment and the device is
47   // propagatable (checked by `device_filter`). Some kinds of edges are
48   // ignored.
49   for (const Edge* e : node->in_edges()) {
50     // Ignore control edge.
51     if (e->IsControlEdge()) {
52       continue;
53     }
54     Node* src = e->src();
55     const string& src_device = AssignedOrRequestedDevice(*src);
56 
57     // Ignore LoopCond -> Switch and Enter -> Merge. In other words, the device
58     // placement of a Switch op is determined by all its non-LoopCond inputs and
59     // that of a Merge op is determined by all its non-Enter inputs.
60     if ((node->IsSwitch() && src->IsLoopCond()) ||
61         (node->IsMerge() && src->IsEnter())) {
62       continue;
63     }
64 
65     // If a source device is not propagatable, stop.
66     if (!device_filter(src_device)) return;
67 
68     if (proposed_src == nullptr) {
69       proposed_device = src_device;
70       proposed_src = src;
71     } else if (proposed_device != src_device) {
72       // The device assignments of some input nodes are not the same. Stop.
73       return;
74     }
75   }
76   if (proposed_src) {
77     node->set_assigned_device_name(proposed_src->assigned_device_name());
78     node->set_requested_device(proposed_src->requested_device());
79   }
80 }
81 
82 }  // namespace
83 
PropagateDevices(const device_propagation::NodeFilter & node_filter,const device_propagation::DeviceFilter & device_filter,Graph * graph)84 void PropagateDevices(const device_propagation::NodeFilter& node_filter,
85                       const device_propagation::DeviceFilter& device_filter,
86                       Graph* graph) {
87   ReverseDFS(*graph, {}, [&node_filter, &device_filter](Node* node) {
88     UpdateDeviceFromInputs(node_filter, device_filter, node);
89   });
90 }
91 
92 }  // namespace tensorflow
93