xref: /aosp_15_r20/external/tensorflow/tensorflow/python/profiler/profile_context.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""A Context that captures profiles and performs profiling/dumping.
"""
17import contextlib
18import os
19import random
20import sys
21import threading
22
23from tensorflow.core.protobuf import config_pb2
24from tensorflow.python.client import session
25from tensorflow.python.framework import errors
26from tensorflow.python.framework import ops
27from tensorflow.python.platform import gfile
28from tensorflow.python.profiler import model_analyzer
29from tensorflow.python.util import _pywrap_tfprof as print_mdl
30from tensorflow.python.util import compat
31
# Automatic tracing only considers steps whose index is greater than this.
WARMUP_STEPS = 10
# Budget for traced steps. Also the largest accepted `trace_steps` length and
# the default step at which the profile is dumped.
MAX_TRACED_STEPS = 100
34
35
def _profiled_init(self, target='', graph=None, config=None):
  """Replacement for `BaseSession.__init__` while a ProfileContext is active.

  Forwards to the original constructor, which `ProfileContext.__enter__`
  stores on the class as `_profiler_init_internal`.

  Args:
    self: The session instance being constructed.
    target: Forwarded unchanged to the original `__init__`.
    graph: Forwarded unchanged to the original `__init__`.
    config: Forwarded unchanged to the original `__init__`.
  """
  self._profiler_init_internal(target, graph, config)  # pylint: disable=protected-access
39
40
# pylint: disable=protected-access

# Maps an auto-profiling command to the profiler method that serves it.
_PROFILE_METHOD_BY_CMD = {
    'graph': 'profile_graph',
    'scope': 'profile_name_scope',
    'op': 'profile_operations',
    'code': 'profile_python',
}


def _run_traced_step(sess, step, fetches, feed_dict, options, run_metadata):
  """Runs one step with FULL_TRACE and adds the trace to the profiler."""
  pctx = sess.profile_context
  if pctx._debug:
    sys.stderr.write('debug: tracing step: %d\n' % step)

  if not run_metadata:
    run_metadata = config_pb2.RunMetadata()

  if not options:
    options = config_pb2.RunOptions(
        trace_level=config_pb2.RunOptions.FULL_TRACE)
    old_trace_level = options.trace_level
  else:
    old_trace_level = options.trace_level
    options.trace_level = config_pb2.RunOptions.FULL_TRACE

  try:
    ret = sess._profiler_run_internal(
        fetches, feed_dict, options, run_metadata)
  finally:
    # Give the caller's RunOptions back unchanged even when the run fails.
    options.trace_level = old_trace_level

  if pctx._debug:
    pctx._dump_file(run_metadata, 'run_meta_%d' % step)

  pctx.profiler._graph = sess.graph
  pctx.profiler.add_step(step, run_metadata)
  return ret


def _run_auto_profiles(pctx, step):
  """Runs every auto-profiling command due now and stores its view."""
  for cmd, opts, _ in pctx._profile_candidates():
    saved_views = pctx._views.setdefault(cmd, {})
    if pctx._debug:
      sys.stderr.write('debug: profiling %s step: %d\n' % (cmd, step))
    method_name = _PROFILE_METHOD_BY_CMD.get(cmd)
    if method_name is None:
      raise ValueError('Unknown cmd: %s\n' % cmd)
    saved_views[step] = getattr(pctx.profiler, method_name)(opts)


def _profiled_run(self,
                  fetches,
                  feed_dict=None,
                  options=None,
                  run_metadata=None):
  """Replacement for `BaseSession.run` while a ProfileContext is active.

  The original `run` is reachable as `self._profiler_run_internal`. Depending
  on the current step this either forwards the call directly, or traces the
  step, dumps the profile and runs the registered auto-profiling commands.

  Args:
    self: The session instance.
    fetches: Forwarded unchanged to the original `run`.
    feed_dict: Forwarded unchanged to the original `run`.
    options: Optional `RunOptions`. On a traced step its `trace_level` is set
      to `FULL_TRACE` for the duration of the run and then restored. A new
      `RunOptions` is created on a traced step when none is given.
    run_metadata: Optional `RunMetadata`, forwarded to the original `run`. A
      new one is created on a traced step when none is given.

  Returns:
    The value returned by the original `run`.

  Raises:
    ValueError: If a due auto-profiling entry has an unknown command.
  """
  pctx = self.profile_context
  # Count the session steps.
  with pctx._new_step() as state:
    step, locked = state
    # The slow path is taken only when the context lock was acquired and the
    # step needs tracing, dumping or profiling.
    if locked and not pctx._is_fast_path(step):
      if pctx._should_trace(step, self.graph, fetches):
        ret = _run_traced_step(
            self, step, fetches, feed_dict, options, run_metadata)
      else:
        # Forward run_metadata too, so a caller-supplied RunMetadata is still
        # passed to the original run on steps that are not traced here.
        ret = self._profiler_run_internal(
            fetches, feed_dict, options, run_metadata)

      # Maybe dump profile.
      pctx._maybe_dump(step)

      # Maybe profile.
      _run_auto_profiles(pctx, step)
      return ret
  # Fast no lock path.
  return self._profiler_run_internal(
      fetches, feed_dict, options, run_metadata)
# pylint: enable=protected-access
107
108
class ProfileContext(object):
  """A Context that captures RunMetadata and performs profiling.

  While the context is entered, `session.BaseSession.run` and
  `session.BaseSession.__init__` are replaced by wrappers that count steps and
  trace, dump and profile at the configured steps. The originals are restored
  on exit.

  ```python
    # Trace steps 100~200, profile at [150, 200] and dump profile at 200.
    with profile_context.ProfileContext('/tmp/train_dir',
                                        trace_steps=range(100, 200, 3),
                                        dump_steps=[200]) as pctx:
      opts = tf.profiler.ProfileOptionBuilder.time_and_memory()
      pctx.add_auto_profiling('op', opts, [150, 200])
      train_loop()

    # Tracing only.
    with profile_context.ProfileContext('/tmp/train_dir') as pctx:
      # Run train/eval loop for at least few hundred steps. Profiles will be
      # dumped to train_dir. Use web UI or command line to do profiling.
      train_loop()

    # When session object is available, do explicit trace, profile and dump.
    with profile_context.ProfileContext('/tmp/train_dir',
                                        trace_steps=[],
                                        dump_steps=[]) as pctx:
      opts = tf.profiler.ProfileOptionBuilder.time_and_memory()
      pctx.trace_next_step()
      _ = session.run(train_op)
      pctx.profiler.profile_operations(options=opts)
  ```

  Args:
    profile_dir: Directory to store profiles.
    trace_steps: A collection (list, set, range, ...) of session run steps to
        trace, holding at most `MAX_TRACED_STEPS` entries. If None, the steps
        to trace are chosen automatically once `WARMUP_STEPS` have passed.
    dump_steps: A collection of steps at which to dump the profile to
        `profile_dir`. If None, the profile is dumped at step
        `MAX_TRACED_STEPS`.
    enabled: If false, everything is disabled with minimal overhead. It allows
        user to only enable profiling when needed.
    debug: If true, also dumps the raw trace RunMetadata text file to
        profile_dir. And print debugging message. Useful for bug report.

  Raises:
    ValueError: If enabled and `profile_dir` is empty, or `trace_steps` has
        more than `MAX_TRACED_STEPS` entries.
  """

  def __init__(self,
               profile_dir,
               trace_steps=None,
               dump_steps=None,
               enabled=True,
               debug=False):
    self._enabled = enabled
    # Set even when disabled, so `get_profiles` raises its documented
    # ValueError instead of an AttributeError.
    self._views = {}
    if not self._enabled:
      return

    self._debug = debug
    if not profile_dir:
      raise ValueError('Must have a directory for profile.\n')
    self._profiler_dir = profile_dir

    if trace_steps is None:
      self._trace_steps = set()
      self._auto_tracing = True
    else:
      if len(trace_steps) > MAX_TRACED_STEPS:
        raise ValueError('Only support tracing up to 100 steps.\n')
      # set() copies any iterable; slicing would reject sets.
      self._trace_steps = set(trace_steps)
      self._auto_tracing = False

    if dump_steps is None:
      self._dump_steps = {MAX_TRACED_STEPS}
    else:
      self._dump_steps = set(dump_steps)

    # Fixed seed, so the random tracing decisions repeat across runs.
    self._rng = random.Random(111)
    # Joined fetch names that automatic tracing has already traced once.
    self._fetched = set()
    # Steps on which `run` must leave the fast path.
    self._slow_path_steps = self._dump_steps | self._trace_steps
    self._trace_next_step = False
    self._dump_next_step = False
    self._step = 0
    self._traced_steps = 0
    # Entries are (cmd, options, steps) tuples from `add_auto_profiling`.
    self._auto_profiles = []
    self._profiler = None
    self._lock = threading.Lock()

  def get_profiles(self, cmd):
    """Returns profiling results for each step at which `cmd` was run.

    Args:
      cmd: string, profiling command used in an `add_auto_profiling` call.

    Returns:
      dict[int: (MultiGraphNodeProto | GraphNodeProto)]. Keys are steps at which
      the profiling command was run. Values are the outputs of profiling.
      For "code" and "op" commands this will be a `MultiGraphNodeProto`, for
      "scope" and "graph" commands this will be a `GraphNodeProto`.

    Raises:
      ValueError: if `cmd` was never run (either because no session.run call was
      made or because there was no `add_auto_profiling` call with the specified
      `cmd`).
    """
    if cmd not in self._views:
      raise ValueError('No autoprofiler for command: {}, was run'.format(cmd))
    return self._views[cmd]

  def add_auto_profiling(self, cmd, options, profile_steps):
    """Traces and profiles at some session run steps.

    Does nothing when the context is disabled.

    Args:
      cmd: The profiling command, one of 'scope', 'op', 'code' or 'graph'.
      options: The profiling options.
      profile_steps: A list/set of integers. The profiling command and options
          will be run automatically at these integer steps. Each step is
          a session.run.
    """
    if not self._enabled:
      return
    # Copy into a set: this accepts both lists and sets, and isolates the
    # stored steps from later changes to the caller's collection.
    steps = set(profile_steps)
    self._auto_profiles.append((cmd, options, steps))
    self._slow_path_steps |= steps
    self._trace_steps |= steps

  @property
  def profiler(self):
    """Returns the current profiler object, or None when disabled.

    The profiler is created on first access from the default graph.
    """
    if not self._enabled:
      return None
    if not self._profiler:
      self._profiler = model_analyzer.Profiler(ops.get_default_graph())
    return self._profiler

  def trace_next_step(self):
    """Enables tracing and adds traces to profiler at next step."""
    if not self._enabled:
      return
    self._trace_next_step = True
    # Force the upcoming step off the fast path.
    self._slow_path_steps.add(self._step)

  def dump_next_step(self):
    """Enable tracing and dump profiles at next step."""
    if not self._enabled:
      return
    self._dump_next_step = True
    # Force the upcoming step off the fast path.
    self._slow_path_steps.add(self._step)

  def _is_fast_path(self, step):
    """Returns True when `step` needs no tracing, dumping or profiling."""
    if step in self._slow_path_steps:
      return False
    # When user doesn't set the tracing steps explicitly, auto decide it.
    if (self._auto_tracing and step > WARMUP_STEPS and
        self._traced_steps <= MAX_TRACED_STEPS):
      return False
    return True

  def _should_trace(self, step, graph, fetches):
    """Whether should do tracing at current step.

    Increments the traced-step counter whenever it returns True.

    Args:
      step: Index of the current session.run step.
      graph: Graph used to resolve `fetches` during automatic tracing.
      fetches: The fetches of the current session.run call.

    Returns:
      True if the current step should be run with tracing enabled.
    """
    if self._traced_steps > MAX_TRACED_STEPS:
      return False
    # Check user-set tracing steps.
    if step in self._trace_steps or self._trace_next_step:
      self._traced_steps += 1
      return True

    # If no user-set tracing steps set and passes warm up steps, auto trace.
    if self._auto_tracing and step > WARMUP_STEPS:
      # If the fetches have not been seen before, trace it.
      with graph.as_default():
        fetch_names = [f.name for f in
                       session._FetchMapper.for_fetch(fetches).unique_fetches()]  # pylint: disable=protected-access
      fetch_name = '-'.join(sorted(fetch_names))
      if self._debug:
        sys.stderr.write('debug: trace fetches: %s\n' % fetch_name)
      if fetch_name not in self._fetched:
        self._fetched.add(fetch_name)
        self._traced_steps += 1
        return True
      # If the trace coverage is low, does some random tracing.
      if (self.profiler._coverage < 0.5 and step < MAX_TRACED_STEPS and  # pylint: disable=protected-access
          self._rng.randint(0, 10) < 2):
        self._traced_steps += 1
        return True
    return False

  def _maybe_dump(self, step):
    """Writes the profile to `profile_<step>` if a dump is due at `step`."""
    if not (step in self._dump_steps or self._dump_next_step):
      return
    if self._debug:
      sys.stderr.write('debug: dumping file at step: %d\n' % step)
    gfile.MakeDirs(self._profiler_dir)

    filename = os.path.join(compat.as_bytes(self._profiler_dir),
                            compat.as_bytes('profile_%d' % step))
    self.profiler._write_profile(filename)  # pylint: disable=protected-access

  def _dump_file(self, pb, basename):
    """Writes the text form of `pb` to `basename` inside the profile dir."""
    gfile.MakeDirs(self._profiler_dir)
    with gfile.Open(os.path.join(self._profiler_dir, basename), 'w') as f:
      f.write('%s' % pb)

  @contextlib.contextmanager
  def _new_step(self):
    """Context manager wrapping a single session.run step.

    Tries to take the context lock without blocking and yields
    `(step, acquired)`. When the body finishes normally the step counter is
    advanced and the one-shot trace/dump requests are cleared. The lock, if
    taken, is released even when the body raises; otherwise one failing
    `session.run` would leave it held and no later step could acquire it.

    Yields:
      A `(step, acquired)` tuple: the current step index and whether the lock
      was acquired.
    """
    acquired = self._lock.acquire(False)  # pylint: disable=assignment-from-no-return
    try:
      yield (self._step, acquired)
      self._step += 1
      self._trace_next_step = False
      self._dump_next_step = False
    finally:
      if acquired:
        self._lock.release()

  def _profile_candidates(self):
    """Returns the auto-profiling entries scheduled for the current step."""
    return [auto_prof for auto_prof in self._auto_profiles
            if self._step in auto_prof[2]]

  def __enter__(self):
    """Patches `session.BaseSession` to route `run`/`__init__` through self.

    Returns:
      This context.

    Raises:
      errors.InternalError: If `BaseSession` lacks `run` or `__init__`, or a
        profile context is already installed on it.
    """
    if not self._enabled:
      return self

    base = session.BaseSession
    self.old_run = getattr(base, 'run', None)
    self.old_init = getattr(base, '__init__', None)
    if not self.old_run:
      raise errors.InternalError(None, None, 'BaseSession misses run method.')
    if not self.old_init:
      raise errors.InternalError(None, None,
                                 'BaseSession misses __init__ method.')
    if (getattr(base, '_profiler_run_internal', None) or
        getattr(base, '_profiler_init_internal', None)):
      raise errors.InternalError(None, None,
                                 'Already in context or context not cleaned.')

    setattr(base, 'run', _profiled_run)
    setattr(base, '__init__', _profiled_init)
    setattr(base, '_profiler_run_internal', self.old_run)
    setattr(base, '_profiler_init_internal', self.old_init)
    setattr(base, 'profile_context', self)
    return self

  def __exit__(self, exec_type, exec_value, exec_tb):
    """Calls `print_mdl.DeleteProfiler()` and restores `session.BaseSession`."""
    if not self._enabled:
      return
    try:
      print_mdl.DeleteProfiler()
    finally:
      # Always undo the patching, even if deleting the profiler fails.
      base = session.BaseSession
      setattr(base, 'run', self.old_run)
      setattr(base, '__init__', self.old_init)
      setattr(base, '_profiler_run_internal', None)
      setattr(base, '_profiler_init_internal', None)
      setattr(base, 'profile_context', None)
357