# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A test lib that defines some models."""
import contextlib

from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import tensor_array_grad  # pylint: disable=unused-import
from tensorflow.python.ops import variable_scope
from tensorflow.python.profiler import model_analyzer
from tensorflow.python.training import gradient_descent
from tensorflow.python.util import _pywrap_tfprof as print_mdl
from tensorflow.python.util import compat


def BuildSmallModel():
  """Build a small forward conv model.

  Creates the variables 'ScalarW', 'DW' and 'DW2' in the current variable
  scope. It then applies two stride-2, SAME-padded convolutions to an all-zero
  image of shape [2, 6, 6, 3].

  Returns:
    The output tensor of the second convolution.
  """
  image = array_ops.zeros([2, 6, 6, 3])
  # The scalar variable is never used in the graph; presumably it exists so
  # profiles contain a scalar parameter -- confirm against the tests using it.
  _ = variable_scope.get_variable(
      'ScalarW', [],
      dtypes.float32,
      initializer=init_ops.random_normal_initializer(stddev=0.001))
  kernel = variable_scope.get_variable(
      'DW', [3, 3, 3, 6],
      dtypes.float32,
      initializer=init_ops.random_normal_initializer(stddev=0.001))
  x = nn_ops.conv2d(image, kernel, [1, 2, 2, 1], padding='SAME')
  kernel = variable_scope.get_variable(
      'DW2', [2, 2, 6, 12],
      dtypes.float32,
      initializer=init_ops.random_normal_initializer(stddev=0.001))
  x = nn_ops.conv2d(x, kernel, [1, 2, 2, 1], padding='SAME')
  return x


def BuildFullModel():
  """Build the full model with conv, RNN and optimizer ops.

  Four copies of `BuildSmallModel` are built, each under its own variable
  scope ('inp_0' .. 'inp_3'). Each output is reshaped to [2, 1, -1], and the
  four are concatenated along axis 1 and fed to a 16-unit `BasicRNNCell` via
  `dynamic_rnn`. The L2 loss of the mean difference between an all-ones
  target and the RNN output is then minimized with plain SGD.

  Returns:
    The training op returned by `GradientDescentOptimizer.minimize`.
  """
  seq = []
  for i in range(4):
    # Separate scopes so each small model gets its own 'DW'/'DW2' variables.
    with variable_scope.variable_scope('inp_%d' % i):
      seq.append(array_ops.reshape(BuildSmallModel(), [2, 1, -1]))

  cell = rnn_cell.BasicRNNCell(16)
  # dynamic_rnn returns (outputs, state); only the outputs are used.
  out = rnn.dynamic_rnn(
      cell, array_ops.concat(seq, axis=1), dtype=dtypes.float32)[0]

  target = array_ops.ones_like(out)
  loss = nn_ops.l2_loss(math_ops.reduce_mean(target - out))
  sgd_op = gradient_descent.GradientDescentOptimizer(1e-2)
  return sgd_op.minimize(loss)


def BuildSplittableModel():
  """Build a small model that can be run partially in each step.

  Two independent convolutions are applied to the same all-zero image, and
  their sum is built on top. The caller can therefore fetch `r1`, `r2` or
  `r3` separately.

  Returns:
    A tuple `(r1, r2, r3)` where `r3 = r1 + r2`.
  """
  image = array_ops.zeros([2, 6, 6, 3])

  kernel1 = variable_scope.get_variable(
      'DW', [3, 3, 3, 6],
      dtypes.float32,
      initializer=init_ops.random_normal_initializer(stddev=0.001))
  r1 = nn_ops.conv2d(image, kernel1, [1, 2, 2, 1], padding='SAME')

  kernel2 = variable_scope.get_variable(
      'DW2', [2, 3, 3, 6],
      dtypes.float32,
      initializer=init_ops.random_normal_initializer(stddev=0.001))
  r2 = nn_ops.conv2d(image, kernel2, [1, 2, 2, 1], padding='SAME')

  r3 = r1 + r2
  return r1, r2, r3


def SearchTFProfNode(node, name):
  """Search a node in the tree.

  Performs a depth-first, pre-order search over `node` and its `children`.

  Args:
    node: Root of the (sub)tree. It must expose `.name` and `.children`.
    name: The node name to look for.

  Returns:
    The first node whose `.name` equals `name`, or None if there is no match.
  """
  if node.name == name:
    return node
  for child in node.children:
    found = SearchTFProfNode(child, name)
    # Compare against None explicitly: a matching node must never be skipped
    # because of how its type defines truthiness.
    if found is not None:
      return found
  return None


@contextlib.contextmanager
def ProfilerFromFile(profile_file):
  """Initialize a profiler from profile file.

  The profile is loaded through the pywrap tfprof binding, and a
  `model_analyzer.Profiler` handle is yielded. The native profiler is deleted
  when the `with` block exits, including when the block raises.

  Args:
    profile_file: Path of the profile file to load.

  Yields:
    A `model_analyzer.Profiler` instance.
  """
  print_mdl.ProfilerFromFile(compat.as_bytes(profile_file))
  try:
    # Bypass Profiler.__init__: the native profiler state was already created
    # from the file above (presumably __init__ would build a new one instead;
    # confirm against model_analyzer.Profiler).
    profiler = model_analyzer.Profiler.__new__(model_analyzer.Profiler)
    yield profiler
  finally:
    # Always release the native profiler, even if the caller's block raised;
    # otherwise it would outlive this context manager.
    print_mdl.DeleteProfiler()


def CheckAndRemoveDoc(profile):
  """Check that `profile` has a doc section, then strip everything before it.

  Args:
    profile: Profiler output text. It is expected to contain a 'Doc:' section
      and a 'Profile:' marker.

  Returns:
    The text following the 'Profile:' marker and the single character
    (normally the newline) right after it.

  Raises:
    AssertionError: If 'Doc:' or 'Profile:' is missing from `profile`.
  """
  assert 'Doc:' in profile, 'Expected a "Doc:" section in the profile output.'
  marker = 'Profile:'
  start_pos = profile.find(marker)
  # Without this check a missing marker would give find() == -1 and silently
  # return an arbitrary suffix of the text.
  assert start_pos >= 0, 'Expected a "Profile:" marker in the profile output.'
  # Skip the marker itself plus the one character that follows it.
  return profile[start_pos + len(marker) + 1:]