xref: /aosp_15_r20/external/tensorflow/tensorflow/core/grappler/optimizers/tfg_passes_builder.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/tfg_passes_builder.h"
17 
18 #include "mlir/Transforms/Passes.h"  // from @llvm-project
19 #include "tensorflow/core/ir/ops.h"
20 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
21 #include "tensorflow/core/transforms/pass_registration.h"
22 #include "tensorflow/core/util/util.h"
23 
24 namespace mlir {
25 namespace tfg {
26 
27 // The default pipeline only does shape inference now.
DefaultGrapplerPipeline(PassManager & manager)28 void DefaultGrapplerPipeline(PassManager& manager) {
29   // Turn certain shape attrs into types to give better knowledge for shape
30   // inference.
31   manager.addPass(CreateConsolidateAttributesPass());
32   // Toposort the graph will bring better performance in some optimizations like
33   // shape inference.
34   manager.addPass(CreateTopoSortPass());
35   // Infer the shape of operation if possible. The TFG importer doesn't do shape
36   // inference for almost all operations.
37   manager.addPass(CreateShapeInferencePass());
38   // Contruct the shape attrs back from types.
39   manager.addPass(CreatePrepareAttributesForExportPass());
40 }
41 
42 // Run the consolidate attributes pass. Convert the whole module to region
43 // control-flow and run control-flow sinking. Convert the whole module back to
44 // functional control-flow and prepare the attributes for export.
DefaultModuleGrapplerPipeline(PassManager & manager,const tensorflow::RewriterConfig & config)45 void DefaultModuleGrapplerPipeline(PassManager& manager,
46                                    const tensorflow::RewriterConfig& config) {
47   manager.addPass(CreateConsolidateAttributesPass());
48   manager.addPass(CreateFunctionalToRegionPass());
49   if (config.experimental_conditional_code_motion() !=
50       tensorflow::RewriterConfig::OFF)
51     manager.addNestedPass<GraphFuncOp>(CreateControlFlowSinkPass());
52   manager.addPass(CreateRegionToFunctionalPass(/*force_control_capture=*/true));
53   manager.addPass(CreateLiftLegacyCallPass());
54   manager.addPass(createSymbolPrivatizePass());
55   manager.addPass(createSymbolDCEPass());
56   manager.addPass(CreatePrepareAttributesForExportPass());
57 }
58 
RemapperPassBuilder(PassManager & manager)59 void RemapperPassBuilder(PassManager& manager) {
60   manager.addPass(CreateConsolidateAttributesPass());
61   manager.addPass(CreateTopoSortPass());
62   manager.addPass(CreateShapeInferencePass());
63   manager.addPass(
64       CreateRemapperPass(/*enable_onednn_patterns=*/tensorflow::IsMKLEnabled(),
65                          /*xla_auto_clustering=*/false));
66   manager.addPass(CreatePrepareAttributesForExportPass());
67 }
68 
69 }  // namespace tfg
70 }  // namespace mlir
71