xref: /aosp_15_r20/external/tensorflow/tensorflow/core/grappler/optimizers/meta_optimizer.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
17 
18 #include <algorithm>
19 #include <functional>
20 #include <string>
21 #include <utility>
22 
23 #include "absl/strings/match.h"
24 #include "absl/strings/str_cat.h"
25 #include "absl/strings/str_join.h"
26 #include "absl/strings/substitute.h"
27 #include "tensorflow/core/common_runtime/function.h"
28 #include "tensorflow/core/common_runtime/graph_constructor.h"
29 #include "tensorflow/core/framework/dataset.h"
30 #include "tensorflow/core/framework/function.pb.h"
31 #include "tensorflow/core/framework/metrics.h"
32 #include "tensorflow/core/framework/tensor_shape.pb.h"
33 #include "tensorflow/core/framework/tensor_util.h"
34 #include "tensorflow/core/framework/versions.pb.h"
35 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
36 #include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
37 #include "tensorflow/core/grappler/optimizers/auto_mixed_precision.h"
38 #include "tensorflow/core/grappler/optimizers/auto_parallel.h"
39 #include "tensorflow/core/grappler/optimizers/common_subgraph_elimination.h"
40 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
41 #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
42 #include "tensorflow/core/grappler/optimizers/debug_stripper.h"
43 #include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
44 #include "tensorflow/core/grappler/optimizers/function_optimizer.h"
45 #include "tensorflow/core/grappler/optimizers/generic_layout_optimizer.h"
46 #include "tensorflow/core/grappler/optimizers/implementation_selector.h"
47 #include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
48 #include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
49 #include "tensorflow/core/grappler/optimizers/model_pruner.h"
50 #include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"
51 #include "tensorflow/core/grappler/optimizers/remapper.h"
52 #include "tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h"
53 #include "tensorflow/core/grappler/optimizers/shape_optimizer.h"
54 #include "tensorflow/core/grappler/utils/canonicalizer.h"
55 #include "tensorflow/core/grappler/utils/colocation.h"
56 #include "tensorflow/core/grappler/utils/functions.h"
57 #include "tensorflow/core/grappler/utils/topological_sort.h"
58 #include "tensorflow/core/grappler/utils/tpu.h"
59 #include "tensorflow/core/grappler/verifiers/structure_verifier.h"
60 #include "tensorflow/core/lib/core/status.h"
61 #include "tensorflow/core/lib/gtl/map_util.h"
62 #include "tensorflow/core/platform/logging.h"
63 #include "tensorflow/core/util/dump_graph.h"
64 #include "tensorflow/core/util/ptr_util.h"
65 #include "tensorflow/core/util/util.h"
66 #include "tensorflow/core/util/xla_config_registry.h"
67 
68 // #TODO(b/200087693): LLVM does not build on Fuchsia.
69 #ifndef __Fuchsia__
70 #include "tensorflow/core/grappler/optimizers/tfg_optimizer_hook.h"
71 #include "tensorflow/core/grappler/optimizers/tfg_passes_builder.h"
72 #endif
73 
74 namespace tensorflow {
75 namespace grappler {
76 
77 namespace {
78 
79 constexpr int kDefaultNumberOfIterations = 2;
80 constexpr int kDefaultMinGraphNodes = 4;
81 constexpr char kGrapplerCategory[] = "Grappler";
82 
83 int64_t NumEdges(const GraphDef& graph) {
84   int64_t num_edges = 0;
85   for (const auto& node : graph.node()) {
86     num_edges += node.input_size();
87   }
88   return num_edges;
89 }
90 
91 string PrintSizesBeforeAfter(const GraphDef& before, const GraphDef& after) {
92   return strings::StrCat("Graph size after: ", after.node_size(), " nodes (",
93                          after.node_size() - before.node_size(), "), ",
94                          NumEdges(after), " edges (",
95                          NumEdges(after) - NumEdges(before), ")");
96 }
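// Worked example (illustrative, not in the original source): for an `after`
// graph with 120 nodes and 150 edges and a `before` graph with 123 nodes and
// 157 edges, the string built above reads
//   "Graph size after: 120 nodes (-3), 150 edges (-7)".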
97 
98 int NumIterations(const RewriterConfig& cfg) {
99   return cfg.meta_optimizer_iterations() == RewriterConfig::DEFAULT_NUM_ITERS
100              ? kDefaultNumberOfIterations
101              : cfg.meta_optimizer_iterations();
102 }
103 
104 // Check if the optimizer is allowed to run only once.
105 bool IsRunOnceOptimizer(const string& name) {
106   return name == "layout" || name == "memory_optimizer" ||
107          name == "loop_optimizer" ||
108          absl::StartsWith(name, "auto_mixed_precision");
109 }
110 
111 // Creates a function library stub from a real function library: copies only
112 // the signatures and attributes of all the functions defined in fdef_lib. This
113 // stub can be swapped with the real function library in a graph before passing
114 // it to an optimizer, if the optimizer doesn't instantiate functions.
115 FunctionDefLibrary GetFunctionDefLibraryStub(
116     const FunctionDefLibrary& fdef_lib) {
117   FunctionDefLibrary stub;
118   for (const FunctionDef& fn : fdef_lib.function()) {
119     FunctionDef* fn_stub = stub.mutable_function()->Add();
120     *(fn_stub->mutable_signature()) = fn.signature();
121     *(fn_stub->mutable_attr()) = fn.attr();
122     *(fn_stub->mutable_arg_attr()) = fn.arg_attr();
123     *(fn_stub->mutable_resource_arg_unique_id()) = fn.resource_arg_unique_id();
124   }
125   *stub.mutable_gradient() = fdef_lib.gradient();
126   return stub;
127 }
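// Illustrative note (not in the original source): for a FunctionDef with a
// populated body, the stub built above keeps only `signature`, `attr`,
// `arg_attr` and `resource_arg_unique_id`; `node_def`, `ret` and `control_ret`
// are left empty, which makes the stub cheap to copy into a graph.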
128 
129 uint64 DeadlineMicroSeconds(const RewriterConfig& cfg) {
130   if (cfg.meta_optimizer_timeout_ms() <= 0) return 0;  // no deadline
131   return Env::Default()->NowMicros() + cfg.meta_optimizer_timeout_ms() * 1000;
132 }
133 
134 // A helper function to decide whether to enable the automatic mixed precision
135 // optimizer.
136 bool AutoMixedPrecisionEnabled(RewriterConfig::Toggle opt_level) {
137   if (opt_level == RewriterConfig::ON ||
138       opt_level == RewriterConfig::AGGRESSIVE) {
139     return true;
140   } else if (opt_level == RewriterConfig::EXPERIMENTAL_MLIR ||
141              opt_level == RewriterConfig::EXPERIMENTAL_BOTH) {
142     VLOG(2) << "auto_mixed_precision is not implemented in TFG yet";
143   }
144   return false;
145 }
146 
147 bool IsXlaGlobalJitOn(
148     const OptimizerOptions::GlobalJitLevel& jit_level_in_session_opts) {
149   xla_config_registry::XlaGlobalJitLevel xla_global_jit_level =
150       xla_config_registry::GetGlobalJitLevel(jit_level_in_session_opts);
151   // Return true only if XLA JIT is ON for both single-gpu and multi-gpu
152   // graphs. This is a conservative approach that turns off the memory
153   // optimizer when we are sure that all graphs will be processed by XLA JIT.
154   return xla_global_jit_level.single_gpu >= OptimizerOptions::ON_1 &&
155          xla_global_jit_level.general >= OptimizerOptions::ON_1;
156 }
157 
158 // A helper function to decide whether to enable the memory optimizer.
159 bool MemoryOptimizerEnabled(RewriterConfig::MemOptType mem_opt_type,
160                             bool xla_auto_clustering_on) {
161   // Disable the default memory optimizer when XLA JIT is ON as it hurts the
162   // XLA JIT performance. The (current) XLA clustering can result in loss of
163   // concurrency between kernel compute and memory copies. As such, it usually
164   // loses the concurrency needed to hide the latencies of the inserted swap-ins
165   // and swap-outs and incurs great performance overhead. Remove this check when
166   // the XLA JIT can better deal with the concurrency.
167   if (mem_opt_type == RewriterConfig::DEFAULT_MEM_OPT &&
168       xla_auto_clustering_on) {
169     return false;
170   }
171 
172   return mem_opt_type != RewriterConfig::NO_MEM_OPT;
173 }
174 
175 Status GetGraphDevice(const GraphDef& g_def, std::set<std::string>* devices) {
176   for (auto& node : g_def.node()) {
177     DeviceNameUtils::ParsedName parsed_name;
178     if (!DeviceNameUtils::ParseFullName(node.device(), &parsed_name)) {
179       return errors::InvalidArgument("Unable to parse ", node.device(),
180                                      " as a device name");
181     }
182     devices->insert(parsed_name.type);
183   }
184   return OkStatus();
185 }
186 
187 }  // namespace
188 
189 #define MK_OPT(NAME, CONFIG, VALUE)                                    \
190   if (optimizer == NAME) {                                             \
191     if (plugin_configs.toggle_config[CONFIG] != RewriterConfig::OFF) { \
192       return std::unique_ptr<GraphOptimizer>(VALUE);                   \
193     }                                                                  \
194   }
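// Illustrative expansion (not part of the original source): a use such as
//   MK_OPT("shape", "shape_optimization", new ShapeOptimizer());
// expands to
//   if (optimizer == "shape") {
//     if (plugin_configs.toggle_config["shape_optimization"] !=
//         RewriterConfig::OFF) {
//       return std::unique_ptr<GraphOptimizer>(new ShapeOptimizer());
//     }
//   }
// so MakeNewOptimizer returns the requested optimizer only when the plugin
// configuration does not explicitly turn it off.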
195 
196 bool MetaOptimizer::LowerControlFlow() const {
197   if (config_proto_.experimental().executor_type() ==
198       "SINGLE_THREADED_EXECUTOR")
199     return false;
200 
201   if (config_proto_.experimental().use_tfrt()) return false;
202 
203   return true;
204 }
205 
206 std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer(
207     const string& optimizer, const std::set<string>& device_types) const {
208   ConfigList plugin_configs = PluginGraphOptimizerRegistry::GetPluginConfigs(
209       cfg_.use_plugin_optimizers() != RewriterConfig::OFF, device_types);
210   if (optimizer == "pruning" && !plugin_configs.disable_model_pruning)
211     return std::unique_ptr<GraphOptimizer>(new ModelPruner());
212   MK_OPT("function", "function_optimization",
213          new FunctionOptimizer(cfg_.function_optimization(),
214                                /*lower_control_flow=*/LowerControlFlow()));
215   MK_OPT("constfold", "constant_folding",
216          new ConstantFolding(
217              cpu_device_,
218              cfg_.experimental_disable_compressed_tensor_optimization(),
219              !cfg_.experimental_disable_folding_quantization_emulation()));
220   MK_OPT("shape", "shape_optimization", new ShapeOptimizer());
221   MK_OPT("remap", "remapping",
222          new Remapper(cfg_.remapping(), cfg_.cpu_layout_conversion(),
223                       xla_auto_clustering_on_));
224   MK_OPT("layout", "layout_optimizer",
225          new GenericLayoutOptimizer(
226              /*optimization level*/ cfg_.layout_optimizer(),
227              /*CPU layout conversion*/ cfg_.cpu_layout_conversion()));
228   MK_OPT("auto_mixed_precision", "auto_mixed_precision",
229          new AutoMixedPrecision(AutoMixedPrecisionMode::CUDA));
230 #ifdef INTEL_MKL
231   if (IsMKLEnabled()) {
232     MK_OPT("auto_mixed_precision_mkl", "auto_mixed_precision_mkl",
233            new AutoMixedPrecision(AutoMixedPrecisionMode::BF16));
234     MK_OPT("auto_mixed_precision_onednn_bfloat16",
235            "auto_mixed_precision_onednn_bfloat16",
236            new AutoMixedPrecision(AutoMixedPrecisionMode::BF16));
237   }
238 #endif
239   MK_OPT("auto_mixed_precision_cpu", "auto_mixed_precision_cpu",
240          new AutoMixedPrecision(AutoMixedPrecisionMode::CPU));
241   MK_OPT("memory", "memory_optimization",
242          new MemoryOptimizer(RewriterConfig::MANUAL));
243   MK_OPT("common_subgraph_elimination", "common_subgraph_elimination",
244          new CommonSubgraphElimination(cfg_.common_subgraph_elimination()));
245   MK_OPT("arithmetic", "arithmetic_optimization",
246          new ArithmeticOptimizer(cfg_.arithmetic_optimization()));
247   MK_OPT("autoparallel", "auto_parallel",
248          new AutoParallel(cfg_.auto_parallel().num_replicas()));
249   MK_OPT("loop", "loop_optimization",
250          new LoopOptimizer(cfg_.loop_optimization(), cpu_device_));
251   MK_OPT("dependency", "dependency_optimization",
252          new DependencyOptimizer(cfg_.dependency_optimization()));
253   MK_OPT("debug_stripper", "debug_stripper", new DebugStripper());
254   MK_OPT("scoped_allocator", "scoped_allocator_optimization",
255          new ScopedAllocatorOptimizer(cfg_.scoped_allocator_optimization(),
256                                       cfg_.scoped_allocator_opts()));
257   MK_OPT("pin_to_host", "pin_to_host_optimization",
258          new PinToHostOptimizer(cfg_.pin_to_host_optimization()));
259 
260   return std::unique_ptr<GraphOptimizer>();
261 }
262 
263 #undef MK_OPT
264 
265 MetaOptimizer::MetaOptimizer(DeviceBase* cpu_device, const ConfigProto& cfg)
266     : cpu_device_(cpu_device),
267       config_proto_(cfg),
268       cfg_(*config_proto_.mutable_graph_options()->mutable_rewrite_options()) {
269   DCHECK(cpu_device_ == nullptr ||
270          cpu_device_->attributes().device_type() == "CPU");
271   auto global_jit_level =
272       cfg.graph_options().optimizer_options().global_jit_level();
273   xla_auto_clustering_on_ = IsXlaGlobalJitOn(global_jit_level);
274 }
275 
276 Status MetaOptimizer::InitializeOptimizers(
277     const std::set<string>& device_types,
278     std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
279   if (cfg_.disable_meta_optimizer()) {
280     return OkStatus();
281   }
282 
283   ConfigList plugin_configs = PluginGraphOptimizerRegistry::GetPluginConfigs(
284       cfg_.use_plugin_optimizers() != RewriterConfig::OFF, device_types);
285   if (!cfg_.disable_model_pruning() && !plugin_configs.disable_model_pruning) {
286     optimizers->push_back(MakeUnique<ModelPruner>());
287   }
288 
289   // #TODO(b/200087693): LLVM does not build on Fuchsia.
290 #ifndef __Fuchsia__
291   // Hook in the MLIR optimizer; it won't run any optimizations right now. This
292   // optimizer instance runs on functions one at a time; don't use any threads.
293   optimizers->push_back(MakeUnique<mlir::tfg::TFGGrapplerOptimizer>(
294       mlir::tfg::DefaultGrapplerPipeline));
295 #endif
296 
297 // A set of macro utilities that check the toggle state of an optimization.
298 // They support both user and plugin configurations.
299 #define USER_IS_ON(CFG) cfg_.CFG() == RewriterConfig::ON
300 #define USER_IS_EXPERIMENTAL_MLIR(CFG) \
301   cfg_.CFG() == RewriterConfig::EXPERIMENTAL_MLIR
302 #define USER_IS_EXPERIMENTAL_BOTH(CFG) \
303   cfg_.CFG() == RewriterConfig::EXPERIMENTAL_BOTH
304 #define USER_NOT_OFF(CFG) cfg_.CFG() != RewriterConfig::OFF
305 #define PLUGIN_IS_ON(CFG) \
306   plugin_configs.toggle_config[#CFG] == RewriterConfig::ON
307 #define PLUGIN_IS_EXPERIMENTAL_MLIR(CFG) \
308   plugin_configs.toggle_config[#CFG] == RewriterConfig::EXPERIMENTAL_MLIR
309 #define PLUGIN_IS_EXPERIMENTAL_BOTH(CFG) \
310   plugin_configs.toggle_config[#CFG] == RewriterConfig::EXPERIMENTAL_BOTH
311 #define PLUGIN_NOT_OFF(CFG) \
312   plugin_configs.toggle_config[#CFG] != RewriterConfig::OFF
313 #define BOTH_ARE_ON(CFG) (USER_IS_ON(CFG) && PLUGIN_IS_ON(CFG))
314 #define BOTH_NOT_OFF(CFG) (USER_NOT_OFF(CFG) && PLUGIN_NOT_OFF(CFG))
315 #define BOTH_ARE_EXPERIMENTAL_MLIR(CFG) \
316   (USER_IS_EXPERIMENTAL_MLIR(CFG) && PLUGIN_IS_EXPERIMENTAL_MLIR(CFG))
317 #define BOTH_ARE_EXPERIMENTAL_BOTH(CFG) \
318   (USER_IS_EXPERIMENTAL_BOTH(CFG) && PLUGIN_IS_EXPERIMENTAL_BOTH(CFG))
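// Illustrative expansion (not part of the original source):
//   BOTH_NOT_OFF(remapping)
// expands to
//   (cfg_.remapping() != RewriterConfig::OFF &&
//    plugin_configs.toggle_config["remapping"] != RewriterConfig::OFF)
// so a pass is registered below only when neither the user config nor the
// plugin config disables it.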
319   if (BOTH_NOT_OFF(implementation_selector)) {
320     if (USER_IS_EXPERIMENTAL_MLIR(implementation_selector) ||
321         USER_IS_EXPERIMENTAL_BOTH(implementation_selector))
322       VLOG(2) << "implementation_selector is not implemented in TFG yet";
323     else
324       optimizers->push_back(MakeUnique<ImplementationSelector>());
325   }
326   if (BOTH_NOT_OFF(function_optimization)) {
327     if (USER_IS_EXPERIMENTAL_MLIR(function_optimization) ||
328         USER_IS_EXPERIMENTAL_BOTH(function_optimization)) {
329       VLOG(2) << "function_optimization is not implemented in TFG yet";
330     } else {
331       optimizers->push_back(MakeUnique<FunctionOptimizer>(
332           cfg_.function_optimization(),
333           /*lower_control_flow=*/LowerControlFlow()));
334     }
335   }
336   if (BOTH_NOT_OFF(common_subgraph_elimination) &&
337       BOTH_NOT_OFF(arithmetic_optimization)) {
338     if (USER_IS_EXPERIMENTAL_MLIR(common_subgraph_elimination) ||
339         USER_IS_EXPERIMENTAL_BOTH(common_subgraph_elimination)) {
340       VLOG(2) << "common_subgraph_elimination is not implemented in TFG yet";
341     } else {
342       optimizers->push_back(MakeUnique<CommonSubgraphElimination>(
343           cfg_.common_subgraph_elimination()));
344     }
345   }
346   if (BOTH_ARE_ON(debug_stripper))
347     optimizers->push_back(MakeUnique<DebugStripper>());
348   else if (BOTH_ARE_EXPERIMENTAL_MLIR(debug_stripper) ||
349            BOTH_ARE_EXPERIMENTAL_BOTH(debug_stripper))
350     VLOG(2) << "debug_stripper is not implemented in TFG yet";
351   if (BOTH_NOT_OFF(constant_folding)) {
352     if (USER_IS_EXPERIMENTAL_MLIR(constant_folding) ||
353         USER_IS_EXPERIMENTAL_BOTH(constant_folding)) {
354       VLOG(2) << "constant_folding is not implemented in TFG yet";
355     } else {
356       optimizers->push_back(MakeUnique<ConstantFolding>(
357           cfg_.constant_folding(), cpu_device_,
358           cfg_.experimental_disable_compressed_tensor_optimization(),
359           !cfg_.experimental_disable_folding_quantization_emulation()));
360     }
361   }
362   if (BOTH_NOT_OFF(shape_optimization)) {
363     if (USER_IS_EXPERIMENTAL_MLIR(shape_optimization) ||
364         USER_IS_EXPERIMENTAL_BOTH(shape_optimization))
365       VLOG(2) << "shape_optimization is not implemented in TFG yet";
366     else
367       optimizers->push_back(MakeUnique<ShapeOptimizer>());
368   }
369   if (AutoMixedPrecisionEnabled(cfg_.auto_mixed_precision()) &&
370       AutoMixedPrecisionEnabled(
371           plugin_configs.toggle_config["auto_mixed_precision"])) {
372     optimizers->push_back(
373         MakeUnique<AutoMixedPrecision>(AutoMixedPrecisionMode::CUDA));
374   }
375 #ifdef INTEL_MKL
376   if (AutoMixedPrecisionEnabled(cfg_.auto_mixed_precision_onednn_bfloat16()) &&
377       AutoMixedPrecisionEnabled(
378           plugin_configs
379               .toggle_config["auto_mixed_precision_onednn_bfloat16"]) &&
380       IsMKLEnabled()) {
381     optimizers->push_back(
382         MakeUnique<AutoMixedPrecision>(AutoMixedPrecisionMode::BF16));
383   }
384   if (AutoMixedPrecisionEnabled(cfg_.auto_mixed_precision_mkl()) &&
385       AutoMixedPrecisionEnabled(
386           plugin_configs.toggle_config["auto_mixed_precision_mkl"]) &&
387       IsMKLEnabled()) {
388     LOG_FIRST_N(WARNING, 1)
389         << "NOTE: auto_mixed_precision_mkl is deprecated."
390            " Please use auto_mixed_precision_onednn_bfloat16 instead";
391     optimizers->push_back(
392         MakeUnique<AutoMixedPrecision>(AutoMixedPrecisionMode::BF16));
393   }
394 #endif
395   if (AutoMixedPrecisionEnabled(cfg_.auto_mixed_precision_cpu()) &&
396       AutoMixedPrecisionEnabled(
397           plugin_configs.toggle_config["auto_mixed_precision_cpu"])) {
398     optimizers->push_back(
399         MakeUnique<AutoMixedPrecision>(AutoMixedPrecisionMode::CPU));
400   }
401   if (BOTH_ARE_ON(pin_to_host_optimization))
402     optimizers->push_back(MakeUnique<PinToHostOptimizer>());
403   else if (BOTH_ARE_EXPERIMENTAL_MLIR(pin_to_host_optimization) ||
404            BOTH_ARE_EXPERIMENTAL_BOTH(pin_to_host_optimization))
405     VLOG(2) << "pin_to_host_optimization is not implemented in TFG yet";
406   if (BOTH_NOT_OFF(arithmetic_optimization)) {
407     if (USER_IS_EXPERIMENTAL_MLIR(arithmetic_optimization) ||
408         USER_IS_EXPERIMENTAL_BOTH(arithmetic_optimization)) {
409       VLOG(2) << "arithmetic_optimization is not implemented in TFG yet";
410     } else {
411       optimizers->push_back(
412           MakeUnique<ArithmeticOptimizer>(cfg_.arithmetic_optimization()));
413     }
414   }
415   if (BOTH_NOT_OFF(layout_optimizer)) {
416     if (USER_IS_EXPERIMENTAL_MLIR(layout_optimizer) ||
417         USER_IS_EXPERIMENTAL_BOTH(layout_optimizer)) {
418       VLOG(2) << "layout_optimizer is not implemented in TFG yet";
419     } else {
420       optimizers->push_back(MakeUnique<GenericLayoutOptimizer>(
421           /*optimization level*/ cfg_.layout_optimizer(),
422           /*CPU layout conversion*/ cfg_.cpu_layout_conversion()));
423     }
424   }
425   if (BOTH_NOT_OFF(remapping)) {
426     bool enable_mlir_pass = USER_IS_EXPERIMENTAL_MLIR(remapping) ||
427                             USER_IS_EXPERIMENTAL_BOTH(remapping);
428     bool enable_grappler_pass =
429         !enable_mlir_pass || USER_IS_EXPERIMENTAL_BOTH(remapping);
430     if (enable_mlir_pass) {
431 // #TODO(b/200087693): LLVM does not build on Fuchsia.
432 #ifndef __Fuchsia__
433       optimizers->push_back(MakeUnique<mlir::tfg::TFGGrapplerOptimizer>(
434           mlir::tfg::RemapperPassBuilder));
435 #else
436       VLOG(2) << "mlir Remapper pass is not supported on Fuchsia";
437 #endif
438     }
439     if (enable_grappler_pass) {
440       optimizers->push_back(MakeUnique<Remapper>(cfg_.remapping(),
441                                                  cfg_.cpu_layout_conversion(),
442                                                  xla_auto_clustering_on_));
443     }
444   }
445   if (BOTH_NOT_OFF(loop_optimization)) {
446     if (USER_IS_EXPERIMENTAL_MLIR(loop_optimization) ||
447         USER_IS_EXPERIMENTAL_BOTH(loop_optimization)) {
448       VLOG(2) << "loop_optimization is not implemented in TFG yet";
449     } else {
450       optimizers->push_back(
451           MakeUnique<LoopOptimizer>(cfg_.loop_optimization(), cpu_device_));
452     }
453   }
454   if (BOTH_NOT_OFF(dependency_optimization)) {
455     if (USER_IS_EXPERIMENTAL_MLIR(dependency_optimization) ||
456         USER_IS_EXPERIMENTAL_BOTH(dependency_optimization)) {
457       VLOG(2) << "dependency_optimization is not implemented in TFG yet";
458     } else {
459       optimizers->push_back(
460           MakeUnique<DependencyOptimizer>(cfg_.dependency_optimization()));
461     }
462   }
463   if (MemoryOptimizerEnabled(cfg_.memory_optimization(),
464                              xla_auto_clustering_on_) &&
465       PLUGIN_NOT_OFF(memory_optimization)) {
466     if (cfg_.memory_optimizer_target_node_name_scope().empty()) {
467       optimizers->push_back(
468           // Use the default target node name prefix "gradients/"
469           MakeUnique<MemoryOptimizer>(cfg_.memory_optimization()));
470     } else {
471       optimizers->push_back(MakeUnique<MemoryOptimizer>(
472           cfg_.memory_optimization(),
473           cfg_.memory_optimizer_target_node_name_scope()));
474     }
475   }
476   if (cfg_.auto_parallel().enable() && PLUGIN_IS_ON(auto_parallel)) {
477     optimizers->push_back(
478         MakeUnique<AutoParallel>(cfg_.auto_parallel().num_replicas()));
479   }
480 
481 #ifndef ENABLE_MKL
482   if (BOTH_ARE_ON(scoped_allocator_optimization)) {
483     optimizers->push_back(MakeUnique<ScopedAllocatorOptimizer>(
484         cfg_.scoped_allocator_optimization(), cfg_.scoped_allocator_opts()));
485   } else if (BOTH_ARE_EXPERIMENTAL_MLIR(scoped_allocator_optimization) ||
486              BOTH_ARE_EXPERIMENTAL_BOTH(scoped_allocator_optimization)) {
487     VLOG(2) << "scoped_allocator_optimization is not implemented in TFG yet";
488   }
489 #endif
490 
491 #undef USER_IS_ON
492 #undef USER_IS_EXPERIMENTAL_MLIR
493 #undef USER_IS_EXPERIMENTAL_BOTH
494 #undef USER_NOT_OFF
495 #undef PLUGIN_IS_ON
496 #undef PLUGIN_NOT_OFF
497 #undef BOTH_ARE_ON
498 #undef BOTH_NOT_OFF
499   return InitializeCustomGraphOptimizers(device_types, std::set<string>(),
500                                          optimizers);
501 }
502 
503 Status MetaOptimizer::InitializeOptimizersByName(
504     const std::set<string>& device_types,
505     std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
506   std::set<string> initialized_custom_optimizers;
507   for (const string& optimizer_name : cfg_.optimizers()) {
508     auto optimizer = MakeNewOptimizer(optimizer_name, device_types);
509     if (optimizer) {
510       VLOG(2) << "Registered default graph optimizer: " << optimizer_name;
511       optimizers->push_back(std::move(optimizer));
512       continue;
513     }
514 
515     auto custom_optimizer =
516         CustomGraphOptimizerRegistry::CreateByNameOrNull(optimizer_name);
517 
518     if (custom_optimizer) {
519       VLOG(2) << "Registered custom graph optimizer: " << optimizer_name;
520       TF_RETURN_IF_ERROR(custom_optimizer->InitWithConfig(
521           config_proto_, GetCustomGraphOptimizerConfig(optimizer_name)));
522       optimizers->push_back(std::move(custom_optimizer));
523       initialized_custom_optimizers.insert(optimizer_name);
524     } else {
525       VLOG(2) << "Can't register an optimizer by name: " << optimizer_name;
526     }
527   }
528   return InitializeCustomGraphOptimizers(
529       device_types, initialized_custom_optimizers, optimizers);
530 }
531 
532 Status MetaOptimizer::InitializeCustomGraphOptimizers(
533     const std::set<string>& device_types,
534     const std::set<string>& pre_initialized_optimizers,
535     std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
536   for (const auto& optimizer_config : cfg_.custom_optimizers()) {
537     if (pre_initialized_optimizers.find(optimizer_config.name()) !=
538         pre_initialized_optimizers.end()) {
539       continue;
540     }
541 
542     auto custom_optimizer = CustomGraphOptimizerRegistry::CreateByNameOrNull(
543         optimizer_config.name());
544 
545     if (custom_optimizer) {
546       VLOG(2) << "Registered custom configurable graph optimizer: "
547               << optimizer_config.name();
548       TF_RETURN_IF_ERROR(
549           custom_optimizer->InitWithConfig(config_proto_, &optimizer_config));
550       optimizers->push_back(std::move(custom_optimizer));
551     } else {
552       // If there is no custom optimizer with the given name, try to initialize a
553       // default optimizer. This way, custom configurable optimizers can be
554       // mixed with default optimizers in any order.
555       auto optimizer = MakeNewOptimizer(optimizer_config.name(), device_types);
556       if (optimizer) {
557         VLOG(2) << "Registered default graph optimizer: "
558                 << optimizer_config.name();
559         optimizers->push_back(std::move(optimizer));
560         continue;
561       }
562       VLOG(2) << "Can't register an optimizer by name: "
563               << optimizer_config.name();
564     }
565   }
566   return InitializePluginGraphOptimizers(device_types, optimizers);
567 }
568 
569 Status MetaOptimizer::InitializePluginGraphOptimizers(
570     const std::set<string>& device_types,
571     std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
572   if (cfg_.use_plugin_optimizers() == RewriterConfig::OFF) return OkStatus();
573   auto plugin_optimizers =
574       PluginGraphOptimizerRegistry::CreateOptimizers(device_types);
575   for (auto& plugin_optimizer : plugin_optimizers) {
576     optimizers->push_back(std::move(plugin_optimizer));
577   }
578   return OkStatus();
579 }
580 
581 const RewriterConfig::CustomGraphOptimizer*
582 MetaOptimizer::GetCustomGraphOptimizerConfig(const string& name) const {
583   for (const auto& config : cfg_.custom_optimizers()) {
584     if (config.name() == name) {
585       return &config;
586     }
587   }
588   return nullptr;
589 }
590 
591 void MetaOptimizer::InitializeVerifiers(
592     std::vector<std::unique_ptr<GraphVerifier>>* inter_optimizer_verifiers,
593     std::vector<std::unique_ptr<GraphVerifier>>* post_optimization_verifiers)
594     const {
595   if (cfg_.inter_optimizer_verifier_config().structure_verifier() ==
596       VerifierConfig::ON) {
597     inter_optimizer_verifiers->push_back(MakeUnique<StructureVerifier>());
598   }
599   if (cfg_.post_optimization_verifier_config().structure_verifier() ==
600       VerifierConfig::ON) {
601     post_optimization_verifiers->push_back(MakeUnique<StructureVerifier>());
602   }
603 }
604 
605 void MetaOptimizer::PrintUserAndPluginConfigs(
606     const std::set<string>& device_types) const {
607   if (cfg_.use_plugin_optimizers() == RewriterConfig::OFF) return;
608   ConfigList plugin_cfg = PluginGraphOptimizerRegistry::GetPluginConfigs(
609       cfg_.use_plugin_optimizers() != RewriterConfig::OFF, device_types);
610   PluginGraphOptimizerRegistry::PrintPluginConfigsIfConflict(device_types);
611 
612   ConfigList user_cfg;
613   // Print user's and plugin's configs.
614   if (cfg_.optimizers().empty()) {
615     if (cfg_.disable_meta_optimizer()) {
616       return;
617     }
618     user_cfg.disable_model_pruning = cfg_.disable_model_pruning();
619 #define PRINT_CFG(CFG) user_cfg.toggle_config[#CFG] = cfg_.CFG();
620     PRINT_CFG(implementation_selector)
621     PRINT_CFG(function_optimization)
622     PRINT_CFG(common_subgraph_elimination)
623     PRINT_CFG(arithmetic_optimization)
624     PRINT_CFG(debug_stripper)
625     PRINT_CFG(constant_folding)
626     PRINT_CFG(shape_optimization)
627     PRINT_CFG(pin_to_host_optimization)
628     PRINT_CFG(layout_optimizer)
629     PRINT_CFG(remapping)
630     PRINT_CFG(loop_optimization)
631     PRINT_CFG(dependency_optimization)
632     PRINT_CFG(scoped_allocator_optimization)
633 #undef PRINT_CFG
634     user_cfg.toggle_config["auto_mixed_precision"] =
635         AutoMixedPrecisionEnabled(cfg_.auto_mixed_precision())
636             ? RewriterConfig::ON
637             : RewriterConfig::OFF;
638     user_cfg.toggle_config["auto_mixed_precision_onednn_bfloat16"] =
639         AutoMixedPrecisionEnabled(cfg_.auto_mixed_precision_onednn_bfloat16())
640             ? RewriterConfig::ON
641             : RewriterConfig::OFF;
642     user_cfg.toggle_config["auto_mixed_precision_mkl"] =
643         AutoMixedPrecisionEnabled(cfg_.auto_mixed_precision_mkl())
644             ? RewriterConfig::ON
645             : RewriterConfig::OFF;
646     user_cfg.toggle_config["auto_mixed_precision_cpu"] =
647         AutoMixedPrecisionEnabled(cfg_.auto_mixed_precision_cpu())
648             ? RewriterConfig::ON
649             : RewriterConfig::OFF;
650     user_cfg.toggle_config["memory_optimization"] =
651         MemoryOptimizerEnabled(cfg_.memory_optimization(),
652                                config_proto_.graph_options()
653                                    .optimizer_options()
654                                    .global_jit_level())
655             ? RewriterConfig::ON
656             : RewriterConfig::OFF;
657     user_cfg.toggle_config["auto_parallel"] = cfg_.auto_parallel().enable()
658                                                   ? RewriterConfig::ON
659                                                   : RewriterConfig::OFF;
660   } else {
661     for (const string& optimizer_name : cfg_.optimizers()) {
662       if (optimizer_name == "pruning") user_cfg.disable_model_pruning = true;
663 
664 #define PRINT_CFG(NAME, CONFIG) \
665   if (optimizer_name == NAME)   \
666     user_cfg.toggle_config[CONFIG] = RewriterConfig::ON;
667 
668       PRINT_CFG("implementation_selector", "implementation_selector")
669       PRINT_CFG("function", "function_optimization")
670       PRINT_CFG("common_subgraph_elimination", "common_subgraph_elimination")
671       PRINT_CFG("arithmetic", "arithmetic_optimization")
672       PRINT_CFG("debug_stripper", "debug_stripper")
673       PRINT_CFG("constfold", "constant_folding")
674       PRINT_CFG("shape", "shape_optimization")
675       PRINT_CFG("auto_mixed_precision", "auto_mixed_precision")
676       PRINT_CFG("auto_mixed_precision_onednn_bfloat16",
677                 "auto_mixed_precision_onednn_bfloat16")
678       PRINT_CFG("auto_mixed_precision_mkl", "auto_mixed_precision_mkl")
679       PRINT_CFG("auto_mixed_precision_cpu", "auto_mixed_precision_cpu")
680       PRINT_CFG("pin_to_host", "pin_to_host_optimization")
681       PRINT_CFG("layout", "layout_optimizer")
682       PRINT_CFG("remap", "remapping")
683       PRINT_CFG("loop", "loop_optimization")
684       PRINT_CFG("dependency", "dependency_optimization")
685       PRINT_CFG("memory", "memory_optimization")
686       PRINT_CFG("autoparallel", "auto_parallel")
687       PRINT_CFG("scoped_allocator", "scoped_allocator_optimization")
688 #undef PRINT_CFG
689     }
690   }
691 
692   // Print logs only when the plugin config conflicts with the user config.
693   if (!PluginGraphOptimizerRegistry::IsConfigsConflict(user_cfg, plugin_cfg))
694     return;
695 
696   ConfigList final_cfg = user_cfg;
697   // If the plugin turns on `disable_model_pruning`, then `disable_model_pruning`
698   // should be true.
699   if (plugin_cfg.disable_model_pruning == true)
700     final_cfg.disable_model_pruning = true;
701   // If the plugin turns off a certain optimizer, then that optimizer should be
702   // turned off.
703   for (auto& pair : plugin_cfg.toggle_config) {
704     if (plugin_cfg.toggle_config[pair.first] == RewriterConfig::OFF)
705       final_cfg.toggle_config[pair.first] = RewriterConfig::OFF;
706   }
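  // Concrete example (not in the original source): if the user config has
  // `remapping` not OFF but a plugin reports `remapping` as OFF, the loop
  // above forces final_cfg to OFF for that entry, and the conflict table
  // logged below reflects the plugin override for that row.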
707 
708   string logs =
709       "\nConfig of optimizers\t\tUser's config\tPlugin's config\tFinal "
710       "config(User & Plugin)\n";
711   strings::StrAppend(&logs, "disable_model_pruning\t\t",
712                      user_cfg.disable_model_pruning, "\t\t",
713                      plugin_cfg.disable_model_pruning, "\t\t",
714                      final_cfg.disable_model_pruning, "\n");
715   for (auto& pair : user_cfg.toggle_config) {
716     if (pair.first == "debug_stripper" ||
717         pair.first == "auto_mixed_precision" ||
718         pair.first == "auto_mixed_precision_onednn_bfloat16" ||
719         pair.first == "auto_mixed_precision_mkl" ||
720         pair.first == "auto_mixed_precision_cpu" ||
721         pair.first == "pin_to_host_optimization" ||
722         pair.first == "scoped_allocator_optimization") {
723       // These optimizers are turned off by default.
724       // TODO(penporn): Remove the hard-coded length and change it to max length
725       // of all option strings.
726       strings::StrAppend(
727           &logs, pair.first, string(40 - pair.first.size(), ' '),
728           (pair.second == RewriterConfig::ON), "\t\t",
729           (plugin_cfg.toggle_config[pair.first] == RewriterConfig::ON), "\t\t",
730           (final_cfg.toggle_config[pair.first] == RewriterConfig::ON), "\n");
731     } else {
732       // These optimizers are turned on by default.
733       // TODO(penporn): Remove the hard-coded length and change it to max length
734       // of all option strings.
735       strings::StrAppend(
736           &logs, pair.first, string(40 - pair.first.size(), ' '),
737           (pair.second != RewriterConfig::OFF), "\t\t",
738           (plugin_cfg.toggle_config[pair.first] != RewriterConfig::OFF), "\t\t",
739           (final_cfg.toggle_config[pair.first] != RewriterConfig::OFF), "\n");
740     }
741   }
742   LOG(WARNING) << "User's config has been changed based on plugin's config.";
743   LOG(WARNING) << logs;
744 }
745 
746 Status MetaOptimizer::OptimizeGraph(
747     const std::vector<std::unique_ptr<GraphOptimizer>>& optimizers,
748     Cluster* cluster, GrapplerItem&& item, GraphDef* optimized_graph) {
749   int min_graph_nodes = cfg_.min_graph_nodes() == 0 ? kDefaultMinGraphNodes
750                                                     : cfg_.min_graph_nodes();
751   if (item.graph.node_size() < min_graph_nodes) {
752     VLOG(3) << "Skipping optimization, graph has less than " << min_graph_nodes
753             << " nodes.";
754     *optimized_graph = item.graph;
755     return OkStatus();
756   }
757 
758   tensorflow::metrics::ScopedCounter<2> timings(
759       tensorflow::metrics::GetGraphOptimizationCounter(),
760       {kGrapplerCategory, "OptimizeMainGraph"});
761 
762   // Initialize the configured verifiers.
763   std::vector<std::unique_ptr<GraphVerifier>> inter_optimizer_verifiers;
764   std::vector<std::unique_ptr<GraphVerifier>> post_optimization_verifiers;
765   InitializeVerifiers(&inter_optimizer_verifiers, &post_optimization_verifiers);
766   if (inter_optimizer_verifiers.empty()) {
767     VLOG(2) << "No inter optimizer verifiers have been configured";
768   } else {
769     VLOG(2) << inter_optimizer_verifiers.size()
770             << " inter optimizer verifiers have been configured";
771   }
772   if (post_optimization_verifiers.empty()) {
773     VLOG(2) << "No post optimization verifiers have been configured";
774   } else {
775     VLOG(2) << post_optimization_verifiers.size()
776             << " post optimization verifiers have been configured";
777   }
778 
779   VLOG(2) << "Optimize GrapplerItem: item.id=" << item.id
780           << " num_optimizers=" << optimizers.size()
781           << ", num nodes = " << item.graph.node_size();
782 
783   if (optimizers.empty()) {
784     VLOG(3) << "Skipping graph optimization, no optimizers registered";
785     *optimized_graph = item.graph;
786     return OkStatus();
787   }
788 
789   // Invariant: optimized_graph contains the most recently optimized version of
790   // the graph.
791   auto original_producer = item.graph.versions().producer();
792   *optimized_graph = std::move(item.graph);
793 
794   GraphOptimizationResult optimization_result(item.id);
795 #ifndef ENABLE_MKL
796   GraphOptimizer* sa_optimizer = nullptr;
797 #endif
798 
799   // Constants in the graph are normally compressed after model_pruner.
800   // Do it here if model pruner is disabled.
801   if (cfg_.disable_model_pruning()) {
802     CompressConstants(optimized_graph);
803   }
804 
805   for (int iteration = 0; iteration < NumIterations(cfg_); ++iteration) {
806     // Don't bother optimizing further if the graph is already tiny.
807     if (optimized_graph->node_size() < min_graph_nodes) {
808       VLOG(3) << "Stopping after iteration " << iteration
809               << ", graph is tiny (#nodes = " << optimized_graph->node_size()
810               << "  < " << min_graph_nodes << ")";
811       break;
812     }
813 
814     VLOG(4) << "Starting optimization iteration " << iteration;
815     if (VLOG_IS_ON(4)) {
816       DumpGraphDefToFile(
817           strings::StrCat("before_MetaOptimizer_iteration_", iteration, "_",
818                           reinterpret_cast<uintptr_t>(optimized_graph)),
819           *optimized_graph);
820     }
821 
822     for (const auto& optimizer : optimizers) {
823       GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
824       // Some optimizers can run only once.
825       if (iteration > 0 && IsRunOnceOptimizer(optimizer->name())) continue;
826 #ifndef ENABLE_MKL
827       // Some must run only on the last iteration.
828       if (optimizer->name() == "scoped_allocator_optimizer") {
829         if (sa_optimizer == nullptr) sa_optimizer = optimizer.get();
830         continue;
831       }
832 #endif
833 
834       TF_RETURN_IF_ERROR(RunOptimizer(optimizer.get(), cluster, &item,
835                                       optimized_graph, &optimization_result));
836 
837       if (iteration == 0 && optimizer->name() == "model_pruner") {
838         CompressConstants(optimized_graph);
839       }
840 
841       if (VLOG_IS_ON(4)) {
842         DumpGraphDefToFile(
843             strings::StrCat("after_MetaOptimizer_iteration_", iteration, "_",
844                             optimizer->name(), "_",
845                             reinterpret_cast<uintptr_t>(optimized_graph)),
846             *optimized_graph);
847       }
848       for (const auto& verifier : inter_optimizer_verifiers) {
849         // TODO(ashwinm): Need to enforce verification_deadline.
850         TF_RETURN_IF_ERROR(verifier->Verify(*optimized_graph));
851       }
852     }
853     if (VLOG_IS_ON(4)) {
854       DumpGraphDefToFile(
855           strings::StrCat("after_MetaOptimizer_iteration_", iteration, "_",
856                           reinterpret_cast<uintptr_t>(optimized_graph)),
857           *optimized_graph);
858     }
859     // TODO(ashwinm): Need to enforce verification_deadline.
860     for (const auto& verifier : post_optimization_verifiers) {
861       TF_RETURN_IF_ERROR(verifier->Verify(*optimized_graph));
862     }
863   }
864 #ifndef ENABLE_MKL
865   // ScopedAllocatorOptimizer must run last.
866   if (sa_optimizer != nullptr) {
867     TF_RETURN_IF_ERROR(RunOptimizer(sa_optimizer, cluster, &item,
868                                     optimized_graph, &optimization_result));
869     GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
870   }
871 #endif
872 
873   bool is_optimized = std::find_if(optimization_result.results.begin(),
874                                    optimization_result.results.end(),
875                                    [](const OptimizerResult& result) {
876                                      return result.status.ok();
877                                    }) != optimization_result.results.end();
878 
879   // Record graph optimization result.
880   optimization_results_.push_back(optimization_result);
881 
882   if (is_optimized) {
883     TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph));
884     ReassignColocation(optimized_graph);
885     // Make sure that the optimizers preserved the graph version.
886     DCHECK_EQ(optimized_graph->versions().producer(), original_producer);
887   }
888 
889   return OkStatus();
890 }
891 
892 Status MetaOptimizer::OptimizeGraph(Cluster* cluster, GrapplerItem&& item,
893                                     GraphDef* optimized_graph) {
894   std::vector<std::unique_ptr<GraphOptimizer>> optimizers;
895   std::set<std::string> device_types;
896   TF_RETURN_IF_ERROR(GetGraphDevice(item.graph, &device_types));
897   if (cfg_.optimizers().empty()) {
898     TF_RETURN_IF_ERROR(InitializeOptimizers(device_types, &optimizers));
899   } else {
900     TF_RETURN_IF_ERROR(InitializeOptimizersByName(device_types, &optimizers));
901   }
902   PrintUserAndPluginConfigs(device_types);
903 
904   return OptimizeGraph(std::move(optimizers), cluster, std::move(item),
905                        optimized_graph);
906 }
907 
908 Status MetaOptimizer::RunOptimizer(
909     GraphOptimizer* optimizer, Cluster* cluster, GrapplerItem* optimized_item,
910     GraphDef* optimized_graph, GraphOptimizationResult* optimization_result) {
911   // If the optimizer doesn't need a function library, we will replace it with a
912   // stub before running optimization, and will put it back at the end.
913   std::unique_ptr<FunctionDefLibrary> optimized_graph_function_library;
914   const bool is_function_library_aware = optimizer->UsesFunctionLibrary();
915 
916   // Replace function library in optimized graph with a stub.
917   if (!is_function_library_aware) {
918     VLOG(3) << "Replace function library with a stub for " << optimizer->name();
919     optimized_graph_function_library =
920         absl::WrapUnique(optimized_graph->release_library());
921     *optimized_graph->mutable_library() =
922         GetFunctionDefLibraryStub(*optimized_graph_function_library);
923   }
924 
925   // This swaps the current optimized_graph into optimized item and
926   // resets optimized_graph to an empty graph.
927   optimized_item->graph = std::move(*optimized_graph);
928   *optimized_graph = GraphDef();
929   optimizer->set_deadline_usec(this->deadline_usec());
930   tensorflow::metrics::ScopedCounter<2> timings(
931       tensorflow::metrics::GetGraphOptimizationCounter(),
932       {kGrapplerCategory, optimizer->name()});
933   Status status =
934       optimizer->Optimize(cluster, *optimized_item, optimized_graph);
935   auto duration_ms = timings.DurationMicroSec().value() / 1000.0f;
936   timings.ReportAndStop();
937 
938   string message;
939   if (!status.ok()) {
940     *optimized_graph = std::move(optimized_item->graph);
941     if (errors::IsAborted(status)) {
942       // By convention we (ab-)use the Aborted error code to signal that the
943       // optimizer returned without performing any changes to the graph.
944       message = strings::StrCat(optimizer->name(),
945                                 " did nothing. time = ", duration_ms, "ms.");
946       // Swallow the non-critical error.
947       status = OkStatus();
948     } else if (errors::IsDeadlineExceeded(status)) {
949       message =
950           strings::StrCat(status.ToString(), ", time = ", duration_ms, "ms.");
951       LOG(WARNING) << optimizer->name() << " failed: " << message;
952     } else {
953       message = status.ToString();
954       LOG(ERROR) << optimizer->name() << " failed: " << message;
955     }
956   } else {
957     message = strings::StrCat(
958         PrintSizesBeforeAfter(optimized_item->graph, *optimized_graph),
959         ", time = ", duration_ms, "ms.");
960     VLOG(1) << optimizer->name() << ": " << message;
961   }
962 
963   // Swap function library back into the main graph.
964   if (!is_function_library_aware) {
965     optimized_graph->set_allocated_library(
966         optimized_graph_function_library.release());
967   }
968 
969   OptimizerResult optimizer_result{optimizer->name(), message, status};
970   optimization_result->results.push_back(optimizer_result);
971 
972   if (!status.ok()) {
973     if (cfg_.fail_on_optimizer_errors()) return status;
974 
975     // Non-aborted failures in the TFG optimizer are always fatal.
976     if (absl::StartsWith(optimizer->name(), "tfg_optimizer")) return status;
977   }
978 
979   return OkStatus();
980 }
981 
982 // Propagates `_tf_data_function` attributes from functions to their callees.
983 void PropagateTFDataAttrs(const FunctionLibraryDefinition& flib,
984                           FunctionDefLibrary& fdef_lib) {
985   // Collect functions that need the attribute in this set.
986   absl::flat_hash_set<std::string> tf_data_functions;
987   std::function<void(const std::string&)> collect_tf_data_functions_dfs =
988       [&](const std::string& func_name) -> void {
989     const FunctionDef* func_def = flib.Find(func_name);
990     // Skip functions that are not reachable from the optimized graph.
991     if (func_def == nullptr) return;
992 
993     // Return if we already found and added this function.
994     if (tf_data_functions.contains(func_name)) return;
995 
996     // We only get here if the function is (directly or indirectly) called from
997     // a tf.data function, so add it to the set.
998     tf_data_functions.insert(func_name);
999 
1000     // Proceed with DFS for functions called from current function.
1001     for (const NodeDef& node : func_def->node_def()) {
1002       if (flib.Contains(node.op())) {
1003         // This is a function call node.
1004         collect_tf_data_functions_dfs(node.op());
1005       }
1006       // Check if there are functions in attributes.
1007       for (const auto& attr : node.attr()) {
1008         const AttrValue& attr_value = attr.second;
1009         if (attr_value.has_func()) {
1010           collect_tf_data_functions_dfs(attr_value.func().name());
1011         }
1012         if (attr_value.has_list()) {
1013           for (const auto& func : attr_value.list().func()) {
1014             collect_tf_data_functions_dfs(func.name());
1015           }
1016         }
1017       }
1018     }
1019   };
1020   // Perform DFS for all tf.data functions in `fdef_lib`.
1021   for (const auto& func_def : fdef_lib.function()) {
1022     const std::string& func_name = func_def.signature().name();
1023     if (data::IsTFDataFunction(func_def))
1024       collect_tf_data_functions_dfs(func_name);
1025   }
1026   // Set attribute for tf.data functions. We cannot do this in the DFS directly
1027   // because `FunctionLibraryDefinition` does not seem to provide mutable access
1028   // to a `FunctionDef`.
1029   for (FunctionDef& func_def : *fdef_lib.mutable_function()) {
1030     const std::string& func_name = func_def.signature().name();
1031     if (tf_data_functions.contains(func_name) &&
1032         !data::IsTFDataFunction(func_def)) {
1033       VLOG(2) << "Marking " << func_name << " as tf.data function";
1034       (*func_def.mutable_attr())[data::kTFDataFunction].set_b(true);
1035     }
1036   }
1037 }
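// Hypothetical example (names invented for illustration): if a tf.data
// function `map_fn` calls a helper `normalize` through a `func` attribute, the
// DFS above adds `normalize` to the set, and the loop above then sets the
// `_tf_data_function` attribute on its FunctionDef as well.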
1038 
1039 Status MetaOptimizer::OptimizeConsumeItem(Cluster* cluster, GrapplerItem&& item,
1040                                           GraphDef* optimized_graph) {
1041   tensorflow::metrics::ScopedCounter<2> timings(
1042       tensorflow::metrics::GetGraphOptimizationCounter(),
1043       {kGrapplerCategory, "*"});
1044 
1045   VLOG(1) << "Starting optimization for grappler item: " << item.id;
1046   optimization_results_.clear();
1047 
1048   // Constructs a FunctionLibraryDefinition with functions that are reachable
1049   // from the nodes of the graph.
1050   const auto minimized_flib =
1051       [](const GraphDef& graph) -> FunctionLibraryDefinition {
1052     return FunctionLibraryDefinition(OpRegistry::Global(), graph.library())
1053         .ReachableDefinitions(graph);
1054   };
1055 
1056   // 0. The original graph might contain a huge function library that is mostly
1057   // unused. This library is copied over by each individual Grappler optimizer,
1058   // which adds a huge overhead. Before starting optimization passes we just
1059   // remove all the unreachable functions.
1060   // TODO(ezhulenev): Construct reachable function library definition directly
1061   // from the proto without constructing temporary FunctionLibraryDefinition.
1062   int old_library_size = item.graph.library().function_size();
1063   *item.graph.mutable_library() = minimized_flib(item.graph).ToProto();
1064   int new_library_size = item.graph.library().function_size();
1065 
1066   VLOG(1) << absl::Substitute(
1067       "Deleted $0 unreachable functions from the graph (library size = $1)",
1068       old_library_size - new_library_size, new_library_size);
1069 
1070   // Save a few small fields from item before we move it.
1071   bool optimize_function_library =
1072       item.optimization_options().optimize_function_library;
1073   const auto producer = item.graph.versions().producer();
1074 
1075   // 1. Optimize main graph
1076   TF_RETURN_IF_ERROR(
1077       OptimizeGraph(cluster, GrapplerItem(item), optimized_graph));
1078   VLOG(1) << "Optimized main graph.";
1079   GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
1080 
1081   // 2. Optimize functions reachable from the optimized graph.
1082   FunctionLibraryDefinition flib = minimized_flib(*optimized_graph);
1083   using NodeDefs = protobuf::RepeatedPtrField<NodeDef>;
1084 
1085   // Find functions for which we might need to compute a gradient at runtime.
1086   absl::flat_hash_set<string> differentiable_functions;
1087 
1088   const auto find_differentiable_functions =
1089       [&](const NodeDefs& nodes) -> void {
1090     for (const NodeDef& node : nodes) {
1091       if (IsSymbolicGradient(node)) {
1092         const auto* f_attr = gtl::FindOrNull(node.attr(), "f");
1093         if (f_attr) differentiable_functions.insert(f_attr->func().name());
1094       }
1095     }
1096   };
1097 
1098   // SymbolicGradient nodes inside the main graph.
1099   find_differentiable_functions(optimized_graph->node());
1100   // SymbolicGradient nodes inside the function library.
1101   for (const FunctionDef& function : optimized_graph->library().function()) {
1102     find_differentiable_functions(function.node_def());
1103   }
1104 
1105   // Find functions that will be compiled by XLA later.
1106   // We do this by looking for XlaLaunch ops that call functions, then doing a
1107   // depth-first search down those functions to find transitively called functions.
1108   // Grappler rewrites can potentially add nodes that are
1109   // not supported by XLA, so we choose to skip such functions when we optimize
1110   // the function library.
1111   absl::flat_hash_set<string> xla_compiled_functions;
1112   std::function<void(const string&)> find_all_functions;
1113   find_all_functions = [&](const string& func) -> void {
1114     // Ignore call cycles in the graph
1115     if (xla_compiled_functions.contains(func)) return;
1116     // Find func in the flib
1117     const FunctionDef* func_def = flib.Find(func);
1118     CHECK(func_def) << "not found: " << func;
1119     // Mark function to be ignored by grappler
1120     xla_compiled_functions.insert(func);
1121     // Depth first search through the func for transitively called funcs
1122     for (const NodeDef& node : func_def->node_def()) {
1123       for (const auto& attr : node.attr()) {
1124         const AttrValue& attr_value = attr.second;
1125         if (attr_value.has_func()) {
1126           find_all_functions(attr_value.func().name());
1127         }
1128       }
1129     }
1130   };
1131 
1132   auto find_xla_compiled_functions = [&](const NodeDefs& nodes) -> void {
1133     NameAttrList function;
1134     for (const NodeDef& node : nodes) {
1135       // Look only for XlaLaunch nodes that call a function
1136       if (!IsXlaLaunch(node)) continue;
1137       if (!GetNodeAttr(node, "function", &function).ok()) continue;
1138       // Find all transitively called functions
1139       find_all_functions(function.name());
1140     }
1141   };
1142 
1143   // XlaLaunch ops inside the main graph ...
1144   find_xla_compiled_functions(optimized_graph->node());
1145   // ... and inside the function library.
1146   for (const FunctionDef& function : optimized_graph->library().function()) {
1147     find_xla_compiled_functions(function.node_def());
1148   }
1149   // Propagate `_tf_data_function` attributes from functions to their callees.
1150   PropagateTFDataAttrs(flib, *optimized_graph->mutable_library());
1151 
1152   // True if this is a TPU graph using the old bridge.
1153   bool is_tpu_graph = IsLegacyTPUBridgeGraphDef(*optimized_graph);
1154 
1155   // Optimize each function only once.
1156   absl::flat_hash_set<string> optimized_funcs;
1157   while (optimize_function_library) {
1158     optimize_function_library = false;
1159 
1160     int function_idx = 0;
1161     for (const FunctionDef& func : optimized_graph->library().function()) {
1162       GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
1163 
1164       const string& func_name = func.signature().name();
1165 
1166       // Skip functions that are not reachable from the optimized graph.
1167       if (!flib.Contains(func_name)) continue;
1168       // Skip already optimized functions.
1169       if (optimized_funcs.contains(func_name)) continue;
1170       // Skip functions that will be compiled by XLA.
1171       if (xla_compiled_functions.contains(func_name)) continue;
1172 
1173       // Skip parametrized functions (function type or body is defined only at
1174       // function call time by caller node attributes).
1175       // They should be specialized to their instantiation type parameters by
1176       // the function optimizer before we can optimize the function body.
1177       if (IsParametrized(func)) continue;
1178 
1179       // Skip tf.data functions as they are optimized by tf.data meta optimizer
1180       // and in function instantiation.
1181       if (data::IsTFDataFunction(func)) continue;
1182 
1183       VLOG(3) << "Optimize function: function=" << func_name << " ["
1184               << function_idx++ << " of "
1185               << optimized_graph->library().function_size() << "]";
1186 
1187       // Function optimization might specialize nested function calls, so we
1188       // have to reset the flag and do at least one more pass over the library.
1189       optimize_function_library = true;
1190       optimized_funcs.insert(func_name);
1191 
1192       // Make a GrapplerItem from a FunctionDef.
1193       GrapplerFunctionItem func_item;
1194       TF_RETURN_IF_ERROR(
1195           MakeGrapplerFunctionItem(func, flib, producer, &func_item));
1196 
1197       // If we need to compute the gradient of the optimized function at
1198       // runtime, we can't perform non-differentiable rewrites.
1199       func_item.optimization_options().allow_non_differentiable_rewrites =
1200           !differentiable_functions.contains(func_name);
1201 
1202       // The device set available to the function is defined only by the runtime,
1203       // when we instantiate and execute the function. We can't use all devices
1204       // available to the main graph, because after partitioning the function
1205       // call node might execute on a remote worker.
1206       if (!func_item.devices().empty()) {
1207         return errors::Internal("GrapplerFunctionItem devices must be empty.");
1208       }
1209 
1210       // We are not allowed to prune certain types of ops from the graph
1211       // instantiated by the function definition, because we must guarantee
1212       // function execution semantics wrt side effects (see
1213       // function_optimizer.cc).
1214       func_item.optimization_options().allow_pruning_stateful_and_dataset_ops =
1215           false;
1216 
1217       // Optimize function body graph.
1218       GraphDef optimized_func_graph;
1219       if (is_tpu_graph) {
1220         // Skip optimizing functions if this is a TPU graph. Currently, Grappler
1221         // passes do not handle TPU functions correctly in a variety of ways
1222         // (Note that due to the pre-placement TPU graph rewriting passes, the
1223         // TPU-related ops are encapsulated away into functions). For example,
1224         // TPU graphs contain TPUReplicateMetadata node that carries relevant
1225         // TPU metadata and Grappler passes could prune that away. Grappler
1226         // passes could also cause issues around shape inference. Since the
1227         // desired and existing behavior is to not optimize TPU functions with
1228         // Grappler, this check preserves that. The only exception is the
1229         // implementation selector, which is required to swap in some TPU-specific
1230         // lowering code and has been verified to work correctly on TPUs.
1231         ImplementationSelector implementation_selector;
1232 
1233         // Implementation selector needs access to a valid function signature
1234         // and attributes; it doesn't need the actual function body.
1235         std::unique_ptr<FunctionDefLibrary> func_item_function_library(
1236             func_item.graph.release_library());
1237         *func_item.graph.mutable_library() =
1238             GetFunctionDefLibraryStub(*func_item_function_library);
1239 
1240         TF_RETURN_IF_ERROR(implementation_selector.Optimize(
1241             cluster, func_item, &optimized_func_graph));
1242       } else {
1243         GrapplerFunctionItem func_item_copy = func_item;
1244         TF_RETURN_IF_ERROR(OptimizeGraph(cluster, std::move(func_item_copy),
1245                                          &optimized_func_graph));
1246       }
1247 
1248       // Function body optimization might have created new specialized
1249       // functions for each instantiation context. Add them to the library.
1250       for (const FunctionDef& func_def :
1251            optimized_func_graph.library().function()) {
1252         if (flib.Find(func_def.signature().name()) == nullptr) {
1253           TF_RETURN_IF_ERROR(flib.AddFunctionDef(func_def));
1254         }
1255       }
1256 
1257       // Convert optimized graph back to FunctionDef.
1258       FunctionDef optimized_func;
1259       func_item.SwapFunctionBody(std::move(optimized_func_graph));
1260       TF_RETURN_IF_ERROR(MakeFunctionDef(func_item, flib, &optimized_func));
1261 
1262       // Replace optimized function with a new FunctionDef.
1263       TF_RETURN_IF_ERROR(flib.ReplaceFunction(func_name, optimized_func));
1264     }
1265 
1266     // If optimized at least one function, update the graph library.
1267     if (optimize_function_library) {
1268       *optimized_graph->mutable_library() = flib.ToProto();
1269     }
1270   }
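// Illustrative iteration (hypothetical names): optimizing `outer_fn` on pass 1
// may specialize a nested call and add `inner_fn_specialized` to the library;
// optimize_function_library is reset to true, so pass 2 visits the newly added
// function, and the loop stops once a pass optimizes nothing new.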
1271 
1272   // Run module-level TFG optimizations at the end of the meta-optimizer.
1273   // TODO(jeffniu): None of the TFG optimizations are meant to create new
1274   // opportunities for other optimizers; they could, but it's unclear whether
1275   // re-running all the other optimizers is worthwhile.
1276 #ifndef __Fuchsia__
1277   {
1278     // Create a Grappler optimization pipeline with only the TFG optimizer.
1279     std::vector<std::unique_ptr<GraphOptimizer>> optimizers;
1280     optimizers.push_back(std::make_unique<mlir::tfg::TFGGrapplerOptimizer>(
1281         // For module-level optimizations, use multithreading to process
1282         // functions in parallel.
1283         [&](mlir::PassManager& manager) {
1284           mlir::tfg::DefaultModuleGrapplerPipeline(manager, cfg_);
1285         },
1286         /*num_tfg_threads=*/4));
1287     // Wrap the optimized GraphDef in a new GrapplerItem with copied
1288     // configuration options from the provided item.
1289     GrapplerItem tfg_item = item.WithGraph(std::move(*optimized_graph));
1290     // Invoke the optimizers.
1291     *optimized_graph = GraphDef();
1292     TF_RETURN_IF_ERROR(OptimizeGraph(optimizers, cluster, std::move(tfg_item),
1293                                      optimized_graph));
1294   }
1295 #endif
1296 
1297   VLOG(1) << "Optimized " << optimized_funcs.size()
1298           << " functions: " << absl::StrJoin(optimized_funcs, ", ");
1299   VLOG(3) << "Optimized graph =\n" << optimized_graph->DebugString();
1300   if (VLOG_IS_ON(1)) {
1301     DumpGraphDefToFile(
1302         strings::StrCat("after_MetaOptimizer_",
1303                         reinterpret_cast<uintptr_t>(optimized_graph)),
1304         *optimized_graph);
1305   }
1306 
1307   return OkStatus();
1308 }
1309 
1310 string MetaOptimizer::GetResultString() const {
1311   std::string result_string;
1312   for (const GraphOptimizationResult& graph_result : optimization_results_) {
1313     absl::StrAppend(&result_string,
1314                     "Optimization results for grappler item: ", graph_result.id,
1315                     "\n");
1316     for (const OptimizerResult& result : graph_result.results) {
1317       absl::StrAppend(&result_string, "  ", result.optimizer_name, ": ",
1318                       result.message, "\n");
1319     }
1320   }
1321   return result_string;
1322 }
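// For reference, the string assembled above has the following shape (item id,
// optimizer names, and messages are placeholders):
//
//   Optimization results for grappler item: <item id>
//     <optimizer name>: <result message>
//     <optimizer name>: <result message>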
1323 
1324 void MetaOptimizer::PrintResult() { VLOG(1) << GetResultString(); }
1325 
1326 bool MetaOptimizerEnabled(const ConfigProto& cfg) {
1327   const auto& rewrite_cfg = cfg.graph_options().rewrite_options();
1328   if (rewrite_cfg.disable_meta_optimizer()) {
1329     return false;
1330   }
1331   return !rewrite_cfg.disable_model_pruning() ||
1332          rewrite_cfg.layout_optimizer() != RewriterConfig::OFF ||
1333          rewrite_cfg.function_optimization() != RewriterConfig::OFF ||
1334          rewrite_cfg.constant_folding() != RewriterConfig::OFF ||
1335          rewrite_cfg.shape_optimization() != RewriterConfig::OFF ||
1336          rewrite_cfg.remapping() != RewriterConfig::OFF ||
1337          rewrite_cfg.common_subgraph_elimination() != RewriterConfig::OFF ||
1338          rewrite_cfg.arithmetic_optimization() != RewriterConfig::OFF ||
1339          rewrite_cfg.loop_optimization() != RewriterConfig::OFF ||
1340          rewrite_cfg.dependency_optimization() != RewriterConfig::OFF ||
1341          rewrite_cfg.auto_parallel().enable() ||
1342          rewrite_cfg.memory_optimization() != RewriterConfig::NO_MEM_OPT ||
1343          rewrite_cfg.debug_stripper() == RewriterConfig::ON ||
1344 #ifndef ENABLE_MKL
1345          rewrite_cfg.scoped_allocator_optimization() == RewriterConfig::ON ||
1346 #endif
1347          rewrite_cfg.pin_to_host_optimization() == RewriterConfig::ON ||
1348          AutoMixedPrecisionEnabled(rewrite_cfg.auto_mixed_precision()) ||
1349          AutoMixedPrecisionEnabled(
1350              rewrite_cfg.auto_mixed_precision_onednn_bfloat16()) ||
1351          AutoMixedPrecisionEnabled(rewrite_cfg.auto_mixed_precision_mkl()) ||
1352          AutoMixedPrecisionEnabled(rewrite_cfg.auto_mixed_precision_cpu()) ||
1353          !rewrite_cfg.optimizers().empty() ||
1354          !rewrite_cfg.custom_optimizers().empty();
1355 }
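// Illustrative caller-side check (not part of this file): the meta optimizer
// can be turned off entirely through the RewriterConfig, e.g.
//
//   ConfigProto cfg;
//   cfg.mutable_graph_options()->mutable_rewrite_options()
//       ->set_disable_meta_optimizer(true);
//   // MetaOptimizerEnabled(cfg) now returns false.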
1356 
1357 Status RunMetaOptimizer(GrapplerItem&& item, const ConfigProto& cfg,
1358                         DeviceBase* cpu_device, Cluster* cluster,
1359                         GraphDef* optimized_graph) {
1360   MetaOptimizer optimizer(cpu_device, cfg);
1361   optimizer.set_deadline_usec(
1362       DeadlineMicroSeconds(cfg.graph_options().rewrite_options()));
1363   return optimizer.OptimizeConsumeItem(cluster, std::move(item),
1364                                        optimized_graph);
1365 }
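// Illustrative call site (assumed context; `item`, `config`, `cpu_device`, and
// `cluster` are placeholders provided by the caller):
//
//   GraphDef optimized_graph_def;
//   TF_RETURN_IF_ERROR(RunMetaOptimizer(std::move(item), config, cpu_device,
//                                       cluster, &optimized_graph_def));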
1366 
1367 Status OptimizeGraph(
1368     std::vector<string> ret_node_names, std::vector<string> keep_node_names,
1369     FunctionLibraryDefinition* flib, const DeviceSet& device_set,
1370     Device* cpu_device, const ConfigProto& config_proto,
1371     const string& grappler_item_id,
1372     const GrapplerItem::OptimizationOptions& optimization_options,
1373     std::unique_ptr<tensorflow::Graph>* g) {
1374   if (!tensorflow::grappler::MetaOptimizerEnabled(config_proto)) {
1375     return OkStatus();
1376   }
1377 
1378   tensorflow::grappler::GrapplerItem item;
1379   item.id = grappler_item_id;
1380   item.optimization_options() = optimization_options;
1381 
1382   // Add all available devices so that inlined functions can be placed.
1383   for (const Device* d : device_set.devices()) {
1384     Status added_device = item.AddDevice(d->name());
1385     if (!added_device.ok()) VLOG(3) << added_device.error_message();
1386   }
1387   VLOG(3) << "Grappler available devices: "
1388           << absl::StrJoin(item.devices(), ", ");
1389 
1390   // Add fetches so that the graph can be pruned.
1391   item.fetch.swap(ret_node_names);
1392 
1393   // Add nodes that can't be removed from the graph.
1394   item.keep_ops = std::move(keep_node_names);
1395 
1396   (*g)->ToGraphDef(&item.graph);
1397 
1398   if (flib) {
1399     *item.graph.mutable_library() = flib->ToProto();
1400   }
1401 
1402   tensorflow::GraphDef out_graph;
1403   tensorflow::grappler::VirtualCluster cluster(&device_set);
1404   // TODO(nareshmodi): Consider adding and using the more generic GraphOptions
1405   // proto (which also contains the OptimizerOptions).
1406   TF_RETURN_IF_ERROR(tensorflow::grappler::RunMetaOptimizer(
1407       std::move(item), config_proto, cpu_device, &cluster, &out_graph));
1408 
1409   std::unique_ptr<tensorflow::Graph> optimized_graph(
1410       new tensorflow::Graph(OpRegistry::Global()));
1411 
1412   // Copy optimized functions back to the overlay lib.
1413   if (flib) {
1414     for (const FunctionDef& fdef : out_graph.library().function()) {
1415       const string& func_name = fdef.signature().name();
1416       if (flib->Contains(func_name)) {
1417         StackTracesMap stack_traces = flib->GetStackTraces(func_name);
1418         TF_RETURN_IF_ERROR(
1419             flib->ReplaceFunction(func_name, fdef, stack_traces));
1420       } else {
1421         TF_RETURN_IF_ERROR(
1422             flib->AddFunctionDef(fdef, flib->GetStackTraces(func_name)));
1423       }
1424     }
1425   }
1426 
1427   TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
1428       GraphConstructorOptions(), std::move(out_graph), optimized_graph.get()));
1429 
1430   // The graph conversion sets the requested device names but not the
1431   // assigned device names. However, since at this point the graph is
1432   // placed TF expects an assigned device name for every node. Therefore
1433   // we copy the requested device into the assigned device field.
1434   for (Node* node : optimized_graph->nodes()) {
1435     if (node->IsOp() && node->assigned_device_name().empty()) {
1436       if (node->requested_device().empty()) {
1437         return errors::Internal(
1438             "Either placer did not place the node or Grappler did not "
1439             "copy the assigned device. Contact Grappler team since latter "
1440             "is more likely. Node=",
1441             node->name(),
1442             " Graph: ", optimized_graph->ToGraphDefDebug().DebugString());
1443       }
1444       node->set_assigned_device_name(node->requested_device());
1445     }
1446   }
1447 
1448   *g = std::move(optimized_graph);
1449   return OkStatus();
1450 }
1451 
1452 }  // namespace grappler
1453 }  // namespace tensorflow
1454