# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A python interface for Grappler items."""

from tensorflow.core.grappler.costs import op_performance_data_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.grappler import _pywrap_tf_item as tf_item


class Item(object):
  """GrapplerItem.

  Wraps a native Grappler item built from a metagraph. The native item is
  rebuilt lazily (on the next access to `tf_item`) whenever the metagraph
  passed to the constructor has been modified in place since the last build.
  """

  def __init__(self,
               metagraph,
               ignore_colocation=True,
               ignore_user_placement=False):
    """Creates an Item.

    Args:
      metagraph: a TensorFlow metagraph (`MetaGraphDef` proto). A reference to
        this object is kept, so later in-place edits to it are picked up the
        next time `tf_item` is read.
      ignore_colocation: if set, the tool will ignore all the colocation
        constraints generated by TensorFlow.
      ignore_user_placement: if set, all the placement annotations annotated in
        the metagraph will be ignored.

    Raises:
      ValueError: the metagraph is incomplete or invalid.
    """
    self._metagraph = metagraph
    # Private snapshot of the metagraph the current native item was built
    # from; `tf_item` compares it against `self._metagraph` to detect edits.
    snapshot = meta_graph_pb2.MetaGraphDef()
    snapshot.CopyFrom(metagraph)
    self._item_graph = snapshot
    self._ignore_colocation = ignore_colocation
    self._ignore_user_placement = ignore_user_placement
    self._tf_item = None
    self._BuildTFItem()

  def IdentifyImportantOps(self, sort_topologically=False):
    """Identifies the important ops of this item via the native wrapper.

    Args:
      sort_topologically: forwarded unchanged to `TF_IdentifyImportantOps`;
        presumably requests topologically ordered output (TODO confirm).

    Returns:
      Whatever `TF_IdentifyImportantOps` returns for the current native item.
    """
    native_item = self.tf_item
    return tf_item.TF_IdentifyImportantOps(native_item, sort_topologically)

  def GetOpProperties(self):
    """Get Op properties.

    Returns:
      A dict with the same keys as the mapping returned by
      `TF_GetOpProperties` (presumably node names; verify against the native
      wrapper). Each value is a list of `OpInfo.TensorProperties` protos
      parsed, in order, from the serialized entries for that key.
    """
    raw_props = tf_item.TF_GetOpProperties(self.tf_item)
    # TODO(petebu): Make this conversion to a dictionary be done in the C++
    # wrapper for performance.
    parse = op_performance_data_pb2.OpInfo.TensorProperties.FromString
    return {
        name: [parse(blob) for blob in serialized_list]
        for name, serialized_list in raw_props.items()
    }

  def GetColocationGroups(self):
    """Return a list of hard colocation constraints.

    All the nodes in a colocation tuple must be placed on the same device for
    the model to work.

    Returns:
      A list of colocation tuples, as produced by `TF_GetColocationGroups`.
    """
    native_item = self.tf_item
    return tf_item.TF_GetColocationGroups(native_item)

  @property
  def metagraph(self):
    """The metagraph object supplied to the constructor (not a copy)."""
    return self._metagraph

  @property
  def tf_item(self):
    """The native Grappler item, rebuilt first if the metagraph changed."""
    out_of_date = self._item_graph != self._metagraph
    if out_of_date:
      self._BuildTFItem()
      # Refresh the snapshot only after the rebuild succeeded, so a failed
      # rebuild is retried on the next access.
      self._item_graph.CopyFrom(self._metagraph)
    return self._tf_item

  def _BuildTFItem(self):
    """(Re)creates the native item from the current metagraph and flags."""
    serialized_graph = self._metagraph.SerializeToString()
    self._tf_item = tf_item.TF_NewItem(serialized_graph,
                                       self._ignore_colocation,
                                       self._ignore_user_placement)