// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "tensorflow/compiler/mlir/lite/experimental/tac/execution_metadata_exporter.h"

#include <string>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "flatbuffers/flatbuffers.h"  // from @flatbuffers
#include "llvm/Support/SourceMgr.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Dialect.h"  // from @llvm-project
#include "mlir/Parser/Parser.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/experimental/tac/runtime_metadata_generated.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

namespace tflite {

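// Hand-builds the RuntimeMetadata flatbuffer the exporter is expected to
// produce for the MLIR module in the test below: two hardware entries
// ("GPU", "CPU") and a single subgraph with four annotated ops.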
std::string CreateRuntimeMetadata() {
  flatbuffers::FlatBufferBuilder fb_builder;

  std::vector<flatbuffers::Offset<flatbuffers::String>> device_names = {
      fb_builder.CreateString("GPU"), fb_builder.CreateString("CPU")};

  const auto hardwares =
      CreateHardwareMetadata(fb_builder, fb_builder.CreateVector(device_names));
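  // One OpMetadata entry per op in the test module: op index, hardware index
  // (0 = GPU, 1 = CPU, matching device_names above), and per-device costs in
  // the same device order.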
  const auto ops = {
      CreateOpMetadata(fb_builder, 0, 0,
                       fb_builder.CreateVector(std::vector<float>({1.0, 5.0}))),
      CreateOpMetadata(fb_builder, 1, 0,
                       fb_builder.CreateVector(std::vector<float>({1.0, 5.0}))),
      CreateOpMetadata(fb_builder, 2, 0,
                       fb_builder.CreateVector(std::vector<float>({1.0, 5.0}))),
      CreateOpMetadata(
          fb_builder, 3, 1,
          fb_builder.CreateVector(std::vector<float>({-1.0, 2.0}))),
  };
  const auto subgraphs = {CreateSubgraphMetadata(
      fb_builder, fb_builder.CreateVector(ops.begin(), ops.size()))};

  const auto metadata = CreateRuntimeMetadata(
      fb_builder, hardwares,
      fb_builder.CreateVector(subgraphs.begin(), subgraphs.size()));
  fb_builder.Finish(metadata);

  return std::string(
      reinterpret_cast<const char*>(fb_builder.GetBufferPointer()),
      fb_builder.GetSize());
}

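// Field-by-field comparison of two RuntimeMetadata buffers: subgraph count,
// then each op's index, hardware assignment, and per-device costs.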
void Verify(const RuntimeMetadata* result, const RuntimeMetadata* expected) {
  EXPECT_EQ(result->subgraph_metadata()->size(),
            expected->subgraph_metadata()->size());
  for (int i = 0; i < result->subgraph_metadata()->size(); ++i) {
    auto result_subgraph_metadata =
        result->subgraph_metadata()->GetAs<SubgraphMetadata>(i);
    auto expected_subgraph_metadata =
        expected->subgraph_metadata()->GetAs<SubgraphMetadata>(i);
    // Nothing to compare for this subgraph; continue with the next one
    // instead of returning early and skipping the remaining subgraphs.
    if (expected_subgraph_metadata->op_metadata() == nullptr &&
        result_subgraph_metadata->op_metadata() == nullptr) {
      continue;
    }
    ASSERT_EQ(expected_subgraph_metadata->op_metadata()->size(),
              result_subgraph_metadata->op_metadata()->size());
    for (int j = 0; j < expected_subgraph_metadata->op_metadata()->size();
         ++j) {
      auto result_op_metadata =
          result_subgraph_metadata->op_metadata()->GetAs<OpMetadata>(j);
      auto expected_op_metadata =
          expected_subgraph_metadata->op_metadata()->GetAs<OpMetadata>(j);
      EXPECT_EQ(result_op_metadata->index(), expected_op_metadata->index());
      EXPECT_EQ(result_op_metadata->hardware(),
                expected_op_metadata->hardware());

      EXPECT_EQ(result_op_metadata->op_costs()->size(),
                expected_op_metadata->op_costs()->size());
      for (int k = 0; k < result_op_metadata->op_costs()->size(); ++k) {
        EXPECT_FLOAT_EQ(result_op_metadata->op_costs()->Get(k),
                        expected_op_metadata->op_costs()->Get(k));
      }
    }
  }
}

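// Parses a small TFL module whose ops carry tac.device and per_device_costs
// attributes, exports its runtime metadata, and checks the result against the
// hand-built buffer from CreateRuntimeMetadata().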
TEST(ExporterTest, Valid) {
  const std::string kMLIR = R"(
func.func @main(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>, %arg3: tensor<1xf32>) -> tensor<2x1xf32> {
  %0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "RELU6",  per_device_costs = {CPU = 5.0 : f32, GPU = 1.0 : f32}, tac.device = "GPU"} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
  %1 = "tfl.mul"(%0, %arg2) {fused_activation_function = "RELU6", per_device_costs = {CPU = 5.0 : f32, GPU = 1.0 : f32}, tac.device = "GPU"} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
  %2 = "tfl.add"(%arg0, %arg3) {fused_activation_function = "RELU6", per_device_costs = {CPU = 5.0 : f32, GPU = 1.0 : f32}, tac.device = "GPU"} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
  %3 = "tfl.pack"(%1, %2) {axis = 0 : i32, per_device_costs = {CPU = 2.0 : f32, GPU = -1.0 : f32}, values_count = 2 : i32, tac.device = "CPU"} : (tensor<1xf32>, tensor<1xf32>) -> tensor<2x1xf32>
  func.return %3 : tensor<2x1xf32>
})";
  const std::string kExpectedFB = CreateRuntimeMetadata();
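  // Register the dialects used in kMLIR so the parser can resolve the ops.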
  mlir::DialectRegistry registry;
  registry.insert<mlir::TFL::TensorFlowLiteDialect,
                  mlir::arith::ArithmeticDialect, mlir::func::FuncDialect>();
  mlir::MLIRContext context(registry);
  auto module = mlir::OwningOpRef<mlir::ModuleOp>(
      mlir::parseSourceString<mlir::ModuleOp>(kMLIR, &context));
  auto module_op = module.get();
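  // Export metadata from the parsed module and deserialize both buffers for
  // the comparison below.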
  auto serialized_result_fb = ExportRuntimeMetadata(module_op);
  // Fail cleanly if the exporter produced no metadata instead of
  // dereferencing an empty Optional below.
  ASSERT_TRUE(serialized_result_fb.hasValue());
  const auto* result =
      GetRuntimeMetadata(serialized_result_fb.getValue().c_str());
  const auto* expected = GetRuntimeMetadata(kExpectedFB.c_str());
  ASSERT_TRUE(result != nullptr);
  ASSERT_TRUE(result->subgraph_metadata() != nullptr);
  ASSERT_TRUE(expected->subgraph_metadata() != nullptr);
  Verify(result, expected);
}

}  // namespace tflite