/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_NODES_H_
#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_NODES_H_

#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2tensorrt/convert/op_converter.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/convert/weights.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_tensor_proxy.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"

#if GOOGLE_CUDA && GOOGLE_TENSORRT
#include "third_party/tensorrt/NvInfer.h"

namespace tensorflow {
namespace tensorrt {

namespace convert {
using ::stream_executor::port::StatusOr;

#define TFTRT_INTERNAL_ERROR_AT_NODE(node)                           \
  do {                                                               \
    return errors::Internal("TFTRT::", __FUNCTION__, ":", __LINE__,  \
                            " failed to add TRT layer, at: ", node); \
  } while (0)

#define TFTRT_RETURN_ERROR_IF_NULLPTR(ptr, node) \
  do {                                           \
    if (ptr == nullptr) {                        \
      TFTRT_INTERNAL_ERROR_AT_NODE(node);        \
    }                                            \
  } while (0)

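// A minimal usage sketch for the macros above (illustrative only; 'network',
// 'input_tensor', and 'node_def' are hypothetical locals inside an op
// converter that returns Status):
//
//   nvinfer1::IShuffleLayer* layer =
//       network->addShuffle(*input_tensor->trt_tensor());
//   TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
//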
struct EngineConnection {
  // Constructs a non-control edge.
  EngineConnection(const string& outside, int out_id, int out_port,
                   const string& inside, int in_id, int in_port,
                   bool input_edge, int port)
      : outside_node_name(outside),
        outside_id(out_id),
        outside_port(out_port),
        inside_node_name(inside),
        inside_id(in_id),
        inside_port(in_port),
        is_input_edge(input_edge),
        port_number(port) {}

  // Constructs a control edge.
  EngineConnection(const string& outside, int out_id, const string& inside,
                   int in_id, bool input_edge)
      : outside_node_name(outside),
        outside_id(out_id),
        outside_port(Graph::kControlSlot),
        inside_node_name(inside),
        inside_id(in_id),
        inside_port(Graph::kControlSlot),
        is_input_edge(input_edge),
        port_number(Graph::kControlSlot) {}

  bool is_control_edge() const { return port_number == Graph::kControlSlot; }

  const string outside_node_name;
  const int outside_id;
  const int outside_port;
  PartialTensorShape outside_shape;  // Only set for input edge.

  const string inside_node_name;
  const int inside_id;
  const int inside_port;
  PartialTensorShape inside_shape;  // Only set for output edge.

  DataType connection_type;
  const bool is_input_edge;

  // The port number of the TRT node connected with this edge.
  const int port_number;
};

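// Illustrative construction of the two edge kinds (a sketch; the node names
// and ids are hypothetical):
//
//   // Data edge: output 0 of outside node "conv" feeds input 0 of segment
//   // node "relu", arriving at engine input port 0.
//   EngineConnection data_edge("conv", /*out_id=*/3, /*out_port=*/0,
//                              "relu", /*in_id=*/7, /*in_port=*/0,
//                              /*input_edge=*/true, /*port=*/0);
//
//   // Control edge: no ports are involved, so port_number is
//   // Graph::kControlSlot and is_control_edge() returns true.
//   EngineConnection ctrl_edge("assert", /*out_id=*/4, "relu", /*in_id=*/7,
//                              /*input_edge=*/true);
//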
struct EngineInfo {
  EngineInfo()
      : engine_type(EngineType::TRTStatic),
        max_workspace_size_bytes(0),
        max_batch_size(std::nullopt),
        maximum_cached_engines(0),
        precision_mode(TrtPrecisionMode::FP32),
        use_calibration(true),
        allow_build_at_runtime(true),
        use_explicit_precision(false) {}

  string engine_name;
  string device;
  GraphDef segment_graph_def;

  // Non-control input connections inside this vector are sorted such that the
  // segment nodes connecting to them are topologically sorted. In addition,
  // for non-control connections, there must be no duplicates.
  std::vector<EngineConnection> connections;

  enum class EngineType { TRTStatic = 0, TRTDynamic = 1 };
  EngineType engine_type;
  int64 max_workspace_size_bytes;
  std::optional<int> max_batch_size;
  int maximum_cached_engines;
  TrtPrecisionMode precision_mode;
  bool use_calibration;
  bool allow_build_at_runtime;
  bool use_explicit_precision;
};

// Constructs a graphdef from the segment in the given graph and stores it in
// engine_info. Adds _Arg nodes for input edges (InputPH_*) and _Retval nodes
// for output edges (OutputPH_*). Maintains the topological order of the
// non-input/output nodes in the graphdef. This function needs to be called
// before TensorRT layers are created because it prepares the original graph
// for TensorRT conversion.
//
// - subgraph_nodes: the nodes of the subgraph; they must be sorted in
//   topological order.
// - engine_info: a data structure that records the information about the
//   engine containing the subgraph.
//
// TODO(aaroey): add tests to validate these properties.
Status ConvertSegmentToGraphDef(
    const Graph* graph, const grappler::GraphProperties& graph_properties,
    const std::vector<const Node*>& subgraph_nodes, EngineInfo* engine_info);

// Converts the given subgraph to a TRT engine saved in 'engine'. Returns OK
// iff 'builder' successfully builds the engine. If the result is not OK,
// 'engine' will be set to nullptr. Once this function returns, 'builder' is
// no longer needed and can be safely destroyed.
//
// - convert_successfully: indicates whether the conversion to a TensorRT
//   network was successful. This is different from successfully building the
//   engine: building can still fail afterwards.
// Note: when 'cluster' is not null, it contains the graph to be converted.
//       We may perform additional optimizations on the graph before
//       converting it.
Status ConvertGraphDefToEngine(
    const GraphDef& gdef, OpKernelContext* ctx, TrtPrecisionMode precision_mode,
    int max_batch_size, size_t max_workspace_size_bytes,
    const std::vector<PartialTensorShape>& input_shapes,
    nvinfer1::ILogger* logger, nvinfer1::IGpuAllocator* allocator,
    TRTInt8Calibrator* calibrator,
    TrtUniquePtrType<nvinfer1::ICudaEngine>* engine, bool use_calibration,
    const bool use_implicit_batch, bool* convert_successfully,
    TrtShapeOptimizationProfile* profiles, absl::string_view engine_name,
    bool use_explicit_precision,
    tensorflow::grappler::Cluster* cluster = nullptr);

// Helper class for the segmenter to determine whether an output edge from the
// TRT segment is valid.
class OutputEdgeValidator {
 public:
  // Returns true if the specified edge is eligible to be an output edge of
  // the TRT segment.
  bool operator()(const Edge* out_edge) const;
};

// Class to verify whether a specific TF node is supported by TRT.
class TrtNodeValidator {
 public:
  // 'graph_properties' is the GraphProperties of the graph whose nodes will
  // be checked by IsTensorRTCandidate() later. It is used to get the shape
  // and data type information of a tensor for validation purposes.
  TrtNodeValidator(const grappler::GraphProperties& graph_properties,
                   TrtPrecisionMode precision_mode, bool use_calibration,
                   bool use_implicit_batch, bool use_explicit_precision);

  // Returns OK iff 'node' is a TF-TRT conversion candidate, which will be
  // added to the TRT subgraph and later converted into a TRT engine.
  Status IsTensorRTCandidate(const Node* node);

  static const std::set<string>* quantize_ops;

  // Returns the validator for the given op type. If no validator is
  // registered for a specific op, no validation is needed and
  // IsTensorRTCandidate() will return OK.
  StatusOr<OpConverter> GetValidator(const std::string& op);

 private:
  // Converts a Const node to a TRT_TensorOrWeights.
  Status ConvertConstToWeights(const NodeDef& const_node_def,
                               const std::vector<TRT_TensorOrWeights>& inputs,
                               TRT_TensorOrWeights* output);

  // Converts a VariableV2 node to a TRT_TensorOrWeights.
  Status ConvertVariableToWeights(
      const NodeDef& const_node_def,
      const std::vector<TRT_TensorOrWeights>& inputs,
      TRT_TensorOrWeights* output);

  // Converts the output tensor at 'output_port' of 'node_def' to a
  // TRT_TensorOrWeights which will later be used as an input to other nodes
  // and passed to IsTensorRTCandidate().
  Status ConvertToTensorOrWeights(const NodeDef& node_def, int output_port,
                                  TRT_TensorOrWeights* tensor_or_weights);

  // Stores the weights added during validation. Some validations (e.g.
  // validation for a Const node) may produce weights.
  TrtWeightStore weight_store_;

  // GraphProperties of the graph whose nodes are to be validated by
  // IsTensorRTCandidate().
  const grappler::GraphProperties& graph_properties_;

  // Quantization ops are only converted when using quantized precisions.
  const TrtPrecisionMode precision_mode_;

  const bool use_calibration_;

  const bool use_implicit_batch_;

  const bool use_explicit_precision_;

  friend class ValidatorTest;
  friend class OpConverterTest;
};

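// A minimal usage sketch (illustrative; 'graph_properties' and 'node' are
// hypothetical values supplied by the caller):
//
//   TrtNodeValidator validator(graph_properties, TrtPrecisionMode::FP32,
//                              /*use_calibration=*/false,
//                              /*use_implicit_batch=*/true,
//                              /*use_explicit_precision=*/false);
//   const Status status = validator.IsTensorRTCandidate(node);
//   if (!status.ok()) {
//     // The node stays in the TF graph instead of joining a TRT segment.
//   }
//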
// Class to convert TF nodes to a TRT network.
class Converter {
 public:
  // Used for Converter::RenameAndMarkOutputTensors().
  struct EngineOutputInfo {
    // The TRT tensor name which produces the output.
    string source_tensor_name;
    // The TensorFlow node name which is receiving the output from the TRT
    // engine. This should always be the Identity node created in
    // ConvertSegmentToGraphDef.
    string dest_node_name;
    // Output type. TensorRT requires this to be explicitly set for engine
    // outputs.
    nvinfer1::DataType trt_dtype;
  };

  static StatusOr<std::unique_ptr<Converter>> Create(
      TrtPrecisionMode precision_mode, bool use_calibration,
      nvinfer1::ILogger* trt_logger, const bool use_implicit_batch,
      absl::string_view engine_name, bool use_explicit_precision = false,
      OpKernelContext* ctx = nullptr);

  //////////////////////////////////////////////////////////////////////////////
  // Methods used by the TRT engine builder to build a TRT network from a TF
  // function/subgraph.

  // Converts the node to the TRT network.
  Status ConvertNode(const NodeDef& node_def);

  // Adds an input tensor to the TRT network with the given 'name', 'dtype',
  // 'dims' and 'batch_size'.
  Status AddInputTensor(const string& name, nvinfer1::DataType dtype,
                        const nvinfer1::Dims& dims, int batch_size);

  // Stores the ResourceHandle as a TRT_TensorOrWeights object. This can
  // later be used as input to other nodes.
  Status AddInputResource(const string& name, const ResourceHandle& resource);

  // Marks the tensors with names specified by source_tensor_name as outputs
  // of the TRT network, and sets their names in the TRT network to
  // dest_node_name.
  Status RenameAndMarkOutputTensors(
      const std::vector<EngineOutputInfo>& output_tensors);

  // Builds a TRT engine using the created network.
  Status BuildCudaEngine(TrtUniquePtrType<nvinfer1::ICudaEngine>* engine,
                         int max_batch_size, size_t max_workspace_size_bytes,
                         nvinfer1::IGpuAllocator* allocator,
                         TRTInt8Calibrator* calibrator,
                         TrtShapeOptimizationProfile* profiles);

  //////////////////////////////////////////////////////////////////////////////
  // Methods used by op converters to convert individual TF nodes and add
  // layers to the TRT network.

  // Op converters (e.g. ConvertReshape) need to access the TRT network in
  // order to add TRT layers.
  nvinfer1::INetworkDefinition* network() { return trt_network_.get(); }

  // What precision are we targeting?
  TrtPrecisionMode precision_mode() const { return precision_mode_; }

  // Variable converters need the context to read variable values.
  OpKernelContext* context() { return ctx_; }

  // Whether calibration will be or was previously performed on this network.
  bool use_calibration() const { return use_calibration_; }

  // Whether implicit batch mode is enabled.
  bool use_implicit_batch() const { return use_implicit_batch_; }

  // This function should be called when we know the quantization range of a
  // tensor from a quantize/dequantize node.
  void ProvideQuantizationRange(ITensorProxyPtr* tensor, float min_range,
                                float max_range);

  // Should be called when the full TRT network has been constructed and
  // before building the engine.
  void MaybeApplyQuantizationRanges();

  // Below are helper methods for op converters to add different layers to the
  // TRT network.

  // Transpose 'input_tensor' with given permutation 'order_with_batch_dim' to
  // 'output_tensor'. The permutation 'order_with_batch_dim' contains the batch
  // dimension which should always be 0. If this is for adding a transpose layer
  // to support the conversion of 'node_def', callers need to provide a
  // non-empty 'sub_op_name' appended to the name of 'node_def' to avoid layer
  // name conflicts.
  Status TransposeTensor(ITensorProxyPtr input_tensor,
                         const std::vector<int>& order_with_batch_dim,
                         ITensorProxyPtr* output_tensor,
                         const NodeDef& node_def,
                         absl::string_view sub_op_name = "");

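  // Illustrative call (a sketch; 'converter', 'input', 'output', and
  // 'node_def' are hypothetical locals): converting an NHWC tensor to NCHW
  // while keeping the batch dimension at position 0:
  //
  //   TF_RETURN_IF_ERROR(converter->TransposeTensor(
  //       input, /*order_with_batch_dim=*/{0, 3, 1, 2}, &output, node_def,
  //       "to_NCHW"));
  //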
  // Reshapes a dynamic shape tensor by removing or adding dimensions of size 1,
  // and/or permuting the dimensions. The new shape is derived from the shape of
  // the input tensor according to the slices and size_for_added_dims arguments.
  //
  // If there were at most one unknown dimension, we could set the new shape
  // using IShuffleLayer::setReshapeDimensions, which treats -1 as a special
  // value (the same way as TF). In general, we can have more than one unknown
  // dimension, and we have to manipulate the shape tensors during runtime to
  // define the new shape. This helper function defines the necessary shape
  // inference layers and calls reshape using the calculated new shape.
  //
  // Example:
  //
  // Assume that we want to reshape a tensor from shape {A,B,C,D} to {C,D,A,B}
  // (no transpose, just change the shape). In dynamic shape mode, the A,B,C,D
  // values are not necessarily known at conversion time; they can all be -1. We
  // can only define the new shape at runtime, when the actual shape is already
  // known. To define the new shape:
  // - We use an IShapeLayer to retrieve a shape tensor with the {A,B,C,D}
  //   values.
  // - Create two slices {C,D} and {A,B} of the shape tensor.
  // - Concatenate these slices into {C,D,A,B}.
  // - Set the {C,D,A,B} shape tensor as an input shape tensor for
  //   IShuffleLayer.
  //
  // This can be achieved by calling DynamicReshape(input, {{2,4},{0,2}},
  // params).
  //
  // Before each slice we can insert new dims if the corresponding
  // size_for_added_dims element is not negative. The size_for_added_dims array
  // can have more than slices.size() elements, in order to insert a dimension
  // after the last slice. For example, to add two leading 1 dimensions, and
  // three trailing 1 dimensions, call DynamicReshape(input, {{0,nbDims}},
  // {2, 3}).
  //
  // Parameters:
  // input - input tensor
  // slices - [start, end) pairs of slices
  // params - conversion parameters
  // output - reshaped tensor
  // size_for_added_dims - size of dimension inserted right before slice[i]. We
  //   only insert a new dim if size_for_added_dims[i] >= 0.
  Status DynamicReshape(ITensorProxyPtr input,
                        std::vector<std::pair<int, int>> slices,
                        OpConverterParams* params, ITensorProxyPtr* output,
                        std::vector<int> size_for_added_dims = {},
                        std::optional<int> op_instance = std::nullopt);

  // Inserts a singleton dimension at axis for a dynamic shape tensor.
  Status DynamicExpandDims(ITensorProxyPtr input, const nvinfer1::Dims& dims,
                           int axis, OpConverterParams* params,
                           ITensorProxyPtr* output,
                           std::optional<int> op_instance = std::nullopt);

  // Helper function to add a squeeze op to the network.
  //
  // The input_dims argument stores the TRT dimensions of the input tensor,
  // where the dimensions to be squeezed are replaced by 0.
  Status SqueezeTensor(ITensorProxyPtr input, std::vector<int>* input_dims,
                       OpConverterParams* params, ITensorProxyPtr* output,
                       std::optional<int> op_instance = std::nullopt);

  // Creates an IConstantLayer using 'weights' whose dimensions are specified by
  // 'dims', and returns the output ITensor.
  ITensorProxyPtr CreateConstantLayer(const TRT_ShapedWeights& weights,
                                      const nvinfer1::Dims& dims);

  // Gets the min and max value in a TRT_ShapedWeights.
  Status GetWeightRange(const TRT_ShapedWeights& weights, float* out_min,
                        float* out_max) const;

  // Constructs a name and passes it to the TensorRT layer to support xprof.
  void SetLayerName(nvinfer1::ILayer* layer, const NodeDef& node_def,
                    absl::string_view sub_op_name = "",
                    std::optional<int> sub_op_instance = std::nullopt,
                    std::optional<std::string> origin_node_name = std::nullopt);

  void SetLayerName(nvinfer1::ILayer* layer, absl::string_view main_op_name,
                    absl::string_view sub_op_name,
                    std::optional<int> sub_op_instance = std::nullopt);

  std::unordered_map<string, TRT_TensorOrWeights>& TensorsMap() {
    return trt_tensors_;
  }

  bool UseExplicitPrecision() const { return use_explicit_precision_; }

 private:
  Converter(TrtPrecisionMode precision_mode, bool use_calibration,
            nvinfer1::ILogger* trt_logger, const bool use_implicit_batch,
            absl::string_view engine_name, bool use_explicit_precision,
            OpKernelContext* ctx);

  Status Init(nvinfer1::ILogger* trt_logger);

  // Verifies that the provided batch_size is consistent with batch_size_ and
  // updates it if necessary.
  Status MaybeUpdateBatchSize(int batch_size);

  // Adds the provided tensor/weights to the map trt_tensors_.
  Status AddTensorOrWeights(const string& name, TRT_TensorOrWeights input);

  // Gets the tensor/weights from trt_tensors_ by 'name'.
  Status GetTensorOrWeights(const string& name, TRT_TensorOrWeights* output);

  // Gets the inputs of 'node_def' from trt_tensors_.
  Status GetInputs(const NodeDef& node_def,
                   std::vector<TRT_TensorOrWeights>* inputs) const;

  // Tensors/weights added during construction of trt_network_.
  std::unordered_map<string, TRT_TensorOrWeights> trt_tensors_;

  // The TRT builder used to create the network and build the engine.
  TrtUniquePtrType<nvinfer1::IBuilder> trt_builder_;

  // The TRT network being built.
  TrtUniquePtrType<nvinfer1::INetworkDefinition> trt_network_;

  // Store the weights added during construction of trt_network_.
  TrtWeightStore weight_store_;

  // Store the context.
  OpKernelContext* ctx_;

  // During conversion, this table is populated with quantization ranges per
  // tensor. MaybeApplyQuantizationRanges() will use this table to set the TRT
  // quantization ranges. Since TRT only supports symmetric ranges, we will
  // store the range as a single float = max(abs(min_range), abs(max_range)).
  // Range refers to the floating point values, e.g. min_range = 0.0f, max_range
  // = 6.0f for Relu6.
  std::unordered_map<ITensorProxyPtr*, float> quantization_ranges_proxy_;
  std::unordered_map<nvinfer1::ITensor*, float> quantization_ranges_;

  const TrtPrecisionMode precision_mode_;

  const bool use_calibration_;

  // If this is false, all dimensions including the batch dimension are
  // set explicitly.
  const bool use_implicit_batch_;

  // Batch size of inputs to trt_network_ added by AddInputTensor(). During
  // network construction the converter will update this value, use it to
  // verify that the batch sizes of all inputs are compatible, and make sure
  // each individual TF node is acceptable to TRT.
  int batch_size_ = -1;

  // Assign an ID to each constant layer we create, so that we can assign a
  // unique name to the layer.
  int next_constant_layer_id_ = 0;

  // The name of the TRTEngineOp node.
  absl::string_view engine_name_;

  // Indicates whether to use explicit precision in TensorRT (Q/DQ support).
  bool use_explicit_precision_;

  friend class ConverterTest;
  friend class OpConverterTest;
};

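// A minimal sketch of the intended build flow (illustrative; 'logger', 'dims',
// 'segment_nodes', 'outputs', and the engine-building arguments are
// hypothetical values supplied by the caller):
//
//   StatusOr<std::unique_ptr<Converter>> converter =
//       Converter::Create(TrtPrecisionMode::FP32, /*use_calibration=*/false,
//                         logger, /*use_implicit_batch=*/true, "my_engine");
//   TF_RETURN_IF_ERROR(converter.status());
//   TF_RETURN_IF_ERROR((*converter)->AddInputTensor(
//       "input_0", nvinfer1::DataType::kFLOAT, dims, /*batch_size=*/8));
//   for (const NodeDef& node_def : segment_nodes) {
//     TF_RETURN_IF_ERROR((*converter)->ConvertNode(node_def));
//   }
//   TF_RETURN_IF_ERROR((*converter)->RenameAndMarkOutputTensors(outputs));
//   TF_RETURN_IF_ERROR((*converter)->BuildCudaEngine(
//       &engine, max_batch_size, max_workspace_size_bytes, allocator,
//       calibrator, profiles));
//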
// Converts a TensorFlow tensor to TRT shaped weights.
Status TfTensorToTrtWeights(const Tensor& tensor, TrtWeightStore* weight_store,
                            TRT_ShapedWeights* weights);

// Converts 'input' of 'node_def' into 'tensor' with shape specified by 'dims'
// (which doesn't contain the batch dimension).
//
// If validation_only is true, it doesn't do the conversion but only does some
// minimal validation for the eligibility of the conversion, and *tensor will
// be set to nullptr.
// If validation_only is false, converter must not be nullptr.
Status PrepareTensorForShape(
    Converter* converter, const TRT_TensorOrWeights& input,
    const DimsAdapter& dims, const bool validation_only,
    ITensorProxyPtr* tensor, const NodeDef& node_def,
    std::optional<int> op_instance = std::nullopt,
    std::optional<std::string> origin_node_name = std::nullopt);

// Returns OK if the broadcast scheme is supported and computes the shapes
// after broadcasting. check_feasibility can be set to false in cases where
// dimensions do not need to match exactly (as in the case of BatchMatMulV2).
Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l,
                            const TRT_TensorOrWeights& operand_r,
                            const bool check_feasibility,
                            const bool use_implicit_batch,
                            nvinfer1::Dims* operand_l_new_dims,
                            nvinfer1::Dims* operand_r_new_dims);

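// Illustrative example (an assumption about the usual right-aligned broadcast
// padding, not stated in this header): in explicit batch mode, broadcasting
// operands of shapes [2, 3, 4] and [4] yields operand_l_new_dims = [2, 3, 4]
// and operand_r_new_dims = [1, 1, 4]; TRT's elementwise layers can then
// broadcast the size-1 dimensions at runtime.
//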
template <typename T>
using OperationMap = std::unordered_map<std::string, T>;

// Map from TensorFlow operation names to TensorRT unary operations.
using UnaryOperationMapType = OperationMap<nvinfer1::UnaryOperation>;
const UnaryOperationMapType* UnaryOperationMap();

// Map from TensorFlow boolean operation names to TensorRT unary operations.
const UnaryOperationMapType* UnaryBooleanOperationMap();

// Map of all supported ActivationTypes.
const OperationMap<nvinfer1::ActivationType>* ActivationTypeMap();

// Map from TensorFlow binary operation names to TensorRT binary operation
// types.
using BinaryOperationMapType = OperationMap<nvinfer1::ElementWiseOperation>;
const BinaryOperationMapType* BinaryOperationMap();

// Map from TensorFlow boolean binary operation names to TensorRT binary
// operation types.
const BinaryOperationMapType* BinaryBooleanOperationMap();

template <typename T>
absl::InlinedVector<std::string, 10> GetOperationNames(const T& set) {
  absl::InlinedVector<std::string, 10> result;
  absl::c_transform(set, std::back_inserter(result),
                    [](const auto x) { return x.first; });
  return result;
}

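// Example usage (illustrative): collect the TF op names that map to TRT unary
// operations.
//
//   absl::InlinedVector<std::string, 10> unary_names =
//       GetOperationNames(*UnaryOperationMap());
//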
// Adds a matrix multiplication operation to the TensorRT graph. The "params"
// pointer is only used to access the TRT network builder. The inputs and
// parameters for the op are fully specified by input_[a|b] and transpose_[a|b].
StatusOr<ITensorProxyPtr> ConvertMatMulImpl(OpConverterParams* params,
                                            TRT_TensorOrWeights input_a,
                                            TRT_TensorOrWeights input_b,
                                            bool transpose_a, bool transpose_b);

std::string convert_range_error_msg(float start, float limit, float delta);
std::string convert_range_expected_msg(const NodeDef& node_def);

inline bool find_name(const string& name, const std::vector<string> names) {
  return std::find(names.begin(), names.end(), name) != names.end();
}

}  // namespace convert
}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT

#endif  // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_NODES_H_