# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the TransformGraph Python wrapper."""

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util
from tensorflow.python.platform import test
from tensorflow.tools.graph_transforms import TransformGraph


class TransformGraphTest(test.TestCase):
  """Checks that TransformGraph applies the strip_unused_nodes transform."""

  def _float_type_attr(self):
    """Returns an AttrValue whose `type` field is the float32 enum."""
    return attr_value_pb2.AttrValue(type=dtypes.float32.as_datatype_enum)

  def _add_const(self, graph_def, name, values):
    """Appends a float32 "Const" node named `name` to `graph_def`.

    Args:
      graph_def: GraphDef proto that receives the new node.
      name: Name assigned to the new node.
      values: Values stored in the node's tensor, which is built with an
        explicit shape of [1, 2].

    Returns:
      The NodeDef that was added.
    """
    node = graph_def.node.add()
    node.op = "Const"
    node.name = name
    node.attr["dtype"].CopyFrom(self._float_type_attr())
    node.attr["value"].CopyFrom(
        attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
            values, dtypes.float32, [1, 2])))
    return node

  def _add_float_op(self, graph_def, op, name, inputs):
    """Appends a node of kind `op` with a float32 "T" attr to `graph_def`.

    Args:
      graph_def: GraphDef proto that receives the new node.
      op: Op type string, e.g. "Add" or "Relu".
      name: Name assigned to the new node.
      inputs: Names of the nodes this node reads from.

    Returns:
      The NodeDef that was added.
    """
    node = graph_def.node.add()
    node.op = op
    node.name = name
    node.attr["T"].CopyFrom(self._float_type_attr())
    node.input.extend(inputs)
    return node

  # This test constructs a graph with a relu op that's not used by the normal
  # inference path, and then tests that the strip_unused_nodes transform
  # removes it as expected.
  def testTransformGraph(self):
    """Builds const -> add -> relu and strips everything past add_op."""
    input_graph_def = graph_pb2.GraphDef()

    self._add_const(input_graph_def, "const_op1", [1, 2])
    self._add_const(input_graph_def, "const_op2", [3, 4])

    # Create an add that has two constants as inputs.
    self._add_float_op(input_graph_def, "Add", "add_op",
                       ["const_op1", "const_op2"])

    # Create a relu that reads from the add.
    self._add_float_op(input_graph_def, "Relu", "relu_op", ["add_op"])

    # We're specifying that add_op is the final output, and so the relu isn't
    # needed.
    input_names = []
    output_names = ["add_op"]
    transforms = ["strip_unused_nodes"]
    transformed_graph_def = TransformGraph(input_graph_def, input_names,
                                           output_names, transforms)

    remaining_ops = [node.op for node in transformed_graph_def.node]
    remaining_names = [node.name for node in transformed_graph_def.node]

    # We expect that the relu is no longer present after running the transform.
    self.assertNotIn("Relu", remaining_ops)
    # The requested output must survive; without this check an empty result
    # graph would satisfy the assertion above vacuously.
    self.assertIn("add_op", remaining_names)


if __name__ == "__main__":
  test.main()