xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
18 
#include <functional>
#include <memory>
#include <utility>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
31 
32 namespace xla {
33 
34 // DfsHloVisitor with default action based on the HloInstruction being visited.
35 // Users should not use this class directly, but use the type aliases
36 // DfsHloVisitorWithDefault/ConstDfsHloVisitorWithDefault instead.
37 //
38 // Do *not* add an override to this class if the opcode is covered by
39 // HandleElementwiseUnary/Binary. These opcode handlers dispatch to
40 // HandleElementwiseUnary/Binary in DfsHloVisitorBase. Adding such a handler
41 // here will break passes which rely on the HandleElementwiseUnary/Binary
42 // handling these opcodes.
43 template <typename HloInstructionPtr>
44 class DfsHloVisitorWithDefaultBase
45     : public DfsHloVisitorBase<HloInstructionPtr> {
46  public:
DfsHloVisitorWithDefaultBase()47   DfsHloVisitorWithDefaultBase() {}
~DfsHloVisitorWithDefaultBase()48   ~DfsHloVisitorWithDefaultBase() override {}
49 
50   // Default action performed on HloInstruction.
51   virtual Status DefaultAction(HloInstructionPtr hlo_instruction) = 0;
52 
HandleElementwiseUnary(HloInstructionPtr hlo)53   Status HandleElementwiseUnary(HloInstructionPtr hlo) override {
54     return DefaultAction(hlo);
55   }
HandleElementwiseBinary(HloInstructionPtr hlo)56   Status HandleElementwiseBinary(HloInstructionPtr hlo) override {
57     return DefaultAction(hlo);
58   }
59 
HandleBatchNormTraining(HloInstructionPtr hlo)60   Status HandleBatchNormTraining(HloInstructionPtr hlo) override {
61     return DefaultAction(hlo);
62   }
63 
HandleBatchNormInference(HloInstructionPtr hlo)64   Status HandleBatchNormInference(HloInstructionPtr hlo) override {
65     return DefaultAction(hlo);
66   }
67 
HandleBatchNormGrad(HloInstructionPtr hlo)68   Status HandleBatchNormGrad(HloInstructionPtr hlo) override {
69     return DefaultAction(hlo);
70   }
71 
HandleClamp(HloInstructionPtr clamp)72   Status HandleClamp(HloInstructionPtr clamp) override {
73     return DefaultAction(clamp);
74   }
HandleConcatenate(HloInstructionPtr concatenate)75   Status HandleConcatenate(HloInstructionPtr concatenate) override {
76     return DefaultAction(concatenate);
77   }
HandleSelect(HloInstructionPtr select)78   Status HandleSelect(HloInstructionPtr select) override {
79     return DefaultAction(select);
80   }
HandleDot(HloInstructionPtr dot)81   Status HandleDot(HloInstructionPtr dot) override {
82     return DefaultAction(dot);
83   }
HandleConvolution(HloInstructionPtr convolution)84   Status HandleConvolution(HloInstructionPtr convolution) override {
85     return DefaultAction(convolution);
86   }
HandleFft(HloInstructionPtr fft)87   Status HandleFft(HloInstructionPtr fft) override {
88     return DefaultAction(fft);
89   }
HandleTriangularSolve(HloInstructionPtr hlo)90   Status HandleTriangularSolve(HloInstructionPtr hlo) override {
91     return DefaultAction(hlo);
92   }
HandleCholesky(HloInstructionPtr hlo)93   Status HandleCholesky(HloInstructionPtr hlo) override {
94     return DefaultAction(hlo);
95   }
HandleOptimizationBarrier(HloInstructionPtr hlo)96   Status HandleOptimizationBarrier(HloInstructionPtr hlo) override {
97     return DefaultAction(hlo);
98   }
HandleAllGather(HloInstructionPtr crs)99   Status HandleAllGather(HloInstructionPtr crs) override {
100     return DefaultAction(crs);
101   }
HandleAllGatherStart(HloInstructionPtr crs)102   Status HandleAllGatherStart(HloInstructionPtr crs) override {
103     return DefaultAction(crs);
104   }
HandleAllGatherDone(HloInstructionPtr crs)105   Status HandleAllGatherDone(HloInstructionPtr crs) override {
106     return DefaultAction(crs);
107   }
HandleAllReduce(HloInstructionPtr crs)108   Status HandleAllReduce(HloInstructionPtr crs) override {
109     return DefaultAction(crs);
110   }
HandleReduceScatter(HloInstructionPtr hlo)111   Status HandleReduceScatter(HloInstructionPtr hlo) override {
112     return DefaultAction(hlo);
113   }
HandleAllReduceStart(HloInstructionPtr hlo)114   Status HandleAllReduceStart(HloInstructionPtr hlo) override {
115     return DefaultAction(hlo);
116   }
HandleAllReduceDone(HloInstructionPtr hlo)117   Status HandleAllReduceDone(HloInstructionPtr hlo) override {
118     return DefaultAction(hlo);
119   }
HandleAllToAll(HloInstructionPtr hlo)120   Status HandleAllToAll(HloInstructionPtr hlo) override {
121     return DefaultAction(hlo);
122   }
HandleCollectivePermute(HloInstructionPtr hlo)123   Status HandleCollectivePermute(HloInstructionPtr hlo) override {
124     return DefaultAction(hlo);
125   }
HandleCollectivePermuteStart(HloInstructionPtr hlo)126   Status HandleCollectivePermuteStart(HloInstructionPtr hlo) override {
127     return DefaultAction(hlo);
128   }
HandleCollectivePermuteDone(HloInstructionPtr hlo)129   Status HandleCollectivePermuteDone(HloInstructionPtr hlo) override {
130     return DefaultAction(hlo);
131   }
HandleReplicaId(HloInstructionPtr hlo)132   Status HandleReplicaId(HloInstructionPtr hlo) override {
133     return DefaultAction(hlo);
134   }
HandlePartitionId(HloInstructionPtr hlo)135   Status HandlePartitionId(HloInstructionPtr hlo) override {
136     return DefaultAction(hlo);
137   }
HandleRng(HloInstructionPtr random)138   Status HandleRng(HloInstructionPtr random) override {
139     return DefaultAction(random);
140   }
HandleRngBitGenerator(HloInstructionPtr random)141   Status HandleRngBitGenerator(HloInstructionPtr random) override {
142     return DefaultAction(random);
143   }
HandleRngGetAndUpdateState(HloInstructionPtr random)144   Status HandleRngGetAndUpdateState(HloInstructionPtr random) override {
145     return DefaultAction(random);
146   }
HandleInfeed(HloInstructionPtr infeed)147   Status HandleInfeed(HloInstructionPtr infeed) override {
148     return DefaultAction(infeed);
149   }
HandleOutfeed(HloInstructionPtr outfeed)150   Status HandleOutfeed(HloInstructionPtr outfeed) override {
151     return DefaultAction(outfeed);
152   }
HandleReverse(HloInstructionPtr reverse)153   Status HandleReverse(HloInstructionPtr reverse) override {
154     return DefaultAction(reverse);
155   }
HandleSort(HloInstructionPtr sort)156   Status HandleSort(HloInstructionPtr sort) override {
157     return DefaultAction(sort);
158   }
HandleConstant(HloInstructionPtr constant)159   Status HandleConstant(HloInstructionPtr constant) override {
160     return DefaultAction(constant);
161   }
HandleIota(HloInstructionPtr iota)162   Status HandleIota(HloInstructionPtr iota) override {
163     return DefaultAction(iota);
164   }
HandleGetTupleElement(HloInstructionPtr get_tuple_element)165   Status HandleGetTupleElement(HloInstructionPtr get_tuple_element) override {
166     return DefaultAction(get_tuple_element);
167   }
HandleParameter(HloInstructionPtr parameter)168   Status HandleParameter(HloInstructionPtr parameter) override {
169     return DefaultAction(parameter);
170   }
HandleFusion(HloInstructionPtr fusion)171   Status HandleFusion(HloInstructionPtr fusion) override {
172     return DefaultAction(fusion);
173   }
HandleCall(HloInstructionPtr call)174   Status HandleCall(HloInstructionPtr call) override {
175     return DefaultAction(call);
176   }
HandleCustomCall(HloInstructionPtr custom_call)177   Status HandleCustomCall(HloInstructionPtr custom_call) override {
178     return DefaultAction(custom_call);
179   }
HandleSlice(HloInstructionPtr slice)180   Status HandleSlice(HloInstructionPtr slice) override {
181     return DefaultAction(slice);
182   }
HandleDynamicSlice(HloInstructionPtr dynamic_slice)183   Status HandleDynamicSlice(HloInstructionPtr dynamic_slice) override {
184     return DefaultAction(dynamic_slice);
185   }
HandleDynamicUpdateSlice(HloInstructionPtr dynamic_update_slice)186   Status HandleDynamicUpdateSlice(
187       HloInstructionPtr dynamic_update_slice) override {
188     return DefaultAction(dynamic_update_slice);
189   }
HandleTuple(HloInstructionPtr tuple)190   Status HandleTuple(HloInstructionPtr tuple) override {
191     return DefaultAction(tuple);
192   }
HandleMap(HloInstructionPtr map)193   Status HandleMap(HloInstructionPtr map) override {
194     return DefaultAction(map);
195   }
HandleReduce(HloInstructionPtr reduce)196   Status HandleReduce(HloInstructionPtr reduce) override {
197     return DefaultAction(reduce);
198   }
HandleReduceWindow(HloInstructionPtr reduce_window)199   Status HandleReduceWindow(HloInstructionPtr reduce_window) override {
200     return DefaultAction(reduce_window);
201   }
HandleSelectAndScatter(HloInstructionPtr select_and_scatter)202   Status HandleSelectAndScatter(HloInstructionPtr select_and_scatter) override {
203     return DefaultAction(select_and_scatter);
204   }
HandleBitcast(HloInstructionPtr bitcast)205   Status HandleBitcast(HloInstructionPtr bitcast) override {
206     return DefaultAction(bitcast);
207   }
HandleBroadcast(HloInstructionPtr broadcast)208   Status HandleBroadcast(HloInstructionPtr broadcast) override {
209     return DefaultAction(broadcast);
210   }
HandlePad(HloInstructionPtr pad)211   Status HandlePad(HloInstructionPtr pad) override {
212     return DefaultAction(pad);
213   }
HandleDynamicReshape(HloInstructionPtr dynamic_reshape)214   Status HandleDynamicReshape(HloInstructionPtr dynamic_reshape) override {
215     return DefaultAction(dynamic_reshape);
216   }
HandleReshape(HloInstructionPtr reshape)217   Status HandleReshape(HloInstructionPtr reshape) override {
218     return DefaultAction(reshape);
219   }
HandleTranspose(HloInstructionPtr transpose)220   Status HandleTranspose(HloInstructionPtr transpose) override {
221     return DefaultAction(transpose);
222   }
HandleWhile(HloInstructionPtr xla_while)223   Status HandleWhile(HloInstructionPtr xla_while) override {
224     return DefaultAction(xla_while);
225   }
HandleConditional(HloInstructionPtr conditional)226   Status HandleConditional(HloInstructionPtr conditional) override {
227     return DefaultAction(conditional);
228   }
HandleAsyncStart(HloInstructionPtr async_start)229   Status HandleAsyncStart(HloInstructionPtr async_start) override {
230     return DefaultAction(async_start);
231   }
HandleAsyncUpdate(HloInstructionPtr async_update)232   Status HandleAsyncUpdate(HloInstructionPtr async_update) override {
233     return DefaultAction(async_update);
234   }
HandleAsyncDone(HloInstructionPtr async_done)235   Status HandleAsyncDone(HloInstructionPtr async_done) override {
236     return DefaultAction(async_done);
237   }
HandleCopyStart(HloInstructionPtr copy_start)238   Status HandleCopyStart(HloInstructionPtr copy_start) override {
239     return DefaultAction(copy_start);
240   }
HandleCopyDone(HloInstructionPtr copy_done)241   Status HandleCopyDone(HloInstructionPtr copy_done) override {
242     return DefaultAction(copy_done);
243   }
HandleRecv(HloInstructionPtr recv)244   Status HandleRecv(HloInstructionPtr recv) override {
245     return DefaultAction(recv);
246   }
HandleRecvDone(HloInstructionPtr recv_done)247   Status HandleRecvDone(HloInstructionPtr recv_done) override {
248     return DefaultAction(recv_done);
249   }
HandleSend(HloInstructionPtr send)250   Status HandleSend(HloInstructionPtr send) override {
251     return DefaultAction(send);
252   }
HandleSendDone(HloInstructionPtr send_done)253   Status HandleSendDone(HloInstructionPtr send_done) override {
254     return DefaultAction(send_done);
255   }
HandleGather(HloInstructionPtr gather)256   Status HandleGather(HloInstructionPtr gather) override {
257     return DefaultAction(gather);
258   }
HandleScatter(HloInstructionPtr scatter)259   Status HandleScatter(HloInstructionPtr scatter) override {
260     return DefaultAction(scatter);
261   }
HandleAfterAll(HloInstructionPtr token)262   Status HandleAfterAll(HloInstructionPtr token) override {
263     return DefaultAction(token);
264   }
HandleGetDimensionSize(HloInstructionPtr get_size)265   Status HandleGetDimensionSize(HloInstructionPtr get_size) override {
266     return DefaultAction(get_size);
267   }
HandleSetDimensionSize(HloInstructionPtr get_size)268   Status HandleSetDimensionSize(HloInstructionPtr get_size) override {
269     return DefaultAction(get_size);
270   }
HandleAddDependency(HloInstructionPtr add_dependency)271   Status HandleAddDependency(HloInstructionPtr add_dependency) override {
272     return DefaultAction(add_dependency);
273   }
274 
275   // Invoked to inform the visitor that the traversal has completed, and that
276   // the root was "root".
FinishVisit(HloInstructionPtr)277   Status FinishVisit(HloInstructionPtr /*root*/) override { return OkStatus(); }
278 
279  private:
280   DfsHloVisitorWithDefaultBase(const DfsHloVisitorWithDefaultBase&) = delete;
281   DfsHloVisitorWithDefaultBase& operator=(const DfsHloVisitorWithDefaultBase&) =
282       delete;
283 };
284 
285 // Users should use these type aliases which are only two valid instantiations.
286 using DfsHloVisitorWithDefault = DfsHloVisitorWithDefaultBase<HloInstruction*>;
287 using ConstDfsHloVisitorWithDefault =
288     DfsHloVisitorWithDefaultBase<const HloInstruction*>;
289 
290 // A common base class for visitors performing rewriting operation.
291 //
292 // Subclasses call ReplaceWithNewInstruction and ReplaceInstruction while
293 // visiting.
294 class DfsHloRewriteVisitor : public DfsHloVisitorWithDefault {
295  public:
296   // Runs a visitor on the module and returns whether the module has changed.
297   StatusOr<bool> RunOnModule(
298       HloModule* module,
299       const absl::flat_hash_set<absl::string_view>& execution_threads = {}) {
300     for (const auto& computation :
301          module->MakeNonfusionComputations(execution_threads)) {
302       TF_RETURN_IF_ERROR(computation->Accept(this));
303     }
304     return changed();
305   }
306 
307   // Default visitor action is to do nothing and return OK.
DefaultAction(HloInstruction *)308   Status DefaultAction(HloInstruction* /*hlo_instruction*/) override {
309     return OkStatus();
310   }
311 
changed()312   bool changed() const { return changed_; }
313 
314  protected:
315   // Replaces the existing HLO instruction old_instruction, with
316   // new_instruction, and marks the optimizer status as changed.
317   // Returns the Status representing the result of the replace operation.
ReplaceWithNewInstruction(HloInstruction * old_instruction,std::unique_ptr<HloInstruction> new_instruction)318   Status ReplaceWithNewInstruction(
319       HloInstruction* old_instruction,
320       std::unique_ptr<HloInstruction> new_instruction) {
321     VLOG(3) << "Replacing instruction:";
322     VLOG(3) << "  old: " << old_instruction->ToString();
323     VLOG(3) << "  new: " << new_instruction->ToString();
324     TF_RETURN_IF_ERROR(old_instruction->parent()->ReplaceWithNewInstruction(
325         old_instruction, std::move(new_instruction)));
326     changed_ = true;
327     return OkStatus();
328   }
329 
330   // Replaces the existing HLO instruction old_instruction, with
331   // new_instruction, and marks the optimizer status as changed.
332   // Returns the Status representing the result of the replace operation.
ReplaceInstruction(HloInstruction * old_instruction,HloInstruction * new_instruction,bool preserve_sharding)333   StatusOr<bool> ReplaceInstruction(HloInstruction* old_instruction,
334                                     HloInstruction* new_instruction,
335                                     bool preserve_sharding) {
336     VLOG(3) << "Replacing instruction:";
337     VLOG(3) << "  old: " << old_instruction->ToString();
338     VLOG(3) << "  new: " << new_instruction->ToString();
339     TF_ASSIGN_OR_RETURN(
340         bool changed, old_instruction->parent()->ReplaceInstruction(
341                           old_instruction, new_instruction, preserve_sharding));
342     changed_ |= changed;
343     return changed;
344   }
345 
ReplaceInstruction(HloInstruction * old_instruction,HloInstruction * new_instruction)346   Status ReplaceInstruction(HloInstruction* old_instruction,
347                             HloInstruction* new_instruction) {
348     TF_ASSIGN_OR_RETURN(bool changed,
349                         ReplaceInstruction(old_instruction, new_instruction,
350                                            /*preserve_sharding=*/false));
351     DCHECK(changed);
352     return OkStatus();
353   }
354 
355   // Mark the computation as having changed.
MarkAsChanged()356   void MarkAsChanged() { changed_ = true; }
357 
358  private:
359   bool changed_ = false;
360 };
361 
362 // (Const)FunctionVisitor lets you transform an
363 // std::function<Status((const) HloInstruction*)> into a (Const)DfsHloVisitor.
364 //
365 // This is useful if you have code that needs to handle visitors in the form of
366 // both std::function and DfsHloVisitor.  You can wrap the function in a
367 // FunctionVisitor and then treat it like any other DfsHloVisitor.
368 template <typename HloInstructionPtr>
369 class FunctionVisitorBase
370     : public DfsHloVisitorWithDefaultBase<HloInstructionPtr> {
371  public:
FunctionVisitorBase(std::function<Status (HloInstructionPtr)> visitor_func)372   explicit FunctionVisitorBase(
373       std::function<Status(HloInstructionPtr)> visitor_func)
374       : visitor_func_(std::move(visitor_func)) {}
375 
DefaultAction(HloInstructionPtr hlo_instruction)376   Status DefaultAction(HloInstructionPtr hlo_instruction) override {
377     return visitor_func_(hlo_instruction);
378   }
379 
380  private:
381   FunctionVisitorBase(const FunctionVisitorBase&) = delete;
382   FunctionVisitorBase& operator=(const FunctionVisitorBase&) = delete;
383 
384   std::function<Status(HloInstructionPtr)> visitor_func_;
385 };
386 
387 using FunctionVisitor = FunctionVisitorBase<HloInstruction*>;
388 using ConstFunctionVisitor = FunctionVisitorBase<const HloInstruction*>;
389 
390 }  // namespace xla
391 
392 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
393