1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/c/experimental/ops/gen/common/controller.h"

#include <utility>

#include "absl/strings/substitute.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
23
24 namespace tensorflow {
25 namespace generator {
26
Controller(PathConfig path_config,Env * env)27 Controller::Controller(PathConfig path_config, Env* env)
28 : env_(env), path_config_(path_config) {
29 // Load the Op and API definitions
30 InitializeOpApi();
31
32 // Convert the Op and API definitions to the internal data model
33 BuildModel();
34 }
~Controller()35 Controller::~Controller() { delete api_def_map_; }
36
WriteFile(const string & file_path,const SourceCode & code) const37 const void Controller::WriteFile(const string& file_path,
38 const SourceCode& code) const {
39 TF_CHECK_OK(WriteStringToFile(env_, file_path, code.Render())) << file_path;
40 }
41
GetModelOps() const42 const std::vector<OpSpec>& Controller::GetModelOps() const {
43 return operators_;
44 }
45
InitializeOpApi()46 void Controller::InitializeOpApi() {
47 OpRegistry::Global()->Export(false, &op_list_);
48
49 // Load matching API defs for each Op. Paths are visited in order, allowing
50 // python/api_def_Xyz.pbtxt to override base/api_def_Xyz.pbtxt, for example.
51 api_def_map_ = new ApiDefMap(op_list_);
52 for (const auto& op : op_list_.op()) {
53 for (const auto& dir : path_config_.api_dirs) {
54 const string file_name = absl::Substitute("api_def_$0.pbtxt", op.name());
55 const string file_path = io::JoinPath(dir, file_name);
56 if (env_->FileExists(file_path).ok()) {
57 TF_CHECK_OK(api_def_map_->LoadFile(env_, file_path)) << file_path;
58 } else {
59 // API defs are currently used for only optional pieces.
60 }
61 }
62 }
63
64 // Doc strings (summary, description) typically come from the API def.
65 api_def_map_->UpdateDocs();
66 }
67
BuildModel()68 void Controller::BuildModel() {
69 // Build the internal data model for the requested ops
70 for (const auto& op_name : path_config_.op_names) {
71 const OpDef* op_def = nullptr;
72 TF_CHECK_OK(OpRegistry::Global()->LookUpOpDef(op_name, &op_def));
73 CHECK(op_def != nullptr); // Crash OK
74
75 const ApiDef* api_def = api_def_map_->GetApiDef(op_name);
76 CHECK(api_def != nullptr); // Crash OK
77
78 operators_.push_back(OpSpec::Create(*op_def, *api_def));
79 }
80 }
81
82 } // namespace generator
83 } // namespace tensorflow
84