# xref: /aosp_15_r20/external/tensorflow/tensorflow/python/profiler/tfprof_logger.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Logging tensorflow::tfprof::OpLogProto.

OpLogProto is used to add extra model information for offline analysis.
"""
19import os
20import sys
21
22import six
23from tensorflow.core.profiler import tfprof_log_pb2
24from tensorflow.python.eager import context
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import tensor_shape
27from tensorflow.python.platform import gfile
28from tensorflow.python.profiler.internal import flops_registry  # pylint: disable=unused-import
29from tensorflow.python.util.tf_export import tf_export
30
# Op type tag appended to the log entries of ops backing trainable variables.
TRAINABLE_VARIABLES = '_trainable_variables'
# Statistic name used when looking up registered per-op FLOPs statistics.
REGISTERED_FLOP_STATS = 'flops'
33
34
35def _fill_missing_graph_shape(graph, run_meta):
36  """Fill Tensor shapes in 'graph' with run time shape from 'run_meta'."""
37  for dev_stat in run_meta.step_stats.dev_stats:
38    for node_stat in dev_stat.node_stats:
39      if not node_stat.output:
40        continue
41      try:
42        op = graph.get_operation_by_name(node_stat.node_name)
43      except KeyError as e:
44        # Graph doesn't contains the node_stat, usually RecvTensor.
45        continue
46      if len(node_stat.output) != len(op.outputs):
47        # For example, conditional op has only 1 output at run time.
48        continue
49      for (i, node_stat_out) in enumerate(node_stat.output):
50        if op.outputs[i].get_shape().is_fully_defined():
51          continue
52        node_stat_dims = node_stat_out.tensor_description.shape.dim
53        node_stat_shape = tensor_shape.TensorShape(
54            [d.size for d in node_stat_dims])
55        try:
56          op.outputs[i].set_shape(op.outputs[i].get_shape().merge_with(
57              node_stat_shape))
58        except ValueError as e:
59          sys.stderr.write('Node %s incompatible shapes: %s.\n' %
60                           (node_stat.node_name, e))
61  return graph
62
63
64def _str_id(s, str_to_id):
65  """Maps string to id."""
66  num = str_to_id.get(s, None)
67  if num is None:
68    num = len(str_to_id)
69    str_to_id[s] = num
70  return num
71
72
73def _get_logged_ops(graph, run_meta=None, add_trace=True,
74                    add_trainable_var=True):
75  """Extract trainable model parameters and FLOPs for ops from a Graph.
76
77  Args:
78    graph: tf.Graph.
79    run_meta: RunMetadata proto used to complete shape information.
80    add_trace: Whether to add op trace information.
81    add_trainable_var: Whether to assign tf.compat.v1.trainable_variables() op
82      type '_trainable_variables'.
83  Returns:
84    logged_ops: dict mapping from op_name to OpLogEntry.
85    string_to_id: dict mapping from string to id.
86  """
87  if run_meta:
88    graph = _fill_missing_graph_shape(graph, run_meta)
89
90  op_missing_shape = 0
91  logged_ops = {}
92  string_to_id = {}
93  string_to_id['none'] = len(string_to_id)
94  # TODO(xpan): Work with Profiler more efficiently.
95  for op in graph.get_operations():
96    try:
97      stats = ops.get_stats_for_node_def(
98          graph, op.node_def, REGISTERED_FLOP_STATS)
99    except ValueError:
100      # Catch Exception When shape is incomplete. Skip it.
101      op_missing_shape += 1
102      stats = None
103
104    entry = tfprof_log_pb2.OpLogEntry()
105    entry.name = op.name
106    add_entry = False
107    if stats and stats.value:
108      entry.float_ops = int(stats.value)
109      add_entry = True
110
111    if add_trace:
112      if op.traceback:
113        for filename, lineno, funcname, line in op.traceback:
114          trace = entry.code_def.traces.add()
115          trace.file_id = _str_id(filename, string_to_id) if filename else 0
116          trace.lineno = lineno if lineno else -1
117          trace.function_id = _str_id(funcname, string_to_id) if funcname else 0
118          trace.line_id = _str_id(line, string_to_id) if line else 0
119          # TODO(slebedev): remove this unused field from the proto.
120          trace.func_start_line = -1
121      add_entry = True
122
123    if add_entry:
124      logged_ops[entry.name] = entry
125
126  if add_trainable_var:
127    for v in graph.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES):
128      if v.op.name not in logged_ops:
129        entry = tfprof_log_pb2.OpLogEntry()
130        entry.name = v.op.name
131        entry.types.append(TRAINABLE_VARIABLES)
132        logged_ops[entry.name] = entry
133      else:
134        logged_ops[v.op.name].types.append(TRAINABLE_VARIABLES)
135
136  if op_missing_shape > 0 and not run_meta:
137    sys.stderr.write('%d ops no flops stats due to incomplete shapes.\n' %
138                     op_missing_shape)
139  return logged_ops, string_to_id
140
141
def merge_default_with_oplog(graph, op_log=None, run_meta=None,
                             add_trace=True, add_trainable_var=True):
  """Merge the tfprof default extra info with caller's op_log.

  Default entries are computed from `graph` by `_get_logged_ops`. When the
  caller supplies `op_log`, its entries win: for an op present in both, the
  default types are appended, and the default float_ops / code_def are used
  only if the caller's entry has none. The entry objects taken from
  `op_log.log_entries` are updated directly during the merge.

  Args:
    graph: tf.Graph. If None and eager execution is not enabled, use
        default graph.
    op_log: OpLogProto proto.
    run_meta: RunMetadata proto used to complete shape information.
    add_trace: Whether to add op trace information.
    add_trainable_var: Whether to assign tf.compat.v1.trainable_variables() op
      type '_trainable_variables'.
  Returns:
    tmp_op_log: Merged OpLogProto proto. Empty when no graph is available.
  """
  if not graph and not context.executing_eagerly():
    graph = ops.get_default_graph()

  tmp_op_log = tfprof_log_pb2.OpLogProto()
  if not graph:
    return tmp_op_log

  logged_ops, string_to_id = _get_logged_ops(
      graph, run_meta, add_trace=add_trace, add_trainable_var=add_trainable_var)

  if not op_log:
    tmp_op_log.log_entries.extend(logged_ops.values())
  else:
    # Index the caller's entries by op name; later duplicates replace earlier.
    all_ops = {entry.name: entry for entry in op_log.log_entries}
    for op_name, entry in logged_ops.items():
      existing = all_ops.get(op_name)
      if existing is None:
        all_ops[op_name] = entry
        continue
      existing.types.extend(entry.types)
      # Defaults only fill in what the caller left unset.
      if entry.float_ops > 0 and existing.float_ops == 0:
        existing.float_ops = entry.float_ops
      if entry.code_def.traces and not existing.code_def.traces:
        existing.code_def.MergeFrom(entry.code_def)
    tmp_op_log.log_entries.extend(all_ops.values())

  # Publish the reverse table so trace ids can be resolved back to strings.
  for s, i in string_to_id.items():
    tmp_op_log.id_to_string[i] = s
  return tmp_op_log
187
188
@tf_export(v1=['profiler.write_op_log'])
def write_op_log(graph, log_dir, op_log=None, run_meta=None, add_trace=True):
  """Log provided 'op_log', and add additional model information below.

  The API also assigns ops in tf.compat.v1.trainable_variables() an op type
  called '_trainable_variables'.
  The API also logs 'flops' statistics for ops with op.RegisterStatistics()
  defined. flops calculation depends on Tensor shapes defined in 'graph',
  which might not be complete. 'run_meta', if provided, completes the shape
  information with best effort.

  The merged log is serialized and written to the file 'tfprof_log' inside
  `log_dir`.

  Args:
    graph: tf.Graph. If None and eager execution is not enabled, use
        default graph.
    log_dir: directory to write the log file.
    op_log: (Optional) OpLogProto proto to be written. If not provided, a new
        one is created.
    run_meta: (Optional) RunMetadata proto that helps flops computation using
        run time shape information.
    add_trace: Whether to add python code trace information.
        Used to support "code" view.
  """
  if not graph and not context.executing_eagerly():
    graph = ops.get_default_graph()
  op_log = merge_default_with_oplog(graph, op_log, run_meta, add_trace)

  with gfile.Open(os.path.join(log_dir, 'tfprof_log'), 'w') as log:
    log.write(op_log.SerializeToString())
217