xref: /aosp_15_r20/external/tensorflow/tensorflow/python/profiler/internal/model_analyzer_testlib.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A test lib that defines some models."""
16import contextlib
17
18from tensorflow.python.framework import dtypes
19from tensorflow.python.ops import array_ops
20from tensorflow.python.ops import init_ops
21from tensorflow.python.ops import math_ops
22from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
23from tensorflow.python.ops import nn_ops
24from tensorflow.python.ops import rnn
25from tensorflow.python.ops import rnn_cell
26from tensorflow.python.ops import tensor_array_grad  # pylint: disable=unused-import
27from tensorflow.python.ops import variable_scope
28from tensorflow.python.profiler import model_analyzer
29from tensorflow.python.training import gradient_descent
30from tensorflow.python.util import _pywrap_tfprof as print_mdl
31from tensorflow.python.util import compat
32
33
def BuildSmallModel():
  """Build a small two-layer forward convolution model.

  Creates, in this order: an all-zeros input of shape [2, 6, 6, 3]; a scalar
  variable 'ScalarW' that the computation does not use; then two stride-2,
  SAME-padded convolutions whose kernels are the variables 'DW'
  ([3, 3, 3, 6]) and 'DW2' ([2, 2, 6, 12]). Every variable is float32 and
  uses a random-normal initializer with stddev 0.001.

  Returns:
    The output tensor of the second convolution.
  """

  def _weight(var_name, var_shape):
    # One fresh initializer object per variable, as in a direct call.
    return variable_scope.get_variable(
        var_name,
        var_shape,
        dtypes.float32,
        initializer=init_ops.random_normal_initializer(stddev=0.001))

  net = array_ops.zeros([2, 6, 6, 3])
  _weight('ScalarW', [])
  # Each kernel is created immediately before the conv that consumes it, so
  # any order-dependent op naming stays the same.
  for var_name, var_shape in (('DW', [3, 3, 3, 6]), ('DW2', [2, 2, 6, 12])):
    net = nn_ops.conv2d(
        net, _weight(var_name, var_shape), [1, 2, 2, 1], padding='SAME')
  return net
52
53
def BuildFullModel():
  """Build a conv + RNN model and return its gradient-descent training op.

  Four copies of BuildSmallModel are built under the variable scopes
  'inp_0' .. 'inp_3'. Each output is reshaped to [2, 1, -1], and the four
  are concatenated along axis 1 to form the RNN input. A 16-unit
  BasicRNNCell is unrolled over that input with dynamic_rnn. The loss is
  the L2 loss of the mean of (all-ones target - RNN outputs), minimized by
  GradientDescentOptimizer with learning rate 1e-2.

  Returns:
    The op returned by GradientDescentOptimizer.minimize(loss).
  """

  def _scoped_step(step):
    with variable_scope.variable_scope('inp_%d' % step):
      return array_ops.reshape(BuildSmallModel(), [2, 1, -1])

  steps = [_scoped_step(step) for step in range(4)]

  recurrent_cell = rnn_cell.BasicRNNCell(16)
  outputs, _ = rnn.dynamic_rnn(
      recurrent_cell, array_ops.concat(steps, axis=1), dtype=dtypes.float32)

  # ones_like is created before the subtraction, matching the original order.
  residual = array_ops.ones_like(outputs) - outputs
  loss = nn_ops.l2_loss(math_ops.reduce_mean(residual))
  return gradient_descent.GradientDescentOptimizer(1e-2).minimize(loss)
69
70
def BuildSplittableModel():
  """Build a small model whose pieces can be fetched independently per step.

  Two stride-2, SAME-padded convolutions are applied to the same all-zeros
  [2, 6, 6, 3] input. Their kernels are the float32 variables 'DW'
  ([3, 3, 3, 6]) and 'DW2' ([2, 3, 3, 6]), both initialized from a random
  normal with stddev 0.001. Neither conv depends on the other; only their
  sum depends on both.

  Returns:
    A tuple (r1, r2, r3): the first conv output, the second conv output,
    and r1 + r2.
  """
  image = array_ops.zeros([2, 6, 6, 3])
  branches = []
  # Variable then conv, per branch, in the same order as a straight-line
  # build.
  for kernel_name, kernel_shape in (('DW', [3, 3, 3, 6]),
                                    ('DW2', [2, 3, 3, 6])):
    kernel = variable_scope.get_variable(
        kernel_name,
        kernel_shape,
        dtypes.float32,
        initializer=init_ops.random_normal_initializer(stddev=0.001))
    branches.append(
        nn_ops.conv2d(image, kernel, [1, 2, 2, 1], padding='SAME'))
  first, second = branches
  return first, second, first + second
89
90
def SearchTFProfNode(node, name):
  """Search a node tree for the first node with the given name.

  The tree is walked depth-first in pre-order: a node is checked before its
  children, and children are visited left to right. This is the same
  visiting order as a recursive search, but it uses an explicit stack, so
  very deep trees cannot exceed Python's recursion limit. A matching node
  is returned as-is, even if it happens to be falsy.

  Args:
    node: Root of the tree. Any object with a `name` attribute and an
      iterable `children` attribute.
    name: The node name to look for.

  Returns:
    The first matching node in pre-order, or None if no node matches.
  """
  stack = [node]
  while stack:
    current = stack.pop()
    if current.name == name:
      return current
    # Push children reversed so the leftmost child is popped (visited) first.
    stack.extend(reversed(list(current.children)))
  return None
99
100
@contextlib.contextmanager
def ProfilerFromFile(profile_file):
  """Initialize a profiler from a profile file for the duration of a block.

  Loads `profile_file` via `print_mdl.ProfilerFromFile` and yields a
  `model_analyzer.Profiler` handle. It calls `print_mdl.DeleteProfiler()`
  when the block exits, whether normally or via an exception.

  Args:
    profile_file: Path to the profile file; converted with compat.as_bytes.

  Yields:
    A `model_analyzer.Profiler` instance created with `__new__`, skipping
    `__init__`. Presumably this is because the profiler state was already
    loaded by the call above; confirm against model_analyzer.Profiler.
  """
  print_mdl.ProfilerFromFile(compat.as_bytes(profile_file))
  try:
    profiler = model_analyzer.Profiler.__new__(model_analyzer.Profiler)
    yield profiler
  finally:
    # Runs even if the with-body raises (e.g. a failing test assertion), so
    # the profiler loaded above does not leak into later tests.
    print_mdl.DeleteProfiler()
108
109
def CheckAndRemoveDoc(profile):
  """Strip the documentation header from a profiler text report.

  Checks that `profile` contains 'Doc:', then returns everything after the
  first 'Profile:' marker. It skips the marker and the single character
  that follows it, which is presumably the newline ending the marker line;
  confirm against the report format.

  Args:
    profile: The profiler output as a string.

  Returns:
    The substring of `profile` starting one character past 'Profile:'.

  Raises:
    AssertionError: If `profile` lacks 'Doc:' or 'Profile:'. Raised
      explicitly rather than via `assert`, so the check survives `python -O`.
  """
  if 'Doc:' not in profile:
    raise AssertionError('Expected a "Doc:" section in profile output.')
  marker = 'Profile:'
  start_pos = profile.find(marker)
  if start_pos < 0:
    # str.find returns -1 when the marker is absent. Slicing from -1 + 9
    # would silently drop the first 8 characters instead of failing.
    raise AssertionError('Expected a "Profile:" section in profile output.')
  return profile[start_pos + len(marker) + 1:]
114