1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/tensorflow_graph_matching/resolve_svdf.h"
16
17 #include <ctype.h>
18 #include <stddef.h>
19
20 #include <algorithm>
21 #include <memory>
22 #include <string>
23 #include <utility>
24 #include <vector>
25
26 #include "google/protobuf/map.h"
27 #include "tensorflow/core/framework/attr_value.pb.h"
28 #include "tensorflow/core/framework/graph.pb.h"
29 #include "tensorflow/core/framework/node_def.pb.h"
30 #include "tensorflow/core/framework/tensor.pb.h"
31 #include "tensorflow/core/framework/tensor_shape.pb.h"
32 #include "tensorflow/core/platform/logging.h"
33 #include "tensorflow/lite/toco/model.h"
34 #include "tensorflow/lite/toco/tensorflow_graph_matching/cluster.h"
35 #include "tensorflow/lite/toco/tensorflow_graph_matching/cluster_utils.h"
36 #include "tensorflow/lite/toco/toco_port.h"
37 #include "tensorflow/lite/toco/tooling_util.h"
38
39 using tensorflow::GraphDef;
40 using tensorflow::NodeDef;
41
42 namespace toco {
43
44 namespace {
45
46 // Receives a vector of cluster nodes and returns only those which are array
47 // partitions (of type 'Const' and have the pattern 'part_<.*>' in their name.
48 // Since these nodes are connected to a Concatenate node, it makes sure the
49 // axis value input of the Concatenate operator is 0.
FilterPartitionedConstNodes(const std::string & const_pattern,const std::vector<const NodeDef * > & cluster_nodes,std::vector<const NodeDef * > * const_node_parts)50 void FilterPartitionedConstNodes(
51 const std::string& const_pattern,
52 const std::vector<const NodeDef*>& cluster_nodes,
53 std::vector<const NodeDef*>* const_node_parts) {
54 for (const NodeDef* node : cluster_nodes) {
55 std::string node_name_to_upper = node->name();
56 std::transform(node_name_to_upper.begin(), node_name_to_upper.end(),
57 node_name_to_upper.begin(), ::toupper);
58 if (StrContains(node->name(), const_pattern) && node->op() == "Const") {
59 if (StrContains(node_name_to_upper, "/PART_")) {
60 const_node_parts->push_back(node);
61 } else if (StrContains(node->name(), "AXIS") &&
62 StrContains(node->name(), "CONCAT")) {
63 // For now only supporting Concatenate on Axix 0
64 const auto& value_attr = node->attr().at("value");
65 const tensorflow::TensorProto& tensor = value_attr.tensor();
66 CHECK_EQ(tensor.int_val(0), 0);
67 }
68 }
69 }
70 std::sort(const_node_parts->begin(), const_node_parts->end(),
71 [](const NodeDef* a, const NodeDef* b) {
72 return (a->name().compare(b->name()) < 0 &&
73 (a->name().size() < b->name().size()));
74 });
75 }
76
77 } // namespace
78
79 // SvdfCluster methods
80
InferFilterRank()81 int SvdfCluster::InferFilterRank() {
82 for (const NodeDef* node : nodes_) {
83 if (StrContains(node->name(), "Reshape/shape")) {
84 const auto& value_attr = node->attr().at("value");
85 const tensorflow::TensorProto& tensor = value_attr.tensor();
86 std::vector<int32> shape_values(
87 tensor.tensor_content().size() / sizeof(int), 0);
88 port::CopyToBuffer(tensor.tensor_content(),
89 reinterpret_cast<char*>(shape_values.data()));
90 CHECK_EQ(shape_values.size(), 3);
91 // shape_value array is arranged as:
92 // [num_units, rank, -1]
93 CHECK_EQ(shape_values[2], -1);
94 return shape_values[1];
95 }
96 }
97 return -1;
98 }
99
CreateNodes()100 void SvdfCluster::CreateNodes() {
101 for (const std::string& const_pattern : const_node_patterns_) {
102 CreateConstNode(const_pattern);
103 }
104 std::unique_ptr<tensorflow::NodeDef> svdf_node(new NodeDef);
105 svdf_node->set_op("Svdf");
106 svdf_node->set_name(name_);
107 svdf_node->set_device(device_);
108
109 // Add the main input.
110 svdf_node->add_input(inputs_[0]);
111
112 // Add the rest of the inputs to Svdf cell: weights and bias.
113 CHECK(new_nodes_.size() == 3 || new_nodes_.size() == 2);
114 std::string* weights_feature_input = svdf_node->add_input();
115 std::string* weights_time_input = svdf_node->add_input();
116 std::string* bias_input;
117 if (new_nodes_.size() == 3) {
118 bias_input = svdf_node->add_input();
119 }
120 for (const std::unique_ptr<tensorflow::NodeDef>& node : new_nodes_) {
121 const std::string node_name = node->name();
122 if (StrContains(node_name, "SVDF_weights_feature")) {
123 *weights_feature_input = node_name;
124 } else if (StrContains(node_name, "SVDF_weights_time")) {
125 *weights_time_input = node_name;
126 } else if (StrContains(node_name, "SVDF_bias")) {
127 CHECK(bias_input) << "Bias input cannot be provided when there are only "
128 "two Const input nodes!";
129 *bias_input = node_name;
130 } else {
131 // Unexpected input for Svdf op.
132 LOG(FATAL) << "Unexpected input node for SVDF op! Accepted inputs are: "
133 "weights_feature, weights_time and bias.";
134 }
135 }
136 const int rank = InferFilterRank();
137 CHECK_GT(rank, 0);
138
139 // Add Svdf activation and rank.
140 std::string activation_function =
141 StrContains(outputs_[0], "Relu") ? "Relu" : "None";
142 (*svdf_node->mutable_attr())["ActivationFunction"].set_s(activation_function);
143 (*svdf_node->mutable_attr())["Rank"].set_i(rank);
144
145 // Finally add it to the list of the newly created nodes.
146 new_nodes_.push_back(std::move(svdf_node));
147 }
148
CreateConstNode(const std::string & const_pattern)149 void SvdfCluster::CreateConstNode(const std::string& const_pattern) {
150 // Find the nodes with pattern like: "const_pattern"/part_xxx of type Const.
151 std::vector<const NodeDef*> const_node_parts;
152 FilterPartitionedConstNodes(const_pattern, nodes_, &const_node_parts);
153
154 if (const_node_parts.empty()) return;
155
156 bool transpose_tensor_value =
157 StrContains(const_pattern, "SVDF_weights_feature");
158
159 // Merge them if necessary.
160 std::unique_ptr<tensorflow::NodeDef> merged_node(new NodeDef);
161 MaybeMergeConstNodes(const_node_parts, transpose_tensor_value, merged_node);
162 new_nodes_.push_back(std::move(merged_node));
163 }
164
MaybeMergeConstNodes(const std::vector<const NodeDef * > & const_node_parts,bool transpose_tensor_value,const std::unique_ptr<tensorflow::NodeDef> & merged_node)165 void SvdfCluster::MaybeMergeConstNodes(
166 const std::vector<const NodeDef*>& const_node_parts,
167 bool transpose_tensor_value,
168 const std::unique_ptr<tensorflow::NodeDef>& merged_node) {
169 merged_node->set_name(const_node_parts[0]->name());
170 merged_node->set_op("Const");
171 merged_node->set_device(const_node_parts[0]->device());
172 (*merged_node->mutable_attr())["dtype"].set_type(
173 const_node_parts[0]->attr().at("dtype").type());
174
175 // Figuring out Value attribute for the merged node.
176 // Assuming the partitioning is done on Axis 0.
177 // The attributes which are inferred:
178 // * Shape and dimensions
179 // * Float content values
180
181 // Inferring shape and dimension
182 int dim0_size = 0;
183 int dim1_size = 1;
184 tensorflow::TensorProto* allocated_tensor =
185 (*merged_node->mutable_attr())["value"].mutable_tensor();
186 tensorflow::TensorShapeProto* allocated_tensor_shape =
187 allocated_tensor->mutable_tensor_shape();
188 auto tensor_shape_dim0 = allocated_tensor_shape->add_dim();
189 int allocated_content_flat_size = 0;
190 for (size_t i = 0; i < const_node_parts.size(); i++) {
191 const auto& value_attr = const_node_parts[i]->attr().at("value");
192 const tensorflow::TensorProto& tensor = value_attr.tensor();
193 if (i == 0) {
194 allocated_tensor->set_dtype(tensor.dtype());
195 } else {
196 CHECK_EQ(allocated_tensor->dtype(), tensor.dtype());
197 }
198 allocated_content_flat_size += tensor.tensor_content().size();
199 CHECK(tensor.has_tensor_shape());
200 const tensorflow::TensorShapeProto shape = tensor.tensor_shape();
201 dim0_size += shape.dim(0).size();
202 for (int d = 1; d < shape.dim_size(); d++) {
203 if (i == 0) {
204 allocated_tensor_shape->add_dim()->set_size(shape.dim(d).size());
205 allocated_tensor_shape->set_unknown_rank(shape.unknown_rank());
206 dim1_size *= shape.dim(d).size();
207 } else {
208 CHECK_EQ(shape.dim(d).size(), allocated_tensor_shape->dim(d).size());
209 CHECK_EQ(allocated_tensor_shape->unknown_rank(), shape.unknown_rank());
210 }
211 }
212 }
213
214 // Copying the float content from each array partition.
215 std::unique_ptr<char[]> allocated_content(
216 new char[allocated_content_flat_size]);
217 char* content_ptr = allocated_content.get();
218 for (size_t i = 0; i < const_node_parts.size(); i++) {
219 const auto& value_attr = const_node_parts[i]->attr().at("value");
220 const tensorflow::TensorProto& tensor = value_attr.tensor();
221 port::CopyToBuffer(tensor.tensor_content(), content_ptr);
222 content_ptr += tensor.tensor_content().size();
223 }
224
225 // Transpose the tensor if needed.
226 if (transpose_tensor_value) {
227 // We use dimension 0 to show the row size for the tensor.
228 // We use multiplication of the rest of dimension size to for the col size
229 // of the tensor.
230 std::unique_ptr<float[]> transposed_tensor(
231 new float[dim0_size * dim1_size]);
232 Transpose2DTensor(reinterpret_cast<float*>(allocated_content.get()),
233 dim0_size, dim1_size, transposed_tensor.get());
234 allocated_tensor_shape->clear_dim();
235 allocated_tensor_shape->add_dim()->set_size(dim1_size);
236 allocated_tensor_shape->add_dim()->set_size(dim0_size);
237
238 // Set the tensor attributes.
239 allocated_tensor->set_tensor_content(
240 std::string(reinterpret_cast<const char*>(transposed_tensor.get()),
241 allocated_content_flat_size));
242 } else {
243 tensor_shape_dim0->set_size(dim0_size);
244
245 // Set the tensor attributes.
246 allocated_tensor->set_tensor_content(
247 std::string(reinterpret_cast<const char*>(allocated_content.get()),
248 allocated_content_flat_size));
249 }
250 }
251
252 // SvdfClusterFactory methods
253
CreateCluster(const NodeDef & node,const GraphDef & graph_def) const254 std::unique_ptr<Cluster> SvdfClusterFactory::CreateCluster(
255 const NodeDef& node, const GraphDef& graph_def) const {
256 std::vector<std::string> node_patterns = {"SVDF_weights_feature",
257 "SVDF_weights_time", "SVDF_bias"};
258
259 std::string node_name_to_upper = node.name();
260 std::transform(node_name_to_upper.begin(), node_name_to_upper.end(),
261 node_name_to_upper.begin(), ::toupper);
262 std::unique_ptr<SvdfCluster> cluster = nullptr;
263 if (node_name_to_upper.find("SVDF", 0) != std::string::npos) {
264 size_t weights_pos = node.name().find(node_patterns[0]);
265 if (weights_pos != std::string::npos) {
266 // Assuming the node name has a pattern like:
267 // "SOMESTRING1/CELLNAME/SEARCH_PATTERN/SOMESTRING2", we use
268 // CELLNAME as the cluster name.
269 size_t cell_pos = node.name().rfind('/', weights_pos - 2) + 1;
270 std::string cell_name =
271 node.name().substr(cell_pos, weights_pos - cell_pos - 1);
272 cluster = std::make_unique<SvdfCluster>();
273 cluster->SetName(cell_name);
274 cluster->SetDevice(node.device());
275 cluster->SetGraphDefInfo(&graph_def);
276 CHECK(cluster->FindClusterInputsAndOutputs());
277
278 for (const std::string& const_pattern : node_patterns) {
279 cluster->AddConstNodePattern(const_pattern);
280 }
281 }
282 }
283 return std::move(cluster);
284 }
285
286 } // end namespace toco
287