# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Logging tensorflow::tfprof::OpLogProto.

OpLogProto is used to add extra model information for offline analysis.
"""
import os
import sys

import six
from tensorflow.core.profiler import tfprof_log_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import gfile
from tensorflow.python.profiler.internal import flops_registry  # pylint: disable=unused-import
from tensorflow.python.util.tf_export import tf_export

TRAINABLE_VARIABLES = '_trainable_variables'
REGISTERED_FLOP_STATS = 'flops'


def _fill_missing_graph_shape(graph, run_meta):
  """Fills Tensor shapes in 'graph' with run-time shapes from 'run_meta'."""
  for dev_stat in run_meta.step_stats.dev_stats:
    for node_stat in dev_stat.node_stats:
      if not node_stat.output:
        continue
      try:
        op = graph.get_operation_by_name(node_stat.node_name)
      except KeyError:
        # The graph doesn't contain the node_stat, usually RecvTensor.
        continue
      if len(node_stat.output) != len(op.outputs):
        # For example, a conditional op has only 1 output at run time.
        continue
      for (i, node_stat_out) in enumerate(node_stat.output):
        if op.outputs[i].get_shape().is_fully_defined():
          continue
        node_stat_dims = node_stat_out.tensor_description.shape.dim
        node_stat_shape = tensor_shape.TensorShape(
            [d.size for d in node_stat_dims])
        try:
          op.outputs[i].set_shape(op.outputs[i].get_shape().merge_with(
              node_stat_shape))
        except ValueError as e:
          sys.stderr.write('Node %s has incompatible shapes: %s.\n' %
                           (node_stat.node_name, e))
  return graph


def _str_id(s, str_to_id):
  """Maps a string to an integer id, assigning a new one if needed."""
  num = str_to_id.get(s, None)
  if num is None:
    num = len(str_to_id)
    str_to_id[s] = num
  return num


def _get_logged_ops(graph, run_meta=None, add_trace=True,
                    add_trainable_var=True):
  """Extracts trainable model parameters and FLOPs for ops from a Graph.

  Args:
    graph: tf.Graph.
    run_meta: RunMetadata proto used to complete shape information.
    add_trace: Whether to add op trace information.
    add_trainable_var: Whether to assign tf.compat.v1.trainable_variables() op
      type '_trainable_variables'.
  Returns:
    logged_ops: dict mapping from op_name to OpLogEntry.
    string_to_id: dict mapping from string to id.
  """
  if run_meta:
    graph = _fill_missing_graph_shape(graph, run_meta)

  op_missing_shape = 0
  logged_ops = {}
  string_to_id = {}
  string_to_id['none'] = len(string_to_id)
  # TODO(xpan): Work with Profiler more efficiently.
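  # Walk every op in the graph, recording FLOP estimates (when the registered
  # statistics can be computed) and, optionally, Python code traces.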
  for op in graph.get_operations():
    try:
      stats = ops.get_stats_for_node_def(
          graph, op.node_def, REGISTERED_FLOP_STATS)
    except ValueError:
      # Catch exceptions raised when the shape is incomplete and skip the op.
      op_missing_shape += 1
      stats = None

    entry = tfprof_log_pb2.OpLogEntry()
    entry.name = op.name
    add_entry = False
    if stats and stats.value:
      entry.float_ops = int(stats.value)
      add_entry = True

    if add_trace:
      if op.traceback:
        for filename, lineno, funcname, line in op.traceback:
          trace = entry.code_def.traces.add()
          trace.file_id = _str_id(filename, string_to_id) if filename else 0
          trace.lineno = lineno if lineno else -1
          trace.function_id = _str_id(funcname, string_to_id) if funcname else 0
          trace.line_id = _str_id(line, string_to_id) if line else 0
          # TODO(slebedev): remove this unused field from the proto.
          trace.func_start_line = -1
      add_entry = True

    if add_entry:
      logged_ops[entry.name] = entry

  if add_trainable_var:
    for v in graph.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES):
      if v.op.name not in logged_ops:
        entry = tfprof_log_pb2.OpLogEntry()
        entry.name = v.op.name
        entry.types.append(TRAINABLE_VARIABLES)
        logged_ops[entry.name] = entry
      else:
        logged_ops[v.op.name].types.append(TRAINABLE_VARIABLES)

  if op_missing_shape > 0 and not run_meta:
    sys.stderr.write('%d ops have no flops stats due to incomplete shapes.\n' %
                     op_missing_shape)
  return logged_ops, string_to_id


def merge_default_with_oplog(graph, op_log=None, run_meta=None,
                             add_trace=True, add_trainable_var=True):
  """Merges the tfprof default extra info with the caller's op_log.

  Args:
    graph: tf.Graph. If None and eager execution is not enabled, uses the
      default graph.
    op_log: OpLogProto proto.
    run_meta: RunMetadata proto used to complete shape information.
    add_trace: Whether to add op trace information.
    add_trainable_var: Whether to assign tf.compat.v1.trainable_variables() op
      type '_trainable_variables'.
  Returns:
    tmp_op_log: Merged OpLogProto proto.
  """
  if not graph and not context.executing_eagerly():
    graph = ops.get_default_graph()

  tmp_op_log = tfprof_log_pb2.OpLogProto()
  if not graph:
    return tmp_op_log

  logged_ops, string_to_id = _get_logged_ops(
      graph, run_meta, add_trace=add_trace, add_trainable_var=add_trainable_var)

  if not op_log:
    tmp_op_log.log_entries.extend(logged_ops.values())
  else:
    all_ops = {}
    for entry in op_log.log_entries:
      all_ops[entry.name] = entry
    for op_name, entry in six.iteritems(logged_ops):
      if op_name in all_ops:
        all_ops[op_name].types.extend(entry.types)
        if entry.float_ops > 0 and all_ops[op_name].float_ops == 0:
          all_ops[op_name].float_ops = entry.float_ops
        if entry.code_def.traces and not all_ops[op_name].code_def.traces:
          all_ops[op_name].code_def.MergeFrom(entry.code_def)
      else:
        all_ops[op_name] = entry
    tmp_op_log.log_entries.extend(all_ops.values())

  for s, i in six.iteritems(string_to_id):
    tmp_op_log.id_to_string[i] = s
  return tmp_op_log


@tf_export(v1=['profiler.write_op_log'])
def write_op_log(graph, log_dir, op_log=None, run_meta=None, add_trace=True):
  """Logs the provided 'op_log' and adds additional model information below.

  The API also assigns ops in tf.compat.v1.trainable_variables() an op type
  called '_trainable_variables'.
  The API also logs 'flops' statistics for ops with op.RegisterStatistics()
  defined. The 'flops' calculation depends on Tensor shapes defined in 'graph',
  which might not be complete. 'run_meta', if provided, completes the shape
  information with best effort.

  Args:
    graph: tf.Graph. If None and eager execution is not enabled, uses the
      default graph.
    log_dir: directory to write the log file.
    op_log: (Optional) OpLogProto proto to be written. If not provided, a new
      one is created.
    run_meta: (Optional) RunMetadata proto that helps flops computation using
      run time shape information.
    add_trace: Whether to add python code trace information.
      Used to support "code" view.
  """
  if not graph and not context.executing_eagerly():
    graph = ops.get_default_graph()
  op_log = merge_default_with_oplog(graph, op_log, run_meta, add_trace)

  with gfile.Open(os.path.join(log_dir, 'tfprof_log'), 'w') as log:
    log.write(op_log.SerializeToString())
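
# A minimal usage sketch (hypothetical names; '/tmp/tfprof_logdir' is only an
# example path), assuming a TF1-style graph workflow. The resulting
# 'tfprof_log' file can then be consumed by tfprof tooling for offline
# analysis:
#
#   import tensorflow.compat.v1 as tf
#
#   g = tf.Graph()
#   with g.as_default():
#     w = tf.Variable(tf.random_normal([128, 64]), name='weights')
#     y = tf.matmul(tf.random_normal([32, 128]), w)
#   tf.profiler.write_op_log(g, log_dir='/tmp/tfprof_logdir')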