xref: /aosp_15_r20/external/tensorflow/tensorflow/core/common_runtime/colocation_graph.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/colocation_graph.h"
17 
18 #include <memory>
19 #include <set>
20 #include <unordered_map>
21 #include <unordered_set>
22 #include <utility>
23 #include <vector>
24 
25 #include "absl/algorithm/container.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "absl/strings/str_join.h"
28 #include "absl/types/optional.h"
29 #include "tensorflow/core/common_runtime/composite_device.h"
30 #include "tensorflow/core/common_runtime/device.h"
31 #include "tensorflow/core/common_runtime/device_set.h"
32 #include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
33 #include "tensorflow/core/common_runtime/inspecting_placer.h"
34 #include "tensorflow/core/common_runtime/partitioning_utils.h"
35 #include "tensorflow/core/framework/attr_value.pb.h"
36 #include "tensorflow/core/framework/attr_value_util.h"
37 #include "tensorflow/core/framework/dataset.h"
38 #include "tensorflow/core/framework/device_attributes.pb.h"
39 #include "tensorflow/core/framework/full_type.pb.h"
40 #include "tensorflow/core/framework/full_type_util.h"
41 #include "tensorflow/core/framework/function.h"
42 #include "tensorflow/core/framework/node_def_util.h"
43 #include "tensorflow/core/framework/op_kernel.h"
44 #include "tensorflow/core/framework/types.h"
45 #include "tensorflow/core/framework/types.pb.h"
46 #include "tensorflow/core/graph/algorithm.h"
47 #include "tensorflow/core/graph/graph_node_util.h"
48 #include "tensorflow/core/lib/core/errors.h"
49 #include "tensorflow/core/lib/core/stringpiece.h"
50 #include "tensorflow/core/lib/strings/str_util.h"
51 #include "tensorflow/core/lib/strings/strcat.h"
52 #include "tensorflow/core/util/device_name_utils.h"
53 #include "tensorflow/core/util/dump_graph.h"
54 #include "tensorflow/core/util/port.h"
55 
56 namespace tensorflow {
57 
58 namespace {
59 
60 // We hoist the conversion from C-style string literal to StringPiece here,
61 // so that we can avoid the many repeated calls to strlen().
62 const StringPiece kColocationAttrNameStringPiece(kColocationAttrName);
63 const StringPiece kColocationGroupPrefixStringPiece(kColocationGroupPrefix);
64 
65 // Using absl::StrJoin with lambda does not work in tf-lite builds.
DevicesToString(const std::vector<Device * > devices)66 std::vector<string> DevicesToString(const std::vector<Device*> devices) {
67   std::vector<string> v;
68   v.reserve(devices.size());
69   for (Device* d : devices) {
70     v.push_back(d->name());
71   }
72   return v;
73 }
74 
75 // Using absl::StrJoin with lambda does not work in tf-lite builds.
DeviceTypeAndPriorityToString(const PrioritizedDeviceTypeVector & devices)76 std::vector<string> DeviceTypeAndPriorityToString(
77     const PrioritizedDeviceTypeVector& devices) {
78   std::vector<string> v;
79   v.reserve(devices.size());
80   for (const std::pair<DeviceType, int32>& device_and_type : devices) {
81     v.push_back(DeviceTypeString(device_and_type.first));
82   }
83   return v;
84 }
85 
IsRefOrResource(DataType data_type)86 bool IsRefOrResource(DataType data_type) {
87   return IsRefType(data_type) || data_type == DT_RESOURCE;
88 }
89 
90 // While Placer can override requested device on ops processing
91 // resources, i.e. node that take (and potentially return) a resource,
92 // it must not override requested device on ops generating a resource,
93 // e.g. VarHandleOp, _Arg. Such ops are currently no-input, single resource/ref
94 // output nodes.
IsRefOrResourceGeneratorNode(const Node & node)95 bool IsRefOrResourceGeneratorNode(const Node& node) {
96   return node.num_inputs() == 0 && node.num_outputs() == 1 &&
97          IsRefOrResource(node.output_type(0));
98 }
99 
IsExemptFromResourceInputColocation(const Node * node)100 bool IsExemptFromResourceInputColocation(const Node* node) {
101   // Note: Partitioned function calls, which place and partition their
102   // function bodies, are exempt from this check: they forward resource and
103   // ref inputs to operations that are appropriately placed, instead of
104   // dereferencing them.
105   const string& op_type = node->op_def().name();
106   auto exempt_ops = InputColocationExemptionRegistry::Global()->Get();
107   return exempt_ops.find(op_type) != exempt_ops.end();
108 }
109 
HasPriorities(const PrioritizedDeviceTypeVector & device_types)110 bool HasPriorities(const PrioritizedDeviceTypeVector& device_types) {
111   for (const auto& prioritized_device_type : device_types) {
112     if (prioritized_device_type.second != 0) return true;
113   }
114   return false;
115 }
116 
ArePrioritiesSame(const PrioritizedDeviceTypeVector & a_types,const PrioritizedDeviceTypeVector & b_types)117 bool ArePrioritiesSame(const PrioritizedDeviceTypeVector& a_types,
118                        const PrioritizedDeviceTypeVector& b_types) {
119   if (a_types.size() != b_types.size()) {
120     return false;
121   }
122   for (int i = 0; i < a_types.size(); ++i) {
123     if (a_types[i].first != b_types[i].first) {
124       return false;
125     }
126   }
127   return true;
128 }
129 
IsXlaDevice(absl::string_view device_type)130 bool IsXlaDevice(absl::string_view device_type) {
131   if (device_type == "XLA_CPU_JIT" || device_type == "XLA_GPU_JIT" ||
132       device_type == "XLA_TPU_JIT") {
133     // Symbolic XLA device.
134     return true;
135   }
136 
137   return (device_type == "XLA_CPU" || device_type == "XLA_GPU" ||
138           device_type == "TPU");
139 }
140 
IsCompositeDevice(absl::string_view device_type)141 bool IsCompositeDevice(absl::string_view device_type) {
142   return device_type == kCompositeDeviceType;
143 }
144 
145 // TODO(mdan): This is still too coarse.
146 // Host-memory constraints are specific to kernel registrations, so in theory
147 // they depend on the assigned device.
148 // So we need a constraint model of the kind: <<node device>>: <<output_device>>
HasHostMemoryOutType(const Node & node)149 bool HasHostMemoryOutType(const Node& node) {
150   if (!node.def().has_experimental_type()) {
151     return false;
152   }
153   const FullTypeDef& ft = node.def().experimental_type();
154   DCHECK(ft.type_id() == TFT_PRODUCT) << ft.DebugString();
155 
156   for (const auto& arg : ft.args()) {
157     if (full_type::IsHostMemoryType(arg)) {
158       return true;
159     }
160   }
161 
162   return false;
163 }
164 }  // namespace
165 
SetParentAndSupportedDevices(const Node & node,const std::vector<DeviceType> & types,const DeviceNameUtils::ParsedName * local_address_spec)166 Status Member::SetParentAndSupportedDevices(
167     const Node& node, const std::vector<DeviceType>& types,
168     const DeviceNameUtils::ParsedName* local_address_spec) {
169   int id = node.id();
170   if (id < 0) {
171     return errors::Internal("Placer should not be creating a Member for node: ",
172                             node.DebugString());
173   }
174   parent_ = id;
175   return SupportedDeviceTypesForNode(
176       types, node.def(), &supported_device_types_, local_address_spec);
177 }
178 
SetAssignedDeviceName(const string & device_name)179 Status Member::SetAssignedDeviceName(const string& device_name) {
180   if (DeviceNameUtils::HasSomeDetails(requested_device_name_)) {
181     return errors::Internal(
182         "Setting assigned device name when there is a requested device set "
183         "is unsupported");
184   }
185   if (!DeviceNameUtils::ParseFullName(device_name, &assigned_device_name_)) {
186     return errors::Internal("Malformed assigned device '", device_name, "'");
187   }
188   // Set requested device to assigned_device to maintain the invariant that
189   // requested is a specialization of assigned.
190   requested_device_name_ = assigned_device_name_;
191   return OkStatus();
192 }
193 
SetResourceDeviceName(const Node & node)194 Status Member::SetResourceDeviceName(const Node& node) {
195   if (DeviceNameUtils::HasSomeDetails(requested_device_name_)) {
196     return errors::Internal(
197         "Setting resource device name when there is a requested device set "
198         "is unsupported");
199   }
200 
201   if (!DeviceNameUtils::ParseFullName(node.requested_device(),
202                                       &resource_device_name_)) {
203     return errors::InvalidArgument("Malformed device specification '",
204                                    node.requested_device(),
205                                    "' in node: ", node.DebugString());
206   }
207 
208   // Set requested device to resource device to maintain the invariant that
209   // requested is a specialization of resource.
210   requested_device_name_ = resource_device_name_;
211   return OkStatus();
212 }
213 
SetRequestedDeviceName(const Node & node)214 Status Member::SetRequestedDeviceName(const Node& node) {
215   if (DeviceNameUtils::HasSomeDetails(assigned_device_name_)) {
216     return errors::Internal(
217         "Setting requested device name when there is an assigned device set "
218         "is unsupported");
219   }
220   if (DeviceNameUtils::HasSomeDetails(resource_device_name_)) {
221     return errors::Internal(
222         "Setting requested device name when there is a resource device set "
223         "is unsupported");
224   }
225   if (!DeviceNameUtils::ParseFullName(node.requested_device(),
226                                       &requested_device_name_)) {
227     return errors::InvalidArgument("Malformed device specification '",
228                                    node.requested_device(),
229                                    "' in node: ", node.DebugString());
230   }
231   return OkStatus();
232 }
233 
FillPossibleDevices(PossibleDevices * possible_device) const234 Status Member::FillPossibleDevices(PossibleDevices* possible_device) const {
235   if (DeviceNameUtils::HasSomeDetails(assigned_device_name_)) {
236     return errors::Internal(
237         "Cannot fill PossibleDevices from a member that has non-empty assigned "
238         "device. Did we start assigning devices to functions called by deep "
239         "ops? ",
240         DebugString());
241   }
242   possible_device->requested_device_name = requested_device_name_;
243   possible_device->resource_device_name = resource_device_name_;
244   possible_device->device_types = supported_device_types_;
245   return OkStatus();
246 }
247 
IsEdgeFromCompositeDeviceToPhysicalDevice(const Member & src_root) const248 bool Member::IsEdgeFromCompositeDeviceToPhysicalDevice(
249     const Member& src_root) const {
250   auto compatible_edge_from_composite_device_to_physical_device =
251       [](const DeviceNameUtils::ParsedName& src_device,
252          const DeviceNameUtils::ParsedName& dst_device) -> bool {
253     return src_device.has_type && dst_device.has_type &&
254            IsCompositeDevice(src_device.type) &&
255            !IsCompositeDevice(dst_device.type);
256   };
257   if (compatible_edge_from_composite_device_to_physical_device(
258           src_root.assigned_device_name_, assigned_device_name_) ||
259       compatible_edge_from_composite_device_to_physical_device(
260           src_root.resource_device_name_, resource_device_name_) ||
261       compatible_edge_from_composite_device_to_physical_device(
262           src_root.requested_device_name_, requested_device_name_)) {
263     return true;
264   }
265   return false;
266 }
267 
EnsureCompatibilityAcrossResourceEdge(const Node & src,const Member & src_root,const Node & dst,bool log_device_placement)268 Status Member::EnsureCompatibilityAcrossResourceEdge(
269     const Node& src, const Member& src_root,
270     const Node& dst, /*dst_root is this*/
271     bool log_device_placement) {
272   if (!DeviceNameUtils::AreCompatibleDevNames(src_root.assigned_device_name_,
273                                               assigned_device_name_)) {
274     return errors::InvalidArgument(
275         "Cannot place the graph because a reference or resource edge "
276         "connects colocation groups with incompatible assigned devices: ",
277         DeviceNameUtils::ParsedNameToString(src_root.assigned_device_name_),
278         " vs ", DeviceNameUtils::ParsedNameToString(assigned_device_name_),
279         ". The edge src node is name='", src.name(), "' (op='", src.def().op(),
280         "'), and the dst node is name='", dst.name(), "' (op='", dst.def().op(),
281         "').");
282   }
283 
284   if (!DeviceNameUtils::AreCompatibleDevNames(src_root.resource_device_name_,
285                                               resource_device_name_)) {
286     return errors::InvalidArgument(
287         "Cannot place the graph because a reference or resource edge "
288         "connects colocation groups with incompatible resource devices: ",
289         DeviceNameUtils::ParsedNameToString(src_root.resource_device_name_),
290         " vs ", DeviceNameUtils::ParsedNameToString(resource_device_name_),
291         ". The edge src node is name='", src.name(), "' (op='", src.def().op(),
292         "'), and the dst node is name='", dst.name(), "' (op='", dst.def().op(),
293         "').");
294   }
295 
296   if (DeviceNameUtils::AreCompatibleDevNames(src_root.requested_device_name_,
297                                              requested_device_name_)) {
298     return OkStatus();
299   }
300 
301   // If we are here, assigned and resource devices are compatible but requested
302   // ones are not. We will be overriding the requested device for destination
303   // node, but need to preserve the invariant that it will be a specialization
304   // of the assigned and resource devices.
305   if (log_device_placement) {
306     LOG(INFO) << "Ignoring device specification "
307               << DeviceNameUtils::ParsedNameToString(requested_device_name_)
308               << " for node '" << dst.name()
309               << "' because the input edge from '" << src.name()
310               << "' is a reference connection and already has a device "
311                  "field set to "
312               << DeviceNameUtils::ParsedNameToString(
313                      src_root.requested_device_name_);
314   }
315   requested_device_name_ = src_root.requested_device_name_;
316   DeviceNameUtils::EnsureSpecification(&requested_device_name_,
317                                        assigned_device_name_);
318   DeviceNameUtils::EnsureSpecification(&requested_device_name_,
319                                        resource_device_name_);
320   return OkStatus();
321 }
322 
Merge(std::vector<Member> * tree,int x_root,int y_root,Member ** new_root,Member ** old_root,bool dry_run)323 void Member::Merge(std::vector<Member>* tree, int x_root, int y_root,
324                    Member** new_root, Member** old_root, bool dry_run) {
325   Member& x_root_member = (*tree)[x_root];
326   Member& y_root_member = (*tree)[y_root];
327 
328   // Merge the sets by setting the parent pointer of the smaller tree's root
329   // node to point to the root of the larger tree. Together with path
330   // compression in ColocationGraph::FindRoot, this ensures that we do not
331   // experience pathological performance on graphs such as chains.
332   int new_root_id, old_root_id;
333   if (x_root_member.rank_ < y_root_member.rank_) {
334     // The tree rooted at x_root is shallower, so connect it to
335     // y_root. The rank of y_root is unchanged because its new
336     // child has strictly less rank.
337     if (!dry_run) {
338       x_root_member.parent_ = y_root;
339     }
340     new_root_id = y_root;
341     old_root_id = x_root;
342   } else if (x_root_member.rank_ > y_root_member.rank_) {
343     // The tree rooted at y_root is shallower, so connect it to
344     // x_root. The rank of x_root is unchanged because its new
345     // child has strictly less rank.
346     if (!dry_run) {
347       y_root_member.parent_ = x_root;
348     }
349     new_root_id = x_root;
350     old_root_id = y_root;
351   } else {
352     if (!dry_run) {
353       // Both trees have the same rank, so break the tie by choosing
354       // x_root as the new root.
355       y_root_member.parent_ = x_root;
356       // Increment the rank of the tree rooted at x_root, because it
357       // is now strictly deeper than before.
358       ++x_root_member.rank_;
359     }
360     new_root_id = x_root;
361     old_root_id = y_root;
362   }
363 
364   *new_root = &(*tree)[new_root_id];
365   *old_root = &(*tree)[old_root_id];
366 }
367 
368 // tree is non-const because we can change some `parent` pointers in some
369 // members for more efficient future lookups. The vector itself is not
370 // changed.
FindAndUpdateRoot(std::vector<Member> * tree,int node_id)371 int Member::FindAndUpdateRoot(std::vector<Member>* tree, int node_id) {
372   Member& member = (*tree)[node_id];
373   if (member.parent_ == node_id) {
374     // member.parent is the root of this disjoint tree.  Do nothing.
375   } else {
376     member.parent_ = FindAndUpdateRoot(tree, member.parent_);
377   }
378   // Now it is guaranteed that member.parent is the root of this disjoint
379   // tree.
380   return member.parent_;
381 }
382 
FindRoot(const std::vector<Member> & tree,int node_id)383 int Member::FindRoot(const std::vector<Member>& tree, int node_id) {
384   const Member& member = tree[node_id];
385   if (member.parent_ == node_id) {
386     return member.parent_;
387   }
388   return FindRoot(tree, member.parent_);
389 }
390 
MergeDeviceNames(const Member & other,bool allow_soft_placement)391 Status Member::MergeDeviceNames(const Member& other,
392                                 bool allow_soft_placement) {
393   // Assuming the "requested is a specialization of assigned and resource
394   // devices" invariant holds for this and `other`, it will hold after the
395   // merges below.
396   DeviceNameUtils::ParsedName assigned_device_name_copy = assigned_device_name_;
397   TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
398       &assigned_device_name_copy, other.assigned_device_name_));
399 
400   DeviceNameUtils::ParsedName resource_device_name_copy = resource_device_name_;
401   TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
402       &resource_device_name_copy, other.resource_device_name_));
403 
404   DeviceNameUtils::ParsedName requested_device_name_copy =
405       requested_device_name_;
406   TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
407       &requested_device_name_copy, other.requested_device_name_,
408       allow_soft_placement));
409 
410   DeviceNameUtils::EnsureSpecification(&requested_device_name_copy,
411                                        assigned_device_name_copy);
412   DeviceNameUtils::EnsureSpecification(&requested_device_name_copy,
413                                        resource_device_name_copy);
414 
415   // We checked for all errors, now change the devices.
416   assigned_device_name_ = std::move(assigned_device_name_copy);
417   resource_device_name_ = std::move(resource_device_name_copy);
418   requested_device_name_ = std::move(requested_device_name_copy);
419   return OkStatus();
420 }
421 
422 // Updates this to contain the intersection of the device types in
423 // this and "other".
MergeSupportedDevices(const Member & other)424 bool Member::MergeSupportedDevices(const Member& other) {
425   return MergeSupportedDevices(other.supported_device_types_);
426 }
427 
MergeSupportedDevices(const PrioritizedDeviceTypeVector & other_devices)428 bool Member::MergeSupportedDevices(
429     const PrioritizedDeviceTypeVector& other_devices) {
430   // Generate intersection with priorities.
431   // Each vector contains the same device types but with different priorities.
432   // The priorities are taken from the corresponding source vector.
433   PrioritizedDeviceTypeVector target_intersection;
434   PrioritizedDeviceTypeVector other_intersection;
435 
436   for (const auto& prioritized_device_type : supported_device_types_) {
437     bool found = false;
438     for (const auto& other_prioritized_device_type : other_devices) {
439       if (prioritized_device_type.first ==
440           other_prioritized_device_type.first) {
441         found = true;
442         other_intersection.push_back(other_prioritized_device_type);
443         break;
444       }
445     }
446     if (found) {
447       target_intersection.push_back(prioritized_device_type);
448     }
449   }
450 
451   DeviceSet::SortPrioritizedDeviceTypeVector(&target_intersection);
452   DeviceSet::SortPrioritizedDeviceTypeVector(&other_intersection);
453 
454   PrioritizedDeviceTypeVector result;
455 
456   bool is_target_prioritized = HasPriorities(target_intersection);
457   bool is_other_prioritized = HasPriorities(other_intersection);
458   if (!is_target_prioritized && !is_other_prioritized) {
459     // If neither are prioritized then we just return the original i.e. target
460     // prioritization.
461     result = target_intersection;
462   } else if (is_target_prioritized && !is_other_prioritized) {
463     // If only one is prioritized, then we respect priorities of that in the
464     // intersection.
465     result = target_intersection;
466   } else if (!is_target_prioritized && is_other_prioritized) {
467     result = other_intersection;
468   } else {
469     // If both have priorities and agree then we go with that. If the
470     // prioritization order is different, then we just fallback to the default
471     // i.e. what the DeviceTypeOrder suggests. In that case, we also set the
472     // merged priorities to 0, so that downstream merges work correctly as well.
473     if (ArePrioritiesSame(target_intersection, other_intersection)) {
474       result = target_intersection;
475     } else {
476       for (const auto& prioritized_device : target_intersection) {
477         result.push_back(std::make_pair(prioritized_device.first, 0));
478       }
479       DeviceSet::SortPrioritizedDeviceTypeVector(&result);
480     }
481   }
482 
483   if (result.empty()) {
484     return false;
485   }
486   supported_device_types_ = result;
487   return true;
488 }
489 
AssignDevice(const Node & node)490 Status Member::AssignDevice(const Node& node) {
491   if (node.assigned_device_name_index() == assigned_device_name_index_) {
492     return OkStatus();
493   }
494 
495   DeviceNameUtils::ParsedName parsed;
496   DeviceNameUtils::ParseFullName(node.assigned_device_name(), &parsed);
497   Status s = DeviceNameUtils::MergeDevNames(&assigned_device_name_, parsed);
498   if (!s.ok()) {
499     return errors::Internal(
500         "Constraining by assigned device should not cause an error. Original "
501         "root's assigned device name: ",
502         DeviceNameUtils::ParsedNameToString(assigned_device_name_),
503         " node's assigned device name \"", node.assigned_device_name(),
504         ". Error: ", s.error_message());
505   }
506   s = DeviceNameUtils::MergeOverrideDevNames(&resource_device_name_, parsed);
507   if (!s.ok()) {
508     return errors::Internal(
509         "Constraining by assigned device should not cause an error. Original "
510         "root's resource device name: ",
511         DeviceNameUtils::ParsedNameToString(resource_device_name_),
512         " node's assigned device name \"", node.assigned_device_name(),
513         ". Error: ", s.error_message());
514   }
515   s = DeviceNameUtils::MergeOverrideDevNames(&requested_device_name_, parsed);
516   if (!s.ok()) {
517     return errors::Internal(
518         "Constraining by assigned device should not cause an error. Original "
519         "root's requested device name: \"",
520         DeviceNameUtils::ParsedNameToString(requested_device_name_),
521         "\", node's assigned device name \"", node.assigned_device_name(),
522         "\". Error: ", s.error_message());
523   }
524 
525   assigned_device_name_index_ = node.assigned_device_name_index();
526   // Clear cached possible_devices, if any.
527   possible_devices_.clear();
528   return OkStatus();
529 }
530 
MaybeExcludeXlaDevices()531 void Member::MaybeExcludeXlaDevices() {
532   for (const auto& parsed_name :
533        {requested_device_name_, assigned_device_name_, resource_device_name_}) {
534     // Don't exculde XLA devices from supported devices if member is explicitly
535     // assigned to a CompositeDevice.
536     if (parsed_name.has_type && (IsXlaDevice(parsed_name.type) ||
537                                  IsCompositeDevice(parsed_name.type))) {
538       return;
539     }
540   }
541 
542   PrioritizedDeviceTypeVector non_xla_types;
543   absl::c_copy_if(supported_device_types_, std::back_inserter(non_xla_types),
544                   [&](const std::pair<DeviceType, int32>& entry) {
545                     return !IsXlaDevice(entry.first.type_string());
546                   });
547 
548   // TODO(b/141216278) Remove all XLA device types from the supported device
549   // types if the node has no requested/assigned/resource XLA device.
550   if (!non_xla_types.empty() &&
551       non_xla_types.size() < supported_device_types_.size()) {
552     supported_device_types_ = std::move(non_xla_types);
553   }
554 }
555 
LimitToPossibleDevices(const PossibleDevices & devices,bool allow_soft_placement)556 Status Member::LimitToPossibleDevices(const PossibleDevices& devices,
557                                       bool allow_soft_placement) {
558   TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
559       &requested_device_name_, devices.requested_device_name,
560       allow_soft_placement));
561   TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
562       &resource_device_name_, devices.resource_device_name));
563   MergeSupportedDevices(devices.device_types);
564   return OkStatus();
565 }
566 
DebugString() const567 string Member::DebugString() const {
568   return absl::StrCat(
569       "Member(assigned_device_name_index_=", assigned_device_name_index_,
570       " requested_device_name_='",
571       DeviceNameUtils::ParsedNameToString(requested_device_name_),
572       "' assigned_device_name_='",
573       DeviceNameUtils::ParsedNameToString(assigned_device_name_),
574       "' resource_device_name_='",
575       DeviceNameUtils::ParsedNameToString(resource_device_name_),
576       "' supported_device_types_=[",
577       absl::StrJoin(DeviceTypeAndPriorityToString(supported_device_types_),
578                     ", "),
579       "] possible_devices_=[",
580       absl::StrJoin(DevicesToString(possible_devices_), ", "), "]");
581 }
582 
GetSoftDeviceName() const583 DeviceNameUtils::ParsedName Member::GetSoftDeviceName() const {
584   DeviceNameUtils::ParsedName soft_device_name = requested_device_name_;
585   if (!assigned_device_name_.has_type) {
586     soft_device_name.type.clear();
587     soft_device_name.has_type = false;
588   }
589   if (!assigned_device_name_.has_id) {
590     soft_device_name.has_id = false;
591   }
592   return soft_device_name;
593 }
594 
GetPreferredSoftDeviceName() const595 DeviceNameUtils::ParsedName Member::GetPreferredSoftDeviceName() const {
596   DeviceNameUtils::ParsedName soft_device_name = requested_device_name_;
597   if (!assigned_device_name_.has_type && !resource_device_name_.has_type) {
598     soft_device_name.type.clear();
599     soft_device_name.has_type = false;
600   }
601   if (!assigned_device_name_.has_id && !resource_device_name_.has_id) {
602     soft_device_name.has_id = false;
603   }
604   return soft_device_name;
605 }
606 
607 // Returns ParsedName whose address space (i.e. job, replica, task) identifies
608 // the address space directly accessible by the local process. If the address
609 // space is fully specified and it is exactly the same as the address space
610 // of a device, then all kernels of that device should be registered in the
611 // local process.
LocalAddressSpec(const Device * client_device,const Device * default_local_device)612 static const DeviceNameUtils::ParsedName LocalAddressSpec(
613     const Device* client_device, const Device* default_local_device) {
614   if (client_device != nullptr) {
615     return DeviceNameUtils::AddressSpace(client_device->parsed_name());
616   }
617 
618   if (default_local_device != nullptr) {
619     return DeviceNameUtils::AddressSpace(default_local_device->parsed_name());
620   }
621 
622   // TODO(b/139617593) Return the name of the first local device in device_set_
623   // once we can trust the output of Device::IsLocal().
624   return DeviceNameUtils::ParsedName();
625 }
626 
ColocationGraph(const Graph * graph,const FunctionStack & stack,const FunctionLibraryDefinition * flib_def,const DeviceSet * device_set,const Device * default_local_device,bool allow_soft_placement,bool log_device_placement)627 ColocationGraph::ColocationGraph(const Graph* graph, const FunctionStack& stack,
628                                  const FunctionLibraryDefinition* flib_def,
629                                  const DeviceSet* device_set,
630                                  const Device* default_local_device,
631                                  bool allow_soft_placement,
632                                  bool log_device_placement)
633     : graph_(*graph),
634       stack_(stack),
635       inspecting_placer_(stack, flib_def, device_set, default_local_device,
636                          allow_soft_placement, log_device_placement),
637       inspection_required_checker_(graph, flib_def),
638       device_set_(*device_set),
639       device_types_(device_set->PrioritizedDeviceTypeList()),
640       local_address_spec_(
641           LocalAddressSpec(device_set->client_device(), default_local_device)),
642       default_local_device_(default_local_device),
643       allow_soft_placement_(allow_soft_placement),
644       log_device_placement_(log_device_placement) {
645   members_.resize(graph_.num_node_ids());
646 }
647 
648 // Adds each node of the Graph to this ColocationGraph as a singleton.
649 //
650 // NOTE: The implementation assumes that the ids of nodes passed to
651 // this method are dense and zero-based; the memory used will be linear in
652 // the largest node ID.
653 // NOTE: If this method returns an error, *this is left in an undefined
654 // state.
ColocateAllNodes()655 Status ColocationGraph::ColocateAllNodes() {
656   // This maps from a colocation group identifier to the 'root' of that
657   // colocation group.  Note that the keys in this map are StringPiece; the
658   // actual strings are stored under the NodeDef.  The lifetime of this map
659   // is limited to this ColocateAllNodes() method, and no part of the
660   // NodeDef trees are changed during the lifetime of this method, so using
661   // StringPiece as a key is safe.
662   //
663   // Also, as a further optimization, we remove the "loc:@" prefix from
664   // "class" attribute values, when they are used as keys in this table.
665   // This allows us to use StringPiece values that refer to substrings of
666   // 'string' values stored in NodeDef attribute lists, as well as StringPiece
667   // values that refer to 'string' values from NodeDef::name(), without
668   // performing any string allocations.
669   std::unordered_map<StringPiece, const Node*, StringPieceHasher>
670       colocation_group_root;
671 
672   for (const Node* node : graph_.op_nodes()) {
673     // When adding the node, identify whether it is part of a colocation
674     // group.
675 
676     // This code is effectively the equivalent of GetNodeAttr() for a string
677     // array, but it avoids all internal allocations (the allocation of the
678     // backing store of the std::vector<string> as well as the copies of the
679     // strings within it).  Instead, we combine the query of the colocation
680     // attribute with the calls to ColocateNodeToGroup.
681     const AttrValue* attr_value =
682         node->attrs().Find(kColocationAttrNameStringPiece);
683     if (attr_value != nullptr) {
684       if (attr_value->has_list()) {
685         for (const string& class_spec : attr_value->list().s()) {
686           StringPiece spec(class_spec);
687           if (absl::ConsumePrefix(&spec, kColocationGroupPrefixStringPiece)) {
688             TF_RETURN_IF_ERROR(
689                 ColocateNodeToGroup(&colocation_group_root, node, spec));
690           }
691         }
692       } else if (!attr_value->s().empty()) {
693         LOG(ERROR) << "The value for colocation attribute '_class' must be a "
694                       "list of strings, not a single string: "
695                    << node->DebugString();
696       }
697     }
698 
699     // Each node belongs to a colocation group with the node's name.
700     TF_RETURN_IF_ERROR(
701         ColocateNodeToGroup(&colocation_group_root, node, node->name()));
702   }
703 
704   return OkStatus();
705 }
706 
ColocateResourceOrRefEdge(const Node * src,const Node * dst)707 Status ColocationGraph::ColocateResourceOrRefEdge(const Node* src,
708                                                   const Node* dst) {
709   // Colocate `src` and `dst` to maintain the invariant that nodes
710   // connected by reference edges are colocated.
711   int src_root_id = FindAndUpdateRoot(src->id());
712   int dst_root_id = FindAndUpdateRoot(dst->id());
713   auto& src_root = members_[src_root_id];
714   auto& dst_root = members_[dst_root_id];
715 
716   if (dst_root.IsEdgeFromCompositeDeviceToPhysicalDevice(src_root)) {
717     // If the src root is assigned to a composite device and the dst root is
718     // assigned to a physical device, don't colocate the dst root with the src
719     // root.
720     return OkStatus();
721   }
722   TF_RETURN_IF_ERROR(dst_root.EnsureCompatibilityAcrossResourceEdge(
723       *src, src_root, *dst, log_device_placement_));
724   Status status = ColocateNodes(*src, src_root_id, *dst, dst_root_id);
725   if (!status.ok()) {
726     return AttachDef(
727         errors::InvalidArgument(
728             "Nodes were connected by a reference or resource connection "
729             "(requiring them to be on the same device), but the two nodes "
730             "were assigned two different devices: ",
731             status.error_message()),
732         *dst);
733   }
734   return OkStatus();
735 }
736 
ColocateResourceAndRefEdges(std::unordered_set<Node * > * inspection_required)737 Status ColocationGraph::ColocateResourceAndRefEdges(
738     std::unordered_set<Node*>* inspection_required) {
739   // If `node` has an input edge with reference type, add an edge from the
740   // source of that edge to `node`.
741   for (const Edge* edge : graph_.edges()) {
742     if (edge->IsControlEdge()) {
743       continue;
744     }
745     Node* src = edge->src();
746     Node* dst = edge->dst();
747     bool needs_inspection;
748     TF_RETURN_IF_ERROR(inspection_required_checker_.IsPlacerInspectionRequired(
749         *src, &needs_inspection));
750     if (needs_inspection) {
751       inspection_required->insert(src);
752       continue;
753     }
754     TF_RETURN_IF_ERROR(inspection_required_checker_.IsPlacerInspectionRequired(
755         *dst, &needs_inspection));
756     if (needs_inspection) {
757       inspection_required->insert(dst);
758       continue;
759     }
760 
761     DataType input_type = dst->input_type(edge->dst_input());
762 
763     // Colocate two DatasetOp nodes connected by edge of dtype=DT_VARIANT.
764     // This is needed to get around the issue in b/135705778.
765     if (input_type == DT_VARIANT &&
766         data::DatasetOpKernel::IsDatasetOp(src->op_def()) &&
767         data::DatasetOpKernel::IsDatasetOp(dst->op_def())) {
768       TF_RETURN_IF_ERROR(ColocateResourceOrRefEdge(src, dst));
769       continue;
770     }
771 
772     // Even though we can look inside function calling ops, we make an exception
773     // here mostly for performance reasons. Looking inside function calling ops
774     // is extra overhead. It is only necessary when they return resources. When
775     // they don't, we don't look inside them and make this exception here.
776     // Looking inside, could potentially enable us to make better placement
777     // decisions. It might be worth doing at some point.
778     if ((input_type == DT_RESOURCE || IsRefType(input_type)) &&
779         !IsExemptFromResourceInputColocation(dst)) {
780       TF_RETURN_IF_ERROR(ColocateResourceOrRefEdge(src, dst));
781     }
782   }
783 
784   return OkStatus();
785 }
786 
787 namespace {
788 // Returns tensor list element data type, if the node is one of the ops that
789 // operate with TensorLists. Otherwise returns DT_INVALID.
790 // TODO(b/199443424): Don't use op names, use FullType here.
GetElementDataType(const Node & node)791 DataType GetElementDataType(const Node& node) {
792   static absl::flat_hash_set<std::string>* tensor_list_ops =
793       new absl::flat_hash_set<std::string>(
794           {"TensorListReserve", "TensorListFromTensor", "EmptyTensorList",
795            "TensorListSplit", "TensorListScatter", "TensorListScatterV2",
796            "TensorListScatterIntoExistingList", "TensorListPushBack",
797            "TensorListPushBackBatch", "TensorListPopBack", "TensorListStack",
798            "TensorListConcat", "TensorListConcatV2", "TensorListGetItem",
799            "TensorListSetItem", "TensorListGather", "TensorListConcatLists"});
800 
801   if (tensor_list_ops->contains(node.type_string())) {
802     DataType element_type;
803     if (GetNodeAttr(node.attrs(), "element_dtype", &element_type).ok()) {
804       return element_type;
805     }
806   }
807 
808   return DT_INVALID;
809 }
810 }  // namespace
811 
AddHostOnlyDataTypesConstraints()812 Status ColocationGraph::AddHostOnlyDataTypesConstraints() {
813   auto is_variant = [](DataType dtype) -> bool { return dtype == DT_VARIANT; };
814 
815   auto is_cpu_device = [](const std::pair<DeviceType, int32>& entry) -> bool {
816     return entry.first == DEVICE_CPU;
817   };
818 
819   for (Node* node : graph_.nodes()) {
820     // Skip nodes that do not have DT_VARIANT inputs.
821     if (absl::c_none_of(node->input_types(), is_variant)) {
822       continue;
823     }
824 
825     // Skip nodes that can't be placed on GPU anyway.
826     Member& root = members_[FindAndUpdateRoot(node->id())];
827     if (absl::c_all_of(root.supported_device_types(), is_cpu_device)) {
828       continue;
829     }
830 
831     absl::optional<bool> constrain_to_host;
832 
833     // This is a list of special nodes that we know to have no HostMemory
834     // inputs, so if they receive a host-only data type, they must necessarily
835     // be constrained to the host.
836     // This is brittle. In general, this should be handled by accounting for
837     // HostMemory as a constraint when the node's device is known, not ahead of
838     // time.
839     // A less ideal, but still better alternative is to look for ops which
840     // have no HostMemory kernels for the corresponding input. Unfortunately,
841     // determining that is challenging because we lack a map from input names
842     // to node input indices.
843     // TODO(mdan): Fix this.
844     if (node->IsRetval() || node->IsIdentity() || node->IsControlFlow() ||
845         node->IsFunctionCall()) {
846       for (const auto& edge : node->in_edges()) {
847         if (HasHostMemoryOutType(*edge->src())) {
848           // Skip nodes in colocation groups that already have a device
849           // assignment
850           if (root.has_assigned_device_name()) {
851             VLOG(4) << "Special node has host-only data type input "
852                     << "but is in a colocation group that already has a device "
853                     << "assignment, so NOT adding constraint:\n"
854                     << node->def().DebugString() << "\nedge:\n"
855                     << edge->DebugString();
856             break;
857           } else {
858             VLOG(4) << "Special node has host-only data type input, "
859                     << "adding constraint:\n"
860                     << node->def().DebugString() << "\nedge:\n"
861                     << edge->DebugString();
862             constrain_to_host = true;
863             break;
864           }
865         }
866       }
867     }
868 
869     if (!constrain_to_host.has_value()) {
870       // Legacy slow path. This covers legacy data types and ops which have not
871       // been upgraded to FullType.
872       auto edge_filter = [&](const Edge& edge) -> bool {
873         // We already found the underlying data type.
874         if (constrain_to_host.has_value()) return false;
875 
876         // Otherwise follow only DT_VARIANT data edges.
877         auto edge_dtype = [&]() -> DataType {
878           return edge.src()->output_type(edge.src_output());
879         };
880         return !edge.IsControlEdge() && edge_dtype() == DT_VARIANT;
881       };
882 
883       auto enter = [&](Node* n) -> void {
884         DataType element_type = GetElementDataType(*n);
885         // To handle nested lists continue traversal after finding a TensorList
886         // operation that uses DT_VARIANT for element type.
887         if (element_type == DT_INVALID || element_type == DT_VARIANT) {
888           return;
889         }
890         constrain_to_host = DataTypeAlwaysOnHost(element_type);
891       };
892 
893       ReverseDFSFrom(graph_, {node}, enter, /*leave=*/nullptr,
894                      /*stable_comparator=*/nullptr, edge_filter);
895     }
896 
897     if (constrain_to_host.has_value() && *constrain_to_host) {
898       VLOG(2) << "Constraining node " << node->name()
899               << " to CPU: it has an input with host-only "
900                  "underlying data type.";
901 
902       // Restrict possible device types to CPU only.
903       PossibleDevices possible_devices;
904       absl::c_copy_if(root.supported_device_types(),
905                       std::back_inserter(possible_devices.device_types),
906                       is_cpu_device);
907 
908       TF_RETURN_IF_ERROR(root.LimitToPossibleDevices(
909           possible_devices, /*allow_soft_placement=*/false));
910     }
911   }
912 
913   return OkStatus();
914 }
915 
AddInspectionConstraints(const std::unordered_set<Node * > & inspection_required)916 Status ColocationGraph::AddInspectionConstraints(
917     const std::unordered_set<Node*>& inspection_required) {
918   for (Node* node : inspection_required) {
919     IOColocationGroups groups;
920     TF_RETURN_IF_ERROR(
921         inspecting_placer_.ComputeIOColocationGroups(*node, &groups));
922     VLOG(2) << "Computed IOColocationGroups for node " << node->name()
923             << ":\n\t" << groups.DebugString();
924     TF_RETURN_IF_ERROR(ApplyIOColocationGroups(groups, *node));
925   }
926   return OkStatus();
927 }
928 
Initialize()929 Status ColocationGraph::Initialize() {
930   TF_RETURN_IF_ERROR(InitializeMembers());
931 
932   std::unordered_set<Node*> inspection_required;
933   TF_RETURN_IF_ERROR(ColocateResourceAndRefEdges(&inspection_required));
934   TF_RETURN_IF_ERROR(AddInspectionConstraints(inspection_required));
935   TF_RETURN_IF_ERROR(ColocateAllNodes());
936   TF_RETURN_IF_ERROR(AddHostOnlyDataTypesConstraints());
937 
938   for (Node* node : graph_.op_nodes()) {
939     int root_id = FindAndUpdateRoot(node->id());
940     members_[root_id].MaybeExcludeXlaDevices();
941   }
942 
943   return OkStatus();
944 }
945 
946 // pair containing a node and whether this node has a resource input
947 // from the node requiring placer inspection.
948 using NodeAndBool = std::pair<const Node*, bool>;
949 
950 namespace {
951 
952 // Returns a vector of node names from `nodes`.
NodeAndBoolToString(const std::vector<NodeAndBool> & nodes)953 std::vector<string> NodeAndBoolToString(const std::vector<NodeAndBool>& nodes) {
954   std::vector<string> v;
955   v.reserve(nodes.size());
956   for (const NodeAndBool& node_and_bool : nodes) {
957     v.push_back(node_and_bool.first->name());
958   }
959   return v;
960 }
961 
962 // Given a node requiring placer inspection and its IOColocationGroups,
963 // computes `group_nodes`.
964 // group_nodes[i] contains the nodes that are members of colocation
965 // group i. These nodes are inputs or outputs of `node`.
966 // group_nodes[i][j] is a pair containing a node and whether this node
967 // has a resource input from `node`.
968 // Note:
969 // The same node can be added multiple times to the same group.
970 // The same node can be added to multiple groups.
GetGroupNodes(const IOColocationGroups & groups,const Node & node,std::vector<std::vector<NodeAndBool>> * group_nodes)971 Status GetGroupNodes(const IOColocationGroups& groups, const Node& node,
972                      std::vector<std::vector<NodeAndBool>>* group_nodes) {
973   group_nodes->reserve(groups.group_devices.size());
974   for (int arg_idx = 0; arg_idx < groups.input_groups.size(); ++arg_idx) {
975     const Node* src;
976     TF_RETURN_IF_ERROR(node.input_node(arg_idx, &src));
977     int group_id = groups.input_groups[arg_idx];
978     (*group_nodes)[group_id].emplace_back(src, false);
979   }
980 
981   for (const Edge* edge : node.out_edges()) {
982     if (edge->IsControlEdge()) {
983       continue;
984     }
985 
986     int group_id = groups.output_groups[edge->src_output()];
987     (*group_nodes)[group_id].emplace_back(
988         edge->dst(), edge->dst()->input_type(edge->dst_input()) == DT_RESOURCE);
989   }
990 
991   if (VLOG_IS_ON(2)) {
992     VLOG(2) << "Colocated inputs/outputs of node: " << node.DebugString();
993     for (const std::vector<NodeAndBool>& nodes : *group_nodes) {
994       VLOG(2) << "\t[" << absl::StrJoin(NodeAndBoolToString(nodes), "\t\n")
995               << "]";
996     }
997   }
998   return OkStatus();
999 }
1000 
1001 // Returns whether the device_type in `device_attributes` is supported.
IsSupportedDeviceType(const DeviceAttributes & device_attributes,const DeviceType & supported_type)1002 bool IsSupportedDeviceType(const DeviceAttributes& device_attributes,
1003                            const DeviceType& supported_type) {
1004   if (DeviceType(device_attributes.device_type()) == supported_type) {
1005     return true;
1006   }
1007   return IsCompositeDevice(device_attributes.device_type());
1008 }
1009 
1010 }  // namespace
1011 
ApplyIOColocationGroups(const IOColocationGroups & groups,const Node & node)1012 Status ColocationGraph::ApplyIOColocationGroups(
1013     const IOColocationGroups& groups, const Node& node) {
1014   if (groups.input_groups.size() != node.num_inputs()) {
1015     return errors::Internal(
1016         "Cannot apply input/output device constraints to node ",
1017         node.DebugString(), " because input_groups.size() (",
1018         groups.input_groups.size(),
1019         ") is different from number of inputs into the op node (",
1020         node.num_inputs(), ")");
1021   }
1022   if (groups.output_groups.size() != node.num_outputs()) {
1023     return errors::Internal(
1024         "Cannot apply input/output device constraints to node ",
1025         node.DebugString(), " because output_groups.size() (",
1026         groups.output_groups.size(),
1027         ") is different from number of outputs into the op node (",
1028         node.num_outputs(), ")");
1029   }
1030 
1031   // group_nodes[i] contains the nodes that are members of colocation
1032   // group i. These nodes are inputs or outputs of `node`.
1033   // group_nodes[i][j] is a pair containing the node and whether this node
1034   // has a resource input from `node`.
1035   // The same node can be added multiple times to the same group.
1036   // The same node can be added to multiple groups.
1037   // NOTE: group ids are guarantees to be [0, 1, ..., num_groups].
1038   std::vector<std::vector<NodeAndBool>> group_nodes(
1039       groups.group_devices.size());
1040   TF_RETURN_IF_ERROR(GetGroupNodes(groups, node, &group_nodes));
1041 
1042   // Colocate nodes in each group
1043   for (const std::vector<NodeAndBool>& nodes : group_nodes) {
1044     for (int i = 1; i < nodes.size(); ++i) {
1045       VLOG(2) << "Colocating \"" << nodes[0].first->name() << "\" and \""
1046               << nodes[i].first->name() << "\"";
1047       if (nodes[i].second) {
1048         TF_RETURN_IF_ERROR(
1049             ColocateResourceOrRefEdge(nodes[0].first, nodes[i].first));
1050       } else {
1051         TF_RETURN_IF_ERROR(ColocateNodes(*nodes[0].first, *nodes[i].first));
1052       }
1053     }
1054   }
1055 
1056   // Limit devices in each group
1057   for (int group_id = 0; group_id < groups.group_devices.size(); ++group_id) {
1058     // Nothing to do for empty groups. Groups can be empty if some output
1059     // of an op is not used.
1060     if (group_nodes[group_id].empty()) {
1061       continue;
1062     }
1063     const Node* group_node = group_nodes[group_id][0].first;
1064     const PossibleDevices& possible_devices = groups.group_devices[group_id];
1065     TF_RETURN_IF_ERROR(LimitToPossibleDevices(*group_node, possible_devices));
1066   }
1067 
1068   return OkStatus();
1069 }
1070 
ColocateNodeToGroup(std::unordered_map<StringPiece,const Node *,StringPieceHasher> * colocation_group_root,const Node * node,StringPiece colocation_group)1071 Status ColocationGraph::ColocateNodeToGroup(
1072     std::unordered_map<StringPiece, const Node*, StringPieceHasher>*
1073         colocation_group_root,
1074     const Node* node, StringPiece colocation_group) {
1075   const Node*& root_node = (*colocation_group_root)[colocation_group];
1076   if (root_node == nullptr) {
1077     // This is the first node of the colocation group, so
1078     // designate this node as the 'root' of that colocation group.
1079     root_node = node;
1080   } else {
1081     // Try to colocate the node with the root.  If there is an
1082     // error, return it.
1083     Status s = ColocateNodes(*node, *root_node);
1084     if (!s.ok()) {
1085       if (!allow_soft_placement_) {
1086         return AttachDef(s, *node);
1087       }
1088       if (log_device_placement_) {
1089         LOG(INFO) << "Ignoring request to colocate node '" << node->name()
1090                   << "' with nodes in colocation group '" << colocation_group
1091                   << "' because soft placement is on and an attempt at doing "
1092                      "so resulted in the following error: "
1093                   << AttachDef(s, *node).ToString();
1094       }
1095     }
1096   }
1097   return OkStatus();
1098 }
1099 
1100 // Merge the (possibly disjoint) sets containing nodes "x" and
1101 // "y". Returns OK if the all nodes in the union of these sets can
1102 // be placed on the same device type.
1103 //
1104 // NOTE: If this method returns an error, *this is left in an undefined
1105 // state.
ColocateNodes(const Node & x,const Node & y)1106 Status ColocationGraph::ColocateNodes(const Node& x, const Node& y) {
1107   int x_root = FindAndUpdateRoot(x.id());
1108   int y_root = FindAndUpdateRoot(y.id());
1109   return ColocateNodes(x, x_root, y, y_root);
1110 }
1111 
1112 // This overload of ColocateNodes() allows a caller to provide the root node
1113 // ids for the two nodes. For large graphs, this noticeably reduces the
1114 // graph load time.
ColocateNodes(const Node & x,int x_root,const Node & y,int y_root)1115 Status ColocationGraph::ColocateNodes(const Node& x, int x_root, const Node& y,
1116                                       int y_root) {
1117   if (x_root == y_root) {
1118     return OkStatus();
1119   }
1120 
1121   Member* new_root_member;
1122   Member* old_root_member;
1123   Member::Merge(&members_, x_root, y_root, &new_root_member, &old_root_member,
1124                 /*dry_run=*/true);
1125 
1126   // Merge the partial device specifications, and ensure that they are
1127   // compatible. NULL options_ is treated as allowing soft placement.
1128   // If there is an error, nothing is modified.
1129   // TODO(mrry): Consider enriching the error message by pointing
1130   // out which nodes have the explicit partial device
1131   // specifications that caused this conflict.
1132   Status s = new_root_member->MergeDeviceNames(*old_root_member,
1133                                                allow_soft_placement_);
1134   if (!s.ok()) {
1135     return errors::InvalidArgument(
1136         "Cannot colocate nodes ",
1137         errors::FormatColocationNodeForError(x.name()), " and ",
1138         errors::FormatColocationNodeForError(y.name()), ": ",
1139         s.error_message());
1140   }
1141 
1142   // Ensure that the common root has at least one supported device
1143   // type, by computing the intersection of
1144   // new_root_member.supported_device_types and
1145   // old_root_member.supported_device_types.
1146   if (!new_root_member->MergeSupportedDevices(*old_root_member)) {
1147     return errors::InvalidArgument(
1148         "Cannot colocate nodes ",
1149         errors::FormatColocationNodeForError(x.name()), " and ",
1150         errors::FormatColocationNodeForError(y.name()),
1151         " because no device type supports both of those nodes and the "
1152         "other nodes colocated with them.",
1153         DebugInfo(x_root), DebugInfo(y_root));
1154   }
1155 
1156   // All error checks are done, merge the colocation graphs.
1157   Member::Merge(&members_, x_root, y_root, &new_root_member, &old_root_member,
1158                 /*dry_run=*/false);
1159   return OkStatus();
1160 }
1161 
LimitToAssignedDevice(const Node & node)1162 Status ColocationGraph::LimitToAssignedDevice(const Node& node) {
1163   if (node.assigned_device_name_index() < 0) {
1164     return errors::Internal(
1165         "Expected an assigned node as argument to LimitToAssignedDevice but "
1166         "got: ",
1167         node.DebugString());
1168   }
1169   int root = FindAndUpdateRoot(node.id());
1170   Member& root_member = members_[root];
1171   return root_member.AssignDevice(node);
1172 }
1173 
GetSoftDeviceCandidates(const Node & node,const Member & root_member,int root_id,std::vector<Device * > * possible_devices)1174 void ColocationGraph::GetSoftDeviceCandidates(
1175     const Node& node, const Member& root_member, int root_id,
1176     std::vector<Device*>* possible_devices) {
1177   // Try to find supported devices that don't violate resource devices.
1178   // The soft_device_name is the same as the requested device name
1179   // without specifying the device type or ID (if assigned and requested
1180   // devices does not specify them).
1181   DeviceNameUtils::ParsedName soft_device_name =
1182       root_member.GetPreferredSoftDeviceName();
1183   device_set_.FindMatchingDevices(soft_device_name, possible_devices);
1184   if (!possible_devices->empty()) {
1185     *possible_devices = FilterSupportedDevices(
1186         *possible_devices, root_member.supported_device_types(),
1187         default_local_device_);
1188   }
1189 
1190   if (!possible_devices->empty()) {
1191     return;
1192   }
1193 
1194   // TODO(iga): Disallow changing resource devices when this ColocationGraph
1195   // is for :
1196   // - a function called by an op requiring deep inspection, or
1197   // - a graph containing ops requiring inspection.
1198   // It is fairly tricky to make changing resource devices in presence of
1199   // ops requiring inspection work correctly. One thing it would require is to
1200   // communicate these "resource movement" decisions across Placer instances.
1201 
1202   // Failed to find supported devices that don't violate resource devices.
1203   // Try finding some devices that violated resource devices.
1204   // If we succeed, we will log a warning below.
1205   soft_device_name = root_member.GetSoftDeviceName();
1206   device_set_.FindMatchingDevices(soft_device_name, possible_devices);
1207   if (!possible_devices->empty()) {
1208     *possible_devices = FilterSupportedDevices(
1209         *possible_devices, root_member.supported_device_types(),
1210         default_local_device_);
1211   }
1212 
1213   if (!possible_devices->empty()) {
1214     LOG(WARNING)
1215         << "Failed to place the graph without changing the devices of some "
1216            "resources. Some of the operations (that had to be colocated with "
1217            "resource generating operations) are not supported on the "
1218            "resources' devices. Current candidate devices are [\n  "
1219         << absl::StrJoin(DevicesToString(*possible_devices), "\n  ")
1220         << "].\nSee below for details of this colocation group:"
1221         << DebugInfo(root_id);
1222   }
1223 }
1224 
LimitToPossibleDevices(const Node & node,const PossibleDevices & devices)1225 Status ColocationGraph::LimitToPossibleDevices(const Node& node,
1226                                                const PossibleDevices& devices) {
1227   int root = FindAndUpdateRoot(node.id());
1228   Member& root_member = members_[root];
1229   return root_member.LimitToPossibleDevices(devices, allow_soft_placement_);
1230 }
1231 
GetDevicesForNode(Node * node,const std::vector<Device * > ** possible_devices)1232 Status ColocationGraph::GetDevicesForNode(
1233     Node* node, const std::vector<Device*>** possible_devices) {
1234   *possible_devices = nullptr;
1235   const int node_root = FindAndUpdateRoot(node->id());
1236   if (!members_[node_root].possible_devices().empty()) {
1237     *possible_devices = &members_[node_root].possible_devices();
1238     return OkStatus();
1239   }
1240 
1241   Member& root_member = members_[node_root];
1242 
1243   // We have not yet computed the possible devices for the
1244   // colocated node set containing 'node', so we do so now using the
1245   // constraints on the root node.
1246 
1247   // "devices" will contain the set of feasible placements for the
1248   // colocated node set containing 'node'.
1249   // NOTE: Basing possible device computation on requested device name
1250   // is guaranteed to respect the assigned and resource device names because
1251   // requested device is always a specialization of both.
1252   std::vector<Device*> devices;
1253   if (DeviceNameUtils::HasSomeDetails(root_member.requested_device_name())) {
1254     // The root node has a (possibly partial) device
1255     // specification, so enumerate the physical devices that
1256     // conform to it.
1257     device_set_.FindMatchingDevices(root_member.requested_device_name(),
1258                                     &devices);
1259 
1260     if (!devices.empty()) {
1261       // Filter devices into those that are compatible with the root
1262       // node (and its children).
1263       devices = FilterSupportedDevices(
1264           devices, root_member.supported_device_types(), default_local_device_);
1265     }
1266 
1267     // Perform soft placement if allow_soft_placement_ is set.
1268     if (devices.empty() && allow_soft_placement_) {
1269       GetSoftDeviceCandidates(*node, root_member, node_root, &devices);
1270     }
1271 
1272     if (devices.empty()) {
1273       // Return an error when a physical device that matches an explicit
1274       // device specification is not found. This ensures that we don't
1275       // assign a node to GPU when the user wanted to force it on CPU.
1276       string debug_info = DebugInfo(node_root);
1277 
1278       DeviceNameUtils::ParsedName specified_device_name;
1279       if (DeviceNameUtils::ParseFullName(node->requested_device(),
1280                                          &specified_device_name) &&
1281           specified_device_name == root_member.requested_device_name()) {
1282         // The specified device and merged set device match, and
1283         // will appear in the GraphDef (for debugging), so just
1284         // print the specified device.
1285         std::vector<Device*> devices_matching_nodedef;
1286         device_set_.FindMatchingDevices(specified_device_name,
1287                                         &devices_matching_nodedef);
1288         if (devices_matching_nodedef.empty()) {
1289           // Sometimes it is almost impossible to understand the problem
1290           // without a list of available devices.
1291           std::vector<string> device_names;
1292           for (const Device* device : device_set_.devices()) {
1293             device_names.push_back(device->name());
1294           }
1295           std::sort(device_names.begin(), device_names.end());
1296 
1297           string gpu_msg = "";
1298           if (!IsGoogleCudaEnabled() &&
1299               absl::AsciiStrToLower(specified_device_name.type) == "gpu") {
1300             gpu_msg =
1301                 " The requested device appears to be a GPU, but CUDA is not "
1302                 "enabled.";
1303           }
1304 
1305           return errors::InvalidArgument(
1306               errors::FormatNodeNameForError(node->name()),
1307               " was explicitly assigned to ", node->requested_device(),
1308               " but available devices are [ ",
1309               absl::StrJoin(device_names, ", "), " ]. Make sure ",
1310               "the device specification refers to a valid device.", gpu_msg);
1311         } else if (specified_device_name.has_type) {
1312           return errors::InvalidArgument(
1313               "Could not satisfy explicit device specification '",
1314               node->requested_device(), "' because no supported kernel for ",
1315               specified_device_name.type, " devices is available.", debug_info,
1316               "\nOp: ", node->type_string(),
1317               "\nNode attrs: ", node->attrs().DebugString(),
1318               "\nRegistered kernels:\n",
1319               KernelsRegisteredForOp(node->type_string()));
1320         } else {
1321           return errors::InvalidArgument(
1322               "Could not satisfy explicit device specification '",
1323               node->requested_device(), debug_info);
1324         }
1325       } else {
1326         // The specified device may be a valid device but the
1327         // merged set device is different, so print both.
1328         // TODO(b/129057603): There are many possibilities at this point.
1329         // Provide good error messages.
1330         return errors::InvalidArgument(
1331             "Could not satisfy explicit device specification '",
1332             node->requested_device(), "' because the node ",
1333             errors::FormatColocationNodeForError(node->name()),
1334             " was colocated with a group of nodes that ",
1335             "required incompatible device '",
1336             DeviceNameUtils::ParsedNameToString(
1337                 root_member.requested_device_name()),
1338             "'. All available devices [",
1339             absl::StrJoin(DevicesToString(device_set_.devices()), ", "), "]. ",
1340             debug_info);
1341       }
1342     }
1343   } else {
1344     // The device is completely unspecified, so enumerate the devices that
1345     // support all of the nodes in the set.
1346     if (device_set_.devices().empty()) {
1347       return errors::Internal("No devices are registered");
1348     }
1349     devices = FilterSupportedDevices(device_set_.devices(),
1350                                      root_member.supported_device_types(),
1351                                      default_local_device_);
1352 
1353     if (devices.empty()) {
1354       return errors::InvalidArgument(
1355           "Node had no OpKernel registered to support this operation: ",
1356           "Operation was ", node->type_string(), " and inputs were [",
1357           DataTypeVectorString(node->input_types()), "].\n",
1358           DebugInfo(node_root));
1359     }
1360   }
1361 
1362   // Cache the result of the possible devices for this node group.
1363   root_member.set_possible_devices(std::move(devices));
1364   *possible_devices = &root_member.possible_devices();
1365   return OkStatus();
1366 }
1367 
InitializeMembers()1368 Status ColocationGraph::InitializeMembers() {
1369   for (Node* node : graph_.op_nodes()) {
1370     Status status = InitializeMember(*node, &members_[node->id()]);
1371     if (!status.ok()) {
1372       return AttachDef(status, *node);
1373     }
1374   }
1375   return OkStatus();
1376 }
1377 
DebugString() const1378 string ColocationGraph::DebugString() const {
1379   std::unordered_set<int> roots;
1380   std::vector<string> root_strings;
1381   for (const Node* node : graph_.nodes()) {
1382     if (!node->IsOp()) {
1383       continue;
1384     }
1385     int node_root = FindRoot(node->id());
1386     if (roots.count(node_root) == 0) {
1387       root_strings.push_back(DebugInfo(node_root));
1388       roots.insert(node_root);
1389     }
1390   }
1391   return absl::StrJoin(root_strings, "\n");
1392 }
1393 
1394 // Returns debugging info for the node referred to by 'node_root'.
DebugInfo(const int node_root) const1395 string ColocationGraph::DebugInfo(const int node_root) const {
1396   string text(
1397       "\nColocation Debug Info:\n"
1398       "Colocation group had the following types and supported devices: ");
1399 
1400   // If this node is part of a colocation group, then we want to
1401   // collect the mapping of ops to supported devices, so that
1402   // the user can see why an unsatisfiable placement occurred.
1403 
1404   std::unordered_map<string, string> type_to_devices;
1405   std::vector<const Node*> colocation_nodes;
1406   int num_nodes_found = 0;
1407 
1408   for (const Node* node : graph_.nodes()) {
1409     if (!node->IsOp()) {
1410       continue;
1411     }
1412     int id = node->id();
1413     if (FindRoot(id) != node_root) {
1414       continue;
1415     }
1416     ++num_nodes_found;
1417     colocation_nodes.push_back(node);
1418 
1419     PrioritizedDeviceTypeVector supported_types;
1420     SupportedDeviceTypesForNode(device_types_, node->def(), &supported_types,
1421                                 &local_address_spec_)
1422         .IgnoreError();
1423     string devices_registered;
1424     for (const auto& device_type : supported_types) {
1425       strings::StrAppend(&devices_registered,
1426                          DeviceTypeString(device_type.first), " ");
1427     }
1428 
1429     const string& op_type = node->type_string();
1430     type_to_devices[op_type] = std::move(devices_registered);
1431   }
1432   strings::StrAppend(&text, "\nRoot ", members_[node_root].DebugString());
1433 
1434   for (const auto& td : type_to_devices) {
1435     strings::StrAppend(&text, "\n", td.first, ": ", td.second);
1436   }
1437   strings::StrAppend(&text,
1438                      "\n\nColocation members, user-requested devices, and "
1439                      "framework assigned devices, if any:");
1440   for (const Node* node : colocation_nodes) {
1441     strings::StrAppend(&text, "\n  ", node->name(), " (", node->type_string(),
1442                        ") ", node->requested_device());
1443     if (node->has_assigned_device_name()) {
1444       strings::StrAppend(
1445           &text, " framework assigned device=", node->assigned_device_name());
1446     }
1447   }
1448   strings::StrAppend(&text, "\n");
1449 
1450   if (num_nodes_found <= 0) {
1451     text.clear();
1452   }
1453   return text;
1454 }
1455 
InitializeMemberWithAssignedDevice(const string & assigned_device_name,const string & node_type,Member * member)1456 Status ColocationGraph::InitializeMemberWithAssignedDevice(
1457     const string& assigned_device_name, const string& node_type,
1458     Member* member) {
1459   // This node has already been assigned to a device, so we
1460   // respect this placement, after sanity-checking it.
1461   // NOTE: Since any assignment must have been performed by
1462   // the TensorFlow runtime, we consider errors in this branch to
1463   // be INTERNAL.
1464   TF_RETURN_IF_ERROR(member->SetAssignedDeviceName(assigned_device_name));
1465 
1466   // Since assigned device must be a full specification, do extra checks.
1467   const Device* assigned_device =
1468       device_set_.FindDeviceByName(assigned_device_name);
1469   if (assigned_device == nullptr) {
1470     // TODO(b/129295848, b/122851476): Remove the bit about cross-host function
1471     // calls when they are supported.
1472     return errors::Internal(
1473         "Assigned device '", assigned_device_name,
1474         "' does not match any device. This error can happen when one attempts "
1475         "to run a tf.function with resource inputs residing on remote devices. "
1476         "This use case is currently not supported. Here are the devices "
1477         "available on this machine: [",
1478         absl::StrJoin(DevicesToString(device_set_.devices()), ", "), "].",
1479         "If you are seeing this error when running using a tf.Session, set "
1480         "share_cluster_devices_in_session to true in the tf.ConfigProto.");
1481   }
1482 
1483   for (const auto& d : member->supported_device_types()) {
1484     if (IsSupportedDeviceType(assigned_device->attributes(), d.first)) {
1485       return OkStatus();
1486     }
1487   }
1488 
1489   return errors::Internal("Assigned device '", assigned_device_name,
1490                           "' does not have registered OpKernel support "
1491                           "for ",
1492                           node_type);
1493 }
1494 
InitializeMember(const Node & node,Member * member)1495 Status ColocationGraph::InitializeMember(const Node& node, Member* member) {
1496   TF_RETURN_IF_ERROR(member->SetParentAndSupportedDevices(
1497       node, device_types_, &local_address_spec_));
1498 
1499   if (node.has_assigned_device_name()) {
1500     TF_RETURN_IF_ERROR(InitializeMemberWithAssignedDevice(
1501         node.assigned_device_name(), node.type_string(), member));
1502   } else {
1503     // This node has not yet been assigned to a device, so we
1504     // calculate any constraints due to the set of registered
1505     // kernels and any (partial) user-provided device specification
1506     // in the NodeDef.
1507 
1508     // If no kernels are registered for this op type, fail with an error.
1509     if (member->supported_device_types().empty()) {
1510       std::set<string> registered_device_types;
1511       for (Device* d : device_set_.devices()) {
1512         registered_device_types.insert(d->device_type());
1513       }
1514       return errors::InvalidArgument(
1515           "No OpKernel was registered to support Op '", node.type_string(),
1516           "' used by ", errors::FormatNodeNameForError(node.name()),
1517           " with these attrs: [", node.attrs().DebugString(),
1518           "]\n"
1519           "Registered devices: [",
1520           absl::StrJoin(registered_device_types, ", "), "]\n",
1521           "Registered kernels:\n", KernelsRegisteredForOp(node.type_string()));
1522     }
1523 
1524     // If the NodeDef contains a device, then we interpret it as a
1525     // (partial) device specification.
1526     if (!node.requested_device().empty()) {
1527       if (IsRefOrResourceGeneratorNode(node)) {
1528         // Treat requested device on resource generating nodes as assigned
1529         // device so that we don't override it.
1530         TF_RETURN_IF_ERROR(member->SetResourceDeviceName(node));
1531       } else {
1532         // The user has specified a device in the NodeDef, try to find a
1533         // valid device matching their specification in the set of
1534         // devices.
1535         // NOTE: The full name may specify a device that is not in
1536         // n.supported_device_types(), but we check that in AssignDevice().
1537         TF_RETURN_IF_ERROR(member->SetRequestedDeviceName(node));
1538       }
1539     }
1540   }
1541   return OkStatus();
1542 }
1543 
1544 // Returns a list of devices having type in supported_device_types.  The
1545 // returned list is sorted by preferred type (higher numeric type is preferred).
FilterSupportedDevices(const std::vector<Device * > & devices,const PrioritizedDeviceTypeVector & supported_device_types,const Device * default_local_device)1546 /*static*/ std::vector<Device*> ColocationGraph::FilterSupportedDevices(
1547     const std::vector<Device*>& devices,
1548     const PrioritizedDeviceTypeVector& supported_device_types,
1549     const Device* default_local_device) {
1550   Device* filtered_default_device = nullptr;
1551   PrioritizedDeviceVector prioritized_filtered_devices;
1552   for (const auto& supported_device_type : supported_device_types) {
1553     for (Device* device : devices) {
1554       if (IsSupportedDeviceType(device->attributes(),
1555                                 supported_device_type.first)) {
1556         if (default_local_device &&
1557             (device == default_local_device ||
1558              // TODO(nareshmodi, fishx): At times the device pointer in the
1559              // device set is different to the one passed in as the default
1560              // device. Figure out why this might be.
1561              device->name() == default_local_device->name())) {
1562           filtered_default_device = device;
1563         } else {
1564           prioritized_filtered_devices.emplace_back(
1565               device, supported_device_type.second);
1566         }
1567       }
1568     }
1569   }
1570   DeviceSet::SortPrioritizedDeviceVector(&prioritized_filtered_devices);
1571 
1572   std::vector<Device*> filtered_devices;
1573   if (filtered_default_device != nullptr) {
1574     filtered_devices.emplace_back(filtered_default_device);
1575   }
1576   for (const auto& prioritized_filtered_device : prioritized_filtered_devices) {
1577     filtered_devices.push_back(prioritized_filtered_device.first);
1578   }
1579   return filtered_devices;
1580 }
1581 
1582 }  // namespace tensorflow
1583