1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/tensorflow_graph_matching/resolve_cluster.h"
16
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
21
22 #include "tensorflow/core/framework/attr_value.pb.h"
23 #include "tensorflow/core/framework/function.pb.h"
24 #include "tensorflow/core/framework/graph.pb.h"
25 #include "tensorflow/core/framework/node_def.pb.h"
26 #include "tensorflow/lite/toco/tensorflow_graph_matching/cluster.h"
27 #include "tensorflow/lite/toco/tensorflow_graph_matching/cluster_utils.h"
28 #include "tensorflow/lite/toco/tensorflow_graph_matching/resolve_svdf.h"
29 #include "tensorflow/lite/toco/tooling_util.h"
30
31 namespace toco {
32
33 using tensorflow::GraphDef;
34 using tensorflow::NodeDef;
35
AddNodeToGraph(const NodeDef & node,const std::vector<std::string> & cluster_names,GraphDef * graph)36 void AddNodeToGraph(const NodeDef& node,
37 const std::vector<std::string>& cluster_names,
38 GraphDef* graph) {
39 NodeDef* new_node = graph->add_node();
40 new_node->set_op(node.op());
41 new_node->set_name(node.name());
42 new_node->set_device(node.device());
43 // If the inputs are coming from a node which belongs to another cluster, then
44 // those inputs are renamed to the source cluster name. Otherwise the original
45 // input name is used.
46 for (const std::string& node_input : node.input()) {
47 bool input_from_cluster = false;
48 for (const std::string& cluster_name : cluster_names) {
49 if (StrContains(node_input, cluster_name) &&
50 !StrContains(node.name(), cluster_name)) {
51 new_node->add_input(cluster_name);
52 input_from_cluster = true;
53 break;
54 }
55 }
56 if (!input_from_cluster) {
57 new_node->add_input(node_input);
58 }
59 }
60 for (const auto& attr : node.attr()) {
61 (*new_node->mutable_attr())[attr.first] = attr.second;
62 }
63 }
64
FindCluster(const ClusterFactoryInterface & cluster_factory,const GraphDef & graph_def,std::unordered_map<std::string,bool> * is_node_in_cluster,std::vector<std::unique_ptr<Cluster>> * clusters)65 bool FindCluster(const ClusterFactoryInterface& cluster_factory,
66 const GraphDef& graph_def,
67 std::unordered_map<std::string, bool>* is_node_in_cluster,
68 std::vector<std::unique_ptr<Cluster>>* clusters) {
69 for (const NodeDef& node : graph_def.node()) {
70 // If the node is not assigned to any cluster, then we check if it belong to
71 // the cluster_factory.
72 bool node_in_cluster = (*is_node_in_cluster)[node.name()];
73 if (!node_in_cluster) {
74 std::unique_ptr<Cluster> cluster =
75 cluster_factory.CreateCluster(node, graph_def);
76 if (cluster) {
77 // Label all the nodes in is_node_in_cluster which are in this cluster
78 // as belonged to this cluster.
79 for (const NodeDef* cluster_node : cluster->GetNodes()) {
80 (*is_node_in_cluster)[cluster_node->name()] = true;
81 }
82 clusters->push_back(std::move(cluster));
83 }
84 }
85 }
86 return (!clusters->empty());
87 }
88
MaybeResolveClusters(const GraphDef & graph_def,const std::vector<ClusterFactoryInterface * > & cluster_factories)89 std::unique_ptr<GraphDef> MaybeResolveClusters(
90 const GraphDef& graph_def,
91 const std::vector<ClusterFactoryInterface*>& cluster_factories) {
92 std::unique_ptr<GraphDef> pruned_graph(new GraphDef);
93 // The structure to keep track of which cluster each node is assigned to, and
94 // to initialize them to all un-assigned,
95 std::unordered_map<std::string, bool> is_node_in_cluster;
96 for (const NodeDef& node : graph_def.node()) {
97 is_node_in_cluster[node.name()] = false;
98 }
99
100 std::vector<std::string> cluster_names;
101 std::vector<std::unique_ptr<Cluster>> all_clusters;
102 // Find the clusters for all available cluster factories.
103 for (const ClusterFactoryInterface* cluster_factory : cluster_factories) {
104 std::vector<std::unique_ptr<Cluster>> clusters;
105 if (FindCluster(*cluster_factory, graph_def, &is_node_in_cluster,
106 &clusters)) {
107 for (auto itr = clusters.begin(); itr != clusters.end(); ++itr) {
108 cluster_names.push_back((*itr)->GetName());
109 (*itr)->CreateNodes();
110 all_clusters.push_back(std::move(*itr));
111 }
112 }
113 }
114
115 for (const std::unique_ptr<Cluster>& cluster : all_clusters) {
116 for (const std::unique_ptr<tensorflow::NodeDef>& src_node :
117 cluster->GetNewNodes()) {
118 // Add it to the output GraphDef.
119 AddNodeToGraph(*src_node, cluster_names, pruned_graph.get());
120 }
121 }
122
123 // Add any node which is not part of a cluster.
124 for (const NodeDef& node : graph_def.node()) {
125 bool node_in_cluster = is_node_in_cluster[node.name()];
126 if (!node_in_cluster) {
127 AddNodeToGraph(node, cluster_names, pruned_graph.get());
128 }
129 }
130
131 if (pruned_graph->node_size() == 0) {
132 return nullptr;
133 } else {
134 return pruned_graph;
135 }
136 }
137
MaybeReplaceCompositeSubgraph(const GraphDef & tf_graph)138 std::unique_ptr<GraphDef> MaybeReplaceCompositeSubgraph(
139 const GraphDef& tf_graph) {
140 SvdfClusterFactory svdf_cluster_factory;
141
142 std::vector<ClusterFactoryInterface*> cluster_factories;
143 cluster_factories.push_back(&svdf_cluster_factory);
144
145 std::unique_ptr<GraphDef> pruned_graph =
146 MaybeResolveClusters(tf_graph, cluster_factories);
147
148 // Copy function definitions
149 if (pruned_graph) {
150 *(pruned_graph->mutable_library()) = tf_graph.library();
151 }
152 return pruned_graph;
153 }
154
155 } // end namespace toco
156