# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========================================================================
"""A utility to trace tensor values on TPU."""

import collections
import hashlib
import operator
import os
import os.path
import sys

import numpy as np

from tensorflow.core.framework import summary_pb2
from tensorflow.python.eager import monitoring
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import function
from tensorflow.python.framework import graph_io
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import summary_ops_v2 as summary
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import analytics
from tensorflow.python.platform import gfile
from tensorflow.python.platform import remote_utils
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary_iterator
from tensorflow.python.tpu import tensor_tracer_flags
from tensorflow.python.tpu import tensor_tracer_report
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.training import training_util

_DEVICE_TYPE_TPU = 'tpu'
_DEVICE_TYPE_CPU = 'cpu'
_TRACE_MODE_PART_TENSOR_SIZE = 3

_REASON_OUTSIDE_OP_RANGE = 'not-traced-outside-op-range'
_REASON_UNSAFE_OP = 'not-traced-unsafe-op'
_REASON_WHILELOOP_OP = 'not-traced-special-whileloop-op'
_REASON_CONTROLFLOW_OP = 'not-traced-control-flow-op'
_REASON_IN_CONTROL_FLOW = 'not-traced-in-control-flow'
_REASON_UNSAFE_SCALAR = 'not-traced-unsafe-scalar'
_REASON_SKIP_SCALAR = 'not-traced-scalar'
_REASON_LESS_INTERESTING_OP = 'not-traced-less-interesting-op'
_REASON_DEVICE_MISMATCH = 'not-traced-device-mismatch'
_REASON_DYNAMIC_SHAPE = 'not-traced-dynamic-shape'
_REASON_SCALAR_GET_TRACED = 'traced-scalar'
_REASON_TENSOR_GET_TRACED = 'traced-tensor'
_REASON_USER_INCLUDED = 'traced-user-included'
_REASON_USER_EXCLUDED = 'not-traced-user-excluded'
_REASON_NOT_EXECUTED = 'not-traced-not-in-exec-path'
_REASON_NON_NUMERIC_TENSOR = 'not-traced-non-numeric-tensor'
_REASON_FEEDS_WHILELOOP_OP = 'not-traced-feeds-special-whileloop-op'

_OUTPUT_STREAM_ESCAPE = 'file://'
_TENSOR_TRACER_COLLECTION = 'tensor_tracer_variables'
TENSOR_TRACER_SUMMARY_COLLECTION = 'tensor_tracer_summary_writers'
_TRACE_FILE_NAME = 'trace.all'
_COMPACT_TRACE_FILE_PREFIX = 'compact_trace.'
_COMPACT_TRACE_ENTRY_INIT_VALUE = -1.0
_TENSOR_TRACER_STORAGE = 'tensor_tracer_storage'
_TT_SNAPSHOT = 'tensor_tracer_snapshot'
_REPLICA_ID_TAG = '#replica-id: '
_SKIP_REPORT_FILE = 'None'  # Do not write report proto if --report_file=None

_TT_SUMMARY_NORM = tensor_tracer_flags.TT_SUMMARY_NORM
_TT_SUMMARY_MAX = tensor_tracer_flags.TT_SUMMARY_MAX
_TT_SUMMARY_MAX_ABS = tensor_tracer_flags.TT_SUMMARY_MAX_ABS
_TT_SUMMARY_MIN = tensor_tracer_flags.TT_SUMMARY_MIN
_TT_SUMMARY_MEAN = tensor_tracer_flags.TT_SUMMARY_MEAN
_TT_SUMMARY_VAR = tensor_tracer_flags.TT_SUMMARY_VAR
_TT_SUMMARY_SIZE = tensor_tracer_flags.TT_SUMMARY_SIZE
_TT_SUMMARY_SPARSITY = tensor_tracer_flags.TT_SUMMARY_SPARSITY

_TT_SUMMARY_TAG = 'tensor_tracer_summary'
_TT_TENSORBOARD_PLUGIN_NAME = 'tensor_tracer'
_TT_HOSTCALL_KEY = 'tensor_tracer_host_call'
_TT_EVENT_FILE_SUFFIX = '.tensor_tracer'

_TT_SUMMARY_MAX_QUEUE = 10

tt_gauge = monitoring.BoolGauge('/tensorflow/api/tensor_tracer/v1',
                                'tensor tracer usage', 'method')


def _graph_summary_tag(graph):
  """Generates and returns a summary tag name for the given graph."""

  if graph is None:
    raise RuntimeError('graph is None')
  # The chance of collision with md5 is effectively 0.
  hash_id = hashlib.md5()
  hash_id.update(repr(graph).encode('utf-8'))
  # hexdigest() returns a string.
  return hash_id.hexdigest()

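
# Illustrative sketch (not part of the TensorTracer API): the summary tag is
# just the md5 hex digest of repr(graph), so any object with a stable repr()
# maps to a stable tag. `_DemoGraph` below is a hypothetical stand-in for a
# tf.Graph.
def _example_graph_summary_tag():
  """Minimal sketch showing how _graph_summary_tag derives its tag."""

  class _DemoGraph(object):

    def __repr__(self):
      return 'demo_graph'

  tag = _graph_summary_tag(_DemoGraph())
  assert tag == hashlib.md5(b'demo_graph').hexdigest()
  return tag
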

def set_parameters(tensor_tracer_params=None):
  """Enables tensor tracer and sets its parameters.

  Example usage:
    tensor_tracer_parameters = {'trace_dir': '/usr/tmp/trace_dir',
                                'trace_mode': 'norm',
                                'report_file': '/usr/tmp/trace_dir/report.all'}
    tensor_tracer.set_parameters(tensor_tracer_parameters)

  This sets up the parameters for tensor tracer. A call to tensor tracer as
  below is necessary to enable debugging on CPUs and GPUs. On TPUs, the call
  below can be skipped, as tensor tracer is hooked into tpu.rewrite.
    tt = tensor_tracer.TensorTracer()
    loss = tt.trace_cpu(tf.get_default_graph(), tensor_fetches=loss)

  Args:
    tensor_tracer_params: Tensor tracer parameter dictionary. Below are
      examples of these parameters; see tensor_tracer_report.py for all
      parameters.
        - enable: If set, tensor tracer will be enabled. Calling
          enable_tensor_tracer automatically adds this parameter.
        - trace_mode: The trace_mode to be used by tensor tracer. These include:
          - summary: Collects multiple statistics for traced tensors, and
            writes them to a summary file that can be visualized using
            TensorBoard. This mode currently only works for TPUEstimator. It
            can also be used for other models, but outfeed must be handled by
            the user.
          - norm: Collects the norm of each traced tensor and writes it into a
            text file pointed to by the 'trace_dir' flag. (Default mode).
          - nan-inf: Checks the existence of NaNs and Infs in the tensor, and
            writes a boolean value to a text file pointed to by the
            'trace_dir' flag. Note that 'norm' mode can also capture this
            information with more numerical info.
          - max-abs: Collects the absolute max for each traced tensor and
            writes it into a text file pointed to by the 'trace_dir' flag.
          - full-tensor: Writes the full tensor content of the traced tensors
            into a text file pointed to by the 'trace_dir' flag.
          - part-tensor: Writes a part of the tensor content of the traced
            tensors into a text file pointed to by the 'trace_dir' flag.
          - full_tensor_summary: Writes the full tensors as binary event files.
            The outputs can be read using: trace =
              tensor_tracer.read_tensor_tracer_event_file(event_file_path)

        - report_file: Path to the metadata file that is written during graph
          construction. If not set, metadata will be printed to stdout during
          graph construction.
        - trace_dir: Path where the execution traces will be written during the
          graph execution. If not set, trace will be printed to stderr.
        - trace_level: Tensor tracer aims to trace everything it can. This
          introduces some overhead on graph execution and graph compilation
          times. Using the trace_level parameter, it is possible to trace
          operations based on their priorities. For example, trace_level=7 is
          the highest trace_level, in which every op is traced. trace_level=6
          will skip constant operations such as tf.constant. trace_level=5
          will skip less important ops such as tf.identity. The default
          trace_level=3 will skip concat ops and random number generators. To
          reduce the graph compile time overhead, trace_level can be set to 0,
          which will also skip additions, subtractions, and multiplications.
        - excluded_opnames: If set, any matching op name will not be traced.
          excluded_opnames can be set as a regular expression. E.g.,
          excluded_opnames=.* will exclude everything.
        - excluded_optypes: If set, any matching op type will not be traced.
          excluded_optypes can be set as a regular expression. E.g.,
          excluded_optypes=.* will exclude everything. excluded_optypes=MatMul
          will exclude all MatMul ops from tracing.
        - included_opnames: If set, any matching op name will be forced to be
          traced. included_opnames can be set as a regular expression. E.g.,
          '--included_opnames=some_op --excluded_opnames=.*' will only trace
          some_op.
        - included_optypes: If set, any matching op type will be forced to be
          traced. included_optypes can be set as a regular expression. E.g.,
          '--included_optypes=some_op_type --excluded_optypes=.*' will trace
          only the ops with type 'some_op_type'.
        - flush_summaries: If summary mode is used, flush_summaries=1 will
          flush summaries using outside compilation. Note that, if used with
          low-level APIs, flush_summaries=1 is necessary to obtain results.
        Advanced Flags:
        - trace_scalar: Scalar values are not traced by default. If this flag
          is set, scalar values will also be traced.
        - op_range: In the form of '%d:%d', limits tracing to the ops within
          this range. --op_range='5:10' will trace only the ops whose
          topological order is between 5 and 10.
        - submode: 'brief' or 'detailed'. If the trace mode is not compact,
          brief mode will print only the id of each traced tensor to save some
          space. 'detailed' mode prints the full tensor name.
        - use_fingerprint_subdirectory: The trace directory will be chosen
          using the fingerprint of the trace metadata under the provided
          trace_dir.
  """
  flags = '--%s=1' % tensor_tracer_flags.FLAG_NAME_ENABLE
  if tensor_tracer_params:
    for key, value in tensor_tracer_params.items():
      flags += ' --%s=%s' % (key, value)
  os.environ[tensor_tracer_flags.FLAGS_ENV_VAR] = flags

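
# Illustrative sketch (not part of the TensorTracer API): set_parameters only
# serializes the given dictionary into the TENSOR_TRACER_FLAGS environment
# variable; the parameter values below are hypothetical.
def _example_set_parameters():
  """Minimal sketch of what set_parameters writes to the environment."""
  set_parameters({'trace_dir': '/tmp/trace_dir', 'trace_mode': 'norm'})
  # The env var now holds something like:
  #   '--enable=1 --trace_dir=/tmp/trace_dir --trace_mode=norm'
  return os.environ[tensor_tracer_flags.FLAGS_ENV_VAR]
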

def op_priority(op_type):
  """Returns the priority of the op.

  If the priority of the op is k, it will be traced if trace_level>=k.
  Args:
    op_type: String name of the operation type.
  Returns:
    Integer value corresponding to the priority of the op.
  """
  if op_type in ('Const', 'Shape', 'BroadcastGradientArgs', 'Range',
                 'VariableShape', 'Fill', 'OneHot', 'ShapeN'):
    # Lowest priority ops, e.g., constant ops across different steps.
    # They will be traced only if trace_level>=7.
    return 7

  if op_type in ('Identity', 'Cast', 'Reshape', 'ExpandDims', 'StopGradient',
                 'PreventGradient', 'Squeeze', 'Gather', 'GatherNd'):
    # Operations without numerical effects.
    # They will be traced only if trace_level>=6.
    return 6
  if op_type in ('ConcatV2', 'Concat', 'StridedSlice', 'Slice', 'Pack', 'Tile',
                 'CollectivePermute', 'SplitV', 'DynamicPartition'):
    # Operations that merge or slice an input; they will be traced if
    # trace_level>=5.
    return 5
  if op_type in ('Pad', 'RandomUniformInt', 'GreaterEqual'):
    # Operations less likely to provide useful information; they will be
    # traced if trace_level>=4.
    return 4
  if op_type in ('Sum', 'AddV2', 'Add', 'AddN', 'BiasAdd', 'CrossReplicaSum'):
    # Add operations that are less likely to create any issues; they will be
    # traced if trace_level>=3 (default=3).
    return 3
  if op_type in ('Neg', 'Sub'):
    # Sub operations that are less likely to create any issues; they will be
    # traced if trace_level>=2.
    return 2
  if op_type in ('Mul', 'Square', 'MatMul', 'RandomUniform', 'Select',
                 'Maximum', 'Mean', 'Variance', 'Exp', 'Rsqrt'):
    # Multiplication and some other operations; they will be traced if
    # trace_level>=1.
    return 1

  # Unclassified op_types default to being traced at level 2 and above.
  return 2

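
# Illustrative sketch (not part of the TensorTracer API): an op with priority
# k is traced when trace_level >= k, so at the default trace_level=3 a Const
# op (priority 7) is skipped while AddV2 (3) and MatMul (1) are traced.
def _example_op_priority():
  """Minimal sketch of how trace_level interacts with op_priority."""
  default_trace_level = 3
  assert op_priority('Const') > default_trace_level    # skipped by default
  assert op_priority('AddV2') <= default_trace_level   # traced by default
  assert op_priority('MatMul') <= default_trace_level  # traced by default
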

def read_tensor_tracer_event_file(event_file):
  """Reads the event file written by tensor tracer.

  This can be used to read the full tensors written into binary event files
  by TensorTracer with trace_mode=full_tensor_summary.

  Example usage:
    result_dict_list = tensor_tracer.read_tensor_tracer_event_file(
      event_file_path)
    for result_dict in result_dict_list:
      for step, tensor_dict in result_dict.items():
        for tensor_name, full_tensor_content in tensor_dict.items():
          logging.info(tensor_name, full_tensor_content)

  Args:
    event_file: Path to the event file that contains only tensor tracer events.
  Returns:
    A list of event dictionaries, each of the form:
    {step_number: {tensor_name: tensor_content}}. This is a list instead of
    a single event dictionary because it is possible that an event file may
    have multiple event traces, each of them covering the same step ranges.
  Raises:
    ValueError: If an unexpected trace is found.
  """

  # Keeps track of how many times a step number shows up in these events.
  step_occurrence_count = collections.defaultdict(int)

  # List of step occurrences.
  step_occurrence_list = []

  for trace_event in summary_iterator.summary_iterator(event_file):
    # The first event is an event with file_version: "brain.Event:2".
    if not trace_event.HasField('summary'):
      continue
    if len(trace_event.summary.value) != 1:
      raise ValueError('Single step contains %d summary values,'
                       ' expected 1.' % len(trace_event.summary.value))
    step = trace_event.step
    step_occurrence_count[step] += 1  # A new occurrence for this step.

    occurrence_idx = step_occurrence_count[step] - 1
    occurrence_size = len(step_occurrence_list)

    if occurrence_idx == occurrence_size:
      # This particular occurrence isn't yet recorded on step_occurrence_list.
      # So append this new occurrence to the end of step_occurrence_list.
      new_occurrence = collections.defaultdict(dict)
      step_occurrence_list.append(new_occurrence)
    else:
      # This particular occurrence must be already recorded on
      # step_occurrence_list (i.e. occurrence_idx < occurrence_size).
      if occurrence_idx > occurrence_size:
        raise ValueError('Unexpected: occurrence_idx (%d) > '
                         'occurrence_size (%d)' % (occurrence_idx,
                                                   occurrence_size))
    tensor_value = trace_event.summary.value[0]
    tensor_name = tensor_value.tag

    real_shape = [d.size for d in tensor_value.tensor.tensor_shape.dim]
    tensor_content = np.frombuffer(
        tensor_value.tensor.tensor_content,
        dtypes.DType(tensor_value.tensor.dtype).as_numpy_dtype()
        ).reshape(real_shape)
    step_occurrence_list[occurrence_idx][step][tensor_name] = tensor_content
  return step_occurrence_list

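
# Illustrative sketch (not part of the TensorTracer API): mirrors how
# read_tensor_tracer_event_file decodes one summary value back into a numpy
# array; `tensor_value` is assumed to be a summary_pb2.Summary.Value with a
# populated `tensor` field.
def _example_decode_tensor_event(tensor_value):
  """Minimal sketch of the tensor_content decoding step above."""
  shape = [d.size for d in tensor_value.tensor.tensor_shape.dim]
  return np.frombuffer(
      tensor_value.tensor.tensor_content,
      dtypes.DType(tensor_value.tensor.dtype).as_numpy_dtype()).reshape(shape)
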

def trace_tensor(tensor, tracepoint_name=None):
  """Programmatic interface to trace a tensor with Tensor Tracer.

  Tensor Tracer, by default, traces all tensors in the execution. This function
  can be used to limit traced tensors. If this function is called for a subset
  of the tensors, only those will be traced.

  For example, Tensor Tracer will only trace c below.
    c = tf.matmul(a, b)
    tensor_tracer.trace_tensor(c)
    d = tf.add(c, 1)
  Args:
     tensor: the tensor object for which the tracing is requested.
     tracepoint_name: an optional tensor tracepoint name string. A tracepoint
       name is a Tensor Tracer internal name for the tensor. It is useful when
       comparing equivalent traces from different models that have different
       tensor names. Equivalent tensors (with different names) can be mapped
       to each other by assigning a common tracepoint_name.

  Returns:
    The provided tensor.
  """
  if tracepoint_name is None:
    tracepoint_name = tensor.name
  tensor.graph.get_collection(_TENSOR_TRACER_COLLECTION)
  tensor.graph.add_to_collection(_TENSOR_TRACER_COLLECTION,
                                 (tensor, tracepoint_name))
  return tensor

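
# Illustrative sketch (not part of the TensorTracer API): trace_tensor only
# records a (tensor, tracepoint_name) pair in the graph's tensor tracer
# collection; the tracepoint name 'my_tracepoint' is hypothetical.
def _example_trace_tensor():
  """Minimal sketch of the collection bookkeeping done by trace_tensor."""
  with ops.Graph().as_default() as g:
    c = constant_op.constant([1.0, 2.0])
    trace_tensor(c, 'my_tracepoint')
    # The collection now holds [(c, 'my_tracepoint')].
    return g.get_collection(_TENSOR_TRACER_COLLECTION)
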

def keras_layer_tracepoint(layer, checkpoint_name):
  """An interface for adding the tensor outputs of a keras layer.

  Encapsulates trace_tensor.

  Args:
     layer: A keras layer.
     checkpoint_name: a string name for the checkpoint. This name has to be
       unique if used within model comparison. The tensors that have the same
       checkpoint identifier are compared in model comparison.

  Returns:
    The provided layer.
  """
  try:
    outputs = layer.output
    if tensor_util.is_tf_type(outputs):
      trace_tensor(outputs, '%s' % (checkpoint_name))
    else:
      for idx, output_tensor in enumerate(outputs):
        if tensor_util.is_tf_type(output_tensor):
          trace_tensor(output_tensor, '%s_%d' % (checkpoint_name, idx))
  except AttributeError:
    pass
  except RuntimeError:
    pass
  return layer


class TensorTracer:
  """A software construct for tracing tensor values in a TF graph.

  This utility is disabled by default. It is hooked into tpu.rewrite, so it
  can easily be enabled on TPUs by setting the TENSOR_TRACER_FLAGS env
  variable as below without a code change.
    export TENSOR_TRACER_FLAGS="--enable=1"

  Below is a usage example that enables it on CPUs or GPUs, or for more
  advanced use cases on TPUs.

    a = x + 1
    b = a * 2
    rs = tf.reduce_sum(b)
    tensor_tracer.set_parameters({'trace_dir': 'path/to/trace_dir',
                                  'report_file': 'path/to/report/file'})
    tt = tensor_tracer.TensorTracer()
    if on_tpu:
      rs = tt.trace_tpu(tf.get_default_graph(),
                        tensor_fetches=rs)
    else:
      rs = tt.trace_cpu(tf.get_default_graph(),
                        tensor_fetches=rs)
    session.run(rs)

  If it is enabled, it will trace the output tensor values of
  selected Ops in the graph. It has two outputs: (1) the traces and (2)
  a report. The traces are dumped to a specified directory during the graph
  execution, while the report is dumped during the graph construction.
  By passing options via the env variable, users can change:
     (1) the trace mode (e.g., detecting NaN/Inf, printing partial or
         full tensor values),
     (2) which Ops are traced (via op.name or op.type), and
     (3) the output trace file path.

  """
  # The set of graphs that are rewritten by tensor tracer.
  _traced_graphs = set()

  @staticmethod
  def is_enabled():
    """Returns True if TensorTracer is enabled."""
    try:
      enable = tensor_tracer_flags.TTParameters().is_enabled()
      # Add metrics to determine API usage.
      if enable:
        tt_gauge.get_cell('is_enabled').set(True)
      return enable
    except (ValueError, RuntimeError) as e:
      logging.warning(
          'Tensor Tracer V1 flags processing error encountered in is_enabled '
          'check. %s', e)
      # TODO(b/210212559): Find a more robust fix.
      # Should only produce exception if Tensor Tracer is enabled.
      return True

  @staticmethod
  def check_device_type(device_type):
    """Checks if the given device type is valid."""

    if device_type not in (_DEVICE_TYPE_TPU, _DEVICE_TYPE_CPU):
      raise ValueError('Invalid device_type "%s"' % device_type)

  @staticmethod
  def check_trace_mode(device_type, trace_mode):
    """Checks if the given trace mode works on the given device type.

    Args:
      device_type: Device type, TPU, GPU, CPU.
      trace_mode: Tensor tracer trace mode.
    Raises:
      ValueError: If the given trace mode is not supported for the device.
    """
    if trace_mode == tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY:
      if device_type != _DEVICE_TYPE_TPU:
        raise ValueError('Device_type "%s" is not yet supported for '
                         'trace mode "%s"' % (device_type, trace_mode))

  @staticmethod
  def loop_cond_op(op):
    return op.type in ('LoopCond', 'RefLoopCond')

  @staticmethod
  def while_loop_op(op):
    """Returns true if op is one of the special ops in a while loop.

    Args:
       op: A tf.Operation.

    Returns:
       True if the given op is one of [Switch, Merge, Enter, Exit,
       NextIteration, LoopCond], which are all building blocks for TF while
       loops.
    """
    return (control_flow_util.IsLoopSwitch(op) or
            control_flow_util.IsLoopMerge(op) or
            control_flow_util.IsLoopEnter(op) or
            control_flow_util.IsLoopExit(op) or
            TensorTracer.loop_cond_op(op) or
            op.type in ('RefNextIteration', 'NextIteration'))

  @staticmethod
  def control_flow_op(op):
    """Returns true if op is a conditional control flow op.

    Args:
       op: A tf.Operation.

    Returns:
       True if the given op is a Switch or Merge op, which are building
       blocks for TF conditionals.
    """
    return (control_flow_util.IsSwitch(op) or
            control_flow_util.IsMerge(op))

  @staticmethod
  def unsafe_op(op):
    """Returns True if this op is not safe to be traced."""

    # Reasons for not including the following op types:
    #    Assign: causes incorrect results with CPU tracing.
    if op.type == 'Assign':
      return True
    return False

  @staticmethod
  def device_mismatch(device_type, op):
    if device_type == _DEVICE_TYPE_TPU:
      # pylint: disable=protected-access
      return tpu._TPU_REPLICATE_ATTR not in op.node_def.attr
      # pylint: enable=protected-access
    return False

  @staticmethod
  def unsafe_scalar_trace(op):
    """Returns True if the scalar output of the op is not safe to be traced."""

    # Tracing the following causes a cycle in the graph on TPU.
    if op.type in ('LoopCond', 'Enter', 'Merge', 'Const',
                   'Switch', 'Less', 'ReadVariableOp'):
      return True
    # Tracing the following will cause casting issues
    # with the norm trace mode, or other compilation issues on CPU.
    if op.type in ('VarHandleOp', 'IteratorToStringHandle',
                   'IteratorGetNext', 'OneShotIterator',
                   'IteratorV2', 'MakeIterator',
                   'BatchDatasetV2', 'MapDataset',
                   'FixedLengthRecordDataset', 'TakeDataset', 'ZipDataset',
                   'Placeholder', 'PlaceholderWithDefault', 'StridedSlice'):
      return True
    return False

545
546  def _is_interesting_op(self, op):
547    """Returns True if the given op is not an interesting one to be traced."""
548    return op_priority(op.type) <= self._parameters.trace_level
549
550  @staticmethod
551  def reason(op_idx, details):
552    """Returns reason why the Op at op_idx is traced or not."""
553
554    return '%d %s'%(op_idx, details)
555
556  def __init__(self):
557    """Initializes a TensorTracer.
558
559    Sets the various member fields from the flags (if given) or the defaults.
560    """
561    self._replica_id = None
562    self._tt_config = tensor_tracer_report.TensorTracerConfig()
563    self._parameters = None
564    self._host_call_fn = {}
565    # _cache_variables is a dict (key = graph, value = dicts
566    # (key = name, value = tensors))
567    self._cache_variables = {}
568    self._traced_op_names = set()
569    self._report_proto = None
570    # _temp_cache_var is a dict (key = graph, value = [])
571    self._temp_cache_var = {}
572    self._report_proto_path = ''
573    self._outmost_context = None
574
575  def report_proto(self):
576    """Getter for tensor_tracer.proto object for summary and full_tensor_summary modes.
577
578    Returns:
579      A tensor_tracer.proto object.
580    Raises:
581      ValueError if called before tracing happens, or when trace mode is not
582      summary or full_tensor_summary.
583    """
584    if self._report_proto:
585      return self._report_proto
586    else:
587      raise ValueError('Call to report_proto must be done after tracing.'
588                       'Report proto only exists for '
589                       'trace_mode=[summary|full_tensor_summary]')
590
591  def report_proto_path(self):
592    """Getter for path where tensor_tracer.proto object should be written.
593
594    Returns:
595      A string path.
596    """
597    return self._report_proto_path
598
599  def _cache_variable_for_graph(self, graph):
600    if graph not in self._cache_variables:
601      self._cache_variables[graph] = {}
602    return self._cache_variables[graph]
603
604  def _create_or_get_tensor_values_cache(self, cache_name, graph,
605                                         shape=None, dtype=dtypes.float32):
606    """Creates a variable as the cache to store intermediate tensor values.
607
608    Args:
609      cache_name: Name to be given to the cache (an instance of tf.variable).
610      graph: Tensorflow graph.
611      shape: A list of dimensions.
612      dtype: Data type of created cache.
613    Returns:
614      A ref to newly created or existing cache with the given dimensions.
615    Raises:
616      ValueError:
617        (1) If graph is None, or
618        (2) shape is None when a new cache needs to be created.
619    """
620
621    def _escape_namescopes(variable_name):
622      # TODO(deveci): This might cause name collisions as in "foo/bar/mytensor"
623      # and "foo_bar/mytensor".
624      return variable_name.replace('/', '_').replace(':', '_')
625
626    if graph is None:
627      raise ValueError('Invalid graph.')
628
629    graph_cache_var = self._cache_variable_for_graph(graph)
630
631    if cache_name not in graph_cache_var:
632      if shape is None:
633        raise ValueError('shape must be provided at cache creation.')
634      if dtype.is_integer:
635        init_val = int(_COMPACT_TRACE_ENTRY_INIT_VALUE)
636      else:
637        init_val = _COMPACT_TRACE_ENTRY_INIT_VALUE
638
639      # Create in proper graph and base name_scope.
640      with graph.as_default() as g, g.name_scope(None):
641        graph_cache_var[cache_name] = variable_scope.get_variable(
642            _TT_SNAPSHOT + '_' + _escape_namescopes(cache_name),
643            shape=shape, dtype=dtype,
644            initializer=init_ops.constant_initializer(init_val),
645            trainable=False,
646            use_resource=True,
647            collections=[_TENSOR_TRACER_STORAGE, ops.GraphKeys.LOCAL_VARIABLES])
648    return graph_cache_var[cache_name]
649
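  # Illustrative sketch (not part of the TensorTracer API): demonstrates the
  # name collision mentioned in the TODO inside
  # _create_or_get_tensor_values_cache -- two distinct tensor names can map to
  # the same escaped cache-variable name.
  @staticmethod
  def _example_escape_namescopes_collision():
    """Minimal sketch of the escaping collision noted above."""

    def escape(name):
      return name.replace('/', '_').replace(':', '_')

    assert escape('foo/bar/mytensor') == escape('foo_bar/mytensor')
    return escape('foo/bar/mytensor')  # 'foo_bar_mytensor'
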
  def _add_replica_id_to_graph(self):
    """Adds nodes for computing the replica ID to the graph."""

    if self._tt_config.num_replicas:
      with ops.control_dependencies(None):
        # Uses None as dependency to run outside of TPU graph rewrites.
        self._replica_id = tpu_ops.tpu_replicated_input(
            list(range(self._tt_config.num_replicas)),
            name='tt_replica_id')
    else:
      self._replica_id = 'unknown'

  def _inside_op_range(self, idx):
    """Returns True if the given index is inside the selected range."""

    if idx < self._parameters.op_range[0]:
      return False
    return (self._parameters.op_range[1] < 0 or
            idx <= self._parameters.op_range[1])

  def _is_user_included_op(self, op):
    """Checks whether the op is included in the tensor tracer flags.

    Args:
      op: tf Operation
    Returns:
      True, if the op is included.
      An op is included if:
      - Its op name is given in included_opnames
      - Its op type is given in included_optypes
      - The op is at most _trace_ops_before_included hops before an included op
      - The op is at most _trace_ops_after_included hops after an included op
    """
    for opname_re in self._parameters.included_opname_re_list:
      if opname_re.match(op.name):
        return True

    for optype_re in self._parameters.included_optype_re_list:
      if optype_re.match(op.type):
        return True
    return False

  def _is_user_excluded_op(self, op):
    for opname_re in self._parameters.excluded_opname_re_list:
      if opname_re.match(op.name):
        return True
    for optype_re in self._parameters.excluded_optype_re_list:
      if optype_re.match(op.type):
        return True
    return False

  def _signature_types(self):
    """Returns a dict holding the order of signatures in the cache for this trace mode."""
    if self._parameters.trace_mode in set([
        tensor_tracer_flags.TRACE_MODE_NAN_INF,
        tensor_tracer_flags.TRACE_MODE_NORM,
        tensor_tracer_flags.TRACE_MODE_MAX_ABS]):
      return {self._parameters.trace_mode: 0}
    if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_SUMMARY:
      return self._parameters.summary_signatures
    return {}

  def _num_signature_dimensions(self):
    return len(self._signature_types())

  def _use_temp_cache(self):
    """Returns True if intermediate values are stacked instead of stored in a tf.Variable.

    Returns:
      A boolean, denoting whether to use a temporary cache or not.
    """
    # If full tensors need to be stored in tf.Variables, then do not use temp
    # variables to store them.
    if self._use_tensor_buffer():
      return False
    if self._use_tensor_values_cache():
      return self._parameters.use_temp_cache_var
    else:
      # A temporary cache only replaces a tf.Variable cache. If no cache is
      # used, return False.
      return False

  def _use_tensor_values_cache(self):
    """Returns True if intermediate tensors should be first saved to a cache."""
    return self._parameters.use_compact_trace

  def _use_tensor_buffer(self):
    """Returns True if the whole tensor needs to be cached/buffered in memory."""
    return (self._parameters.trace_mode ==
            tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY)

  def _merge_tensor_signatures(self, signatures):
    """Returns a tensor that merges the given signatures.

    Args:
      signatures: A dictionary of the signature updates from signature name to
        a tensor of dimension [1].
    Returns:
      A tensor that concatenates the signature values in a predefined order.
    Raises:
      ValueError: Unable to merge signatures.
    """
    sorted_update = []
    if self._num_signature_dimensions() > 1:
      signature_indices = self._signature_types()
      for _, val in sorted(signatures.items(),
                           key=lambda item: signature_indices[item[0]]):
        sorted_update.append(val)
      updates = array_ops.stack(
          sorted_update, axis=0, name='merge_single_op_signatures')
    elif self._num_signature_dimensions() == 1:
      # Avoid the stack operation if there is only a single signature.
      (_, val), = signatures.items()
      updates = val
    else:
      raise ValueError('Cannot merge 0 signatures. Check the value passed for '
                       'flag --signatures.')
    return updates

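  # Illustrative sketch (not part of the TensorTracer API): mirrors the
  # ordering step of _merge_tensor_signatures with plain Python values; the
  # signature names 'max' and 'min' are hypothetical.
  @staticmethod
  def _example_merge_order():
    """Minimal sketch of the signature ordering used when merging."""
    signature_indices = {'max': 0, 'min': 1}
    signatures = {'min': -3.0, 'max': 7.0}
    merged = [val for _, val in sorted(
        signatures.items(), key=lambda item: signature_indices[item[0]])]
    assert merged == [7.0, -3.0]
    return merged
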
  def _save_tensor_value_to_tmp_cache(self, cache_idx, updates, graph):
    """Saves the given updates to an entry in the temporary cache.

    Args:
      cache_idx: The cache index of the tensor within the cache.
      updates: A dictionary of the signature updates from signature name to
        a tensor of dimension [1].
      graph: A TensorFlow graph.
    Raises:
      RuntimeError:
        (1) graph is not already in self._temp_cache_var, or
        (2) cache_idx is out of range.
    """
    updates = self._merge_tensor_signatures(updates)
    updates = array_ops.reshape(updates,
                                [self._num_signature_dimensions()])
    if graph not in self._temp_cache_var:
      raise RuntimeError('graph is not in self._temp_cache_var')
    if cache_idx >= len(self._temp_cache_var[graph]):
      raise RuntimeError('cache_idx (%d) is out of range (%d)' % (
          cache_idx, len(self._temp_cache_var[graph])))
    self._temp_cache_var[graph][cache_idx] = updates

  def _save_tensor_value_to_cache_op(self, cache_idx, updates, graph):
    """Returns an op that will save the given updates to an entry in the cache.

    Args:
      cache_idx: The cache index of the tensor within the cache.
      updates: A dictionary of the signature updates.
      graph: A TensorFlow graph.
    Returns:
      Cache update operation.
    """
    # state_ops.scatter_update allows updates only along the first dimension.
    # Make a compact array by concatenating different signatures, and update
    # them all together.
    updates = self._merge_tensor_signatures(updates)
    updates = array_ops.reshape(updates,
                                [1, self._num_signature_dimensions()])
    indices = constant_op.constant([cache_idx])
    cache = self._create_or_get_tensor_values_cache(_TT_SUMMARY_TAG, graph)
    return state_ops.scatter_update(cache, indices, updates).op

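  # Illustrative sketch (not part of the TensorTracer API): the scatter_update
  # in _save_tensor_value_to_cache_op is equivalent to writing one row of a
  # 2-D cache, shown here with numpy; the dimensions are hypothetical.
  @staticmethod
  def _example_cache_row_update():
    """Minimal numpy sketch of the cache row update above."""
    num_tensors, num_signatures = 4, 2
    cache = np.full((num_tensors, num_signatures),
                    _COMPACT_TRACE_ENTRY_INIT_VALUE, dtype=np.float32)
    cache_idx = 1
    updates = np.array([0.5, 2.0], dtype=np.float32)
    cache[cache_idx, :] = updates  # scatter_update along the first dimension
    return cache
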
  def _snapshot_tensor(self, tensor):
    """Creates a new tf.Variable and an op that assigns the tensor's value to it.

    Args:
      tensor: tensor whose values will be stored in a new tf.Variable.
    Returns:
      An assignment operation.
    """

    snapshot_variable = self._create_or_get_tensor_values_cache(
        tensor.name, tensor.op.graph,
        tensor.shape.as_list(), tensor.dtype)
    return state_ops.assign(snapshot_variable, tensor).op

  def _preprocess_traced_tensor(self, tensor):
    """Computes NaN/Norm/Max on TPUs before sending to the CPU.

    Args:
      tensor: The tensor to be traced.
    Returns:
      A tensor that should be input to the trace_function.
    Raises:
      RuntimeError: If the signature is invalid.
    """

    def _detect_nan_inf(tensor):
      """Trace function for detecting any NaN/Inf in the tensor."""

      if tensor.dtype.is_floating:
        mask = math_ops.reduce_any(
            gen_math_ops.logical_or(
                gen_math_ops.is_nan(tensor), gen_math_ops.is_inf(tensor)))
        output_tensor = control_flow_ops.cond(
            mask,
            lambda: constant_op.constant([1.0]),
            lambda: constant_op.constant([0.0]))
      else:
        output_tensor = constant_op.constant([0.0])
      return output_tensor

    def _compute_signature(tensor, tf_op, cast_to_f32=True):
      if cast_to_f32:
        tensor = math_ops.cast(tensor, dtypes.float32)
      output_tensor = tf_op(tensor)
      # The return type should be a scalar. Set the shape if the information
      # is missing.
      if not output_tensor.get_shape().is_fully_defined():
        output_tensor = array_ops.reshape(output_tensor, [])
      return output_tensor

    def _show_size(tensor):
      # Returns the size of the tensor. Not all sizes are known at compile
      # time; also, different replicas sometimes get tensors of different
      # sizes. Collect it here to be used when merging replica data.
      tsize = _compute_signature(tensor, array_ops.size, cast_to_f32=False)
      # Cast to float32, so that it can be placed into the same cache with
      # other signatures.
      return math_ops.cast(tsize, dtypes.float32)

    def _show_max(tensor, cast_to_f32=True):
      # Returns -inf for an empty tensor.
      return _compute_signature(tensor, math_ops.reduce_max, cast_to_f32)

    def _show_min(tensor, cast_to_f32=True):
      # Returns inf for an empty tensor.
      return _compute_signature(tensor, math_ops.reduce_min, cast_to_f32)

    def _show_norm(tensor, cast_to_f32=True):
      # Returns 0 for an empty tensor.
      return _compute_signature(tensor, linalg_ops.norm, cast_to_f32)

    def _show_sparsity(tensor, cast_to_f32=True, tolerance=1e-06):
      # Returns nan for an empty tensor and treats nans as non-zero numbers.
      def sparsity_fn(tensor):
        non_zeros = math_ops.greater_equal(math_ops.abs(tensor), tolerance)
        nans = math_ops.is_nan(tensor)
        return nn_impl.zero_fraction(math_ops.logical_or(non_zeros, nans))

      return _compute_signature(tensor, sparsity_fn, cast_to_f32)

    def _show_mean_and_variance(tensor, cast_to_f32=True):
      """Returns the mean and variance of the given tensor."""
      if cast_to_f32:
        tensor = math_ops.cast(tensor, dtypes.float32)
      # Returns nan for an empty tensor.
      mean, var = nn_impl.moments(array_ops.reshape(tensor, [-1]), axes=[0])
      # The shape has to be 1. Set it if it does not have the information.
      if not mean.get_shape().is_fully_defined():
        mean = array_ops.reshape(mean, [])
      if not var.get_shape().is_fully_defined():
        var = array_ops.reshape(var, [])
      return mean, var

    def _show_max_abs(tensor, cast_to_f32=True):
      return _compute_signature(
          tensor, lambda t: math_ops.reduce_max(math_ops.abs(t)), cast_to_f32)

    if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_NAN_INF:
      return {self._parameters.trace_mode: _detect_nan_inf(tensor)}
    if (self._parameters.trace_mode ==
        tensor_tracer_flags.TRACE_MODE_PART_TENSOR):
      return {self._parameters.trace_mode: tensor}
    if (self._parameters.trace_mode in (
        tensor_tracer_flags.TRACE_MODE_FULL_TENSOR,
        tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY)):
      return {self._parameters.trace_mode: tensor}
    if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_NORM:
      return {self._parameters.trace_mode: array_ops.reshape(
          _show_norm(tensor), [1])}
    if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_MAX_ABS:
      return {self._parameters.trace_mode: _show_max_abs(tensor)}

    if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_SUMMARY:
      tensor = math_ops.cast(tensor, dtypes.float32)
      result_dict = {}
      # Call mean and variance computation here to avoid adding the same nodes
      # twice.
      if (_TT_SUMMARY_MEAN in self._signature_types() or
          _TT_SUMMARY_VAR in self._signature_types()):
        mean, variance = _show_mean_and_variance(tensor, cast_to_f32=False)

      for signature_name, _ in sorted(self._signature_types().items(),
                                      key=lambda x: x[1]):
        if signature_name == _TT_SUMMARY_NORM:
          signature_result_tensor = _show_norm(tensor, cast_to_f32=False)
        elif signature_name == _TT_SUMMARY_MAX:
          signature_result_tensor = _show_max(tensor, cast_to_f32=False)
        elif signature_name == _TT_SUMMARY_MAX_ABS:
          signature_result_tensor = _show_max_abs(tensor, cast_to_f32=False)
        elif signature_name == _TT_SUMMARY_MIN:
          signature_result_tensor = _show_min(tensor, cast_to_f32=False)
        elif signature_name == _TT_SUMMARY_SPARSITY:
          signature_result_tensor = _show_sparsity(tensor)
        elif signature_name == _TT_SUMMARY_SIZE:
          signature_result_tensor = _show_size(tensor)
        elif signature_name == _TT_SUMMARY_MEAN:
          signature_result_tensor = mean
        elif signature_name == _TT_SUMMARY_VAR:
          signature_result_tensor = variance
        else:
          raise ValueError('Unknown signature type: %s.' % signature_name)

        result_dict[signature_name] = signature_result_tensor
      return result_dict

    raise RuntimeError(
        'Unsupported signature for trace mode %s.'
        % self._parameters.trace_mode)

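  # Illustrative sketch (not part of the TensorTracer API): numpy reference
  # versions of a few signatures computed in _preprocess_traced_tensor, useful
  # for checking traced results offline. `values` is assumed to be a
  # non-empty float numpy array.
  @staticmethod
  def _example_signature_reference(values):
    """Minimal numpy sketch of the norm, max-abs and sparsity signatures."""
    tolerance = 1e-06
    non_zeros = np.abs(values) >= tolerance
    nans = np.isnan(values)
    return {
        'norm': float(np.linalg.norm(values)),
        'max_abs': float(np.max(np.abs(values))),
        # Counterpart of zero_fraction(non_zeros | nans): the fraction of
        # entries that are near zero, treating NaNs as non-zero.
        'sparsity': float(np.mean(~(non_zeros | nans))),
    }
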
  def _make_tensor_trace_fun(self, tensor_name, tensor_trace_order):
    """Makes the tensor tracing function called by outside compilation.

    Args:
      tensor_name: name of the tensor being traced.
      tensor_trace_order: TensorTraceOrder object holding tensorname to id map.
    Returns:
      A function to be passed as the first argument to outside compilation.

    Raises:
      RuntimeError: If the trace mode is invalid.
    """

    def _print_tensor(tensor_name, num_elements, tensor, output_tensor):
      """Prints a tensor value to a file.

      Args:
        tensor_name: name of the tensor being traced.
        num_elements: number of elements to print (-1 means print all).
        tensor: the tensor to be returned.
        output_tensor: the tensor to be printed.

      Returns:
        The same tensor passed via the "tensor" argument.

      Raises:
        ValueError: If tensor_name is not already in
                    tensor_trace_order.tensorname_to_cache_idx.
      """

      if self._parameters.is_brief_mode():
        if tensor_name not in tensor_trace_order.tensorname_to_cache_idx:
          raise ValueError(
              'Tensor %s with name %s is not in the tensorname_to_cache_idx' %
              (tensor, tensor_name))
        msg = '%d' % tensor_trace_order.tensorname_to_cache_idx[tensor_name]
      else:
        msg = '"%s"' % tensor_name

      if self._parameters.trace_dir:
        output_path = os.path.join(
            self._parameters.trace_dir,
            _TRACE_FILE_NAME + self._get_outfile_suffix())
        output_stream = _OUTPUT_STREAM_ESCAPE + output_path
      else:
        output_stream = sys.stderr
      return logging_ops.print_v2(msg, array_ops.shape(output_tensor),
                                  '@', self._replica_id,
                                  '\n', output_tensor, '\n',
                                  summarize=num_elements,
                                  output_stream=output_stream)

    def _show_part_tensor(tensor):
      """Trace function for printing part of the tensor."""

      return _print_tensor(tensor_name, _TRACE_MODE_PART_TENSOR_SIZE,
                           tensor, tensor)

    def _show_full_tensor(tensor):
      """Trace function for printing the entire tensor."""

      return _print_tensor(tensor_name, -1, tensor, tensor)

    if (self._parameters.trace_mode ==
        tensor_tracer_flags.TRACE_MODE_PART_TENSOR):
      return _show_part_tensor
    # The input tensor has a shape of "[1]" for TRACE_MODE_NAN_INF,
    # TRACE_MODE_NORM, and TRACE_MODE_MAX_ABS, as the related computations are
    # performed within TPUs and only their results are transferred to CPU.
    # Simply print the full tensor for these trace modes.
    if self._parameters.trace_mode in (
        tensor_tracer_flags.TRACE_MODE_NAN_INF,
        tensor_tracer_flags.TRACE_MODE_NORM,
        tensor_tracer_flags.TRACE_MODE_FULL_TENSOR,
        tensor_tracer_flags.TRACE_MODE_MAX_ABS,
        tensor_tracer_flags.TRACE_MODE_SUMMARY
        ):
      return _show_full_tensor

    raise RuntimeError(
        'Full tensor support is not available with trace mode %s' %
        self._parameters.trace_mode)

  def _is_in_control_flow(self, op):
    """Returns True if the given op is inside a tf.cond branch.

    Args:
      op: A tensorflow op that should be checked whether in control flow or not.
    Returns:
      A boolean value whether the op is in control flow or not.
    """
    return control_flow_util.IsInCond(op)

  def _is_in_outmost_while_loop(self, op):
    """Returns True if the op is at the same level as the training loop.

    Returns False if the op is in an inner while loop or if it is outside of
    the training loop.
    Args:
      op: tf.Operation

    Returns:
      A boolean.
    """
    ctxt = self._get_op_control_flow_context(op)
    outer_while_context = control_flow_util.GetContainingWhileContext(ctxt)
    return outer_while_context == control_flow_util.GetContainingWhileContext(
        self._outmost_context)

  def _should_trace_in_control_flow(self):
    """Returns False when it is not safe to trace ops in tf.cond or tf.while_loop."""
    # Unlike the other trace modes, TRACE_MODE_OPTIONAL_SUMMARY
    # forces the execution of the traced tensors. We should not trace the ops
    # that may not be executed due to control flow.
    if self._use_temp_cache():
      return False
    elif self._tt_config.device_type == _DEVICE_TYPE_TPU:
      # On TPUs, do not trace in control flow unless we use caches to store
      # intermediate values, as calling outside compilation within an inner
      # loop causes errors.
      return self._use_tensor_values_cache() or self._use_tensor_buffer()
    return True

  def _skip_op(self, op_id, op, ops_in_exec_path, report_handler):
    """Returns True if we should not trace Op.

    Args:
      op_id: Topological index of the op.
      op: tf.Operation
      ops_in_exec_path: Set of operations that are in the execution path.
      report_handler: An instance of tensor_tracer_report.TTReportHandle.
    Returns:
      True if the op should not be traced, False otherwise.
    """
    if TensorTracer.while_loop_op(op):
      report_handler.instrument_op(
          op, TensorTracer.reason(op_id, _REASON_WHILELOOP_OP))
      return True
    if TensorTracer.control_flow_op(op):
      report_handler.instrument_op(
          op, TensorTracer.reason(op_id, _REASON_CONTROLFLOW_OP))
      return True
    if TensorTracer.unsafe_op(op):
      report_handler.instrument_op(
          op, TensorTracer.reason(op_id, _REASON_UNSAFE_OP))
      return True
    if TensorTracer.device_mismatch(self._tt_config.device_type, op):
      report_handler.instrument_op(
          op, TensorTracer.reason(op_id, _REASON_DEVICE_MISMATCH))
      return True
    if op not in ops_in_exec_path:
      report_handler.instrument_op(
          op, TensorTracer.reason(op_id, _REASON_NOT_EXECUTED))
      return True
    # TensorTracer will not trace the operations that are in an inner while
    # loop or tf.cond when a temporary cache is used. A temporary cache adds
    # direct data dependencies to traced operations, and needs a static number
    # of traced operations. For these cases:
    # - We do not know the number of slots required when there are inner while
    #   loops. TensorTracer can only trace the result of a while loop.
    # - We do not know ahead of time which branch of the tf.cond will be
    #   taken, so we avoid introducing data dependencies for the operations
    #   inside a tf.cond.
    # - We also cannot have a data dependency to an operation in a different
    #   while context.
    if self._is_in_control_flow(op) or not self._is_in_outmost_while_loop(op):
      if not self._should_trace_in_control_flow():
        report_handler.instrument_op(
            op, TensorTracer.reason(op_id, _REASON_IN_CONTROL_FLOW))
        return True
    if self._is_user_included_op(op):
      report_handler.instrument_op(
          op, TensorTracer.reason(op_id, _REASON_USER_INCLUDED))
      return False

    if not self._inside_op_range(op_id):
      report_handler.instrument_op(
          op, TensorTracer.reason(op_id, _REASON_OUTSIDE_OP_RANGE))
      return True
    if not self._is_interesting_op(op):
      report_handler.instrument_op(
          op, TensorTracer.reason(op_id, _REASON_LESS_INTERESTING_OP))
      return True
    if self._is_user_excluded_op(op):
      report_handler.instrument_op(
          op, TensorTracer.reason(op_id, _REASON_USER_EXCLUDED))
      return True
    return False

1149
1150  def _skip_tensor(self, op_id, out_tensor, report_handler):
1151    """Returns True if we should not trace out_tensor.
1152
1153    Args:
1154      op_id: Topological index of the op producing tensor.
1155      out_tensor: tf.Tensor
1156      report_handler: An instance of tensor_tracer_report.TTReportHandle.
1157    Returns:
1158      True if the tensor should not be traced, false otherwise.
1159    """
1160
1161    # Skips a tensor if the tensor has a non-numeric type.
1162    #   Note: we cannot use check_ops.is_numeric_tensor(out_tensor)
1163    #         because it also excludes tensors with dtypes, bool, and
1164    #         float32_ref, which we actually want to trace.
1165    non_numeric_tensor_types = set([dtypes.variant, dtypes.resource,
1166                                    dtypes.string])
1167    if out_tensor.dtype in non_numeric_tensor_types:
1168
1169      report_handler.instrument_tensor(
1170          out_tensor, TensorTracer.reason(op_id, _REASON_NON_NUMERIC_TENSOR))
1171      return True
1172    # Skip a tensor if it feeds a special while loop op.
1173    if [consumer for consumer in out_tensor.consumers() if
1174        TensorTracer.while_loop_op(consumer)]:
1175      report_handler.instrument_tensor(
1176          out_tensor, TensorTracer.reason(op_id, _REASON_FEEDS_WHILELOOP_OP))
1177      return True
1178    if self._is_user_included_op(out_tensor.op):
1179      report_handler.instrument_tensor(
1180          out_tensor, TensorTracer.reason(op_id, _REASON_USER_INCLUDED))
1181      return False
1182    if self._is_user_excluded_op(out_tensor.op):
1183      report_handler.instrument_tensor(
1184          out_tensor, TensorTracer.reason(op_id, _REASON_USER_EXCLUDED))
1185      return True
1186    if not out_tensor.get_shape().is_fully_defined():
1187      # If trace mode is nan-inf, norm or max, then the tensor will be reduced
1188      # to a scalar before the outside compilation call.
1189      if self._parameters.trace_mode in (
1190          tensor_tracer_flags.TRACE_MODE_NAN_INF,
1191          tensor_tracer_flags.TRACE_MODE_NORM,
1192          tensor_tracer_flags.TRACE_MODE_MAX_ABS,
1193          tensor_tracer_flags.TRACE_MODE_SUMMARY
1194          ):
1195        report_handler.instrument_tensor(
1196            out_tensor, TensorTracer.reason(op_id, _REASON_TENSOR_GET_TRACED))
1197        return False
1198      else:
1199        report_handler.instrument_tensor(
1200            out_tensor, TensorTracer.reason(op_id, _REASON_DYNAMIC_SHAPE))
1201        return True
1202    rank = len(out_tensor.shape)
1203    if rank < 1:
1204      # scalar
1205      if self._parameters.trace_scalar_ops:
1206        if TensorTracer.unsafe_scalar_trace(out_tensor.op):
1207          report_handler.instrument_tensor(
1208              out_tensor, TensorTracer.reason(op_id, _REASON_UNSAFE_SCALAR))
1209          return True
1210        else:
1211          report_handler.instrument_tensor(
1212              out_tensor, TensorTracer.reason(op_id, _REASON_SCALAR_GET_TRACED))
1213          return False
1214      else:
1215        report_handler.instrument_tensor(
1216            out_tensor, TensorTracer.reason(op_id, _REASON_SKIP_SCALAR))
1217        return True
1218    else:
1219      # tensor
1220      report_handler.instrument_tensor(
1221          out_tensor, TensorTracer.reason(op_id, _REASON_TENSOR_GET_TRACED))
1222      return False
1223
1224  def _filter_execution_path_operations(self, operations, fetches):
1225    """Returns the set of ops in the execution path to compute given fetches."""
1226
1227    # If no fetch provided, then return all operations.
1228    if fetches is None:
1229      return set(operations)
1230    # Convert to list, if a single element is provided.
1231    if not isinstance(fetches, (list, tuple)):
1232      fetches = [fetches]
1233    # If a tensor is given as fetch, convert it to op.
1234    op_fetches = []
1235    for fetch in fetches:
1236      if isinstance(fetch, ops.Operation):
1237        op_fetches.append(fetch)
1238      elif isinstance(fetch, ops.Tensor):
1239        op_fetches.append(fetch.op)
1240      else:
1241        raise RuntimeError('Given fetch:%s is neither a tensor nor an op.'
1242                           %fetch)
1243
1244    execution_path_operations = set(op_fetches)
1245    traverse_stack = list(op_fetches)
1246    while True:
1247      if not traverse_stack:
1248        break
1249      head_op = traverse_stack.pop()
1250      input_ops = [tensor_input.op for tensor_input in head_op.inputs]
1251      input_ops.extend(head_op.control_inputs)
1252
1253      for input_op in input_ops:
1254        if input_op not in execution_path_operations:
1255          # Filter out loop condition operations, tracing them causes a cycle.
1256          # Trace only the loop-body.
1257          if TensorTracer.loop_cond_op(input_op):
1258            continue
1259          execution_path_operations.add(input_op)
1260          traverse_stack.append(input_op)
1261    return execution_path_operations
1262
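  # Illustrative sketch (not part of the TensorTracer API): the traversal in
  # _filter_execution_path_operations is a reverse walk from the fetches over
  # data and control inputs; here is the same walk over a hypothetical
  # adjacency map from op name to input op names.
  @staticmethod
  def _example_execution_path_walk():
    """Minimal sketch of the reverse reachability walk above."""
    inputs = {'loss': ['matmul'], 'matmul': ['x', 'w'], 'x': [], 'w': [],
              'unused': []}
    visited = {'loss'}
    stack = ['loss']
    while stack:
      head = stack.pop()
      for inp in inputs[head]:
        if inp not in visited:
          visited.add(inp)
          stack.append(inp)
    assert visited == {'loss', 'matmul', 'x', 'w'}
    return visited
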
1263  def _determine_and_instrument_traced_tensors(self, graph_order,
1264                                               ops_in_exec_path,
1265                                               tensor_trace_points,
1266                                               report_handler):
1267    """Determines the tensors to trace and instruments the trace details.
1268
1269    Args:
1270      graph_order: graph_order tuple containing graph (tf.graph), operations
1271        (list of operations), op_to_idx (op id mapping), (tensors) list of
1272        tensors, tensor_to_idx (tensor id mapping), contains_cycle (whether
1273        there is a cycle in the graph), topological_order_or_cycle (list of ops
1274        in topological order or list of ops creating a cycle).
1275      ops_in_exec_path: Set of ops in the execution path.
1276      tensor_trace_points: Collection of programatic tensor trace points.
1277      report_handler: An instance of tensor_tracer_report.TTReportHandle.
1278    Returns:
1279      List of tensors to be traced.
1280    """
1281
1282    traced_tensors = []
1283    checkpoint_operations = set([tensor.op
1284                                 for (tensor, _) in tensor_trace_points])
1285    for op_id, op in enumerate(graph_order.operations):
1286      if checkpoint_operations and op not in checkpoint_operations:
1287        continue
1288      if self._skip_op(op_id, op, ops_in_exec_path, report_handler):
1289        continue
1290      for out_tensor in op.outputs:
1292        if not self._skip_tensor(op_id, out_tensor, report_handler):
1293          traced_tensors.append(out_tensor)
1294    return traced_tensors
1295
1296  def _check_trace_files(self):
1297    """Checks if any requirements for trace files are satisfied."""
1298
1299    if not self._parameters.trace_dir:
1300      # Traces will be written to stderr. No need to check trace files.
1301      return
1302    if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_SUMMARY:
1303      # Output files are handled by tf.summary operations; no need to
1304      # pre-create them.
1305      return
1306    if not gfile.Exists(self._parameters.trace_dir):
1307      file_io.recursive_create_dir(self._parameters.trace_dir)
1308      if not gfile.Exists(self._parameters.trace_dir):
1309        raise RuntimeError('Failed to create trace directory at %s' %
1310                           self._parameters.trace_dir)
1311
1312  def _create_temp_cache(self, num_traced_tensors, num_signatures, graph):
1313    """Creates a temporary cache with the given dimensions.
1314
1315    Fills self._temp_cache_var with num_traced_tensors tf.constant() ops,
1316    each with shape [num_signatures].
1317    Args:
1318      num_traced_tensors: Int, denoting total number of traced tensors.
1319      num_signatures: Int, denoting the number of statistics collected per
1320        tensor.
1321      graph: TensorFlow graph.
1322    """
1323    init_value = constant_op.constant(_COMPACT_TRACE_ENTRY_INIT_VALUE,
1324                                      dtype=dtypes.float32,
1325                                      shape=[num_signatures])
1326    self._temp_cache_var[graph] = [
1327        init_value for _ in range(num_traced_tensors)]
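
  # For illustration, with num_traced_tensors=2 and num_signatures=3 the
  # temporary cache holds two identical constants, each equivalent to
  # (assuming _COMPACT_TRACE_ENTRY_INIT_VALUE is -1.0):
  #
  #   constant_op.constant([-1.0, -1.0, -1.0], dtype=dtypes.float32)
  #
  # The list entries are later overwritten in place by
  # _save_tensor_value_to_tmp_cache as tracing ops are generated.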
1328
1329  def _determine_trace_and_create_report(self, graph, ops_in_exec_path,
1330                                         graph_summary_tag):
1331    """Work needs to be done prior to TPU or CPU tracing.
1332
1333    Args:
1334      graph: tf.graph
1335      ops_in_exec_path: Set of operations in the execution path.
1336      graph_summary_tag: the summary tag name for the given graph.
1337    Returns:
1338      An instance of tensor_tracer_report.TensorTraceOrder, containing the
1339      list of tensors to be traced with their topological order information.
1340    """
1341
1342    self._check_trace_files()
1343
1344    graph_order = tensor_tracer_report.sort_tensors_and_ops(graph)
1345    tensor_trace_points = graph.get_collection(_TENSOR_TRACER_COLLECTION)
1346
1347    report_handler = tensor_tracer_report.TTReportHandle()
1348    traced_tensors = self._determine_and_instrument_traced_tensors(
1349        graph_order, ops_in_exec_path, tensor_trace_points, report_handler)
1350    logging.info('TensorTracer is tracing %d tensors.', len(traced_tensors))
1351
1352    tensor_trace_order = tensor_tracer_report.TensorTraceOrder(graph_order,
1353                                                               traced_tensors)
1354    num_signatures = self._num_signature_dimensions()
1355    # Create a cache variable if compact_tracing is used.
1356    if num_signatures and self._use_tensor_values_cache():
1357      if self._use_temp_cache():
1358        self._create_temp_cache(len(traced_tensors), num_signatures, graph)
1359      else:
1360        self._create_or_get_tensor_values_cache(_TT_SUMMARY_TAG,
1361                                                graph,
1362                                                [len(traced_tensors),
1363                                                 num_signatures])
1364    if self._parameters.trace_mode in (
1365        tensor_tracer_flags.TRACE_MODE_SUMMARY,
1366        tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY):
1367      self._report_proto = report_handler.create_report_proto(
1368          self._tt_config, self._parameters, tensor_trace_order,
1369          tensor_trace_points, self._signature_types())
1370      if self._parameters.use_fingerprint_subdir:
1371        self._parameters.trace_dir = os.path.join(
1372            self._parameters.trace_dir, self._report_proto.fingerprint)
1373        logging.info('TensorTracer updating trace_dir to %s',
1374                     self._parameters.trace_dir)
1375      self._report_proto_path = report_handler.report_proto_path(
1376          self._parameters.trace_dir, graph_summary_tag)
1377
1378      if self._parameters.report_file_path != _SKIP_REPORT_FILE:
1379        report_handler.write_report_proto(self._report_proto_path,
1380                                          self._report_proto, self._parameters)
1381    else:
1382      report_handler.create_report(self._tt_config, self._parameters,
1383                                   tensor_trace_order, tensor_trace_points)
1384    return tensor_trace_order
1385
1386  def _create_host_call(self):
1387    return self._parameters.trace_mode in (
1388        tensor_tracer_flags.TRACE_MODE_SUMMARY,
1389        tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY)
1390
1391  def _inspect_summary_cache(self, cache, replica_id, step_num, output_stream,
1392                             tensor_trace_order):
1393    """Generates a print operation to print trace inspection.
1394
1395    Args:
1396      cache: Tensor storing the trace results for the step.
1397      replica_id: Tensor storing the replica id of the running core.
1398      step_num: Step number.
1399      output_stream: Where to print the outputs, e.g., a file path or sys.stderr.
1400      tensor_trace_order: TensorTraceOrder object holding tensorname to id map.
1401
1402    Returns:
1403      An op that prints the trace inspection output.
1404    """
1405    def _inspect_tensor(tensor):
1406      """Returns the text to be printed for inspection output."""
1407      if (self._parameters.trace_mode ==
1408          tensor_tracer_flags.TRACE_MODE_NAN_INF):
1409        return control_flow_ops.cond(
1410            math_ops.greater(tensor, 0.0),
1411            lambda: 'has NaNs/Infs!',
1412            lambda: 'has no NaNs or Infs.')
1413      else:
1414        return tensor
1415
1416    # Check if there are graph operations being profiled.
1417    if not tensor_trace_order.traced_tensors:
1418      logging.warn('Inspect mode has no tensors in the cache to check.')
1419      return control_flow_ops.no_op()
1420
1421    # Check if the cache includes any NaN or Inf.
1422    if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_NAN_INF:
1423      # Cache has 1s or 0s if the mode is NaN_INF
1424      step_has_nan_or_inf = math_ops.greater(math_ops.reduce_sum(cache), 0.0)
1425    else:
1426      # Cache has the actual numerics for other modes.
1427      step_has_nan_or_inf = math_ops.reduce_any(
1428          gen_math_ops.logical_or(
1429              gen_math_ops.is_nan(cache), gen_math_ops.is_inf(cache)))
1430
1431    # Summarizing message for each step.
1432    step_error_message = control_flow_ops.cond(
1433        step_has_nan_or_inf,
1434        lambda: 'NaNs or Infs in the step!',
1435        lambda: 'No numerical issues have been found for the step.')
1436
1437    # No need to print core numbers if the cache is merged already.
1438    if self._parameters.collect_summary_per_core:
1439      stats = ['\n\n', 'core:', replica_id, ',', 'step:', step_num, '-->',
1440               step_error_message,
1441               'Printing tensors for mode:%s...' % self._parameters.trace_mode]
1442    else:
1443      stats = ['\n\n', 'step:', step_num, '-->', step_error_message,
1444               'Printing tensors for mode:%s...' % self._parameters.trace_mode]
1445
1446    for tensor_name, cache_idx in sorted(
1447        tensor_trace_order.tensorname_to_cache_idx.items(),
1448        key=lambda item: item[1]):
1449      if self._parameters.collect_summary_per_core:
1450        stats.extend([
1451            '\n', 'core:', replica_id, ',', 'step:', step_num, ',',
1452            tensor_name, '-->', _inspect_tensor(cache[cache_idx, 0])])
1453      else:
1454        stats.extend([
1455            '\n', 'step:', step_num, ',',
1456            tensor_name, '-->', _inspect_tensor(cache[cache_idx, 0])])
1457    return logging_ops.print_v2(*stats, summarize=-1,
1458                                output_stream=output_stream)
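
  # For reference, the generated print output looks roughly like the
  # following in nan-inf mode (illustrative step and tensor names):
  #
  #   step: 42 --> No numerical issues have been found for the step.
  #   Printing tensors for mode:nan-inf...
  #   step: 42 , dense/MatMul:0 --> has no NaNs or Infs.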
1459
1460  def _get_outfile_suffix(self):
1461    if remote_utils.is_remote_path(self._parameters.trace_dir):
1462      return remote_utils.get_appendable_file_encoding()
1463    else:
1464      return ''
1465
1466  def _generate_flush_cache_op(self, num_replicas, on_tpu,
1467                               tensor_trace_order, graph):
1468    """Generates an Op that will flush the cache to file.
1469
1470    Args:
1471      num_replicas: total number of replicas.
1472      on_tpu: if the graph is executed on TPU.
1473      tensor_trace_order: TensorTraceOrder object holding tensorname to id map.
1474      graph: TensorFlow graph.
1475
1476    Returns:
1477      The Op to flush the cache to file.
1478    """
1479
1480    def _flush_fun(cache, replica_id, step_num):
1481      """Flushes the cache to a file corresponding to replica_id."""
1482
1483      def _f(file_index):
1484        """Generates a func that flushes the cache to a file."""
1485        def _print_cache():
1486          """Flushes the cache to a file."""
1487          replica_str = ('%d' % file_index)
1488          if self._parameters.trace_dir:
1489            output_path = (os.path.join(self._parameters.trace_dir,
1490                                        _COMPACT_TRACE_FILE_PREFIX)
1491                           + replica_str + self._get_outfile_suffix())
1492            output_stream = _OUTPUT_STREAM_ESCAPE + output_path
1493          else:
1494            output_stream = sys.stderr
1495
1496          new_step_line = _REPLICA_ID_TAG + replica_str
1497          print_ops = []
1498          if self._parameters.inspect_trace:
1499            if self._num_signature_dimensions() > 1:
1500              raise ValueError('Inspecting multiple signatures is not supported.')
1501            print_ops.append(self._inspect_summary_cache(
1502                cache=cache, replica_id=replica_id, step_num=step_num,
1503                output_stream=output_stream,
1504                tensor_trace_order=tensor_trace_order))
1505          else:
1506            for i in range(self._num_signature_dimensions()):
1507              print_ops.append(logging_ops.print_v2(
1508                  new_step_line, '\n',
1509                  cache[:, i], '\n',
1510                  summarize=-1,
1511                  output_stream=output_stream))
1512          with ops.control_dependencies(print_ops):
1513            return constant_op.constant(0).op
1514        return _print_cache
1515
1516      def _eq(file_index):
1517        return math_ops.equal(replica_id, file_index)
1518
1519      flush_op_cases = {}
1520      flush_op_cases[_eq(0)] = _f(0)
1521      for i in range(1, num_replicas):
1522        if on_tpu and not self._parameters.collect_summary_per_core:
1523          # If this is the case, the cache is already merged for all cores.
1524          # Only first core flushes the cache.
1525          flush_op_cases[_eq(i)] = control_flow_ops.no_op
1526        else:
1527          flush_op_cases[_eq(i)] = _f(i)
1528      # Each replica needs to determine where to write their output.
1529      # To do this, we check if replica_id is 0, then 1, ..., and then
1530      # num_replicas - 1 statically; and return the corresponding static file
1531      # name. We cannot simply set the file name in Python, as replica_id is
1532      # only known at TF runtime, and we cannot create dynamic filenames.
1533      return control_flow_ops.case(flush_op_cases, exclusive=True)
1534
1535    cache = self._create_or_get_tensor_values_cache(_TT_SUMMARY_TAG, graph)
1536    if self._use_temp_cache():
1537      cache_val = cache
1538    else:
1539      cache_val = cache.value()
1540
1541    if on_tpu:
1542      # If we do not need to collect traces for all cores, merge and aggregate
1543      # per core trace.
1544      if not self._parameters.collect_summary_per_core:
1545        cache_val = self.merge_caches_on_tpu(cache_val)
1546        cache_val = self.aggregate_global_cache(cache_val)[0]
1547
1548      flush_op = tpu.outside_compilation(
1549          _flush_fun, cache_val, self._replica_id,
1550          array_ops.identity(training_util.get_or_create_global_step()))
1551    else:
1552      global_step = training_util.get_or_create_global_step()
1553      flush_op = _flush_fun(cache_val, self._replica_id, global_step)
1554
1555    if self._use_temp_cache():
1556      with ops.control_dependencies([flush_op]):
1557        return constant_op.constant(0).op
1558    else:
1559      # Re-initialize the local cache variable.
1560      with ops.control_dependencies([flush_op]):
1561        reset_value = constant_op.constant(_COMPACT_TRACE_ENTRY_INIT_VALUE,
1562                                           dtype=cache.dtype,
1563                                           shape=cache.shape)
1564        assign_op = state_ops.assign(cache, reset_value).op
1565        with ops.control_dependencies([assign_op]):
1566          return constant_op.constant(0).op
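
  # The per-replica dispatch above reduces to the following pattern; since
  # replica_id is a runtime tensor, each candidate file name must be baked
  # into its own statically-constructed branch (sketch):
  #
  #   cases = {math_ops.equal(replica_id, i): _f(i)
  #            for i in range(num_replicas)}
  #   flush_op = control_flow_ops.case(cases, exclusive=True)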
1567
1568  def _flush_tensor_values_cache(self, tensor_fetches, op_fetches, on_tpu,
1569                                 tensor_trace_order, graph):
1570    """Flushes the intermediate tensor values in the graph to the cache.
1571
1572    Args:
1573      tensor_fetches: list of tensor results returned by the model_fn.
1574      op_fetches: list of ops that are returned by the model_fn, e.g., train_op.
1575      on_tpu: if the graph is executed on TPU.
1576      tensor_trace_order: TensorTraceOrder object holding tensorname to id map.
1577      graph: TensorFlow graph.
1578
1579    Returns:
1580      An identical copy of tensor_fetches.
1581    """
1582    # Add a dependency to op and tensor fetches to make sure that all tracing
1583    # ops are executed before flushing trace results.
1584    if not tensor_trace_order.traced_tensors:
1585      logging.warn('No tensor values being traced. No flush cache op added.')
1586      return tensor_fetches
1587    with ops.control_dependencies(op_fetches +
1588                                  [tensor.op for tensor in tensor_fetches]):
1589      flush_cache_op = self._generate_flush_cache_op(
1590          self._tt_config.num_replicas, on_tpu, tensor_trace_order, graph)
1591      return control_flow_ops.tuple(tensor_fetches,
1592                                    control_inputs=[flush_cache_op])
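
  # control_flow_ops.tuple with control_inputs returns identity copies of
  # the fetches that cannot run before the flush op completes, e.g. (sketch
  # with a hypothetical `loss` fetch):
  #
  #   [loss_with_dep] = control_flow_ops.tuple(
  #       [loss], control_inputs=[flush_cache_op])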
1593
1594  def _process_tensor_fetches(self, tensor_fetches):
1595    """Check that tensor_fetches is not empty and have valid tensors."""
1596    # If none or empty list.
1597    if tensor_fetches is None:
1598      raise RuntimeError('tensor_fetches provided to tensor_tracer cannot be '
1599                         'None.')
1600    if not isinstance(tensor_fetches, (list, tuple)):
1601      tensor_fetches = [tensor_fetches]
1602    elif not tensor_fetches:
1603      raise RuntimeError('tensor_fetches provided to tensor_tracer cannot be '
1604                         'an empty list.')
1605    fetches = []
1606    for fetch in tensor_fetches:
1607      if isinstance(fetch, ops.Tensor):
1608        fetches.append(fetch)
1609      else:
1610        raise RuntimeError('Given tensor_fetch:%s is not a tensor.' % fetch)
1611    return fetches
1612
1613  def _process_op_fetches(self, op_fetches):
1614    """Check that op_fetches have valid ops."""
1615    if op_fetches is None:
1616      return []
1617
1618    if not isinstance(op_fetches, (list, tuple)):
1619      op_fetches = [op_fetches]
1620
1621    fetches = []
1622    for fetch in op_fetches:
1623      if isinstance(fetch, ops.Operation):
1624        fetches.append(fetch)
1625      elif isinstance(fetch, ops.Tensor):
1626        fetches.append(fetch.op)
1627      else:
1628        logging.warning('Ignoring the given op_fetch:%s, which is not an op.' %
1629                        fetch)
1630    return fetches
1631
1632  def _convert_fetches_to_input_format(self, input_fetches, current_fetches):
1633    """Changes current_fetches' format, so that it matches input_fetches."""
1634    if isinstance(input_fetches, ops.Tensor):
1635      if len(current_fetches) != 1:
1636        raise RuntimeError('Tensor tracer input/output fetches do not match.')
1637      return current_fetches[0]
1638    else:
1639      if len(input_fetches) != len(current_fetches):
1640        raise RuntimeError('Tensor tracer input/output fetches do not match.')
1641      elif isinstance(input_fetches, tuple):
1642        return tuple(current_fetches)
1643      else:
1644        return current_fetches
1645
1646  def _get_op_control_flow_context(self, op):
1647    """Returns the control flow of the given op.
1648
1649    Args:
1650      op: tf.Operation for which the control flow context is requested.
1651    Returns:
1652      op_control_flow_context: the control flow context of the given op. If
1653      the operation type is LoopExit, returns the outer control flow
1654      context.
1655    """
1656    # pylint: disable=protected-access
1657    op_control_flow_context = op._control_flow_context
1658    # pylint: enable=protected-access
1659    if control_flow_util.IsLoopExit(op):
1660      op_control_flow_context = op_control_flow_context.outer_context
1661    return op_control_flow_context
1662
1663  def merge_caches_on_tpu(self, local_tpu_cache_tensor):
1664    """Merges the given caches on tpu.
1665
1666    Args:
1667      local_tpu_cache_tensor: A local tensor that needs to be merged
1668        by concatenating data from the other TPU cores.
1669    Returns:
1670      A merged tf.Tensor.
1671    """
1672    x = array_ops.broadcast_to(
1673        local_tpu_cache_tensor,
1674        shape=[self._tt_config.num_replicas] +
1675        local_tpu_cache_tensor.shape.as_list())
1676    return tpu_ops.all_to_all(
1677        x, concat_dimension=0, split_dimension=0,
1678        split_count=self._tt_config.num_replicas,
1679        group_assignment=[list(range(self._tt_config.num_replicas))])
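
  # Shape-wise (sketch): a local cache of shape [T, S] is first broadcast to
  # [num_replicas, T, S], i.e., num_replicas copies of the local cache. The
  # all_to_all exchange along dimension 0 then leaves every core holding a
  # [num_replicas, T, S] tensor whose row i is the cache computed on core i.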
1680
1681  def aggregate_global_cache(self, global_tt_summary_cache):
1682    """Merges the given caches on tpu.
1683
1684    Args:
1685      global_tt_summary_cache: The global tensor tracer summary cache tensor
1686        with shape (num_cores, num_traced_tensors, num_traced_signatures). The
1687        first dimension corresponds to core_id, where global_tt_summary_cache[i]
1688        corresponds to the local cache from core i.
1689    Returns:
1690      An aggregated tf.Tensor.
1691    Raises:
1692      RuntimeError: if there is no aggregate function defined for a signature.
1693    """
1694
1695    # Merge only the statistics tensors; any other tensor is simply
1696    # concatenated.
1697    agg_fn_map = self._parameters.get_signature_to_agg_fn_map()
1698    signature_idx_map = self._signature_types()
1699    aggregation_result = []
1700    for signature, idx in sorted(signature_idx_map.items(),
1701                                 key=operator.itemgetter(1)):
1702      if signature not in agg_fn_map:
1703        raise RuntimeError('No aggregation function is defined for '
1704                           'signature %s.' % signature)
1705      # The dimensions of the statistics tensor are
1706      # num_cores x num_traced_tensors x num_signatures.
1707      # global_tt_summary_cache[:, :, idx] returns the portion of the tensor
1708      # related to the signature.
1709      signature_tensor = global_tt_summary_cache[:, :, idx]
1710      # Merge it along the first (core) axis.
1711      agg_fn = agg_fn_map[signature]
1712      agg_tensor = agg_fn(signature_tensor, axis=0)
1713      aggregation_result.append(agg_tensor)
1714    # Merge results corresponding to different signatures.
1716    merged_signatures = array_ops.stack(aggregation_result)
1717    # merged_signatures has dimensions
1718    # num_signatures x num_traced_tensors, transpose it so that it
1719    # will match with the original structure
1720    # num_traced_tensors x num_signatures.
1721    transposed_signatures = array_ops.transpose(merged_signatures)
1722    # Expand 1 more dimension so that it will match with the expected
1723    # structure num_cores x num_traced_tensors x num_signatures.
1724    return array_ops.expand_dims(transposed_signatures, axis=0)
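
  # A numpy sketch of the aggregation and reshaping above, assuming two
  # signatures whose aggregation functions happen to be np.max and np.min
  # (hypothetical choices; the real map comes from
  # self._parameters.get_signature_to_agg_fn_map()):
  #
  #   cache = np.random.rand(num_cores, num_tensors, 2)
  #   agg = np.stack([np.max(cache[:, :, 0], axis=0),
  #                   np.min(cache[:, :, 1], axis=0)])  # [2, num_tensors]
  #   agg = np.expand_dims(agg.T, axis=0)               # [1, num_tensors, 2]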
1725
1726  def _prepare_host_call_fn(self, processed_t_fetches,
1727                            op_fetches, graph, graph_summary_tag):
1728    """Creates a host call function that will write the cache as tb summary.
1729
1730    Args:
1731      processed_t_fetches: List of tensor provided to session.run.
1732      op_fetches: List of operations provided to session.run.
1733      graph: TensorFlow graph.
1734      graph_summary_tag: the summary_tag name for the given graph.
1735    Raises:
1736      ValueError: If trace_dir is not set.
1737    """
1738    if self._parameters.trace_dir is None:
1739      raise ValueError('Provide a trace_dir for tensor tracer in summary mode. '
1740                       '--trace_dir=/model/dir')
1741
1742    def _write_cache(step, event_file_suffix=None, **kwargs):
1743      """Writes the given caches as tensor summary.
1744
1745      Args:
1746        step: Step tensor with dimension [num_cores].
1747        event_file_suffix: Event filename suffix tensor.
1748        **kwargs: The dictionary of tensors that need to be written as
1749          summaries. Key and value pairs within kwargs correspond to the tag
1750          name, and tensor content that will be written using summary.write.
1751          The trace_modes that use this function are:
1752            - summary: In summary mode, kwargs includes a single (tag, content)
1753            pair which are, _TT_SUMMARY_TAG and a tf.float32 signature_cache
1754            variable. The dimension of the signature_cache is:
1755              num_cores x num_traced_tensors x num_signatures.
1756            - full_tensor_summary: kwargs will include all traced tensors. Tag
1757            and content correspond to the name of the tensor, and its actual
1758            content.
1759      Returns:
1760        A tf.Operation that needs to be executed for the host call dependencies.
1761      """
1762      file_suffix = _TT_EVENT_FILE_SUFFIX
1763      if event_file_suffix is not None:
1764        file_suffix = string_ops.string_join([file_suffix, event_file_suffix],
1765                                             separator='.')
1766      # TODO(deveci): Parametrize max_queue, so that flushing op can be called
1767      # less frequently.
1768      # Setting max_queue to 100 appears to be safe even when the number of
1769      # iterations is much lower, as the destructor of the writer flushes it.
1770      summary_write_ops = []
1771      summary_writer = summary.create_file_writer_v2(
1772          self._parameters.trace_dir,
1773          filename_suffix=file_suffix,
1774          max_queue=_TT_SUMMARY_MAX_QUEUE)
1775      graph.add_to_collection(
1776          TENSOR_TRACER_SUMMARY_COLLECTION, summary_writer)
1777
1778      step_value = step[0]
1779      dt = step_value.dtype
1780
1781      # The step parameter to a summary write call must be 64-bit.
1782      if dt not in (dtypes.int64, dtypes.uint64, dtypes.float64):
1784        step_value = math_ops.cast(step_value, dtypes.int64)
1785
1786      with summary_writer.as_default():
1787        summary_metadata = summary_pb2.SummaryMetadata(
1788            plugin_data=summary_pb2.SummaryMetadata.PluginData(
1789                plugin_name=_TT_TENSORBOARD_PLUGIN_NAME))
1790        for key, value in kwargs.items():
1791          # Check whether we need to compute aggregated statistics that merge
1792          # all cores statistics.
1793          if not self._parameters.collect_summary_per_core:
1794            # Merge only the statistics tensors; any other tensor is simply
1795            # concatenated.
1796            # Also, if there is only a single core (first dim. is 1), then skip
1797            # aggregation.
1798            if key == _TT_SUMMARY_TAG and value.shape.as_list()[0] != 1:
1799              value = self.aggregate_global_cache(value)
1800          with ops.control_dependencies([summary_writer.init()]):
1801            summary_write_ops.append(summary.write(
1802                _TT_SUMMARY_TAG + '/' + key + '.' + graph_summary_tag,
1803                value, metadata=summary_metadata,
1804                step=step_value))
1805      return control_flow_ops.group(summary_write_ops)
1806
1807    global_step = training_util.get_or_create_global_step()
1808    step = array_ops.reshape(global_step, [1])
1809    self._host_call_fn = {}
1810
1811    host_call_deps = op_fetches + [tensor.op for tensor in processed_t_fetches]
1812
1813    caches_to_write = {}
1814    with ops.control_dependencies(host_call_deps):
1815      all_caches = self._cache_variable_for_graph(graph)
1816      for cache_name, cache_variable in all_caches.items():
1817        # Increase the cache rank by 1, so that when host call concatenates
1818        # tensors from different replicas, we can identify them with [core_id].
1819        new_cache_shape = [1]
1820        new_cache_shape.extend(cache_variable.shape.as_list())
1821        cache = array_ops.reshape(cache_variable, new_cache_shape)
1822        caches_to_write[cache_name] = cache
1823    # Add step to parameter dictionary.
1824    caches_to_write['step'] = step
1825    # Other options without adding step to parameter dictionary are
1826    #  * host_call_fn = (_write_cache(step, caches_to_write)) : fails as it
1827    #    considers caches_to_write as a single parameter, rather than
1828    #    keyword parameters.
1829    #  * host_call_fn = (_write_cache(step, **caches_to_write)) : fails with
1830    #    a syntax error.
1831    self._host_call_fn[_TT_HOSTCALL_KEY] = (_write_cache, caches_to_write)
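
  # After this call, self._host_call_fn has the following form (sketch;
  # cache names and shapes depend on the trace mode):
  #
  #   {_TT_HOSTCALL_KEY: (_write_cache,
  #                       {'step': <tensor of shape [1]>,
  #                        _TT_SUMMARY_TAG: <tensor of shape [1, T, S]>})}
  #
  # so that a TPUEstimator-style host call can later invoke
  # _write_cache(**caches_to_write) on the host.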
1832
1833  def host_call_deps_and_fn(self):
1834    return self._host_call_fn
1835
1836  def get_traced_op_names(self):
1837    """Returns the set of traced op names."""
1838    return self._traced_op_names
1839
1840  def _trace_execution(self, graph,
1841                       tensor_fetches,
1842                       op_fetches=None,
1843                       on_tpu=True):
1844    """Commong tracing function for both CPU and TPUs.
1845
1846    The caller function should set device_type, num_replicas,
1847    num_replicas_per_host, num_hosts and replica_id before calling
1848    _trace_execution.
1849
1851    Args:
1852      graph: the graph of Ops executed on the TPU.
1853      tensor_fetches: a list, tuple, or single object of tensor fetches
1854        returned by model_fn given to session.run. The function must be
1855        provided with at least one tensor to fetch.
1856      op_fetches: A list of op fetches returned by model_fn given to
1857        session.run. op_fetches and tensor_fetches are used to determine the
1858        nodes that will be executed. Can be None.
1859      on_tpu: True if executing on TPU.
1860
1861    Returns:
1862      tensor_fetches: an exact copy of tensor_fetches that has additional
1863                      dependencies.
1864    Raises:
1865      RuntimeError: If tensor_fetches is None or empty.
1866    """
1867    def _cast_unsupported_dtypes(tensor):
1868      """Casts tensor to a supported type."""
1869
1870      if tensor.dtype == dtypes.int64:
1871        # outside-compilation doesn't support int64 input yet.
1872        return math_ops.cast(tensor, dtypes.int32)
1873      if tensor.dtype in (dtypes.bfloat16, dtypes.float16):
1875        # Since host can't handle bf16, convert tensor to f32.
1876        return math_ops.cast(tensor, dtypes.float32)
1877      return tensor
1878
1879    trace_mode = self._parameters.trace_mode
1880    device_type = self._tt_config.device_type
1881    # pylint: disable=protected-access
1882    self._outmost_context = graph._get_control_flow_context()
1883    # pylint: enable=protected-access
1884
1885    analytics.track_usage('tensor_tracer', [trace_mode, device_type])
1886    TensorTracer.check_device_type(device_type)
1887    TensorTracer.check_trace_mode(device_type, trace_mode)
1888    # Check tensor_fetches and op_fetches, and convert them to lists.
1889    processed_t_fetches = self._process_tensor_fetches(tensor_fetches)
1890    op_fetches = self._process_op_fetches(op_fetches)
1891    all_fetches = op_fetches + [tensor.op for tensor in processed_t_fetches]
1892
1893    # Filter out the operations that won't be executed.
1894    # If fetches is None, then ops_in_exec_path = set(operations).
1895    exec_op_set = self._filter_execution_path_operations(graph.get_operations(),
1896                                                         all_fetches)
1897    graph_summary_tag = _graph_summary_tag(graph)
1898
1899    # Write report file, and determine the traced tensors.
1900    tensor_trace_order = self._determine_trace_and_create_report(
1901        graph, exec_op_set, graph_summary_tag)
1902
1903    tensor_fetch_set = set(processed_t_fetches)
1904    tracing_ops = []
1905
1906    sorted_exec_op_list = sorted(exec_op_set, key=lambda op: op.name)
1908    # Trace ops only if they are in the execution path.
1909    for op in sorted_exec_op_list:
1910      for out_tensor in op.outputs:
1912        tensor_name = out_tensor.name
1913        if tensor_name not in tensor_trace_order.tensorname_to_cache_idx:
1914          continue
1915        self._traced_op_names.add(op.name)
1916        # Create the list of consumers before calling _preprocess_traced_tensor.
1917        # Otherwise, adding the control input below will introduce a cycle
1918        # in the graph.
1919        consumers = out_tensor.consumers()
1920        # Not all consumers may be in the exec path. Filter out the consumers
1921        # to keep the graph simpler.
1922        consumers = [cop for cop in consumers if cop in exec_op_set]
1923
1924        # If there is no consumer of the tensor, there is no need to trace it;
1925        # unless the tensor itself is one of the fetches.
1926        is_a_fetched_tensor = out_tensor in tensor_fetch_set
1927        if (not consumers) and (not is_a_fetched_tensor):
1928          continue
1929
1930        op_control_flow_context = self._get_op_control_flow_context(op)
1931        if op_control_flow_context:
1932          # pylint: disable=protected-access
1933          graph._set_control_flow_context(op_control_flow_context)
1934          # pylint: enable=protected-access
1935
1936        processed_tensors = self._preprocess_traced_tensor(out_tensor)
1937
1938        if on_tpu:
1939          for signature in processed_tensors.keys():
1940            processed_tensors[signature] = _cast_unsupported_dtypes(
1941                processed_tensors[signature])
1942
1943        if self._use_tensor_values_cache():
1944          # Use a small cache (either temp cache or tf local variable) to store
1945          # the characteristics of the tensor.
1946          if self._use_temp_cache():
1947            cache_idx = tensor_trace_order.tensorname_to_cache_idx[tensor_name]
1948            self._save_tensor_value_to_tmp_cache(cache_idx,
1949                                                 processed_tensors,
1950                                                 graph)
1951            trace_op = None
1952          else:
1953            cache_idx = tensor_trace_order.tensorname_to_cache_idx[tensor_name]
1954            trace_op = self._save_tensor_value_to_cache_op(cache_idx,
1955                                                           processed_tensors,
1956                                                           graph)
1957        elif self._use_tensor_buffer():
1958          if len(processed_tensors) != 1:
1959            raise RuntimeError('Multiple stats are only allowed in compact '
1960                               'mode.')
1961          processed_out_tensor = next(iter(processed_tensors.values()))
1962          # Store the whole tensor in a buffer.
1963          trace_op = self._snapshot_tensor(processed_out_tensor)
1964        else:
1965
1966          def tpu_wrap_trace_fn(tensor, out_tensor_name):
1967            """Wraps the trace_fn with outside compilation if on TPUs."""
1968            tensor_trace_fn = self._make_tensor_trace_fun(out_tensor_name,
1969                                                          tensor_trace_order)
1970            if on_tpu:
1971              return tpu.outside_compilation(tensor_trace_fn, tensor)
1972            else:
1973              return tensor_trace_fn(tensor)
1974
1975          if len(processed_tensors) != 1:
1976            raise RuntimeError('Multiple stats are only allowed in compact '
1977                               'mode.')
1978          # Collecting multiple statistics is only supported in the summary
1979          # mode that uses the compact format (self._use_tensor_values_cache
1980          # is true); non-compact mode allows a single stat per tensor.
1981          processed_out_tensor = next(iter(processed_tensors.values()))
1982          trace_op = tpu_wrap_trace_fn(processed_out_tensor, tensor_name)
1983
1984        if op_control_flow_context:
1985          # pylint: disable=protected-access
1986          graph._set_control_flow_context(self._outmost_context)
1987          # pylint: enable=protected-access
1988        if trace_op:
1989          if is_a_fetched_tensor:
1990            tracing_ops.append(trace_op)
1991            continue
1992          # Add it to all consumers, as some consumers may not be executed if
1993          # they are in a control flow.
1994          for consumer_op in consumers:
1995            # pylint: disable=protected-access
1996            consumer_op._add_control_input(trace_op)
1997            # pylint: enable=protected-access
1998
1999    # pylint: disable=protected-access
2000    graph._set_control_flow_context(self._outmost_context)
2001    # pylint: enable=protected-access
2002    if tracing_ops:
2003      # If we are tracing fetched tensors, their dependencies are stored in
2004      # tracing_ops.
2005      processed_t_fetches = control_flow_ops.tuple(processed_t_fetches,
2006                                                   control_inputs=tracing_ops)
2007    if self._use_tensor_values_cache() or self._use_tensor_buffer():
2008      if self._use_temp_cache():
2009        # Create the temporary tf cache variable by concatenating all
2010        # statistics.
2011        graph_cache_var = self._cache_variable_for_graph(graph)
2012        if graph not in self._temp_cache_var:
2013          raise RuntimeError('graph is not in self._temp_cache_var')
2014        graph_cache_var[_TT_SUMMARY_TAG] = array_ops.stack(
2015            self._temp_cache_var[graph], axis=0, name='stack_all_op_signatures')
2016      if self._create_host_call():
2017        self._prepare_host_call_fn(processed_t_fetches, op_fetches, graph,
2018                                   graph_summary_tag)
2019        if not on_tpu:
2020          write_cache, caches_to_write = self._host_call_fn[_TT_HOSTCALL_KEY]
2021          cache_write_op = write_cache(**caches_to_write)
2022          processed_t_fetches = control_flow_ops.tuple(
2023              processed_t_fetches, control_inputs=[cache_write_op])
2024          del self._host_call_fn[_TT_HOSTCALL_KEY]
2025        elif self._parameters.flush_summaries_with_outside_compile:
2026          write_cache, caches_to_write = self._host_call_fn[_TT_HOSTCALL_KEY]
2027          if (_TT_SUMMARY_TAG in caches_to_write and 'step' in caches_to_write):
2028            step = caches_to_write['step']
2029            tensor_tracer_summary = caches_to_write[_TT_SUMMARY_TAG]
2030            tt_core_summary = self.merge_caches_on_tpu(tensor_tracer_summary[0])
2031            if not self._parameters.collect_summary_per_core:
2032              tt_core_summary = self.aggregate_global_cache(tt_core_summary)
2033
2034            def write_if_core_0(step, replica_id, tt_summary):
2036              return control_flow_ops.cond(
2037                  math_ops.equal(replica_id, 0),
2038                  lambda: write_cache(step=step, event_file_suffix=None,  # pylint: disable=g-long-lambda
2039                                      tensor_tracer_summary=tt_summary),
2040                  control_flow_ops.no_op)
2041
2042            write_op = tpu.outside_compilation(write_if_core_0, step=step,
2043                                               replica_id=self._replica_id,
2044                                               tt_summary=tt_core_summary)
2045            processed_t_fetches = control_flow_ops.tuple(
2046                processed_t_fetches, control_inputs=[write_op])
2047            del self._host_call_fn[_TT_HOSTCALL_KEY]
2048          else:
2049            raise ValueError('Outside compiled flush is only supported for '
2050                             'summary mode.')
2051      else:
2052        processed_t_fetches = self._flush_tensor_values_cache(
2053            processed_t_fetches, op_fetches, on_tpu=on_tpu,
2054            tensor_trace_order=tensor_trace_order,
2055            graph=graph)
2056
2057    # processed_t_fetches is a list at this point. Convert it to the same
2058    # format as given in tensor_fetches.
2059    return self._convert_fetches_to_input_format(tensor_fetches,
2060                                                 processed_t_fetches)
2061
2062  def trace_tpu(self, graph,
2063                tensor_fetches,
2064                op_fetches=None,
2065                num_replicas=None,
2066                num_replicas_per_host=None,
2067                num_hosts=None):
2068    """Traces the tensors generated by TPU Ops in a TF graph.
2069
2070    Args:
2071      graph: the graph of Ops executed on the TPU.
2072      tensor_fetches: a list, tuple, or single object of tensor fetches
2073        returned by model_fn given to session.run. The function must be
2074        provided with at least one tensor to fetch.
2075      op_fetches: A list of op fetches returned by model_fn given to
2076        session.run. op_fetches and tensor_fetches are used to determine the
2077        nodes that will be executed. Can be None.
2078      num_replicas: number of replicas used on the TPU.
2079      num_replicas_per_host: number of replicas per TPU host.
2080      num_hosts: total number of TPU hosts.
2081
2082    Returns:
2083      tensor_fetches: an exact copy of tensor_fetches that has additional
2084                      dependencies.
2085    """
2086    if isinstance(graph, func_graph.FuncGraph) or isinstance(
2087        graph, function._FuncGraph):  # pylint: disable=protected-access
2088      logging.warning('Tensor Tracer is not supported for tracing FuncGraphs. '
2089                      'Ignoring tracing.')
2090      return tensor_fetches
2091
2092    if graph in TensorTracer._traced_graphs:
2093      logging.warning('Graph is already rewritten with tensor tracer, ignoring '
2094                      'multiple calls.')
2095      return tensor_fetches
2096    else:
2097      TensorTracer._traced_graphs.add(graph)
2098    # Reset the parameters in case parameters are changed.
2099    self._parameters = tensor_tracer_flags.TTParameters()
2100    self._tt_config.device_type = _DEVICE_TYPE_TPU
2101    self._tt_config.num_replicas = num_replicas
2102    self._tt_config.num_replicas_per_host = num_replicas_per_host
2103    self._tt_config.num_hosts = num_hosts
2104    if self._tt_config.num_replicas is not None:
2105      if self._tt_config.num_replicas_per_host is None:
2106        self._tt_config.num_replicas_per_host = 8
2107      if self._tt_config.num_hosts is None:
2108        self._tt_config.num_hosts = (
2109            num_replicas // self._tt_config.num_replicas_per_host +
2110            (num_replicas % self._tt_config.num_replicas_per_host > 0))
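        # E.g., num_replicas=16 with 8 replicas per host yields
        # 16 // 8 + (16 % 8 > 0) == 2 hosts, and num_replicas=17 yields
        # 17 // 8 + (17 % 8 > 0) == 3; i.e., this is a ceiling division.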
2111
2112    if self._parameters.graph_dump_path:
2113      graph_io.write_graph(graph, self._parameters.graph_dump_path,
2114                           'graph_before_tt.pbtxt')
2115    with graph.as_default():
2116      self._add_replica_id_to_graph()
2117      tensor_fetches = self._trace_execution(graph, tensor_fetches, op_fetches,
2118                                             on_tpu=True)
2119    if self._parameters.graph_dump_path:
2120      graph_io.write_graph(graph, self._parameters.graph_dump_path,
2121                           'graph_after_tt.pbtxt')
2122    return tensor_fetches
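
  # A minimal usage sketch (hypothetical caller code inside a TPU-rewritten
  # model_fn; `loss` and `num_replicas` are placeholders):
  #
  #   if TensorTracer.is_enabled():
  #     tt = TensorTracer()
  #     loss = tt.trace_tpu(ops.get_default_graph(), loss,
  #                         num_replicas=num_replicas)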
2123
2124  def trace_cpu(self, graph, tensor_fetches, op_fetches=None):
2125    """Traces the tensors generated by CPU Ops in a TF graph.
2126
2127    Args:
2128      graph: the graph of Ops executed on the CPU.
2129      tensor_fetches: a list, tuple, or single object of tensor fetches
2130        returned by model_fn given to session.run. The function must be
2131        provided with at least one tensor to fetch.
2132      op_fetches: A list of op fetches returned by model_fn given to
2133        session.run. op_fetches and tensor_fetches are used to determine the
2134        nodes that will be executed. Can be None.
2135
2136    Returns:
2137      tensor_fetches: an exact copy of tensor_fetches that has additional
2138                      dependencies.
2139    """
2140    if isinstance(graph, func_graph.FuncGraph) or isinstance(
2141        graph, function._FuncGraph):  # pylint: disable=protected-access
2142      logging.warning('Tensor Tracer is not supported for tracing FuncGraphs. '
2143                      'Ignoring tracing.')
2144      return tensor_fetches
2145
2146    if graph in TensorTracer._traced_graphs:
2147      logging.warning('Graph is already rewritten with tensor tracer, ignoring '
2148                      'multiple calls.')
2149      return tensor_fetches
2150    else:
2151      TensorTracer._traced_graphs.add(graph)
2152    # Reset the parameters in case parameters are changed.
2153    self._parameters = tensor_tracer_flags.TTParameters()
2154
2155    self._tt_config.device_type = _DEVICE_TYPE_CPU
2156    self._tt_config.num_replicas = 1
2157    self._tt_config.num_replicas_per_host = 1
2158    self._tt_config.num_hosts = 1
2159    self._replica_id = 0
2160    if self._parameters.graph_dump_path:
2161      graph_io.write_graph(graph, self._parameters.graph_dump_path,
2162                           'graph_before_tt.pbtxt')
2163    with graph.as_default():
2164      tensor_fetches = self._trace_execution(graph, tensor_fetches, op_fetches,
2165                                             on_tpu=False)
2166    if self._parameters.graph_dump_path:
2167      graph_io.write_graph(graph, self._parameters.graph_dump_path,
2168                           'graph_after_tt.pbtxt')
2169    return tensor_fetches
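
  # CPU counterpart of the usage sketch above (hypothetical caller code):
  #
  #   if TensorTracer.is_enabled():
  #     tt = TensorTracer()
  #     loss = tt.trace_cpu(ops.get_default_graph(), tensor_fetches=loss)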
2170