/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/fusion_bitcast_lift.h"

#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/errors.h"

namespace xla {
namespace gpu {

// Returns true if all instructions are supported operations.
static bool AreInstructionSupported(HloComputation* comp) {
  for (HloInstruction* instr : comp->instructions()) {
    bool supported =
        HloInstruction::IsOpElementwise(instr->opcode()) ||
        instr->opcode() == HloOpcode::kConstant ||
        // We only support reductions when they are at the root or, in a
        // multi-output fusion (MOF), when they feed the root tuple. This
        // should always be true for now, but if we implement reduction
        // epilogue fusion in the future, this optimization will need to be
        // updated. So disable those cases now, for future safety.
        (instr->opcode() == HloOpcode::kReduce &&
         (comp->root_instruction() == instr ||
          (instr->users().size() == 1 &&
           instr->users()[0]->opcode() == HloOpcode::kTuple))) ||
        instr->opcode() == HloOpcode::kTuple ||
        instr->opcode() == HloOpcode::kParameter ||
        (instr->opcode() == HloOpcode::kBitcast &&
         instr->shape().rank() < instr->operand(0)->shape().rank()) ||
        (instr->opcode() == HloOpcode::kBroadcast &&
         (instr->dimensions().empty() ||       // scalar broadcasting
          (instr->dimensions().size() == 1 &&  // row broadcasting
           instr->dimensions()[0] == (instr->shape().rank() - 1))));
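    // Hypothetical HLO sketches of the accepted broadcast forms (for
    // illustration only; the shapes are made up):
    //   scalar broadcasting: f32[16,32] broadcast(f32[] c), dimensions={}
    //   row broadcasting:    f32[16,32] broadcast(f32[32] r), dimensions={1}
    // In the row case, the single broadcast dimension must be the last
    // output dimension (rank - 1).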
    if (!supported) {
      VLOG(2) << "NOT SUPPORTED: " << instr->ToString();
      return false;
    }

    // If there is an instruction that changes the layout, we do not do
    // anything.
    if (HloInstruction::IsOpElementwise(instr->opcode()) &&
        !absl::c_all_of(instr->operands(), [&](HloInstruction* input) {
          return ShapeUtil::EqualIgnoringElementType(input->shape(),
                                                     instr->shape());
        })) {
      VLOG(2) << "NOT SUPPORTED (instruction changes the layout): "
              << instr->ToString();
      return false;
    }
  }
  return true;
}

StatusOr<bool> FusionBitcastLift::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  XLA_VLOG_LINES(2, "FusionBitcastLift::Run(), before:\n" + module->ToString());
  bool changed = false;
  for (HloComputation* comp :
       module->MakeNonfusionComputations(execution_threads)) {
    // Copy the instruction list, as we modify the HloComputation.
    std::vector<HloInstruction*> comp_instruction(comp->instructions().begin(),
                                                  comp->instructions().end());
    for (HloInstruction* instr : comp_instruction) {
      // 1) Is this a fusion that we want to modify?
      if (auto* fusion = DynCast<HloFusionInstruction>(instr)) {
        // 1.1) We only support kInput fusions and some operations.
        if (fusion->fusion_kind() != HloInstruction::FusionKind::kInput ||
            !AreInstructionSupported(
                fusion->fused_instructions_computation())) {
          continue;
        }
        // 1.2) Check if there is a bitcast that we can lift. Currently we
        //      do not lift (merge) bitcasts above (with) broadcasts.
        if (!std::any_of(
                fusion->fused_instructions().begin(),
                fusion->fused_instructions().end(), [](HloInstruction* inner) {
                  return inner->opcode() == HloOpcode::kBitcast &&
                         inner->operand(0)->opcode() != HloOpcode::kBroadcast;
                })) {
          continue;
        }

        // 1.3) Check that all the bitcasts have the same shape pattern.
        //      Multiple bitcast patterns aren't supported/tested.
        HloInstruction* bitcast = nullptr;
        bool same_shape_pattern = true;
        for (HloInstruction* fused_instr : fusion->fused_instructions()) {
          if (fused_instr->opcode() == HloOpcode::kBitcast &&
              fused_instr->shape().rank() <
                  fused_instr->operand(0)->shape().rank()) {
            if (bitcast != nullptr &&
                (!ShapeUtil::Equal(fused_instr->shape(), bitcast->shape()) ||
                 !ShapeUtil::Equal(bitcast->operand(0)->shape(),
                                   fused_instr->operand(0)->shape()))) {
              same_shape_pattern = false;
              break;
            }
            bitcast = fused_instr;
          }
        }
        if (bitcast == nullptr || !same_shape_pattern) {
          VLOG(2) << "NOT SUPPORTED: Multiple rank-reducing bitcast patterns.";
          continue;
        }

        // 2) Now that we have found a fusion that we want to modify,
        //    create the new fusion. We do so by:
        //    a) Cloning the old fusion.
        //    b) Recursively walking the graph from the root and lifting
        //       each bitcast up across one instruction at a time.
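        //    As a minimal sketch, assume a hypothetical fused computation
        //    (illustrative HLO, not taken from a real module):
        //      before: p0 = f32[2,3] parameter(0)
        //              e  = f32[2,3] exponential(p0)
        //              b  = f32[6]   bitcast(e)       // root
        //      after:  p0 = f32[6]   parameter(0)
        //              e  = f32[6]   exponential(p0)  // now the root
        //    with a new f32[6] bitcast wrapping the corresponding fusion
        //    operand in the outer computation. The bitcast first moves above
        //    the elementwise op, then is absorbed into the parameter.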
        std::unique_ptr<HloInstruction> cloned_fusion =
            fusion->Clone("bitcast");
        // The following stack and set always contain the same data.
        // The stack is used for the order of traversal.
        // The set is used only as an optimization, for fast membership tests.
        std::vector<HloInstruction*> stack(
            {cloned_fusion->fused_expression_root()});
        absl::flat_hash_set<HloInstruction*> set(
            {cloned_fusion->fused_expression_root()});
        bool clone_changed = false;
        while (!stack.empty()) {
          HloInstruction* i = stack.back();
          stack.pop_back();
          set.erase(i);
          if (i->opcode() == HloOpcode::kTuple) {
            stack.insert(stack.end(), i->operands().begin(),
                         i->operands().end());
            set.insert(i->operands().begin(), i->operands().end());
            VLOG(3) << "kTuple: " << i->ToString();
          } else if (i->opcode() == HloOpcode::kParameter &&
                     absl::c_all_of(i->users(), [](HloInstruction* u) {
                       return u->opcode() == HloOpcode::kBitcast;
                     })) {
            VLOG(3) << "kParameter: " << i->ToString();
            // Replace the parameter inside the fusion.
            Shape new_shape = i->users()[0]->shape();
            int64_t parameter_number = i->parameter_number();
            std::string name = i->name();
            auto n = HloInstruction::CreateParameter(parameter_number,
                                                     new_shape, name);
            HloInstruction* new_parameter =
                i->parent()->ReplaceParameter(parameter_number, std::move(n));
            // Remove the old inner bitcasts.
            auto old_users = new_parameter->users();
            for (HloInstruction* param_user : old_users) {
              DCHECK(param_user->opcode() == HloOpcode::kBitcast)
                  << "Expected a bitcast";
              TF_RETURN_IF_ERROR(
                  param_user->parent()->ReplaceInstructionWithDifferentShape(
                      param_user, new_parameter));
            }
            // Replace the corresponding fusion operand with a new bitcast.
            HloInstruction* old_outer_parameter =
                cloned_fusion->mutable_operand(parameter_number);
            HloInstruction* new_op =
                old_outer_parameter->parent()->AddInstruction(
                    HloInstruction::CreateBitcast(new_shape,
                                                  old_outer_parameter));
            TF_RETURN_IF_ERROR(cloned_fusion->ReplaceOperandWithDifferentShape(
                parameter_number, new_op));
            clone_changed = true;
            changed = true;
          } else if (i->opcode() == HloOpcode::kBroadcast) {
            VLOG(3) << "kBroadcast: " << i->ToString();
            // For now, do nothing. Later we could merge the broadcast
            // and the bitcast, but this doesn't bring a benefit in the
            // cases seen so far.
            if (set.insert(i->mutable_operand(0)).second) {
              stack.push_back(i->mutable_operand(0));
            }
          } else if (i->opcode() == HloOpcode::kConstant &&
                     !i->users().empty() &&
                     absl::c_all_of(i->users(), [](HloInstruction* u) {
                       return u->opcode() == HloOpcode::kBitcast;
                     })) {
            // Handling this case is optional for correctness, but
            // handling it cleans up the graph.
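            // As a hypothetical example: a f32[2,3] constant whose only
            // user is a rank-reducing f32[6] bitcast is recloned below
            // with the bitcast's shape, making the lifted bitcast
            // unnecessary.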
            VLOG(3) << "kConstant: " << i->ToString();
            Shape new_shape = i->users()[0]->shape();
            TF_RETURN_IF_ERROR(i->parent()->ReplaceWithNewInstruction(
                i, i->CloneWithNewOperands(new_shape, {})));
            clone_changed = true;
            changed = true;
          } else if (!i->users().empty() &&
                     // With 0 operands we can't lift the bitcast; such
                     // instructions must be handled specially, as done
                     // for kConstant and kParameter above.
                     !i->operands().empty() &&
                     absl::c_all_of(i->users(), [](HloInstruction* u) {
                       return u->opcode() == HloOpcode::kBitcast;
                     })) {
            VLOG(3) << "All users are bitcasts: " << i->ToString();
            // All users are bitcasts, so lift the bitcast.
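            // A sketch of this step, assuming a hypothetical compare whose
            // users are all bitcasts (illustrative HLO only):
            //   c  = pred[2,3] compare(f32[2,3] a, f32[2,3] b)
            //   b' = pred[6] bitcast(c)
            // becomes
            //   a' = f32[6] bitcast(a); b'' = f32[6] bitcast(b)
            //   c' = pred[6] compare(a', b'')
            // ChangeElementType keeps each operand's own element type while
            // reusing the bitcast's dimensions.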
            Shape new_shape = i->users()[0]->shape();
            std::vector<HloInstruction*> new_operands;
            for (HloInstruction* opnd : i->operands()) {
              Shape dtyped_new_shape = ShapeUtil::ChangeElementType(
                  new_shape, opnd->shape().element_type());
              HloInstruction* new_opnd = opnd->parent()->AddInstruction(
                  HloInstruction::CreateBitcast(dtyped_new_shape, opnd));
              new_operands.push_back(new_opnd);
              // Queue the operand, which now sits right below the
              // inserted bitcast, for processing.
              if (set.insert(opnd).second) {
                stack.push_back(opnd);
              }
            }
            Shape dtyped_new_shape = ShapeUtil::ChangeElementType(
                new_shape, i->shape().element_type());
            HloInstruction* cloned_i = i->parent()->AddInstruction(
                i->CloneWithNewOperands(dtyped_new_shape, new_operands));
            // Replace the old bitcasts with the new instruction to
            // remove them.
            // Copy the vector, as it will be modified while we iterate
            // over it.
            const std::vector<HloInstruction*> users = i->users();
            for (HloInstruction* user : users) {
              TF_RETURN_IF_ERROR(
                  i->parent()->ReplaceInstructionWithDifferentShape(user,
                                                                    cloned_i));
            }
            clone_changed = true;
            changed = true;
          } else {
            VLOG(3) << "Else: " << i->ToString();
            for (auto* opnd : i->operands()) {
              if (set.insert(opnd).second) {
                stack.push_back(opnd);
              }
            }
          }
        }  // while
        DCHECK(clone_changed) << "We should have changed the fusion!";
        auto opts = HloVerifierOpts{}.MakeLayoutSensitive();
        auto shape_verifier = std::make_unique<ShapeVerifier>(opts);
        if (clone_changed) {
          Status status =
              cloned_fusion->fused_instructions_computation()->Accept(
                  shape_verifier.get());
          if (status.ok()) {
            // 3) Replace the old fusion with the new fusion.
            TF_RETURN_IF_ERROR(fusion->parent()->ReplaceWithNewInstruction(
                fusion, std::move(cloned_fusion)));
          } else {
            VLOG(2) << "Not lifting due to a shape problem: "
                    << cloned_fusion->ToString();
          }
        }
      }  // if fusion
    }
  }
  XLA_VLOG_LINES(2, "FusionBitcastLift::Run(), after:\n" + module->ToString());
  return changed;
}

}  // namespace gpu
}  // namespace xla