1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/copy_insertion.h"
17 
18 #include <algorithm>
19 #include <cstddef>
20 #include <optional>
21 #include <sstream>
22 
23 #include "absl/container/flat_hash_map.h"
24 #include "absl/container/flat_hash_set.h"
25 #include "absl/container/inlined_vector.h"
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/str_join.h"
28 #include "absl/types/any.h"
29 #include "tensorflow/compiler/xla/service/compile_time_cap.h"
30 #include "tensorflow/compiler/xla/service/dump.h"
31 #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
32 #include "tensorflow/compiler/xla/service/hlo_buffer.h"
33 #include "tensorflow/compiler/xla/service/hlo_computation.h"
34 #include "tensorflow/compiler/xla/service/hlo_dce.h"
35 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
36 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
37 #include "tensorflow/compiler/xla/service/hlo_module.h"
38 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
39 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
40 #include "tensorflow/compiler/xla/service/logical_buffer.h"
41 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
42 #include "tensorflow/compiler/xla/status_macros.h"
43 #include "tensorflow/compiler/xla/statusor.h"
44 #include "tensorflow/compiler/xla/types.h"
45 #include "tensorflow/compiler/xla/util.h"
46 #include "tensorflow/core/platform/logging.h"
47 
48 namespace xla {
49 namespace {
50 
51 using absl::StrAppend;
52 
53 bool IsReadonlyEntryParameterValue(const HloValue& value) {
54   const HloComputation* computation = value.defining_instruction()->parent();
55   return value.defining_instruction()->opcode() == HloOpcode::kParameter &&
56          computation == computation->parent()->entry_computation() &&
57          !computation->parent()->input_output_alias_config().ParameterHasAlias(
58              value.defining_instruction()->parameter_number(), value.index());
59 }
60 
61 bool IsConstantValue(const HloValue& value) {
62   return value.defining_instruction()->opcode() == HloOpcode::kConstant;
63 }
64 
65 bool ValueIsReadOnly(const HloValue& value) {
66   return IsConstantValue(value) || IsReadonlyEntryParameterValue(value);
67 }
68 
69 // Data structure describing the action which should be taken on parts of a
70 // computation's buffers, with respect to adding special case copies.
71 struct SpecialCaseCopyPolicy {
72   // Insert a copy if the same buffer is found at multiple indices within the
73   // output tuple.
74   bool copy_root_replicated_buffers = false;
75   // If true, insert a copy if a buffer coming from a constant or a parameter
76   // is found within the output tuple.
77   bool copy_parameters_and_constants = false;
78 };
79 
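// Returns the special-case copy policy to apply to the given computation;
// currently only the entry computation requires these extra root copies.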
80 SpecialCaseCopyPolicy GetSpecialCaseCopyPolicy(const CallGraphNode& node,
81                                                HloModule* module,
82                                                HloComputation* computation) {
83   SpecialCaseCopyPolicy policy;
84   if (computation == module->entry_computation()) {
85     policy.copy_parameters_and_constants = true;
86     policy.copy_root_replicated_buffers = true;
87   }
88   return policy;
89 }
90 
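// Returns whether the given value, when it appears in the output tuple of a
// computation, must be copied under the given policy.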
91 bool ShouldCopyRootValue(const HloValue& value,
92                          const SpecialCaseCopyPolicy& policy) {
93   if (policy.copy_parameters_and_constants) {
94     return ValueIsReadOnly(value);
95   }
96   return false;
97 }
98 
99 // Deep copy the given instructions 'from' and 'to' at the ShapeIndexes given in
100 // 'indices_to_copy'. Add control edges from the respective kCopy instructions
101 // in the deep copy of 'from' to the respective kCopy instruction in the deep copy
102 // of 'to'.
103 //
104 // Requirements: 'from' and 'to' must have compatible shapes.
105 //
106 // For example, suppose 'from' and 'to' are two-element tuples where index 0 is
107 // the only index to copy. Prior to deep-copying we have:
108 //
109 //
110 //      'from'
111 //         |
112 //        ...
113 //         |
114 //       'to'
115 //
116 // DeepCopyAndAddControlEdges produces:
117 //
118 //       'from'
119 //        /   \
120 //      GTE   GTE
121 //       |     |
122 //     Copy    |
123 //    /   \   /
124 //   |    Tuple
125 //   |      |
126 //  ctrl   ...
127 //  edge    |
128 //   |      |
129 //   |    'to'
130 //   |    /   \
131 //   |  GTE   GTE
132 //    \  |     |
133 //     Copy    |
134 //        \   /
135 //        Tuple
136 //
137 StatusOr<std::pair<HloInstruction*, HloInstruction*>>
138 DeepCopyAndAddControlEdges(HloInstruction* from, HloInstruction* to,
139                            const ShapeTree<bool>& indices_to_copy) {
140   DCHECK(ShapeUtil::Compatible(from->shape(), to->shape()));
141   // to/from_copy_tree hold the kCopy instructions produced by the deep
142   // copies. Elements which are not copied (indices_to_copy.element(index) ==
143   // false) have nullptr at that index.
144   ShapeTree<HloInstruction*> from_copy_tree(from->shape(),
145                                             /*init_value=*/nullptr);
146   TF_ASSIGN_OR_RETURN(HloInstruction * from_deep_copy,
147                       from->parent()->DeepCopyInstruction(
148                           from, &indices_to_copy, &from_copy_tree));
149 
150   ShapeTree<HloInstruction*> to_copy_tree(to->shape(), /*init_value=*/nullptr);
151   TF_ASSIGN_OR_RETURN(
152       HloInstruction * to_deep_copy,
153       to->parent()->DeepCopyInstruction(to, &indices_to_copy, &to_copy_tree));
154 
155   // Add control edges between the respective kCopy instructions.
156   for (const auto& pair : from_copy_tree) {
157     const ShapeIndex& index = pair.first;
158     HloInstruction* from_copy = pair.second;
159     HloInstruction* to_copy = to_copy_tree.element(index);
160     if (from_copy == nullptr) {
161       TF_RET_CHECK(to_copy == nullptr);
162       continue;
163     }
164     TF_RET_CHECK(to_copy != nullptr);
165     TF_RETURN_IF_ERROR(from_copy->AddControlDependencyTo(to_copy));
166   }
167 
168   return std::make_pair(from_deep_copy, to_deep_copy);
169 }
170 
171 // Compute the indices of the loop state which need copies in order to avoid
172 // live range interference. Generally, an element in the loop state does not
173 // need to be copied if the element is passed transparently through the
174 // body.
175 //
176 // Returns whether any indices need to be copied.
177 bool IndicesToCopyForWhile(const HloDataflowAnalysis& dataflow,
178                            const HloInstruction* xla_while,
179                            ShapeTree<bool>* indices_to_copy) {
180   DCHECK(ShapeUtil::Compatible(indices_to_copy->shape(), xla_while->shape()));
181 
182   bool any_copies = false;
183   const HloInstruction* init = xla_while->operand(0);
184   for (auto& pair : *indices_to_copy) {
185     const ShapeIndex& index = pair.first;
186     bool& should_copy = pair.second;
187     // If there is any ambiguity, then loop state must be copied.
188     if (dataflow.GetValueSet(init, index).values().size() > 1 ||
189         dataflow.GetValueSet(xla_while, index).values().size() > 1) {
190       should_copy = true;
191     } else {
192       // If the output of the while instruction is not the same as the init
193       // value of the while, then this element is not passed through the body
194       // transparently and must be copied.
195       should_copy = dataflow.GetUniqueValueAt(xla_while, index) !=
196                     dataflow.GetUniqueValueAt(init, index);
197     }
198     any_copies |= should_copy;
199   }
200   return any_copies;
201 }
202 
203 // Compute the indices of the conditional outputs which need copies. Unambiguous
204 // buffers (buffers with only one value) don't need copies.
205 bool IndicesToCopyForConditional(const HloDataflowAnalysis& dataflow,
206                                  const HloInstruction* xla_conditional,
207                                  ShapeTree<bool>* indices_to_copy) {
208   DCHECK(ShapeUtil::Compatible(indices_to_copy->shape(),
209                                xla_conditional->shape()));
210 
211   bool any_copies = false;
212   for (auto& pair : *indices_to_copy) {
213     const ShapeIndex& index = pair.first;
214     bool& should_copy = pair.second;
215 
216     CHECK_EQ(dataflow.GetValueSet(xla_conditional, index).values().size(), 1);
217 
218     auto value = dataflow.GetValueSet(xla_conditional, index).values()[0];
219     // The conditional must be copied if the value is a phi.
220     should_copy =
221         value->is_phi() && value->defining_instruction() == xla_conditional;
222     any_copies |= should_copy;
223   }
224   return any_copies;
225 }
226 
227 // Add kCopy instructions around the given kWhile instruction to eliminate any
228 // possible live range interference of HLO values assuming a dependency-based
229 // ordering. Copies are added conservatively. There are likely copies which are
230 // not strictly necessary, but they are removed later in the pass via
231 // RemoveUnnecessaryCopies.
232 //
233 // Elements (each ShapeIndex) in the loop state are considered independently.  A
234 // copy is added to each element of the loop state which is modified in the
235 // while body. For each such element, a total of three kCopy instructions are
236 // added at following locations:
237 //
238 //   (1) The init value is copied before the kWhile instruction. Before:
239 //
240 //           (Init)
241 //             |
242 //           kWhile
243 //             |
244 //            ...
245 //
246 //       After:
247 //
248 //           (Init)
249 //             |
250 //           kCopy
251 //             |
252 //           kWhile
253 //             |
254 //            ...
255 //
256 //       This copy is necessary in case the init value is simultaneously live
257 //       with the kWhile.
258 //
259 //   (2) Copies are added to the parameter and root of the while body
260 //       computation. Before:
261 //
262 //           kParameter
263 //               |
264 //              ...
265 //               |
266 //           (body root)
267 //
268 //       After:
269 //
270 //           kParameter
271 //               |
272 //             kCopy ----------+
273 //               |             |
274 //              ...           ctrl
275 //               |            edge
276 //           (body root)       |
277 //               |             |
278 //             kCopy <---------+
279 //
280 //       The root kCopy becomes the new root of the computation. Both copies are
281 //       necessary to avoid any potential interference between the parameter value and
282 //       the root value. The control edge prevents potential interference
283 //       between the copies themselves.
284 //
285 // If the loop state is a tuple then the above kCopy instructions are a deep
286 // copy constructed of kCopy, kGetTupleElement, and kTuple instructions, as
287 // constructed by HloInstruction::DeepCopyInstruction.
288 Status AddCopiesForWhile(const HloAliasAnalysis& alias_analysis,
289                          HloInstruction* xla_while) {
290   VLOG(2) << "Adding copies for kWhile instruction " << xla_while->name();
291   TF_RET_CHECK(xla_while->opcode() == HloOpcode::kWhile);
292 
293   ShapeTree<bool> indices_to_copy(xla_while->shape());
294   if (!IndicesToCopyForWhile(alias_analysis.dataflow_analysis(), xla_while,
295                              &indices_to_copy)) {
296     VLOG(2) << "No copies necessary for kWhile instruction "
297             << xla_while->name();
298     return OkStatus();
299   }
300 
301   VLOG(2) << "Adding copies for " << xla_while->name() << " at indices:";
302   for (auto& pair : indices_to_copy) {
303     if (pair.second) {
304       VLOG(2) << "  " << pair.first;
305     }
306   }
307 
308   // Deep copy init.
309   HloInstruction* while_init = xla_while->mutable_operand(0);
310   TF_ASSIGN_OR_RETURN(
311       HloInstruction * while_init_copy,
312       xla_while->parent()->DeepCopyInstruction(while_init, &indices_to_copy));
313   TF_RETURN_IF_ERROR(while_init->ReplaceUseWith(xla_while, while_init_copy));
314 
315   // Deep copy the parameter and the root. Extend a control edge from the copy
316   // of the parameter value to the corresponding copy value of the root.
317   HloComputation* body = xla_while->while_body();
318   HloInstruction* param = body->parameter_instruction(0);
319   HloInstruction* root = body->root_instruction();
320 
321   // If param is the root then all indices should have been passed through the
322   // while body and we should have returned early above.
323   TF_RET_CHECK(param != root);
324 
325   // Copy users before making a deep copy of the parameter as the deep copy
326   // will create new users of the parameter (e.g., the GTE instructions of the
327   // deep copy).
328   std::vector<HloInstruction*> param_users = param->users();
329 
330   TF_ASSIGN_OR_RETURN(auto pair,
331                       DeepCopyAndAddControlEdges(param, root, indices_to_copy));
332 
333   HloInstruction* param_copy = pair.first;
334   HloInstruction* root_copy = pair.second;
335 
336   for (HloInstruction* user : param_users) {
337     TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, param_copy));
338   }
339 
340   body->set_root_instruction(root_copy);
341   return OkStatus();
342 }
343 
344 // Add copies for the operands of in-place operations. RemoveUnnecessaryCopies
345 // will remove the unnecessary copies.
346 Status AddCopiesForInPlaceOperation(const HloAliasAnalysis& alias_analysis,
347                                     HloInstruction* in_place_op,
348                                     int64_t operand_number) {
349   VLOG(2) << "Adding copies for in-place operation " << in_place_op->name();
350   HloInstruction* operand = in_place_op->mutable_operand(operand_number);
351   TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy,
352                       in_place_op->parent()->DeepCopyInstruction(operand));
353   TF_RETURN_IF_ERROR(
354       operand->ReplaceUseWith(in_place_op, operand_number, deep_copy));
355   return OkStatus();
356 }
357 
358 // Conservatively adds copies before root instruction of entry computation and
359 // each aliased parameter to resolve interference of aliased input and output
360 // buffer. We later rely on RemoveUnnecessaryCopies to drop the unnecessary
361 // ones.
362 Status AddCopiesForAliasedInputOutputs(HloModule* module) {
363   HloComputation* entry = module->entry_computation();
364   HloInstruction* root = entry->root_instruction();
365 
366   ShapeTree<bool> output_indices_to_copy(root->shape());
367   std::vector<std::optional<ShapeTree<HloInstruction*>>> copied_parameters(
368       entry->num_parameters());
369   bool has_alias = false;
370   for (auto* param : entry->parameter_instructions()) {
371     bool param_has_alias = false;
372     ShapeTree<bool> param_indices_to_copy(param->shape());
373 
374     module->input_output_alias_config().ForEachAlias(
375         [&](const ShapeIndex& output_index,
376             const HloInputOutputAliasConfig::Alias& alias) {
377           if (alias.parameter_number == param->parameter_number()) {
378             param_has_alias = true;
379             *(param_indices_to_copy.mutable_element(alias.parameter_index)) =
380                 true;
381             *(output_indices_to_copy.mutable_element(output_index)) = true;
382           }
383         });
384 
385     if (!param_has_alias) {
386       continue;
387     }
388 
389     TF_RET_CHECK(param->parameter_number() < entry->num_parameters());
390     TF_RET_CHECK(!copied_parameters[param->parameter_number()]);
391 
392     has_alias = true;
393     // Store a snapshot of users before DeepCopyInstruction, as
394     // DeepCopyInstruction introduces new users of the instruction.
395     std::vector<HloInstruction*> users = param->users();
396     ShapeTree<HloInstruction*> param_copy_tree(param->shape(),
397                                                /*init_value=*/nullptr);
398     TF_ASSIGN_OR_RETURN(HloInstruction * copied,
399                         entry->DeepCopyInstruction(
400                             param, &param_indices_to_copy, &param_copy_tree));
401     if (param == root) {
402       entry->set_root_instruction(copied);
403       root = copied;
404     }
405     for (HloInstruction* user : users) {
406       TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, copied));
407     }
408 
409     copied_parameters[param->parameter_number()] = param_copy_tree;
410   }
411 
412   if (!has_alias) {
413     return OkStatus();
414   }
415 
416   // Add copies before root instruction.
417   ShapeTree<HloInstruction*> output_copy_tree(root->shape(),
418                                               /*init_value=*/nullptr);
419 
420   TF_ASSIGN_OR_RETURN(HloInstruction * root_copied,
421                       root->parent()->DeepCopyInstruction(
422                           root, &output_indices_to_copy, &output_copy_tree));
423 
424   // Add control dependencies between the input/output copies.
425   TF_RETURN_IF_ERROR(module->input_output_alias_config().ForEachAliasWithStatus(
426       [&](const ShapeIndex& output_index,
427           const HloInputOutputAliasConfig::Alias& alias) -> Status {
428         if (!copied_parameters[alias.parameter_number]) {
429           return OkStatus();
430         }
431         HloInstruction* from =
432             copied_parameters[alias.parameter_number]->element(
433                 alias.parameter_index);
434         HloInstruction* to = output_copy_tree.element(output_index);
435 
436         TF_RET_CHECK(from != nullptr);
437         TF_RET_CHECK(to != nullptr);
438         TF_RETURN_IF_ERROR(from->AddControlDependencyTo(to));
439         return OkStatus();
440       }));
441 
442   entry->set_root_instruction(root_copied);
443 
444   return OkStatus();
445 }
446 
447 // Removes any control dependencies to or from the given instruction.
448 Status StripControlDependenciesFrom(HloInstruction* instruction) {
449   while (!instruction->control_successors().empty()) {
450     TF_RETURN_IF_ERROR(instruction->RemoveControlDependencyTo(
451         instruction->control_successors().front()));
452   }
453 
454   while (!instruction->control_predecessors().empty()) {
455     TF_RETURN_IF_ERROR(
456         instruction->control_predecessors().front()->RemoveControlDependencyTo(
457             instruction));
458   }
459 
460   return OkStatus();
461 }
462 
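// Records the region of instructions that make up a live range: for each
// computation involved, the instructions that define or use the live range's
// values, together with the defining instruction of the value each one touches.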
463 class LiveRangeRegions {
464  public:
465   struct InstructionInfo {
466     InstructionInfo() : value_definition(nullptr), is_definition(false) {}
467 
468     // The instruction that defines the value being used. It basically saves
469     // the defining instruction of each HloValue.
470     HloInstruction* value_definition;
471     // Whether the instruction defines a new value (or merely uses one). This
472     // basically remembers whether the instruction actually creates an HloValue
473     // or merely uses one, from a collection of given HloValues. Note that if
474     // is_definition = true, it merely says the instruction creates a new
475     // HloValue with or without defining a new one. For example, kAdd creates a
476     // new HloValue (can be value_definition), while tuples or get-tuple-element
477     // create a new aliasing HloValue without defining a new value (cannot be
478     // value_definition).
479     bool is_definition;
480   };
481   // Map instructions that use a value to the defining instruction of the value.
482   // Because all values must belong to the same live range, an instruction can
483   // have at most a single value-defining instruction; otherwise the multiple
484   // incoming active values would share a single buffer, which is not allowed.
485   // The value-defining and value-use instructions do not have to belong to the
486   // same computation, but the value use needs to be nested within the defining
487   // computation.
488   typedef absl::flat_hash_map<HloInstruction*, InstructionInfo> InstructionMap;
489   typedef std::pair<HloInstruction*, InstructionInfo> InstructionEntry;
490   // Map each computation to its immediately contained instructions.
491   typedef absl::flat_hash_map<const HloComputation*, InstructionMap>
492       ComputationMap;
493 
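  // Returns the instruction map of the given computation, creating an empty
  // entry (and remembering the computation's insertion order) if absent.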
494   InstructionMap& operator[](const HloComputation* computation) {
495     if (computation_map_.find(computation) == computation_map_.end()) {
496       computation_vector_.push_back(computation);
497     }
498     return computation_map_[computation];
499   }
500 
501   const InstructionMap& operator[](const HloComputation* computation) const {
502     ComputationMap::const_iterator p = computation_map_.find(computation);
503     CHECK(p != computation_map_.end());
504     return p->second;
505   }
506   ComputationMap::const_iterator begin() const {
507     return computation_map_.begin();
508   }
509   ComputationMap::const_iterator end() const { return computation_map_.end(); }
510   int64_t size() const {
511     CHECK_EQ(computation_vector_.size(), computation_map_.size());
512     return computation_vector_.size();
513   }
514   bool empty() const { return size() == 0; }
515   const HloComputation* Computation(int64_t index) const {
516     return computation_vector_[index];
517   }
518   bool contains(const HloInstruction* instr) const {
519     CHECK_NE(instr, nullptr);
520     auto* computation = instr->parent();
521     auto p = computation_map_.find(computation);
522     if (p == computation_map_.end()) {
523       return false;
524     }
525     auto instr_map = (*p).second;
526     return instr_map.find(instr) != instr_map.end();
527   }
528 
529  private:
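  // Instruction regions, keyed by computation.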
530   ComputationMap computation_map_;
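  // Computations in insertion order, so they can be looked up by index.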
531   absl::InlinedVector<const HloComputation*, 5> computation_vector_;
532 };
533 
534 namespace {
535 // Represents relations between the locations of two regions of instructions;
536 // each region can include 0-n instructions.
537 class Relation {
538  public:
539   enum RuntimeOrder {
540     // Indicate that there is no overlap whatsoever between the two regions.
541     kNoOverlap = 0,
542     // Indicate that the first region includes the same set of instructions as
543     // the second region.
544     kSameInstr = 1,
545     // Indicate that the first region is entirely before the second region
546     // starts.
547     kBeforeStart = 2,
548     // Indicate that the first region is before the second region ends.
549     kBeforeStartOrSameInstr = kBeforeStart | kSameInstr,
550     // Indicate that the first region is entirely after the second region ends.
551     kAfterEnd = 4,
552     // Indicate that the first region is after the second region
553     // starts, with some instructions before the second region ends.
554     kAfterEndOrSameInstr = kAfterEnd | kSameInstr,
555     // Indicate that the first region overlaps with the second one, but shares no
556     // common instructions.
557     kBeforeStartOrAfterEnd = kBeforeStart | kAfterEnd,
558     // Indicate that the first region overlaps with the second one, and has
559     // some common instructions.
560     kBeforeOrAfterOrOverlap = kBeforeStart | kAfterEnd | kSameInstr,
561   };
562   Relation() : intercept_def_use_(false) {}
563   explicit Relation(RuntimeOrder order, bool intercept_def_use = false)
564       : intercept_def_use_(intercept_def_use) {
565     orders_.push_back(order);
566   }
567   Relation(const Relation& that)
568       : intercept_def_use_(that.intercept_def_use_), orders_(that.orders_) {}
569   bool operator==(const Relation& that) const {
570     return intercept_def_use_ == that.intercept_def_use_ &&
571            absl::c_equal(orders_, that.orders_);
572   }
573 
574   // Return whether the runtime ordering may imply interception, assuming it
575   // models the relation between a modifying and a use instruction.
576   bool UseImpliesInterception() const {
577     CHECK_EQ(orders_.size(), 1);
578     return UseImpliesInterception(orders_[0]);
579   }
580   // Return whether the runtime ordering may imply interception, assuming it
581   // models the relation between a modifying and a definition instruction.
582   bool DefinitionImpliesInterception() const {
583     CHECK_EQ(orders_.size(), 1);
584     return DefinitionImpliesInterception(orders_[0]);
585   }
586   // Return whether the current relation models a modifying instruction that
587   // intercepts the dataflow of another live range region.
588   bool InterceptDefUse() const { return intercept_def_use_; }
589   // Update interception state to the given value.
590   void UpdateInterception(bool value) {
591     CHECK_EQ(orders_.size(), 1);
592     intercept_def_use_ = value;
593   }
594   Relation::RuntimeOrder GetRuntimeOrder() const {
595     if (orders_.empty()) {
596       return Relation::kNoOverlap;
597     }
598     CHECK_EQ(orders_.size(), 1);
599     return orders_[0];
600   }
601   // Return whether the current relation implies two overlapping regions.
602   bool RuntimeOrderOverlap() const {
603     return absl::c_any_of(orders_, ImpliesOverlap);
604   }
605   bool RuntimeOrderIsUnordered() const {
606     return orders_.size() == 1 && orders_[0] == kBeforeStartOrAfterEnd;
607   }
608   bool RuntimeOrderIsNoOverlap() const {
609     return orders_.empty() || (orders_.size() == 1 && orders_[0] == kNoOverlap);
610   }
611   bool RuntimeOrderIsRunBefore() const {
612     return orders_.size() == 1 && orders_[0] == kBeforeStart;
613   }
614   bool RuntimeOrderIsRunAfter() const {
615     return orders_.size() == 1 && orders_[0] == kAfterEnd;
616   }
617   std::string ToString() const {
618     return absl::StrCat("Interception = ", intercept_def_use_, ";",
619                         absl::StrJoin(orders_, ","));
620   }
621 
622   static bool DefinitionImpliesInterception(RuntimeOrder definition) {
623     return (definition == kAfterEnd || definition == kBeforeStartOrAfterEnd);
624   }
625   static bool UseImpliesInterception(RuntimeOrder use) {
626     return (use == kBeforeStart || use == kBeforeStartOrAfterEnd);
627   }
628 
629   // Summarize additional relations into a single runtime ordering, assuming
630   // both relations are modeling constraints of the same source instruction.
631   void UnionRelationFromSameSource(const Relation& rel) {
632     CHECK_LE(orders_.size(), 1);
633     CHECK_EQ(rel.orders_.size(), 1);
634     if (orders_.empty()) {
635       orders_.push_back(rel.orders_[0]);
636     } else {
637       orders_[0] = Union(orders_[0], rel.orders_[0]);
638     }
639     intercept_def_use_ = intercept_def_use_ || rel.intercept_def_use_;
640   }
641 
642   // Summarize additional relations into disjoint runtime orderings, assuming
643   // the relations are modeling constraints of different source instructions.
644   void UnionRelationFromDifferentSource(const Relation& rel) {
645     if (rel.orders_.empty()) {
646       return;
647     }
648     CHECK_EQ(rel.orders_.size(), 1);
649     intercept_def_use_ = intercept_def_use_ || rel.intercept_def_use_;
650     for (auto& local_order : orders_) {
651       if (OverwriteIfSubsume(rel.orders_[0], &local_order)) {
652         return;
653       }
654     }
655     orders_.push_back(rel.orders_[0]);
656   }
657 
658   static Relation::RuntimeOrder ReverseRuntimeOrder(RuntimeOrder order) {
659     switch (order) {
660       case kNoOverlap:
661       case kSameInstr:
662       case kBeforeStartOrAfterEnd:
663       case kBeforeOrAfterOrOverlap:
664         return order;
665       case kBeforeStart:
666         return kAfterEnd;
667       case kBeforeStartOrSameInstr:
668         return kAfterEndOrSameInstr;
669       case kAfterEnd:
670         return kBeforeStart;
671       case kAfterEndOrSameInstr:
672         return kBeforeStartOrSameInstr;
673     }
674   }
675 
676  private:
677   // Indicate that the second region may intercept the def-use dataflow of the
678   // first region, if their buffers are combined.
679   bool intercept_def_use_;
680   // Remember the different runtime orderings of different instructions.
681   absl::InlinedVector<RuntimeOrder, 4> orders_;
682 
683   static RuntimeOrder Union(RuntimeOrder o1, RuntimeOrder o2) {
684     return static_cast<Relation::RuntimeOrder>(o1 | o2);
685   }
686   static bool ImpliesOverlap(RuntimeOrder o) {
687     return o >= RuntimeOrder::kBeforeStartOrAfterEnd;
688   }
689   // Returns whether ordering constraint o1 includes o2 as a subset, when they
690   // represent runtime orderings (interleavings) of two different regions.
691   static bool Subsume(RuntimeOrder o1, RuntimeOrder o2) {
692     return Union(o1, o2) == o1;
693   }
694   // Overwrites o1 with o2 if o2 subsumes o1 (as defined above by the Subsume
695   // function). Return whether o2 is subsumed by the new value in o1.
696   static bool OverwriteIfSubsume(RuntimeOrder o2, RuntimeOrder* o1) {
697     if (*o1 == o2) {
698       return true;
699     }
700     CHECK_NE(o1, nullptr);
701     // Overwrite o1 with o2 if it is subsumed by o2.
702     if (Subsume(o2, *o1)) {
703       *o1 = o2;
704       return true;
705     } else if (Subsume(*o1, o2)) {
706       // If o2 is already subsumed by o1, do nothing.
707       return true;
708     }
709     // If neither o1 nor o2 is subsumed by the other, return false, so that o2
710     // will be inserted as a separate entry representing all possible orderings.
711     return false;
712   }
713 };
714 
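// Computes the relative runtime locations (Relation) of instructions and live
// range regions with respect to one another, using the given HloOrdering, and
// can force an order between otherwise unordered instructions by recording
// control dependencies.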
715 class ComputeRelativeLocation {
716  public:
717   typedef LiveRangeRegions::InstructionEntry InstructionEntry;
718   explicit ComputeRelativeLocation(HloOrdering* ordering)
719       : ordering_(ordering) {
720     VLOG(3) << "New analysis\n";
721   }
722 
723   // Compute locationing constraints between two instructions. Here entry2 is
724   // the source instruction, in that the returned value describes the relation
725   // of entry2 in terms of whether it is before or after entry1, and whether it
726   // can intercept the def-use data flow of entry1.
727   Relation Compute(const InstructionEntry& entry1,
728                    const InstructionEntry& entry2, bool instr2_can_modify) {
729     auto def = entry1.second.value_definition;
730     auto use = entry1.first;
731     Relation::RuntimeOrder order =
732         ComputeRuntimeOrdering(entry2.first, entry1.first);
733     if (order == Relation::kSameInstr &&
734         entry1.second.is_definition != entry2.second.is_definition) {
735       if (entry1.second.is_definition) {
736         order = Relation::kBeforeStart;
737       } else {
738         order = Relation::kAfterEnd;
739       }
740     }
741     bool intercept = AlwaysForceInterception(entry2.first);
742     if (def == nullptr || !instr2_can_modify) {
743       return Relation(order, intercept);
744     }
745     // If the definition and use are parameter and return (root) of the parent
746     // computation, then any modification is considered intercepting.
747     if (def->opcode() == HloOpcode::kParameter &&
748         use == use->parent()->root_instruction()) {
749       VLOG(3) << "Setting interception due to parameter/root relation\n";
750       return Relation(order, true);
751     }
752     if (Relation::UseImpliesInterception(order)) {
753       auto order2 = ComputeRuntimeOrdering(entry2.first, def);
754       if (Relation::DefinitionImpliesInterception(order2)) {
755         VLOG(3) << "Setting interception for " << def->ToString()
756                 << " with use:" << entry1.first->ToString() << "\n";
757         intercept = true;
758       }
759     }
760     return Relation(order, intercept);
761   }
762 
763   // Return the relative locations (defined above) of range2 in relation to
764   // instructions in range1. Return kNoOverlap if range2 is outside of range1.
765   Relation Compute(const LiveRangeRegions& range1,
766                    const LiveRangeRegions& range2) {
767     Relation dir_src_dest;
768     for (int64_t index = 0; index < range1.size(); index++) {
769       auto* computation1 = range1.Computation(index);
770       for (const auto& computation_entry2 : range2) {
771         auto* computation2 = computation_entry2.first;
772         for (auto instr_entry2 : computation_entry2.second) {
773           if (!ordering_->call_graph().Dominates(computation1, computation2)) {
774             continue;
775           }
776           VLOG(3) << "Locationing " << instr_entry2.first->ToString();
777           // Saves relations between instr2 and other instructions in range1.
778           bool instr2_can_modify =
779               InstructionCanIntercept(instr_entry2, range1);
780           Relation instr2_relation;
781           std::vector<InstructionEntry> unordered_ops;
782           bool unordered_intercept = false;
783           for (auto instr_entry1 : range1[computation1]) {
784             auto rel = Compute(instr_entry1, instr_entry2, instr2_can_modify);
785             VLOG(3) << "new relation with:" << instr_entry1.first->ToString()
786                     << " = " << rel.ToString() << "\n";
787             if (!rel.RuntimeOrderIsUnordered()) {
788               instr2_relation.UnionRelationFromSameSource(rel);
789             } else {
790               unordered_ops.push_back(instr_entry1);
791               unordered_intercept |= rel.InterceptDefUse();
792             }
793             VLOG(3) << "instr2 relation:" << instr2_relation.ToString() << "\n";
794           }
795           // Here instr2_relation is guaranteed to have at most a single entry,
796           // because it was initialized to be empty, and has been updated only
797           // via instr2_relation.UnionRelationFromSameSource(rel), which
798           // maintains that the updated result has only a single entry.
799           if (!ForceRuntimeOrder(unordered_ops, instr_entry2,
800                                  instr2_relation.GetRuntimeOrder())) {
801             VLOG(3) << "Unable to force ordering of unordered ops\n";
802             instr2_relation.UnionRelationFromSameSource(Relation(
803                 Relation::kBeforeStartOrAfterEnd, unordered_intercept));
804           }
805           dir_src_dest.UnionRelationFromDifferentSource(instr2_relation);
806           VLOG(3) << "Resulting relation : " << dir_src_dest.ToString() << "\n";
807         }
808       }
809     }
810     return dir_src_dest;
811   }
812 
813   // Return whether control dependencies, if any exist, are added successfully.
814   bool AddControlDependenceForUnorderedOps() {
815     if (ctrl_deps_.empty()) {
816       return true;
817     }
818     PredecessorHloOrdering* ordering =
819         dynamic_cast<PredecessorHloOrdering*>(ordering_);
820     if (ordering == nullptr) {
821       // Forcing the ordering of unordered ops is supported only when using
822       // predecessor ordering.
823       return false;
824     }
825     for (const auto& comp_it : ctrl_deps_) {
826       HloComputation* parent = comp_it.first;
827       HloReachabilityMap& reachability_map = ordering->reachability_map(parent);
828       for (const auto& instr_it : comp_it.second) {
829         HloInstruction* entry1 = instr_it.first;
830         for (HloInstruction* entry2 : instr_it.second) {
831           VLOG(3) << "Add control dependence between " << entry2->ToString();
832           VLOG(3) << "\n vs " << entry1->ToString() << "\n";
833           TF_CHECK_OK(entry2->AddControlDependencyTo(entry1));
834         }
835         reachability_map.UpdateReachabilityThroughInstruction(entry1);
836         for (HloInstruction* entry2 : instr_it.second) {
837           DCHECK(ordering_->GetExecutionConstraint(entry1, entry2) ==
838                  HloOrdering::ExecutionConstraint::kRunAfter);
839         }
840       }
841     }
842     return true;
843   }
844 
845  private:
846   enum ComputeStatus {
847     kFullyComputed,
848     kPartiallyComputed,
849     kNotComputed,
850   };
851   typedef std::pair<ComputeStatus, Relation::RuntimeOrder> SavedRelation;
852 
853   // Returns whether it is safe to force the desired_relation ordering between
854   // all operations in unordered_ops and entry2. If safe, save the new enforced
855   // ordering relations.
856   bool ForceRuntimeOrder(absl::Span<const InstructionEntry> unordered_ops,
857                          const InstructionEntry entry2,
858                          Relation::RuntimeOrder desired_relation) {
859     if (unordered_ops.empty()) {
860       return true;
861     }
862     if (desired_relation != Relation::kBeforeStart &&
863         desired_relation != Relation::kAfterEnd) {
864       return false;
865     }
866     auto ModifiesNonCopy = [](HloInstruction* instr, const HloInstruction* op) {
867       auto in_place = HloDataflowAnalysis::GetInPlaceInputOutputPairs(instr);
868       if (in_place.empty()) {
869         return false;
870       }
871       return absl::c_any_of(
872           in_place, [&](const std::pair<HloOperandIndex, ShapeIndex>&
873                             operand_and_output_index) {
874             auto* op2 =
875                 instr->operand(operand_and_output_index.first.operand_number);
876             return (op == nullptr) ? (op2->opcode() == HloOpcode::kCopy)
877                                    : (op2 == op);
878           });
879     };
880     for (const InstructionEntry& entry1 : unordered_ops) {
881       // Only consider instructions in the same computation.
882       if (entry1.first->parent() != entry2.first->parent()) {
883         return false;
884       }
885       HloInstruction* pred = (desired_relation == Relation::kBeforeStart)
886                                  ? entry2.first
887                                  : entry1.first;
888       HloInstruction* succ = (desired_relation == Relation::kBeforeStart)
889                                  ? entry1.first
890                                  : entry2.first;
891       if (pred == pred->parent()->root_instruction()) {
892         return false;
893       }
894       if (succ->opcode() == HloOpcode::kCopy &&
895           ModifiesNonCopy(pred, succ->operand(0))) {
896         VLOG(3) << "Failed to force unordered op ordering due to copy ordering "
897                 << " between " << pred->ToString() << "\n";
898         VLOG(3) << " vs. " << succ->ToString() << "\n";
899         return false;
900       }
901     }
902     for (const InstructionEntry& entry1 : unordered_ops) {
903       Save(entry2.first, entry1.first, desired_relation, true);
904     }
905     return true;
906   }
907 
908   static bool AlwaysForceInterception(HloInstruction* instr) {
909     // The following communication operations can have some unexpected side
910     // effects, when synchronizing across processes. Therefore, we
911     // conservatively try to provide dedicated buffers to these operations instead
912     // of allowing them to share buffers with other operations, as the reuse may
913     // cause unexpected interferences.
914     if (HloDataflowAnalysis::IsAsynchronousOperationStart(instr->opcode()) ||
915         HloDataflowAnalysis::IsAsynchronousOperationDone(instr->opcode())) {
916       return true;
917     }
918     switch (instr->opcode()) {
919       // TODO(b/190903339): It appears that collectivePermute needs to be
920       // followed by a copy when escaping through a computation root.
921       case HloOpcode::kCollectivePermute:
922         return true;
923       default:
924         return false;
925     }
926   }
927 
928   // Returns whether the given instr may intercept the def-use flow of another
929   // ongoing live range if its buffer is combined with the other live range.
930   // The function should return true if instr creates a new HloValue that could
931   // overwrite an existing HloValue in the combined buffer.
932   // More specifically, here we are looking for operations that create new
933   // values, e.g., add, subtract, in contrast to HLOs that merely create
934   // aliasings among existing values, e.g., tuple, get-tuple-element. Any of the
935   // new values created by operations such as add or subtract, when included as
936   // definition operations in a live range, are aliases of the buffer to be
937     // allocated to the live range and so are treated as if they may be modifying
938     // the target buffer.
939   bool InstructionCanIntercept(const InstructionEntry& entry,
940                                const LiveRangeRegions& region) {
941     auto instr = entry.first;
942     if (!entry.second.is_definition) {
943       // If the instruction only uses the value, it can intercept only if it
944       // modifies the buffer in place.
945       return !HloDataflowAnalysis::GetInPlaceInputOutputPairs(instr).empty();
946     }
947     switch (instr->opcode()) {
948       // If the copy instruction is used to connect two live range regions,
949       // it does not overwrite the combined buffer with new values.
950       case HloOpcode::kCopy:
951         // Check that the copy simply copies from the other live range with no
952         // layout conflicts.
953         if (region.contains(instr->operand(0)) &&
954             ShapeUtil::Equal(instr->shape(), instr->operand(0)->shape())) {
955           return false;  // Cannot intercept.
956         }
957         return true;
958       // The following operations merely create aliases among the HloValues.
959       case HloOpcode::kParameter:
960       case HloOpcode::kTuple:
961       case HloOpcode::kGetTupleElement:
962       // Here we consider all the compound operations (e.g., conditionals and
963       // while loops) as if they do not modify any HloValue, with the argument
964       // being that any value modifying operation contained inside will be
965       // considered separately to make sure the kIntercept relation is
966       // recorded as appropriate. Since the compound operations may or may not
967       // modify, not treating them as value modifying would make the algorithm
968       // less conservative.
969       case HloOpcode::kWhile:
970       case HloOpcode::kCall:
971       case HloOpcode::kConditional:
972         return false;
973       default:
974         return true;
975     }
976     return true;
977   }
978 
979   SavedRelation AlreadyComputed(HloInstruction* op1, HloInstruction* op2) {
980     auto p2 = saved_relations_.find(op2);
981     if (p2 != saved_relations_.end()) {
982       auto p1 = (*p2).second.find(op1);
983       if (p1 != (*p2).second.end()) {
984         return SavedRelation(kFullyComputed, (*p1).second);
985       }
986     }
987     p2 = saved_relations_.find(op1);
988     if (p2 != saved_relations_.end()) {
989       auto p1 = (*p2).second.find(op2);
990       if (p1 != (*p2).second.end()) {
991         return SavedRelation(kPartiallyComputed,
992                              Relation::ReverseRuntimeOrder((*p1).second));
993       }
994     }
995     return SavedRelation(kNotComputed, Relation::kNoOverlap);
996   }
997 
998   Relation::RuntimeOrder Save(HloInstruction* entry1, HloInstruction* entry2,
999                               const Relation::RuntimeOrder relation,
1000                               bool is_unordered_originally = false) {
1001     CHECK_EQ(AlreadyComputed(entry1, entry2).first, kNotComputed);
1002     // Do not save unordered relations.
1003     CHECK_NE(relation, Relation::kBeforeStartOrAfterEnd);
1004     saved_relations_[entry2][entry1] = relation;
1005     if (is_unordered_originally) {
1006       CHECK(relation == Relation::kBeforeStart ||
1007             relation == Relation::kAfterEnd)
1008           << relation;
1009       HloInstruction* pred =
1010           (relation == Relation::kBeforeStart) ? entry1 : entry2;
1011       HloInstruction* succ =
1012           (relation == Relation::kBeforeStart) ? entry2 : entry1;
1013       VLOG(3) << "Save unordered relation: " << pred->ToString() << "\n";
1014       VLOG(3) << " vs " << succ->ToString() << "\n";
1015       CHECK_EQ(succ->parent(), pred->parent());
1016       auto& dep_vec = ctrl_deps_[succ->parent()][succ];
1017       for (HloInstruction*& op : dep_vec) {
1018         auto rel = AlreadyComputed(pred, op);
1019         if (rel.first != kNotComputed) {
1020           if (rel.second == Relation::kAfterEnd) {
1021             op = pred;
1022           } else {
1023             CHECK(rel.second == Relation::kBeforeStart);
1024           }
1025           return relation;
1026         }
1027       }
1028       VLOG(2) << "Forcing unordered:" << pred->ToString() << "\n";
1029       VLOG(2) << " vs " << succ->ToString() << "\n";
1030       dep_vec.push_back(pred);
1031     }
1032     return relation;
1033   }
1034 
1035   // Compute the runtime ordering constraints between two instructions.
1036   Relation::RuntimeOrder ComputeRuntimeOrdering(HloInstruction* instr1,
1037                                                 HloInstruction* instr2) {
1038     auto saved_relation = AlreadyComputed(instr1, instr2);
1039     if (saved_relation.first != kNotComputed) {
1040       VLOG(3) << "Already computed between " << instr1->ToString() << "\n vs "
1041               << instr2->ToString() << "\n";
1042       return saved_relation.second;
1043     }
1044     auto constraint = ordering_->GetExecutionConstraint(instr1, instr2);
1045     switch (constraint) {
1046       case HloOrdering::ExecutionConstraint::kIsSame:
1047         return Save(instr1, instr2, Relation::kSameInstr);
1048       case HloOrdering::ExecutionConstraint::kRunBeforeEnd:
1049         return Save(instr1, instr2, Relation::kBeforeStartOrSameInstr);
1050       case HloOrdering::ExecutionConstraint::kRunBeforeStart:
1051         return Save(instr1, instr2, Relation::kBeforeStart);
1052       case HloOrdering::ExecutionConstraint::kRunAfter:
1053         return Save(instr1, instr2, Relation::kAfterEnd);
1054       case HloOrdering::ExecutionConstraint::kRunExclusiveBefore:
1055       case HloOrdering::ExecutionConstraint::kRunExclusiveAfter:
1056         return Save(instr1, instr2, Relation::kNoOverlap);
1057       case HloOrdering::ExecutionConstraint::kUnordered: {
1058         if (instr1->parent() != instr2->parent()) {
1059           return Relation::kBeforeStartOrAfterEnd;
1060         }
1061         auto ControlDependenceBefore = [&](HloInstruction* op1,
1062                                            HloInstruction* op2) {
1063           auto constraint = ComputeRuntimeOrdering(op1, op2);
1064           if (constraint == Relation::kBeforeStart ||
1065               constraint == Relation::kSameInstr ||
1066               constraint == Relation::kBeforeStartOrSameInstr) {
1067             return true;
1068           } else {
1069             return false;
1070           }
1071         };
1072         if (!ctrl_deps_.empty()) {
1073           auto ctrl_deps = ctrl_deps_[instr1->parent()];
1074           if (absl::c_any_of(ctrl_deps[instr2], [&](HloInstruction* pred2) {
1075                 return ControlDependenceBefore(instr1, pred2);
1076               })) {
1077             VLOG(2) << "control-dependent: " << instr1->ToString() << "\n";
1078             VLOG(2) << "vs " << instr2->ToString() << "\n";
1079             return Save(instr1, instr2, Relation::kBeforeStart);
1080           } else if (absl::c_any_of(
1081                          ctrl_deps[instr1], [&](HloInstruction* pred1) {
1082                            return ControlDependenceBefore(instr2, pred1);
1083                          })) {
1084             VLOG(2) << "control-dependent: " << instr2->ToString() << "\n";
1085             VLOG(2) << "vs " << instr1->ToString() << "\n";
1086             return Save(instr1, instr2, Relation::kAfterEnd);
1087           }
1088         }
1089         // Don't save the result for unordered operations, so they can be
1090         // refined later.
1091         return Relation::kBeforeStartOrAfterEnd;
1092       }
1093     }
1094   }
1095 
1096   HloOrdering* ordering_;
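  // Caches computed runtime orderings between pairs of instructions.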
1097   absl::flat_hash_map<
1098       HloInstruction*,
1099       absl::flat_hash_map<HloInstruction*, Relation::RuntimeOrder>>
1100       saved_relations_;
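  // Control dependencies queued by Save() for originally unordered pairs:
  // for each computation, maps an instruction to the predecessors that must be
  // forced to run before it.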
1101   absl::flat_hash_map<
1102       HloComputation*,
1103       absl::flat_hash_map<HloInstruction*, std::vector<HloInstruction*>>>
1104       ctrl_deps_;
1105 };
1106 }  // namespace
1107 
1108 // Class which tracks the HLO values within each HLO buffer in the module
1109 // during copy removal.
1110 //
1111 // The values are held in a linked list where there is one list for each
1112 // buffer. Removing a copy instruction merges together the values in the
1113 // source buffer of the copy to the destination buffer of the copy. This class
1114 // tracks these value lists as copies are removed from the graph (and value
1115 // lists are merged).
1116 //
1117 // The CopyRemover object is initialized to match the state of
1118 // HloAliasAnalysis. However, as copies are removed this state diverges. The
1119 // values-to-buffer mapping is maintained outside of HloAliasAnalysis because
1120 // a fully updatable alias analysis is very slow.
1121 class CopyRemover {
1122  public:
1123   // The values held in a single HLO buffer are represented using a linked
1124   // list. An element type in this list is ValueNode.
1125   //
1126   // This linked list is hand-rolled to enable efficient splicing of lists
1127   // using only references to list elements without knowing which lists are
1128   // being spliced. std::list requires a reference to the list object to
1129   // splice.
1130   struct ValueNode {
1131     explicit ValueNode(const HloValue* v) : value(v) {}
1132 
1133     const HloValue* value;
1134 
1135     // The uses are maintained outside of HloValue::uses() because
1136     // HloValue::uses() is not updatable (a fully updatable dataflow analysis
1137     // is slow).
1138     std::vector<const HloUse*> uses;
1139 
1140     // next/prev elements in the linked list. The list is circularly linked so
1141     // these values are never null for elements in the list.
1142     ValueNode* prev = nullptr;
1143     ValueNode* next = nullptr;
1144   };
1145 
1146   CopyRemover(const HloModule& module, const HloAliasAnalysis& alias_analysis,
1147               HloOrdering* ordering, bool check_live_range_ordering)
1148       : dataflow_(alias_analysis.dataflow_analysis()), ordering_(ordering) {
1149     // Construct a list for each HLO buffer in the alias analysis. Maintain a
1150     // map from HloValue to the respective list element representing that
1151     // value. The map is used to construct the copy info map below.
1152     absl::flat_hash_map<const HloValue*, ValueNode*> value_to_node;
1153     // Perform check only if the default dependence-based ordering is used.
1154     for (const HloBuffer& buffer : alias_analysis.buffers()) {
1155       // No copies should have been inserted within fused computations, so no
1156       // need to remove them. HloOrdering isn't compatible with HloValues inside
1157       // fusions, so skip copy removal for them.
1158       if (buffer.values().at(0)->defining_instruction()->IsFused()) {
1159         continue;
1160       }
1161       if (check_live_range_ordering) {
1162         // Verify values contained in the buffer are strictly ordered. This
1163         // should always be the case after adding copies to eliminate
1164         // interference. Specifically, the addition of the control flow edges
1165         // between copies added around aliased operations (kWhile) guarantees
1166         // this strict order.
1167         for (const HloValue* value_a : buffer.values()) {
1168           if (value_a->shape().IsToken()) {
1169             // Token values have no representation and cannot interfere.
1170             continue;
1171           }
1172           for (const HloValue* value_b : buffer.values()) {
1173             if (value_a != value_b) {
1174               DCHECK(ordering_->LiveRangeStrictlyBefore(
1175                          *value_a, *value_b, dataflow_,
1176                          /*use_is_always_before_def_in_same_instr=*/true) ||
1177                      ordering_->LiveRangeStrictlyBefore(
1178                          *value_b, *value_a, dataflow_,
1179                          /*use_is_always_before_def_in_same_instr=*/true))
1180                   << value_a->ToString() << " and " << value_b->ToString()
1181                   << " are not ordered";
1182             }
1183           }
1184         }
1185       }
1186 
1187       std::vector<const HloValue*> values = buffer.values();
1188       absl::c_sort(values, [this](const HloValue* a, const HloValue* b) {
1189         return ordering_->IsDefinedBefore(*a, *b);
1190       });
1191 
1192       // Create a list containing all of the values in the buffer.
1193       AddValueList(values, &value_to_node);
1194     }
1195 
1196     // Create copy_map_ which contains the source and destination values
1197     // of all copies.
1198     CreateCopyMap(module, value_to_node);
1199 
1200     XLA_VLOG_LINES(3, ToString());
1201     TF_DCHECK_OK(Verify());
1202   }
1203 
1204   // Add a list containing the given values to CopyRemover. This
1205   // represents the values contained in a single buffer. For each value in
1206   // 'values' an entry is created in value_to_node which indicates the
1207   // respective ValueNode representing that value.
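  // For example, given values {v0, v1, v2} already sorted in definition
  // order, the list built here is v0 <-> v1 <-> v2 with v2->next == v0 and
  // v0->prev == v2, and v0 is recorded as the list head in value_lists_.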
1208   void AddValueList(
1209       absl::Span<const HloValue* const> values,
1210       absl::flat_hash_map<const HloValue*, ValueNode*>* value_to_node) {
1211     ValueNode* tail = nullptr;
1212     ValueNode* head = nullptr;
1213     for (const HloValue* value : values) {
1214       auto new_node = new ValueNode(value);
1215       (*value_to_node)[value] = new_node;
1216 
1217       // Copy the HLO value's uses into the ValueNode for the value. These
1218       // uses in ValueNode are updated as copies are removed.
1219       new_node->uses.reserve(value->GetUses().size());
1220       for (const HloUse& use : value->GetUses()) {
1221         new_node->uses.push_back(&use);
1222       }
1223 
1224       // Connect the new node into the linked list.
1225       if (tail == nullptr) {
1226         head = new_node;
1227       } else {
1228         tail->next = new_node;
1229         new_node->prev = tail;
1230       }
1231       tail = new_node;
1232     }
1233 
1234     // The linked list is circular so connect the head and tail.
1235     tail->next = head;
1236     head->prev = tail;
1237     value_lists_.insert(head);
1238   }
1239 
1240   // This method fills in copy_map_, which indicates which nodes in the
1241   // value lists correspond to the source and destination values of each
1242   // kCopy instruction. value_to_node should map each HloValue to its
1243   // respective ValueNode.
1244   void CreateCopyMap(
1245       const HloModule& module,
1246       const absl::flat_hash_map<const HloValue*, ValueNode*>& value_to_node) {
1247     for (HloComputation* computation : module.MakeNonfusionComputations()) {
1248       for (HloInstruction* instruction : computation->instructions()) {
1249         // Add copies with unambiguous source values to the map. Copies with
1250         // ambiguous sources are not removable.
1251         if (instruction->opcode() == HloOpcode::kCopy) {
1252           const HloValueSet& src_value_set =
1253               dataflow_.GetValueSet(instruction->operand(0));
1254           if (src_value_set.values().size() == 1) {
1255             CopyNodes& copy_node = copy_map_[instruction];
1256             copy_node.dest =
1257                 value_to_node.at(&dataflow_.GetUniqueValueAt(instruction));
1258             copy_node.src = value_to_node.at(&src_value_set.GetUniqueValue());
1259           }
1260         }
1261       }
1262     }
1263   }
1264 
1265   ~CopyRemover() {
1266     for (const ValueNode* head : value_lists_) {
1267       const ValueNode* p = head;
1268       do {
1269         const ValueNode* tmp = p->next;
1270         delete p;
1271         p = tmp;
1272       } while (p != head);
1273     }
1274   }
1275 
1276   // Verify invariants within the linked lists.
1277   Status Verify() const {
1278     for (const ValueNode* head : value_lists_) {
1279       const ValueNode* p = head;
1280       do {
1281         // Verify links between elements are consistent.
1282         TF_RET_CHECK(p->prev->next == p);
1283         TF_RET_CHECK(p->next->prev == p);
1284 
1285         const HloInstruction* def = p->value->defining_instruction();
1286         if (def->opcode() == HloOpcode::kCopy && ContainsKey(copy_map_, def)) {
1287           TF_RET_CHECK(copy_map_.at(def).dest == p);
1288         }
1289         for (const HloUse* use : p->uses) {
1290           if (use->instruction->opcode() == HloOpcode::kCopy &&
1291               ContainsKey(copy_map_, use->instruction)) {
1292             TF_RET_CHECK(copy_map_.at(use->instruction).src == p);
1293           }
1294         }
1295 
1296         p = p->next;
1297       } while (p != head);
1298     }
1299     return OkStatus();
1300   }
1301 
1302   // Compute the set of instructions where values are alive and organize these
1303   // instructions by separating them into their respective computations.
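  // In the returned LiveRangeRegions, each instruction defining one of the
  // values is marked as a definition, and each using instruction records the
  // instruction defining the value it uses, grouped by parent computation.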
1304   LiveRangeRegions ComputeLiveRangeRegions(const ValueNode* head) {
1305     LiveRangeRegions live_range;
1306 
1307     auto VisitValueNode = [&](const ValueNode* node) {
1308       HloInstruction* def_op = node->value->instruction();
1309       HloComputation* def_parent = def_op->parent();
1310       live_range[def_parent][def_op].is_definition = true;
1311       for (const auto& use : node->uses) {
1312         auto* use_op = use->instruction;
1313         HloComputation* use_parent = use_op->parent();
1314         live_range[use_parent][use_op].value_definition = def_op;
1315       }
1316     };
1317     ForEachValueInRange(head, VisitValueNode);
1318     return live_range;
1319   }
1320 
1321   // Try to elide the given copy. Elision of a copy is possible only if no
1322   // live range interference is introduced by the copy's elimination. If
1323   // elision is possible, then the internal state (value lists) is updated,
1324   // and true is returned. Returns false otherwise.
1325   bool TryElideCopy(const HloInstruction* copy,
1326                     int64_t* region_analysis_limit) {
1327     VLOG(2) << "Trying to remove " << copy->name();
1328     CHECK_NE(region_analysis_limit, nullptr);
1329 
1330     if (!ContainsKey(copy_map_, copy)) {
1331       VLOG(2) << copy->name() << " is not removable";
1332       return false;
1333     }
1334     if (!ShapeUtil::Equal(copy->shape(), copy->operand(0)->shape())) {
1335       VLOG(2) << copy->name() << " is not removable (shape mismatch)";
1336       return false;
1337     }
1338     const CopyNodes& copy_node = copy_map_.at(copy);
1339     DCHECK(copy_node.src != nullptr);
1340     DCHECK(copy_node.dest != nullptr);
1341 
1342     int64_t live_range_size1 = 0, live_range_size2 = 0;
1343     ForEachValueInRange(copy_node.src, [&](const ValueNode* node) {
1344       live_range_size1 += 1 + node->uses.size();
1345     });
1346     ForEachValueInRange(copy_node.dest, [&](const ValueNode* node) {
1347       live_range_size2 += 1 + node->uses.size();
1348     });
1349     // Use the more accurate region-based live range interference analysis if
1350     // the live range size is within a given limit (or if no limit is given).
1351     // Also don't use region-based analysis for copies of broadcasts as these
1352     // copies are cheap and are later removed by replicating the broadcasts.
1353     bool use_region_analysis =
1354         copy->operand(0)->opcode() != HloOpcode::kBroadcast &&
1355         (*region_analysis_limit < 0 ||
1356          live_range_size1 * live_range_size2 <= *region_analysis_limit);
1357     *region_analysis_limit = 0;
1358     VLOG(3) << copy->name() << " copies value "
1359             << copy_node.src->value->ToShortString();
1360     VLOG(3) << "Source buffer values: " << ValueListToString(copy_node.src);
1361     VLOG(3) << "Dest buffer values: " << ValueListToString(copy_node.dest);
1362     // Checks that the live ranges of 'src' and its predecessors are all before those of 'dest' and its successors.
1363     auto CheckLiveRangeBefore = [&](ValueNode* src, ValueNode* dest) {
1364       for (ValueNode* next_dest = dest; next_dest != nullptr;
1365            next_dest = Next(*next_dest)) {
1366         for (ValueNode* prev_src = src; prev_src != nullptr;
1367              prev_src = Prev(*prev_src)) {
1368           if (!LiveRangeBefore(*prev_src, *next_dest)) {
1369             VLOG(2) << "Live range of " << prev_src->value->ToShortString()
1370                     << " is not before " << next_dest->value->ToShortString();
1371             return false;
1372           }
1373         }
1374       }
1375       return true;
1376     };
1377     auto CheckLiveRangeInterference = [&](ValueNode* src, ValueNode* dest,
1378                                           const CombineLiveRangeOption option) {
1379       CHECK_NE(src, nullptr);
1380       CHECK_NE(dest, nullptr);
1381       if (!use_region_analysis) {
1382         VLOG(2) << "Configured to not use region-based analysis.\n";
1383         return true;
1384       }
1385       *region_analysis_limit += live_range_size1 * live_range_size2;
1386       if (ValuesInterfere(src, dest, option)) {
1387         VLOG(2) << "Region-based interference is true. \n";
1388         return true;
1389       }
1390       VLOG(2) << "Region-based interference is false. \n";
1391       return false;
1392     };
1393 
1394     // A kCopy instruction copies an HLO value from a source buffer and
1395     // defines an HLO value in a destination buffer. Most generally, the
1396     // source and destination buffers may each hold more than one value at
1397     // different points in the computation so we define the following:
1398     //
1399     //   Values in source buffer:      {s_0, ..., s_n}
1400     //   Values in destination buffer: {d_0, ..., d_m}
1401     //
1402     // A kCopy instruction between these buffers copies a value s_x in the
1403     // source buffer and defines a value d_y in the destination buffer. The
1404     // elision of a copy merges the source and destination buffers together,
1405     // so the list of values for the source and destination buffers are
1406     // merged.
1407     //
1408     // We handle two different cases for copy elision:
1409     //
1410     //  (1) the kCopy defines the first value in the destination buffer (d_0).
1411     //
1412     //  (2) the kCopy copies the last value in the source buffer (s_n).
1413     //
1414     // For the remaining case where the kCopy copies a not-last value from the
1415     // source buffer to a not-first value of the destination buffer, the kCopy
1416     // instruction cannot be removed. This case is generated, for example, if
1417     // the kCopy copies a while body parameter of the loop state at one tuple
1418     // index to a different tuple index in the while body root. Removal of the
1419     // copy necessarily results in live range interference of values in the
1420     // loop state at the two different tuple indices.
1421     //
1422     //  We can only perform copy elision if the resulting merged values have
1423     //  totally ordered live ranges; otherwise the merged buffer would have
1424     //  live range interference.
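    // The branches below implement these cases: when src->next == dest the
    // copy's source and destination values already share a buffer and the
    // copy is trivially removable; case (1) corresponds to IsHead(*dest) and
    // case (2) to IsTail(*src); any other configuration is rejected.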
1425     if (copy_node.src->next == copy_node.dest) {
1426       // In the process of eliding copies, it's possible for a copy to have the
1427       // same source and destination buffer. In this case, the copy can be
1428       // safely removed.
1429       VLOG(2) << copy->name() << " source and destination buffers are the same.";
1430     } else if (IsHead(*copy_node.dest)) {
1431       // The copy copies an arbitrary value in the source buffer (call it s_x)
1432       // and defines d_0, the first value in the destination buffer. After
1433       // merging, the values in the combined buffer must be strictly ordered
1434       // as follows** to elide the copy:
1435       //
1436       // {s_0, ..., s_x, d_1, ..., d_m, s_{x+1}, ..., s_n}
1437       //
1438       // Removing the copy eliminates d_0, and uses of d_0 become uses of
1439       // s_x. In the above ordering, the live range of d_m will be ordered
1440       // before the live range of s_{x+1} and the definition and all uses of
1441       // s_x will be ordered before the definition of d_1. To make sure the
1442       // copy elision is safe, the following code checks that this ordering is
1443       // valid --- in particular we check it is safe to order d_m ahead of all
1444       // the live ranges at and after s_{x+1}, and it is safe to order all uses
1445       // of s_x before the definition of d_1, by checking the live range
1446       // constraints for each pair --- we cannot skip the later checks because
1447       // the live range ordering is not guaranteed to be transitive --- while it
1448       // may be ok to have lr_1 before lr_2, and lr_2 before lr_3 while merging
1449       // their buffers, it may not be ok to merge the buffers of lr_1 and lr_3,
1450       // because the exclusiveness relation of non-overlapping computations is
1451       // not transitive.
1452       //
1453       // ** Technically it might be possible to have a non-interfering
1454       //    non-trivial interleaving of the values of the source and
1455       //    destination buffers in the resulting order. This can be potentially
1456       //    supported in the ValuesInterfere function, which performs
1457       //    interference analysis at a more global scope than the alternative
1458       //    LiveRangeBefore analysis which requires strict ordering of all live
1459       //    ranges. Currently, however, this is not yet supported, as
1460       //    we simply check for the case where *all* values of the destination
1461       //    buffer (d_1 through d_m) are spliced into the point where the copy
1462       //    used to be.
1463       VLOG(2) << copy->name() << " defines the first value in its buffer";
1464       bool live_range_before =
1465           // Live range of (s_x, s_{x-1},...) must be before 'next_dest' (d_1);
1466           CheckLiveRangeBefore(copy_node.src, Next(*copy_node.dest)) &&
1467           // Live range of 'last_dest' (d_m) must be before 'next_src' s_{x+1}.
1468           CheckLiveRangeBefore(copy_node.dest->prev, Next(*copy_node.src));
1469       VLOG(2) << "LiveRangeBefore result: " << live_range_before << "\n";
1470       if (!live_range_before &&
1471           CheckLiveRangeInterference(copy_node.src, copy_node.dest,
1472                                      kMergeFirstDestInSource)) {
1473         return false;
1474       }
1475       VLOG(2) << "Splice dest after source.";
1476       // Splice in destination buffer values list right after 'src'.
1477       SpliceAfter(copy_node.dest, copy_node.src);
1478     } else if (IsTail(*copy_node.src)) {
1479       // The copy copies the last value in the source buffer, s_n, and defines
1480       // an arbitrary value in the destination buffer, d_y.  After
1481       // merging, the values in the combined buffer must be strictly ordered
1482       // as follows** to elide the copy:
1483       //
1484       // {d_0, ..., d_{y-1}, s_0, ..., s_n, d_{y+1}, ..., d_m}
1485       //
1486       // Removing the copy eliminates d_y, and uses of d_y become uses of
1487       // s_n. To enforce the above order, the live range of d_{y-1} must be
1488       // before the live range of s_0, and the live range of s_n must be
1489       // before the live range of d_{y+1}.
1490       //
1491       // ** See comment above in the code handling Case (1).
1492       VLOG(2) << copy->name() << " copies the last value ("
1493               << copy_node.src->value->ToShortString() << ") in its buffer";
1494       bool live_range_before =
1495           // Live range of d_0, ..., d_{y-1} must be before s_0;
1496           CheckLiveRangeBefore(Prev(*copy_node.dest), copy_node.src->next) &&
1497           // Live range of 'last_src' must be before next_dest d_{y+1}.
1498           CheckLiveRangeBefore(copy_node.src, Next(*copy_node.dest));
1499       VLOG(2) << "LiveRangeBefore result: " << live_range_before << "\n";
1500       if (!live_range_before &&
1501           CheckLiveRangeInterference(copy_node.src, copy_node.dest,
1502                                      kMergeLastSourceInDest)) {
1503         VLOG(2) << "Region-based analysis concludes interference.\n";
1504         return false;
1505       }
1506       VLOG(2) << "Splice src after prev of dest.";
1507       // Splice source buffer values list right after 'prev_dest'.
1508       SpliceAfter(copy_node.src->next, Prev(*copy_node.dest));
1509     } else {
1510       VLOG(2) << copy->name()
1511               << " copies value in middle of source buffer to value in middle "
1512                  "of destination buffer";
1513       return false;
1514     }
1515 
1516     RemoveCopyValue(copy_node.dest);
1517 
1518     XLA_VLOG_LINES(4, ToString());
1519     TF_DCHECK_OK(Verify());
1520 
1521     return true;
1522   }
1523 
1524   // Delete the given ValueNode associated with an elided kCopy
1525   // instruction. This should be called after splicing the value lists of the
1526   // source and destination buffers together.
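  // For example, if the list reads ... <-> operand <-> copy <-> next <-> ...,
  // the copy node is unlinked so that operand links directly to next, the use
  // of the copy is removed from operand's uses, and the copy's remaining uses
  // are transferred to operand.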
1527   void RemoveCopyValue(ValueNode* copy_value_node) {
1528     CHECK_EQ(copy_value_node->value->defining_instruction()->opcode(),
1529              HloOpcode::kCopy);
1530     ValueNode* operand_node = copy_value_node->prev;
1531     CHECK(operand_node != copy_value_node);
1532 
1533     VLOG(2) << "Removing copy " << operand_node->value->ToShortString()
1534             << " => " << copy_value_node->value->ToShortString();
1535 
1536     // Splice out the copy value node.
1537     operand_node->next = copy_value_node->next;
1538     copy_value_node->next->prev = operand_node;
1539 
1540     // Patch up uses. Remove use of copy from operand_node uses.
1541     auto it = absl::c_find_if(operand_node->uses, [copy_value_node](
1542                                                       const HloUse* use) {
1543       return use->instruction == copy_value_node->value->defining_instruction();
1544     });
1545     CHECK(it != operand_node->uses.end());
1546     operand_node->uses.erase(it);
1547 
1548     // If the elided copy has any uses which are themselves kCopy instructions
1549     // then patch up the copy info to reflect that this kCopy instruction
1550     // has a different operand (the operand of the elided copy).
1551     for (const HloUse* copy_use : copy_value_node->uses) {
1552       operand_node->uses.push_back(copy_use);
1553       if (copy_use->instruction->opcode() == HloOpcode::kCopy &&
1554           ContainsKey(copy_map_, copy_use->instruction)) {
1555         copy_map_.at(copy_use->instruction).src = operand_node;
1556       }
1557     }
1558 
1559     // Delete the copy info and the value node.
1560     copy_map_.erase(copy_value_node->value->defining_instruction());
1561     delete copy_value_node;
1562   }
1563 
1564   // Returns true if the live range of given value 'a' is before the live
1565   // range of 'b'.
1566   //
1567   // We cannot use LiveRangeStrictlyBefore because HloValue::uses() is not
1568   // updated as copies are removed. Also here because the result is used
1569   // to directly drive copy elision, use_is_always_before_def_in_same_instr is
1570   // set to false.
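  // Concretely: if 'a' has no remaining uses, fall back to IsDefinedBefore;
  // if 'a' appears at the root of the computation defining 'b', its live
  // range extends to the end of that computation and cannot be before 'b';
  // otherwise every remaining use of 'a' must occur before 'b' is defined.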
1571   bool LiveRangeBefore(const ValueNode& a, const ValueNode& b) {
1572     if (a.uses.empty()) {
1573       VLOG(2) << "Empty uses for " << *a.value;
1574       return ordering_->IsDefinedBefore(*a.value, *b.value);
1575     }
1576     VLOG(3) << "Checking live ranges before :" << ValueListToString(&a)
1577             << " vs " << ValueListToString(&b) << "\n";
1578     // If any of the positions of the "a" value is a root of the same
1579     // computation as "b", "a"'s live range cannot be before "b"'s. This catches
1580     // the cases where the root may not be the last instruction in the
1581     // computation.
1582     if (a.value->IsRootOf(b.value->defining_instruction()->parent())) {
1583       VLOG(3) << "Value is root of the same computation";
1584       return false;
1585     }
1586     return ordering_->UsesBeforeValueDefinition(
1587         a.uses, *b.value, dataflow_,
1588         /* use_is_always_before_def_in_same_instr=*/false);
1589   }
1590 
1591   // Returns whether 'node' is the last node in its list.
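  // The list is circular and value_lists_ stores only list heads, so a node
  // is the tail exactly when the node following it is a stored head.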
1592   bool IsTail(const ValueNode& node) const {
1593     return ContainsKey(value_lists_, node.next);
1594   }
1595 
1596   // Returns whether 'node' is the first node in its list.
1597   bool IsHead(const ValueNode& node) const {
1598     return ContainsKey(value_lists_, &node);
1599   }
1600 
1601   // Returns the next node in the list after 'node'. If 'node' is the
1602   // tail, then nullptr is returned.
1603   ValueNode* Next(const ValueNode& node) const {
1604     if (IsTail(node)) {
1605       return nullptr;
1606     } else {
1607       return node.next;
1608     }
1609   }
1610 
1611   // Returns the previous node in the list before 'node'. If 'node'
1612   // is the head, then nullptr is returned.
1613   ValueNode* Prev(const ValueNode& node) const {
1614     if (IsHead(node)) {
1615       return nullptr;
1616     } else {
1617       return node.prev;
1618     }
1619   }
1620 
1621   // Splices the entire linked list with 'head' as its head right after the
1622   // node 'insert_after' in another linked list.
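  // For example, splicing {d_0, ..., d_m} (head == d_0) after node s_x in
  // {s_0, ..., s_x, s_{x+1}, ...} produces
  // {s_0, ..., s_x, d_0, ..., d_m, s_{x+1}, ...}.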
1623   void SpliceAfter(ValueNode* head, ValueNode* insert_after) {
1624     DCHECK(IsHead(*head));
1625     value_lists_.erase(head);
1626 
1627     ValueNode* tail = head->prev;
1628     tail->next = insert_after->next;
1629     insert_after->next->prev = tail;
1630 
1631     insert_after->next = head;
1632     head->prev = insert_after;
1633   }
1634 
1635   enum CombineLiveRangeOption {
1636     kMergeFirstDestInSource = 1,
1637     kMergeLastSourceInDest = 2
1638   };
1639   // This function analyzes all the HloValues that have been grouped together
1640   // with src to share a single buffer, and all the HloValues that have been
1641   // similarly grouped together with dest, to determine whether these two groups
1642   // can be combined, by removing the operation in dest, which makes a copy of
1643   // the buffer in src.
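  // Interference is reported when the two live ranges overlap in both
  // directions, when a one-directional overlap intercepts a def-use pair that
  // the requested merge_location cannot tolerate, or when the control
  // dependencies needed to order otherwise-unordered operations cannot be
  // added.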
1644   bool ValuesInterfere(const ValueNode* src, const ValueNode* dest,
1645                        CombineLiveRangeOption merge_location) {
1646     // Get the entire range of values sharing the buffers in src and dest.
1647     auto src_live_range = ComputeLiveRangeRegions(src);
1648     auto dest_live_range = ComputeLiveRangeRegions(dest);
1649     ComputeRelativeLocation relative_location_analysis(ordering_);
1650     auto rel1 =
1651         relative_location_analysis.Compute(src_live_range, dest_live_range);
1652     VLOG(3) << "Location of dest in relation to src:" << rel1.ToString()
1653             << " with interception set to " << rel1.InterceptDefUse() << "\n";
1654     auto rel2 =
1655         relative_location_analysis.Compute(dest_live_range, src_live_range);
1656     VLOG(3) << "Location of src in relation to dest:" << rel2.ToString()
1657             << " with interception set to " << rel2.InterceptDefUse() << "\n";
1658     // If src and dest are interleaved with each other, they interfere.
1659     if (rel1.RuntimeOrderOverlap() && rel2.RuntimeOrderOverlap()) {
1660       VLOG(3) << "Both relations are overlap.\n";
1661       return true;
1662     }
1663     // If src and dest belong to the same group of computations and do not
1664     // overlap, they do not interfere.
1665     if (rel1.RuntimeOrderOverlap() || rel2.RuntimeOrderOverlap()) {
1666       VLOG(3) << "At least one relation is overlap.\n";
1667       if (rel1.RuntimeOrderOverlap()) {
1668         VLOG(3) << "rel1 is overlap, with interception = "
1669                 << rel1.InterceptDefUse() << "\n";
1670         if (rel1.InterceptDefUse() ||
1671             (merge_location != kMergeFirstDestInSource &&
1672              rel2.InterceptDefUse())) {
1673           return true;
1674         }
1675       } else {
1676         VLOG(3) << "rel2 is overlap, with interception = "
1677                 << rel2.InterceptDefUse() << "\n";
1678         // Here src is at the end of a nested computation inside dest.
1679         if (rel2.InterceptDefUse() ||
1680             (merge_location != kMergeLastSourceInDest &&
1681              rel1.InterceptDefUse())) {
1682           return true;
1683         }
1684       }
1685     }
1686     if (relative_location_analysis.AddControlDependenceForUnorderedOps()) {
1687       return false;
1688     } else {
1689       // Disallow removing of copy if control deps cannot be added.
1690       return true;
1691     }
1692   }
1693 
1694   // Calls 'visitor' on each value node in the list containing 'element':
1695   // first from 'element' to the tail, then wrapping around from the head back
1696   // to 'element'. The ordering is important for live range region analysis.
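  // For example, for the list {v0, v1, v2} with element == v1, the visit
  // order is v1, v2, v0.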
1697   void ForEachValueInRange(const ValueNode* element,
1698                            std::function<void(const ValueNode*)> visitor) {
1699     const ValueNode* head = element;
1700     std::vector<const ValueNode*> values;
1701     for (const ValueNode* p = head; p != nullptr; p = Next(*p)) {
1702       visitor(p);
1703     }
1704     while (!IsHead(*head)) {
1705       head = Prev(*head);
1706     }
1707     for (const ValueNode* p = head; p != element; p = Next(*p)) {
1708       visitor(p);
1709     }
1710   }
1711 
1712   std::string ValueListToString(const ValueNode* element) {
1713     std::string result = "{";
1714     auto VisitValueNode = [&](const ValueNode* node) {
1715       if (result == "{") {
1716         StrAppend(&result, node->value->ToShortString());
1717       } else {
1718         StrAppend(&result, ", ");
1719         StrAppend(&result, node->value->ToShortString());
1720       }
1721     };
1722     ForEachValueInRange(element, VisitValueNode);
1723     StrAppend(&result, "}");
1724     return result;
1725   }
1726 
1727   std::string ToString() const {
1728     std::string out = absl::StrCat("CopyRemover:\n");
1729     StrAppend(&out, "  Def-use chains in each buffer:\n");
1730     for (const ValueNode* head : value_lists_) {
1731       StrAppend(&out, "    Buffer defined by ", head->value->ToShortString(),
1732                 ":\n");
1733       const ValueNode* p = head;
1734       do {
1735         StrAppend(&out, "      ", p->value->ToShortString(), ", uses: ",
1736                   absl::StrJoin(p->uses, "; ",
1737                                 [](std::string* s, const HloUse* use) {
1738                                   StrAppend(s, use->ToString());
1739                                 }),
1740                   "\n");
1741 
1742         p = p->next;
1743       } while (p != head);
1744     }
1745     StrAppend(&out, "  Potentially removable copies:\n");
1746     for (const auto& pair : copy_map_) {
1747       const HloInstruction* copy = pair.first;
1748       const CopyNodes& copy_info = pair.second;
1749 
1750       StrAppend(&out, "    ", copy->name(), " : ",
1751                 copy_info.src->value->ToShortString(), " => ",
1752                 copy_info.dest->value->ToShortString(), "\n");
1753     }
1754     return out;
1755   }
1756 
1757  private:
1758   const HloDataflowAnalysis& dataflow_;
1759   HloOrdering* ordering_;
1760 
1761   // The heads of all the value lists. Each value list represents the HLO
1762   // values contained in a particular HLO buffer. The values in the list are
1763   // in dependency order.
1764   absl::flat_hash_set<const ValueNode*> value_lists_;
1765 
1766   // Copy removal requires fast access to the value list elements
1767   // corresponding to the source and destination values of the kCopy
1768   // instruction. This data structure holds pointers to these elements for
1769   // each kCopy instruction in the graph.
1770   struct CopyNodes {
1771     // The source and destinations values of the kCopy instruction.
1772     ValueNode* src = nullptr;
1773     ValueNode* dest = nullptr;
1774   };
1775   absl::flat_hash_map<const HloInstruction*, CopyNodes> copy_map_;
1776 };
1777 
1778 }  // namespace
1779 
1780 // We add copies for all non-phi indices of the true and false computation
1781 // roots, in order to resolve interference. We later rely on
1782 // RemoveUnnecessaryCopies to drop the unnecessary ones.
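// For each branch computation, the root is deep-copied at the indices
// computed by IndicesToCopyForConditional, the copy becomes the new branch
// root, and existing users of the old root are redirected to the copy.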
1783 Status CopyInsertion::AddCopiesForConditional(
1784     const HloAliasAnalysis& alias_analysis, HloInstruction* conditional) {
1785   VLOG(2) << "Adding copies for kConditional instruction "
1786           << conditional->name();
1787   ShapeTree<bool> indices_to_copy(conditional->shape());
1788   TF_RET_CHECK(conditional->opcode() == HloOpcode::kConditional);
1789   if (!IndicesToCopyForConditional(alias_analysis.dataflow_analysis(),
1790                                    conditional, &indices_to_copy)) {
1791     VLOG(2) << "No copies necessary for kConditional instruction "
1792             << conditional->name();
1793     return OkStatus();
1794   }
1795 
1796   for (HloComputation* computation : conditional->branch_computations()) {
1797     HloInstruction* root = computation->root_instruction();
1798     std::vector<HloInstruction*> users = root->users();
1799     TF_ASSIGN_OR_RETURN(
1800         HloInstruction * deep_copy,
1801         computation->DeepCopyInstruction(root, &indices_to_copy));
1802     for (HloInstruction* user : users) {
1803       TF_RETURN_IF_ERROR(root->ReplaceUseWith(user, deep_copy));
1804     }
1805     computation->set_root_instruction(deep_copy);
1806   }
1807   return OkStatus();
1808 }
1809 
1810 // Add kCopy instructions to the given module to guarantee there is no
1811 // live-range interference. Generally interference can only occur around kWhile
1812 // instructions which have update-in-place semantics.
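// Besides kWhile, copies are also added around kConditional branches, around
// in-place operations reported by HloDataflowAnalysis::GetInPlaceInputOutputPairs,
// and for aliased entry inputs/outputs (AddCopiesForAliasedInputOutputs).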
1813 Status CopyInsertion::AddCopiesToResolveInterference(
1814     HloModule* module,
1815     const absl::flat_hash_set<absl::string_view>& execution_threads) {
1816   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
1817                       HloAliasAnalysis::Run(module, can_share_buffer_));
1818   for (HloComputation* computation :
1819        module->MakeNonfusionComputations(execution_threads)) {
1820     for (HloInstruction* instruction :
1821          computation->MakeInstructionPostOrder()) {
1822       if (instruction->opcode() == HloOpcode::kWhile) {
1823         TF_RETURN_IF_ERROR(AddCopiesForWhile(*alias_analysis, instruction));
1824       } else if (instruction->opcode() == HloOpcode::kConditional) {
1825         TF_RETURN_IF_ERROR(
1826             AddCopiesForConditional(*alias_analysis, instruction));
1827       } else {
1828         // When an operand is a tuple, we avoid copying the operand multiple
1829         // times by recording and checking the operand number of operands that
1830         // have been copied.
1831         absl::flat_hash_set<int64_t> copied_operands;
1832         for (const auto& operand_and_output_index :
1833              HloDataflowAnalysis::GetInPlaceInputOutputPairs(instruction)) {
1834           const HloOperandIndex& operand_index = operand_and_output_index.first;
1835           if (copied_operands.contains(operand_index.operand_number)) {
1836             continue;
1837           }
1838           copied_operands.insert(operand_index.operand_number);
1839           TF_RETURN_IF_ERROR(AddCopiesForInPlaceOperation(
1840               *alias_analysis, instruction, operand_index.operand_number));
1841         }
1842       }
1843     }
1844   }
1845 
1846   TF_RETURN_IF_ERROR(AddCopiesForAliasedInputOutputs(module));
1847   return OkStatus();
1848 }
1849 
1850 Status CopyInsertion::AddSpecialCaseCopies(
1851     HloModule* module,
1852     const absl::flat_hash_set<absl::string_view>& execution_threads) {
1853   std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
1854   return AddSpecialCaseCopies(*call_graph, execution_threads, module);
1855 }
1856 
1857 Status CopyInsertion::AddSpecialCaseCopies(
1858     const CallGraph& call_graph,
1859     const absl::flat_hash_set<absl::string_view>& execution_threads,
1860     HloModule* module) {
1861   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
1862                       HloAliasAnalysis::Run(module, can_share_buffer_));
1863 
1864   // Identify which shape indices of which instructions need to be copied. Store
1865   // these results in 'instructions_to_copy'.
1866   HloInstructionMap<ShapeTree<bool>> instructions_to_copy;
1867   auto add_index_to_copy = [&instructions_to_copy](HloInstruction* instruction,
1868                                                    const ShapeIndex& index) {
1869     auto it = instructions_to_copy.find(instruction);
1870     if (it == instructions_to_copy.end()) {
1871       auto it_added = instructions_to_copy.emplace(
1872           std::piecewise_construct, std::forward_as_tuple(instruction),
1873           std::forward_as_tuple(instruction->shape(), /*init_value=*/false));
1874       it = it_added.first;
1875     }
1876     *it->second.mutable_element(index) = true;
1877   };
1878 
1879   // Iterate through values of all constants and entry parameters. These values
1880   // are special because they are held in read-only buffers. If any of these
1881   // values share a buffer with other values (for example, the init value of a
1882   // while is a constant) then copy the value at its definition and replace all
1883   // its uses with the copy.
1884   // Also, locate all input-output aliasing violations for operations that
1885   // cannot be done in place. Such aliasing can be created when some copies are
1886   // removed too aggressively by CopyRemoval.
1887   for (const HloValue* value : alias_analysis->dataflow_analysis().values()) {
1888     HloBuffer& buffer = alias_analysis->GetBufferContainingValue(*value);
1889     if (buffer.values().size() > 1 && ValueIsReadOnly(*value)) {
1890       VLOG(2) << "Value " << value->ToShortString()
1891               << " is read only, but its buffer contains more than one value. "
1892                  "Copying.";
1893       add_index_to_copy(value->defining_instruction(), value->defining_index());
1894     }
1895     for (const HloValue* value2 : buffer.values()) {
1896       // Find HloValues that share a position and use, which would cause the use
1897       // and operand to share buffers. Check if this is allowed and insert a
1898       // copy if it isn't.
1899       if (value2 == value) {
1900         continue;
1901       }
1902       HloPosition position = value2->defining_position();
1903       for (const HloUse& use : value->GetUses()) {
1904         if (use.instruction == position.instruction) {
1905           VLOG(3) << "Same instruction: " << position.instruction->ToString();
1906           if (!alias_analysis->dataflow_analysis()
1907                    .CanShareOperandBufferWithUser(
1908                        /*operand=*/use.instruction->mutable_operand(
1909                            use.operand_number),
1910                        /*operand_index=*/use.operand_index,
1911                        /*user=*/position.instruction,
1912                        /*user_index=*/position.index)) {
1913             VLOG(2) << "Adding back copy: "
1914                     << use.instruction->operand(use.operand_number)->ToString()
1915                     << "@" << use.operand_index.ToString()
1916                     << " instr: " << position.instruction->ToString() << "@"
1917                     << position.index;
1918             add_index_to_copy(
1919                 use.instruction->mutable_operand(use.operand_number),
1920                 use.operand_index);
1921           }
1922         }
1923       }
1924     }
1925   }
1926 
1927   // Identify copies which must be added at root instructions
1928   for (HloComputation* computation : module->computations(execution_threads)) {
1929     const CallGraphNode& node = call_graph.GetNode(computation);
1930     if (node.context() == CallContext::kEmbedded) {
1931       continue;
1932     }
1933     TF_RET_CHECK(node.context() == CallContext::kControlFlow);
1934 
1935     SpecialCaseCopyPolicy policy =
1936         GetSpecialCaseCopyPolicy(node, module, computation);
1937     HloInstruction* root = computation->root_instruction();
1938 
1939     // Mark nondistinct/ambiguous indices.
1940     absl::flat_hash_map<const HloBuffer*, ShapeIndex> seen;
1941     ShapeUtil::ForEachSubshape(
1942         root->shape(), [&](const Shape& /*subshape*/, const ShapeIndex& index) {
1943           std::vector<const HloBuffer*> buffers_at_index =
1944               alias_analysis->ComputeBuffersAt(root, index);
1945           bool buffer_seen_before = false;
1946           for (const HloBuffer* buffer : buffers_at_index) {
1947             buffer_seen_before |= !seen.emplace(buffer, index).second;
1948           }
1949 
1950           if (buffer_seen_before && policy.copy_root_replicated_buffers &&
1951               computation == module->entry_computation() &&
1952               module->input_output_alias_config().OutputHasAlias(index) &&
1953               buffers_at_index.size() == 1) {
1954             std::optional<HloInputOutputAliasConfig::Alias> alias =
1955                 module->input_output_alias_config().GetAliasedParameter(index);
1956             CHECK(alias) << "Alias does not exist";
1957             const ShapeIndex& other_index = seen[buffers_at_index[0]];
1958             VLOG(2) << "Output indices " << index.ToString() << " and "
1959                     << other_index.ToString() << " are both aliased to "
1960                     << alias->parameter_number << " copying " << other_index;
1961             add_index_to_copy(root, other_index);
1962             return;
1963           }
1964 
1965           if (buffers_at_index.size() > 1 ||
1966               (buffer_seen_before && policy.copy_root_replicated_buffers)) {
1967             VLOG(2) << "Index " << index << " of computation "
1968                     << computation->name() << " (" << root->name()
1969                     << ") has ambiguous or non-distinct buffer. Copying.";
1970             add_index_to_copy(root, index);
1971           }
1972         });
1973 
1974     for (const auto& pair :
1975          alias_analysis->dataflow_analysis().GetInstructionValueSet(root)) {
1976       const ShapeIndex& index = pair.first;
1977       const HloValueSet& value_set = pair.second;
1978       for (const HloValue* value : value_set.values()) {
1979         if (ShouldCopyRootValue(*value, policy)) {
1980           VLOG(2) << "Root of (" << root->name() << ") of computation("
1981                   << computation->name()
1982                   << ") has constant or parameter value at index " << index
1983                   << ". Copying.";
1984           add_index_to_copy(root, index);
1985         }
1986       }
1987     }
1988   }
1989 
1990   // Add copy instructions indicated in 'instructions_to_copy' to the module.
1991   for (const auto& pair : instructions_to_copy) {
1992     HloInstruction* instruction = pair.first;
1993     const ShapeTree<bool>& indices_to_copy = pair.second;
1994 
1995     ShapeTree<HloInstruction*> copies_added(indices_to_copy.shape());
1996     std::vector<HloInstruction*> users = instruction->users();
1997     TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy,
1998                         instruction->parent()->DeepCopyInstruction(
1999                             instruction, &indices_to_copy, &copies_added));
2000     for (HloInstruction* user : users) {
2001       TF_RETURN_IF_ERROR(instruction->ReplaceUseWith(user, deep_copy));
2002     }
2003     if (instruction == instruction->parent()->root_instruction()) {
2004       instruction->parent()->set_root_instruction(deep_copy);
2005     }
2006   }
2007   return OkStatus();
2008 }
2009 
2010 static int64_t GetNumExistingCopies(
2011     const HloModule* module,
2012     const absl::flat_hash_set<absl::string_view>& execution_threads) {
2013   int64_t num_existing_copies = 0;
2014   for (HloComputation* computation : module->computations(execution_threads)) {
2015     for (HloInstruction* instruction : computation->instructions()) {
2016       if (instruction->opcode() == HloOpcode::kCopy) {
2017         ++num_existing_copies;
2018       }
2019     }
2020   }
2021   return num_existing_copies;
2022 }
2023 
2024 Status CopyInsertion::RemoveUnnecessaryCopies(
2025     HloOrdering* ordering, HloModule* module, bool check_live_range_ordering,
2026     const absl::flat_hash_set<absl::string_view>& execution_threads) {
2027   XLA_VLOG_LINES(4, module->ToString());
2028   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
2029                       HloAliasAnalysis::Run(module, can_share_buffer_));
2030   CopyRemover copy_remover(*module, *alias_analysis, ordering,
2031                            check_live_range_ordering);
2032   if (VLOG_IS_ON(3)) {
2033     LOG(INFO) << "Removing unnecessary copies in " << module->name();
2034     LOG(INFO) << "Buffer values, in dependency order: ";
2035     for (const HloBuffer& buffer : alias_analysis->buffers()) {
2036       LOG(INFO) << "    HloBuffer " << buffer.id();
2037     }
2038   }
2039 
2040   std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
2041 
2042   int64_t num_existing_copies = GetNumExistingCopies(module, execution_threads);
2043   bool changed = true;
2044   int64_t num_iterations = -1;
2045   VLOG(6) << "Copy Insertion analyzing module with instruction count = "
2046           << module->instruction_count() << "\n";
2047   BoundNonLinearCompilerAnalysis allowance(module, name(), 10);
2048   while (changed) {
2049     CHECK_LE(++num_iterations, num_existing_copies);
2050     changed = false;
2051     VLOG(2) << "Running fixpoint iteration " << num_iterations
2052             << " of copy elision";
2053     for (HloComputation* computation :
2054          module->computations(execution_threads)) {
2055       VLOG(2) << "computation:" << computation->name() << "\n";
2056       for (HloInstruction* instruction : computation->instructions()) {
2057         VLOG(2) << instruction->ToString() << "\n";
2058         // If use_region_based_live_range_analysis_ is 0, region-based analysis is
2059         // disabled; if it is negative, the analysis is always performed; otherwise
2060         // its cost is capped by the remaining analysis allowance.
2061         int64_t region_analysis_cost_now =
2062             (use_region_based_live_range_analysis_ == 0)
2063                 ? 0
2064                 : std::min(allowance.analysis_allowance(),
2065                            use_region_based_live_range_analysis_);
2066         if (instruction->opcode() == HloOpcode::kCopy) {
2067           if (copy_remover.TryElideCopy(instruction,
2068                                         &region_analysis_cost_now)) {
2069             changed = true;
2070             TF_RETURN_IF_ERROR(StripControlDependenciesFrom(instruction));
2071             TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(
2072                 instruction->mutable_operand(0)));
2073             VLOG(6) << "succeeded in eliminating copy.\n";
2074           }
2075           if (allowance.ContinueAnalysis() && region_analysis_cost_now > 0) {
2076             VLOG(6) << "Copy Insertion analyzing module cost: "
2077                     << region_analysis_cost_now << "\n";
2078             VLOG(6) << "instruction:" << instruction->ToString() << "\n";
2079             allowance.DeductCost(region_analysis_cost_now);
2080             VLOG(6) << "allowance:" << allowance.analysis_allowance() << "\n";
2081           }
2082         }
2083       }
2084     }
2085   }
2086   return OkStatus();
2087 }
2088 
2089 StatusOr<bool> CopyInsertion::Run(
2090     HloModule* module,
2091     const absl::flat_hash_set<absl::string_view>& execution_threads) {
2092   // Copy insertion is performed in three steps:
2093   //
2094   // (1) Add copies conservatively to guarantee that there is no live-range
2095   //     interference. This is done simplistically and usually results in more
2096   //     copies than is strictly necessary.
2097   //
2098   // (2) Using a more fine-grained analysis, remove as many copies that were
2099   //     added in (1) as possible while ensuring no live-range interference.
2100   //
2101   // (3) Add copies to resolve issues not related to live range interference
2102   //     such as parameters and constants live out of the entry computation.
2103   //
2104   // We add copies then remove them (step (1) then (2)) rather than simply
2105   // adding only the copies that are necessary because, in general, it is
2106   // difficult to figure out the minimal set of copies to add once there is
2107   // interference. On the other hand, it is easy to determine if removing a copy
2108   // will introduce interference.
2109   //
2110   // The final copy insertion in (3) is done separately to simplify the
2111   // implementation of copy removal in (2) which is the most complicated part of
2112   // the pass. As is, copy removal only has to reason about live range
2113   // interference. If all copies were added in step (1) then copy removal would
2114   // also have to reason about things like constants and parameters live out of
2115   // the computation.
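  // Concretely, the sequence below is: AddCopiesToResolveInterference,
  // TupleSimplifier + HloDCE, RemoveUnnecessaryCopies (under a
  // DependencyHloOrdering), AddSpecialCaseCopies, then TupleSimplifier +
  // HloDCE again.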
2116   std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
2117   if (!call_graph->IsFlattened()) {
2118     return FailedPrecondition(
2119         "Call graph must be flattened before copy insertion.");
2120   }
2121 
2122   int64_t num_copies_before = GetNumExistingCopies(module, execution_threads);
2123 
2124   TF_RETURN_IF_ERROR(AddCopiesToResolveInterference(module, execution_threads));
2125 
2126   // Simplify the tuple structures introduced by the deep copies. This should be
2127   // done before removing copies (RemoveUnnecessaryCopies) because tuple
2128   // simplification changes dependencies in the graph which changes live range
2129   // interference in the graph. Also run DCE to remove the dead Tuple/GTE
2130   // instructions introduced by tuple simplification.
2131   TupleSimplifier tuple_simplifier;
2132   HloDCE dce;
2133   TF_RETURN_IF_ERROR(tuple_simplifier.Run(module, execution_threads).status());
2134   TF_RETURN_IF_ERROR(dce.Run(module, execution_threads).status());
2135   DumpHloModuleDuringPassIfEnabled(
2136       name(), "after adding copies to resolve interference", *module);
2137 
2138   DependencyHloOrdering ordering(module);
2139   TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(&ordering, module,
2140                                              /*check_live_range_ordering=*/true,
2141                                              execution_threads));
2142   DumpHloModuleDuringPassIfEnabled(name(), "after removing unnecessary copies",
2143                                    *module);
2144   TF_RETURN_IF_ERROR(
2145       AddSpecialCaseCopies(*call_graph, execution_threads, module));
2146   DumpHloModuleDuringPassIfEnabled(name(), "after adding special-case copies",
2147                                    *module);
2148 
2149   TF_RETURN_IF_ERROR(tuple_simplifier.Run(module, execution_threads).status());
2150   TF_RETURN_IF_ERROR(dce.Run(module, execution_threads).status());
2151 
2152   VLOG(1) << "Num copies before copy-insertion: " << num_copies_before;
2153   VLOG(1) << "Num copies after copy-insertion: "
2154           << GetNumExistingCopies(module, execution_threads);
2155 
2156   return true;
2157 }
2158 }  // namespace xla
2159