/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"

namespace toco {

::tensorflow::Status UnpartitionEmbeddingLookup::Run(Model* model,
                                                     std::size_t op_index,
                                                     bool* modified) {
  *modified = false;
  // Collapses a partitioned tf.nn.embedding_lookup back into a single Gather.
  // https://www.tensorflow.org/api_docs/python/tf/nn/embedding_lookup
  // This transform attempts to identify the len(params) > 1 case and collapse
  // it to the len(params) = 1 case by concatenating the original params and
  // reversing the partitioning.
  //
  // If len(params) to the tf.nn.embedding_lookup == 1, the whole op becomes
  // simply a gather:
  // https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/python/ops/embedding_ops.py#L150
  //
  // Notes on this implementation:
  // - only supports partition_strategy='mod'
  //
  // A rough graph of a partitioned embedding_lookup looks like:
  //   (ids)--+-->FloorDiv--+-->DynamicPartition-->[[Gather]]--\
  //          \-->FloorMod--/                                  |
  //                  V                                        |
  //   Range-->DynamicPartition-------->DynamicStitch<---------/
  //   (const)                                V
  //                                    (embeddings)
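  //
  // With partition_strategy='mod' and, say, 3 partitions, id 7 is served by
  // partition 7 % 3 == 1 at row 7 / 3 == 2; the FloorMod/FloorDiv pair
  // matched below computes exactly this selector/offset split.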

  // First look for the final DynamicStitch.
  auto op_it = model->operators.begin() + op_index;
  if (op_it->get()->type != OperatorType::kDynamicStitch) {
    return ::tensorflow::OkStatus();
  }
  auto* stitch_op = static_cast<DynamicStitchOperator*>(op_it->get());

  // Split up the DynamicStitch inputs into the indices and data.
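  // DynamicStitch takes its num_partitions indices tensors first, followed by
  // the num_partitions data tensors, so inputs [0, N) are indices and
  // inputs [N, 2N) are data.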
  std::vector<std::string> stitch_indices_inputs;
  std::vector<std::string> stitch_data_inputs;
  stitch_indices_inputs.reserve(stitch_op->num_partitions);
  for (int i = 0; i < stitch_op->num_partitions; ++i) {
    stitch_indices_inputs.push_back(stitch_op->inputs[i]);
  }
  for (int i = stitch_op->num_partitions; i < stitch_op->num_partitions * 2;
       ++i) {
    stitch_data_inputs.push_back(stitch_op->inputs[i]);
  }

  // Validate all indices come from the same DynamicPartition.
  DynamicPartitionOperator* indices_partition_op = nullptr;
  for (const std::string& indices_partition_output_name :
       stitch_indices_inputs) {
    auto* op = GetOpWithOutput(*model, indices_partition_output_name);
    CHECK(op) << "Source of " << indices_partition_output_name << " not found";
    if (op->type != OperatorType::kDynamicPartition) {
      AddMessageF(
          "Skipping because indices input %s into "
          "%s is unexpected",
          LogName(*op), LogName(*stitch_op));
      return ::tensorflow::OkStatus();
    }
    if (!indices_partition_op) {
      indices_partition_op = static_cast<DynamicPartitionOperator*>(op);
    } else {
      // Ensure this is the same op as previous ones.
      if (op != indices_partition_op) {
        AddMessageF(
            "Skipping because indices input %s into "
            "%s is from a different source op than others",
            LogName(*op), LogName(*stitch_op));
        return ::tensorflow::OkStatus();
      }
    }
  }
  CHECK(indices_partition_op) << "No indices inputs";

  // The data for the indices must be a constant range of the array shape.
  if (!IsConstantParameterArray(*model, indices_partition_op->inputs[0])) {
    AddMessageF("Skipping because indices partition data is non-constant");
    return ::tensorflow::OkStatus();
  }
  auto& indices_data_array = model->GetArray(indices_partition_op->inputs[0]);
  if (indices_data_array.data_type == ArrayDataType::kNone) {
    // Yield until data types are propagated.
    return ::tensorflow::OkStatus();
  }
  CHECK(indices_data_array.data_type == ArrayDataType::kInt32)
      << "Indices partition inputs must be int32";
  const auto& indices_data_buffer =
      indices_data_array.GetBuffer<ArrayDataType::kInt32>().data;
  for (size_t i = 0; i < indices_data_buffer.size(); ++i) {
    CHECK_EQ(indices_data_buffer[i], i) << "Indices range must be identity";
  }
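  // An identity range means the indices partition is just scattering the
  // positions 0..N-1 of the ids across partitions, so the DynamicStitch only
  // undoes the partitioning and can be folded away once the gathers merge.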

  // Find all of the gathers used for the data inputs.
  std::vector<GatherOperator*> gather_ops;
  for (const std::string& gather_output_name : stitch_data_inputs) {
    auto* op = GetOpWithOutput(*model, gather_output_name);
    CHECK(op) << "Source of " << gather_output_name << " not found";
    if (op->type != OperatorType::kGather) {
      AddMessageF(
          "Skipping because data input %s into %s "
          "is unexpected",
          LogName(*op), LogName(*stitch_op));
      return ::tensorflow::OkStatus();
    }
    gather_ops.push_back(static_cast<GatherOperator*>(op));
  }

  // Validate all gathers come from the same DynamicPartition.
  DynamicPartitionOperator* data_partition_op = nullptr;
  for (auto* gather_op : gather_ops) {
    auto* op = GetOpWithOutput(*model, gather_op->inputs[1]);
    CHECK(op) << "Source of " << gather_op->inputs[1] << " not found";
    if (op->type != OperatorType::kDynamicPartition) {
      AddMessageF(
          "Skipping because data input %s into "
          "%s is unexpected",
          LogName(*op), LogName(*gather_op));
      return ::tensorflow::OkStatus();
    }
    if (!data_partition_op) {
      data_partition_op = static_cast<DynamicPartitionOperator*>(op);
    } else {
      // Ensure this is the same op as previous ones.
      if (op != data_partition_op) {
        AddMessageF(
            "Skipping because data input %s into "
            "%s is from a different source op than others",
            LogName(*op), LogName(*gather_op));
        return ::tensorflow::OkStatus();
      }
    }
  }
  CHECK(data_partition_op) << "No data inputs";

  // Validate the partition ops have the same sizes.
  CHECK_EQ(indices_partition_op->num_partitions,
           data_partition_op->num_partitions)
      << "Indices and data partition ops have differing dimensions";
  int num_partitions = indices_partition_op->num_partitions;

  // A partition strategy of 'mod' gives us a FloorDiv and a FloorMod.
  // The data partition uses the FloorDiv output as its data and the FloorMod
  // output as its partitions; the indices partition also uses the FloorMod
  // output as its partitions.
  Operator* div_op = GetOpWithOutput(*model, data_partition_op->inputs[0]);
  Operator* mod_op = GetOpWithOutput(*model, data_partition_op->inputs[1]);
  CHECK(div_op && div_op->type == OperatorType::kFloorDiv)
      << "Unsupported partition strategy";
  CHECK(mod_op && mod_op->type == OperatorType::kFloorMod)
      << "Unsupported partition strategy";
  CHECK_EQ(mod_op, GetOpWithOutput(*model, indices_partition_op->inputs[1]))
      << "Indices and data partition ops require the same partition strategy "
         "and inputs";

  // Glob together all of the gather data. This is not yet in the correct
  // order.
  auto* gather_params_concat_op = new ConcatenationOperator;
  for (const auto& gather_op : gather_ops) {
    gather_params_concat_op->inputs.push_back(gather_op->inputs[0]);
  }
  gather_params_concat_op->outputs.push_back(
      AvailableArrayName(*model, gather_ops[0]->inputs[0] + "_unpartitioned"));
  op_it = model->operators.emplace(op_it, gather_params_concat_op) + 1;
  model->GetOrCreateArray(gather_params_concat_op->outputs[0]);
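  // Concatenation stacks each partition's rows back into a single table, but
  // in partition-major order rather than original id order; the permutation
  // below undoes that interleaving.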

  // Permute the gather params to undo the partitioning that was originally
  // done.
  auto* gather_params_permute_op = new GatherOperator;
  gather_params_permute_op->inputs.push_back(
      gather_params_concat_op->outputs[0]);
  gather_params_permute_op->inputs.push_back(
      AvailableArrayName(*model, gather_ops[0]->inputs[0] + "_permuted/perm"));
  gather_params_permute_op->outputs.push_back(
      AvailableArrayName(*model, gather_ops[0]->inputs[0] + "_permuted"));
  gather_params_permute_op->axis = {0};
  op_it = model->operators.emplace(op_it, gather_params_permute_op) + 1;
  model->GetOrCreateArray(gather_params_permute_op->outputs[0]);
  const auto& partition_array = model->GetArray(gather_ops[0]->inputs[0]);
  const auto& partition_array_dims = partition_array.shape().dims();
  gather_params_permute_op->input_rank =
      partition_array.shape().dimensions_count();
  auto& perm_array =
      model->GetOrCreateArray(gather_params_permute_op->inputs[1]);
  perm_array.data_type = ArrayDataType::kInt32;
  perm_array.mutable_shape()->ReplaceDims(
      {num_partitions * partition_array_dims[0]});
  auto& perm_data = perm_array.GetMutableBuffer<ArrayDataType::kInt32>().data;
  perm_data.resize(RequiredBufferSizeForShape(perm_array.shape()));
  // NOTE: this is what relies on the partition_strategy.
  for (int i = 0; i < num_partitions * partition_array_dims[0]; ++i) {
    int p = i % num_partitions;
    perm_data[i] = p * partition_array_dims[0] + i / num_partitions;
  }
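  // For example, with num_partitions == 2 and 3 rows per partition the
  // permutation is [0, 3, 1, 4, 2, 5]: under the 'mod' split, original row i
  // ended up at concatenated row (i % 2) * 3 + i / 2, so gathering with this
  // permutation restores the original row order.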

  // Insert the new unpartitioned gather op.
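  // Its indices are the original, un-split ids (the first input to FloorMod),
  // so no partition arithmetic remains after the rewrite.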
  auto* merged_gather_op = new GatherOperator;
  merged_gather_op->inputs = {gather_params_permute_op->outputs[0],
                              mod_op->inputs[0]};
  merged_gather_op->outputs = {stitch_op->outputs[0]};
  merged_gather_op->input_rank = partition_array.shape().dimensions_count();
  merged_gather_op->axis = {0};
  model->operators.emplace(op_it, merged_gather_op);

  AddMessageF(
      "Replacing suspected partitioned tf.nn.embedding_lookup (starting at %s "
      "+ %s and ending at %s) with a single unpartitioned gather %s",
      LogName(*div_op), LogName(*mod_op), LogName(*stitch_op),
      LogName(*merged_gather_op));

  // Ensure the stitch output array is dead, as we don't want whatever was in
  // it previously now that we've redefined it. It'll be recreated when needed.
  model->EraseArray(merged_gather_op->outputs[0]);
  model->GetOrCreateArray(merged_gather_op->outputs[0]);

  // Erase all the original ops.
  DeleteOpAndArrays(model, div_op);
  DeleteOpAndArrays(model, mod_op);
  for (auto* gather_op : gather_ops) {
    DeleteOpAndArrays(model, gather_op);
  }
  DeleteOpAndArrays(model, indices_partition_op);
  DeleteOpAndArrays(model, data_partition_op);
  DeleteOpAndArrays(model, stitch_op);
  *modified = true;
  return ::tensorflow::OkStatus();
}

}  // namespace toco