1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/transforms/region_to_functional/pass.h" 17 18 #include <memory> 19 #include <utility> 20 21 #include "mlir/IR/PatternMatch.h" // from @llvm-project 22 #include "mlir/Pass/Pass.h" // from @llvm-project 23 #include "mlir/Pass/PassManager.h" // from @llvm-project 24 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project 25 #include "tensorflow/core/transforms/pass_detail.h" 26 #include "tensorflow/core/transforms/region_to_functional/impl.h" 27 28 namespace mlir { 29 namespace tfg { 30 31 namespace { 32 struct RegionToFunctionalPass 33 : public RegionToFunctionalBase<RegionToFunctionalPass> { RegionToFunctionalPassmlir::tfg::__anond0b0aa330111::RegionToFunctionalPass34 explicit RegionToFunctionalPass(bool force_ctl_capture) { 35 force_control_capture = force_ctl_capture; 36 } 37 runOnOperationmlir::tfg::__anond0b0aa330111::RegionToFunctionalPass38 void runOnOperation() override { 39 RewritePatternSet patterns(&getContext()); 40 SymbolTable table(getOperation()); 41 PopulateRegionToFunctionalPatterns(patterns, table, force_control_capture); 42 43 GreedyRewriteConfig config; 44 // Use top-down traversal for more efficient conversion. Disable region 45 // simplification as all regions are single block. 46 config.useTopDownTraversal = true; 47 config.enableRegionSimplification = false; 48 // Iterate until all regions have been outlined. This is guaranteed to 49 // terminate because the IR can only hold a finite depth of regions. 50 config.maxIterations = GreedyRewriteConfig::kNoIterationLimit; 51 if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), 52 config))) { 53 getOperation()->emitError(getArgument() + " pass failed"); 54 signalPassFailure(); 55 } 56 } 57 }; 58 } // namespace 59 CreateRegionToFunctionalPass(bool force_control_capture)60std::unique_ptr<Pass> CreateRegionToFunctionalPass(bool force_control_capture) { 61 return std::make_unique<RegionToFunctionalPass>(force_control_capture); 62 } 63 64 } // namespace tfg 65 } // namespace mlir 66