xref: /aosp_15_r20/external/tensorflow/tensorflow/tools/graph_transforms/quantize_weights_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/cc/ops/const_op.h"
17 #include "tensorflow/cc/ops/image_ops.h"
18 #include "tensorflow/cc/ops/nn_ops.h"
19 #include "tensorflow/cc/ops/sendrecv_ops.h"
20 #include "tensorflow/cc/ops/standard_ops.h"
21 #include "tensorflow/core/framework/tensor_testutil.h"
22 #include "tensorflow/core/lib/core/status_test_util.h"
23 #include "tensorflow/core/platform/test.h"
24 #include "tensorflow/core/platform/test_benchmark.h"
25 #include "tensorflow/core/public/session.h"
26 #include "tensorflow/tools/graph_transforms/transform_utils.h"
27 
28 namespace tensorflow {
29 namespace graph_transforms {
30 
31 // Declare here, so we don't need a public header.
32 Status QuantizeWeights(const GraphDef& input_graph_def,
33                        const TransformFuncContext& context,
34                        GraphDef* output_graph_def);
35 
36 class QuantizeWeightsTest : public ::testing::Test {
37  protected:
BuildGraphDef(const TensorShape & input_shape,std::initializer_list<float> input_values,const TensorShape & weight_shape,std::initializer_list<float> weight_values,GraphDef * original_graph_def)38   void BuildGraphDef(const TensorShape& input_shape,
39                      std::initializer_list<float> input_values,
40                      const TensorShape& weight_shape,
41                      std::initializer_list<float> weight_values,
42                      GraphDef* original_graph_def) {
43     auto root = tensorflow::Scope::DisabledShapeInferenceScope();
44 
45     Tensor input_data(DT_FLOAT, input_shape);
46     test::FillValues<float>(&input_data, input_values);
47     Output input_op =
48         ops::Const(root.WithOpName("input_op"), Input::Initializer(input_data));
49 
50     Tensor weights_data(DT_FLOAT, weight_shape);
51     test::FillValues<float>(&weights_data, weight_values);
52     Output weights_op = ops::Const(root.WithOpName("weights_op"),
53                                    Input::Initializer(weights_data));
54 
55     Output conv_op = ops::Conv2D(root.WithOpName("output"), input_op,
56                                  weights_op, {1, 1, 1, 1}, "VALID");
57 
58     TF_ASSERT_OK(root.ToGraphDef(original_graph_def));
59   }
60 
TestQuantizeWeights()61   void TestQuantizeWeights() {
62     GraphDef original_graph_def;
63     BuildGraphDef({1, 1, 6, 2},
64                   {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
65                    -5.0f, -3.0f, -6.0f},
66                   {1, 2, 2, 10},
67                   {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f, 1.0f, 2.0f,
68                    3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f,
69                    0.1f, 0.2f, 0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f,
70                    0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f},
71                   &original_graph_def);
72 
73     TransformFuncContext context;
74     context.output_names = {"output"};
75     context.params["minimum_size"] = {"16"};
76     GraphDef quantized_graph_def;
77     TF_ASSERT_OK(
78         QuantizeWeights(original_graph_def, context, &quantized_graph_def));
79 
80     // Verify the structure of the quantized graph.
81     std::map<string, const NodeDef*> node_lookup;
82     MapNamesToNodes(quantized_graph_def, &node_lookup);
83     EXPECT_EQ(1, node_lookup.count("input_op"));
84     const NodeDef* q_input_op = node_lookup.at("input_op");
85     EXPECT_EQ(DT_FLOAT, q_input_op->attr().at("dtype").type());
86     EXPECT_EQ(1, node_lookup.count("weights_op"));
87     const NodeDef* q_weights_op = node_lookup.at("weights_op");
88     EXPECT_EQ("Dequantize", q_weights_op->op());
89     const string& weights_const_name =
90         NodeNameFromInput(q_weights_op->input(0));
91     EXPECT_EQ(1, node_lookup.count(weights_const_name));
92     const NodeDef* q_weights_const = node_lookup.at(weights_const_name);
93     EXPECT_EQ("Const", q_weights_const->op());
94     EXPECT_EQ(DT_QUINT8, q_weights_const->attr().at("dtype").type());
95 
96     // Run the original graph.
97     std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
98     TF_ASSERT_OK(original_session->Create(original_graph_def));
99     std::vector<Tensor> original_outputs;
100     TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
101 
102     // Run the quantized graph.
103     std::unique_ptr<Session> quantized_session(NewSession(SessionOptions()));
104     TF_ASSERT_OK(quantized_session->Create(quantized_graph_def));
105     std::vector<Tensor> quantized_outputs;
106     TF_ASSERT_OK(
107         quantized_session->Run({}, {"output"}, {}, &quantized_outputs));
108 
109     // Compare the results
110     test::ExpectTensorNear<float>(original_outputs[0], quantized_outputs[0],
111                                   0.5);
112   }
113 };
114 
// End-to-end scenario: all checks live in the fixture helper of the same name.
TEST_F(QuantizeWeightsTest, TestQuantizeWeights) {
  TestQuantizeWeights();
}
116 
// Checks that the float range recorded for each quantized constant is widened
// to include 0.0:
//  * an all-negative input gets max 0.0, and
//  * an all-positive weight tensor gets min 0.0.
TEST_F(QuantizeWeightsTest, RangesAlwaysIncludeZero) {
  GraphDef original_graph_def;
  // 16 all-negative input values in a 1x1x8x2 shape, so the 2 channels match
  // the filter's in_channels. The filter has 40 all-positive weights in a
  // 1x2x2x10 shape. Both constants meet minimum_size=16 and are expected to
  // be quantized.
  ASSERT_NO_FATAL_FAILURE(BuildGraphDef(
      {1, 1, 8, 2},
      {-1.0f, -4.0f, -2.0f, -5.0f, -1.0f, -4.0f, -2.0f, -5.0f, -1.0f, -4.0f,
       -2.0f, -5.0f, -1.0f, -4.0f, -2.0f, -5.0f},
      {1, 2, 2, 10},
      {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f, 1.0f, 2.0f,
       3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f,
       0.1f, 0.2f, 0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f,
       0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f},
      &original_graph_def));

  TransformFuncContext context;
  context.output_names = {"output"};
  context.params["minimum_size"] = {"16"};
  GraphDef quantized_graph_def;
  TF_ASSERT_OK(
      QuantizeWeights(original_graph_def, context, &quantized_graph_def));

  std::map<string, const NodeDef*> node_lookup;
  MapNamesToNodes(quantized_graph_def, &node_lookup);

  // Fail cleanly up front rather than letting node_lookup.at() throw below.
  for (const char* name :
       {"input_op_quantized_min", "input_op_quantized_max",
        "weights_op_quantized_min", "weights_op_quantized_max"}) {
    ASSERT_EQ(1, node_lookup.count(name)) << "missing node " << name;
  }

  // Builds a scalar float tensor holding `value`.
  auto expected_tensor = [](float value) {
    Tensor tensor(DT_FLOAT, TensorShape({}));
    test::FillValues<float>(&tensor, {value});
    return tensor;
  };
  // Reads the "value" tensor attribute of the node named `op`. Presence was
  // asserted above.
  auto existing_tensor = [&node_lookup](const string& op) {
    return GetNodeTensorAttr(*node_lookup.at(op), "value");
  };

  // The max of input_op is moved from -1.0 to 0.0.
  test::ExpectTensorNear<float>(
      expected_tensor(-5.0f), existing_tensor("input_op_quantized_min"), 1e-5);
  test::ExpectTensorNear<float>(
      expected_tensor(0.0f), existing_tensor("input_op_quantized_max"), 1e-5);

  // The min of weights_op is moved from 0.1 to 0.0.
  test::ExpectTensorNear<float>(
      expected_tensor(0.0f), existing_tensor("weights_op_quantized_min"), 1e-5);
  test::ExpectTensorNear<float>(
      expected_tensor(4.0f), existing_tensor("weights_op_quantized_max"), 1e-5);
}
161 
162 }  // namespace graph_transforms
163 }  // namespace tensorflow
164