# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for StatSummarizer Python wrapper."""

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util
from tensorflow.python.platform import test
from tensorflow.tools.graph_transforms import TransformGraph


class TransformGraphTest(test.TestCase):

  # This test constructs a graph with a relu op that's not used by the normal
  # inference path, and then checks that the strip_unused_nodes transform
  # removes it as expected.
  def testTransformGraph(self):
    input_graph_def = graph_pb2.GraphDef()

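    # Build the first constant input: a 1x2 float32 tensor holding [1, 2].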
    const_op1 = input_graph_def.node.add()
    const_op1.op = "Const"
    const_op1.name = "const_op1"
    const_op1.attr["dtype"].CopyFrom(attr_value_pb2.AttrValue(
        type=dtypes.float32.as_datatype_enum))
    const_op1.attr["value"].CopyFrom(
        attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
            [1, 2], dtypes.float32, [1, 2])))

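    # Build the second constant input: a 1x2 float32 tensor holding [3, 4].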
    const_op2 = input_graph_def.node.add()
    const_op2.op = "Const"
    const_op2.name = "const_op2"
    const_op2.attr["dtype"].CopyFrom(attr_value_pb2.AttrValue(
        type=dtypes.float32.as_datatype_enum))
    const_op2.attr["value"].CopyFrom(
        attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
            [3, 4], dtypes.float32, [1, 2])))

    # Create an add that has the two constants as inputs.
    add_op = input_graph_def.node.add()
    add_op.op = "Add"
    add_op.attr["T"].CopyFrom(attr_value_pb2.AttrValue(
        type=dtypes.float32.as_datatype_enum))
    add_op.name = "add_op"
    add_op.input.extend(["const_op1", "const_op2"])

    # Create a relu that reads from the add.
    relu_op = input_graph_def.node.add()
    relu_op.op = "Relu"
    relu_op.attr["T"].CopyFrom(attr_value_pb2.AttrValue(
        type=dtypes.float32.as_datatype_enum))
    relu_op.name = "relu_op"
    relu_op.input.extend(["add_op"])

    # We're specifying that add_op is the final output, and so the relu isn't
    # needed.
    input_names = []
    output_names = ["add_op"]
    transforms = ["strip_unused_nodes"]
    transformed_graph_def = TransformGraph(input_graph_def, input_names,
                                           output_names, transforms)

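    # TransformGraph returns a new GraphDef with the transforms applied; the
    # input graph_def built above is left unmodified (the wrapper serializes
    # it and parses the transformed result into a fresh proto).
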
    # We expect that the relu is no longer present after running the transform.
    for node in transformed_graph_def.node:
      self.assertNotEqual("Relu", node.op)
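
    # A stricter follow-up check (a sketch added here, not part of the
    # original test): the add_op named as the output should survive stripping.
    node_names = [node.name for node in transformed_graph_def.node]
    self.assertIn("add_op", node_names)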


if __name__ == "__main__":
  test.main()