1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.h"
17 
18 #include "llvm/Support/CommandLine.h"
19 #include "mlir/IR/Builders.h"  // from @llvm-project
20 #include "mlir/IR/Location.h"  // from @llvm-project
21 #include "mlir/Pass/Pass.h"  // from @llvm-project
22 #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
23 #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
24 #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
25 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
26 #include "tensorflow/core/common_runtime/graph_constructor.h"
27 #include "tensorflow/core/common_runtime/optimization_registry.h"
28 #include "tensorflow/core/framework/function.h"
29 #include "tensorflow/core/graph/graph.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/protobuf/graph_debug_info.pb.h"
32 #include "tensorflow/core/public/session_options.h"
33 #include "tensorflow/stream_executor/lib/statusor.h"
34 
35 #define DEBUG_TYPE "run-tf-graph-optimization"
36 
37 namespace tensorflow {
38 namespace {
39 // Creates a pass to convert MLIR to Graph, run user-specified Graph
40 // Optimization Passes and convert back to MLIR.
41 // Constraints: This pass expects that all operations in the MLIR module either
42 // belong to 'tf' or '_tf' dialect. The output is in '_tf' dialect.
43 class GraphOptPass
44     : public mlir::PassWrapper<GraphOptPass,
45                                mlir::OperationPass<mlir::ModuleOp>> {
getDependentDialects(mlir::DialectRegistry & registry) const46   void getDependentDialects(mlir::DialectRegistry& registry) const override {
47     mlir::RegisterAllTensorFlowDialects(registry);
48   }
49 
50  public:
51   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GraphOptPass)
52 
GraphOptPass(std::vector<tensorflow::GraphOptimizationPass * > passes)53   explicit GraphOptPass(std::vector<tensorflow::GraphOptimizationPass*> passes)
54       : passes_(std::move(passes)) {}
55 
56  protected:
57   void runOnOperation() override;
58 
59   // The passes to run on the module.
60   std::vector<GraphOptimizationPass*> passes_;
61 };
62 }  // anonymous namespace
63 
runOnOperation()64 void GraphOptPass::runOnOperation() {
65   mlir::ModuleOp module_in = getOperation();
66   mlir::MLIRContext& ctx = getContext();
67 
68   // Convert MLIR to Graph
69   FunctionLibraryDefinition flib_def(OpRegistry::Global(),
70                                      FunctionDefLibrary());
71   GraphExportConfig confs;
72   auto graph = std::make_unique<Graph>(flib_def);
73   Status status = ConvertMlirToGraph(module_in, confs, &graph, &flib_def);
74   if (!status.ok()) {
75     mlir::emitError(mlir::UnknownLoc::get(&ctx)) << status.error_message();
76     return signalPassFailure();
77   }
78 
79   // Run each of the passes that were selected.
80   GraphConstructorOptions opts;
81   opts.allow_internal_ops = true;
82   opts.expect_device_spec = false;
83 
84   GraphOptimizationPassOptions options;
85   SessionOptions sess_options;
86   options.graph = &graph;
87   options.flib_def = &flib_def;
88   options.session_options = &sess_options;
89 
90   for (auto pass : passes_) {
91     assert(pass != nullptr);
92     Status status = pass->Run(options);
93     if (!status.ok()) {
94       mlir::emitError(mlir::UnknownLoc::get(&ctx))
95           << pass->name() << ": " << status.error_message();
96       return signalPassFailure();
97     }
98   }
99 
100   // Convert Graph to MLIR
101   GraphDebugInfo debug_info;
102   GraphImportConfig specs;
103   auto module_or_status =
104       ConvertGraphToMlir(**options.graph, debug_info, flib_def, specs, &ctx);
105   if (!module_or_status.ok()) {
106     mlir::emitError(mlir::UnknownLoc::get(&ctx))
107         << module_or_status.status().error_message();
108     return signalPassFailure();
109   }
110   auto module_out = std::move(module_or_status).ValueOrDie();
111 
112   // We cannot replace the module in a ModulePass. So we simply copy the
113   // operation list from module_out to module_in.
114   auto& module_in_ops = module_in.getBody()->getOperations();
115   module_in_ops.clear();
116   module_in_ops.splice(module_in_ops.end(),
117                        module_out->getBody()->getOperations());
118 }
119 
120 // Returns a vector of passes from their names. If a pass is not found, then the
121 // corresponding return entry is null.
FindRegisteredPassesByName(const std::vector<std::string> & pass_names)122 static std::vector<GraphOptimizationPass*> FindRegisteredPassesByName(
123     const std::vector<std::string>& pass_names) {
124   std::vector<GraphOptimizationPass*> pass_ids(pass_names.size(), nullptr);
125 
126   for (const auto& group : OptimizationPassRegistry::Global()->groups()) {
127     for (const auto& phase : group.second) {
128       for (const auto& pass : phase.second) {
129         // Iterate over the pass_names_ and insert the pass pointer at all the
130         // corresponding indices in the pass_ids vector.
131         auto iter = pass_names.begin();
132         while ((iter = std::find(iter, pass_names.end(), pass->name())) !=
133                pass_names.end()) {
134           pass_ids[std::distance(pass_names.begin(), iter)] = pass.get();
135           iter++;
136         }
137       }
138     }
139   }
140   return pass_ids;
141 }
142 
// TODO(prakalps): Move these flags and pass registration to a header file so
// that it is clear that this is a generic pass library and command line is used
// for testing only.

// Category grouping this pass's flags in --help output.
// NOLINTNEXTLINE
static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");

// Names of the GraphOptimizationPasses to run, as given on the command line.
// NOLINTNEXTLINE
static llvm::cl::list<std::string> cl_pass_list(
    "graph-passes", llvm::cl::value_desc("list"),
    llvm::cl::desc("comma separated list of GraphOptimizationPass to run."),
    llvm::cl::CommaSeparated, llvm::cl::cat(clOptionsCategory));
155 
156 class GraphOptByNamePass : public GraphOptPass {
157  public:
GraphOptByNamePass()158   explicit GraphOptByNamePass() : GraphOptByNamePass(cl_pass_list) {}
GraphOptByNamePass(const std::vector<std::string> & pass_names)159   explicit GraphOptByNamePass(const std::vector<std::string>& pass_names)
160       : GraphOptPass(FindRegisteredPassesByName(pass_names)) {}
161 
getArgument() const162   llvm::StringRef getArgument() const final {
163     return "run-tf-graph-optimization";
164   }
165 
getDescription() const166   llvm::StringRef getDescription() const final {
167     return "runs passes registered as tensorflow::GraphOptimizationPass";
168   }
169 
170  private:
runOnOperation()171   void runOnOperation() override {
172     // Verify all passes requested were registered/found.
173     for (auto pass_it : llvm::enumerate(passes_)) {
174       if (pass_it.value() == nullptr) {
175         mlir::emitError(mlir::UnknownLoc::get(&getContext()))
176             << "could not find pass " << cl_pass_list[pass_it.index()];
177         return signalPassFailure();
178       }
179     }
180     return GraphOptPass::runOnOperation();
181   }
182 };
183 
184 }  // namespace tensorflow
185 
186 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateTensorFlowGraphOptimizationPass(std::vector<tensorflow::GraphOptimizationPass * > tf_passes)187 tensorflow::CreateTensorFlowGraphOptimizationPass(
188     std::vector<tensorflow::GraphOptimizationPass*> tf_passes) {
189   return std::make_unique<GraphOptPass>(std::move(tf_passes));
190 }
191 
192 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateTensorFlowGraphOptimizationPass(const std::vector<std::string> & pass_names)193 tensorflow::CreateTensorFlowGraphOptimizationPass(
194     const std::vector<std::string>& pass_names) {
195   return std::make_unique<GraphOptByNamePass>(pass_names);
196 }
197 
RegisterGraphOptimizationPasses()198 void tensorflow::RegisterGraphOptimizationPasses() {
199   ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
200     return std::make_unique<GraphOptByNamePass>();
201   });
202 }
203