/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/loop_schedule_linearizer.h"

#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"

namespace xla {

namespace {

// Calculate ordering for HLO, for fast online checking of whether adding
// additional dependencies would create cycles.
struct ComputationInstructionOrdering {
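  // Seeds the cycle-detection graph with the computation's existing edges:
  // one edge per control dependency and one per operand (data) dependency.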
  explicit ComputationInstructionOrdering(const HloComputation& computation) {
    for (const HloInstruction* instr : computation.instructions()) {
      for (const HloInstruction* control_pred : instr->control_predecessors()) {
        CHECK(this->InsertEdge(*control_pred, *instr))
            << "Graph already contained a cycle";
      }

      for (int op_id = 0; op_id < instr->operand_count(); op_id++) {
        const HloInstruction* op = instr->operand(op_id);
        CHECK(this->InsertEdge(*op, *instr))
            << "Graph already contained a cycle";
      }
    }
  }

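  // Returns the graph_cycles node id for `instr`, creating it on first use.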
  int32_t NodeIdForInstruction(const HloInstruction& instr) {
    int32_t instruction_id = instr.unique_id();
    auto it = node_id_to_graph_id.find(instruction_id);

    if (it != node_id_to_graph_id.end()) {
      return it->second;
    }
    int32_t node_id = graph_cycles.NewNode();
    node_id_to_graph_id[instruction_id] = node_id;
    return node_id;
  }

  // Returns `false` and does not add the edge if adding it would have
  // introduced a cycle; returns `true` otherwise.
  bool InsertEdge(const HloInstruction& source, const HloInstruction& dest) {
    int32_t source_id = NodeIdForInstruction(source);
    int32_t dest_id = NodeIdForInstruction(dest);
    return graph_cycles.InsertEdge(source_id, dest_id);
  }

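  // Maps HloInstruction::unique_id() to the corresponding node id inside
  // `graph_cycles`.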
  absl::flat_hash_map<int32_t, int32_t> node_id_to_graph_id;

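  // Incremental cycle-detection structure over the instructions of the
  // computation.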
  tensorflow::GraphCycles graph_cycles;
};
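
// Illustrative usage (a sketch, not part of the pass): check that adding a
// dependency a -> b keeps the graph acyclic before committing to it.
//
//   ComputationInstructionOrdering ordering(*computation);
//   if (ordering.InsertEdge(*a, *b)) {
//     // Safe: a -> b does not close a cycle and is now recorded.
//   }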

}  // namespace

static StatusOr<bool> AddControlEdgesForLoopWrites(
    HloInstruction* xla_while, HloAliasAnalysis& alias_analysis) {
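  // Loop-carried values enter the body through parameter 0 and leave it
  // through the root tuple; the same ShapeIndex in both refers to the same
  // logical loop variable.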
  HloDataflowAnalysis& dataflow = alias_analysis.dataflow_analysis();
  HloComputation* body = xla_while->while_body();
  HloInstruction* root = body->root_instruction();
  HloInstruction* input = body->parameter_instruction(0);

  bool changed = false;

  // Compute the dependency ordering ourselves rather than reusing an existing
  // analysis: it is hard to extract the underlying graph from those
  // abstractions.
  ComputationInstructionOrdering ordering(*body);
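  // The ShapeTree below is used only to enumerate every ShapeIndex of the
  // loop state; its bool payload is never read.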
  ShapeTree<bool> indices_to_copy(xla_while->shape());

  for (auto& p : indices_to_copy) {
    const ShapeIndex& index = p.first;

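    // The empty index denotes the top-level tuple itself; only proper
    // subshapes correspond to individual loop-carried values.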
    if (index.empty()) {
      continue;
    }

    if (dataflow.GetValueSet(root, index).values().size() > 1 ||
        dataflow.GetValueSet(input, index).values().size() > 1) {
      VLOG(2) << "Index " << index.ToString() << " is associated with multiple "
              << "values, not attempting to introduce stricter dependencies";
    } else {
      HloValue& value_at_root = dataflow.GetUniqueValueAt(root, index);
      HloValue& value_at_input = dataflow.GetUniqueValueAt(input, index);

      if (value_at_root.shape().IsTuple()) {
        // TODO(cheshire): For simplicity we currently do not handle nested
        // tuples, as we haven't seen them in the examples we care about.
        continue;
      }

      // TODO(cheshire): This is too conservative and does not take aliasing
      // into account.
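      // `write` is the instruction that produces the value flowing out of the
      // loop body through the root tuple at this index.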
      HloInstruction* write = value_at_root.defining_instruction();

      for (const HloUse& use : value_at_input.GetUses()) {
        HloInstruction* read = use.instruction;

        if (read != write &&
            value_at_root != value_at_input

            // TODO(cheshire): Parents sometimes differ in the case of e.g.
            // nested loops, where the value is read/written in the inner
            // loop. For now we skip this case for simplicity (as the inner
            // loop performance is more important in any case).
            && read->parent() == write->parent()) {
          VLOG(2) << "Inside " << body->name() << ", index "
                  << index.ToString();
          if (!ordering.InsertEdge(*read, *write)) {
            VLOG(2) << "Not adding a control dependency from "
                    << read->ToShortString() << " to " << write->ToShortString()
                    << " as it would introduce a cycle";
            continue;
          }

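          // Record a change only if the control edge does not already exist;
          // AddControlDependencyTo below is a no-op for duplicate edges.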
          changed |= !absl::c_linear_search(read->control_successors(), write);

          // Unless we want a copy, read should happen before write.
          TF_RETURN_IF_ERROR(read->AddControlDependencyTo(write));
          VLOG(2) << "Adding dependency: " << read->ToShortString()
                  << " before " << write->ToShortString();
        }
      }
    }
  }
  return changed;
}

StatusOr<bool> LoopScheduleLinearizer::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
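  // Run alias analysis once per module; its dataflow analysis identifies the
  // values read and written at each index of the loop state.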
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module, can_share_buffer_));

  bool changed = false;
  for (HloComputation* computation :
       module->MakeNonfusionComputations(execution_threads)) {
    for (HloInstruction* instruction :
         computation->MakeInstructionPostOrder()) {
      if (instruction->opcode() == HloOpcode::kWhile) {
        StatusOr<bool> updated_loop =
            AddControlEdgesForLoopWrites(instruction, *alias_analysis);
        TF_RETURN_IF_ERROR(updated_loop.status());
        changed |= *updated_loop;
      }
    }
  }
  DumpHloModuleDuringPassIfEnabled(
      name(), "after inserting control edges inside while loop bodies",
      *module);

  return changed;
}

}  // namespace xla