xref: /aosp_15_r20/external/tensorflow/tensorflow/python/grappler/cost_analyzer_tool.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""A tool for cost analysis."""
16
17import argparse
18import sys
19
20from absl import app
21
22from google.protobuf import message
23from google.protobuf import text_format
24from tensorflow.contrib.fused_conv.ops import gen_fused_conv2d_bias_activation_op  # pylint: disable=unused-import
25from tensorflow.core.framework import graph_pb2
26from tensorflow.core.protobuf import config_pb2
27from tensorflow.core.protobuf import meta_graph_pb2
28from tensorflow.core.protobuf import saved_model_pb2
29from tensorflow.python.framework import importer
30from tensorflow.python.framework import ops
31from tensorflow.python.grappler import cost_analyzer
32from tensorflow.python.grappler import tf_optimizer
33from tensorflow.python.platform import gfile
34from tensorflow.python.training import saver
35
36
def _parse_proto(input_data, proto):
  """Parses `input_data` into `proto`, trying text format first, then binary.

  Args:
    input_data: Raw file contents as bytes.
    proto: An empty protocol buffer message to populate in place.

  Returns:
    True if either the text or the binary parse succeeded, False otherwise.
  """
  try:
    # Text protos are read as UTF-8; binary data usually fails to decode here.
    text_format.Merge(input_data.decode("utf-8"), proto)
    return True
  except (UnicodeDecodeError, text_format.ParseError):
    pass
  try:
    # ParseFromString clears `proto` first, discarding any partial text merge.
    proto.ParseFromString(input_data)
    return True
  except message.DecodeError:
    return False


def _load_meta_graph(input_data):
  """Returns a MetaGraphDef decoded from `input_data`, or None if unrecognized.

  Formats are tried in order: SavedModel, MetaGraphDef, GraphDef.
  """
  saved_model = saved_model_pb2.SavedModel()
  # A SavedModel without meta graphs is unusable (and binary parsing is
  # lenient enough that other protos may decode as one), so fall through.
  if _parse_proto(input_data, saved_model) and saved_model.meta_graphs:
    return saved_model.meta_graphs[0]

  meta_graph = meta_graph_pb2.MetaGraphDef()
  if _parse_proto(input_data, meta_graph):
    return meta_graph

  graph_def = graph_pb2.GraphDef()
  if not _parse_proto(input_data, graph_def):
    return None
  # Import the bare GraphDef into the default graph so it can be exported as
  # a MetaGraphDef.
  importer.import_graph_def(graph_def, name="")
  graph = ops.get_default_graph()
  return saver.export_meta_graph(graph_def=graph.as_graph_def(), graph=graph)


def get_metagraph():
  """Constructs and returns a MetaGraphDef from the input file.

  The file named by `FLAGS.input` may contain a SavedModel, a MetaGraphDef or
  a GraphDef, each in either text or binary protobuf format. The formats are
  tried in that order, text before binary. A SavedModel contributes its first
  MetaGraphDef. A GraphDef is imported into the default graph and exported as
  a MetaGraphDef. If `FLAGS.fetch` is set, its comma-separated node names
  replace the "train_op" collection of the returned MetaGraphDef.

  Returns:
    A `meta_graph_pb2.MetaGraphDef`.

  Raises:
    ValueError: If no input file is given, or its contents are empty or cannot
      be parsed as any of the supported formats.
  """
  if not FLAGS.input:
    raise ValueError("No input file given; pass --input.")
  # Read raw bytes: binary protos are generally not valid UTF-8 text, and
  # ParseFromString expects bytes.
  with gfile.GFile(FLAGS.input, "rb") as input_file:
    input_data = input_file.read()
  if not input_data:
    raise ValueError(f"Invalid input file: {FLAGS.input}.")

  meta_graph = _load_meta_graph(input_data)
  if meta_graph is None:
    raise ValueError(f"Invalid input file: {FLAGS.input}.")

  if FLAGS.fetch is not None:
    fetch_collection = meta_graph_pb2.CollectionDef()
    # Ignore surrounding whitespace and empty entries (e.g. a trailing comma).
    fetch_collection.node_list.value.extend(
        name.strip() for name in FLAGS.fetch.split(",") if name.strip())
    meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)
  return meta_graph
75
76
def main(_):
  """Optimizes the input graph with Grappler and prints cost reports.

  Args:
    _: Argument list passed in by `app.run`; unused.
  """
  meta_graph = get_metagraph()

  session_config = config_pb2.ConfigProto()
  if FLAGS.rewriter_config is not None:
    # Overlay the user's text-format RewriterConfig onto the default options.
    text_format.Merge(FLAGS.rewriter_config,
                      session_config.graph_options.rewrite_options)

  # Swap in the optimized graph so the reports describe what Grappler produced.
  meta_graph.graph_def.CopyFrom(
      tf_optimizer.OptimizeGraph(session_config, meta_graph))

  print(cost_analyzer.GenerateCostReport(meta_graph, FLAGS.per_node_report,
                                         FLAGS.verbose))
  if not FLAGS.memory_report:
    return
  print(cost_analyzer.GenerateMemoryReport(meta_graph))
92
93
if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  # The tool cannot do anything without an input file, so fail fast here
  # instead of crashing later inside the file-reading code.
  parser.add_argument(
      "--input",
      type=str,
      required=True,
      help="Input file path. Accept SavedModel, MetaGraphDef, and GraphDef in "
      "either binary or text format.")
  parser.add_argument(
      "--fetch",
      type=str,
      default=None,
      help="The names of the fetch node delimited by comma.")
  # Example 1 uses the current RewriterConfig field `layout_optimizer`
  # (a Toggle enum); the old `optimize_tensor_layout` bool no longer parses.
  parser.add_argument(
      "--rewriter_config",
      type=str,
      default=None,
      help="Configuration for the grappler optimizers, described as a "
      "RewriterConfig protocol buffer. Usage example 1: "
      "--rewriter_config='layout_optimizer: ON "
      "disable_model_pruning: true'. Usage example 2: "
      "--rewriter_config='optimizers: \"constfold\" optimizers: \"layout\"'")
  parser.add_argument(
      "--per_node_report",
      action="store_true",
      help="Generate per-node report. By default the report contains stats "
      "aggregated on a per op type basis, per_node_report adds results "
      "for each individual node to the report.")
  parser.add_argument(
      "--memory_report",
      action="store_true",
      help="Generate memory usage report.")
  parser.add_argument(
      "--verbose",
      action="store_true",
      help="Generate verbose reports. By default, succinct reports are used.")
  # FLAGS is read as a module-level global by get_metagraph() and main().
  # Arguments not declared above are forwarded to absl's app.run.
  FLAGS, unparsed = parser.parse_known_args()
  app.run(main=main, argv=[sys.argv[0]] + unparsed)
132