1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
16 
17 #include <string>
18 #include <unordered_map>
19 
20 #include "tensorflow/core/platform/logging.h"
21 
22 namespace tensorflow {
23 namespace grappler {
24 namespace {
25 
26 typedef std::unordered_map<string, CustomGraphOptimizerRegistry::Creator>
27     RegistrationMap;
28 RegistrationMap* registered_optimizers = nullptr;
GetRegistrationMap()29 RegistrationMap* GetRegistrationMap() {
30   if (registered_optimizers == nullptr)
31     registered_optimizers = new RegistrationMap;
32   return registered_optimizers;
33 }
34 
35 // This map is a global map for registered plugin optimizers. It contains the
36 // device_type as its key, and an optimizer creator as the value.
37 typedef std::unordered_map<string, PluginGraphOptimizerRegistry::Creator>
38     PluginRegistrationMap;
GetPluginRegistrationMap()39 PluginRegistrationMap* GetPluginRegistrationMap() {
40   static PluginRegistrationMap* registered_plugin_optimizers =
41       new PluginRegistrationMap;
42   return registered_plugin_optimizers;
43 }
44 
45 // This map is a global map for registered plugin configs. It contains the
46 // device_type as its key, and ConfigList as the value.
47 typedef std::unordered_map<string, ConfigList> PluginConfigMap;
GetPluginConfigMap()48 PluginConfigMap* GetPluginConfigMap() {
49   static PluginConfigMap* plugin_config_map = new PluginConfigMap;
50   return plugin_config_map;
51 }
52 
53 // Returns plugin's default configuration for each Grappler optimizer (on/off).
54 // See tensorflow/core/protobuf/rewriter_config.proto for more details about
55 // each optimizer.
DefaultPluginConfigs()56 const ConfigList& DefaultPluginConfigs() {
57   static ConfigList* default_plugin_configs = new ConfigList(
58       /*disable_model_pruning=*/false,
59       {{"implementation_selector", RewriterConfig::ON},
60        {"function_optimization", RewriterConfig::ON},
61        {"common_subgraph_elimination", RewriterConfig::ON},
62        {"arithmetic_optimization", RewriterConfig::ON},
63        {"debug_stripper", RewriterConfig::ON},
64        {"constant_folding", RewriterConfig::ON},
65        {"shape_optimization", RewriterConfig::ON},
66        {"auto_mixed_precision", RewriterConfig::ON},
67        {"auto_mixed_precision_onednn_bfloat16", RewriterConfig::ON},
68        {"auto_mixed_precision_mkl", RewriterConfig::ON},
69        {"auto_mixed_precision_cpu", RewriterConfig::ON},
70        {"pin_to_host_optimization", RewriterConfig::ON},
71        {"layout_optimizer", RewriterConfig::ON},
72        {"remapping", RewriterConfig::ON},
73        {"loop_optimization", RewriterConfig::ON},
74        {"dependency_optimization", RewriterConfig::ON},
75        {"auto_parallel", RewriterConfig::ON},
76        {"memory_optimization", RewriterConfig::ON},
77        {"scoped_allocator_optimization", RewriterConfig::ON}});
78   return *default_plugin_configs;
79 }
80 
81 }  // namespace
82 
83 std::unique_ptr<CustomGraphOptimizer>
CreateByNameOrNull(const string & name)84 CustomGraphOptimizerRegistry::CreateByNameOrNull(const string& name) {
85   const auto it = GetRegistrationMap()->find(name);
86   if (it == GetRegistrationMap()->end()) return nullptr;
87   return std::unique_ptr<CustomGraphOptimizer>(it->second());
88 }
89 
GetRegisteredOptimizers()90 std::vector<string> CustomGraphOptimizerRegistry::GetRegisteredOptimizers() {
91   std::vector<string> optimizer_names;
92   optimizer_names.reserve(GetRegistrationMap()->size());
93   for (const auto& opt : *GetRegistrationMap())
94     optimizer_names.emplace_back(opt.first);
95   return optimizer_names;
96 }
97 
RegisterOptimizerOrDie(const Creator & optimizer_creator,const string & name)98 void CustomGraphOptimizerRegistry::RegisterOptimizerOrDie(
99     const Creator& optimizer_creator, const string& name) {
100   const auto it = GetRegistrationMap()->find(name);
101   if (it != GetRegistrationMap()->end()) {
102     LOG(FATAL) << "CustomGraphOptimizer is registered twice: " << name;
103   }
104   GetRegistrationMap()->insert({name, optimizer_creator});
105 }
106 
107 std::vector<std::unique_ptr<CustomGraphOptimizer>>
CreateOptimizers(const std::set<string> & device_types)108 PluginGraphOptimizerRegistry::CreateOptimizers(
109     const std::set<string>& device_types) {
110   std::vector<std::unique_ptr<CustomGraphOptimizer>> optimizer_list;
111   for (auto it = GetPluginRegistrationMap()->begin();
112        it != GetPluginRegistrationMap()->end(); ++it) {
113     if (device_types.find(it->first) == device_types.end()) continue;
114     LOG(INFO) << "Plugin optimizer for device_type " << it->first
115               << " is enabled.";
116     optimizer_list.emplace_back(
117         std::unique_ptr<CustomGraphOptimizer>(it->second()));
118   }
119   return optimizer_list;
120 }
121 
RegisterPluginOptimizerOrDie(const Creator & optimizer_creator,const std::string & device_type,ConfigList & configs)122 void PluginGraphOptimizerRegistry::RegisterPluginOptimizerOrDie(
123     const Creator& optimizer_creator, const std::string& device_type,
124     ConfigList& configs) {
125   auto ret = GetPluginConfigMap()->insert({device_type, configs});
126   if (!ret.second) {
127     LOG(FATAL) << "PluginGraphOptimizer with device_type "  // Crash OK
128                << device_type << " is registered twice.";
129   }
130   GetPluginRegistrationMap()->insert({device_type, optimizer_creator});
131 }
132 
PrintPluginConfigsIfConflict(const std::set<string> & device_types)133 void PluginGraphOptimizerRegistry::PrintPluginConfigsIfConflict(
134     const std::set<string>& device_types) {
135   bool init = false, conflict = false;
136   ConfigList plugin_configs;
137   // Check if plugin's configs have conflict.
138   for (const auto& device_type : device_types) {
139     const auto it = GetPluginConfigMap()->find(device_type);
140     if (it == GetPluginConfigMap()->end()) continue;
141     auto cur_plugin_configs = it->second;
142 
143     if (!init) {
144       plugin_configs = cur_plugin_configs;
145       init = true;
146     } else {
147       if (!(plugin_configs == cur_plugin_configs)) {
148         conflict = true;
149         break;
150       }
151     }
152   }
153   if (!conflict) return;
154   LOG(WARNING) << "Plugins have conflicting configs. Potential performance "
155                   "regression may happen.";
156   for (const auto& device_type : device_types) {
157     const auto it = GetPluginConfigMap()->find(device_type);
158     if (it == GetPluginConfigMap()->end()) continue;
159     auto cur_plugin_configs = it->second;
160 
161     // Print logs in following style:
162     // disable_model_pruning    0
163     // remapping                1
164     // ...
165     string logs = "";
166     strings::StrAppend(&logs, "disable_model_pruning\t\t",
167                        cur_plugin_configs.disable_model_pruning, "\n");
168     for (auto const& pair : cur_plugin_configs.toggle_config) {
169       strings::StrAppend(&logs, pair.first, string(32 - pair.first.size(), ' '),
170                          (pair.second != RewriterConfig::OFF), "\n");
171     }
172     LOG(WARNING) << "Plugin's configs for device_type " << device_type << ":\n"
173                  << logs;
174   }
175 }
176 
GetPluginConfigs(bool use_plugin_optimizers,const std::set<string> & device_types)177 ConfigList PluginGraphOptimizerRegistry::GetPluginConfigs(
178     bool use_plugin_optimizers, const std::set<string>& device_types) {
179   if (!use_plugin_optimizers) return DefaultPluginConfigs();
180 
181   ConfigList ret_plugin_configs = DefaultPluginConfigs();
182   for (const auto& device_type : device_types) {
183     const auto it = GetPluginConfigMap()->find(device_type);
184     if (it == GetPluginConfigMap()->end()) continue;
185     auto cur_plugin_configs = it->second;
186     // If any of the plugin turns on `disable_model_pruning`,
187     // then `disable_model_pruning` should be true;
188     if (cur_plugin_configs.disable_model_pruning == true)
189       ret_plugin_configs.disable_model_pruning = true;
190 
191     // If any of the plugin turns off a certain optimizer,
192     // then the optimizer should be turned off;
193     for (auto& pair : cur_plugin_configs.toggle_config) {
194       if (cur_plugin_configs.toggle_config[pair.first] == RewriterConfig::OFF)
195         ret_plugin_configs.toggle_config[pair.first] = RewriterConfig::OFF;
196     }
197   }
198 
199   return ret_plugin_configs;
200 }
201 
IsConfigsConflict(ConfigList & user_config,ConfigList & plugin_config)202 bool PluginGraphOptimizerRegistry::IsConfigsConflict(
203     ConfigList& user_config, ConfigList& plugin_config) {
204   if (plugin_config == DefaultPluginConfigs()) return false;
205   if (user_config.disable_model_pruning != plugin_config.disable_model_pruning)
206     return true;
207   // Returns true if user_config is turned on but plugin_config is turned off.
208   for (auto& pair : user_config.toggle_config) {
209     if ((user_config.toggle_config[pair.first] == RewriterConfig::ON) &&
210         (plugin_config.toggle_config[pair.first] == RewriterConfig::OFF))
211       return true;
212   }
213   return false;
214 }
215 
216 }  // end namespace grappler
217 }  // end namespace tensorflow
218