/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/loop_schedule_linearizer.h"

#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h"

namespace xla {

namespace {

// Calculates an ordering for the instructions of an HLO computation, allowing
// fast online checking of whether adding an additional dependency would
// create a cycle.
struct ComputationInstructionOrdering {
  explicit ComputationInstructionOrdering(const HloComputation& computation) {
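    // Seed the cycle detector with every edge already present in the
    // computation: explicit control edges and implicit operand (data) edges.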
    for (const HloInstruction* instr : computation.instructions()) {
      for (const HloInstruction* control_pred : instr->control_predecessors()) {
        CHECK(this->InsertEdge(*control_pred, *instr))
            << "Graph already contained a cycle";
      }

      for (int op_id = 0; op_id < instr->operand_count(); op_id++) {
        const HloInstruction* op = instr->operand(op_id);
        CHECK(this->InsertEdge(*op, *instr))
            << "Graph already contained a cycle";
      }
    }
  }

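  // Returns the cycle-detector node id for `instr`, lazily creating a node on
  // first use and caching the mapping for later lookups.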
  int32_t NodeIdForInstruction(const HloInstruction& instr) {
    int32_t instruction_id = instr.unique_id();
    auto it = node_id_to_graph_id.find(instruction_id);

    if (it != node_id_to_graph_id.end()) {
      return it->second;
    }
    int32_t node_id = graph_cycles.NewNode();
    node_id_to_graph_id[instruction_id] = node_id;
    return node_id;
  }

  // Returns `false` if adding an edge would have introduced a cycle. Does not
  // add an edge in that case. Returns `true` otherwise.
  bool InsertEdge(const HloInstruction& source, const HloInstruction& dest) {
    int32_t source_id = NodeIdForInstruction(source);
    int32_t dest_id = NodeIdForInstruction(dest);
    return graph_cycles.InsertEdge(source_id, dest_id);
  }

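  // Maps HloInstruction::unique_id() to the corresponding node id in
  // `graph_cycles`.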
  absl::flat_hash_map<int32_t, int32_t> node_id_to_graph_id;

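  // Cycle detector holding the computation's control and operand edges.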
  tensorflow::GraphCycles graph_cycles;
};

}  // namespace

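// Adds control edges inside the body of `xla_while` so that, for each element
// of the loop state, the instructions reading that element from the loop
// parameter are scheduled before the instruction that defines the
// corresponding element of the root tuple. Scheduling all reads before the
// overwriting write lets copy insertion avoid an extra copy for that element.
// Returns true if any control edge was added.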
static StatusOr<bool> AddControlEdgesForLoopWrites(
    HloInstruction* xla_while, HloAliasAnalysis& alias_analysis) {
  HloDataflowAnalysis& dataflow = alias_analysis.dataflow_analysis();
  HloComputation* body = xla_while->while_body();
  HloInstruction* root = body->root_instruction();
  HloInstruction* input = body->parameter_instruction(0);

  bool changed = false;

  // Compute the dependency ordering ourselves. We do not reuse the existing
  // ordering analyses because it is hard to extract the underlying graph from
  // those abstractions.
  ComputationInstructionOrdering ordering(*body);
  ShapeTree<bool> indices_to_copy(xla_while->shape());

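  // Visit each index of the loop state shape, skipping the root (whole-tuple)
  // index.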
  for (auto& p : indices_to_copy) {
    const ShapeIndex& index = p.first;

    if (index.empty()) {
      continue;
    }

    if (dataflow.GetValueSet(root, index).values().size() > 1 ||
        dataflow.GetValueSet(input, index).values().size() > 1) {
      VLOG(2) << "Index " << index.ToString() << " is associated with multiple "
              << "values, not attempting to introduce stricter dependencies";
    } else {
      HloValue& value_at_root = dataflow.GetUniqueValueAt(root, index);
      HloValue& value_at_input = dataflow.GetUniqueValueAt(input, index);

      if (value_at_root.shape().IsTuple()) {
        // TODO(cheshire): For simplicity we currently do not handle nested
        // tuples, as we haven't seen them in the examples we care about.
        continue;
      }

      // TODO(cheshire): This is too conservative and does not take aliasing
      // into account.
      HloInstruction* write = value_at_root.defining_instruction();

      for (const HloUse& use : value_at_input.GetUses()) {
        HloInstruction* read = use.instruction;

        if (read != write &&
            value_at_root != value_at_input

            // TODO(cheshire): Parents sometimes differ in case of e.g. nested
            // loops, where the value is read from and written to in the inner
            // loop. For now we skip this case for simplicity, as the inner
            // loop performance is more important in any case.
            && read->parent() == write->parent()) {
          VLOG(2) << "Inside " << body->name() << ", index "
                  << index.ToString();
          if (!ordering.InsertEdge(*read, *write)) {
            VLOG(2) << "Not adding a control dependency from "
                    << read->ToShortString() << " to " << write->ToShortString()
                    << " as it would introduce a cycle";
            continue;
          }

          // Only modify the graph (and report a change) when the control edge
          // is not already present.
          if (!absl::c_linear_search(read->control_successors(), write)) {
            // Unless we want a copy, read should happen before write.
            TF_RETURN_IF_ERROR(read->AddControlDependencyTo(write));
            VLOG(2) << "Adding dependency: " << read->ToShortString()
                    << " before " << write->ToShortString();
            changed = true;
          }
        }
      }
    }
  }
  return changed;
}

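// Runs the pass: for every while loop in every non-fusion computation, adds
// read-before-write control edges inside the loop body.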
StatusOr<bool> LoopScheduleLinearizer::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module, can_share_buffer_));

  bool changed = false;
  for (HloComputation* computation :
       module->MakeNonfusionComputations(execution_threads)) {
    for (HloInstruction* instruction :
         computation->MakeInstructionPostOrder()) {
      if (instruction->opcode() == HloOpcode::kWhile) {
        StatusOr<bool> updated_loop =
            AddControlEdgesForLoopWrites(instruction, *alias_analysis);
        TF_RETURN_IF_ERROR(updated_loop.status());
        changed |= *updated_loop;
      }
    }
  }
  DumpHloModuleDuringPassIfEnabled(
      name(), "after inserting control edges inside while loop bodies",
      *module);

  return changed;
}

}  // namespace xla