1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/cc/ops/const_op.h"
17 #include "tensorflow/cc/ops/image_ops.h"
18 #include "tensorflow/cc/ops/nn_ops.h"
19 #include "tensorflow/cc/ops/sendrecv_ops.h"
20 #include "tensorflow/cc/ops/standard_ops.h"
21 #include "tensorflow/core/framework/tensor_testutil.h"
22 #include "tensorflow/core/lib/core/status_test_util.h"
23 #include "tensorflow/core/platform/test.h"
24 #include "tensorflow/core/platform/test_benchmark.h"
25 #include "tensorflow/core/public/session.h"
26 #include "tensorflow/tools/graph_transforms/transform_utils.h"
27
28 namespace tensorflow {
29 namespace graph_transforms {
30
31 // Declare here, so we don't need a public header.
32 Status QuantizeWeights(const GraphDef& input_graph_def,
33 const TransformFuncContext& context,
34 GraphDef* output_graph_def);
35
36 class QuantizeWeightsTest : public ::testing::Test {
37 protected:
BuildGraphDef(const TensorShape & input_shape,std::initializer_list<float> input_values,const TensorShape & weight_shape,std::initializer_list<float> weight_values,GraphDef * original_graph_def)38 void BuildGraphDef(const TensorShape& input_shape,
39 std::initializer_list<float> input_values,
40 const TensorShape& weight_shape,
41 std::initializer_list<float> weight_values,
42 GraphDef* original_graph_def) {
43 auto root = tensorflow::Scope::DisabledShapeInferenceScope();
44
45 Tensor input_data(DT_FLOAT, input_shape);
46 test::FillValues<float>(&input_data, input_values);
47 Output input_op =
48 ops::Const(root.WithOpName("input_op"), Input::Initializer(input_data));
49
50 Tensor weights_data(DT_FLOAT, weight_shape);
51 test::FillValues<float>(&weights_data, weight_values);
52 Output weights_op = ops::Const(root.WithOpName("weights_op"),
53 Input::Initializer(weights_data));
54
55 Output conv_op = ops::Conv2D(root.WithOpName("output"), input_op,
56 weights_op, {1, 1, 1, 1}, "VALID");
57
58 TF_ASSERT_OK(root.ToGraphDef(original_graph_def));
59 }
60
TestQuantizeWeights()61 void TestQuantizeWeights() {
62 GraphDef original_graph_def;
63 BuildGraphDef({1, 1, 6, 2},
64 {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
65 -5.0f, -3.0f, -6.0f},
66 {1, 2, 2, 10},
67 {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f, 1.0f, 2.0f,
68 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f,
69 0.1f, 0.2f, 0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f,
70 0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f},
71 &original_graph_def);
72
73 TransformFuncContext context;
74 context.output_names = {"output"};
75 context.params["minimum_size"] = {"16"};
76 GraphDef quantized_graph_def;
77 TF_ASSERT_OK(
78 QuantizeWeights(original_graph_def, context, &quantized_graph_def));
79
80 // Verify the structure of the quantized graph.
81 std::map<string, const NodeDef*> node_lookup;
82 MapNamesToNodes(quantized_graph_def, &node_lookup);
83 EXPECT_EQ(1, node_lookup.count("input_op"));
84 const NodeDef* q_input_op = node_lookup.at("input_op");
85 EXPECT_EQ(DT_FLOAT, q_input_op->attr().at("dtype").type());
86 EXPECT_EQ(1, node_lookup.count("weights_op"));
87 const NodeDef* q_weights_op = node_lookup.at("weights_op");
88 EXPECT_EQ("Dequantize", q_weights_op->op());
89 const string& weights_const_name =
90 NodeNameFromInput(q_weights_op->input(0));
91 EXPECT_EQ(1, node_lookup.count(weights_const_name));
92 const NodeDef* q_weights_const = node_lookup.at(weights_const_name);
93 EXPECT_EQ("Const", q_weights_const->op());
94 EXPECT_EQ(DT_QUINT8, q_weights_const->attr().at("dtype").type());
95
96 // Run the original graph.
97 std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
98 TF_ASSERT_OK(original_session->Create(original_graph_def));
99 std::vector<Tensor> original_outputs;
100 TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
101
102 // Run the quantized graph.
103 std::unique_ptr<Session> quantized_session(NewSession(SessionOptions()));
104 TF_ASSERT_OK(quantized_session->Create(quantized_graph_def));
105 std::vector<Tensor> quantized_outputs;
106 TF_ASSERT_OK(
107 quantized_session->Run({}, {"output"}, {}, &quantized_outputs));
108
109 // Compare the results
110 test::ExpectTensorNear<float>(original_outputs[0], quantized_outputs[0],
111 0.5);
112 }
113 };
114
// Delegates to the fixture's end-to-end check: quantize the weights, inspect
// the rewritten graph, and compare its output against the float graph.
TEST_F(QuantizeWeightsTest, TestQuantizeWeights) {
  this->TestQuantizeWeights();
}
116
// Checks that the quantization range recorded for each constant always
// contains zero. An all-negative tensor gets its max pulled up to 0.0, and an
// all-positive tensor gets its min pulled down to 0.0.
TEST_F(QuantizeWeightsTest, RangesAlwaysIncludeZero) {
  GraphDef original_graph_def;
  // Both constants reach minimum_size = 16: the input has 16 values, all
  // negative, and the weights have 40 values, all positive. Both are
  // therefore expected to be quantized.
  //
  // The input is shaped {1, 1, 8, 2} so that its channel count matches the
  // filter's input-channel dimension (2). The graph is not executed here, but
  // it should still describe a valid Conv2D.
  ASSERT_NO_FATAL_FAILURE(BuildGraphDef(
      {1, 1, 8, 2},
      {-1.0f, -4.0f, -2.0f, -5.0f, -1.0f, -4.0f, -2.0f, -5.0f, -1.0f, -4.0f,
       -2.0f, -5.0f, -1.0f, -4.0f, -2.0f, -5.0f},
      {1, 2, 2, 10},
      {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f, 1.0f, 2.0f,
       3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f,
       0.1f, 0.2f, 0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f,
       0.3f, 0.4f, 1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f},
      &original_graph_def));
  TransformFuncContext context;
  context.output_names = {"output"};
  context.params["minimum_size"] = {"16"};
  GraphDef quantized_graph_def;
  TF_ASSERT_OK(
      QuantizeWeights(original_graph_def, context, &quantized_graph_def));

  std::map<string, const NodeDef*> node_lookup;
  MapNamesToNodes(quantized_graph_def, &node_lookup);

  // Builds a scalar float tensor holding `value`, used as the expected range
  // bound.
  auto expected_tensor = [](float value) {
    Tensor tensor(DT_FLOAT, TensorShape({}));
    test::FillValues<float>(&tensor, {value});
    return tensor;
  };
  // Returns the "value" tensor attribute of the node named `op`. Fails with a
  // descriptive message if the quantized graph has no such node.
  auto existing_tensor = [&node_lookup](const string& op) {
    const auto it = node_lookup.find(op);
    CHECK(it != node_lookup.end())
        << "Node not found in quantized graph: " << op;
    CHECK(it->second != nullptr) << "Null NodeDef for: " << op;
    return GetNodeTensorAttr(*it->second, "value");
  };

  // The max of input_op is moved from -1.0 to 0.0.
  test::ExpectTensorNear<float>(
      expected_tensor(-5.0f), existing_tensor("input_op_quantized_min"), 1e-5);
  test::ExpectTensorNear<float>(
      expected_tensor(0.0f), existing_tensor("input_op_quantized_max"), 1e-5);

  // The min of weights_op is moved from 0.1 to 0.0.
  test::ExpectTensorNear<float>(
      expected_tensor(0.0f), existing_tensor("weights_op_quantized_min"), 1e-5);
  test::ExpectTensorNear<float>(
      expected_tensor(4.0f), existing_tensor("weights_op_quantized_max"), 1e-5);
}
161
162 } // namespace graph_transforms
163 } // namespace tensorflow
164