xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/toco/tensorflow_graph_matching/resolve_svdf.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/tensorflow_graph_matching/resolve_svdf.h"
16 
17 #include <ctype.h>
18 #include <stddef.h>
19 
20 #include <algorithm>
21 #include <memory>
22 #include <string>
23 #include <utility>
24 #include <vector>
25 
26 #include "google/protobuf/map.h"
27 #include "tensorflow/core/framework/attr_value.pb.h"
28 #include "tensorflow/core/framework/graph.pb.h"
29 #include "tensorflow/core/framework/node_def.pb.h"
30 #include "tensorflow/core/framework/tensor.pb.h"
31 #include "tensorflow/core/framework/tensor_shape.pb.h"
32 #include "tensorflow/core/platform/logging.h"
33 #include "tensorflow/lite/toco/model.h"
34 #include "tensorflow/lite/toco/tensorflow_graph_matching/cluster.h"
35 #include "tensorflow/lite/toco/tensorflow_graph_matching/cluster_utils.h"
36 #include "tensorflow/lite/toco/toco_port.h"
37 #include "tensorflow/lite/toco/tooling_util.h"
38 
39 using tensorflow::GraphDef;
40 using tensorflow::NodeDef;
41 
42 namespace toco {
43 
44 namespace {
45 
46 // Receives a vector of cluster nodes and returns only those which are array
47 // partitions (of type 'Const' and have the pattern 'part_<.*>' in their name.
48 // Since these nodes are connected to a Concatenate node, it makes sure the
49 // axis value input of the Concatenate operator is 0.
FilterPartitionedConstNodes(const std::string & const_pattern,const std::vector<const NodeDef * > & cluster_nodes,std::vector<const NodeDef * > * const_node_parts)50 void FilterPartitionedConstNodes(
51     const std::string& const_pattern,
52     const std::vector<const NodeDef*>& cluster_nodes,
53     std::vector<const NodeDef*>* const_node_parts) {
54   for (const NodeDef* node : cluster_nodes) {
55     std::string node_name_to_upper = node->name();
56     std::transform(node_name_to_upper.begin(), node_name_to_upper.end(),
57                    node_name_to_upper.begin(), ::toupper);
58     if (StrContains(node->name(), const_pattern) && node->op() == "Const") {
59       if (StrContains(node_name_to_upper, "/PART_")) {
60         const_node_parts->push_back(node);
61       } else if (StrContains(node->name(), "AXIS") &&
62                  StrContains(node->name(), "CONCAT")) {
63         // For now only supporting Concatenate on Axix 0
64         const auto& value_attr = node->attr().at("value");
65         const tensorflow::TensorProto& tensor = value_attr.tensor();
66         CHECK_EQ(tensor.int_val(0), 0);
67       }
68     }
69   }
70   std::sort(const_node_parts->begin(), const_node_parts->end(),
71             [](const NodeDef* a, const NodeDef* b) {
72               return (a->name().compare(b->name()) < 0 &&
73                       (a->name().size() < b->name().size()));
74             });
75 }
76 
77 }  // namespace
78 
79 // SvdfCluster methods
80 
InferFilterRank()81 int SvdfCluster::InferFilterRank() {
82   for (const NodeDef* node : nodes_) {
83     if (StrContains(node->name(), "Reshape/shape")) {
84       const auto& value_attr = node->attr().at("value");
85       const tensorflow::TensorProto& tensor = value_attr.tensor();
86       std::vector<int32> shape_values(
87           tensor.tensor_content().size() / sizeof(int), 0);
88       port::CopyToBuffer(tensor.tensor_content(),
89                          reinterpret_cast<char*>(shape_values.data()));
90       CHECK_EQ(shape_values.size(), 3);
91       // shape_value array is arranged as:
92       // [num_units, rank, -1]
93       CHECK_EQ(shape_values[2], -1);
94       return shape_values[1];
95     }
96   }
97   return -1;
98 }
99 
CreateNodes()100 void SvdfCluster::CreateNodes() {
101   for (const std::string& const_pattern : const_node_patterns_) {
102     CreateConstNode(const_pattern);
103   }
104   std::unique_ptr<tensorflow::NodeDef> svdf_node(new NodeDef);
105   svdf_node->set_op("Svdf");
106   svdf_node->set_name(name_);
107   svdf_node->set_device(device_);
108 
109   // Add the main input.
110   svdf_node->add_input(inputs_[0]);
111 
112   // Add the rest of the inputs to Svdf cell: weights and bias.
113   CHECK(new_nodes_.size() == 3 || new_nodes_.size() == 2);
114   std::string* weights_feature_input = svdf_node->add_input();
115   std::string* weights_time_input = svdf_node->add_input();
116   std::string* bias_input;
117   if (new_nodes_.size() == 3) {
118     bias_input = svdf_node->add_input();
119   }
120   for (const std::unique_ptr<tensorflow::NodeDef>& node : new_nodes_) {
121     const std::string node_name = node->name();
122     if (StrContains(node_name, "SVDF_weights_feature")) {
123       *weights_feature_input = node_name;
124     } else if (StrContains(node_name, "SVDF_weights_time")) {
125       *weights_time_input = node_name;
126     } else if (StrContains(node_name, "SVDF_bias")) {
127       CHECK(bias_input) << "Bias input cannot be provided when there are only "
128                            "two Const input nodes!";
129       *bias_input = node_name;
130     } else {
131       // Unexpected input for Svdf op.
132       LOG(FATAL) << "Unexpected input node for SVDF op! Accepted inputs are: "
133                     "weights_feature, weights_time and bias.";
134     }
135   }
136   const int rank = InferFilterRank();
137   CHECK_GT(rank, 0);
138 
139   // Add Svdf activation and rank.
140   std::string activation_function =
141       StrContains(outputs_[0], "Relu") ? "Relu" : "None";
142   (*svdf_node->mutable_attr())["ActivationFunction"].set_s(activation_function);
143   (*svdf_node->mutable_attr())["Rank"].set_i(rank);
144 
145   // Finally add it to the list of the newly created nodes.
146   new_nodes_.push_back(std::move(svdf_node));
147 }
148 
CreateConstNode(const std::string & const_pattern)149 void SvdfCluster::CreateConstNode(const std::string& const_pattern) {
150   // Find the nodes with pattern like: "const_pattern"/part_xxx of type Const.
151   std::vector<const NodeDef*> const_node_parts;
152   FilterPartitionedConstNodes(const_pattern, nodes_, &const_node_parts);
153 
154   if (const_node_parts.empty()) return;
155 
156   bool transpose_tensor_value =
157       StrContains(const_pattern, "SVDF_weights_feature");
158 
159   // Merge them if necessary.
160   std::unique_ptr<tensorflow::NodeDef> merged_node(new NodeDef);
161   MaybeMergeConstNodes(const_node_parts, transpose_tensor_value, merged_node);
162   new_nodes_.push_back(std::move(merged_node));
163 }
164 
MaybeMergeConstNodes(const std::vector<const NodeDef * > & const_node_parts,bool transpose_tensor_value,const std::unique_ptr<tensorflow::NodeDef> & merged_node)165 void SvdfCluster::MaybeMergeConstNodes(
166     const std::vector<const NodeDef*>& const_node_parts,
167     bool transpose_tensor_value,
168     const std::unique_ptr<tensorflow::NodeDef>& merged_node) {
169   merged_node->set_name(const_node_parts[0]->name());
170   merged_node->set_op("Const");
171   merged_node->set_device(const_node_parts[0]->device());
172   (*merged_node->mutable_attr())["dtype"].set_type(
173       const_node_parts[0]->attr().at("dtype").type());
174 
175   // Figuring out Value attribute for the merged node.
176   // Assuming the partitioning is done on Axis 0.
177   // The attributes which are inferred:
178   // * Shape and dimensions
179   // * Float content values
180 
181   // Inferring shape and dimension
182   int dim0_size = 0;
183   int dim1_size = 1;
184   tensorflow::TensorProto* allocated_tensor =
185       (*merged_node->mutable_attr())["value"].mutable_tensor();
186   tensorflow::TensorShapeProto* allocated_tensor_shape =
187       allocated_tensor->mutable_tensor_shape();
188   auto tensor_shape_dim0 = allocated_tensor_shape->add_dim();
189   int allocated_content_flat_size = 0;
190   for (size_t i = 0; i < const_node_parts.size(); i++) {
191     const auto& value_attr = const_node_parts[i]->attr().at("value");
192     const tensorflow::TensorProto& tensor = value_attr.tensor();
193     if (i == 0) {
194       allocated_tensor->set_dtype(tensor.dtype());
195     } else {
196       CHECK_EQ(allocated_tensor->dtype(), tensor.dtype());
197     }
198     allocated_content_flat_size += tensor.tensor_content().size();
199     CHECK(tensor.has_tensor_shape());
200     const tensorflow::TensorShapeProto shape = tensor.tensor_shape();
201     dim0_size += shape.dim(0).size();
202     for (int d = 1; d < shape.dim_size(); d++) {
203       if (i == 0) {
204         allocated_tensor_shape->add_dim()->set_size(shape.dim(d).size());
205         allocated_tensor_shape->set_unknown_rank(shape.unknown_rank());
206         dim1_size *= shape.dim(d).size();
207       } else {
208         CHECK_EQ(shape.dim(d).size(), allocated_tensor_shape->dim(d).size());
209         CHECK_EQ(allocated_tensor_shape->unknown_rank(), shape.unknown_rank());
210       }
211     }
212   }
213 
214   // Copying the float content from each array partition.
215   std::unique_ptr<char[]> allocated_content(
216       new char[allocated_content_flat_size]);
217   char* content_ptr = allocated_content.get();
218   for (size_t i = 0; i < const_node_parts.size(); i++) {
219     const auto& value_attr = const_node_parts[i]->attr().at("value");
220     const tensorflow::TensorProto& tensor = value_attr.tensor();
221     port::CopyToBuffer(tensor.tensor_content(), content_ptr);
222     content_ptr += tensor.tensor_content().size();
223   }
224 
225   // Transpose the tensor if needed.
226   if (transpose_tensor_value) {
227     // We use dimension 0 to show the row size for the tensor.
228     // We use multiplication of the rest of dimension size to for the col size
229     // of the tensor.
230     std::unique_ptr<float[]> transposed_tensor(
231         new float[dim0_size * dim1_size]);
232     Transpose2DTensor(reinterpret_cast<float*>(allocated_content.get()),
233                       dim0_size, dim1_size, transposed_tensor.get());
234     allocated_tensor_shape->clear_dim();
235     allocated_tensor_shape->add_dim()->set_size(dim1_size);
236     allocated_tensor_shape->add_dim()->set_size(dim0_size);
237 
238     // Set the tensor attributes.
239     allocated_tensor->set_tensor_content(
240         std::string(reinterpret_cast<const char*>(transposed_tensor.get()),
241                     allocated_content_flat_size));
242   } else {
243     tensor_shape_dim0->set_size(dim0_size);
244 
245     // Set the tensor attributes.
246     allocated_tensor->set_tensor_content(
247         std::string(reinterpret_cast<const char*>(allocated_content.get()),
248                     allocated_content_flat_size));
249   }
250 }
251 
252 // SvdfClusterFactory methods
253 
CreateCluster(const NodeDef & node,const GraphDef & graph_def) const254 std::unique_ptr<Cluster> SvdfClusterFactory::CreateCluster(
255     const NodeDef& node, const GraphDef& graph_def) const {
256   std::vector<std::string> node_patterns = {"SVDF_weights_feature",
257                                             "SVDF_weights_time", "SVDF_bias"};
258 
259   std::string node_name_to_upper = node.name();
260   std::transform(node_name_to_upper.begin(), node_name_to_upper.end(),
261                  node_name_to_upper.begin(), ::toupper);
262   std::unique_ptr<SvdfCluster> cluster = nullptr;
263   if (node_name_to_upper.find("SVDF", 0) != std::string::npos) {
264     size_t weights_pos = node.name().find(node_patterns[0]);
265     if (weights_pos != std::string::npos) {
266       // Assuming the node name has a pattern like:
267       // "SOMESTRING1/CELLNAME/SEARCH_PATTERN/SOMESTRING2", we use
268       // CELLNAME as the cluster name.
269       size_t cell_pos = node.name().rfind('/', weights_pos - 2) + 1;
270       std::string cell_name =
271           node.name().substr(cell_pos, weights_pos - cell_pos - 1);
272       cluster = std::make_unique<SvdfCluster>();
273       cluster->SetName(cell_name);
274       cluster->SetDevice(node.device());
275       cluster->SetGraphDefInfo(&graph_def);
276       CHECK(cluster->FindClusterInputsAndOutputs());
277 
278       for (const std::string& const_pattern : node_patterns) {
279         cluster->AddConstNodePattern(const_pattern);
280       }
281     }
282   }
283   return std::move(cluster);
284 }
285 
286 }  // end namespace toco
287