/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cmath>
#include <memory>
#include <unordered_map>

#include "tensorflow/c/checkpoint_reader.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"

namespace tensorflow {
using str_util::Split;
using str_util::StringReplace;
using strings::StrCat;

namespace graph_transforms {

// Sparsifies a Tensor of shape [N, 1]. Returns the index and value vectors
// for the non-zero tensor content.
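// Example (illustrative): a [4, 1] tensor [[0.0], [0.5], [0.0], [-1.2]]
// yields indices = [1, 3] and values = [0.5, -1.2]. Entries with magnitude
// below 1e-5 are treated as zero, and an all-zero input produces the dummy
// pair indices = [0], values = [0.0] (see the workaround note below).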
Status SparsifyWeights(const Tensor& tensor, Tensor* indices_tensor,
                       Tensor* values_tensor) {
  if (tensor.dims() != 2 || tensor.dim_size(1) != 1) {
    return tensorflow::errors::FailedPrecondition(
        "Transform only applicable to subgraphs with a 'Const' holding a "
        "tensor of shape [N, 1]. Instead found shape ",
        tensor.shape().DebugString(), ".");
  }

  auto flat = tensor.flat<float>();
  std::vector<int64_t> indices;
  std::vector<float> values;

  for (int64_t i = 0; i < flat.size(); i++) {
    float val = flat(i);
    if (std::abs(val) >= 1.0e-5) {
      indices.push_back(i);
      values.push_back(val);
    }
  }

  // During model initialization, InitializeTableOp makes use of
  // KeyValueTensorIterator, which does not accept empty keys or values.
  // Consequently, add a dummy pair of indices and values as a workaround.
  if (indices.empty() || values.empty()) {
    indices.push_back(0);
    values.push_back(0);
  }
  *indices_tensor = Tensor(DataTypeToEnum<int64_t>::value,
                           {static_cast<int64_t>(indices.size())});
  std::copy_n(indices.begin(), indices.size(),
              indices_tensor->flat<int64_t>().data());

  *values_tensor = Tensor(DataTypeToEnum<float>::value,
                          {static_cast<int64_t>(values.size())});
  std::copy_n(values.begin(), values.size(),
              values_tensor->flat<float>().data());

  return OkStatus();
}

void CreateConstNode(const Tensor& tensor, const string& name,
                     NodeDef* node_def) {
  node_def->set_op("Const");
  node_def->set_name(name);
  SetNodeTensorAttr<float>("value", tensor, node_def);
}

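// Maps a partitioned variable's slice name to the monolithic tensor key used
// in the checkpoint by dropping a trailing "part_*" component.
// Example (illustrative): "embedding/weights/part_3" -> "embedding/weights";
// a name without a "part_" suffix is returned unchanged.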
string GetMonolithicTensorKey(const string& tensor_slice_name) {
  std::vector<string> names = Split(tensor_slice_name, "/");
  if (absl::StartsWith(names[names.size() - 1], "part_")) {
    CHECK_GE(names.size(), 2);
    names.pop_back();
  }
  return absl::StrJoin(names, "/");
}

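// Finds the shape-and-slice string that the save/restore subgraph uses for
// `target_name`: first the "save*/Assign*" node that writes to the variable,
// then the RestoreV2 node feeding it, then the matching entry in that node's
// shape_and_slices const input.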
Status ObtainTensorSlice(const GraphDef& input_graph_def,
                         const string& target_name,
                         string* shape_slice_string) {
  string restore_node_name;
  for (const auto& node : input_graph_def.node()) {
    std::vector<string> node_name_parts = Split(node.name(), "/");
    if (node_name_parts.size() == 2 &&
        absl::StartsWith(node_name_parts[0], "save") &&
        absl::StartsWith(node_name_parts[1], "Assign") &&
        node.input(0) == target_name) {
      restore_node_name = node.input(1);
      break;
    }
  }

  std::vector<string> restore_node_parts = Split(restore_node_name, ":");
  CHECK_LE(restore_node_parts.size(), 2);
  string tensor_names_node;
  string shape_and_slices_node;
  for (const auto& node : input_graph_def.node()) {
    if ((node.name() == restore_node_parts[0]) && (node.op() == "RestoreV2")) {
      tensor_names_node = node.input(1);
      shape_and_slices_node = node.input(2);
      break;
    }
  }

  int offset = -1;
  for (const auto& node : input_graph_def.node()) {
    if (node.name() == tensor_names_node) {
      Tensor tensor_names_tensor;
      TF_RETURN_IF_ERROR(GetNodeAttr(node, "value", &tensor_names_tensor));
      const auto& tensor_names_value = tensor_names_tensor.flat<tstring>();
      for (int i = 0; i < tensor_names_value.size(); i++) {
        if (tensor_names_value(i) == GetMonolithicTensorKey(target_name)) {
          offset = i;
          break;
        }
      }
    }
  }
  if (offset == -1) {
    return errors::Internal("Unable to find RestoreV2 entry for variable: ",
                            target_name);
  }
  for (const auto& node : input_graph_def.node()) {
    if (node.name() == shape_and_slices_node) {
      Tensor shape_and_slices_tensor;
      TF_RETURN_IF_ERROR(GetNodeAttr(node, "value", &shape_and_slices_tensor));
      const auto& shape_and_slices_value =
          shape_and_slices_tensor.flat<tstring>();
      *shape_slice_string = shape_and_slices_value(offset);
      return OkStatus();
    }
  }
  return errors::Internal("Unable to find slice for variable: ", target_name);
}

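// Reads `tensor_name` from the checkpoint, honoring a non-empty
// shape-and-slice spec. Example spec (illustrative): "10 5 0,5:-" describes a
// [10, 5] full shape of which rows [0, 5) and all columns are read; an empty
// spec, or a slice that covers the full shape, restores the whole tensor.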
Status ReadTensorFromCheckpoint(
    const string& tensor_name, const std::unique_ptr<BundleReader>& ckpt_reader,
    const string& shape_and_slice, Tensor* tensor) {
  if (ckpt_reader) {
    TensorShape parsed_full_shape;
    TensorSlice parsed_slice;
    TensorShape parsed_slice_shape;

    bool get_slice = false;
    if (!shape_and_slice.empty()) {
      TF_RETURN_IF_ERROR(
          checkpoint::ParseShapeAndSlice(shape_and_slice, &parsed_full_shape,
                                         &parsed_slice, &parsed_slice_shape));
      get_slice = (parsed_full_shape != parsed_slice_shape);
    }
    if (get_slice) {
      TF_RETURN_IF_ERROR(ckpt_reader->LookupSlice(
          GetMonolithicTensorKey(tensor_name), parsed_slice, tensor));
    } else {
      TF_RETURN_IF_ERROR(
          ckpt_reader->Lookup(GetMonolithicTensorKey(tensor_name), tensor));
    }
    return OkStatus();
  }
  return errors::Internal("Checkpoint reader was not initialized.");
}

Status InitializeCheckpointReader(const TransformFuncContext& context,
                                  std::unique_ptr<BundleReader>* ckpt_reader) {
  if (context.params.count("input_checkpoint")) {
    const string input_checkpoint = context.params.at("input_checkpoint")[0];
    ckpt_reader->reset(new BundleReader(Env::Default(), input_checkpoint));
    TF_RETURN_IF_ERROR((*ckpt_reader)->status());
  }
  return OkStatus();
}

Status ObtainVariableInfo(
    const GraphDef& input_graph_def,
    std::unique_ptr<std::unordered_map<string, string> >* shapes_and_slices) {
  shapes_and_slices->reset(new std::unordered_map<string, string>());
  for (const auto& node : input_graph_def.node()) {
    if ((node.op() == "Variable") || (node.op() == "VariableV2")) {
      string s;
      TF_RETURN_IF_ERROR(ObtainTensorSlice(input_graph_def, node.name(), &s));
      (**shapes_and_slices)[node.name()] = s;
    }
  }
  return OkStatus();
}

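// Removes input `index` from `n` by bubbling it to the end with successive
// swaps and dropping the last element, which preserves the relative order of
// the remaining (positional) inputs.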
Status RemoveInputAtIndex(NodeDef* n, int index) {
  for (int i = index; i < n->input_size() - 1; i++) {
    n->mutable_input()->SwapElements(i, i + 1);
  }
  n->mutable_input()->RemoveLast();
  return OkStatus();
}

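// Removes node `index` from `g` with the same order-preserving swap-and-pop
// scheme as RemoveInputAtIndex.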
Status RemoveNodeAtIndex(GraphDef* g, int index) {
  for (int i = index; i < g->node_size() - 1; i++) {
    g->mutable_node()->SwapElements(i, i + 1);
  }
  g->mutable_node()->RemoveLast();
  return OkStatus();
}

Status SparsifyGatherInternal(
    const GraphDef& input_graph_def,
    const std::unique_ptr<std::unordered_map<string, string> >&
        shapes_and_slices,
    const TransformFuncContext& context, const OpTypePattern& pattern,
    const std::unique_ptr<BundleReader>& ckpt_reader,
    GraphDef* output_graph_def) {
  string group_init_node = "group_deps";
  if (context.params.count("group_init_node")) {
    group_init_node = context.params.at("group_init_node")[0];
  }
  GraphDef current_graph_def = input_graph_def;
  bool any_match_found = false;

  // Populate reference counts: each input of every node (control inputs
  // included, with their "^" prefix stripped) counts as one reference to the
  // node it names.
  std::unordered_map<string, int> refs;
  for (const auto& node : current_graph_def.node()) {
    for (const auto& input : node.input()) {
      auto parsed_input = StringReplace(input, "^", "", true);
      refs[parsed_input] += 1;
    }
  }

  // The subgraphs may have overlapping components, so GraphMatcher doesn't
  // return all matching subgraphs in one round; the update has to run over
  // multiple rounds.
  do {
    any_match_found = false;
    GraphDef replaced_graph_def = current_graph_def;
    std::vector<string> init_table_node_names;
    std::vector<string> removed_node_names;

    TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
        current_graph_def, pattern,
        [&ckpt_reader, &any_match_found, &init_table_node_names,
         &shapes_and_slices, &removed_node_names,
         &refs](const NodeMatch& match, const std::set<string>& input_nodes,
                const std::set<string>& output_nodes,
                std::vector<NodeDef>* new_nodes) {
          any_match_found = true;

          // The captured subgraph should be of the following pattern:
          // Const --> Identity --> Gather --> ...
          //                          ^
          //                          |
          //                        (ids)
          //
          // After transform, it becomes:
          //                   --> NoOp(group_deps)
          //                   |
          // Const --> InitializeTable --> HashTable
          //                   ^              |
          //                   |              |
          // Const -------------              |
          //                                  v
          //               (ids) ---> LookupTableFind <--- Const(default)
          //                                  |
          //                                  v
          //                                 ...

          // clang-format off
          // For each matched subgraph, do the following:
          // 1. Sparsify the `Const` into two new `Const`s holding the
          //    hashtable keys and values.
          // 2. Create an `InitializeTable` op connecting to the two `Const`s.
          // 3. Create a `HashTable` op connecting to the `InitializeTable` op.
          // 4. Replace the `Gather` with a `LookupTableFind` op.
          // 5. Connect the `LookupTableFind` with
          //    a. `HashTable`
          //    b. `Gather`'s ids input
          //    c. a `default_val` arg, valued at 0
          // clang-format on
          const NodeDef& gather_node = match.node;

          // GatherV2 adds an "axis" parameter. sparsify_gather only supports
          // axis 0 gathers.
          if (gather_node.op() == "GatherV2") {
            // Per the OpTypePattern, the 3rd input to Gather must be a Const.
            const NodeDef& axis_node = match.inputs[2].node;

            Tensor axis_t;
            TF_RETURN_IF_ERROR(GetNodeAttr(axis_node, "value", &axis_t));
            int64_t axis = 0;
            if (axis_t.dtype() == DT_INT32) {
              axis = axis_t.scalar<int32>()();
            } else if (axis_t.dtype() == DT_INT64) {
              axis = axis_t.scalar<int64_t>()();
            } else {
              return tensorflow::errors::FailedPrecondition(
                  "Gather axis was not int32 or int64.");
            }

            if (axis != 0) {
              return tensorflow::errors::FailedPrecondition(
                  "Transform only applicable to subgraph with GatherV2 over "
                  "axis 0. Found axis ",
                  axis, ".");
            }
          }

          const NodeDef& weights_node = match.inputs[0].inputs[0].node;

          DataType data_type;
          TF_RETURN_IF_ERROR(GetNodeAttr(weights_node, "dtype", &data_type));
          if (data_type != DT_FLOAT) {
            return tensorflow::errors::FailedPrecondition(
                "Transform only applicable to subgraph with 'Const', "
                "'Variable', or 'VariableV2' of dtype "
                "'DT_FLOAT'. Found '" +
                    weights_node.op() + "' with name '",
                weights_node.name(), "' and dtype '", data_type, "'.");
          }

          Tensor weight;
          if (weights_node.op() == "Const") {
            weight = GetNodeTensorAttr(weights_node, "value");
          } else {
            TF_RETURN_IF_ERROR(ReadTensorFromCheckpoint(
                weights_node.name(), ckpt_reader,
                (*shapes_and_slices)[weights_node.name()], &weight));
          }
          // Mark both the weights node and its Identity for removal.
          removed_node_names.push_back(weights_node.name());
          removed_node_names.push_back(match.inputs[0].node.name());
          for (auto input_node : match.inputs[0].node.input()) {
            auto parsed_input = StringReplace(input_node, "^", "", true);
            refs[parsed_input]--;
          }
          Tensor indices_tensor;
          Tensor values_tensor;
          TF_RETURN_IF_ERROR(
              SparsifyWeights(weight, &indices_tensor, &values_tensor));

          // indices and values of sparsified `Const`
          DataType key_dtype = DT_INT64;
          NodeDef indices_node;
          CreateConstNode(indices_tensor,
                          StrCat(weights_node.name(), "/indices"),
                          &indices_node);
          SetNodeAttr("dtype", key_dtype, &indices_node);

          NodeDef values_node;
          CreateConstNode(values_tensor, StrCat(weights_node.name(), "/values"),
                          &values_node);
          SetNodeAttr("dtype", data_type, &values_node);

          // HashTable node
          NodeDef hashtable_node;
          hashtable_node.set_op("HashTable");
          hashtable_node.set_name(StrCat(weights_node.name(), "/HashTable"));
          SetNodeAttr("key_dtype", key_dtype, &hashtable_node);
          SetNodeAttr("value_dtype", data_type, &hashtable_node);

          // InitializeTable node
          NodeDef init_table_node;
          init_table_node.set_op("InitializeTable");
          init_table_node.set_name(
              StrCat(weights_node.name(), "/InitializeTable"));
          SetNodeAttr("Tkey", key_dtype, &init_table_node);
          SetNodeAttr("Tval", data_type, &init_table_node);
          init_table_node_names.push_back(init_table_node.name());

          // LookupTableFind node
          NodeDef lookup_node;
          lookup_node.set_op("LookupTableFind");
          lookup_node.set_name(StrCat(gather_node.name(), "/LookupTableFind"));
          SetNodeAttr("Tin", key_dtype, &lookup_node);
          SetNodeAttr("Tout", data_type, &lookup_node);

          // Default return value of hashtable lookup
          Tensor zero_tensor(data_type, TensorShape({}));
          zero_tensor.flat<float>()(0) = 0.0;
          NodeDef default_value_node;
          CreateConstNode(zero_tensor, StrCat(gather_node.name(), "/Const"),
                          &default_value_node);
          SetNodeAttr("dtype", data_type, &default_value_node);

          // ExpandDims argument
          Tensor dim_idx(DT_INT32, TensorShape({}));
          dim_idx.flat<int32>()(0) = -1;
          NodeDef dim_idx_node;
          dim_idx_node.set_op("Const");
          dim_idx_node.set_name(
              StrCat(gather_node.name(), "/ExpandDims/Const"));
          SetNodeAttr("value", dim_idx, &dim_idx_node);
          SetNodeAttr("dtype", DT_INT32, &dim_idx_node);

          // ExpandDims node
          NodeDef expand_dims_node;
          expand_dims_node.set_op("ExpandDims");
          // Reuse gather_node's name so as not to change its dependents'
          // inputs.
          expand_dims_node.set_name(gather_node.name());
          SetNodeAttr("T", data_type, &expand_dims_node);

          // Connect nodes
          AddNodeInput(hashtable_node.name(), &init_table_node);
          refs[hashtable_node.name()]++;
          AddNodeInput(indices_node.name(), &init_table_node);
          refs[indices_node.name()]++;
          AddNodeInput(values_node.name(), &init_table_node);
          refs[values_node.name()]++;

          AddNodeInput(hashtable_node.name(), &lookup_node);
          refs[hashtable_node.name()]++;
          AddNodeInput(gather_node.input(1), &lookup_node);
          refs[gather_node.input(1)]++;
          AddNodeInput(default_value_node.name(), &lookup_node);
          refs[default_value_node.name()]++;

          AddNodeInput(lookup_node.name(), &expand_dims_node);
          refs[lookup_node.name()]++;
          AddNodeInput(dim_idx_node.name(), &expand_dims_node);
          refs[dim_idx_node.name()]++;

          // Copy 'ids' input of original 'Gather'
          new_nodes->push_back(match.inputs[1].node);
          new_nodes->push_back(indices_node);
          new_nodes->push_back(values_node);
          new_nodes->push_back(hashtable_node);
          new_nodes->push_back(init_table_node);
          new_nodes->push_back(lookup_node);
          new_nodes->push_back(default_value_node);
          new_nodes->push_back(dim_idx_node);
          new_nodes->push_back(expand_dims_node);

          return OkStatus();
        },
        {true}, &replaced_graph_def));

    NodeDef* init_op = nullptr;
    for (int i = 0; i < replaced_graph_def.node_size(); i++) {
      if (replaced_graph_def.node(i).name() == group_init_node &&
          replaced_graph_def.node(i).op() == "NoOp") {
        init_op = replaced_graph_def.mutable_node(i);
        break;
      }
    }
    if (!init_op) {
      // Init node
      init_op = replaced_graph_def.mutable_node()->Add();
      init_op->set_op("NoOp");
      init_op->set_name(group_init_node);
    }
    for (const string& name : init_table_node_names) {
      // Add a control dependency from init_table_node to group_deps_node.
      AddNodeInput(StrCat("^", name), init_op);
      refs[name]++;
    }

    // Erase inputs and outputs as they are not considered for deletion.
    for (const auto& output : context.output_names) {
      refs.erase(output);
    }

    for (const auto& input : context.input_names) {
      refs.erase(input);
    }

    // Add nodes with a reference count of 0 for deletion.
    for (const auto& entry : refs) {
      if (entry.second == 0) {
        removed_node_names.push_back(entry.first);
      }
    }

    while (!removed_node_names.empty()) {
      auto name = removed_node_names.back();
      removed_node_names.pop_back();

      int i = 0;
      while (i < replaced_graph_def.node_size()) {
        // Revisit this to see if we can safely remove RestoreV2 nodes.
        if ((replaced_graph_def.node(i).name() == name) &&
            (replaced_graph_def.node(i).op() != "RestoreV2")) {
          for (const auto& input : replaced_graph_def.node(i).input()) {
            auto parsed_input = StringReplace(input, "^", "", true);
            refs[parsed_input] -= 1;
            if (refs[parsed_input] == 0) {
              removed_node_names.push_back(parsed_input);
            }
          }
          TF_RETURN_IF_ERROR(RemoveNodeAtIndex(&replaced_graph_def, i));
          continue;
        }
        int j = 0;
        bool deleted_inputs = false;
        while (j < replaced_graph_def.node(i).input_size()) {
          if (replaced_graph_def.node(i).input(j) == name ||
              replaced_graph_def.node(i).input(j) == ("^" + name)) {
            TF_RETURN_IF_ERROR(
                RemoveInputAtIndex(replaced_graph_def.mutable_node(i), j));
            deleted_inputs = true;
            continue;
          }
          j++;
        }
        if (deleted_inputs) {
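          // A ConcatV2 that lost value inputs needs its "N" attr updated; once
          // only one value input plus the trailing axis input remains, rewrite
          // the op as Identity and retire the now-unused axis const.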
          if (replaced_graph_def.node(i).op() == "ConcatV2") {
            if (replaced_graph_def.node(i).input_size() > 2) {
              SetNodeAttr("N", replaced_graph_def.node(i).input_size() - 1,
                          replaced_graph_def.mutable_node(i));
            } else if (replaced_graph_def.node(i).input_size() == 2) {
              if (refs[replaced_graph_def.node(i).input(1)] != 1) {
                return errors::Internal(
                    "Expect axis tensor of ConcatV2 node to only be referenced "
                    "once.");
              }
              refs[replaced_graph_def.node(i).input(1)] -= 1;
              removed_node_names.push_back(replaced_graph_def.node(i).input(1));
              replaced_graph_def.mutable_node(i)->mutable_input()->RemoveLast();
              replaced_graph_def.mutable_node(i)->mutable_attr()->erase("N");
              replaced_graph_def.mutable_node(i)->set_op("Identity");
            } else {
              return errors::Internal(
                  "ConcatV2 should have at least two elements");
            }
          }
          if ((replaced_graph_def.node(i).op() == "Assign" ||
               replaced_graph_def.node(i).op() == "Reshape" ||
               replaced_graph_def.node(i).op() == "Equal" ||
               replaced_graph_def.node(i).op() == "Mean" ||
               replaced_graph_def.node(i).op() == "ScalarSummary") &&
              replaced_graph_def.node(i).input_size() == 1) {
            removed_node_names.push_back(replaced_graph_def.node(i).name());
          }
          if (!replaced_graph_def.node(i).input_size()) {
            removed_node_names.push_back(replaced_graph_def.node(i).name());
          }
        }
        i++;
      }
    }
    current_graph_def = replaced_graph_def;
  } while (any_match_found);
  *output_graph_def = current_graph_def;
  return OkStatus();
}

Status SparsifyGather(const GraphDef& input_graph_def,
                      const TransformFuncContext& context,
                      GraphDef* output_graph_def) {
  // clang-format off
  const OpTypePattern gather_pattern =
    {"Gather",
     {
       {"Identity",
        {
          {"Const|Variable|VariableV2"}
        }
       },
       {"*"},
     }
    };
  const OpTypePattern gather_v2_pattern =
    {"GatherV2",
      {
        {"Identity",
          {
            {"Const|Variable|VariableV2"}
          }
        },
        {"*"},
        // GatherV2's axis must be constant.
        {"Const"},
      }
    };
  // clang-format on

  GraphDef cleaned_input_graph_def;
  RemoveAttributes(input_graph_def, {"_output_shapes"},
                   &cleaned_input_graph_def);

  GraphDef temp_output;

  std::unique_ptr<BundleReader> ckpt_reader;
  TF_RETURN_IF_ERROR(InitializeCheckpointReader(context, &ckpt_reader));

  std::unique_ptr<std::unordered_map<string, string> > shapes_and_slices;
  TF_RETURN_IF_ERROR(
      ObtainVariableInfo(cleaned_input_graph_def, &shapes_and_slices));

  TF_RETURN_IF_ERROR(SparsifyGatherInternal(
      cleaned_input_graph_def, shapes_and_slices, context, gather_pattern,
      ckpt_reader, &temp_output));

  TF_RETURN_IF_ERROR(SparsifyGatherInternal(temp_output, shapes_and_slices,
                                            context, gather_v2_pattern,
                                            ckpt_reader, output_graph_def));

  return OkStatus();
}
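
// Typical command-line invocation via the graph transform tool (illustrative;
// file paths and node names below are placeholders):
//   transform_graph --in_graph=in.pb --out_graph=out.pb \
//     --inputs='ids' --outputs='output' \
//     --transforms='sparsify_gather(input_checkpoint="model.ckpt",
//                                   group_init_node="init")'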
REGISTER_GRAPH_TRANSFORM("sparsify_gather", SparsifyGather);

}  // namespace graph_transforms
}  // namespace tensorflow