xref: /aosp_15_r20/external/tensorflow/tensorflow/python/grappler/cluster.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A python interface for Grappler clusters."""
16
17import contextlib
18
19from tensorflow.core.framework import step_stats_pb2
20from tensorflow.core.grappler.costs import op_performance_data_pb2
21from tensorflow.core.protobuf import device_properties_pb2
22from tensorflow.python.grappler import _pywrap_tf_cluster as tf_cluster
23
24
class Cluster(object):
  """Grappler Clusters.

  Wraps a native Grappler cluster handle. The handle is created in the
  constructor and released by `Shutdown`, which `__del__` also calls. After
  shutdown, `tf_cluster` is None, `ListDevices` returns an empty list and
  `MeasureCosts` raises `RuntimeError`.
  """

  def __init__(self,
               allow_soft_placement=True,
               disable_detailed_stats=True,
               disable_timeline=True,
               devices=None):
    """Creates a Cluster.

    Args:
      allow_soft_placement: If True, TF will automatically fix illegal
        placements instead of erroring out if the placement isn't legal.
      disable_detailed_stats: If True, detailed statistics will not be
        available.
      disable_timeline: If True, the timeline information will not be reported.
      devices: A list of devices of type device_properties_pb2.NamedDevice.
        If None, a device list will be created based on the spec of
        the local machine.
    """
    # Assigned first so Shutdown()/__del__ find a defined attribute even if
    # creating the native cluster below raises.
    self._tf_cluster = None
    self._generate_timeline = not disable_timeline

    if devices is None:
      self._tf_cluster = tf_cluster.TF_NewCluster(allow_soft_placement,
                                                  disable_detailed_stats)
    else:
      # Only the serialized device protos are forwarded in this branch;
      # allow_soft_placement and disable_detailed_stats are not passed on.
      devices_serialized = [device.SerializeToString() for device in devices]
      self._tf_cluster = tf_cluster.TF_NewVirtualCluster(devices_serialized)

  def Shutdown(self):
    """Shuts down the native cluster, if any, and drops the handle.

    Calling this again after a successful shutdown is a no-op.
    """
    # getattr tolerates instances whose __init__ never ran (e.g. created via
    # __new__); __del__ calls this method and would otherwise raise.
    if getattr(self, '_tf_cluster', None) is not None:
      tf_cluster.TF_ShutdownCluster(self._tf_cluster)
      self._tf_cluster = None

  def __del__(self):
    self.Shutdown()

  @property
  def tf_cluster(self):
    """The native cluster handle, or None once the cluster is shut down."""
    return self._tf_cluster

  def ListDevices(self):
    """Returns a list of available hardware devices.

    Returns:
      A list of device_properties_pb2.NamedDevice protos, or an empty list if
      there is no native cluster (e.g. after `Shutdown`).
    """
    if self._tf_cluster is None:
      return []
    return [device_properties_pb2.NamedDevice.FromString(device)
            for device in tf_cluster.TF_ListDevices(self._tf_cluster)]

  def ListAvailableOps(self):
    """Returns a list of all available operations (sorted alphabetically)."""
    # Does not use the cluster handle, so it also works after Shutdown().
    return tf_cluster.TF_ListAvailableOps()

  def GetSupportedDevices(self, item):
    """Returns the devices of this cluster that can run the ops of `item`.

    Args:
      item: An object exposing a `tf_item` attribute (presumably a Grappler
        Item wrapper — confirm against callers).

    Returns:
      Whatever `TF_GetSupportedDevices` returns for this cluster and item.
    """
    return tf_cluster.TF_GetSupportedDevices(self._tf_cluster, item.tf_item)

  def EstimatePerformance(self, device):
    """Returns the native performance estimate for `device`.

    Args:
      device: A proto message; it is serialized with `SerializeToString()`
        before being handed to the native estimator.
    """
    return tf_cluster.TF_EstimatePerformance(device.SerializeToString())

  def MeasureCosts(self, item):
    """Returns the cost of running the specified item.

    Args:
      item: The item for which to measure the costs.
    Returns: The triplet op_perfs, runtime, step_stats.
    Raises:
      RuntimeError: If there is no native cluster (e.g. after `Shutdown`).
    """
    # Unlike ListDevices there is no meaningful empty result here, and a None
    # handle must not be forwarded to the native measurement code.
    if self._tf_cluster is None:
      raise RuntimeError(
          'MeasureCosts requires a live cluster, but this Cluster has no '
          'native cluster (it may have been shut down).')
    op_perf_bytes_list, run_time, step_stats_bytes = tf_cluster.TF_MeasureCosts(
        item.tf_item, self._tf_cluster, self._generate_timeline)

    op_perfs = [op_performance_data_pb2.OpPerformance.FromString(op_perf_bytes)
                for op_perf_bytes in op_perf_bytes_list]
    return (op_perfs, run_time,
            step_stats_pb2.StepStats.FromString(step_stats_bytes))

  def DeterminePeakMemoryUsage(self, item):
    """Returns a snapshot of the peak memory usage.

    Args:
      item: The item whose peak memory usage should be determined.
    Returns: A hashtable indexed by device name.
    """
    return tf_cluster.TF_DeterminePeakMemoryUsage(item.tf_item,
                                                  self._tf_cluster)
108
109
@contextlib.contextmanager
def Provision(allow_soft_placement=True,
              disable_detailed_stats=True,
              disable_timeline=True,
              devices=None):
  """Context manager that yields a `Cluster` and shuts it down on exit.

  Args:
    allow_soft_placement: Forwarded to `Cluster`.
    disable_detailed_stats: Forwarded to `Cluster`.
    disable_timeline: Forwarded to `Cluster`.
    devices: Forwarded to `Cluster`.

  Yields:
    The newly created `Cluster`.
  """
  cluster = Cluster(allow_soft_placement, disable_detailed_stats,
                    disable_timeline, devices)
  try:
    yield cluster
  finally:
    # Without finally, an exception raised inside the `with` body would skip
    # this call and leave the native cluster alive until garbage collection.
    cluster.Shutdown()
119