xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/toco/tensorflow_graph_matching/resolve_cluster.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/tensorflow_graph_matching/resolve_cluster.h"
16 
17 #include <string>
18 #include <unordered_map>
19 #include <utility>
20 #include <vector>
21 
22 #include "tensorflow/core/framework/attr_value.pb.h"
23 #include "tensorflow/core/framework/function.pb.h"
24 #include "tensorflow/core/framework/graph.pb.h"
25 #include "tensorflow/core/framework/node_def.pb.h"
26 #include "tensorflow/lite/toco/tensorflow_graph_matching/cluster.h"
27 #include "tensorflow/lite/toco/tensorflow_graph_matching/cluster_utils.h"
28 #include "tensorflow/lite/toco/tensorflow_graph_matching/resolve_svdf.h"
29 #include "tensorflow/lite/toco/tooling_util.h"
30 
31 namespace toco {
32 
33 using tensorflow::GraphDef;
34 using tensorflow::NodeDef;
35 
AddNodeToGraph(const NodeDef & node,const std::vector<std::string> & cluster_names,GraphDef * graph)36 void AddNodeToGraph(const NodeDef& node,
37                     const std::vector<std::string>& cluster_names,
38                     GraphDef* graph) {
39   NodeDef* new_node = graph->add_node();
40   new_node->set_op(node.op());
41   new_node->set_name(node.name());
42   new_node->set_device(node.device());
43   // If the inputs are coming from a node which belongs to another cluster, then
44   // those inputs are renamed to the source cluster name. Otherwise the original
45   // input name is used.
46   for (const std::string& node_input : node.input()) {
47     bool input_from_cluster = false;
48     for (const std::string& cluster_name : cluster_names) {
49       if (StrContains(node_input, cluster_name) &&
50           !StrContains(node.name(), cluster_name)) {
51         new_node->add_input(cluster_name);
52         input_from_cluster = true;
53         break;
54       }
55     }
56     if (!input_from_cluster) {
57       new_node->add_input(node_input);
58     }
59   }
60   for (const auto& attr : node.attr()) {
61     (*new_node->mutable_attr())[attr.first] = attr.second;
62   }
63 }
64 
FindCluster(const ClusterFactoryInterface & cluster_factory,const GraphDef & graph_def,std::unordered_map<std::string,bool> * is_node_in_cluster,std::vector<std::unique_ptr<Cluster>> * clusters)65 bool FindCluster(const ClusterFactoryInterface& cluster_factory,
66                  const GraphDef& graph_def,
67                  std::unordered_map<std::string, bool>* is_node_in_cluster,
68                  std::vector<std::unique_ptr<Cluster>>* clusters) {
69   for (const NodeDef& node : graph_def.node()) {
70     // If the node is not assigned to any cluster, then we check if it belong to
71     // the cluster_factory.
72     bool node_in_cluster = (*is_node_in_cluster)[node.name()];
73     if (!node_in_cluster) {
74       std::unique_ptr<Cluster> cluster =
75           cluster_factory.CreateCluster(node, graph_def);
76       if (cluster) {
77         // Label all the nodes in is_node_in_cluster which are in this cluster
78         // as belonged to this cluster.
79         for (const NodeDef* cluster_node : cluster->GetNodes()) {
80           (*is_node_in_cluster)[cluster_node->name()] = true;
81         }
82         clusters->push_back(std::move(cluster));
83       }
84     }
85   }
86   return (!clusters->empty());
87 }
88 
MaybeResolveClusters(const GraphDef & graph_def,const std::vector<ClusterFactoryInterface * > & cluster_factories)89 std::unique_ptr<GraphDef> MaybeResolveClusters(
90     const GraphDef& graph_def,
91     const std::vector<ClusterFactoryInterface*>& cluster_factories) {
92   std::unique_ptr<GraphDef> pruned_graph(new GraphDef);
93   // The structure to keep track of which cluster each node is assigned to, and
94   // to initialize them to all un-assigned,
95   std::unordered_map<std::string, bool> is_node_in_cluster;
96   for (const NodeDef& node : graph_def.node()) {
97     is_node_in_cluster[node.name()] = false;
98   }
99 
100   std::vector<std::string> cluster_names;
101   std::vector<std::unique_ptr<Cluster>> all_clusters;
102   // Find the clusters for all available cluster factories.
103   for (const ClusterFactoryInterface* cluster_factory : cluster_factories) {
104     std::vector<std::unique_ptr<Cluster>> clusters;
105     if (FindCluster(*cluster_factory, graph_def, &is_node_in_cluster,
106                     &clusters)) {
107       for (auto itr = clusters.begin(); itr != clusters.end(); ++itr) {
108         cluster_names.push_back((*itr)->GetName());
109         (*itr)->CreateNodes();
110         all_clusters.push_back(std::move(*itr));
111       }
112     }
113   }
114 
115   for (const std::unique_ptr<Cluster>& cluster : all_clusters) {
116     for (const std::unique_ptr<tensorflow::NodeDef>& src_node :
117          cluster->GetNewNodes()) {
118       // Add it to the output GraphDef.
119       AddNodeToGraph(*src_node, cluster_names, pruned_graph.get());
120     }
121   }
122 
123   // Add any node which is not part of a cluster.
124   for (const NodeDef& node : graph_def.node()) {
125     bool node_in_cluster = is_node_in_cluster[node.name()];
126     if (!node_in_cluster) {
127       AddNodeToGraph(node, cluster_names, pruned_graph.get());
128     }
129   }
130 
131   if (pruned_graph->node_size() == 0) {
132     return nullptr;
133   } else {
134     return pruned_graph;
135   }
136 }
137 
MaybeReplaceCompositeSubgraph(const GraphDef & tf_graph)138 std::unique_ptr<GraphDef> MaybeReplaceCompositeSubgraph(
139     const GraphDef& tf_graph) {
140   SvdfClusterFactory svdf_cluster_factory;
141 
142   std::vector<ClusterFactoryInterface*> cluster_factories;
143   cluster_factories.push_back(&svdf_cluster_factory);
144 
145   std::unique_ptr<GraphDef> pruned_graph =
146       MaybeResolveClusters(tf_graph, cluster_factories);
147 
148   // Copy function definitions
149   if (pruned_graph) {
150     *(pruned_graph->mutable_library()) = tf_graph.library();
151   }
152   return pruned_graph;
153 }
154 
155 }  // end namespace toco
156