1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/utils/dump_graph.h"
17
18 #include <cstdint>
19 #include <cstring>
20 #include <string>
21
22 #include "llvm/ADT/StringMap.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/ADT/Twine.h"
25 #include "llvm/Support/FormatVariadic.h"
26 #include "llvm/Support/raw_ostream.h"
27 #include "mlir/IR/Operation.h" // from @llvm-project
28 #include "mlir/IR/Verifier.h" // from @llvm-project
29 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
30 #include "tensorflow/core/ir/importexport/graphdef_import.h"
31 #include "tensorflow/core/platform/env.h"
32 #include "tensorflow/core/platform/logging.h"
33 #include "tensorflow/core/platform/path.h"
34 #include "tensorflow/core/util/dump_graph.h"
35
36 namespace tensorflow {
37
38 namespace {
39
40 // Simple raw_ostream that prints to a file (doesn't take ownership).
41 struct WritableFileRawStream : public llvm::raw_ostream {
WritableFileRawStreamtensorflow::__anon102086150111::WritableFileRawStream42 explicit WritableFileRawStream(WritableFile* file) : file(file) {
43 SetUnbuffered();
44 }
45 ~WritableFileRawStream() override = default;
current_postensorflow::__anon102086150111::WritableFileRawStream46 uint64_t current_pos() const override { return 0; }
47
write_impltensorflow::__anon102086150111::WritableFileRawStream48 void write_impl(const char* ptr, size_t size) override {
49 // If an error is encountered, null out the file.
50 if (file) {
51 Status s = file->Append(StringPiece(ptr, size));
52 if (!s.ok()) {
53 LOG(WARNING) << "Write failed: " << s;
54 file = nullptr;
55 }
56 }
57 }
58
59 // The file being written to.
60 WritableFile* file;
61 };
62 } // namespace
63
DumpTextualIRToFile(const MlirDumpConfig & config,const Graph & graph,const FunctionLibraryDefinition * flib_def,WritableFile * file)64 Status DumpTextualIRToFile(const MlirDumpConfig& config, const Graph& graph,
65 const FunctionLibraryDefinition* flib_def,
66 WritableFile* file) {
67 WritableFileRawStream os(std::move(file));
68 mlir::MLIRContext context;
69 mlir::OwningOpRef<mlir::ModuleOp> module;
70 if (flib_def) {
71 flib_def = &graph.flib_def();
72 }
73 auto convert = [&]() -> Status {
74 mlir::StatusScopedDiagnosticHandler status_handler(&context);
75 // TODO(jpienaar): Both the graph debug info and import config should be
76 // specifiable.
77 GraphDebugInfo debug_info;
78 switch (config.dialect) {
79 case MlirDumpConfig::Dialect::kTFG: {
80 TF_ASSIGN_OR_RETURN(module,
81 mlir::tfg::ImportGraphAndFunctionsToMlir(
82 &context, debug_info, graph,
83 flib_def ? *flib_def : graph.flib_def()));
84 break;
85 }
86 }
87 if (failed(mlir::verify(*module))) {
88 return status_handler.ConsumeStatus();
89 }
90 return status_handler.ConsumeStatus();
91 };
92
93 TF_RETURN_IF_ERROR(convert());
94 module->print(os, config.op_printing_flags);
95 return OkStatus();
96 }
97
UseMlirForGraphDump(const MlirDumpConfig & config)98 void UseMlirForGraphDump(const MlirDumpConfig& config) {
99 SetGraphDumper(
100 [config](const Graph& graph, const FunctionLibraryDefinition* flib_def,
101 WritableFile* file) -> Status {
102 return DumpTextualIRToFile(config, graph, flib_def, file);
103 },
104 /*suffix=*/".mlir");
105 }
106
107 } // namespace tensorflow
108