# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model Analyzer.

Analyze model, including shape, params, time, memory, structure, etc.
"""
import sys

import six

from google.protobuf import message
from tensorflow.core.profiler import tfprof_options_pb2
from tensorflow.core.profiler import tfprof_output_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.profiler import option_builder
from tensorflow.python.profiler import tfprof_logger
from tensorflow.python.util import _pywrap_tfprof as print_mdl
from tensorflow.python.util.tf_export import tf_export

_DEFAULT_PROFILE_OPTIONS = 0
_DEFAULT_ADVISE_OPTIONS = 0

# The following options are for 'advise' cmd.
# Show all advice.
ALL_ADVICE = {
    'ExpensiveOperationChecker': {},
    'AcceleratorUtilizationChecker': {},
    'JobChecker': {},  # Only available internally.
    'OperationChecker': {},
}


def _graph_string(graph):
  """Helper to serialize a graph to string."""
  if graph:
    return graph.as_graph_def(add_shapes=True).SerializeToString()
  else:
    return b''


def _build_options(options):
  """Build tfprof.OptionsProto.

  Args:
    options: A dictionary of options.

  Returns:
    tfprof.OptionsProto.
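
  Example (an illustrative sketch; the keys shown are a subset of those
  documented in core/profiler/g3doc/options.md):

    options = {
        'max_depth': 10,
        'order_by': 'name',
        'select': ['params'],
        'output': 'stdout',
    }
    opts = _build_options(options)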
63  """
64  opts = tfprof_options_pb2.OptionsProto()
65  opts.max_depth = options.get('max_depth', 10)
66  opts.min_bytes = options.get('min_bytes', 0)
67  opts.min_peak_bytes = options.get('min_peak_bytes', 0)
68  opts.min_residual_bytes = options.get('min_residual_bytes', 0)
69  opts.min_output_bytes = options.get('min_output_bytes', 0)
70  opts.min_micros = options.get('min_micros', 0)
71  opts.min_accelerator_micros = options.get('min_accelerator_micros', 0)
72  opts.min_cpu_micros = options.get('min_cpu_micros', 0)
73  opts.min_params = options.get('min_params', 0)
74  opts.min_float_ops = options.get('min_float_ops', 0)
75  opts.min_occurrence = options.get('min_occurrence', 0)
76
77  opts.step = options.get('step', -1)
78
79  opts.order_by = options.get('order_by', 'name')
80
81  for p in options.get('account_type_regexes', []):
82    opts.account_type_regexes.append(p)
83  for p in options.get('start_name_regexes', []):
84    opts.start_name_regexes.append(p)
85  for p in options.get('trim_name_regexes', []):
86    opts.trim_name_regexes.append(p)
87  for p in options.get('show_name_regexes', []):
88    opts.show_name_regexes.append(p)
89  for p in options.get('hide_name_regexes', []):
90    opts.hide_name_regexes.append(p)
91  opts.account_displayed_op_only = options.get('account_displayed_op_only',
92                                               False)
93
94  for p in options.get('select', []):
95    opts.select.append(p)
96
97  opts.output = options.get('output', 'stdout')
98  opts.dump_to_file = options.get('dump_to_file', '')
99
100  return opts
101

def _build_advisor_options(options):
  """Build tfprof.AdvisorOptionsProto.

  Args:
    options: A dictionary of options. See ALL_ADVICE example.

  Returns:
    tfprof.AdvisorOptionsProto.
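
  Example (a minimal sketch; 'ExpensiveOperationChecker' is one of the
  checkers listed in ALL_ADVICE):

    opts = _build_advisor_options({'ExpensiveOperationChecker': {}})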
111  """
112  opts = tfprof_options_pb2.AdvisorOptionsProto()
113  if options is None:
114    return opts
115  for checker, checker_opts in six.iteritems(options):
116    checker_ops_pb = tfprof_options_pb2.AdvisorOptionsProto.CheckerOption()
117    for k, v in six.iteritems(checker_opts):
      # Copy each key/value into the checker's settings (assumes CheckerOption
      # carries them in its `options` map field).
      checker_ops_pb.options[k] = v
    opts.checkers[checker].MergeFrom(checker_ops_pb)
  return opts


@tf_export(v1=['profiler.Profiler'])
class Profiler(object):
  """TensorFlow multi-step profiler.

  Typical use case:

  ```python
    # Currently only one profiler can be created per process.
    profiler = Profiler(sess.graph)

    for i in range(total_steps):
      if i % 10000 == 0:
        run_meta = tf.compat.v1.RunMetadata()
        _ = sess.run(...,
                     options=tf.compat.v1.RunOptions(
                         trace_level=tf.compat.v1.RunOptions.FULL_TRACE),
                     run_metadata=run_meta)
        profiler.add_step(i, run_meta)

        # Profile the parameters of your model.
        profiler.profile_name_scope(options=(option_builder.ProfileOptionBuilder
            .trainable_variables_parameter()))

        # Or profile the timing of your model operations.
        opts = option_builder.ProfileOptionBuilder.time_and_memory()
        profiler.profile_operations(options=opts)

        # Or you can generate a timeline:
        opts = (option_builder.ProfileOptionBuilder(
                option_builder.ProfileOptionBuilder.time_and_memory())
                .with_step(i)
                .with_timeline_output(filename).build())
        profiler.profile_graph(options=opts)
      else:
        _ = sess.run(...)
    # Auto-detect problems and generate advice.
    profiler.advise()
  ```
  """

  def __init__(self, graph=None, op_log=None):
    """Constructor.

    Args:
      graph: tf.Graph. If None and eager execution is not enabled, use default
        graph.
      op_log: optional tensorflow::tfprof::OpLogProto proto. Used to define
        extra op types.
    """
    if not graph and not context.executing_eagerly():
      graph = ops.get_default_graph()
    self._coverage = 0.0
    self._graph = graph
    # pylint: disable=protected-access
    op_log = tfprof_logger.merge_default_with_oplog(self._graph, op_log=op_log)
    # pylint: enable=protected-access
    print_mdl.NewProfiler(
        _graph_string(self._graph), op_log.SerializeToString())

  def __del__(self):
    print_mdl.DeleteProfiler()

  def add_step(self, step, run_meta):
    """Add statistics of a step.

    Args:
      step: int, An id used to group one or more different `run_meta` together.
        When profiling with the profile_xxx APIs, users can use the `step` id in
        the `options` to profile these `run_meta` together.
      run_meta: RunMetadata proto that contains statistics of a session run.
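
    For example (a sketch; `i` and `run_meta` come from a traced session run
    as in the class docstring), stats recorded under step `i` can later be
    selected through the `step` option:

      profiler.add_step(i, run_meta)
      opts = option_builder.ProfileOptionBuilder.time_and_memory()
      opts['step'] = i
      profiler.profile_operations(options=opts)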
193    """
194    # pylint: disable=protected-access
195    op_log = tfprof_logger.merge_default_with_oplog(
196        self._graph, run_meta=run_meta)
197    # pylint: enable=protected-access
198    # TODO(xpan): P1: Better to find the current graph.
199    self._coverage = print_mdl.AddStep(step, _graph_string(self._graph),
200                                       run_meta.SerializeToString(),
201                                       op_log.SerializeToString())
202
  def profile_python(self, options):
    """Profile the statistics of the Python code.

      By default, it shows the call stack from the root. To avoid
      redundant output, you may use options to filter, e.g.:
        options['show_name_regexes'] = ['.*my_code.py.*']

    Args:
      options: A dict of options. See core/profiler/g3doc/options.md.

    Returns:
      a MultiGraphNodeProto that records the results.
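
    For example (a sketch; `my_code.py` stands in for your own source file):

      opts = option_builder.ProfileOptionBuilder.time_and_memory()
      opts['show_name_regexes'] = ['.*my_code.py.*']
      tfprof_node = profiler.profile_python(options=opts)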
215    """
216    opts = _build_options(options)
217    tfprof_node = tfprof_output_pb2.MultiGraphNodeProto()
218    try:
219      tfprof_node.ParseFromString(
220          print_mdl.Profile('code'.encode('utf-8'), opts.SerializeToString()))
221    except message.DecodeError as e:
222      sys.stderr.write('Cannot parse returned proto: %s.\n' % e)
223    return tfprof_node
224
  def profile_operations(self, options):
    """Profile the statistics of the Operation types (e.g. MatMul, Conv2D).

    Args:
      options: A dict of options. See core/profiler/g3doc/options.md.

    Returns:
      a MultiGraphNodeProto that records the results.
    """
    opts = _build_options(options)
    tfprof_node = tfprof_output_pb2.MultiGraphNodeProto()
    try:
      tfprof_node.ParseFromString(
          print_mdl.Profile('op'.encode('utf-8'), opts.SerializeToString()))
    except message.DecodeError as e:
      sys.stderr.write('Cannot parse returned proto: %s.\n' % e)
    return tfprof_node

  def profile_name_scope(self, options):
    """Profile the statistics of graph nodes, organized by name scope.

    Args:
      options: A dict of options. See core/profiler/g3doc/options.md.

    Returns:
      a GraphNodeProto that records the results.
    """
    opts = _build_options(options)
    tfprof_node = tfprof_output_pb2.GraphNodeProto()
    try:
      tfprof_node.ParseFromString(
          print_mdl.Profile('scope'.encode('utf-8'), opts.SerializeToString()))
    except message.DecodeError as e:
      sys.stderr.write('Cannot parse returned proto: %s.\n' % e)
    return tfprof_node

  def profile_graph(self, options):
    """Profile the statistics of graph nodes, organized by dataflow graph.

    Args:
      options: A dict of options. See core/profiler/g3doc/options.md.

    Returns:
      a GraphNodeProto that records the results.
    """
    opts = _build_options(options)
    tfprof_node = tfprof_output_pb2.GraphNodeProto()
    try:
      tfprof_node.ParseFromString(
          print_mdl.Profile('graph'.encode('utf-8'), opts.SerializeToString()))
    except message.DecodeError as e:
      sys.stderr.write('Cannot parse returned proto: %s.\n' % e)
    return tfprof_node

  def advise(self, options):
    """Automatically detect problems and generate reports.

    Args:
      options: A dict of options. See ALL_ADVICE example above.

    Returns:
      An AdviceProto that contains the reports from all checkers.
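
    For example (a sketch), to run all available checkers with ALL_ADVICE
    defined at the top of this module:

      advice = profiler.advise(options=ALL_ADVICE)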
289    """
290    advise_pb = tfprof_output_pb2.AdviceProto()
291    opts = _build_advisor_options(options)
292    advise_pb.ParseFromString(
293        print_mdl.Profile('advise'.encode('utf-8'), opts.SerializeToString()))
294    return advise_pb
295
  def serialize_to_string(self):
    """Serialize the ProfileProto to a binary string.

      Users can write it to a file for offline analysis by the tfprof
      command line or graphical interface.

    Returns:
      ProfileProto binary string.
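
    For example (a sketch; '/tmp/profile.bin' is an arbitrary path):

      with open('/tmp/profile.bin', 'wb') as f:
        f.write(profiler.serialize_to_string())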
304    """
305    return print_mdl.SerializeToString()
306
307  def _write_profile(self, filename):
308    """Writes the profile to a file."""
309    print_mdl.WriteProfile(filename)
310
311
@tf_export(v1=['profiler.profile'])
def profile(graph=None,
            run_meta=None,
            op_log=None,
            cmd='scope',
            options=_DEFAULT_PROFILE_OPTIONS):
  """Profile model.

    Tutorials and examples can be found in:
    https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/profiler/g3doc/python_api.md

  Args:
    graph: tf.Graph. If None and eager execution is not enabled, use default
      graph.
    run_meta: optional tensorflow.RunMetadata proto. It is necessary for
      profiling run time information, such as time and memory.
    op_log: tensorflow.tfprof.OpLogProto proto. Users can assign "types" to
      graph nodes with op_log. "types" allow users to flexibly group and
      account profiles using options['account_type_regexes'].
    cmd: string. Either 'op', 'scope', 'graph' or 'code'. The 'op' view
      organizes the profile by operation type (e.g. MatMul). The 'scope' view
      organizes the profile by graph node name scope. The 'graph' view
      organizes the profile by graph node inputs/outputs. The 'code' view
      organizes the profile by Python call stack.
    options: A dict of options. See core/profiler/g3doc/options.md.

  Returns:
    If cmd is 'scope' or 'graph', returns a GraphNodeProto.
    If cmd is 'op' or 'code', returns a MultiGraphNodeProto.
    Side effect: writes to stdout/file/timeline.json depending on
      options['output'].
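
  For example (a sketch; assumes the default graph contains trainable
  variables):

    param_stats = tf.compat.v1.profiler.profile(
        tf.compat.v1.get_default_graph(),
        options=(option_builder.ProfileOptionBuilder
                 .trainable_variables_parameter()))
    print('total params: %d' % param_stats.total_parameters)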
342  """
343  if not graph and not context.executing_eagerly():
344    graph = ops.get_default_graph()
345
346  if options == _DEFAULT_PROFILE_OPTIONS:
347    options = (
348        option_builder.ProfileOptionBuilder.trainable_variables_parameter())
349  # pylint: disable=protected-access
350  op_log = tfprof_logger.merge_default_with_oplog(
351      graph, op_log, run_meta, add_trace=cmd == 'code')
352  # pylint: enable=protected-access
353
354  opts = _build_options(options)
355
356  run_meta_str = run_meta.SerializeToString() if run_meta else b''
357
358  graph_str = _graph_string(graph)
359
360  if cmd == 'code' or cmd == 'op':
361    tfprof_node = tfprof_output_pb2.MultiGraphNodeProto()
362    ret = print_mdl.PrintModelAnalysis(graph_str, run_meta_str,
363                                       op_log.SerializeToString(),
364                                       cmd.encode('utf-8'),
365                                       opts.SerializeToString())
366    try:
367      tfprof_node.ParseFromString(ret)
368    except message.DecodeError as e:
369      sys.stderr.write('Cannot parse returned proto: %s.\n' % e)
370
371  elif cmd == 'graph' or cmd == 'scope':
372    tfprof_node = tfprof_output_pb2.GraphNodeProto()
373    ret = print_mdl.PrintModelAnalysis(graph_str, run_meta_str,
374                                       op_log.SerializeToString(),
375                                       cmd.encode('utf-8'),
376                                       opts.SerializeToString())
377    try:
378      tfprof_node.ParseFromString(ret)
379    except message.DecodeError as e:
380      sys.stderr.write('Cannot parse returned proto: %s.\n' % e)
381  else:
382    raise errors.InvalidArgumentError(None, None, 'unknown cmd: %s\n' % cmd)
383
384  return tfprof_node
385
386
@tf_export(v1=['profiler.advise'])
def advise(graph=None, run_meta=None, options=_DEFAULT_ADVISE_OPTIONS):
  """Auto profile and advise.

    Builds profiles and automatically checks various aspects of them for
    anomalies. For more details:
    https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core/profiler/README.md

  Args:
    graph: tf.Graph. If None and eager execution is not enabled, use default
      graph.
    run_meta: optional tensorflow.RunMetadata proto. It is necessary for
      profiling run time information, such as time and memory.
    options: see ALL_ADVICE example above. Default checks everything.

  Returns:
    An AdviceProto proto.
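
  For example (a sketch; `sess` and `fetches` are assumed to come from the
  surrounding training code):

    run_meta = tf.compat.v1.RunMetadata()
    sess.run(fetches,
             options=tf.compat.v1.RunOptions(
                 trace_level=tf.compat.v1.RunOptions.FULL_TRACE),
             run_metadata=run_meta)
    tf.compat.v1.profiler.advise(sess.graph, run_meta=run_meta)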
404  """
405  if not graph and not context.executing_eagerly():
406    graph = ops.get_default_graph()
407
408  if options == _DEFAULT_ADVISE_OPTIONS:
409    options = ALL_ADVICE.copy()
410
411  # pylint: disable=protected-access
412  op_log = tfprof_logger.merge_default_with_oplog(
413      graph, None, run_meta, add_trace=True)
414  # pylint: enable=protected-access
415
416  run_meta_str = run_meta.SerializeToString() if run_meta else b''
417
418  opts = _build_advisor_options(options)
419  ret = tfprof_output_pb2.AdviceProto()
420  ret.ParseFromString(
421      print_mdl.PrintModelAnalysis(
422          _graph_string(graph), run_meta_str, op_log.SerializeToString(),
423          'advise'.encode('utf-8'), opts.SerializeToString()))
424  return ret
425