xref: /aosp_15_r20/external/tensorflow/tensorflow/c/experimental/grappler/grappler_internal.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 // Classes and utilities that work with Graph C API for internal use.
16 // This includes functions used for optimizer registration and interfaces needed
17 // for testing.
18 
19 #ifndef TENSORFLOW_C_EXPERIMENTAL_GRAPPLER_GRAPPLER_INTERNAL_H_
20 #define TENSORFLOW_C_EXPERIMENTAL_GRAPPLER_GRAPPLER_INTERNAL_H_
21 
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/experimental/grappler/grappler.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
35 
36 namespace tensorflow {
37 namespace grappler {
38 
39 // Plugin initialization function that a device plugin
40 // must define.
41 typedef void (*TFInitGraphPluginFn)(TP_OptimizerRegistrationParams* const,
42                                     TF_Status* const);
43 
44 // Registers Graph optimizers.
45 Status InitGraphPlugin(void* dso_handle);
46 
47 // Allow registering a graph optimizer using a function (used for
48 // testing).
49 Status InitGraphPlugin(TFInitGraphPluginFn init_fn);
50 
51 struct GrapplerItem;
52 class Cluster;
53 
54 struct TFStatusDeleter {
operatorTFStatusDeleter55   void operator()(TF_Status* s) const { TF_DeleteStatus(s); }
56 };
57 using OwnedTFStatus = std::unique_ptr<TF_Status, TFStatusDeleter>;
58 
59 struct TFBufferDeleter {
operatorTFBufferDeleter60   void operator()(TF_Buffer* buf) const { TF_DeleteBuffer(buf); }
61 };
62 using OwnedTFBuffer = std::unique_ptr<TF_Buffer, TFBufferDeleter>;
63 
64 class CGraphOptimizer : public CustomGraphOptimizer {
65  public:
CGraphOptimizer(TP_Optimizer optimizer,const char * device_type)66   explicit CGraphOptimizer(TP_Optimizer optimizer, const char* device_type)
67       : optimizer_(optimizer), device_type_(device_type) {
68     if (optimizer.create_func != nullptr) {
69       c_optimizer_ = (*optimizer_.create_func)();
70     } else {
71       c_optimizer_ = nullptr;
72     }
73   }
name()74   std::string name() const override { return "PluggableGraphOptimizer"; }
UsesFunctionLibrary()75   bool UsesFunctionLibrary() const override { return false; }
Init(const tensorflow::RewriterConfig_CustomGraphOptimizer * config)76   Status Init(
77       const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
78     return OkStatus();
79   }
80   Status Optimize(Cluster* cluster, const GrapplerItem& item,
81                   GraphDef* optimized_graph_def) override;
82 
~CGraphOptimizer()83   ~CGraphOptimizer() override {
84     if (optimizer_.destroy_func != nullptr) {
85       (*optimizer_.destroy_func)(c_optimizer_);
86     }
87   }
88 
89  private:
90   TP_Optimizer optimizer_;
91   std::string device_type_;
92   void* c_optimizer_;
93 };
94 
95 // Registration function to register a CGraphOptimizer along with plugin configs
96 // and device type.
97 void CGraphOptimizerRegister(
98     const PluginGraphOptimizerRegistry::Creator& creator,
99     const TP_OptimizerConfigs tp_configs, const char* device_type);
100 
101 }  // namespace grappler
102 }  // namespace tensorflow
103 
104 #endif  // TENSORFLOW_C_EXPERIMENTAL_GRAPPLER_GRAPPLER_INTERNAL_H_
105