1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ======================================================================== 15"""A utility to trace tensor values on TPU.""" 16 17import collections 18import hashlib 19import operator 20import os 21import os.path 22import sys 23 24import numpy as np 25 26from tensorflow.core.framework import summary_pb2 27from tensorflow.python.eager import monitoring 28from tensorflow.python.framework import constant_op 29from tensorflow.python.framework import dtypes 30from tensorflow.python.framework import func_graph 31from tensorflow.python.framework import function 32from tensorflow.python.framework import graph_io 33from tensorflow.python.framework import ops 34from tensorflow.python.framework import tensor_util 35from tensorflow.python.lib.io import file_io 36from tensorflow.python.ops import array_ops 37from tensorflow.python.ops import control_flow_ops 38from tensorflow.python.ops import control_flow_util 39from tensorflow.python.ops import gen_math_ops 40from tensorflow.python.ops import init_ops 41from tensorflow.python.ops import linalg_ops 42from tensorflow.python.ops import logging_ops 43from tensorflow.python.ops import math_ops 44from tensorflow.python.ops import nn_impl 45from tensorflow.python.ops import state_ops 46from tensorflow.python.ops import string_ops 47from tensorflow.python.ops import summary_ops_v2 as summary 48from tensorflow.python.ops import variable_scope 49from tensorflow.python.platform import analytics 50from tensorflow.python.platform import gfile 51from tensorflow.python.platform import remote_utils 52from tensorflow.python.platform import tf_logging as logging 53from tensorflow.python.summary import summary_iterator 54from tensorflow.python.tpu import tensor_tracer_flags 55from tensorflow.python.tpu import tensor_tracer_report 56from tensorflow.python.tpu import tpu 57from tensorflow.python.tpu.ops import tpu_ops 58from tensorflow.python.training import training_util 59 60_DEVICE_TYPE_TPU = 'tpu' 61_DEVICE_TYPE_CPU = 'cpu' 62_TRACE_MODE_PART_TENSOR_SIZE = 3 63 64_REASON_OUTSIDE_OP_RANGE = 'not-traced-outside-op-range' 65_REASON_UNSAFE_OP = 'not-traced-unsafe-op' 66_REASON_WHILELOOP_OP = 'not-traced-special-whileloop-op' 67_REASON_CONTROLFLOW_OP = 'not-traced-control-flow-op' 68_REASON_IN_CONTROL_FLOW = 'not-traced-in-control-flow' 69_REASON_UNSAFE_SCALAR = 'not-traced-unsafe-scalar' 70_REASON_SKIP_SCALAR = 'not-traced-scalar' 71_REASON_LESS_INTERESTING_OP = 'not-traced-less-interesting-op' 72_REASON_DEVICE_MISMATCH = 'not-traced-device-mismatch' 73_REASON_DYNAMIC_SHAPE = 'not-traced-dynamic-shape' 74_REASON_SCALAR_GET_TRACED = 'traced-scalar' 75_REASON_TENSOR_GET_TRACED = 'traced-tensor' 76_REASON_USER_INCLUDED = 'traced-user-included' 77_REASON_USER_EXCLUDED = 'not-traced-user-excluded' 78_REASON_NOT_EXECUTED = 'not-traced-not-in-exec-path' 79_REASON_NON_NUMERIC_TENSOR = 'not-traced-non-numeric-tensor' 80_REASON_FEEDS_WHILELOOP_OP = 'not-traced-feeds-special-whileloop-op' 81 82_OUTPUT_STREAM_ESCAPE = 'file://' 83_TENSOR_TRACER_COLLECTION = 'tensor_tracer_variables' 84TENSOR_TRACER_SUMMARY_COLLECTION = 'tensor_tracer_summary_writers' 85_TRACE_FILE_NAME = 'trace.all' 86_COMPACT_TRACE_FILE_PREFIX = 'compact_trace.' 87_COMPACT_TRACE_ENTRY_INIT_VALUE = -1.0 88_TENSOR_TRACER_STORAGE = 'tensor_tracer_storage' 89_TT_SNAPSHOT = 'tensor_tracer_snapshot' 90_REPLICA_ID_TAG = '#replica-id: ' 91_SKIP_REPORT_FILE = 'None' # Do not write report proto if --report_file=None 92 93_TT_SUMMARY_NORM = tensor_tracer_flags.TT_SUMMARY_NORM 94_TT_SUMMARY_MAX = tensor_tracer_flags.TT_SUMMARY_MAX 95_TT_SUMMARY_MAX_ABS = tensor_tracer_flags.TT_SUMMARY_MAX_ABS 96_TT_SUMMARY_MIN = tensor_tracer_flags.TT_SUMMARY_MIN 97_TT_SUMMARY_MEAN = tensor_tracer_flags.TT_SUMMARY_MEAN 98_TT_SUMMARY_VAR = tensor_tracer_flags.TT_SUMMARY_VAR 99_TT_SUMMARY_SIZE = tensor_tracer_flags.TT_SUMMARY_SIZE 100_TT_SUMMARY_SPARSITY = tensor_tracer_flags.TT_SUMMARY_SPARSITY 101 102_TT_SUMMARY_TAG = 'tensor_tracer_summary' 103_TT_TENSORBOARD_PLUGIN_NAME = 'tensor_tracer' 104_TT_HOSTCALL_KEY = 'tensor_tracer_host_call' 105_TT_EVENT_FILE_SUFFIX = '.tensor_tracer' 106 107_TT_SUMMARY_MAX_QUEUE = 10 108 109tt_gauge = monitoring.BoolGauge('/tensorflow/api/tensor_tracer/v1', 110 'tensor tracer usage', 'method') 111 112 113def _graph_summary_tag(graph): 114 """Generates and returns a summary tag name for the given graph.""" 115 116 if graph is None: 117 raise RuntimeError('graph is None') 118 # The chance of collision with md5 is effectively 0. 119 hash_id = hashlib.md5() 120 hash_id.update(repr(graph).encode('utf-8')) 121 # hexdigest() returns a string. 122 return hash_id.hexdigest() 123 124 125def set_parameters(tensor_tracer_params=None): 126 """Enables tensor tracer and sets its parameters. 127 128 Example usage: 129 tensor_tracer_parameters = {'trace_dir': '/usr/tmp/trace_dir', 130 'trace_mode': 'norm', 131 'report_file': '/usr/tmp/trace_dir/report.all'} 132 tensor_tracer.set_parameters(tensor_tracer_parameters) 133 134 This sets up the parameters for tensor tracer. A call to tensor tracer as 135 below is necessary to enable debugging on CPUs and GPUs. On TPUs below can be 136 skipped as this call is hooked into tpu.rewrite. 137 tt = tensor_tracer.TensorTracer() 138 loss = tt.trace_cpu(tf.get_default_graph(), tensor_fetches=loss) 139 140 Args: 141 tensor_tracer_params: Tensor tracer parameter dictionary. Below gives 142 examples of these parameters: See tensor_tracer_report.py for all 143 parameters. 144 - enable: If set, tensor tracer will be enabled. Calling 145 enable_tensor_tracer automatically adds this parameters. 146 - trace_mode: The trace_mode to be used by tensor tracer. These include: 147 - summary: Collects multiple statistics for traced tensors, and writes 148 them a summary file that can be visualized using tensorboard. This 149 mode currently only works for TPUEstimator. It can be also be used 150 for other models, but outfeed must be handled by the user. 151 - norm: Collects norm of each traced tensor and writes them into a 152 text file pointed by 'trace_dir' flag. (Default mode). 153 - nan-inf: Checks the existince of NaNs and Infs in the tensor, and 154 writes a boolean value to a text file pointed by 'trace_dir' flag. 155 Note that 'norm' mode can also capture this information with more 156 numerical info. 157 - max-abs: Collects the absolute max for each traced tensors and 158 writes it into a text file pointed by 'trace_dir' flag. 159 - full-tensor: Writes the full tensor content of the traced tensors 160 into a text file pointed by 'trace_dir' flag. 161 - part-tensor: Writes a part of the tensor content of the traced 162 tensors into a text file pointed by 'trace_dir' flag. 163 - full_tensor_summary: Writes the full tensors as binary event files. 164 The outputs can be read using: trace = 165 tensor_tracer.read_tensor_tracer_event_file(event_file_path) 166 167 - report_file: Path to the metadata file that is written during graph 168 construction. If not set, metadata will be printed to stdout during 169 graph construction. 170 - trace_dir: Path where the execution traces will be written during the 171 graph execution. If not set, trace will be printed to stderr. 172 - trace_level: Tensor tracer aims to trace everything it can. This 173 introduces some overhead on graph execution and graph compilation 174 times. Using trace_level parameter, it is possible to trace operation 175 based on their priorities. For example, - trace_level=7 is the highest 176 trace_level, in which every op is traced. - trace_level=6 will skip 177 constant operations such as tf.constant. - trace_level=5 will skip 178 less important ops such as tf.identities. - The default trace_level=3, 179 that will skip concat ops, or random number generators. - To reduce 180 the graph compile time overhead, trace_level can be set to 0, that 181 will skip additions, and substractions, and multiplications as well. 182 - excluded_opnames: If set, any matching op name will not be traced. 183 excluded_opnames can be set as a regular expression. E.g, 184 excluded_opnames=.* will exclude everything. 185 - excluded_optypes: If set, any matching op type will not be traced. 186 excluded_optypes can be set as a regular expression. E.g, 187 excluded_optypes=.* will exclude everything. excluded_optypes=MatMul 188 will exclude all MatMul ops from tracing. 189 - included_opnames: If set, any matching op name will be forced to be 190 traced. included_opnames can be set as a regular expression. E.g, 191 '--included_opnames=some_op --excluded_opname=*.' will only trace 192 some_op. 193 - included_optypes: If set, any matching op type will be forced to be 194 traced. included_optypes can be set as a regular expression. E.g, 195 '--included_optypes=some_op_type --excluded_optypes=*.' will trace 196 only the ops with type 'some_op_type' 197 - flush_summaries: If summary mode is used, flush_summaries=1 will 198 flush summaries using outside compilation. Note that, if used with 199 low level APIs, flush_summaries=1 is necessary to obtain results. 200 Advanced Flags: 201 - trace_scalar: Scalar values are not traced by default. If this flag is 202 set, scalar values will also be traced. 203 - op_range: In the form of '%d:%d' that limits the tracing to the ops 204 within this limit. --op_range='5:10' will trace only the ops that have 205 topological order between 5-10. 206 - submode: 'brief' or 'detailed'. If the trace mode is not compact, 207 brief mode will print only the id of each traced tensor to save some 208 space. 'detailed' mode prints the full tensor name. 209 - use_fingerprint_subdirectory: The trace directory will be chosen as 210 using the fingerprint of the trace metadata under the provided 211 trace_dir. 212 """ 213 flags = '--%s=1' % tensor_tracer_flags.FLAG_NAME_ENABLE 214 if tensor_tracer_params: 215 for key, value in tensor_tracer_params.items(): 216 flags += ' --%s=%s' % (key, value) 217 os.environ[tensor_tracer_flags.FLAGS_ENV_VAR] = flags 218 219 220def op_priority(op_type): 221 """Returns the priority of the op. 222 223 If the priority of the op is k, it will be traced if trace_level>=k. 224 Args: 225 op_type: String name of the operation type. 226 Returns: 227 Integer value corresponding the priority of the op. 228 """ 229 if op_type in ('Const', 'Shape', 'BroadcastGradientArgs', 'Range', 230 'VariableShape', 'Fill', 'OneHot', 'ShapeN'): 231 # Lowest priority ops, e.g., constant ops across different steps, 232 # They will be traced only if trace_level>=7 233 return 7 234 235 if op_type in ('Identity', 'Cast', 'Reshape', 'ExpandDims', 'StopGradient', 236 'PreventGradient', 'Squeeze', 'Gather', 'GatherNd'): 237 # Operations without numerical effects. 238 # They will be only if trace_level>=6 239 return 6 240 if op_type in ('ConcatV2', 'Concat', 'StridedSlice', 'Slice', 'Pack', 'Tile', 241 'CollectivePermute', 'SplitV', 'DynamicPartition'): 242 # Operations that merge or slice an input, will be traced if trace_level>=5 243 return 5 244 if op_type in ('Pad', 'RandomUniformInt', 'GreaterEqual'): 245 # Operations less likely to provide useful information, 246 # will be traced if trace_level>=4 247 return 4 248 if op_type in ('Sum', 'AddV2', 'Add', 'AddN', 'BiasAdd', 'CrossReplicaSum'): 249 # Add operations that are less likely create any issues, will be traced 250 # if trace_level>=3 (default=3) 251 return 3 252 if op_type in ('Neg', 'Sub'): 253 # Sub operations that are less likely create any issues, will be traced 254 # trace_level>=2 255 return 2 256 if op_type in ('Mul', 'Square', 'MatMul', 'RandomUniform', 'Select', 257 'Maximum', 'Mean', 'Variance', 'Exp', 'Rsqrt'): 258 # Multiplication and some other operations, will be traced if trace_level>=1 259 return 1 260 261 # Unclassified op_types default to being traced at level 2 and above. 262 return 2 263 264 265def read_tensor_tracer_event_file(event_file): 266 """Reads the event file written by tensor tracer. 267 268 This can be used to read the full tensors written into binary event files by 269 by TensorTracer with trace_mode=full_tensor_summary. 270 271 Example usage: 272 result_dict_list = tensor_tracer.read_tensor_tracer_event_file( 273 event_file_path) 274 for result_dict in result_dict_list: 275 for step, tensor_dict in result_dict.items(): 276 for tensor_name, full_tensor_content in tensor_dict.items(): 277 logging.info(tensor_name, full_tensor_content) 278 279 Args: 280 event_file: Path to the event file that contains only tensor tracer events. 281 Returns: 282 A list of event dictionaries, each of which with the form: 283 {step_number: {tensor_name: tensor_content}}. This is a list instead of 284 a single event dictionary because it is possible that an event file may 285 have multiple event traces, each of them covering the same step ranges. 286 Raises: 287 ValueError: If an unexpected trace is found. 288 """ 289 290 # Keeps track of how many times that a step number shows up in these events. 291 step_occurrence_count = collections.defaultdict(int) 292 293 # List of step occurrences. 294 step_occurrence_list = [] 295 296 for trace_event in summary_iterator.summary_iterator(event_file): 297 # First event is an event with file_version: "brain.Event:2" 298 if not trace_event.HasField('summary'): 299 continue 300 if len(trace_event.summary.value) != 1: 301 raise ValueError('Single step contains %d summary values,' 302 ' expected 1.' % len(trace_event.summary.value)) 303 step = trace_event.step 304 step_occurrence_count[step] += 1 # a new occurrence for this step. 305 306 occurrence_idx = step_occurrence_count[step] - 1 307 occurrence_size = len(step_occurrence_list) 308 309 if occurrence_idx == occurrence_size: 310 # This particular occurrence isn't yet recorded on step_occurrence_list. 311 # So append this new occurrence to the end of step_occurrence_list. 312 new_occurrence = collections.defaultdict(dict) 313 step_occurrence_list.append(new_occurrence) 314 else: 315 # This particular occurrence must be already recorded on 316 # step_occurrence_list (i.e. occurrence_idx < occurrence_size). 317 if occurrence_idx > occurrence_size: 318 raise ValueError('Unexpected: occurrence_idx (%d) > ' 319 'occurrence_size (%d)' % (occurrence_idx, 320 occurrence_size)) 321 tensor_value = trace_event.summary.value[0] 322 tensor_name = tensor_value.tag 323 324 real_shape = [d.size for d in tensor_value.tensor.tensor_shape.dim] 325 tensor_content = np.frombuffer( 326 tensor_value.tensor.tensor_content, 327 dtypes.DType(tensor_value.tensor.dtype).as_numpy_dtype() 328 ).reshape(real_shape) 329 step_occurrence_list[occurrence_idx][step][tensor_name] = tensor_content 330 return step_occurrence_list 331 332 333def trace_tensor(tensor, tracepoint_name=None): 334 """Programmatic interface to trace a tensor with Tensor Tracer. 335 336 Tensor Tracer, by default, traces all tensors in the execution. This function 337 can be used to limit traced tensors. If this function is called for a subset 338 of the tensors, only those will be traced. 339 340 For example, Tensor Traacer will only trace c below. 341 c = tf.MatMul(a, b) 342 tensor_tracer.trace_tensor(c) 343 d = tf.add(c, 1) 344 Args: 345 tensor: the tensor object for which the tracing is requested. 346 tracepoint_name: an optional tensor tracepoint name string. A tracepoint 347 name is an Tensor Tracer internal name for the tensor. It is useful when 348 comparing equivalent traces from different models that have different 349 tensor namings. Equivalent tensors (with different names) can be mapped 350 to each other by assigning a common tracepoint_name. 351 352 Returns: 353 The provided tensor. 354 """ 355 if tracepoint_name is None: 356 tracepoint_name = tensor.name 357 tensor.graph.get_collection(_TENSOR_TRACER_COLLECTION) 358 tensor.graph.add_to_collection(_TENSOR_TRACER_COLLECTION, 359 (tensor, tracepoint_name)) 360 return tensor 361 362 363def keras_layer_tracepoint(layer, checkpoint_name): 364 """An interface for adding the tensor outputs of a keras layer. 365 366 Encapsulates trace_tensor. 367 368 Args: 369 layer: A keras layer. 370 checkpoint_name: a string name for the checkpoint. This name has to be a 371 unique name if used within model comparison. The tensors that have the same 372 checkpoint identifier is compared in model comparison. 373 374 Returns: 375 The provided layer. 376 """ 377 try: 378 outputs = layer.output 379 if tensor_util.is_tf_type(outputs): 380 trace_tensor(outputs, '%s' % (checkpoint_name)) 381 else: 382 idx = 0 383 for output_tensor in outputs: 384 if tensor_util.is_tf_type(outputs): 385 trace_tensor(output_tensor, '%s_%d' % (checkpoint_name, idx)) 386 idx += 1 387 except AttributeError: 388 pass 389 except RuntimeError: 390 pass 391 return layer 392 393 394class TensorTracer: 395 """A software construct for tracing tensor values in a TF graph. 396 397 This utility is disabled by default. It is hooked into tpu.rewrite, so it can 398 easily be enabled on TPUs by setting the TENSOR_TRACER_FLAGS env variable as 399 below without a code change. 400 export TENSOR_TRACER_FLAGS="--enable=1" 401 402 Below is the use example to enable it on CPUs or GPUs, or for more advance use 403 cases on TPUs. 404 405 a = x + 1 406 b = a * 2 407 rs = tf.reduce_sum(b) 408 tensor_tracer.set_parameters({'trace_dir': 'path/to/trace_dir', 409 'report_file: 'path/to/report/file'}) 410 tt = tensor_tracer.TensorTracer() 411 if on_tpu: 412 rs = tt.trace_tpu(tf.get_default_graph(), 413 tensor_fetches=rs) 414 else: 415 rs = tt.trace_cpu(tf.get_default_graph(), 416 tensor_fetches=rs) 417 session.run(rs) 418 419 If it is enabled, it will trace the output tensor values of 420 selected Ops in the graph. It has two outputs: (1) the traces and (2) 421 a report. The traces are dumped to a specified directory during the graph 422 execution, while the report is dumped during the graph construction. 423 By passing options via the env variable, users can change: 424 (1) the trace mode (e.g., detecting NaN/Inf, printing partial or 425 full tensor values) 426 (2) which Ops to be traced (via op.name or op.type) 427 (3) output trace file path. 428 429 """ 430 # The set of graphs that are rewritten by tensor tracer. 431 _traced_graphs = set() 432 433 @staticmethod 434 def is_enabled(): 435 """Returns True if TensorTracer is enabled.""" 436 try: 437 enable = tensor_tracer_flags.TTParameters().is_enabled() 438 # Add metrics to determine API usage. 439 if enable: tt_gauge.get_cell('is_enabled').set(True) 440 return enable 441 except (ValueError, RuntimeError) as e: 442 logging.warning( 443 'Tensor Tracer V1 flags processing error encountered in is_enabled ' 444 'check. %s', e) 445 # TODO(b/210212559): Find a more robust fix. 446 # Should only produce exception if Tensor Tracer is enabled. 447 return True 448 449 @staticmethod 450 def check_device_type(device_type): 451 """Checks if the given device type is valid.""" 452 453 if device_type not in (_DEVICE_TYPE_TPU, _DEVICE_TYPE_CPU): 454 raise ValueError('Invalid device_type "%s"'%device_type) 455 456 @staticmethod 457 def check_trace_mode(device_type, trace_mode): 458 """Checks if the given trace mode work on the given device type. 459 460 Args: 461 device_type: Device type, TPU, GPU, CPU. 462 trace_mode: Tensor tracer trace mode. 463 Raises: 464 ValueError: If the given trace mode is not supported for the device. 465 """ 466 if trace_mode == tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY: 467 if device_type != _DEVICE_TYPE_TPU: 468 raise ValueError('Device_type "%s" is not yet supported for ' 469 'trace mode "%s"' % (device_type, trace_mode)) 470 471 @staticmethod 472 def loop_cond_op(op): 473 return op.type in ('LoopCond', 'RefLoopCond') 474 475 @staticmethod 476 def while_loop_op(op): 477 """Returns true if op is one of the special ops of in a while loop. 478 479 Args: 480 op: A tf.Operation. 481 482 Returns: 483 True if the given op is one of [Switch, Merge, Enter, Exit, 484 NextIteration, LoopCond], which are all building blocks for TF while 485 loops. 486 """ 487 return (control_flow_util.IsLoopSwitch(op) or 488 control_flow_util.IsLoopMerge(op) or 489 control_flow_util.IsLoopEnter(op) or 490 control_flow_util.IsLoopExit(op) or 491 TensorTracer.loop_cond_op(op) or 492 op.type in ('RefNextIteration', 'NextIteration')) 493 494 @staticmethod 495 def control_flow_op(op): 496 """Returns true if op is one of the special ops of in a while loop. 497 498 Args: 499 op: A tf.Operation. 500 501 Returns: 502 True if the given op is one of [Switch, Merge, Enter, Exit, 503 NextIteration, LoopCond], which are all building blocks for TF while 504 loops. 505 """ 506 return (control_flow_util.IsSwitch(op) or 507 control_flow_util.IsMerge(op)) 508 509 @staticmethod 510 def unsafe_op(op): 511 """Returns True if this op is not safe to be traced.""" 512 513 # Reasons for not including following op types: 514 # Assign: cause incorrect result with CPU tracing. 515 if op.type == 'Assign': 516 return True 517 return False 518 519 @staticmethod 520 def device_mismatch(device_type, op): 521 if device_type == _DEVICE_TYPE_TPU: 522 # pylint: disable=protected-access 523 return tpu._TPU_REPLICATE_ATTR not in op.node_def.attr 524 # pylint: enable=protected-access 525 return False 526 527 @staticmethod 528 def unsafe_scalar_trace(op): 529 """Return true if scalar output tensor from Op is not safe to be traced.""" 530 531 # Tracing the following causes cycle in the graph on TPU. 532 if op.type in ('LoopCond', 'Enter', 'Merge', 'Const', 533 'Switch', 'Less', 'ReadVariableOp'): 534 return True 535 # Tracing the following will cause casting-issue 536 # with the norm tracing mode or other compilation issues on CPU. 537 if op.type in ('VarHandleOp', 'IteratorToStringHandle', 538 'IteratorGetNext', 'OneShotIterator', 539 'IteratorV2', 'MakeIterator', 540 'BatchDatasetV2', 'MapDataset', 541 'FixedLengthRecordDataset', 'TakeDataset', 'ZipDataset', 542 'Placeholder', 'PlaceholderWithDefault', 'StridedSlice'): 543 return True 544 return False 545 546 def _is_interesting_op(self, op): 547 """Returns True if the given op is not an interesting one to be traced.""" 548 return op_priority(op.type) <= self._parameters.trace_level 549 550 @staticmethod 551 def reason(op_idx, details): 552 """Returns reason why the Op at op_idx is traced or not.""" 553 554 return '%d %s'%(op_idx, details) 555 556 def __init__(self): 557 """Initializes a TensorTracer. 558 559 Sets the various member fields from the flags (if given) or the defaults. 560 """ 561 self._replica_id = None 562 self._tt_config = tensor_tracer_report.TensorTracerConfig() 563 self._parameters = None 564 self._host_call_fn = {} 565 # _cache_variables is a dict (key = graph, value = dicts 566 # (key = name, value = tensors)) 567 self._cache_variables = {} 568 self._traced_op_names = set() 569 self._report_proto = None 570 # _temp_cache_var is a dict (key = graph, value = []) 571 self._temp_cache_var = {} 572 self._report_proto_path = '' 573 self._outmost_context = None 574 575 def report_proto(self): 576 """Getter for tensor_tracer.proto object for summary and full_tensor_summary modes. 577 578 Returns: 579 A tensor_tracer.proto object. 580 Raises: 581 ValueError if called before tracing happens, or when trace mode is not 582 summary or full_tensor_summary. 583 """ 584 if self._report_proto: 585 return self._report_proto 586 else: 587 raise ValueError('Call to report_proto must be done after tracing.' 588 'Report proto only exists for ' 589 'trace_mode=[summary|full_tensor_summary]') 590 591 def report_proto_path(self): 592 """Getter for path where tensor_tracer.proto object should be written. 593 594 Returns: 595 A string path. 596 """ 597 return self._report_proto_path 598 599 def _cache_variable_for_graph(self, graph): 600 if graph not in self._cache_variables: 601 self._cache_variables[graph] = {} 602 return self._cache_variables[graph] 603 604 def _create_or_get_tensor_values_cache(self, cache_name, graph, 605 shape=None, dtype=dtypes.float32): 606 """Creates a variable as the cache to store intermediate tensor values. 607 608 Args: 609 cache_name: Name to be given to the cache (an instance of tf.variable). 610 graph: Tensorflow graph. 611 shape: A list of dimensions. 612 dtype: Data type of created cache. 613 Returns: 614 A ref to newly created or existing cache with the given dimensions. 615 Raises: 616 ValueError: 617 (1) If graph is None, or 618 (2) shape is None when a new cache needs to be created. 619 """ 620 621 def _escape_namescopes(variable_name): 622 # TODO(deveci): This might cause name collisions as in "foo/bar/mytensor" 623 # and "foo_bar/mytensor". 624 return variable_name.replace('/', '_').replace(':', '_') 625 626 if graph is None: 627 raise ValueError('Invalid graph.') 628 629 graph_cache_var = self._cache_variable_for_graph(graph) 630 631 if cache_name not in graph_cache_var: 632 if shape is None: 633 raise ValueError('shape must be provided at cache creation.') 634 if dtype.is_integer: 635 init_val = int(_COMPACT_TRACE_ENTRY_INIT_VALUE) 636 else: 637 init_val = _COMPACT_TRACE_ENTRY_INIT_VALUE 638 639 # Create in proper graph and base name_scope. 640 with graph.as_default() as g, g.name_scope(None): 641 graph_cache_var[cache_name] = variable_scope.get_variable( 642 _TT_SNAPSHOT + '_' + _escape_namescopes(cache_name), 643 shape=shape, dtype=dtype, 644 initializer=init_ops.constant_initializer(init_val), 645 trainable=False, 646 use_resource=True, 647 collections=[_TENSOR_TRACER_STORAGE, ops.GraphKeys.LOCAL_VARIABLES]) 648 return graph_cache_var[cache_name] 649 650 def _add_replica_id_to_graph(self): 651 """Adds nodes for computing the replica ID to the graph.""" 652 653 if self._tt_config.num_replicas: 654 with ops.control_dependencies(None): 655 # Uses None as dependency to run outside of TPU graph rewrites. 656 self._replica_id = tpu_ops.tpu_replicated_input( 657 list(range(self._tt_config.num_replicas)), 658 name='tt_replica_id') 659 else: 660 self._replica_id = 'unknown' 661 662 def _inside_op_range(self, idx): 663 """Return True if the given index is inside the selected range.""" 664 665 if idx < self._parameters.op_range[0]: 666 return False 667 return (self._parameters.op_range[1] < 0 or 668 idx <= self._parameters.op_range[1]) 669 670 def _is_user_included_op(self, op): 671 """Checks whether the op is included in the tensor tracer flags. 672 673 Args: 674 op: tf Operation 675 Returns: 676 True, if the op is included. 677 An op is included if: 678 - Its op name is given in included_opnames 679 - Its op type is given in included_optypes 680 - The op is at most _trace_ops_before_included hops before an included op 681 - The op is at most _trace_ops_after_included hops after an included op 682 """ 683 for opname_re in self._parameters.included_opname_re_list: 684 if opname_re.match(op.name): 685 return True 686 687 for optype_re in self._parameters.included_optype_re_list: 688 if optype_re.match(op.type): 689 return True 690 return False 691 692 def _is_user_excluded_op(self, op): 693 for opname_re in self._parameters.excluded_opname_re_list: 694 if opname_re.match(op.name): 695 return True 696 for optype_re in self._parameters.excluded_optype_re_list: 697 if optype_re.match(op.type): 698 return True 699 return False 700 701 def _signature_types(self): 702 """Returns a dictionary holding the order of signatures in the cache for the selected trace mode.""" 703 if self._parameters.trace_mode in set([ 704 tensor_tracer_flags.TRACE_MODE_NAN_INF, 705 tensor_tracer_flags.TRACE_MODE_NORM, 706 tensor_tracer_flags.TRACE_MODE_MAX_ABS]): 707 return {self._parameters.trace_mode: 0} 708 if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_SUMMARY: 709 return self._parameters.summary_signatures 710 return {} 711 712 def _num_signature_dimensions(self): 713 return len(self._signature_types()) 714 715 def _use_temp_cache(self): 716 """Returns true if the intermediate values should be stacked instead of being stored in a tf.Variable. 717 718 Returns: 719 A boolean, denoting whether to use a temporary cache or not. 720 """ 721 # If full tensors need to be stored tf.variables, then do not use temp 722 # variables to store them. 723 if self._use_tensor_buffer(): 724 return False 725 if self._use_tensor_values_cache(): 726 return self._parameters.use_temp_cache_var 727 else: 728 # Temporary caches only replaces tf.Variables caches. If no cache is used 729 # return False. 730 return False 731 732 def _use_tensor_values_cache(self): 733 """Returns True if immediate tensors should be first saved to a cache.""" 734 return self._parameters.use_compact_trace 735 736 def _use_tensor_buffer(self): 737 """Returns true if the whole tensor needs to be cached/buffered in memory.""" 738 return (self._parameters.trace_mode == 739 tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY) 740 741 def _merge_tensor_signatures(self, signatures): 742 """Returns a tensor that merges the given signatures. 743 744 Args: 745 signatures: A dictionary of the signature updates from signature name to 746 a tensor of dimension [1]. 747 Returns: 748 A tensor that concats the signature values in a predefined order. 749 Raises: 750 ValueError: Unable to merge signatures. 751 """ 752 sorted_update = [] 753 if self._num_signature_dimensions() > 1: 754 signature_indices = self._signature_types() 755 for _, val in sorted(signatures.items(), 756 key=lambda item: signature_indices[item[0]]): 757 sorted_update.append(val) 758 updates = array_ops.stack( 759 sorted_update, axis=0, name='merge_single_op_signatures') 760 elif self._num_signature_dimensions() == 1: 761 # Avoid stack operation if there is only a single signature. 762 (_, val), = signatures.items() 763 updates = val 764 else: 765 raise ValueError('Cannot merge 0 signatures. Check the value passed for ' 766 'flag --signatures.') 767 return updates 768 769 def _save_tensor_value_to_tmp_cache(self, cache_idx, updates, graph): 770 """Returns an op that will save the given updates to an entry in the cache. 771 772 Args: 773 cache_idx: The cache index of the tensor within the cache. 774 updates: A dictionary of the signature updates from signature name to 775 a tensor of dimension [1]. 776 graph: A TensorFlow graph. 777 Raises: 778 RuntimeError: 779 (1) graph is not already in self._temp_cache_var, or 780 (2) cache_idx is out of range. 781 """ 782 updates = self._merge_tensor_signatures(updates) 783 updates = array_ops.reshape(updates, 784 [self._num_signature_dimensions()]) 785 if graph not in self._temp_cache_var: 786 raise RuntimeError('graph is not in self._temp_cache_var') 787 if cache_idx >= len(self._temp_cache_var[graph]): 788 raise RuntimeError('cache_idx (%d) is out of range (%d)' % ( 789 cache_idx, len(self._temp_cache_var[graph]))) 790 self._temp_cache_var[graph][cache_idx] = updates 791 792 def _save_tensor_value_to_cache_op(self, cache_idx, updates, graph): 793 """Returns an op that will save the given updates to an entry in the cache. 794 795 Args: 796 cache_idx: The cache index of the tensor within the cache. 797 updates: A dictionary of the signature updates. 798 graph: A TensorFlow graph. 799 Returns: 800 Cache update operation. 801 """ 802 # state_ops.scatter_update allows updates only along the first dimension. 803 # Make a compact array by concatenating different signatures, and update 804 # them all together. 805 updates = self._merge_tensor_signatures(updates) 806 updates = array_ops.reshape(updates, 807 [1, self._num_signature_dimensions()]) 808 indices = constant_op.constant([cache_idx]) 809 cache = self._create_or_get_tensor_values_cache(_TT_SUMMARY_TAG, graph) 810 return state_ops.scatter_update(cache, indices, updates).op 811 812 def _snapshot_tensor(self, tensor): 813 """Creates a new tf.Variable and a new tf.Operation that assigns the value of the tensor to this variable. 814 815 Args: 816 tensor: tensor whose values will be stored in a new tf.Variable. 817 Returns: 818 An assignment operation. 819 """ 820 821 snapshot_variable = self._create_or_get_tensor_values_cache( 822 tensor.name, tensor.op.graph, 823 tensor.shape.as_list(), tensor.dtype) 824 return state_ops.assign(snapshot_variable, tensor).op 825 826 def _preprocess_traced_tensor(self, tensor): 827 """Computes NAN/Norm/Max on TPUs before sending to CPU. 828 829 Args: 830 tensor: The tensor to be traced. 831 Returns: 832 A tensor that should be input to the trace_function. 833 Raises: 834 RuntimeError: If the signature is invalid. 835 """ 836 837 def _detect_nan_inf(tensor): 838 """Trace function for detecting any NaN/Inf in the tensor.""" 839 840 if tensor.dtype.is_floating: 841 mask = math_ops.reduce_any( 842 gen_math_ops.logical_or( 843 gen_math_ops.is_nan(tensor), gen_math_ops.is_inf(tensor))) 844 output_tensor = control_flow_ops.cond( 845 mask, 846 lambda: constant_op.constant([1.0]), 847 lambda: constant_op.constant([0.0])) 848 else: 849 output_tensor = constant_op.constant([0.0]) 850 return output_tensor 851 852 def _compute_signature(tensor, tf_op, cast_to_f32=True): 853 if cast_to_f32: 854 tensor = math_ops.cast(tensor, dtypes.float32) 855 output_tensor = tf_op(tensor) 856 # Return type should be scalar. Set it if it does not have the 857 # information. 858 if not output_tensor.get_shape().is_fully_defined(): 859 output_tensor = array_ops.reshape(output_tensor, []) 860 return output_tensor 861 862 def _show_size(tensor): 863 # In order to check the size of a tensor. 864 # Not all sizes are known at the compile time, also, different replicas 865 # sometimes get different sizes of tensors. 866 # Collect it here to be used in merging replica data. 867 tsize = _compute_signature(tensor, array_ops.size, cast_to_f32=False) 868 # Cast to float32, so that it can be placed into same cache with other 869 # signatures. 870 return math_ops.cast(tsize, dtypes.float32) 871 872 def _show_max(tensor, cast_to_f32=True): 873 # returns -inf for empty tensor 874 return _compute_signature(tensor, math_ops.reduce_max, cast_to_f32) 875 876 def _show_min(tensor, cast_to_f32=True): 877 # returns inf for empty tensor 878 return _compute_signature(tensor, math_ops.reduce_min, cast_to_f32) 879 880 def _show_norm(tensor, cast_to_f32=True): 881 # returns 0 for empty tensor 882 return _compute_signature(tensor, linalg_ops.norm, cast_to_f32) 883 884 def _show_sparsity(tensor, cast_to_f32=True, tolerance=1e-06): 885 # returns nan for empty tensor and treats nans as non-zero numbers 886 def sparsity_fn(tensor): 887 non_zeros = math_ops.greater_equal(math_ops.abs(tensor), tolerance) 888 nans = math_ops.is_nan(tensor) 889 return nn_impl.zero_fraction(math_ops.logical_or(non_zeros, nans)) 890 891 return _compute_signature(tensor, sparsity_fn, cast_to_f32) 892 893 def _show_mean_and_variance(tensor, cast_to_f32=True): 894 """Returns the mean and variance of the given tensor.""" 895 if cast_to_f32: 896 tensor = math_ops.cast(tensor, dtypes.float32) 897 # returns nan for empty tensor 898 mean, var = nn_impl.moments(array_ops.reshape(tensor, [-1]), axes=[0]) 899 # The shape has to be 1. Set it if it does not have the information. 900 if not mean.get_shape().is_fully_defined(): 901 mean = array_ops.reshape(mean, []) 902 if not var.get_shape().is_fully_defined(): 903 var = array_ops.reshape(var, []) 904 return mean, var 905 906 def _show_max_abs(tensor, cast_to_f32=True): 907 return _compute_signature( 908 tensor, lambda t: math_ops.reduce_max(math_ops.abs(t)), cast_to_f32) 909 910 if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_NAN_INF: 911 return {self._parameters.trace_mode: _detect_nan_inf(tensor)} 912 if (self._parameters.trace_mode == 913 tensor_tracer_flags.TRACE_MODE_PART_TENSOR): 914 return {self._parameters.trace_mode: tensor} 915 if (self._parameters.trace_mode in ( 916 tensor_tracer_flags.TRACE_MODE_FULL_TENSOR, 917 tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY)): 918 return {self._parameters.trace_mode: tensor} 919 if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_NORM: 920 return {self._parameters.trace_mode: array_ops.reshape( 921 _show_norm(tensor), [1])} 922 if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_MAX_ABS: 923 return {self._parameters.trace_mode: _show_max_abs(tensor)} 924 925 if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_SUMMARY: 926 tensor = math_ops.cast(tensor, dtypes.float32) 927 result_dict = {} 928 # Call mean and variance computation here to avoid adding the same nodes 929 # twice. 930 if (_TT_SUMMARY_MEAN in self._signature_types() or 931 _TT_SUMMARY_VAR in self._signature_types()): 932 mean, variance = _show_mean_and_variance(tensor, cast_to_f32=False) 933 934 for signature_name, _ in sorted(self._signature_types().items(), 935 key=lambda x: x[1]): 936 if signature_name == _TT_SUMMARY_NORM: 937 signature_result_tensor = _show_norm(tensor, cast_to_f32=False) 938 elif signature_name == _TT_SUMMARY_MAX: 939 signature_result_tensor = _show_max(tensor, cast_to_f32=False) 940 elif signature_name == _TT_SUMMARY_MAX_ABS: 941 signature_result_tensor = _show_max_abs(tensor, cast_to_f32=False) 942 elif signature_name == _TT_SUMMARY_MIN: 943 signature_result_tensor = _show_min(tensor, cast_to_f32=False) 944 elif signature_name == _TT_SUMMARY_SPARSITY: 945 signature_result_tensor = _show_sparsity(tensor) 946 elif signature_name == _TT_SUMMARY_SIZE: 947 signature_result_tensor = _show_size(tensor) 948 elif signature_name == _TT_SUMMARY_MEAN: 949 signature_result_tensor = mean 950 elif signature_name == _TT_SUMMARY_VAR: 951 signature_result_tensor = variance 952 else: 953 raise ValueError('Unknown signature type :%s.' % signature_name) 954 955 result_dict[signature_name] = signature_result_tensor 956 return result_dict 957 958 raise RuntimeError( 959 'Unsupported signature for trace mode %s.' 960 % self._parameters.trace_mode) 961 962 def _make_tensor_trace_fun(self, tensor_name, tensor_trace_order): 963 """Makes the tensor tracing function called by outside compilation. 964 965 Args: 966 tensor_name: name of the tensor being traced. 967 tensor_trace_order: TensorTraceOrder object holding tensorname to id map. 968 Returns: 969 A function to be passed as the first argument to outside compilation. 970 971 Raises: 972 RuntimeError: If the trace mode is invalid. 973 """ 974 975 def _print_tensor(tensor_name, num_elements, tensor, output_tensor): 976 """Prints a tensor value to a file. 977 978 Args: 979 tensor_name: name of the tensor being traced. 980 num_elements: number of elements to print (-1 means print all). 981 tensor: the tensor needs to be returned. 982 output_tensor: the tensor needs to be printed. 983 984 Returns: 985 The same tensor passed via the "tensor" argument. 986 987 Raises: 988 ValueError: If tensor_name is not already in 989 tensor_trace_order.tensorname_to_cache_idx. 990 """ 991 992 if self._parameters.is_brief_mode(): 993 if tensor_name not in tensor_trace_order.tensorname_to_cache_idx: 994 raise ValueError( 995 'Tensor %s with name %s is not in the tensorname_to_cache_idx' % 996 (tensor, tensor_name)) 997 msg = '%d' % tensor_trace_order.tensorname_to_cache_idx[tensor_name] 998 else: 999 msg = '"%s"' % tensor_name 1000 1001 if self._parameters.trace_dir: 1002 output_path = os.path.join( 1003 self._parameters.trace_dir, 1004 _TRACE_FILE_NAME + self._get_outfile_suffix()) 1005 output_stream = _OUTPUT_STREAM_ESCAPE + output_path 1006 else: 1007 output_stream = sys.stderr 1008 return logging_ops.print_v2(msg, array_ops.shape(output_tensor), 1009 '@', self._replica_id, 1010 '\n', output_tensor, '\n', 1011 summarize=num_elements, 1012 output_stream=output_stream) 1013 1014 def _show_part_tensor(tensor): 1015 """Trace function for printing part of the tensor.""" 1016 1017 return _print_tensor(tensor_name, _TRACE_MODE_PART_TENSOR_SIZE, 1018 tensor, tensor) 1019 1020 def _show_full_tensor(tensor): 1021 """Trace function for printing the entire tensor.""" 1022 1023 return _print_tensor(tensor_name, -1, tensor, tensor) 1024 1025 if (self._parameters.trace_mode == 1026 tensor_tracer_flags.TRACE_MODE_PART_TENSOR): 1027 return _show_part_tensor 1028 # The input tensor has a shape of "[1]" for TRACE_MODE_NAN_INF, 1029 # TRACE_MODE_NORM, and TRACE_MODE_MAX_ABS, as related computations are 1030 # performed within TPUs and only their results are transferred to CPU. 1031 # Simply, print the full tensor for these trace modes. 1032 if self._parameters.trace_mode in ( 1033 tensor_tracer_flags.TRACE_MODE_NAN_INF, 1034 tensor_tracer_flags.TRACE_MODE_NORM, 1035 tensor_tracer_flags.TRACE_MODE_FULL_TENSOR, 1036 tensor_tracer_flags.TRACE_MODE_MAX_ABS, 1037 tensor_tracer_flags.TRACE_MODE_SUMMARY 1038 ): 1039 return _show_full_tensor 1040 1041 raise RuntimeError('Full tensor support is not available with trace mode %s' 1042 %self._parameters.trace_mode) 1043 1044 def _is_in_control_flow(self, op): 1045 """Returns true if the given op is inside a tf.cond or in tf.while_loop. 1046 1047 Args: 1048 op: A tensorflow op that should be checked whether in control flow or not. 1049 Returns: 1050 A boolean value whether the op is in control flow or not. 1051 """ 1052 return control_flow_util.IsInCond(op) 1053 1054 def _is_in_outmost_while_loop(self, op): 1055 """Returns true if the op is at the same level with the training loop. 1056 1057 Returns false if the op is in an inner while loop or if it is outside of the 1058 training loop. 1059 Args: 1060 op: tf.Operation 1061 1062 Returns: 1063 A boolean. 1064 """ 1065 ctxt = self._get_op_control_flow_context(op) 1066 outer_while_context = control_flow_util.GetContainingWhileContext(ctxt) 1067 return outer_while_context == control_flow_util.GetContainingWhileContext( 1068 self._outmost_context) 1069 1070 def _should_trace_in_control_flow(self): 1071 """Returns false incase it is not safe to trace ops in tf.cond or tf.while_loop.""" 1072 # As different from the other trace modes, TRACE_MODE_OPTIONAL_SUMMARY 1073 # forces the execution of the traced tensors. We should not trace the ops 1074 # that may not be executed due to control flow. 1075 if self._use_temp_cache(): 1076 return False 1077 elif self._tt_config.device_type == _DEVICE_TYPE_TPU: 1078 # On TPUs do not trace in control flow unless we use caches to store 1079 # intermediate values as calling outside compilation within an inner loop 1080 # causes errors. 1081 return self._use_tensor_values_cache() or self._use_tensor_buffer() 1082 return True 1083 1084 def _skip_op(self, op_id, op, ops_in_exec_path, report_handler): 1085 """Returns True if we should not trace Op. 1086 1087 Args: 1088 op_id: Topological index of the op. 1089 op: tf.Operation 1090 ops_in_exec_path: Set of operations that are in the execution path. 1091 report_handler: An instance of tensor_tracer_report.TTReportHandle. 1092 Returns: 1093 True if the op should not be traced, false otherwise. 1094 """ 1095 if TensorTracer.while_loop_op(op): 1096 report_handler.instrument_op( 1097 op, TensorTracer.reason(op_id, _REASON_WHILELOOP_OP)) 1098 return True 1099 if TensorTracer.control_flow_op(op): 1100 report_handler.instrument_op( 1101 op, TensorTracer.reason(op_id, _REASON_CONTROLFLOW_OP)) 1102 return True 1103 if TensorTracer.unsafe_op(op): 1104 report_handler.instrument_op( 1105 op, TensorTracer.reason(op_id, _REASON_UNSAFE_OP)) 1106 return True 1107 if TensorTracer.device_mismatch(self._tt_config.device_type, op): 1108 report_handler.instrument_op( 1109 op, TensorTracer.reason(op_id, _REASON_DEVICE_MISMATCH)) 1110 return True 1111 if op not in ops_in_exec_path: 1112 report_handler.instrument_op( 1113 op, TensorTracer.reason(op_id, _REASON_NOT_EXECUTED)) 1114 return True 1115 # TensorTracer will not trace the operations that are in an inner while loop 1116 # or tf.cond when a temporary cache is used. Temporary cache adds direct 1117 # data dependencies to traced operations, and needs a static number of 1118 # traced operations. For these cases, 1119 # - We do not know the number of slots required when there are inner while 1120 # loops. TensorTracer can only trace the result of a while loop. 1121 # - We do not know ahead of time which branch of the tf.cond 1122 # will be taken, so we avoid introducing data dependencies for the 1123 # operations inside a tf.cond. 1124 # - We also cannot have a data dependency to an operation in a different 1125 # while context. 1126 if self._is_in_control_flow(op) or not self._is_in_outmost_while_loop(op): 1127 if not self._should_trace_in_control_flow(): 1128 report_handler.instrument_op( 1129 op, TensorTracer.reason(op_id, _REASON_IN_CONTROL_FLOW)) 1130 return True 1131 if self._is_user_included_op(op): 1132 report_handler.instrument_op( 1133 op, TensorTracer.reason(op_id, _REASON_USER_INCLUDED)) 1134 return False 1135 1136 if not self._inside_op_range(op_id): 1137 report_handler.instrument_op( 1138 op, TensorTracer.reason(op_id, _REASON_OUTSIDE_OP_RANGE)) 1139 return True 1140 if not self._is_interesting_op(op): 1141 report_handler.instrument_op( 1142 op, TensorTracer.reason(op_id, _REASON_LESS_INTERESTING_OP)) 1143 return True 1144 if self._is_user_excluded_op(op): 1145 report_handler.instrument_op( 1146 op, TensorTracer.reason(op_id, _REASON_USER_EXCLUDED)) 1147 return True 1148 return False 1149 1150 def _skip_tensor(self, op_id, out_tensor, report_handler): 1151 """Returns True if we should not trace out_tensor. 1152 1153 Args: 1154 op_id: Topological index of the op producing tensor. 1155 out_tensor: tf.Tensor 1156 report_handler: An instance of tensor_tracer_report.TTReportHandle. 1157 Returns: 1158 True if the tensor should not be traced, false otherwise. 1159 """ 1160 1161 # Skips a tensor if the tensor has a non-numeric type. 1162 # Note: we cannot use check_ops.is_numeric_tensor(out_tensor) 1163 # because it also excludes tensors with dtypes, bool, and 1164 # float32_ref, which we actually want to trace. 1165 non_numeric_tensor_types = set([dtypes.variant, dtypes.resource, 1166 dtypes.string]) 1167 if out_tensor.dtype in non_numeric_tensor_types: 1168 1169 report_handler.instrument_tensor( 1170 out_tensor, TensorTracer.reason(op_id, _REASON_NON_NUMERIC_TENSOR)) 1171 return True 1172 # Skip a tensor if it feeds a special while loop op. 1173 if [consumer for consumer in out_tensor.consumers() if 1174 TensorTracer.while_loop_op(consumer)]: 1175 report_handler.instrument_tensor( 1176 out_tensor, TensorTracer.reason(op_id, _REASON_FEEDS_WHILELOOP_OP)) 1177 return True 1178 if self._is_user_included_op(out_tensor.op): 1179 report_handler.instrument_tensor( 1180 out_tensor, TensorTracer.reason(op_id, _REASON_USER_INCLUDED)) 1181 return False 1182 if self._is_user_excluded_op(out_tensor.op): 1183 report_handler.instrument_tensor( 1184 out_tensor, TensorTracer.reason(op_id, _REASON_USER_EXCLUDED)) 1185 return True 1186 if not out_tensor.get_shape().is_fully_defined(): 1187 # If trace mode is nan-inf, norm or max, then the tensor will be reduced 1188 # to a scalar before the outside compilation call. 1189 if self._parameters.trace_mode in ( 1190 tensor_tracer_flags.TRACE_MODE_NAN_INF, 1191 tensor_tracer_flags.TRACE_MODE_NORM, 1192 tensor_tracer_flags.TRACE_MODE_MAX_ABS, 1193 tensor_tracer_flags.TRACE_MODE_SUMMARY 1194 ): 1195 report_handler.instrument_tensor( 1196 out_tensor, TensorTracer.reason(op_id, _REASON_TENSOR_GET_TRACED)) 1197 return False 1198 else: 1199 report_handler.instrument_tensor( 1200 out_tensor, TensorTracer.reason(op_id, _REASON_DYNAMIC_SHAPE)) 1201 return True 1202 rank = len(out_tensor.shape) 1203 if rank < 1: 1204 # scalar 1205 if self._parameters.trace_scalar_ops: 1206 if TensorTracer.unsafe_scalar_trace(out_tensor.op): 1207 report_handler.instrument_tensor( 1208 out_tensor, TensorTracer.reason(op_id, _REASON_UNSAFE_SCALAR)) 1209 return True 1210 else: 1211 report_handler.instrument_tensor( 1212 out_tensor, TensorTracer.reason(op_id, _REASON_SCALAR_GET_TRACED)) 1213 return False 1214 else: 1215 report_handler.instrument_tensor( 1216 out_tensor, TensorTracer.reason(op_id, _REASON_SKIP_SCALAR)) 1217 return True 1218 else: 1219 # tensor 1220 report_handler.instrument_tensor( 1221 out_tensor, TensorTracer.reason(op_id, _REASON_TENSOR_GET_TRACED)) 1222 return False 1223 1224 def _filter_execution_path_operations(self, operations, fetches): 1225 """Returns the set of ops in the execution path to compute given fetches.""" 1226 1227 # If no fetch provided, then return all operations. 1228 if fetches is None: 1229 return set(operations) 1230 # Convert to list, if a single element is provided. 1231 if not isinstance(fetches, (list, tuple)): 1232 fetches = [fetches] 1233 # If a tensor is given as fetch, convert it to op. 1234 op_fetches = [] 1235 for fetch in fetches: 1236 if isinstance(fetch, ops.Operation): 1237 op_fetches.append(fetch) 1238 elif isinstance(fetch, ops.Tensor): 1239 op_fetches.append(fetch.op) 1240 else: 1241 raise RuntimeError('Given fetch:%s is neither a tensor nor an op.' 1242 %fetch) 1243 1244 execution_path_operations = set(op_fetches) 1245 traverse_stack = list(op_fetches) 1246 while True: 1247 if not traverse_stack: 1248 break 1249 head_op = traverse_stack.pop() 1250 input_ops = [tensor_input.op for tensor_input in head_op.inputs] 1251 input_ops.extend(head_op.control_inputs) 1252 1253 for input_op in input_ops: 1254 if input_op not in execution_path_operations: 1255 # Filter out loop condition operations, tracing them causes a cycle. 1256 # Trace only the loop-body. 1257 if TensorTracer.loop_cond_op(input_op): 1258 continue 1259 execution_path_operations.add(input_op) 1260 traverse_stack.append(input_op) 1261 return execution_path_operations 1262 1263 def _determine_and_instrument_traced_tensors(self, graph_order, 1264 ops_in_exec_path, 1265 tensor_trace_points, 1266 report_handler): 1267 """Determines the tensors to trace and instruments the trace details. 1268 1269 Args: 1270 graph_order: graph_order tuple containing graph (tf.graph), operations 1271 (list of operations), op_to_idx (op id mapping), (tensors) list of 1272 tensors, tensor_to_idx (tensor id mapping), contains_cycle (whether 1273 there is a cycle in the graph), topological_order_or_cycle (list of ops 1274 in topological order or list of ops creating a cycle). 1275 ops_in_exec_path: Set of ops in the execution path. 1276 tensor_trace_points: Collection of programatic tensor trace points. 1277 report_handler: An instance of tensor_tracer_report.TTReportHandle. 1278 Returns: 1279 List of tensors to be traced. 1280 """ 1281 1282 traced_tensors = [] 1283 checkpoint_operations = set([tensor.op 1284 for (tensor, _) in tensor_trace_points]) 1285 for op_id, op in enumerate(graph_order.operations): 1286 if checkpoint_operations and op not in checkpoint_operations: 1287 continue 1288 if self._skip_op(op_id, op, ops_in_exec_path, report_handler): 1289 continue 1290 for i in range(len(op.outputs)): 1291 out_tensor = op.outputs[i] 1292 if not self._skip_tensor(op_id, out_tensor, report_handler): 1293 traced_tensors.append(out_tensor) 1294 return traced_tensors 1295 1296 def _check_trace_files(self): 1297 """Checks if any requirements for trace files are satisfied.""" 1298 1299 if not self._parameters.trace_dir: 1300 # traces will be written to stderr. No need to check trace files. 1301 return 1302 if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_SUMMARY: 1303 # Output files are handled by tf.summary operations, no need to precreate 1304 # them. 1305 return 1306 if not gfile.Exists(self._parameters.trace_dir): 1307 file_io.recursive_create_dir(self._parameters.trace_dir) 1308 if not gfile.Exists(self._parameters.trace_dir): 1309 raise RuntimeError('Failed to create trace directory at %s' % 1310 self._parameters.trace_dir) 1311 1312 def _create_temp_cache(self, num_traced_tensors, num_signatures, graph): 1313 """Creates a temporary cache with the given dimensions. 1314 1315 Fills the self._temp_cache_var with num_traced_tensors tf.constant() ops 1316 that have shape of [num_signatures]. 1317 Args: 1318 num_traced_tensors: Int, denoting total number of traced tensors. 1319 num_signatures: Int, denoting the number of statistics collected per 1320 tensors. 1321 graph: TensorFlow graph. 1322 """ 1323 init_value = constant_op.constant(_COMPACT_TRACE_ENTRY_INIT_VALUE, 1324 dtype=dtypes.float32, 1325 shape=[num_signatures]) 1326 self._temp_cache_var[graph] = [ 1327 init_value for _ in range(num_traced_tensors)] 1328 1329 def _determine_trace_and_create_report(self, graph, ops_in_exec_path, 1330 graph_summary_tag): 1331 """Work needs to be done prior to TPU or CPU tracing. 1332 1333 Args: 1334 graph: tf.graph 1335 ops_in_exec_path: Set of operations in the execution path. 1336 graph_summary_tag: the summary tag name for the given graph. 1337 Returns: 1338 An instance of tensor_tracer_report.TensorTraceOrder, containing list of 1339 tensors to be traced with their topological order information. 1340 """ 1341 1342 self._check_trace_files() 1343 1344 graph_order = tensor_tracer_report.sort_tensors_and_ops(graph) 1345 tensor_trace_points = graph.get_collection(_TENSOR_TRACER_COLLECTION) 1346 1347 report_handler = tensor_tracer_report.TTReportHandle() 1348 traced_tensors = self._determine_and_instrument_traced_tensors( 1349 graph_order, ops_in_exec_path, tensor_trace_points, report_handler) 1350 logging.info('TensorTracer is tracing %d tensors.', len(traced_tensors)) 1351 1352 tensor_trace_order = tensor_tracer_report.TensorTraceOrder(graph_order, 1353 traced_tensors) 1354 num_signatures = self._num_signature_dimensions() 1355 # Create a cache variable if compact_tracing is used. 1356 if num_signatures and self._use_tensor_values_cache(): 1357 if self._use_temp_cache(): 1358 self._create_temp_cache(len(traced_tensors), num_signatures, graph) 1359 else: 1360 self._create_or_get_tensor_values_cache(_TT_SUMMARY_TAG, 1361 graph, 1362 [len(traced_tensors), 1363 num_signatures]) 1364 if self._parameters.trace_mode in ( 1365 tensor_tracer_flags.TRACE_MODE_SUMMARY, 1366 tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY): 1367 self._report_proto = report_handler.create_report_proto( 1368 self._tt_config, self._parameters, tensor_trace_order, 1369 tensor_trace_points, self._signature_types()) 1370 if self._parameters.use_fingerprint_subdir: 1371 self._parameters.trace_dir = os.path.join( 1372 self._parameters.trace_dir, self._report_proto.fingerprint) 1373 logging.info('TensorTracer updating trace_dir to %s', 1374 self._parameters.trace_dir) 1375 self._report_proto_path = report_handler.report_proto_path( 1376 self._parameters.trace_dir, graph_summary_tag) 1377 1378 if self._parameters.report_file_path != _SKIP_REPORT_FILE: 1379 report_handler.write_report_proto(self._report_proto_path, 1380 self._report_proto, self._parameters) 1381 else: 1382 report_handler.create_report(self._tt_config, self._parameters, 1383 tensor_trace_order, tensor_trace_points) 1384 return tensor_trace_order 1385 1386 def _create_host_call(self): 1387 return self._parameters.trace_mode in ( 1388 tensor_tracer_flags.TRACE_MODE_SUMMARY, 1389 tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY) 1390 1391 def _inspect_summary_cache(self, cache, replica_id, step_num, output_stream, 1392 tensor_trace_order): 1393 """Generates a print operation to print trace inspection. 1394 1395 Args: 1396 cache: Tensor storing the trace results for the step. 1397 replica_id: Tensor storing the replica id of the running core. 1398 step_num: Step number. 1399 output_stream: Where to print the outputs, e.g., file path, or sys.stderr. 1400 tensor_trace_order: TensorTraceOrder object holding tensorname to id map. 1401 1402 Returns: 1403 The Op to flush the cache to file. 1404 """ 1405 def _inspect_tensor(tensor): 1406 """Returns the text to be printed for inspection output.""" 1407 if (self._parameters.trace_mode == 1408 tensor_tracer_flags.TRACE_MODE_NAN_INF): 1409 return control_flow_ops.cond( 1410 math_ops.greater(tensor, 0.0), 1411 lambda: 'has NaNs/Infs!', 1412 lambda: 'has no NaNs or Infs.') 1413 else: 1414 return tensor 1415 1416 # Check if there are graph operations being profiled. 1417 if not tensor_trace_order.traced_tensors: 1418 logging.warn('Inspect mode has no tensors in the cache to check.') 1419 return control_flow_ops.no_op 1420 1421 # Check if the cache includes any nan or inf 1422 if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_NAN_INF: 1423 # Cache has 1s or 0s if the mode is NaN_INF 1424 step_has_nan_or_inf = math_ops.greater(math_ops.reduce_sum(cache), 0.0) 1425 else: 1426 # Cache has the actual numerics for other modes. 1427 step_has_nan_or_inf = math_ops.reduce_any( 1428 gen_math_ops.logical_or( 1429 gen_math_ops.is_nan(cache), gen_math_ops.is_inf(cache))) 1430 1431 # Summarizing message for each step. 1432 step_error_message = control_flow_ops.cond( 1433 step_has_nan_or_inf, 1434 lambda: 'NaNs or Infs in the step!', 1435 lambda: 'No numerical issues have been found for the step.') 1436 1437 # No need to print core numbers if the cache is merged already. 1438 if self._parameters.collect_summary_per_core: 1439 stats = ['\n\n', 'core:', replica_id, ',', 'step:', step_num, '-->', 1440 step_error_message, 1441 'Printing tensors for mode:%s...' % self._parameters.trace_mode] 1442 else: 1443 stats = ['\n\n', 'step:', step_num, '-->', step_error_message, 1444 'Printing tensors for mode:%s...' % self._parameters.trace_mode] 1445 1446 for tensor_name, cache_idx in sorted( 1447 tensor_trace_order.tensorname_to_cache_idx.items(), 1448 key=lambda item: item[1]): 1449 if self._parameters.collect_summary_per_core: 1450 stats.extend([ 1451 '\n', 'core:', replica_id, ',', 'step:', step_num, ',', 1452 tensor_name, '-->', _inspect_tensor(cache[cache_idx, 0])]) 1453 else: 1454 stats.extend([ 1455 '\n', 'step:', step_num, ',', 1456 tensor_name, '-->', _inspect_tensor(cache[cache_idx, 0])]) 1457 return logging_ops.print_v2(*stats, summarize=-1, 1458 output_stream=output_stream) 1459 1460 def _get_outfile_suffix(self): 1461 if remote_utils.is_remote_path(self._parameters.trace_dir): 1462 return remote_utils.get_appendable_file_encoding() 1463 else: 1464 return '' 1465 1466 def _generate_flush_cache_op(self, num_replicas, on_tpu, 1467 tensor_trace_order, graph): 1468 """Generates an Op that will flush the cache to file. 1469 1470 Args: 1471 num_replicas: total number of replicas. 1472 on_tpu: if the graph is executed on TPU. 1473 tensor_trace_order: TensorTraceOrder object holding tensorname to id map. 1474 graph: TensorFlow graph. 1475 1476 Returns: 1477 The Op to flush the cache to file. 1478 """ 1479 1480 def _flush_fun(cache, replica_id, step_num): 1481 """Flushes the cache to a file corresponding to replica_id.""" 1482 1483 def _f(file_index): 1484 """Generates a func that flushes the cache to a file.""" 1485 def _print_cache(): 1486 """Flushes the cache to a file.""" 1487 replica_str = ('%d' % file_index) 1488 if self._parameters.trace_dir: 1489 output_path = (os.path.join(self._parameters.trace_dir, 1490 _COMPACT_TRACE_FILE_PREFIX) 1491 + replica_str + self._get_outfile_suffix()) 1492 output_stream = _OUTPUT_STREAM_ESCAPE + output_path 1493 else: 1494 output_stream = sys.stderr 1495 1496 new_step_line = _REPLICA_ID_TAG + replica_str 1497 print_ops = [] 1498 if self._parameters.inspect_trace: 1499 if self._num_signature_dimensions() > 1: 1500 raise ValueError('Inspecting multi signatures are not supported.') 1501 print_ops.append(self._inspect_summary_cache( 1502 cache=cache, replica_id=replica_id, step_num=step_num, 1503 output_stream=output_stream, 1504 tensor_trace_order=tensor_trace_order)) 1505 else: 1506 for i in range(self._num_signature_dimensions()): 1507 print_ops.append(logging_ops.print_v2( 1508 new_step_line, '\n', 1509 cache[:, i], '\n', 1510 summarize=-1, 1511 output_stream=output_stream)) 1512 with ops.control_dependencies(print_ops): 1513 return constant_op.constant(0).op 1514 return _print_cache 1515 1516 def _eq(file_index): 1517 return math_ops.equal(replica_id, file_index) 1518 1519 flush_op_cases = {} 1520 flush_op_cases[_eq(0)] = _f(0) 1521 for i in range(1, num_replicas): 1522 if on_tpu and not self._parameters.collect_summary_per_core: 1523 # If this is the case, the cache is already merged for all cores. 1524 # Only first core flushes the cache. 1525 flush_op_cases[_eq(i)] = control_flow_ops.no_op 1526 else: 1527 flush_op_cases[_eq(i)] = _f(i) 1528 # Each replica needs to determine where to write their output. 1529 # To do this, we check if replica_id is 0, then 1, ..., and then 1530 # num_replicas - 1 statically; and return the corresponding static file 1531 # name. We cannot simply set the file name in python, as replica_id is 1532 # only known during tf runtime, and we cannot create dynamic filenames. 1533 return control_flow_ops.case(flush_op_cases, exclusive=True) 1534 1535 cache = self._create_or_get_tensor_values_cache(_TT_SUMMARY_TAG, graph) 1536 if self._use_temp_cache(): 1537 cache_val = cache 1538 else: 1539 cache_val = cache.value() 1540 1541 if on_tpu: 1542 # If we do not need to collect traces for all cores, merge and aggregate 1543 # per core trace. 1544 if not self._parameters.collect_summary_per_core: 1545 cache_val = self.merge_caches_on_tpu(cache_val) 1546 cache_val = self.aggregate_global_cache(cache_val)[0] 1547 1548 flush_op = tpu.outside_compilation( 1549 _flush_fun, cache_val, self._replica_id, 1550 array_ops.identity(training_util.get_or_create_global_step())) 1551 else: 1552 global_step = training_util.get_or_create_global_step() 1553 flush_op = _flush_fun(cache_val, self._replica_id, global_step) 1554 1555 if self._use_temp_cache(): 1556 with ops.control_dependencies([flush_op]): 1557 return constant_op.constant(0).op 1558 else: 1559 # Re-initialize the local cache variable. 1560 with ops.control_dependencies([flush_op]): 1561 reset_value = constant_op.constant(_COMPACT_TRACE_ENTRY_INIT_VALUE, 1562 dtype=cache.dtype, 1563 shape=cache.shape) 1564 assign_op = state_ops.assign(cache, reset_value).op 1565 with ops.control_dependencies([assign_op]): 1566 return constant_op.constant(0).op 1567 1568 def _flush_tensor_values_cache(self, tensor_fetches, op_fetches, on_tpu, 1569 tensor_trace_order, graph): 1570 """Flushes the intermediate tensor values in the graph to the cache. 1571 1572 Args: 1573 tensor_fetches: list of tensor results returned by the model_fn. 1574 op_fetches: list of ops that are returned by the model_fn, e.g., train_op. 1575 on_tpu: if the graph is executed on TPU. 1576 tensor_trace_order: TensorTraceOrder object holding tensorname to id map. 1577 graph: TensorFlow graph. 1578 1579 Returns: 1580 An identical copy of tensor_fetches. 1581 """ 1582 # Add a dependency to op and tensor fetches to make sure that all tracing 1583 # ops are executed before flushing trace results. 1584 if not tensor_trace_order.traced_tensors: 1585 logging.warn('No tensor values being traced. No flush cache op added.') 1586 return tensor_fetches 1587 with ops.control_dependencies(op_fetches + 1588 [tensor.op for tensor in tensor_fetches]): 1589 flush_cache_op = self._generate_flush_cache_op( 1590 self._tt_config.num_replicas, on_tpu, tensor_trace_order, graph) 1591 return control_flow_ops.tuple(tensor_fetches, 1592 control_inputs=[flush_cache_op]) 1593 1594 def _process_tensor_fetches(self, tensor_fetches): 1595 """Check that tensor_fetches is not empty and have valid tensors.""" 1596 # If none or empty list. 1597 if tensor_fetches is None: 1598 raise RuntimeError('tensor_fetches provided to tensor_tracer cannot be ' 1599 'None.') 1600 if not isinstance(tensor_fetches, (list, tuple)): 1601 tensor_fetches = [tensor_fetches] 1602 elif not tensor_fetches: 1603 raise RuntimeError('tensor_fetches provided to tensor_tracer cannot be ' 1604 'empty list.') 1605 fetches = [] 1606 for fetch in tensor_fetches: 1607 if isinstance(fetch, ops.Tensor): 1608 fetches.append(fetch) 1609 else: 1610 raise RuntimeError('Given tensor_fetch:%s is not a tensor.' % fetch) 1611 return fetches 1612 1613 def _process_op_fetches(self, op_fetches): 1614 """Check that op_fetches have valid ops.""" 1615 if op_fetches is None: 1616 return [] 1617 1618 if not isinstance(op_fetches, (list, tuple)): 1619 op_fetches = [op_fetches] 1620 1621 fetches = [] 1622 for fetch in op_fetches: 1623 if isinstance(fetch, ops.Operation): 1624 fetches.append(fetch) 1625 elif isinstance(fetch, ops.Tensor): 1626 fetches.append(fetch.op) 1627 else: 1628 logging.warning('Ignoring the given op_fetch:%s, which is not an op.' % 1629 fetch) 1630 return fetches 1631 1632 def _convert_fetches_to_input_format(self, input_fetches, current_fetches): 1633 """Changes current_fetches' format, so that it matches input_fetches.""" 1634 if isinstance(input_fetches, ops.Tensor): 1635 if len(current_fetches) != 1: 1636 raise RuntimeError('Tensor tracer input/output fetches do not match.') 1637 return current_fetches[0] 1638 else: 1639 if len(current_fetches) != len(current_fetches): 1640 raise RuntimeError('Tensor tracer input/output fetches do not match.') 1641 elif isinstance(input_fetches, tuple): 1642 return tuple(current_fetches) 1643 else: 1644 return current_fetches 1645 1646 def _get_op_control_flow_context(self, op): 1647 """Returns the control flow of the given op. 1648 1649 Args: 1650 op: tf.Operation for which the control flow context is requested. 1651 Returns: 1652 op_control_flow_context: which the is control flow context of the given 1653 op. If the operation type is LoopExit, returns the outer control flow 1654 context. 1655 """ 1656 # pylint: disable=protected-access 1657 op_control_flow_context = op._control_flow_context 1658 # pylint: enable=protected-access 1659 if control_flow_util.IsLoopExit(op): 1660 op_control_flow_context = op_control_flow_context.outer_context 1661 return op_control_flow_context 1662 1663 def merge_caches_on_tpu(self, local_tpu_cache_tensor): 1664 """Merges the given caches on tpu. 1665 1666 Args: 1667 local_tpu_cache_tensor: A local tensor that needs to be merged 1668 by concanting data from other tpu cores. 1669 Returns: 1670 A merged tf.Tensor. 1671 """ 1672 x = array_ops.broadcast_to( 1673 local_tpu_cache_tensor, 1674 shape=[self._tt_config.num_replicas] + 1675 local_tpu_cache_tensor.shape.as_list()) 1676 return tpu_ops.all_to_all( 1677 x, concat_dimension=0, split_dimension=0, 1678 split_count=self._tt_config.num_replicas, 1679 group_assignment=[list(range(self._tt_config.num_replicas))]) 1680 1681 def aggregate_global_cache(self, global_tt_summary_cache): 1682 """Merges the given caches on tpu. 1683 1684 Args: 1685 global_tt_summary_cache: The global tensor tracer summary cache tensor 1686 with shape (num_cores, num_traced_tensors, num_traced_signatures). First 1687 dimension corresponds to core_id, where global_tpu_cache_tensor[i] 1688 correspond to the local cache from core-i. 1689 Returns: 1690 An aggregated tf.Tensor. 1691 Raises: 1692 RuntimeError: if there is no aggregate function defined for a signature. 1693 """ 1694 1695 # Merge only statistics tensor, if it is any other tensor we simply, 1696 # concatenate them. 1697 agg_fn_map = self._parameters.get_signature_to_agg_fn_map() 1698 signature_idx_map = self._signature_types() 1699 aggregation_result = [] 1700 for signature, idx in sorted(signature_idx_map.items(), 1701 key=operator.itemgetter(1)): 1702 if signature not in agg_fn_map: 1703 raise RuntimeError('No aggregation function is defined for ' 1704 'signature %s.' % signature) 1705 # The dimensions of the statistics tensor is 1706 # num_cores x num_traced_tensors x num_signatures 1707 # value[:,:,idx] will return the portion of the tensor related 1708 # to signature. 1709 signature_tensor = global_tt_summary_cache[:, :, idx] 1710 # Merge it along the first (core) axis. 1711 agg_fn = agg_fn_map[signature] 1712 agg_tensor = agg_fn(signature_tensor, axis=0) 1713 aggregation_result.append(agg_tensor) 1714 # Merge results corresponding to different signatures 1715 1716 merged_signatures = array_ops.stack(aggregation_result) 1717 # merged_signatures has dimensions 1718 # num_signatures x num_traced_tensors, transpose it so that it 1719 # will match with the original structure 1720 # num_traced_tensors x num_signatures. 1721 transposed_signatures = array_ops.transpose(merged_signatures) 1722 # Expand 1 more dimension so that it will match with the expected 1723 # structure num_cores x num_traced_tensors x num_signatures. 1724 return array_ops.expand_dims(transposed_signatures, axis=0) 1725 1726 def _prepare_host_call_fn(self, processed_t_fetches, 1727 op_fetches, graph, graph_summary_tag): 1728 """Creates a host call function that will write the cache as tb summary. 1729 1730 Args: 1731 processed_t_fetches: List of tensor provided to session.run. 1732 op_fetches: List of operations provided to session.run. 1733 graph: TensorFlow graph. 1734 graph_summary_tag: the summary_tag name for the given graph. 1735 Raises: 1736 ValueError if trace_dir is not set. 1737 """ 1738 if self._parameters.trace_dir is None: 1739 raise ValueError('Provide a trace_dir for tensor tracer in summary mode. ' 1740 '--trace_dir=/model/dir') 1741 1742 def _write_cache(step, event_file_suffix=None, **kwargs): 1743 """Writes the given caches as tensor summary. 1744 1745 Args: 1746 step: Step tensor with dimension [num_cores]. 1747 event_file_suffix: Event filename suffix tensor. 1748 **kwargs: The dictionary of tensors that needs to be written as 1749 summaries. Key and value pairs within kwargs correspond to the tag 1750 name, and tensor content that will be written using summary.write. 1751 The trace_modes that use this function are: 1752 - summary: In summary mode, kwargs includes a single (tag, content) 1753 pair which are, _TT_SUMMARY_TAG and a tf.float32 signature_cache 1754 variable. The dimension of the signature_cache is: 1755 num_cores x num_traced_tensors x num_signatures. 1756 - full_tensor_summary: kwargs will include all traced tensors. Tag 1757 and content correspond to the name of the tensor, and its actual 1758 content. 1759 Returns: 1760 A tf.Operation that needs to be executed for the host call dependencies. 1761 """ 1762 file_suffix = _TT_EVENT_FILE_SUFFIX 1763 if event_file_suffix is not None: 1764 file_suffix = string_ops.string_join([file_suffix, event_file_suffix], 1765 separator='.') 1766 # TODO(deveci): Parametrize max_queue, so that flushing op can be called 1767 # less frequently. 1768 # Setting max_queue to 100 appears to be safe even when the number of 1769 # iterations are much lower, as the destructor of the writer flushes it. 1770 summary_write_ops = [] 1771 summary_writer = summary.create_file_writer_v2( 1772 self._parameters.trace_dir, 1773 filename_suffix=file_suffix, 1774 max_queue=_TT_SUMMARY_MAX_QUEUE) 1775 graph.add_to_collection( 1776 TENSOR_TRACER_SUMMARY_COLLECTION, summary_writer) 1777 1778 step_value = step[0] 1779 dt = step_value.dtype 1780 1781 # The step parameter to a summary write call must be 64-bit. 1782 if dt.__ne__(dtypes.int64) and dt.__ne__( 1783 dtypes.uint64) and dt.__ne__(dtypes.float64): 1784 step_value = math_ops.cast(step_value, dtypes.int64) 1785 1786 with summary_writer.as_default(): 1787 summary_metadata = summary_pb2.SummaryMetadata( 1788 plugin_data=summary_pb2.SummaryMetadata.PluginData( 1789 plugin_name=_TT_TENSORBOARD_PLUGIN_NAME)) 1790 for key, value in kwargs.items(): 1791 # Check whether we need to compute aggregated statistics that merge 1792 # all cores statistics. 1793 if not self._parameters.collect_summary_per_core: 1794 # Merge only statistics tensor, if it is any other tensor we simply, 1795 # concatenate them. 1796 # Also, if there is only a single core (first dim. is 0), then skip 1797 # aggregation. 1798 if key == _TT_SUMMARY_TAG and value.shape.as_list()[0] != 1: 1799 value = self.aggregate_global_cache(value) 1800 with ops.control_dependencies([summary_writer.init()]): 1801 summary_write_ops.append(summary.write( 1802 _TT_SUMMARY_TAG + '/' + key + '.' + graph_summary_tag, 1803 value, metadata=summary_metadata, 1804 step=step_value)) 1805 return control_flow_ops.group(summary_write_ops) 1806 1807 global_step = training_util.get_or_create_global_step() 1808 step = array_ops.reshape(global_step, [1]) 1809 self._host_call_fn = {} 1810 1811 host_call_deps = op_fetches + [tensor.op for tensor in processed_t_fetches] 1812 1813 caches_to_write = {} 1814 with ops.control_dependencies(host_call_deps): 1815 all_caches = self._cache_variable_for_graph(graph) 1816 for cache_name, cache_variable in all_caches.items(): 1817 # Increase the cache rank by 1, so that when host call concatenates 1818 # tensors from different replicas, we can identify them with [core_id]. 1819 new_cache_shape = [1] 1820 new_cache_shape.extend(cache_variable.shape.as_list()) 1821 cache = array_ops.reshape(cache_variable, new_cache_shape) 1822 caches_to_write[cache_name] = cache 1823 # Add step to parameter dictionary. 1824 caches_to_write['step'] = step 1825 # Other options without adding step to parameter dictionary are 1826 # * host_call_fn = (_write_cache(step, caches_to_write)) : fails as it 1827 # considers caches_to_write as a single parameter, rather than a keyword 1828 # parameters. 1829 # * host_call_fn = (_write_cache(step, **caches_to_write)) : fails with 1830 # a syntax error. 1831 self._host_call_fn[_TT_HOSTCALL_KEY] = (_write_cache, caches_to_write) 1832 1833 def host_call_deps_and_fn(self): 1834 return self._host_call_fn 1835 1836 def get_traced_op_names(self): 1837 """Returns the set of traced op names.""" 1838 return self._traced_op_names 1839 1840 def _trace_execution(self, graph, 1841 tensor_fetches, 1842 op_fetches=None, 1843 on_tpu=True): 1844 """Commong tracing function for both CPU and TPUs. 1845 1846 The caller function should set device_type, num_replicas, 1847 num_replicas_per_host, num_hosts and replica_id before calling 1848 _trace_execution. 1849 1850 1851 Args: 1852 graph: the graph of Ops executed on the TPU. 1853 tensor_fetches: a (list,tuple,or a single object) of tensor fetches 1854 returned by model_fn given to session.run. Function must be provided 1855 with as least one tensor to fetch. 1856 op_fetches: A list of op fetches returned by model_fn given to 1857 session.run. op_fetches and tensor_fetches are used to determine the 1858 nodes that will be executed. Can be None. 1859 on_tpu: True if executing on TPU. 1860 1861 Returns: 1862 tensor_fetches: an exact copy of tensor_fetches that has additional 1863 dependencies. 1864 Raises: 1865 RuntimeError: If tensor_fetches is None or empty. 1866 """ 1867 def _cast_unsupported_dtypes(tensor): 1868 """Casts tensor to a supported type.""" 1869 1870 if tensor.dtype.__eq__(dtypes.int64): 1871 # outside-compilation doesn't support int64 input yet. 1872 return math_ops.cast(tensor, dtypes.int32) 1873 if tensor.dtype.__eq__(dtypes.bfloat16) or tensor.dtype.__eq__( 1874 dtypes.float16): 1875 # Since host can't handle bf16, convert tensor to f32. 1876 return math_ops.cast(tensor, dtypes.float32) 1877 return tensor 1878 1879 trace_mode = self._parameters.trace_mode 1880 device_type = self._tt_config.device_type 1881 # pylint: disable=protected-access 1882 self._outmost_context = graph._get_control_flow_context() 1883 # pylint: enable=protected-access 1884 1885 analytics.track_usage('tensor_tracer', [trace_mode, device_type]) 1886 TensorTracer.check_device_type(device_type) 1887 TensorTracer.check_trace_mode(device_type, trace_mode) 1888 # Check in_tensor_fetches, and op_fetches and convert them to lists. 1889 processed_t_fetches = self._process_tensor_fetches(tensor_fetches) 1890 op_fetches = self._process_op_fetches(op_fetches) 1891 all_fetches = op_fetches + [tensor.op for tensor in processed_t_fetches] 1892 1893 # Filter out the operations that won't be executed. 1894 # if fetches=None, then ops_in_exec_path = set(operations) 1895 exec_op_set = self._filter_execution_path_operations(graph.get_operations(), 1896 all_fetches) 1897 graph_summary_tag = _graph_summary_tag(graph) 1898 1899 # Write report file, and determine the traced tensors. 1900 tensor_trace_order = self._determine_trace_and_create_report( 1901 graph, exec_op_set, graph_summary_tag) 1902 1903 tensor_fetch_set = set(processed_t_fetches) 1904 tracing_ops = [] 1905 1906 sorted_exec_op_list = list(exec_op_set) 1907 sorted_exec_op_list.sort(key=lambda op: op.name) 1908 # Trace ops only if they are in the execution path. 1909 for op in sorted_exec_op_list: 1910 for i in range(len(op.outputs)): 1911 out_tensor = op.outputs[i] 1912 tensor_name = out_tensor.name 1913 if tensor_name not in tensor_trace_order.tensorname_to_cache_idx: 1914 continue 1915 self._traced_op_names.add(op.name) 1916 # Create the list of consumers before calling _preprocess_traced_tensor. 1917 # Otherwise, adding control input below, will introduce a cycle in the 1918 # graph. 1919 consumers = out_tensor.consumers() 1920 # Not all consumers may be in the exec path. Filter out the consumers 1921 # to keep the graph simpler. 1922 consumers = [cop for cop in consumers if cop in exec_op_set] 1923 1924 # If there is no consumer of the tensor, there is no need to trace it; 1925 # unless the tensor itself is one of the fetches. 1926 is_a_fetched_tensor = out_tensor in tensor_fetch_set 1927 if (not consumers) and (not is_a_fetched_tensor): 1928 continue 1929 1930 op_control_flow_context = self._get_op_control_flow_context(op) 1931 if op_control_flow_context: 1932 # pylint: disable=protected-access 1933 graph._set_control_flow_context(op_control_flow_context) 1934 # pylint: enable=protected-access 1935 1936 processed_tensors = self._preprocess_traced_tensor(out_tensor) 1937 1938 if on_tpu: 1939 for signature in processed_tensors.keys(): 1940 processed_tensors[signature] = _cast_unsupported_dtypes( 1941 processed_tensors[signature]) 1942 1943 if self._use_tensor_values_cache(): 1944 # Use a small cache (either temp cache or tf local variable) to store 1945 # the characteristics of the tensor. 1946 if self._use_temp_cache(): 1947 cache_idx = tensor_trace_order.tensorname_to_cache_idx[tensor_name] 1948 self._save_tensor_value_to_tmp_cache(cache_idx, 1949 processed_tensors, 1950 graph) 1951 trace_op = None 1952 else: 1953 cache_idx = tensor_trace_order.tensorname_to_cache_idx[tensor_name] 1954 trace_op = self._save_tensor_value_to_cache_op(cache_idx, 1955 processed_tensors, 1956 graph) 1957 elif self._use_tensor_buffer(): 1958 if len(processed_tensors) != 1: 1959 raise RuntimeError('Multiple stats are only allowed in compact ' 1960 'mode.') 1961 processed_out_tensor = list(processed_tensors.values())[0] 1962 # Store the whole tensor in a buffer. 1963 trace_op = self._snapshot_tensor(processed_out_tensor) 1964 else: 1965 1966 def tpu_wrap_trace_fn(tensor, out_tensor_name): 1967 """Wraps the trace_fn with outside compilation if on TPUs.""" 1968 tensor_trace_fn = self._make_tensor_trace_fun(out_tensor_name, 1969 tensor_trace_order) 1970 if on_tpu: 1971 return tpu.outside_compilation(tensor_trace_fn, tensor) 1972 else: 1973 return tensor_trace_fn(tensor) 1974 1975 if len(processed_tensors) != 1: 1976 raise RuntimeError('Multiple stats are only allowed in compact ' 1977 'mode.') 1978 # Collecting multiple statistics are only supported in the summary 1979 # mode that uses compact format(self._use_tensor_values_cache = true). 1980 # Non-compact mode currently allows single stat per tensor. 1981 processed_out_tensor = next(iter(processed_tensors.values())) 1982 trace_op = tpu_wrap_trace_fn(processed_out_tensor, tensor_name) 1983 1984 if op_control_flow_context: 1985 # pylint: disable=protected-access 1986 graph._set_control_flow_context(self._outmost_context) 1987 # pylint: enable=protected-access 1988 if trace_op: 1989 if is_a_fetched_tensor: 1990 tracing_ops.append(trace_op) 1991 continue 1992 # Add it to all consumers, as some consumers may not be executed if 1993 # they are in a control flow. 1994 for consumer_op in consumers: 1995 # pylint: disable=protected-access 1996 consumer_op._add_control_input(trace_op) 1997 # pylint: enable=protected-access 1998 1999 # pylint: disable=protected-access 2000 graph._set_control_flow_context(self._outmost_context) 2001 # pylint: enable=protected-access 2002 if tracing_ops: 2003 # If we are tracing a fetched tensor, their dependency is stored in 2004 # tracing_ops. 2005 processed_t_fetches = control_flow_ops.tuple(processed_t_fetches, 2006 control_inputs=tracing_ops) 2007 if self._use_tensor_values_cache() or self._use_tensor_buffer(): 2008 if self._use_temp_cache(): 2009 # Create the temporary tf cache variable by concantanating all 2010 # statistics. 2011 graph_cache_var = self._cache_variable_for_graph(graph) 2012 if graph not in self._temp_cache_var: 2013 raise RuntimeError('graph is not in self._temp_cache_var') 2014 graph_cache_var[_TT_SUMMARY_TAG] = array_ops.stack( 2015 self._temp_cache_var[graph], axis=0, name='stack_all_op_signatures') 2016 if self._create_host_call(): 2017 self._prepare_host_call_fn(processed_t_fetches, op_fetches, graph, 2018 graph_summary_tag) 2019 if not on_tpu: 2020 write_cache, caches_to_write = self._host_call_fn[_TT_HOSTCALL_KEY] 2021 cache_write_op = write_cache(**caches_to_write) 2022 processed_t_fetches = control_flow_ops.tuple( 2023 processed_t_fetches, control_inputs=[cache_write_op]) 2024 del self._host_call_fn[_TT_HOSTCALL_KEY] 2025 elif self._parameters.flush_summaries_with_outside_compile: 2026 write_cache, caches_to_write = self._host_call_fn[_TT_HOSTCALL_KEY] 2027 if (_TT_SUMMARY_TAG in caches_to_write and 'step' in caches_to_write): 2028 step = caches_to_write['step'] 2029 tensor_tracer_summary = caches_to_write[_TT_SUMMARY_TAG] 2030 tt_core_summary = self.merge_caches_on_tpu(tensor_tracer_summary[0]) 2031 if not self._parameters.collect_summary_per_core: 2032 tt_core_summary = self.aggregate_global_cache(tt_core_summary) 2033 2034 def write_if_core_0(step, replica_id, tt_summary): 2035 2036 return control_flow_ops.cond( 2037 math_ops.equal(replica_id, 0), 2038 lambda: write_cache(step=step, event_file_suffix=None, # pylint: disable=g-long-lambda 2039 tensor_tracer_summary=tt_summary), 2040 control_flow_ops.no_op) 2041 2042 write_op = tpu.outside_compilation(write_if_core_0, step=step, 2043 replica_id=self._replica_id, 2044 tt_summary=tt_core_summary) 2045 processed_t_fetches = control_flow_ops.tuple( 2046 processed_t_fetches, control_inputs=[write_op]) 2047 del self._host_call_fn[_TT_HOSTCALL_KEY] 2048 else: 2049 raise ValueError('Outside compiled flush in only supported for ' 2050 'summary mode') 2051 else: 2052 processed_t_fetches = self._flush_tensor_values_cache( 2053 processed_t_fetches, op_fetches, on_tpu=on_tpu, 2054 tensor_trace_order=tensor_trace_order, 2055 graph=graph) 2056 2057 # processed_t_fetches is a list at this point. Convert it to the same 2058 # format as given in tensor_fetches. 2059 return self._convert_fetches_to_input_format(tensor_fetches, 2060 processed_t_fetches) 2061 2062 def trace_tpu(self, graph, 2063 tensor_fetches, 2064 op_fetches=None, 2065 num_replicas=None, 2066 num_replicas_per_host=None, 2067 num_hosts=None): 2068 """Traces the tensors generated by TPU Ops in a TF graph. 2069 2070 Args: 2071 graph: the graph of Ops executed on the TPU. 2072 tensor_fetches: a (list,tuple,or a single object) of tensor fetches 2073 returned by model_fn given to session.run. Function must be provided 2074 with as least one tensor to fetch. 2075 op_fetches: A list of op fetches returned by model_fn given to 2076 session.run. op_fetches and tensor_fetches are used to determine the 2077 nodes that will be executed. Can be None. 2078 num_replicas: number of replicas used on the TPU. 2079 num_replicas_per_host: number of replicas per TPU host. 2080 num_hosts: total number of TPU hosts. 2081 2082 Returns: 2083 tensor_fetches: an exact copy of tensor_fetches that has additional 2084 dependencies. 2085 """ 2086 if isinstance(graph, func_graph.FuncGraph) or isinstance( 2087 graph, function._FuncGraph): # pylint: disable=protected-access 2088 logging.warning('Tensor Tracer is not supported for tracing FuncGraphs. ' 2089 'Ignoring tracing.') 2090 return tensor_fetches 2091 2092 if graph in TensorTracer._traced_graphs: 2093 logging.warning('Graph is already rewritten with tensor tracer, ignoring ' 2094 'multiple calls.') 2095 return tensor_fetches 2096 else: 2097 TensorTracer._traced_graphs.add(graph) 2098 # Reset the parameters in case parameters are changed. 2099 self._parameters = tensor_tracer_flags.TTParameters() 2100 self._tt_config.device_type = _DEVICE_TYPE_TPU 2101 self._tt_config.num_replicas = num_replicas 2102 self._tt_config.num_replicas_per_host = num_replicas_per_host 2103 self._tt_config.num_hosts = num_hosts 2104 if self._tt_config.num_replicas is not None: 2105 if self._tt_config.num_replicas_per_host is None: 2106 self._tt_config.num_replicas_per_host = 8 2107 if self._tt_config.num_hosts is None: 2108 self._tt_config.num_hosts = ( 2109 num_replicas // self._tt_config.num_replicas_per_host + 2110 (num_replicas % self._tt_config.num_replicas_per_host > 0)) 2111 2112 if self._parameters.graph_dump_path: 2113 graph_io.write_graph(graph, self._parameters.graph_dump_path, 2114 'graph_before_tt.pbtxt') 2115 with graph.as_default(): 2116 self._add_replica_id_to_graph() 2117 tensor_fetches = self._trace_execution(graph, tensor_fetches, op_fetches, 2118 on_tpu=True) 2119 if self._parameters.graph_dump_path: 2120 graph_io.write_graph(graph, self._parameters.graph_dump_path, 2121 'graph_after_tt.pbtxt') 2122 return tensor_fetches 2123 2124 def trace_cpu(self, graph, tensor_fetches, op_fetches=None): 2125 """Traces the tensors generated by CPU Ops in a TF graph. 2126 2127 Args: 2128 graph: the graph of Ops executed on the CPU. 2129 tensor_fetches: a (list,tuple,or a single object) of tensor fetches 2130 returned by model_fn given to session.run. Function must be provided 2131 with as least one tensor to fetch. 2132 op_fetches: A list of op fetches returned by model_fn given to 2133 session.run. op_fetches and tensor_fetches are used to determine the 2134 nodes that will be executed. Can be None. 2135 2136 Returns: 2137 tensor_fetches: an exact copy of tensor_fetches that has additional 2138 dependencies. 2139 """ 2140 if isinstance(graph, func_graph.FuncGraph) or isinstance( 2141 graph, function._FuncGraph): # pylint: disable=protected-access 2142 logging.warning('Tensor Tracer is not supported for tracing FuncGraphs. ' 2143 'Ignoring tracing.') 2144 return tensor_fetches 2145 2146 if graph in TensorTracer._traced_graphs: 2147 logging.warning('Graph is already rewritten with tensor tracer, ignoring ' 2148 'multiple calls.') 2149 return tensor_fetches 2150 else: 2151 TensorTracer._traced_graphs.add(graph) 2152 # Reset the parameters in case parameters are changed. 2153 self._parameters = tensor_tracer_flags.TTParameters() 2154 2155 self._tt_config.device_type = _DEVICE_TYPE_CPU 2156 self._tt_config.num_replicas = 1 2157 self._tt_config.num_replicas_per_host = 1 2158 self._tt_config.num_hosts = 1 2159 self._replica_id = 0 2160 if self._parameters.graph_dump_path: 2161 graph_io.write_graph(graph, self._parameters.graph_dump_path, 2162 'graph_before_tt.pbtxt') 2163 with graph.as_default(): 2164 tensor_fetches = self._trace_execution(graph, tensor_fetches, op_fetches, 2165 on_tpu=False) 2166 if self._parameters.graph_dump_path: 2167 graph_io.write_graph(graph, self._parameters.graph_dump_path, 2168 'graph_after_tt.pbtxt') 2169 return tensor_fetches 2170