# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A python interface for Grappler clusters."""

import contextlib

from tensorflow.core.framework import step_stats_pb2
from tensorflow.core.grappler.costs import op_performance_data_pb2
from tensorflow.core.protobuf import device_properties_pb2
from tensorflow.python.grappler import _pywrap_tf_cluster as tf_cluster


class Cluster(object):
  """Python wrapper around a Grappler cluster.

  The cluster is built either from the local machine (when `devices` is None)
  or as a virtual cluster from an explicit list of device descriptions.
  Call `Shutdown()` (or use `Provision`) to release it; `__del__` also calls
  `Shutdown()` as a fallback.
  """

  def __init__(self,
               allow_soft_placement=True,
               disable_detailed_stats=True,
               disable_timeline=True,
               devices=None):
    """Creates a Cluster.

    Args:
      allow_soft_placement: If True, TF will automatically fix illegal
        placements instead of erroring out if the placement isn't legal.
      disable_detailed_stats: If True, detailed statistics will not be
        available.
      disable_timeline: If True, the timeline information will not be reported.
      devices: A list of devices of type device_properties_pb2.NamedDevice.
        If None, a device list will be created based on the spec of
        the local machine.
    """
    # Set before creating the cluster so Shutdown()/__del__ always find the
    # attribute, even if cluster creation below raises.
    self._tf_cluster = None
    self._generate_timeline = not disable_timeline

    if devices is None:
      self._tf_cluster = tf_cluster.TF_NewCluster(allow_soft_placement,
                                                  disable_detailed_stats)
    else:
      # The wrapper takes serialized device protos. Note that
      # allow_soft_placement and disable_detailed_stats are not forwarded on
      # this (virtual cluster) path.
      devices_serialized = [device.SerializeToString() for device in devices]
      self._tf_cluster = tf_cluster.TF_NewVirtualCluster(devices_serialized)

  def Shutdown(self):
    """Releases the underlying cluster. Safe to call more than once."""
    if self._tf_cluster is not None:
      tf_cluster.TF_ShutdownCluster(self._tf_cluster)
      self._tf_cluster = None

  def __del__(self):
    # Fallback cleanup for callers that never call Shutdown() explicitly.
    self.Shutdown()

  @property
  def tf_cluster(self):
    """The wrapped cluster handle, or None once Shutdown() has been called."""
    return self._tf_cluster

  def ListDevices(self):
    """Returns a list of available hardware devices.

    Returns:
      A list of `device_properties_pb2.NamedDevice` protos, or an empty list
      if the cluster has already been shut down.
    """
    if self._tf_cluster is None:
      return []
    return [device_properties_pb2.NamedDevice.FromString(device)
            for device in tf_cluster.TF_ListDevices(self._tf_cluster)]

  def ListAvailableOps(self):
    """Returns a list of all available operations (sorted alphabetically)."""
    return tf_cluster.TF_ListAvailableOps()

  def GetSupportedDevices(self, item):
    """Returns the devices this cluster supports for `item`.

    Args:
      item: An object exposing a `tf_item` attribute (presumably a Grappler
        `Item` wrapper -- confirm against callers).

    Returns:
      The result of `TF_GetSupportedDevices` for the item on this cluster.
      NOTE(review): after Shutdown() this passes a None cluster handle to the
      wrapper; confirm whether callers rely on that.
    """
    return tf_cluster.TF_GetSupportedDevices(self._tf_cluster, item.tf_item)

  def EstimatePerformance(self, device):
    """Returns a performance estimate for a single device.

    This does not use the cluster handle; it only depends on `device`.

    Args:
      device: A proto describing the device (presumably
        `device_properties_pb2.DeviceProperties` -- confirm). It is serialized
        before being handed to the wrapper.

    Returns:
      The value reported by `TF_EstimatePerformance`.
    """
    return tf_cluster.TF_EstimatePerformance(device.SerializeToString())

  def MeasureCosts(self, item):
    """Returns the cost of running the specified item.

    Args:
      item: The item for which to measure the costs.

    Returns:
      The triplet `(op_perfs, run_time, step_stats)`:
        op_perfs: a list of `op_performance_data_pb2.OpPerformance` protos.
        run_time: the run time value reported by the wrapper.
        step_stats: a `step_stats_pb2.StepStats` proto.
      Timeline data is only requested when the cluster was created with
      `disable_timeline=False`.
    """
    op_perf_bytes_list, run_time, step_stats_bytes = tf_cluster.TF_MeasureCosts(
        item.tf_item, self._tf_cluster, self._generate_timeline)

    op_perfs = [op_performance_data_pb2.OpPerformance.FromString(op_perf_bytes)
                for op_perf_bytes in op_perf_bytes_list]
    return (op_perfs, run_time,
            step_stats_pb2.StepStats.FromString(step_stats_bytes))

  def DeterminePeakMemoryUsage(self, item):
    """Returns a snapshot of the peak memory usage.

    Args:
      item: The item for which to determine the peak memory usage.

    Returns:
      A hashtable indexed by device name.
    """
    return tf_cluster.TF_DeterminePeakMemoryUsage(item.tf_item,
                                                  self._tf_cluster)


@contextlib.contextmanager
def Provision(allow_soft_placement=True,
              disable_detailed_stats=True,
              disable_timeline=True,
              devices=None):
  """Creates a `Cluster` for the duration of a `with` block.

  Args:
    allow_soft_placement: Forwarded to `Cluster`.
    disable_detailed_stats: Forwarded to `Cluster`.
    disable_timeline: Forwarded to `Cluster`.
    devices: Forwarded to `Cluster`.

  Yields:
    The newly created `Cluster`. It is shut down when the block exits.
  """
  cluster = Cluster(allow_soft_placement, disable_detailed_stats,
                    disable_timeline, devices)
  try:
    yield cluster
  finally:
    # Shut down even when the `with` body raises, so the cluster is released
    # promptly instead of waiting for Cluster.__del__.
    cluster.Shutdown()