1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>
21
22 #include <google/protobuf/descriptor.h>
23 #include <google/protobuf/compiler/plugin.h>
24 #include <google/protobuf/compiler/code_generator.h>
25 #include <google/protobuf/io/printer.h>
26 #include <google/protobuf/io/zero_copy_stream.h>
27
28 #include "absl/strings/strip.h"
29 #include "absl/strings/str_cat.h"
30 #include "absl/strings/str_join.h"
31 #include "absl/strings/str_split.h"
32 #include "nugget/protobuf/options.pb.h"
33
34 using ::google::protobuf::FileDescriptor;
35 using ::google::protobuf::MethodDescriptor;
36 using ::google::protobuf::ServiceDescriptor;
37 using ::google::protobuf::compiler::CodeGenerator;
38 using ::google::protobuf::compiler::OutputDirectory;
39 using ::google::protobuf::io::Printer;
40 using ::google::protobuf::io::ZeroCopyOutputStream;
41
42 using ::nugget::protobuf::app_id;
43 using ::nugget::protobuf::request_buffer_size;
44 using ::nugget::protobuf::response_buffer_size;
45
46 namespace {
47
validateServiceOptions(const ServiceDescriptor & service)48 std::string validateServiceOptions(const ServiceDescriptor& service) {
49 if (!service.options().HasExtension(app_id)) {
50 return "nugget.protobuf.app_id is not defined for service " + service.name();
51 }
52 if (!service.options().HasExtension(request_buffer_size)) {
53 return "nugget.protobuf.request_buffer_size is not defined for service " + service.name();
54 }
55 if (!service.options().HasExtension(response_buffer_size)) {
56 return "nugget.protobuf.response_buffer_size is not defined for service " + service.name();
57 }
58 return "";
59 }
60
61 template <typename Descriptor>
Packages(const Descriptor & descriptor)62 std::vector<std::string> Packages(const Descriptor& descriptor) {
63 std::vector<std::string> namespaces =
64 absl::StrSplit(descriptor.full_name(), '.');
65 namespaces.pop_back(); // just take the package
66 return namespaces;
67 }
68
69 template <typename Descriptor>
FullyQualifiedIdentifier(const Descriptor & descriptor)70 std::string FullyQualifiedIdentifier(const Descriptor& descriptor) {
71 const auto namespaces = Packages(descriptor);
72 if (namespaces.empty()) {
73 return "::" + descriptor.name();
74 } else {
75 return absl::StrCat("::", absl::StrJoin(namespaces, "::"), "::", descriptor.name());
76 }
77 }
78
79 template <typename Descriptor>
FullyQualifiedHeader(const Descriptor & descriptor)80 std::string FullyQualifiedHeader(const Descriptor& descriptor) {
81 const std::vector<std::string> packages = Packages(descriptor);
82 const std::vector<std::string> path_components =
83 absl::StrSplit(descriptor.file()->name(), '/');
84 const std::string file(path_components.back());
85 const std::string header = absl::StrCat(absl::StripSuffix(file, ".proto"), ".pb.h");
86 if (packages.empty()) {
87 return header;
88 } else {
89 return absl::StrCat(absl::StrJoin(packages, "/"), "/", header);
90 }
91 }
92
93 template <typename Descriptor>
OpenNamespaces(Printer & printer,const Descriptor & descriptor)94 void OpenNamespaces(Printer& printer, const Descriptor& descriptor) {
95 const auto namespaces = Packages(descriptor);
96 for (const auto& ns : namespaces) {
97 std::map<std::string, std::string> namespaceVars;
98 namespaceVars["namespace"] = ns;
99 printer.Print(namespaceVars, R"(
100 namespace $namespace$ {)");
101 }
102 }
103
104 template <typename Descriptor>
CloseNamespaces(Printer & printer,const Descriptor & descriptor)105 void CloseNamespaces(Printer& printer, const Descriptor& descriptor) {
106 const auto namespaces = Packages(descriptor);
107 for (auto it = namespaces.crbegin(); it != namespaces.crend(); ++it) {
108 std::map<std::string, std::string> namespaceVars;
109 namespaceVars["namespace"] = *it;
110 printer.Print(namespaceVars, R"(
111 } // namespace $namespace$)");
112 }
113 }
114
ForEachMethod(const ServiceDescriptor & service,std::function<void (std::map<std::string,std::string>)> handler)115 void ForEachMethod(const ServiceDescriptor& service,
116 std::function<void(std::map<std::string, std::string>)> handler) {
117 for (int i = 0; i < service.method_count(); ++i) {
118 const MethodDescriptor& method = *service.method(i);
119 std::map<std::string, std::string> vars;
120 vars["method_id"] = std::to_string(i);
121 vars["method_name"] = method.name();
122 vars["method_input_type"] = FullyQualifiedIdentifier(*method.input_type());
123 vars["method_output_type"] = FullyQualifiedIdentifier(*method.output_type());
124 handler(vars);
125 }
126 }
127
GenerateMockClient(Printer & printer,const ServiceDescriptor & service)128 void GenerateMockClient(Printer& printer, const ServiceDescriptor& service) {
129 std::map<std::string, std::string> vars;
130 vars["include_guard"] = "PROTOC_GENERATED_MOCK_" + service.name() + "_CLIENT_H";
131 vars["service_header"] = service.name() + ".client.h";
132 vars["mock_class"] = "Mock" + service.name();
133 vars["class"] = service.name();
134
135 printer.Print(vars, R"(
136 #ifndef $include_guard$
137 #define $include_guard$
138
139 #include <gmock/gmock.h>
140
141 #include <$service_header$>)");
142
143 OpenNamespaces(printer, service);
144
145 printer.Print(vars, R"(
146 struct $mock_class$ : public I$class$ {)");
147
148 ForEachMethod(service, [&](std::map<std::string, std::string> methodVars) {
149 printer.Print(methodVars, R"(
150 MOCK_METHOD2($method_name$, uint32_t(const $method_input_type$&, $method_output_type$*));)");
151 });
152
153 printer.Print(vars, R"(
154 };)");
155
156 CloseNamespaces(printer, service);
157
158 printer.Print(vars, R"(
159 #endif)");
160 }
161
GenerateClientHeader(Printer & printer,const ServiceDescriptor & service)162 void GenerateClientHeader(Printer& printer, const ServiceDescriptor& service) {
163 std::map<std::string, std::string> vars;
164 vars["include_guard"] = "PROTOC_GENERATED_" + service.name() + "_CLIENT_H";
165 vars["protobuf_header"] = FullyQualifiedHeader(service);
166 vars["class"] = service.name();
167 vars["iface_class"] = "I" + service.name();
168 vars["app_id"] = "APP_ID_" + service.options().GetExtension(app_id);
169
170 printer.Print(vars, R"(
171 #ifndef $include_guard$
172 #define $include_guard$
173
174 #include <application.h>
175 #include <nos/AppClient.h>
176 #include <nos/NuggetClientInterface.h>
177
178 #include "$protobuf_header$")");
179
180 OpenNamespaces(printer, service);
181
182 // Pure virtual interface to make testing easier
183 printer.Print(vars, R"(
184 class $iface_class$ {
185 public:
186 virtual ~$iface_class$() = default;)");
187
188 ForEachMethod(service, [&](std::map<std::string, std::string> methodVars) {
189 printer.Print(methodVars, R"(
190 virtual uint32_t $method_name$(const $method_input_type$&, $method_output_type$*) = 0;)");
191 });
192
193 printer.Print(vars, R"(
194 };)");
195
196 // Implementation of the interface for Nugget
197 printer.Print(vars, R"(
198 class $class$ : public $iface_class$ {
199 ::nos::AppClient _app;
200 public:
201 $class$(::nos::NuggetClientInterface& client) : _app{client, $app_id$} {}
202 ~$class$() override = default;)");
203
204 ForEachMethod(service, [&](std::map<std::string, std::string> methodVars) {
205 printer.Print(methodVars, R"(
206 uint32_t $method_name$(const $method_input_type$&, $method_output_type$*) override;)");
207 });
208
209 printer.Print(vars, R"(
210 };)");
211
212 CloseNamespaces(printer, service);
213
214 printer.Print(vars, R"(
215 #endif)");
216 }
217
GenerateClientSource(Printer & printer,const ServiceDescriptor & service)218 void GenerateClientSource(Printer& printer, const ServiceDescriptor& service) {
219 std::map<std::string, std::string> vars;
220 vars["generated_header"] = service.name() + ".client.h";
221 vars["class"] = service.name();
222
223 const uint32_t max_request_size = service.options().GetExtension(request_buffer_size);
224 const uint32_t max_response_size = service.options().GetExtension(response_buffer_size);
225 vars["max_request_size"] = std::to_string(max_request_size);
226 vars["max_response_size"] = std::to_string(max_response_size);
227
228 printer.Print(vars, R"(
229 #include <$generated_header$>
230
231 #include <application.h>)");
232
233 OpenNamespaces(printer, service);
234
235 // Methods
236 ForEachMethod(service, [&](std::map<std::string, std::string> methodVars) {
237 methodVars.insert(vars.begin(), vars.end());
238 printer.Print(methodVars, R"(
239 uint32_t $class$::$method_name$(const $method_input_type$& request, $method_output_type$* response) {
240 const size_t request_size = request.ByteSizeLong();
241 if (request_size > $max_request_size$) {
242 return APP_ERROR_TOO_MUCH;
243 }
244 std::vector<uint8_t> buffer(request_size);
245 if (!request.SerializeToArray(buffer.data(), buffer.size())) {
246 return APP_ERROR_RPC;
247 }
248 std::vector<uint8_t> responseBuffer;
249 if (response != nullptr) {
250 responseBuffer.resize($max_response_size$);
251 }
252 const uint32_t appStatus = _app.Call($method_id$, buffer,
253 (response != nullptr) ? &responseBuffer : nullptr);
254 if (appStatus == APP_SUCCESS && response != nullptr) {
255 if (!response->ParseFromArray(responseBuffer.data(), responseBuffer.size())) {
256 return APP_ERROR_RPC;
257 }
258 }
259 return appStatus;
260 })");
261 });
262
263 CloseNamespaces(printer, service);
264 }
265
266 // Generator for C++ Nugget service client
267 class CppNuggetServiceClientGenerator : public CodeGenerator {
268 public:
269 CppNuggetServiceClientGenerator() = default;
270 CppNuggetServiceClientGenerator(const CppNuggetServiceClientGenerator&) = delete;
271 CppNuggetServiceClientGenerator& operator=(const CppNuggetServiceClientGenerator&) = delete;
272 ~CppNuggetServiceClientGenerator() override = default;
273
Generate(const FileDescriptor * file,const std::string & parameter,OutputDirectory * output_directory,std::string * error) const274 bool Generate(const FileDescriptor* file,
275 const std::string& parameter,
276 OutputDirectory* output_directory,
277 std::string* error) const override {
278 for (int i = 0; i < file->service_count(); ++i) {
279 const auto& service = *file->service(i);
280
281 *error = validateServiceOptions(service);
282 if (!error->empty()) {
283 return false;
284 }
285
286 if (parameter == "mock") {
287 std::unique_ptr<ZeroCopyOutputStream> output{
288 output_directory->Open("Mock" + service.name() + ".client.h")};
289 Printer printer(output.get(), '$');
290 GenerateMockClient(printer, service);
291 } else if (parameter == "header") {
292 std::unique_ptr<ZeroCopyOutputStream> output{
293 output_directory->Open(service.name() + ".client.h")};
294 Printer printer(output.get(), '$');
295 GenerateClientHeader(printer, service);
296 } else if (parameter == "source") {
297 std::unique_ptr<ZeroCopyOutputStream> output{
298 output_directory->Open(service.name() + ".client.cpp")};
299 Printer printer(output.get(), '$');
300 GenerateClientSource(printer, service);
301 } else {
302 *error = "Illegal parameter: must be mock|header|source";
303 return false;
304 }
305 }
306
307 return true;
308 }
309 };
310
311 } // namespace
312
main(int argc,char * argv[])313 int main(int argc, char* argv[]) {
314 GOOGLE_PROTOBUF_VERIFY_VERSION;
315 CppNuggetServiceClientGenerator generator;
316 return google::protobuf::compiler::PluginMain(argc, argv, &generator);
317 }
318