# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""A tool for cost analysis."""

import argparse
import sys

from absl import app

from google.protobuf import message
from google.protobuf import text_format
from tensorflow.contrib.fused_conv.ops import gen_fused_conv2d_bias_activation_op  # pylint: disable=unused-import
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.grappler import cost_analyzer
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.platform import gfile
from tensorflow.python.training import saver


def _parse_proto(input_data, proto):
  """Fills `proto` from `input_data`, trying text format and then binary.

  Args:
    input_data: Raw bytes read from the input file.
    proto: Protocol buffer message instance that is populated in place.

  Returns:
    True if `input_data` could be parsed into `proto` in either the text or
    the binary wire format, False otherwise.
  """
  try:
    text_format.Merge(input_data, proto)
    return True
  except (text_format.ParseError, UnicodeDecodeError):
    # A binary payload is usually neither valid text format nor valid UTF-8;
    # both failures just mean "not text", so fall through to binary parsing.
    pass
  try:
    # ParseFromString clears the message first, which discards any fields a
    # partially successful text merge may have left behind.
    proto.ParseFromString(input_data)
    return True
  except message.DecodeError:
    return False


def get_metagraph():
  """Constructs and returns a MetaGraphDef from the input file.

  The file named by `FLAGS.input` is tried, in order, as a SavedModel, a
  MetaGraphDef and a GraphDef, each first in text format and then in binary
  format. A GraphDef is imported into the default graph and exported as a
  MetaGraphDef. If `FLAGS.fetch` is set, its comma separated node names are
  stored in the "train_op" collection of the result.

  Returns:
    A MetaGraphDef protocol buffer.

  Raises:
    ValueError: If the input file cannot be parsed in any supported format.
  """
  # Read raw bytes: the default text mode would UTF-8 decode the contents,
  # which breaks on binary serialized protos.
  with gfile.GFile(FLAGS.input, "rb") as input_file:
    input_data = input_file.read()

  saved_model = saved_model_pb2.SavedModel()
  # A SavedModel without any meta graph is treated as "not a SavedModel" so
  # that the remaining formats still get a chance instead of raising
  # IndexError on meta_graphs[0].
  if _parse_proto(input_data, saved_model) and saved_model.meta_graphs:
    meta_graph = saved_model.meta_graphs[0]
  else:
    meta_graph = meta_graph_pb2.MetaGraphDef()
    if not _parse_proto(input_data, meta_graph):
      graph_def = graph_pb2.GraphDef()
      if not _parse_proto(input_data, graph_def):
        raise ValueError(f"Invalid input file: {FLAGS.input}.")
      # A bare GraphDef has to be wrapped: import it into the default graph
      # and export that graph as a MetaGraphDef.
      importer.import_graph_def(graph_def, name="")
      graph = ops.get_default_graph()
      meta_graph = saver.export_meta_graph(
          graph_def=graph.as_graph_def(), graph=graph)

  if FLAGS.fetch is not None:
    fetch_collection = meta_graph_pb2.CollectionDef()
    fetch_collection.node_list.value.extend(FLAGS.fetch.split(","))
    meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)
  return meta_graph


def main(_):
  """Optimizes the input graph and prints its cost (and memory) report.

  Args:
    _: Unused positional arguments forwarded by `app.run`.
  """
  metagraph = get_metagraph()
  config = config_pb2.ConfigProto()
  if FLAGS.rewriter_config is not None:
    # The flag holds a text-format RewriterConfig; merge it into the session
    # config that is handed to the optimizer.
    text_format.Merge(FLAGS.rewriter_config,
                      config.graph_options.rewrite_options)
  optimized_graph = tf_optimizer.OptimizeGraph(config, metagraph)
  # The reports below are generated for the optimized graph.
  metagraph.graph_def.CopyFrom(optimized_graph)

  report = cost_analyzer.GenerateCostReport(metagraph, FLAGS.per_node_report,
                                            FLAGS.verbose)
  print(report)
  if FLAGS.memory_report:
    report = cost_analyzer.GenerateMemoryReport(metagraph)
    print(report)


if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument(
      "--input",
      type=str,
      default=None,
      # Without an input path there is nothing to analyze; fail at argument
      # parsing time with a clear message instead of inside the file reader.
      required=True,
      help="Input file path. Accept SavedModel, MetaGraphDef, and GraphDef in "
      "either binary or text format.")
  parser.add_argument(
      "--fetch",
      type=str,
      default=None,
      help="The names of the fetch node delimited by comma.")
  parser.add_argument(
      "--rewriter_config",
      type=str,
      default=None,
      help="Configuration for the grappler optimizers, described as a "
      "RewriterConfig protocol buffer. Usage example 1: "
      "--rewriter_config='optimize_tensor_layout: true "
      "disable_model_pruning: true'. Usage example 2: "
      "--rewriter_config='optimizers: \"constfold\" optimizers: \"layout\"'")
  parser.add_argument(
      "--per_node_report",
      action="store_true",
      help="Generate per-node report. By default the report contains stats "
      "aggregated on a per op type basis, per_node_report adds results "
      "for each individual node to the report.")
  parser.add_argument(
      "--memory_report",
      action="store_true",
      help="Generate memory usage report.")
  parser.add_argument(
      "--verbose",
      action="store_true",
      help="Generate verbose reports. By default, succinct reports are used.")
  # FLAGS is a module-level global read by get_metagraph() and main(); flags
  # not recognized here are forwarded to absl.
  FLAGS, unparsed = parser.parse_known_args()
  app.run(main=main, argv=[sys.argv[0]] + unparsed)