xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/hlo_module_dce.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_module_dce.h"
17 
18 #include <deque>
19 
20 #include "tensorflow/compiler/xla/service/hlo_computation.h"
21 #include "tensorflow/compiler/xla/service/hlo_dce.h"
22 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
23 #include "tensorflow/compiler/xla/service/hlo_liveness_analysis.h"
24 #include "tensorflow/compiler/xla/service/hlo_module.h"
25 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
26 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
27 #include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
28 #include "tensorflow/compiler/xla/status.h"
29 #include "tensorflow/compiler/xla/status_macros.h"
30 #include "tensorflow/compiler/xla/statusor.h"
31 #include "tensorflow/compiler/xla/types.h"
32 #include "tensorflow/compiler/xla/util.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 #include "tensorflow/core/platform/logging.h"
35 
36 namespace xla {
37 
38 namespace {
39 
RunWhileDCE(HloModule * module,HloLivenessAnalysis * liveness,const absl::flat_hash_set<absl::string_view> & execution_threads)40 StatusOr<bool> RunWhileDCE(
41     HloModule* module, HloLivenessAnalysis* liveness,
42     const absl::flat_hash_set<absl::string_view>& execution_threads) {
43   bool changed = false;
44   std::vector<HloComputation*> while_body_comps_to_dce;
45   for (auto* computation : module->computations(execution_threads)) {
46     for (auto* instruction : computation->instructions()) {
47       if (instruction->opcode() != HloOpcode::kWhile) {
48         continue;
49       }
50 
51       const auto* xla_while = instruction;
52       auto* while_body_comp = xla_while->while_body();
53       auto* while_body_param = while_body_comp->parameter_instruction(0);
54       auto* while_body_root = while_body_comp->root_instruction();
55 
56       if (!xla_while->shape().IsTuple() ||
57           while_body_root->opcode() != HloOpcode::kTuple) {
58         // Only run DCE on tuple-shaped while loops where body root is Tuple,
59         // with no I/O instructions.
60         VLOG(1) << "WhileDCE SKIP while: " << xla_while->ToString();
61         continue;
62       }
63 
64       // Remove dead tuple elements.
65       const int64_t tuple_element_count =
66           ShapeUtil::TupleElementCount(xla_while->shape());
67       bool modified_while_body_comp = false;
68       for (int64_t i = 0; i < tuple_element_count; ++i) {
69         if (liveness->IsLive(xla_while, {i})) {
70           continue;
71         }
72         VLOG(1) << "WhileDCE Dead while tuple element."
73                 << " while: " << xla_while->name() << " tuple_index: " << i;
74         // Transform while.body computation to make tuple element at
75         // 'shape_index' as simple pass-through parameter (which candidate
76         // be removed later by simplification pass).
77         HloInstruction* pass_thru_gte = while_body_comp->AddInstruction(
78             HloInstruction::CreateGetTupleElement(
79                 while_body_param->shape().tuple_shapes(i), while_body_param,
80                 i));
81         // Replace while.body.root Tuple operand at 'tuple_index' with
82         // 'pass_thru_gte', making prior operand a dead root (to be cleaned
83         // up with a subsequent DCE pass).
84         TF_RETURN_IF_ERROR(
85             while_body_root->ReplaceOperandWith(i, pass_thru_gte));
86         changed = true;
87         modified_while_body_comp = true;
88       }
89       if (modified_while_body_comp) {
90         while_body_comps_to_dce.push_back(while_body_comp);
91       }
92     }
93   }
94 
95   // Run DCE on while body computations that we modified.
96   for (auto* while_body_comp : while_body_comps_to_dce) {
97     TF_ASSIGN_OR_RETURN(bool changed_for_computation,
98                         HloDCE::RunOnComputation(
99                             while_body_comp,
100                             /*remove_cross_partition_collective_ops=*/false));
101     changed |= changed_for_computation;
102   }
103   return changed;
104 }
105 
106 }  // namespace
107 
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)108 StatusOr<bool> HloModuleDCE::Run(
109     HloModule* module,
110     const absl::flat_hash_set<absl::string_view>& execution_threads) {
111   VLOG(2) << "Before HloModuleDCE:";
112   XLA_VLOG_LINES(3, module->ToString());
113 
114   std::unique_ptr<HloLivenessAnalysis> liveness;
115   TF_ASSIGN_OR_RETURN(liveness, HloLivenessAnalysis::Run(*module));
116 
117   // Sweep through while instructions, transforming dead while tuple element
118   // computations to pass through tuple values (creating dead roots in while
119   // body computation in the process).
120   TF_ASSIGN_OR_RETURN(bool hlo_module_dce_changed,
121                       RunWhileDCE(module, liveness.get(), execution_threads));
122 
123   // Run the while loop simplifier to remove dead tuple elements.
124   WhileLoopSimplifier while_loop_simplifier;
125   TF_ASSIGN_OR_RETURN(bool while_loop_simplifier_changed,
126                       while_loop_simplifier.Run(module, execution_threads));
127 
128   TupleSimplifier tuple_simplifier;
129   TF_ASSIGN_OR_RETURN(bool tuple_simplifier_changed,
130                       tuple_simplifier.Run(module, execution_threads));
131 
132   // Run HloDCE to clean up any dead code created during HloModuleDCE.
133   HloDCE hlo_dce;
134   TF_ASSIGN_OR_RETURN(bool hlo_dce_changed,
135                       hlo_dce.Run(module, execution_threads));
136 
137   VLOG(2) << "After HloModuleDCE:";
138   XLA_VLOG_LINES(3, module->ToString());
139 
140   return hlo_module_dce_changed | hlo_dce_changed | tuple_simplifier_changed |
141          while_loop_simplifier_changed;
142 }
143 
144 }  // namespace xla
145