1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TEST_UTILS_H_
17 #define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TEST_UTILS_H_
18
19 #if GOOGLE_CUDA && GOOGLE_TENSORRT
20
21 #include <algorithm>
22 #include <map>
23 #include <numeric>
24 #include <string>
25 #include <vector>
26
27 #include <gmock/gmock.h>
28 #include <gtest/gtest.h>
29 #include "absl/strings/str_format.h"
30 #include "absl/types/span.h"
31 #include "tensorflow/cc/framework/scope.h"
32 #include "tensorflow/cc/ops/standard_ops.h"
33 #include "tensorflow/compiler/tf2tensorrt/common/utils.h"
34 #include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
35 #include "tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.h"
36 #include "tensorflow/core/framework/node_def.pb.h" // NOLINT
37 #include "tensorflow/core/framework/tensor.pb.h" // NOLINT
38 #include "tensorflow/core/framework/tensor_shape.h"
39 #include "tensorflow/core/framework/tensor_testutil.h"
40 #include "tensorflow/core/lib/core/status.h"
41 #include "tensorflow/core/lib/core/status_test_util.h"
42 #include "third_party/tensorrt/NvInfer.h"
43
44 namespace tensorflow {
45 namespace tensorrt {
46 namespace convert {
// Creates a node with the given op, inputs, and attributes.
//   name:   name of the node to create.
//   op:     the op the node runs.
//   inputs: the node's inputs (presumably "node" / "node:port" tensor
//           names -- TODO(review): confirm against the definition).
//   attrs:  attribute values keyed by attribute name; empty by default.
// NOTE(review): `attrs` is taken by const value, so the map is copied on every
// call. Switching it to `const&` has to be done together with the out-of-line
// definition, otherwise this declaration no longer matches it.
NodeDef MakeNodeDef(const std::string& name, const std::string& op,
                    const std::vector<std::string>& inputs,
                    const std::map<std::string, AttrValue> attrs = {});
51
52 // Creates a constant node with the given name and values arranged in the given
53 // shape.
54 template <typename T>
MakeConstNodeDef(const std::string & name,const std::vector<T> & vals,const TensorShape & shape)55 NodeDef MakeConstNodeDef(const std::string& name, const std::vector<T>& vals,
56 const TensorShape& shape) {
57 Scope s = Scope::NewRootScope();
58 Tensor t = test::AsTensor<T>(vals, shape);
59 auto const_op = ops::Const(s.WithOpName(name), t);
60 return const_op.node()->def();
61 }
62
63 // Creates a constant node with the given name and values, assuming a 1-D shape.
64 template <typename T>
MakeConstNodeDef(const std::string & name,const std::vector<T> & vals)65 NodeDef MakeConstNodeDef(const std::string& name, const std::vector<T>& vals) {
66 TensorShape shape;
67 const std::vector<int32> shape_dims = {static_cast<int32>(vals.size())};
68 TF_EXPECT_OK(TensorShapeUtils::MakeShape(shape_dims, &shape));
69 return MakeConstNodeDef(name, vals, shape);
70 }
71
// Creates an nvinfer1::Dims struct from the given vector.
// Defined out of line. Presumably nbDims is set to d.size() and d[i] is copied
// into Dims::d[i] -- TODO(review): confirm, including what happens when
// d.size() exceeds nvinfer1::Dims::MAX_DIMS.
nvinfer1::Dims CreateDims(const std::vector<int>& d);
74
// A gMock matcher that checks that the elements of a float vector match
// `values` to within the tolerance `max_abs_error` (default 1e-5).
// `nan_sensitive` (default false) presumably makes NaN match NaN, as in gMock's
// NanSensitiveFloatNear -- TODO(review): confirm against the definition.
::testing::Matcher<std::vector<float>> ArrayFloatNear(
    const std::vector<float>& values, float max_abs_error = 1e-5,
    bool nan_sensitive = false);
80
81 // nvinfer1::Dims gMock matchers
82
83 // matches nvinfer1::Dims to initializer list or vector of ints
84 // Example: EXPECT_THAT(my_dims, DimsAreArray({1, 2, 3}))
85 MATCHER_P(DimsAreArrayHelper, array_value,
86 absl::StrFormat("%s [%s]", negation ? "are" : "are not",
87 ::testing::PrintToString(array_value))) {
88 if (arg.nbDims != array_value.size()) return false;
89 for (int i = 0; i < arg.nbDims; ++i) {
90 if (arg.d[i] != array_value[i]) {
91 return false;
92 }
93 }
94 return true;
95 }
96 using DimsAreArray = DimsAreArrayHelperMatcherP<std::vector<int>>;
97
98 // nvinfer1::INetworkDefinition gMock matchers
99
100 // Checks that layer names are equal to initializer list or vector of strings.
101 // Example: EXPECT_THAT(my_network, LayerNamesAreArray({"conv1", "conv2"}))
102 MATCHER_P(LayerNamesAreArrayHelper, array_value,
103 absl::StrFormat("layer names %s [%s]", negation ? "are" : "are not",
104 ::testing::PrintToString(array_value))) {
105 if (array_value.size() != arg->getNbLayers()) return false;
106 for (int i = 0; i < arg->getNbLayers(); ++i) {
107 if (arg->getLayer(i)->getName() == nullptr) {
108 return false;
109 }
110 }
111 return true;
112 }
113 using LayerNamesAreArray =
114 LayerNamesAreArrayHelperMatcherP<std::vector<std::string>>;
115
116 // Checks layer names are all non-empty.
117 MATCHER(LayerNamesNonEmpty, "") {
118 for (int i = 0; i < arg->getNbLayers(); ++i) {
119 if (arg->getLayer(i)->getName() == nullptr) {
120 return false;
121 }
122 }
123 return true;
124 }
125
126 // TRT_ShapedWeights gMock matchers.
127
128 // Checks that the weight dimensions are values are equal to the given values.
129 // Example: EXPECT_THAT(my_weights,
130 // ShapedWeightsHasDimsAndValues({1, 2},{1.0f, 2.0f}))
131 MATCHER_P2(ShapedWeightsHasDimsAndValuesHelper, dims_vec, expected_values, "") {
132 DimsAdapter dims(dims_vec);
133 if (arg.Shape() != dims) {
134 return false;
135 }
136 if (arg.count() != expected_values.size()) {
137 return false;
138 }
139 using T = typename decltype(expected_values)::value_type;
140 const T* actual_values = arg.template GetPointer<T>();
141 for (int i = 0; i < expected_values.size(); ++i) {
142 if (expected_values[i] != actual_values[i]) {
143 return false;
144 }
145 }
146 return true;
147 }
148
149 template <typename T>
150 using ShapedWeightsHasDimsAndValues =
151 ShapedWeightsHasDimsAndValuesHelperMatcherP2<std::vector<int>,
152 std::vector<T>>;
153
154 // std::vector convenience utilities.
155
156 // Creates a new vector by casting all values of the given InCType vector to
157 // OutCType.
158 template <typename InCType, typename OutCType>
CastVector(const gtl::ArraySlice<InCType> & vals)159 std::vector<OutCType> CastVector(
160 const gtl::ArraySlice<InCType>& vals) { // non-absl ok
161 std::vector<OutCType> res(vals.size());
162 std::transform(vals.begin(), vals.end(), res.begin(),
163 [](const InCType in_val) -> OutCType {
164 return static_cast<OutCType>(in_val);
165 });
166 return res;
167 }
168
// Returns a vector of `size` elements: start_value, start_value + 1, ...,
// produced by repeatedly incrementing a running counter. start_value
// defaults to CType(0).
template <typename CType>
std::vector<CType> CreateVectorIota(int size, CType start_value = CType(0)) {
  std::vector<CType> sequence(size);
  CType current = start_value;
  std::generate(sequence.begin(), sequence.end(),
                [&current]() { return current++; });
  return sequence;
}
177
178 } // namespace convert
179 } // namespace tensorrt
180 } // namespace tensorflow
181
182 #endif // GOOGLE_CUDA && GOOGLE_TENSORRT
183 #endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TEST_UTILS_H_
184