# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Provides a proper python API for the symbols exported through swig."""

from tensorflow.python.grappler import _pywrap_cost_analyzer as tf_wrap
from tensorflow.python.grappler import cluster as gcluster
from tensorflow.python.grappler import item as gitem


def GenerateCostReport(metagraph,
                       per_node_report=False,
                       verbose=False,
                       cluster=None):
  """Analyze the cost of each TensorFlow op and node in the provided metagraph.

  Args:
    metagraph: A TensorFlow MetaGraphDef. It is serialized with
      `SerializeToString()` before being handed to the native analyzer.
    per_node_report: by default the report contains stats aggregated on a per op
      type basis, setting per_node_report to True adds results for each
      individual node to the report.
    verbose: Prints out the entire operation proto instead of a summary table.
    cluster: Analyze the costs using the specified cluster, or the local machine
      if no cluster was specified.

  Returns:
    A string containing the cost report, as produced by the native
    `GenerateCostReport` wrapper.
  """
  if cluster is None:
    # Detailed stats are explicitly kept enabled for the default cluster so the
    # cost analysis has per-op information to work with.
    cluster = gcluster.Cluster(disable_detailed_stats=False)

  return tf_wrap.GenerateCostReport(metagraph.SerializeToString(),
                                    per_node_report, verbose,
                                    cluster.tf_cluster)


def GenerateMemoryReport(metagraph, detailed_report=True, cluster=None):
  """Analyze the peak memory usage for the provided metagraph.

  Args:
    metagraph: A TensorFlow MetaGraphDef.
    detailed_report: if True, include the live tensors at the peak in addition
      to the peak memory usage of each device.
    cluster: Analyze the memory using the specified cluster, or the local
      machine if no cluster was specified.

  Returns:
    A string with the formatted memory usage: one "Peak usage for device ..."
    line per device, each optionally followed by one indented line per live
    tensor.
  """
  if cluster is None:
    # Only the memory analysis is needed here, so the default cluster is
    # created with detailed stats and the timeline disabled.
    cluster = gcluster.Cluster(
        disable_detailed_stats=True, disable_timeline=True)

  item = gitem.Item(metagraph)
  # Maps each device to a snapshot indexed as (peak_bytes, live_tensors), where
  # each live tensor is indexed as (op_name, output_id, mem_used). Shape
  # inferred from the indexing below; NOTE(review): confirm against
  # Cluster.DeterminePeakMemoryUsage.
  peak_usage_per_device = cluster.DeterminePeakMemoryUsage(item)

  # Collect the pieces and join once instead of repeatedly concatenating.
  report_lines = []
  for device, snapshot in peak_usage_per_device.items():
    device_peak = snapshot[0]
    report_lines.append("Peak usage for device " + device + ": " +
                        str(device_peak) + " bytes\n")
    if detailed_report:
      live_tensors = snapshot[1]
      for tensor in live_tensors:
        op_name = tensor[0]
        output_id = tensor[1]
        mem_used = tensor[2]
        report_lines.append("  " + str(op_name) + ":" + str(output_id) +
                            " uses " + str(mem_used) + " bytes\n")

  return "".join(report_lines)