1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/copy_insertion.h"
17
18 #include <algorithm>
19 #include <cstddef>
20 #include <optional>
21 #include <sstream>
22
23 #include "absl/container/flat_hash_map.h"
24 #include "absl/container/flat_hash_set.h"
25 #include "absl/container/inlined_vector.h"
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/str_join.h"
28 #include "absl/types/any.h"
29 #include "tensorflow/compiler/xla/service/compile_time_cap.h"
30 #include "tensorflow/compiler/xla/service/dump.h"
31 #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
32 #include "tensorflow/compiler/xla/service/hlo_buffer.h"
33 #include "tensorflow/compiler/xla/service/hlo_computation.h"
34 #include "tensorflow/compiler/xla/service/hlo_dce.h"
35 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
36 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
37 #include "tensorflow/compiler/xla/service/hlo_module.h"
38 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
39 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
40 #include "tensorflow/compiler/xla/service/logical_buffer.h"
41 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
42 #include "tensorflow/compiler/xla/status_macros.h"
43 #include "tensorflow/compiler/xla/statusor.h"
44 #include "tensorflow/compiler/xla/types.h"
45 #include "tensorflow/compiler/xla/util.h"
46 #include "tensorflow/core/platform/logging.h"
47
48 namespace xla {
49 namespace {
50
51 using absl::StrAppend;
52
53 bool IsReadonlyEntryParameterValue(const HloValue& value) {
54 const HloComputation* computation = value.defining_instruction()->parent();
55 return value.defining_instruction()->opcode() == HloOpcode::kParameter &&
56 computation == computation->parent()->entry_computation() &&
57 !computation->parent()->input_output_alias_config().ParameterHasAlias(
58 value.defining_instruction()->parameter_number(), value.index());
59 }
60
61 bool IsConstantValue(const HloValue& value) {
62 return value.defining_instruction()->opcode() == HloOpcode::kConstant;
63 }
64
65 bool ValueIsReadOnly(const HloValue& value) {
66 return IsConstantValue(value) || IsReadonlyEntryParameterValue(value);
67 }
68
69 // Data structure describing the action which should be taken on parts of a
70 // computation's buffers, with respect to the adding of special case copies.
71 struct SpecialCaseCopyPolicy {
72 // Insert a copy if the same buffer is found at multiple indices within the
73 // output tuple.
74 bool copy_root_replicated_buffers = false;
75 // If true, insert a copy if a buffer coming from a constant or a parameter
76 // is found within the output tuple.
77 bool copy_parameters_and_constants = false;
78 };
79
80 SpecialCaseCopyPolicy GetSpecialCaseCopyPolicy(const CallGraphNode& node,
81 HloModule* module,
82 HloComputation* computation) {
83 SpecialCaseCopyPolicy policy;
84 if (computation == module->entry_computation()) {
85 policy.copy_parameters_and_constants = true;
86 policy.copy_root_replicated_buffers = true;
87 }
88 return policy;
89 }
90
91 bool ShouldCopyRootValue(const HloValue& value,
92 const SpecialCaseCopyPolicy& policy) {
93 if (policy.copy_parameters_and_constants) {
94 return ValueIsReadOnly(value);
95 }
96 return false;
97 }
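// Illustrative note (not from the original comments): under the entry-computation
// policy returned by GetSpecialCaseCopyPolicy above, a kConstant or a non-aliased
// entry parameter that appears directly in the module's output is flagged for
// copying here, so the result buffer never aliases read-only storage.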
98
99 // Deep copy the given instructions 'from' and 'to' at the ShapeIndexes given in
100 // 'indices_to_copy'. Add control edges from the respective kCopy instructions
101 // in the deep copy of 'from' to the respective kCopy instruction in the deep copy
102 // of 'to'.
103 //
104 // Requirements: 'from' and 'to' must have compatible shapes.
105 //
106 // For example, suppose 'from' and 'to' are two-element tuples where index 0 is
107 // the only index to copy. Prior to deep-copying we have:
108 //
109 //
110 // 'from'
111 // |
112 // ...
113 // |
114 // 'to'
115 //
116 // DeepCopyAndAddControlEdges produces:
117 //
118 // 'from'
119 // / \
120 // GTE GTE
121 // | |
122 // Copy |
123 // / \ /
124 // | Tuple
125 // | |
126 // ctrl ...
127 // edge |
128 // | |
129 // | 'to'
130 // | / \
131 // | GTE GTE
132 // \ | |
133 // Copy |
134 // \ /
135 // Tuple
136 //
137 StatusOr<std::pair<HloInstruction*, HloInstruction*>>
138 DeepCopyAndAddControlEdges(HloInstruction* from, HloInstruction* to,
139 const ShapeTree<bool>& indices_to_copy) {
140 DCHECK(ShapeUtil::Compatible(from->shape(), to->shape()));
141 // to/from_copy_tree hold the kCopy instructions produced by the deep
142 // copies. Elements which are not copied (indices_to_copy.element(index) ==
143 // false) have nullptr at that index.
144 ShapeTree<HloInstruction*> from_copy_tree(from->shape(),
145 /*init_value=*/nullptr);
146 TF_ASSIGN_OR_RETURN(HloInstruction * from_deep_copy,
147 from->parent()->DeepCopyInstruction(
148 from, &indices_to_copy, &from_copy_tree));
149
150 ShapeTree<HloInstruction*> to_copy_tree(to->shape(), /*init_value=*/nullptr);
151 TF_ASSIGN_OR_RETURN(
152 HloInstruction * to_deep_copy,
153 to->parent()->DeepCopyInstruction(to, &indices_to_copy, &to_copy_tree));
154
155 // Add control edges between the respective kCopy instructions.
156 for (const auto& pair : from_copy_tree) {
157 const ShapeIndex& index = pair.first;
158 HloInstruction* from_copy = pair.second;
159 HloInstruction* to_copy = to_copy_tree.element(index);
160 if (from_copy == nullptr) {
161 TF_RET_CHECK(to_copy == nullptr);
162 continue;
163 }
164 TF_RET_CHECK(to_copy != nullptr);
165 TF_RETURN_IF_ERROR(from_copy->AddControlDependencyTo(to_copy));
166 }
167
168 return std::make_pair(from_deep_copy, to_deep_copy);
169 }
170
171 // Compute the indices of the loop state which need copies in order to avoid
172 // live range interference. Generally, an element in the loop state does not
173 // need to be copied if the element is passed transparently through the
174 // body.
175 //
176 // Returns whether any indices need to be copied.
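//
// Illustrative sketch (pseudo-HLO, not part of the pass): for a while body
//
//   body { p = parameter(0)
//          ROOT t = tuple(get-tuple-element(p, 0),         // passed through
//                         add(get-tuple-element(p, 1), c)) // modified
//        }
//
// indices_to_copy would end up false at {0} (the value flows through the body
// unchanged) and true at {1}, so only the modified element receives copies.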
177 bool IndicesToCopyForWhile(const HloDataflowAnalysis& dataflow,
178 const HloInstruction* xla_while,
179 ShapeTree<bool>* indices_to_copy) {
180 DCHECK(ShapeUtil::Compatible(indices_to_copy->shape(), xla_while->shape()));
181
182 bool any_copies = false;
183 const HloInstruction* init = xla_while->operand(0);
184 for (auto& pair : *indices_to_copy) {
185 const ShapeIndex& index = pair.first;
186 bool& should_copy = pair.second;
187 // If there is any ambiguity, then loop state must be copied.
188 if (dataflow.GetValueSet(init, index).values().size() > 1 ||
189 dataflow.GetValueSet(xla_while, index).values().size() > 1) {
190 should_copy = true;
191 } else {
192 // If the output of the while instruction is not the same as the init
193 // value of the while, then this element is not passed through the body
194 // transparently and must be copied.
195 should_copy = dataflow.GetUniqueValueAt(xla_while, index) !=
196 dataflow.GetUniqueValueAt(init, index);
197 }
198 any_copies |= should_copy;
199 }
200 return any_copies;
201 }
202
203 // Compute the indices of the conditional outputs which need copies. Unambiguous
204 // buffers (buffers with only one value) don't need copies.
205 bool IndicesToCopyForConditional(const HloDataflowAnalysis& dataflow,
206 const HloInstruction* xla_conditional,
207 ShapeTree<bool>* indices_to_copy) {
208 DCHECK(ShapeUtil::Compatible(indices_to_copy->shape(),
209 xla_conditional->shape()));
210
211 bool any_copies = false;
212 for (auto& pair : *indices_to_copy) {
213 const ShapeIndex& index = pair.first;
214 bool& should_copy = pair.second;
215
216 CHECK_EQ(dataflow.GetValueSet(xla_conditional, index).values().size(), 1);
217
218 auto value = dataflow.GetValueSet(xla_conditional, index).values()[0];
219 // The conditional must be copied if the value is a phi.
220 should_copy =
221 value->is_phi() && value->defining_instruction() == xla_conditional;
222 any_copies |= should_copy;
223 }
224 return any_copies;
225 }
226
227 // Add kCopy instructions around the given kWhile instruction to eliminate any
228 // possible live range interference of HLO values assuming a dependency-based
229 // ordering. Copies are added conservatively. There likely are copies which are
230 // not strictly necessary, but they are removed later in the pass via
231 // RemoveUnnecessaryCopies.
232 //
233 // Elements (each ShapeIndex) in the loop state are considered independently. A
234 // copy is added to each element of the loop state which is modified in the
235 // while body. For each such element, a total of three kCopy instructions are
236 // added at following locations:
237 //
238 // (1) The init value is copied before the kWhile instruction. Before:
239 //
240 // (Init)
241 // |
242 // kWhile
243 // |
244 // ...
245 //
246 // After:
247 //
248 // (Init)
249 // |
250 // kCopy
251 // |
252 // kWhile
253 // |
254 // ...
255 //
256 // This copy is necessary in case the init value is simultaneously live
257 // with the kWhile.
258 //
259 // (2) Copies are added to the parameter and root of the while body
260 // computation. Before:
261 //
262 // kParameter
263 // |
264 // ...
265 // |
266 // (body root)
267 //
268 // After:
269 //
270 // kParameter
271 // |
272 // kCopy ----------+
273 // | |
274 // ... ctrl
275 // | edge
276 // (body root) |
277 // | |
278 // kCopy <---------+
279 //
280 // The root kCopy becomes the new root of the computation. Both copies are
281 // necessary to avoid any potential interference between the parameter value and
282 // the root value. The control edge prevents potential interference
283 // between the copies themselves.
284 //
285 // If the loop state is a tuple then the above kCopy instructions are a deep
286 // copy constructed of kCopy, kGetTupleElement, and kTuple instructions, as
287 // constructed by HloInstruction::DeepCopyInstruction.
288 Status AddCopiesForWhile(const HloAliasAnalysis& alias_analysis,
289 HloInstruction* xla_while) {
290 VLOG(2) << "Adding copies for kWhile instruction " << xla_while->name();
291 TF_RET_CHECK(xla_while->opcode() == HloOpcode::kWhile);
292
293 ShapeTree<bool> indices_to_copy(xla_while->shape());
294 if (!IndicesToCopyForWhile(alias_analysis.dataflow_analysis(), xla_while,
295 &indices_to_copy)) {
296 VLOG(2) << "No copies necessary for kWhile instruction "
297 << xla_while->name();
298 return OkStatus();
299 }
300
301 VLOG(2) << "Adding copies for " << xla_while->name() << " at indices:";
302 for (auto& pair : indices_to_copy) {
303 if (pair.second) {
304 VLOG(2) << " " << pair.first;
305 }
306 }
307
308 // Deep copy init.
309 HloInstruction* while_init = xla_while->mutable_operand(0);
310 TF_ASSIGN_OR_RETURN(
311 HloInstruction * while_init_copy,
312 xla_while->parent()->DeepCopyInstruction(while_init, &indices_to_copy));
313 TF_RETURN_IF_ERROR(while_init->ReplaceUseWith(xla_while, while_init_copy));
314
315 // Deep copy the parameter and the root. Extend a control edge from the copy
316 // of the parameter value to the corresponding copy value of the root.
317 HloComputation* body = xla_while->while_body();
318 HloInstruction* param = body->parameter_instruction(0);
319 HloInstruction* root = body->root_instruction();
320
321 // If param is the root then all indices should have been passed through the
322 // while body and we should have returned early above.
323 TF_RET_CHECK(param != root);
324
325 // Copy users before making a deep copy of the parameter as the deep copy
326 // will create new users of the parameter (eg, the GTE instructions of the
327 // deep copy).
328 std::vector<HloInstruction*> param_users = param->users();
329
330 TF_ASSIGN_OR_RETURN(auto pair,
331 DeepCopyAndAddControlEdges(param, root, indices_to_copy));
332
333 HloInstruction* param_copy = pair.first;
334 HloInstruction* root_copy = pair.second;
335
336 for (HloInstruction* user : param_users) {
337 TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, param_copy));
338 }
339
340 body->set_root_instruction(root_copy);
341 return OkStatus();
342 }
343
344 // Add copies for the operands of in-place operations. RemoveUnnecessaryCopies
345 // will remove the unnecessary copies.
346 Status AddCopiesForInPlaceOperation(const HloAliasAnalysis& alias_analysis,
347 HloInstruction* in_place_op,
348 int64_t operand_number) {
349 VLOG(2) << "Adding copies for in-place operation " << in_place_op->name();
350 HloInstruction* operand = in_place_op->mutable_operand(operand_number);
351 TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy,
352 in_place_op->parent()->DeepCopyInstruction(operand));
353 TF_RETURN_IF_ERROR(
354 operand->ReplaceUseWith(in_place_op, operand_number, deep_copy));
355 return OkStatus();
356 }
357
358 // Conservatively adds copies before root instruction of entry computation and
359 // each aliased parameter to resolve interference of aliased input and output
360 // buffers. We later rely on RemoveUnnecessaryCopies to drop the unnecessary
361 // ones.
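//
// Illustrative example (assumed aliasing, not taken from a real module): if
// parameter 0 at index {1} is aliased with output index {0}, this function
// deep-copies the parameter at {1}, deep-copies the root at {0}, and adds a
// control edge from the parameter-side kCopy to the root-side kCopy so the two
// copies cannot interfere; RemoveUnnecessaryCopies may later drop whichever
// copy turns out to be redundant.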
362 Status AddCopiesForAliasedInputOutputs(HloModule* module) {
363 HloComputation* entry = module->entry_computation();
364 HloInstruction* root = entry->root_instruction();
365
366 ShapeTree<bool> output_indices_to_copy(root->shape());
367 std::vector<std::optional<ShapeTree<HloInstruction*>>> copied_parameters(
368 entry->num_parameters());
369 bool has_alias = false;
370 for (auto* param : entry->parameter_instructions()) {
371 bool param_has_alias = false;
372 ShapeTree<bool> param_indices_to_copy(param->shape());
373
374 module->input_output_alias_config().ForEachAlias(
375 [&](const ShapeIndex& output_index,
376 const HloInputOutputAliasConfig::Alias& alias) {
377 if (alias.parameter_number == param->parameter_number()) {
378 param_has_alias = true;
379 *(param_indices_to_copy.mutable_element(alias.parameter_index)) =
380 true;
381 *(output_indices_to_copy.mutable_element(output_index)) = true;
382 }
383 });
384
385 if (!param_has_alias) {
386 continue;
387 }
388
389 TF_RET_CHECK(param->parameter_number() < entry->num_parameters());
390 TF_RET_CHECK(!copied_parameters[param->parameter_number()]);
391
392 has_alias = true;
393 // Store a snapshot of users before DeepCopyInstruction, as
394 // DeepCopyInstruction introduces new users of the instruction.
395 std::vector<HloInstruction*> users = param->users();
396 ShapeTree<HloInstruction*> param_copy_tree(param->shape(),
397 /*init_value=*/nullptr);
398 TF_ASSIGN_OR_RETURN(HloInstruction * copied,
399 entry->DeepCopyInstruction(
400 param, &param_indices_to_copy, &param_copy_tree));
401 if (param == root) {
402 entry->set_root_instruction(copied);
403 root = copied;
404 }
405 for (HloInstruction* user : users) {
406 TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, copied));
407 }
408
409 copied_parameters[param->parameter_number()] = param_copy_tree;
410 }
411
412 if (!has_alias) {
413 return OkStatus();
414 }
415
416 // Add copies before root instruction.
417 ShapeTree<HloInstruction*> output_copy_tree(root->shape(),
418 /*init_value=*/nullptr);
419
420 TF_ASSIGN_OR_RETURN(HloInstruction * root_copied,
421 root->parent()->DeepCopyInstruction(
422 root, &output_indices_to_copy, &output_copy_tree));
423
424 // Add control dependencies between the input/output copies.
425 TF_RETURN_IF_ERROR(module->input_output_alias_config().ForEachAliasWithStatus(
426 [&](const ShapeIndex& output_index,
427 const HloInputOutputAliasConfig::Alias& alias) -> Status {
428 if (!copied_parameters[alias.parameter_number]) {
429 return OkStatus();
430 }
431 HloInstruction* from =
432 copied_parameters[alias.parameter_number]->element(
433 alias.parameter_index);
434 HloInstruction* to = output_copy_tree.element(output_index);
435
436 TF_RET_CHECK(from != nullptr);
437 TF_RET_CHECK(to != nullptr);
438 TF_RETURN_IF_ERROR(from->AddControlDependencyTo(to));
439 return OkStatus();
440 }));
441
442 entry->set_root_instruction(root_copied);
443
444 return OkStatus();
445 }
446
447 // Removes any control dependencies to or from the given instruction.
448 Status StripControlDependenciesFrom(HloInstruction* instruction) {
449 while (!instruction->control_successors().empty()) {
450 TF_RETURN_IF_ERROR(instruction->RemoveControlDependencyTo(
451 instruction->control_successors().front()));
452 }
453
454 while (!instruction->control_predecessors().empty()) {
455 TF_RETURN_IF_ERROR(
456 instruction->control_predecessors().front()->RemoveControlDependencyTo(
457 instruction));
458 }
459
460 return OkStatus();
461 }
462
463 class LiveRangeRegions {
464 public:
465 struct InstructionInfo {
466 InstructionInfo() : value_definition(nullptr), is_definition(false) {}
467
468 // The instruction that defines the value being used. It basically saves
469 // the defining instruction of each HloValue.
470 HloInstruction* value_definition;
471 // Whether the instruction defines a new value (or merely uses one). This
472 // basically remembers whether the instruction actually creates an HloValue
473 // or merely uses one, from a collection of given HloValues. Note that if
474 // is_definition = true, it merely says the instruction creates a new
475 // HloValue with or without defining a new one. For example, kAdd creates a
476 // new HloValue (can be value_definition), but tuple or get-tuple-element
477 // create a new HloValue that aliases an existing value (cannot be
478 // value_definition).
479 bool is_definition;
480 };
481 // Map instructions that use a value to the defining instruction of the value.
482 // Because all values must belong to the same live range, an instruction can
483 // have at most a single value-defining instruction; otherwise the multiple
484 // incoming active values would share a single buffer, which is not allowed.
485 // The value-defining and value-use instructions do not have to belong to the
486 // same computation, but the value use needs to be nested within the defining
487 // computation.
488 typedef absl::flat_hash_map<HloInstruction*, InstructionInfo> InstructionMap;
489 typedef std::pair<HloInstruction*, InstructionInfo> InstructionEntry;
490 // Map each computation to its immediately contained instructions.
491 typedef absl::flat_hash_map<const HloComputation*, InstructionMap>
492 ComputationMap;
493
494 InstructionMap& operator[](const HloComputation* computation) {
495 if (computation_map_.find(computation) == computation_map_.end()) {
496 computation_vector_.push_back(computation);
497 }
498 return computation_map_[computation];
499 }
500
501 const InstructionMap& operator[](const HloComputation* computation) const {
502 ComputationMap::const_iterator p = computation_map_.find(computation);
503 CHECK(p != computation_map_.end());
504 return p->second;
505 }
506 ComputationMap::const_iterator begin() const {
507 return computation_map_.begin();
508 }
509 ComputationMap::const_iterator end() const { return computation_map_.end(); }
510 int64_t size() const {
511 CHECK_EQ(computation_vector_.size(), computation_map_.size());
512 return computation_vector_.size();
513 }
514 bool empty() const { return size() == 0; }
515 const HloComputation* Computation(int64_t index) const {
516 return computation_vector_[index];
517 }
518 bool contains(const HloInstruction* instr) const {
519 CHECK_NE(instr, nullptr);
520 auto* computation = instr->parent();
521 auto p = computation_map_.find(computation);
522 if (p == computation_map_.end()) {
523 return false;
524 }
525 auto instr_map = (*p).second;
526 return instr_map.find(instr) != instr_map.end();
527 }
528
529 private:
530 ComputationMap computation_map_;
531 absl::InlinedVector<const HloComputation*, 5> computation_vector_;
532 };
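// Illustrative sketch (hypothetical handles, shown only to clarify the data
// structure): a LiveRangeRegions for a value defined by an add and used by the
// root tuple of the same computation would be populated roughly as
//
//   LiveRangeRegions regions;
//   regions[comp][add_instr].is_definition = true;           // defining site
//   regions[comp][tuple_instr].value_definition = add_instr; // use site
//
// where comp, add_instr, and tuple_instr stand in for real pointers; the
// actual population happens in CopyRemover::ComputeLiveRangeRegions below.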
533
534 namespace {
535 // Represents relations between the locations of two regions of instructions;
536 // each region can include 0..n instructions.
537 class Relation {
538 public:
539 enum RuntimeOrder {
540 // Indicate that there is no overlap whatsoever between the two regions.
541 kNoOverlap = 0,
542 // Indicate that the first region includes the same set of instructions as
543 // the second region.
544 kSameInstr = 1,
545 // Indicate that the first region is entirely before the second region
546 // starts.
547 kBeforeStart = 2,
548 // Indicate that the first region is before the second region ends.
549 kBeforeStartOrSameInstr = kBeforeStart | kSameInstr,
550 // Indicate that the first region is entirely after the second region ends.
551 kAfterEnd = 4,
552 // Indicate that the first region is after the second region
553 // starts, with some instructions before the second region ends.
554 kAfterEndOrSameInstr = kAfterEnd | kSameInstr,
555 // Indicate that the first region overlaps with the second one, but shares no
556 // common instructions.
557 kBeforeStartOrAfterEnd = kBeforeStart | kAfterEnd,
558 // Indicate that the first region overlaps with the second one, and has
559 // some common instructions.
560 kBeforeOrAfterOrOverlap = kBeforeStart | kAfterEnd | kSameInstr,
561 };
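// Worked example (illustrative only): the enum values above form a small bit
// lattice, so combining orderings is a bitwise union. For instance, merging
// kBeforeStart (2) with kAfterEnd (4) yields kBeforeStartOrAfterEnd (6), and
// kBeforeStartOrSameInstr (3) subsumes kBeforeStart (2) because (3 | 2) == 3;
// see Union() and Subsume() below.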
562 Relation() : intercept_def_use_(false) {}
563 explicit Relation(RuntimeOrder order, bool intercept_def_use = false)
564 : intercept_def_use_(intercept_def_use) {
565 orders_.push_back(order);
566 }
567 Relation(const Relation& that)
568 : intercept_def_use_(that.intercept_def_use_), orders_(that.orders_) {}
569 bool operator==(const Relation& that) const {
570 return intercept_def_use_ == that.intercept_def_use_ &&
571 absl::c_equal(orders_, that.orders_);
572 }
573
574 // Return whether the runtime ordering may imply interception, assuming it
575 // models the relation between a modifying and a use instruction.
576 bool UseImpliesInterception() const {
577 CHECK_EQ(orders_.size(), 1);
578 return UseImpliesInterception(orders_[0]);
579 }
580 // Return whether the runtime ordering may imply interception, assuming it
581 // models the relation between a modifying and a definition instruction.
582 bool DefinitionImpliesInterception() const {
583 CHECK_EQ(orders_.size(), 1);
584 return DefinitionImpliesInterception(orders_[0]);
585 }
586 // Return whether the current relation models a modifying instruction that
587 // intercepts the dataflow of another live range region.
588 bool InterceptDefUse() const { return intercept_def_use_; }
589 // Update interception state to the given value.
590 void UpdateInterception(bool value) {
591 CHECK_EQ(orders_.size(), 1);
592 intercept_def_use_ = value;
593 }
594 Relation::RuntimeOrder GetRuntimeOrder() const {
595 if (orders_.empty()) {
596 return Relation::kNoOverlap;
597 }
598 CHECK_EQ(orders_.size(), 1);
599 return orders_[0];
600 }
601 // Return whether the current relation implies two overlapping regions.
602 bool RuntimeOrderOverlap() const {
603 return absl::c_any_of(orders_, ImpliesOverlap);
604 }
605 bool RuntimeOrderIsUnordered() const {
606 return orders_.size() == 1 && orders_[0] == kBeforeStartOrAfterEnd;
607 }
608 bool RuntimeOrderIsNoOverlap() const {
609 return orders_.empty() || (orders_.size() == 1 && orders_[0] == kNoOverlap);
610 }
611 bool RuntimeOrderIsRunBefore() const {
612 return orders_.size() == 1 && orders_[0] == kBeforeStart;
613 }
614 bool RuntimeOrderIsRunAfter() const {
615 return orders_.size() == 1 && orders_[0] == kAfterEnd;
616 }
617 std::string ToString() const {
618 return absl::StrCat("Interception = ", intercept_def_use_, ";",
619 absl::StrJoin(orders_, ","));
620 }
621
622 static bool DefinitionImpliesInterception(RuntimeOrder definition) {
623 return (definition == kAfterEnd || definition == kBeforeStartOrAfterEnd);
624 }
625 static bool UseImpliesInterception(RuntimeOrder use) {
626 return (use == kBeforeStart || use == kBeforeStartOrAfterEnd);
627 }
628
629 // Summarize additional relations into a single runtime ordering, assuming
630 // both relations are modeling constraints of the same source instruction.
631 void UnionRelationFromSameSource(const Relation& rel) {
632 CHECK_LE(orders_.size(), 1);
633 CHECK_EQ(rel.orders_.size(), 1);
634 if (orders_.empty()) {
635 orders_.push_back(rel.orders_[0]);
636 } else {
637 orders_[0] = Union(orders_[0], rel.orders_[0]);
638 }
639 intercept_def_use_ = intercept_def_use_ || rel.intercept_def_use_;
640 }
641
642 // Summarize additional relations into disjoint runtime orderings, assuming
643 // the relations are modeling constraints of different source instructions.
644 void UnionRelationFromDifferentSource(const Relation& rel) {
645 if (rel.orders_.empty()) {
646 return;
647 }
648 CHECK_EQ(rel.orders_.size(), 1);
649 intercept_def_use_ = intercept_def_use_ || rel.intercept_def_use_;
650 for (auto& local_order : orders_) {
651 if (OverwriteIfSubsume(rel.orders_[0], &local_order)) {
652 return;
653 }
654 }
655 orders_.push_back(rel.orders_[0]);
656 }
657
658 static Relation::RuntimeOrder ReverseRuntimeOrder(RuntimeOrder order) {
659 switch (order) {
660 case kNoOverlap:
661 case kSameInstr:
662 case kBeforeStartOrAfterEnd:
663 case kBeforeOrAfterOrOverlap:
664 return order;
665 case kBeforeStart:
666 return kAfterEnd;
667 case kBeforeStartOrSameInstr:
668 return kAfterEndOrSameInstr;
669 case kAfterEnd:
670 return kBeforeStart;
671 case kAfterEndOrSameInstr:
672 return kBeforeStartOrSameInstr;
673 }
674 }
675
676 private:
677 // Indicate that the second region may intercept the def-use dataflow of the
678 // first region, if their buffers are combined.
679 bool intercept_def_use_;
680 // Remember the different runtime orderings of different instructions.
681 absl::InlinedVector<RuntimeOrder, 4> orders_;
682
683 static RuntimeOrder Union(RuntimeOrder o1, RuntimeOrder o2) {
684 return static_cast<Relation::RuntimeOrder>(o1 | o2);
685 }
686 static bool ImpliesOverlap(RuntimeOrder o) {
687 return o >= RuntimeOrder::kBeforeStartOrAfterEnd;
688 }
689 // Returns whether ordering constraint o1 includes o2 as a subset, when they
690 // represent runtime orderings (interleavings) of two different regions.
691 static bool Subsume(RuntimeOrder o1, RuntimeOrder o2) {
692 return Union(o1, o2) == o1;
693 }
694 // Overwrites o1 with o2 if o2 subsumes o1 (as defined above by the Subsume
695 // function). Return whether o2 is subsumed by the new value in o1.
696 static bool OverwriteIfSubsume(RuntimeOrder o2, RuntimeOrder* o1) {
697 if (*o1 == o2) {
698 return true;
699 }
700 CHECK_NE(o1, nullptr);
701 // Overwrite o1 with o2 if it is subsumed by o2.
702 if (Subsume(o2, *o1)) {
703 *o1 = o2;
704 return true;
705 } else if (Subsume(*o1, o2)) {
706 // If o2 is already subsumed by o1, do nothing.
707 return true;
708 }
709 // If neither o1 nor o2 is subsumed by the other, return false, so that o2
710 // will be inserted as a separate entry representing all possible orderings.
711 return false;
712 }
713 };
714
715 class ComputeRelativeLocation {
716 public:
717 typedef LiveRangeRegions::InstructionEntry InstructionEntry;
718 explicit ComputeRelativeLocation(HloOrdering* ordering)
719 : ordering_(ordering) {
720 VLOG(3) << "New analysis\n";
721 }
722
723 // Compute locationing constraints between two instructions. Here entry2 is
724 // the source instruction, in that the returned value describes the relation
725 // of entry2 in terms of whether it is before or after entry1, and whether it
726 // can intercept the def-use data flow of entry1.
727 Relation Compute(const InstructionEntry& entry1,
728 const InstructionEntry& entry2, bool instr2_can_modify) {
729 auto def = entry1.second.value_definition;
730 auto use = entry1.first;
731 Relation::RuntimeOrder order =
732 ComputeRuntimeOrdering(entry2.first, entry1.first);
733 if (order == Relation::kSameInstr &&
734 entry1.second.is_definition != entry2.second.is_definition) {
735 if (entry1.second.is_definition) {
736 order = Relation::kBeforeStart;
737 } else {
738 order = Relation::kAfterEnd;
739 }
740 }
741 bool intercept = AlwaysForceInterception(entry2.first);
742 if (def == nullptr || !instr2_can_modify) {
743 return Relation(order, intercept);
744 }
745 // If the definition and use are parameter and return (root) of the parent
746 // computation, then any modification is considered intercepting.
747 if (def->opcode() == HloOpcode::kParameter &&
748 use == use->parent()->root_instruction()) {
749 VLOG(3) << "Setting interception due to parameter/root relation\n";
750 return Relation(order, true);
751 }
752 if (Relation::UseImpliesInterception(order)) {
753 auto order2 = ComputeRuntimeOrdering(entry2.first, def);
754 if (Relation::DefinitionImpliesInterception(order2)) {
755 VLOG(3) << "Setting interception for " << def->ToString()
756 << " with use:" << entry1.first->ToString() << "\n";
757 intercept = true;
758 }
759 }
760 return Relation(order, intercept);
761 }
762
763 // Return the relative locations (defined above) of range2 in relation to
764 // instructions in range1. Return kNoOverlap if range2 is outside of range1.
765 Relation Compute(const LiveRangeRegions& range1,
766 const LiveRangeRegions& range2) {
767 Relation dir_src_dest;
768 for (int64_t index = 0; index < range1.size(); index++) {
769 auto* computation1 = range1.Computation(index);
770 for (const auto& computation_entry2 : range2) {
771 auto* computation2 = computation_entry2.first;
772 for (auto instr_entry2 : computation_entry2.second) {
773 if (!ordering_->call_graph().Dominates(computation1, computation2)) {
774 continue;
775 }
776 VLOG(3) << "Locationing " << instr_entry2.first->ToString();
777 // Saves relations between instr2 and other instructions in range1.
778 bool instr2_can_modify =
779 InstructionCanIntercept(instr_entry2, range1);
780 Relation instr2_relation;
781 std::vector<InstructionEntry> unordered_ops;
782 bool unordered_intercept = false;
783 for (auto instr_entry1 : range1[computation1]) {
784 auto rel = Compute(instr_entry1, instr_entry2, instr2_can_modify);
785 VLOG(3) << "new relation with:" << instr_entry1.first->ToString()
786 << " = " << rel.ToString() << "\n";
787 if (!rel.RuntimeOrderIsUnordered()) {
788 instr2_relation.UnionRelationFromSameSource(rel);
789 } else {
790 unordered_ops.push_back(instr_entry1);
791 unordered_intercept |= rel.InterceptDefUse();
792 }
793 VLOG(3) << "instr2 relation:" << instr2_relation.ToString() << "\n";
794 }
795 // Here instr2_relation is guaranteed to have at most a single entry,
796 // because it was initialized to be empty, and has been updated only
797 // via instr2_relation.UnionRelationFromSameSource(rel), which
798 // maintains that the updated result has only a single entry.
799 if (!ForceRuntimeOrder(unordered_ops, instr_entry2,
800 instr2_relation.GetRuntimeOrder())) {
801 VLOG(3) << "Unable to force ordering of unordered ops\n";
802 instr2_relation.UnionRelationFromSameSource(Relation(
803 Relation::kBeforeStartOrAfterEnd, unordered_intercept));
804 }
805 dir_src_dest.UnionRelationFromDifferentSource(instr2_relation);
806 VLOG(3) << "Resulting relation : " << dir_src_dest.ToString() << "\n";
807 }
808 }
809 }
810 return dir_src_dest;
811 }
812
813 // Return whether control dependences, if any exist, are added successfully.
814 bool AddControlDependenceForUnorderedOps() {
815 if (ctrl_deps_.empty()) {
816 return true;
817 }
818 PredecessorHloOrdering* ordering =
819 dynamic_cast<PredecessorHloOrdering*>(ordering_);
820 if (ordering == nullptr) {
821 // Forcing the ordering of unordered ops is supported only when using
822 // predecessor ordering.
823 return false;
824 }
825 for (const auto& comp_it : ctrl_deps_) {
826 HloComputation* parent = comp_it.first;
827 HloReachabilityMap& reachability_map = ordering->reachability_map(parent);
828 for (const auto& instr_it : comp_it.second) {
829 HloInstruction* entry1 = instr_it.first;
830 for (HloInstruction* entry2 : instr_it.second) {
831 VLOG(3) << "Add control dependence between " << entry2->ToString();
832 VLOG(3) << "\n vs " << entry1->ToString() << "\n";
833 TF_CHECK_OK(entry2->AddControlDependencyTo(entry1));
834 }
835 reachability_map.UpdateReachabilityThroughInstruction(entry1);
836 for (HloInstruction* entry2 : instr_it.second) {
837 DCHECK(ordering_->GetExecutionConstraint(entry1, entry2) ==
838 HloOrdering::ExecutionConstraint::kRunAfter);
839 }
840 }
841 }
842 return true;
843 }
844
845 private:
846 enum ComputeStatus {
847 kFullyComputed,
848 kPartiallyComputed,
849 kNotComputed,
850 };
851 typedef std::pair<ComputeStatus, Relation::RuntimeOrder> SavedRelation;
852
853 // Returns whether it is safe to force the desired_relation ordering between
854 // all operations in unordered_ops and entry2. If safe, save the new enforced
855 // ordering relations.
856 bool ForceRuntimeOrder(absl::Span<const InstructionEntry> unordered_ops,
857 const InstructionEntry entry2,
858 Relation::RuntimeOrder desired_relation) {
859 if (unordered_ops.empty()) {
860 return true;
861 }
862 if (desired_relation != Relation::kBeforeStart &&
863 desired_relation != Relation::kAfterEnd) {
864 return false;
865 }
866 auto ModifiesNonCopy = [](HloInstruction* instr, const HloInstruction* op) {
867 auto in_place = HloDataflowAnalysis::GetInPlaceInputOutputPairs(instr);
868 if (in_place.empty()) {
869 return false;
870 }
871 return absl::c_any_of(
872 in_place, [&](const std::pair<HloOperandIndex, ShapeIndex>&
873 operand_and_output_index) {
874 auto* op2 =
875 instr->operand(operand_and_output_index.first.operand_number);
876 return (op == nullptr) ? (op2->opcode() == HloOpcode::kCopy)
877 : (op2 == op);
878 });
879 };
880 for (const InstructionEntry& entry1 : unordered_ops) {
881 // Only consider instructions in the same computation.
882 if (entry1.first->parent() != entry2.first->parent()) {
883 return false;
884 }
885 HloInstruction* pred = (desired_relation == Relation::kBeforeStart)
886 ? entry2.first
887 : entry1.first;
888 HloInstruction* succ = (desired_relation == Relation::kBeforeStart)
889 ? entry1.first
890 : entry2.first;
891 if (pred == pred->parent()->root_instruction()) {
892 return false;
893 }
894 if (succ->opcode() == HloOpcode::kCopy &&
895 ModifiesNonCopy(pred, succ->operand(0))) {
896 VLOG(3) << "Failed to force unordered op ordering due to copy ordering "
897 << " between " << pred->ToString() << "\n";
898 VLOG(3) << " vs. " << succ->ToString() << "\n";
899 return false;
900 }
901 }
902 for (const InstructionEntry& entry1 : unordered_ops) {
903 Save(entry2.first, entry1.first, desired_relation, true);
904 }
905 return true;
906 }
907
908 static bool AlwaysForceInterception(HloInstruction* instr) {
909 // The following communication operations can have some unexpected side
910 // effects, when synchronizing across processes. Therefore, we
911 // conservatively try to provide dedicated buffers to these operations instead
912 // of allowing them to share buffers with other operations, as the reuse may
913 // cause unexpected interferences.
914 if (HloDataflowAnalysis::IsAsynchronousOperationStart(instr->opcode()) ||
915 HloDataflowAnalysis::IsAsynchronousOperationDone(instr->opcode())) {
916 return true;
917 }
918 switch (instr->opcode()) {
919 // TODO(b/190903339): It appears that collectivePermute needs to be
920 // followed by a copy when escaping through a computation root.
921 case HloOpcode::kCollectivePermute:
922 return true;
923 default:
924 return false;
925 }
926 }
927
928 // Returns whether the given instr may intercept the def-use flow of another
929 // ongoing live range if its buffer is combined with the other live range.
930 // The function should return true if instr creates a new HloValue that could
931 // overwrite an existing HloValue in the combined buffer.
932 // More specifically, here we are looking for operations that create new
933 // values, e.g., add, subtract, in contrast to HLOs that merely create
934 // aliasings among existing values, e.g., tuple, get-tuple-element. Any of the
935 // new values created by operations such as add or subtract, when included as
936 // definition operations in a live range, are aliases of the buffer to be
937 // allocated to the live range and so are treated as they may be modifying the
938 // targeting buffer.
939 bool InstructionCanIntercept(const InstructionEntry& entry,
940 const LiveRangeRegions& region) {
941 auto instr = entry.first;
942 if (!entry.second.is_definition) {
943 // If the instruction only uses the value, it can intercept only if it
944 // modifies the buffer in place.
945 return !HloDataflowAnalysis::GetInPlaceInputOutputPairs(instr).empty();
946 }
947 switch (instr->opcode()) {
948 // If the copy instruction is used to connect two live range regions,
949 // it does not overwrite the combined buffer with new values.
950 case HloOpcode::kCopy:
951 // Check that the copy simply copies from the other live range with no
952 // layout conflicts.
953 if (region.contains(instr->operand(0)) &&
954 ShapeUtil::Equal(instr->shape(), instr->operand(0)->shape())) {
955 return false; // Cannot intercept.
956 }
957 return true;
958 // The following operations merely create aliases among the HloValues.
959 case HloOpcode::kParameter:
960 case HloOpcode::kTuple:
961 case HloOpcode::kGetTupleElement:
962 // Here we consider all the compound operations (e.g., conditionals and
963 // while loops) as if they do not modify any HloValue, with the argument
964 // being that any value modifying operation contained inside will be
965 // considered separately to make sure the kIntercept relation is
966 // recorded as appropriate. Since the compound operations may or may not
967 // modify, not treating them as value modifying would make the algorithm
968 // less conservative.
969 case HloOpcode::kWhile:
970 case HloOpcode::kCall:
971 case HloOpcode::kConditional:
972 return false;
973 default:
974 return true;
975 }
976 return true;
977 }
978
979 SavedRelation AlreadyComputed(HloInstruction* op1, HloInstruction* op2) {
980 auto p2 = saved_relations_.find(op2);
981 if (p2 != saved_relations_.end()) {
982 auto p1 = (*p2).second.find(op1);
983 if (p1 != (*p2).second.end()) {
984 return SavedRelation(kFullyComputed, (*p1).second);
985 }
986 }
987 p2 = saved_relations_.find(op1);
988 if (p2 != saved_relations_.end()) {
989 auto p1 = (*p2).second.find(op2);
990 if (p1 != (*p2).second.end()) {
991 return SavedRelation(kPartiallyComputed,
992 Relation::ReverseRuntimeOrder((*p1).second));
993 }
994 }
995 return SavedRelation(kNotComputed, Relation::kNoOverlap);
996 }
997
998 Relation::RuntimeOrder Save(HloInstruction* entry1, HloInstruction* entry2,
999 const Relation::RuntimeOrder relation,
1000 bool is_unordered_originally = false) {
1001 CHECK_EQ(AlreadyComputed(entry1, entry2).first, kNotComputed);
1002 // Do not save unordered relations.
1003 CHECK_NE(relation, Relation::kBeforeStartOrAfterEnd);
1004 saved_relations_[entry2][entry1] = relation;
1005 if (is_unordered_originally) {
1006 CHECK(relation == Relation::kBeforeStart ||
1007 relation == Relation::kAfterEnd)
1008 << relation;
1009 HloInstruction* pred =
1010 (relation == Relation::kBeforeStart) ? entry1 : entry2;
1011 HloInstruction* succ =
1012 (relation == Relation::kBeforeStart) ? entry2 : entry1;
1013 VLOG(3) << "Save unordered relation: " << pred->ToString() << "\n";
1014 VLOG(3) << " vs " << succ->ToString() << "\n";
1015 CHECK_EQ(succ->parent(), pred->parent());
1016 auto& dep_vec = ctrl_deps_[succ->parent()][succ];
1017 for (HloInstruction*& op : dep_vec) {
1018 auto rel = AlreadyComputed(pred, op);
1019 if (rel.first != kNotComputed) {
1020 if (rel.second == Relation::kAfterEnd) {
1021 op = pred;
1022 } else {
1023 CHECK(rel.second == Relation::kBeforeStart);
1024 }
1025 return relation;
1026 }
1027 }
1028 VLOG(2) << "Forcing unordered:" << pred->ToString() << "\n";
1029 VLOG(2) << " vs " << succ->ToString() << "\n";
1030 dep_vec.push_back(pred);
1031 }
1032 return relation;
1033 }
1034
1035 // Compute the runtime ordering constraints between two instructions.
1036 Relation::RuntimeOrder ComputeRuntimeOrdering(HloInstruction* instr1,
1037 HloInstruction* instr2) {
1038 auto saved_relation = AlreadyComputed(instr1, instr2);
1039 if (saved_relation.first != kNotComputed) {
1040 VLOG(3) << "Already computed between " << instr1->ToString() << "\n vs "
1041 << instr2->ToString() << "\n";
1042 return saved_relation.second;
1043 }
1044 auto constraint = ordering_->GetExecutionConstraint(instr1, instr2);
1045 switch (constraint) {
1046 case HloOrdering::ExecutionConstraint::kIsSame:
1047 return Save(instr1, instr2, Relation::kSameInstr);
1048 case HloOrdering::ExecutionConstraint::kRunBeforeEnd:
1049 return Save(instr1, instr2, Relation::kBeforeStartOrSameInstr);
1050 case HloOrdering::ExecutionConstraint::kRunBeforeStart:
1051 return Save(instr1, instr2, Relation::kBeforeStart);
1052 case HloOrdering::ExecutionConstraint::kRunAfter:
1053 return Save(instr1, instr2, Relation::kAfterEnd);
1054 case HloOrdering::ExecutionConstraint::kRunExclusiveBefore:
1055 case HloOrdering::ExecutionConstraint::kRunExclusiveAfter:
1056 return Save(instr1, instr2, Relation::kNoOverlap);
1057 case HloOrdering::ExecutionConstraint::kUnordered: {
1058 if (instr1->parent() != instr2->parent()) {
1059 return Relation::kBeforeStartOrAfterEnd;
1060 }
1061 auto ControlDependenceBefore = [&](HloInstruction* op1,
1062 HloInstruction* op2) {
1063 auto constraint = ComputeRuntimeOrdering(op1, op2);
1064 if (constraint == Relation::kBeforeStart ||
1065 constraint == Relation::kSameInstr ||
1066 constraint == Relation::kBeforeStartOrSameInstr) {
1067 return true;
1068 } else {
1069 return false;
1070 }
1071 };
1072 if (!ctrl_deps_.empty()) {
1073 auto ctrl_deps = ctrl_deps_[instr1->parent()];
1074 if (absl::c_any_of(ctrl_deps[instr2], [&](HloInstruction* pred2) {
1075 return ControlDependenceBefore(instr1, pred2);
1076 })) {
1077 VLOG(2) << "control-dependent: " << instr1->ToString() << "\n";
1078 VLOG(2) << "vs " << instr2->ToString() << "\n";
1079 return Save(instr1, instr2, Relation::kBeforeStart);
1080 } else if (absl::c_any_of(
1081 ctrl_deps[instr1], [&](HloInstruction* pred1) {
1082 return ControlDependenceBefore(instr2, pred1);
1083 })) {
1084 VLOG(2) << "control-dependent: " << instr2->ToString() << "\n";
1085 VLOG(2) << "vs " << instr1->ToString() << "\n";
1086 return Save(instr1, instr2, Relation::kAfterEnd);
1087 }
1088 }
1089 // Don't save the result for unordered operations, so they can be
1090 // refined later.
1091 return Relation::kBeforeStartOrAfterEnd;
1092 }
1093 }
1094 }
1095
1096 HloOrdering* ordering_;
1097 absl::flat_hash_map<
1098 HloInstruction*,
1099 absl::flat_hash_map<HloInstruction*, Relation::RuntimeOrder>>
1100 saved_relations_;
1101 absl::flat_hash_map<
1102 HloComputation*,
1103 absl::flat_hash_map<HloInstruction*, std::vector<HloInstruction*>>>
1104 ctrl_deps_;
1105 };
1106 } // namespace
1107
1108 // Class which tracks the HLO values within each HLO buffer in the module
1109 // during copy removal.
1110 //
1111 // The values are held in a linked list where there is one list for each
1112 // buffer. Removing a copy instruction merges the values in the
1113 // source buffer of the copy into the destination buffer of the copy. This class
1114 // tracks these value lists as copies are removed from the graph (and value
1115 // lists are merged).
1116 //
1117 // The CopyRemover object is initialized to match the state of
1118 // HloAliasAnalysis. However, as copies are removed this state diverges. The
1119 // values-to-buffer mapping is maintained outside of HloAliasAnalysis because
1120 // a fully updatable alias analysis is very slow.
1121 class CopyRemover {
1122 public:
1123 // The values held in a single HLO buffer are represented using a linked
1124 // list. An element type in this list is ValueNode.
1125 //
1126 // This linked list is hand-rolled to enable efficient splicing of lists
1127 // using only references to list elements without knowing which lists are
1128 // being spliced. std::list requires a reference to the list object to
1129 // splice.
1130 struct ValueNode {
1131 explicit ValueNode(const HloValue* v) : value(v) {}
1132
1133 const HloValue* value;
1134
1135 // The uses are maintained outside of HloValue::uses() because
1136 // HloValue::uses() is not updatable (a fully updatable dataflow analysis
1137 // is slow).
1138 std::vector<const HloUse*> uses;
1139
1140 // next/prev elements in the linked list. The list is circularly linked so
1141 // these values are never null for elements in the list.
1142 ValueNode* prev = nullptr;
1143 ValueNode* next = nullptr;
1144 };
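// Illustrative sketch (conceptual, not code from this file): because the list
// is circular and hand-rolled, eliding a copy can splice the source buffer's
// value list into the destination buffer's list using only the nodes
// themselves, roughly
//
//   src_tail->next = dest_head;  dest_head->prev = src_tail;
//   dest_tail->next = src_head;  src_head->prev = dest_tail;
//
// where the src_/dest_ head/tail names are hypothetical node pointers. This is
// why a hand-rolled list is used instead of std::list, which would require a
// reference to the list object to splice.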
1145
1146 CopyRemover(const HloModule& module, const HloAliasAnalysis& alias_analysis,
1147 HloOrdering* ordering, bool check_live_range_ordering)
1148 : dataflow_(alias_analysis.dataflow_analysis()), ordering_(ordering) {
1149 // Construct a list for each HLO buffer in the alias analysis. Maintain a
1150 // map from HloValue to the respective list element representing that
1151 // value. The map is used to construct the copy info map below.
1152 absl::flat_hash_map<const HloValue*, ValueNode*> value_to_node;
1153 // Perform check only if the default dependence-based ordering is used.
1154 for (const HloBuffer& buffer : alias_analysis.buffers()) {
1155 // No copies should have been inserted within fused computations, so no
1156 // need to remove them. HloOrdering isn't compatible with HloValues inside
1157 // fusions, so skip copy removal for them.
1158 if (buffer.values().at(0)->defining_instruction()->IsFused()) {
1159 continue;
1160 }
1161 if (check_live_range_ordering) {
1162 // Verify values contained in the buffer are strictly ordered. This
1163 // should always be the case after adding copies to eliminate
1164 // interference. Specifically, the addition of the control flow edges
1165 // between copies added around aliased operations (kWhile) guarantees
1166 // this strict order.
1167 for (const HloValue* value_a : buffer.values()) {
1168 if (value_a->shape().IsToken()) {
1169 // Token values have no representation and cannot interfere.
1170 continue;
1171 }
1172 for (const HloValue* value_b : buffer.values()) {
1173 if (value_a != value_b) {
1174 DCHECK(ordering_->LiveRangeStrictlyBefore(
1175 *value_a, *value_b, dataflow_,
1176 /*use_is_always_before_def_in_same_instr=*/true) ||
1177 ordering_->LiveRangeStrictlyBefore(
1178 *value_b, *value_a, dataflow_,
1179 /*use_is_always_before_def_in_same_instr=*/true))
1180 << value_a->ToString() << " and " << value_b->ToString()
1181 << " are not ordered";
1182 }
1183 }
1184 }
1185 }
1186
1187 std::vector<const HloValue*> values = buffer.values();
1188 absl::c_sort(values, [this](const HloValue* a, const HloValue* b) {
1189 return ordering_->IsDefinedBefore(*a, *b);
1190 });
1191
1192 // Create a list containing all of the values in the buffer.
1193 AddValueList(values, &value_to_node);
1194 }
1195
1196 // Create copy_map_ which contains the source and destination values
1197 // of all copies.
1198 CreateCopyMap(module, value_to_node);
1199
1200 XLA_VLOG_LINES(3, ToString());
1201 TF_DCHECK_OK(Verify());
1202 }
1203
1204 // Add a list containing the given values to CopyRemover. This
1205 // represents the values contained in a single buffer. For each value in
1206 // 'values' an entry is created in value_to_node which indicates the
1207 // respective ValueNode representing that value.
1208 void AddValueList(
1209 absl::Span<const HloValue* const> values,
1210 absl::flat_hash_map<const HloValue*, ValueNode*>* value_to_node) {
1211 ValueNode* tail = nullptr;
1212 ValueNode* head = nullptr;
1213 for (const HloValue* value : values) {
1214 auto new_node = new ValueNode(value);
1215 (*value_to_node)[value] = new_node;
1216
1217 // Copy the HLO value's uses into the ValueNode for the value. These
1218 // uses in ValueNode are updated as copies are removed.
1219 new_node->uses.reserve(value->GetUses().size());
1220 for (const HloUse& use : value->GetUses()) {
1221 new_node->uses.push_back(&use);
1222 }
1223
1224 // Connect the new node into the linked list.
1225 if (tail == nullptr) {
1226 head = new_node;
1227 } else {
1228 tail->next = new_node;
1229 new_node->prev = tail;
1230 }
1231 tail = new_node;
1232 }
1233
1234 // The linked list is circular so connect the head and tail.
1235 tail->next = head;
1236 head->prev = tail;
1237 value_lists_.insert(head);
1238 }
1239
1240 // This method also fills in copy_map_ which indicates which nodes
1241 // in the value lists correspond to the source and destination values of
1242 // kCopy instructions. value_to_node should map each HloValue to its
1243 // respective ValueNode.
1244 void CreateCopyMap(
1245 const HloModule& module,
1246 const absl::flat_hash_map<const HloValue*, ValueNode*>& value_to_node) {
1247 for (HloComputation* computation : module.MakeNonfusionComputations()) {
1248 for (HloInstruction* instruction : computation->instructions()) {
1249 // Add copies with unambiguous source values to the map. Copies with
1250 // ambiguous sources are not removable.
1251 if (instruction->opcode() == HloOpcode::kCopy) {
1252 const HloValueSet& src_value_set =
1253 dataflow_.GetValueSet(instruction->operand(0));
1254 if (src_value_set.values().size() == 1) {
1255 CopyNodes& copy_node = copy_map_[instruction];
1256 copy_node.dest =
1257 value_to_node.at(&dataflow_.GetUniqueValueAt(instruction));
1258 copy_node.src = value_to_node.at(&src_value_set.GetUniqueValue());
1259 }
1260 }
1261 }
1262 }
1263 }
1264
1265 ~CopyRemover() {
1266 for (const ValueNode* head : value_lists_) {
1267 const ValueNode* p = head;
1268 do {
1269 const ValueNode* tmp = p->next;
1270 delete p;
1271 p = tmp;
1272 } while (p != head);
1273 }
1274 }
1275
1276 // Verify invariants within the linked lists.
1277 Status Verify() const {
1278 for (const ValueNode* head : value_lists_) {
1279 const ValueNode* p = head;
1280 do {
1281 // Verify links between elements are consistent.
1282 TF_RET_CHECK(p->prev->next == p);
1283 TF_RET_CHECK(p->next->prev == p);
1284
1285 const HloInstruction* def = p->value->defining_instruction();
1286 if (def->opcode() == HloOpcode::kCopy && ContainsKey(copy_map_, def)) {
1287 TF_RET_CHECK(copy_map_.at(def).dest == p);
1288 }
1289 for (const HloUse* use : p->uses) {
1290 if (use->instruction->opcode() == HloOpcode::kCopy &&
1291 ContainsKey(copy_map_, use->instruction)) {
1292 TF_RET_CHECK(copy_map_.at(use->instruction).src == p);
1293 }
1294 }
1295
1296 p = p->next;
1297 } while (p != head);
1298 }
1299 return OkStatus();
1300 }
1301
1302 // Compute the set of instructions where values are alive and organize these
1303 // instructions by separating them into their respective computations.
1304 LiveRangeRegions ComputeLiveRangeRegions(const ValueNode* head) {
1305 LiveRangeRegions live_range;
1306
1307 auto VisitValueNode = [&](const ValueNode* node) {
1308 HloInstruction* def_op = node->value->instruction();
1309 HloComputation* def_parent = def_op->parent();
1310 live_range[def_parent][def_op].is_definition = true;
1311 for (const auto& use : node->uses) {
1312 auto* use_op = use->instruction;
1313 HloComputation* use_parent = use_op->parent();
1314 live_range[use_parent][use_op].value_definition = def_op;
1315 }
1316 };
1317 ForEachValueInRange(head, VisitValueNode);
1318 return live_range;
1319 }
1320
1321 // Try to elide the given copy. Elision of a copy is possible only if no
1322 // live range interference is introduced by the copy's elimination. If
1323 // elision is possible, then the internal state (value lists) is updated,
1324 // and true is returned. Returns false otherwise.
1325 bool TryElideCopy(const HloInstruction* copy,
1326 int64_t* region_analysis_limit) {
1327 VLOG(2) << "Trying to remove " << copy->name();
1328 CHECK_NE(region_analysis_limit, nullptr);
1329
1330 if (!ContainsKey(copy_map_, copy)) {
1331 VLOG(2) << copy->name() << " is not removable";
1332 return false;
1333 }
1334 if (!ShapeUtil::Equal(copy->shape(), copy->operand(0)->shape())) {
1335 VLOG(2) << copy->name() << " is not removable (shape mismatch)";
1336 return false;
1337 }
1338 const CopyNodes& copy_node = copy_map_.at(copy);
1339 DCHECK(copy_node.src != nullptr);
1340 DCHECK(copy_node.dest != nullptr);
1341
1342 int64_t live_range_size1 = 0, live_range_size2 = 0;
1343 ForEachValueInRange(copy_node.src, [&](const ValueNode* node) {
1344 live_range_size1 += 1 + node->uses.size();
1345 });
1346 ForEachValueInRange(copy_node.dest, [&](const ValueNode* node) {
1347 live_range_size2 += 1 + node->uses.size();
1348 });
1349 // Use the more accurate region-based live range interference analysis if
1350 // the live range size is within a given limit (or if no limit is given).
1351 // Also don't use the new analysis for copies of broadcasts as these copies
1352 // are cheap and are later removed by replicating the broadcasts.
1353 bool use_region_analysis =
1354 copy->operand(0)->opcode() != HloOpcode::kBroadcast &&
1355 (*region_analysis_limit < 0 ||
1356 live_range_size1 * live_range_size2 <= *region_analysis_limit);
1357 *region_analysis_limit = 0;
1358 VLOG(3) << copy->name() << " copies value "
1359 << copy_node.src->value->ToShortString();
1360 VLOG(3) << "Source buffer values: " << ValueListToString(copy_node.src);
1361 VLOG(3) << "Dest buffer values: " << ValueListToString(copy_node.dest);
1362     // Checks that the live range of every value at or before 'src' ends before the live range of every value at or after 'dest' begins.
1363 auto CheckLiveRangeBefore = [&](ValueNode* src, ValueNode* dest) {
1364 for (ValueNode* next_dest = dest; next_dest != nullptr;
1365 next_dest = Next(*next_dest)) {
1366 for (ValueNode* prev_src = src; prev_src != nullptr;
1367 prev_src = Prev(*prev_src)) {
1368 if (!LiveRangeBefore(*prev_src, *next_dest)) {
1369 VLOG(2) << "Live range of " << prev_src->value->ToShortString()
1370 << " is not before " << next_dest->value->ToShortString();
1371 return false;
1372 }
1373 }
1374 }
1375 return true;
1376 };
1377 auto CheckLiveRangeInterference = [&](ValueNode* src, ValueNode* dest,
1378 const CombineLiveRangeOption option) {
1379 CHECK_NE(src, nullptr);
1380 CHECK_NE(dest, nullptr);
1381 if (!use_region_analysis) {
1382 VLOG(2) << "Configured to not use region-based analysis.\n";
1383 return true;
1384 }
1385 *region_analysis_limit += live_range_size1 * live_range_size2;
1386 if (ValuesInterfere(src, dest, option)) {
1387 VLOG(2) << "Region-based interference is true. \n";
1388 return true;
1389 }
1390 VLOG(2) << "Region-based interference is false. \n";
1391 return false;
1392 };
1393
1394 // A kCopy instruction copies an HLO value from a source buffer and
1395 // defines an HLO value in a destination buffer. Most generally, the
1396 // source and destination buffers may each hold more than one value at
1397 // different points in the computation so we define the following:
1398 //
1399 // Values in source buffer: {s_0, ..., s_n}
1400 // Values in destination buffer: {d_0, ..., d_m}
1401 //
1402 // A kCopy instruction between these buffers copies a value s_x in the
1403 // source buffer and defines a value d_y in the destination buffer. The
1404 // elision of a copy merges the source and destination buffers together,
1405 // so the list of values for the source and destination buffers are
1406 // merged.
1407 //
1408 // We handle two different cases for copy elision:
1409 //
1410 // (1) the kCopy defines the first value in the destination buffer (d_0).
1411 //
1412 // (2) the kCopy copies the last value in the source buffer (s_n).
1413 //
1414 // For the remaining case where the kCopy copies a not-last value from the
1415 // source buffer to a not-first value of the destination buffer, the kCopy
1416 // instruction cannot be removed. This case is generated, for example, if
1417 // the kCopy copies a while body parameter of the loop state at one tuple
1418 // index to a different tuple index in the while body root. Removal of the
1419 // copy necessarily results in live range interference of values in the
1420 // loop state at the two different tuple indices.
1421 //
1422 // We can only perform copy elision if the resulting merged values have
1423 // totally ordered live ranges; otherwise the merged buffer would have
1424 // live range interference.
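    // For example (case (1)): if the source buffer holds {s_0, s_1, s_2} and
    // the copy reads s_1 to define d_0 in a destination buffer holding
    // {d_0, d_1}, eliding the copy merges the lists into {s_0, s_1, d_1, s_2}.
    // This is legal only if the definition and all uses of s_1 come before the
    // definition of d_1, and the live range of d_1 ends before the live range
    // of s_2 begins.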
1425 if (copy_node.src->next == copy_node.dest) {
1426       // In the process of eliding copies, it's possible for a copy to have the
1427 // same source and destination buffer. In this case, the copy can be
1428 // safely removed.
1429 VLOG(2) << copy->name() << " source and destination buffers are same.";
1430 } else if (IsHead(*copy_node.dest)) {
1431 // The copy copies an arbitrary value in the source buffer (call it s_x)
1432 // and defines d_0, the first value in the destination buffer. After
1433 // merging, the values in the combined buffer must be strictly ordered
1434 // as follows** to elide the copy:
1435 //
1436 // {s_0, ..., s_x, d_1, ..., d_m, s_{x+1}, ..., s_n}
1437 //
1438 // Removing the copy eliminates d_0, and uses of d_0 become uses of
1439 // s_x. In the above ordering, the live range of d_m will be ordered
1440 // before the live range of s_{x+1} and the definition and all uses of
1441 // s_x will be ordered before the definition of d_1. To make sure the
1442 // copy elision is safe, the following code checks that this ordering is
1443 // valid --- in particular we check it is safe to order d_m ahead of all
1444       // the live ranges at and after s_{x+1}, and it is safe to order all uses
1445 // of s_x before the definition of d_1, by checking the live range
1446 // constraints for each pair --- we cannot skip the later checks because
1447       // the live range ordering is not guaranteed to be transitive --- while it
1448       // may be ok to have lr_1 before lr_2, and lr_2 before lr_3 while merging
1449       // their buffers, it may not be ok to merge the buffers of lr_1 and lr_3,
1450 // because the exclusiveness relation of non-overlapping computations is
1451 // not transitive.
1452 //
1453 // ** Technically it might be possible to have a non-interfering
1454 // non-trivial interleaving of the values of the source and
1455 // destination buffers in the resulting order. This can be potentially
1456 // supported in the ValuesInterfere function, which performs
1457 // interference analysis at a more global scope than the alternative
1458 // LiveRangeBefore analysis which requires strict ordering of all live
1459 // ranges. Currently, however, this is not yet supported, as
1460 // we simply check for the case where *all* values of the destination
1461 // buffer (d_1 through d_m) are spliced into the point where the copy
1462 // used to be.
1463 VLOG(2) << copy->name() << " defines the first value in its buffer";
1464 bool live_range_before =
1465 // Live range of (s_x, s_{x-1},...) must be before 'next_dest' (d_1);
1466 CheckLiveRangeBefore(copy_node.src, Next(*copy_node.dest)) &&
1467 // Live range of 'last_dest' (d_m) must be before 'next_src' s_{x+1}.
1468 CheckLiveRangeBefore(copy_node.dest->prev, Next(*copy_node.src));
1469 VLOG(2) << "LiveRangeBefore result: " << live_range_before << "\n";
1470 if (!live_range_before &&
1471 CheckLiveRangeInterference(copy_node.src, copy_node.dest,
1472 kMergeFirstDestInSource)) {
1473 return false;
1474 }
1475 VLOG(2) << "Splice dest after source.";
1476 // Splice in destination buffer values list right after 'src'.
1477 SpliceAfter(copy_node.dest, copy_node.src);
1478 } else if (IsTail(*copy_node.src)) {
1479 // The copy copies the last value in the source buffer, s_n, and defines
1480 // an arbitrary value in the destination buffer, d_y. After
1481 // merging, the values in the combined buffer must be strictly ordered
1482 // as follows** to elide the copy:
1483 //
1484 // {d_0, ..., d_{y-1}, s_0, ..., s_n, d_{y+1}, ..., d_m}
1485 //
1486 // Removing the copy eliminates d_y, and uses of d_y become uses of
1487 // s_n. To enforce the above order, the live range of d_{y-1} must be
1488 // before the live range of s_0, and the live range of s_n must be
1489 // before the live range of d_{y+1}.
1490 //
1491 // ** See comment above in the code handling Case (1).
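      // For example: if the destination buffer holds {d_0, d_1, d_2} and the
      // copy reads s_n (the last source value) to define d_1, eliding the copy
      // merges the lists into {d_0, s_0, ..., s_n, d_2}. This is legal only if
      // the live range of d_0 ends before s_0 begins and the live range of s_n
      // ends before d_2 begins.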
1492 VLOG(2) << copy->name() << " copies the last value ("
1493 << copy_node.src->value->ToShortString() << ") in its buffer";
1494 bool live_range_before =
1495 // Live range of d_0, ..., d_{y-1} must be before s_0;
1496 CheckLiveRangeBefore(Prev(*copy_node.dest), copy_node.src->next) &&
1497 // Live range of 'last_src' must be before next_dest d_{y+1}.
1498 CheckLiveRangeBefore(copy_node.src, Next(*copy_node.dest));
1499 VLOG(2) << "LiveRangeBefore result: " << live_range_before << "\n";
1500 if (!live_range_before &&
1501 CheckLiveRangeInterference(copy_node.src, copy_node.dest,
1502 kMergeLastSourceInDest)) {
1503 VLOG(2) << "Region-based analysis concludes interference.\n";
1504 return false;
1505 }
1506 VLOG(2) << "Splice src after prev of dest.";
1507 // Splice source buffer values list right after 'prev_dest'.
1508 SpliceAfter(copy_node.src->next, Prev(*copy_node.dest));
1509 } else {
1510 VLOG(2) << copy->name()
1511 << " copies value in middle of source buffer to value in middle "
1512 "of destination buffer";
1513 return false;
1514 }
1515
1516 RemoveCopyValue(copy_node.dest);
1517
1518 XLA_VLOG_LINES(4, ToString());
1519 TF_DCHECK_OK(Verify());
1520
1521 return true;
1522 }
1523
1524   // Delete the given ValueNode associated with an elided kCopy
1525 // instruction. This should be called after splicing the value lists of the
1526 // source and destination buffers together.
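  // Concretely, if the merged list is {..., operand, copy, next, ...}, this
  // unlinks 'copy' from the list, drops the copy's use from the operand's use
  // list, transfers the copy's remaining uses to the operand, and updates
  // copy_map_ entries for any dependent kCopy instructions.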
1527   void RemoveCopyValue(ValueNode* copy_value_node) {
1528 CHECK_EQ(copy_value_node->value->defining_instruction()->opcode(),
1529 HloOpcode::kCopy);
1530 ValueNode* operand_node = copy_value_node->prev;
1531 CHECK(operand_node != copy_value_node);
1532
1533 VLOG(2) << "Removing copy " << operand_node->value->ToShortString()
1534 << " => " << copy_value_node->value->ToShortString();
1535
1536 // Splice out the copy value node.
1537 operand_node->next = copy_value_node->next;
1538 copy_value_node->next->prev = operand_node;
1539
1540 // Patch up uses. Remove use of copy from operand_node uses.
1541 auto it = absl::c_find_if(operand_node->uses, [copy_value_node](
1542 const HloUse* use) {
1543 return use->instruction == copy_value_node->value->defining_instruction();
1544 });
1545 CHECK(it != operand_node->uses.end());
1546 operand_node->uses.erase(it);
1547
1548 // If the elided copy has any uses which are themselves kCopy instructions
1549     // then patch up the copy info to reflect the fact that this kCopy instruction
1550 // has a different operand (the operand of the elided copy).
1551 for (const HloUse* copy_use : copy_value_node->uses) {
1552 operand_node->uses.push_back(copy_use);
1553 if (copy_use->instruction->opcode() == HloOpcode::kCopy &&
1554 ContainsKey(copy_map_, copy_use->instruction)) {
1555 copy_map_.at(copy_use->instruction).src = operand_node;
1556 }
1557 }
1558
1559 // Delete the copy info and the value node.
1560 copy_map_.erase(copy_value_node->value->defining_instruction());
1561 delete copy_value_node;
1562 }
1563
1564 // Returns true if the live range of given value 'a' is before the live
1565 // range of 'b'.
1566 //
1567 // We cannot use LiveRangeStrictlyBefore because HloValue::uses() is not
1568   // updated as copies are removed. Also, because the result here is used to
1569   // directly drive copy elision, use_is_always_before_def_in_same_instr is
1570   // set to false.
1571   bool LiveRangeBefore(const ValueNode& a, const ValueNode& b) {
1572 if (a.uses.empty()) {
1573 VLOG(2) << "Empty uses for " << *a.value;
1574 return ordering_->IsDefinedBefore(*a.value, *b.value);
1575 }
1576 VLOG(3) << "Checking live ranges before :" << ValueListToString(&a)
1577 << " vs " << ValueListToString(&b) << "\n";
1578 // If any of the positions of the "a" value is a root of the same
1579 // computation as "b", "a"'s live range cannot be before "b"'s. This catches
1580 // the cases where the root may not be the last instruction in the
1581 // computation.
1582 if (a.value->IsRootOf(b.value->defining_instruction()->parent())) {
1583 VLOG(3) << "Value is root of the same computation";
1584 return false;
1585 }
1586 return ordering_->UsesBeforeValueDefinition(
1587 a.uses, *b.value, dataflow_,
1588 /* use_is_always_before_def_in_same_instr=*/false);
1589 }
1590
1591 // Returns whether 'node' is the last node in its list.
1592   bool IsTail(const ValueNode& node) const {
1593 return ContainsKey(value_lists_, node.next);
1594 }
1595
1596 // Returns whether 'node' is the first node in its list.
1597   bool IsHead(const ValueNode& node) const {
1598 return ContainsKey(value_lists_, &node);
1599 }
1600
1601 // Returns the next node in the list after 'node'. If 'node' is the
1602 // tail, then nullptr is returned.
1603   ValueNode* Next(const ValueNode& node) const {
1604 if (IsTail(node)) {
1605 return nullptr;
1606 } else {
1607 return node.next;
1608 }
1609 }
1610
1611 // Returns the previous node in the list before 'node'. If 'node'
1612 // is the head, then nullptr is returned.
1613   ValueNode* Prev(const ValueNode& node) const {
1614 if (IsHead(node)) {
1615 return nullptr;
1616 } else {
1617 return node.prev;
1618 }
1619 }
1620
1621 // Splices the entire linked list with 'head' as its head right after the
1622 // node 'insert_after' in another linked list.
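  // For example, splicing the list {b_0, ..., b_k} (with head b_0) after node
  // a_i of the list {a_0, ..., a_n} yields {a_0, ..., a_i, b_0, ..., b_k,
  // a_{i+1}, ..., a_n}. The lists are circular, so the tail of the spliced-in
  // list is reached via head->prev.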
1623   void SpliceAfter(ValueNode* head, ValueNode* insert_after) {
1624 DCHECK(IsHead(*head));
1625 value_lists_.erase(head);
1626
1627 ValueNode* tail = head->prev;
1628 tail->next = insert_after->next;
1629 insert_after->next->prev = tail;
1630
1631 insert_after->next = head;
1632 head->prev = insert_after;
1633 }
1634
1635 enum CombineLiveRangeOption {
1636 kMergeFirstDestInSource = 1,
1637 kMergeLastSourceInDest = 2
1638 };
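  // kMergeFirstDestInSource corresponds to case (1) in TryElideCopy, where the
  // copy defines the first value of the destination buffer;
  // kMergeLastSourceInDest corresponds to case (2), where the copy reads the
  // last value of the source buffer.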
1639 // This function analyzes all the HloValues that have been grouped together
1640 // with src to share a single buffer, and all the HloValues that have been
1641   // similarly grouped together with dest, to determine whether the two groups
1642   // can be combined by removing the copy operation that defines dest and reads
1643   // from the buffer of src.
1644   bool ValuesInterfere(const ValueNode* src, const ValueNode* dest,
1645 CombineLiveRangeOption merge_location) {
1646 // Get the entire range of values sharing the buffers in src and dest.
1647 auto src_live_range = ComputeLiveRangeRegions(src);
1648 auto dest_live_range = ComputeLiveRangeRegions(dest);
1649 ComputeRelativeLocation relative_location_analysis(ordering_);
1650 auto rel1 =
1651 relative_location_analysis.Compute(src_live_range, dest_live_range);
1652 VLOG(3) << "Location of dest in relation to src:" << rel1.ToString()
1653 << " with interception set to " << rel1.InterceptDefUse() << "\n";
1654 auto rel2 =
1655 relative_location_analysis.Compute(dest_live_range, src_live_range);
1656 VLOG(3) << "Location of src in relation to dest:" << rel2.ToString()
1657 << " with interception set to " << rel1.InterceptDefUse() << "\n";
1658 // If src and dest are interleaved with each other, they interfere.
1659 if (rel1.RuntimeOrderOverlap() && rel2.RuntimeOrderOverlap()) {
1660 VLOG(3) << "Both relations are overlap.\n";
1661 return true;
1662 }
1663 // If src and dest belong to the same group of computations and do not
1664 // overlap, they do not interfere.
1665 if (rel1.RuntimeOrderOverlap() || rel2.RuntimeOrderOverlap()) {
1666 VLOG(3) << "At least one relation is overlap.\n";
1667 if (rel1.RuntimeOrderOverlap()) {
1668 VLOG(3) << "rel1 is overlap, with interception = "
1669 << rel1.InterceptDefUse() << "\n";
1670 if (rel1.InterceptDefUse() ||
1671 (merge_location != kMergeFirstDestInSource &&
1672 rel2.InterceptDefUse())) {
1673 return true;
1674 }
1675 } else {
1676 VLOG(3) << "rel2 is overlap, with interception = "
1677 << rel2.InterceptDefUse() << "\n";
1678 // Here src is at the end of a nested computation inside dest.
1679 if (rel2.InterceptDefUse() ||
1680 (merge_location != kMergeLastSourceInDest &&
1681 rel1.InterceptDefUse())) {
1682 return true;
1683 }
1684 }
1685 }
1686 if (relative_location_analysis.AddControlDependenceForUnorderedOps()) {
1687 return false;
1688 } else {
1689 // Disallow removing of copy if control deps cannot be added.
1690 return true;
1691 }
1692 }
1693
1694   // Visit each value in the list containing 'element', starting from 'element'.
1695   // If 'element' is not the head, traverse from 'element' to the tail, then wrap
1696   // around from the head. The visitation order matters for live range region analysis.
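  // For example, for a list {a, b, c, d} and element == c, the visitation
  // order is c, d, a, b.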
1697   void ForEachValueInRange(const ValueNode* element,
1698 std::function<void(const ValueNode*)> visitor) {
1699 const ValueNode* head = element;
1700 std::vector<const ValueNode*> values;
1701 for (const ValueNode* p = head; p != nullptr; p = Next(*p)) {
1702 visitor(p);
1703 }
1704 while (!IsHead(*head)) {
1705 head = Prev(*head);
1706 }
1707 for (const ValueNode* p = head; p != element; p = Next(*p)) {
1708 visitor(p);
1709 }
1710 }
1711
1712   std::string ValueListToString(const ValueNode* element) {
1713 std::string result = "{";
1714 auto VisitValueNode = [&](const ValueNode* node) {
1715 if (result == "{") {
1716         StrAppend(&result, node->value->ToShortString());
1717 } else {
1718 StrAppend(&result, ", ");
1719 StrAppend(&result, node->value->ToShortString());
1720 }
1721 };
1722     ForEachValueInRange(element, VisitValueNode);
1723 StrAppend(&result, "}");
1724 return result;
1725 }
1726
1727   std::string ToString() const {
1728 std::string out = absl::StrCat("CopyRemover:\n");
1729 StrAppend(&out, " Def-use chains in each buffer:\n");
1730 for (const ValueNode* head : value_lists_) {
1731 StrAppend(&out, " Buffer defined by ", head->value->ToShortString(),
1732 ":\n");
1733 const ValueNode* p = head;
1734 do {
1735 StrAppend(&out, " ", p->value->ToShortString(), ", uses: ",
1736 absl::StrJoin(p->uses, "; ",
1737 [](std::string* s, const HloUse* use) {
1738 StrAppend(s, use->ToString());
1739 }),
1740 "\n");
1741
1742 p = p->next;
1743 } while (p != head);
1744 }
1745 StrAppend(&out, " Potentially removable copies:\n");
1746 for (const auto& pair : copy_map_) {
1747 const HloInstruction* copy = pair.first;
1748 const CopyNodes& copy_info = pair.second;
1749
1750 StrAppend(&out, " ", copy->name(), " : ",
1751 copy_info.src->value->ToShortString(), " => ",
1752 copy_info.dest->value->ToShortString(), "\n");
1753 }
1754 return out;
1755 }
1756
1757 private:
1758 const HloDataflowAnalysis& dataflow_;
1759 HloOrdering* ordering_;
1760
1761 // The heads of all the value lists. Each value list represents the HLO
1762 // values contained in a particular HLO buffer. The values in the list are
1763 // in dependency order.
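  // The value lists are circular doubly-linked lists (head->prev is the tail),
  // so only the head node of each list needs to be stored here.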
1764 absl::flat_hash_set<const ValueNode*> value_lists_;
1765
1766 // Copy removal requires fast access to the value list elements
1767 // corresponding to the source and destination values of the kCopy
1768 // instruction. This data structure holds pointers to these elements for
1769 // each kCopy instruction in the graph.
1770 struct CopyNodes {
1771 // The source and destinations values of the kCopy instruction.
1772 ValueNode* src = nullptr;
1773 ValueNode* dest = nullptr;
1774 };
1775 absl::flat_hash_map<const HloInstruction*, CopyNodes> copy_map_;
1776 };
1777
1778 } // namespace
1779
1780 // We add copies for all non-phi indices of the true and false computation
1781 // roots, in order to resolve interference. We later rely on
1782 // RemoveUnnecessaryCopies to drop the unnecessary ones.
1783 Status CopyInsertion::AddCopiesForConditional(
1784 const HloAliasAnalysis& alias_analysis, HloInstruction* conditional) {
1785 VLOG(2) << "Adding copies for kConditional instruction "
1786 << conditional->name();
1787 ShapeTree<bool> indices_to_copy(conditional->shape());
1788 TF_RET_CHECK(conditional->opcode() == HloOpcode::kConditional);
1789 if (!IndicesToCopyForConditional(alias_analysis.dataflow_analysis(),
1790 conditional, &indices_to_copy)) {
1791 VLOG(2) << "No copies necessary for kWhile instruction "
1792 << conditional->name();
1793 return OkStatus();
1794 }
1795
1796 for (HloComputation* computation : conditional->branch_computations()) {
1797 HloInstruction* root = computation->root_instruction();
1798 std::vector<HloInstruction*> users = root->users();
1799 TF_ASSIGN_OR_RETURN(
1800 HloInstruction * deep_copy,
1801 computation->DeepCopyInstruction(root, &indices_to_copy));
1802 for (HloInstruction* user : users) {
1803 TF_RETURN_IF_ERROR(root->ReplaceUseWith(user, deep_copy));
1804 }
1805 computation->set_root_instruction(deep_copy);
1806 }
1807 return OkStatus();
1808 }
1809
1810 // Add kCopy instructions to the given module to guarantee there is no
1811 // live-range interference. Generally interference can only occur around kWhile
1812 // instructions which have update-in-place semantics.
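// Besides kWhile, copies are also added for kConditional branch roots, for the
// in-place operands of other instructions, and for aliased entry computation
// inputs/outputs (via AddCopiesForAliasedInputOutputs).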
1813 Status CopyInsertion::AddCopiesToResolveInterference(
1814 HloModule* module,
1815 const absl::flat_hash_set<absl::string_view>& execution_threads) {
1816 TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
1817 HloAliasAnalysis::Run(module, can_share_buffer_));
1818 for (HloComputation* computation :
1819 module->MakeNonfusionComputations(execution_threads)) {
1820 for (HloInstruction* instruction :
1821 computation->MakeInstructionPostOrder()) {
1822 if (instruction->opcode() == HloOpcode::kWhile) {
1823 TF_RETURN_IF_ERROR(AddCopiesForWhile(*alias_analysis, instruction));
1824 } else if (instruction->opcode() == HloOpcode::kConditional) {
1825 TF_RETURN_IF_ERROR(
1826 AddCopiesForConditional(*alias_analysis, instruction));
1827 } else {
1828         // When an operand is a tuple, we avoid copying the same operand multiple
1829         // times by recording the operand numbers that have already been copied
1830         // and skipping them.
1831 absl::flat_hash_set<int64_t> copied_operands;
1832 for (const auto& operand_and_output_index :
1833 HloDataflowAnalysis::GetInPlaceInputOutputPairs(instruction)) {
1834 const HloOperandIndex& operand_index = operand_and_output_index.first;
1835 if (copied_operands.contains(operand_index.operand_number)) {
1836 continue;
1837 }
1838 copied_operands.insert(operand_index.operand_number);
1839 TF_RETURN_IF_ERROR(AddCopiesForInPlaceOperation(
1840 *alias_analysis, instruction, operand_index.operand_number));
1841 }
1842 }
1843 }
1844 }
1845
1846 TF_RETURN_IF_ERROR(AddCopiesForAliasedInputOutputs(module));
1847 return OkStatus();
1848 }
1849
1850 Status CopyInsertion::AddSpecialCaseCopies(
1851 HloModule* module,
1852 const absl::flat_hash_set<absl::string_view>& execution_threads) {
1853 std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
1854 return AddSpecialCaseCopies(*call_graph, execution_threads, module);
1855 }
1856
1857 Status CopyInsertion::AddSpecialCaseCopies(
1858 const CallGraph& call_graph,
1859 const absl::flat_hash_set<absl::string_view>& execution_threads,
1860 HloModule* module) {
1861 TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
1862 HloAliasAnalysis::Run(module, can_share_buffer_));
1863
1864 // Identify which shape indices of which instructions need to be copied. Store
1865 // these results in 'instructions_to_copy'.
1866 HloInstructionMap<ShapeTree<bool>> instructions_to_copy;
1867 auto add_index_to_copy = [&instructions_to_copy](HloInstruction* instruction,
1868 const ShapeIndex& index) {
1869 auto it = instructions_to_copy.find(instruction);
1870 if (it == instructions_to_copy.end()) {
1871 auto it_added = instructions_to_copy.emplace(
1872 std::piecewise_construct, std::forward_as_tuple(instruction),
1873 std::forward_as_tuple(instruction->shape(), /*init_value=*/false));
1874 it = it_added.first;
1875 }
1876 *it->second.mutable_element(index) = true;
1877 };
1878
1879 // Iterate through values of all constants and entry parameters. These values
1880 // are special because they are held in read-only buffers. If any of these
1881 // values share a buffer with other values (for example, the init value of a
1882 // while is a constant) then copy the value at its definition and replace all
1883 // its uses with the copy.
1884 // Also, locate all input-output aliasing violations for operations that
1885 // cannot be done in place. Such aliasing can be created when some copies are
1886 // removed too aggressively by CopyRemoval.
1887 for (const HloValue* value : alias_analysis->dataflow_analysis().values()) {
1888 HloBuffer& buffer = alias_analysis->GetBufferContainingValue(*value);
1889 if (buffer.values().size() > 1 && ValueIsReadOnly(*value)) {
1890 VLOG(2) << "Value " << value->ToShortString()
1891 << " is read only, but its buffer contains more than one value. "
1892 "Copying.";
1893 add_index_to_copy(value->defining_instruction(), value->defining_index());
1894 }
1895 for (const HloValue* value2 : buffer.values()) {
1896 // Find HloValues that share a position and use, which would cause the use
1897 // and operand to share buffers. Check if this is allowed and insert a
1898 // copy if it isn't.
1899 if (value2 == value) {
1900 continue;
1901 }
1902 HloPosition position = value2->defining_position();
1903 for (const HloUse& use : value->GetUses()) {
1904 if (use.instruction == position.instruction) {
1905 VLOG(3) << "Same instruction: " << position.instruction->ToString();
1906 if (!alias_analysis->dataflow_analysis()
1907 .CanShareOperandBufferWithUser(
1908 /*operand=*/use.instruction->mutable_operand(
1909 use.operand_number),
1910 /*operand_index=*/use.operand_index,
1911 /*user=*/position.instruction,
1912 /*user_index=*/position.index)) {
1913 VLOG(2) << "Adding back copy: "
1914 << use.instruction->operand(use.operand_number)->ToString()
1915 << "@" << use.operand_index.ToString()
1916 << " instr: " << position.instruction->ToString() << "@"
1917 << position.index;
1918 add_index_to_copy(
1919 use.instruction->mutable_operand(use.operand_number),
1920 use.operand_index);
1921 }
1922 }
1923 }
1924 }
1925 }
1926
1927 // Identify copies which must be added at root instructions
1928 for (HloComputation* computation : module->computations(execution_threads)) {
1929 const CallGraphNode& node = call_graph.GetNode(computation);
1930 if (node.context() == CallContext::kEmbedded) {
1931 continue;
1932 }
1933 TF_RET_CHECK(node.context() == CallContext::kControlFlow);
1934
1935 SpecialCaseCopyPolicy policy =
1936 GetSpecialCaseCopyPolicy(node, module, computation);
1937 HloInstruction* root = computation->root_instruction();
1938
1939 // Mark nondistinct/ambiguous indices.
1940 absl::flat_hash_map<const HloBuffer*, ShapeIndex> seen;
1941 ShapeUtil::ForEachSubshape(
1942 root->shape(), [&](const Shape& /*subshape*/, const ShapeIndex& index) {
1943 std::vector<const HloBuffer*> buffers_at_index =
1944 alias_analysis->ComputeBuffersAt(root, index);
1945 bool buffer_seen_before = false;
1946 for (const HloBuffer* buffer : buffers_at_index) {
1947 buffer_seen_before |= !seen.emplace(buffer, index).second;
1948 }
1949
1950 if (buffer_seen_before && policy.copy_root_replicated_buffers &&
1951 computation == module->entry_computation() &&
1952 module->input_output_alias_config().OutputHasAlias(index) &&
1953 buffers_at_index.size() == 1) {
1954 std::optional<HloInputOutputAliasConfig::Alias> alias =
1955 module->input_output_alias_config().GetAliasedParameter(index);
1956 CHECK(alias) << "Alias does not exist";
1957 const ShapeIndex& other_index = seen[buffers_at_index[0]];
1958 VLOG(2) << "Output indices " << index.ToString() << " and "
1959 << other_index.ToString() << " are both aliased to "
1960 << alias->parameter_number << " copying " << other_index;
1961 add_index_to_copy(root, other_index);
1962 return;
1963 }
1964
1965 if (buffers_at_index.size() > 1 ||
1966 (buffer_seen_before && policy.copy_root_replicated_buffers)) {
1967 VLOG(2) << "Index " << index << " of computation "
1968 << computation->name() << " (" << root->name()
1969 << ") has ambiguous or non-distinct buffer. Copying.";
1970 add_index_to_copy(root, index);
1971 }
1972 });
1973
1974 for (const auto& pair :
1975 alias_analysis->dataflow_analysis().GetInstructionValueSet(root)) {
1976 const ShapeIndex& index = pair.first;
1977 const HloValueSet& value_set = pair.second;
1978 for (const HloValue* value : value_set.values()) {
1979 if (ShouldCopyRootValue(*value, policy)) {
1980 VLOG(2) << "Root of (" << root->name() << ") of computation("
1981 << computation->name()
1982 << ") has constant or parameter value at index " << index
1983 << ". Copying.";
1984 add_index_to_copy(root, index);
1985 }
1986 }
1987 }
1988 }
1989
1990 // Add copy instructions indicated in 'instructions_to_copy' to the module.
1991 for (const auto& pair : instructions_to_copy) {
1992 HloInstruction* instruction = pair.first;
1993 const ShapeTree<bool>& indices_to_copy = pair.second;
1994
1995 ShapeTree<HloInstruction*> copies_added(indices_to_copy.shape());
1996 std::vector<HloInstruction*> users = instruction->users();
1997 TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy,
1998 instruction->parent()->DeepCopyInstruction(
1999 instruction, &indices_to_copy, &copies_added));
2000 for (HloInstruction* user : users) {
2001 TF_RETURN_IF_ERROR(instruction->ReplaceUseWith(user, deep_copy));
2002 }
2003 if (instruction == instruction->parent()->root_instruction()) {
2004 instruction->parent()->set_root_instruction(deep_copy);
2005 }
2006 }
2007 return OkStatus();
2008 }
2009
2010 static int64_t GetNumExistingCopies(
2011 const HloModule* module,
2012 const absl::flat_hash_set<absl::string_view>& execution_threads) {
2013 int64_t num_existing_copies = 0;
2014 for (HloComputation* computation : module->computations(execution_threads)) {
2015 for (HloInstruction* instruction : computation->instructions()) {
2016 if (instruction->opcode() == HloOpcode::kCopy) {
2017 ++num_existing_copies;
2018 }
2019 }
2020 }
2021 return num_existing_copies;
2022 }
2023
2024 Status CopyInsertion::RemoveUnnecessaryCopies(
2025 HloOrdering* ordering, HloModule* module, bool check_live_range_ordering,
2026 const absl::flat_hash_set<absl::string_view>& execution_threads) {
2027 XLA_VLOG_LINES(4, module->ToString());
2028 TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
2029 HloAliasAnalysis::Run(module, can_share_buffer_));
2030 CopyRemover copy_remover(*module, *alias_analysis, ordering,
2031 check_live_range_ordering);
2032 if (VLOG_IS_ON(3)) {
2033 LOG(INFO) << "Removing unnecessary copies in " << module->name();
2034 LOG(INFO) << "Buffer values, in dependency order: ";
2035 for (const HloBuffer& buffer : alias_analysis->buffers()) {
2036 LOG(INFO) << " HloBuffer " << buffer.id();
2037 }
2038 }
2039
2040 std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
2041
2042 int64_t num_existing_copies = GetNumExistingCopies(module, execution_threads);
2043 bool changed = true;
2044 int64_t num_iterations = -1;
2045 VLOG(6) << "Copy Insertion analyzing module with instructino count = "
2046 << module->instruction_count() << "\n";
2047 BoundNonLinearCompilerAnalysis allowance(module, name(), 10);
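  // Fixpoint loop: each iteration that changes the module must elide at least
  // one copy, so the number of iterations is bounded by the number of copies
  // present before elision (enforced by the CHECK below).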
2048 while (changed) {
2049 CHECK_LE(++num_iterations, num_existing_copies);
2050 changed = false;
2051 VLOG(2) << "Running fixpoint iteration " << num_iterations
2052 << " of copy elision";
2053 for (HloComputation* computation :
2054 module->computations(execution_threads)) {
2055 VLOG(2) << "computation:" << computation->name() << "\n";
2056 for (HloInstruction* instruction : computation->instructions()) {
2057 VLOG(2) << instruction->ToString() << "\n";
2058         // If use_region_based_live_range_analysis_ is negative, then
2059         // region_analysis_cost_now is also negative (i.e. no cost limit), in
2060         // which case the region-based analysis is always performed.
2061 int64_t region_analysis_cost_now =
2062 (use_region_based_live_range_analysis_ == 0)
2063 ? 0
2064 : std::min(allowance.analysis_allowance(),
2065 use_region_based_live_range_analysis_);
2066 if (instruction->opcode() == HloOpcode::kCopy) {
2067 if (copy_remover.TryElideCopy(instruction,
2068                                         &region_analysis_cost_now)) {
2069 changed = true;
2070 TF_RETURN_IF_ERROR(StripControlDependenciesFrom(instruction));
2071 TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(
2072 instruction->mutable_operand(0)));
2073 VLOG(6) << "succeeded in eliminating copy.\n";
2074 }
2075 if (allowance.ContinueAnalysis() && region_analysis_cost_now > 0) {
2076 VLOG(6) << "Copy Insertion analyzing module cost: "
2077 << region_analysis_cost_now << "\n";
2078 VLOG(6) << "instruction:" << instruction->ToString() << "\n";
2079 allowance.DeductCost(region_analysis_cost_now);
2080 VLOG(6) << "allowance:" << allowance.analysis_allowance() << "\n";
2081 }
2082 }
2083 }
2084 }
2085 }
2086 return OkStatus();
2087 }
2088
2089 StatusOr<bool> CopyInsertion::Run(
2090 HloModule* module,
2091 const absl::flat_hash_set<absl::string_view>& execution_threads) {
2092 // Copy insertion is performed in three steps:
2093 //
2094 // (1) Add copies conservatively to guarantee that there is no live-range
2095 // interference. This is done simplistically and usually results in more
2096 // copies than is strictly necessary.
2097 //
2098 // (2) Using a more fine-grained analysis, remove as many copies that were
2099 // added in (1) as possible while ensuring no live-range interference.
2100 //
2101 // (3) Add copies to resolve issues not related to live range interference
2102 // such as parameters and constants live out of the entry computation.
2103 //
2104 // We add copies then remove them (step (1) then (2)) rather than simply
2105 // adding only the copies that are necessary because, in general, it is
2106 // difficult to figure out the minimal set of copies to add once there is
2107 // interference. On the other hand, it is easy to determine if removing a copy
2108 // will introduce interference.
2109 //
2110 // The final copy insertion in (3) is done separately to simplify the
2111 // implementation of copy removal in (2) which is the most complicated part of
2112 // the pass. As is, copy removal only has to reason about live range
2113 // interference. If all copies were added in step (1) then copy removal would
2114 // also have to reason about things like constants and parameters live out of
2115 // the computation.
2116 std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
2117 if (!call_graph->IsFlattened()) {
2118 return FailedPrecondition(
2119 "Call graph must be flattened before copy insertion.");
2120 }
2121
2122 int64_t num_copies_before = GetNumExistingCopies(module, execution_threads);
2123
2124 TF_RETURN_IF_ERROR(AddCopiesToResolveInterference(module, execution_threads));
2125
2126 // Simplify the tuple structures introduced by the deep copies. This should be
2127 // done before removing copies (RemoveUnnecessaryCopies) because tuple
2128 // simplification changes dependencies in the graph which changes live range
2129 // interference in the graph. Also run DCE to remove the dead Tuple/GTE
2130 // instructions introduced by tuple simplification.
2131 TupleSimplifier tuple_simplifier;
2132 HloDCE dce;
2133 TF_RETURN_IF_ERROR(tuple_simplifier.Run(module, execution_threads).status());
2134 TF_RETURN_IF_ERROR(dce.Run(module, execution_threads).status());
2135 DumpHloModuleDuringPassIfEnabled(
2136 name(), "after adding copies to resolve interference", *module);
2137
2138 DependencyHloOrdering ordering(module);
2139 TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(&ordering, module,
2140 /*check_live_range_ordering=*/true,
2141 execution_threads));
2142 DumpHloModuleDuringPassIfEnabled(name(), "after removing unnecessary copies",
2143 *module);
2144 TF_RETURN_IF_ERROR(
2145 AddSpecialCaseCopies(*call_graph, execution_threads, module));
2146 DumpHloModuleDuringPassIfEnabled(name(), "after adding special-case copies",
2147 *module);
2148
2149 TF_RETURN_IF_ERROR(tuple_simplifier.Run(module, execution_threads).status());
2150 TF_RETURN_IF_ERROR(dce.Run(module, execution_threads).status());
2151
2152 VLOG(1) << "Num copies before copy-insertion: " << num_copies_before;
2153 VLOG(1) << "Num copies after copy-insertion: "
2154 << GetNumExistingCopies(module, execution_threads);
2155
2156 return true;
2157 }
2158 } // namespace xla
2159