xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h"
17 
18 #include <memory>
19 #include <sstream>
20 #include <string>
21 #include <utility>
22 
23 #include "absl/container/flat_hash_set.h"
24 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
25 #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
26 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
27 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
28 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/status_macros.h"
31 #include "tensorflow/compiler/xla/util.h"
32 #include "tensorflow/core/lib/core/errors.h"
33 #include "tensorflow/core/platform/logging.h"
34 
35 namespace xla {
36 
ToString() const37 std::string HloModuleGroupMetadata::TrackedInstruction::ToString() const {
38   std::string repr =
39       (instruction_ != nullptr) ? instruction_->ToShortString() : "NULL";
40   switch (kind_) {
41     case ComputationKind::kInvalid:
42       repr += ":INVALID";
43       break;
44     case ComputationKind::kWhileCondition:
45       repr += ":WHILE_CONDITION";
46       break;
47     case ComputationKind::kWhileBody:
48       repr += ":WHILE_BODY";
49       break;
50     case ComputationKind::kConditionalBranch:
51       repr += absl::StrCat(":CONDITIONAL_BRANCH_", index_);
52       break;
53     case ComputationKind::kCallFunction:
54       repr += ":CALL";
55       break;
56   }
57   return repr;
58 }
59 
60 /* static */ StatusOr<std::unique_ptr<HloModuleGroupMetadata>>
Build(absl::Span<HloModule * const> modules)61 HloModuleGroupMetadata::Build(absl::Span<HloModule* const> modules) {
62   auto metadata = std::make_unique<HloModuleGroupMetadata>(modules);
63   TF_RETURN_IF_ERROR(metadata->Build());
64   return std::move(metadata);
65 }
66 
Build()67 Status HloModuleGroupMetadata::Build() {
68   TF_RETURN_IF_ERROR(RecordInstructions());
69   TF_RETURN_IF_ERROR(VerifyChannelInstructions());
70 
71   // Record all companion while instructions.
72   const auto visitor = [this](HloInstruction* hlo) -> Status {
73     // We only need to process if the instruction is within the computation
74     // of a companion instruction, like in the condition or body computation
75     // of a While.
76     const TrackedInstruction* tracked = GetTrackedInstruction(hlo->parent());
77     if (tracked == nullptr) {
78       return OkStatus();
79     }
80 
81     if (IsChannelInstruction(hlo) || hlo->IsCrossModuleAllReduce()) {
82       std::vector<HloComputation*> peers;
83       if (IsChannelInstruction(hlo)) {
84         peers.push_back(PeerComputation(hlo));
85       } else if (hlo->IsCrossModuleAllReduce()) {
86         for (HloInstruction* instr : GetAllReduceGroup(*hlo->channel_id())) {
87           if (instr == hlo) {
88             continue;
89           }
90           peers.push_back(instr->parent());
91         }
92       }
93 
94       // Add the parent computation of this channel (or all-reduce) instruction
95       // and its peer computation(s) (both must be while computations) as
96       // companions.
97       for (HloComputation* peer_computation : peers) {
98         const TrackedInstruction* peer_tracked =
99             GetTrackedInstruction(peer_computation);
100         if (peer_tracked == nullptr) {
101           continue;
102         }
103         TF_RET_CHECK(*tracked == *peer_tracked)
104             << "Peer instruction does not match the computation kind";
105         TF_RETURN_IF_ERROR(
106             AddCompanion(tracked->instruction(), peer_tracked->instruction()));
107         tracked_instructions_comms_[tracked->instruction()].push_back(hlo);
108       }
109     } else if (IsCompanionInstruction(hlo)) {
110       // Add the parents of companion instructions (they must be all of the same
111       // kind of instructions, opcode wise) as companions.
112       for (HloInstruction* companion : Companions(hlo)) {
113         const TrackedInstruction* companion_tracked =
114             GetTrackedInstruction(companion->parent());
115         TF_RET_CHECK(companion_tracked != nullptr);
116         TF_RET_CHECK(*tracked == *companion_tracked);
117         TF_RETURN_IF_ERROR(AddCompanion(tracked->instruction(),
118                                         companion_tracked->instruction()));
119       }
120     }
121 
122     return OkStatus();
123   };
124 
125   // Visit the computations in postorder so that the companion information grows
126   // from inner computations to outer ones.
127   for (HloModule* module : modules_) {
128     FunctionVisitor function_visitor(visitor);
129     for (HloComputation* computation : module->MakeComputationPostOrder()) {
130       TF_RETURN_IF_ERROR(computation->Accept(&function_visitor));
131     }
132   }
133 
134   // While building the companion sets, initial sets may be removed by inserting
135   // nullptr in companion_sets_. Prune those removed sets to compact.
136   std::vector<std::unique_ptr<std::vector<HloInstruction*>>> sets;
137   for (int64_t i = 0; i < companion_sets_.size(); ++i) {
138     if (companion_sets_[i] == nullptr) {
139       continue;
140     }
141     sets.push_back(std::move(companion_sets_[i]));
142     for (HloInstruction* hlo : *sets.back()) {
143       companion_set_index_[hlo] = sets.size() - 1;
144     }
145   }
146   companion_sets_ = std::move(sets);
147 
148   TF_RETURN_IF_ERROR(VerifyCompanionSets());
149   if (VLOG_IS_ON(4)) {
150     DumpCollectedStats();
151   }
152 
153   for (HloModule* module : modules_) {
154     TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
155                         HloAliasAnalysis::Run(module));
156     alias_analyses_[module] = std::move(alias_analysis);
157   }
158 
159   return OkStatus();
160 }
161 
VerifyCompanionSets() const162 Status HloModuleGroupMetadata::VerifyCompanionSets() const {
163   for (const auto& companions : companion_sets_) {
164     // A companion set must be composed at most of an instruction per
165     // device/module.
166     absl::flat_hash_set<int64_t> devices;
167     for (HloInstruction* instruction : *companions) {
168       // Go through all the communicating instructions (send, recv) of the given
169       // companion, and record their device.
170       auto it = tracked_instructions_comms_.find(instruction);
171       if (it == tracked_instructions_comms_.end()) {
172         // Companions can be added even if they have no communicating
173         // instructions, if they are parent of companions.
174         continue;
175       }
176       absl::flat_hash_set<int64_t> comm_devices;
177       for (HloInstruction* comm_instruction : it->second) {
178         auto device = GetInstructionDevice(*comm_instruction);
179         TF_RET_CHECK(device) << "Instruction " << comm_instruction->ToString()
180                              << " does not have a device";
181         comm_devices.insert(*device);
182       }
183       for (int64_t device : comm_devices) {
184         if (!devices.insert(device).second) {
185           std::stringstream ss;
186           ss << "Companion set:" << std::endl;
187           for (HloInstruction* hlo : *companions) {
188             ss << "  " << hlo->name() << std::endl;
189           }
190           ss << "has multiple instructions on the same device";
191           return FailedPrecondition("%s", ss.str());
192         }
193       }
194     }
195   }
196   return OkStatus();
197 }
198 
IsChannelInstruction(const HloInstruction * instruction) const199 bool HloModuleGroupMetadata::IsChannelInstruction(
200     const HloInstruction* instruction) const {
201   switch (instruction->opcode()) {
202     case HloOpcode::kSend:
203     case HloOpcode::kRecv:
204     case HloOpcode::kSendDone:
205     case HloOpcode::kRecvDone: {
206       const HloSendRecvInstruction* send_recv_instr =
207           DynCast<HloSendRecvInstruction>(instruction);
208       CHECK(send_recv_instr != nullptr);
209       return !send_recv_instr->is_host_transfer();
210     }
211     default:
212       return false;
213   }
214 }
215 
IsCompanionInstruction(HloInstruction * hlo) const216 bool HloModuleGroupMetadata::IsCompanionInstruction(HloInstruction* hlo) const {
217   return companion_set_index_.contains(hlo);
218 }
219 
InstructionCommunicates(HloInstruction * hlo) const220 bool HloModuleGroupMetadata::InstructionCommunicates(
221     HloInstruction* hlo) const {
222   return IsChannelInstruction(hlo) || IsCompanionInstruction(hlo) ||
223          hlo->IsCrossModuleAllReduce();
224 }
225 
GetChannel(int64_t channel_id) const226 const HloModuleGroupMetadata::Channel& HloModuleGroupMetadata::GetChannel(
227     int64_t channel_id) const {
228   CHECK(channel_id_map_.find(channel_id) != channel_id_map_.end());
229   return channels_[channel_id_map_.at(channel_id)];
230 }
231 
HasChannel(int64_t channel_id) const232 bool HloModuleGroupMetadata::HasChannel(int64_t channel_id) const {
233   return channel_id_map_.find(channel_id) != channel_id_map_.end();
234 }
235 
PeerComputation(const HloInstruction * instruction) const236 HloComputation* HloModuleGroupMetadata::PeerComputation(
237     const HloInstruction* instruction) const {
238   CHECK(IsChannelInstruction(instruction));
239   const Channel& channel = GetChannel(*instruction->channel_id());
240   switch (instruction->opcode()) {
241     case HloOpcode::kSend:
242     case HloOpcode::kSendDone:
243       return channel.recv->parent();
244     case HloOpcode::kRecv:
245     case HloOpcode::kRecvDone:
246       return channel.send->parent();
247     default:
248       LOG(FATAL) << "opcode not supported";
249   }
250 }
251 
GetAllReduceGroup(int64_t channel_id) const252 const std::vector<HloInstruction*>& HloModuleGroupMetadata::GetAllReduceGroup(
253     int64_t channel_id) const {
254   auto it = all_reduce_map_.find(channel_id);
255   CHECK(it != all_reduce_map_.end());
256   return it->second;
257 }
258 
259 std::vector<HloModuleGroupMetadata::TrackedInstruction>
GetCompanionsPath(const HloInstruction * hlo) const260 HloModuleGroupMetadata::GetCompanionsPath(const HloInstruction* hlo) const {
261   std::vector<TrackedInstruction> path;
262   const HloComputation* parent = hlo->parent();
263   const TrackedInstruction* companion;
264   while ((companion = GetTrackedInstruction(parent)) != nullptr) {
265     parent = companion->instruction()->parent();
266     path.push_back(*companion);
267   }
268   return path;
269 }
270 
CheckCompanionPathsCompatibility(const std::vector<TrackedInstruction> & path0,const std::vector<TrackedInstruction> & path1) const271 bool HloModuleGroupMetadata::CheckCompanionPathsCompatibility(
272     const std::vector<TrackedInstruction>& path0,
273     const std::vector<TrackedInstruction>& path1) const {
274   if (path0.size() != path1.size()) {
275     VLOG(5) << "Companion path size do not match: " << path0.size()
276             << " != " << path1.size();
277     return false;
278   }
279   for (int64_t i = 0; i < path0.size(); ++i) {
280     if (path0[i] != path1[i]) {
281       VLOG(5) << "Companion instructions at path index " << i
282               << " do not have the same opcode: " << path0[i].ToString()
283               << " vs " << path1[i].ToString();
284       return false;
285     }
286   }
287   return true;
288 }
289 
GetModuleId(const HloModule * module) const290 int64_t HloModuleGroupMetadata::GetModuleId(const HloModule* module) const {
291   for (int64_t i = 0; i < modules_.size(); ++i) {
292     if (modules_[i] == module) {
293       return i;
294     }
295   }
296   LOG(FATAL) << "unknown module";
297 }
298 
GetInstructionDevice(const HloInstruction & instruction) const299 std::optional<int64_t> HloModuleGroupMetadata::GetInstructionDevice(
300     const HloInstruction& instruction) const {
301   // The module group metadata can be created in both "single module, multiple
302   // devices" and "multiple modules, no explicit devices" fashions.
303   // The API returns an optional even though the current implementation always
304   // returns a device, to account for cases where we cannot guess a device.
305   // In such cases the VerifyChannelInstructions() will return proper errors.
306   std::optional<int64_t> device = instruction.sharding_unique_device();
307   if (!device) {
308     device = GetModuleId(instruction.parent()->parent());
309   }
310   return device;
311 }
312 
GetDeviceModulesCount() const313 int64_t HloModuleGroupMetadata::GetDeviceModulesCount() const {
314   return modules_.size();
315 }
316 
RecordInstructions()317 Status HloModuleGroupMetadata::RecordInstructions() {
318   const auto visitor = [this](HloInstruction* hlo) -> Status {
319     if (hlo->opcode() == HloOpcode::kWhile) {
320       tracked_instructions_[hlo->while_condition()] =
321           TrackedInstruction(hlo, ComputationKind::kWhileCondition);
322       tracked_instructions_[hlo->while_body()] =
323           TrackedInstruction(hlo, ComputationKind::kWhileBody);
324     } else if (hlo->opcode() == HloOpcode::kConditional) {
325       for (int b = 0; b < hlo->branch_count(); ++b) {
326         tracked_instructions_[hlo->branch_computation(b)] =
327             TrackedInstruction(hlo, ComputationKind::kConditionalBranch, b);
328       }
329     } else if (hlo->opcode() == HloOpcode::kCall) {
330       tracked_instructions_[hlo->to_apply()] =
331           TrackedInstruction(hlo, ComputationKind::kCallFunction);
332     }
333 
334     // Group cross module all-reduce instructions by the channel id.
335     if (hlo->IsCrossModuleAllReduce()) {
336       TF_RET_CHECK(channel_id_map_.find(*hlo->channel_id()) ==
337                    channel_id_map_.end())
338           << "channel_id " << *hlo->channel_id()
339           << " is already used by a send/recv instruction";
340       all_reduce_map_[*hlo->channel_id()].push_back(hlo);
341       max_channel_id_ = std::max(max_channel_id_, *hlo->channel_id());
342       return OkStatus();
343     }
344 
345     if (!IsChannelInstruction(hlo)) {
346       return OkStatus();
347     }
348 
349     TF_RET_CHECK(all_reduce_map_.find(*hlo->channel_id()) ==
350                  all_reduce_map_.end())
351         << "channel id " << *hlo->channel_id()
352         << " is already used by an all-reduce instruction";
353 
354     // Add a new channel if needed.
355     if (channel_id_map_.find(*hlo->channel_id()) == channel_id_map_.end()) {
356       channels_.emplace_back();
357       channels_.back().id = *hlo->channel_id();
358       channel_id_map_[*hlo->channel_id()] = channels_.size() - 1;
359       max_channel_id_ = std::max(max_channel_id_, *hlo->channel_id());
360     }
361     Channel& channel = channels_[channel_id_map_[*hlo->channel_id()]];
362 
363     if (hlo->opcode() == HloOpcode::kSend) {
364       TF_RET_CHECK(channel.send == nullptr)
365           << "channel id " << *hlo->channel_id()
366           << " is used by multiple send instructions";
367       channel.send = hlo;
368     }
369     if (hlo->opcode() == HloOpcode::kRecv) {
370       TF_RET_CHECK(channel.recv == nullptr)
371           << "channel id " << *hlo->channel_id()
372           << " is used by multiple recv instructions";
373       channel.recv = hlo;
374     }
375     if (hlo->opcode() == HloOpcode::kSendDone) {
376       TF_RET_CHECK(channel.send_done == nullptr)
377           << "channel id " << *hlo->channel_id()
378           << " is used by multiple send-done instructions";
379       channel.send_done = hlo;
380     }
381     if (hlo->opcode() == HloOpcode::kRecvDone) {
382       TF_RET_CHECK(channel.recv_done == nullptr)
383           << "channel id " << *hlo->channel_id()
384           << " is used by multiple recv-done instructions";
385       channel.recv_done = hlo;
386     }
387     return OkStatus();
388   };
389 
390   for (HloModule* module : modules_) {
391     FunctionVisitor function_visitor(visitor);
392     for (auto* computation : module->computations()) {
393       TF_RETURN_IF_ERROR(computation->Accept(&function_visitor));
394     }
395   }
396   VLOG(2) << "Created " << channels_.size() << " channels";
397   VLOG(2) << "Created " << all_reduce_map_.size() << " all-reduce groups";
398   return OkStatus();
399 }
400 
AddCompanion(HloInstruction * instruction1,HloInstruction * instruction2)401 Status HloModuleGroupMetadata::AddCompanion(HloInstruction* instruction1,
402                                             HloInstruction* instruction2) {
403   TF_RET_CHECK(instruction1->opcode() == HloOpcode::kWhile ||
404                instruction1->opcode() == HloOpcode::kConditional ||
405                instruction1->opcode() == HloOpcode::kCall);
406   VLOG(2) << "adding as companions:" << instruction1->ToString() << " and "
407           << instruction2->ToString();
408   if (instruction1 == instruction2) {
409     return OkStatus();
410   } else if (!ContainsKey(companion_set_index_, instruction1) &&
411              !ContainsKey(companion_set_index_, instruction2)) {
412     companion_sets_.push_back(std::make_unique<std::vector<HloInstruction*>>());
413     auto companion_set = companion_sets_.back().get();
414     companion_set->push_back(instruction1);
415     companion_set->push_back(instruction2);
416     companion_set_index_[instruction1] = companion_sets_.size() - 1;
417     companion_set_index_[instruction2] = companion_sets_.size() - 1;
418   } else if (!ContainsKey(companion_set_index_, instruction1)) {
419     companion_sets_[companion_set_index_[instruction2]]->push_back(
420         instruction1);
421     companion_set_index_[instruction1] = companion_set_index_[instruction2];
422   } else if (!ContainsKey(companion_set_index_, instruction2)) {
423     companion_sets_[companion_set_index_[instruction1]]->push_back(
424         instruction2);
425     companion_set_index_[instruction2] = companion_set_index_[instruction1];
426   } else if (companion_set_index_[instruction1] !=
427              companion_set_index_[instruction2]) {
428     // At any point while building the companion sets, each instruction belongs
429     // to at most 1 companion set, so the union of two companion sets is
430     // concatenating two disjoint sets.
431     absl::c_copy(Companions(instruction2),
432                  std::back_inserter(
433                      *companion_sets_[companion_set_index_[instruction1]]));
434     int64_t index_to_remove = companion_set_index_[instruction2];
435     for (HloInstruction* hlo : Companions(instruction2)) {
436       companion_set_index_[hlo] = companion_set_index_[instruction1];
437     }
438     // We can't remove the set from the vector because companion_set_index_
439     // references sets by their index in this vector, so we reset to nullptr
440     // instead.
441     companion_sets_[index_to_remove].reset(nullptr);
442   }
443   return OkStatus();
444 }
445 
VerifyChannelInstructions()446 Status HloModuleGroupMetadata::VerifyChannelInstructions() {
447   for (const Channel& channel : channels_) {
448     if (channel.send == nullptr) {
449       return FailedPrecondition("missing send for id : %d", channel.id);
450     }
451     if (channel.recv == nullptr) {
452       return FailedPrecondition("missing recv for id : %d", channel.id);
453     }
454     if (channel.send_done == nullptr) {
455       return FailedPrecondition("missing send-done for id : %d", channel.id);
456     }
457     if (channel.recv_done == nullptr) {
458       return FailedPrecondition("missing recv-done for id : %d", channel.id);
459     }
460   }
461 
462   // Check if the shapes match for each channel.
463   for (const Channel& channel : channels_) {
464     const Shape& send_shape = channel.send->operand(0)->shape();
465     const Shape& recv_shape =
466         ShapeUtil::GetTupleElementShape(channel.recv_done->shape(), 0);
467     if (!ShapeUtil::Compatible(send_shape, recv_shape)) {
468       return FailedPrecondition("send/recv shapes do not match");
469     }
470     auto send_device = GetInstructionDevice(*channel.send);
471     auto send_done_device = GetInstructionDevice(*channel.send_done);
472     if (!send_device) {
473       return FailedPrecondition("send instruction must have a device: %s",
474                                 channel.send->ToString());
475     }
476     if (!send_done_device) {
477       return FailedPrecondition("send_done instruction must have a device: %s",
478                                 channel.send_done->ToString());
479     }
480     if (*send_device != *send_done_device) {
481       return FailedPrecondition(
482           "send and send-done (channel=%d) must be on the same device: %d "
483           "vs. %d",
484           channel.id, *send_device, *send_done_device);
485     }
486     auto recv_device = GetInstructionDevice(*channel.recv);
487     auto recv_done_device = GetInstructionDevice(*channel.recv_done);
488     if (!recv_done_device) {
489       return FailedPrecondition("recv_done instruction must have a device: %s",
490                                 channel.recv_done->ToString());
491     }
492     if (*recv_device != *recv_done_device) {
493       return FailedPrecondition(
494           "recv and recv-done (channel=%d) must be on the same device: %d "
495           "vs. %d",
496           channel.id, *recv_device, *recv_done_device);
497     }
498     if (*send_device == *recv_device) {
499       return FailedPrecondition(
500           "send and recv (channel=%d) must be on different devices: %d",
501           channel.id, *send_device);
502     }
503   }
504 
505   for (const Channel& channel : channels_) {
506     TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.send));
507     TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.send_done));
508     TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.recv));
509     TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.recv_done));
510   }
511   // Check if the nest levels match for each channel.
512   for (const Channel& channel : channels_) {
513     std::vector<TrackedInstruction> path = GetCompanionsPath(channel.send);
514     if (!CheckCompanionPathsCompatibility(
515             path, GetCompanionsPath(channel.send_done)) ||
516         !CheckCompanionPathsCompatibility(path,
517                                           GetCompanionsPath(channel.recv)) ||
518         !CheckCompanionPathsCompatibility(
519             path, GetCompanionsPath(channel.recv_done))) {
520       return FailedPrecondition(
521           "Nest companion paths do not match for channel %d", channel.id);
522     }
523   }
524   return OkStatus();
525 }
526 
CheckCommunicatingInstruction(HloInstruction * instruction) const527 Status HloModuleGroupMetadata::CheckCommunicatingInstruction(
528     HloInstruction* instruction) const {
529   HloComputation* computation = instruction->parent();
530   const HloModule* module = computation->parent();
531   if (module->entry_computation() == computation ||
532       tracked_instructions_.contains(computation)) {
533     return OkStatus();
534   }
535   return FailedPrecondition("channel is used in disallowed computation");
536 }
537 
DumpCollectedStats() const538 void HloModuleGroupMetadata::DumpCollectedStats() const {
539   std::map<std::pair<int64_t, int64_t>, int64_t> communication_histogram;
540   for (auto& channel : channels_) {
541     auto from_device = GetInstructionDevice(*channel.send);
542     auto to_device = GetInstructionDevice(*channel.recv);
543     LOG(INFO) << "Channel " << channel.id << ": from_device=" << *from_device
544               << " to_device=" << *to_device << " send=" << channel.send->name()
545               << " send_done=" << channel.send_done->name()
546               << " recv=" << channel.recv->name()
547               << " recv_done=" << channel.recv_done->name();
548     communication_histogram[std::pair<int64_t, int64_t>(*from_device,
549                                                         *to_device)] += 1;
550   }
551   for (auto& fromto_count : communication_histogram) {
552     LOG(INFO) << "From " << fromto_count.first.first << " to "
553               << fromto_count.first.second << ": " << fromto_count.second;
554   }
555   for (auto& companion_set : companion_sets_) {
556     LOG(INFO) << "Companion set:";
557     for (HloInstruction* instruction : *companion_set) {
558       LOG(INFO) << "  " << instruction->name();
559     }
560   }
561   for (auto& instruction_comm : tracked_instructions_comms_) {
562     LOG(INFO) << "Communicating instruction " << instruction_comm.first->name();
563     for (HloInstruction* instruction : instruction_comm.second) {
564       auto device = GetInstructionDevice(*instruction);
565       LOG(INFO) << "  " << instruction->name() << " on device " << *device;
566     }
567   }
568 }
569 
570 }  // namespace xla
571