xref: /aosp_15_r20/external/tensorflow/tensorflow/python/grappler/item.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A python interface for Grappler items."""
16
17from tensorflow.core.grappler.costs import op_performance_data_pb2
18from tensorflow.core.protobuf import meta_graph_pb2
19from tensorflow.python.grappler import _pywrap_tf_item as tf_item
20
21
class Item(object):
  """A Python wrapper around a Grappler item built from a MetaGraphDef.

  The C++ item handle is created from a serialized copy of the metagraph
  passed to the constructor. A private snapshot of that metagraph is kept so
  that, if the caller mutates the metagraph object afterwards, the handle is
  rebuilt lazily the next time it is needed (see the `tf_item` property).

  Note: inside the method bodies, the bare name `tf_item` refers to the
  module-level `_pywrap_tf_item` import, not to the property of the same name.
  """

  def __init__(self,
               metagraph,
               ignore_colocation=True,
               ignore_user_placement=False):
    """Creates an Item.

    Args:
      metagraph: a TensorFlow metagraph. The object is kept by reference, so
        later mutations of it are picked up on the next use of the item.
      ignore_colocation: if set, the tool will ignore all the colocation
        constraints generated by TensorFlow.
      ignore_user_placement: if set, all the placement annotations annotated in
        the metagraph will be ignored.

    Raises:
      ValueError: the metagraph is incomplete or invalid.
    """
    self._metagraph = metagraph
    # Snapshot of the metagraph the current C++ item was built from. It is
    # compared against `self._metagraph` to detect caller-side mutations.
    self._item_graph = meta_graph_pb2.MetaGraphDef()
    self._item_graph.CopyFrom(metagraph)
    self._ignore_colocation = ignore_colocation
    self._ignore_user_placement = ignore_user_placement
    self._tf_item = None
    self._BuildTFItem()

  def IdentifyImportantOps(self, sort_topologically=False):
    """Identifies the important ops of this item via the C++ wrapper.

    Args:
      sort_topologically: forwarded unchanged to `TF_IdentifyImportantOps`;
        presumably requests the result in topological order (confirm against
        the C++ wrapper).

    Returns:
      Whatever `TF_IdentifyImportantOps` returns. This is presumably a list of
      node names; verify against the pybind wrapper.
    """
    return tf_item.TF_IdentifyImportantOps(self.tf_item, sort_topologically)

  def GetOpProperties(self):
    """Get Op properties.

    Returns:
      A dict mapping each key reported by `TF_GetOpProperties` (presumably a
      node name; confirm) to a list of `OpInfo.TensorProperties` protos. The
      list holds one proto per serialized entry, in the order the wrapper
      returned them.
    """
    props = tf_item.TF_GetOpProperties(self.tf_item)
    # TODO(petebu): Make this conversion to a dictionary be done in the C++
    # wrapper for performance.
    parse = op_performance_data_pb2.OpInfo.TensorProperties.FromString
    return {
        key: [parse(value) for value in values]
        for key, values in props.items()
    }

  def GetColocationGroups(self):
    """Return a list of hard colocation constraints.

    All the nodes in a colocation tuple must be placed on the same device for
    the model to work.

    Returns:
      A list of colocation tuples.
    """
    return tf_item.TF_GetColocationGroups(self.tf_item)

  @property
  def metagraph(self):
    """The metagraph passed to the constructor (the same object, not a copy)."""
    return self._metagraph

  @property
  def tf_item(self):
    """The C++ Grappler item, rebuilt first if the metagraph has changed.

    The handle is rebuilt before the snapshot is refreshed. If rebuilding
    raises, the snapshot stays stale and the rebuild is retried on the next
    access.
    """
    if self._item_graph != self._metagraph:
      self._BuildTFItem()
      self._item_graph.CopyFrom(self._metagraph)
    return self._tf_item

  def _BuildTFItem(self):
    """(Re)creates the C++ item from the current contents of the metagraph."""
    self._tf_item = tf_item.TF_NewItem(self._metagraph.SerializeToString(),
                                       self._ignore_colocation,
                                       self._ignore_user_placement)
91