1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_ 18 19 #include "absl/strings/string_view.h" 20 #include "absl/types/span.h" 21 #include "tensorflow/compiler/xla/literal.h" 22 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" 23 #include "tensorflow/compiler/xla/service/hlo_computation.h" 24 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 25 #include "tensorflow/compiler/xla/service/hlo_module.h" 26 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 27 #include "tensorflow/compiler/xla/statusor.h" 28 #include "tensorflow/compiler/xla/types.h" 29 #include "tensorflow/compiler/xla/xla_data.pb.h" 30 #include "tensorflow/core/lib/core/status.h" 31 32 namespace xla { 33 34 // DfsHloVisitor with default action based on the HloInstruction being visited. 35 // Users should not use this class directly, but use the type aliases 36 // DfsHloVisitorWithDefault/ConstDfsHloVisitorWithDefault instead. 37 // 38 // Do *not* add an override to this class if the opcode is covered by 39 // HandleElementwiseUnary/Binary. These opcode handlers dispatch to 40 // HandleElementwiseUnary/Binary in DfsHloVisitorBase. Adding such a handler 41 // here will break passes which rely on the HandleElementwiseUnary/Binary 42 // handling these opcodes. 43 template <typename HloInstructionPtr> 44 class DfsHloVisitorWithDefaultBase 45 : public DfsHloVisitorBase<HloInstructionPtr> { 46 public: DfsHloVisitorWithDefaultBase()47 DfsHloVisitorWithDefaultBase() {} ~DfsHloVisitorWithDefaultBase()48 ~DfsHloVisitorWithDefaultBase() override {} 49 50 // Default action performed on HloInstruction. 51 virtual Status DefaultAction(HloInstructionPtr hlo_instruction) = 0; 52 HandleElementwiseUnary(HloInstructionPtr hlo)53 Status HandleElementwiseUnary(HloInstructionPtr hlo) override { 54 return DefaultAction(hlo); 55 } HandleElementwiseBinary(HloInstructionPtr hlo)56 Status HandleElementwiseBinary(HloInstructionPtr hlo) override { 57 return DefaultAction(hlo); 58 } 59 HandleBatchNormTraining(HloInstructionPtr hlo)60 Status HandleBatchNormTraining(HloInstructionPtr hlo) override { 61 return DefaultAction(hlo); 62 } 63 HandleBatchNormInference(HloInstructionPtr hlo)64 Status HandleBatchNormInference(HloInstructionPtr hlo) override { 65 return DefaultAction(hlo); 66 } 67 HandleBatchNormGrad(HloInstructionPtr hlo)68 Status HandleBatchNormGrad(HloInstructionPtr hlo) override { 69 return DefaultAction(hlo); 70 } 71 HandleClamp(HloInstructionPtr clamp)72 Status HandleClamp(HloInstructionPtr clamp) override { 73 return DefaultAction(clamp); 74 } HandleConcatenate(HloInstructionPtr concatenate)75 Status HandleConcatenate(HloInstructionPtr concatenate) override { 76 return DefaultAction(concatenate); 77 } HandleSelect(HloInstructionPtr select)78 Status HandleSelect(HloInstructionPtr select) override { 79 return DefaultAction(select); 80 } HandleDot(HloInstructionPtr dot)81 Status HandleDot(HloInstructionPtr dot) override { 82 return DefaultAction(dot); 83 } HandleConvolution(HloInstructionPtr convolution)84 Status HandleConvolution(HloInstructionPtr convolution) override { 85 return DefaultAction(convolution); 86 } HandleFft(HloInstructionPtr fft)87 Status HandleFft(HloInstructionPtr fft) override { 88 return DefaultAction(fft); 89 } HandleTriangularSolve(HloInstructionPtr hlo)90 Status HandleTriangularSolve(HloInstructionPtr hlo) override { 91 return DefaultAction(hlo); 92 } HandleCholesky(HloInstructionPtr hlo)93 Status HandleCholesky(HloInstructionPtr hlo) override { 94 return DefaultAction(hlo); 95 } HandleOptimizationBarrier(HloInstructionPtr hlo)96 Status HandleOptimizationBarrier(HloInstructionPtr hlo) override { 97 return DefaultAction(hlo); 98 } HandleAllGather(HloInstructionPtr crs)99 Status HandleAllGather(HloInstructionPtr crs) override { 100 return DefaultAction(crs); 101 } HandleAllGatherStart(HloInstructionPtr crs)102 Status HandleAllGatherStart(HloInstructionPtr crs) override { 103 return DefaultAction(crs); 104 } HandleAllGatherDone(HloInstructionPtr crs)105 Status HandleAllGatherDone(HloInstructionPtr crs) override { 106 return DefaultAction(crs); 107 } HandleAllReduce(HloInstructionPtr crs)108 Status HandleAllReduce(HloInstructionPtr crs) override { 109 return DefaultAction(crs); 110 } HandleReduceScatter(HloInstructionPtr hlo)111 Status HandleReduceScatter(HloInstructionPtr hlo) override { 112 return DefaultAction(hlo); 113 } HandleAllReduceStart(HloInstructionPtr hlo)114 Status HandleAllReduceStart(HloInstructionPtr hlo) override { 115 return DefaultAction(hlo); 116 } HandleAllReduceDone(HloInstructionPtr hlo)117 Status HandleAllReduceDone(HloInstructionPtr hlo) override { 118 return DefaultAction(hlo); 119 } HandleAllToAll(HloInstructionPtr hlo)120 Status HandleAllToAll(HloInstructionPtr hlo) override { 121 return DefaultAction(hlo); 122 } HandleCollectivePermute(HloInstructionPtr hlo)123 Status HandleCollectivePermute(HloInstructionPtr hlo) override { 124 return DefaultAction(hlo); 125 } HandleCollectivePermuteStart(HloInstructionPtr hlo)126 Status HandleCollectivePermuteStart(HloInstructionPtr hlo) override { 127 return DefaultAction(hlo); 128 } HandleCollectivePermuteDone(HloInstructionPtr hlo)129 Status HandleCollectivePermuteDone(HloInstructionPtr hlo) override { 130 return DefaultAction(hlo); 131 } HandleReplicaId(HloInstructionPtr hlo)132 Status HandleReplicaId(HloInstructionPtr hlo) override { 133 return DefaultAction(hlo); 134 } HandlePartitionId(HloInstructionPtr hlo)135 Status HandlePartitionId(HloInstructionPtr hlo) override { 136 return DefaultAction(hlo); 137 } HandleRng(HloInstructionPtr random)138 Status HandleRng(HloInstructionPtr random) override { 139 return DefaultAction(random); 140 } HandleRngBitGenerator(HloInstructionPtr random)141 Status HandleRngBitGenerator(HloInstructionPtr random) override { 142 return DefaultAction(random); 143 } HandleRngGetAndUpdateState(HloInstructionPtr random)144 Status HandleRngGetAndUpdateState(HloInstructionPtr random) override { 145 return DefaultAction(random); 146 } HandleInfeed(HloInstructionPtr infeed)147 Status HandleInfeed(HloInstructionPtr infeed) override { 148 return DefaultAction(infeed); 149 } HandleOutfeed(HloInstructionPtr outfeed)150 Status HandleOutfeed(HloInstructionPtr outfeed) override { 151 return DefaultAction(outfeed); 152 } HandleReverse(HloInstructionPtr reverse)153 Status HandleReverse(HloInstructionPtr reverse) override { 154 return DefaultAction(reverse); 155 } HandleSort(HloInstructionPtr sort)156 Status HandleSort(HloInstructionPtr sort) override { 157 return DefaultAction(sort); 158 } HandleConstant(HloInstructionPtr constant)159 Status HandleConstant(HloInstructionPtr constant) override { 160 return DefaultAction(constant); 161 } HandleIota(HloInstructionPtr iota)162 Status HandleIota(HloInstructionPtr iota) override { 163 return DefaultAction(iota); 164 } HandleGetTupleElement(HloInstructionPtr get_tuple_element)165 Status HandleGetTupleElement(HloInstructionPtr get_tuple_element) override { 166 return DefaultAction(get_tuple_element); 167 } HandleParameter(HloInstructionPtr parameter)168 Status HandleParameter(HloInstructionPtr parameter) override { 169 return DefaultAction(parameter); 170 } HandleFusion(HloInstructionPtr fusion)171 Status HandleFusion(HloInstructionPtr fusion) override { 172 return DefaultAction(fusion); 173 } HandleCall(HloInstructionPtr call)174 Status HandleCall(HloInstructionPtr call) override { 175 return DefaultAction(call); 176 } HandleCustomCall(HloInstructionPtr custom_call)177 Status HandleCustomCall(HloInstructionPtr custom_call) override { 178 return DefaultAction(custom_call); 179 } HandleSlice(HloInstructionPtr slice)180 Status HandleSlice(HloInstructionPtr slice) override { 181 return DefaultAction(slice); 182 } HandleDynamicSlice(HloInstructionPtr dynamic_slice)183 Status HandleDynamicSlice(HloInstructionPtr dynamic_slice) override { 184 return DefaultAction(dynamic_slice); 185 } HandleDynamicUpdateSlice(HloInstructionPtr dynamic_update_slice)186 Status HandleDynamicUpdateSlice( 187 HloInstructionPtr dynamic_update_slice) override { 188 return DefaultAction(dynamic_update_slice); 189 } HandleTuple(HloInstructionPtr tuple)190 Status HandleTuple(HloInstructionPtr tuple) override { 191 return DefaultAction(tuple); 192 } HandleMap(HloInstructionPtr map)193 Status HandleMap(HloInstructionPtr map) override { 194 return DefaultAction(map); 195 } HandleReduce(HloInstructionPtr reduce)196 Status HandleReduce(HloInstructionPtr reduce) override { 197 return DefaultAction(reduce); 198 } HandleReduceWindow(HloInstructionPtr reduce_window)199 Status HandleReduceWindow(HloInstructionPtr reduce_window) override { 200 return DefaultAction(reduce_window); 201 } HandleSelectAndScatter(HloInstructionPtr select_and_scatter)202 Status HandleSelectAndScatter(HloInstructionPtr select_and_scatter) override { 203 return DefaultAction(select_and_scatter); 204 } HandleBitcast(HloInstructionPtr bitcast)205 Status HandleBitcast(HloInstructionPtr bitcast) override { 206 return DefaultAction(bitcast); 207 } HandleBroadcast(HloInstructionPtr broadcast)208 Status HandleBroadcast(HloInstructionPtr broadcast) override { 209 return DefaultAction(broadcast); 210 } HandlePad(HloInstructionPtr pad)211 Status HandlePad(HloInstructionPtr pad) override { 212 return DefaultAction(pad); 213 } HandleDynamicReshape(HloInstructionPtr dynamic_reshape)214 Status HandleDynamicReshape(HloInstructionPtr dynamic_reshape) override { 215 return DefaultAction(dynamic_reshape); 216 } HandleReshape(HloInstructionPtr reshape)217 Status HandleReshape(HloInstructionPtr reshape) override { 218 return DefaultAction(reshape); 219 } HandleTranspose(HloInstructionPtr transpose)220 Status HandleTranspose(HloInstructionPtr transpose) override { 221 return DefaultAction(transpose); 222 } HandleWhile(HloInstructionPtr xla_while)223 Status HandleWhile(HloInstructionPtr xla_while) override { 224 return DefaultAction(xla_while); 225 } HandleConditional(HloInstructionPtr conditional)226 Status HandleConditional(HloInstructionPtr conditional) override { 227 return DefaultAction(conditional); 228 } HandleAsyncStart(HloInstructionPtr async_start)229 Status HandleAsyncStart(HloInstructionPtr async_start) override { 230 return DefaultAction(async_start); 231 } HandleAsyncUpdate(HloInstructionPtr async_update)232 Status HandleAsyncUpdate(HloInstructionPtr async_update) override { 233 return DefaultAction(async_update); 234 } HandleAsyncDone(HloInstructionPtr async_done)235 Status HandleAsyncDone(HloInstructionPtr async_done) override { 236 return DefaultAction(async_done); 237 } HandleCopyStart(HloInstructionPtr copy_start)238 Status HandleCopyStart(HloInstructionPtr copy_start) override { 239 return DefaultAction(copy_start); 240 } HandleCopyDone(HloInstructionPtr copy_done)241 Status HandleCopyDone(HloInstructionPtr copy_done) override { 242 return DefaultAction(copy_done); 243 } HandleRecv(HloInstructionPtr recv)244 Status HandleRecv(HloInstructionPtr recv) override { 245 return DefaultAction(recv); 246 } HandleRecvDone(HloInstructionPtr recv_done)247 Status HandleRecvDone(HloInstructionPtr recv_done) override { 248 return DefaultAction(recv_done); 249 } HandleSend(HloInstructionPtr send)250 Status HandleSend(HloInstructionPtr send) override { 251 return DefaultAction(send); 252 } HandleSendDone(HloInstructionPtr send_done)253 Status HandleSendDone(HloInstructionPtr send_done) override { 254 return DefaultAction(send_done); 255 } HandleGather(HloInstructionPtr gather)256 Status HandleGather(HloInstructionPtr gather) override { 257 return DefaultAction(gather); 258 } HandleScatter(HloInstructionPtr scatter)259 Status HandleScatter(HloInstructionPtr scatter) override { 260 return DefaultAction(scatter); 261 } HandleAfterAll(HloInstructionPtr token)262 Status HandleAfterAll(HloInstructionPtr token) override { 263 return DefaultAction(token); 264 } HandleGetDimensionSize(HloInstructionPtr get_size)265 Status HandleGetDimensionSize(HloInstructionPtr get_size) override { 266 return DefaultAction(get_size); 267 } HandleSetDimensionSize(HloInstructionPtr get_size)268 Status HandleSetDimensionSize(HloInstructionPtr get_size) override { 269 return DefaultAction(get_size); 270 } HandleAddDependency(HloInstructionPtr add_dependency)271 Status HandleAddDependency(HloInstructionPtr add_dependency) override { 272 return DefaultAction(add_dependency); 273 } 274 275 // Invoked to inform the visitor that the traversal has completed, and that 276 // the root was "root". FinishVisit(HloInstructionPtr)277 Status FinishVisit(HloInstructionPtr /*root*/) override { return OkStatus(); } 278 279 private: 280 DfsHloVisitorWithDefaultBase(const DfsHloVisitorWithDefaultBase&) = delete; 281 DfsHloVisitorWithDefaultBase& operator=(const DfsHloVisitorWithDefaultBase&) = 282 delete; 283 }; 284 285 // Users should use these type aliases which are only two valid instantiations. 286 using DfsHloVisitorWithDefault = DfsHloVisitorWithDefaultBase<HloInstruction*>; 287 using ConstDfsHloVisitorWithDefault = 288 DfsHloVisitorWithDefaultBase<const HloInstruction*>; 289 290 // A common base class for visitors performing rewriting operation. 291 // 292 // Subclasses call ReplaceWithNewInstruction and ReplaceInstruction while 293 // visiting. 294 class DfsHloRewriteVisitor : public DfsHloVisitorWithDefault { 295 public: 296 // Runs a visitor on the module and returns whether the module has changed. 297 StatusOr<bool> RunOnModule( 298 HloModule* module, 299 const absl::flat_hash_set<absl::string_view>& execution_threads = {}) { 300 for (const auto& computation : 301 module->MakeNonfusionComputations(execution_threads)) { 302 TF_RETURN_IF_ERROR(computation->Accept(this)); 303 } 304 return changed(); 305 } 306 307 // Default visitor action is to do nothing and return OK. DefaultAction(HloInstruction *)308 Status DefaultAction(HloInstruction* /*hlo_instruction*/) override { 309 return OkStatus(); 310 } 311 changed()312 bool changed() const { return changed_; } 313 314 protected: 315 // Replaces the existing HLO instruction old_instruction, with 316 // new_instruction, and marks the optimizer status as changed. 317 // Returns the Status representing the result of the replace operation. ReplaceWithNewInstruction(HloInstruction * old_instruction,std::unique_ptr<HloInstruction> new_instruction)318 Status ReplaceWithNewInstruction( 319 HloInstruction* old_instruction, 320 std::unique_ptr<HloInstruction> new_instruction) { 321 VLOG(3) << "Replacing instruction:"; 322 VLOG(3) << " old: " << old_instruction->ToString(); 323 VLOG(3) << " new: " << new_instruction->ToString(); 324 TF_RETURN_IF_ERROR(old_instruction->parent()->ReplaceWithNewInstruction( 325 old_instruction, std::move(new_instruction))); 326 changed_ = true; 327 return OkStatus(); 328 } 329 330 // Replaces the existing HLO instruction old_instruction, with 331 // new_instruction, and marks the optimizer status as changed. 332 // Returns the Status representing the result of the replace operation. ReplaceInstruction(HloInstruction * old_instruction,HloInstruction * new_instruction,bool preserve_sharding)333 StatusOr<bool> ReplaceInstruction(HloInstruction* old_instruction, 334 HloInstruction* new_instruction, 335 bool preserve_sharding) { 336 VLOG(3) << "Replacing instruction:"; 337 VLOG(3) << " old: " << old_instruction->ToString(); 338 VLOG(3) << " new: " << new_instruction->ToString(); 339 TF_ASSIGN_OR_RETURN( 340 bool changed, old_instruction->parent()->ReplaceInstruction( 341 old_instruction, new_instruction, preserve_sharding)); 342 changed_ |= changed; 343 return changed; 344 } 345 ReplaceInstruction(HloInstruction * old_instruction,HloInstruction * new_instruction)346 Status ReplaceInstruction(HloInstruction* old_instruction, 347 HloInstruction* new_instruction) { 348 TF_ASSIGN_OR_RETURN(bool changed, 349 ReplaceInstruction(old_instruction, new_instruction, 350 /*preserve_sharding=*/false)); 351 DCHECK(changed); 352 return OkStatus(); 353 } 354 355 // Mark the computation as having changed. MarkAsChanged()356 void MarkAsChanged() { changed_ = true; } 357 358 private: 359 bool changed_ = false; 360 }; 361 362 // (Const)FunctionVisitor lets you transform an 363 // std::function<Status((const) HloInstruction*)> into a (Const)DfsHloVisitor. 364 // 365 // This is useful if you have code that needs to handle visitors in the form of 366 // both std::function and DfsHloVisitor. You can wrap the function in a 367 // FunctionVisitor and then treat it like any other DfsHloVisitor. 368 template <typename HloInstructionPtr> 369 class FunctionVisitorBase 370 : public DfsHloVisitorWithDefaultBase<HloInstructionPtr> { 371 public: FunctionVisitorBase(std::function<Status (HloInstructionPtr)> visitor_func)372 explicit FunctionVisitorBase( 373 std::function<Status(HloInstructionPtr)> visitor_func) 374 : visitor_func_(std::move(visitor_func)) {} 375 DefaultAction(HloInstructionPtr hlo_instruction)376 Status DefaultAction(HloInstructionPtr hlo_instruction) override { 377 return visitor_func_(hlo_instruction); 378 } 379 380 private: 381 FunctionVisitorBase(const FunctionVisitorBase&) = delete; 382 FunctionVisitorBase& operator=(const FunctionVisitorBase&) = delete; 383 384 std::function<Status(HloInstructionPtr)> visitor_func_; 385 }; 386 387 using FunctionVisitor = FunctionVisitorBase<HloInstruction*>; 388 using ConstFunctionVisitor = FunctionVisitorBase<const HloInstruction*>; 389 390 } // namespace xla 391 392 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_ 393