xref: /aosp_15_r20/external/tensorflow/tensorflow/python/client/timeline.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Timeline visualization for TensorFlow using Chrome Trace Format."""
16
17import collections
18import copy
19import json
20import re
21
# The timeline target is usually imported as part of the BUILD target
# "platform_test", which also includes the "platform" dependency.
# This is why the logging import here is okay.
25from tensorflow.python.platform import build_info
26from tensorflow.python.platform import tf_logging as logging
27
28
class AllocationMaximum(collections.namedtuple(
    'AllocationMaximum', ('timestamp', 'num_bytes', 'tensors'))):
  """Stores the maximum allocation for a given allocator within the timeline.

  Parameters:
    timestamp: `tensorflow::Env::NowMicros()` when this maximum was reached.
    num_bytes: the total memory used at this time.
    tensors: the set of tensors allocated at this time.
  """
  # Prevent creation of per-instance __dict__ (standard namedtuple-subclass
  # idiom); the redundant `pass` after the docstring was removed.
  __slots__ = ()
39
40
class StepStatsAnalysis(collections.namedtuple(
    'StepStatsAnalysis', ('chrome_trace', 'allocator_maximums'))):
  """Stores the step stats analysis output.

  Parameters:
    chrome_trace: A dict containing the chrome trace analysis.
    allocator_maximums: A dict mapping allocator names to AllocationMaximum.
  """
  # Prevent creation of per-instance __dict__ (standard namedtuple-subclass
  # idiom); the redundant `pass` after the docstring was removed.
  __slots__ = ()
50
51
52class _ChromeTraceFormatter(object):
53  """A helper class for generating traces in Chrome Trace Format."""
54
55  def __init__(self, show_memory=False):
56    """Constructs a new Chrome Trace formatter."""
57    self._show_memory = show_memory
58    self._events = []
59    self._metadata = []
60
61  def _create_event(self, ph, category, name, pid, tid, timestamp):
62    """Creates a new Chrome Trace event.
63
64    For details of the file format, see:
65    https://github.com/catapult-project/catapult/blob/master/tracing/README.md
66
67    Args:
68      ph:  The type of event - usually a single character.
69      category: The event category as a string.
70      name:  The event name as a string.
71      pid:  Identifier of the process generating this event as an integer.
72      tid:  Identifier of the thread generating this event as an integer.
73      timestamp:  The timestamp of this event as a long integer.
74
75    Returns:
76      A JSON compatible event object.
77    """
78    event = {}
79    event['ph'] = ph
80    event['cat'] = category
81    event['name'] = name
82    event['pid'] = pid
83    event['tid'] = tid
84    event['ts'] = timestamp
85    return event
86
87  def emit_pid(self, name, pid):
88    """Adds a process metadata event to the trace.
89
90    Args:
91      name:  The process name as a string.
92      pid:  Identifier of the process as an integer.
93    """
94    event = {}
95    event['name'] = 'process_name'
96    event['ph'] = 'M'
97    event['pid'] = pid
98    event['args'] = {'name': name}
99    self._metadata.append(event)
100
101  def emit_tid(self, name, pid, tid):
102    """Adds a thread metadata event to the trace.
103
104    Args:
105      name:  The thread name as a string.
106      pid:  Identifier of the process as an integer.
107      tid:  Identifier of the thread as an integer.
108    """
109    event = {}
110    event['name'] = 'thread_name'
111    event['ph'] = 'M'
112    event['pid'] = pid
113    event['tid'] = tid
114    event['args'] = {'name': name}
115    self._metadata.append(event)
116
117  def emit_region(self, timestamp, duration, pid, tid, category, name, args):
118    """Adds a region event to the trace.
119
120    Args:
121      timestamp:  The start timestamp of this region as a long integer.
122      duration:  The duration of this region as a long integer.
123      pid:  Identifier of the process generating this event as an integer.
124      tid:  Identifier of the thread generating this event as an integer.
125      category: The event category as a string.
126      name:  The event name as a string.
127      args:  A JSON-compatible dictionary of event arguments.
128    """
129    event = self._create_event('X', category, name, pid, tid, timestamp)
130    event['dur'] = duration
131    event['args'] = args
132    self._events.append(event)
133
134  def emit_obj_create(self, category, name, timestamp, pid, tid, object_id):
135    """Adds an object creation event to the trace.
136
137    Args:
138      category: The event category as a string.
139      name:  The event name as a string.
140      timestamp:  The timestamp of this event as a long integer.
141      pid:  Identifier of the process generating this event as an integer.
142      tid:  Identifier of the thread generating this event as an integer.
143      object_id: Identifier of the object as an integer.
144    """
145    event = self._create_event('N', category, name, pid, tid, timestamp)
146    event['id'] = object_id
147    self._events.append(event)
148
149  def emit_obj_delete(self, category, name, timestamp, pid, tid, object_id):
150    """Adds an object deletion event to the trace.
151
152    Args:
153      category: The event category as a string.
154      name:  The event name as a string.
155      timestamp:  The timestamp of this event as a long integer.
156      pid:  Identifier of the process generating this event as an integer.
157      tid:  Identifier of the thread generating this event as an integer.
158      object_id: Identifier of the object as an integer.
159    """
160    event = self._create_event('D', category, name, pid, tid, timestamp)
161    event['id'] = object_id
162    self._events.append(event)
163
164  def emit_obj_snapshot(self, category, name, timestamp, pid, tid, object_id,
165                        snapshot):
166    """Adds an object snapshot event to the trace.
167
168    Args:
169      category: The event category as a string.
170      name:  The event name as a string.
171      timestamp:  The timestamp of this event as a long integer.
172      pid:  Identifier of the process generating this event as an integer.
173      tid:  Identifier of the thread generating this event as an integer.
174      object_id: Identifier of the object as an integer.
175      snapshot:  A JSON-compatible representation of the object.
176    """
177    event = self._create_event('O', category, name, pid, tid, timestamp)
178    event['id'] = object_id
179    event['args'] = {'snapshot': snapshot}
180    self._events.append(event)
181
182  def emit_flow_start(self, name, timestamp, pid, tid, flow_id):
183    """Adds a flow start event to the trace.
184
185    When matched with a flow end event (with the same 'flow_id') this will
186    cause the trace viewer to draw an arrow between the start and end events.
187
188    Args:
189      name:  The event name as a string.
190      timestamp:  The timestamp of this event as a long integer.
191      pid:  Identifier of the process generating this event as an integer.
192      tid:  Identifier of the thread generating this event as an integer.
193      flow_id: Identifier of the flow as an integer.
194    """
195    event = self._create_event('s', 'DataFlow', name, pid, tid, timestamp)
196    event['id'] = flow_id
197    self._events.append(event)
198
199  def emit_flow_end(self, name, timestamp, pid, tid, flow_id):
200    """Adds a flow end event to the trace.
201
202    When matched with a flow start event (with the same 'flow_id') this will
203    cause the trace viewer to draw an arrow between the start and end events.
204
205    Args:
206      name:  The event name as a string.
207      timestamp:  The timestamp of this event as a long integer.
208      pid:  Identifier of the process generating this event as an integer.
209      tid:  Identifier of the thread generating this event as an integer.
210      flow_id: Identifier of the flow as an integer.
211    """
212    event = self._create_event('t', 'DataFlow', name, pid, tid, timestamp)
213    event['id'] = flow_id
214    self._events.append(event)
215
216  def emit_counter(self, category, name, pid, timestamp, counter, value):
217    """Emits a record for a single counter.
218
219    Args:
220      category: The event category as a string.
221      name:  The event name as a string.
222      pid:  Identifier of the process generating this event as an integer.
223      timestamp:  The timestamp of this event as a long integer.
224      counter: Name of the counter as a string.
225      value:  Value of the counter as an integer.
226    """
227    event = self._create_event('C', category, name, pid, 0, timestamp)
228    event['args'] = {counter: value}
229    self._events.append(event)
230
231  def emit_counters(self, category, name, pid, timestamp, counters):
232    """Emits a counter record for the dictionary 'counters'.
233
234    Args:
235      category: The event category as a string.
236      name:  The event name as a string.
237      pid:  Identifier of the process generating this event as an integer.
238      timestamp:  The timestamp of this event as a long integer.
239      counters: Dictionary of counter values.
240    """
241    event = self._create_event('C', category, name, pid, 0, timestamp)
242    event['args'] = counters.copy()
243    self._events.append(event)
244
245  def format_to_string(self, pretty=False):
246    """Formats the chrome trace to a string.
247
248    Args:
249      pretty: (Optional.)  If True, produce human-readable JSON output.
250
251    Returns:
252      A JSON-formatted string in Chrome Trace format.
253    """
254    trace = {}
255    trace['traceEvents'] = self._metadata + self._events
256    if pretty:
257      return json.dumps(trace, indent=4, separators=(',', ': '))
258    else:
259      return json.dumps(trace, separators=(',', ':'))
260
261
262class _TensorTracker(object):
263  """An internal class to track the lifetime of a Tensor."""
264
265  def __init__(self, name, object_id, timestamp, pid, allocator, num_bytes):
266    """Creates an object to track tensor references.
267
268    This class is not thread safe and is intended only for internal use by
269    the 'Timeline' class in this file.
270
271    Args:
272      name:  The name of the Tensor as a string.
273      object_id:  Chrome Trace object identifier assigned for this Tensor.
274      timestamp:  The creation timestamp of this event as a long integer.
275      pid:  Process identifier of the associated device, as an integer.
276      allocator:  Name of the allocator used to create the Tensor.
277      num_bytes:  Number of bytes allocated (long integer).
278
279    Returns:
280      A 'TensorTracker' object.
281    """
282    self._name = name
283    self._pid = pid
284    self._object_id = object_id
285    self._create_time = timestamp
286    self._allocator = allocator
287    self._num_bytes = num_bytes
288    self._ref_times = []
289    self._unref_times = []
290
291  @property
292  def name(self):
293    """Name of this tensor."""
294    return self._name
295
296  @property
297  def pid(self):
298    """ID of the process which created this tensor (an integer)."""
299    return self._pid
300
301  @property
302  def create_time(self):
303    """Timestamp when this tensor was created (long integer)."""
304    return self._create_time
305
306  @property
307  def object_id(self):
308    """Returns the object identifier of this tensor (integer)."""
309    return self._object_id
310
311  @property
312  def num_bytes(self):
313    """Size of this tensor in bytes (long integer)."""
314    return self._num_bytes
315
316  @property
317  def allocator(self):
318    """Name of the allocator used to create this tensor (string)."""
319    return self._allocator
320
321  @property
322  def last_unref(self):
323    """Last unreference timestamp of this tensor (long integer)."""
324    return max(self._unref_times)
325
326  def add_ref(self, timestamp):
327    """Adds a reference to this tensor with the specified timestamp.
328
329    Args:
330      timestamp:  Timestamp of object reference as an integer.
331    """
332    self._ref_times.append(timestamp)
333
334  def add_unref(self, timestamp):
335    """Adds an unref to this tensor with the specified timestamp.
336
337    Args:
338      timestamp:  Timestamp of object unreference as an integer.
339    """
340    self._unref_times.append(timestamp)
341
342
class Timeline(object):
  """A class for visualizing execution timelines of TensorFlow steps."""

  def __init__(self, step_stats, graph=None):
    """Constructs a new Timeline.

    A 'Timeline' is used for visualizing the execution of a TensorFlow
    computation.  It shows the timings and concurrency of execution at
    the granularity of TensorFlow Ops.
    This class is not thread safe.

    Args:
      step_stats: The 'StepStats' proto recording execution times.
      graph: (Optional) The 'Graph' that was executed.
    """

    # The caller's proto is kept untouched; _preprocess_op_time either
    # aliases it into _step_stats or deep-copies it, depending on op_time.
    self._origin_step_stats = step_stats
    self._step_stats = None
    self._graph = graph
    self._chrome_trace = _ChromeTraceFormatter()
    self._next_pid = 0
    self._device_pids = {}  # device name -> pid for compute activity.
    self._tensor_pids = {}  # device name -> pid for tensors.
    self._tensors = {}  # tensor_name -> TensorTracker
    self._next_flow_id = 0
    self._flow_starts = {}  # tensor_name -> (timestamp, pid, tid)
    self._alloc_times = {}  # tensor_name -> ( time, allocator, size )
    self._allocator_maximums = {}  # allocator name => maximum bytes long

  def _alloc_pid(self):
    """Allocate a process Id.

    Returns:
      The next sequential fake process id (integer).
    """
    pid = self._next_pid
    self._next_pid += 1
    return pid

  def _alloc_flow_id(self):
    """Allocate a flow Id.

    Returns:
      The next sequential flow id (integer).
    """
    flow_id = self._next_flow_id
    self._next_flow_id += 1
    return flow_id

  def _parse_op_label(self, label):
    """Parses the fields in a node timeline label.

    Args:
      label: Timeline label string of the form 'name = op(arg, arg, ...)'.

    Returns:
      A (name, op, inputs) tuple where inputs is a (possibly empty) list of
      input names; ('unknown', 'unknown', []) if the label does not match
      the expected form.
    """
    # Expects labels of the form: name = op(arg, arg, ...).
    match = re.match(r'(.*) = (.*)\((.*)\)', label)
    if match is None:
      return 'unknown', 'unknown', []
    nn, op, inputs = match.groups()
    if not inputs:
      inputs = []
    else:
      inputs = inputs.split(', ')
    return nn, op, inputs

  def _parse_kernel_label(self, label, node_name):
    """Parses the fields in a node timeline label.

    Args:
      label: Kernel timeline label; the originating node name, if present,
        is embedded between '@@' and '#'.
      node_name: Fallback node name used when the label has no annotation.

    Returns:
      A (name, op) tuple; op is 'unknown' when node_name has no ':op' part.
    """
    # Expects labels of the form: retval (arg) detail @@annotation
    start = label.find('@@')
    end = label.find('#')
    if start >= 0 and end >= 0 and start + 2 < end:
      # Take the annotated node name between '@@' and '#'.
      node_name = label[start + 2:end]
    # Node names should always have the form 'name:op'.
    # Appending 'unknown' guarantees fields has at least two entries.
    fields = node_name.split(':') + ['unknown']
    name, op = fields[:2]
    return name, op

  def _assign_lanes(self):
    """Assigns non-overlapping lanes for the activities on each device."""
    for device_stats in self._step_stats.dev_stats:
      # TODO(pbar): Genuine thread IDs in NodeExecStats might be helpful.
      # lanes[i] holds the end time of the latest activity placed in lane i.
      lanes = [0]
      for ns in device_stats.node_stats:
        l = -1
        for (i, lts) in enumerate(lanes):
          # Greedily reuse the first lane that is already free when this
          # node starts.
          if ns.all_start_micros > lts:
            l = i
            lanes[l] = ns.all_start_micros + ns.all_end_rel_micros
            break
        if l < 0:
          # Every existing lane overlaps this node; open a new lane.
          l = len(lanes)
          lanes.append(ns.all_start_micros + ns.all_end_rel_micros)
        # The lane index is stored as the node's thread id for the trace.
        ns.thread_id = l

  def _emit_op(self, nodestats, pid, is_gputrace):
    """Generates a Chrome Trace event to show Op execution.

    Args:
      nodestats: The 'NodeExecStats' proto recording op execution.
      pid: The pid assigned for the device where this op ran.
      is_gputrace: If True then this op came from the GPUTracer.
    """
    node_name = nodestats.node_name
    start = nodestats.all_start_micros
    duration = nodestats.all_end_rel_micros
    tid = nodestats.thread_id
    inputs = []
    if is_gputrace:
      node_name, op = self._parse_kernel_label(nodestats.timeline_label,
                                               node_name)
    elif node_name == 'RecvTensor':
      # RPC tracing does not use the standard timeline_label format.
      op = 'RecvTensor'
    else:
      _, op, inputs = self._parse_op_label(nodestats.timeline_label)
    args = {'name': node_name, 'op': op}
    if build_info.build_info['is_rocm_build']:
      # On ROCm builds the text before '@@' in the label is the kernel name.
      args['kernel'] = nodestats.timeline_label.split('@@')[0]
    for i, iname in enumerate(inputs):
      args['input%d' % i] = iname
    self._chrome_trace.emit_region(start, duration, pid, tid, 'Op', op, args)

  def _emit_tensor_snapshot(self, tensor, timestamp, pid, tid, value):
    """Generate Chrome Trace snapshot event for a computed Tensor.

    Args:
      tensor: A 'TensorTracker' object.
      timestamp:  The timestamp of this snapshot as a long integer.
      pid: The pid assigned for showing the device where this op ran.
      tid: The tid of the thread computing the tensor snapshot.
      value: A JSON-compliant snapshot of the object.
    """
    # Strip double quotes so the proto text stays JSON-friendly in the trace.
    desc = str(value.tensor_description).replace('"', '')
    snapshot = {'tensor_description': desc}
    self._chrome_trace.emit_obj_snapshot('Tensor', tensor.name, timestamp, pid,
                                         tid, tensor.object_id, snapshot)

  def _produce_tensor(self, name, timestamp, tensors_pid, allocator, num_bytes):
    """Creates and registers a _TensorTracker for a newly produced tensor.

    Args:
      name: Tensor name as a string.
      timestamp: Creation timestamp (long integer).
      tensors_pid: Fake pid of the device's tensor track.
      allocator: Name of the allocator used to create the tensor.
      num_bytes: Number of bytes allocated.

    Returns:
      The newly created _TensorTracker.
    """
    # Object ids are assigned sequentially in creation order.
    object_id = len(self._tensors)
    tensor = _TensorTracker(name, object_id, timestamp, tensors_pid, allocator,
                            num_bytes)
    self._tensors[name] = tensor
    return tensor

  def _is_gputrace_device(self, device_name):
    """Returns true if this device is part of the GPUTracer logging."""
    return '/stream:' in device_name or '/memcpy' in device_name

  def _allocate_pids(self):
    """Allocate fake process ids for each device in the StepStats."""
    # A dedicated pid groups all allocator counter series together.
    self._allocators_pid = self._alloc_pid()
    self._chrome_trace.emit_pid('Allocators', self._allocators_pid)

    # Add processes in the Chrome trace to show compute and data activity.
    # Each device gets two pids: one for op execution, one for tensors.
    for dev_stats in self._step_stats.dev_stats:
      device_pid = self._alloc_pid()
      self._device_pids[dev_stats.device] = device_pid
      tensors_pid = self._alloc_pid()
      self._tensor_pids[dev_stats.device] = tensors_pid
      self._chrome_trace.emit_pid(dev_stats.device + ' Compute', device_pid)
      self._chrome_trace.emit_pid(dev_stats.device + ' Tensors', tensors_pid)

  def _analyze_tensors(self, show_memory):
    """Analyze tensor references to track dataflow.

    Args:
      show_memory: If True, also emit object-create and snapshot events so
        tensor sizes and lifetimes appear in the trace.
    """
    for dev_stats in self._step_stats.dev_stats:
      device_pid = self._device_pids[dev_stats.device]
      tensors_pid = self._tensor_pids[dev_stats.device]
      for node_stats in dev_stats.node_stats:
        tid = node_stats.thread_id
        node_name = node_stats.node_name
        start_time = node_stats.all_start_micros
        end_time = node_stats.all_start_micros + node_stats.all_end_rel_micros
        for index, output in enumerate(node_stats.output):
          # Output slot 0 keeps the bare node name; other slots get a
          # ':<index>' suffix.
          if index:
            output_name = '%s:%d' % (node_name, index)
          else:
            output_name = node_name

          allocation = output.tensor_description.allocation_description
          num_bytes = allocation.requested_bytes
          allocator_name = allocation.allocator_name
          tensor = self._produce_tensor(output_name, start_time, tensors_pid,
                                        allocator_name, num_bytes)
          tensor.add_ref(start_time)
          tensor.add_unref(end_time)
          # Dataflow arrows for this tensor originate at the producing op.
          self._flow_starts[output_name] = (end_time, device_pid, tid)

          if show_memory:
            self._chrome_trace.emit_obj_create('Tensor', output_name,
                                               start_time, tensors_pid, tid,
                                               tensor.object_id)
            # Snapshot just before the op ends so it sorts inside the region.
            self._emit_tensor_snapshot(tensor, end_time - 1, tensors_pid, tid,
                                       output)

  def _show_compute(self, show_dataflow):
    """Visualize the computation activity.

    Args:
      show_dataflow: If True, emit flow events connecting each consumed
        tensor back to its producer.
    """
    for dev_stats in self._step_stats.dev_stats:
      device_name = dev_stats.device
      device_pid = self._device_pids[device_name]
      is_gputrace = self._is_gputrace_device(device_name)

      for node_stats in dev_stats.node_stats:
        tid = node_stats.thread_id
        start_time = node_stats.all_start_micros
        end_time = node_stats.all_start_micros + node_stats.all_end_rel_micros
        self._emit_op(node_stats, device_pid, is_gputrace)

        # GPUTracer / RPC entries have no parseable input list; skip the
        # dataflow bookkeeping below.
        if is_gputrace or node_stats.node_name == 'RecvTensor':
          continue

        _, _, inputs = self._parse_op_label(node_stats.timeline_label)
        for input_name in inputs:
          if input_name not in self._tensors:
            # This can happen when partitioning has inserted a Send/Recv.
            # We remove the numeric suffix so that the dataflow appears to
            # come from the original node.  Ideally, the StepStats would
            # contain logging for the Send and Recv nodes.
            index = input_name.rfind('/_')
            if index > 0:
              input_name = input_name[:index]

          if input_name in self._tensors:
            tensor = self._tensors[input_name]
            tensor.add_ref(start_time)
            tensor.add_unref(end_time - 1)

            if show_dataflow:
              # We use a different flow ID for every graph edge.
              create_time, create_pid, create_tid = self._flow_starts[
                  input_name]
              # Don't add flows when producer and consumer ops are on the same
              # pid/tid since the horizontal arrows clutter the visualization.
              if create_pid != device_pid or create_tid != tid:
                flow_id = self._alloc_flow_id()
                self._chrome_trace.emit_flow_start(input_name, create_time,
                                                   create_pid, create_tid,
                                                   flow_id)
                self._chrome_trace.emit_flow_end(input_name, start_time,
                                                 device_pid, tid, flow_id)
          else:
            logging.vlog(1, 'Can\'t find tensor %s - removed by CSE?',
                         input_name)

  def _show_memory_counters(self):
    """Produce a counter series for each memory allocator."""
    # Iterate over all tensor trackers to build a list of allocations and
    # frees for each allocator. Then sort the lists and emit a cumulative
    # counter series for each allocator.
    allocations = {}
    for name in self._tensors:
      tensor = self._tensors[name]
      self._chrome_trace.emit_obj_delete('Tensor', name, tensor.last_unref,
                                         tensor.pid, 0, tensor.object_id)
      allocator = tensor.allocator
      if allocator not in allocations:
        allocations[allocator] = []
      num_bytes = tensor.num_bytes
      # One (+bytes) entry at creation, one (-bytes) entry at last unref.
      allocations[allocator].append((tensor.create_time, num_bytes, name))
      allocations[allocator].append((tensor.last_unref, -num_bytes, name))

    alloc_maxes = {}

    # Generate a counter series showing total allocations for each allocator.
    for allocator in allocations:
      alloc_list = allocations[allocator]
      alloc_list.sort()
      total_bytes = 0
      alloc_tensor_set = set()
      alloc_maxes[allocator] = AllocationMaximum(
          timestamp=0, num_bytes=0, tensors=set())
      for time, num_bytes, name in sorted(
          alloc_list, key=lambda allocation: allocation[0]):
        total_bytes += num_bytes
        if num_bytes < 0:
          alloc_tensor_set.discard(name)
        else:
          alloc_tensor_set.add(name)

        # Track the high-water mark along with the tensors live at that time.
        if total_bytes > alloc_maxes[allocator].num_bytes:
          alloc_maxes[allocator] = AllocationMaximum(
              timestamp=time,
              num_bytes=total_bytes,
              tensors=copy.deepcopy(alloc_tensor_set))

        self._chrome_trace.emit_counter('Memory', allocator,
                                        self._allocators_pid, time, allocator,
                                        total_bytes)
    self._allocator_maximums = alloc_maxes

  def _preprocess_op_time(self, op_time):
    """Update the start and end time of ops in step stats.

    Args:
    op_time: How the execution time of op is shown in timeline. Possible values
      are "schedule", "gpu" and "all". "schedule" will show op from the time it
      is scheduled to the end of the scheduling. Notice by the end of its
      scheduling its async kernels may not start yet. It is shown using the
      default value from step_stats. "gpu" will show op with the execution time
      of its kernels on GPU. "all" will show op from the start of its scheduling
      to the end of its last kernel.
    """
    if op_time == 'schedule':
      # No rewriting needed; use the original stats directly.
      self._step_stats = self._origin_step_stats
      return
    # Deep copy so the timestamp rewrites below never mutate the caller's
    # original StepStats proto.
    self._step_stats = copy.deepcopy(self._origin_step_stats)
    # Separate job task and gpu tracer stream
    stream_all_stats = []
    job_stats = []
    for stats in self._step_stats.dev_stats:
      if '/stream:all' in stats.device:
        stream_all_stats.append(stats)
      elif '/job' in stats.device:
        job_stats.append(stats)

    # Record the start time of the first kernel and the end time of
    # the last gpu kernel for all ops.
    op_gpu_start = {}
    op_gpu_end = {}
    for stats in stream_all_stats:
      for kernel in stats.node_stats:
        name, _ = self._parse_kernel_label(kernel.timeline_label,
                                           kernel.node_name)
        start = kernel.all_start_micros
        end = kernel.all_start_micros + kernel.all_end_rel_micros
        if name in op_gpu_start:
          # Widen the [start, end] window to cover every kernel of this op.
          op_gpu_start[name] = min(op_gpu_start[name], start)
          op_gpu_end[name] = max(op_gpu_end[name], end)
        else:
          op_gpu_start[name] = start
          op_gpu_end[name] = end

    # Update the start and end time of each op according to the op_time
    for stats in job_stats:
      for op in stats.node_stats:
        if op.node_name in op_gpu_start:
          # 'all': keep the scheduling start, extend the end to the last
          # kernel.  'gpu': also move the start to the first kernel.
          end = max(op_gpu_end[op.node_name],
                    op.all_start_micros + op.all_end_rel_micros)
          if op_time == 'gpu':
            op.all_start_micros = op_gpu_start[op.node_name]
          op.all_end_rel_micros = end - op.all_start_micros

  def analyze_step_stats(self,
                         show_dataflow=True,
                         show_memory=True,
                         op_time='schedule'):
    """Analyze the step stats and format it into Chrome Trace Format.

    Args:
      show_dataflow: (Optional.) If True, add flow events to the trace
        connecting producers and consumers of tensors.
      show_memory: (Optional.) If True, add object snapshot events to the trace
        showing the sizes and lifetimes of tensors.
      op_time: (Optional.) How the execution time of op is shown in timeline.
        Possible values are "schedule", "gpu" and "all". "schedule" will show op
        from the time it is scheduled to the end of the scheduling. Notice by
        the end of its scheduling its async kernels may not start yet. It is
        shown using the default value from step_stats. "gpu" will show op with
        the execution time of its kernels on GPU. "all" will show op from the
        start of its scheduling to the end of its last kernel.

    Returns:
      A 'StepStatsAnalysis' object.
    """
    self._preprocess_op_time(op_time)
    self._allocate_pids()
    self._assign_lanes()
    self._analyze_tensors(show_memory)
    self._show_compute(show_dataflow)
    if show_memory:
      self._show_memory_counters()
    return StepStatsAnalysis(
        chrome_trace=self._chrome_trace,
        allocator_maximums=self._allocator_maximums)

  def generate_chrome_trace_format(self,
                                   show_dataflow=True,
                                   show_memory=False,
                                   op_time='schedule'):
    """Produces a trace in Chrome Trace Format.

    Args:
      show_dataflow: (Optional.) If True, add flow events to the trace
        connecting producers and consumers of tensors.
      show_memory: (Optional.) If True, add object snapshot events to the trace
        showing the sizes and lifetimes of tensors.
      op_time: (Optional.) How the execution time of op is shown in timeline.
        Possible values are "schedule", "gpu" and "all".
        "schedule" will show op from the time it is scheduled to the end of
          the scheduling.
          Notice by the end of its scheduling its async kernels may not start
          yet. It is shown using the default value from step_stats.
        "gpu" will show op with the execution time of its kernels on GPU.
        "all" will show op from the start of its scheduling to the end of
          its last kernel.

    Returns:
      A JSON formatted string in Chrome Trace format.
    """
    step_stats_analysis = self.analyze_step_stats(
        show_dataflow=show_dataflow, show_memory=show_memory, op_time=op_time)

    return step_stats_analysis.chrome_trace.format_to_string(pretty=True)
734