xref: /aosp_15_r20/external/tensorflow/tensorflow/core/common_runtime/placer.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/placer.h"
17 
18 #include <memory>
19 #include <vector>
20 
21 #include "tensorflow/core/common_runtime/colocation_graph.h"
22 #include "tensorflow/core/common_runtime/device.h"
23 #include "tensorflow/core/framework/attr_value_util.h"
24 #include "tensorflow/core/framework/device_attributes.pb.h"
25 #include "tensorflow/core/framework/function.h"
26 #include "tensorflow/core/framework/graph.pb.h"
27 #include "tensorflow/core/framework/types.h"
28 #include "tensorflow/core/framework/types.pb.h"
29 #include "tensorflow/core/graph/graph_node_util.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/util/dump_graph.h"
32 #include "tensorflow/core/util/port.h"
33 
34 namespace tensorflow {
35 
36 namespace {
37 
// Bookkeeping used by MakeUniqueFilename(): remembers, for every sanitized
// base name, how many filenames have been handed out for it so far.
struct NameCounts {
  // Held by MakeUniqueFilename() while it reads and increments `counts`.
  mutex counts_mutex;
  // Sanitized base name -> number of times that name has been requested.
  std::unordered_map<string, int> counts;
};
42 
MakeUniqueFilename(string name)43 string MakeUniqueFilename(string name) {
44   static NameCounts& instance = *new NameCounts;
45 
46   // Remove illegal characters from `name`.
47   for (int i = 0; i < name.size(); ++i) {
48     char ch = name[i];
49     if (ch == '/' || ch == '[' || ch == ']' || ch == '*' || ch == '?') {
50       name[i] = '_';
51     }
52   }
53 
54   int count;
55   {
56     mutex_lock lock(instance.counts_mutex);
57     count = instance.counts[name]++;
58   }
59 
60   string filename = name;
61   if (count > 0) {
62     absl::StrAppend(&filename, "_", count);
63   }
64   absl::StrAppend(&filename, ".txt");
65   return filename;
66 }
67 
GetFileName(string base_name,string * fname)68 Status GetFileName(string base_name, string* fname) {
69   const char* dir = nullptr;
70   dir = getenv("TF_DUMP_GRAPH_PREFIX");
71   if (!dir) {
72     return errors::Internal("Failed to get the directory for ", base_name,
73                             " because dump location is not specified through "
74                             "TF_DUMP_GRAPH_PREFIX environment variable");
75   }
76   base_name = MakeUniqueFilename(base_name);
77   *fname = absl::StrCat(dir, "/", base_name);
78   return OkStatus();
79 }
80 
DumpColocationGraph(const string & base_name,const ColocationGraph & colocation_graph)81 void DumpColocationGraph(const string& base_name,
82                          const ColocationGraph& colocation_graph) {
83   string fname;
84   Status status = GetFileName(base_name, &fname);
85   if (status.ok()) {
86     status = WriteStringToFile(Env::Default(), fname,
87                                colocation_graph.DebugString());
88     if (status.ok()) {
89       LOG(INFO) << "Wrote ColocationGraph to " << fname;
90     }
91   }
92   if (!status.ok()) {
93     LOG(ERROR) << "Failed to write final colocation graph to file " << fname
94                << " with " << status.ToString();
95   }
96 }
97 
98 // Returns true if the node has no inputs and produces outputs
99 // that are consumed by a single node.
100 //
101 // TODO(vrv): Currently this handles only nodes with one output, but
102 // this could be extended to handle the case where a node has many
103 // outputs that are connected to nodes in the same colocation group.
IsGeneratorNode(const Node * node)104 bool IsGeneratorNode(const Node* node) {
105   return node->num_inputs() == 0 && node->num_outputs() == 1 &&
106          !IsRefType(node->output_type(0));
107 }
108 
LogDeviceAssignment(const Node * node,bool log_device_placement)109 void LogDeviceAssignment(const Node* node, bool log_device_placement) {
110   // Log placement if log_device_placement is set.
111   if (log_device_placement) {
112     printf("%s: (%s): %s\n", node->name().c_str(), node->type_string().c_str(),
113            node->assigned_device_name().c_str());
114     LOG(INFO) << node->name() << ": "
115               << "(" << node->type_string()
116               << "): " << node->assigned_device_name();
117   }
118   if (VLOG_IS_ON(1)) {
119     if (VLOG_IS_ON(4)) {
120       VLOG(4) << "\nNode:\n"
121               << node->def().DebugString()
122               << "placed on: " << node->assigned_device_name();
123     } else {
124       VLOG(1) << node->name() << "(" << node->type_string()
125               << ") placed on: " << node->assigned_device_name();
126     }
127   }
128 }
129 
AssignAndLog(int assigned_device,Node * node,ColocationGraph * colocation_graph,bool log_device_placement)130 Status AssignAndLog(int assigned_device, Node* node,
131                     ColocationGraph* colocation_graph,
132                     bool log_device_placement) {
133   node->set_assigned_device_name_index(assigned_device);
134 
135   // Constraint the group of node to the assigned device.
136   TF_RETURN_IF_ERROR(colocation_graph->LimitToAssignedDevice(*node));
137 
138   LogDeviceAssignment(node, log_device_placement);
139   return OkStatus();
140 }
141 
142 }  // namespace
143 
// Constructs a Placer by storing every argument in the corresponding member.
// No placement work happens here; the stored values are consumed by Run().
//
// NOTE(review): the pointer arguments are stored as-is; presumably they are
// not owned and must outlive this object -- confirm against placer.h.
Placer::Placer(Graph* graph, const string& function_name,
               const FunctionLibraryDefinition* flib_def,
               const DeviceSet* devices, const Device* default_local_device,
               bool allow_soft_placement, bool log_device_placement)
    : graph_(graph),
      function_name_(function_name),
      flib_def_(flib_def),
      devices_(devices),
      default_local_device_(default_local_device),
      allow_soft_placement_(allow_soft_placement),
      log_device_placement_(log_device_placement) {}
155 
// Convenience overload: delegates to the seven-argument constructor with
// soft placement enabled and device-placement logging disabled.
Placer::Placer(Graph* graph, const string& function_name,
               const FunctionLibraryDefinition* flib_def,
               const DeviceSet* devices, const Device* default_local_device)
    : Placer(graph, function_name, flib_def, devices, default_local_device,
             /*allow_soft_placement=*/true, /*log_device_placement=*/false) {}
// Convenience overload: delegates to the seven-argument constructor with no
// default local device, soft placement enabled and device-placement logging
// disabled.
Placer::Placer(Graph* graph, const string& function_name,
               const FunctionLibraryDefinition* flib_def,
               const DeviceSet* devices)
    : Placer(graph, function_name, flib_def, devices,
             /*default_local_device=*/nullptr, /*allow_soft_placement=*/true,
             /*log_device_placement=*/false) {}
165 
~Placer()166 Placer::~Placer() {}
167 
Run()168 Status Placer::Run() {
169   if (devices_->devices().empty()) {
170     return errors::FailedPrecondition("No devices are registered");
171   }
172 
173   if (VLOG_IS_ON(3)) {
174     DumpGraphToFile("placer_input", *graph_, nullptr);
175   }
176   if (VLOG_IS_ON(5)) {
177     for (const Node* node : graph_->op_nodes()) {
178       VLOG(5) << "    " << node->name() << ": requested: '"
179               << node->requested_device() << "' assigned: '"
180               << node->assigned_device_name() << "'";
181     }
182   }
183 
184   FunctionStack stack(function_name_);
185   ColocationGraph colocation_graph(graph_, stack, flib_def_, devices_,
186                                    default_local_device_, allow_soft_placement_,
187                                    log_device_placement_);
188 
189   TF_RETURN_IF_ERROR(colocation_graph.Initialize());
190 
191   // For each node, assign a device based on the constraints in the disjoint
192   // node set.
193   std::vector<Node*> second_pass;
194   for (Node* node : graph_->op_nodes()) {
195     // The graph may have come pre-populated by the framework with assigned
196     // devices (e.g., for stateful placements), so the placer should not try to
197     // place nodes that are already placed.
198     if (node->has_assigned_device_name()) {
199       TF_RETURN_IF_ERROR(colocation_graph.LimitToAssignedDevice(*node));
200       LogDeviceAssignment(node, log_device_placement_);
201       continue;
202     }
203 
204     // Heuristic A: prefer to place "generators" with their only
205     // consumers.
206     //
207     // If this is a node with no inputs and one output, we save
208     // this for a second pass, so that the consumer's placement
209     // is chosen.
210     if (IsGeneratorNode(node)) {
211       second_pass.push_back(node);
212       continue;
213     }
214 
215     const std::vector<Device*>* devices;
216     Status status = colocation_graph.GetDevicesForNode(node, &devices);
217     if (!status.ok()) {
218       return AttachDef(
219           errors::InvalidArgument("Cannot assign a device for operation ",
220                                   node->name(), ": ", status.error_message()),
221           *node);
222     }
223 
224     // TODO(mdan): This is a constrained optimization solver. Write it like one.
225 
226     // Returns the first device in sorted devices list so we will always
227     // choose the same device.
228     //
229     // TODO(vrv): Factor this assignment out into a pluggable
230     // algorithm, so that Placer is responsible for enforcing
231     // preconditions and we can experiment with other algorithms when
232     // given a choice of devices. Once we have a better idea of the
233     // types of heuristics we want to use and the information needed
234     // to perform good placement we can add an interface for this.
235     int assigned_device = -1;
236 
237     // Heuristic B: If the node only operates on metadata, not data,
238     // then it is desirable to place that metadata node with its
239     // input.
240     if (IsMetadata(node)) {
241       // Make sure that the input device type is in the list of supported
242       // device types for this node.
243       const Node* input = (*node->in_edges().begin())->src();
244       // TODO(vrv): if the input is empty, consider postponing this
245       // node's assignment to the second pass, so that we handle the
246       // case where a metadata node's input comes from a backedge
247       // of a loop.
248       if (CanAssignToDevice(input->assigned_device_name(), *devices)) {
249         assigned_device = input->assigned_device_name_index();
250       }
251     }
252 
253     // Provide the default, if necessary.
254     if (assigned_device == -1) {
255       assigned_device = graph_->InternDeviceName((*devices)[0]->name());
256     }
257 
258     TF_RETURN_IF_ERROR(AssignAndLog(assigned_device, node, &colocation_graph,
259                                     log_device_placement_));
260   }
261 
262   // Perform a second pass assignment for those nodes explicitly
263   // skipped during the first pass.
264   for (Node* node : second_pass) {
265     const std::vector<Device*>* devices;
266     Status status = colocation_graph.GetDevicesForNode(node, &devices);
267     if (!status.ok()) {
268       return AttachDef(
269           errors::InvalidArgument("Cannot assign a device for operation ",
270                                   node->name(), ": ", status.error_message()),
271           *node);
272     }
273 
274     int assigned_device = -1;
275 
276     // Heuristic A application.
277     if (IsGeneratorNode(node) && !node->out_edges().empty()) {
278       const Node* output = (*node->out_edges().begin())->dst();
279       int output_device_name = output->assigned_device_name_index();
280 
281       const bool consumers_on_same_device = std::all_of(
282           node->out_edges().begin(), node->out_edges().end(),
283           [output_device_name](const Edge* e) {
284             return e->dst()->assigned_device_name_index() == output_device_name;
285           });
286 
287       if (consumers_on_same_device &&
288           CanAssignToDevice(output->assigned_device_name(), *devices)) {
289         assigned_device = output_device_name;
290       }
291     }
292 
293     // Provide the default, if necessary.
294     if (assigned_device == -1) {
295       assigned_device = graph_->InternDeviceName((*devices)[0]->name());
296     }
297 
298     TF_RETURN_IF_ERROR(AssignAndLog(assigned_device, node, &colocation_graph,
299                                     log_device_placement_));
300   }
301 
302   if (VLOG_IS_ON(3)) {
303     DumpGraphToFile("placer_output", *graph_, nullptr);
304     DumpColocationGraph("colocation_graph", colocation_graph);
305   }
306   return OkStatus();
307 }
308 
CanAssignToDevice(const string & candidate_device_name,const std::vector<Device * > & devices) const309 bool Placer::CanAssignToDevice(const string& candidate_device_name,
310                                const std::vector<Device*>& devices) const {
311   if (!candidate_device_name.empty()) {
312     // 'devices' lists the set of devices that the placer or the user has
313     // constrained the operation to.  "candidate_device_name" must
314     // refer to a concrete Device that is in the list of 'devices'.
315     const Device* other_device =
316         devices_->FindDeviceByName(candidate_device_name);
317     if (std::find(devices.begin(), devices.end(), other_device) !=
318         devices.end()) {
319       return true;
320     }
321   }
322 
323   return false;
324 }
325 
326 }  // namespace tensorflow
327