/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file extends/implements core graph optimizer base classes in terms of
// the C API defined in grappler.h. A class "CSomething" represents a
// "Something" that can be manipulated via calls in the C interface and a C
// struct called "TP_Something".
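// For example, CGraphOptimizer (declared in grappler_internal.h) is the C++
// graph optimizer that forwards Optimize() calls to a plugin-supplied
// TP_Optimizer through its optimize_func callback.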

#include "tensorflow/c/experimental/grappler/grappler.h"

#include <memory>
#include <unordered_map>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/experimental/grappler/grappler_internal.h"
#include "tensorflow/c/tf_buffer_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"

namespace {

#define VALIDATE_STRUCT_SIZE(STRUCT_NAME, STRUCT_OBJ, SIZE_VALUE_NAME)    \
  do {                                                                    \
    if (STRUCT_OBJ.struct_size == 0) {                                    \
      return tensorflow::Status(tensorflow::error::FAILED_PRECONDITION,   \
                                "struct_size field in " #STRUCT_NAME      \
                                " must be set to " #SIZE_VALUE_NAME "."); \
    }                                                                     \
  } while (0)

#define VALIDATE_MEMBER(STRUCT_NAME, STRUCT_OBJ, NAME)                  \
  do {                                                                  \
    if (STRUCT_OBJ.NAME == 0) {                                         \
      return tensorflow::Status(tensorflow::error::FAILED_PRECONDITION, \
                                "'" #NAME "' field in " #STRUCT_NAME    \
                                " must be set.");                       \
    }                                                                   \
  } while (0)

tensorflow::Status ValidateTPOptimizerRegistrationParams(
    const TP_OptimizerRegistrationParams& params) {
  VALIDATE_STRUCT_SIZE(TP_OptimizerRegistrationParams, params,
                       TP_OPTIMIZER_REGISTRATION_PARAMS_STRUCT_SIZE);
  VALIDATE_MEMBER(TP_OptimizerRegistrationParams, params, device_type);
  return ::tensorflow::OkStatus();
}

tensorflow::Status ValidateTPOptimizer(const TP_Optimizer& optimizer) {
  VALIDATE_STRUCT_SIZE(TP_Optimizer, optimizer, TP_OPTIMIZER_STRUCT_SIZE);
  VALIDATE_MEMBER(TP_Optimizer, optimizer, optimize_func);
  return ::tensorflow::OkStatus();
}

tensorflow::Status ValidateTPOptimizerConfigs(
    const TP_OptimizerConfigs& configs) {
  VALIDATE_STRUCT_SIZE(TP_OptimizerConfigs, configs,
                       TP_OPTIMIZER_CONFIGS_STRUCT_SIZE);
  return ::tensorflow::OkStatus();
}

#undef VALIDATE_MEMBER
#undef VALIDATE_STRUCT_SIZE
}  // namespace

namespace tensorflow {
namespace grappler {

Status CGraphOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                                 GraphDef* optimized_graph_def) {
  OwnedTFStatus c_status(TF_NewStatus());
  OwnedTFBuffer graph_buf(TF_NewBuffer());
  OwnedTFBuffer optimized_graph_buf(TF_NewBuffer());
  TF_RETURN_IF_ERROR(MessageToBuffer(item.graph, graph_buf.get()));

  optimizer_.optimize_func(c_optimizer_, graph_buf.get(),
                           reinterpret_cast<const TF_GrapplerItem*>(&item),
                           optimized_graph_buf.get(), c_status.get());
  TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
  TF_RETURN_IF_ERROR(
      BufferToMessage(optimized_graph_buf.get(), optimized_graph_def));

  return OkStatus();
}
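
// A minimal sketch of a plugin-side `optimize_func` matching the call above
// (the name `PluginOptimizeFunc` is illustrative, not part of this API): it
// receives the serialized input GraphDef in `graph_buf` and must place a
// serialized GraphDef into `optimized_graph_buf`.
//
//   void PluginOptimizeFunc(void* optimizer, const TF_Buffer* graph_buf,
//                           const TF_GrapplerItem* item,
//                           TF_Buffer* optimized_graph_buf,
//                           TF_Status* status) {
//     // Pass-through optimizer: copy the serialized GraphDef unchanged.
//     void* copy = malloc(graph_buf->length);
//     memcpy(copy, graph_buf->data, graph_buf->length);
//     optimized_graph_buf->data = copy;
//     optimized_graph_buf->length = graph_buf->length;
//     optimized_graph_buf->data_deallocator = [](void* data, size_t) {
//       free(data);
//     };
//     TF_SetStatus(status, TF_OK, "");
//   }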

#define CONFIG_TOGGLE(optimizer)                              \
  if (tp_configs.optimizer == TF_TriState_Off)                \
    configs.toggle_config[#optimizer] = RewriterConfig::OFF;  \
  else                                                        \
    configs.toggle_config[#optimizer] = RewriterConfig::ON;

void CGraphOptimizerRegister(
    const PluginGraphOptimizerRegistry::Creator& creator,
    const TP_OptimizerConfigs tp_configs, const char* device_type) {
  ConfigList configs;
  // disable_model_pruning is turned off by default.
  if (tp_configs.disable_model_pruning == TF_TriState_On)
    configs.disable_model_pruning = true;
  else
    configs.disable_model_pruning = false;
  // The other configs are turned on by default.
  CONFIG_TOGGLE(implementation_selector);
  CONFIG_TOGGLE(function_optimization);
  CONFIG_TOGGLE(common_subgraph_elimination);
  CONFIG_TOGGLE(arithmetic_optimization);
  CONFIG_TOGGLE(debug_stripper);
  CONFIG_TOGGLE(constant_folding);
  CONFIG_TOGGLE(shape_optimization);
  CONFIG_TOGGLE(auto_mixed_precision);
  CONFIG_TOGGLE(auto_mixed_precision_onednn_bfloat16);
  CONFIG_TOGGLE(auto_mixed_precision_mkl);
  CONFIG_TOGGLE(pin_to_host_optimization);
  CONFIG_TOGGLE(layout_optimizer);
  CONFIG_TOGGLE(remapping);
  CONFIG_TOGGLE(loop_optimization);
  CONFIG_TOGGLE(dependency_optimization);
  CONFIG_TOGGLE(auto_parallel);
  CONFIG_TOGGLE(memory_optimization);
  CONFIG_TOGGLE(scoped_allocator_optimization);
  PluginGraphOptimizerRegistry::RegisterPluginOptimizerOrDie(
      creator, device_type, configs);
}

#undef CONFIG_TOGGLE
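
// On the plugin side, these tri-state preferences come from the
// TP_OptimizerConfigs filled in during TF_InitGraph. For instance, a plugin
// that wants the remapper disabled for its device might set (sketch):
//
//   params->optimizer_configs->remapping = TF_TriState_Off;
//   params->optimizer_configs->auto_mixed_precision = TF_TriState_On;
//
// Any value other than TF_TriState_Off leaves the corresponding pass enabled,
// matching the CONFIG_TOGGLE mapping above.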

tensorflow::Status InitGraphPlugin(void* dso_handle) {
  tensorflow::Env* env = tensorflow::Env::Default();

  // Step 1: Load symbol for `TF_InitGraph`
  void* dso_symbol;
  TF_RETURN_IF_ERROR(
      env->GetSymbolFromLibrary(dso_handle, "TF_InitGraph", &dso_symbol));

  // Step 2: Call `TF_InitGraph`
  auto init_fn = reinterpret_cast<TFInitGraphPluginFn>(dso_symbol);
  return InitGraphPlugin(init_fn);
}
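
// A sketch of how this overload is typically driven (error handling elided;
// the library name is illustrative and Env::LoadDynamicLibrary is assumed to
// be available for opening the plugin shared object):
//
//   void* dso_handle = nullptr;
//   TF_RETURN_IF_ERROR(tensorflow::Env::Default()->LoadDynamicLibrary(
//       "libplugin_graph_optimizer.so", &dso_handle));
//   TF_RETURN_IF_ERROR(tensorflow::grappler::InitGraphPlugin(dso_handle));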

tensorflow::Status InitGraphPlugin(TFInitGraphPluginFn init_fn) {
  TP_OptimizerRegistrationParams params{
      TP_OPTIMIZER_REGISTRATION_PARAMS_STRUCT_SIZE};
  TP_Optimizer optimizer{TP_OPTIMIZER_STRUCT_SIZE};
  TP_OptimizerConfigs optimizer_configs{TP_OPTIMIZER_CONFIGS_STRUCT_SIZE};
  params.major_version = GO_MAJOR;
  params.minor_version = GO_MINOR;
  params.patch_version = GO_PATCH;
  params.optimizer = &optimizer;
  params.optimizer_configs = &optimizer_configs;

  OwnedTFStatus c_status(TF_NewStatus());
  init_fn(&params, c_status.get());
  TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
  TF_RETURN_IF_ERROR(ValidateTPOptimizerRegistrationParams(params));
  TF_RETURN_IF_ERROR(ValidateTPOptimizer(optimizer));
  TF_RETURN_IF_ERROR(ValidateTPOptimizerConfigs(optimizer_configs));

  CGraphOptimizerRegister(
      [=]() { return new CGraphOptimizer(optimizer, params.device_type); },
      optimizer_configs, params.device_type);

  return OkStatus();
}
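
// On the plugin side, the exported `TF_InitGraph` only needs to fill in the
// fields validated above; struct sizes and version numbers are populated by
// this function before `init_fn` is invoked. A minimal sketch
// (`PluginOptimizeFunc` and "MY_DEVICE" are illustrative):
//
//   void TF_InitGraph(TP_OptimizerRegistrationParams* params,
//                     TF_Status* status) {
//     params->device_type = "MY_DEVICE";
//     params->optimizer->optimize_func = PluginOptimizeFunc;
//     // Optionally adjust params->optimizer_configs tri-state fields here.
//     TF_SetStatus(status, TF_OK, "");
//   }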

}  // namespace grappler
}  // namespace tensorflow

void TF_GetNodesToPreserveListSize(const TF_GrapplerItem* item, int* num_values,
                                   size_t* storage_size, TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
  const std::unordered_set<std::string>& nodes =
      reinterpret_cast<const tensorflow::grappler::GrapplerItem*>(item)
          ->NodesToPreserve();
  *num_values = nodes.size();
  *storage_size = 0;
  for (const std::string& str : nodes) {
    *storage_size += str.size();
  }
}

void TF_GetNodesToPreserveList(const TF_GrapplerItem* item, char** values,
                               size_t* lengths, int num_values, void* storage,
                               size_t storage_size, TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
  const std::unordered_set<std::string>& nodes =
      reinterpret_cast<const tensorflow::grappler::GrapplerItem*>(item)
          ->NodesToPreserve();
  char* p = static_cast<char*>(storage);

  int index = 0;
  for (const std::string& s : nodes) {
    if (index >= num_values) break;
    values[index] = p;
    lengths[index] = s.size();
    if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) {
      status->status = tensorflow::errors::InvalidArgument(
          "Not enough storage to hold the requested list of nodes");
      return;
    }
    memcpy(values[index], s.data(), s.size());
    p += s.size();
    index++;
  }
}
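
// These two calls are meant to be used together: query the counts, allocate,
// then fetch. A sketch (error handling elided):
//
//   int num_values = 0;
//   size_t storage_size = 0;
//   TF_GetNodesToPreserveListSize(item, &num_values, &storage_size, status);
//   std::vector<char*> values(num_values);
//   std::vector<size_t> lengths(num_values);
//   std::vector<char> storage(storage_size);
//   TF_GetNodesToPreserveList(item, values.data(), lengths.data(), num_values,
//                             storage.data(), storage.size(), status);
//
// TF_GetFetchNodesListSize/TF_GetFetchNodesList below follow the same pattern.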

void TF_GetFetchNodesListSize(const TF_GrapplerItem* item, int* num_values,
                              size_t* storage_size, TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
  const std::vector<std::string>& nodes =
      reinterpret_cast<const tensorflow::grappler::GrapplerItem*>(item)->fetch;
  *num_values = nodes.size();
  *storage_size = 0;
  for (const std::string& str : nodes) {
    *storage_size += str.size();
  }
}

void TF_GetFetchNodesList(const TF_GrapplerItem* item, char** values,
                          size_t* lengths, int num_values, void* storage,
                          size_t storage_size, TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
  const std::vector<std::string>& nodes =
      reinterpret_cast<const tensorflow::grappler::GrapplerItem*>(item)->fetch;

  const int len = std::min(num_values, static_cast<int>(nodes.size()));
  char* p = static_cast<char*>(storage);
  for (int index = 0; index < len; ++index) {
    const std::string& s = nodes[index];
    values[index] = p;
    lengths[index] = s.size();
    if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) {
      status->status = tensorflow::errors::InvalidArgument(
          "Not enough storage to hold the requested list of nodes");
      return;
    }
    memcpy(values[index], s.data(), s.size());
    p += s.size();
  }
}

TF_GraphProperties* TF_NewGraphProperties(const TF_GrapplerItem* item) {
  return reinterpret_cast<TF_GraphProperties*>(
      new tensorflow::grappler::GraphProperties(
          *reinterpret_cast<const tensorflow::grappler::GrapplerItem*>(item)));
}

void TF_DeleteGraphProperties(TF_GraphProperties* graph_properties) {
  if (graph_properties == nullptr) return;
  delete reinterpret_cast<tensorflow::grappler::GraphProperties*>(
      graph_properties);
}

void TF_InferStatically(TF_GraphProperties* graph_properties,
                        TF_Bool assume_valid_feeds,
                        TF_Bool aggressive_shape_inference,
                        TF_Bool include_input_tensor_values,
                        TF_Bool include_output_tensor_values,
                        TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
  tensorflow::Status s =
      reinterpret_cast<tensorflow::grappler::GraphProperties*>(graph_properties)
          ->InferStatically(assume_valid_feeds, aggressive_shape_inference,
                            include_input_tensor_values,
                            include_output_tensor_values);
  if (!s.ok()) {
    ::tensorflow::Set_TF_Status_from_Status(status, s);
  }
}

void TF_GetInputPropertiesListSize(TF_GraphProperties* graph_properties,
                                   const char* name, int* num_values,
                                   TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
  *num_values =
      reinterpret_cast<tensorflow::grappler::GraphProperties*>(graph_properties)
          ->GetInputProperties(name)
          .size();
}

void TF_GetOutputPropertiesListSize(TF_GraphProperties* graph_properties,
                                    const char* name, int* num_values,
                                    TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
  *num_values =
      reinterpret_cast<tensorflow::grappler::GraphProperties*>(graph_properties)
          ->GetOutputProperties(name)
          .size();
}

void TF_GetInputPropertiesList(TF_GraphProperties* graph_properties,
                               const char* name, TF_Buffer** properties,
                               int num_values, TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
  const std::vector<tensorflow::OpInfo::TensorProperties>& tensor_properties =
      reinterpret_cast<tensorflow::grappler::GraphProperties*>(graph_properties)
          ->GetInputProperties(name);
  const int len =
      std::min(num_values, static_cast<int>(tensor_properties.size()));
  for (int i = 0; i < len; ++i) {
    tensorflow::Status s =
        tensorflow::MessageToBuffer(tensor_properties[i], properties[i]);
    if (!s.ok()) {
      ::tensorflow::Set_TF_Status_from_Status(status, s);
      return;
    }
  }
}

void TF_GetOutputPropertiesList(TF_GraphProperties* graph_properties,
                                const char* name, TF_Buffer** properties,
                                int num_values, TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
  const std::vector<tensorflow::OpInfo::TensorProperties>& tensor_properties =
      reinterpret_cast<tensorflow::grappler::GraphProperties*>(graph_properties)
          ->GetOutputProperties(name);
  const int len =
      std::min(num_values, static_cast<int>(tensor_properties.size()));
  for (int i = 0; i < len; ++i) {
    tensorflow::Status s =
        tensorflow::MessageToBuffer(tensor_properties[i], properties[i]);
    if (!s.ok()) {
      ::tensorflow::Set_TF_Status_from_Status(status, s);
      return;
    }
  }
}
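
// Typical usage of the TF_GraphProperties API from a plugin (sketch; the node
// name "MatMul" is illustrative and error handling is elided): infer shapes
// once, then read per-node properties as serialized OpInfo::TensorProperties.
//
//   TF_GraphProperties* props = TF_NewGraphProperties(item);
//   TF_InferStatically(props, /*assume_valid_feeds=*/1,
//                      /*aggressive_shape_inference=*/0,
//                      /*include_input_tensor_values=*/0,
//                      /*include_output_tensor_values=*/0, status);
//   int num_outputs = 0;
//   TF_GetOutputPropertiesListSize(props, "MatMul", &num_outputs, status);
//   std::vector<TF_Buffer*> out_props(num_outputs);
//   for (int i = 0; i < num_outputs; ++i) out_props[i] = TF_NewBuffer();
//   TF_GetOutputPropertiesList(props, "MatMul", out_props.data(), num_outputs,
//                              status);
//   // ... deserialize each buffer into OpInfo::TensorProperties ...
//   for (TF_Buffer* b : out_props) TF_DeleteBuffer(b);
//   TF_DeleteGraphProperties(props);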

TF_FunctionLibraryDefinition* TF_NewFunctionLibraryDefinition(
    const TF_Buffer* graph_buf, TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
  tensorflow::GraphDef graph_def;
  tensorflow::Status s = tensorflow::BufferToMessage(graph_buf, &graph_def);
  if (!s.ok()) {
    ::tensorflow::Set_TF_Status_from_Status(status, s);
    return nullptr;
  }
  return reinterpret_cast<TF_FunctionLibraryDefinition*>(
      new tensorflow::FunctionLibraryDefinition(
          tensorflow::OpRegistry::Global(), graph_def.library()));
}

void TF_DeleteFunctionLibraryDefinition(TF_FunctionLibraryDefinition* fn_lib) {
  if (fn_lib == nullptr) return;
  delete reinterpret_cast<tensorflow::FunctionLibraryDefinition*>(fn_lib);
}

void TF_LookUpOpDef(TF_FunctionLibraryDefinition* fn_lib, const char* name,
                    TF_Buffer* buf, TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
  const tensorflow::OpDef* op_def_ptr = nullptr;
  tensorflow::Status s =
      reinterpret_cast<tensorflow::FunctionLibraryDefinition*>(fn_lib)
          ->LookUpOpDef(name, &op_def_ptr);
  if (!s.ok()) {
    ::tensorflow::Set_TF_Status_from_Status(status, s);
    return;
  }

  s = tensorflow::MessageToBuffer(*op_def_ptr, buf);
  if (!s.ok()) {
    ::tensorflow::Set_TF_Status_from_Status(status, s);
    return;
  }
}
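
// Sketch of resolving an OpDef through the function library (the op name
// "Conv2D" is illustrative; error handling elided):
//
//   TF_FunctionLibraryDefinition* fn_lib =
//       TF_NewFunctionLibraryDefinition(graph_buf, status);
//   TF_Buffer* op_def_buf = TF_NewBuffer();
//   TF_LookUpOpDef(fn_lib, "Conv2D", op_def_buf, status);
//   // op_def_buf now holds a serialized tensorflow::OpDef proto.
//   TF_DeleteBuffer(op_def_buf);
//   TF_DeleteFunctionLibraryDefinition(fn_lib);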