1 // Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "tensorflow/compiler/mlir/lite/experimental/tac/execution_metadata_exporter.h"
16
#include <cstdint>
#include <string>
#include <vector>
18
19 #include <gmock/gmock.h>
20 #include <gtest/gtest.h>
21 #include "flatbuffers/flatbuffers.h" // from @flatbuffers
22 #include "llvm/Support/SourceMgr.h"
23 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project
24 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
25 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
26 #include "mlir/IR/Dialect.h" // from @llvm-project
27 #include "mlir/Parser/Parser.h" // from @llvm-project
28 #include "tensorflow/compiler/mlir/lite/experimental/tac/runtime_metadata_generated.h"
29 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
30
31 namespace tflite {
32
CreateRuntimeMetadata()33 std::string CreateRuntimeMetadata() {
34 flatbuffers::FlatBufferBuilder fb_builder;
35
36 std::vector<flatbuffers::Offset<flatbuffers::String>> device_names = {
37 fb_builder.CreateString("GPU"), fb_builder.CreateString("CPU")};
38
39 const auto hardwares =
40 CreateHardwareMetadata(fb_builder, fb_builder.CreateVector(device_names));
41 const auto ops = {
42 CreateOpMetadata(fb_builder, 0, 0,
43 fb_builder.CreateVector(std::vector<float>({1.0, 5.0}))),
44 CreateOpMetadata(fb_builder, 1, 0,
45 fb_builder.CreateVector(std::vector<float>({1.0, 5.0}))),
46 CreateOpMetadata(fb_builder, 2, 0,
47 fb_builder.CreateVector(std::vector<float>({1.0, 5.0}))),
48 CreateOpMetadata(
49 fb_builder, 3, 1,
50 fb_builder.CreateVector(std::vector<float>({-1.0, 2.0}))),
51 };
52 const auto subgraphs = {CreateSubgraphMetadata(
53 fb_builder, fb_builder.CreateVector(ops.begin(), ops.size()))};
54
55 const auto metadata = CreateRuntimeMetadata(
56 fb_builder, hardwares,
57 fb_builder.CreateVector(subgraphs.begin(), subgraphs.size()));
58 fb_builder.Finish(metadata);
59
60 return std::string(
61 reinterpret_cast<const char*>(fb_builder.GetBufferPointer()),
62 fb_builder.GetSize());
63 }
64
Verify(const RuntimeMetadata * result,const RuntimeMetadata * expected)65 void Verify(const RuntimeMetadata* result, const RuntimeMetadata* expected) {
66 EXPECT_EQ(result->subgraph_metadata()->size(),
67 expected->subgraph_metadata()->size());
68 for (int i = 0; i < result->subgraph_metadata()->size(); ++i) {
69 auto result_subgraph_metadata =
70 result->subgraph_metadata()->GetAs<SubgraphMetadata>(i);
71 auto expected_subgraph_metadata =
72 expected->subgraph_metadata()->GetAs<SubgraphMetadata>(i);
73 if (expected_subgraph_metadata->op_metadata() == nullptr &&
74 result_subgraph_metadata->op_metadata() == nullptr) {
75 return;
76 }
77 ASSERT_EQ(expected_subgraph_metadata->op_metadata()->size(),
78 result_subgraph_metadata->op_metadata()->size());
79 for (int j = 0; j < expected_subgraph_metadata->op_metadata()->size();
80 ++j) {
81 auto result_op_metadata =
82 result_subgraph_metadata->op_metadata()->GetAs<OpMetadata>(j);
83 auto expected_op_metadata =
84 expected_subgraph_metadata->op_metadata()->GetAs<OpMetadata>(j);
85 EXPECT_EQ(result_op_metadata->index(), expected_op_metadata->index());
86 EXPECT_EQ(result_op_metadata->hardware(),
87 expected_op_metadata->hardware());
88
89 EXPECT_EQ(result_op_metadata->op_costs()->size(),
90 expected_op_metadata->op_costs()->size());
91 for (int i = 0; i < result_op_metadata->op_costs()->size(); ++i) {
92 EXPECT_FLOAT_EQ(result_op_metadata->op_costs()->Get(i),
93 expected_op_metadata->op_costs()->Get(i));
94 }
95 }
96 }
97 }
98
TEST(ExporterTest,Valid)99 TEST(ExporterTest, Valid) {
100 const std::string kMLIR = R"(
101 func.func @main(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>, %arg3: tensor<1xf32>) -> tensor<2x1xf32> {
102 %0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6", per_device_costs = {CPU = 5.0 : f32, GPU = 1.0 : f32}, tac.device = "GPU"} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
103 %1 = "tfl.mul"(%0, %arg2) {fused_activation_function = "RELU6", per_device_costs = {CPU = 5.0 : f32, GPU = 1.0 : f32}, tac.device = "GPU"} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
104 %2 = "tfl.add"(%arg0, %arg3) {fused_activation_function = "RELU6", per_device_costs = {CPU = 5.0 : f32, GPU = 1.0 : f32}, tac.device = "GPU"} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
105 %3 = "tfl.pack"(%1, %2) {axis = 0 : i32, per_device_costs = {CPU = 2.0 : f32, GPU = -1.0 : f32}, values_count = 2 : i32, tac.device = "CPU"} : (tensor<1xf32>, tensor<1xf32>) -> tensor<2x1xf32>
106 func.return %3 : tensor<2x1xf32>
107 })";
108 const std::string kExpectedFB = CreateRuntimeMetadata();
109 mlir::DialectRegistry registry;
110 registry.insert<mlir::TFL::TensorFlowLiteDialect,
111 mlir::arith::ArithmeticDialect, mlir::func::FuncDialect>();
112 mlir::MLIRContext context(registry);
113 auto module = mlir::OwningOpRef<mlir::ModuleOp>(
114 mlir::parseSourceString<mlir::ModuleOp>(kMLIR, &context));
115 auto module_op = module.get();
116 auto serialized_result_fb = ExportRuntimeMetadata(module_op);
117 const auto* result =
118 GetRuntimeMetadata(serialized_result_fb.getValue().c_str());
119 const auto* expected = GetRuntimeMetadata(kExpectedFB.c_str());
120 ASSERT_TRUE(result != nullptr);
121 ASSERT_TRUE(result->subgraph_metadata() != nullptr);
122 ASSERT_TRUE(expected->subgraph_metadata() != nullptr);
123 Verify(result, expected);
124 }
125
126 } // namespace tflite
127