xref: /aosp_15_r20/external/tensorflow/tensorflow/python/grappler/cost_analyzer.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""Provides a proper python API for the symbols exported through swig."""
16
17from tensorflow.python.grappler import _pywrap_cost_analyzer as tf_wrap
18from tensorflow.python.grappler import cluster as gcluster
19from tensorflow.python.grappler import item as gitem
20
21
def GenerateCostReport(metagraph,
                       per_node_report=False,
                       verbose=False,
                       cluster=None):
  """Build a textual cost report for the ops and nodes of a metagraph.

  The actual analysis is delegated to the native `GenerateCostReport`
  binding; this function only prepares its inputs.

  Args:
    metagraph: A TensorFlow MetaGraphDef.
    per_node_report: When False (the default) the report aggregates stats per
      op type; when True it additionally lists results for every individual
      node.
    verbose: Print the entire operation proto instead of a summary table.
    cluster: The cluster used to analyze costs. When None, a cluster for the
      local machine is created with detailed stats enabled.

  Returns:
    The cost report as a string.
  """
  target_cluster = (
      cluster if cluster is not None else
      gcluster.Cluster(disable_detailed_stats=False))

  # The native binding consumes the serialized proto, not the Python object.
  serialized_metagraph = metagraph.SerializeToString()
  native_cluster = target_cluster.tf_cluster
  return tf_wrap.GenerateCostReport(serialized_metagraph, per_node_report,
                                    verbose, native_cluster)
46
47
def GenerateMemoryReport(metagraph, detailed_report=True, cluster=None):
  """Analyze the peak memory usage for the provided metagraph.

  Args:
    metagraph: A TensorFlow MetaGraphDef.
    detailed_report: If True, list each device's live tensors in addition to
      that device's peak memory usage.
    cluster: Analyze the memory using the specified cluster, or the local
      machine if no cluster was specified.

  Returns:
    A string with the formatted memory usage. For each device it contains one
    line of the form "Peak usage for device <device>: <peak> bytes". When
    `detailed_report` is True, that line is followed by one indented line per
    live tensor: "  <op_name>:<output_id> uses <mem_used> bytes". Every line
    ends with a newline. The result is an empty string if no device is
    reported.
  """
  if cluster is None:
    # Only peak memory is needed, so the default cluster is created with
    # detailed stats and the timeline disabled.
    cluster = gcluster.Cluster(
        disable_detailed_stats=True, disable_timeline=True)

  item = gitem.Item(metagraph)
  # Maps device -> snapshot. As used below, snapshot[0] is the peak usage
  # and snapshot[1] an iterable of (op_name, output_id, mem_used)-like
  # entries; presumably all sizes are in bytes (per the report text).
  # This is kept under its own name and is never rebound inside the loop.
  usage_by_device = cluster.DeterminePeakMemoryUsage(item)

  # Collect lines and join once instead of growing a string with `+=`, which
  # can degrade to quadratic time on large reports.
  lines = []
  for device, snapshot in usage_by_device.items():
    device_peak = snapshot[0]
    lines.append("Peak usage for device {}: {} bytes\n".format(
        device, device_peak))
    if detailed_report:
      # snapshot[1] is only accessed when the detailed report is requested.
      for tensor in snapshot[1]:
        op_name, output_id, mem_used = tensor[0], tensor[1], tensor[2]
        lines.append("  {}:{} uses {} bytes\n".format(
            op_name, output_id, mem_used))

  return "".join(lines)
82