/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h"

#include <fstream>
#include <list>
#include <map>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/tf2tensorrt/common/utils.h"
#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
#include "tensorflow/compiler/tf2tensorrt/convert/logger_registry.h"
#include "tensorflow/compiler/tf2tensorrt/convert/ops/quantization_ops.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/segment/segment.h"
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h"
#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/devices.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/protobuf/config.pb.h"  // NOLINT
#include "tensorflow/core/protobuf/device_properties.pb.h"  // NOLINT
#include "tensorflow/core/protobuf/rewriter_config.pb.h"  // NOLINT
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"

#if GOOGLE_CUDA && GOOGLE_TENSORRT
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "third_party/tensorrt/NvInfer.h"
namespace tensorflow {
namespace tensorrt {
namespace convert {

using absl::StrAppend;
using absl::StrCat;
using ::tensorflow::tensorrt::segment::ClusterProperty;
using ::tensorflow::tensorrt::segment::NodePtrCompare;
using ::tensorflow::tensorrt::segment::Segment;

namespace {

Status BuildNodeMap(const Graph& graph,
                    std::unordered_map<string, Node*>* node_map) {
  for (auto* node : graph.op_nodes()) {
    if (!node_map->insert({node->name(), node}).second) {
      return errors::AlreadyExists("Node name is not unique in graph: " +
                                   node->name());
    }
  }
  return Status::OK();
}

EngineInfo::EngineType GetEngineType(
    const TRTOptimizationPass::ConversionParams& params) {
  return (params.is_dynamic_op || params.use_calibration)
             ? EngineInfo::EngineType::TRTDynamic
             : EngineInfo::EngineType::TRTStatic;
}

// Returns true when use_implicit_batch is false or when we are building a
// dynamic engine, to allow an unknown size for dimensions other than the
// batch dimension (dimension 0).
bool AllowDynamicNonBatchDimension(
    const TRTOptimizationPass::ConversionParams& params) {
  return !params.use_implicit_batch ||
         GetEngineType(params) == EngineInfo::EngineType::TRTDynamic;
}
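
// For reference, a plain-language summary of the two helpers above (derived
// from their bodies, not an independent spec):
//
//   is_dynamic_op || use_calibration  -> TRTDynamic engine
//   otherwise                         -> TRTStatic engine
//
//   explicit batch (use_implicit_batch == false) -> dynamic non-batch dims OK
//   implicit batch + TRTDynamic engine           -> dynamic non-batch dims OK
//   implicit batch + TRTStatic engine            -> non-batch dims must be known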

struct EdgePtrCompare {
  bool operator()(const Edge* lhs, const Edge* rhs) const {
    return lhs->id() < rhs->id();
  }
};

// TODO(laigd): instead of deciding the device here, the converter should
// accept a device name as one of the conversion parameters so users can
// control which device they want to run the conversion on.
std::pair<TfDeviceId, PlatformDeviceId> GetFirstValidDeviceId() {
  for (int tf_device_id_value = 0; tf_device_id_value < 100;
       ++tf_device_id_value) {
    TfDeviceId tf_device_id(tf_device_id_value);
    PlatformDeviceId platform_device_id;
    Status s =
        GpuIdManager::TfToPlatformDeviceId(tf_device_id, &platform_device_id);
    if (s.ok()) {
      VLOG(1) << "Found TF GPU " << tf_device_id.value() << " at cuda device "
              << platform_device_id.value();
      return std::make_pair(tf_device_id, platform_device_id);
    }
  }
  LOG(ERROR) << "Could not find any TF GPUs";
  return std::make_pair(TfDeviceId(-1), PlatformDeviceId(-1));
}

// Returns false for const nodes (we intend to drop control edges from those).
bool ShallKeepControlEdgeFrom(const Node* input_node) {
  if (!input_node) {
    LOG(ERROR) << "Node pointer is null, this should not happen";
    return false;
  }
  return input_node->type_string() != "Const";
}

// Populates the EngineInfo structure for the given segment.
Status GetEngineInfo(const Graph* g,
                     const grappler::GraphProperties& graph_properties,
                     const Segment& segment,
                     const std::vector<Node*>& reverse_topo_order,
                     EngineInfo* info) {
  std::vector<const Node*> subgraph_nodes;  // Topologically sorted nodes.
  std::set<const Node*> added_const_nodes;  // Used to prevent double insertion.

  const ClusterProperty& segment_property = segment.property;
  const std::set<const Node*, NodePtrCompare>& segment_nodes = segment.nodes;

  // The device assignment accumulated from the compatible device assignments
  // for the nodes in the segment.
  const DeviceNameUtils::ParsedName segment_device =
      segment_property.DeviceName();
  info->max_batch_size = segment_property.BatchSize().GetOptionalMaxBatchSize();

  // Map from src_node_name+port to the unique port numbers of the TRT op,
  // where src_node_name is the name of the source node of the input/output
  // edge. There must not be any duplicates, since the source nodes of
  // input/output edges must be in a different split of the graph.
  // TODO(aaroey): consider using node id and port instead.
  // TODO(aaroey): use topological order instead of reversing the reverse
  // topological order.
  std::unordered_map<string, int> input_to_engine_port, output_to_engine_port;
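  // Illustrative example (hypothetical node names): if two data edges A:0->n1
  // and A:0->n2 cross into the segment, both share the key "A:0" and thus the
  // same engine input port, while a different source B:1 gets the next free
  // port:
  //   input_to_engine_port = { "A:0" -> 0, "B:1" -> 1 }
  // Output ports are assigned the same way, keyed by the inside node/port.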
  for (auto it = reverse_topo_order.rbegin(); it != reverse_topo_order.rend();
       ++it) {
    const Node* node = *it;
    if (segment_nodes.count(node) == 0) continue;
    subgraph_nodes.push_back(node);

    const int node_id = node->id();
    const string& node_name = node->name();

    // Create input connections. Sort the edges first to make the order
    // deterministic, since in_edges is a set of pointers.
    std::vector<const Edge*> in_edges(node->in_edges().begin(),
                                      node->in_edges().end());
    std::sort(in_edges.begin(), in_edges.end(), EdgePtrCompare());
    for (const auto edge : in_edges) {
      auto input_node = edge->src();
      if (input_node->IsSource() || segment_nodes.count(input_node)) {
        continue;
      }
      if (edge->IsControlEdge()) {
        if (ShallKeepControlEdgeFrom(input_node)) {
          // Non-Const control input.
          info->connections.emplace_back(input_node->name(), input_node->id(),
                                         node_name, node_id,
                                         /*input_edge=*/true);
        }
      } else if (input_node->type_string() == "Const") {
        // Add constant data input nodes into the segment graphdef (thus also
        // in the engine). We don't care if it has other output edges going
        // into other engines or TF nodes. Since we add it only to the segment
        // graphdef, not the segment itself, it won't be removed from the
        // graph. If it doesn't have any edges, TF will prune it out.
        //
        // Note that the segmenter already ensures that the constant data
        // input is valid and supported by the engine.
        if (!added_const_nodes.insert(input_node).second) {
          // Already added before.
          continue;
        }
        VLOG(1) << "Adding const node " << input_node->name();
      } else {
        // Non-const data input.
        int port = Graph::kControlSlot - 1;
        // Use the source non-segment node name/port as key.
        const string s = StrCat(input_node->name(), ":", edge->src_output());
        VLOG(1) << "Input edge = " << s;
        if (input_to_engine_port.count(s)) {
          port = input_to_engine_port.at(s);
        } else {
          port = input_to_engine_port.size();
          input_to_engine_port.insert({s, port});
        }
        info->connections.emplace_back(
            input_node->name(), input_node->id(), edge->src_output(), node_name,
            node_id, edge->dst_input(), /*input_edge=*/true, port);
      }
    }
    // Create output connections. Sort the edges first to make the order
    // deterministic, since out_edges is a set of pointers.
    std::vector<const Edge*> out_edges(node->out_edges().begin(),
                                       node->out_edges().end());
    std::sort(out_edges.begin(), out_edges.end(), EdgePtrCompare());
    for (const auto edge : out_edges) {
      auto output_node = edge->dst();
      if (output_node->IsSink() || segment_nodes.count(output_node)) {
        continue;
      }
      if (edge->IsControlEdge()) {
        // Control output.
        if (ShallKeepControlEdgeFrom(node)) {
          info->connections.emplace_back(output_node->name(), output_node->id(),
                                         node_name, node_id,
                                         /*input_edge=*/false);
        }
      } else {
        // Data output.
        int port = Graph::kControlSlot - 1;
        // Use the source segment node name/port as key.
        const string s = StrCat(node_name, ":", edge->src_output());
        VLOG(1) << "Output edge = " << s;
        if (output_to_engine_port.count(s)) {
          port = output_to_engine_port.at(s);
        } else {
          port = output_to_engine_port.size();
          output_to_engine_port.insert({s, port});
        }
        info->connections.emplace_back(
            output_node->name(), output_node->id(), edge->dst_input(),
            node_name, node_id, edge->src_output(), /*input_edge=*/false, port);
      }
    }
  }  // For each segment node in topological order.

  // Construct the const nodes first.
  subgraph_nodes.insert(subgraph_nodes.begin(), added_const_nodes.begin(),
                        added_const_nodes.end());
  TF_RETURN_IF_ERROR(
      ConvertSegmentToGraphDef(g, graph_properties, subgraph_nodes, info));
  VLOG(1) << "Converted TensorRT candidate segment '" << info->engine_name
          << "' to a GraphDef";
  if (segment_device.has_type) {
    // If the accumulated device assignment for the segment has a device type,
    // the segmenter guarantees the device type is GPU. Use the device
    // assignment in this case.
    if (segment_device.type != "GPU") {
      return errors::Internal(
          "segment device is not GPU: ",
          DeviceNameUtils::ParsedNameToString(segment_device));
    }
    info->device = DeviceNameUtils::ParsedNameToString(segment_device);
  } else {
    TfDeviceId tf_device_id;
    PlatformDeviceId platform_device_id;
    std::tie(tf_device_id, platform_device_id) = GetFirstValidDeviceId();
    if (tf_device_id.value() >= 0) {
      DeviceNameUtils::ParsedName parsed_name;
      parsed_name.type = "GPU";
      parsed_name.has_type = true;
      parsed_name.id = tf_device_id.value();
      parsed_name.has_id = true;
      info->device = DeviceNameUtils::ParsedNameToString(parsed_name);
    } else {
      VLOG(1) << "No device is assigned to the segment. A device will be "
                 "assigned during graph execution (inference).";
    }
  }
  return Status::OK();
}

// Helper function to update an edge connection from a removed node to the
// engine node. If an outside node is gone, it must have been absorbed into
// an engine node; find that engine node.
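// For example (hypothetical names and ports): if engine node
// TRTEngineOp_000_000 absorbed node "conv1", and engine 0 recorded the output
// connection {inside_node_name: "conv1", inside_port: 0, port_number: 2},
// then a caller looking up "conv1":0 is redirected to TRTEngineOp_000_000:2.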
void UpdateToEngineNode(const std::vector<EngineInfo>& infos,
                        const size_t my_engine_id,
                        const std::vector<Node*>& engine_nodes,
                        const bool is_input_edge, const string& node_name,
                        Node** node, int* port) {
  for (size_t t = 0; t < infos.size(); ++t) {
    if (t == my_engine_id) {
      continue;
    }
    const auto& info = infos.at(t);
    for (const auto& eng_conn : info.connections) {
      // If the connection being updated is an input connection, the source of
      // the connection must be an output connection of another engine, and
      // vice versa.
      if (is_input_edge == eng_conn.is_input_edge) continue;
      if (eng_conn.inside_node_name == node_name &&
          eng_conn.inside_port == *port) {
        *node = CHECK_NOTNULL(engine_nodes[t]);
        QCHECK_EQ(info.engine_name, (**node).name())
            << "Engine name mismatch: " << info.engine_name << " vs "
            << (**node).name();
        *port = eng_conn.port_number;
        return;
      }
    }
  }
  LOG(FATAL) << "Node " << node_name << " not found in any engine.";
}

// Resizes the shape vectors so they can hold index `port_number`, and returns
// `conn_shape` serialized as a TensorShapeProto.
tensorflow::TensorShapeProto ComputeTRTNodeIOShape(
    std::vector<PartialTensorShape>& partial_tensorshape_vect,
    std::vector<tensorflow::TensorShapeProto>& shape_proto_vect,
    const PartialTensorShape& conn_shape, int port_number) {
  tensorflow::TensorShapeProto tmp_shape_proto;
  conn_shape.AsProto(&tmp_shape_proto);

  if (partial_tensorshape_vect.size() <= port_number) {
    shape_proto_vect.resize(port_number + 1);
    partial_tensorshape_vect.resize(port_number + 1);
  }

  return tmp_shape_proto;
}

// Function to insert a TRT engine node into the graph.
// Create engine nodes in the following way:
// 1. Each invocation of CreateTRTNode creates an engine node for infos[pos].
// 2. When an engine node is created, add it into the graph with the necessary
//    re-wiring.
//    2.1. If the outside connected node still exists, connect the engine
//         node to it.
//    2.2. If the outside connected node is gone, it must have been absorbed
//         into another engine node (which was processed before the current
//         one). Connect to that pre-existing engine node instead.
// 3. In this way, we ensure the graph is topologically sortable after each
//    invocation of CreateTRTNode().
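//
// A sketch of the rewiring for a chain A -> B -> C -> D with two single-node
// segments S1 = {B} and S2 = {C} (illustrative names; ConvertGraph() below
// removes the replaced nodes after each successful call):
//   after CreateTRTNode(..., pos=0, ...): A -> engine_0 -> C -> D
//   after CreateTRTNode(..., pos=1, ...): A -> engine_0 -> engine_1 -> D
// When engine_1 is created, its input B is already gone (absorbed into
// engine_0), so UpdateToEngineNode() locates engine_0 to connect to
// (case 2.2 above).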
Status CreateTRTNode(const TRTOptimizationPass::ConversionParams& params,
                     const std::vector<EngineInfo>& infos, int pos,
                     int default_max_batch_size, Graph* graph,
                     std::vector<Node*>* engine_nodes,
                     grappler::Cluster* cluster) {
  const auto& info = infos.at(pos);
  std::vector<tensorflow::TensorShapeProto> input_shape_protos;
  std::vector<tensorflow::TensorShapeProto> output_shape_protos;
  std::vector<PartialTensorShape> input_shapes;
  std::vector<PartialTensorShape> output_shapes;
  std::vector<NodeDefBuilder::NodeOut> inputs;
  std::vector<Node*> input_nodes;
  std::vector<Node*> control_input_nodes;
  std::unordered_set<string> control_input_names;
  std::vector<DataType> out_types;

  VLOG(1) << "Processing " << info.engine_name;
  // Collect the info needed for creating the engine node in the graph.
  for (const auto& conn : info.connections) {
    // Control edges.
    if (conn.is_control_edge()) {
      // Skip control outputs for now: control output info is not needed for
      // node creation and will be processed later.
      if (!conn.is_input_edge) continue;

      // Rewire the control input if it's not found in the original graph.
      Node* input_node = graph->FindNodeId(conn.outside_id);
      int port = Graph::kControlSlot;
      if (!input_node) {
        UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/true,
                           conn.outside_node_name, &input_node, &port);
        QCHECK_EQ(Graph::kControlSlot, port);
      }
      if (!control_input_names.insert(input_node->name()).second) {
        continue;
      }
      control_input_nodes.push_back(input_node);
      VLOG(1) << "Engine Control Input " << input_node->name() << " -> "
              << info.engine_name;
    } else {
      // Data edges.
      if (!conn.is_input_edge) {
        // Set the shapes and data types of the output edge.
        tensorflow::TensorShapeProto out_shape = ComputeTRTNodeIOShape(
            /*partial_tensorshape_vect=*/output_shapes,
            /*shape_proto_vect=*/output_shape_protos,
            /*conn_shape=*/conn.inside_shape,
            /*port_number=*/conn.port_number);

        output_shape_protos.at(conn.port_number) = out_shape;
        output_shapes.at(conn.port_number) = conn.inside_shape;

        if (out_types.size() <= conn.port_number) {
          out_types.resize(conn.port_number + 1);
        }
        out_types.at(conn.port_number) = conn.connection_type;
        VLOG(2) << "Collected output shape "
                << output_shape_protos.at(conn.port_number).DebugString();
      } else {
        // Set the shapes of the input edge.
        tensorflow::TensorShapeProto in_shape = ComputeTRTNodeIOShape(
            /*partial_tensorshape_vect=*/input_shapes,
            /*shape_proto_vect=*/input_shape_protos,
            /*conn_shape=*/conn.outside_shape,
            /*port_number=*/conn.port_number);

        input_shape_protos.at(conn.port_number) = in_shape;
        input_shapes.at(conn.port_number) = conn.outside_shape;

        // The shape must be fully defined (excluding the batch dimension) for
        // static mode.
        if (params.use_implicit_batch &&
            info.engine_type == EngineInfo::EngineType::TRTStatic) {
          for (int i = 1; i < conn.outside_shape.dims(); i++) {
            if (conn.outside_shape.dim_size(i) <= 0) {
              return errors::Internal(
                  "Not fully defined input shape when in static mode which "
                  "should have been excluded by the segmenter. ");
            }
          }
        }

        // Rewire the data input if it's not found in the original graph.
        Node* input_node = graph->FindNodeId(conn.outside_id);
        int port = conn.outside_port;
        if (!input_node) {
          UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/true,
                             conn.outside_node_name, &input_node, &port);
        }
        if (std::find_if(
                std::begin(inputs), std::end(inputs),
                [input_node, &port](const NodeDefBuilder::NodeOut& inp) {
                  return inp.node == input_node->name() && inp.index == port;
                }) == std::end(inputs)) {
          inputs.emplace_back(input_node->name(), port, conn.connection_type);
          input_nodes.push_back(CHECK_NOTNULL(input_node));
          VLOG(1) << "Engine Input " << input_node->name() << ":" << port
                  << " -> " << info.engine_name << ":" << inputs.size() - 1;
        }
      }
    }
  }
  // We don't support segments with no inputs. Fall back to native TF here to
  // avoid a crash later. Constant folding should've folded the ops that make
  // up these segments.
  if (inputs.empty()) {
    return errors::Internal(
        "Segment has no inputs (possible constfold failure)");
  }

  // Build the engine and get its serialized representation.
  string segment_string;

  int max_batch_size = info.max_batch_size.has_value()
                           ? info.max_batch_size.value()
                           : default_max_batch_size;

  if (info.engine_type == EngineInfo::EngineType::TRTStatic) {
    TF_RETURN_IF_ERROR(CreateStaticEngine(params, info, max_batch_size,
                                          input_shapes, nullptr,
                                          &segment_string, cluster));
  }

  string prec_string;
  TF_RETURN_IF_ERROR(TrtPrecisionModeToName(info.precision_mode, &prec_string));
  NodeDefBuilder node_builder(info.engine_name, "TRTEngineOp");
  if (!info.device.empty()) node_builder.Device(info.device);
  if (VLOG_IS_ON(1)) {
    string ins = StrCat(info.engine_name, " inputs= ");
    for (const auto& ii : inputs) {
      StrAppend(&ins, ii.node, ":", ii.index, " ");
    }
    VLOG(1) << ins;
  }
  node_builder.Input(inputs);
  for (const string& c : control_input_names) {
    node_builder.ControlInput(c);
  }

  NodeDef trt_node;
  NameAttrList function;
  function.set_name(StrCat(info.engine_name, "_native_segment"));

  node_builder.Attr("input_shapes", input_shape_protos)
      .Attr("output_shapes", output_shape_protos)
      .Attr("static_engine",
            info.engine_type == EngineInfo::EngineType::TRTStatic)
      .Attr("segment_func", function)
      .Attr("serialized_segment", segment_string)
      .Attr("calibration_data", "")
      .Attr("max_cached_engines_count", info.maximum_cached_engines)
      .Attr("workspace_size_bytes", info.max_workspace_size_bytes)
      .Attr("max_batch_size", max_batch_size)
      .Attr("precision_mode", prec_string)
      .Attr("use_calibration", info.use_calibration)
      .Attr("_use_implicit_batch", params.use_implicit_batch)
      .Attr("use_explicit_precision", params.use_explicit_precision)
      .Attr("_allow_build_at_runtime", info.allow_build_at_runtime)
      .Attr("OutT", out_types);

  if (!params.use_implicit_batch) {
    node_builder.Attr("profile_strategy",
                      ProfileStrategyToName(params.profile_strategy));
  }

  Status status = node_builder.Finalize(&trt_node);

  if (!status.ok()) {
    LOG(ERROR) << "Node construction failed with " << status;
    return status;
  }
  VLOG(1) << "Adding TRTEngine " << info.engine_name << " to graph";

  // Up until this point, the graph is not modified. If we return !status.ok()
  // from here, this segment will be skipped.
  // TODO(aaroey): let it return a proper error status for the following logic
  // instead of checking fail.
  TF_ASSIGN_OR_RETURN(Node * engine_node, graph->AddNode(trt_node));
  (*engine_nodes)[pos] = engine_node;
  // Add control input and input edges to the engine node.
  for (const auto in : control_input_nodes) {
    VLOG(1) << "Connecting control edge from " << in->name() << " to "
            << engine_node->name();
    graph->AddControlEdge(in, engine_node);
  }
  VLOG(1) << "input_nodes size = " << input_nodes.size();
  for (int i = 0; i < input_nodes.size(); ++i) {
    Node* n = CHECK_NOTNULL(input_nodes[i]);
    const auto& in = inputs[i];
    VLOG(1) << "Connecting data edge from " << n->name() << ":" << in.index
            << " to " << engine_node->name() << ":" << i;
    graph->AddEdge(n, in.index, engine_node, i);
  }

  // Update the inputs of the destination nodes of the output edges to point
  // them to the engine node.
  for (auto& conn : info.connections) {
    if (conn.is_input_edge) {
      continue;
    }
    Node* output_node = graph->FindNodeId(conn.outside_id);
    int port = conn.outside_port;
    if (!output_node) {
      UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/false,
                         conn.outside_node_name, &output_node, &port);
    }
    if (conn.is_control_edge()) {
      VLOG(1) << "Updating control edge from " << engine_node->name() << " to "
              << output_node->name();
      QCHECK_EQ(Graph::kControlSlot, port);
      graph->AddControlEdge(engine_node, output_node);
    } else {
      VLOG(1) << "Updating data edge from " << engine_node->name() << ":"
              << conn.port_number << " to " << output_node->name() << ":"
              << port;
      // Use UpdateEdge() to avoid adding the same edge multiple times.
      TF_CHECK_OK(
          graph->UpdateEdge(engine_node, conn.port_number, output_node, port));
    }
  }
  return Status::OK();
}

int64 GetNextGraphSequenceNumber() {
  static std::atomic<int64_t> graph_sequence_num;
  return graph_sequence_num++;
}
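
// The sequence number above feeds the engine name prefix built in
// ConvertGraph() below: e.g. the third segment of the first converted graph
// gets the name "TRTEngineOp_000_002" (assuming no graph was converted
// earlier in the process).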

constexpr char kCastInputTypeAttrName[] = "SrcT";

// Transforms node = cast(x, fp32) where datatype(x) != fp16 to:
//   castToFp16 = cast(x, fp16)
//   node = cast(castToFp16, fp32)
//
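// For illustration (hypothetical node name, and assuming the
// kCastOutputTypeAttrName constant used below resolves to "DstT"), a matching
// node
//   c = Cast(x), SrcT=DT_INT32, DstT=DT_FLOAT
// is rewritten into the pair
//   c_split = Cast(x), SrcT=DT_INT32, DstT=DT_HALF
//   c = Cast(c_split:0), SrcT=DT_HALF, DstT=DT_FLOAT
//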
Status MaybeRewriteCastToFp32(GraphDef* graph_def, NodeDef* node_def) {
  if (node_def->op() != "Cast") {
    return Status::OK();
  }

  DataTypeVector input_types;
  DataTypeVector output_types;
  TF_RETURN_IF_ERROR(
      graph_transforms::GetInOutTypes(*node_def, &input_types, &output_types));

  if (input_types.size() != 1 || output_types.size() != 1) {
    return errors::Internal("Bad cast operation");
  }

  if (input_types[0] == DT_HALF || output_types[0] != DT_FLOAT) {
    return Status::OK();
  }

  VLOG(2) << "Rewriting cast to FP32 " << node_def->DebugString();

  NodeDef* castToFp16 = graph_def->add_node();
  for (const auto& attr_value : node_def->attr()) {
    (*castToFp16->mutable_attr())[attr_value.first] = attr_value.second;
  }
  castToFp16->set_name(node_def->name() + "_split");
  castToFp16->set_op("Cast");
  castToFp16->set_device(node_def->device());
  castToFp16->add_input(node_def->input(0));
  (*castToFp16->mutable_attr())[kCastOutputTypeAttrName].set_type(DT_HALF);

  node_def->set_input(0, castToFp16->name() + ":0");
  (*node_def->mutable_attr())[kCastInputTypeAttrName].set_type(DT_HALF);

  VLOG(2) << castToFp16->DebugString();
  VLOG(2) << node_def->DebugString();

  return Status::OK();
}

}  // namespace

Status RegisterGraphToFunctionLibrary(const GraphDef& segment_graph_def,
                                      Graph* graph, const string& engine_name) {
  Graph segment_graph(graph->flib_def());
  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(),
                                            segment_graph_def, &segment_graph));
  FunctionDefLibrary library;
  auto segment_func = library.add_function();
  TF_RETURN_IF_ERROR(GraphToFunctionDef(
      segment_graph, StrCat(engine_name, "_native_segment"), segment_func));
  if (VLOG_IS_ON(7)) {
    VLOG(7) << engine_name << " Function_Def ";
    VLOG(7) << segment_func->DebugString();
  }
  VLOG(1) << "Adding funcdef " << segment_func->signature().name()
          << " to graphlib";
  TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(library));
  return Status::OK();
}

std::pair<int, Allocator*> GetDeviceAndAllocator(
    const grappler::Cluster* cluster, const EngineInfo& engine) {
  int cuda_device_id = -1;
  Allocator* dev_allocator = nullptr;
  if (cluster == nullptr || cluster->GetDeviceSet() == nullptr ||
      engine.device.empty()) {
    // If the device is not set, use the first found GPU device for the
    // conversion.
    TfDeviceId tf_device_id;
    PlatformDeviceId platform_device_id;
    std::tie(tf_device_id, platform_device_id) = GetFirstValidDeviceId();
    cuda_device_id = platform_device_id.value();
    if (cuda_device_id >= 0) {
      GPUOptions gpu_options;
      // If the TF to Cuda gpu id mapping exists, the device and corresponding
      // allocator must have been initialized already, so the
      // GetGPUAllocator() call won't create a new allocator.
      dev_allocator = GPUProcessState::singleton()->GetGPUAllocator(
          gpu_options, tf_device_id, /*total_bytes=*/1, /*peer_gpu_ids=*/{});
    }
    return std::make_pair(cuda_device_id, dev_allocator);
  }

  // Use the device requested by the engine.
  auto device_set = cluster->GetDeviceSet();
  std::vector<Device*> devices;
  DeviceNameUtils::ParsedName parsed_name;
  if (DeviceNameUtils::ParseFullName(engine.device, &parsed_name) &&
      parsed_name.has_id) {
    device_set->FindMatchingDevices(parsed_name, &devices);
  }
  if (!devices.empty()) {
    if (devices.size() > 1) {
      string msg = "Found multiple matching devices using name '";
      StrAppend(&msg, engine.device, "': ");
      for (auto d : devices) StrAppend(&msg, d->name(), ", ");
      StrAppend(&msg, ". Will get the allocator from the first one.");
      LOG_WARNING_WITH_PREFIX << msg;
    }
    AllocatorAttributes alloc_attr;
    cuda_device_id = devices[0]->tensorflow_accelerator_device_info()->gpu_id;
    dev_allocator = devices[0]->GetAllocator(alloc_attr);
    VLOG(1) << "Using allocator " << dev_allocator->Name()
            << " and cuda_device_id " << cuda_device_id;
  } else {
    LOG_WARNING_WITH_PREFIX << "Cluster is set but device '" << engine.device
                            << "' is not found in the cluster";
  }
  return std::make_pair(cuda_device_id, dev_allocator);
}

Status CreateStaticEngine(const TRTOptimizationPass::ConversionParams& params,
                          const EngineInfo& info, int max_batch_size,
                          const std::vector<PartialTensorShape>& input_shapes,
                          TrtShapeOptimizationProfile* profile,
                          string* segment_string, grappler::Cluster* cluster) {
  std::pair<int, Allocator*> device_allocator =
      GetDeviceAndAllocator(cluster, info);
  int cuda_device_id = 0;
  std::unique_ptr<TRTBaseAllocator> trt_allocator;
  if (device_allocator.first >= 0) {
    cuda_device_id = device_allocator.first;
    trt_allocator.reset(new TRTDeviceAllocator(device_allocator.second));
  } else {
    // trt_allocator remains nullptr, so cudaMalloc will be used.
    LOG_WARNING_WITH_PREFIX << "Can't identify the cuda device. Running on "
                               "device 0 and using cudaMalloc as the allocator";
  }
  cudaSetDevice(cuda_device_id);

  auto trt_logger = GetLoggerRegistry()->LookUp(params.trt_logger_name);
  const bool calibrate_int8 =
      (info.precision_mode == TrtPrecisionMode::INT8 && info.use_calibration);

  // Create static engines with precision_mode fp32/fp16.
  TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
  TF_RETURN_IF_ERROR(ConvertGraphDefToEngine(
      info.segment_graph_def, nullptr,
      calibrate_int8 ? TrtPrecisionMode::FP32 : info.precision_mode,
      max_batch_size, info.max_workspace_size_bytes, input_shapes, trt_logger,
      trt_allocator.get(), /*calibrator=*/nullptr, &engine,
      info.use_calibration, params.use_implicit_batch,
      /*convert_successfully=*/nullptr, profile, info.engine_name,
      /*use_explicit_precision=*/params.use_explicit_precision, cluster));
  TrtUniquePtrType<nvinfer1::IHostMemory> engine_data(engine->serialize());
  *segment_string = string(static_cast<const char*>(engine_data->data()),
                           engine_data->size());
  return Status::OK();
}

Status ConvertGraph(const TRTOptimizationPass::ConversionParams& params,
                    grappler::GrapplerItem& grappler_item,
                    const std::vector<string>& input_output_names,
                    grappler::Cluster* cluster, GraphDef* output) {
  // Sanity checks.
  TRT_ENSURE(output != nullptr)
  if (params.precision_mode != TrtPrecisionMode::INT8 &&
      params.use_calibration) {
    return errors::InvalidArgument(
        "Calibration with FP32 or FP16 is not supported.");
  }

  GraphDef& graph_def = grappler_item.graph;

  // When precision_mode is FP16, transform cast(x, fp32) to
  // cast(cast(x, fp16), fp32). This creates a cast(fp16, fp32) that can be
  // included in the TRTEngineOp as a TensorRT Identity layer for performance:
  //  . Avoids cast(fp32, fp16) in the TRT engine implementation for fp16
  //    precision.
  //  . Changing the input to the TRTEngineOp from fp32 to fp16 may reduce data
  //    movement from the host to the GPU.
  if (params.precision_mode == TrtPrecisionMode::FP16) {
    for (int i = 0; i < graph_def.node_size(); i++) {
      NodeDef* node_def = graph_def.mutable_node(i);
      TF_RETURN_IF_ERROR(MaybeRewriteCastToFp32(&graph_def, node_def));
    }
  }

  // Construct a GrapplerItem using the modified graph_def and the input
  // grappler_item.
  grappler::GraphProperties static_graph_properties(grappler_item);
  TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(true));

  // Convert graphdef to graph.
  FunctionLibraryDefinition flib(OpRegistry::Global(), graph_def.library());
  Graph graph(flib);
  TF_RETURN_IF_ERROR(
      ConvertGraphDefToGraph(GraphConstructorOptions(), graph_def, &graph));

  // Segment the graph into subgraphs that can be converted to TensorRT
  segment::SegmentOptions segment_options;
  // TODO(ben,jie,sami): exclude output nodes (DISCUSS IT)
  for (const auto& node : input_output_names) {
    segment_options.exclude_node_list.insert(node);
  }
  segment_options.minimum_segment_size = params.minimum_segment_size;
  segment_options.use_implicit_batch = params.use_implicit_batch;
  if (segment_options.use_implicit_batch)
    segment_options.maximum_batch_size = params.max_batch_size;
  segment_options.allow_dynamic_non_batch_dim =
      AllowDynamicNonBatchDimension(params);

  segment::SegmentVector initial_segments;
  TrtNodeValidator validator(static_graph_properties, params.precision_mode,
                             params.use_calibration, params.use_implicit_batch,
                             params.use_explicit_precision);
  TF_RETURN_IF_ERROR(segment::SegmentGraph(
      &graph, &static_graph_properties,
      std::bind(&TrtNodeValidator::IsTensorRTCandidate, &validator,
                std::placeholders::_1),
      // Input validation is already done by TrtNodeValidator, so we don't
      // need to check the input edges.
      [](const Edge* edge) { return true; }, OutputEdgeValidator(),
      segment_options, &initial_segments));
  LOG(INFO) << "Number of TensorRT candidate segments: "
            << initial_segments.size();

  // Get the EngineInfo for each segment.
  std::unordered_map<string, Node*> node_map;
  TF_RETURN_IF_ERROR(BuildNodeMap(graph, &node_map));
  std::vector<EngineInfo> engine_segments;
  engine_segments.reserve(initial_segments.size());
  std::vector<Node*> reverse_topo_order;
  GetPostOrder(graph, &reverse_topo_order);
  segment::SegmentVector converted_segments;
  converted_segments.reserve(initial_segments.size());
  string engine_name_prefix =
      StrCat("TRTEngineOp_",
             absl::StrFormat("%0*d", 3, GetNextGraphSequenceNumber()), "_");
  for (size_t t = 0; t < initial_segments.size(); t++) {
    auto& curr_segment = initial_segments.at(t);
    EngineInfo curr_engine;
    curr_engine.engine_name =
        StrCat(engine_name_prefix, absl::StrFormat("%0*d", 3, t));

    bool int8_no_calib = (!params.use_calibration &&
                          params.precision_mode == TrtPrecisionMode::INT8);
    bool has_qdq = false;
    if (int8_no_calib) {
      has_qdq = absl::c_any_of(reverse_topo_order, IsQuantizeAndDequantizeOp);
    }

    Status status = GetEngineInfo(&graph, static_graph_properties, curr_segment,
                                  reverse_topo_order, &curr_engine);
    if (!status.ok()) {
      LOG_WARNING_WITH_PREFIX << "Failed to get engine info for segment " << t
                              << ": " << status;
      continue;
    }

    curr_engine.engine_type = GetEngineType(params);
    curr_engine.use_calibration = params.use_calibration;
    // Building cuda engines for INT8 without calibration and without dynamic
    // range info causes TRT failures. Avoid this situation by setting the
    // precision to FP16.
    if (int8_no_calib && !has_qdq) {
      LOG(WARNING) << "Set engine precision to FP16 due to missing QDQ OP";
      curr_engine.precision_mode = TrtPrecisionMode::FP16;
    } else {
      curr_engine.precision_mode = params.precision_mode;
    }
    curr_engine.maximum_cached_engines = params.max_cached_engines;
    curr_engine.allow_build_at_runtime = params.allow_build_at_runtime;
    if (!curr_engine.max_batch_size.has_value()) {
      curr_engine.max_batch_size = params.max_batch_size;
    }

    status = RegisterGraphToFunctionLibrary(curr_engine.segment_graph_def,
                                            &graph, curr_engine.engine_name);

    if (!status.ok()) {
      LOG_WARNING_WITH_PREFIX
          << "Failed to register segment graphdef to the library " << t << ": "
          << status;
      continue;
    }

    engine_segments.push_back(std::move(curr_engine));
    converted_segments.push_back(std::move(curr_segment));

    if (VLOG_IS_ON(8)) {
      string fname = engine_segments.back().engine_name;
      StrAppend(&fname, ".pb");
      std::fstream f;
      f.open(fname.c_str(), std::fstream::out | std::fstream::binary);
      f << engine_segments.at(t).segment_graph_def.SerializeAsString();
      f.close();
    }
  }

  // Save the cuda device since we may need to switch to another cuda device
  // to build static engines.
  std::optional<int> old_cuda_device = std::nullopt;
  if (!params.is_dynamic_op) {
    int cuda_device_id;
    cudaError_t cuda_error = cudaGetDevice(&cuda_device_id);
    if (cuda_error != cudaSuccess) {
      LOG_WARNING_WITH_PREFIX << "Couldn't get current device: "
                              << cudaGetErrorString(cuda_error);
    } else {
      VLOG(1) << "Current cuda device is " << cuda_device_id;
      old_cuda_device = cuda_device_id;
    }
  }

  auto restore_cuda_device = gtl::MakeCleanup([old_cuda_device] {
    if (old_cuda_device.has_value()) {
      cudaSetDevice(old_cuda_device.value());
    }
  });

  std::vector<Node*> engine_nodes;
  engine_nodes.resize(engine_segments.size());
  for (int i = 0; i < engine_segments.size(); ++i) {
    auto& engine = engine_segments.at(i);
    // TODO(b/170762693): implement the heuristic to calculate
    // max_workspace_size_bytes.
    engine.max_workspace_size_bytes = params.max_workspace_size_bytes;
    VLOG(1) << "Assigned " << engine.max_workspace_size_bytes << " bytes to "
            << engine.engine_name;
    auto status =
        CreateTRTNode(params, engine_segments, i, params.max_batch_size, &graph,
                      &engine_nodes, cluster);

    string msg = StrCat("segment ", i, " consisting of ",
                        converted_segments.at(i).nodes.size(), " nodes by ",
                        engine.engine_name);
    if (status.ok()) {
      LOG(INFO) << "Replaced " << msg << ".";
    } else {
      // The graph is not modified.
      LOG_WARNING_WITH_PREFIX << "Cannot replace " << msg
                              << " reason: " << status.error_message()
                              << " (keeping original segment).";
    }
    if (VLOG_IS_ON(1)) {
      msg = "Segment consists of nodes: ";
      for (const Node* node : converted_segments.at(i).nodes) {
        StrAppend(&msg, node->name(), ", ");
      }
      VLOG(1) << msg;
    }

    // If status is ok, we successfully added the node to the graph and can
    // remove the segment ops. Otherwise the graph is not modified.
    if (status.ok()) {
      for (const Node* node : converted_segments.at(i).nodes) {
        graph.RemoveNode(const_cast<Node*>(node));
      }
    }
  }
  graph.ToGraphDef(output);
  return Status::OK();
}

}  // namespace convert
}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT