xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2tensorrt/segment/segment.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2tensorrt/segment/segment.h"
17 
#include <algorithm>
#include <cstdint>
#include <fstream>
#include <functional>
#include <map>
#include <numeric>
#include <optional>
#include <queue>
#include <set>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
26 
27 #include "absl/container/flat_hash_set.h"
28 #include "absl/strings/str_cat.h"
29 #include "absl/strings/str_format.h"
30 #include "tensorflow/compiler/tf2tensorrt/common/utils.h"
31 #include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
32 #include "tensorflow/core/common_runtime/graph_constructor.h"
33 #include "tensorflow/core/graph/algorithm.h"
34 #include "tensorflow/core/graph/graph.h"
35 #include "tensorflow/core/lib/core/errors.h"
36 #include "tensorflow/core/lib/core/status.h"
37 #include "tensorflow/core/lib/strings/str_util.h"
38 #include "tensorflow/core/lib/strings/strcat.h"
39 #include "tensorflow/core/platform/types.h"
40 #include "tensorflow/core/util/env_var.h"
41 
42 #if GOOGLE_CUDA && GOOGLE_TENSORRT
43 
44 namespace tensorflow {
45 namespace tensorrt {
46 namespace segment {
47 namespace {
48 using absl::StrAppend;
49 using absl::StrAppendFormat;
50 using absl::StrCat;
51 using absl::StrJoin;
52 
// A simple graph representation that mirrors Graph. The segmenter modifies
// this structure in place, which saves memory because the original graph does
// not have to be copied. It is composed of edges and nodes; nodes keep
// pointers to the original TF nodes.
57 class SimpleNode;
58 class SimpleGraph;
59 class SimpleEdge {
60  public:
SimpleEdge(int id,SimpleNode * src,int src_port,SimpleNode * dst,int dst_port,bool is_control=false)61   SimpleEdge(int id, SimpleNode* src, int src_port, SimpleNode* dst,
62              int dst_port, bool is_control = false)
63       : id_(id),
64         src_(src),
65         src_port_(src_port),
66         dst_(dst),
67         dst_port_(dst_port),
68         control_(is_control) {}
~SimpleEdge()69   ~SimpleEdge() {}
70 
src() const71   SimpleNode* src() const { return src_; }
dst() const72   SimpleNode* dst() const { return dst_; }
src_output() const73   int src_output() const { return src_port_; }
dst_input() const74   int dst_input() const { return dst_port_; }
id() const75   int id() const { return id_; }
IsControlEdge() const76   bool IsControlEdge() const { return control_; }
77 
78  private:
79   int id_;
80   SimpleNode* src_;
81   int src_port_;
82   SimpleNode* dst_;
83   int dst_port_;
84   bool control_;
85 };
86 
87 class SimpleNode {
88  public:
89   SimpleNode(const Node* node, const int id);
90 
in_edges() const91   const std::vector<SimpleEdge*>& in_edges() const { return in_edges_; }
out_edges() const92   const std::vector<SimpleEdge*>& out_edges() const { return out_edges_; }
93 
in_nodes() const94   std::vector<SimpleNode*> in_nodes() const {
95     std::vector<SimpleNode*> res;
96     res.reserve(in_edges_.size());
97     for (const auto e : in_edges_) {
98       if (e) res.push_back(e->src());
99     }
100     return res;
101   }
102 
out_nodes() const103   std::vector<SimpleNode*> out_nodes() const {
104     std::vector<SimpleNode*> res;
105     res.reserve(out_edges_.size());
106     for (const auto e : out_edges_) {
107       if (e) res.push_back(e->dst());
108     }
109     return res;
110   }
111 
name() const112   const string& name() const { return node_->name(); }
tf_node() const113   const Node* tf_node() const { return node_; }
id() const114   int id() const { return id_; }
115 
116  private:
117   const Node* node_;
118   std::vector<SimpleEdge*> in_edges_;
119   std::vector<SimpleEdge*> out_edges_;
120   int id_;
121 
122   friend class SimpleGraph;
123 };
124 
125 class SimpleGraph {
126  public:
127   explicit SimpleGraph(const Graph* g);
128   ~SimpleGraph();
129 
130   void AddControlEdge(SimpleNode* src, SimpleNode* dst);
131   void AddEdge(SimpleNode* src, int out_port, SimpleNode* dst, int in_port);
132   void RemoveEdge(const SimpleEdge*);
FindNodeId(int node_id)133   SimpleNode* FindNodeId(int node_id) {
134     if (node_id < 0 || node_id > static_cast<int>(nodes_.size())) {
135       return nullptr;
136     }
137     return nodes_[node_id];
138   }
num_node_ids() const139   int num_node_ids() const { return nodes_.size(); }
source_node() const140   const SimpleNode* source_node() const { return nodes_[Graph::kSourceId]; }
sink_node() const141   const SimpleNode* sink_node() const { return nodes_[Graph::kSinkId]; }
142 
143  private:
144   const Graph* g_;
145   std::vector<SimpleNode*> nodes_;
146   std::vector<SimpleEdge*> edges_;
147   // free_edge_ids_ and free_node_ids_ contain freed indices.
148   std::set<int> free_edge_ids_;
149   std::set<int> free_node_ids_;
150 };
151 
SimpleNode(const Node * node,const int id)152 SimpleNode::SimpleNode(const Node* node, const int id) : node_(node), id_(id) {
153   if (node_) {
154     in_edges_.reserve(node_->in_edges().size());
155     out_edges_.reserve(node_->out_edges().size());
156   }
157 }
158 
SimpleGraph(const Graph * g)159 SimpleGraph::SimpleGraph(const Graph* g) : g_(g) {
160   int n_nodes = g_->num_node_ids();
161   nodes_.resize(n_nodes, nullptr);
162   nodes_[g->kSourceId] = new SimpleNode(g->source_node(), g->kSourceId);
163   nodes_[g->kSinkId] = new SimpleNode(g->sink_node(), g->kSinkId);
164   int n_edges = g->num_edge_ids();
165   edges_.resize(n_edges, nullptr);
166   for (int i = 2; i < n_nodes; i++) {
167     const auto n = g->FindNodeId(i);
168     if (n) {
169       nodes_[i] = new SimpleNode(n, i);
170     } else {
171       free_node_ids_.insert(i);
172     }
173   }
174   for (int i = 0; i < n_edges; i++) {
175     const auto e = g->FindEdgeId(i);
176     if (e) {
177       const auto tfsrc = e->src();
178       const auto tfdst = e->dst();
179       bool is_control = e->IsControlEdge();
180       auto src = nodes_[tfsrc->id()];
181       auto dst = nodes_[tfdst->id()];
182       auto edge = new SimpleEdge(i, src, e->src_output(), dst, e->dst_input(),
183                                  is_control);
184       edges_[i] = edge;
185       src->out_edges_.push_back(edge);
186       dst->in_edges_.push_back(edge);
187     } else {
188       free_edge_ids_.insert(i);
189     }
190   }
191 }
192 
AddEdge(SimpleNode * src,int out_port,SimpleNode * dst,int in_port)193 void SimpleGraph::AddEdge(SimpleNode* src, int out_port, SimpleNode* dst,
194                           int in_port) {
195   int i = edges_.size();
196   if (!free_edge_ids_.empty()) {
197     auto it = free_edge_ids_.begin();
198     i = *it;
199     free_edge_ids_.erase(it);
200   } else {
201     edges_.push_back(nullptr);
202   }
203   bool is_control = (out_port == Graph::kControlSlot);
204   is_control |= (in_port == Graph::kControlSlot);
205   auto edge = new SimpleEdge(i, src, out_port, dst, in_port, is_control);
206   edges_[i] = edge;
207   src->out_edges_.push_back(edge);
208   dst->in_edges_.push_back(edge);
209 }
210 
AddControlEdge(SimpleNode * src,SimpleNode * dst)211 void SimpleGraph::AddControlEdge(SimpleNode* src, SimpleNode* dst) {
212   AddEdge(src, Graph::kControlSlot, dst, Graph::kControlSlot);
213 }
214 
RemoveEdge(const SimpleEdge * edge)215 void SimpleGraph::RemoveEdge(const SimpleEdge* edge) {
216   auto src = edge->src();
217   auto dst = edge->dst();
218   for (auto it = src->out_edges_.begin(); it != src->out_edges_.end(); ++it) {
219     if (*it == edge) {
220       src->out_edges_.erase(it);
221       break;
222     }
223   }
224   for (auto it = dst->in_edges_.begin(); it != dst->in_edges_.end(); ++it) {
225     if (*it == edge) {
226       dst->in_edges_.erase(it);
227       break;
228     }
229   }
230 }
231 
~SimpleGraph()232 SimpleGraph::~SimpleGraph() {
233   for (auto x : nodes_) delete x;
234   for (auto x : edges_) delete x;
235 }
236 
237 // Define comparison functions for std::set with pointer keys so that behavior
238 // is deterministic. When using std::set with pointer key types, the items are
239 // sorted by pointer address which is non-deterministic. This can cause issues
240 // for INT8 mode because the graph is converted twice and non-determinism may
241 // cause a mismatch between the calibration tables of the conversions.
242 struct SimpleEdgePtrCompare {
operator ()tensorflow::tensorrt::segment::__anon2febd5fa0111::SimpleEdgePtrCompare243   bool operator()(const SimpleEdge* lhs, const SimpleEdge* rhs) const {
244     return lhs->id() < rhs->id();
245   }
246 };
247 
248 // Copied from TF ReverseDFS, which only works for Graph.
StableDFS(const SimpleGraph & g,bool reverse,const std::vector<const SimpleNode * > & start,const std::function<bool (const SimpleNode *)> & enter,const std::function<bool (const SimpleNode *)> & leave)249 void StableDFS(const SimpleGraph& g, bool reverse,
250                const std::vector<const SimpleNode*>& start,
251                const std::function<bool(const SimpleNode*)>& enter,
252                const std::function<bool(const SimpleNode*)>& leave) {
253   // Stack of work to do.
254   struct Work {
255     const SimpleNode* node;
256     bool leave;  // Are we entering or leaving n?
257   };
258   std::vector<Work> stack(start.size());
259   for (int i = 0; i < start.size(); ++i) {
260     stack[i] = Work{start[i], false};
261   }
262 
263   auto get_nodes = [reverse](const SimpleNode* n) {
264     return reverse ? n->in_nodes() : n->out_nodes();
265   };
266   std::vector<bool> visited(g.num_node_ids(), false);
267   while (!stack.empty()) {
268     Work w = stack.back();
269     stack.pop_back();
270 
271     auto n = w.node;
272     if (w.leave) {
273       if (leave && !leave(n)) return;
274       continue;
275     }
276 
277     if (visited[n->id()]) continue;
278     visited[n->id()] = true;
279     if (enter && !enter(n)) return;
280 
281     // Arrange to call leave(n) when all done with descendants.
282     if (leave) stack.push_back(Work{n, true});
283 
284     auto nodes = get_nodes(n);
285     std::vector<const SimpleNode*> nodes_sorted(nodes.begin(), nodes.end());
286     std::sort(nodes_sorted.begin(), nodes_sorted.end(),
287               [](const SimpleNode* lhs, const SimpleNode* rhs) {
288                 return lhs->name() < rhs->name();
289               });
290     for (const SimpleNode* node : nodes_sorted) {
291       if (!visited[node->id()]) {
292         stack.push_back(Work{node, false});
293       }
294     }
295   }
296 }
297 
CanContractEdge(const SimpleEdge * edge,const std::unique_ptr<SimpleGraph> & graph)298 bool CanContractEdge(const SimpleEdge* edge,
299                      const std::unique_ptr<SimpleGraph>& graph) {
300   const auto src = edge->src();
301   const auto dst = edge->dst();
302 
303   // Can't contract edge if doing so would cause a cycle in the
304   // graph. So, if there is a directed path from 'src' to 'dst', other
305   // than 'edge' (or any other direct edge from 'src' to 'dst'), then
306   // combining 'src' and 'dst' will cause a cycle along that path.
307   //
308   // In practice, to avoid modifying the graph and to take advantage
309   // of existing graph functions, we perform an equivalent.
310   //   1. Get all nodes incoming to 'dst', excluding 'src'
311   //   2. Reverse DFS from those nodes
312   //   3. If reverse DFS reaches 'src' then we have a cycle
313   //
314   // TODO(aaroey): there are several problems with the current approach:
315   // 1. src->dst->src, this is not detected but it should be;
316   // 2. src->dst->...(any node sequence that doesn't contain src)...->dst, this
317   //    is detected but it should not be.
318   //
319   // Note that it's fine that dst connects back to src indirectly (i.e. through
320   // a path with length > 1 that consists of intermedia nodes other than src).
321   // While loops is one example.
322   //
323   // The goal is to make sure that the trt subgraph:
324   // 1. has no loops (i.e. is a DAG), and
325   // 2. if there is a path in the subgraph from X to Y (X and Y are both nodes
326   //    in the subgraph), then all paths from X to Y are in the subgraph.
327   //
328   // To achieve this goal, the correct way seems to be:
329   // 1. remove any direct edge from src->dst;
330   // 2. detect if src can reach dst, if so they cannot be merged.
331   std::vector<const SimpleNode*> dfs_start_nodes;
332   for (const SimpleNode* node : dst->in_nodes()) {
333     if (node != src) {
334       dfs_start_nodes.push_back(node);
335     }
336   }
337   bool has_cycle = false;
338   StableDFS(*graph, /*reverse=*/true, dfs_start_nodes, /*enter=*/nullptr,
339             [&has_cycle, src](const SimpleNode* n) {
340               if (n == src) {
341                 has_cycle = true;
342                 return false;
343               }
344               return true;
345             });
346   return !has_cycle;
347 }
348 
349 // TODO(bixia): put this to a common utility file.
TensorPropertiesToString(const OpInfo::TensorProperties & prop)350 string TensorPropertiesToString(const OpInfo::TensorProperties& prop) {
351   string s = StrCat(DataTypeString(prop.dtype()), ": ");
352   StrAppend(&s, "[");
353   if (prop.shape().unknown_rank()) {
354     StrAppend(&s, "?");
355   } else {
356     StrAppend(&s, StrJoin(prop.shape().dim(), ",",
357                           [](string* out, const TensorShapeProto_Dim& d) {
358                             StrAppendFormat(out, "%d", d.size());
359                           }));
360   }
361   StrAppend(&s, "]");
362   return s;
363 }
364 
TensorPropertiesToString(const std::vector<OpInfo::TensorProperties> & properties)365 string TensorPropertiesToString(
366     const std::vector<OpInfo::TensorProperties>& properties) {
367   return StrJoin(properties, "; ",
368                  [](string* out, const OpInfo::TensorProperties& prop) {
369                    StrAppend(out, TensorPropertiesToString(prop));
370                  });
371 }
372 
373 // From the given list of input properties, returns the leading shape, which is
374 // the shape that determines the batch size of the operation. The leading shape
375 // is selected from the group of input shapes with the highest rank as follows:
376 //  . If all of those shapes have non-negative values for the batch dimension,
377 //    the leading shape is the one with the largest value for the batch
378 //    dimension.
379 //  . If some or all of those shapes have negative values for the batch
380 //    dimension, and the rest of those shapes have 1 for the batch dimension,
381 //    the leading shape is the first of those shapes with a negative value for
382 //    the batch dimension.
383 //  . Otherwise, we can't determine the leading shape for the operation and
384 //    have to exclude the operation from TRT.
385 //
386 // Examples:
387 //    case-1: a[1,3,4] + b[2,3,4] => leading shape [2,3,4]
388 //    case-2: a[2,3,4] + b[scalar] => leading shape [2,3,4]
389 //    case-3: a[-1,3,4] + b[1,3,4] => leading shape [-1,3,4]
390 //    case-4: a[-1,3,4] + b[2,3,4] => no leading shape
391 //
392 // We have to return "no leading shape" for case-4 to exclude such operation
393 // from being translated for this reason:
394 //   The actually input for "a" have to be in the shape of [2,3,4] for the
395 //   operation to be valid. On the other hand, if we translate the operation
396 //   to implicit batch mode, it will becomes a[3,4]+b[3,4] which is valid for
397 //   any input shape of "a".
398 //
399 // This routine assumes the input program is valid. For example, we shouldn't
400 // see invalid operation like a[2,3,4] + b[3,3,4]. It also assumes the input
401 // properties is not empty and all input have known shapes.
402 //
403 // TODO(bixia): find a way to share this knowledge with the converter.
404 // TODO(bixia): investigate the use of symbolic shape analysis to improve
405 //   segmentation, such as by requiring the dynamic dimensions to have the same
406 //   negative value.
FindLeadingShape(absl::Span<const OpInfo::TensorProperties> properties)407 std::optional<const TensorShapeProto*> FindLeadingShape(
408     absl::Span<const OpInfo::TensorProperties> properties) {
409   DCHECK(!properties.empty());
410   const TensorShapeProto* result;
411   int max_batch_dim_value;
412   auto choose_shape_with_higher_rank = [&](const TensorShapeProto* s) {
413     result = s;
414     max_batch_dim_value = s->dim_size() < 1 ? 1 : s->dim(0).size();
415   };
416 
417   DCHECK(!properties[0].shape().unknown_rank());
418   choose_shape_with_higher_rank(&properties[0].shape());
419 
420   for (const OpInfo::TensorProperties& p : properties.subspan(1)) {
421     DCHECK(!p.shape().unknown_rank());
422     if (p.shape().dim_size() < result->dim_size()) continue;
423 
424     if (p.shape().dim_size() > result->dim_size()) {
425       choose_shape_with_higher_rank(&p.shape());
426       continue;
427     }
428 
429     // Among the shapes with the same rank, choose the one with a dynamic batch
430     // size. If no shapes have a dynamic batch size, choose the one with the
431     // largest size.
432     if (result->dim_size() < 1) continue;
433 
434     if (p.shape().dim(0).size() < 0 || result->dim(0).size() < 0) {
435       if (p.shape().dim(0).size() < 0 && result->dim(0).size() >= 0) {
436         result = &p.shape();
437       } else {
438         max_batch_dim_value =
439             std::max<int>(max_batch_dim_value, p.shape().dim(0).size());
440       }
441 
442       continue;
443     }
444 
445     if (p.shape().dim(0).size() > result->dim(0).size()) {
446       result = &p.shape();
447       max_batch_dim_value = result->dim(0).size();
448     }
449   }
450 
451   if (result->dim_size() > 0 && result->dim(0).size() < 0) {
452     // dynamic batch size
453     if (max_batch_dim_value <= 1) {
454       return result;
455     } else {
456       return std::nullopt;
457     }
458   }
459 
460   return result;
461 }
462 
463 // Returns the inputs that are relevant to determinate the batch size of the
464 // operation. This routine handles the following cases:
465 //   . Operations that support implicit boradcasting, such as operation mul.
466 //     In this case, we need to inspect all the inputs in order to determine the
467 //     batch size of the operation.
468 //   . Special cases. Such as "Conv2DBackpropInput", "Conv3DBackpropInputV2".
469 //   . The batch size of a operation is determined by the first input of the
470 //     operation.
GetInputsToDeterminateBatchSize(const Node * node,const std::vector<OpInfo::TensorProperties> & all_inputs)471 absl::Span<const OpInfo::TensorProperties> GetInputsToDeterminateBatchSize(
472     const Node* node, const std::vector<OpInfo::TensorProperties>& all_inputs) {
473   // TODO(bixia): Find a way to share this knowledge with the converter.
474   static std::set<string> broadcast_supporting_ops = {
475       // ops corresponding to ConvertBinary in the converter
476       "Add",
477       "AddV2",
478       "Mul",
479       "Sub",
480       "Div",
481       "FloorDiv",
482       "RealDiv",
483       "Minimum",
484       "Maximum",
485       "Pow",
486       // other ops that need to need GetTrtBroadcastShape to convert
487       "BiasAdd",
488       "SquaredDifference",
489       "BatchMatMul",
490       "BatchMatMulV2",
491   };
492   const string& op = node->def().op();
493 
494   if (op == "Conv2DBackpropInput" || op == "Conv3DBackpropInputV2") {
495     DCHECK_EQ(all_inputs.size(), 3);
496     return absl::MakeSpan(all_inputs).subspan(2, 1);
497   }
498 
499   if (broadcast_supporting_ops.count(op)) {
500     return absl::MakeSpan(all_inputs);
501   }
502 
503   // This is the common case for the operations that don't support implicit
504   // broadcasting: the first operand determines its batch size. All otherwise
505   // cases are handled before reaching here.
506   return absl::MakeSpan(all_inputs).subspan(0, 1);
507 }
508 
509 // Returns true if the operation we can remove the implicit batch of the
510 // operation.
511 //
512 // In particular, if the input shape has dynamic rank or the input shape rank
513 // is less than 2, we can't remove the implicit batch dimension and generate
514 // a new operation for TRT translation.
OperationCanBeTranslatedToImplicitBatch(const grappler::GraphProperties * graph_properties,const Node * node)515 bool OperationCanBeTranslatedToImplicitBatch(
516     const grappler::GraphProperties* graph_properties, const Node* node) {
517   VLOG(3) << "process node " << node->name();
518   if (node->num_inputs() == 0) return true;
519   if (!graph_properties || !graph_properties->HasInputProperties(node->name()))
520     return false;
521 
522   VLOG(3) << "input shapes "
523           << TensorPropertiesToString(
524                  graph_properties->GetInputProperties(node->name()));
525 
526   const std::vector<OpInfo::TensorProperties>& all_input_properties =
527       graph_properties->GetInputProperties(node->name());
528   absl::Span<const OpInfo::TensorProperties> input_properties =
529       GetInputsToDeterminateBatchSize(node, all_input_properties);
530   if (absl::c_any_of(input_properties, [](const OpInfo::TensorProperties& p) {
531         return p.shape().unknown_rank();
532       })) {
533     return false;
534   }
535 
536   std::optional<const TensorShapeProto*> leading_shape =
537       FindLeadingShape(input_properties);
538   return leading_shape.has_value() && leading_shape.value()->dim_size() >= 2;
539 }
540 
541 // Returns true if we can't be sure that the operand with the given properties
542 // won't have negative values for non-batch dimensions.
543 //
HasDynamicNonBatchDimension(const OpInfo::TensorProperties & prop)544 bool HasDynamicNonBatchDimension(const OpInfo::TensorProperties& prop) {
545   const TensorShapeProto& shape = prop.shape();
546   if (shape.unknown_rank()) return true;
547 
548   // Scalar is a well specified shape, and TRT supports implicit broadcasting
549   // from scalar to other shapes.
550   if (shape.dim_size() == 0) return false;
551   for (int i = 1; i < shape.dim_size(); ++i) {
552     // The value of a dynamic dimension can be other negative values besides
553     // -1, representing the symbolic group of the dimension.
554     if (shape.dim(i).size() <= -1) {
555       return true;
556     }
557   }
558   return false;
559 }
560 
561 // Returns true if we can't be sure that the operation won't have dynamic
562 // non-batch dimension involved. We only check the shape of the first output
563 // assuming shape inference already propagates the shapes.
OperationHasDynamicNonBatchDimension(const grappler::GraphProperties * graph_properties,const Node * node)564 bool OperationHasDynamicNonBatchDimension(
565     const grappler::GraphProperties* graph_properties, const Node* node) {
566   VLOG(3) << "process node " << node->name();
567   // If the node doesn't have any input or output, not computation is involved.
568   if (node->num_inputs() == 0 || node->num_outputs() == 0) return false;
569 
570   // If the node doesn't have output properties, return true to be conservative.
571   if (!graph_properties->HasOutputProperties(node->name())) return true;
572   VLOG(3) << "output shapes "
573           << TensorPropertiesToString(
574                  graph_properties->GetOutputProperties(node->name()));
575   return HasDynamicNonBatchDimension(
576       graph_properties->GetOutputProperties(node->name()).at(0));
577 }
578 
ContractEdge(SimpleEdge * edge,SimpleGraph * graph,std::vector<const SimpleEdge * > * remove_edges)579 void ContractEdge(SimpleEdge* edge, SimpleGraph* graph,
580                   std::vector<const SimpleEdge*>* remove_edges) {
581   // Transfer all inputs and outputs of 'dst' to 'src' except edges
582   // connecting the two.
583   auto src = edge->src();
584   auto dst = edge->dst();
585 
586   // We can use '0' for input/output index because we don't need them
587   // to be accurate for the way we are using the graph.
588   std::vector<const SimpleEdge*> in_edges(dst->in_edges().begin(),
589                                           dst->in_edges().end());
590   for (const SimpleEdge* in_edge : in_edges) {
591     if (in_edge->IsControlEdge()) {
592       if (in_edge->src() != src) {
593         SimpleEdge* e = const_cast<SimpleEdge*>(in_edge);
594         graph->AddControlEdge(e->src(), src);
595       }
596     } else {
597       if (in_edge->src() != src) {
598         SimpleEdge* e = const_cast<SimpleEdge*>(in_edge);
599         if (e->src() == graph->source_node()) {
600           graph->AddEdge(e->src(), e->src_output(), src, Graph::kControlSlot);
601         } else {
602           graph->AddEdge(e->src(), e->src_output(), src, 0 /* input index */);
603         }
604       }
605     }
606   }
607 
608   std::vector<const SimpleEdge*> out_edges(dst->out_edges().begin(),
609                                            dst->out_edges().end());
610   for (const SimpleEdge* out_edge : out_edges) {
611     if (out_edge->IsControlEdge()) {
612       SimpleEdge* e = const_cast<SimpleEdge*>(out_edge);
613       graph->AddControlEdge(src, e->dst());
614     } else {
615       SimpleEdge* e = const_cast<SimpleEdge*>(out_edge);
616       if (e->dst() == graph->sink_node()) {
617         VLOG(1) << " edge to sink node " << src->name() << " -> "
618                 << e->dst()->name();
619         graph->AddEdge(src, Graph::kControlSlot, e->dst(), e->dst_input());
620       } else {
621         graph->AddEdge(src, 0 /* output index */, e->dst(), e->dst_input());
622       }
623     }
624   }
625 
626   // Return the edges that must be removed to disconnect 'dst' from
627   // the graph. We don't actually remove 'dst' since the caller holds
628   // references to all the nodes.
629   for (const auto& in_edge : dst->in_edges()) {
630     remove_edges->push_back(in_edge);
631   }
632   for (const auto& out_edge : dst->out_edges()) {
633     remove_edges->push_back(out_edge);
634   }
635 }
636 
637 // Returns a batch size representation for a segment that only contains the
638 // given node.
GetClusterBatchSizeForNode(const grappler::GraphProperties * graph_properties,const Node * node,bool use_implicit_batch)639 ClusterBatchSize GetClusterBatchSizeForNode(
640     const grappler::GraphProperties* graph_properties, const Node* node,
641     bool use_implicit_batch) {
642   ClusterBatchSize cluster_batch_size;
643   if (!use_implicit_batch || !node || node->num_inputs() == 0) {
644     return cluster_batch_size;
645   }
646 
647   const NodeDef& node_def = node->def();
648   if (node_def.attr().count(kTftrtOpMaxBatchSizeAttr)) {
649     cluster_batch_size.SetMaxBatchSize(
650         node_def.attr().at(kTftrtOpMaxBatchSizeAttr).i());
651   }
652 
653   // As shape inference cannot provide any useful information about the batch
654   // size, we keep it as missing.
655   if (!graph_properties ||
656       !graph_properties->HasInputProperties(node->name())) {
657     VLOG(3) << "doesn't have input property";
658     return cluster_batch_size;
659   }
660 
661   const std::vector<OpInfo::TensorProperties>& input_properties =
662       graph_properties->GetInputProperties(node->name());
663   std::optional<const TensorShapeProto*> optional_leading_shape =
664       FindLeadingShape(GetInputsToDeterminateBatchSize(node, input_properties));
665   DCHECK(optional_leading_shape.has_value());
666   const TensorShapeProto* leading_shape = optional_leading_shape.value();
667   DCHECK(!leading_shape->unknown_rank() && leading_shape->dim_size() >= 2);
668   VLOG(3) << "set batch size as " << leading_shape->dim(0).size();
669   return cluster_batch_size.SetBatchSize(leading_shape->dim(0).size());
670 }
671 
AddSegmentForNode(const grappler::GraphProperties * graph_properties,std::vector<UnionFind<SimpleNode * >> * segments,SimpleNode * node,const DeviceNameUtils::ParsedName & device_name,bool use_implicit_batch)672 void AddSegmentForNode(const grappler::GraphProperties* graph_properties,
673                        std::vector<UnionFind<SimpleNode*>>* segments,
674                        SimpleNode* node,
675                        const DeviceNameUtils::ParsedName& device_name,
676                        bool use_implicit_batch) {
677   ClusterProperty property(
678       GetClusterBatchSizeForNode(graph_properties,
679                                  node == nullptr ? nullptr : node->tf_node(),
680                                  use_implicit_batch),
681       device_name);
682   segments->emplace_back(node, std::move(property));
683 }
684 
685 }  // namespace
686 
ExportNonConversionReportToCSV(string filename,std::map<string,std::map<string,int>> & nonconverted_ops_map,string sep="|")687 Status ExportNonConversionReportToCSV(
688     string filename,
689     std::map<string, std::map<string, int>>& nonconverted_ops_map,
690     string sep = "|") {
691   std::fstream csv_file(filename, std::fstream::out | std::fstream::trunc);
692 
693   if (!csv_file || !csv_file.good()) {
694     return errors::Internal("Failed to open output file: `", filename, "`");
695   }
696 
697   LOG(WARNING) << "TF-TRT Non-Conversion Report saved at: `" << filename << "`";
698 
699   csv_file << "OP Name" << sep << "Reason" << sep << "Count" << std::endl;
700 
701   for (auto& op_details : nonconverted_ops_map) {
702     auto op_name = op_details.first;
703     auto op_data = op_details.second;
704 
705     for (auto& reject_data : op_data) {
706       auto reason = reject_data.first;
707       auto count = reject_data.second;
708       csv_file << op_name << sep << reason << sep << count << std::endl;
709     }
710   }
711 
712   csv_file.close();
713 
714   if (csv_file.bad() || csv_file.fail()) {
715     return errors::Internal("Error closing the file `", filename,
716                             "`. The file might be corrupted.");
717   }
718 
719   return Status::OK();
720 }
721 
GenerateNonConversionReport(std::map<string,std::map<string,int>> & nonconverted_ops_map)722 string GenerateNonConversionReport(
723     std::map<string, std::map<string, int>>& nonconverted_ops_map) {
724   // Fetch whether to print a detailed version of the TF-TRT conversion report.
725   // TF_TRT_SHOW_DETAILED_REPORT triggers three possible behaviors:
726   // - If Number >= 1:      Print detailed non-conversion report on stdout.
727   //                        Usage: TF_TRT_SHOW_DETAILED_REPORT=1
728   // - If non empty string: Exports the non-conversion report in CSV format at
729   //                        the path defined by the environment variable.
730   //                        This will also print the detailed non-conversion
731   //                        report on stdout.
732   //                        Usage: TF_TRT_SHOW_DETAILED_REPORT=/path/to/file.csv
733   // - Else:                Print normal (undetailed) non-conversion report on
734   //                        stdout.
735 
736   string detailed_report_var;
737   TF_CHECK_OK(ReadStringFromEnvVar("TF_TRT_SHOW_DETAILED_REPORT",
738                                    /*default_value=*/"", &detailed_report_var));
739 
740   bool show_detailed_conversion_report = false;
741 
742   if (detailed_report_var != "") {
743     // Checking if `TF_TRT_SHOW_DETAILED_REPORT` env var is a string or a number
744     if (detailed_report_var.find_first_not_of("-0123456789") != string::npos) {
745       const Status status = ExportNonConversionReportToCSV(
746           detailed_report_var, nonconverted_ops_map);
747 
748       if (!status.ok()) {
749         // Log the error in case of issue, however do not stop execution.
750         LOG(ERROR) << "Problem encountered while generating the TF-TRT "
751                    << "Non-Conversion Report in CSV Format:\n"
752                    << status.error_message();
753       }
754       show_detailed_conversion_report = true;
755     } else if (std::stoi(detailed_report_var) >= 1) {
756       show_detailed_conversion_report = true;
757     }
758   }
759 
760   string unsupported_op_report =
761       StrCat("\n\n", string(80, '#'), "\n",
762              "TensorRT unsupported/non-converted OP Report:");
763   int total_nonconverted_ops{0};
764 
765   // <Reason, Count for this reason>
766   using ReasonCounterVector = std::vector<std::pair<string, int>>;
767   // <OP Name, Total Non-Converted for OP, <Reason, Count for this reason>>>
768   using NotConvertedOPTuple = std::tuple<string, int, ReasonCounterVector>;
769 
770   std::vector<NotConvertedOPTuple> nonconverted_ops_vec;
771 
772   // Populate the vector from the map
773   for (auto& nonconverted_op_data : nonconverted_ops_map) {
774     int total_nonconverted_op{0};
775     ReasonCounterVector reason_occurances_vect;
776 
777     auto op_name = nonconverted_op_data.first;
778     auto op_data = nonconverted_op_data.second;
779 
780     for (auto& notconversion_reason_data : op_data) {
781       auto reason_count = notconversion_reason_data.second;
782       total_nonconverted_op += reason_count;
783       reason_occurances_vect.push_back(notconversion_reason_data);
784     }
785 
786     // Sort in descending number of occurances for the reasons why a given
787     // TensorFlow OP was not converted.
788     std::sort(reason_occurances_vect.begin(), reason_occurances_vect.end(),
789               [](const std::pair<string, int>& a,
790                  const std::pair<string, int>& b) -> bool {
791                 return a.second > b.second;
792               });
793 
794     nonconverted_ops_vec.push_back(std::make_tuple(
795         op_name, total_nonconverted_op, reason_occurances_vect));
796   }
797 
798   // Sort the vector by descending OP names.
799   std::sort(nonconverted_ops_vec.begin(), nonconverted_ops_vec.end(),
800             [](const NotConvertedOPTuple& a, const NotConvertedOPTuple& b) {
801               return std::get<1>(a) > std::get<1>(b);
802             });
803 
804   for (auto& notconverted_op_detail : nonconverted_ops_vec) {
805     auto& op_name = std::get<0>(notconverted_op_detail);
806     auto& op_total_nonconverted = std::get<1>(notconverted_op_detail);
807     total_nonconverted_ops += op_total_nonconverted;
808 
809     unsupported_op_report = StrCat(unsupported_op_report, "\n\t- ", op_name,
810                                    " -> ", op_total_nonconverted, "x");
811 
812     if (show_detailed_conversion_report) {
813       auto& nonconverted_ops_details = std::get<2>(notconverted_op_detail);
814 
815       for (auto& nonconversion_details : nonconverted_ops_details) {
816         auto& reason = nonconversion_details.first;
817         auto& reason_count = nonconversion_details.second;
818         if (reason_count == 0) {
819           continue;
820         }
821 
822         unsupported_op_report = StrCat(unsupported_op_report, "\n\t\t- ",
823                                        "[Count: ", reason_count, "x] ", reason);
824       }
825       unsupported_op_report = StrCat(unsupported_op_report, "\n");
826     }
827   }
828 
829   unsupported_op_report =
830       StrCat(unsupported_op_report, "\n", string(80, '-'),
831              "\n\t- Total nonconverted OPs: ", total_nonconverted_ops,
832              "\n\t- Total nonconverted OP Types: ", nonconverted_ops_map.size(),
833              "\nFor more information see https://docs.nvidia.com/deeplearning",
834              "/frameworks/tf-trt-user-guide/index.html#supported-ops.", "\n",
835              string(80, '#'), "\n");
836 
837   return unsupported_op_report;
838 }
839 
SegmentGraph(const Graph * tf_graph,const grappler::GraphProperties * graph_properties,const std::function<Status (const Node *)> & candidate_fn,const std::function<bool (const Edge *)> & input_candidate_fn,const std::function<bool (const Edge *)> & output_candidate_fn,const SegmentOptions & options,SegmentVector * segments)840 Status SegmentGraph(const Graph* tf_graph,
841                     const grappler::GraphProperties* graph_properties,
842                     const std::function<Status(const Node*)>& candidate_fn,
843                     const std::function<bool(const Edge*)>& input_candidate_fn,
844                     const std::function<bool(const Edge*)>& output_candidate_fn,
845                     const SegmentOptions& options, SegmentVector* segments) {
846   if (!options.use_implicit_batch && !options.allow_dynamic_non_batch_dim) {
847     return errors::Internal(
848         "Explicit batch mode should allow dynamic non-batch dimensions");
849   }
850 
851   if (options.use_implicit_batch && !options.maximum_batch_size.has_value()) {
852     return errors::Internal("Implicit batch mode requires maximum_batch_size");
853   }
854 
855   if (!options.allow_dynamic_non_batch_dim && !graph_properties) {
856     return errors::Internal(
857         "Need graph propertities to disallow dynamic non-batch dimensions");
858   }
859 
860   // Steps:
861   // 1. run the segmentation algorithm to find all the segments, which uses
862   //    candidate_fn to determine the candidates segment nodes;
863   // 2. for each segments, remove the nodes that are inputs/outputs of the
864   //    segment but are not eligible, using input/output_candidate_fn to
865   //    determine the eligibilities;
866   // 3. convert the segment into expected return format and return the result.
867 
868   // --------------------------------- Step 1 ---------------------------------
869   auto graph = std::unique_ptr<SimpleGraph>(new SimpleGraph(tf_graph));
870 
871   // Fetch the user-provide TF operations denylisted for conversion by TF-TRT.
872   const absl::flat_hash_set<string> tftrt_op_denylist = [] {
873     string tftrt_op_denylist_str;
874     TF_CHECK_OK(ReadStringFromEnvVar("TF_TRT_OP_DENYLIST", /*default_value=*/"",
875                                      &tftrt_op_denylist_str));
876     absl::flat_hash_set<string> tftrt_op_denylist{};
877     for (const auto& x : str_util::Split(tftrt_op_denylist_str, ",")) {
878       tftrt_op_denylist.insert(x);
879     }
880     // Force a rehash of the flat hash set
881     tftrt_op_denylist.rehash(0);
882     return tftrt_op_denylist;
883   }();
884 
885   // Use a union-find to collect the nodes that belong to the same
886   // segment. A node value of nullptr indicates that the node is not a candidate
887   // for TRT.
888 
889   std::map<string, std::map<string, int>> nonconverted_ops_map = {};
890 
891   // Parsing each node of the graph
892   std::vector<UnionFind<SimpleNode*>> node_segments;
893   for (int i = 0; i < graph->num_node_ids(); ++i) {
894     SimpleNode* node = graph->FindNodeId(i);
895 
896     if (!node) {
897       VLOG(3) << "Node " << i << " doesn't exist in the graph";
898       continue;
899     }
900 
901     const string node_op_type{node->tf_node()->type_string()};
902 
903     auto exclude_node = [&](absl::string_view reason) {
904       VLOG(1) << "Not a TF-TRT candidate, "
905               << "(Op type: " << node_op_type << "), "
906               << "(Op name: " << node->name() << "), "
907               << "(Reason: " << reason << ")";
908       nonconverted_ops_map[node_op_type][string(reason)]++;
909       node = nullptr;
910     };
911     std::optional<DeviceNameUtils::ParsedName> device_name =
912         GetDeviceParsedName(node->tf_node());
913     // GetDeviceParseName capitalizes the device type.
914     if (!device_name.has_value() ||
915         (device_name->has_type && device_name->type != "GPU")) {
916       exclude_node("node can't be placed on GPU");
917     } else if (options.exclude_node_list.count(node->name()) != 0) {
918       exclude_node(
919           "excluded by segmenter option. Most likely an input or "
920           "output node.");
921     } else if (options.use_implicit_batch &&
922                !OperationCanBeTranslatedToImplicitBatch(graph_properties,
923                                                         node->tf_node())) {
924       exclude_node(
925           "implicit batch mode requires input shape with at least two "
926           "dimensions");
927     } else if (!options.allow_dynamic_non_batch_dim &&
928                OperationHasDynamicNonBatchDimension(graph_properties,
929                                                     node->tf_node())) {
930       exclude_node("dynamic non-batch dimensions not allowed");
931     } else {
932       const Status status = candidate_fn(node->tf_node());
933       if (!status.ok()) {
934         exclude_node(status.error_message());
935       } else if (tftrt_op_denylist.contains(node->tf_node()->type_string())) {
936         // WARNING verbosity since the user explicitly requests this behavior.
937         LOG_WARNING_WITH_PREFIX
938             << "Denylisted as TF-TRT candidate, "
939             << "(Op type: " << node->tf_node()->type_string() << "), "
940             << "(Op name: " << node->name() << ")";
941         exclude_node("Denylisted with the env var TF_TRT_OP_DENYLIST");
942       } else {
943         VLOG(2) << "Accepted as a TF-TRT candidate, "
944                 << "(Op type: " << node->tf_node()->type_string() << "), "
945                 << "(Op name: " << node->name();
946       }
947     }
948     AddSegmentForNode(graph_properties, &node_segments, node, *device_name,
949                       options.use_implicit_batch);
950   }
951 
952   LOG(WARNING) << GenerateNonConversionReport(nonconverted_ops_map);
953 
954   // The segmentation algorithm below visits nodes in reverse topological order
955   // and attempts to merge nodes along output edges. That means that subgraphs
956   // grow from the output-side of the network towards the inputs.
957   //
958   // In general this is not guaranteed to produce a globally optimal
959   // segmentation. For example, consider graph with node {A, B, C, D} and edges
960   // {A->B, A->C, B->D, C->D), where A, B, D are trt compatible but C is not, so
961   // in theory we can choose to contract either A, B or B, D but not both, but
962   // here it always choose to contract B, D.
963   //
964   // In the future if we have a measure of how beneficial it is to include a
965   // given node in a TRT subgraph then we can revisit this algorithm to take
966   // advantage of that information.
967   std::vector<const SimpleNode*> order;
968   order.reserve(graph->num_node_ids());
969   StableDFS(*graph, /*reverse=*/false, {graph->source_node()},
970             /*enter=*/nullptr, [&order](const SimpleNode* n) {
971               order.push_back(n);
972               return true;
973             });
974   for (const SimpleNode* node : order) {
975     // All output nodes of 'node' have been visited.
976     VLOG(3) << "Trying node " << node->name() << " id=" << node->id();
977     // 'node' must be a TRT candidate.
978     if (node_segments[node->id()].Value() == nullptr) {
979       VLOG(3) << "... not a TRT candidate";
980       continue;
981     }
982     // Contract output edges to combine 'node' with output nodes. Repeat this
983     // step until no output edges can be further contracted. This is because
984     // contracting an output edge may unblock new edges for contracting.
985     ClusterBatchSize expected_batch_size =
986         node_segments[node->id()].Property().BatchSize();
987     DeviceNameUtils::ParsedName expected_device_name =
988         node_segments[node->id()].Property().DeviceName();
989     VLOG(3) << "batch size " << expected_batch_size;
990     while (true) {
991       std::set<const SimpleEdge*, SimpleEdgePtrCompare> contract_edges;
992       // TODO(bixia): consider merging the loop to find the edges and the loop
993       // to contract the edges.
994       for (const SimpleEdge* out_edge : node->out_edges()) {
995         VLOG(3) << "... out node " << out_edge->dst()->name() << " ( "
996                 << out_edge->dst()->id() << " <- " << node->id() << " )";
997         if (out_edge->IsControlEdge()) {
998           VLOG(3) << "... ... Control Edge, Skipping";
999           continue;
1000         }
1001         UnionFind<SimpleNode*>* out_cluster =
1002             &node_segments[out_edge->dst()->id()];
1003         // Out node must be a TRT candidate.
1004         if (out_cluster->Value() == nullptr) {
1005           VLOG(3) << "... ... not a TRT candidate";
1006           continue;
1007         }
1008         // Out node must have compatible batch size.
1009         ClusterBatchSize out_batch_size = out_cluster->Property().BatchSize();
1010         ClusterBatchSize merged_batch_size = expected_batch_size;
1011         if (!merged_batch_size.MergeIfCompatible(out_batch_size)) {
1012           VLOG(3) << "... ... incompatible batch sizes "
1013                   << expected_batch_size.ToString() << " "
1014                   << out_batch_size.ToString();
1015           continue;
1016         }
1017 
1018         const DeviceNameUtils::ParsedName& out_device_name =
1019             out_cluster->Property().DeviceName();
1020         std::optional<DeviceNameUtils::ParsedName> merged_device_name =
1021             MergeIfCompatible(expected_device_name, out_device_name);
1022         if (!merged_device_name.has_value()) {
1023           VLOG(3) << "... ... incompatible device names "
1024                   << expected_device_name << " " << out_device_name;
1025           continue;
1026         }
1027 
1028         if (CanContractEdge(out_edge, graph)) {
1029           VLOG(3) << "... ... can contract. new batch size "
1030                   << merged_batch_size.ToString();
1031           contract_edges.insert(out_edge);
1032           expected_batch_size = merged_batch_size;
1033           expected_device_name = *merged_device_name;
1034         } else {
1035           VLOG(3) << "... ... cannot contract, would form cycle";
1036         }
1037       }
1038       if (contract_edges.empty()) {
1039         break;
1040       }
1041       // Contract edges and collect the adjacent nodes into the same
1042       // segment/subgraph.
1043       while (!contract_edges.empty()) {
1044         const SimpleEdge* contract_edge = *contract_edges.begin();
1045         const SimpleNode* src = contract_edge->src();
1046         const SimpleNode* dst = contract_edge->dst();
1047 
1048         VLOG(3) << "Merge " << src->name() << " <- " << dst->name() << " ("
1049                 << src->id() << " <- " << dst->id();
1050         TF_RETURN_IF_ERROR(
1051             node_segments[src->id()].Merge(&node_segments[dst->id()]));
1052 
1053         // Contracting the edge leaves disconnected graph edges.
1054         // Remove these from the graph and from 'contract_edges' so we
1055         // don't visit them again.
1056         SimpleEdge* e = const_cast<SimpleEdge*>(contract_edge);
1057         std::vector<const SimpleEdge*> remove_edges;
1058         ContractEdge(e, graph.get(), &remove_edges);
1059 
1060         for (const SimpleEdge* r : remove_edges) {
1061           contract_edges.erase(r);
1062           graph->RemoveEdge(r);
1063         }
1064       }
1065       if (expected_batch_size !=
1066           node_segments[node->id()].Property().BatchSize()) {
1067         return errors::Internal(
1068             "expected batch size is not the same as the actual batch size");
1069       }
1070       if (expected_device_name !=
1071           node_segments[node->id()].Property().DeviceName()) {
1072         return errors::Internal(
1073             "expected device name is not the same as the actual device name");
1074       }
1075     }
1076   }
1077 
1078   // Collect the segments/subgraphs. Each subgraph is represented by a
1079   // set of the names of the nodes in that subgraph.
1080 
1081   // A map from the segment identifier (currently the name of the root node of
1082   // the segment tree) to the segment nodes set.
1083   std::map<string, Segment> sg_map;
1084 
1085   for (auto& u : node_segments) {
1086     if ((u.Value() != nullptr) && (u.ParentValue() != nullptr)) {
1087       sg_map[u.ParentValue()->name()].nodes.insert(u.Value()->tf_node());
1088     }
1089     if ((u.Value() != nullptr) && (u.ParentValue() == u.Value())) {
1090       sg_map[u.Value()->name()].property = u.Property();
1091     }
1092   }
1093 
1094   // --------------------------------- Step 2 ---------------------------------
1095   // Remove ineligible input/output nodes.
1096   for (auto& itr : sg_map) {
1097     std::set<const Node*, NodePtrCompare>& segment_nodes = itr.second.nodes;
1098     VLOG(1) << "Segment original size: " << segment_nodes.size();
1099     while (true) {
1100       std::deque<const Node*> in_nodes_que, out_nodes_que;
1101       // Find an input node that is not eligible and add it to the queue.
1102       // Nodes that has no incoming edges should not be treated as "input",
1103       // as there are really no inputs to them. Similar for output nodes.
1104       for (auto node : segment_nodes) {
1105         bool added = false;
1106         for (const Edge* edge : node->in_edges()) {
1107           if (!edge->IsControlEdge() && !edge->src()->IsSource() &&
1108               !segment_nodes.count(edge->src())) {  // 'node' is an input node.
1109             if (!input_candidate_fn(edge)) {
1110               in_nodes_que.push_back(node);
1111               added = true;
1112               break;
1113             }
1114           }
1115         }
1116         if (added) continue;  // Only adding the node once to either queue.
1117         for (const Edge* edge : node->out_edges()) {
1118           if (!edge->dst()->IsSink() && !edge->IsControlEdge() &&
1119               !segment_nodes.count(edge->dst())) {  // 'node' is an output node.
1120             if (!output_candidate_fn(edge)) {
1121               out_nodes_que.push_back(node);
1122               break;
1123             }
1124           }
1125         }
1126       }
1127       if (in_nodes_que.empty() && out_nodes_que.empty()) {
1128         // No more ineligible input/output nodes.
1129         break;
1130       }
1131       // Now for each ineligible node, remove all of its inputs or outputs from
1132       // the subgraph.
1133       //
1134       // It can be proven that, if the original subgraph:
1135       // 1. is a DAG, and
1136       // 2. all paths between two nodes in the subgraph are all inside the
1137       //    subgraph
1138       // then after doing this operation the resulting subgraph will keep the
1139       // same properties 1 and 2.
1140       //
1141       // For simplicity we use heuristics: for input and const output nodes
1142       // remove all their inputs, and for non-const output nodes remove all
1143       // their outputs. In this way, for common cases the number of removed
1144       // nodes should be minimum.
1145       auto remove_nodes = [&segment_nodes](bool is_input_nodes,
1146                                            std::deque<const Node*>* que) {
1147         // Run a BFS on the queue to find all the input/output nodes.
1148         std::set<const Node*, NodePtrCompare> visited;
1149         std::set<const Node*, NodePtrCompare> logged(que->begin(), que->end());
1150         while (!que->empty()) {
1151           auto node = que->front();
1152           que->pop_front();
1153           if (!visited.insert(node).second) continue;
1154           segment_nodes.erase(node);
1155           for (auto in : (is_input_nodes || node->type_string() == "Const")
1156                              ? node->in_nodes()
1157                              : node->out_nodes()) {
1158             if (segment_nodes.count(in)) {
1159               que->push_back(in);
1160               if (VLOG_IS_ON(2)) {
1161                 if (!logged.count(in)) {
1162                   VLOG(2) << "----> Need to remove node " << in->name()
1163                           << " because one of its "
1164                           << (is_input_nodes ? "output" : "input")
1165                           << " nodes in the graph was removed: "
1166                           << node->name();
1167                   logged.insert(in);
1168                 }
1169               }
1170             }
1171           }
1172         }
1173       };
1174       remove_nodes(true, &in_nodes_que);
1175       remove_nodes(false, &out_nodes_que);
1176     }
1177     VLOG(1) << "Segment new size: " << segment_nodes.size();
1178   }
1179 
1180   // --------------------------------- Step 3 ---------------------------------
1181   // Convert the segments into the expected return format
1182   std::vector<int> effective_nodes_counts;
1183   for (const auto& itr : sg_map) {
1184     const string& segment_root = itr.first;
1185     // Return format does not require set comparator.
1186     std::set<const Node*, NodePtrCompare> segment_nodes(
1187         itr.second.nodes.begin(), itr.second.nodes.end());
1188     if (VLOG_IS_ON(1) && !segment_nodes.empty()) {
1189       string s;
1190       for (auto node : segment_nodes) {
1191         StrAppend(&s, "\n[Op type: ", node->type_string(), "] ", node->name());
1192       }
1193       VLOG(1) << "Nodes in segment " << segments->size()
1194               << " with parent=" << segment_root << ":" << s;
1195     }
1196 
1197     const int num_effective_nodes = std::count_if(
1198         segment_nodes.begin(), segment_nodes.end(), [](const Node* node) {
1199           static auto noops =
1200               new std::set<string>{"Identity", "Snapshot", "StopGradient"};
1201           return noops->count(node->type_string()) == 0;
1202         });
1203 
1204     // Don't use segments whose number of effective nodes is small.
1205     if (num_effective_nodes == 0 ||
1206         num_effective_nodes < options.minimum_segment_size) {
1207       VLOG(1) << "Segment " << segments->size() << " has only "
1208               << num_effective_nodes << " effective nodes, dropping";
1209       continue;
1210     }
1211     segments->emplace_back(itr.second.property, segment_nodes);
1212     effective_nodes_counts.push_back(num_effective_nodes);
1213   }
1214 
1215   // --------------------------------- Step 4 ---------------------------------
1216   // If the number of segments exceeds max_engines, prune the smallest ones.
1217 
1218   int64_t max_trt_engine_ops;
1219   TF_CHECK_OK(ReadInt64FromEnvVar("TF_TRT_MAX_ALLOWED_ENGINES",
1220                                   /*default_value=*/20, &max_trt_engine_ops));
1221 
1222   if (max_trt_engine_ops <= 0) {
1223     LOG(WARNING) << "The environment variable TF_TRT_MAX_ALLOWED_ENGINES is "
1224                  << "<= 0. TF-TRT did not limit the number of TensorRT engines "
1225                  << "created.";
1226 
1227   } else {
1228     if (segments->size() > max_trt_engine_ops) {
1229       LOG(WARNING) << "A total of " << segments->size() << " segments with at "
1230                    << "least minimum_segment_size="
1231                    << options.minimum_segment_size << " nodes have been found. "
1232                    << "TF-TRT will only convert the " << max_trt_engine_ops
1233                    << " largest segments. You can change this behavior by "
1234                    << "modifying the environment variable "
1235                    << "TF_TRT_MAX_ALLOWED_ENGINES=" << max_trt_engine_ops;
1236 
1237       // Stable sort of the segment indices according to their effective sizes.
1238       std::vector<int> indices(segments->size());
1239       std::iota(indices.begin(), indices.end(), 0);
1240 
1241       std::stable_sort(indices.begin(), indices.end(),
1242                        [&effective_nodes_counts](int i1, int i2) {
1243                          return effective_nodes_counts[i1] >
1244                                 effective_nodes_counts[i2];
1245                        });
1246 
1247       // Create a mask of segments to keep.
1248       std::vector<bool> mask = std::vector<bool>(segments->size(), false);
1249 
1250       for (int i = 0; i < max_trt_engine_ops; i++) {
1251         mask[indices[i]] = true;
1252       }
1253 
1254       // Gather the masked elements at the start of the array, in place.
1255       int j = 0;
1256       VLOG(1) << "The following segments have been accepted by TF-TRT:";
1257       for (int i = 0; i < segments->size(); i++) {
1258         if (mask[i]) {
1259           VLOG(1) << "[*] Segment " << i
1260                   << " [node count: " << effective_nodes_counts[i]
1261                   << "] accepted. Re-assigned "
1262                   << "segment id=" << j;
1263           segments->at(j) = segments->at(i);
1264           j++;
1265         }
1266       }
1267 
1268       VLOG(1) << "The following segments have been rejected by TF-TRT:";
1269       for (int i = 0; i < segments->size(); i++) {
1270         if (!mask[i]) {
1271           VLOG(1) << "[*] Segment " << i
1272                   << " [node count: " << effective_nodes_counts[i]
1273                   << "] rejected.";
1274         }
1275       }
1276 
1277       // Resize the array.
1278       segments->resize(max_trt_engine_ops);
1279     } else {
1280       LOG(WARNING) << "The environment variable TF_TRT_MAX_ALLOWED_ENGINES="
1281                    << max_trt_engine_ops << " has no effect since there are "
1282                    << "only " << segments->size() << " TRT Engines with  at "
1283                    << "least minimum_segment_size="
1284                    << options.minimum_segment_size << " nodes.";
1285     }
1286   }
1287 
1288   return Status::OK();
1289 }
1290 
1291 }  // namespace segment
1292 }  // namespace tensorrt
1293 }  // namespace tensorflow
1294 
1295 #endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
1296