xref: /aosp_15_r20/external/flatbuffers/grpc/src/compiler/swift_generator.cc (revision 890232f25432b36107d06881e0a25aaa6b473652)
1 /*
2  * Copyright 2020 Google Inc. All rights reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 /*
18  * NOTE: The following implementation is a translation of the Swift-grpc
19  * generator, since flatbuffers doesn't allow plugins for now. If an issue arises,
20  * please open an issue in the flatbuffers repository. This file should always
21  * be kept in sync with the Swift-grpc repository.
22  */
23 #include <map>
24 #include <sstream>
25 
26 #include "flatbuffers/util.h"
27 #include "src/compiler/schema_interface.h"
28 #include "src/compiler/swift_generator.h"
29 
30 namespace grpc_swift_generator {
31 namespace {
32 
WrapInNameSpace(const std::vector<std::string> & components,const grpc::string & name)33 static std::string WrapInNameSpace(const std::vector<std::string> &components,
34                             const grpc::string &name) {
35   std::string qualified_name;
36   for (auto it = components.begin(); it != components.end(); ++it)
37     qualified_name += *it + "_";
38   return qualified_name + name;
39 }
40 
// Returns the Swift payload type for a FlatBuffers table: the
// namespace-flattened type name wrapped in the `Message<...>` generic.
// Example: components {"a"} and name "B" yield "Message<a_B>".
static grpc::string GenerateMessage(const std::vector<std::string> &components,
                                    const grpc::string &name) {
  grpc::string wrapped("Message<");
  wrapped += WrapInNameSpace(components, name);
  wrapped += '>';
  return wrapped;
}
44 }
45 
46 // MARK: - Client
47 
GenerateClientFuncName(const grpc_generator::Method * method,grpc_generator::Printer * printer,std::map<grpc::string,grpc::string> * dictonary)48 static void GenerateClientFuncName(const grpc_generator::Method *method,
49                             grpc_generator::Printer *printer,
50                             std::map<grpc::string, grpc::string> *dictonary) {
51   auto vars = *dictonary;
52   if (method->NoStreaming()) {
53     printer->Print(vars,
54                    "  $GenAccess$func $MethodName$(\n"
55                    "    _ request: $Input$\n"
56                    "    , callOptions: CallOptions?$isNil$\n"
57                    "  ) -> UnaryCall<$Input$, $Output$>");
58     return;
59   }
60 
61   if (method->ServerStreaming()) {
62     printer->Print(vars,
63                    "  $GenAccess$func $MethodName$(\n"
64                    "    _ request: $Input$\n"
65                    "    , callOptions: CallOptions?$isNil$,\n"
66                    "    handler: @escaping ($Output$) -> Void\n"
67                    "  ) -> ServerStreamingCall<$Input$, $Output$>");
68     return;
69   }
70 
71   if (method->ClientStreaming()) {
72     printer->Print(vars,
73                    "  $GenAccess$func $MethodName$(\n"
74                    "    callOptions: CallOptions?$isNil$\n"
75                    "  ) -> ClientStreamingCall<$Input$, $Output$>");
76     return;
77   }
78 
79   printer->Print(vars,
80                  "  $GenAccess$func $MethodName$(\n"
81                  "    callOptions: CallOptions?$isNil$,\n"
82                  "    handler: @escaping ($Output$ ) -> Void\n"
83                  "  ) -> BidirectionalStreamingCall<$Input$, $Output$>");
84 }
85 
GenerateClientFuncBody(const grpc_generator::Method * method,grpc_generator::Printer * printer,std::map<grpc::string,grpc::string> * dictonary)86 static void GenerateClientFuncBody(const grpc_generator::Method *method,
87                             grpc_generator::Printer *printer,
88                             std::map<grpc::string, grpc::string> *dictonary) {
89   auto vars = *dictonary;
90   vars["Interceptor"] =
91       "interceptors: self.interceptors?.make$MethodName$Interceptors() ?? []";
92   if (method->NoStreaming()) {
93     printer->Print(
94         vars,
95         "    return self.makeUnaryCall(\n"
96         "      path: \"/$PATH$$ServiceName$/$MethodName$\",\n"
97         "      request: request,\n"
98         "      callOptions: callOptions ?? self.defaultCallOptions,\n"
99         "      $Interceptor$\n"
100         "    )\n");
101     return;
102   }
103 
104   if (method->ServerStreaming()) {
105     printer->Print(
106         vars,
107         "    return self.makeServerStreamingCall(\n"
108         "      path: \"/$PATH$$ServiceName$/$MethodName$\",\n"
109         "      request: request,\n"
110         "      callOptions: callOptions ?? self.defaultCallOptions,\n"
111         "      $Interceptor$,\n"
112         "      handler: handler\n"
113         "    )\n");
114     return;
115   }
116 
117   if (method->ClientStreaming()) {
118     printer->Print(
119         vars,
120         "    return self.makeClientStreamingCall(\n"
121         "      path: \"/$PATH$$ServiceName$/$MethodName$\",\n"
122         "      callOptions: callOptions ?? self.defaultCallOptions,\n"
123         "      $Interceptor$\n"
124         "    )\n");
125     return;
126   }
127   printer->Print(vars,
128                  "    return self.makeBidirectionalStreamingCall(\n"
129                  "      path: \"/$PATH$$ServiceName$/$MethodName$\",\n"
130                  "      callOptions: callOptions ?? self.defaultCallOptions,\n"
131                  "      $Interceptor$,\n"
132                  "      handler: handler\n"
133                  "    )\n");
134 }
135 
GenerateClientProtocol(const grpc_generator::Service * service,grpc_generator::Printer * printer,std::map<grpc::string,grpc::string> * dictonary)136 void GenerateClientProtocol(const grpc_generator::Service *service,
137                             grpc_generator::Printer *printer,
138                             std::map<grpc::string, grpc::string> *dictonary) {
139   auto vars = *dictonary;
140   printer->Print(
141       vars,
142       "$ACCESS$ protocol $ServiceQualifiedName$ClientProtocol: GRPCClient {");
143   printer->Print("\n\n");
144   printer->Print("  var serviceName: String { get }");
145   printer->Print("\n\n");
146   printer->Print(
147       vars,
148       "  var interceptors: "
149       "$ServiceQualifiedName$ClientInterceptorFactoryProtocol? { get }");
150   printer->Print("\n\n");
151 
152   vars["GenAccess"] = "";
153   for (auto it = 0; it < service->method_count(); it++) {
154     auto method = service->method(it);
155     vars["Input"] = GenerateMessage(method->get_input_namespace_parts(),
156                                     method->get_input_type_name());
157     vars["Output"] = GenerateMessage(method->get_output_namespace_parts(),
158                                      method->get_output_type_name());
159     vars["MethodName"] = method->name();
160     vars["isNil"] = "";
161     GenerateClientFuncName(method.get(), &*printer, &vars);
162     printer->Print("\n\n");
163   }
164   printer->Print("}\n\n");
165 
166   printer->Print(vars, "extension $ServiceQualifiedName$ClientProtocol {");
167   printer->Print("\n\n");
168   printer->Print(vars,
169                  "  $ACCESS$ var serviceName: String { "
170                  "\"$PATH$$ServiceName$\" }\n");
171 
172   vars["GenAccess"] = service->is_internal() ? "internal " : "public ";
173   for (auto it = 0; it < service->method_count(); it++) {
174     auto method = service->method(it);
175     vars["Input"] = GenerateMessage(method->get_input_namespace_parts(),
176                                     method->get_input_type_name());
177     vars["Output"] = GenerateMessage(method->get_output_namespace_parts(),
178                                      method->get_output_type_name());
179     vars["MethodName"] = method->name();
180     vars["isNil"] = " = nil";
181     printer->Print("\n");
182     GenerateClientFuncName(method.get(), &*printer, &vars);
183     printer->Print(" {\n");
184     GenerateClientFuncBody(method.get(), &*printer, &vars);
185     printer->Print("  }\n");
186   }
187   printer->Print("}\n\n");
188 
189   printer->Print(vars,
190                  "$ACCESS$ protocol "
191                  "$ServiceQualifiedName$ClientInterceptorFactoryProtocol {\n");
192 
193   for (auto it = 0; it < service->method_count(); it++) {
194     auto method = service->method(it);
195     vars["Input"] = GenerateMessage(method->get_input_namespace_parts(),
196                                     method->get_input_type_name());
197     vars["Output"] = GenerateMessage(method->get_output_namespace_parts(),
198                                      method->get_output_type_name());
199     vars["MethodName"] = method->name();
200     printer->Print(
201         vars,
202         "  /// - Returns: Interceptors to use when invoking '$MethodName$'.\n");
203     printer->Print(vars,
204                    "  func make$MethodName$Interceptors() -> "
205                    "[ClientInterceptor<$Input$, $Output$>]\n\n");
206   }
207   printer->Print("}\n\n");
208 }
209 
GenerateClientClass(grpc_generator::Printer * printer,std::map<grpc::string,grpc::string> * dictonary)210 void GenerateClientClass(grpc_generator::Printer *printer,
211                          std::map<grpc::string, grpc::string> *dictonary) {
212   auto vars = *dictonary;
213   printer->Print(vars,
214                  "$ACCESS$ final class $ServiceQualifiedName$ServiceClient: "
215                  "$ServiceQualifiedName$ClientProtocol {\n");
216   printer->Print(vars, "  $ACCESS$ let channel: GRPCChannel\n");
217   printer->Print(vars, "  $ACCESS$ var defaultCallOptions: CallOptions\n");
218   printer->Print(vars,
219                  "  $ACCESS$ var interceptors: "
220                  "$ServiceQualifiedName$ClientInterceptorFactoryProtocol?\n");
221   printer->Print("\n");
222   printer->Print(
223       vars,
224       "  $ACCESS$ init(\n"
225       "    channel: GRPCChannel,\n"
226       "    defaultCallOptions: CallOptions = CallOptions(),\n"
227       "    interceptors: "
228       "$ServiceQualifiedName$ClientInterceptorFactoryProtocol? = nil\n"
229       "  ) {\n");
230   printer->Print("    self.channel = channel\n");
231   printer->Print("    self.defaultCallOptions = defaultCallOptions\n");
232   printer->Print("    self.interceptors = interceptors\n");
233   printer->Print("  }");
234   printer->Print("\n");
235   printer->Print("}\n");
236 }
237 
238 // MARK: - Server
239 
GenerateServerFuncName(const grpc_generator::Method * method)240 grpc::string GenerateServerFuncName(const grpc_generator::Method *method) {
241   if (method->NoStreaming()) {
242     return "func $MethodName$(request: $Input$"
243            ", context: StatusOnlyCallContext) -> EventLoopFuture<$Output$>";
244   }
245 
246   if (method->ClientStreaming()) {
247     return "func $MethodName$(context: UnaryResponseCallContext<$Output$>) -> "
248            "EventLoopFuture<(StreamEvent<$Input$"
249            ">) -> Void>";
250   }
251 
252   if (method->ServerStreaming()) {
253     return "func $MethodName$(request: $Input$"
254            ", context: StreamingResponseCallContext<$Output$>) -> "
255            "EventLoopFuture<GRPCStatus>";
256   }
257   return "func $MethodName$(context: StreamingResponseCallContext<$Output$>) "
258          "-> EventLoopFuture<(StreamEvent<$Input$>) -> Void>";
259 }
260 
GenerateServerExtensionBody(const grpc_generator::Method * method)261 grpc::string GenerateServerExtensionBody(const grpc_generator::Method *method) {
262   grpc::string start = "    case \"$MethodName$\":\n    ";
263   grpc::string interceptors =
264       "      interceptors: self.interceptors?.make$MethodName$Interceptors() "
265       "?? [],\n";
266   if (method->NoStreaming()) {
267     return start +
268            "return UnaryServerHandler(\n"
269            "      context: context,\n"
270            "      requestDeserializer: GRPCPayloadDeserializer<$Input$>(),\n"
271            "      responseSerializer: GRPCPayloadSerializer<$Output$>(),\n" +
272            interceptors +
273            "      userFunction: self.$MethodName$(request:context:))\n";
274   }
275   if (method->ServerStreaming()) {
276     return start +
277            "return ServerStreamingServerHandler(\n"
278            "      context: context,\n"
279            "      requestDeserializer: GRPCPayloadDeserializer<$Input$>(),\n"
280            "      responseSerializer: GRPCPayloadSerializer<$Output$>(),\n" +
281            interceptors +
282            "      userFunction: self.$MethodName$(request:context:))\n";
283   }
284   if (method->ClientStreaming()) {
285     return start +
286            "return ClientStreamingServerHandler(\n"
287            "      context: context,\n"
288            "      requestDeserializer: GRPCPayloadDeserializer<$Input$>(),\n"
289            "      responseSerializer: GRPCPayloadSerializer<$Output$>(),\n" +
290            interceptors +
291            "      observerFactory: self.$MethodName$(context:))\n";
292   }
293   if (method->BidiStreaming()) {
294     return start +
295            "return BidirectionalStreamingServerHandler(\n"
296            "      context: context,\n"
297            "      requestDeserializer: GRPCPayloadDeserializer<$Input$>(),\n"
298            "      responseSerializer: GRPCPayloadSerializer<$Output$>(),\n" +
299            interceptors +
300            "      observerFactory: self.$MethodName$(context:))\n";
301   }
302   return "";
303 }
304 
GenerateServerProtocol(const grpc_generator::Service * service,grpc_generator::Printer * printer,std::map<grpc::string,grpc::string> * dictonary)305 void GenerateServerProtocol(const grpc_generator::Service *service,
306                             grpc_generator::Printer *printer,
307                             std::map<grpc::string, grpc::string> *dictonary) {
308   auto vars = *dictonary;
309   printer->Print(vars,
310                  "$ACCESS$ protocol $ServiceQualifiedName$Provider: "
311                  "CallHandlerProvider {\n");
312   printer->Print(
313       vars,
314       "  var interceptors: "
315       "$ServiceQualifiedName$ServerInterceptorFactoryProtocol? { get }\n");
316   for (auto it = 0; it < service->method_count(); it++) {
317     auto method = service->method(it);
318     vars["Input"] = GenerateMessage(method->get_input_namespace_parts(),
319                                     method->get_input_type_name());
320     vars["Output"] = GenerateMessage(method->get_output_namespace_parts(),
321                                      method->get_output_type_name());
322     vars["MethodName"] = method->name();
323     printer->Print("  ");
324     auto func = GenerateServerFuncName(method.get());
325     printer->Print(vars, func.c_str());
326     printer->Print("\n");
327   }
328   printer->Print("}\n\n");
329 
330   printer->Print(vars, "$ACCESS$ extension $ServiceQualifiedName$Provider {\n");
331   printer->Print("\n");
332   printer->Print(vars,
333                  "  var serviceName: Substring { return "
334                  "\"$PATH$$ServiceName$\" }\n");
335   printer->Print("\n");
336   printer->Print(
337       "  func handle(method name: Substring, context: "
338       "CallHandlerContext) -> GRPCServerHandlerProtocol? {\n");
339   printer->Print("    switch name {\n");
340   for (auto it = 0; it < service->method_count(); it++) {
341     auto method = service->method(it);
342     vars["Input"] = GenerateMessage(method->get_input_namespace_parts(),
343                                     method->get_input_type_name());
344     vars["Output"] = GenerateMessage(method->get_output_namespace_parts(),
345                                      method->get_output_type_name());
346     vars["MethodName"] = method->name();
347     auto body = GenerateServerExtensionBody(method.get());
348     printer->Print(vars, body.c_str());
349     printer->Print("\n");
350   }
351   printer->Print("    default: return nil;\n");
352   printer->Print("    }\n");
353   printer->Print("  }\n\n");
354   printer->Print("}\n\n");
355 
356   printer->Print(vars,
357                  "$ACCESS$ protocol "
358                  "$ServiceQualifiedName$ServerInterceptorFactoryProtocol {\n");
359   for (auto it = 0; it < service->method_count(); it++) {
360     auto method = service->method(it);
361     vars["Input"] = GenerateMessage(method->get_input_namespace_parts(),
362                                     method->get_input_type_name());
363     vars["Output"] = GenerateMessage(method->get_output_namespace_parts(),
364                                      method->get_output_type_name());
365     vars["MethodName"] = method->name();
366     printer->Print(
367         vars,
368         "  /// - Returns: Interceptors to use when handling '$MethodName$'.\n"
369         "  ///   Defaults to calling `self.makeInterceptors()`.\n");
370     printer->Print(vars,
371                    "  func make$MethodName$Interceptors() -> "
372                    "[ServerInterceptor<$Input$, $Output$>]\n\n");
373   }
374   printer->Print("}");
375 }
376 } // namespace
377 
Generate(grpc_generator::File * file,const grpc_generator::Service * service)378 grpc::string Generate(grpc_generator::File *file,
379                       const grpc_generator::Service *service) {
380   grpc::string output;
381   std::map<grpc::string, grpc::string> vars;
382   vars["PATH"] = file->package();
383   if (!file->package().empty()) { vars["PATH"].append("."); }
384   vars["ServiceQualifiedName"] =
385       WrapInNameSpace(service->namespace_parts(), service->name());
386   vars["ServiceName"] = service->name();
387   vars["ACCESS"] = service->is_internal() ? "internal" : "public";
388   auto printer = file->CreatePrinter(&output);
389   printer->Print(
390       vars,
391       "/// Usage: instantiate $ServiceQualifiedName$ServiceClient, then call "
392       "methods of this protocol to make API calls.\n");
393   GenerateClientProtocol(service, &*printer, &vars);
394   GenerateClientClass(&*printer, &vars);
395   printer->Print("\n");
396   GenerateServerProtocol(service, &*printer, &vars);
397   return output;
398 }
399 
GenerateHeader()400 grpc::string GenerateHeader() {
401   grpc::string code;
402   code +=
403       "/// The following code is generated by the Flatbuffers library which "
404       "might not be in sync with grpc-swift\n";
405   code +=
406       "/// in case of an issue please open github issue, though it would be "
407       "maintained\n";
408   code += "\n";
409   code += "// swiftlint:disable all\n";
410   code += "// swiftformat:disable all\n";
411   code += "\n";
412   code += "import Foundation\n";
413   code += "import GRPC\n";
414   code += "import NIO\n";
415   code += "import NIOHTTP1\n";
416   code += "import FlatBuffers\n";
417   code += "\n";
418   code +=
419       "public protocol GRPCFlatBufPayload: GRPCPayload, FlatBufferGRPCMessage "
420       "{}\n";
421 
422   code += "public extension GRPCFlatBufPayload {\n";
423   code += "  init(serializedByteBuffer: inout NIO.ByteBuffer) throws {\n";
424   code +=
425       "    self.init(byteBuffer: FlatBuffers.ByteBuffer(contiguousBytes: "
426       "serializedByteBuffer.readableBytesView, count: "
427       "serializedByteBuffer.readableBytes))\n";
428   code += "  }\n";
429 
430   code += "  func serialize(into buffer: inout NIO.ByteBuffer) throws {\n";
431   code +=
432       "    let buf = UnsafeRawBufferPointer(start: self.rawPointer, count: "
433       "Int(self.size))\n";
434   code += "    buffer.writeBytes(buf)\n";
435   code += "  }\n";
436   code += "}\n";
437   code += "extension Message: GRPCFlatBufPayload {}\n";
438   return code;
439 }
440 }  // namespace grpc_swift_generator
441