/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.h"

#include <memory>
#include <string>
#include <vector>

#include <gtest/gtest.h>
#include "absl/status/status.h"
#include "absl/types/any.h"
#include "absl/types/optional.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"

namespace tflite {
namespace gpu {
namespace {

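// Populates the optional QuantizationParams with the given (min, max, scale).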
void AddQuantParams(absl::optional<QuantizationParams>* params, float min,
                    float max, float scale) {
  params->emplace();
  params->value().min = min;
  params->value().max = max;
  params->value().scale = scale;
}

// Scenario:
// -> Add ->
//
// Since the lone Add's output has no consumers, no new node should be added.
TEST(AddQuantAdjustments, OneNode) {
  GraphFloat32 graph;
  auto input = graph.NewValue();
  input->tensor.shape = BHWC(1, 4, 4, 8);
  AddQuantParams(&input->quant_params, /*min=*/0.0, /*max=*/1.0,
                 /*scale=*/0.004);

  Tensor<Linear, DataType::FLOAT32> add_tensor;
  add_tensor.shape = Linear(8);
  add_tensor.data.resize(8);
  ElementwiseAttributes add_attr;
  add_attr.param = add_tensor;
  auto add_node = graph.NewNode();
  add_node->operation.type = ToString(OperationType::ADD);
  add_node->operation.attributes = add_attr;

  ASSERT_TRUE(graph.AddConsumer(add_node->id, input->id).ok());

  Value* output = nullptr;
  ASSERT_TRUE(AddOutput(&graph, add_node, &output).ok());
  AddQuantParams(&output->quant_params, /*min=*/0.0, /*max=*/2.0,
                 /*scale=*/0.008);
  output->tensor.shape = BHWC(1, 4, 4, 8);

  ASSERT_EQ(1, graph.nodes().size());
  ASSERT_EQ(2, graph.values().size());

  auto transformation = NewAddQuantAdjustments();
  ModelTransformer transformer(&graph);
  transformer.Apply("add_quant_adjustments", transformation.get());

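  // Nothing should change: the Add output has no consumers, so no
  // QuantizeAndDequantize node needs to be inserted.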
  EXPECT_EQ(1, graph.nodes().size());
  EXPECT_EQ(2, graph.values().size());
}

// Scenario:
// -> Add -> QuantizeAndDequantize -> Add ->
//      |                              ^
//      |                              |
//      --------------------------------
//
// A new QuantizeAndDequantize should be added only after the left/first 'Add'
// op, and its output should feed both consumers of that op's output.
TEST(AddQuantAdjustments, GeneralCase) {
  GraphFloat32 graph;
  auto input = graph.NewValue();
  input->tensor.shape = BHWC(1, 4, 4, 8);
  AddQuantParams(&input->quant_params, /*min=*/0.0, /*max=*/1.0,
                 /*scale=*/0.004);

  // First Add.
  Tensor<Linear, DataType::FLOAT32> add_tensor;
  add_tensor.shape = Linear(8);
  add_tensor.data.resize(8);
  ElementwiseAttributes add_attr;
  add_attr.param = add_tensor;
  auto add1_node = graph.NewNode();
  add1_node->operation.type = ToString(OperationType::ADD);
  add1_node->operation.attributes = add_attr;
  // QuantizeAndDequantize.
  QuantizeAndDequantizeAttributes quant_attr;
  quant_attr.min = -1.0;
  quant_attr.max = 1.0;
  quant_attr.scale = 0.008;
  auto quant_node = graph.NewNode();
  quant_node->operation.type =
      ToString(OperationType::QUANTIZE_AND_DEQUANTIZE);
  quant_node->operation.attributes = quant_attr;
  // Second Add.
  auto add2_node = graph.NewNode();
  add2_node->operation.type = ToString(OperationType::ADD);
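  // The second Add consumes two runtime tensors (link1 and link2 below), so it
  // carries no constant parameter.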

  // Connections.
  ASSERT_TRUE(graph.AddConsumer(add1_node->id, input->id).ok());
  Value* link1 = nullptr;
  ASSERT_TRUE(ConnectTwoNodes(&graph, add1_node, quant_node, &link1).ok());
  AddQuantParams(&link1->quant_params, /*min=*/0.0, /*max=*/2.0,
                 /*scale=*/0.008);
  link1->tensor.shape = BHWC(1, 4, 4, 8);
  ASSERT_TRUE(graph.AddConsumer(add2_node->id, link1->id).ok());
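  // link1 (the first Add's output) carries quantization parameters and is
  // consumed both by quant_node and directly by add2_node; this is the case
  // the transformation adjusts.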
  Value* link2 = nullptr;
  ASSERT_TRUE(ConnectTwoNodes(&graph, quant_node, add2_node, &link2).ok());
  AddQuantParams(&link2->quant_params, /*min=*/-1.0, /*max=*/1.0,
                 /*scale=*/0.008);
  link2->tensor.shape = BHWC(1, 4, 4, 8);
  Value* output = nullptr;
  ASSERT_TRUE(AddOutput(&graph, add2_node, &output).ok());
  AddQuantParams(&output->quant_params, /*min=*/-1.0, /*max=*/1.0,
                 /*scale=*/0.008);
  output->tensor.shape = BHWC(1, 4, 4, 8);

  ASSERT_EQ(3, graph.nodes().size());
  ASSERT_EQ(4, graph.values().size());

  auto transformation = NewAddQuantAdjustments();
  ModelTransformer transformer(&graph);
  transformer.Apply("add_quant_adjustments", transformation.get());

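  // Exactly one QuantizeAndDequantize node (and its output value) should have
  // been added after the first Add.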
  EXPECT_EQ(4, graph.nodes().size());
  EXPECT_EQ(5, graph.values().size());
  EXPECT_EQ(ToString(OperationType::ADD), graph.nodes()[0]->operation.type);
  // The new node should be inserted at index 1, just after add1.
  EXPECT_EQ(ToString(OperationType::QUANTIZE_AND_DEQUANTIZE),
            graph.nodes()[1]->operation.type);
  EXPECT_EQ(ToString(OperationType::QUANTIZE_AND_DEQUANTIZE),
            graph.nodes()[2]->operation.type);
  EXPECT_EQ(quant_node->id, graph.nodes()[2]->id);
  EXPECT_EQ(ToString(OperationType::ADD), graph.nodes()[3]->operation.type);
  auto new_quant_attr = absl::any_cast<QuantizeAndDequantizeAttributes>(
      graph.nodes()[1]->operation.attributes);
  EXPECT_EQ(0.0, new_quant_attr.min);
  EXPECT_EQ(2.0, new_quant_attr.max);
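  // The new node's output should feed both of the first Add's original
  // consumers: the pre-existing QuantizeAndDequantize and the second Add.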
  const auto& new_quant_consumers = graph.FindConsumers(graph.values()[4]->id);
  EXPECT_EQ(2, new_quant_consumers.size());
  EXPECT_EQ(quant_node, new_quant_consumers[0]);
  EXPECT_EQ(add2_node, new_quant_consumers[1]);

  // Transformation should be idempotent.
  transformer.Apply("add_quant_adjustments", transformation.get());
  EXPECT_EQ(4, graph.nodes().size());
  EXPECT_EQ(5, graph.values().size());
}

}  // namespace
}  // namespace gpu
}  // namespace tflite