1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h"
16
17 #include <cstddef>
18 #include <set>
19 #include <string>
20 #include <utility>
21 #include <vector>
22
23 #include <gmock/gmock.h>
24 #include <gtest/gtest.h>
25 #include "llvm/Support/MemoryBuffer.h"
26 #include "llvm/Support/SourceMgr.h"
27 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
28 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
29 #include "mlir/IR/Operation.h" // from @llvm-project
30 #include "mlir/Parser/Parser.h" // from @llvm-project
31 #include "mlir/Pass/Pass.h" // from @llvm-project
32 #include "mlir/Pass/PassManager.h" // from @llvm-project
33 #include "mlir/Support/FileUtilities.h" // from @llvm-project
34 #include "tensorflow/compiler/mlir/lite/metrics/types_util.h"
35 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
36 #include "tensorflow/core/platform/resource_loader.h"
37 #include "tensorflow/core/platform/test.h"
38 #include "tensorflow/stream_executor/lib/statusor.h"
39
40 namespace mlir {
41 namespace TFL {
42 namespace {
43 using stream_executor::port::StatusOr;
44
45 // MockSuccessPass reports errors but doesn't fail.
46 class MockSuccessPass
47 : public PassWrapper<MockSuccessPass, OperationPass<ModuleOp>> {
getDependentDialects(DialectRegistry & registry) const48 void getDependentDialects(DialectRegistry& registry) const override {
49 registry.insert<TF::TensorFlowDialect>();
50 }
51
52 public:
53 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MockSuccessPass)
54
MockSuccessPass()55 explicit MockSuccessPass() {}
56
57 private:
runOnOperation()58 void runOnOperation() override {
59 getOperation().walk([](Operation* nestedOp) {
60 nestedOp->emitError()
61 << "Error at " << nestedOp->getName().getStringRef().str() << " op";
62 });
63 };
64 };
65
66 // MockFailurePass reports errors and fails.
67 class MockFailurePass
68 : public PassWrapper<MockFailurePass, OperationPass<ModuleOp>> {
getDependentDialects(DialectRegistry & registry) const69 void getDependentDialects(DialectRegistry& registry) const override {
70 registry.insert<TF::TensorFlowDialect>();
71 }
72
73 public:
74 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MockFailurePass)
75
MockFailurePass()76 explicit MockFailurePass() {}
77
78 private:
runOnOperation()79 void runOnOperation() override {
80 getOperation().walk([](Operation* nestedOp) {
81 if (nestedOp->getName().getStringRef().str().rfind("tf.") != -1) {
82 AttachErrorCode(
83 nestedOp->emitError()
84 << "Failed at " << nestedOp->getName().getStringRef().str()
85 << " op",
86 tflite::metrics::ConverterErrorData::ERROR_NEEDS_FLEX_OPS);
87 }
88 });
89 signalPassFailure();
90 };
91 };
92
LoadModule(MLIRContext * context,const std::string & file_name)93 StatusOr<OwningOpRef<mlir::ModuleOp>> LoadModule(MLIRContext* context,
94 const std::string& file_name) {
95 std::string error_message;
96 auto file = openInputFile(file_name, &error_message);
97 if (!file) {
98 return tensorflow::errors::InvalidArgument("fail to open input file");
99 }
100
101 llvm::SourceMgr source_mgr;
102 source_mgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc());
103 return OwningOpRef<mlir::ModuleOp>(
104 parseSourceFile<mlir::ModuleOp>(source_mgr, context));
105 }
106
// Runs a pass that emits diagnostics but never signals failure, and checks
// that the ErrorCollector ends up with no collected errors.
TEST(ErrorCollectorTest, TessSuccessPass) {
  std::string input_file = tensorflow::GetDataDependencyFilepath(
      "tensorflow/compiler/mlir/lite/metrics/testdata/strided_slice.mlir");
  MLIRContext context;
  context.getOrLoadDialect<mlir::func::FuncDialect>();
  context.allowUnregisteredDialects();
  context.enableMultithreading();

  auto module = LoadModule(&context, input_file);
  // ASSERT rather than EXPECT: ValueOrDie() below must not be reached with an
  // error status, otherwise the whole test binary aborts.
  ASSERT_TRUE(module.ok()) << module.status().ToString();

  PassManager pm(&context, OpPassManager::Nesting::Implicit);
  pm.addPass(std::make_unique<MockSuccessPass>());

  pm.addInstrumentation(
      std::make_unique<ErrorCollectorInstrumentation>(&context));
  EXPECT_TRUE(succeeded(pm.run(module.ValueOrDie().get())));

  const auto& collected_errors =
      ErrorCollector::GetErrorCollector()->CollectedErrors();
  EXPECT_TRUE(collected_errors.empty());
}
129
// Runs a succeeding pass followed by a failing one and checks that the
// ErrorCollector holds exactly three errors, matching the expected
// MockFailurePass message, error code, op name and location for each.
TEST(ErrorCollectorTest, TessFailurePass) {
  using tflite::metrics::ConverterErrorData;
  MLIRContext context;
  context.getOrLoadDialect<mlir::func::FuncDialect>();
  const std::string input_file =
      "tensorflow/compiler/mlir/lite/metrics/testdata/strided_slice.mlir";
  auto input_file_id = StringAttr::get(&context, input_file);

  context.allowUnregisteredDialects();
  context.enableMultithreading();

  auto module =
      LoadModule(&context, tensorflow::GetDataDependencyFilepath(input_file));
  // ASSERT rather than EXPECT: ValueOrDie() below must not be reached with an
  // error status, otherwise the whole test binary aborts.
  ASSERT_TRUE(module.ok()) << module.status().ToString();

  PassManager pm(&context, OpPassManager::Nesting::Implicit);
  pm.addPass(std::make_unique<MockSuccessPass>());
  pm.addPass(std::make_unique<MockFailurePass>());

  pm.addInstrumentation(
      std::make_unique<ErrorCollectorInstrumentation>(&context));
  EXPECT_FALSE(succeeded(pm.run(module.ValueOrDie().get())));

  const auto& collected_errors =
      ErrorCollector::GetErrorCollector()->CollectedErrors();

  // Unsigned literals keep these size_t comparisons free of sign-compare
  // warnings inside the gtest comparison templates.
  EXPECT_EQ(collected_errors.size(), 3u);
  EXPECT_EQ(collected_errors.count(NewConverterErrorData(
                "MockFailurePass",
                "Failed at tf.Const op\nsee current operation: %0 = "
                "\"tf.Const\"() {value = dense<1> : tensor<4xi32>} : () -> "
                "tensor<4xi32>\nError code: ERROR_NEEDS_FLEX_OPS",
                ConverterErrorData::ERROR_NEEDS_FLEX_OPS, "tf.Const",
                mlir::FileLineColLoc::get(input_file_id, 2, 9))),
            1u);
  EXPECT_EQ(collected_errors.count(NewConverterErrorData(
                "MockFailurePass",
                "Failed at tf.Const op\nsee current operation: %1 = "
                "\"tf.Const\"() {value = dense<0> : tensor<4xi32>} : () -> "
                "tensor<4xi32>\nError code: ERROR_NEEDS_FLEX_OPS",
                ConverterErrorData::ERROR_NEEDS_FLEX_OPS, "tf.Const",
                mlir::FileLineColLoc::get(input_file_id, 2, 9))),
            1u);
  EXPECT_EQ(collected_errors.count(NewConverterErrorData(
                "MockFailurePass",
                "Failed at tf.StridedSlice op\nsee current operation: %2 = "
                "\"tf.StridedSlice\"(%arg0, %1, %1, %0) {begin_mask = 11 : "
                "i64, device = \"\", ellipsis_mask = 0 : i64, end_mask = 11 : "
                "i64, new_axis_mask = 4 : i64, shrink_axis_mask = 0 : i64} : "
                "(tensor<*xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) "
                "-> tensor<*xf32>\nError code: ERROR_NEEDS_FLEX_OPS",
                ConverterErrorData::ERROR_NEEDS_FLEX_OPS, "tf.StridedSlice",
                mlir::FileLineColLoc::get(input_file_id, 4, 10))),
            1u);

  // Check the location information.
  std::vector<std::string> locations;
  locations.reserve(collected_errors.size());
  for (const auto& error : collected_errors) {
    EXPECT_TRUE(error.has_location());
    locations.push_back(error.location().DebugString());
  }

  EXPECT_THAT(locations, testing::Each(testing::HasSubstr("CALLSITELOC")));
  EXPECT_THAT(locations, testing::Each(testing::HasSubstr(input_file)));
  EXPECT_THAT(locations, testing::Contains(testing::HasSubstr("line: 2")));
  EXPECT_THAT(locations, testing::Contains(testing::HasSubstr("column: 9")));
  EXPECT_THAT(locations, testing::Contains(testing::HasSubstr("line: 4")));
  EXPECT_THAT(locations, testing::Contains(testing::HasSubstr("column: 10")));
}
199 } // namespace
200 } // namespace TFL
201 } // namespace mlir
202