/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/fusion_bitcast_lift.h"

#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/errors.h"

namespace xla {
namespace gpu {
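// Overview of the rewrite this pass performs (a minimal sketch with
// hypothetical shapes, written as HLO-like pseudocode). Inside a kInput
// fusion, rank-reducing bitcasts are lifted from the root toward the
// parameters, so a fused computation such as
//
//   p0 = f32[2,1024] parameter(0)
//   b0 = f32[2048] bitcast(p0)
//   e  = f32[2048] exponential(b0)
//   ROOT r = f32[] reduce(e, ...)
//
// becomes
//
//   p0 = f32[2048] parameter(0)
//   e  = f32[2048] exponential(p0)
//   ROOT r = f32[] reduce(e, ...)
//
// while the matching fusion operand outside the fusion is wrapped in a new
// f32[2048] bitcast.
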
// Returns true if all instructions are supported operations.
static bool AreInstructionSupported(HloComputation* comp) {
  for (HloInstruction* instr : comp->instructions()) {
    bool supported =
        HloInstruction::IsOpElementwise(instr->opcode()) ||
        instr->opcode() == HloOpcode::kConstant ||
        // We only support reductions when they are at the root or, in a
        // multi-output fusion (MOF), right before the root tuple. This
        // should always be true for now, but if we implement reduction
        // epilogue fusion in the future, this optimization will need to be
        // updated. So disable it preemptively for future safety.
        (instr->opcode() == HloOpcode::kReduce &&
         (comp->root_instruction() == instr ||
          (instr->users().size() == 1 &&
           instr->users()[0]->opcode() == HloOpcode::kTuple))) ||
        instr->opcode() == HloOpcode::kTuple ||
        instr->opcode() == HloOpcode::kParameter ||
        (instr->opcode() == HloOpcode::kBitcast &&
         instr->shape().rank() < instr->operand(0)->shape().rank()) ||
        (instr->opcode() == HloOpcode::kBroadcast &&
         (instr->dimensions().empty() ||  // scalar broadcasting
          (instr->dimensions().size() == 1 &&  // row broadcasting
           instr->dimensions()[0] == (instr->shape().rank() - 1))));
    if (!supported) {
      VLOG(2) << "NOT SUPPORTED: " << instr->ToString();
      return false;
    }

    // If there is an instruction that changes the layout, we do not do
    // anything.
    if (HloInstruction::IsOpElementwise(instr->opcode()) &&
        !absl::c_all_of(instr->operands(), [&](HloInstruction* input) {
          return ShapeUtil::EqualIgnoringElementType(input->shape(),
                                                     instr->shape());
        })) {
      VLOG(2) << "NOT SUPPORTED (instruction changes the layout): "
              << instr->ToString();
      return false;
    }
  }
  return true;
}

StatusOr<bool> FusionBitcastLift::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  XLA_VLOG_LINES(2, "FusionBitcastLift::Run(), before:\n" + module->ToString());
  bool changed = false;
  for (HloComputation* comp :
       module->MakeNonfusionComputations(execution_threads)) {
    // Copy the instruction list as we modify the HloComputation.
    std::vector<HloInstruction*> comp_instruction(comp->instructions().begin(),
                                                  comp->instructions().end());
    for (HloInstruction* instr : comp_instruction) {
      // 1) Is this a fusion that we want to modify?
      if (auto* fusion = DynCast<HloFusionInstruction>(instr)) {
        // 1.1) We only support kInput fusions and some operations.
        if (fusion->fusion_kind() != HloInstruction::FusionKind::kInput ||
            !AreInstructionSupported(
                fusion->fused_instructions_computation())) {
          continue;
        }
        // 1.2) Check if there is a bitcast that we can lift. Currently we
        //      do not lift (merge) bitcasts above (with) broadcasts.
        if (!std::any_of(
                fusion->fused_instructions().begin(),
                fusion->fused_instructions().end(), [](HloInstruction* inner) {
                  return inner->opcode() == HloOpcode::kBitcast &&
                         inner->operand(0)->opcode() != HloOpcode::kBroadcast;
                })) {
          continue;
        }

        // 1.3) Check that all the bitcasts have the same shape pattern.
        //      Multiple bitcast patterns aren't supported/tested.
        HloInstruction* bitcast = nullptr;
        bool same_shape_pattern = true;
        for (HloInstruction* fused_instr : fusion->fused_instructions()) {
          if (fused_instr->opcode() == HloOpcode::kBitcast &&
              fused_instr->shape().rank() <
                  fused_instr->operand(0)->shape().rank()) {
            if (bitcast != nullptr &&
                (!ShapeUtil::Equal(fused_instr->shape(), bitcast->shape()) ||
                 !ShapeUtil::Equal(bitcast->operand(0)->shape(),
                                   fused_instr->operand(0)->shape()))) {
              same_shape_pattern = false;
              break;
            }
            bitcast = fused_instr;
          }
        }
        if (bitcast == nullptr || !same_shape_pattern) {
          VLOG(2) << "NOT SUPPORTED: Multiple rank-reducing bitcast pattern.";
          continue;
        }

        // 2) Now that we have found a fusion that we want to modify,
        //    create the new fusion. We do so by:
        //    a) cloning the old fusion;
        //    b) recursively walking the cloned graph from the root and
        //       lifting the bitcasts up one instruction at a time.
        std::unique_ptr<HloInstruction> cloned_fusion =
            fusion->Clone("bitcast");
        // The following stack and set always contain the same data.
        // The stack defines the order of traversal; the set is used only
        // as an optimization for fast membership checks.
        std::vector<HloInstruction*> stack(
            {cloned_fusion->fused_expression_root()});
        absl::flat_hash_set<HloInstruction*> set(
            {cloned_fusion->fused_expression_root()});
        bool clone_changed = false;
        while (!stack.empty()) {
          HloInstruction* i = stack.back();
          stack.pop_back();
          set.erase(i);
          if (i->opcode() == HloOpcode::kTuple) {
            stack.insert(stack.end(), i->operands().begin(),
                         i->operands().end());
            set.insert(i->operands().begin(), i->operands().end());
            VLOG(3) << "kTuple: " << i->ToString();
          } else if (i->opcode() == HloOpcode::kParameter &&
                     absl::c_all_of(i->users(), [](HloInstruction* u) {
                       return u->opcode() == HloOpcode::kBitcast;
                     })) {
            VLOG(3) << "kParameter: " << i->ToString();
            // Replace the parameter inside the fusion.
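            // Illustrative sketch (hypothetical shapes): a fused
            //   p = f32[2,1024] parameter(0); b = f32[2048] bitcast(p)
            // becomes a single f32[2048] parameter, and the matching
            // fusion operand is wrapped in a new outer bitcast below.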
            Shape new_shape = i->users()[0]->shape();
            int64_t parameter_number = i->parameter_number();
            std::string name = i->name();
            auto n = HloInstruction::CreateParameter(parameter_number,
                                                     new_shape, name);
            HloInstruction* new_parameter =
                i->parent()->ReplaceParameter(parameter_number, std::move(n));
            // Remove the old inner bitcasts.
            auto old_users = new_parameter->users();
            for (HloInstruction* param_user : old_users) {
              DCHECK(param_user->opcode() == HloOpcode::kBitcast)
                  << "Expected a bitcast";
              TF_RETURN_IF_ERROR(
                  param_user->parent()->ReplaceInstructionWithDifferentShape(
                      param_user, new_parameter));
            }
            // Replace the corresponding fusion operand with a new bitcast.
            HloInstruction* old_outer_parameter =
                cloned_fusion->mutable_operand(parameter_number);
            HloInstruction* new_op =
                old_outer_parameter->parent()->AddInstruction(
                    HloInstruction::CreateBitcast(new_shape,
                                                  old_outer_parameter));
            TF_RETURN_IF_ERROR(cloned_fusion->ReplaceOperandWithDifferentShape(
                parameter_number, new_op));
            clone_changed = true;
            changed = true;
          } else if (i->opcode() == HloOpcode::kBroadcast) {
            VLOG(3) << "kBroadcast: " << i->ToString();
            // For now, do nothing. Later we could merge the broadcast and
            // the bitcast, but this does not bring a benefit in the cases
            // seen so far.
            if (set.insert(i->mutable_operand(0)).second) {
              stack.push_back(i->mutable_operand(0));
            }
          } else if (i->opcode() == HloOpcode::kConstant &&
                     !i->users().empty() &&
                     absl::c_all_of(i->users(), [](HloInstruction* u) {
                       return u->opcode() == HloOpcode::kBitcast;
                     })) {
            // Handling this case is optional for correctness, but handling
            // it cleans up the graph.
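            // Illustrative sketch (hypothetical shapes): an f32[1,1024]
            // constant whose only users are f32[1024] bitcasts is simply
            // re-created directly with the bitcast's shape.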
            VLOG(3) << "kConstant: " << i->ToString();
            Shape new_shape = i->users()[0]->shape();
            TF_RETURN_IF_ERROR(i->parent()->ReplaceWithNewInstruction(
                i, i->CloneWithNewOperands(new_shape, {})));
            clone_changed = true;
            changed = true;
          } else if (!i->users().empty() &&
                     // If there are 0 operands, we can't lift the bitcast
                     // here; such cases are handled separately, as for
                     // kConstant and kParameter above.
                     !i->operands().empty() &&
                     absl::c_all_of(i->users(), [](HloInstruction* u) {
                       return u->opcode() == HloOpcode::kBitcast;
                     })) {
            VLOG(3) << "All users bitcast: " << i->ToString();
            // All users are bitcasts, so lift the bitcast above this
            // instruction.
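            // Illustrative sketch (hypothetical shapes): the fused pattern
            //   a = f32[2,1024] add(x, y); b = f32[2048] bitcast(a)
            // is rewritten as
            //   x' = f32[2048] bitcast(x); y' = f32[2048] bitcast(y)
            //   a' = f32[2048] add(x', y')
            // and the old bitcast users of the add are replaced by a'.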
            Shape new_shape = i->users()[0]->shape();
            std::vector<HloInstruction*> new_operands;
            for (HloInstruction* opnd : i->operands()) {
              Shape dtyped_new_shape = ShapeUtil::ChangeElementType(
                  new_shape, opnd->shape().element_type());
              HloInstruction* new_opnd = opnd->parent()->AddInstruction(
                  HloInstruction::CreateBitcast(dtyped_new_shape, opnd));
              new_operands.push_back(new_opnd);
              // Handle the operand right before the inserted bitcast now.
              if (set.insert(opnd).second) {
                stack.push_back(opnd);
              }
            }
            Shape dtyped_new_shape = ShapeUtil::ChangeElementType(
                new_shape, i->shape().element_type());
            HloInstruction* cloned_i = i->parent()->AddInstruction(
                i->CloneWithNewOperands(dtyped_new_shape, new_operands));
            // Replace the old bitcasts with the new instruction in order to
            // remove them.
            // Copy the vector as it will be modified while we iterate on it.
            const std::vector<HloInstruction*> users = i->users();
            for (HloInstruction* user : users) {
              TF_RETURN_IF_ERROR(
                  i->parent()->ReplaceInstructionWithDifferentShape(user,
                                                                    cloned_i));
            }
            clone_changed = true;
            changed = true;
          } else {
            VLOG(3) << "Else: " << i->ToString();
            for (auto* opnd : i->operands()) {
              if (set.insert(opnd).second) {
                stack.push_back(opnd);
              }
            }
          }
        }  // while
        DCHECK(clone_changed) << "We should have changed the fusion!";
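        // Verify (layout-sensitively) that the rewritten fusion computation
        // still has consistent shapes; only swap it in if verification
        // succeeds, otherwise keep the original fusion.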
        auto opts = HloVerifierOpts{}.MakeLayoutSensitive();
        auto shape_verifier = std::make_unique<ShapeVerifier>(opts);
        if (clone_changed) {
          Status status =
              cloned_fusion->fused_instructions_computation()->Accept(
                  shape_verifier.get());
          if (status.ok()) {
            // 3) Replace the old fusion with the new fusion.
            TF_RETURN_IF_ERROR(fusion->parent()->ReplaceWithNewInstruction(
                fusion, std::move(cloned_fusion)));
          } else {
            VLOG(2) << "Not lifting due to shape problem: "
                    << cloned_fusion->ToString();
          }
        }
      }  // if fusion
    }
  }
  XLA_VLOG_LINES(2, "FusionBitcastLift::Run(), after:\n" + module->ToString());
  return changed;
}

}  // namespace gpu
}  // namespace xla