xref: /aosp_15_r20/external/tensorflow/tensorflow/core/data/rewrite_utils.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/data/rewrite_utils.h"

#include "tensorflow/core/platform/refcount.h"

// On mobile we do not provide this functionality because not all of its
// dependencies are available there.
#if !defined(IS_MOBILE_PLATFORM)

#include <algorithm>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/substitute.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/data/hash_utils.h"
#include "tensorflow/core/data/serialization_utils.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/grappler_item_builder.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/device_properties.pb.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"

namespace tensorflow {
namespace data {
namespace {

constexpr char kOptimizerName[] = "tf_data_meta_optimizer";
constexpr char kOptimizers[] = "optimizers";
constexpr char kOptimizerConfigs[] = "optimizer_configs";

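// Adds a fake "Identity" sink after each output of `function_def` and points
// the function's return values at the sinks, so that Grappler can rewrite the
// nodes that actually produce the outputs.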
void AddFakeSinks(FunctionDef* function_def) {
  int counter = 0;
  for (const auto& output : function_def->signature().output_arg()) {
    NodeDef* node = function_def->add_node_def();
    tensorflow::grappler::function_utils::SetUniqueFunctionNodeName(
        strings::StrCat("FakeSink", counter++), function_def, node);
    node->set_op("Identity");
    node->add_input(function_def->ret().at(output.name()));
    (*node->mutable_attr())["T"].set_type(output.type());

    (*function_def->mutable_ret())[output.name()] =
        strings::StrCat(node->name(), ":output:0");
  }
}

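// Undoes AddFakeSinks: points each function return value that currently names
// a fake "Identity" sink back at the tensor feeding that sink.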
void RemoveFakeSinks(FunctionDef* function_def) {
  // Map from identity node names to their input tensor strings
  std::map<std::string, std::string> identity_map;
  for (const auto& node : function_def->node_def()) {
    if (node.op() == "Identity" && node.input_size() == 1) {
      identity_map[node.name()] = node.input(0);
    }
  }
  for (const auto& output_arg : function_def->signature().output_arg()) {
    const std::string& tensor = function_def->ret().at(output_arg.name());
    const std::string& output_node = tensor.substr(0, tensor.find(':'));
    if (identity_map.find(output_node) != identity_map.end()) {
      (*function_def->mutable_ret())[output_arg.name()] =
          identity_map.at(output_node);
    }
  }
}

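// Applies the rewrites requested by `config_factory` to the graph in
// `graph_def` by running Grappler's meta optimizer, temporarily adding fake
// sinks so that function outputs can be rewritten.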
Status ApplyRewrites(OpKernelContext* ctx,
                     const std::function<RewriterConfig(void)> config_factory,
                     GraphDef* graph_def, string* dataset_node) {
  std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
      GetGrapplerItem(graph_def, dataset_node, /*add_fake_sinks=*/true);
  std::unordered_map<std::string, tensorflow::DeviceProperties> device_map;
  tensorflow::grappler::VirtualCluster cluster(device_map);

  // Run data optimizer using grappler's meta optimizer.
  tensorflow::ConfigProto config;
  *config.mutable_graph_options()->mutable_rewrite_options() = config_factory();
  TF_RETURN_IF_ERROR(tensorflow::grappler::RunMetaOptimizer(
      std::move(*grappler_item), config, ctx->device(), &cluster, graph_def));

  // Remove fake sinks after optimizations are done.
  //
  // TODO(b/118820916): When MetaOptimizer adds provisions for function retvals
  // to be optimizable, we will no longer need this.
  for (auto& function_def : *graph_def->mutable_library()->mutable_function()) {
    RemoveFakeSinks(&function_def);
  }

  return OkStatus();
}
}  // anonymous namespace

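// Produces a RewriterConfig that runs the tf.data meta optimizer exactly once
// with the given optimizations (skipping any that are not registered) and
// optimizer configs, and that fails on optimizer errors.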
RewriterConfig CreateRewriterConfig(
    const absl::flat_hash_set<tstring>& optimizations,
    const absl::flat_hash_set<tstring>& optimizations_configs) {
  RewriterConfig rewriter_config;
  rewriter_config.add_optimizers(kOptimizerName);
  rewriter_config.set_meta_optimizer_iterations(RewriterConfig::ONE);
  rewriter_config.set_fail_on_optimizer_errors(true);
  auto custom_optimizer = rewriter_config.add_custom_optimizers();
  custom_optimizer->set_name(kOptimizerName);
  auto* custom_optimizations_list =
      (*custom_optimizer->mutable_parameter_map())[kOptimizers].mutable_list();
  const auto& registered_optimizers =
      grappler::CustomGraphOptimizerRegistry::GetRegisteredOptimizers();
  for (const auto& optimization : optimizations) {
    if (std::find(registered_optimizers.begin(), registered_optimizers.end(),
                  optimization) != registered_optimizers.end()) {
      custom_optimizations_list->add_s(optimization.data(),
                                       optimization.size());
    } else {
      VLOG(1) << "Optimization " << optimization << " is not registered.";
    }
  }
  auto* config_list =
      (*custom_optimizer->mutable_parameter_map())[kOptimizerConfigs]
          .mutable_list();
  for (const auto& config : optimizations_configs) {
    config_list->add_s(config.data(), config.size());
  }
  return rewriter_config;
}

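// Rewrites `input` by serializing it to a graph, applying the rewrites
// produced by `config_factory`, and re-instantiating the optimized dataset
// from the rewritten graph. If `record_fingerprint` is true, a fingerprint of
// the rewritten graph is computed asynchronously and exported as a metric.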
Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input,
                      std::function<RewriterConfig(void)> config_factory,
                      bool record_fingerprint,
                      core::RefCountPtr<DatasetBase>* rewritten_input) {
  std::vector<std::pair<string, Tensor>> input_list;
  GraphDef graph_def;
  string output_node;
  TF_RETURN_IF_ERROR(
      AsGraphDefForRewrite(ctx, input, &input_list, &graph_def, &output_node));

  VLOG(3) << "Before graph rewrites: " << graph_def.DebugString();
  TF_RETURN_IF_ERROR(
      ApplyRewrites(ctx, config_factory, &graph_def, &output_node));
  VLOG(3) << "After graph rewrites: " << graph_def.DebugString();

  // Instantiate the optimized input pipeline by running the optimized graph
  // using the optimized function library.
  FunctionLibraryRuntime* flr = nullptr;
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr = nullptr;
  std::unique_ptr<FunctionLibraryDefinition> lib_def = nullptr;
  TF_RETURN_IF_ERROR(
      ctx->function_library()->Clone(&lib_def, &pflr, &flr, true));

  // Some functions may have been modified without having their names changed
  // (for example, nested dataset graphs from FlatMap or Interleave).
  TF_RETURN_IF_ERROR(AddToFunctionLibrary(lib_def.get(), graph_def.library()));

  Graph graph(OpRegistry::Global());
  TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
  std::vector<Tensor> outputs;
  GraphRunner graph_runner(flr->device());

  TF_RETURN_IF_ERROR(
      graph_runner.Run(&graph, flr, input_list, {output_node}, &outputs));
  DatasetBase* rewritten_dataset;
  TF_RETURN_IF_ERROR(
      GetDatasetFromVariantTensor(outputs[0], &rewritten_dataset));
  rewritten_dataset->Ref();
  rewritten_input->reset(rewritten_dataset);

  if (record_fingerprint) {
    (*ctx->runner())([graph_def = std::move(graph_def),
                      lib_def = lib_def.release(),
                      input_list = std::move(input_list),
                      output_node = std::move(output_node)]() {
      std::unique_ptr<FunctionLibraryDefinition> lib_def_owner(lib_def);
      const NodeDef* node_def = nullptr;
      for (const auto& node : graph_def.node()) {
        if (node.name() == output_node) {
          node_def = &node;
          break;
        }
      }
      if (node_def == nullptr) {
        VLOG(3) << "Failed to find node: " << output_node;
        return;
      }
      uint64 hash = 0;
      Status s = HashNode(graph_def, *node_def, *lib_def, &hash);
      if (!s.ok()) {
        VLOG(3) << "Failed to hash graph: " << s.ToString();
        return;
      }
      for (const auto& pair : input_list) {
        hash = Hash64CombineUnordered(hash, Hash64(pair.first));
        uint64 tensor_hash = 0;
        Status s = HashTensor(pair.second, &tensor_hash);
        if (s.ok()) {
          hash = Hash64CombineUnordered(hash, tensor_hash);
        } else {
          VLOG(3) << "Failed to hash tensor: " << s.ToString();
        }
      }
      string graph_hash =
          strings::StrCat(strings::Hex(hash, strings::kZeroPad16));
      metrics::RecordTFDataFingerprint(graph_hash);
    });
  }

  return OkStatus();
}

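// Wraps `graph_def` in a GrapplerItem whose fetch node is a new "Identity"
// sink fed by `*dataset_node` (which is updated to the sink's name),
// optionally adding fake sinks to the function library so that function
// outputs can be rewritten.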
std::unique_ptr<tensorflow::grappler::GrapplerItem> GetGrapplerItem(
    GraphDef* graph_def, std::string* dataset_node, bool add_fake_sinks) {
  // Add an identity node as the fetch node, otherwise we might get 'placeholder
  // is both fed and fetched' errors in some cases when using input list with
  // placeholder dataset nodes.
  NodeDef* node = graph_def->mutable_node()->Add();
  tensorflow::grappler::graph_utils::SetUniqueGraphNodeName("Sink", graph_def,
                                                            node);
  node->set_op("Identity");
  node->add_input(*dataset_node);
  (*node->mutable_attr())["T"].set_type(DT_VARIANT);
  *dataset_node = node->name();

  if (add_fake_sinks) {
    // Add fake sink node to graph and functions to allow rewriting the actual
    // sink nodes.
    //
    // TODO(b/118820916): When MetaOptimizer adds provisions for function
    // retvals to be optimizable, we will no longer need this.
    for (auto& function_def :
         *graph_def->mutable_library()->mutable_function()) {
      AddFakeSinks(&function_def);
    }
  }

  // Create metagraph.
  MetaGraphDef meta_graph_def;
  (*meta_graph_def.mutable_graph_def()) = *graph_def;

  // Grappler determines fetch ops from collection 'train_op'.
  CollectionDef collection_def;
  auto node_list = collection_def.mutable_node_list();
  node_list->add_value(*dataset_node);
  (*meta_graph_def.mutable_collection_def())["train_op"] = collection_def;

  // Create Grappler item.
  tensorflow::grappler::ItemConfig item_config;
  item_config.apply_optimizations = true;
  std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
      tensorflow::grappler::GrapplerItemFromMetaGraphDef(
          "graph", meta_graph_def, item_config);
  // Grappler should not optimize function library of tf.data graphs. The
  // tf.data meta optimizer takes care of optimizing tf.data functions.
  grappler_item->optimization_options().optimize_function_library = false;
  return grappler_item;
}

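// Returns the union of the explicitly enabled optimizations, the default
// optimizations that are not disabled, and the experiments that name a
// registered optimizer and are not disabled.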
absl::flat_hash_set<tstring> SelectOptimizations(
    const absl::flat_hash_set<string>& experiments,
    const absl::flat_hash_set<tstring>& optimizations_enabled,
    const absl::flat_hash_set<tstring>& optimizations_disabled,
    const absl::flat_hash_set<tstring>& optimizations_default) {
  absl::flat_hash_set<tstring> optimizations;

  // Add the enabled optimizations.
  optimizations.insert(optimizations_enabled.begin(),
                       optimizations_enabled.end());

  // Add all default optimizations that are not disabled.
  for (const auto& optimization : optimizations_default) {
    if (!optimizations_disabled.contains(optimization)) {
      optimizations.insert(optimization);
    }
  }

  // Add experiments that correspond to an optimization unless the optimization
  // is disabled.
  const auto& registered_optimizers =
      grappler::CustomGraphOptimizerRegistry::GetRegisteredOptimizers();
  for (const auto& experiment : experiments) {
    if (std::find(registered_optimizers.begin(), registered_optimizers.end(),
                  experiment) != registered_optimizers.end() &&
        !optimizations_disabled.contains(experiment)) {
      optimizations.insert(experiment);
    }
  }

  return optimizations;
}

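// Returns the name of the node that produces the dataset, i.e. the input of
// the graph's "_Retval" node.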
StatusOr<std::string> GetDatasetNode(const GraphDef& graph_def) {
  // The symbolic `_Retval` node indicates which node corresponds to the
  // dataset.
  for (const auto& node : graph_def.node()) {
    if (node.op() == "_Retval") {
      return node.input(0);
    }
  }
  return errors::NotFound(
      absl::Substitute("Dataset node for graph is not found:\n$0",
                       graph_def.ShortDebugString()));
}

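// Returns the NodeDef of the node identified by GetDatasetNode().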
StatusOr<NodeDef> GetDatasetNodeDef(const GraphDef& graph_def) {
  TF_ASSIGN_OR_RETURN(std::string dataset_node_name, GetDatasetNode(graph_def));
  for (const auto& node : graph_def.node()) {
    if (node.name() == dataset_node_name) {
      return node;
    }
  }
  return errors::NotFound(
      absl::Substitute("Dataset node for graph is not found:\n$0",
                       graph_def.ShortDebugString()));
}

}  // namespace data
}  // namespace tensorflow
#endif  // !IS_MOBILE_PLATFORM