1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.h"
17 
18 #include <memory>
19 #include <string>
20 #include <vector>
21 
22 #include <gtest/gtest.h>
23 #include "absl/status/status.h"
24 #include "absl/types/any.h"
25 #include "absl/types/optional.h"
26 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
27 #include "tensorflow/lite/delegates/gpu/common/model.h"
28 #include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
29 #include "tensorflow/lite/delegates/gpu/common/operations.h"
30 #include "tensorflow/lite/delegates/gpu/common/shape.h"
31 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
32 
33 namespace tflite {
34 namespace gpu {
35 namespace {
36 
AddQuantParams(absl::optional<QuantizationParams> * params,float min,float max,float scale)37 void AddQuantParams(absl::optional<QuantizationParams>* params, float min,
38                     float max, float scale) {
39   params->emplace();
40   params->value().min = min;
41   params->value().max = max;
42   params->value().scale = scale;
43 }
44 
45 // Scenario:
46 // -> Add ->
47 //
48 // Since there is only one node output with no consumers, no new node should be
49 // added.
TEST(AddQuantAdjustments,OneNode)50 TEST(AddQuantAdjustments, OneNode) {
51   GraphFloat32 graph;
52   auto input = graph.NewValue();
53   input->tensor.shape = BHWC(1, 4, 4, 8);
54   AddQuantParams(&input->quant_params, /*min=*/0.0, /*max=*/1.0,
55                  /*scale=*/0.004);
56 
57   Tensor<Linear, DataType::FLOAT32> add_tensor;
58   add_tensor.shape = Linear(8);
59   add_tensor.data.resize(8);
60   ElementwiseAttributes add_attr;
61   add_attr.param = add_tensor;
62   auto add_node = graph.NewNode();
63   add_node->operation.type = ToString(OperationType::ADD);
64   add_node->operation.attributes = add_attr;
65 
66   ASSERT_TRUE(graph.AddConsumer(add_node->id, input->id).ok());
67 
68   Value* output = nullptr;
69   AddQuantParams(&input->quant_params, /*min=*/0.0, /*max=*/2.0,
70                  /*scale=*/0.008);
71   ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok());
72   output->tensor.shape = BHWC(1, 4, 4, 8);
73 
74   ASSERT_EQ(1, graph.nodes().size());
75   ASSERT_EQ(2, graph.values().size());
76 
77   auto transformation = NewAddQuantAdjustments();
78   ModelTransformer transformer(&graph);
79   transformer.Apply("add_quant_adjustments", transformation.get());
80 
81   EXPECT_EQ(1, graph.nodes().size());
82   EXPECT_EQ(2, graph.values().size());
83 }
84 
85 // Scenario:
86 // -> Add -> QuantizeAndDequantize -> Add ->
87 //        |                            ^
88 //        |                            |
89 //        ------------------------------
90 //
91 // A new QuantizeAndDequantize should only be added after the left/first 'Add'
92 // op, and it should connect to both its consumers.
TEST(AddQuantAdjustments,GeneralCase)93 TEST(AddQuantAdjustments, GeneralCase) {
94   GraphFloat32 graph;
95   auto input = graph.NewValue();
96   input->tensor.shape = BHWC(1, 4, 4, 8);
97   AddQuantParams(&input->quant_params, /*min=*/0.0, /*max=*/1.0,
98                  /*scale=*/0.004);
99 
100   // First Add.
101   Tensor<Linear, DataType::FLOAT32> add_tensor;
102   add_tensor.shape = Linear(8);
103   add_tensor.data.resize(8);
104   ElementwiseAttributes add_attr;
105   add_attr.param = add_tensor;
106   auto add1_node = graph.NewNode();
107   add1_node->operation.type = ToString(OperationType::ADD);
108   add1_node->operation.attributes = add_attr;
109   // QuantizeAndDequantize.
110   QuantizeAndDequantizeAttributes quant_attr;
111   quant_attr.min = -1.0;
112   quant_attr.max = 1.0;
113   quant_attr.scale = 0.008;
114   auto quant_node = graph.NewNode();
115   quant_node->operation.type = ToString(OperationType::QUANTIZE_AND_DEQUANTIZE);
116   quant_node->operation.attributes = quant_attr;
117   // Second Add.
118   auto add2_node = graph.NewNode();
119   add2_node->operation.type = ToString(OperationType::ADD);
120 
121   // Connections.
122   ASSERT_TRUE(graph.AddConsumer(add1_node->id, input->id).ok());
123   Value* link1 = nullptr;
124   ASSERT_TRUE(ConnectTwoNodes(&graph, add1_node, quant_node, &link1).ok());
125   AddQuantParams(&link1->quant_params, /*min=*/0.0, /*max=*/2.0,
126                  /*scale=*/0.008);
127   link1->tensor.shape = BHWC(1, 4, 4, 8);
128   ASSERT_TRUE(graph.AddConsumer(add2_node->id, link1->id).ok());
129   Value* link2 = nullptr;
130   ASSERT_TRUE(ConnectTwoNodes(&graph, quant_node, add2_node, &link2).ok());
131   AddQuantParams(&link2->quant_params, /*min=*/-1.0, /*max=*/1.0,
132                  /*scale=*/0.008);
133   link2->tensor.shape = BHWC(1, 4, 4, 8);
134   Value* output = nullptr;
135   ASSERT_TRUE(AddOutput(&graph, add2_node, &output).ok());
136   AddQuantParams(&output->quant_params, /*min=*/-1.0, /*max=*/1.0,
137                  /*scale=*/0.008);
138   output->tensor.shape = BHWC(1, 4, 4, 8);
139 
140   ASSERT_EQ(3, graph.nodes().size());
141   ASSERT_EQ(4, graph.values().size());
142 
143   auto transformation = NewAddQuantAdjustments();
144   ModelTransformer transformer(&graph);
145   transformer.Apply("add_quant_adjustments", transformation.get());
146 
147   EXPECT_EQ(4, graph.nodes().size());
148   EXPECT_EQ(5, graph.values().size());
149   EXPECT_EQ(ToString(OperationType::ADD), graph.nodes()[0]->operation.type);
150   // The new node should be inserted at index 1, just after add1.
151   EXPECT_EQ(ToString(OperationType::QUANTIZE_AND_DEQUANTIZE),
152             graph.nodes()[1]->operation.type);
153   EXPECT_EQ(ToString(OperationType::QUANTIZE_AND_DEQUANTIZE),
154             graph.nodes()[2]->operation.type);
155   EXPECT_EQ(quant_node->id, graph.nodes()[2]->id);
156   EXPECT_EQ(ToString(OperationType::ADD), graph.nodes()[3]->operation.type);
157   auto new_quant_attr = absl::any_cast<QuantizeAndDequantizeAttributes>(
158       graph.nodes()[1]->operation.attributes);
159   EXPECT_EQ(0.0, new_quant_attr.min);
160   EXPECT_EQ(2.0, new_quant_attr.max);
161   const auto& new_quant_consumers = graph.FindConsumers(graph.values()[4]->id);
162   EXPECT_EQ(2, new_quant_consumers.size());
163   EXPECT_EQ(quant_node, new_quant_consumers[0]);
164   EXPECT_EQ(add2_node, new_quant_consumers[1]);
165 
166   // Transformation should be idempotent.
167   transformer.Apply("add_quant_adjustments", transformation.get());
168   EXPECT_EQ(4, graph.nodes().size());
169   EXPECT_EQ(5, graph.values().size());
170 }
171 
172 }  // namespace
173 }  // namespace gpu
174 }  // namespace tflite
175