xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2tensorrt/utils/trt_testutils.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TEST_UTILS_H_
17 #define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TEST_UTILS_H_
18 
19 #if GOOGLE_CUDA && GOOGLE_TENSORRT
20 
21 #include <algorithm>
22 #include <map>
23 #include <numeric>
24 #include <string>
25 #include <vector>
26 
27 #include <gmock/gmock.h>
28 #include <gtest/gtest.h>
29 #include "absl/strings/str_format.h"
30 #include "absl/types/span.h"
31 #include "tensorflow/cc/framework/scope.h"
32 #include "tensorflow/cc/ops/standard_ops.h"
33 #include "tensorflow/compiler/tf2tensorrt/common/utils.h"
34 #include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
35 #include "tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.h"
36 #include "tensorflow/core/framework/node_def.pb.h"  // NOLINT
37 #include "tensorflow/core/framework/tensor.pb.h"    // NOLINT
38 #include "tensorflow/core/framework/tensor_shape.h"
39 #include "tensorflow/core/framework/tensor_testutil.h"
40 #include "tensorflow/core/lib/core/status.h"
41 #include "tensorflow/core/lib/core/status_test_util.h"
42 #include "third_party/tensorrt/NvInfer.h"
43 
44 namespace tensorflow {
45 namespace tensorrt {
46 namespace convert {
47 // Creates a node with the given op, inputs, and attributes.
48 NodeDef MakeNodeDef(const std::string& name, const std::string& op,
49                     const std::vector<std::string>& inputs,
50                     const std::map<std::string, AttrValue> attrs = {});
51 
52 // Creates a constant node with the given name and values arranged in the given
53 // shape.
54 template <typename T>
MakeConstNodeDef(const std::string & name,const std::vector<T> & vals,const TensorShape & shape)55 NodeDef MakeConstNodeDef(const std::string& name, const std::vector<T>& vals,
56                          const TensorShape& shape) {
57   Scope s = Scope::NewRootScope();
58   Tensor t = test::AsTensor<T>(vals, shape);
59   auto const_op = ops::Const(s.WithOpName(name), t);
60   return const_op.node()->def();
61 }
62 
63 // Creates a constant node with the given name and values, assuming a 1-D shape.
64 template <typename T>
MakeConstNodeDef(const std::string & name,const std::vector<T> & vals)65 NodeDef MakeConstNodeDef(const std::string& name, const std::vector<T>& vals) {
66   TensorShape shape;
67   const std::vector<int32> shape_dims = {static_cast<int32>(vals.size())};
68   TF_EXPECT_OK(TensorShapeUtils::MakeShape(shape_dims, &shape));
69   return MakeConstNodeDef(name, vals, shape);
70 }
71 
72 // Creates an nvinfer1::Dims struct from the given vector.
73 nvinfer1::Dims CreateDims(const std::vector<int>& d);
74 
75 // A gmock matcher that check that elements of a float vector match to a given
76 // tolerance.
77 ::testing::Matcher<std::vector<float>> ArrayFloatNear(
78     const std::vector<float>& values, float max_abs_error = 1e-5,
79     bool nan_sensitive = false);
80 
81 // nvinfer1::Dims gMock matchers
82 
83 // matches nvinfer1::Dims to initializer list or vector of ints
84 // Example: EXPECT_THAT(my_dims, DimsAreArray({1, 2, 3}))
85 MATCHER_P(DimsAreArrayHelper, array_value,
86           absl::StrFormat("%s [%s]", negation ? "are" : "are not",
87                           ::testing::PrintToString(array_value))) {
88   if (arg.nbDims != array_value.size()) return false;
89   for (int i = 0; i < arg.nbDims; ++i) {
90     if (arg.d[i] != array_value[i]) {
91       return false;
92     }
93   }
94   return true;
95 }
96 using DimsAreArray = DimsAreArrayHelperMatcherP<std::vector<int>>;
97 
98 // nvinfer1::INetworkDefinition gMock matchers
99 
100 // Checks that layer names are equal to initializer list or vector of strings.
101 // Example: EXPECT_THAT(my_network, LayerNamesAreArray({"conv1", "conv2"}))
102 MATCHER_P(LayerNamesAreArrayHelper, array_value,
103           absl::StrFormat("layer names %s [%s]", negation ? "are" : "are not",
104                           ::testing::PrintToString(array_value))) {
105   if (array_value.size() != arg->getNbLayers()) return false;
106   for (int i = 0; i < arg->getNbLayers(); ++i) {
107     if (arg->getLayer(i)->getName() == nullptr) {
108       return false;
109     }
110   }
111   return true;
112 }
113 using LayerNamesAreArray =
114     LayerNamesAreArrayHelperMatcherP<std::vector<std::string>>;
115 
116 // Checks layer names are all non-empty.
117 MATCHER(LayerNamesNonEmpty, "") {
118   for (int i = 0; i < arg->getNbLayers(); ++i) {
119     if (arg->getLayer(i)->getName() == nullptr) {
120       return false;
121     }
122   }
123   return true;
124 }
125 
126 // TRT_ShapedWeights gMock matchers.
127 
128 // Checks that the weight dimensions are values are equal to the given values.
129 // Example: EXPECT_THAT(my_weights,
130 //                      ShapedWeightsHasDimsAndValues({1, 2},{1.0f, 2.0f}))
131 MATCHER_P2(ShapedWeightsHasDimsAndValuesHelper, dims_vec, expected_values, "") {
132   DimsAdapter dims(dims_vec);
133   if (arg.Shape() != dims) {
134     return false;
135   }
136   if (arg.count() != expected_values.size()) {
137     return false;
138   }
139   using T = typename decltype(expected_values)::value_type;
140   const T* actual_values = arg.template GetPointer<T>();
141   for (int i = 0; i < expected_values.size(); ++i) {
142     if (expected_values[i] != actual_values[i]) {
143       return false;
144     }
145   }
146   return true;
147 }
148 
149 template <typename T>
150 using ShapedWeightsHasDimsAndValues =
151     ShapedWeightsHasDimsAndValuesHelperMatcherP2<std::vector<int>,
152                                                  std::vector<T>>;
153 
154 // std::vector convenience utilities.
155 
156 // Creates a new vector by casting all values of the given InCType vector to
157 // OutCType.
158 template <typename InCType, typename OutCType>
CastVector(const gtl::ArraySlice<InCType> & vals)159 std::vector<OutCType> CastVector(
160     const gtl::ArraySlice<InCType>& vals) {  // non-absl ok
161   std::vector<OutCType> res(vals.size());
162   std::transform(vals.begin(), vals.end(), res.begin(),
163                  [](const InCType in_val) -> OutCType {
164                    return static_cast<OutCType>(in_val);
165                  });
166   return res;
167 }
168 
// Returns a vector of `size` elements: start_value, then each subsequent
// element obtained by pre-incrementing the previous value (the same stepping
// std::iota performs).
template <typename CType>
std::vector<CType> CreateVectorIota(int size, CType start_value = CType(0)) {
  std::vector<CType> sequence(size);
  CType next = start_value;
  for (CType& slot : sequence) {
    slot = next;
    ++next;
  }
  return sequence;
}
177 
178 }  // namespace convert
179 }  // namespace tensorrt
180 }  // namespace tensorflow
181 
182 #endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
183 #endif  // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TEST_UTILS_H_
184