xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/lite/metrics/error_collector_inst_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h"
16 
17 #include <cstddef>
18 #include <set>
19 #include <string>
20 #include <utility>
21 #include <vector>
22 
23 #include <gmock/gmock.h>
24 #include <gtest/gtest.h>
25 #include "llvm/Support/MemoryBuffer.h"
26 #include "llvm/Support/SourceMgr.h"
27 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
28 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
29 #include "mlir/IR/Operation.h"  // from @llvm-project
30 #include "mlir/Parser/Parser.h"  // from @llvm-project
31 #include "mlir/Pass/Pass.h"  // from @llvm-project
32 #include "mlir/Pass/PassManager.h"  // from @llvm-project
33 #include "mlir/Support/FileUtilities.h"  // from @llvm-project
34 #include "tensorflow/compiler/mlir/lite/metrics/types_util.h"
35 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
36 #include "tensorflow/core/platform/resource_loader.h"
37 #include "tensorflow/core/platform/test.h"
38 #include "tensorflow/stream_executor/lib/statusor.h"
39 
40 namespace mlir {
41 namespace TFL {
42 namespace {
43 using stream_executor::port::StatusOr;
44 
45 // MockSuccessPass reports errors but doesn't fail.
46 class MockSuccessPass
47     : public PassWrapper<MockSuccessPass, OperationPass<ModuleOp>> {
getDependentDialects(DialectRegistry & registry) const48   void getDependentDialects(DialectRegistry& registry) const override {
49     registry.insert<TF::TensorFlowDialect>();
50   }
51 
52  public:
53   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MockSuccessPass)
54 
MockSuccessPass()55   explicit MockSuccessPass() {}
56 
57  private:
runOnOperation()58   void runOnOperation() override {
59     getOperation().walk([](Operation* nestedOp) {
60       nestedOp->emitError()
61           << "Error at " << nestedOp->getName().getStringRef().str() << " op";
62     });
63   };
64 };
65 
66 // MockFailurePass reports errors and fails.
67 class MockFailurePass
68     : public PassWrapper<MockFailurePass, OperationPass<ModuleOp>> {
getDependentDialects(DialectRegistry & registry) const69   void getDependentDialects(DialectRegistry& registry) const override {
70     registry.insert<TF::TensorFlowDialect>();
71   }
72 
73  public:
74   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MockFailurePass)
75 
MockFailurePass()76   explicit MockFailurePass() {}
77 
78  private:
runOnOperation()79   void runOnOperation() override {
80     getOperation().walk([](Operation* nestedOp) {
81       if (nestedOp->getName().getStringRef().str().rfind("tf.") != -1) {
82         AttachErrorCode(
83             nestedOp->emitError()
84                 << "Failed at " << nestedOp->getName().getStringRef().str()
85                 << " op",
86             tflite::metrics::ConverterErrorData::ERROR_NEEDS_FLEX_OPS);
87       }
88     });
89     signalPassFailure();
90   };
91 };
92 
LoadModule(MLIRContext * context,const std::string & file_name)93 StatusOr<OwningOpRef<mlir::ModuleOp>> LoadModule(MLIRContext* context,
94                                                  const std::string& file_name) {
95   std::string error_message;
96   auto file = openInputFile(file_name, &error_message);
97   if (!file) {
98     return tensorflow::errors::InvalidArgument("fail to open input file");
99   }
100 
101   llvm::SourceMgr source_mgr;
102   source_mgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc());
103   return OwningOpRef<mlir::ModuleOp>(
104       parseSourceFile<mlir::ModuleOp>(source_mgr, context));
105 }
106 
// Runs only MockSuccessPass, which emits diagnostics without signalling
// failure, and expects the pipeline to succeed with no collected errors.
TEST(ErrorCollectorTest, TessSuccessPass) {
  std::string input_file = tensorflow::GetDataDependencyFilepath(
      "tensorflow/compiler/mlir/lite/metrics/testdata/strided_slice.mlir");
  MLIRContext context;
  context.getOrLoadDialect<mlir::func::FuncDialect>();
  context.allowUnregisteredDialects();
  context.enableMultithreading();

  auto module = LoadModule(&context, input_file);
  // Fatal on failure: ValueOrDie() below would abort the whole test binary if
  // the module could not be loaded.
  ASSERT_TRUE(module.ok());

  PassManager pm(&context, OpPassManager::Nesting::Implicit);
  pm.addPass(std::make_unique<MockSuccessPass>());

  pm.addInstrumentation(
      std::make_unique<ErrorCollectorInstrumentation>(&context));
  EXPECT_TRUE(succeeded(pm.run(module.ValueOrDie().get())));

  auto collected_errors =
      ErrorCollector::GetErrorCollector()->CollectedErrors();
  EXPECT_EQ(collected_errors.size(), 0u);
}
129 
TEST(ErrorCollectorTest,TessFailurePass)130 TEST(ErrorCollectorTest, TessFailurePass) {
131   using tflite::metrics::ConverterErrorData;
132   MLIRContext context;
133   context.getOrLoadDialect<mlir::func::FuncDialect>();
134   const std::string input_file =
135       "tensorflow/compiler/mlir/lite/metrics/testdata/strided_slice.mlir";
136   auto input_file_id = StringAttr::get(&context, input_file);
137 
138   context.allowUnregisteredDialects();
139   context.enableMultithreading();
140 
141   auto module =
142       LoadModule(&context, tensorflow::GetDataDependencyFilepath(input_file));
143   EXPECT_EQ(module.ok(), true);
144 
145   PassManager pm(&context, OpPassManager::Nesting::Implicit);
146   pm.addPass(std::make_unique<MockSuccessPass>());
147   pm.addPass(std::make_unique<MockFailurePass>());
148 
149   pm.addInstrumentation(
150       std::make_unique<ErrorCollectorInstrumentation>(&context));
151   EXPECT_EQ(succeeded(pm.run(module.ValueOrDie().get())), false);
152 
153   auto collected_errors =
154       ErrorCollector::GetErrorCollector()->CollectedErrors();
155 
156   EXPECT_EQ(collected_errors.size(), 3);
157   EXPECT_EQ(collected_errors.count(NewConverterErrorData(
158                 "MockFailurePass",
159                 "Failed at tf.Const op\nsee current operation: %0 = "
160                 "\"tf.Const\"() {value = dense<1> : tensor<4xi32>} : () -> "
161                 "tensor<4xi32>\nError code: ERROR_NEEDS_FLEX_OPS",
162                 ConverterErrorData::ERROR_NEEDS_FLEX_OPS, "tf.Const",
163                 mlir::FileLineColLoc::get(input_file_id, 2, 9))),
164             1);
165   EXPECT_EQ(collected_errors.count(NewConverterErrorData(
166                 "MockFailurePass",
167                 "Failed at tf.Const op\nsee current operation: %1 = "
168                 "\"tf.Const\"() {value = dense<0> : tensor<4xi32>} : () -> "
169                 "tensor<4xi32>\nError code: ERROR_NEEDS_FLEX_OPS",
170                 ConverterErrorData::ERROR_NEEDS_FLEX_OPS, "tf.Const",
171                 mlir::FileLineColLoc::get(input_file_id, 2, 9))),
172             1);
173   EXPECT_EQ(collected_errors.count(NewConverterErrorData(
174                 "MockFailurePass",
175                 "Failed at tf.StridedSlice op\nsee current operation: %2 = "
176                 "\"tf.StridedSlice\"(%arg0, %1, %1, %0) {begin_mask = 11 : "
177                 "i64, device = \"\", ellipsis_mask = 0 : i64, end_mask = 11 : "
178                 "i64, new_axis_mask = 4 : i64, shrink_axis_mask = 0 : i64} : "
179                 "(tensor<*xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) "
180                 "-> tensor<*xf32>\nError code: ERROR_NEEDS_FLEX_OPS",
181                 ConverterErrorData::ERROR_NEEDS_FLEX_OPS, "tf.StridedSlice",
182                 mlir::FileLineColLoc::get(input_file_id, 4, 10))),
183             1);
184 
185   // Check the location information.
186   std::vector<std::string> locations;
187   for (const auto& error : collected_errors) {
188     EXPECT_TRUE(error.has_location());
189     locations.push_back(error.location().DebugString());
190   }
191 
192   EXPECT_THAT(locations, Each(testing::HasSubstr("CALLSITELOC")));
193   EXPECT_THAT(locations, Each(testing::HasSubstr(input_file)));
194   EXPECT_THAT(locations, Contains(testing::HasSubstr("line: 2")));
195   EXPECT_THAT(locations, Contains(testing::HasSubstr("column: 9")));
196   EXPECT_THAT(locations, Contains(testing::HasSubstr("line: 4")));
197   EXPECT_THAT(locations, Contains(testing::HasSubstr("column: 10")));
198 }
199 }  // namespace
200 }  // namespace TFL
201 }  // namespace mlir
202