# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========================================================================
"""Tensor Tracer report generation utilities."""

import collections
import hashlib
import os

from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import tensor_tracer_pb2

_TRACER_LOG_PREFIX = ' [>>>TT>>>]'
_MARKER_SECTION_BEGIN = '!!!!!!! section-begin:'
_MARKER_SECTION_END = '!!!!!!! section-end:'

_SECTION_NAME_CONFIG = 'configuration'
_SECTION_NAME_REASON = 'reason'
_SECTION_NAME_OP_LIST = 'op-list'
_SECTION_NAME_TENSOR_LIST = 'tensor-list'
_SECTION_NAME_CACHE_INDEX_MAP = 'cache-index-map'
_SECTION_NAME_GRAPH = 'graph'
_SECTION_NAME_TENSOR_TRACER_CHECKPOINT = 'tensor_tracer_checkpoint'

_FIELD_NAME_VERSION = 'version:'
_FIELD_NAME_DEVICE = 'device:'
_FIELD_NAME_TRACE_MODE = 'trace-mode:'
_FIELD_NAME_SUBMODE = 'submode:'
_FIELD_NAME_NUM_REPLICAS = 'num-replicas:'
_FIELD_NAME_NUM_REPLICAS_PER_HOST = 'num-replicas-per-host:'
_FIELD_NAME_NUM_HOSTS = 'num-hosts:'
_FIELD_NAME_NUM_OPS = 'number-of-ops:'
_FIELD_NAME_NUM_TENSORS = 'number-of-tensors:'
_FIELD_NAME_NUM_CACHE_INDICES = 'number-of-indices:'
_FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED = 'topological-sort-succeed:'

_CURRENT_VERSION = 'use-outside-compilation'

_TT_REPORT_PROTO = 'tensor_tracer_report.report_pb'


def topological_sort(g):
  """Performs a topological sort on the given graph.

  Args:
     g: the graph.

  Returns:
     A pair where the first element indicates whether the graph contains a
     cycle (True if a cycle was found, in which case the sort failed; False
     if the sort succeeded) and the second element is either the set of ops
     that could not be ordered because of the cycle, or the topologically
     sorted list of ops.
  """
  def _is_loop_edge(op):
    """Returns true if the op is the end of a while-loop creating a cycle."""
    return op.type in ['NextIteration']

  def _in_op_degree(op):
    """Returns the number of incoming edges to the given op.

    The edge calculation skips the edges that come from 'NextIteration' ops.
    NextIteration creates a cycle in the graph. We break cycles by treating
    this op as a 'sink' and ignoring all of its outgoing edges.

    Args:
      op: a tf.Operation.
    Returns:
      The number of incoming edges.
    """
    count = 0
    for in_op in op.control_inputs + [in_tensor.op for in_tensor in op.inputs]:
      if not _is_loop_edge(in_op):
        count += 1
    return count

  sorted_ops = []
  op_in_degree = {op: _in_op_degree(op) for op in g.get_operations()}

  frontier = [op for (op, degree) in op_in_degree.items() if degree == 0]
  frontier.sort(key=lambda op: op.name)
  while frontier:
    op = frontier.pop()
    # Remove the op from the graph along with its outgoing edges.
    sorted_ops.append(op)
    if _is_loop_edge(op):
      continue
    # pylint: disable=protected-access
    consumers = list(op._control_outputs)
    # pylint: enable=protected-access
    for out_tensor in op.outputs:
      consumers += out_tensor.consumers()
    consumers.sort(key=lambda op: op.name)
    for consumer in consumers:
      # For each deleted edge, decrement the in-degree of the consumer.
      op_in_degree[consumer] -= 1
      if op_in_degree[consumer] == 0:
        frontier.append(consumer)
      if op_in_degree[consumer] < 0:
        raise ValueError('consumer %s: degree mismatch' % consumer.name)

  left_ops = set(op for (op, degree) in op_in_degree.items() if degree > 0)
  if left_ops:
    return (True, left_ops)
  else:
    assert len(g.get_operations()) == len(sorted_ops)
    return (False, sorted_ops)
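
# Example usage (a minimal sketch; the graph is assumed to be built in graph
# mode through the public TensorFlow API, and nothing here is executed by this
# module):
#
#   import tensorflow as tf
#   g = tf.Graph()
#   with g.as_default():
#     a = tf.constant(1.0, name='a')
#     b = tf.add(a, a, name='b')
#   contains_cycle, order_or_cycle = topological_sort(g)
#   # contains_cycle is False here, and order_or_cycle lists every op with
#   # its producers appearing before its consumers.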


class TensorTracerConfig(object):
  """Tensor Tracer config object."""

  def __init__(self):
    self.version = _CURRENT_VERSION
    self.device_type = None
    self.num_replicas = None
    self.num_replicas_per_host = None
    self.num_hosts = None


class TensorTraceOrder(object):
  """Class that is responsible for storing the trace-id of the tensors."""

  def __init__(self, graph_order, traced_tensors):
    self.graph_order = graph_order
    self.traced_tensors = traced_tensors
    self._create_tensor_maps()

  def _create_tensor_maps(self):
    """Creates tensor to cache id maps."""
    self.tensorname_to_cache_idx = {}
    self.cache_idx_to_tensor_idx = []
    for out_tensor in self.traced_tensors:
      tensor_name = out_tensor.name
      if tensor_name in self.tensorname_to_cache_idx:
        raise ValueError('Tensor name {} should not already be in '
                         'tensorname_to_cache_idx'.format(tensor_name))
      if tensor_name not in self.graph_order.tensor_to_idx:
        raise ValueError(
            'Tensor name {} is not in tensor_to_idx, tensor_to_idx={}'
            .format(tensor_name, self.graph_order.tensor_to_idx))
      tensor_idx = self.graph_order.tensor_to_idx[tensor_name]
      cache_idx = len(self.tensorname_to_cache_idx)
      self.tensorname_to_cache_idx[tensor_name] = cache_idx
      self.cache_idx_to_tensor_idx.append(tensor_idx)
      if len(self.tensorname_to_cache_idx) != len(
          self.cache_idx_to_tensor_idx):
        raise RuntimeError(
            'len(self.tensorname_to_cache_idx) must equal '
            'len(self.cache_idx_to_tensor_idx), got '
            'len(self.tensorname_to_cache_idx)={}, '
            'len(self.cache_idx_to_tensor_idx)={}'
            .format(
                len(self.tensorname_to_cache_idx),
                len(self.cache_idx_to_tensor_idx)))
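
  # Illustrative note (hypothetical names, not from a real run): if the graph
  # order lists tensors ['x:0', 'y:0', 'z:0'] and traced_tensors contains
  # ['y:0', 'z:0'], the maps built above are
  #   tensorname_to_cache_idx == {'y:0': 0, 'z:0': 1}
  #   cache_idx_to_tensor_idx == [1, 2]
  # i.e. cache slot k holds the signature of the tensor whose graph index is
  # cache_idx_to_tensor_idx[k].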


def sort_tensors_and_ops(graph):
  """Returns a wrapper that has consistent tensor and op orders."""
  graph_wrapper = collections.namedtuple('GraphWrapper',
                                         ['graph', 'operations', 'op_to_idx',
                                          'tensors', 'tensor_to_idx',
                                          'contains_cycle',
                                          'topological_order_or_cycle'])
  contains_cycle, topological_order_or_cycle = topological_sort(graph)
  if not contains_cycle:
    operations = topological_order_or_cycle
  else:
    operations = graph.get_operations()
  op_to_idx = {op.name: index for index, op
               in enumerate(operations)}
  tensors = []
  for op in operations:
    tensors.extend(op.outputs)
  tensor_to_idx = {tensor.name: index for index, tensor in
                   enumerate(tensors)}
  return graph_wrapper(graph=graph, operations=operations, op_to_idx=op_to_idx,
                       tensors=tensors, tensor_to_idx=tensor_to_idx,
                       contains_cycle=contains_cycle,
                       topological_order_or_cycle=topological_order_or_cycle)
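
# Example usage (a minimal sketch; 'g' is assumed to be a tf.Graph built in
# graph mode, as in the topological_sort example above):
#
#   graph_order = sort_tensors_and_ops(g)
#   first_op = graph_order.operations[0]
#   assert graph_order.op_to_idx[first_op.name] == 0
#   # Every output tensor of every op also gets a stable index:
#   for t in first_op.outputs:
#     print(t.name, graph_order.tensor_to_idx[t.name])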


class OpenReportFile(object):
  """Context manager for writing the report file."""

  def __init__(self, tt_parameters):
    if not tt_parameters.report_file_path:
      self._report_file = None
      return
    self._report_file = gfile.Open(tt_parameters.report_file_path, 'w')

  def __enter__(self):
    return self._report_file

  def __exit__(self, unused_type, unused_value, unused_traceback):
    if self._report_file:
      self._report_file.close()
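
# Example usage (a minimal sketch; 'tt_parameters' is assumed to be a
# TTParameters instance whose report_file_path is set):
#
#   with OpenReportFile(tt_parameters) as report_file:
#     if report_file:
#       report_file.write('some line\n')
#
# When report_file_path is empty the manager yields None, and callers fall
# back to logging (see TTReportHandle._write_report below).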


def proto_fingerprint(message_proto):
  """Returns a hex SHA-256 digest of the serialized proto."""
  serialized_message = message_proto.SerializeToString()
  hasher = hashlib.sha256(serialized_message)
  return hasher.hexdigest()
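
# Example (illustrative only): two freshly constructed, empty protos of the
# same type serialize to the same bytes and therefore share a fingerprint.
#
#   report_a = tensor_tracer_pb2.TensorTracerReport()
#   report_b = tensor_tracer_pb2.TensorTracerReport()
#   assert proto_fingerprint(report_a) == proto_fingerprint(report_b)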


class TTReportHandle(object):
  """Utility class responsible for creating a tensor tracer report."""

  def __init__(self):
    self.instrument_records = {}
    self._report_file = None

  def instrument(self, name, explanation):
    self.instrument_records[name] = explanation

  def instrument_op(self, op, explanation):
    self.instrument(op.name, explanation)

  def instrument_tensor(self, tensor, explanation):
    self.instrument(tensor.name, explanation)

  def create_report_proto(self, tt_config, tt_parameters, tensor_trace_order,
                          tensor_trace_points, collected_signature_types):
    """Creates and returns a proto that stores tensor tracer configuration.

    Args:
      tt_config: TensorTracerConfig object holding information about the run
        environment (device, # cores, # hosts), and tensor tracer version
        information.
      tt_parameters: TTParameters object storing the user-provided parameters
        for tensor tracer.
      tensor_trace_order: TensorTraceOrder object storing the topological
        graph order and the mapping from traced tensor names to cache indices.
      tensor_trace_points: Programmatically added trace points/checkpoints.
      collected_signature_types: The signature types collected, e.g. norm,
        max, min, mean...
    Returns:
      TensorTracerReport proto.
    """
    report = tensor_tracer_pb2.TensorTracerReport()
    report.config.version = tt_config.version
    report.config.device = tt_config.device_type
    report.config.num_cores = tt_config.num_replicas
    report.config.num_hosts = tt_config.num_hosts
    report.config.num_cores_per_host = tt_config.num_replicas_per_host
    report.config.submode = tt_parameters.submode
    report.config.trace_mode = tt_parameters.trace_mode

    for signature_name, _ in sorted(collected_signature_types.items(),
                                    key=lambda x: x[1]):
      report.config.signatures.append(signature_name)

    for tensor in tensor_trace_order.graph_order.tensors:
      tensor_def = tensor_tracer_pb2.TensorTracerReport.TracedTensorDef()
      tensor_def.name = tensor.name
      if tensor.name in tensor_trace_order.tensorname_to_cache_idx:
        tensor_def.is_traced = True
        tensor_def.cache_index = (
            tensor_trace_order.tensorname_to_cache_idx[tensor.name])
      else:
        # To prevent small changes affecting the fingerprint calculation, avoid
        # writing the untraced tensors to metadata. Fingerprints will differ
        # only when the list of traced tensors differs.
        if tt_parameters.use_fingerprint_subdir:
          continue
        tensor_def.is_traced = False

      if tensor.name in tensor_trace_points:
        tensor_def.trace_point_name = tensor_trace_points[tensor.name]
      if tensor.name in self.instrument_records:
        tensor_def.explanation = self.instrument_records[tensor.name]
      elif tensor.op.name in self.instrument_records:
        tensor_def.explanation = self.instrument_records[tensor.op.name]
      report.tensordef[tensor.name].CopyFrom(tensor_def)
    report.fingerprint = proto_fingerprint(report)
    logging.info('TensorTracerProto fingerprint is %s.',
                 report.fingerprint)
    tf_graph = tensor_trace_order.graph_order.graph
    report.graphdef.CopyFrom(tf_graph.as_graph_def())
    return report
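
  # Example of inspecting the returned proto (a sketch; 'report' denotes the
  # value returned by create_report_proto):
  #
  #   traced = [name for name, tdef in report.tensordef.items()
  #             if tdef.is_traced]
  #   logging.info('%d tensors traced, fingerprint=%s', len(traced),
  #                report.fingerprint)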

  def report_proto_path(self, trace_dir, summary_tag_name):
    """Returns the path where the report proto should be written.

    Args:
      trace_dir: String denoting the trace directory.
      summary_tag_name: Name of the unique tag that relates to the report.
    Returns:
      A string denoting the path to the report proto.
    """
    filename = _TT_REPORT_PROTO + '.' + summary_tag_name.replace('/', '_')
    return os.path.join(trace_dir, filename)
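
  # Example (hypothetical values): with trace_dir='/tmp/tt_traces' and
  # summary_tag_name='train/step_stats', the returned path is
  # '/tmp/tt_traces/tensor_tracer_report.report_pb.train_step_stats'.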

  def write_report_proto(self, report_path, report_proto, tt_parameters):
    """Writes the given report proto under trace_dir."""
    gfile.MakeDirs(tt_parameters.trace_dir)
    with gfile.GFile(report_path, 'wb') as f:
      f.write(report_proto.SerializeToString())
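
  # Example of reading the proto back (a sketch; 'report_path' is assumed to
  # point at a previously written report):
  #
  #   report = tensor_tracer_pb2.TensorTracerReport()
  #   with gfile.GFile(report_path, 'rb') as f:
  #     report.ParseFromString(f.read())
  #   logging.info('fingerprint=%s, %d tensordefs', report.fingerprint,
  #                len(report.tensordef))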

  def create_report(self, tt_config, tt_parameters,
                    tensor_trace_order, tensor_trace_points):
    """Creates a report file and writes the trace information."""
    with OpenReportFile(tt_parameters) as self._report_file:
      self._write_config_section(tt_config, tt_parameters)
      self._write_op_list_section(tensor_trace_order.graph_order)
      self._write_tensor_list_section(tensor_trace_order.graph_order)
      self._write_trace_points(tensor_trace_points)
      self._write_cache_index_map_section(tensor_trace_order)
      self._write_reason_section()
      self._write_graph_section(tensor_trace_order.graph_order)

  def _write_trace_points(self, tensor_trace_points):
    """Writes the list of checkpoints."""
    self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN,
                                  _SECTION_NAME_TENSOR_TRACER_CHECKPOINT))
    for (tensor, checkpoint_name) in tensor_trace_points:
      self._write_report('%s %s\n'%(tensor.name, checkpoint_name))
    self._write_report('%s %s\n'%(_MARKER_SECTION_END,
                                  _SECTION_NAME_TENSOR_TRACER_CHECKPOINT))

  def _write_report(self, content):
    """Writes the given content to the report file or, if none, logs it."""

    line = '%s %s'%(_TRACER_LOG_PREFIX, content)
    if self._report_file:
      self._report_file.write(line)
    else:
      logging.info(line)

  def _write_config_section(self, tt_config, tt_parameters):
    """Writes the config section of the report."""

    self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_CONFIG))
    self._write_report('%s %s\n'%(_FIELD_NAME_VERSION, tt_config.version))
    self._write_report('%s %s\n'%(_FIELD_NAME_DEVICE, tt_config.device_type))
    self._write_report('%s %s\n'%(_FIELD_NAME_TRACE_MODE,
                                  tt_parameters.trace_mode))
    self._write_report('%s %s\n'%(_FIELD_NAME_SUBMODE,
                                  tt_parameters.submode))
    self._write_report('%s %s\n'%(_FIELD_NAME_NUM_REPLICAS,
                                  tt_config.num_replicas))
    self._write_report('%s %s\n'%(_FIELD_NAME_NUM_REPLICAS_PER_HOST,
                                  tt_config.num_replicas_per_host))
    self._write_report('%s %s\n'%(_FIELD_NAME_NUM_HOSTS, tt_config.num_hosts))
    self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_CONFIG))
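
  # For illustration, with hypothetical values (device_type='tpu',
  # trace_mode='norm', submode='detailed', 8 replicas on 1 host) the emitted
  # section looks like:
  #
  #    [>>>TT>>>] !!!!!!! section-begin: configuration
  #    [>>>TT>>>] version: use-outside-compilation
  #    [>>>TT>>>] device: tpu
  #    [>>>TT>>>] trace-mode: norm
  #    [>>>TT>>>] submode: detailed
  #    [>>>TT>>>] num-replicas: 8
  #    [>>>TT>>>] num-replicas-per-host: 8
  #    [>>>TT>>>] num-hosts: 1
  #    [>>>TT>>>] !!!!!!! section-end: configuration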

  def _write_reason_section(self):
    """Writes the reason section of the report."""

    self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_REASON))
    for key in sorted(self.instrument_records):
      self._write_report('"%s" %s\n'%(key, self.instrument_records[key]))
    self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_REASON))

  def _write_op_list_section(self, graph_order):
    """Writes the op-list section of the report."""

    self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_OP_LIST))
    self._write_report('%s %d\n'%(_FIELD_NAME_NUM_OPS,
                                  len(graph_order.operations)))
    for i, op in enumerate(graph_order.operations):
      line = '%d "%s" %s'%(i, op.name, op.type)
      for out_tensor in op.outputs:
        if out_tensor.name not in graph_order.tensor_to_idx:
          raise ValueError(
              'out_tensor is not in tensor_to_idx. out_tensor={}, '
              'tensor_to_idx={}'
              .format(out_tensor.name, graph_order.tensor_to_idx))
        line += ' %d'%graph_order.tensor_to_idx[out_tensor.name]
      line += '\n'
      self._write_report(line)
    self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_OP_LIST))

  def _write_tensor_list_section(self, graph_order):
    """Writes the tensor-list section of the report."""

    self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN,
                                  _SECTION_NAME_TENSOR_LIST))
    self._write_report('%s %d\n'%(_FIELD_NAME_NUM_TENSORS,
                                  len(graph_order.tensors)))
    for i, tensor in enumerate(graph_order.tensors):
      line = '%d "%s"'%(i, tensor.name)
      consumers = tensor.consumers()
      consumers.sort(key=lambda op: op.name)
      for consumer_op in consumers:
        if consumer_op.name not in graph_order.op_to_idx:
          raise ValueError(
              'consumer_op is not in op_to_idx. '
              'got consumer_op={}, op_to_idx={}'
              .format(consumer_op.name, graph_order.op_to_idx))
        line += ' %d'%graph_order.op_to_idx[consumer_op.name]
      line += '\n'
      self._write_report(line)
    self._write_report('%s %s\n'%(_MARKER_SECTION_END,
                                  _SECTION_NAME_TENSOR_LIST))

  def _write_cache_index_map_section(self, tensor_trace_order):
    """Writes the mapping from cache index to tensor index to the report."""
    self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN,
                                  _SECTION_NAME_CACHE_INDEX_MAP))
    self._write_report('%s %d\n'%(
        _FIELD_NAME_NUM_CACHE_INDICES,
        len(tensor_trace_order.cache_idx_to_tensor_idx)))
    for cache_idx, tensor_idx in enumerate(
        tensor_trace_order.cache_idx_to_tensor_idx):
      line = '%d %d\n'%(cache_idx, tensor_idx)
      self._write_report(line)
    self._write_report('%s %s\n'%(_MARKER_SECTION_END,
                                  _SECTION_NAME_CACHE_INDEX_MAP))

  def _write_graph_section(self, graph_order):
    """Writes the graph section of the report."""

    self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_GRAPH))
    self._write_report('%s %s\n'%(_FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED,
                                  not graph_order.contains_cycle))
    nodes = list(graph_order.topological_order_or_cycle)
    for i, node in enumerate(nodes):
      self._write_report('%d "%s"\n'%(i, node.name))
    self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_GRAPH))
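
# End-to-end sketch (assumed names; tt_config, tt_parameters and
# tensor_trace_order are normally constructed by tensor_tracer.py, and
# tensor_trace_points here maps tensor names to trace point names):
#
#   handle = TTReportHandle()
#   report = handle.create_report_proto(tt_config, tt_parameters,
#                                       tensor_trace_order,
#                                       tensor_trace_points,
#                                       collected_signature_types={'norm': 0})
#   path = handle.report_proto_path(tt_parameters.trace_dir, 'train')
#   handle.write_report_proto(path, report, tt_parameters)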