# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========================================================================
"""Tensor Tracer report generation utilities."""

import collections
import hashlib
import os

from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import tensor_tracer_pb2

# Prefix prepended to every line written to the report.
_TRACER_LOG_PREFIX = ' [>>>TT>>>]'
_MARKER_SECTION_BEGIN = '!!!!!!! section-begin:'
_MARKER_SECTION_END = '!!!!!!! section-end:'

_SECTION_NAME_CONFIG = 'configuration'
_SECTION_NAME_REASON = 'reason'
_SECTION_NAME_OP_LIST = 'op-list'
_SECTION_NAME_TENSOR_LIST = 'tensor-list'
_SECTION_NAME_CACHE_INDEX_MAP = 'cache-index-map'
_SECTION_NAME_GRAPH = 'graph'
_SECTION_NAME_TENSOR_TRACER_CHECKPOINT = 'tensor_tracer_checkpoint'

_FIELD_NAME_VERSION = 'version:'
_FIELD_NAME_DEVICE = 'device:'
_FIELD_NAME_TRACE_MODE = 'trace-mode:'
_FIELD_NAME_SUBMODE = 'submode:'
_FIELD_NAME_NUM_REPLICAS = 'num-replicas:'
_FIELD_NAME_NUM_REPLICAS_PER_HOST = 'num-replicas-per-host:'
_FIELD_NAME_NUM_HOSTS = 'num-hosts:'
_FIELD_NAME_NUM_OPS = 'number-of-ops:'
_FIELD_NAME_NUM_TENSORS = 'number-of-tensors:'
_FIELD_NAME_NUM_CACHE_INDICES = 'number-of-indices:'
_FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED = 'topological-sort-succeed:'

_CURRENT_VERSION = 'use-outside-compilation'

_TT_REPORT_PROTO = 'tensor_tracer_report.report_pb'


def topological_sort(g):
  """Performs topological sort on the given graph.

  Edges that leave 'NextIteration' ops are ignored, so the back edge of a
  while-loop is not reported as a cycle.

  Args:
    g: the graph. Only `g.get_operations()` is called on it; each returned op
      must expose `name`, `type`, `inputs`, `control_inputs`, `outputs` and
      `_control_outputs`.

  Returns:
    A pair `(contains_cycle, ops)`:
      * If no cycle is found, `contains_cycle` is False and `ops` is the list
        of all ops of the graph in topological order.
      * If a cycle is found, `contains_cycle` is True and `ops` is the set of
        ops that could not be ordered (their in-degree never reached zero).

  Raises:
    ValueError: if the in-degree of an op drops below zero, i.e. the consumer
      edges of the graph do not match its input edges.
  """
  def _is_loop_edge(op):
    """Returns true if the op is the end of a while-loop creating a cycle."""
    return op.type in ['NextIteration']

  def _in_op_degree(op):
    """Returns the number of incoming edges to the given op.

    The edge calculation skips the edges that come from 'NextIteration' ops.
    NextIteration creates a cycle in the graph. We break cycles by treating
    this op as 'sink' and ignoring all outgoing edges from it.

    Args:
      op: Tf.Operation

    Returns:
      the number of incoming edges.
    """
    producers = op.control_inputs + [in_tensor.op for in_tensor in op.inputs]
    return sum(1 for producer in producers if not _is_loop_edge(producer))

  sorted_ops = []
  op_in_degree = {op: _in_op_degree(op) for op in g.get_operations()}

  # Ops are popped from the end of the name-sorted frontier; this keeps the
  # resulting order independent of dictionary iteration order.
  frontier = [op for (op, degree) in op_in_degree.items() if degree == 0]
  frontier.sort(key=lambda op: op.name)
  while frontier:
    op = frontier.pop()
    # Remove the op from graph, and remove its outgoing edges.
    sorted_ops.append(op)
    if _is_loop_edge(op):
      # Outgoing edges of NextIteration were never counted as in-edges.
      continue
    # pylint: disable=protected-access
    consumers = list(op._control_outputs)
    # pylint: enable=protected-access
    for out_tensor in op.outputs:
      consumers.extend(out_tensor.consumers())
    consumers.sort(key=lambda consumer_op: consumer_op.name)
    for consumer in consumers:
      # For each deleted edge shift the bucket of the vertex.
      op_in_degree[consumer] -= 1
      if op_in_degree[consumer] == 0:
        frontier.append(consumer)
      if op_in_degree[consumer] < 0:
        raise ValueError('consumer:%s degree mismatch' % consumer.name)

  left_ops = set(op for (op, degree) in op_in_degree.items() if degree > 0)
  if left_ops:
    return (True, left_ops)
  else:
    assert len(g.get_operations()) == len(sorted_ops)
    return (False, sorted_ops)


class TensorTracerConfig(object):
  """Tensor Tracer config object."""

  def __init__(self):
    # Report format version; written to the config section and the proto.
    self.version = _CURRENT_VERSION
    # The remaining fields start unset and are filled in by the owner of
    # this object before a report is created.
    self.device_type = None
    self.num_replicas = None
    self.num_replicas_per_host = None
    self.num_hosts = None


class TensorTraceOrder(object):
  """Class that is responsible for storing the trace-id of the tensors."""

  def __init__(self, graph_order, traced_tensors):
    """Builds the cache index maps for the traced tensors.

    Args:
      graph_order: object exposing `tensor_to_idx`, a mapping from tensor
        name to tensor index (e.g. the result of `sort_tensors_and_ops`).
      traced_tensors: iterable of tensors (objects with a `name`) to trace.

    Raises:
      ValueError: if a traced tensor name is repeated or is missing from
        `graph_order.tensor_to_idx`.
    """
    self.graph_order = graph_order
    self.traced_tensors = traced_tensors
    self._create_tensor_maps()

  def _create_tensor_maps(self):
    """Creates tensor to cache id maps.

    Populates `tensorname_to_cache_idx` (tensor name -> cache index) and
    `cache_idx_to_tensor_idx` (list indexed by cache index holding the tensor
    index). Cache indices are assigned in the order of `traced_tensors`.
    """
    self.tensorname_to_cache_idx = {}
    self.cache_idx_to_tensor_idx = []
    for out_tensor in self.traced_tensors:
      tensor_name = out_tensor.name
      if tensor_name in self.tensorname_to_cache_idx:
        raise ValueError('Tensor name {} should not be already in '
                         'tensorname_to_cache_idx'.format(tensor_name))
      if tensor_name not in self.graph_order.tensor_to_idx:
        raise ValueError(
            'Tensor name {} is not in the tensor_to_idx, tensor_to_idx={} '
            .format(tensor_name, self.graph_order.tensor_to_idx))
      tensor_idx = self.graph_order.tensor_to_idx[tensor_name]
      cache_idx = len(self.tensorname_to_cache_idx)
      self.tensorname_to_cache_idx[tensor_name] = cache_idx
      self.cache_idx_to_tensor_idx.append(tensor_idx)
      # Sanity check: both maps must grow in lock-step.
      if len(self.tensorname_to_cache_idx) != len(
          self.cache_idx_to_tensor_idx):
        raise RuntimeError(
            'len(self.tensorname_to_cache_idx) must equal '
            'len(self.cache_idx_to_tensor_idx), got '
            'len(self.tensorname_to_cache_idx)={}, '
            'len(self.cache_idx_to_tensor_idx)={}'
            .format(
                len(self.tensorname_to_cache_idx),
                len(self.cache_idx_to_tensor_idx)))


def sort_tensors_and_ops(graph):
  """Returns a wrapper that has consistent tensor and op orders.

  Args:
    graph: the graph to wrap; passed to `topological_sort`.

  Returns:
    A `GraphWrapper` namedtuple with fields:
      graph: the input graph.
      operations: ops in topological order, or `graph.get_operations()` when
        the graph contains a cycle.
      op_to_idx: mapping from op name to its index in `operations`.
      tensors: outputs of all `operations`, in op order.
      tensor_to_idx: mapping from tensor name to its index in `tensors`.
      contains_cycle: True if `topological_sort` found a cycle.
      topological_order_or_cycle: second element returned by
        `topological_sort`.
  """
  graph_wrapper = collections.namedtuple('GraphWrapper',
                                         ['graph', 'operations', 'op_to_idx',
                                          'tensors', 'tensor_to_idx',
                                          'contains_cycle',
                                          'topological_order_or_cycle'])
  contains_cycle, topological_order_or_cycle = topological_sort(graph)
  if contains_cycle:
    # No valid topological order exists; fall back to the graph's own order.
    operations = graph.get_operations()
  else:
    operations = topological_order_or_cycle
  op_to_idx = {op.name: index for index, op in enumerate(operations)}
  tensors = []
  for op in operations:
    tensors.extend(op.outputs)
  tensor_to_idx = {tensor.name: index for index, tensor in enumerate(tensors)}
  return graph_wrapper(graph=graph, operations=operations, op_to_idx=op_to_idx,
                       tensors=tensors, tensor_to_idx=tensor_to_idx,
                       contains_cycle=contains_cycle,
                       topological_order_or_cycle=topological_order_or_cycle)


class OpenReportFile(object):
  """Context manager for writing report file.

  Entering the context yields the opened file, or None when
  `tt_parameters.report_file_path` is empty. The file, if any, is closed on
  exit.
  """

  def __init__(self, tt_parameters):
    """Opens `tt_parameters.report_file_path` for writing, if it is set.

    Args:
      tt_parameters: object exposing `report_file_path`.

    Raises:
      IOError: propagated unchanged from `gfile.Open` if the open fails.
    """
    if not tt_parameters.report_file_path:
      self._report_file = None
      return
    self._report_file = gfile.Open(tt_parameters.report_file_path, 'w')

  def __enter__(self):
    return self._report_file

  def __exit__(self, unused_type, unused_value, unused_traceback):
    if self._report_file:
      self._report_file.close()


def proto_fingerprint(message_proto):
  """Returns the SHA-256 hex digest of the serialized proto.

  Args:
    message_proto: object exposing `SerializeToString()` returning bytes.

  Returns:
    The hexadecimal SHA-256 digest of the serialized bytes, as a string.
  """
  # NOTE(review): the fingerprint is only as stable as the bytes produced by
  # SerializeToString() for equal messages -- confirm this holds for the
  # protos fingerprinted here.
  serialized_message = message_proto.SerializeToString()
  hasher = hashlib.sha256(serialized_message)
  return hasher.hexdigest()


class TTReportHandle(object):
  """Utility class responsible for creating a tensor tracer report."""

  def __init__(self):
    # Maps an op or tensor name to the explanation recorded for it.
    self.instrument_records = {}
    # File currently being written by create_report; None means that
    # _write_report falls back to logging.
    self._report_file = None

  def instrument(self, name, explanation):
    """Records `explanation` for `name`, replacing any previous record."""
    self.instrument_records[name] = explanation

  def instrument_op(self, op, explanation):
    """Records `explanation` under the name of the given op."""
    self.instrument(op.name, explanation)

  def instrument_tensor(self, tensor, explanation):
    """Records `explanation` under the name of the given tensor."""
    self.instrument(tensor.name, explanation)

  def create_report_proto(self, tt_config, tt_parameters, tensor_trace_order,
                          tensor_trace_points, collected_signature_types):
    """Creates and returns a proto that stores tensor tracer configuration.

    Args:
      tt_config: TensorTracerConfig object holding information about the run
        environment (device, # cores, # hosts), and tensor tracer version
        information.
      tt_parameters: TTParameters objects storing the user provided parameters
        for tensor tracer.
      tensor_trace_order: TensorTraceOrder object storing a topological order
        of the graph.
      tensor_trace_points: Programmatically added trace_points/checkpoints.
        Here it is queried by tensor name (`name in tensor_trace_points` and
        `tensor_trace_points[name]`).
        NOTE(review): `_write_trace_points` iterates its argument as
        (tensor, checkpoint_name) pairs instead -- confirm that callers pass
        the form each method expects.
      collected_signature_types: The signature types collected, e.g. norm,
        max, min, mean... A mapping from signature name to a sortable value;
        signatures are appended to the proto in increasing order of that
        value.

    Returns:
      TensorTracerReport proto.
    """
    report = tensor_tracer_pb2.TensorTracerReport()
    report.config.version = tt_config.version
    report.config.device = tt_config.device_type
    report.config.num_cores = tt_config.num_replicas
    report.config.num_hosts = tt_config.num_hosts
    report.config.num_cores_per_host = tt_config.num_replicas_per_host
    report.config.submode = tt_parameters.submode
    report.config.trace_mode = tt_parameters.trace_mode

    for signature_name, _ in sorted(collected_signature_types.items(),
                                    key=lambda x: x[1]):
      report.config.signatures.append(signature_name)

    cache_indices = tensor_trace_order.tensorname_to_cache_idx
    for tensor in tensor_trace_order.graph_order.tensors:
      tensor_def = tensor_tracer_pb2.TensorTracerReport.TracedTensorDef()
      tensor_def.name = tensor.name
      if tensor.name in cache_indices:
        tensor_def.is_traced = True
        tensor_def.cache_index = cache_indices[tensor.name]
      else:
        # To prevent small changes affecting the fingerprint calculation, avoid
        # writing the untraced tensors to metadata. Fingerprints will be
        # different only when the list of the traced tensors are different.
        if tt_parameters.use_fingerprint_subdir:
          continue
        tensor_def.is_traced = False

      if tensor.name in tensor_trace_points:
        tensor_def.trace_point_name = tensor_trace_points[tensor.name]
      # An explanation recorded for the tensor itself takes precedence over
      # one recorded for the op that produces it.
      if tensor.name in self.instrument_records:
        tensor_def.explanation = self.instrument_records[tensor.name]
      elif tensor.op.name in self.instrument_records:
        tensor_def.explanation = self.instrument_records[tensor.op.name]
      report.tensordef[tensor.name].CopyFrom(tensor_def)
    # The fingerprint is computed before the graphdef is attached, so it
    # covers only the config and the tensor definitions.
    report.fingerprint = proto_fingerprint(report)
    logging.info('TensorTracerProto fingerprint is %s.',
                 report.fingerprint)
    tf_graph = tensor_trace_order.graph_order.graph
    report.graphdef.CopyFrom(tf_graph.as_graph_def())
    return report

  def report_proto_path(self, trace_dir, summary_tag_name):
    """Returns the path where report proto should be written.

    Args:
      trace_dir: String denoting the trace directory.
      summary_tag_name: Name of the unique tag that relates to
                        the report. Any '/' in it is replaced with '_'.

    Returns:
      A string denoting the path to the report proto.
    """
    filename = _TT_REPORT_PROTO + '.' + summary_tag_name.replace('/', '_')
    return os.path.join(trace_dir, filename)

  def write_report_proto(self, report_path, report_proto, tt_parameters):
    """Writes the given report proto to `report_path`.

    Args:
      report_path: path of the file to write the serialized proto to.
      report_proto: proto to serialize.
      tt_parameters: object exposing `trace_dir`, which is created first if
        it does not exist.
    """
    gfile.MakeDirs(tt_parameters.trace_dir)
    with gfile.GFile(report_path, 'wb') as f:
      f.write(report_proto.SerializeToString())

  def create_report(self, tt_config, tt_parameters,
                    tensor_trace_order, tensor_trace_points):
    """Creates a report file and writes the trace information.

    When `tt_parameters.report_file_path` is empty, the report is written to
    the log instead of a file.

    Args:
      tt_config: TensorTracerConfig object.
      tt_parameters: object exposing `report_file_path`, `trace_mode` and
        `submode`.
      tensor_trace_order: TensorTraceOrder object.
      tensor_trace_points: iterable of (tensor, checkpoint_name) pairs.
    """
    try:
      with OpenReportFile(tt_parameters) as self._report_file:
        self._write_config_section(tt_config, tt_parameters)
        self._write_op_list_section(tensor_trace_order.graph_order)
        self._write_tensor_list_section(tensor_trace_order.graph_order)
        self._write_trace_points(tensor_trace_points)
        self._write_cache_index_map_section(tensor_trace_order)
        self._write_reason_section()
        self._write_graph_section(tensor_trace_order.graph_order)
    finally:
      # Do not keep a closed file handle around: a later _write_report would
      # otherwise attempt to write to it.
      self._report_file = None

  def _write_trace_points(self, tensor_trace_points):
    """Writes the list of checkpoints.

    Args:
      tensor_trace_points: iterable of (tensor, checkpoint_name) pairs.
    """
    self._write_report('%s %s\n' % (_MARKER_SECTION_BEGIN,
                                    _SECTION_NAME_TENSOR_TRACER_CHECKPOINT))
    for (tensor, checkpoint_name) in tensor_trace_points:
      self._write_report('%s %s\n' % (tensor.name, checkpoint_name))
    self._write_report('%s %s\n' % (_MARKER_SECTION_END,
                                    _SECTION_NAME_TENSOR_TRACER_CHECKPOINT))

  def _write_report(self, content):
    """Writes the given content to the report.

    The content is prefixed with the tracer log prefix and goes to the open
    report file if there is one, otherwise to the info log.

    Args:
      content: string to write; callers include the trailing newline.
    """
    line = '%s %s' % (_TRACER_LOG_PREFIX, content)
    if self._report_file:
      self._report_file.write(line)
    else:
      logging.info(line)

  def _write_config_section(self, tt_config, tt_parameters):
    """Writes the config section of the report."""

    self._write_report('%s %s\n' % (_MARKER_SECTION_BEGIN,
                                    _SECTION_NAME_CONFIG))
    self._write_report('%s %s\n' % (_FIELD_NAME_VERSION, tt_config.version))
    self._write_report('%s %s\n' % (_FIELD_NAME_DEVICE,
                                    tt_config.device_type))
    self._write_report('%s %s\n' % (_FIELD_NAME_TRACE_MODE,
                                    tt_parameters.trace_mode))
    self._write_report('%s %s\n' % (_FIELD_NAME_SUBMODE,
                                    tt_parameters.submode))
    self._write_report('%s %s\n' % (_FIELD_NAME_NUM_REPLICAS,
                                    tt_config.num_replicas))
    self._write_report('%s %s\n' % (_FIELD_NAME_NUM_REPLICAS_PER_HOST,
                                    tt_config.num_replicas_per_host))
    self._write_report('%s %s\n' % (_FIELD_NAME_NUM_HOSTS,
                                    tt_config.num_hosts))
    self._write_report('%s %s\n' % (_MARKER_SECTION_END,
                                    _SECTION_NAME_CONFIG))

  def _write_reason_section(self):
    """Writes the reason section of the report.

    Records are written sorted by name.
    """

    self._write_report('%s %s\n' % (_MARKER_SECTION_BEGIN,
                                    _SECTION_NAME_REASON))
    for key in sorted(self.instrument_records):
      self._write_report('"%s" %s\n' % (key, self.instrument_records[key]))
    self._write_report('%s %s\n' % (_MARKER_SECTION_END,
                                    _SECTION_NAME_REASON))

  def _write_op_list_section(self, graph_order):
    """Writes the Op-list section of the report.

    Each line holds the op index, the quoted op name, the op type and the
    indices of the op's output tensors.

    Args:
      graph_order: object exposing `operations` and `tensor_to_idx`.

    Raises:
      ValueError: if an output tensor is missing from
        `graph_order.tensor_to_idx`.
    """

    self._write_report('%s %s\n' % (_MARKER_SECTION_BEGIN,
                                    _SECTION_NAME_OP_LIST))
    self._write_report('%s %d\n' % (_FIELD_NAME_NUM_OPS,
                                    len(graph_order.operations)))
    for i, op in enumerate(graph_order.operations):
      line = '%d "%s" %s' % (i, op.name, op.type)
      for out_tensor in op.outputs:
        if out_tensor.name not in graph_order.tensor_to_idx:
          raise ValueError(
              'out_tensor is not in tensor_to_idx. out_tensor={}, '
              'tensor_to_idx={}'
              .format(out_tensor.name, graph_order.tensor_to_idx))
        line += ' %d' % graph_order.tensor_to_idx[out_tensor.name]
      line += '\n'
      self._write_report(line)
    self._write_report('%s %s\n' % (_MARKER_SECTION_END,
                                    _SECTION_NAME_OP_LIST))

  def _write_tensor_list_section(self, graph_order):
    """Writes the tensor-list section of the report.

    Each line holds the tensor index, the quoted tensor name and the indices
    of the ops consuming the tensor, ordered by consumer op name.

    Args:
      graph_order: object exposing `tensors` and `op_to_idx`.

    Raises:
      ValueError: if a consumer op is missing from `graph_order.op_to_idx`.
    """

    self._write_report('%s %s\n' % (_MARKER_SECTION_BEGIN,
                                    _SECTION_NAME_TENSOR_LIST))
    self._write_report('%s %d\n' % (_FIELD_NAME_NUM_TENSORS,
                                    len(graph_order.tensors)))
    for i, tensor in enumerate(graph_order.tensors):
      line = '%d "%s"' % (i, tensor.name)
      # sorted() rather than an in-place sort, so the list returned by
      # consumers() is left untouched.
      consumers = sorted(tensor.consumers(), key=lambda op: op.name)
      for consumer_op in consumers:
        if consumer_op.name not in graph_order.op_to_idx:
          raise ValueError(
              'consumer_op is not in op_to_idx.  '
              'got consumer_op={}, op_to_idx={}'
              .format(consumer_op.name, graph_order.op_to_idx))
        line += ' %d' % graph_order.op_to_idx[consumer_op.name]
      line += '\n'
      self._write_report(line)
    self._write_report('%s %s\n' % (_MARKER_SECTION_END,
                                    _SECTION_NAME_TENSOR_LIST))

  def _write_cache_index_map_section(self, tensor_trace_order):
    """Writes the mapping from cache index to tensor index to the report."""
    self._write_report('%s %s\n' % (_MARKER_SECTION_BEGIN,
                                    _SECTION_NAME_CACHE_INDEX_MAP))
    self._write_report('%s %d\n' % (
        _FIELD_NAME_NUM_CACHE_INDICES,
        len(tensor_trace_order.cache_idx_to_tensor_idx)))
    for cache_idx, tensor_idx in enumerate(
        tensor_trace_order.cache_idx_to_tensor_idx):
      self._write_report('%d %d\n' % (cache_idx, tensor_idx))
    self._write_report('%s %s\n' % (_MARKER_SECTION_END,
                                    _SECTION_NAME_CACHE_INDEX_MAP))

  def _write_graph_section(self, graph_order):
    """Writes the graph section of the report.

    Writes whether the topological sort succeeded, followed by one line per
    op of `graph_order.topological_order_or_cycle`.

    Args:
      graph_order: object exposing `contains_cycle` and
        `topological_order_or_cycle`.
    """

    self._write_report('%s %s\n' % (_MARKER_SECTION_BEGIN,
                                    _SECTION_NAME_GRAPH))
    self._write_report('%s %s\n' % (_FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED,
                                    not graph_order.contains_cycle))
    if graph_order.contains_cycle:
      # The leftover ops come as an unordered collection; list them by name
      # so that the report is reproducible.
      ops = sorted(graph_order.topological_order_or_cycle,
                   key=lambda op: op.name)
    else:
      ops = list(graph_order.topological_order_or_cycle)
    for i, op in enumerate(ops):
      self._write_report('%d "%s"\n' % (i, op.name))
    self._write_report('%s %s\n' % (_MARKER_SECTION_END,
                                    _SECTION_NAME_GRAPH))