/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h"

#include <fstream>
#include <list>
#include <map>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/tf2tensorrt/common/utils.h"
#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
#include "tensorflow/compiler/tf2tensorrt/convert/logger_registry.h"
#include "tensorflow/compiler/tf2tensorrt/convert/ops/quantization_ops.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/segment/segment.h"
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h"
#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/devices.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/protobuf/config.pb.h"             // NOLINT
#include "tensorflow/core/protobuf/device_properties.pb.h"  // NOLINT
#include "tensorflow/core/protobuf/rewriter_config.pb.h"    // NOLINT
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"

#if GOOGLE_CUDA && GOOGLE_TENSORRT
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "third_party/tensorrt/NvInfer.h"
namespace tensorflow {
namespace tensorrt {
namespace convert {

using absl::StrAppend;
using absl::StrCat;
using ::tensorflow::tensorrt::segment::ClusterProperty;
using ::tensorflow::tensorrt::segment::NodePtrCompare;
using ::tensorflow::tensorrt::segment::Segment;

namespace {

Status BuildNodeMap(const Graph& graph,
                    std::unordered_map<string, Node*>* node_map) {
  for (auto* node : graph.op_nodes()) {
    if (!node_map->insert({node->name(), node}).second) {
      return errors::AlreadyExists("Node name is not unique in graph: " +
                                   node->name());
    }
  }
  return Status::OK();
}

EngineInfo::EngineType GetEngineType(
    const TRTOptimizationPass::ConversionParams& params) {
  return (params.is_dynamic_op || params.use_calibration)
             ? EngineInfo::EngineType::TRTDynamic
             : EngineInfo::EngineType::TRTStatic;
}

// Returns true when use_implicit_batch is false or when we are building a
// dynamic engine, to allow unknown size for dimensions rather than dimension
// 0.
bool AllowDynamicNonBatchDimension(
    const TRTOptimizationPass::ConversionParams& params) {
  return !params.use_implicit_batch ||
         GetEngineType(params) == EngineInfo::EngineType::TRTDynamic;
}

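// Orders edges by id so that sorting edge-pointer vectors yields a
// deterministic order; the in/out edge sets hold pointers whose iteration
// order may vary from run to run.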
struct EdgePtrCompare {
  bool operator()(const Edge* lhs, const Edge* rhs) const {
    return lhs->id() < rhs->id();
  }
};

// TODO(laigd): instead of deciding the device here, the converter should
// accept a device name as one of the conversion parameters so users can
// control which device the conversion runs on.
std::pair<TfDeviceId, PlatformDeviceId> GetFirstValidDeviceId() {
  for (int tf_device_id_value = 0; tf_device_id_value < 100;
       ++tf_device_id_value) {
    TfDeviceId tf_device_id(tf_device_id_value);
    PlatformDeviceId platform_device_id;
    Status s =
        GpuIdManager::TfToPlatformDeviceId(tf_device_id, &platform_device_id);
    if (s.ok()) {
      VLOG(1) << "Found TF GPU " << tf_device_id.value() << " at cuda device "
              << platform_device_id.value();
      return std::make_pair(tf_device_id, platform_device_id);
    }
  }
  LOG(ERROR) << "Could not find any TF GPUs";
  return std::make_pair(TfDeviceId(-1), PlatformDeviceId(-1));
}

// Returns false for const nodes (we intend to drop control edges from those).
bool ShallKeepControlEdgeFrom(const Node* input_node) {
  if (!input_node) {
    LOG(ERROR) << "Node pointer is null, this should not happen";
    return false;
  }
  return input_node->type_string() != "Const";
}

// Function to get subsegment information structure.
Status GetEngineInfo(const Graph* g,
                     const grappler::GraphProperties& graph_properties,
                     const Segment& segment,
                     const std::vector<Node*>& reverse_topo_order,
                     EngineInfo* info) {
  std::vector<const Node*> subgraph_nodes;  // Topologically sorted nodes.
  std::set<const Node*> added_const_nodes;  // Used to prevent double insertion.

  const ClusterProperty& segment_property = segment.property;
  const std::set<const Node*, NodePtrCompare>& segment_nodes = segment.nodes;

  // The device assignment accumulated from the compatible device assignments
  // for the nodes in the segment.
  const DeviceNameUtils::ParsedName segment_device =
      segment_property.DeviceName();
  info->max_batch_size = segment_property.BatchSize().GetOptionalMaxBatchSize();

  // Map from src_node_name:port to the unique port number of the TRT op, where
  // src_node_name is the name of the source node of the input/output edge.
  // There cannot be any duplicates, since the source nodes of input/output
  // edges are in a different split of the graph.
  // TODO(aaroey): consider using node id and port instead.
  // TODO(aaroey): use topo order instead of reversing reverse topo order.
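  // For example (illustrative): if an outside node "conv/Relu" feeds two
  // segment nodes from its output 0, the key "conv/Relu:0" maps both
  // connections to one shared engine input port.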
  std::unordered_map<string, int> input_to_engine_port, output_to_engine_port;
  for (auto it = reverse_topo_order.rbegin(); it != reverse_topo_order.rend();
       ++it) {
    const Node* node = *it;
    if (segment_nodes.count(node) == 0) continue;
    subgraph_nodes.push_back(node);

    const int node_id = node->id();
    const string& node_name = node->name();

    // Create input connections. Sort the edges first to make the order
    // deterministic, since in_edges is a set of pointers.
    std::vector<const Edge*> in_edges(node->in_edges().begin(),
                                      node->in_edges().end());
    std::sort(in_edges.begin(), in_edges.end(), EdgePtrCompare());
    for (const auto edge : in_edges) {
      auto input_node = edge->src();
      if (input_node->IsSource() || segment_nodes.count(input_node)) {
        continue;
      }
      if (edge->IsControlEdge()) {
        if (ShallKeepControlEdgeFrom(input_node)) {
          // Non-Const control input.
          info->connections.emplace_back(input_node->name(), input_node->id(),
                                         node_name, node_id,
                                         /*input_edge=*/true);
        }
      } else if (input_node->type_string() == "Const") {
        // Add constant data input nodes into the segment graphdef (thus also
        // in the engine). We don't care if it has other output edges going
        // into other engines or TF nodes. Since we add it only to the segment
        // graphdef, not the segment itself, it won't be removed from the
        // graph. If it doesn't have any edges, TF will prune it out.
        //
        // Note that the segmenter already ensures that the constant data
        // input is valid and supported by the engine.
        if (!added_const_nodes.insert(input_node).second) {
          // Already added before.
          continue;
        }
        VLOG(1) << "Adding const node " << input_node->name();
      } else {
        // Non-const data input.
        int port = Graph::kControlSlot - 1;
        // Use the source non-segment node name/port as key.
        const string s = StrCat(input_node->name(), ":", edge->src_output());
        VLOG(1) << "Input edge = " << s;
        if (input_to_engine_port.count(s)) {
          port = input_to_engine_port.at(s);
        } else {
          port = input_to_engine_port.size();
          input_to_engine_port.insert({s, port});
        }
        info->connections.emplace_back(
            input_node->name(), input_node->id(), edge->src_output(),
            node_name, node_id, edge->dst_input(), /*input_edge=*/true, port);
      }
    }
    // Create output connections. Sort the edges first to make the order
    // deterministic, since out_edges is a set of pointers.
    std::vector<const Edge*> out_edges(node->out_edges().begin(),
                                       node->out_edges().end());
    std::sort(out_edges.begin(), out_edges.end(), EdgePtrCompare());
    for (const auto edge : out_edges) {
      auto output_node = edge->dst();
      if (output_node->IsSink() || segment_nodes.count(output_node)) {
        continue;
      }
      if (edge->IsControlEdge()) {
        // Control output.
        if (ShallKeepControlEdgeFrom(node)) {
          info->connections.emplace_back(output_node->name(),
                                         output_node->id(), node_name, node_id,
                                         /*input_edge=*/false);
        }
      } else {
        // Data output.
        int port = Graph::kControlSlot - 1;
        // Use the source segment node name/port as key.
        const string s = StrCat(node_name, ":", edge->src_output());
        VLOG(1) << "Output edge = " << s;
        if (output_to_engine_port.count(s)) {
          port = output_to_engine_port.at(s);
        } else {
          port = output_to_engine_port.size();
          output_to_engine_port.insert({s, port});
        }
        info->connections.emplace_back(
            output_node->name(), output_node->id(), edge->dst_input(),
            node_name, node_id, edge->src_output(), /*input_edge=*/false,
            port);
      }
    }
  }  // For each segment node in topological order.

  // Construct the const nodes first.
  subgraph_nodes.insert(subgraph_nodes.begin(), added_const_nodes.begin(),
                        added_const_nodes.end());
  TF_RETURN_IF_ERROR(
      ConvertSegmentToGraphDef(g, graph_properties, subgraph_nodes, info));
  VLOG(1) << "Converted TensorRT candidate segment '" << info->engine_name
          << "' to a GraphDef";
  if (segment_device.has_type) {
    // If the accumulated device assignment for the segment has a device type,
    // the segmenter guarantees the device type is GPU. Use the device
    // assignment in this case.
    if (segment_device.type != "GPU") {
      return errors::Internal(
          "segment device is not GPU: ",
          DeviceNameUtils::ParsedNameToString(segment_device));
    }
    info->device = DeviceNameUtils::ParsedNameToString(segment_device);
  } else {
    TfDeviceId tf_device_id;
    PlatformDeviceId platform_device_id;
    std::tie(tf_device_id, platform_device_id) = GetFirstValidDeviceId();
    if (tf_device_id.value() >= 0) {
      DeviceNameUtils::ParsedName parsed_name;
      parsed_name.type = "GPU";
      parsed_name.has_type = true;
      parsed_name.id = tf_device_id.value();
      parsed_name.has_id = true;
      info->device = DeviceNameUtils::ParsedNameToString(parsed_name);
    } else {
      VLOG(1) << "No device is assigned to the segment. A device will be "
                 "assigned during graph execution (inference).";
    }
  }
  return Status::OK();
}

// Helper function to update an edge connection from a removed node to the
// engine node. If an outside node is gone, it must have been absorbed into
// an engine node. Find that engine node.
void UpdateToEngineNode(const std::vector<EngineInfo>& infos,
                        const size_t my_engine_id,
                        const std::vector<Node*>& engine_nodes,
                        const bool is_input_edge, const string& node_name,
                        Node** node, int* port) {
  for (size_t t = 0; t < infos.size(); ++t) {
    if (t == my_engine_id) {
      continue;
    }
    const auto& info = infos.at(t);
    for (const auto& eng_conn : info.connections) {
      // If the connection being updated is an input connection, the source of
      // the connection must be an output connection of another engine. And
      // vice versa.
      if (is_input_edge == eng_conn.is_input_edge) continue;
      if (eng_conn.inside_node_name == node_name &&
          eng_conn.inside_port == *port) {
        *node = CHECK_NOTNULL(engine_nodes[t]);
        QCHECK_EQ(info.engine_name, (**node).name())
            << "Engine name mismatch: " << info.engine_name << " vs "
            << (**node).name();
        *port = eng_conn.port_number;
        return;
      }
    }
  }
  LOG(FATAL) << "Node " << node_name << " not found in any engine.";
}

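// Resizes the shape vectors so that index `port_number` is valid and returns
// `conn_shape` serialized as a TensorShapeProto; the caller fills in the
// entries at `port_number` afterwards.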
tensorflow::TensorShapeProto ComputeTRTNodeIOShape(
    std::vector<PartialTensorShape>& partial_tensorshape_vect,
    std::vector<tensorflow::TensorShapeProto>& shape_proto_vect,
    const PartialTensorShape& conn_shape, int port_number) {
  tensorflow::TensorShapeProto tmp_shape_proto;
  conn_shape.AsProto(&tmp_shape_proto);

  if (partial_tensorshape_vect.size() <= port_number) {
    shape_proto_vect.resize(port_number + 1);
    partial_tensorshape_vect.resize(port_number + 1);
  }

  return tmp_shape_proto;
}

// Function to insert a TRT engine node into the graph.
// Engine nodes are created in the following way:
// 1. Each invocation of CreateTRTNode creates an engine node for infos[pos].
// 2. When an engine node is created, it is added to the graph with the
//    necessary re-wiring:
//    2.1. If the outside connected node still exists, connect the engine
//         node to it.
//    2.2. If the outside connected node is gone, it must have been absorbed
//         into another engine node (one that was processed earlier). Connect
//         to that pre-existing engine node instead.
// 3. In this way, we ensure the graph is topologically sort-able after each
//    invocation of CreateTRTNode().
Status CreateTRTNode(const TRTOptimizationPass::ConversionParams& params,
                     const std::vector<EngineInfo>& infos, int pos,
                     int default_max_batch_size, Graph* graph,
                     std::vector<Node*>* engine_nodes,
                     grappler::Cluster* cluster) {
  const auto& info = infos.at(pos);
  std::vector<tensorflow::TensorShapeProto> input_shape_protos;
  std::vector<tensorflow::TensorShapeProto> output_shape_protos;
  std::vector<PartialTensorShape> input_shapes;
  std::vector<PartialTensorShape> output_shapes;
  std::vector<NodeDefBuilder::NodeOut> inputs;
  std::vector<Node*> input_nodes;
  std::vector<Node*> control_input_nodes;
  std::unordered_set<string> control_input_names;
  std::vector<DataType> out_types;

  VLOG(1) << "Processing " << info.engine_name;
  // Collect the info needed for creating the engine node in the graph.
  for (const auto& conn : info.connections) {
    // Control edges
    if (conn.is_control_edge()) {
      // Skip control outputs for now: control output info is not needed for
      // node creation and will be processed later.
      if (!conn.is_input_edge) continue;

      // Rewire the control input if it's not found in the original graph.
      Node* input_node = graph->FindNodeId(conn.outside_id);
      int port = Graph::kControlSlot;
      if (!input_node) {
        UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/true,
                           conn.outside_node_name, &input_node, &port);
        QCHECK_EQ(Graph::kControlSlot, port);
      }
      if (!control_input_names.insert(input_node->name()).second) {
        continue;
      }
      control_input_nodes.push_back(input_node);
      VLOG(1) << "Engine Control Input " << input_node->name() << " -> "
              << info.engine_name;
    } else {
      // Data edges
      if (!conn.is_input_edge) {
        // Set the shapes and data types of the output edge.
        tensorflow::TensorShapeProto out_shape = ComputeTRTNodeIOShape(
            /*partial_tensorshape_vect=*/output_shapes,
            /*shape_proto_vect=*/output_shape_protos,
            /*conn_shape=*/conn.inside_shape,
            /*port_number=*/conn.port_number);

        output_shape_protos.at(conn.port_number) = out_shape;
        output_shapes.at(conn.port_number) = conn.inside_shape;

        if (out_types.size() <= conn.port_number) {
          out_types.resize(conn.port_number + 1);
        }
        out_types.at(conn.port_number) = conn.connection_type;
        VLOG(2) << "Collected output shape "
                << output_shape_protos.at(conn.port_number).DebugString();
      } else {
        // Set the shapes of the input edge.
        tensorflow::TensorShapeProto in_shape = ComputeTRTNodeIOShape(
            /*partial_tensorshape_vect=*/input_shapes,
            /*shape_proto_vect=*/input_shape_protos,
            /*conn_shape=*/conn.outside_shape,
            /*port_number=*/conn.port_number);

        input_shape_protos.at(conn.port_number) = in_shape;
        input_shapes.at(conn.port_number) = conn.outside_shape;

        // The shape must be fully defined (excluding the batch dimension) in
        // static mode.
        if (params.use_implicit_batch &&
            info.engine_type == EngineInfo::EngineType::TRTStatic) {
          for (int i = 1; i < conn.outside_shape.dims(); i++) {
            if (conn.outside_shape.dim_size(i) <= 0) {
              return errors::Internal(
                  "Not fully defined input shape when in static mode which "
                  "should have been excluded by the segmenter. ");
            }
          }
        }

        // Rewire the data input if it's not found in the original graph.
        Node* input_node = graph->FindNodeId(conn.outside_id);
        int port = conn.outside_port;
        if (!input_node) {
          UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/true,
                             conn.outside_node_name, &input_node, &port);
        }
        if (std::find_if(
                std::begin(inputs), std::end(inputs),
                [input_node, &port](const NodeDefBuilder::NodeOut& inp) {
                  return inp.node == input_node->name() && inp.index == port;
                }) == std::end(inputs)) {
          inputs.emplace_back(input_node->name(), port, conn.connection_type);
          input_nodes.push_back(CHECK_NOTNULL(input_node));
          VLOG(1) << "Engine Input " << input_node->name() << ":" << port
                  << " -> " << info.engine_name << ":" << inputs.size() - 1;
        }
      }
    }
  }
  // We don't support segments with no inputs. Fall back to native TF here to
  // avoid a crash later. Constant folding should've folded the ops that make
  // up these segments.
  if (inputs.empty()) {
    return errors::Internal(
        "Segment has no inputs (possible constfold failure)");
  }

  // Build the engine and get its serialized representation.
  string segment_string;

  int max_batch_size = info.max_batch_size.has_value()
                           ? info.max_batch_size.value()
                           : default_max_batch_size;

  if (info.engine_type == EngineInfo::EngineType::TRTStatic) {
    TF_RETURN_IF_ERROR(CreateStaticEngine(params, info, max_batch_size,
                                          input_shapes, nullptr,
                                          &segment_string, cluster));
  }

  string prec_string;
  TF_RETURN_IF_ERROR(TrtPrecisionModeToName(info.precision_mode, &prec_string));
  NodeDefBuilder node_builder(info.engine_name, "TRTEngineOp");
  if (!info.device.empty()) node_builder.Device(info.device);
  if (VLOG_IS_ON(1)) {
    string ins = StrCat(info.engine_name, " inputs= ");
    for (const auto& ii : inputs) {
      StrAppend(&ins, ii.node, ":", ii.index, " ");
    }
    VLOG(1) << ins;
  }
  node_builder.Input(inputs);
  for (const string& c : control_input_names) {
    node_builder.ControlInput(c);
  }

  NodeDef trt_node;
  NameAttrList function;
  function.set_name(StrCat(info.engine_name, "_native_segment"));

  node_builder.Attr("input_shapes", input_shape_protos)
      .Attr("output_shapes", output_shape_protos)
      .Attr("static_engine",
            info.engine_type == EngineInfo::EngineType::TRTStatic)
      .Attr("segment_func", function)
      .Attr("serialized_segment", segment_string)
      .Attr("calibration_data", "")
      .Attr("max_cached_engines_count", info.maximum_cached_engines)
      .Attr("workspace_size_bytes", info.max_workspace_size_bytes)
      .Attr("max_batch_size", max_batch_size)
      .Attr("precision_mode", prec_string)
      .Attr("use_calibration", info.use_calibration)
      .Attr("_use_implicit_batch", params.use_implicit_batch)
      .Attr("use_explicit_precision", params.use_explicit_precision)
      .Attr("_allow_build_at_runtime", info.allow_build_at_runtime)
      .Attr("OutT", out_types);

  if (!params.use_implicit_batch) {
    node_builder.Attr("profile_strategy",
                      ProfileStrategyToName(params.profile_strategy));
  }

  Status status = node_builder.Finalize(&trt_node);

  if (!status.ok()) {
    LOG(ERROR) << "Node construction failed with " << status;
    return status;
  }
  VLOG(1) << "Adding TRTEngine " << info.engine_name << " to graph";

  // Up until this point, the graph is not modified. If we return !status.ok()
  // from here, this segment will be skipped.
  // TODO(aaroey): let it return a proper error status for the following logic
  // instead of checking fail.
  TF_ASSIGN_OR_RETURN(Node * engine_node, graph->AddNode(trt_node));
  (*engine_nodes)[pos] = engine_node;
  // Add control input and input edges to the engine node.
  for (const auto in : control_input_nodes) {
    VLOG(1) << "Connecting control edge from " << in->name() << " to "
            << engine_node->name();
    graph->AddControlEdge(in, engine_node);
  }
  VLOG(1) << "input_nodes size = " << input_nodes.size();
  for (int i = 0; i < input_nodes.size(); ++i) {
    Node* n = CHECK_NOTNULL(input_nodes[i]);
    const auto& in = inputs[i];
    VLOG(1) << "Connecting data edge from " << n->name() << ":" << in.index
            << " to " << engine_node->name() << ":" << i;
    graph->AddEdge(n, in.index, engine_node, i);
  }

  // Update the inputs of the destination nodes of output edges to point to
  // the engine node instead.
  for (auto& conn : info.connections) {
    if (conn.is_input_edge) {
      continue;
    }
    Node* output_node = graph->FindNodeId(conn.outside_id);
    int port = conn.outside_port;
    if (!output_node) {
      UpdateToEngineNode(infos, pos, *engine_nodes, /*is_input_edge=*/false,
                         conn.outside_node_name, &output_node, &port);
    }
    if (conn.is_control_edge()) {
      VLOG(1) << "Updating control edge from " << engine_node->name() << " to "
              << output_node->name();
      QCHECK_EQ(Graph::kControlSlot, port);
      graph->AddControlEdge(engine_node, output_node);
    } else {
      VLOG(1) << "Updating data edge from " << engine_node->name() << ":"
              << conn.port_number << " to " << output_node->name() << ":"
              << port;
      // Use UpdateEdge() to avoid adding the same edge multiple times.
      TF_CHECK_OK(
          graph->UpdateEdge(engine_node, conn.port_number, output_node, port));
    }
  }
  return Status::OK();
}

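// Returns a monotonically increasing sequence number used to make engine
// names unique across conversions; e.g. the first converted graph yields the
// engine name prefix "TRTEngineOp_000_".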
int64 GetNextGraphSequenceNumber() {
  static std::atomic<int64_t> graph_sequence_num;
  return graph_sequence_num++;
}

constexpr char kCastInputTypeAttrName[] = "SrcT";

// Transforms node = cast(x, fp32) where datatype(x) != fp16 to:
//   castToFp16 = cast(x, fp16)
//   node = cast(castToFp16, fp32)
//
Status MaybeRewriteCastToFp32(GraphDef* graph_def, NodeDef* node_def) {
  if (node_def->op() != "Cast") {
    return Status::OK();
  }

  DataTypeVector input_types;
  DataTypeVector output_types;
  TF_RETURN_IF_ERROR(
      graph_transforms::GetInOutTypes(*node_def, &input_types, &output_types));

  if (input_types.size() != 1 || output_types.size() != 1) {
    return errors::Internal("Bad cast operation");
  }

  if (input_types[0] == DT_HALF || output_types[0] != DT_FLOAT) {
    return Status::OK();
  }

  VLOG(2) << "Rewriting cast to FP32 " << node_def->DebugString();

  NodeDef* castToFp16 = graph_def->add_node();
  for (auto attr_value : node_def->attr()) {
    (*castToFp16->mutable_attr())[attr_value.first] = attr_value.second;
  }
  castToFp16->set_name(node_def->name() + "_split");
  castToFp16->set_op("Cast");
  castToFp16->set_device(node_def->device());
  castToFp16->add_input(node_def->input(0));
  (*castToFp16->mutable_attr())[kCastOutputTypeAttrName].set_type(DT_HALF);

  node_def->set_input(0, castToFp16->name() + ":0");
  (*node_def->mutable_attr())[kCastInputTypeAttrName].set_type(DT_HALF);

  VLOG(2) << castToFp16->DebugString();
  VLOG(2) << node_def->DebugString();

  return Status::OK();
}

}  // namespace

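// Registers `segment_graph_def` as the function
// "<engine_name>_native_segment" in the graph's function library. A minimal
// sketch of the call site used below in ConvertGraph:
//   TF_RETURN_IF_ERROR(RegisterGraphToFunctionLibrary(
//       curr_engine.segment_graph_def, &graph, curr_engine.engine_name));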
Status RegisterGraphToFunctionLibrary(const GraphDef& segment_graph_def,
                                      Graph* graph, const string& engine_name) {
  Graph segment_graph(graph->flib_def());
  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(),
                                            segment_graph_def, &segment_graph));
  FunctionDefLibrary library;
  auto segment_func = library.add_function();
  TF_RETURN_IF_ERROR(GraphToFunctionDef(
      segment_graph, StrCat(engine_name, "_native_segment"), segment_func));
  if (VLOG_IS_ON(7)) {
    VLOG(7) << engine_name << " Function_Def ";
    VLOG(7) << segment_func->DebugString();
  }
  VLOG(1) << "Adding funcdef " << segment_func->signature().name()
          << " to graphlib";
  TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(library));
  return Status::OK();
}

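// Returns the CUDA device id and allocator to build the engine with. When no
// cluster or device information is available, falls back to the first TF GPU
// found; returns {-1, nullptr} if no GPU can be identified.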
std::pair<int, Allocator*> GetDeviceAndAllocator(
    const grappler::Cluster* cluster, const EngineInfo& engine) {
  int cuda_device_id = -1;
  Allocator* dev_allocator = nullptr;
  if (cluster == nullptr || cluster->GetDeviceSet() == nullptr ||
      engine.device.empty()) {
    // If the device is not set, use the first found GPU device for the
    // conversion.
    TfDeviceId tf_device_id;
    PlatformDeviceId platform_device_id;
    std::tie(tf_device_id, platform_device_id) = GetFirstValidDeviceId();
    cuda_device_id = platform_device_id.value();
    if (cuda_device_id >= 0) {
      GPUOptions gpu_options;
      // If the TF-to-CUDA gpu id mapping exists, the device and corresponding
      // allocator must have been initialized already, so the
      // GetGPUAllocator() call won't create a new allocator.
      dev_allocator = GPUProcessState::singleton()->GetGPUAllocator(
          gpu_options, tf_device_id, /*total_bytes=*/1, /*peer_gpu_ids=*/{});
    }
    return std::make_pair(cuda_device_id, dev_allocator);
  }

  // Use the device requested by the engine.
  auto device_set = cluster->GetDeviceSet();
  std::vector<Device*> devices;
  DeviceNameUtils::ParsedName parsed_name;
  if (DeviceNameUtils::ParseFullName(engine.device, &parsed_name) &&
      parsed_name.has_id) {
    device_set->FindMatchingDevices(parsed_name, &devices);
  }
  if (!devices.empty()) {
    if (devices.size() > 1) {
      string msg = "Found multiple matching devices using name '";
      StrAppend(&msg, engine.device, "': ");
      for (auto d : devices) StrAppend(&msg, d->name(), ", ");
      StrAppend(&msg, ". Will get the allocator from the first one.");
      LOG_WARNING_WITH_PREFIX << msg;
    }
    AllocatorAttributes alloc_attr;
    cuda_device_id = devices[0]->tensorflow_accelerator_device_info()->gpu_id;
    dev_allocator = devices[0]->GetAllocator(alloc_attr);
    VLOG(1) << "Using allocator " << dev_allocator->Name()
            << " and cuda_device_id " << cuda_device_id;
  } else {
    LOG_WARNING_WITH_PREFIX << "Cluster is set but device '" << engine.device
                            << "' is not found in the cluster";
  }
  return std::make_pair(cuda_device_id, dev_allocator);
}

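// Builds a static TRT engine for `info` (selecting a device and allocator,
// then converting the segment GraphDef to a cuda engine) and returns the
// serialized engine in `segment_string`.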
Status CreateStaticEngine(const TRTOptimizationPass::ConversionParams& params,
                          const EngineInfo& info, int max_batch_size,
                          const std::vector<PartialTensorShape>& input_shapes,
                          TrtShapeOptimizationProfile* profile,
                          string* segment_string, grappler::Cluster* cluster) {
  std::pair<int, Allocator*> device_allocator =
      GetDeviceAndAllocator(cluster, info);
  int cuda_device_id = 0;
  std::unique_ptr<TRTBaseAllocator> trt_allocator;
  if (device_allocator.first >= 0) {
    cuda_device_id = device_allocator.first;
    trt_allocator.reset(new TRTDeviceAllocator(device_allocator.second));
  } else {
    // The value in trt_allocator is a nullptr, so cudaMalloc will be used.
    LOG_WARNING_WITH_PREFIX << "Can't identify the cuda device. Running on "
                               "device 0 and using cudaMalloc as an allocator";
  }
  cudaSetDevice(cuda_device_id);

  auto trt_logger = GetLoggerRegistry()->LookUp(params.trt_logger_name);
  const bool calibrate_int8 =
      (info.precision_mode == TrtPrecisionMode::INT8 && info.use_calibration);

  // Create static engines with precision_mode fp32/fp16.
  TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
  TF_RETURN_IF_ERROR(ConvertGraphDefToEngine(
      info.segment_graph_def, nullptr,
      calibrate_int8 ? TrtPrecisionMode::FP32 : info.precision_mode,
      max_batch_size, info.max_workspace_size_bytes, input_shapes, trt_logger,
      trt_allocator.get(), /*calibrator=*/nullptr, &engine,
      info.use_calibration, params.use_implicit_batch,
      /*convert_successfully=*/nullptr, profile, info.engine_name,
      /*use_explicit_precision=*/params.use_explicit_precision, cluster));
  TrtUniquePtrType<nvinfer1::IHostMemory> engine_data(engine->serialize());
  *segment_string = string(static_cast<const char*>(engine_data->data()),
                           engine_data->size());
  return Status::OK();
}

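// Entry point of the TF-TRT conversion: segments `grappler_item.graph`,
// replaces each viable segment with a TRTEngineOp, and writes the rewritten
// GraphDef to `output`. A minimal sketch of a call site (illustrative only;
// the variable names are assumptions):
//   GraphDef rewritten;
//   TF_RETURN_IF_ERROR(
//       ConvertGraph(params, item, {"input", "output"}, cluster, &rewritten));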
Status ConvertGraph(const TRTOptimizationPass::ConversionParams& params,
                    grappler::GrapplerItem& grappler_item,
                    const std::vector<string>& input_output_names,
                    grappler::Cluster* cluster, GraphDef* output) {
  // Sanity checks.
  TRT_ENSURE(output != nullptr)
  if (params.precision_mode != TrtPrecisionMode::INT8 &&
      params.use_calibration) {
    return errors::InvalidArgument(
        "Calibration with FP32 or FP16 is not supported.");
  }

  GraphDef& graph_def = grappler_item.graph;

  // When precision_mode is FP16, transform cast(x, fp32) to
  // cast(cast(x, fp16), fp32). This creates cast(fp16, fp32) that can be
  // included in the TRTEngineOp as a TensorRT Identity layer, for performance:
  // . Avoid cast(fp32, fp16) in the TRT engine implementation for fp16
  //   precision.
  // . Changing the input to the TRTEngineOp from fp32 to fp16 may reduce data
  //   movement from the host to the GPU.
  if (params.precision_mode == TrtPrecisionMode::FP16) {
    for (int i = 0; i < graph_def.node_size(); i++) {
      NodeDef* node_def = graph_def.mutable_node(i);
      TF_RETURN_IF_ERROR(MaybeRewriteCastToFp32(&graph_def, node_def));
    }
  }

  // Infer shapes statically, using the (possibly rewritten) graph_def from
  // the input grappler_item.
  grappler::GraphProperties static_graph_properties(grappler_item);
  TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(true));

  // Convert the GraphDef to a Graph.
  FunctionLibraryDefinition flib(OpRegistry::Global(), graph_def.library());
  Graph graph(flib);
  TF_RETURN_IF_ERROR(
      ConvertGraphDefToGraph(GraphConstructorOptions(), graph_def, &graph));

  // Segment the graph into subgraphs that can be converted to TensorRT.
  segment::SegmentOptions segment_options;
  // TODO(ben,jie,sami): exclude output nodes (DISCUSS IT)
  for (const auto& node : input_output_names) {
    segment_options.exclude_node_list.insert(node);
  }
  segment_options.minimum_segment_size = params.minimum_segment_size;
  segment_options.use_implicit_batch = params.use_implicit_batch;
  if (segment_options.use_implicit_batch)
    segment_options.maximum_batch_size = params.max_batch_size;
  segment_options.allow_dynamic_non_batch_dim =
      AllowDynamicNonBatchDimension(params);

  segment::SegmentVector initial_segments;
  TrtNodeValidator validator(static_graph_properties, params.precision_mode,
                             params.use_calibration, params.use_implicit_batch,
                             params.use_explicit_precision);
  TF_RETURN_IF_ERROR(segment::SegmentGraph(
      &graph, &static_graph_properties,
      std::bind(&TrtNodeValidator::IsTensorRTCandidate, &validator,
                std::placeholders::_1),
      // Input validation is already done by TrtNodeValidator, so we don't
      // need to check the input edges.
      [](const Edge* edge) { return true; }, OutputEdgeValidator(),
      segment_options, &initial_segments));
  LOG(INFO) << "Number of TensorRT candidate segments: "
            << initial_segments.size();

  // Get the EngineInfo for each segment.
  std::unordered_map<string, Node*> node_map;
  TF_RETURN_IF_ERROR(BuildNodeMap(graph, &node_map));
  std::vector<EngineInfo> engine_segments;
  engine_segments.reserve(initial_segments.size());
  std::vector<Node*> reverse_topo_order;
  GetPostOrder(graph, &reverse_topo_order);
  segment::SegmentVector converted_segments;
  converted_segments.reserve(initial_segments.size());
  string engine_name_prefix =
      StrCat("TRTEngineOp_",
             absl::StrFormat("%0*d", 3, GetNextGraphSequenceNumber()), "_");
  for (size_t t = 0; t < initial_segments.size(); t++) {
    auto& curr_segment = initial_segments.at(t);
    EngineInfo curr_engine;
    curr_engine.engine_name =
        StrCat(engine_name_prefix, absl::StrFormat("%0*d", 3, t));

    bool int8_no_calib = (!params.use_calibration &&
                          params.precision_mode == TrtPrecisionMode::INT8);
    bool has_qdq = false;
    if (int8_no_calib) {
      has_qdq = absl::c_any_of(reverse_topo_order, IsQuantizeAndDequantizeOp);
    }

    Status status = GetEngineInfo(&graph, static_graph_properties,
                                  curr_segment, reverse_topo_order,
                                  &curr_engine);
    if (!status.ok()) {
      LOG_WARNING_WITH_PREFIX << "Failed to get engine info for segment " << t
                              << ": " << status;
      continue;
    }

    curr_engine.engine_type = GetEngineType(params);
    curr_engine.use_calibration = params.use_calibration;
    // Building cuda engines for INT8 without calibration and without dynamic
    // range info causes TRT failures. Avoid this situation by setting the
    // precision to FP16.
    if (int8_no_calib && !has_qdq) {
      LOG(WARNING) << "Set engine precision to FP16 due to missing QDQ OP";
      curr_engine.precision_mode = TrtPrecisionMode::FP16;
    } else {
      curr_engine.precision_mode = params.precision_mode;
    }
    curr_engine.maximum_cached_engines = params.max_cached_engines;
    curr_engine.allow_build_at_runtime = params.allow_build_at_runtime;
    if (!curr_engine.max_batch_size.has_value()) {
      curr_engine.max_batch_size = params.max_batch_size;
    }

    status = RegisterGraphToFunctionLibrary(curr_engine.segment_graph_def,
                                            &graph, curr_engine.engine_name);

    if (!status.ok()) {
      LOG_WARNING_WITH_PREFIX
          << "Failed to register segment graphdef to the library " << t << ": "
          << status;
      continue;
    }

    engine_segments.push_back(std::move(curr_engine));
    converted_segments.push_back(std::move(curr_segment));

    if (VLOG_IS_ON(8)) {
      string fname = engine_segments.back().engine_name;
      StrAppend(&fname, ".pb");
      std::fstream f;
      f.open(fname.c_str(), std::fstream::out | std::fstream::binary);
      f << engine_segments.at(t).segment_graph_def.SerializeAsString();
      f.close();
    }
  }

  // Save the cuda device since we may need to switch to another cuda device
  // to build static engines.
  std::optional<int> old_cuda_device = std::nullopt;
  if (!params.is_dynamic_op) {
    int cuda_device_id;
    cudaError_t cuda_error = cudaGetDevice(&cuda_device_id);
    if (cuda_error != cudaSuccess) {
      LOG_WARNING_WITH_PREFIX << "Couldn't get current device: "
                              << cudaGetErrorString(cuda_error);
    } else {
      VLOG(1) << "Current cuda device is " << cuda_device_id;
      old_cuda_device = cuda_device_id;
    }
  }

  auto restore_cuda_device = gtl::MakeCleanup([old_cuda_device] {
    if (old_cuda_device.has_value()) {
      cudaSetDevice(old_cuda_device.value());
    }
  });

  std::vector<Node*> engine_nodes;
  engine_nodes.resize(engine_segments.size());
  for (int i = 0; i < engine_segments.size(); ++i) {
    auto& engine = engine_segments.at(i);
    // TODO(b/170762693): implement the heuristic to calculate
    // max_workspace_size_bytes.
    engine.max_workspace_size_bytes = params.max_workspace_size_bytes;
    VLOG(1) << "Assigned " << engine.max_workspace_size_bytes << " bytes to "
            << engine.engine_name;
    auto status =
        CreateTRTNode(params, engine_segments, i, params.max_batch_size,
                      &graph, &engine_nodes, cluster);

    string msg = StrCat("segment ", i, " consisting of ",
                        converted_segments.at(i).nodes.size(), " nodes by ",
                        engine.engine_name);
    if (status.ok()) {
      LOG(INFO) << "Replaced " << msg << ".";
    } else {
      // Graph is not modified.
      LOG_WARNING_WITH_PREFIX << "Cannot replace " << msg
                              << " reason: " << status.error_message()
                              << " (keeping original segment).";
    }
    if (VLOG_IS_ON(1)) {
      msg = "Segment consists of nodes: ";
      for (const Node* node : converted_segments.at(i).nodes) {
        StrAppend(&msg, node->name(), ", ");
      }
      VLOG(1) << msg;
    }

    // If status is ok, we successfully added the node to the graph and can
    // remove the segment ops. Otherwise the graph is not modified.
    if (status.ok()) {
      for (const Node* node : converted_segments.at(i).nodes) {
        graph.RemoveNode(const_cast<Node*>(node));
      }
    }
  }
  graph.ToGraphDef(output);
  return Status::OK();
}

}  // namespace convert
}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT