/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"

namespace toco {

// Graph transformation pass: rewrites the op at `op_index` (and the subgraph
// feeding it) when it matches a partitioned embedding lookup. On success sets
// *modified = true; otherwise leaves the model untouched. Always returns OK —
// non-matching graphs are skipped, and malformed matches CHECK-fail.
::tensorflow::Status UnpartitionEmbeddingLookup::Run(Model* model,
                                                     std::size_t op_index,
                                                     bool* modified) {
  *modified = false;
  // Collapses a partitioned tf.nn.embedding_lookup back into a single Gather.
  // https://www.tensorflow.org/api_docs/python/tf/nn/embedding_lookup
  // This transform attempts to identify the len(params) > 1 case and collapse
  // it to the len(params) = 1 case by concatenating the original params and
  // reversing the partitioning.
  //
  // If len(params) to the tf.nn.embedding_lookup == 1, the whole op becomes
  // simply a gather:
  // https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/python/ops/embedding_ops.py#L150
  //
  // Notes on this implementation:
  // - only supports partition_strategy='mod'
  //
  // A rough graph of a partitioned embedding_lookup looks like:
  //   (ids)--+-->FloorDiv--+-->DynamicPartition-->[[Gather]]--\
  //          \-->FloorMod--/                                  |
  //                 V                                         |
  //   Range-->DynamicPartition-------->DynamicStitch<---------/
  //  (const)                                V
  //                                     (embeddings)

  // First look for the final DynamicStitch.
  auto op_it = model->operators.begin() + op_index;
  if (op_it->get()->type != OperatorType::kDynamicStitch) {
    return ::tensorflow::OkStatus();
  }
  auto* stitch_op = static_cast<DynamicStitchOperator*>(op_it->get());

  // Split up the DynamicStitch inputs into the indices and data.
  // The inputs are laid out as
  // [indices_0 .. indices_{N-1}, data_0 .. data_{N-1}] for N partitions.
  std::vector<std::string> stitch_indices_inputs;
  std::vector<std::string> stitch_data_inputs;
  stitch_indices_inputs.reserve(stitch_op->num_partitions);
  for (int i = 0; i < stitch_op->num_partitions; ++i) {
    stitch_indices_inputs.push_back(stitch_op->inputs[i]);
  }
  for (int i = stitch_op->num_partitions; i < stitch_op->num_partitions * 2;
       ++i) {
    stitch_data_inputs.push_back(stitch_op->inputs[i]);
  }

  // Validate all indices come from the same DynamicPartition.
  DynamicPartitionOperator* indices_partition_op = nullptr;
  for (const std::string& indices_partition_output_name :
       stitch_indices_inputs) {
    auto* op = GetOpWithOutput(*model, indices_partition_output_name);
    CHECK(op) << "Source of " << indices_partition_output_name << " not found";
    if (op->type != OperatorType::kDynamicPartition) {
      AddMessageF(
          "Skipping because indices input %s into "
          "%s is unexpected",
          LogName(*op), LogName(*stitch_op));
      return ::tensorflow::OkStatus();
    }
    if (!indices_partition_op) {
      // First indices input seen: remember its producer for comparison below.
      indices_partition_op = static_cast<DynamicPartitionOperator*>(op);
    } else {
      // Ensure this is the same op as previous ones.
      if (op != indices_partition_op) {
        AddMessageF(
            "Skipping because indices input %s into "
            "%s is from a different source op than others",
            LogName(*op), LogName(*stitch_op));
        return ::tensorflow::OkStatus();
      }
    }
  }
  CHECK(indices_partition_op) << "No indices inputs";

  // The data for the indices must be a constant range of the array shape.
  // (i.e. the Range op in the diagram above, already constant-folded.)
  if (!IsConstantParameterArray(*model, indices_partition_op->inputs[0])) {
    AddMessageF("Skipping because indices partition data is non-constant");
    return ::tensorflow::OkStatus();
  }
  auto& indices_data_array = model->GetArray(indices_partition_op->inputs[0]);
  if (indices_data_array.data_type == ArrayDataType::kNone) {
    // Yield until data types are propagated.
    return ::tensorflow::OkStatus();
  }
  CHECK(indices_data_array.data_type == ArrayDataType::kInt32)
      << "Indices partition inputs must be int32";
  const auto& indices_data_buffer =
      indices_data_array.GetBuffer<ArrayDataType::kInt32>().data;
  for (size_t i = 0; i < indices_data_buffer.size(); ++i) {
    // Require the identity range [0, 1, 2, ...]; anything else would mean the
    // stitch reorders rows in a way this transform does not model.
    CHECK_EQ(indices_data_buffer[i], i) << "Indices range must be identity";
  }

  // Find all of the gathers used for the data inputs.
  std::vector<GatherOperator*> gather_ops;
  for (const std::string& gather_output_name : stitch_data_inputs) {
    auto* op = GetOpWithOutput(*model, gather_output_name);
    CHECK(op) << "Source of " << gather_output_name << " not found";
    if (op->type != OperatorType::kGather) {
      AddMessageF(
          "Skipping because data input %s into %s "
          "is unexpected",
          LogName(*op), LogName(*stitch_op));
      return ::tensorflow::OkStatus();
    }
    gather_ops.push_back(static_cast<GatherOperator*>(op));
  }

  // Validate all gathers come from the same DynamicPartition.
  // (gather inputs are [params, indices]; inputs[1] is the partitioned ids.)
  DynamicPartitionOperator* data_partition_op = nullptr;
  for (auto* gather_op : gather_ops) {
    auto* op = GetOpWithOutput(*model, gather_op->inputs[1]);
    CHECK(op) << "Source of " << gather_op->inputs[1] << " not found";
    if (op->type != OperatorType::kDynamicPartition) {
      AddMessageF(
          "Skipping because data input %s into "
          "%s is unexpected",
          LogName(*op), LogName(*gather_op));
      return ::tensorflow::OkStatus();
    }
    if (!data_partition_op) {
      data_partition_op = static_cast<DynamicPartitionOperator*>(op);
    } else {
      // Ensure this is the same op as previous ones.
      if (op != data_partition_op) {
        AddMessageF(
            "Skipping because data input %s into "
            "%s is from a different source op than others",
            LogName(*op), LogName(*gather_op));
        return ::tensorflow::OkStatus();
      }
    }
  }
  CHECK(data_partition_op) << "No data inputs";

  // Validate the partition ops have the same sizes.
  CHECK_EQ(indices_partition_op->num_partitions,
           data_partition_op->num_partitions)
      << "Indices and data partition ops have differing dimensions";
  int num_partitions = indices_partition_op->num_partitions;

  // Partition strategy of 'mod' gives us a FloorMod and FloorDiv.
  // The gather partition uses the FloorDiv as the data and FloorMod as the
  // partitions and the indices use the FloorMod as their partitions.
  Operator* div_op = GetOpWithOutput(*model, data_partition_op->inputs[0]);
  Operator* mod_op = GetOpWithOutput(*model, data_partition_op->inputs[1]);
  CHECK(div_op && div_op->type == OperatorType::kFloorDiv)
      << "Unsupported partition strategy";
  CHECK(mod_op && mod_op->type == OperatorType::kFloorMod)
      << "Unsupported partition strategy";
  // Both DynamicPartitions must be driven by the same FloorMod, otherwise the
  // two partitionings might disagree and the rewrite would be invalid.
  CHECK_EQ(mod_op, GetOpWithOutput(*model, indices_partition_op->inputs[1]))
      << "Indices and data partition ops require the same partition strategy "
         "and inputs";

  // Glob together all of the gather data. This is not yet in the correct order.
  auto* gather_params_concat_op = new ConcatenationOperator;
  for (const auto& gather_op : gather_ops) {
    gather_params_concat_op->inputs.push_back(gather_op->inputs[0]);
  }
  gather_params_concat_op->outputs.push_back(
      AvailableArrayName(*model, gather_ops[0]->inputs[0] + "_unpartitioned"));
  // emplace may reallocate the operators vector, invalidating op_it; keep it
  // pointing one past the inserted op so later inserts land before the stitch.
  op_it = model->operators.emplace(op_it, gather_params_concat_op) + 1;
  model->GetOrCreateArray(gather_params_concat_op->outputs[0]);

  // Permute the gather params to undo the partitioning that was originally
  // done.
  auto* gather_params_permute_op = new GatherOperator;
  gather_params_permute_op->inputs.push_back(
      gather_params_concat_op->outputs[0]);
  gather_params_permute_op->inputs.push_back(
      AvailableArrayName(*model, gather_ops[0]->inputs[0] + "_permuted/perm"));
  gather_params_permute_op->outputs.push_back(
      AvailableArrayName(*model, gather_ops[0]->inputs[0] + "_permuted"));
  gather_params_permute_op->axis = {0};
  op_it = model->operators.emplace(op_it, gather_params_permute_op) + 1;
  model->GetOrCreateArray(gather_params_permute_op->outputs[0]);
  // NOTE(review): all partitions are assumed to share the shape of the first
  // gather's params array (same leading dimension per partition).
  const auto& partition_array = model->GetArray(gather_ops[0]->inputs[0]);
  const auto& partition_array_dims = partition_array.shape().dims();
  gather_params_permute_op->input_rank =
      partition_array.shape().dimensions_count();
  auto& perm_array =
      model->GetOrCreateArray(gather_params_permute_op->inputs[1]);
  perm_array.data_type = ArrayDataType::kInt32;
  perm_array.mutable_shape()->ReplaceDims(
      {num_partitions * partition_array_dims[0]});
  auto& perm_data = perm_array.GetMutableBuffer<ArrayDataType::kInt32>().data;
  perm_data.resize(RequiredBufferSizeForShape(perm_array.shape()));
  // NOTE: this is what relies on the partition_strategy.
  // Row i of the permuted params came from partition (i % num_partitions),
  // local row (i / num_partitions) — the inverse of 'mod' partitioning — and
  // partition p's rows start at offset p * partition_array_dims[0] in the
  // concatenated array.
  for (int i = 0; i < num_partitions * partition_array_dims[0]; ++i) {
    int p = i % num_partitions;
    perm_data[i] = p * partition_array_dims[0] + i / num_partitions;
  }

  // Insert the new unpartitioned gather op.
  // mod_op->inputs[0] is the original (unpartitioned) ids tensor that fed
  // FloorMod, so the merged gather indexes the full params directly.
  auto* merged_gather_op = new GatherOperator;
  merged_gather_op->inputs = {gather_params_permute_op->outputs[0],
                              mod_op->inputs[0]};
  merged_gather_op->outputs = {stitch_op->outputs[0]};
  merged_gather_op->input_rank = partition_array.shape().dimensions_count();
  merged_gather_op->axis = {0};
  model->operators.emplace(op_it, merged_gather_op);

  AddMessageF(
      "Replacing suspected partitioned tf.nn.embedding_lookup (starting at %s "
      "+ %s and ending at %s) with a single unpartitioned gather %s",
      LogName(*div_op), LogName(*mod_op), LogName(*stitch_op),
      LogName(*merged_gather_op));

  // Ensure the stitch output array is dead, as we don't want whatever was in it
  // previously now that we've redefined it. It'll be recreated when needed.
  model->EraseArray(merged_gather_op->outputs[0]);
  model->GetOrCreateArray(merged_gather_op->outputs[0]);

  // Erase all the original ops.
  DeleteOpAndArrays(model, div_op);
  DeleteOpAndArrays(model, mod_op);
  for (auto* gather_op : gather_ops) {
    DeleteOpAndArrays(model, gather_op);
  }
  DeleteOpAndArrays(model, indices_partition_op);
  DeleteOpAndArrays(model, data_partition_op);
  DeleteOpAndArrays(model, stitch_op);
  *modified = true;
  return ::tensorflow::OkStatus();
}

}  // namespace toco