xref: /aosp_15_r20/external/tensorflow/tensorflow/python/compiler/tensorrt/utils.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""Exposes the Python wrapper conversion to trt_graph."""
16
17import collections
18import os
19import re
20
21from packaging import version
22
23from tensorflow.compiler.tf2tensorrt import _pywrap_py_utils
24from tensorflow.core.protobuf import rewriter_config_pb2
25from tensorflow.python.framework import dtypes
26
27
def disable_non_trt_optimizers_in_rewriter_config(rewriter_config):
  """Turns off every non-TRT optimization in `rewriter_config`, in place.

  Args:
    rewriter_config: RewriterConfig message to modify. Nothing is returned;
      the fields are overwritten on the object that is passed in.
  """
  config_enum = rewriter_config_pb2.RewriterConfig

  # Optimizers that are controlled by a plain toggle: all switched to OFF.
  for optimizer_field in (
      "arithmetic_optimization",
      "auto_mixed_precision",
      "constant_folding",
      "debug_stripper",
      "dependency_optimization",
      "function_optimization",
      "implementation_selector",
      "layout_optimizer",
      "loop_optimization",
      "pin_to_host_optimization",
      "remapping",
      "scoped_allocator_optimization",
      "shape_optimization",
  ):
    setattr(rewriter_config, optimizer_field, config_enum.OFF)

  # Settings that are not simple ON/OFF toggles.
  rewriter_config.auto_parallel.enable = False
  rewriter_config.memory_optimization = config_enum.NO_MEM_OPT
  rewriter_config.disable_model_pruning = True
  rewriter_config.min_graph_nodes = -1

  # The meta optimizer itself needs to stay ON to allow TF-TRT.
  rewriter_config.disable_meta_optimizer = False
52
53
def version_tuple_to_string(ver_tuple):
  """Formats a three-component version tuple as a dot-separated string.

  Args:
    ver_tuple: Tuple holding exactly three version components, e.g.
      `(8, 2, 1)`.

  Returns:
    The components converted with `str` and joined by ".", e.g. "8.2.1".

  Raises:
    AssertionError: If `ver_tuple` is not a tuple or does not have exactly
      three elements.
  """
  assert isinstance(ver_tuple, tuple)
  assert len(ver_tuple) == 3

  return ".".join(map(str, ver_tuple))
60
61
def _is_tensorrt_version_greater_equal(trt_ver, target_ver):
  """Tells whether `trt_ver` is at least `target_ver`.

  Args:
    trt_ver: Version to test, as a 3-element tuple.
    target_ver: Version to compare against, as a 3-element tuple.

  Returns:
    True if `trt_ver >= target_ver` when both are compared as
    `packaging.version.Version` objects, False otherwise.
  """
  parsed_trt_ver, parsed_target_ver = (
      version.Version(version_tuple_to_string(ver_tuple))
      for ver_tuple in (trt_ver, target_ver))

  return parsed_trt_ver >= parsed_target_ver
67
68
def is_linked_tensorrt_version_greater_equal(major, minor=0, patch=0):
  """Checks the linked TensorRT version against a minimum version.

  Args:
    major: Required major version.
    minor: Required minor version (defaults to 0).
    patch: Required patch version (defaults to 0).

  Returns:
    True if the version reported by
    `_pywrap_py_utils.get_linked_tensorrt_version()` is greater than or equal
    to `(major, minor, patch)`.
  """
  linked_version = _pywrap_py_utils.get_linked_tensorrt_version()
  required_version = (major, minor, patch)
  return _is_tensorrt_version_greater_equal(linked_version, required_version)
72
73
def is_loaded_tensorrt_version_greater_equal(major, minor=0, patch=0):
  """Checks the loaded TensorRT version against a minimum version.

  Args:
    major: Required major version.
    minor: Required minor version (defaults to 0).
    patch: Required patch version (defaults to 0).

  Returns:
    True if the version reported by
    `_pywrap_py_utils.get_loaded_tensorrt_version()` is greater than or equal
    to `(major, minor, patch)`.
  """
  loaded_version = _pywrap_py_utils.get_loaded_tensorrt_version()
  required_version = (major, minor, patch)
  return _is_tensorrt_version_greater_equal(loaded_version, required_version)
77
78
def is_experimental_feature_activated(feature_name):
  """Determines if a TF-TRT experimental feature is enabled.

  This helper function checks if an experimental feature was enabled using
  the environment variable `TF_TRT_EXPERIMENTAL_FEATURES=feature_1,feature_2`.

  Whitespace around each comma-separated entry is ignored and empty entries
  never match, so `"feature_1, feature_2"` enables both features while an
  unset or empty variable enables none (not even the empty feature name).

  Args:
    feature_name: Name of the feature being tested for activation.

  Returns:
    True if `feature_name` is listed in `TF_TRT_EXPERIMENTAL_FEATURES`,
    False otherwise.
  """
  raw_value = os.environ.get("TF_TRT_EXPERIMENTAL_FEATURES", default="")

  # Normalize the entries: strip surrounding blanks and drop empty ones, which
  # appear for an unset variable, a trailing comma or a doubled comma.
  stripped_entries = (entry.strip() for entry in raw_value.split(","))
  enabled_features = [entry for entry in stripped_entries if entry]

  return feature_name in enabled_features
92
93
def _convert_dtype_id_to_str(dtype):
  """Translates dtype id(s) into the corresponding string name(s).

  Args:
    dtype: Either a single integer dtype id or an iterable of dtype ids.

  Returns:
    For an integer, the name stored in `dtypes._TYPE_TO_STRING` for that id.
    For anything else, a list with one name per id, in iteration order.
  """
  id_to_name = dtypes._TYPE_TO_STRING
  if not isinstance(dtype, int):
    return [id_to_name[dtype_id] for dtype_id in dtype]
  return id_to_name[dtype]
100
101
def get_node_compute_dtype(node):
  """Returns the compute DType of a GraphDef Node.

  The node attributes are probed in a fixed order and the first one that
  yields a usable value wins. The node is only read: attribute keys that are
  absent are skipped rather than indexed, because indexing a protobuf map with
  a missing key inserts a default entry into the message.

  Args:
    node: GraphDef node whose `attr` map is inspected.

  Returns:
    The dtype name as a string. For a non-empty `precision_mode` attribute
    this is "float32", "float16" or "int8" for "FP32", "FP16" and "INT8"
    respectively, and "unknown" for any other value. Returns None when no
    probed attribute provides a dtype.
  """
  precision_mode_to_dtype = {
      "FP32": "float32",
      "FP16": "float16",
      "INT8": "int8",
  }

  # Note: Order is important, by default TF Node compute dtype is mentioned
  # under `T` key, unless these nodes are one of ["TRTEngineOP", "Cast", "Plh"].
  for type_key in (
      "precision_mode",  # TRTEngineOp
      "DstT",  # Cast Nodes
      "dtype",  # Placeholder
      "T",  # Everything Else
  ):
    # Test membership first so that probing never adds attributes to `node`.
    if type_key not in node.attr:
      continue
    attr_value = node.attr[type_key]

    if type_key == "precision_mode":
      try:
        precision_mode = attr_value.s.decode("utf-8")
      except UnicodeDecodeError:
        continue
      if not precision_mode:
        # An empty precision mode carries no information: try the next key.
        continue
      return precision_mode_to_dtype.get(precision_mode, "unknown")

    try:
      return _convert_dtype_id_to_str(attr_value.type)
    except KeyError:
      # The dtype id has no registered string name: try the next key.
      continue

  return None
130
131
def get_node_io_shapes(node, key):
  """Returns the input/output shapes of a GraphDef Node.

  Args:
    node: GraphDef node to inspect.
    key: Name of the node attribute holding the list of shapes.

  Returns:
    A list with one entry per shape stored under `node.attr[key]`; each entry
    is the list of that shape's dimension sizes.
  """
  shape_protos = node.attr[key].list.shape
  return [[dim.size for dim in shape_proto.dim] for shape_proto in shape_protos]
138
139
def get_trtengineop_io_dtypes(node, key):
  """Returns the input/output dtypes of a TRTEngineOp.

  Args:
    node: TRTEngineOp node to inspect.
    key: Name of the node attribute holding the list of dtype ids.

  Returns:
    The result of `_convert_dtype_id_to_str` applied to the dtype ids stored
    under `node.attr[key]`.
  """
  io_dtype_ids = node.attr[key].list.type
  return _convert_dtype_id_to_str(io_dtype_ids)
143
144
def get_trtengineop_io_nodes_count(node, key):
  """Returns the number of input/output nodes of a TRTEngineOp.

  Args:
    node: TRTEngineOp node to inspect.
    key: Name of the node attribute holding the list of dtype ids.

  Returns:
    The number of dtype entries stored under `node.attr[key]`.
  """
  io_dtype_ids = node.attr[key].list.type
  return len(io_dtype_ids)
148
149
def get_trtengineop_node_op_count(graphdef, node_name):
  """Counts the number of nodes and OP types of a given TRTEngineOp.

  The ops are read from the first library function of `graphdef` whose
  signature name is `<node_name>_native_segment`.

  Args:
    graphdef: GraphDef whose function library is searched.
    node_name: Name of the TRTEngineOp node.

  Returns:
    A tuple `(node_count, ops_in_engine)` where `node_count` is the number of
    nodes in the matching function and `ops_in_engine` is a
    `collections.defaultdict(int)` mapping each op type to its number of
    occurrences. When no function matches, `node_count` is 0 and
    `ops_in_engine` is empty.
  """
  segment_name = f"{node_name}_native_segment"

  # Defaults for the "no matching function" case; without them the return
  # statement would fail on an unbound `node_count`.
  node_count = 0
  ops_in_engine = collections.defaultdict(int)

  for func in graphdef.library.function:
    if func.signature.name == segment_name:
      node_count = len(func.node_def)
      for node in func.node_def:
        ops_in_engine[node.op] += 1
      break  # Only the first matching function is counted.

  return node_count, ops_in_engine
160
161
class DTypeIndex(dict):
  """Dictionary that hands out incremental, 1-based indices to dtypes."""

  def get_dtype_index(self, dtype):
    """Returns the index of `dtype`, registering it first if it is new.

    Args:
      dtype: Hashable dtype key.

    Returns:
      The index already stored for `dtype`, or, for a dtype seen for the
      first time, the current number of entries plus one.
    """
    return self.setdefault(dtype, len(self) + 1)
169
170
def draw_graphdef_as_graphviz(graphdef, dot_output_filename):
  """Writes a GraphViz (.dot) rendering of a GraphDef to a file.

  The file is produced in three steps:
  - Step 1: One box per GraphDef node, plus one edge per node input.
  - Step 2: A legend cluster with one box per compute dtype found in step 1.
  - Step 3: Invisible edges from every legend box to every node without
    inputs, used to align the legend with the graph.

  Each graph box consequently mentions:
  - Op Type (for a TRTEngineOp: the node name and its segment node count)
  - Compute Dtype (through the fill color, see the legend)
  - Compute Device

  Once the file is written, usage instructions are printed to stdout.

  Args:
    graphdef: GraphDef to draw.
    dot_output_filename: Path of the .dot file to write (opened in "w" mode).
  """

  legend = DTypeIndex()
  root_node_names = []

  with open(dot_output_filename, "w") as dot_file:

    def emit(line):
      """Writes one line, newline-terminated, to the .dot file."""
      print(line, file=dot_file)

    for header_line in (
        "digraph tftrt_converted_graph {",
        "  graph [fontsize=10 fontname=\"Verdana\"];",
        # ColorScheme Documentation: https://graphviz.org/doc/info/colors.html
        "  node [style=filled height=0.55 colorscheme=set312 shape=box];",
        # Step 1: Parsing the graph and drawing OPs one by one.
        "\n  subgraph tensorflow_graph {",
        "    node [width=1.35];",
    ):
      emit(header_line)

    for node in graphdef.node:
      node_name = node.name
      fill_color = legend.get_dtype_index(get_node_compute_dtype(node))

      # Only the last "/"-separated component of the device is displayed.
      device = node.device.split("/")[-1] or "device:Unspecified"

      if node.op != "TRTEngineOp":
        title = f"{node.op}"
      else:
        segment_size, _ = get_trtengineop_node_op_count(graphdef, node_name)
        title = f"{node_name} [{segment_size}]"

      # Note: double space before <br/> is necessary for formatting.
      html_label = f"<b>{title}</b>  <br/><i>{device}</i>"

      emit(f"    \"{node_name}\" [label=<{html_label}> "
           f"fillcolor={fill_color}];")

      if not node.input:
        # Remembered for step 3, where the legend gets anchored to these nodes.
        root_node_names.append(node_name)
        continue

      for input_ref in node.input:
        # Keep what precedes the first ":" and drop a leading "^" if present.
        producer_name = re.sub(r"^\^", "", input_ref.split(":")[0])
        emit(f"  \"{producer_name}\" -> \"{node_name}\";")
    emit("  }")

    # Step 2: Creating the DType Nodes previously found in Step 1.
    for legend_line in (
        "\n  subgraph cluster_legend {",
        "    label=\"Compute Dtype Legend\";",
        "    margin=\"30\";",
        "    node [width=2];",
    ):
      emit(legend_line)

    for dtype_name, legend_color in legend.items():
      emit(f"    {dtype_name} [fillcolor={legend_color} "
           f"label=<<b>{dtype_name}</b>>];")

    emit("  }")

    # Step 3: Alignment of the legend with the graph.
    emit("\n  edge[style=\"invisible\", dir=\"none\"];")
    for dtype_name in legend:
      for root_name in root_node_names:
        emit(f"  \"{dtype_name}\" -> \"{root_name}\"")

    emit("}")

  for message in (
      "\n===================================================================",
      f"Graph Visualization Exported to: `{dot_output_filename}`.",
      "We recommend using https://edotor.net/ to visualize the .dot file.",
      "You can also use `graphviz` utility to convert them to PNG format:",
      "  - `sudo apt install -y graphviz`",
      "  - `dot -Tpng <input_filename>.dot -o <output_filename>.png`",
      "===================================================================\n",
  ):
    print(message)
260