xref: /aosp_15_r20/external/tensorflow/tensorflow/c/experimental/grappler/grappler.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 // This file extends/implements core graph optimizer base classes in terms of
16 // the C API defined in grappler.h. A class "CSomething" represents a
17 // "Something" that can be manipulated via calls in the C interface and a C
18 // struct called "TP_Something".
19 
20 #include "tensorflow/c/experimental/grappler/grappler.h"
21 
22 #include <memory>
23 #include <unordered_map>
24 #include <vector>
25 
26 #include "absl/container/flat_hash_map.h"
27 #include "tensorflow/c/c_api_internal.h"
28 #include "tensorflow/c/experimental/grappler/grappler_internal.h"
29 #include "tensorflow/c/tf_buffer_internal.h"
30 #include "tensorflow/c/tf_status_helper.h"
31 #include "tensorflow/core/framework/function.h"
32 #include "tensorflow/core/grappler/costs/graph_properties.h"
33 #include "tensorflow/core/grappler/grappler_item.h"
34 #include "tensorflow/core/platform/errors.h"
35 #include "tensorflow/core/platform/logging.h"
36 #include "tensorflow/core/platform/status.h"
37 
38 namespace {
39 
// Returns FAILED_PRECONDITION from the enclosing function when
// `STRUCT_OBJ.struct_size` is 0. The message names the struct type and the
// size constant that should have been stored in `struct_size`; both names are
// stringized, not expanded. The object argument is parenthesized so that an
// arbitrary expression can be passed safely.
#define VALIDATE_STRUCT_SIZE(STRUCT_NAME, STRUCT_OBJ, SIZE_VALUE_NAME)    \
  do {                                                                    \
    if ((STRUCT_OBJ).struct_size == 0) {                                  \
      return tensorflow::Status(tensorflow::error::FAILED_PRECONDITION,   \
                                "struct_size field in " #STRUCT_NAME      \
                                " must be set to " #SIZE_VALUE_NAME "."); \
    }                                                                     \
  } while (0)
48 
// Returns FAILED_PRECONDITION from the enclosing function when the member
// `STRUCT_OBJ.NAME` compares equal to 0 (i.e. a null pointer / unset field).
// The message names the member and the struct type. The object argument is
// parenthesized so that an arbitrary expression can be passed safely.
#define VALIDATE_MEMBER(STRUCT_NAME, STRUCT_OBJ, NAME)                  \
  do {                                                                  \
    if ((STRUCT_OBJ).NAME == 0) {                                       \
      return tensorflow::Status(tensorflow::error::FAILED_PRECONDITION, \
                                "'" #NAME "' field in " #STRUCT_NAME    \
                                " must be set.");                       \
    }                                                                   \
  } while (0)
57 
ValidateTPOptimizerRegistrationParams(const TP_OptimizerRegistrationParams & params)58 tensorflow::Status ValidateTPOptimizerRegistrationParams(
59     const TP_OptimizerRegistrationParams& params) {
60   VALIDATE_STRUCT_SIZE(TP_OptimizerRegistrationParams, params,
61                        TP_OPTIMIZER_REGISTRATION_PARAMS_STRUCT_SIZE);
62   VALIDATE_MEMBER(TP_OptimizerRegistrationParams, params, device_type);
63   return ::tensorflow::OkStatus();
64 }
65 
ValidateTPOptimizer(const TP_Optimizer & optimizer)66 tensorflow::Status ValidateTPOptimizer(const TP_Optimizer& optimizer) {
67   VALIDATE_STRUCT_SIZE(TP_Optimizer, optimizer, TP_OPTIMIZER_STRUCT_SIZE);
68   VALIDATE_MEMBER(TP_Optimizer, optimizer, optimize_func);
69   return ::tensorflow::OkStatus();
70 }
71 
ValidateTPOptimizerConfigs(const TP_OptimizerConfigs & configs)72 tensorflow::Status ValidateTPOptimizerConfigs(
73     const TP_OptimizerConfigs& configs) {
74   VALIDATE_STRUCT_SIZE(TP_OptimizerConfigs, configs,
75                        TP_OPTIMIZER_CONFIGS_STRUCT_SIZE);
76   return ::tensorflow::OkStatus();
77 }
78 
79 #undef VALIDATE_MEMBER
80 #undef VALIDATE_STRUCT_SIZE
81 }  // namespace
82 
83 namespace tensorflow {
84 namespace grappler {
85 
Optimize(Cluster * cluster,const GrapplerItem & item,GraphDef * optimized_graph_def)86 Status CGraphOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
87                                  GraphDef* optimized_graph_def) {
88   OwnedTFStatus c_status(TF_NewStatus());
89   OwnedTFBuffer graph_buf(TF_NewBuffer());
90   OwnedTFBuffer optimized_graph_buf(TF_NewBuffer());
91   TF_RETURN_IF_ERROR(MessageToBuffer(item.graph, graph_buf.get()));
92 
93   optimizer_.optimize_func(c_optimizer_, graph_buf.get(),
94                            reinterpret_cast<const TF_GrapplerItem*>(&item),
95                            optimized_graph_buf.get(), c_status.get());
96   TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
97   TF_RETURN_IF_ERROR(
98       BufferToMessage(optimized_graph_buf.get(), optimized_graph_def));
99 
100   return OkStatus();
101 }
102 
// Copies the plugin's tri-state setting for `optimizer` into
// `configs.toggle_config`, keyed by the stringized optimizer name. Only an
// explicit TF_TriState_Off disables the optimizer; any other value enables it.
// Expects `tp_configs` and `configs` to be in scope at the expansion site.
// Wrapped in do/while(0) so the macro expands to exactly one statement and is
// safe inside an unbraced if/else.
#define CONFIG_TOGGLE(optimizer)                               \
  do {                                                         \
    if (tp_configs.optimizer == TF_TriState_Off) {             \
      configs.toggle_config[#optimizer] = RewriterConfig::OFF; \
    } else {                                                   \
      configs.toggle_config[#optimizer] = RewriterConfig::ON;  \
    }                                                          \
  } while (0)
108 
CGraphOptimizerRegister(const PluginGraphOptimizerRegistry::Creator & creator,const TP_OptimizerConfigs tp_configs,const char * device_type)109 void CGraphOptimizerRegister(
110     const PluginGraphOptimizerRegistry::Creator& creator,
111     const TP_OptimizerConfigs tp_configs, const char* device_type) {
112   ConfigList configs;
113   // disable_model_pruning is turned off by default.
114   if (tp_configs.disable_model_pruning == TF_TriState_On)
115     configs.disable_model_pruning = true;
116   else
117     configs.disable_model_pruning = false;
118   // The other configs are turned on by default.
119   CONFIG_TOGGLE(implementation_selector);
120   CONFIG_TOGGLE(function_optimization);
121   CONFIG_TOGGLE(common_subgraph_elimination);
122   CONFIG_TOGGLE(arithmetic_optimization);
123   CONFIG_TOGGLE(debug_stripper);
124   CONFIG_TOGGLE(constant_folding);
125   CONFIG_TOGGLE(shape_optimization);
126   CONFIG_TOGGLE(auto_mixed_precision);
127   CONFIG_TOGGLE(auto_mixed_precision_onednn_bfloat16);
128   CONFIG_TOGGLE(auto_mixed_precision_mkl);
129   CONFIG_TOGGLE(pin_to_host_optimization);
130   CONFIG_TOGGLE(layout_optimizer);
131   CONFIG_TOGGLE(remapping);
132   CONFIG_TOGGLE(loop_optimization);
133   CONFIG_TOGGLE(dependency_optimization);
134   CONFIG_TOGGLE(auto_parallel);
135   CONFIG_TOGGLE(memory_optimization);
136   CONFIG_TOGGLE(scoped_allocator_optimization);
137   PluginGraphOptimizerRegistry::RegisterPluginOptimizerOrDie(
138       creator, device_type, configs);
139 }
140 
141 #undef CONFIG_TOGGLE
142 
InitGraphPlugin(void * dso_handle)143 tensorflow::Status InitGraphPlugin(void* dso_handle) {
144   tensorflow::Env* env = tensorflow::Env::Default();
145 
146   // Step 1: Load symbol for `TF_InitPlugin`
147   void* dso_symbol;
148   TF_RETURN_IF_ERROR(
149       env->GetSymbolFromLibrary(dso_handle, "TF_InitGraph", &dso_symbol));
150 
151   // Step 2: Call `TF_InitPlugin`
152   auto init_fn = reinterpret_cast<TFInitGraphPluginFn>(dso_symbol);
153   return InitGraphPlugin(init_fn);
154 }
155 
InitGraphPlugin(TFInitGraphPluginFn init_fn)156 tensorflow::Status InitGraphPlugin(TFInitGraphPluginFn init_fn) {
157   TP_OptimizerRegistrationParams params{
158       TP_OPTIMIZER_REGISTRATION_PARAMS_STRUCT_SIZE};
159   TP_Optimizer optimizer{TP_OPTIMIZER_STRUCT_SIZE};
160   TP_OptimizerConfigs optimizer_configs{TP_OPTIMIZER_CONFIGS_STRUCT_SIZE};
161   params.major_version = GO_MAJOR;
162   params.minor_version = GO_MINOR;
163   params.patch_version = GO_PATCH;
164   params.optimizer = &optimizer;
165   params.optimizer_configs = &optimizer_configs;
166 
167   OwnedTFStatus c_status(TF_NewStatus());
168   init_fn(&params, c_status.get());
169   TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
170   TF_RETURN_IF_ERROR(ValidateTPOptimizerRegistrationParams(params));
171   TF_RETURN_IF_ERROR(ValidateTPOptimizer(optimizer));
172   TF_RETURN_IF_ERROR(ValidateTPOptimizerConfigs(optimizer_configs));
173 
174   CGraphOptimizerRegister(
175       [=]() { return new CGraphOptimizer(optimizer, params.device_type); },
176       optimizer_configs, params.device_type);
177 
178   return OkStatus();
179 }
180 
181 }  // namespace grappler
182 }  // namespace tensorflow
183 
TF_GetNodesToPreserveListSize(const TF_GrapplerItem * item,int * num_values,size_t * storage_size,TF_Status * status)184 void TF_GetNodesToPreserveListSize(const TF_GrapplerItem* item, int* num_values,
185                                    size_t* storage_size, TF_Status* status) {
186   TF_SetStatus(status, TF_OK, "");
187   const std::unordered_set<std::string>& nodes =
188       reinterpret_cast<const tensorflow::grappler::GrapplerItem*>(item)
189           ->NodesToPreserve();
190   *num_values = nodes.size();
191   *storage_size = 0;
192   for (const std::string& str : nodes) {
193     *storage_size += str.size();
194   }
195 }
196 
TF_GetNodesToPreserveList(const TF_GrapplerItem * item,char ** values,size_t * lengths,int num_values,void * storage,size_t storage_size,TF_Status * status)197 void TF_GetNodesToPreserveList(const TF_GrapplerItem* item, char** values,
198                                size_t* lengths, int num_values, void* storage,
199                                size_t storage_size, TF_Status* status) {
200   TF_SetStatus(status, TF_OK, "");
201   const std::unordered_set<std::string>& nodes =
202       reinterpret_cast<const tensorflow::grappler::GrapplerItem*>(item)
203           ->NodesToPreserve();
204   char* p = static_cast<char*>(storage);
205 
206   int index = 0;
207   for (const std::string& s : nodes) {
208     if (index >= num_values) break;
209     values[index] = p;
210     lengths[index] = s.size();
211     if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) {
212       status->status = tensorflow::errors::InvalidArgument(
213           "Not enough storage to hold the requested list of nodes");
214       return;
215     }
216     memcpy(values[index], s.data(), s.size());
217     p += s.size();
218     index++;
219   }
220 }
221 
TF_GetFetchNodesListSize(const TF_GrapplerItem * item,int * num_values,size_t * storage_size,TF_Status * status)222 void TF_GetFetchNodesListSize(const TF_GrapplerItem* item, int* num_values,
223                               size_t* storage_size, TF_Status* status) {
224   TF_SetStatus(status, TF_OK, "");
225   const std::vector<std::string>& nodes =
226       reinterpret_cast<const tensorflow::grappler::GrapplerItem*>(item)->fetch;
227   *num_values = nodes.size();
228   *storage_size = 0;
229   for (const std::string& str : nodes) {
230     *storage_size += str.size();
231   }
232 }
233 
TF_GetFetchNodesList(const TF_GrapplerItem * item,char ** values,size_t * lengths,int num_values,void * storage,size_t storage_size,TF_Status * status)234 void TF_GetFetchNodesList(const TF_GrapplerItem* item, char** values,
235                           size_t* lengths, int num_values, void* storage,
236                           size_t storage_size, TF_Status* status) {
237   TF_SetStatus(status, TF_OK, "");
238   const std::vector<std::string>& nodes =
239       reinterpret_cast<const tensorflow::grappler::GrapplerItem*>(item)->fetch;
240 
241   const int len = std::min(num_values, static_cast<int>(nodes.size()));
242   char* p = static_cast<char*>(storage);
243   for (int index = 0; index < len; ++index) {
244     const std::string& s = nodes[index];
245     values[index] = p;
246     lengths[index] = s.size();
247     if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) {
248       status->status = tensorflow::errors::InvalidArgument(
249           "Not enough storage to hold the requested list of nodes");
250       return;
251     }
252     memcpy(values[index], s.data(), s.size());
253     p += s.size();
254   }
255 }
256 
TF_NewGraphProperties(const TF_GrapplerItem * item)257 TF_GraphProperties* TF_NewGraphProperties(const TF_GrapplerItem* item) {
258   return reinterpret_cast<TF_GraphProperties*>(
259       new tensorflow::grappler::GraphProperties(
260           *reinterpret_cast<const tensorflow::grappler::GrapplerItem*>(item)));
261 }
262 
TF_DeleteGraphProperties(TF_GraphProperties * graph_properties)263 void TF_DeleteGraphProperties(TF_GraphProperties* graph_properties) {
264   if (graph_properties == nullptr) return;
265   delete reinterpret_cast<tensorflow::grappler::GraphProperties*>(
266       graph_properties);
267 }
268 
TF_InferStatically(TF_GraphProperties * graph_properties,TF_Bool assume_valid_feeds,TF_Bool aggressive_shape_inference,TF_Bool include_input_tensor_values,TF_Bool include_output_tensor_values,TF_Status * status)269 void TF_InferStatically(TF_GraphProperties* graph_properties,
270                         TF_Bool assume_valid_feeds,
271                         TF_Bool aggressive_shape_inference,
272                         TF_Bool include_input_tensor_values,
273                         TF_Bool include_output_tensor_values,
274                         TF_Status* status) {
275   TF_SetStatus(status, TF_OK, "");
276   tensorflow::Status s =
277       reinterpret_cast<tensorflow::grappler::GraphProperties*>(graph_properties)
278           ->InferStatically(assume_valid_feeds, aggressive_shape_inference,
279                             include_input_tensor_values,
280                             include_output_tensor_values);
281   if (!s.ok()) {
282     ::tensorflow::Set_TF_Status_from_Status(status, s);
283   }
284 }
285 
TF_GetInputPropertiesListSize(TF_GraphProperties * graph_properties,const char * name,int * num_values,TF_Status * status)286 void TF_GetInputPropertiesListSize(TF_GraphProperties* graph_properties,
287                                    const char* name, int* num_values,
288                                    TF_Status* status) {
289   TF_SetStatus(status, TF_OK, "");
290   *num_values =
291       reinterpret_cast<tensorflow::grappler::GraphProperties*>(graph_properties)
292           ->GetInputProperties(name)
293           .size();
294 }
295 
TF_GetOutputPropertiesListSize(TF_GraphProperties * graph_properties,const char * name,int * num_values,TF_Status * status)296 void TF_GetOutputPropertiesListSize(TF_GraphProperties* graph_properties,
297                                     const char* name, int* num_values,
298                                     TF_Status* status) {
299   TF_SetStatus(status, TF_OK, "");
300   *num_values =
301       reinterpret_cast<tensorflow::grappler::GraphProperties*>(graph_properties)
302           ->GetOutputProperties(name)
303           .size();
304 }
305 
TF_GetInputPropertiesList(TF_GraphProperties * graph_properties,const char * name,TF_Buffer ** properties,int num_values,TF_Status * status)306 void TF_GetInputPropertiesList(TF_GraphProperties* graph_properties,
307                                const char* name, TF_Buffer** properties,
308                                int num_values, TF_Status* status) {
309   TF_SetStatus(status, TF_OK, "");
310   const std::vector<tensorflow::OpInfo::TensorProperties>& tensor_properties =
311       reinterpret_cast<tensorflow::grappler::GraphProperties*>(graph_properties)
312           ->GetInputProperties(name);
313   const int len =
314       std::min(num_values, static_cast<int>(tensor_properties.size()));
315   for (int i = 0; i < len; ++i) {
316     tensorflow::Status s =
317         tensorflow::MessageToBuffer(tensor_properties[i], properties[i]);
318     if (!s.ok()) {
319       ::tensorflow::Set_TF_Status_from_Status(status, s);
320       return;
321     }
322   }
323 }
324 
TF_GetOutputPropertiesList(TF_GraphProperties * graph_properties,const char * name,TF_Buffer ** properties,int num_values,TF_Status * status)325 void TF_GetOutputPropertiesList(TF_GraphProperties* graph_properties,
326                                 const char* name, TF_Buffer** properties,
327                                 int num_values, TF_Status* status) {
328   TF_SetStatus(status, TF_OK, "");
329   const std::vector<tensorflow::OpInfo::TensorProperties>& tensor_properties =
330       reinterpret_cast<tensorflow::grappler::GraphProperties*>(graph_properties)
331           ->GetOutputProperties(name);
332   const int len =
333       std::min(num_values, static_cast<int>(tensor_properties.size()));
334   for (int i = 0; i < len; ++i) {
335     tensorflow::Status s =
336         tensorflow::MessageToBuffer(tensor_properties[i], properties[i]);
337     if (!s.ok()) {
338       ::tensorflow::Set_TF_Status_from_Status(status, s);
339       return;
340     }
341   }
342 }
343 
TF_NewFunctionLibraryDefinition(const TF_Buffer * graph_buf,TF_Status * status)344 TF_FunctionLibraryDefinition* TF_NewFunctionLibraryDefinition(
345     const TF_Buffer* graph_buf, TF_Status* status) {
346   TF_SetStatus(status, TF_OK, "");
347   tensorflow::GraphDef graph_def;
348   tensorflow::Status s = tensorflow::BufferToMessage(graph_buf, &graph_def);
349   if (!s.ok()) {
350     ::tensorflow::Set_TF_Status_from_Status(status, s);
351     return nullptr;
352   }
353   return reinterpret_cast<TF_FunctionLibraryDefinition*>(
354       new tensorflow::FunctionLibraryDefinition(
355           tensorflow::OpRegistry::Global(), graph_def.library()));
356 }
357 
TF_DeleteFunctionLibraryDefinition(TF_FunctionLibraryDefinition * fn_lib)358 void TF_DeleteFunctionLibraryDefinition(TF_FunctionLibraryDefinition* fn_lib) {
359   if (fn_lib == nullptr) return;
360   delete reinterpret_cast<tensorflow::FunctionLibraryDefinition*>(fn_lib);
361 }
362 
TF_LookUpOpDef(TF_FunctionLibraryDefinition * fn_lib,const char * name,TF_Buffer * buf,TF_Status * status)363 void TF_LookUpOpDef(TF_FunctionLibraryDefinition* fn_lib, const char* name,
364                     TF_Buffer* buf, TF_Status* status) {
365   TF_SetStatus(status, TF_OK, "");
366   const tensorflow::OpDef* op_def_ptr = nullptr;
367   tensorflow::Status s =
368       reinterpret_cast<tensorflow::FunctionLibraryDefinition*>(fn_lib)
369           ->LookUpOpDef(name, &op_def_ptr);
370   if (!s.ok()) {
371     ::tensorflow::Set_TF_Status_from_Status(status, s);
372     return;
373   }
374 
375   s = tensorflow::MessageToBuffer(*op_def_ptr, buf);
376   if (!s.ok()) {
377     ::tensorflow::Set_TF_Status_from_Status(status, s);
378     return;
379   }
380 }
381