xref: /aosp_15_r20/external/tensorflow/tensorflow/core/transforms/region_to_functional/pass.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/transforms/region_to_functional/pass.h"
17 
18 #include <memory>
19 #include <utility>
20 
21 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
22 #include "mlir/Pass/Pass.h"  // from @llvm-project
23 #include "mlir/Pass/PassManager.h"  // from @llvm-project
24 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
25 #include "tensorflow/core/transforms/pass_detail.h"
26 #include "tensorflow/core/transforms/region_to_functional/impl.h"
27 
28 namespace mlir {
29 namespace tfg {
30 
31 namespace {
32 struct RegionToFunctionalPass
33     : public RegionToFunctionalBase<RegionToFunctionalPass> {
RegionToFunctionalPassmlir::tfg::__anond0b0aa330111::RegionToFunctionalPass34   explicit RegionToFunctionalPass(bool force_ctl_capture) {
35     force_control_capture = force_ctl_capture;
36   }
37 
runOnOperationmlir::tfg::__anond0b0aa330111::RegionToFunctionalPass38   void runOnOperation() override {
39     RewritePatternSet patterns(&getContext());
40     SymbolTable table(getOperation());
41     PopulateRegionToFunctionalPatterns(patterns, table, force_control_capture);
42 
43     GreedyRewriteConfig config;
44     // Use top-down traversal for more efficient conversion. Disable region
45     // simplification as all regions are single block.
46     config.useTopDownTraversal = true;
47     config.enableRegionSimplification = false;
48     // Iterate until all regions have been outlined. This is guaranteed to
49     // terminate because the IR can only hold a finite depth of regions.
50     config.maxIterations = GreedyRewriteConfig::kNoIterationLimit;
51     if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
52                                             config))) {
53       getOperation()->emitError(getArgument() + " pass failed");
54       signalPassFailure();
55     }
56   }
57 };
58 }  // namespace
59 
CreateRegionToFunctionalPass(bool force_control_capture)60 std::unique_ptr<Pass> CreateRegionToFunctionalPass(bool force_control_capture) {
61   return std::make_unique<RegionToFunctionalPass>(force_control_capture);
62 }
63 
64 }  // namespace tfg
65 }  // namespace mlir
66