xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/hlo_dce.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_dce.h"
17 
18 #include <memory>
19 #include <utility>
20 #include <vector>
21 
22 #include "absl/container/flat_hash_map.h"
23 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
24 #include "tensorflow/compiler/xla/service/hlo_computation.h"
25 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
26 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
27 #include "tensorflow/compiler/xla/service/hlo_module.h"
28 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
29 #include "tensorflow/compiler/xla/status.h"
30 #include "tensorflow/compiler/xla/status_macros.h"
31 #include "tensorflow/compiler/xla/statusor.h"
32 #include "tensorflow/compiler/xla/types.h"
33 #include "tensorflow/compiler/xla/util.h"
34 #include "tensorflow/core/lib/core/errors.h"
35 #include "tensorflow/core/platform/errors.h"
36 #include "tensorflow/core/platform/logging.h"
37 
38 namespace xla {
39 
40 namespace {
41 
42 // Checks if the instruction is a removable while given
43 // remove_cross_partition_collective_ops
IsRemovableWhile(HloInstruction * instruction,bool remove_cross_partition_collective_ops)44 bool IsRemovableWhile(HloInstruction* instruction,
45                       bool remove_cross_partition_collective_ops) {
46   if (instruction->opcode() != HloOpcode::kWhile) {
47     return false;
48   }
49   for (HloComputation* computation : instruction->called_computations()) {
50     for (HloInstruction* called_instr : computation->instructions()) {
51       auto maybe_collective_op =
52           DynCast<HloCollectiveInstruction>(called_instr);
53       if (called_instr->HasSideEffect() &&
54           (!remove_cross_partition_collective_ops || !maybe_collective_op ||
55            maybe_collective_op->constrain_layout())) {
56         return false;
57       }
58     }
59   }
60   return true;
61 }
62 }  // namespace
63 
RunOnComputation(HloComputation * computation,bool remove_cross_partition_collective_ops)64 /*static*/ StatusOr<bool> HloDCE::RunOnComputation(
65     HloComputation* computation, bool remove_cross_partition_collective_ops) {
66   bool changed = false;
67   VLOG(3) << "Before dce:";
68   XLA_VLOG_LINES(3, computation->ToString());
69   // Remove any dead roots and their dead transitive operands. Collect them
70   // into a separate list first to avoid problems with iterating through the
71   // computation's instruction while simultaneously removing instructions.
72   std::vector<HloInstruction*> dead_roots;
73   for (auto* instruction : computation->instructions()) {
74     auto maybe_collective_op = DynCast<HloCollectiveInstruction>(instruction);
75     if (instruction->IsDead() && computation->IsSafelyRemovable(instruction) &&
76         (!instruction->HasSideEffect() ||
77          (remove_cross_partition_collective_ops && maybe_collective_op &&
78           !maybe_collective_op->constrain_layout()) ||
79          IsRemovableWhile(instruction,
80                           remove_cross_partition_collective_ops))) {
81       dead_roots.push_back(instruction);
82     }
83   }
84 
85   for (HloInstruction* dead_root : dead_roots) {
86     VLOG(1) << "Removing dead root " << dead_root->ToString()
87             << " and its unused operands";
88     TF_RETURN_IF_ERROR(
89         computation->RemoveInstructionAndUnusedOperands(dead_root));
90     changed = true;
91   }
92   if (changed) {
93     VLOG(3) << "After dce:";
94     XLA_VLOG_LINES(3, computation->ToString());
95   }
96   return changed;
97 }
98 
RecursivelyRemoveDeadComputation(HloModule * module,HloComputation * computation,absl::flat_hash_map<HloComputation *,int> & live_call_counts)99 Status HloDCE::RecursivelyRemoveDeadComputation(
100     HloModule* module, HloComputation* computation,
101     absl::flat_hash_map<HloComputation*, int>& live_call_counts) {
102   // First loops all the sub-instructions/sub-computations.
103   for (HloInstruction* instruction : computation->instructions()) {
104     for (HloComputation* subcomp : instruction->called_computations()) {
105       auto iter = live_call_counts.find(subcomp);
106       if (iter == live_call_counts.end()) {
107         return tensorflow::errors::Internal(
108             "called computation not found in live_call_counts table during "
109             "HloDCE");
110       }
111 
112       // Decrements the live call count and sees if there are no more live
113       // calls to this computation.
114       int live_call_count = --iter->second;
115       CHECK_GE(live_call_count, 0);
116       if (live_call_count == 0) {
117         TF_RETURN_IF_ERROR(RecursivelyRemoveDeadComputation(module, subcomp,
118                                                             live_call_counts));
119       }
120     }
121   }
122   VLOG(1) << "Removing dead computation " << computation->name();
123   // After looping called subcomputations, now safe to delete the computation.
124   return module->RemoveEmbeddedComputation(computation);
125 }
126 
RecursivelyRemoveDeadComputations(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)127 StatusOr<bool> HloDCE::RecursivelyRemoveDeadComputations(
128     HloModule* module,
129     const absl::flat_hash_set<absl::string_view>& execution_threads) {
130   // Tracks whether any dead code is eliminated by this pass.
131   bool module_contains_dead_code = false;
132 
133   // First, collect the computations that are
134   // referenced by some remaining instruction. We need to record this as a
135   // refcount map rather than a set since we cannot guarantee that control
136   // flow flattening has been done and there may be multiple call sites.
137   absl::flat_hash_map<HloComputation*, int> live_computation_call_count;
138   if (HloComputation* entry_computation = module->entry_computation()) {
139     ++live_computation_call_count[entry_computation];
140   }
141   for (auto* computation :
142        module->MakeComputationPostOrder(execution_threads)) {
143     for (auto* instruction : computation->instructions()) {
144       for (auto* subcomp : instruction->called_computations()) {
145         ++live_computation_call_count[subcomp];
146       }
147     }
148   }
149 
150   // Find dead computations.
151   absl::flat_hash_set<HloComputation*> dead_computations;
152   for (auto* computation :
153        module->MakeComputationPostOrder(execution_threads)) {
154     // Finds all "top-level" dead computations not called by any instructions.
155     // contains(comp) = true and live_computation_call_count[comp] = 0 also
156     // implies that the computation is dead, but is nested in other dead
157     // computations. These inner computations are ignored here since they will
158     // be removed recursing through other computations.
159     if (!live_computation_call_count.contains(computation)) {
160       TF_RETURN_IF_ERROR(RecursivelyRemoveDeadComputation(
161           module, computation, live_computation_call_count));
162       module_contains_dead_code = true;
163     }
164   }
165   return module_contains_dead_code;
166 }
167 
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)168 StatusOr<bool> HloDCE::Run(
169     HloModule* module,
170     const absl::flat_hash_set<absl::string_view>& execution_threads) {
171   bool changed = false;
172 
173   VLOG(2) << "Before dce:";
174   XLA_VLOG_LINES(2, module->ToString());
175 
176   // Run DCE on each computation.
177   for (auto* computation :
178        module->MakeComputationPostOrder(execution_threads)) {
179     TF_ASSIGN_OR_RETURN(
180         bool changed_for_computation,
181         RunOnComputation(computation, remove_cross_partition_collective_ops_));
182     changed |= changed_for_computation;
183   }
184 
185   // Now DCE HloComputations.  Keep doing passes through the module until no
186   // more computations can be eliminated. The function removes all
187   // subcomputations that can be proved to have no remaining live callers.
188   TF_ASSIGN_OR_RETURN(
189       bool module_contains_dead_code,
190       RecursivelyRemoveDeadComputations(module, execution_threads));
191   changed |= module_contains_dead_code;
192 
193   VLOG(2) << "After dce:";
194   XLA_VLOG_LINES(2, module->ToString());
195 
196   return changed;
197 }
198 
199 }  // namespace xla
200