1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <functional>
18 #include <map>
19 #include <string>
20 #include <vector>
21 
22 #include <google/protobuf/descriptor.h>
23 #include <google/protobuf/compiler/plugin.h>
24 #include <google/protobuf/compiler/code_generator.h>
25 #include <google/protobuf/io/printer.h>
26 #include <google/protobuf/io/zero_copy_stream.h>
27 
28 #include "absl/strings/strip.h"
29 #include "absl/strings/str_cat.h"
30 #include "absl/strings/str_join.h"
31 #include "absl/strings/str_split.h"
32 #include "nugget/protobuf/options.pb.h"
33 
34 using ::google::protobuf::FileDescriptor;
35 using ::google::protobuf::MethodDescriptor;
36 using ::google::protobuf::ServiceDescriptor;
37 using ::google::protobuf::compiler::CodeGenerator;
38 using ::google::protobuf::compiler::OutputDirectory;
39 using ::google::protobuf::io::Printer;
40 using ::google::protobuf::io::ZeroCopyOutputStream;
41 
42 using ::nugget::protobuf::app_id;
43 using ::nugget::protobuf::request_buffer_size;
44 using ::nugget::protobuf::response_buffer_size;
45 
46 namespace {
47 
validateServiceOptions(const ServiceDescriptor & service)48 std::string validateServiceOptions(const ServiceDescriptor& service) {
49     if (!service.options().HasExtension(app_id)) {
50         return "nugget.protobuf.app_id is not defined for service " + service.name();
51     }
52     if (!service.options().HasExtension(request_buffer_size)) {
53         return "nugget.protobuf.request_buffer_size is not defined for service " + service.name();
54     }
55     if (!service.options().HasExtension(response_buffer_size)) {
56         return "nugget.protobuf.response_buffer_size is not defined for service " + service.name();
57     }
58     return "";
59 }
60 
61 template <typename Descriptor>
Packages(const Descriptor & descriptor)62 std::vector<std::string> Packages(const Descriptor& descriptor) {
63     std::vector<std::string> namespaces =
64         absl::StrSplit(descriptor.full_name(), '.');
65     namespaces.pop_back(); // just take the package
66     return namespaces;
67 }
68 
69 template <typename Descriptor>
FullyQualifiedIdentifier(const Descriptor & descriptor)70 std::string FullyQualifiedIdentifier(const Descriptor& descriptor) {
71     const auto namespaces = Packages(descriptor);
72     if (namespaces.empty()) {
73         return "::" + descriptor.name();
74     } else {
75         return absl::StrCat("::", absl::StrJoin(namespaces, "::"), "::", descriptor.name());
76     }
77 }
78 
79 template <typename Descriptor>
FullyQualifiedHeader(const Descriptor & descriptor)80 std::string FullyQualifiedHeader(const Descriptor& descriptor) {
81     const std::vector<std::string> packages = Packages(descriptor);
82     const std::vector<std::string> path_components =
83         absl::StrSplit(descriptor.file()->name(), '/');
84     const std::string file(path_components.back());
85     const std::string header = absl::StrCat(absl::StripSuffix(file, ".proto"), ".pb.h");
86     if (packages.empty()) {
87         return header;
88     } else {
89         return absl::StrCat(absl::StrJoin(packages, "/"), "/", header);
90     }
91 }
92 
93 template <typename Descriptor>
OpenNamespaces(Printer & printer,const Descriptor & descriptor)94 void OpenNamespaces(Printer& printer, const Descriptor& descriptor) {
95     const auto namespaces = Packages(descriptor);
96     for (const auto& ns : namespaces) {
97         std::map<std::string, std::string> namespaceVars;
98         namespaceVars["namespace"] = ns;
99         printer.Print(namespaceVars, R"(
100 namespace $namespace$ {)");
101     }
102 }
103 
104 template <typename Descriptor>
CloseNamespaces(Printer & printer,const Descriptor & descriptor)105 void CloseNamespaces(Printer& printer, const Descriptor& descriptor) {
106     const auto namespaces = Packages(descriptor);
107     for (auto it = namespaces.crbegin(); it != namespaces.crend(); ++it) {
108         std::map<std::string, std::string> namespaceVars;
109         namespaceVars["namespace"] = *it;
110         printer.Print(namespaceVars, R"(
111 } // namespace $namespace$)");
112     }
113 }
114 
ForEachMethod(const ServiceDescriptor & service,std::function<void (std::map<std::string,std::string>)> handler)115 void ForEachMethod(const ServiceDescriptor& service,
116                    std::function<void(std::map<std::string, std::string>)> handler) {
117     for (int i = 0; i < service.method_count(); ++i) {
118         const MethodDescriptor& method = *service.method(i);
119         std::map<std::string, std::string> vars;
120         vars["method_id"] = std::to_string(i);
121         vars["method_name"] = method.name();
122         vars["method_input_type"] = FullyQualifiedIdentifier(*method.input_type());
123         vars["method_output_type"] = FullyQualifiedIdentifier(*method.output_type());
124         handler(vars);
125     }
126 }
127 
GenerateMockClient(Printer & printer,const ServiceDescriptor & service)128 void GenerateMockClient(Printer& printer, const ServiceDescriptor& service) {
129     std::map<std::string, std::string> vars;
130     vars["include_guard"] = "PROTOC_GENERATED_MOCK_" + service.name() + "_CLIENT_H";
131     vars["service_header"] = service.name() + ".client.h";
132     vars["mock_class"] = "Mock" + service.name();
133     vars["class"] = service.name();
134 
135     printer.Print(vars, R"(
136 #ifndef $include_guard$
137 #define $include_guard$
138 
139 #include <gmock/gmock.h>
140 
141 #include <$service_header$>)");
142 
143     OpenNamespaces(printer, service);
144 
145     printer.Print(vars, R"(
146 struct $mock_class$ : public I$class$ {)");
147 
148     ForEachMethod(service, [&](std::map<std::string, std::string> methodVars) {
149         printer.Print(methodVars, R"(
150     MOCK_METHOD2($method_name$, uint32_t(const $method_input_type$&, $method_output_type$*));)");
151     });
152 
153     printer.Print(vars, R"(
154 };)");
155 
156     CloseNamespaces(printer, service);
157 
158     printer.Print(vars, R"(
159 #endif)");
160 }
161 
GenerateClientHeader(Printer & printer,const ServiceDescriptor & service)162 void GenerateClientHeader(Printer& printer, const ServiceDescriptor& service) {
163     std::map<std::string, std::string> vars;
164     vars["include_guard"] = "PROTOC_GENERATED_" + service.name() + "_CLIENT_H";
165     vars["protobuf_header"] = FullyQualifiedHeader(service);
166     vars["class"] = service.name();
167     vars["iface_class"] = "I" + service.name();
168     vars["app_id"] = "APP_ID_" + service.options().GetExtension(app_id);
169 
170     printer.Print(vars, R"(
171 #ifndef $include_guard$
172 #define $include_guard$
173 
174 #include <application.h>
175 #include <nos/AppClient.h>
176 #include <nos/NuggetClientInterface.h>
177 
178 #include "$protobuf_header$")");
179 
180     OpenNamespaces(printer, service);
181 
182     // Pure virtual interface to make testing easier
183     printer.Print(vars, R"(
184 class $iface_class$ {
185 public:
186     virtual ~$iface_class$() = default;)");
187 
188     ForEachMethod(service, [&](std::map<std::string, std::string> methodVars) {
189         printer.Print(methodVars, R"(
190     virtual uint32_t $method_name$(const $method_input_type$&, $method_output_type$*) = 0;)");
191     });
192 
193     printer.Print(vars, R"(
194 };)");
195 
196     // Implementation of the interface for Nugget
197     printer.Print(vars, R"(
198 class $class$ : public $iface_class$ {
199     ::nos::AppClient _app;
200 public:
201     $class$(::nos::NuggetClientInterface& client) : _app{client, $app_id$} {}
202     ~$class$() override = default;)");
203 
204     ForEachMethod(service, [&](std::map<std::string, std::string> methodVars) {
205         printer.Print(methodVars, R"(
206     uint32_t $method_name$(const $method_input_type$&, $method_output_type$*) override;)");
207     });
208 
209     printer.Print(vars, R"(
210 };)");
211 
212     CloseNamespaces(printer, service);
213 
214     printer.Print(vars, R"(
215 #endif)");
216 }
217 
GenerateClientSource(Printer & printer,const ServiceDescriptor & service)218 void GenerateClientSource(Printer& printer, const ServiceDescriptor& service) {
219     std::map<std::string, std::string> vars;
220     vars["generated_header"] = service.name() + ".client.h";
221     vars["class"] = service.name();
222 
223     const uint32_t max_request_size = service.options().GetExtension(request_buffer_size);
224     const uint32_t max_response_size = service.options().GetExtension(response_buffer_size);
225     vars["max_request_size"] = std::to_string(max_request_size);
226     vars["max_response_size"] = std::to_string(max_response_size);
227 
228     printer.Print(vars, R"(
229 #include <$generated_header$>
230 
231 #include <application.h>)");
232 
233     OpenNamespaces(printer, service);
234 
235     // Methods
236     ForEachMethod(service, [&](std::map<std::string, std::string>  methodVars) {
237         methodVars.insert(vars.begin(), vars.end());
238         printer.Print(methodVars, R"(
239 uint32_t $class$::$method_name$(const $method_input_type$& request, $method_output_type$* response) {
240     const size_t request_size = request.ByteSizeLong();
241     if (request_size > $max_request_size$) {
242         return APP_ERROR_TOO_MUCH;
243     }
244     std::vector<uint8_t> buffer(request_size);
245     if (!request.SerializeToArray(buffer.data(), buffer.size())) {
246         return APP_ERROR_RPC;
247     }
248     std::vector<uint8_t> responseBuffer;
249     if (response != nullptr) {
250       responseBuffer.resize($max_response_size$);
251     }
252     const uint32_t appStatus = _app.Call($method_id$, buffer,
253                                          (response != nullptr) ? &responseBuffer : nullptr);
254     if (appStatus == APP_SUCCESS && response != nullptr) {
255         if (!response->ParseFromArray(responseBuffer.data(), responseBuffer.size())) {
256             return APP_ERROR_RPC;
257         }
258     }
259     return appStatus;
260 })");
261     });
262 
263     CloseNamespaces(printer, service);
264 }
265 
266 // Generator for C++ Nugget service client
267 class CppNuggetServiceClientGenerator : public CodeGenerator {
268 public:
269     CppNuggetServiceClientGenerator() = default;
270     CppNuggetServiceClientGenerator(const CppNuggetServiceClientGenerator&) = delete;
271     CppNuggetServiceClientGenerator& operator=(const CppNuggetServiceClientGenerator&) = delete;
272     ~CppNuggetServiceClientGenerator() override = default;
273 
Generate(const FileDescriptor * file,const std::string & parameter,OutputDirectory * output_directory,std::string * error) const274     bool Generate(const FileDescriptor* file,
275                   const std::string& parameter,
276                   OutputDirectory* output_directory,
277                   std::string* error) const override {
278         for (int i = 0; i < file->service_count(); ++i) {
279             const auto& service = *file->service(i);
280 
281             *error = validateServiceOptions(service);
282             if (!error->empty()) {
283                 return false;
284             }
285 
286             if (parameter == "mock") {
287                 std::unique_ptr<ZeroCopyOutputStream> output{
288                         output_directory->Open("Mock" + service.name() + ".client.h")};
289                 Printer printer(output.get(), '$');
290                 GenerateMockClient(printer, service);
291             } else if (parameter == "header") {
292                 std::unique_ptr<ZeroCopyOutputStream> output{
293                         output_directory->Open(service.name() + ".client.h")};
294                 Printer printer(output.get(), '$');
295                 GenerateClientHeader(printer, service);
296             } else if (parameter == "source") {
297                 std::unique_ptr<ZeroCopyOutputStream> output{
298                         output_directory->Open(service.name() + ".client.cpp")};
299                 Printer printer(output.get(), '$');
300                 GenerateClientSource(printer, service);
301             } else {
302                 *error = "Illegal parameter: must be mock|header|source";
303                 return false;
304             }
305         }
306 
307         return true;
308     }
309 };
310 
311 } // namespace
312 
main(int argc,char * argv[])313 int main(int argc, char* argv[]) {
314     GOOGLE_PROTOBUF_VERIFY_VERSION;
315     CppNuggetServiceClientGenerator generator;
316     return google::protobuf::compiler::PluginMain(argc, argv, &generator);
317 }
318