1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
16
17 #include <string>
18 #include <unordered_map>
19
20 #include "tensorflow/core/platform/logging.h"
21
22 namespace tensorflow {
23 namespace grappler {
24 namespace {
25
26 typedef std::unordered_map<string, CustomGraphOptimizerRegistry::Creator>
27 RegistrationMap;
28 RegistrationMap* registered_optimizers = nullptr;
GetRegistrationMap()29 RegistrationMap* GetRegistrationMap() {
30 if (registered_optimizers == nullptr)
31 registered_optimizers = new RegistrationMap;
32 return registered_optimizers;
33 }
34
35 // This map is a global map for registered plugin optimizers. It contains the
36 // device_type as its key, and an optimizer creator as the value.
37 typedef std::unordered_map<string, PluginGraphOptimizerRegistry::Creator>
38 PluginRegistrationMap;
GetPluginRegistrationMap()39 PluginRegistrationMap* GetPluginRegistrationMap() {
40 static PluginRegistrationMap* registered_plugin_optimizers =
41 new PluginRegistrationMap;
42 return registered_plugin_optimizers;
43 }
44
45 // This map is a global map for registered plugin configs. It contains the
46 // device_type as its key, and ConfigList as the value.
47 typedef std::unordered_map<string, ConfigList> PluginConfigMap;
GetPluginConfigMap()48 PluginConfigMap* GetPluginConfigMap() {
49 static PluginConfigMap* plugin_config_map = new PluginConfigMap;
50 return plugin_config_map;
51 }
52
53 // Returns plugin's default configuration for each Grappler optimizer (on/off).
54 // See tensorflow/core/protobuf/rewriter_config.proto for more details about
55 // each optimizer.
DefaultPluginConfigs()56 const ConfigList& DefaultPluginConfigs() {
57 static ConfigList* default_plugin_configs = new ConfigList(
58 /*disable_model_pruning=*/false,
59 {{"implementation_selector", RewriterConfig::ON},
60 {"function_optimization", RewriterConfig::ON},
61 {"common_subgraph_elimination", RewriterConfig::ON},
62 {"arithmetic_optimization", RewriterConfig::ON},
63 {"debug_stripper", RewriterConfig::ON},
64 {"constant_folding", RewriterConfig::ON},
65 {"shape_optimization", RewriterConfig::ON},
66 {"auto_mixed_precision", RewriterConfig::ON},
67 {"auto_mixed_precision_onednn_bfloat16", RewriterConfig::ON},
68 {"auto_mixed_precision_mkl", RewriterConfig::ON},
69 {"auto_mixed_precision_cpu", RewriterConfig::ON},
70 {"pin_to_host_optimization", RewriterConfig::ON},
71 {"layout_optimizer", RewriterConfig::ON},
72 {"remapping", RewriterConfig::ON},
73 {"loop_optimization", RewriterConfig::ON},
74 {"dependency_optimization", RewriterConfig::ON},
75 {"auto_parallel", RewriterConfig::ON},
76 {"memory_optimization", RewriterConfig::ON},
77 {"scoped_allocator_optimization", RewriterConfig::ON}});
78 return *default_plugin_configs;
79 }
80
81 } // namespace
82
83 std::unique_ptr<CustomGraphOptimizer>
CreateByNameOrNull(const string & name)84 CustomGraphOptimizerRegistry::CreateByNameOrNull(const string& name) {
85 const auto it = GetRegistrationMap()->find(name);
86 if (it == GetRegistrationMap()->end()) return nullptr;
87 return std::unique_ptr<CustomGraphOptimizer>(it->second());
88 }
89
GetRegisteredOptimizers()90 std::vector<string> CustomGraphOptimizerRegistry::GetRegisteredOptimizers() {
91 std::vector<string> optimizer_names;
92 optimizer_names.reserve(GetRegistrationMap()->size());
93 for (const auto& opt : *GetRegistrationMap())
94 optimizer_names.emplace_back(opt.first);
95 return optimizer_names;
96 }
97
RegisterOptimizerOrDie(const Creator & optimizer_creator,const string & name)98 void CustomGraphOptimizerRegistry::RegisterOptimizerOrDie(
99 const Creator& optimizer_creator, const string& name) {
100 const auto it = GetRegistrationMap()->find(name);
101 if (it != GetRegistrationMap()->end()) {
102 LOG(FATAL) << "CustomGraphOptimizer is registered twice: " << name;
103 }
104 GetRegistrationMap()->insert({name, optimizer_creator});
105 }
106
107 std::vector<std::unique_ptr<CustomGraphOptimizer>>
CreateOptimizers(const std::set<string> & device_types)108 PluginGraphOptimizerRegistry::CreateOptimizers(
109 const std::set<string>& device_types) {
110 std::vector<std::unique_ptr<CustomGraphOptimizer>> optimizer_list;
111 for (auto it = GetPluginRegistrationMap()->begin();
112 it != GetPluginRegistrationMap()->end(); ++it) {
113 if (device_types.find(it->first) == device_types.end()) continue;
114 LOG(INFO) << "Plugin optimizer for device_type " << it->first
115 << " is enabled.";
116 optimizer_list.emplace_back(
117 std::unique_ptr<CustomGraphOptimizer>(it->second()));
118 }
119 return optimizer_list;
120 }
121
RegisterPluginOptimizerOrDie(const Creator & optimizer_creator,const std::string & device_type,ConfigList & configs)122 void PluginGraphOptimizerRegistry::RegisterPluginOptimizerOrDie(
123 const Creator& optimizer_creator, const std::string& device_type,
124 ConfigList& configs) {
125 auto ret = GetPluginConfigMap()->insert({device_type, configs});
126 if (!ret.second) {
127 LOG(FATAL) << "PluginGraphOptimizer with device_type " // Crash OK
128 << device_type << " is registered twice.";
129 }
130 GetPluginRegistrationMap()->insert({device_type, optimizer_creator});
131 }
132
PrintPluginConfigsIfConflict(const std::set<string> & device_types)133 void PluginGraphOptimizerRegistry::PrintPluginConfigsIfConflict(
134 const std::set<string>& device_types) {
135 bool init = false, conflict = false;
136 ConfigList plugin_configs;
137 // Check if plugin's configs have conflict.
138 for (const auto& device_type : device_types) {
139 const auto it = GetPluginConfigMap()->find(device_type);
140 if (it == GetPluginConfigMap()->end()) continue;
141 auto cur_plugin_configs = it->second;
142
143 if (!init) {
144 plugin_configs = cur_plugin_configs;
145 init = true;
146 } else {
147 if (!(plugin_configs == cur_plugin_configs)) {
148 conflict = true;
149 break;
150 }
151 }
152 }
153 if (!conflict) return;
154 LOG(WARNING) << "Plugins have conflicting configs. Potential performance "
155 "regression may happen.";
156 for (const auto& device_type : device_types) {
157 const auto it = GetPluginConfigMap()->find(device_type);
158 if (it == GetPluginConfigMap()->end()) continue;
159 auto cur_plugin_configs = it->second;
160
161 // Print logs in following style:
162 // disable_model_pruning 0
163 // remapping 1
164 // ...
165 string logs = "";
166 strings::StrAppend(&logs, "disable_model_pruning\t\t",
167 cur_plugin_configs.disable_model_pruning, "\n");
168 for (auto const& pair : cur_plugin_configs.toggle_config) {
169 strings::StrAppend(&logs, pair.first, string(32 - pair.first.size(), ' '),
170 (pair.second != RewriterConfig::OFF), "\n");
171 }
172 LOG(WARNING) << "Plugin's configs for device_type " << device_type << ":\n"
173 << logs;
174 }
175 }
176
GetPluginConfigs(bool use_plugin_optimizers,const std::set<string> & device_types)177 ConfigList PluginGraphOptimizerRegistry::GetPluginConfigs(
178 bool use_plugin_optimizers, const std::set<string>& device_types) {
179 if (!use_plugin_optimizers) return DefaultPluginConfigs();
180
181 ConfigList ret_plugin_configs = DefaultPluginConfigs();
182 for (const auto& device_type : device_types) {
183 const auto it = GetPluginConfigMap()->find(device_type);
184 if (it == GetPluginConfigMap()->end()) continue;
185 auto cur_plugin_configs = it->second;
186 // If any of the plugin turns on `disable_model_pruning`,
187 // then `disable_model_pruning` should be true;
188 if (cur_plugin_configs.disable_model_pruning == true)
189 ret_plugin_configs.disable_model_pruning = true;
190
191 // If any of the plugin turns off a certain optimizer,
192 // then the optimizer should be turned off;
193 for (auto& pair : cur_plugin_configs.toggle_config) {
194 if (cur_plugin_configs.toggle_config[pair.first] == RewriterConfig::OFF)
195 ret_plugin_configs.toggle_config[pair.first] = RewriterConfig::OFF;
196 }
197 }
198
199 return ret_plugin_configs;
200 }
201
IsConfigsConflict(ConfigList & user_config,ConfigList & plugin_config)202 bool PluginGraphOptimizerRegistry::IsConfigsConflict(
203 ConfigList& user_config, ConfigList& plugin_config) {
204 if (plugin_config == DefaultPluginConfigs()) return false;
205 if (user_config.disable_model_pruning != plugin_config.disable_model_pruning)
206 return true;
207 // Returns true if user_config is turned on but plugin_config is turned off.
208 for (auto& pair : user_config.toggle_config) {
209 if ((user_config.toggle_config[pair.first] == RewriterConfig::ON) &&
210 (plugin_config.toggle_config[pair.first] == RewriterConfig::OFF))
211 return true;
212 }
213 return false;
214 }
215
216 } // end namespace grappler
217 } // end namespace tensorflow
218