1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/grappler/optimizers/tfg_passes_builder.h" 17 18 #include "mlir/Transforms/Passes.h" // from @llvm-project 19 #include "tensorflow/core/ir/ops.h" 20 #include "tensorflow/core/protobuf/rewriter_config.pb.h" 21 #include "tensorflow/core/transforms/pass_registration.h" 22 #include "tensorflow/core/util/util.h" 23 24 namespace mlir { 25 namespace tfg { 26 27 // The default pipeline only does shape inference now. DefaultGrapplerPipeline(PassManager & manager)28void DefaultGrapplerPipeline(PassManager& manager) { 29 // Turn certain shape attrs into types to give better knowledge for shape 30 // inference. 31 manager.addPass(CreateConsolidateAttributesPass()); 32 // Toposort the graph will bring better performance in some optimizations like 33 // shape inference. 34 manager.addPass(CreateTopoSortPass()); 35 // Infer the shape of operation if possible. The TFG importer doesn't do shape 36 // inference for almost all operations. 37 manager.addPass(CreateShapeInferencePass()); 38 // Contruct the shape attrs back from types. 39 manager.addPass(CreatePrepareAttributesForExportPass()); 40 } 41 42 // Run the consolidate attributes pass. Convert the whole module to region 43 // control-flow and run control-flow sinking. Convert the whole module back to 44 // functional control-flow and prepare the attributes for export. DefaultModuleGrapplerPipeline(PassManager & manager,const tensorflow::RewriterConfig & config)45void DefaultModuleGrapplerPipeline(PassManager& manager, 46 const tensorflow::RewriterConfig& config) { 47 manager.addPass(CreateConsolidateAttributesPass()); 48 manager.addPass(CreateFunctionalToRegionPass()); 49 if (config.experimental_conditional_code_motion() != 50 tensorflow::RewriterConfig::OFF) 51 manager.addNestedPass<GraphFuncOp>(CreateControlFlowSinkPass()); 52 manager.addPass(CreateRegionToFunctionalPass(/*force_control_capture=*/true)); 53 manager.addPass(CreateLiftLegacyCallPass()); 54 manager.addPass(createSymbolPrivatizePass()); 55 manager.addPass(createSymbolDCEPass()); 56 manager.addPass(CreatePrepareAttributesForExportPass()); 57 } 58 RemapperPassBuilder(PassManager & manager)59void RemapperPassBuilder(PassManager& manager) { 60 manager.addPass(CreateConsolidateAttributesPass()); 61 manager.addPass(CreateTopoSortPass()); 62 manager.addPass(CreateShapeInferencePass()); 63 manager.addPass( 64 CreateRemapperPass(/*enable_onednn_patterns=*/tensorflow::IsMKLEnabled(), 65 /*xla_auto_clustering=*/false)); 66 manager.addPass(CreatePrepareAttributesForExportPass()); 67 } 68 69 } // namespace tfg 70 } // namespace mlir 71