/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"

#include <algorithm>
#include <functional>
#include <string>
#include <utility>

#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/substitute.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
#include "tensorflow/core/grappler/optimizers/auto_mixed_precision.h"
#include "tensorflow/core/grappler/optimizers/auto_parallel.h"
#include "tensorflow/core/grappler/optimizers/common_subgraph_elimination.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/optimizers/debug_stripper.h"
#include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
#include "tensorflow/core/grappler/optimizers/function_optimizer.h"
#include "tensorflow/core/grappler/optimizers/generic_layout_optimizer.h"
#include "tensorflow/core/grappler/optimizers/implementation_selector.h"
#include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
#include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
#include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"
#include "tensorflow/core/grappler/optimizers/remapper.h"
#include "tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h"
#include "tensorflow/core/grappler/optimizers/shape_optimizer.h"
#include "tensorflow/core/grappler/utils/canonicalizer.h"
#include "tensorflow/core/grappler/utils/colocation.h"
#include "tensorflow/core/grappler/utils/functions.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/grappler/utils/tpu.h"
#include "tensorflow/core/grappler/verifiers/structure_verifier.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/dump_graph.h"
#include "tensorflow/core/util/ptr_util.h"
#include "tensorflow/core/util/util.h"
#include "tensorflow/core/util/xla_config_registry.h"

// #TODO(b/200087693): LLVM does not build on Fuchsia.
#ifndef __Fuchsia__
#include "tensorflow/core/grappler/optimizers/tfg_optimizer_hook.h"
#include "tensorflow/core/grappler/optimizers/tfg_passes_builder.h"
#endif

namespace tensorflow {
namespace grappler {

namespace {

constexpr int kDefaultNumberOfIterations = 2;
constexpr int kDefaultMinGraphNodes = 4;
constexpr char kGrapplerCategory[] = "Grappler";

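// Returns the number of edges in the graph, counted as the sum of the input
// counts over all nodes.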
int64_t NumEdges(const GraphDef& graph) {
  int64_t num_edges = 0;
  for (const auto& node : graph.node()) {
    num_edges += node.input_size();
  }
  return num_edges;
}

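// Formats a human-readable summary of the graph size after an optimizer pass,
// including node and edge deltas relative to the input graph.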
string PrintSizesBeforeAfter(const GraphDef& before, const GraphDef& after) {
  return strings::StrCat("Graph size after: ", after.node_size(), " nodes (",
                         after.node_size() - before.node_size(), "), ",
                         NumEdges(after), " edges (",
                         NumEdges(after) - NumEdges(before), ")");
}

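// Returns the number of meta-optimizer iterations to run, substituting the
// default when the config requests DEFAULT_NUM_ITERS.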
int NumIterations(const RewriterConfig& cfg) {
  return cfg.meta_optimizer_iterations() == RewriterConfig::DEFAULT_NUM_ITERS
             ? kDefaultNumberOfIterations
             : cfg.meta_optimizer_iterations();
}

// Checks if the optimizer is only allowed to run once.
bool IsRunOnceOptimizer(const string& name) {
  return name == "layout" || name == "memory_optimizer" ||
         name == "loop_optimizer" ||
         absl::StartsWith(name, "auto_mixed_precision");
}

// Creates a function library stub from a real function library: copies only
// the signatures and attributes of all the functions defined in fdef_lib. This
// stub can be swapped in for the real function library in a graph before
// passing it to an optimizer, if the optimizer doesn't instantiate functions.
FunctionDefLibrary GetFunctionDefLibraryStub(
    const FunctionDefLibrary& fdef_lib) {
  FunctionDefLibrary stub;
  for (const FunctionDef& fn : fdef_lib.function()) {
    FunctionDef* fn_stub = stub.mutable_function()->Add();
    *(fn_stub->mutable_signature()) = fn.signature();
    *(fn_stub->mutable_attr()) = fn.attr();
    *(fn_stub->mutable_arg_attr()) = fn.arg_attr();
    *(fn_stub->mutable_resource_arg_unique_id()) = fn.resource_arg_unique_id();
  }
  *stub.mutable_gradient() = fdef_lib.gradient();
  return stub;
}

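// Returns the absolute deadline in microseconds derived from the configured
// meta-optimizer timeout, or 0 when no deadline is set.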
uint64 DeadlineMicroSeconds(const RewriterConfig& cfg) {
  if (cfg.meta_optimizer_timeout_ms() <= 0) return 0;  // no deadline
  return Env::Default()->NowMicros() + cfg.meta_optimizer_timeout_ms() * 1000;
}

// A helper function to decide whether to enable the automatic mixed precision
// optimizer.
bool AutoMixedPrecisionEnabled(RewriterConfig::Toggle opt_level) {
  if (opt_level == RewriterConfig::ON ||
      opt_level == RewriterConfig::AGGRESSIVE) {
    return true;
  } else if (opt_level == RewriterConfig::EXPERIMENTAL_MLIR ||
             opt_level == RewriterConfig::EXPERIMENTAL_BOTH) {
    VLOG(2) << "auto_mixed_precision is not implemented in TFG yet";
  }
  return false;
}

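// Returns true if XLA auto-clustering (global JIT) is enabled in the session
// options for both single-GPU and multi-GPU graphs.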
bool IsXlaGlobalJitOn(
    const OptimizerOptions::GlobalJitLevel& jit_level_in_session_opts) {
  xla_config_registry::XlaGlobalJitLevel xla_global_jit_level =
      xla_config_registry::GetGlobalJitLevel(jit_level_in_session_opts);
  // Return true only if XLA JIT is ON for both single-gpu and multi-gpu
  // graphs. This is a conservative approach that turns off the memory
  // optimizer when we are sure that all graphs will be processed by XLA JIT.
  return xla_global_jit_level.single_gpu >= OptimizerOptions::ON_1 &&
         xla_global_jit_level.general >= OptimizerOptions::ON_1;
}

// A helper function to decide whether to enable the memory optimizer.
bool MemoryOptimizerEnabled(RewriterConfig::MemOptType mem_opt_type,
                            bool xla_auto_clustering_on) {
  // Disable the default memory optimizer when XLA JIT is ON as it hurts the
  // XLA JIT performance. The (current) XLA clustering can result in loss of
  // concurrency between kernel compute and memory copies. As such, it usually
  // loses the concurrency needed to hide the latencies of the inserted
  // swap-ins and swap-outs and incurs great performance overhead. Remove this
  // check when the XLA JIT can better deal with the concurrency.
  if (mem_opt_type == RewriterConfig::DEFAULT_MEM_OPT &&
      xla_auto_clustering_on) {
    return false;
  }

  return mem_opt_type != RewriterConfig::NO_MEM_OPT;
}

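// Collects the device types (e.g. "CPU", "GPU") referenced by the nodes of the
// graph into `devices`.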
Status GetGraphDevice(const GraphDef& g_def, std::set<std::string>* devices) {
  for (auto& node : g_def.node()) {
    DeviceNameUtils::ParsedName parsed_name;
    if (!DeviceNameUtils::ParseFullName(node.device(), &parsed_name)) {
      return errors::InvalidArgument("Unable to parse ", node.device(),
                                     " as a device name");
    }
    devices->insert(parsed_name.type);
  }
  return OkStatus();
}

}  // namespace

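// Helper macro for MakeNewOptimizer: returns VALUE for the requested optimizer
// NAME unless the plugin configuration turned that optimizer off.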
#define MK_OPT(NAME, CONFIG, VALUE)                                      \
  if (optimizer == NAME) {                                               \
    if (plugin_configs.toggle_config[CONFIG] != RewriterConfig::OFF) {   \
      return std::unique_ptr<GraphOptimizer>(VALUE);                     \
    }                                                                    \
  }

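// Whether the function optimizer may lower functional control flow (If/While);
// disabled for the single-threaded executor and for TFRT.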
bool MetaOptimizer::LowerControlFlow() const {
  if (config_proto_.experimental().executor_type() ==
      "SINGLE_THREADED_EXECUTOR")
    return false;

  if (config_proto_.experimental().use_tfrt()) return false;

  return true;
}

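// Creates a default optimizer instance for the given optimizer name, honoring
// the plugin configs; returns nullptr if the name is unknown or disabled.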
std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer(
    const string& optimizer, const std::set<string>& device_types) const {
  ConfigList plugin_configs = PluginGraphOptimizerRegistry::GetPluginConfigs(
      cfg_.use_plugin_optimizers() != RewriterConfig::OFF, device_types);
  if (optimizer == "pruning" && !plugin_configs.disable_model_pruning)
    return std::unique_ptr<GraphOptimizer>(new ModelPruner());
  MK_OPT("function", "function_optimization",
         new FunctionOptimizer(cfg_.function_optimization(),
                               /*lower_control_flow=*/LowerControlFlow()));
  MK_OPT("constfold", "constant_folding",
         new ConstantFolding(
             cpu_device_,
             cfg_.experimental_disable_compressed_tensor_optimization(),
             !cfg_.experimental_disable_folding_quantization_emulation()));
  MK_OPT("shape", "shape_optimization", new ShapeOptimizer());
  MK_OPT("remap", "remapping",
         new Remapper(cfg_.remapping(), cfg_.cpu_layout_conversion(),
                      xla_auto_clustering_on_));
  MK_OPT("layout", "layout_optimizer",
         new GenericLayoutOptimizer(
             /*optimization level*/ cfg_.layout_optimizer(),
             /*CPU layout conversion*/ cfg_.cpu_layout_conversion()));
  MK_OPT("auto_mixed_precision", "auto_mixed_precision",
         new AutoMixedPrecision(AutoMixedPrecisionMode::CUDA));
#ifdef INTEL_MKL
  if (IsMKLEnabled()) {
    MK_OPT("auto_mixed_precision_mkl", "auto_mixed_precision_mkl",
           new AutoMixedPrecision(AutoMixedPrecisionMode::BF16));
    MK_OPT("auto_mixed_precision_onednn_bfloat16",
           "auto_mixed_precision_onednn_bfloat16",
           new AutoMixedPrecision(AutoMixedPrecisionMode::BF16));
  }
#endif
  MK_OPT("auto_mixed_precision_cpu", "auto_mixed_precision_cpu",
         new AutoMixedPrecision(AutoMixedPrecisionMode::CPU));
  MK_OPT("memory", "memory_optimization",
         new MemoryOptimizer(RewriterConfig::MANUAL));
  MK_OPT("common_subgraph_elimination", "common_subgraph_elimination",
         new CommonSubgraphElimination(cfg_.common_subgraph_elimination()));
  MK_OPT("arithmetic", "arithmetic_optimization",
         new ArithmeticOptimizer(cfg_.arithmetic_optimization()));
  MK_OPT("autoparallel", "auto_parallel",
         new AutoParallel(cfg_.auto_parallel().num_replicas()));
  MK_OPT("loop", "loop_optimization",
         new LoopOptimizer(cfg_.loop_optimization(), cpu_device_));
  MK_OPT("dependency", "dependency_optimization",
         new DependencyOptimizer(cfg_.dependency_optimization()));
  MK_OPT("debug_stripper", "debug_stripper", new DebugStripper());
  MK_OPT("scoped_allocator", "scoped_allocator_optimization",
         new ScopedAllocatorOptimizer(cfg_.scoped_allocator_optimization(),
                                      cfg_.scoped_allocator_opts()));
  MK_OPT("pin_to_host", "pin_to_host_optimization",
         new PinToHostOptimizer(cfg_.pin_to_host_optimization()));

  return std::unique_ptr<GraphOptimizer>();
}

#undef MK_OPT

MetaOptimizer::MetaOptimizer(DeviceBase* cpu_device, const ConfigProto& cfg)
    : cpu_device_(cpu_device),
      config_proto_(cfg),
      cfg_(*config_proto_.mutable_graph_options()->mutable_rewrite_options()) {
  DCHECK(cpu_device_ == nullptr ||
         cpu_device_->attributes().device_type() == "CPU");
  auto global_jit_level =
      cfg.graph_options().optimizer_options().global_jit_level();
  xla_auto_clustering_on_ = IsXlaGlobalJitOn(global_jit_level);
}

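// Builds the default optimization pipeline from the RewriterConfig, merging
// user toggles with plugin configs, then appends custom graph optimizers.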
Status MetaOptimizer::InitializeOptimizers(
    const std::set<string>& device_types,
    std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
  if (cfg_.disable_meta_optimizer()) {
    return OkStatus();
  }

  ConfigList plugin_configs = PluginGraphOptimizerRegistry::GetPluginConfigs(
      cfg_.use_plugin_optimizers() != RewriterConfig::OFF, device_types);
  if (!cfg_.disable_model_pruning() && !plugin_configs.disable_model_pruning) {
    optimizers->push_back(MakeUnique<ModelPruner>());
  }

  // #TODO(b/200087693): LLVM does not build on Fuchsia.
#ifndef __Fuchsia__
  // Hooks up the MLIR optimizer; it won't run any optimizations right now.
  // This optimizer instance runs on functions one at a time; don't use any
  // threads.
  optimizers->push_back(MakeUnique<mlir::tfg::TFGGrapplerOptimizer>(
      mlir::tfg::DefaultGrapplerPipeline));
#endif

  // A set of macro utilities that check the toggle state of an optimization.
  // Both user and plugin configurations are supported.
#define USER_IS_ON(CFG) cfg_.CFG() == RewriterConfig::ON
#define USER_IS_EXPERIMENTAL_MLIR(CFG) \
  cfg_.CFG() == RewriterConfig::EXPERIMENTAL_MLIR
#define USER_IS_EXPERIMENTAL_BOTH(CFG) \
  cfg_.CFG() == RewriterConfig::EXPERIMENTAL_BOTH
#define USER_NOT_OFF(CFG) cfg_.CFG() != RewriterConfig::OFF
#define PLUGIN_IS_ON(CFG) \
  plugin_configs.toggle_config[#CFG] == RewriterConfig::ON
#define PLUGIN_IS_EXPERIMENTAL_MLIR(CFG) \
  plugin_configs.toggle_config[#CFG] == RewriterConfig::EXPERIMENTAL_MLIR
#define PLUGIN_IS_EXPERIMENTAL_BOTH(CFG) \
  plugin_configs.toggle_config[#CFG] == RewriterConfig::EXPERIMENTAL_BOTH
#define PLUGIN_NOT_OFF(CFG) \
  plugin_configs.toggle_config[#CFG] != RewriterConfig::OFF
#define BOTH_ARE_ON(CFG) (USER_IS_ON(CFG) && PLUGIN_IS_ON(CFG))
#define BOTH_NOT_OFF(CFG) (USER_NOT_OFF(CFG) && PLUGIN_NOT_OFF(CFG))
#define BOTH_ARE_EXPERIMENTAL_MLIR(CFG) \
  (USER_IS_EXPERIMENTAL_MLIR(CFG) && PLUGIN_IS_EXPERIMENTAL_MLIR(CFG))
#define BOTH_ARE_EXPERIMENTAL_BOTH(CFG) \
  (USER_IS_EXPERIMENTAL_BOTH(CFG) && PLUGIN_IS_EXPERIMENTAL_BOTH(CFG))
  if (BOTH_NOT_OFF(implementation_selector)) {
    if (USER_IS_EXPERIMENTAL_MLIR(implementation_selector) ||
        USER_IS_EXPERIMENTAL_BOTH(implementation_selector))
      VLOG(2) << "implementation_selector is not implemented in TFG yet";
    else
      optimizers->push_back(MakeUnique<ImplementationSelector>());
  }
  if (BOTH_NOT_OFF(function_optimization)) {
    if (USER_IS_EXPERIMENTAL_MLIR(function_optimization) ||
        USER_IS_EXPERIMENTAL_BOTH(function_optimization)) {
      VLOG(2) << "function_optimization is not implemented in TFG yet";
    } else {
      optimizers->push_back(MakeUnique<FunctionOptimizer>(
          cfg_.function_optimization(),
          /*lower_control_flow=*/LowerControlFlow()));
    }
  }
  if (BOTH_NOT_OFF(common_subgraph_elimination) &&
      BOTH_NOT_OFF(arithmetic_optimization)) {
    if (USER_IS_EXPERIMENTAL_MLIR(common_subgraph_elimination) ||
        USER_IS_EXPERIMENTAL_BOTH(common_subgraph_elimination)) {
      VLOG(2) << "common_subgraph_elimination is not implemented in TFG yet";
    } else {
      optimizers->push_back(MakeUnique<CommonSubgraphElimination>(
          cfg_.common_subgraph_elimination()));
    }
  }
  if (BOTH_ARE_ON(debug_stripper))
    optimizers->push_back(MakeUnique<DebugStripper>());
  else if (BOTH_ARE_EXPERIMENTAL_MLIR(debug_stripper) ||
           BOTH_ARE_EXPERIMENTAL_BOTH(debug_stripper))
    VLOG(2) << "debug_stripper is not implemented in TFG yet";
  if (BOTH_NOT_OFF(constant_folding)) {
    if (USER_IS_EXPERIMENTAL_MLIR(constant_folding) ||
        USER_IS_EXPERIMENTAL_BOTH(constant_folding)) {
      VLOG(2) << "constant_folding is not implemented in TFG yet";
    } else {
      optimizers->push_back(MakeUnique<ConstantFolding>(
          cfg_.constant_folding(), cpu_device_,
          cfg_.experimental_disable_compressed_tensor_optimization(),
          !cfg_.experimental_disable_folding_quantization_emulation()));
    }
  }
  if (BOTH_NOT_OFF(shape_optimization)) {
    if (USER_IS_EXPERIMENTAL_MLIR(shape_optimization) ||
        USER_IS_EXPERIMENTAL_BOTH(shape_optimization))
      VLOG(2) << "shape_optimization is not implemented in TFG yet";
    else
      optimizers->push_back(MakeUnique<ShapeOptimizer>());
  }
  if (AutoMixedPrecisionEnabled(cfg_.auto_mixed_precision()) &&
      AutoMixedPrecisionEnabled(
          plugin_configs.toggle_config["auto_mixed_precision"])) {
    optimizers->push_back(
        MakeUnique<AutoMixedPrecision>(AutoMixedPrecisionMode::CUDA));
  }
#ifdef INTEL_MKL
  if (AutoMixedPrecisionEnabled(cfg_.auto_mixed_precision_onednn_bfloat16()) &&
      AutoMixedPrecisionEnabled(
          plugin_configs
              .toggle_config["auto_mixed_precision_onednn_bfloat16"]) &&
      IsMKLEnabled()) {
    optimizers->push_back(
        MakeUnique<AutoMixedPrecision>(AutoMixedPrecisionMode::BF16));
  }
  if (AutoMixedPrecisionEnabled(cfg_.auto_mixed_precision_mkl()) &&
      AutoMixedPrecisionEnabled(
          plugin_configs.toggle_config["auto_mixed_precision_mkl"]) &&
      IsMKLEnabled()) {
    LOG_FIRST_N(WARNING, 1)
        << "NOTE: auto_mixed_precision_mkl is deprecated."
           " Please use auto_mixed_precision_onednn_bfloat16 instead";
    optimizers->push_back(
        MakeUnique<AutoMixedPrecision>(AutoMixedPrecisionMode::BF16));
  }
#endif
  if (AutoMixedPrecisionEnabled(cfg_.auto_mixed_precision_cpu()) &&
      AutoMixedPrecisionEnabled(
          plugin_configs.toggle_config["auto_mixed_precision_cpu"])) {
    optimizers->push_back(
        MakeUnique<AutoMixedPrecision>(AutoMixedPrecisionMode::CPU));
  }
  if (BOTH_ARE_ON(pin_to_host_optimization))
    optimizers->push_back(MakeUnique<PinToHostOptimizer>());
  else if (BOTH_ARE_EXPERIMENTAL_MLIR(pin_to_host_optimization) ||
           BOTH_ARE_EXPERIMENTAL_BOTH(pin_to_host_optimization))
    VLOG(2) << "pin_to_host_optimization is not implemented in TFG yet";
  if (BOTH_NOT_OFF(arithmetic_optimization)) {
    if (USER_IS_EXPERIMENTAL_MLIR(arithmetic_optimization) ||
        USER_IS_EXPERIMENTAL_BOTH(arithmetic_optimization)) {
      VLOG(2) << "arithmetic_optimization is not implemented in TFG yet";
    } else {
      optimizers->push_back(
          MakeUnique<ArithmeticOptimizer>(cfg_.arithmetic_optimization()));
    }
  }
  if (BOTH_NOT_OFF(layout_optimizer)) {
    if (USER_IS_EXPERIMENTAL_MLIR(layout_optimizer) ||
        USER_IS_EXPERIMENTAL_BOTH(layout_optimizer)) {
      VLOG(2) << "layout_optimizer is not implemented in TFG yet";
    } else {
      optimizers->push_back(MakeUnique<GenericLayoutOptimizer>(
          /*optimization level*/ cfg_.layout_optimizer(),
          /*CPU layout conversion*/ cfg_.cpu_layout_conversion()));
    }
  }
  if (BOTH_NOT_OFF(remapping)) {
    bool enable_mlir_pass = USER_IS_EXPERIMENTAL_MLIR(remapping) ||
                            USER_IS_EXPERIMENTAL_BOTH(remapping);
    bool enable_grappler_pass =
        !enable_mlir_pass || USER_IS_EXPERIMENTAL_BOTH(remapping);
    if (enable_mlir_pass) {
// #TODO(b/200087693): LLVM does not build on Fuchsia.
#ifndef __Fuchsia__
      optimizers->push_back(MakeUnique<mlir::tfg::TFGGrapplerOptimizer>(
          mlir::tfg::RemapperPassBuilder));
#else
      VLOG(2) << "mlir Remapper pass is not supported on Fuchsia";
#endif
    }
    if (enable_grappler_pass) {
      optimizers->push_back(MakeUnique<Remapper>(cfg_.remapping(),
                                                 cfg_.cpu_layout_conversion(),
                                                 xla_auto_clustering_on_));
    }
  }
  if (BOTH_NOT_OFF(loop_optimization)) {
    if (USER_IS_EXPERIMENTAL_MLIR(loop_optimization) ||
        USER_IS_EXPERIMENTAL_BOTH(loop_optimization)) {
      VLOG(2) << "loop_optimization is not implemented in TFG yet";
    } else {
      optimizers->push_back(
          MakeUnique<LoopOptimizer>(cfg_.loop_optimization(), cpu_device_));
    }
  }
  if (BOTH_NOT_OFF(dependency_optimization)) {
    if (USER_IS_EXPERIMENTAL_MLIR(dependency_optimization) ||
        USER_IS_EXPERIMENTAL_BOTH(dependency_optimization)) {
      VLOG(2) << "dependency_optimization is not implemented in TFG yet";
    } else {
      optimizers->push_back(
          MakeUnique<DependencyOptimizer>(cfg_.dependency_optimization()));
    }
  }
  if (MemoryOptimizerEnabled(cfg_.memory_optimization(),
                             xla_auto_clustering_on_) &&
      PLUGIN_NOT_OFF(memory_optimization)) {
    if (cfg_.memory_optimizer_target_node_name_scope().empty()) {
      optimizers->push_back(
          // Use the default target node name prefix "gradients/".
          MakeUnique<MemoryOptimizer>(cfg_.memory_optimization()));
    } else {
      optimizers->push_back(MakeUnique<MemoryOptimizer>(
          cfg_.memory_optimization(),
          cfg_.memory_optimizer_target_node_name_scope()));
    }
  }
  if (cfg_.auto_parallel().enable() && PLUGIN_IS_ON(auto_parallel)) {
    optimizers->push_back(
        MakeUnique<AutoParallel>(cfg_.auto_parallel().num_replicas()));
  }

#ifndef ENABLE_MKL
  if (BOTH_ARE_ON(scoped_allocator_optimization)) {
    optimizers->push_back(MakeUnique<ScopedAllocatorOptimizer>(
        cfg_.scoped_allocator_optimization(), cfg_.scoped_allocator_opts()));
  } else if (BOTH_ARE_EXPERIMENTAL_MLIR(scoped_allocator_optimization) ||
             BOTH_ARE_EXPERIMENTAL_BOTH(scoped_allocator_optimization)) {
    VLOG(2) << "scoped_allocator_optimization is not implemented in TFG yet";
  }
#endif

#undef USER_IS_ON
#undef USER_IS_EXPERIMENTAL_MLIR
#undef USER_IS_EXPERIMENTAL_BOTH
#undef USER_NOT_OFF
#undef PLUGIN_IS_ON
#undef PLUGIN_NOT_OFF
#undef BOTH_ARE_ON
#undef BOTH_NOT_OFF
  return InitializeCustomGraphOptimizers(device_types, std::set<string>(),
                                         optimizers);
}

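// Builds the pipeline from the explicit optimizer list in the config, resolving
// each name to a default or registered custom optimizer.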
Status MetaOptimizer::InitializeOptimizersByName(
    const std::set<string>& device_types,
    std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
  std::set<string> initialized_custom_optimizers;
  for (const string& optimizer_name : cfg_.optimizers()) {
    auto optimizer = MakeNewOptimizer(optimizer_name, device_types);
    if (optimizer) {
      VLOG(2) << "Registered default graph optimizer: " << optimizer_name;
      optimizers->push_back(std::move(optimizer));
      continue;
    }

    auto custom_optimizer =
        CustomGraphOptimizerRegistry::CreateByNameOrNull(optimizer_name);

    if (custom_optimizer) {
      VLOG(2) << "Registered custom graph optimizer: " << optimizer_name;
      TF_RETURN_IF_ERROR(custom_optimizer->InitWithConfig(
          config_proto_, GetCustomGraphOptimizerConfig(optimizer_name)));
      optimizers->push_back(std::move(custom_optimizer));
      initialized_custom_optimizers.insert(optimizer_name);
    } else {
      VLOG(2) << "Can't register an optimizer by name: " << optimizer_name;
    }
  }
  return InitializeCustomGraphOptimizers(
      device_types, initialized_custom_optimizers, optimizers);
}

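// Instantiates the custom optimizers listed in the config that have not been
// initialized yet, falling back to default optimizers with the same name.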
Status MetaOptimizer::InitializeCustomGraphOptimizers(
    const std::set<string>& device_types,
    const std::set<string>& pre_initialized_optimizers,
    std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
  for (const auto& optimizer_config : cfg_.custom_optimizers()) {
    if (pre_initialized_optimizers.find(optimizer_config.name()) !=
        pre_initialized_optimizers.end()) {
      continue;
    }

    auto custom_optimizer = CustomGraphOptimizerRegistry::CreateByNameOrNull(
        optimizer_config.name());

    if (custom_optimizer) {
      VLOG(2) << "Registered custom configurable graph optimizer: "
              << optimizer_config.name();
      TF_RETURN_IF_ERROR(
          custom_optimizer->InitWithConfig(config_proto_, &optimizer_config));
      optimizers->push_back(std::move(custom_optimizer));
    } else {
      // If there is no custom optimizer with the given name, try to initialize
      // a default optimizer. This way, custom configurable optimizers can be
      // mixed with default optimizers in any order.
      auto optimizer = MakeNewOptimizer(optimizer_config.name(), device_types);
      if (optimizer) {
        VLOG(2) << "Registered default graph optimizer: "
                << optimizer_config.name();
        optimizers->push_back(std::move(optimizer));
        continue;
      }
      VLOG(2) << "Can't register an optimizer by name: "
              << optimizer_config.name();
    }
  }
  return InitializePluginGraphOptimizers(device_types, optimizers);
}

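// Appends optimizers registered by plugins for the given device types, unless
// plugin optimizers are turned off.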
Status MetaOptimizer::InitializePluginGraphOptimizers(
    const std::set<string>& device_types,
    std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
  if (cfg_.use_plugin_optimizers() == RewriterConfig::OFF) return OkStatus();
  auto plugin_optimizers =
      PluginGraphOptimizerRegistry::CreateOptimizers(device_types);
  for (auto& plugin_optimizer : plugin_optimizers) {
    optimizers->push_back(std::move(plugin_optimizer));
  }
  return OkStatus();
}

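// Returns the config entry for the named custom optimizer, or nullptr if the
// config does not mention it.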
const RewriterConfig::CustomGraphOptimizer*
MetaOptimizer::GetCustomGraphOptimizerConfig(const string& name) const {
  for (const auto& config : cfg_.custom_optimizers()) {
    if (config.name() == name) {
      return &config;
    }
  }
  return nullptr;
}

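// Instantiates the graph verifiers that run between optimizer passes and after
// the whole optimization, as requested by the verifier configs.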
void MetaOptimizer::InitializeVerifiers(
    std::vector<std::unique_ptr<GraphVerifier>>* inter_optimizer_verifiers,
    std::vector<std::unique_ptr<GraphVerifier>>* post_optimization_verifiers)
    const {
  if (cfg_.inter_optimizer_verifier_config().structure_verifier() ==
      VerifierConfig::ON) {
    inter_optimizer_verifiers->push_back(MakeUnique<StructureVerifier>());
  }
  if (cfg_.post_optimization_verifier_config().structure_verifier() ==
      VerifierConfig::ON) {
    post_optimization_verifiers->push_back(MakeUnique<StructureVerifier>());
  }
}

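// Logs a comparison of the user's rewriter config with the plugin config when
// the two conflict, along with the final merged configuration.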
void MetaOptimizer::PrintUserAndPluginConfigs(
    const std::set<string>& device_types) const {
  if (cfg_.use_plugin_optimizers() == RewriterConfig::OFF) return;
  ConfigList plugin_cfg = PluginGraphOptimizerRegistry::GetPluginConfigs(
      cfg_.use_plugin_optimizers() != RewriterConfig::OFF, device_types);
  PluginGraphOptimizerRegistry::PrintPluginConfigsIfConflict(device_types);

  ConfigList user_cfg;
  // Print the user's and the plugin's configs.
  if (cfg_.optimizers().empty()) {
    if (cfg_.disable_meta_optimizer()) {
      return;
    }
    user_cfg.disable_model_pruning = cfg_.disable_model_pruning();
#define PRINT_CFG(CFG) user_cfg.toggle_config[#CFG] = cfg_.CFG();
    PRINT_CFG(implementation_selector)
    PRINT_CFG(function_optimization)
    PRINT_CFG(common_subgraph_elimination)
    PRINT_CFG(arithmetic_optimization)
    PRINT_CFG(debug_stripper)
    PRINT_CFG(constant_folding)
    PRINT_CFG(shape_optimization)
    PRINT_CFG(pin_to_host_optimization)
    PRINT_CFG(layout_optimizer)
    PRINT_CFG(remapping)
    PRINT_CFG(loop_optimization)
    PRINT_CFG(dependency_optimization)
    PRINT_CFG(scoped_allocator_optimization)
#undef PRINT_CFG
    user_cfg.toggle_config["auto_mixed_precision"] =
        AutoMixedPrecisionEnabled(cfg_.auto_mixed_precision())
            ? RewriterConfig::ON
            : RewriterConfig::OFF;
    user_cfg.toggle_config["auto_mixed_precision_onednn_bfloat16"] =
        AutoMixedPrecisionEnabled(cfg_.auto_mixed_precision_onednn_bfloat16())
            ? RewriterConfig::ON
            : RewriterConfig::OFF;
    user_cfg.toggle_config["auto_mixed_precision_mkl"] =
        AutoMixedPrecisionEnabled(cfg_.auto_mixed_precision_mkl())
            ? RewriterConfig::ON
            : RewriterConfig::OFF;
    user_cfg.toggle_config["auto_mixed_precision_cpu"] =
        AutoMixedPrecisionEnabled(cfg_.auto_mixed_precision_cpu())
            ? RewriterConfig::ON
            : RewriterConfig::OFF;
    user_cfg.toggle_config["memory_optimization"] =
        MemoryOptimizerEnabled(cfg_.memory_optimization(),
                               config_proto_.graph_options()
                                   .optimizer_options()
                                   .global_jit_level())
            ? RewriterConfig::ON
            : RewriterConfig::OFF;
    user_cfg.toggle_config["auto_parallel"] = cfg_.auto_parallel().enable()
                                                  ? RewriterConfig::ON
                                                  : RewriterConfig::OFF;
  } else {
    for (const string& optimizer_name : cfg_.optimizers()) {
      if (optimizer_name == "pruning") user_cfg.disable_model_pruning = true;

#define PRINT_CFG(NAME, CONFIG) \
  if (optimizer_name == NAME)   \
    user_cfg.toggle_config[CONFIG] = RewriterConfig::ON;

      PRINT_CFG("implementation_selector", "implementation_selector")
      PRINT_CFG("function", "function_optimization")
      PRINT_CFG("common_subgraph_elimination", "common_subgraph_elimination")
      PRINT_CFG("arithmetic", "arithmetic_optimization")
      PRINT_CFG("debug_stripper", "debug_stripper")
      PRINT_CFG("constfold", "constant_folding")
      PRINT_CFG("shape", "shape_optimization")
      PRINT_CFG("auto_mixed_precision", "auto_mixed_precision")
      PRINT_CFG("auto_mixed_precision_onednn_bfloat16",
                "auto_mixed_precision_onednn_bfloat16")
      PRINT_CFG("auto_mixed_precision_mkl", "auto_mixed_precision_mkl")
      PRINT_CFG("auto_mixed_precision_cpu", "auto_mixed_precision_cpu")
      PRINT_CFG("pin_to_host", "pin_to_host_optimization")
      PRINT_CFG("layout", "layout_optimizer")
      PRINT_CFG("remap", "remapping")
      PRINT_CFG("loop", "loop_optimization")
      PRINT_CFG("dependency", "dependency_optimization")
      PRINT_CFG("memory", "memory_optimization")
      PRINT_CFG("autoparallel", "auto_parallel")
      PRINT_CFG("scoped_allocator", "scoped_allocator_optimization")
#undef PRINT_CFG
    }
  }

  // Print logs only when the plugin config conflicts with the user config.
  if (!PluginGraphOptimizerRegistry::IsConfigsConflict(user_cfg, plugin_cfg))
    return;

  ConfigList final_cfg = user_cfg;
  // If the plugin turns on `disable_model_pruning`, then `disable_model_pruning`
  // should be true.
  if (plugin_cfg.disable_model_pruning == true)
    final_cfg.disable_model_pruning = true;
  // If the plugin turns off a certain optimizer, then that optimizer should be
  // turned off.
  for (auto& pair : plugin_cfg.toggle_config) {
    if (plugin_cfg.toggle_config[pair.first] == RewriterConfig::OFF)
      final_cfg.toggle_config[pair.first] = RewriterConfig::OFF;
  }

  string logs =
      "\nConfig of optimizers\t\tUser's config\tPlugin's config\tFinal "
      "config(User & Plugin)\n";
  strings::StrAppend(&logs, "disable_model_pruning\t\t",
                     user_cfg.disable_model_pruning, "\t\t",
                     plugin_cfg.disable_model_pruning, "\t\t",
                     final_cfg.disable_model_pruning, "\n");
  for (auto& pair : user_cfg.toggle_config) {
    if (pair.first == "debug_stripper" ||
        pair.first == "auto_mixed_precision" ||
        pair.first == "auto_mixed_precision_onednn_bfloat16" ||
        pair.first == "auto_mixed_precision_mkl" ||
        pair.first == "auto_mixed_precision_cpu" ||
        pair.first == "pin_to_host_optimization" ||
        pair.first == "scoped_allocator_optimization") {
      // These optimizers are turned off by default.
      // TODO(penporn): Remove the hard-coded length and change it to the max
      // length of all option strings.
      strings::StrAppend(
          &logs, pair.first, string(40 - pair.first.size(), ' '),
          (pair.second == RewriterConfig::ON), "\t\t",
          (plugin_cfg.toggle_config[pair.first] == RewriterConfig::ON), "\t\t",
          (final_cfg.toggle_config[pair.first] == RewriterConfig::ON), "\n");
    } else {
      // These optimizers are turned on by default.
      // TODO(penporn): Remove the hard-coded length and change it to the max
      // length of all option strings.
      strings::StrAppend(
          &logs, pair.first, string(40 - pair.first.size(), ' '),
          (pair.second != RewriterConfig::OFF), "\t\t",
          (plugin_cfg.toggle_config[pair.first] != RewriterConfig::OFF), "\t\t",
          (final_cfg.toggle_config[pair.first] != RewriterConfig::OFF), "\n");
    }
  }
  LOG(WARNING) << "User's config has been changed based on plugin's config.";
  LOG(WARNING) << logs;
}

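// Runs the given optimizers over the item's graph for up to NumIterations
// passes, interleaving the configured verifiers between and after passes.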
Status MetaOptimizer::OptimizeGraph(
    const std::vector<std::unique_ptr<GraphOptimizer>>& optimizers,
    Cluster* cluster, GrapplerItem&& item, GraphDef* optimized_graph) {
  int min_graph_nodes = cfg_.min_graph_nodes() == 0 ? kDefaultMinGraphNodes
                                                    : cfg_.min_graph_nodes();
  if (item.graph.node_size() < min_graph_nodes) {
    VLOG(3) << "Skipping optimization, graph has less than " << min_graph_nodes
            << " nodes.";
    *optimized_graph = item.graph;
    return OkStatus();
  }

  tensorflow::metrics::ScopedCounter<2> timings(
      tensorflow::metrics::GetGraphOptimizationCounter(),
      {kGrapplerCategory, "OptimizeMainGraph"});

  // Initialize the configured verifiers.
  std::vector<std::unique_ptr<GraphVerifier>> inter_optimizer_verifiers;
  std::vector<std::unique_ptr<GraphVerifier>> post_optimization_verifiers;
  InitializeVerifiers(&inter_optimizer_verifiers, &post_optimization_verifiers);
  if (inter_optimizer_verifiers.empty()) {
    VLOG(2) << "No inter optimizer verifiers have been configured";
  } else {
    VLOG(2) << inter_optimizer_verifiers.size()
            << " inter optimizer verifiers have been configured";
  }
  if (post_optimization_verifiers.empty()) {
    VLOG(2) << "No post optimization verifiers have been configured";
  } else {
    VLOG(2) << post_optimization_verifiers.size()
            << " post optimization verifiers have been configured";
  }

  VLOG(2) << "Optimize GrapplerItem: item.id=" << item.id
          << " num_optimizers=" << optimizers.size()
          << ", num nodes = " << item.graph.node_size();

  if (optimizers.empty()) {
    VLOG(3) << "Skipping graph optimization, no optimizers registered";
    *optimized_graph = item.graph;
    return OkStatus();
  }

  // Invariant: optimized_graph contains the most recently optimized version of
  // the graph.
  auto original_producer = item.graph.versions().producer();
  *optimized_graph = std::move(item.graph);

  GraphOptimizationResult optimization_result(item.id);
#ifndef ENABLE_MKL
  GraphOptimizer* sa_optimizer = nullptr;
#endif

  // Constants in the graph are normally compressed after model_pruner.
  // Do it here if model pruner is disabled.
  if (cfg_.disable_model_pruning()) {
    CompressConstants(optimized_graph);
  }

  for (int iteration = 0; iteration < NumIterations(cfg_); ++iteration) {
    // Don't bother optimizing further if the graph is already tiny.
    if (optimized_graph->node_size() < min_graph_nodes) {
      VLOG(3) << "Stopping after iteration " << iteration
              << ", graph is tiny (#nodes = " << optimized_graph->node_size()
              << " < " << min_graph_nodes << ")";
      break;
    }

    VLOG(4) << "Starting optimization iteration " << iteration;
    if (VLOG_IS_ON(4)) {
      DumpGraphDefToFile(
          strings::StrCat("before_MetaOptimizer_iteration_", iteration, "_",
                          reinterpret_cast<uintptr_t>(optimized_graph)),
          *optimized_graph);
    }

    for (const auto& optimizer : optimizers) {
      GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
      // Some optimizers can run only once.
      if (iteration > 0 && IsRunOnceOptimizer(optimizer->name())) continue;
#ifndef ENABLE_MKL
      // Some must run only on the last iteration.
      if (optimizer->name() == "scoped_allocator_optimizer") {
        if (sa_optimizer == nullptr) sa_optimizer = optimizer.get();
        continue;
      }
#endif

      TF_RETURN_IF_ERROR(RunOptimizer(optimizer.get(), cluster, &item,
                                      optimized_graph, &optimization_result));

      if (iteration == 0 && optimizer->name() == "model_pruner") {
        CompressConstants(optimized_graph);
      }

      if (VLOG_IS_ON(4)) {
        DumpGraphDefToFile(
            strings::StrCat("after_MetaOptimizer_iteration_", iteration, "_",
                            optimizer->name(), "_",
                            reinterpret_cast<uintptr_t>(optimized_graph)),
            *optimized_graph);
      }
      for (const auto& verifier : inter_optimizer_verifiers) {
        // TODO(ashwinm): Need to enforce verification_deadline.
        TF_RETURN_IF_ERROR(verifier->Verify(*optimized_graph));
      }
    }
    if (VLOG_IS_ON(4)) {
      DumpGraphDefToFile(
          strings::StrCat("after_MetaOptimizer_iteration_", iteration, "_",
                          reinterpret_cast<uintptr_t>(optimized_graph)),
          *optimized_graph);
    }
    // TODO(ashwinm): Need to enforce verification_deadline.
    for (const auto& verifier : post_optimization_verifiers) {
      TF_RETURN_IF_ERROR(verifier->Verify(*optimized_graph));
    }
  }
#ifndef ENABLE_MKL
  // ScopedAllocatorOptimizer must run last.
  if (sa_optimizer != nullptr) {
    TF_RETURN_IF_ERROR(RunOptimizer(sa_optimizer, cluster, &item,
                                    optimized_graph, &optimization_result));
    GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
  }
#endif

  bool is_optimized = std::find_if(optimization_result.results.begin(),
                                   optimization_result.results.end(),
                                   [](const OptimizerResult& result) {
                                     return result.status.ok();
                                   }) != optimization_result.results.end();

  // Record graph optimization result.
  optimization_results_.push_back(optimization_result);

  if (is_optimized) {
    TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph));
    ReassignColocation(optimized_graph);
    // Make sure that the optimizers preserved the graph version.
    DCHECK_EQ(optimized_graph->versions().producer(), original_producer);
  }

  return OkStatus();
}

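// Builds the optimizer pipeline for the item's device types and runs it over
// the main graph.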
Status MetaOptimizer::OptimizeGraph(Cluster* cluster, GrapplerItem&& item,
                                    GraphDef* optimized_graph) {
  std::vector<std::unique_ptr<GraphOptimizer>> optimizers;
  std::set<std::string> device_types;
  TF_RETURN_IF_ERROR(GetGraphDevice(item.graph, &device_types));
  if (cfg_.optimizers().empty()) {
    TF_RETURN_IF_ERROR(InitializeOptimizers(device_types, &optimizers));
  } else {
    TF_RETURN_IF_ERROR(InitializeOptimizersByName(device_types, &optimizers));
  }
  PrintUserAndPluginConfigs(device_types);

  return OptimizeGraph(std::move(optimizers), cluster, std::move(item),
                       optimized_graph);
}

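// Runs a single optimizer pass over the item, recording timing and status;
// non-fatal failures leave the graph unchanged.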
Status MetaOptimizer::RunOptimizer(
    GraphOptimizer* optimizer, Cluster* cluster, GrapplerItem* optimized_item,
    GraphDef* optimized_graph, GraphOptimizationResult* optimization_result) {
  // If the optimizer doesn't need a function library, we will replace it with
  // a stub before running the optimization, and will put it back at the end.
  std::unique_ptr<FunctionDefLibrary> optimized_graph_function_library;
  const bool is_function_library_aware = optimizer->UsesFunctionLibrary();

  // Replace the function library in the optimized graph with a stub.
  if (!is_function_library_aware) {
    VLOG(3) << "Replace function library with a stub for " << optimizer->name();
    optimized_graph_function_library =
        absl::WrapUnique(optimized_graph->release_library());
    *optimized_graph->mutable_library() =
        GetFunctionDefLibraryStub(*optimized_graph_function_library);
  }

  // This swaps the current optimized_graph into the optimized item and
  // resets optimized_graph to an empty graph.
  optimized_item->graph = std::move(*optimized_graph);
  *optimized_graph = GraphDef();
  optimizer->set_deadline_usec(this->deadline_usec());
  tensorflow::metrics::ScopedCounter<2> timings(
      tensorflow::metrics::GetGraphOptimizationCounter(),
      {kGrapplerCategory, optimizer->name()});
  Status status =
      optimizer->Optimize(cluster, *optimized_item, optimized_graph);
  auto duration_ms = timings.DurationMicroSec().value() / 1000.0f;
  timings.ReportAndStop();

  string message;
  if (!status.ok()) {
    *optimized_graph = std::move(optimized_item->graph);
    if (errors::IsAborted(status)) {
      // By convention we (ab-)use the Aborted error code to signal that the
      // optimizer returned without performing any changes to the graph.
      message = strings::StrCat(optimizer->name(),
                                " did nothing. time = ", duration_ms, "ms.");
      // Swallow the non-critical error.
      status = OkStatus();
    } else if (errors::IsDeadlineExceeded(status)) {
      message =
          strings::StrCat(status.ToString(), ", time = ", duration_ms, "ms.");
      LOG(WARNING) << optimizer->name() << " failed: " << message;
    } else {
      message = status.ToString();
      LOG(ERROR) << optimizer->name() << " failed: " << message;
    }
  } else {
    message = strings::StrCat(
        PrintSizesBeforeAfter(optimized_item->graph, *optimized_graph),
        ", time = ", duration_ms, "ms.");
    VLOG(1) << optimizer->name() << ": " << message;
  }

  // Swap the function library back into the main graph.
  if (!is_function_library_aware) {
    optimized_graph->set_allocated_library(
        optimized_graph_function_library.release());
  }

  OptimizerResult optimizer_result{optimizer->name(), message, status};
  optimization_result->results.push_back(optimizer_result);

  if (!status.ok()) {
    if (cfg_.fail_on_optimizer_errors()) return status;

    // Non-aborted failures in the TFG optimizer are always fatal.
    if (absl::StartsWith(optimizer->name(), "tfg_optimizer")) return status;
  }

  return OkStatus();
}

// Propagates `_tf_data_function` attributes from functions to their callees.
void PropagateTFDataAttrs(const FunctionLibraryDefinition& flib,
                          FunctionDefLibrary& fdef_lib) {
  // Collect functions that need the attribute in this set.
  absl::flat_hash_set<std::string> tf_data_functions;
  std::function<void(const std::string&)> collect_tf_data_functions_dfs =
      [&](const std::string& func_name) -> void {
    const FunctionDef* func_def = flib.Find(func_name);
    // Skip functions that are not reachable from the optimized graph.
    if (func_def == nullptr) return;

    // Return if we already found and added this function.
    if (tf_data_functions.contains(func_name)) return;

    // We only get here if the function is (directly or indirectly) called from
    // a tf.data function, so add it to the set.
    tf_data_functions.insert(func_name);

    // Proceed with DFS for functions called from current function.
    for (const NodeDef& node : func_def->node_def()) {
      if (flib.Contains(node.op())) {
        // This is a function call node.
        collect_tf_data_functions_dfs(node.op());
      }
      // Check if there are functions in attributes.
      for (const auto& attr : node.attr()) {
        const AttrValue& attr_value = attr.second;
        if (attr_value.has_func()) {
          collect_tf_data_functions_dfs(attr_value.func().name());
        }
        if (attr_value.has_list()) {
          for (const auto& func : attr_value.list().func()) {
            collect_tf_data_functions_dfs(func.name());
          }
        }
      }
    }
  };
  // Perform DFS for all tf.data functions in `fdef_lib`.
  for (const auto& func_def : fdef_lib.function()) {
    const std::string& func_name = func_def.signature().name();
    if (data::IsTFDataFunction(func_def))
      collect_tf_data_functions_dfs(func_name);
  }
  // Set attribute for tf.data functions. We cannot do this in the DFS directly
  // because `FunctionLibraryDefinition` does not seem to provide mutable
  // access to a `FunctionDef`.
  for (FunctionDef& func_def : *fdef_lib.mutable_function()) {
    const std::string& func_name = func_def.signature().name();
    if (tf_data_functions.contains(func_name) &&
        !data::IsTFDataFunction(func_def)) {
      VLOG(2) << "Marking " << func_name << " as tf.data function";
      (*func_def.mutable_attr())[data::kTFDataFunction].set_b(true);
    }
  }
}

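// Entry point that optimizes the main graph and then the reachable functions
// in its library, consuming the GrapplerItem.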
Status MetaOptimizer::OptimizeConsumeItem(Cluster* cluster, GrapplerItem&& item,
                                          GraphDef* optimized_graph) {
  tensorflow::metrics::ScopedCounter<2> timings(
      tensorflow::metrics::GetGraphOptimizationCounter(),
      {kGrapplerCategory, "*"});

  VLOG(1) << "Starting optimization for grappler item: " << item.id;
  optimization_results_.clear();

  // Constructs a FunctionLibraryDefinition with functions that are reachable
  // from the nodes of the graph.
  const auto minimized_flib =
      [](const GraphDef& graph) -> FunctionLibraryDefinition {
    return FunctionLibraryDefinition(OpRegistry::Global(), graph.library())
        .ReachableDefinitions(graph);
  };

  // 0. The original graph might contain a huge function library that is mostly
  // unused. This library is copied over by each individual Grappler optimizer,
  // which adds a huge overhead. Before starting the optimization passes we
  // just remove all the unreachable functions.
  // TODO(ezhulenev): Construct reachable function library definition directly
  // from the proto without constructing temporary FunctionLibraryDefinition.
  int old_library_size = item.graph.library().function_size();
  *item.graph.mutable_library() = minimized_flib(item.graph).ToProto();
  int new_library_size = item.graph.library().function_size();

  VLOG(1) << absl::Substitute(
      "Deleted $0 unreachable functions from the graph (library size = $1)",
      old_library_size - new_library_size, new_library_size);

  // Save a few small fields from item before we move it.
  bool optimize_function_library =
      item.optimization_options().optimize_function_library;
  const auto producer = item.graph.versions().producer();

  // 1. Optimize the main graph.
  TF_RETURN_IF_ERROR(
      OptimizeGraph(cluster, GrapplerItem(item), optimized_graph));
  VLOG(1) << "Optimized main graph.";
  GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();

  // 2. Optimize functions reachable from the optimized graph.
  FunctionLibraryDefinition flib = minimized_flib(*optimized_graph);
  using NodeDefs = protobuf::RepeatedPtrField<NodeDef>;

  // Find functions for which we might need to compute a gradient at runtime.
  absl::flat_hash_set<string> differentiable_functions;

  const auto find_differentiable_functions =
      [&](const NodeDefs& nodes) -> void {
    for (const NodeDef& node : nodes) {
      if (IsSymbolicGradient(node)) {
        const auto* f_attr = gtl::FindOrNull(node.attr(), "f");
        if (f_attr) differentiable_functions.insert(f_attr->func().name());
      }
    }
  };

  // SymbolicGradient nodes inside the main graph.
  find_differentiable_functions(optimized_graph->node());
  // SymbolicGradient nodes inside the function library.
  for (const FunctionDef& function : optimized_graph->library().function()) {
    find_differentiable_functions(function.node_def());
  }

  // Find functions that will be compiled by XLA later. We do this by looking
  // for XlaLaunch ops that call functions, then doing a depth-first search
  // down those functions to find transitively called functions. Grappler
  // rewrites can potentially add nodes that are not supported by XLA, so we
  // choose to skip such functions when we optimize the function library.
  absl::flat_hash_set<string> xla_compiled_functions;
  std::function<void(const string&)> find_all_functions;
  find_all_functions = [&](const string& func) -> void {
    // Ignore call cycles in the graph.
    if (xla_compiled_functions.contains(func)) return;
    // Find func in the flib.
    const FunctionDef* func_def = flib.Find(func);
    CHECK(func_def) << "not found: " << func;
    // Mark the function to be ignored by Grappler.
    xla_compiled_functions.insert(func);
    // Depth-first search through the func for transitively called funcs.
    for (const NodeDef& node : func_def->node_def()) {
      for (const auto& attr : node.attr()) {
        const AttrValue& attr_value = attr.second;
        if (attr_value.has_func()) {
          find_all_functions(attr_value.func().name());
        }
      }
    }
  };

  auto find_xla_compiled_functions = [&](const NodeDefs& nodes) -> void {
    NameAttrList function;
    for (const NodeDef& node : nodes) {
      // Look only for XlaLaunch nodes that call a function.
      if (!IsXlaLaunch(node)) continue;
      if (!GetNodeAttr(node, "function", &function).ok()) continue;
      // Find all transitively called functions.
      find_all_functions(function.name());
    }
  };

  // XlaLaunch ops inside the main graph ...
  find_xla_compiled_functions(optimized_graph->node());
  // ... and inside the function library.
  for (const FunctionDef& function : optimized_graph->library().function()) {
    find_xla_compiled_functions(function.node_def());
  }
  // Propagate `_tf_data_function` attributes from functions to their callees.
  PropagateTFDataAttrs(flib, *optimized_graph->mutable_library());
1151
1152 // True if this is a TPU graph using the old bridge.
1153 bool is_tpu_graph = IsLegacyTPUBridgeGraphDef(*optimized_graph);
1154
1155 // Optimize each function only once.
1156 absl::flat_hash_set<string> optimized_funcs;
1157 while (optimize_function_library) {
1158 optimize_function_library = false;
1159
1160 int function_idx = 0;
1161 for (const FunctionDef& func : optimized_graph->library().function()) {
1162 GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
1163
1164 const string& func_name = func.signature().name();
1165
1166 // Skip functions that are not reachable from the optimized graph.
1167 if (!flib.Contains(func_name)) continue;
1168 // Skip already optimized functions.
1169 if (optimized_funcs.contains(func_name)) continue;
1170 // Skip functions that will be compiled by XLA.
1171 if (xla_compiled_functions.contains(func_name)) continue;
1172
1173 // Skip parametrized functions (function type or body is defined only at
1174 // function call time by caller node attributes).
1175 // They should be specialized to their instantiation type parameters by
1176 // the function optimizer, before we can optimize function body.
1177 if (IsParametrized(func)) continue;
1178
1179 // Skip tf.data functions as they are optimized by tf.data meta optimizer
1180 // and in function instantiation.
1181 if (data::IsTFDataFunction(func)) continue;
1182
1183 VLOG(3) << "Optimize function: function=" << func_name << " ["
1184 << function_idx++ << " of "
1185 << optimized_graph->library().function_size() << "]";
1186
1187 // Function optimization might specialize nested function calls, so we
1188 // have to reset the flag and do at least one more pass over the library.
1189 optimize_function_library = true;
1190 optimized_funcs.insert(func_name);
1191
1192 // Make a GrapplerItem from a FunctionDef.
1193 GrapplerFunctionItem func_item;
1194 TF_RETURN_IF_ERROR(
1195 MakeGrapplerFunctionItem(func, flib, producer, &func_item));
1196
1197 // If we need to compute the gradient of optimized function at runtime, we
1198 // can't perform non-differentiable rewrites.
1199 func_item.optimization_options().allow_non_differentiable_rewrites =
1200 !differentiable_functions.contains(func_name);
1201
1202 // Device set available to the function is defined only by the runtime,
1203 // when we instantiate and execute the function. We can't use all devices
1204 // available to the main graph, because after partitioning the function
1205 // call node might execute on a remote worker.
1206 if (!func_item.devices().empty()) {
1207 return errors::Internal("GrapplerFunctionItem devices must be empty.");
1208 }
1209
1210 // We are not allowed to prune certain types of ops from the graph
1211 // instantiated by the function definition, because we must guarantee
1212 // function execution semantics wrt side effects (see
1213 // function_optimizer.cc).
1214 func_item.optimization_options().allow_pruning_stateful_and_dataset_ops =
1215 false;
1216
1217 // Optimize function body graph.
1218 GraphDef optimized_func_graph;
1219 if (is_tpu_graph) {
1220 // Skip optimizing functions if this is a TPU graph. Currently, Grappler
1221 // passes do not handle TPU functions correctly in a variety of ways
1222 // (Note that due to the pre-placement TPU graph rewriting passes, the
1223 // TPU-related ops are encapsulated away into functions). For example,
1224 // TPU graphs contain TPUReplicateMetadata node that carries relevant
1225 // TPU metadata and Grappler passes could prune that away. Grappler
1226 // passes could also cause issues around shape inference. Since the
1227 // desired and existing behavior is to not optimize TPU functions with
1228 // Grappler, this check preserves that. The only exception is
1229 // implementation selector what is required to swap in some TPU specific
1230 // lowering code and is verified the work correctly on TPUs.
1231 ImplementationSelector implementation_selector;
1232
1233 // Implementation selector needs to have access to valid function
1234 // signature and attributes, and it doesn't need actual function body.
1235 std::unique_ptr<FunctionDefLibrary> func_item_function_library(
1236 func_item.graph.release_library());
1237 *func_item.graph.mutable_library() =
1238 GetFunctionDefLibraryStub(*func_item_function_library);
1239
1240 TF_RETURN_IF_ERROR(implementation_selector.Optimize(
1241 cluster, func_item, &optimized_func_graph));
1242 } else {
1243 GrapplerFunctionItem func_item_copy = func_item;
1244 TF_RETURN_IF_ERROR(OptimizeGraph(cluster, std::move(func_item_copy),
1245 &optimized_func_graph));
1246 }
1247
1248 // Function body optimization might have created new specialized
1249 // functions for each instantiation context. Add them to the library.
1250 for (const FunctionDef& func_def :
1251 optimized_func_graph.library().function()) {
1252 if (flib.Find(func_def.signature().name()) == nullptr) {
1253 TF_RETURN_IF_ERROR(flib.AddFunctionDef(func_def));
1254 }
1255 }
1256
1257 // Convert optimized graph back to FunctionDef.
1258 FunctionDef optimized_func;
1259 func_item.SwapFunctionBody(std::move(optimized_func_graph));
1260 TF_RETURN_IF_ERROR(MakeFunctionDef(func_item, flib, &optimized_func));
1261
1262 // Replace optimized function with a new FunctionDef.
1263 TF_RETURN_IF_ERROR(flib.ReplaceFunction(func_name, optimized_func));
1264 }
1265
1266 // If optimized at least one function, update the graph library.
1267 if (optimize_function_library) {
1268 *optimized_graph->mutable_library() = flib.ToProto();
1269 }
1270 }
1271
1272 // Run module-level TFG optimizations at the end of the meta-optimizer.
1273 // TODO(jeffniu): None of the TFG optimizations are meant to create new
1274 // opportunities for other optimizers; they could, but it's unclear whether
1275 // re-running all the other optimizers is worthwhile.
1276 #ifndef __Fuchsia__
1277 {
1278 // Create a Grappler optimization pipeline with only the TFG optimizer.
1279 std::vector<std::unique_ptr<GraphOptimizer>> optimizers;
1280 optimizers.push_back(std::make_unique<mlir::tfg::TFGGrapplerOptimizer>(
1281 // For module-level optimizations, use multithreading to process
1282 // functions in parallel.
1283 [&](mlir::PassManager& manager) {
1284 mlir::tfg::DefaultModuleGrapplerPipeline(manager, cfg_);
1285 },
1286 /*num_tfg_threads=*/4));
1287 // Wrap the optimized GraphDef in a new GrapplerItem with copied
1288 // configuration options from the provided item.
1289 GrapplerItem tfg_item = item.WithGraph(std::move(*optimized_graph));
1290 // Invoke the optimizers.
1291 *optimized_graph = GraphDef();
1292 TF_RETURN_IF_ERROR(OptimizeGraph(optimizers, cluster, std::move(tfg_item),
1293 optimized_graph));
1294 }
1295 #endif
1296
1297 VLOG(1) << "Optimized " << optimized_funcs.size()
1298 << " functions: " << absl::StrJoin(optimized_funcs, ", ");
1299 VLOG(3) << "Optimized graph =\n" << optimized_graph->DebugString();
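// When VLOG(1) is enabled, also dump the final GraphDef to a file. The graph
// pointer, cast to an integer, keeps the dump file name unique per graph.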
1300 if (VLOG_IS_ON(1)) {
1301 DumpGraphDefToFile(
1302 strings::StrCat("after_MetaOptimizer_",
1303 reinterpret_cast<uintptr_t>(optimized_graph)),
1304 *optimized_graph);
1305 }
1306
1307 return OkStatus();
1308 }
1309
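// Builds a human-readable report: for every Grappler item that was optimized,
// list each optimizer that ran together with the status message it recorded.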
1310 string MetaOptimizer::GetResultString() const {
1311 std::string result_string;
1312 for (const GraphOptimizationResult& graph_result : optimization_results_) {
1313 absl::StrAppend(&result_string,
1314 "Optimization results for grappler item: ", graph_result.id,
1315 "\n");
1316 for (const OptimizerResult& result : graph_result.results) {
1317 absl::StrAppend(&result_string, " ", result.optimizer_name, ": ",
1318 result.message, "\n");
1319 }
1320 }
1321 return result_string;
1322 }
1323
1324 void MetaOptimizer::PrintResult() { VLOG(1) << GetResultString(); }
1325
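// The meta-optimizer is considered enabled only if it is not explicitly
// disabled and at least one individual rewrite (or a custom optimizer) is
// turned on in the RewriterConfig checked below.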
1326 bool MetaOptimizerEnabled(const ConfigProto& cfg) {
1327 const auto& rewrite_cfg = cfg.graph_options().rewrite_options();
1328 if (rewrite_cfg.disable_meta_optimizer()) {
1329 return false;
1330 }
1331 return !rewrite_cfg.disable_model_pruning() ||
1332 rewrite_cfg.layout_optimizer() != RewriterConfig::OFF ||
1333 rewrite_cfg.function_optimization() != RewriterConfig::OFF ||
1334 rewrite_cfg.constant_folding() != RewriterConfig::OFF ||
1335 rewrite_cfg.shape_optimization() != RewriterConfig::OFF ||
1336 rewrite_cfg.remapping() != RewriterConfig::OFF ||
1337 rewrite_cfg.common_subgraph_elimination() != RewriterConfig::OFF ||
1338 rewrite_cfg.arithmetic_optimization() != RewriterConfig::OFF ||
1339 rewrite_cfg.loop_optimization() != RewriterConfig::OFF ||
1340 rewrite_cfg.dependency_optimization() != RewriterConfig::OFF ||
1341 rewrite_cfg.auto_parallel().enable() ||
1342 rewrite_cfg.memory_optimization() != RewriterConfig::NO_MEM_OPT ||
1343 rewrite_cfg.debug_stripper() == RewriterConfig::ON ||
1344 #ifndef ENABLE_MKL
1345 rewrite_cfg.scoped_allocator_optimization() == RewriterConfig::ON ||
1346 #endif
1347 rewrite_cfg.pin_to_host_optimization() == RewriterConfig::ON ||
1348 AutoMixedPrecisionEnabled(rewrite_cfg.auto_mixed_precision()) ||
1349 AutoMixedPrecisionEnabled(
1350 rewrite_cfg.auto_mixed_precision_onednn_bfloat16()) ||
1351 AutoMixedPrecisionEnabled(rewrite_cfg.auto_mixed_precision_mkl()) ||
1352 AutoMixedPrecisionEnabled(rewrite_cfg.auto_mixed_precision_cpu()) ||
1353 !rewrite_cfg.optimizers().empty() ||
1354 !rewrite_cfg.custom_optimizers().empty();
1355 }
1356
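// Convenience entry point: builds a MetaOptimizer for `cfg`, applies the
// rewriter deadline, and writes the optimized graph to `optimized_graph`.
// A minimal, hypothetical call site (all names below are placeholders, not
// part of this file) might look like:
//
//   tensorflow::grappler::GrapplerItem item;
//   item.id = "my_graph";
//   item.graph = graph_def;        // GraphDef to optimize
//   item.fetch = {"output"};       // fetch nodes, so pruning keeps them
//   GraphDef optimized;
//   TF_RETURN_IF_ERROR(tensorflow::grappler::RunMetaOptimizer(
//       std::move(item), config_proto, cpu_device, &virtual_cluster,
//       &optimized));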
1357 Status RunMetaOptimizer(GrapplerItem&& item, const ConfigProto& cfg,
1358 DeviceBase* cpu_device, Cluster* cluster,
1359 GraphDef* optimized_graph) {
1360 MetaOptimizer optimizer(cpu_device, cfg);
1361 optimizer.set_deadline_usec(
1362 DeadlineMicroSeconds(cfg.graph_options().rewrite_options()));
1363 return optimizer.OptimizeConsumeItem(cluster, std::move(item),
1364 optimized_graph);
1365 }
1366
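// Helper for optimizing an already-placed tensorflow::Graph: wraps the graph
// and function library in a GrapplerItem, runs the meta-optimizer against a
// VirtualCluster built from `device_set`, copies any optimized functions back
// into `flib`, and converts the result back into a Graph with assigned
// devices.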
1367 Status OptimizeGraph(
1368 std::vector<string> ret_node_names, std::vector<string> keep_node_names,
1369 FunctionLibraryDefinition* flib, const DeviceSet& device_set,
1370 Device* cpu_device, const ConfigProto& config_proto,
1371 const string& grappler_item_id,
1372 const GrapplerItem::OptimizationOptions& optimization_options,
1373 std::unique_ptr<tensorflow::Graph>* g) {
1374 if (!tensorflow::grappler::MetaOptimizerEnabled(config_proto)) {
1375 return OkStatus();
1376 }
1377
1378 tensorflow::grappler::GrapplerItem item;
1379 item.id = grappler_item_id;
1380 item.optimization_options() = optimization_options;
1381
1382 // Add all available devices so that inlined functions can be placed.
1383 for (const Device* d : device_set.devices()) {
1384 Status added_device = item.AddDevice(d->name());
1385 if (!added_device.ok()) VLOG(3) << added_device.error_message();
1386 }
1387 VLOG(3) << "Grappler available devices: "
1388 << absl::StrJoin(item.devices(), ", ");
1389
1390 // Add fetches so that the graph can be pruned.
1391 item.fetch.swap(ret_node_names);
1392
1393 // Add nodes that can't be removed from the graph.
1394 item.keep_ops = std::move(keep_node_names);
1395
1396 (*g)->ToGraphDef(&item.graph);
1397
1398 if (flib) {
1399 *item.graph.mutable_library() = flib->ToProto();
1400 }
1401
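// Optimize against a virtual cluster that only describes the available
// devices; nothing is executed while the meta-optimizer runs.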
1402 tensorflow::GraphDef out_graph;
1403 tensorflow::grappler::VirtualCluster cluster(&device_set);
1404 // TODO(nareshmodi): Consider adding and using the more generic GraphOptions
1405 // proto (which also contains the OptimizerOptions).
1406 TF_RETURN_IF_ERROR(tensorflow::grappler::RunMetaOptimizer(
1407 std::move(item), config_proto, cpu_device, &cluster, &out_graph));
1408
1409 auto optimized_graph =
1410 std::make_unique<tensorflow::Graph>(OpRegistry::Global());
1411
1412 // Copy optimized functions back to the overlay lib.
1413 if (flib) {
1414 for (const FunctionDef& fdef : out_graph.library().function()) {
1415 const string& func_name = fdef.signature().name();
1416 if (flib->Contains(func_name)) {
1417 StackTracesMap stack_traces = flib->GetStackTraces(func_name);
1418 TF_RETURN_IF_ERROR(
1419 flib->ReplaceFunction(func_name, fdef, stack_traces));
1420 } else {
1421 TF_RETURN_IF_ERROR(
1422 flib->AddFunctionDef(fdef, flib->GetStackTraces(func_name)));
1423 }
1424 }
1425 }
1426
1427 TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
1428 GraphConstructorOptions(), std::move(out_graph), optimized_graph.get()));
1429
1430 // The graph conversion sets the requested device names but not the
1431 // assigned device names. However, since at this point the graph is
1432 // placed, TF expects an assigned device name for every node. Therefore,
1433 // we copy the requested device into the assigned device field.
1434 for (Node* node : optimized_graph->nodes()) {
1435 if (node->IsOp() && node->assigned_device_name().empty()) {
1436 if (node->requested_device().empty()) {
1437 return errors::Internal(
1438 "Either placer did not place the node or Grappler did not "
1439 "copy the assigned device. Contact Grappler team since latter "
1440 "is more likely. Node=",
1441 node->name(),
1442 " Graph: ", optimized_graph->ToGraphDefDebug().DebugString());
1443 }
1444 node->set_assigned_device_name(node->requested_device());
1445 }
1446 }
1447
1448 *g = std::move(optimized_graph);
1449 return OkStatus();
1450 }
1451
1452 } // namespace grappler
1453 } // namespace tensorflow
1454