# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A Context that captures profile and performs profiling/dumping.
"""
import contextlib
import os
import random
import sys
import threading

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import gfile
from tensorflow.python.profiler import model_analyzer
from tensorflow.python.util import _pywrap_tfprof as print_mdl
from tensorflow.python.util import compat

# With auto tracing, steps at or below this number are never auto-traced.
WARMUP_STEPS = 10
# Caps the number of traced steps and the size of a user-given `trace_steps`;
# also the default dump step.
MAX_TRACED_STEPS = 100

# Auto-profiling commands understood by `_profiled_run`.
_PROFILE_CMDS = ('graph', 'scope', 'op', 'code')


def _profiled_init(self, target='', graph=None, config=None):
  """Overwrites the session.__init__."""
  self._profiler_init_internal(target, graph, config)  # pylint: disable=protected-access


def _profiled_run(self,
                  fetches,
                  feed_dict=None,
                  options=None,
                  run_metadata=None):
  """Overwrites the session.run().

  Installed as `BaseSession.run` while a `ProfileContext` is active. Steps
  that need no tracing/dumping/profiling go straight to the original `run`.
  Otherwise the step may be run with FULL_TRACE and its `RunMetadata` added
  to the profiler; the profile may then be dumped and auto-profiling
  commands executed.

  Args:
    fetches: Forwarded to the original `run`.
    feed_dict: Forwarded to the original `run`.
    options: Optional `RunOptions`. When tracing, its `trace_level` is
      temporarily set to FULL_TRACE and restored afterwards.
    run_metadata: Optional `RunMetadata`, always forwarded to the original
      `run`. Created internally when tracing and none was given.

  Returns:
    Whatever the original `run` returns.

  Raises:
    ValueError: If a registered auto-profiling command is unknown.
  """
  # pylint: disable=protected-access
  pctx = self.profile_context
  # Count the session steps.
  with pctx._new_step() as state:
    step, locked = state
    # Fast path if no need for profiling. `locked` is False when the context
    # lock could not be acquired (another run currently holds it).
    if locked and not pctx._is_fast_path(step):
      # Maybe trace this step.
      if pctx._should_trace(step, self.graph, fetches):
        if pctx._debug:
          sys.stderr.write('debug: tracing step: %d\n' % step)
        # Enable tracing, perform auto profiling or auto dump.
        if run_metadata is None:
          run_metadata = config_pb2.RunMetadata()
        if options is None:
          options = config_pb2.RunOptions(
              trace_level=config_pb2.RunOptions.FULL_TRACE)
        old_trace_level = options.trace_level
        options.trace_level = config_pb2.RunOptions.FULL_TRACE
        try:
          ret = self._profiler_run_internal(
              fetches, feed_dict, options, run_metadata)
        finally:
          # Never leave a caller-owned RunOptions switched to FULL_TRACE,
          # even when run() raises.
          options.trace_level = old_trace_level
        if pctx._debug:
          pctx._dump_file(run_metadata, 'run_meta_%d' % step)

        pctx.profiler._graph = self.graph
        pctx.profiler.add_step(step, run_metadata)
      else:
        # Forward run_metadata so a caller-supplied proto is still filled in.
        ret = self._profiler_run_internal(
            fetches, feed_dict, options, run_metadata)

      # Maybe dump profile.
      pctx._maybe_dump(step)

      # Maybe profile.
      for cmd, opts, _ in pctx._profile_candidates():
        saved_views = pctx._views.setdefault(cmd, {})
        if pctx._debug:
          sys.stderr.write('debug: profiling %s step: %d\n' % (cmd, step))
        if cmd == 'graph':
          saved_views[step] = pctx.profiler.profile_graph(opts)
        elif cmd == 'scope':
          saved_views[step] = pctx.profiler.profile_name_scope(opts)
        elif cmd == 'op':
          saved_views[step] = pctx.profiler.profile_operations(opts)
        elif cmd == 'code':
          saved_views[step] = pctx.profiler.profile_python(opts)
        else:
          raise ValueError('Unknown cmd: %s\n' % cmd)
      return ret
  # Fast no lock path.
  return self._profiler_run_internal(
      fetches, feed_dict, options, run_metadata)
  # pylint: enable=protected-access


class ProfileContext(object):
  """A Context that captures RunMetadata and performs profiling.

  ```python
    # Trace every 3rd step in [100, 200), profile at [150, 200] and dump
    # profile at 200.
    with profile_context.ProfileContext('/tmp/train_dir',
                                        trace_steps=range(100, 200, 3),
                                        dump_steps=[200]) as pctx:
      opts = tf.profiler.ProfileOptionBuilder.time_and_memory()
      pctx.add_auto_profiling('op', opts, [150, 200])
      train_loop()

    # Tracing only.
    with profile_context.ProfileContext('/tmp/train_dir') as pctx:
      # Run train/eval loop for at least a few hundred steps. Profiles will be
      # dumped to train_dir. Use web UI or command line to do profiling.
      train_loop()

    # When session object is available, do explicit trace, profile and dump.
    with profile_context.ProfileContext('/tmp/train_dir',
                                        trace_steps=[],
                                        dump_steps=[]) as pctx:
      opts = tf.profiler.ProfileOptionBuilder.time_and_memory()
      pctx.trace_next_step()
      _ = session.run(train_op)
      pctx.profiler.profile_operations(options=opts)
  ```

  Args:
    profile_dir: Directory to store profiles.
    trace_steps: An iterable of session run steps to trace (at most
      `MAX_TRACED_STEPS` distinct steps). If None, steps are chosen
      automatically after `WARMUP_STEPS` warm-up steps.
    dump_steps: An iterable of steps at which to dump the profile to
      `profile_dir`. If None, the profile is dumped at step
      `MAX_TRACED_STEPS`.
    enabled: If false, everything is disabled with minimal overhead. It allows
      user to only enable profiling when needed.
    debug: If true, also dumps the raw trace RunMetadata text file to
      profile_dir, and prints debugging messages. Useful for bug reports.
  """

  def __init__(self,
               profile_dir,
               trace_steps=None,
               dump_steps=None,
               enabled=True,
               debug=False):
    self._enabled = enabled
    if not self._enabled:
      return

    self._debug = debug
    if not profile_dir:
      raise ValueError('Must have a directory for profile.\n')
    self._profiler_dir = profile_dir

    if trace_steps is None:
      self._trace_steps = set()
      self._auto_tracing = True
    else:
      # set() accepts lists, ranges, sets and other iterables alike.
      self._trace_steps = set(trace_steps)
      if len(self._trace_steps) > MAX_TRACED_STEPS:
        raise ValueError(
            'Only support tracing up to %d steps.\n' % MAX_TRACED_STEPS)
      self._auto_tracing = False

    if dump_steps is None:
      self._dump_steps = set([MAX_TRACED_STEPS])
    else:
      self._dump_steps = set(dump_steps)

    # Fixed seed so random tracing decisions are reproducible across runs.
    self._rng = random.Random(111)
    # Fetch signatures that auto tracing has already traced once.
    self._fetched = set()
    # Steps that must bypass _is_fast_path.
    self._slow_path_steps = self._dump_steps | self._trace_steps
    # One-shot flags, cleared by _new_step after every run.
    self._trace_next_step = False
    self._dump_next_step = False
    self._step = 0
    self._traced_steps = 0
    # List of (cmd, options, set_of_steps) registered by add_auto_profiling.
    self._auto_profiles = []
    self._profiler = None
    # cmd -> {step: profiling result}.
    self._views = {}
    self._lock = threading.Lock()

  def get_profiles(self, cmd):
    """Returns profiling results for each step at which `cmd` was run.

    Args:
      cmd: string, profiling command used in an `add_auto_profiling` call.

    Returns:
      dict[int: (MultiGraphNodeProto | GraphNodeProto)]. Keys are steps at
      which the profiling command was run. Values are the outputs of
      profiling. For "code" and "op" commands this will be a
      `MultiGraphNodeProto`, for "scope" and "graph" commands this will be a
      `GraphNodeProto`.

    Raises:
      ValueError: if `cmd` was never run (either because the context is
        disabled, no session.run call was made, or there was no
        `add_auto_profiling` call with the specified `cmd`).
    """
    # A disabled context never initializes _views; report it like any other
    # "never run" command instead of failing with AttributeError.
    if not self._enabled or cmd not in self._views:
      raise ValueError('No autoprofiler for command: {}, was run'.format(cmd))
    return self._views[cmd]

  def add_auto_profiling(self, cmd, options, profile_steps):
    """Traces and profiles at some session run steps.

    Args:
      cmd: The profiling command: one of 'scope', 'op', 'code', 'graph'.
      options: The profiling options.
      profile_steps: An iterable (e.g. list/set) of integers. The profiling
        command and options will be run automatically at these integer
        steps. Each step is a session.run.

    Raises:
      ValueError: If `cmd` is not a supported profiling command.
    """
    if not self._enabled:
      return
    # Fail now rather than at the profile step, possibly deep into training.
    if cmd not in _PROFILE_CMDS:
      raise ValueError('Unknown cmd: %s\n' % cmd)
    steps = set(profile_steps)
    self._auto_profiles.append((cmd, options, steps))
    # Profiled steps need a trace, so they also leave the fast path.
    self._slow_path_steps |= steps
    self._trace_steps |= steps

  @property
  def profiler(self):
    """Returns the current profiler object (None if disabled).

    The profiler is created lazily on first access, bound to the default
    graph at that time.
    """
    if not self._enabled:
      return None
    if not self._profiler:
      self._profiler = model_analyzer.Profiler(ops.get_default_graph())
    return self._profiler

  def trace_next_step(self):
    """Enables tracing and adds traces to profiler at next step."""
    if not self._enabled:
      return
    self._trace_next_step = True
    self._slow_path_steps.add(self._step)

  def dump_next_step(self):
    """Enable tracing and dump profiles at next step."""
    if not self._enabled:
      return
    self._dump_next_step = True
    self._slow_path_steps.add(self._step)

  def _is_fast_path(self, step):
    """Returns True if `step` needs no tracing, dumping or profiling work."""
    if step in self._slow_path_steps:
      return False
    # When user doesn't set the tracing steps explicitly, auto decide it.
    if (self._auto_tracing and step > WARMUP_STEPS and
        self._traced_steps <= MAX_TRACED_STEPS):
      return False
    return True

  def _should_trace(self, step, graph, fetches):
    """Whether should do tracing at current step.

    Increments the traced-step counter whenever it returns True.
    """
    if self._traced_steps > MAX_TRACED_STEPS:
      return False
    # Check user-set tracing steps.
    if step in self._trace_steps or self._trace_next_step:
      self._traced_steps += 1
      return True

    # If no user-set tracing steps set and passes warm up steps, auto trace.
    if self._auto_tracing and step > WARMUP_STEPS:
      # If the fetches have not been seen before, trace it.
      with graph.as_default():
        fetch_names = [f.name for f in
                       session._FetchMapper.for_fetch(fetches).unique_fetches()]  # pylint: disable=protected-access
      fetch_name = '-'.join(sorted(fetch_names))
      if self._debug:
        sys.stderr.write('debug: trace fetches: %s\n' % fetch_name)
      if fetch_name not in self._fetched:
        self._fetched.add(fetch_name)
        self._traced_steps += 1
        return True
      # If the trace coverage is low, does some random tracing.
      if (self.profiler._coverage < 0.5 and step < MAX_TRACED_STEPS and  # pylint: disable=protected-access
          self._rng.randint(0, 10) < 2):
        self._traced_steps += 1
        return True
    return False

  def _maybe_dump(self, step):
    """Writes `profile_<step>` under the profile dir if this is a dump step."""
    if not (step in self._dump_steps or self._dump_next_step):
      return
    if self._debug:
      sys.stderr.write('debug: dumping file at step: %d\n' % step)
    gfile.MakeDirs(self._profiler_dir)

    filename = os.path.join(compat.as_bytes(self._profiler_dir),
                            compat.as_bytes('profile_%d' % step))
    self.profiler._write_profile(filename)  # pylint: disable=protected-access

  def _dump_file(self, pb, basename):
    """Writes the text form of `pb` to `<profile_dir>/<basename>`."""
    gfile.MakeDirs(self._profiler_dir)
    with gfile.Open(os.path.join(self._profiler_dir, basename), 'w') as f:
      f.write('%s' % pb)

  @contextlib.contextmanager
  def _new_step(self):
    """Yields `(step, acquired)` for one session.run, then advances the step.

    The lock is taken without blocking; `acquired` tells the caller whether
    it may do slow-path work. Bookkeeping and the lock release live in a
    `finally` so an exception raised inside the `with` body cannot leave
    the lock held, which would otherwise force every later run onto the
    fast path and silently disable profiling.
    """
    acquired = self._lock.acquire(False)  # pylint: disable=assignment-from-no-return
    try:
      yield (self._step, acquired)
    finally:
      self._step += 1
      self._trace_next_step = False
      self._dump_next_step = False
      if acquired:
        self._lock.release()

  def _profile_candidates(self):
    """Returns the auto-profile entries registered for the current step."""
    return [auto_prof for auto_prof in self._auto_profiles
            if self._step in auto_prof[2]]

  def __enter__(self):
    """Patches `BaseSession.run`/`__init__` with the profiling versions."""
    if not self._enabled:
      return self
    self.old_run = getattr(session.BaseSession, 'run', None)
    self.old_init = getattr(session.BaseSession, '__init__', None)
    if not self.old_run:
      raise errors.InternalError(None, None, 'BaseSession misses run method.')
    if not self.old_init:
      raise errors.InternalError(None, None,
                                 'BaseSession misses __init__ method.')
    # A non-empty internal hook means another context is active or an
    # earlier one was not cleaned up.
    if (getattr(session.BaseSession, '_profiler_run_internal', None) or
        getattr(session.BaseSession, '_profiler_init_internal', None)):
      raise errors.InternalError(None, None,
                                 'Already in context or context not cleaned.')
    setattr(session.BaseSession, 'run', _profiled_run)
    setattr(session.BaseSession, '__init__', _profiled_init)
    setattr(session.BaseSession, '_profiler_run_internal', self.old_run)
    setattr(session.BaseSession, '_profiler_init_internal', self.old_init)
    setattr(session.BaseSession, 'profile_context', self)
    return self

  def __exit__(self, exec_type, exec_value, exec_tb):
    """Deletes the native profiler and restores the original BaseSession."""
    if not self._enabled:
      return
    print_mdl.DeleteProfiler()
    setattr(session.BaseSession, 'run', self.old_run)
    setattr(session.BaseSession, '__init__', self.old_init)
    setattr(session.BaseSession, '_profiler_run_internal', None)
    setattr(session.BaseSession, '_profiler_init_internal', None)
    setattr(session.BaseSession, 'profile_context', None)