1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2tensorrt/segment/segment.h"
17
#include <algorithm>
#include <cstdlib>
#include <fstream>
#include <functional>
#include <map>
#include <numeric>
#include <optional>
#include <queue>
#include <set>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
26
27 #include "absl/container/flat_hash_set.h"
28 #include "absl/strings/str_cat.h"
29 #include "absl/strings/str_format.h"
30 #include "tensorflow/compiler/tf2tensorrt/common/utils.h"
31 #include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
32 #include "tensorflow/core/common_runtime/graph_constructor.h"
33 #include "tensorflow/core/graph/algorithm.h"
34 #include "tensorflow/core/graph/graph.h"
35 #include "tensorflow/core/lib/core/errors.h"
36 #include "tensorflow/core/lib/core/status.h"
37 #include "tensorflow/core/lib/strings/str_util.h"
38 #include "tensorflow/core/lib/strings/strcat.h"
39 #include "tensorflow/core/platform/types.h"
40 #include "tensorflow/core/util/env_var.h"
41
42 #if GOOGLE_CUDA && GOOGLE_TENSORRT
43
44 namespace tensorflow {
45 namespace tensorrt {
46 namespace segment {
47 namespace {
48 using absl::StrAppend;
49 using absl::StrAppendFormat;
50 using absl::StrCat;
51 using absl::StrJoin;
52
53 // A simple graph representation to mirror Graph. This structure
54 // helps saving memory since segmenter modifies the graph in place, preventing
55 // the need to create a copy of the graph. It is composed of edges and nodes.
56 // Nodes keep pointers to original TF nodes.
57 class SimpleNode;
58 class SimpleGraph;
59 class SimpleEdge {
60 public:
SimpleEdge(int id,SimpleNode * src,int src_port,SimpleNode * dst,int dst_port,bool is_control=false)61 SimpleEdge(int id, SimpleNode* src, int src_port, SimpleNode* dst,
62 int dst_port, bool is_control = false)
63 : id_(id),
64 src_(src),
65 src_port_(src_port),
66 dst_(dst),
67 dst_port_(dst_port),
68 control_(is_control) {}
~SimpleEdge()69 ~SimpleEdge() {}
70
src() const71 SimpleNode* src() const { return src_; }
dst() const72 SimpleNode* dst() const { return dst_; }
src_output() const73 int src_output() const { return src_port_; }
dst_input() const74 int dst_input() const { return dst_port_; }
id() const75 int id() const { return id_; }
IsControlEdge() const76 bool IsControlEdge() const { return control_; }
77
78 private:
79 int id_;
80 SimpleNode* src_;
81 int src_port_;
82 SimpleNode* dst_;
83 int dst_port_;
84 bool control_;
85 };
86
87 class SimpleNode {
88 public:
89 SimpleNode(const Node* node, const int id);
90
in_edges() const91 const std::vector<SimpleEdge*>& in_edges() const { return in_edges_; }
out_edges() const92 const std::vector<SimpleEdge*>& out_edges() const { return out_edges_; }
93
in_nodes() const94 std::vector<SimpleNode*> in_nodes() const {
95 std::vector<SimpleNode*> res;
96 res.reserve(in_edges_.size());
97 for (const auto e : in_edges_) {
98 if (e) res.push_back(e->src());
99 }
100 return res;
101 }
102
out_nodes() const103 std::vector<SimpleNode*> out_nodes() const {
104 std::vector<SimpleNode*> res;
105 res.reserve(out_edges_.size());
106 for (const auto e : out_edges_) {
107 if (e) res.push_back(e->dst());
108 }
109 return res;
110 }
111
name() const112 const string& name() const { return node_->name(); }
tf_node() const113 const Node* tf_node() const { return node_; }
id() const114 int id() const { return id_; }
115
116 private:
117 const Node* node_;
118 std::vector<SimpleEdge*> in_edges_;
119 std::vector<SimpleEdge*> out_edges_;
120 int id_;
121
122 friend class SimpleGraph;
123 };
124
125 class SimpleGraph {
126 public:
127 explicit SimpleGraph(const Graph* g);
128 ~SimpleGraph();
129
130 void AddControlEdge(SimpleNode* src, SimpleNode* dst);
131 void AddEdge(SimpleNode* src, int out_port, SimpleNode* dst, int in_port);
132 void RemoveEdge(const SimpleEdge*);
FindNodeId(int node_id)133 SimpleNode* FindNodeId(int node_id) {
134 if (node_id < 0 || node_id > static_cast<int>(nodes_.size())) {
135 return nullptr;
136 }
137 return nodes_[node_id];
138 }
num_node_ids() const139 int num_node_ids() const { return nodes_.size(); }
source_node() const140 const SimpleNode* source_node() const { return nodes_[Graph::kSourceId]; }
sink_node() const141 const SimpleNode* sink_node() const { return nodes_[Graph::kSinkId]; }
142
143 private:
144 const Graph* g_;
145 std::vector<SimpleNode*> nodes_;
146 std::vector<SimpleEdge*> edges_;
147 // free_edge_ids_ and free_node_ids_ contain freed indices.
148 std::set<int> free_edge_ids_;
149 std::set<int> free_node_ids_;
150 };
151
SimpleNode(const Node * node,const int id)152 SimpleNode::SimpleNode(const Node* node, const int id) : node_(node), id_(id) {
153 if (node_) {
154 in_edges_.reserve(node_->in_edges().size());
155 out_edges_.reserve(node_->out_edges().size());
156 }
157 }
158
SimpleGraph(const Graph * g)159 SimpleGraph::SimpleGraph(const Graph* g) : g_(g) {
160 int n_nodes = g_->num_node_ids();
161 nodes_.resize(n_nodes, nullptr);
162 nodes_[g->kSourceId] = new SimpleNode(g->source_node(), g->kSourceId);
163 nodes_[g->kSinkId] = new SimpleNode(g->sink_node(), g->kSinkId);
164 int n_edges = g->num_edge_ids();
165 edges_.resize(n_edges, nullptr);
166 for (int i = 2; i < n_nodes; i++) {
167 const auto n = g->FindNodeId(i);
168 if (n) {
169 nodes_[i] = new SimpleNode(n, i);
170 } else {
171 free_node_ids_.insert(i);
172 }
173 }
174 for (int i = 0; i < n_edges; i++) {
175 const auto e = g->FindEdgeId(i);
176 if (e) {
177 const auto tfsrc = e->src();
178 const auto tfdst = e->dst();
179 bool is_control = e->IsControlEdge();
180 auto src = nodes_[tfsrc->id()];
181 auto dst = nodes_[tfdst->id()];
182 auto edge = new SimpleEdge(i, src, e->src_output(), dst, e->dst_input(),
183 is_control);
184 edges_[i] = edge;
185 src->out_edges_.push_back(edge);
186 dst->in_edges_.push_back(edge);
187 } else {
188 free_edge_ids_.insert(i);
189 }
190 }
191 }
192
AddEdge(SimpleNode * src,int out_port,SimpleNode * dst,int in_port)193 void SimpleGraph::AddEdge(SimpleNode* src, int out_port, SimpleNode* dst,
194 int in_port) {
195 int i = edges_.size();
196 if (!free_edge_ids_.empty()) {
197 auto it = free_edge_ids_.begin();
198 i = *it;
199 free_edge_ids_.erase(it);
200 } else {
201 edges_.push_back(nullptr);
202 }
203 bool is_control = (out_port == Graph::kControlSlot);
204 is_control |= (in_port == Graph::kControlSlot);
205 auto edge = new SimpleEdge(i, src, out_port, dst, in_port, is_control);
206 edges_[i] = edge;
207 src->out_edges_.push_back(edge);
208 dst->in_edges_.push_back(edge);
209 }
210
// Adds a control-dependency edge src -> dst by using the control slot for
// both endpoint ports.
void SimpleGraph::AddControlEdge(SimpleNode* src, SimpleNode* dst) {
  AddEdge(src, Graph::kControlSlot, dst, Graph::kControlSlot);
}
214
RemoveEdge(const SimpleEdge * edge)215 void SimpleGraph::RemoveEdge(const SimpleEdge* edge) {
216 auto src = edge->src();
217 auto dst = edge->dst();
218 for (auto it = src->out_edges_.begin(); it != src->out_edges_.end(); ++it) {
219 if (*it == edge) {
220 src->out_edges_.erase(it);
221 break;
222 }
223 }
224 for (auto it = dst->in_edges_.begin(); it != dst->in_edges_.end(); ++it) {
225 if (*it == edge) {
226 dst->in_edges_.erase(it);
227 break;
228 }
229 }
230 }
231
~SimpleGraph()232 SimpleGraph::~SimpleGraph() {
233 for (auto x : nodes_) delete x;
234 for (auto x : edges_) delete x;
235 }
236
237 // Define comparison functions for std::set with pointer keys so that behavior
238 // is deterministic. When using std::set with pointer key types, the items are
239 // sorted by pointer address which is non-deterministic. This can cause issues
240 // for INT8 mode because the graph is converted twice and non-determinism may
241 // cause a mismatch between the calibration tables of the conversions.
struct SimpleEdgePtrCompare {
  // Order edges by their stable integer id rather than by pointer address.
  bool operator()(const SimpleEdge* lhs, const SimpleEdge* rhs) const {
    return lhs->id() < rhs->id();
  }
};
247
248 // Copied from TF ReverseDFS, which only works for Graph.
// Copied from TF ReverseDFS, which only works for Graph.
//
// Depth-first search from `start` over `g`, expanding children in node-name
// order so the traversal is deterministic. When `reverse` is true the DFS
// follows in-edges instead of out-edges. `enter` runs before a node's
// descendants are expanded and `leave` after they are all done; either
// callback may be null, and returning false from one aborts the traversal.
void StableDFS(const SimpleGraph& g, bool reverse,
               const std::vector<const SimpleNode*>& start,
               const std::function<bool(const SimpleNode*)>& enter,
               const std::function<bool(const SimpleNode*)>& leave) {
  // Each pending item is either an "enter" visit or a deferred "leave" visit.
  struct Work {
    const SimpleNode* node;
    bool leave;
  };
  std::vector<Work> pending;
  pending.reserve(start.size());
  for (const SimpleNode* n : start) {
    pending.push_back(Work{n, false});
  }

  auto neighbors = [reverse](const SimpleNode* n) {
    return reverse ? n->in_nodes() : n->out_nodes();
  };
  std::vector<bool> visited(g.num_node_ids(), false);
  while (!pending.empty()) {
    const Work item = pending.back();
    pending.pop_back();

    const SimpleNode* node = item.node;
    if (item.leave) {
      if (leave && !leave(node)) return;
      continue;
    }

    if (visited[node->id()]) continue;
    visited[node->id()] = true;
    if (enter && !enter(node)) return;

    // Push the deferred leave(node) first so it runs only after every
    // descendant pushed below has been fully processed.
    if (leave) pending.push_back(Work{node, true});

    auto children = neighbors(node);
    std::vector<const SimpleNode*> ordered(children.begin(), children.end());
    std::sort(ordered.begin(), ordered.end(),
              [](const SimpleNode* lhs, const SimpleNode* rhs) {
                return lhs->name() < rhs->name();
              });
    for (const SimpleNode* child : ordered) {
      if (!visited[child->id()]) {
        pending.push_back(Work{child, false});
      }
    }
  }
}
297
// Returns true iff contracting `edge` (merging its dst into its src) would
// keep the graph acyclic, i.e. there is no path from src to dst other than
// direct src->dst edges. Determinism of the answer relies on StableDFS.
bool CanContractEdge(const SimpleEdge* edge,
                     const std::unique_ptr<SimpleGraph>& graph) {
  const auto src = edge->src();
  const auto dst = edge->dst();

  // Can't contract edge if doing so would cause a cycle in the
  // graph. So, if there is a directed path from 'src' to 'dst', other
  // than 'edge' (or any other direct edge from 'src' to 'dst'), then
  // combining 'src' and 'dst' will cause a cycle along that path.
  //
  // In practice, to avoid modifying the graph and to take advantage
  // of existing graph functions, we perform an equivalent.
  // 1. Get all nodes incoming to 'dst', excluding 'src'
  // 2. Reverse DFS from those nodes
  // 3. If reverse DFS reaches 'src' then we have a cycle
  //
  // TODO(aaroey): there are several problems with the current approach:
  // 1. src->dst->src, this is not detected but it should be;
  // 2. src->dst->...(any node sequence that doesn't contain src)...->dst, this
  //    is detected but it should not be.
  //
  // Note that it's fine that dst connects back to src indirectly (i.e. through
  // a path with length > 1 that consists of intermedia nodes other than src).
  // While loops is one example.
  //
  // The goal is to make sure that the trt subgraph:
  // 1. has no loops (i.e. is a DAG), and
  // 2. if there is a path in the subgraph from X to Y (X and Y are both nodes
  //    in the subgraph), then all paths from X to Y are in the subgraph.
  //
  // To achieve this goal, the correct way seems to be:
  // 1. remove any direct edge from src->dst;
  // 2. detect if src can reach dst, if so they cannot be merged.
  std::vector<const SimpleNode*> dfs_start_nodes;
  for (const SimpleNode* node : dst->in_nodes()) {
    if (node != src) {
      dfs_start_nodes.push_back(node);
    }
  }
  bool has_cycle = false;
  // Walk backwards from dst's other predecessors; hitting src means a path
  // src -> ... -> dst exists besides the direct edge(s).
  StableDFS(*graph, /*reverse=*/true, dfs_start_nodes, /*enter=*/nullptr,
            [&has_cycle, src](const SimpleNode* n) {
              if (n == src) {
                has_cycle = true;
                return false;
              }
              return true;
            });
  return !has_cycle;
}
348
349 // TODO(bixia): put this to a common utility file.
TensorPropertiesToString(const OpInfo::TensorProperties & prop)350 string TensorPropertiesToString(const OpInfo::TensorProperties& prop) {
351 string s = StrCat(DataTypeString(prop.dtype()), ": ");
352 StrAppend(&s, "[");
353 if (prop.shape().unknown_rank()) {
354 StrAppend(&s, "?");
355 } else {
356 StrAppend(&s, StrJoin(prop.shape().dim(), ",",
357 [](string* out, const TensorShapeProto_Dim& d) {
358 StrAppendFormat(out, "%d", d.size());
359 }));
360 }
361 StrAppend(&s, "]");
362 return s;
363 }
364
TensorPropertiesToString(const std::vector<OpInfo::TensorProperties> & properties)365 string TensorPropertiesToString(
366 const std::vector<OpInfo::TensorProperties>& properties) {
367 return StrJoin(properties, "; ",
368 [](string* out, const OpInfo::TensorProperties& prop) {
369 StrAppend(out, TensorPropertiesToString(prop));
370 });
371 }
372
373 // From the given list of input properties, returns the leading shape, which is
374 // the shape that determines the batch size of the operation. The leading shape
375 // is selected from the group of input shapes with the highest rank as follows:
376 // . If all of those shapes have non-negative values for the batch dimension,
377 // the leading shape is the one with the largest value for the batch
378 // dimension.
379 // . If some or all of those shapes have negative values for the batch
380 // dimension, and the rest of those shapes have 1 for the batch dimension,
381 // the leading shape is the first of those shapes with a negative value for
382 // the batch dimension.
383 // . Otherwise, we can't determine the leading shape for the operation and
384 // have to exclude the operation from TRT.
385 //
386 // Examples:
387 // case-1: a[1,3,4] + b[2,3,4] => leading shape [2,3,4]
388 // case-2: a[2,3,4] + b[scalar] => leading shape [2,3,4]
389 // case-3: a[-1,3,4] + b[1,3,4] => leading shape [-1,3,4]
390 // case-4: a[-1,3,4] + b[2,3,4] => no leading shape
391 //
392 // We have to return "no leading shape" for case-4 to exclude such operation
393 // from being translated for this reason:
394 // The actually input for "a" have to be in the shape of [2,3,4] for the
395 // operation to be valid. On the other hand, if we translate the operation
396 // to implicit batch mode, it will becomes a[3,4]+b[3,4] which is valid for
397 // any input shape of "a".
398 //
399 // This routine assumes the input program is valid. For example, we shouldn't
400 // see invalid operation like a[2,3,4] + b[3,3,4]. It also assumes the input
401 // properties is not empty and all input have known shapes.
402 //
403 // TODO(bixia): find a way to share this knowledge with the converter.
404 // TODO(bixia): investigate the use of symbolic shape analysis to improve
405 // segmentation, such as by requiring the dynamic dimensions to have the same
406 // negative value.
std::optional<const TensorShapeProto*> FindLeadingShape(
    absl::Span<const OpInfo::TensorProperties> properties) {
  DCHECK(!properties.empty());
  // Both values are unconditionally initialized by the first call to
  // choose_shape_with_higher_rank below.
  const TensorShapeProto* result;
  int max_batch_dim_value;
  // Makes `s` the current candidate and seeds the batch-dim tracker from its
  // first dimension (1 for scalars).
  auto choose_shape_with_higher_rank = [&](const TensorShapeProto* s) {
    result = s;
    max_batch_dim_value = s->dim_size() < 1 ? 1 : s->dim(0).size();
  };

  DCHECK(!properties[0].shape().unknown_rank());
  choose_shape_with_higher_rank(&properties[0].shape());

  // Scan the remaining shapes, keeping the highest-rank one and applying the
  // batch-dimension rules documented above within the highest-rank group.
  for (const OpInfo::TensorProperties& p : properties.subspan(1)) {
    DCHECK(!p.shape().unknown_rank());
    if (p.shape().dim_size() < result->dim_size()) continue;

    if (p.shape().dim_size() > result->dim_size()) {
      choose_shape_with_higher_rank(&p.shape());
      continue;
    }

    // Among the shapes with the same rank, choose the one with a dynamic batch
    // size. If no shapes have a dynamic batch size, choose the one with the
    // largest size.
    if (result->dim_size() < 1) continue;

    if (p.shape().dim(0).size() < 0 || result->dim(0).size() < 0) {
      if (p.shape().dim(0).size() < 0 && result->dim(0).size() >= 0) {
        // Prefer the dynamic batch dimension over the static one.
        result = &p.shape();
      } else {
        // Track the largest static batch size seen alongside the dynamic one.
        max_batch_dim_value =
            std::max<int>(max_batch_dim_value, p.shape().dim(0).size());
      }

      continue;
    }

    if (p.shape().dim(0).size() > result->dim(0).size()) {
      result = &p.shape();
      max_batch_dim_value = result->dim(0).size();
    }
  }

  if (result->dim_size() > 0 && result->dim(0).size() < 0) {
    // dynamic batch size
    if (max_batch_dim_value <= 1) {
      return result;
    } else {
      // A dynamic batch dim coexists with a static batch size > 1 (case-4 in
      // the comment above): no leading shape can be determined.
      return std::nullopt;
    }
  }

  return result;
}
462
463 // Returns the inputs that are relevant to determinate the batch size of the
464 // operation. This routine handles the following cases:
465 // . Operations that support implicit boradcasting, such as operation mul.
466 // In this case, we need to inspect all the inputs in order to determine the
467 // batch size of the operation.
468 // . Special cases. Such as "Conv2DBackpropInput", "Conv3DBackpropInputV2".
469 // . The batch size of a operation is determined by the first input of the
470 // operation.
GetInputsToDeterminateBatchSize(const Node * node,const std::vector<OpInfo::TensorProperties> & all_inputs)471 absl::Span<const OpInfo::TensorProperties> GetInputsToDeterminateBatchSize(
472 const Node* node, const std::vector<OpInfo::TensorProperties>& all_inputs) {
473 // TODO(bixia): Find a way to share this knowledge with the converter.
474 static std::set<string> broadcast_supporting_ops = {
475 // ops corresponding to ConvertBinary in the converter
476 "Add",
477 "AddV2",
478 "Mul",
479 "Sub",
480 "Div",
481 "FloorDiv",
482 "RealDiv",
483 "Minimum",
484 "Maximum",
485 "Pow",
486 // other ops that need to need GetTrtBroadcastShape to convert
487 "BiasAdd",
488 "SquaredDifference",
489 "BatchMatMul",
490 "BatchMatMulV2",
491 };
492 const string& op = node->def().op();
493
494 if (op == "Conv2DBackpropInput" || op == "Conv3DBackpropInputV2") {
495 DCHECK_EQ(all_inputs.size(), 3);
496 return absl::MakeSpan(all_inputs).subspan(2, 1);
497 }
498
499 if (broadcast_supporting_ops.count(op)) {
500 return absl::MakeSpan(all_inputs);
501 }
502
503 // This is the common case for the operations that don't support implicit
504 // broadcasting: the first operand determines its batch size. All otherwise
505 // cases are handled before reaching here.
506 return absl::MakeSpan(all_inputs).subspan(0, 1);
507 }
508
509 // Returns true if the operation we can remove the implicit batch of the
510 // operation.
511 //
512 // In particular, if the input shape has dynamic rank or the input shape rank
513 // is less than 2, we can't remove the implicit batch dimension and generate
514 // a new operation for TRT translation.
OperationCanBeTranslatedToImplicitBatch(const grappler::GraphProperties * graph_properties,const Node * node)515 bool OperationCanBeTranslatedToImplicitBatch(
516 const grappler::GraphProperties* graph_properties, const Node* node) {
517 VLOG(3) << "process node " << node->name();
518 if (node->num_inputs() == 0) return true;
519 if (!graph_properties || !graph_properties->HasInputProperties(node->name()))
520 return false;
521
522 VLOG(3) << "input shapes "
523 << TensorPropertiesToString(
524 graph_properties->GetInputProperties(node->name()));
525
526 const std::vector<OpInfo::TensorProperties>& all_input_properties =
527 graph_properties->GetInputProperties(node->name());
528 absl::Span<const OpInfo::TensorProperties> input_properties =
529 GetInputsToDeterminateBatchSize(node, all_input_properties);
530 if (absl::c_any_of(input_properties, [](const OpInfo::TensorProperties& p) {
531 return p.shape().unknown_rank();
532 })) {
533 return false;
534 }
535
536 std::optional<const TensorShapeProto*> leading_shape =
537 FindLeadingShape(input_properties);
538 return leading_shape.has_value() && leading_shape.value()->dim_size() >= 2;
539 }
540
541 // Returns true if we can't be sure that the operand with the given properties
542 // won't have negative values for non-batch dimensions.
543 //
HasDynamicNonBatchDimension(const OpInfo::TensorProperties & prop)544 bool HasDynamicNonBatchDimension(const OpInfo::TensorProperties& prop) {
545 const TensorShapeProto& shape = prop.shape();
546 if (shape.unknown_rank()) return true;
547
548 // Scalar is a well specified shape, and TRT supports implicit broadcasting
549 // from scalar to other shapes.
550 if (shape.dim_size() == 0) return false;
551 for (int i = 1; i < shape.dim_size(); ++i) {
552 // The value of a dynamic dimension can be other negative values besides
553 // -1, representing the symbolic group of the dimension.
554 if (shape.dim(i).size() <= -1) {
555 return true;
556 }
557 }
558 return false;
559 }
560
561 // Returns true if we can't be sure that the operation won't have dynamic
562 // non-batch dimension involved. We only check the shape of the first output
563 // assuming shape inference already propagates the shapes.
OperationHasDynamicNonBatchDimension(const grappler::GraphProperties * graph_properties,const Node * node)564 bool OperationHasDynamicNonBatchDimension(
565 const grappler::GraphProperties* graph_properties, const Node* node) {
566 VLOG(3) << "process node " << node->name();
567 // If the node doesn't have any input or output, not computation is involved.
568 if (node->num_inputs() == 0 || node->num_outputs() == 0) return false;
569
570 // If the node doesn't have output properties, return true to be conservative.
571 if (!graph_properties->HasOutputProperties(node->name())) return true;
572 VLOG(3) << "output shapes "
573 << TensorPropertiesToString(
574 graph_properties->GetOutputProperties(node->name()));
575 return HasDynamicNonBatchDimension(
576 graph_properties->GetOutputProperties(node->name()).at(0));
577 }
578
// Contracts `edge` by merging its dst into its src: every edge of 'dst' that
// does not connect the two nodes is re-created on 'src', and all edges that
// touched 'dst' are appended to *remove_edges for the caller to delete.
void ContractEdge(SimpleEdge* edge, SimpleGraph* graph,
                  std::vector<const SimpleEdge*>* remove_edges) {
  // Transfer all inputs and outputs of 'dst' to 'src' except edges
  // connecting the two.
  auto src = edge->src();
  auto dst = edge->dst();

  // We can use '0' for input/output index because we don't need them
  // to be accurate for the way we are using the graph.
  // Copy the list first: the AddEdge/AddControlEdge calls below mutate the
  // nodes' edge vectors while we iterate.
  std::vector<const SimpleEdge*> in_edges(dst->in_edges().begin(),
                                          dst->in_edges().end());
  for (const SimpleEdge* in_edge : in_edges) {
    if (in_edge->IsControlEdge()) {
      if (in_edge->src() != src) {
        SimpleEdge* e = const_cast<SimpleEdge*>(in_edge);
        graph->AddControlEdge(e->src(), src);
      }
    } else {
      if (in_edge->src() != src) {
        SimpleEdge* e = const_cast<SimpleEdge*>(in_edge);
        if (e->src() == graph->source_node()) {
          // Data edges from the graph source are rebuilt as control inputs.
          graph->AddEdge(e->src(), e->src_output(), src, Graph::kControlSlot);
        } else {
          graph->AddEdge(e->src(), e->src_output(), src, 0 /* input index */);
        }
      }
    }
  }

  std::vector<const SimpleEdge*> out_edges(dst->out_edges().begin(),
                                           dst->out_edges().end());
  for (const SimpleEdge* out_edge : out_edges) {
    if (out_edge->IsControlEdge()) {
      SimpleEdge* e = const_cast<SimpleEdge*>(out_edge);
      graph->AddControlEdge(src, e->dst());
    } else {
      SimpleEdge* e = const_cast<SimpleEdge*>(out_edge);
      if (e->dst() == graph->sink_node()) {
        VLOG(1) << " edge to sink node " << src->name() << " -> "
                << e->dst()->name();
        // Data edges into the graph sink are rebuilt as control outputs.
        graph->AddEdge(src, Graph::kControlSlot, e->dst(), e->dst_input());
      } else {
        graph->AddEdge(src, 0 /* output index */, e->dst(), e->dst_input());
      }
    }
  }

  // Return the edges that must be removed to disconnect 'dst' from
  // the graph. We don't actually remove 'dst' since the caller holds
  // references to all the nodes.
  for (const auto& in_edge : dst->in_edges()) {
    remove_edges->push_back(in_edge);
  }
  for (const auto& out_edge : dst->out_edges()) {
    remove_edges->push_back(out_edge);
  }
}
636
637 // Returns a batch size representation for a segment that only contains the
638 // given node.
// Returns a batch size representation for a segment that only contains the
// given node. In explicit-batch mode (or for null/input-less nodes) the
// returned ClusterBatchSize is left in its default state.
ClusterBatchSize GetClusterBatchSizeForNode(
    const grappler::GraphProperties* graph_properties, const Node* node,
    bool use_implicit_batch) {
  ClusterBatchSize cluster_batch_size;
  if (!use_implicit_batch || !node || node->num_inputs() == 0) {
    return cluster_batch_size;
  }

  // Honor an explicit per-op max-batch-size annotation when present.
  const NodeDef& node_def = node->def();
  if (node_def.attr().count(kTftrtOpMaxBatchSizeAttr)) {
    cluster_batch_size.SetMaxBatchSize(
        node_def.attr().at(kTftrtOpMaxBatchSizeAttr).i());
  }

  // As shape inference cannot provide any useful information about the batch
  // size, we keep it as missing.
  if (!graph_properties ||
      !graph_properties->HasInputProperties(node->name())) {
    VLOG(3) << "doesn't have input property";
    return cluster_batch_size;
  }

  const std::vector<OpInfo::TensorProperties>& input_properties =
      graph_properties->GetInputProperties(node->name());
  // The DCHECKs encode the assumption that nodes without a valid leading
  // shape were already filtered out earlier in segmentation — presumably by
  // OperationCanBeTranslatedToImplicitBatch; verify against callers.
  std::optional<const TensorShapeProto*> optional_leading_shape =
      FindLeadingShape(GetInputsToDeterminateBatchSize(node, input_properties));
  DCHECK(optional_leading_shape.has_value());
  const TensorShapeProto* leading_shape = optional_leading_shape.value();
  DCHECK(!leading_shape->unknown_rank() && leading_shape->dim_size() >= 2);
  VLOG(3) << "set batch size as " << leading_shape->dim(0).size();
  return cluster_batch_size.SetBatchSize(leading_shape->dim(0).size());
}
671
AddSegmentForNode(const grappler::GraphProperties * graph_properties,std::vector<UnionFind<SimpleNode * >> * segments,SimpleNode * node,const DeviceNameUtils::ParsedName & device_name,bool use_implicit_batch)672 void AddSegmentForNode(const grappler::GraphProperties* graph_properties,
673 std::vector<UnionFind<SimpleNode*>>* segments,
674 SimpleNode* node,
675 const DeviceNameUtils::ParsedName& device_name,
676 bool use_implicit_batch) {
677 ClusterProperty property(
678 GetClusterBatchSizeForNode(graph_properties,
679 node == nullptr ? nullptr : node->tf_node(),
680 use_implicit_batch),
681 device_name);
682 segments->emplace_back(node, std::move(property));
683 }
684
685 } // namespace
686
ExportNonConversionReportToCSV(string filename,std::map<string,std::map<string,int>> & nonconverted_ops_map,string sep="|")687 Status ExportNonConversionReportToCSV(
688 string filename,
689 std::map<string, std::map<string, int>>& nonconverted_ops_map,
690 string sep = "|") {
691 std::fstream csv_file(filename, std::fstream::out | std::fstream::trunc);
692
693 if (!csv_file || !csv_file.good()) {
694 return errors::Internal("Failed to open output file: `", filename, "`");
695 }
696
697 LOG(WARNING) << "TF-TRT Non-Conversion Report saved at: `" << filename << "`";
698
699 csv_file << "OP Name" << sep << "Reason" << sep << "Count" << std::endl;
700
701 for (auto& op_details : nonconverted_ops_map) {
702 auto op_name = op_details.first;
703 auto op_data = op_details.second;
704
705 for (auto& reject_data : op_data) {
706 auto reason = reject_data.first;
707 auto count = reject_data.second;
708 csv_file << op_name << sep << reason << sep << count << std::endl;
709 }
710 }
711
712 csv_file.close();
713
714 if (csv_file.bad() || csv_file.fail()) {
715 return errors::Internal("Error closing the file `", filename,
716 "`. The file might be corrupted.");
717 }
718
719 return Status::OK();
720 }
721
GenerateNonConversionReport(std::map<string,std::map<string,int>> & nonconverted_ops_map)722 string GenerateNonConversionReport(
723 std::map<string, std::map<string, int>>& nonconverted_ops_map) {
724 // Fetch whether to print a detailed version of the TF-TRT conversion report.
725 // TF_TRT_SHOW_DETAILED_REPORT triggers three possible behaviors:
726 // - If Number >= 1: Print detailed non-conversion report on stdout.
727 // Usage: TF_TRT_SHOW_DETAILED_REPORT=1
728 // - If non empty string: Exports the non-conversion report in CSV format at
729 // the path defined by the environment variable.
730 // This will also print the detailed non-conversion
731 // report on stdout.
732 // Usage: TF_TRT_SHOW_DETAILED_REPORT=/path/to/file.csv
733 // - Else: Print normal (undetailed) non-conversion report on
734 // stdout.
735
736 string detailed_report_var;
737 TF_CHECK_OK(ReadStringFromEnvVar("TF_TRT_SHOW_DETAILED_REPORT",
738 /*default_value=*/"", &detailed_report_var));
739
740 bool show_detailed_conversion_report = false;
741
742 if (detailed_report_var != "") {
743 // Checking if `TF_TRT_SHOW_DETAILED_REPORT` env var is a string or a number
744 if (detailed_report_var.find_first_not_of("-0123456789") != string::npos) {
745 const Status status = ExportNonConversionReportToCSV(
746 detailed_report_var, nonconverted_ops_map);
747
748 if (!status.ok()) {
749 // Log the error in case of issue, however do not stop execution.
750 LOG(ERROR) << "Problem encountered while generating the TF-TRT "
751 << "Non-Conversion Report in CSV Format:\n"
752 << status.error_message();
753 }
754 show_detailed_conversion_report = true;
755 } else if (std::stoi(detailed_report_var) >= 1) {
756 show_detailed_conversion_report = true;
757 }
758 }
759
760 string unsupported_op_report =
761 StrCat("\n\n", string(80, '#'), "\n",
762 "TensorRT unsupported/non-converted OP Report:");
763 int total_nonconverted_ops{0};
764
765 // <Reason, Count for this reason>
766 using ReasonCounterVector = std::vector<std::pair<string, int>>;
767 // <OP Name, Total Non-Converted for OP, <Reason, Count for this reason>>>
768 using NotConvertedOPTuple = std::tuple<string, int, ReasonCounterVector>;
769
770 std::vector<NotConvertedOPTuple> nonconverted_ops_vec;
771
772 // Populate the vector from the map
773 for (auto& nonconverted_op_data : nonconverted_ops_map) {
774 int total_nonconverted_op{0};
775 ReasonCounterVector reason_occurances_vect;
776
777 auto op_name = nonconverted_op_data.first;
778 auto op_data = nonconverted_op_data.second;
779
780 for (auto& notconversion_reason_data : op_data) {
781 auto reason_count = notconversion_reason_data.second;
782 total_nonconverted_op += reason_count;
783 reason_occurances_vect.push_back(notconversion_reason_data);
784 }
785
786 // Sort in descending number of occurances for the reasons why a given
787 // TensorFlow OP was not converted.
788 std::sort(reason_occurances_vect.begin(), reason_occurances_vect.end(),
789 [](const std::pair<string, int>& a,
790 const std::pair<string, int>& b) -> bool {
791 return a.second > b.second;
792 });
793
794 nonconverted_ops_vec.push_back(std::make_tuple(
795 op_name, total_nonconverted_op, reason_occurances_vect));
796 }
797
798 // Sort the vector by descending OP names.
799 std::sort(nonconverted_ops_vec.begin(), nonconverted_ops_vec.end(),
800 [](const NotConvertedOPTuple& a, const NotConvertedOPTuple& b) {
801 return std::get<1>(a) > std::get<1>(b);
802 });
803
804 for (auto& notconverted_op_detail : nonconverted_ops_vec) {
805 auto& op_name = std::get<0>(notconverted_op_detail);
806 auto& op_total_nonconverted = std::get<1>(notconverted_op_detail);
807 total_nonconverted_ops += op_total_nonconverted;
808
809 unsupported_op_report = StrCat(unsupported_op_report, "\n\t- ", op_name,
810 " -> ", op_total_nonconverted, "x");
811
812 if (show_detailed_conversion_report) {
813 auto& nonconverted_ops_details = std::get<2>(notconverted_op_detail);
814
815 for (auto& nonconversion_details : nonconverted_ops_details) {
816 auto& reason = nonconversion_details.first;
817 auto& reason_count = nonconversion_details.second;
818 if (reason_count == 0) {
819 continue;
820 }
821
822 unsupported_op_report = StrCat(unsupported_op_report, "\n\t\t- ",
823 "[Count: ", reason_count, "x] ", reason);
824 }
825 unsupported_op_report = StrCat(unsupported_op_report, "\n");
826 }
827 }
828
829 unsupported_op_report =
830 StrCat(unsupported_op_report, "\n", string(80, '-'),
831 "\n\t- Total nonconverted OPs: ", total_nonconverted_ops,
832 "\n\t- Total nonconverted OP Types: ", nonconverted_ops_map.size(),
833 "\nFor more information see https://docs.nvidia.com/deeplearning",
834 "/frameworks/tf-trt-user-guide/index.html#supported-ops.", "\n",
835 string(80, '#'), "\n");
836
837 return unsupported_op_report;
838 }
839
// Partitions `tf_graph` into segments (clusters) of nodes that may be
// converted to TensorRT engines, writing the surviving segments to
// `segments`.
//
// High-level flow (details in the step comments below):
//   1. Classify every node as a TRT candidate or not (device placement,
//      user exclusions, implicit-batch constraints, `candidate_fn`, and the
//      TF_TRT_OP_DENYLIST env var), then greedily contract output edges
//      between compatible candidates using a union-find, refusing
//      contractions that would create a cycle.
//   2. Trim segment boundary nodes rejected by `input_candidate_fn` /
//      `output_candidate_fn`.
//   3. Drop segments with fewer than `options.minimum_segment_size`
//      effective (non-trivial) nodes and convert the rest to the return
//      format.
//   4. If more segments remain than TF_TRT_MAX_ALLOWED_ENGINES, keep only
//      the largest ones.
//
// Returns errors::Internal on inconsistent options or violated internal
// invariants; Status::OK() otherwise.
Status SegmentGraph(const Graph* tf_graph,
                    const grappler::GraphProperties* graph_properties,
                    const std::function<Status(const Node*)>& candidate_fn,
                    const std::function<bool(const Edge*)>& input_candidate_fn,
                    const std::function<bool(const Edge*)>& output_candidate_fn,
                    const SegmentOptions& options, SegmentVector* segments) {
  // Validate option combinations up front, before any graph work.
  if (!options.use_implicit_batch && !options.allow_dynamic_non_batch_dim) {
    return errors::Internal(
        "Explicit batch mode should allow dynamic non-batch dimensions");
  }

  if (options.use_implicit_batch && !options.maximum_batch_size.has_value()) {
    return errors::Internal("Implicit batch mode requires maximum_batch_size");
  }

  // Shape information is required to prove the absence of dynamic
  // non-batch dimensions.
  if (!options.allow_dynamic_non_batch_dim && !graph_properties) {
    return errors::Internal(
        "Need graph propertities to disallow dynamic non-batch dimensions");
  }

  // Steps:
  // 1. run the segmentation algorithm to find all the segments, which uses
  //    candidate_fn to determine the candidates segment nodes;
  // 2. for each segments, remove the nodes that are inputs/outputs of the
  //    segment but are not eligible, using input/output_candidate_fn to
  //    determine the eligibilities;
  // 3. convert the segment into expected return format and return the result.

  // --------------------------------- Step 1 ---------------------------------
  // Work on a lightweight mirror of the TF graph: the segmenter mutates the
  // graph (edge contraction below), and SimpleGraph keeps that cheap.
  auto graph = std::unique_ptr<SimpleGraph>(new SimpleGraph(tf_graph));

  // Fetch the user-provided TF operations denylisted for conversion by
  // TF-TRT (comma-separated op type names in TF_TRT_OP_DENYLIST).
  const absl::flat_hash_set<string> tftrt_op_denylist = [] {
    string tftrt_op_denylist_str;
    TF_CHECK_OK(ReadStringFromEnvVar("TF_TRT_OP_DENYLIST", /*default_value=*/"",
                                     &tftrt_op_denylist_str));
    absl::flat_hash_set<string> tftrt_op_denylist{};
    for (const auto& x : str_util::Split(tftrt_op_denylist_str, ",")) {
      tftrt_op_denylist.insert(x);
    }
    // Force a rehash of the flat hash set
    tftrt_op_denylist.rehash(0);
    return tftrt_op_denylist;
  }();

  // Use a union-find to collect the nodes that belong to the same
  // segment. A node value of nullptr indicates that the node is not a candidate
  // for TRT.

  // op type -> (rejection reason -> occurrence count); feeds the
  // non-conversion report logged after the loop.
  std::map<string, std::map<string, int>> nonconverted_ops_map = {};

  // Parsing each node of the graph: decide candidacy and seed one
  // union-find entry per node id.
  std::vector<UnionFind<SimpleNode*>> node_segments;
  for (int i = 0; i < graph->num_node_ids(); ++i) {
    SimpleNode* node = graph->FindNodeId(i);

    if (!node) {
      VLOG(3) << "Node " << i << " doesn't exist in the graph";
      continue;
    }

    const string node_op_type{node->tf_node()->type_string()};

    // Marks `node` as a non-candidate (sets it to nullptr) and records the
    // reason for the report.
    auto exclude_node = [&](absl::string_view reason) {
      VLOG(1) << "Not a TF-TRT candidate, "
              << "(Op type: " << node_op_type << "), "
              << "(Op name: " << node->name() << "), "
              << "(Reason: " << reason << ")";
      nonconverted_ops_map[node_op_type][string(reason)]++;
      node = nullptr;
    };
    std::optional<DeviceNameUtils::ParsedName> device_name =
        GetDeviceParsedName(node->tf_node());
    // GetDeviceParseName capitalizes the device type.
    if (!device_name.has_value() ||
        (device_name->has_type && device_name->type != "GPU")) {
      exclude_node("node can't be placed on GPU");
    } else if (options.exclude_node_list.count(node->name()) != 0) {
      exclude_node(
          "excluded by segmenter option. Most likely an input or "
          "output node.");
    } else if (options.use_implicit_batch &&
               !OperationCanBeTranslatedToImplicitBatch(graph_properties,
                                                        node->tf_node())) {
      exclude_node(
          "implicit batch mode requires input shape with at least two "
          "dimensions");
    } else if (!options.allow_dynamic_non_batch_dim &&
               OperationHasDynamicNonBatchDimension(graph_properties,
                                                    node->tf_node())) {
      exclude_node("dynamic non-batch dimensions not allowed");
    } else {
      const Status status = candidate_fn(node->tf_node());
      if (!status.ok()) {
        exclude_node(status.error_message());
      } else if (tftrt_op_denylist.contains(node->tf_node()->type_string())) {
        // WARNING verbosity since the user explicitly requests this behavior.
        LOG_WARNING_WITH_PREFIX
            << "Denylisted as TF-TRT candidate, "
            << "(Op type: " << node->tf_node()->type_string() << "), "
            << "(Op name: " << node->name() << ")";
        exclude_node("Denylisted with the env var TF_TRT_OP_DENYLIST");
      } else {
        VLOG(2) << "Accepted as a TF-TRT candidate, "
                << "(Op type: " << node->tf_node()->type_string() << "), "
                << "(Op name: " << node->name();
      }
    }
    // NOTE(review): `*device_name` is dereferenced even on the
    // `!device_name.has_value()` branch above (node is nullptr then, but
    // dereferencing an empty std::optional is UB) — TODO confirm/guard.
    AddSegmentForNode(graph_properties, &node_segments, node, *device_name,
                      options.use_implicit_batch);
  }

  LOG(WARNING) << GenerateNonConversionReport(nonconverted_ops_map);

  // The segmentation algorithm below visits nodes in reverse topological order
  // and attempts to merge nodes along output edges. That means that subgraphs
  // grow from the output-side of the network towards the inputs.
  //
  // In general this is not guaranteed to produce a globally optimal
  // segmentation. For example, consider graph with node {A, B, C, D} and edges
  // {A->B, A->C, B->D, C->D), where A, B, D are trt compatible but C is not, so
  // in theory we can choose to contract either A, B or B, D but not both, but
  // here it always choose to contract B, D.
  //
  // In the future if we have a measure of how beneficial it is to include a
  // given node in a TRT subgraph then we can revisit this algorithm to take
  // advantage of that information.
  std::vector<const SimpleNode*> order;
  order.reserve(graph->num_node_ids());
  // StableDFS gives a deterministic ordering so segmentation is
  // reproducible across runs.
  StableDFS(*graph, /*reverse=*/false, {graph->source_node()},
            /*enter=*/nullptr, [&order](const SimpleNode* n) {
              order.push_back(n);
              return true;
            });
  for (const SimpleNode* node : order) {
    // All output nodes of 'node' have been visited.
    VLOG(3) << "Trying node " << node->name() << " id=" << node->id();
    // 'node' must be a TRT candidate.
    if (node_segments[node->id()].Value() == nullptr) {
      VLOG(3) << "... not a TRT candidate";
      continue;
    }
    // Contract output edges to combine 'node' with output nodes. Repeat this
    // step until no output edges can be further contracted. This is because
    // contracting an output edge may unblock new edges for contracting.
    ClusterBatchSize expected_batch_size =
        node_segments[node->id()].Property().BatchSize();
    DeviceNameUtils::ParsedName expected_device_name =
        node_segments[node->id()].Property().DeviceName();
    VLOG(3) << "batch size " << expected_batch_size;
    while (true) {
      std::set<const SimpleEdge*, SimpleEdgePtrCompare> contract_edges;
      // TODO(bixia): consider merging the loop to find the edges and the loop
      // to contract the edges.
      for (const SimpleEdge* out_edge : node->out_edges()) {
        VLOG(3) << "... out node " << out_edge->dst()->name() << " ( "
                << out_edge->dst()->id() << " <- " << node->id() << " )";
        if (out_edge->IsControlEdge()) {
          VLOG(3) << "... ... Control Edge, Skipping";
          continue;
        }
        UnionFind<SimpleNode*>* out_cluster =
            &node_segments[out_edge->dst()->id()];
        // Out node must be a TRT candidate.
        if (out_cluster->Value() == nullptr) {
          VLOG(3) << "... ... not a TRT candidate";
          continue;
        }
        // Out node must have compatible batch size.
        ClusterBatchSize out_batch_size = out_cluster->Property().BatchSize();
        ClusterBatchSize merged_batch_size = expected_batch_size;
        if (!merged_batch_size.MergeIfCompatible(out_batch_size)) {
          VLOG(3) << "... ... incompatible batch sizes "
                  << expected_batch_size.ToString() << " "
                  << out_batch_size.ToString();
          continue;
        }

        // Both clusters must also agree on (or be mergeable to) a single
        // device assignment.
        const DeviceNameUtils::ParsedName& out_device_name =
            out_cluster->Property().DeviceName();
        std::optional<DeviceNameUtils::ParsedName> merged_device_name =
            MergeIfCompatible(expected_device_name, out_device_name);
        if (!merged_device_name.has_value()) {
          VLOG(3) << "... ... incompatible device names "
                  << expected_device_name << " " << out_device_name;
          continue;
        }

        // Reject contractions that would turn the remaining graph cyclic.
        if (CanContractEdge(out_edge, graph)) {
          VLOG(3) << "... ... can contract. new batch size "
                  << merged_batch_size.ToString();
          contract_edges.insert(out_edge);
          expected_batch_size = merged_batch_size;
          expected_device_name = *merged_device_name;
        } else {
          VLOG(3) << "... ... cannot contract, would form cycle";
        }
      }
      if (contract_edges.empty()) {
        break;
      }
      // Contract edges and collect the adjacent nodes into the same
      // segment/subgraph.
      while (!contract_edges.empty()) {
        const SimpleEdge* contract_edge = *contract_edges.begin();
        const SimpleNode* src = contract_edge->src();
        const SimpleNode* dst = contract_edge->dst();

        VLOG(3) << "Merge " << src->name() << " <- " << dst->name() << " ("
                << src->id() << " <- " << dst->id();
        TF_RETURN_IF_ERROR(
            node_segments[src->id()].Merge(&node_segments[dst->id()]));

        // Contracting the edge leaves disconnected graph edges.
        // Remove these from the graph and from 'contract_edges' so we
        // don't visit them again.
        SimpleEdge* e = const_cast<SimpleEdge*>(contract_edge);
        std::vector<const SimpleEdge*> remove_edges;
        ContractEdge(e, graph.get(), &remove_edges);

        for (const SimpleEdge* r : remove_edges) {
          contract_edges.erase(r);
          graph->RemoveEdge(r);
        }
      }
      // Sanity check: the union-find's merged properties must match the
      // values we accumulated while selecting edges to contract.
      if (expected_batch_size !=
          node_segments[node->id()].Property().BatchSize()) {
        return errors::Internal(
            "expected batch size is not the same as the actual batch size");
      }
      if (expected_device_name !=
          node_segments[node->id()].Property().DeviceName()) {
        return errors::Internal(
            "expected device name is not the same as the actual device name");
      }
    }
  }

  // Collect the segments/subgraphs. Each subgraph is represented by a
  // set of the names of the nodes in that subgraph.

  // A map from the segment identifier (currently the name of the root node of
  // the segment tree) to the segment nodes set.
  std::map<string, Segment> sg_map;

  for (auto& u : node_segments) {
    if ((u.Value() != nullptr) && (u.ParentValue() != nullptr)) {
      sg_map[u.ParentValue()->name()].nodes.insert(u.Value()->tf_node());
    }
    // The set's root carries the merged cluster property (batch size,
    // device) for the whole segment.
    if ((u.Value() != nullptr) && (u.ParentValue() == u.Value())) {
      sg_map[u.Value()->name()].property = u.Property();
    }
  }

  // --------------------------------- Step 2 ---------------------------------
  // Remove ineligible input/output nodes.
  for (auto& itr : sg_map) {
    std::set<const Node*, NodePtrCompare>& segment_nodes = itr.second.nodes;
    VLOG(1) << "Segment original size: " << segment_nodes.size();
    // Iterate to a fixed point: removing one boundary node can expose new
    // ineligible boundary nodes.
    while (true) {
      std::deque<const Node*> in_nodes_que, out_nodes_que;
      // Find an input node that is not eligible and add it to the queue.
      // Nodes that has no incoming edges should not be treated as "input",
      // as there are really no inputs to them. Similar for output nodes.
      for (auto node : segment_nodes) {
        bool added = false;
        for (const Edge* edge : node->in_edges()) {
          if (!edge->IsControlEdge() && !edge->src()->IsSource() &&
              !segment_nodes.count(edge->src())) {  // 'node' is an input node.
            if (!input_candidate_fn(edge)) {
              in_nodes_que.push_back(node);
              added = true;
              break;
            }
          }
        }
        if (added) continue;  // Only adding the node once to either queue.
        for (const Edge* edge : node->out_edges()) {
          if (!edge->dst()->IsSink() && !edge->IsControlEdge() &&
              !segment_nodes.count(edge->dst())) {  // 'node' is an output node.
            if (!output_candidate_fn(edge)) {
              out_nodes_que.push_back(node);
              break;
            }
          }
        }
      }
      if (in_nodes_que.empty() && out_nodes_que.empty()) {
        // No more ineligible input/output nodes.
        break;
      }
      // Now for each ineligible node, remove all of its inputs or outputs from
      // the subgraph.
      //
      // It can be proven that, if the original subgraph:
      // 1. is a DAG, and
      // 2. all paths between two nodes in the subgraph are all inside the
      //    subgraph
      // then after doing this operation the resulting subgraph will keep the
      // same properties 1 and 2.
      //
      // For simplicity we use heuristics: for input and const output nodes
      // remove all their inputs, and for non-const output nodes remove all
      // their outputs. In this way, for common cases the number of removed
      // nodes should be minimum.
      auto remove_nodes = [&segment_nodes](bool is_input_nodes,
                                           std::deque<const Node*>* que) {
        // Run a BFS on the queue to find all the input/output nodes.
        std::set<const Node*, NodePtrCompare> visited;
        // `logged` only throttles the VLOG(2) below so each node is
        // reported at most once.
        std::set<const Node*, NodePtrCompare> logged(que->begin(), que->end());
        while (!que->empty()) {
          auto node = que->front();
          que->pop_front();
          if (!visited.insert(node).second) continue;
          segment_nodes.erase(node);
          for (auto in : (is_input_nodes || node->type_string() == "Const")
                             ? node->in_nodes()
                             : node->out_nodes()) {
            if (segment_nodes.count(in)) {
              que->push_back(in);
              if (VLOG_IS_ON(2)) {
                if (!logged.count(in)) {
                  VLOG(2) << "----> Need to remove node " << in->name()
                          << " because one of its "
                          << (is_input_nodes ? "output" : "input")
                          << " nodes in the graph was removed: "
                          << node->name();
                  logged.insert(in);
                }
              }
            }
          }
        }
      };
      remove_nodes(true, &in_nodes_que);
      remove_nodes(false, &out_nodes_que);
    }
    VLOG(1) << "Segment new size: " << segment_nodes.size();
  }

  // --------------------------------- Step 3 ---------------------------------
  // Convert the segments into the expected return format
  std::vector<int> effective_nodes_counts;
  for (const auto& itr : sg_map) {
    const string& segment_root = itr.first;
    // Return format does not require set comparator.
    std::set<const Node*, NodePtrCompare> segment_nodes(
        itr.second.nodes.begin(), itr.second.nodes.end());
    if (VLOG_IS_ON(1) && !segment_nodes.empty()) {
      string s;
      for (auto node : segment_nodes) {
        StrAppend(&s, "\n[Op type: ", node->type_string(), "] ", node->name());
      }
      VLOG(1) << "Nodes in segment " << segments->size()
              << " with parent=" << segment_root << ":" << s;
    }

    // "Effective" nodes exclude trivial pass-through ops, so a segment of
    // only Identity-like ops is never worth an engine.
    const int num_effective_nodes = std::count_if(
        segment_nodes.begin(), segment_nodes.end(), [](const Node* node) {
          static auto noops =
              new std::set<string>{"Identity", "Snapshot", "StopGradient"};
          return noops->count(node->type_string()) == 0;
        });

    // Don't use segments whose number of effective nodes is small.
    if (num_effective_nodes == 0 ||
        num_effective_nodes < options.minimum_segment_size) {
      VLOG(1) << "Segment " << segments->size() << " has only "
              << num_effective_nodes << " effective nodes, dropping";
      continue;
    }
    segments->emplace_back(itr.second.property, segment_nodes);
    effective_nodes_counts.push_back(num_effective_nodes);
  }

  // --------------------------------- Step 4 ---------------------------------
  // If the number of segments exceeds max_engines, prune the smallest ones.

  int64_t max_trt_engine_ops;
  TF_CHECK_OK(ReadInt64FromEnvVar("TF_TRT_MAX_ALLOWED_ENGINES",
                                  /*default_value=*/20, &max_trt_engine_ops));

  if (max_trt_engine_ops <= 0) {
    // Non-positive means "no limit".
    LOG(WARNING) << "The environment variable TF_TRT_MAX_ALLOWED_ENGINES is "
                 << "<= 0. TF-TRT did not limit the number of TensorRT engines "
                 << "created.";

  } else {
    // NOTE(review): signed/unsigned comparison (size_t vs int64_t); safe
    // here since max_trt_engine_ops > 0, but worth an explicit cast.
    if (segments->size() > max_trt_engine_ops) {
      LOG(WARNING) << "A total of " << segments->size() << " segments with at "
                   << "least minimum_segment_size="
                   << options.minimum_segment_size << " nodes have been found. "
                   << "TF-TRT will only convert the " << max_trt_engine_ops
                   << " largest segments. You can change this behavior by "
                   << "modifying the environment variable "
                   << "TF_TRT_MAX_ALLOWED_ENGINES=" << max_trt_engine_ops;

      // Stable sort of the segment indices according to their effective sizes.
      std::vector<int> indices(segments->size());
      std::iota(indices.begin(), indices.end(), 0);

      std::stable_sort(indices.begin(), indices.end(),
                       [&effective_nodes_counts](int i1, int i2) {
                         return effective_nodes_counts[i1] >
                                effective_nodes_counts[i2];
                       });

      // Create a mask of segments to keep.
      std::vector<bool> mask = std::vector<bool>(segments->size(), false);

      for (int i = 0; i < max_trt_engine_ops; i++) {
        mask[indices[i]] = true;
      }

      // Gather the masked elements at the start of the array, in place.
      int j = 0;
      VLOG(1) << "The following segments have been accepted by TF-TRT:";
      for (int i = 0; i < segments->size(); i++) {
        if (mask[i]) {
          VLOG(1) << "[*] Segment " << i
                  << " [node count: " << effective_nodes_counts[i]
                  << "] accepted. Re-assigned "
                  << "segment id=" << j;
          segments->at(j) = segments->at(i);
          j++;
        }
      }

      VLOG(1) << "The following segments have been rejected by TF-TRT:";
      for (int i = 0; i < segments->size(); i++) {
        if (!mask[i]) {
          VLOG(1) << "[*] Segment " << i
                  << " [node count: " << effective_nodes_counts[i]
                  << "] rejected.";
        }
      }

      // Resize the array.
      segments->resize(max_trt_engine_ops);
    } else {
      LOG(WARNING) << "The environment variable TF_TRT_MAX_ALLOWED_ENGINES="
                   << max_trt_engine_ops << " has no effect since there are "
                   << "only " << segments->size() << " TRT Engines with at "
                   << "least minimum_segment_size="
                   << options.minimum_segment_size << " nodes.";
    }
  }

  return Status::OK();
}
1290
1291 } // namespace segment
1292 } // namespace tensorrt
1293 } // namespace tensorflow
1294
1295 #endif // GOOGLE_CUDA && GOOGLE_TENSORRT
1296