# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Utility helpers for the TF-TRT Python wrapper.

Covers rewriter-config setup, TensorRT version checks, experimental-feature
flags, and GraphDef inspection / GraphViz export of converted graphs.
"""

import collections
import html
import os
import re

from packaging import version

from tensorflow.compiler.tf2tensorrt import _pywrap_py_utils
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.framework import dtypes


def disable_non_trt_optimizers_in_rewriter_config(rewriter_config):
  """Modifies rewriter_config in place to disable all non-TRT optimizations.

  Args:
    rewriter_config: A `RewriterConfig` proto; it is mutated in place.
  """
  off = rewriter_config_pb2.RewriterConfig.OFF

  rewriter_config.arithmetic_optimization = off
  rewriter_config.auto_mixed_precision = off
  rewriter_config.auto_parallel.enable = False
  rewriter_config.constant_folding = off
  rewriter_config.debug_stripper = off
  rewriter_config.dependency_optimization = off
  # This one needs to be ON (i.e. not disabled) to allow TF-TRT to run.
  rewriter_config.disable_meta_optimizer = False
  rewriter_config.disable_model_pruning = True
  rewriter_config.function_optimization = off
  rewriter_config.implementation_selector = off
  rewriter_config.layout_optimizer = off
  rewriter_config.loop_optimization = off
  rewriter_config.memory_optimization = (
      rewriter_config_pb2.RewriterConfig.NO_MEM_OPT)
  # A negative value means optimization is never skipped for small graphs.
  rewriter_config.min_graph_nodes = -1
  rewriter_config.pin_to_host_optimization = off
  rewriter_config.remapping = off
  rewriter_config.scoped_allocator_optimization = off
  rewriter_config.shape_optimization = off


def version_tuple_to_string(ver_tuple):
  """Converts a `(major, minor, patch)` tuple into a "major.minor.patch" str.

  Args:
    ver_tuple: A tuple of exactly three version components.

  Returns:
    The components converted with `str` and joined with ".".
  """
  assert isinstance(ver_tuple, tuple)
  assert len(ver_tuple) == 3

  return ".".join(str(x) for x in ver_tuple)


def _is_tensorrt_version_greater_equal(trt_ver, target_ver):
  """Returns True if version tuple `trt_ver` >= version tuple `target_ver`."""
  trt_ver = version.Version(version_tuple_to_string(trt_ver))
  target_ver = version.Version(version_tuple_to_string(target_ver))

  return trt_ver >= target_ver


def is_linked_tensorrt_version_greater_equal(major, minor=0, patch=0):
  """Checks the version from `get_linked_tensorrt_version()` against a target.

  Returns:
    True if the reported linked TensorRT version is >= major.minor.patch.
  """
  ver = _pywrap_py_utils.get_linked_tensorrt_version()
  return _is_tensorrt_version_greater_equal(ver, (major, minor, patch))


def is_loaded_tensorrt_version_greater_equal(major, minor=0, patch=0):
  """Checks the version from `get_loaded_tensorrt_version()` against a target.

  Returns:
    True if the reported loaded TensorRT version is >= major.minor.patch.
  """
  ver = _pywrap_py_utils.get_loaded_tensorrt_version()
  return _is_tensorrt_version_greater_equal(ver, (major, minor, patch))


def is_experimental_feature_activated(feature_name):
  """Determines if a TF-TRT experimental feature is enabled.

  This helper function checks if an experimental feature was enabled using
  the environment variable `TF_TRT_EXPERIMENTAL_FEATURES=feature_1,feature_2`.
  Whitespace around each entry is ignored, and empty entries never match.

  Args:
    feature_name: Name of the feature being tested for activation.

  Returns:
    True if `feature_name` is listed in `TF_TRT_EXPERIMENTAL_FEATURES`.
  """
  raw_features = os.environ.get("TF_TRT_EXPERIMENTAL_FEATURES", default="")
  # Filter empty entries so that an unset variable (or "a,,b") does not make
  # `feature_name == ""` report as activated.
  activated = {
      entry.strip() for entry in raw_features.split(",") if entry.strip()
  }
  return feature_name in activated


def _convert_dtype_id_to_str(dtype):
  """Converts a dtype id (or an iterable of ids) to its string name(s).

  Raises:
    KeyError: If an id has no entry in `dtypes._TYPE_TO_STRING`.
  """
  if isinstance(dtype, int):
    return dtypes._TYPE_TO_STRING[dtype]
  else:
    return [dtypes._TYPE_TO_STRING[d] for d in dtype]


# Maps a TRTEngineOp `precision_mode` attribute to a dtype name.
_TRT_PRECISION_MODE_TO_DTYPE = {
    "FP32": "float32",
    "FP16": "float16",
    "INT8": "int8",
}


def get_node_compute_dtype(node):
  """Returns the compute DType of a GraphDef Node.

  Attributes are probed in priority order. By default a TF node's compute
  dtype lives under `T`, unless the node is a TRTEngineOp, Cast or
  Placeholder. The node itself is never modified.

  Args:
    node: A `NodeDef` proto.

  Returns:
    A dtype name such as "float32"; "unknown" for an unrecognized non-empty
    TRTEngineOp precision mode; or None if no probed attribute yields a dtype.
  """
  # Note: Order is important.
  for type_key in (
      "precision_mode",  # TRTEngineOp
      "DstT",  # Cast Nodes
      "dtype",  # Placeholder
      "T",  # Everything Else
  ):
    # Test membership first: indexing a missing key of a protobuf
    # message-valued map inserts a default entry, mutating the caller's graph.
    if type_key not in node.attr:
      continue
    attr_value = node.attr[type_key]

    if type_key == "precision_mode":
      try:
        precision_mode = attr_value.s.decode("utf-8")
      except UnicodeDecodeError:
        continue
      if not precision_mode:
        continue
      return _TRT_PRECISION_MODE_TO_DTYPE.get(precision_mode, "unknown")

    try:
      return _convert_dtype_id_to_str(attr_value.type)
    except KeyError:
      # The type id has no registered name (e.g. the attr is not a scalar
      # type); fall through to the next candidate key.
      continue
  return None


def get_node_io_shapes(node, key):
  """Returns the input/output shapes of a GraphDef Node.

  Args:
    node: A `NodeDef` proto.
    key: Name of a list-of-shapes attribute on `node`.

  Returns:
    One list of dimension sizes per shape, or [] if `key` is absent.
  """
  if key not in node.attr:  # Avoid inserting an empty attr into `node`.
    return []
  return [[dim.size for dim in shape.dim] for shape in node.attr[key].list.shape]


def get_trtengineop_io_dtypes(node, key):
  """Returns the input/output dtype names of a TRTEngineOp.

  Returns [] if `key` is absent (without adding it to `node`).
  """
  if key not in node.attr:
    return []
  return _convert_dtype_id_to_str(node.attr[key].list.type)


def get_trtengineop_io_nodes_count(node, key):
  """Returns the number of input/output nodes of a TRTEngineOp.

  Returns 0 if `key` is absent (without adding it to `node`).
  """
  if key not in node.attr:
    return 0
  return len(node.attr[key].list.type)


def get_trtengineop_node_op_count(graphdef, node_name):
  """Counts the number of nodes and OP types of a given TRTEngineOp.

  The engine's nodes are read from the library function named
  "<node_name>_native_segment".

  Args:
    graphdef: A `GraphDef` proto whose function library is searched.
    node_name: Name of the TRTEngineOp node.

  Returns:
    A `(node_count, ops_in_engine)` tuple where `ops_in_engine` maps op type to
    occurrence count. If no matching function exists, returns `(0, {})`.
  """
  node_count = 0
  ops_in_engine = collections.defaultdict(int)
  segment_name = f"{node_name}_native_segment"
  for func in graphdef.library.function:
    if func.signature.name == segment_name:
      node_count = len(func.node_def)
      for node in func.node_def:
        ops_in_engine[node.op] += 1
      break
  return node_count, ops_in_engine


class DTypeIndex(dict):
  """Helper class to create an index of dtypes with incremental values.

  Indices start at 1 and follow first-seen order.
  """

  def get_dtype_index(self, dtype):
    """Returns the index of `dtype`, assigning the next one if it is new."""
    if dtype not in self:
      self[dtype] = len(self) + 1
    return self[dtype]


# Number of colors defined by the Graphviz "set312" color scheme.
_SET312_NUM_COLORS = 12


def _graphviz_color(index):
  """Maps a 1-based DTypeIndex value onto set312's valid 1..12 range."""
  return (index - 1) % _SET312_NUM_COLORS + 1


def draw_graphdef_as_graphviz(graphdef, dot_output_filename):
  """Exports a GraphDef to GraphViz format.

  - Step 1: Drawing Each Node of the compute GraphDef.
  - Step 2: Create nodes for each collected dtype in the graph.
  - Step 3: Creating invisible links to align properly the legend.

  Each node consequently mentions:
  - Op Type (for a TRTEngineOp: its name and native-segment node count)
  - Compute Dtype (as the fill color; see the legend)
  - Compute Device

  Args:
    graphdef: A `GraphDef` proto to draw. It is not modified.
    dot_output_filename: Path of the `.dot` file to write (overwritten).
  """

  dtype_index = DTypeIndex()

  with open(dot_output_filename, "w", encoding="utf-8") as f:
    print("digraph tftrt_converted_graph {", file=f)

    print("  graph [fontsize=10 fontname=\"Verdana\"];", file=f)
    # ColorScheme Documentation: https://graphviz.org/doc/info/colors.html
    print(
        "  node [style=filled height=0.55 colorscheme=set312 shape=box];",
        file=f)

    # Step 1: Parsing the graph and drawing OPs one by one.
    print("\n  subgraph tensorflow_graph {", file=f)
    print("    node [width=1.35];", file=f)
    nodes_with_no_inputs = []
    for node in graphdef.node:
      output_name = node.name

      node_precision = get_node_compute_dtype(node)
      color_idx = _graphviz_color(dtype_index.get_dtype_index(node_precision))

      device_key = node.device.split("/")[-1]
      if not device_key:
        device_key = "device:Unspecified"

      if node.op == "TRTEngineOp":
        node_count, _ = get_trtengineop_node_op_count(graphdef, output_name)
        node_label = f"{output_name} [{node_count}]"
      else:
        node_label = f"{node.op}"

      # Names are escaped because the label is Graphviz HTML-like markup and
      # TF names may contain characters such as '>'.
      # Note: double space before <br/> is necessary for formatting.
      node_label = (f"<b>{html.escape(node_label)}</b>  "
                    f"<br/><i>{html.escape(device_key)}</i>")

      print(
          f"    \"{output_name}\" [label=<{node_label}> "
          f"fillcolor={color_idx}];",
          file=f)

      if node.input:
        for input_full_name in node.input:
          # Drop the output slot suffix (":N") and the control-dependency
          # marker ("^") to recover the producing node's name.
          input_name = re.sub(r"^\^", "", input_full_name.split(":")[0])
          print(f"    \"{input_name}\" -> \"{output_name}\";", file=f)
      else:
        nodes_with_no_inputs.append(output_name)
    print("  }", file=f)

    # Step 2: Creating the DType Nodes previously found in Step 1.
    print("\n  subgraph cluster_legend {", file=f)
    print("    label=\"Compute Dtype Legend\";", file=f)
    print("    margin=\"30\";", file=f)
    print("    node [width=2];", file=f)

    for dtype, dtype_idx in dtype_index.items():
      # Quote the ID so it matches the quoted edge endpoints in Step 3.
      print(
          f"    \"{dtype}\" [fillcolor={_graphviz_color(dtype_idx)} "
          f"label=<<b>{html.escape(str(dtype))}</b>>];",
          file=f)

    print("  }", file=f)

    # Step 3: Alignment of the legend with the graph.
    print("\n  edge[style=\"invisible\", dir=\"none\"];", file=f)
    for dtype in dtype_index:
      for node_name in nodes_with_no_inputs:
        print(f"  \"{dtype}\" -> \"{node_name}\"", file=f)

    print("}", file=f)

  print("\n===================================================================")
  print(f"Graph Visualization Exported to: `{dot_output_filename}`.")
  print("We recommend using https://edotor.net/ to visualize the .dot file.")
  print("You can also use `graphviz` utility to convert them to PNG format:")
  print("  - `sudo apt install -y graphviz`")
  print("  - `dot -Tpng <input_filename>.dot -o <output_filename>.png`")
  print("===================================================================\n")