xref: /aosp_15_r20/external/flatbuffers/grpc/src/compiler/go_generator.cc (revision 890232f25432b36107d06881e0a25aaa6b473652)
1 #include "src/compiler/go_generator.h"
2 
3 #include <cctype>
4 #include <map>
5 #include <sstream>
6 
as_string(T x)7 template<class T> grpc::string as_string(T x) {
8   std::ostringstream out;
9   out << x;
10   return out.str();
11 }
12 
ClientOnlyStreaming(const grpc_generator::Method * method)13 inline bool ClientOnlyStreaming(const grpc_generator::Method *method) {
14   return method->ClientStreaming() && !method->ServerStreaming();
15 }
16 
ServerOnlyStreaming(const grpc_generator::Method * method)17 inline bool ServerOnlyStreaming(const grpc_generator::Method *method) {
18   return !method->ClientStreaming() && method->ServerStreaming();
19 }
20 
21 namespace grpc_go_generator {
22 namespace {
23 
24 // Returns string with first letter to lowerCase
unexportName(grpc::string s)25 static grpc::string unexportName(grpc::string s) {
26   if (s.empty()) return s;
27   s[0] = static_cast<char>(std::tolower(s[0]));
28   return s;
29 }
30 
31 // Returns string with first letter to uppercase
exportName(grpc::string s)32 static grpc::string exportName(grpc::string s) {
33   if (s.empty()) return s;
34   s[0] = static_cast<char>(std::toupper(s[0]));
35   return s;
36 }
37 
GenerateError(grpc_generator::Printer * printer,std::map<grpc::string,grpc::string> vars,const bool multiple_return=true)38 static void GenerateError(grpc_generator::Printer *printer,
39                    std::map<grpc::string, grpc::string> vars,
40                    const bool multiple_return = true) {
41   printer->Print(vars, "if $Error_Check$ {\n");
42   printer->Indent();
43   vars["Return"] = multiple_return ? "nil, err" : "err";
44   printer->Print(vars, "return $Return$\n");
45   printer->Outdent();
46   printer->Print("}\n");
47 }
48 
49 // Generates imports for the service
GenerateImports(grpc_generator::File * file,grpc_generator::Printer * printer,std::map<grpc::string,grpc::string> vars)50 static void GenerateImports(grpc_generator::File *file,
51                      grpc_generator::Printer *printer,
52                      std::map<grpc::string, grpc::string> vars) {
53   vars["filename"] = file->filename();
54   printer->Print("//Generated by gRPC Go plugin\n");
55   printer->Print("//If you make any local changes, they will be lost\n");
56   printer->Print(vars, "//source: $filename$\n\n");
57   printer->Print(vars, "package $Package$\n\n");
58   printer->Print("import (\n");
59   printer->Indent();
60   printer->Print(vars, "$context$ \"context\"\n");
61   printer->Print("flatbuffers \"github.com/google/flatbuffers/go\"\n");
62   printer->Print(vars, "$grpc$ \"google.golang.org/grpc\"\n");
63   printer->Print("\"google.golang.org/grpc/codes\"\n");
64   printer->Print("\"google.golang.org/grpc/status\"\n");
65   printer->Outdent();
66   printer->Print(")\n\n");
67 }
68 
69 // Generates Server method signature source
GenerateServerMethodSignature(const grpc_generator::Method * method,grpc_generator::Printer * printer,std::map<grpc::string,grpc::string> vars)70 static void GenerateServerMethodSignature(const grpc_generator::Method *method,
71                                    grpc_generator::Printer *printer,
72                                    std::map<grpc::string, grpc::string> vars) {
73   vars["Method"] = exportName(method->name());
74   vars["Request"] = method->get_input_type_name();
75   vars["Response"] = (vars["CustomMethodIO"] == "")
76                          ? method->get_output_type_name()
77                          : vars["CustomMethodIO"];
78   if (method->NoStreaming()) {
79     printer->Print(
80         vars,
81         "$Method$($context$.Context, *$Request$) (*$Response$, error)$Ending$");
82   } else if (ServerOnlyStreaming(method)) {
83     printer->Print(
84         vars, "$Method$(*$Request$, $Service$_$Method$Server) error$Ending$");
85   } else {
86     printer->Print(vars, "$Method$($Service$_$Method$Server) error$Ending$");
87   }
88 }
89 
GenerateServerMethod(const grpc_generator::Method * method,grpc_generator::Printer * printer,std::map<grpc::string,grpc::string> vars)90 static void GenerateServerMethod(const grpc_generator::Method *method,
91                           grpc_generator::Printer *printer,
92                           std::map<grpc::string, grpc::string> vars) {
93   vars["Method"] = exportName(method->name());
94   vars["Request"] = method->get_input_type_name();
95   vars["Response"] = (vars["CustomMethodIO"] == "")
96                          ? method->get_output_type_name()
97                          : vars["CustomMethodIO"];
98   vars["FullMethodName"] =
99       "/" + vars["ServicePrefix"] + vars["Service"] + "/" + vars["Method"];
100   vars["Handler"] = "_" + vars["Service"] + "_" + vars["Method"] + "_Handler";
101   if (method->NoStreaming()) {
102     printer->Print(
103         vars,
104         "func $Handler$(srv interface{}, ctx $context$.Context,\n\tdec "
105         "func(interface{}) error, interceptor $grpc$.UnaryServerInterceptor) "
106         "(interface{}, error) {\n");
107     printer->Indent();
108     printer->Print(vars, "in := new($Request$)\n");
109     vars["Error_Check"] = "err := dec(in); err != nil";
110     GenerateError(printer, vars);
111     printer->Print("if interceptor == nil {\n");
112     printer->Indent();
113     printer->Print(vars, "return srv.($Service$Server).$Method$(ctx, in)\n");
114     printer->Outdent();
115     printer->Print("}\n");
116     printer->Print(vars, "info := &$grpc$.UnaryServerInfo{\n");
117     printer->Indent();
118     printer->Print("Server:     srv,\n");
119     printer->Print(vars, "FullMethod: \"$FullMethodName$\",\n");
120     printer->Outdent();
121     printer->Print("}\n");
122     printer->Outdent();
123     printer->Print("\n");
124     printer->Indent();
125     printer->Print(vars,
126                    "handler := func(ctx $context$.Context, req interface{}) "
127                    "(interface{}, error) {\n");
128     printer->Indent();
129     printer->Print(
130         vars, "return srv.($Service$Server).$Method$(ctx, req.(*$Request$))\n");
131     printer->Outdent();
132     printer->Print("}\n");
133     printer->Print("return interceptor(ctx, in, info, handler)\n");
134     printer->Outdent();
135     printer->Print("}\n");
136     return;
137   }
138   vars["StreamType"] = vars["ServiceUnexported"] + vars["Method"] + "Server";
139   printer->Print(
140       vars,
141       "func $Handler$(srv interface{}, stream $grpc$.ServerStream) error {\n");
142   printer->Indent();
143   if (ServerOnlyStreaming(method)) {
144     printer->Print(vars, "m := new($Request$)\n");
145     vars["Error_Check"] = "err := stream.RecvMsg(m); err != nil";
146     GenerateError(printer, vars, false);
147     printer->Print(
148         vars,
149         "return srv.($Service$Server).$Method$(m, &$StreamType${stream})\n");
150   } else {
151     printer->Print(
152         vars, "return srv.($Service$Server).$Method$(&$StreamType${stream})\n");
153   }
154   printer->Outdent();
155   printer->Print("}\n\n");
156 
157   bool genSend = method->BidiStreaming() || ServerOnlyStreaming(method);
158   bool genRecv = method->BidiStreaming() || ClientOnlyStreaming(method);
159   bool genSendAndClose = ClientOnlyStreaming(method);
160 
161   printer->Print(vars, "type $Service$_$Method$Server interface {\n");
162   printer->Indent();
163   if (genSend) { printer->Print(vars, "Send(*$Response$) error\n"); }
164   if (genRecv) { printer->Print(vars, "Recv() (*$Request$, error)\n"); }
165   if (genSendAndClose) {
166     printer->Print(vars, "SendAndClose(*$Response$) error\n");
167   }
168   printer->Print(vars, "$grpc$.ServerStream\n");
169   printer->Outdent();
170   printer->Print("}\n\n");
171 
172   printer->Print(vars, "type $StreamType$ struct {\n");
173   printer->Indent();
174   printer->Print(vars, "$grpc$.ServerStream\n");
175   printer->Outdent();
176   printer->Print("}\n\n");
177 
178   if (genSend) {
179     printer->Print(vars,
180                    "func (x *$StreamType$) Send(m *$Response$) error {\n");
181     printer->Indent();
182     printer->Print("return x.ServerStream.SendMsg(m)\n");
183     printer->Outdent();
184     printer->Print("}\n\n");
185   }
186   if (genRecv) {
187     printer->Print(vars,
188                    "func (x *$StreamType$) Recv() (*$Request$, error) {\n");
189     printer->Indent();
190     printer->Print(vars, "m := new($Request$)\n");
191     vars["Error_Check"] = "err := x.ServerStream.RecvMsg(m); err != nil";
192     GenerateError(printer, vars);
193     printer->Print("return m, nil\n");
194     printer->Outdent();
195     printer->Print("}\n\n");
196   }
197   if (genSendAndClose) {
198     printer->Print(
199         vars, "func (x *$StreamType$) SendAndClose(m *$Response$) error {\n");
200     printer->Indent();
201     printer->Print("return x.ServerStream.SendMsg(m)\n");
202     printer->Outdent();
203     printer->Print("}\n\n");
204   }
205 }
206 
207 // Generates Client method signature source
GenerateClientMethodSignature(const grpc_generator::Method * method,grpc_generator::Printer * printer,std::map<grpc::string,grpc::string> vars)208 static void GenerateClientMethodSignature(const grpc_generator::Method *method,
209                                    grpc_generator::Printer *printer,
210                                    std::map<grpc::string, grpc::string> vars) {
211   vars["Method"] = exportName(method->name());
212   vars["Request"] =
213       ", in *" + ((vars["CustomMethodIO"] == "") ? method->get_input_type_name()
214                                                  : vars["CustomMethodIO"]);
215   if (ClientOnlyStreaming(method) || method->BidiStreaming()) {
216     vars["Request"] = "";
217   }
218   vars["Response"] = "*" + method->get_output_type_name();
219   if (ClientOnlyStreaming(method) || method->BidiStreaming() ||
220       ServerOnlyStreaming(method)) {
221     vars["Response"] = vars["Service"] + "_" + vars["Method"] + "Client";
222   }
223   printer->Print(vars,
224                  "$Method$(ctx $context$.Context$Request$,\n\topts "
225                  "...$grpc$.CallOption) ($Response$, error)$Ending$");
226 }
227 
228 // Generates Client method source
GenerateClientMethod(const grpc_generator::Method * method,grpc_generator::Printer * printer,std::map<grpc::string,grpc::string> vars)229 static void GenerateClientMethod(const grpc_generator::Method *method,
230                           grpc_generator::Printer *printer,
231                           std::map<grpc::string, grpc::string> vars) {
232   printer->Print(vars, "func (c *$ServiceUnexported$Client) ");
233   vars["Ending"] = " {\n";
234   GenerateClientMethodSignature(method, printer, vars);
235   printer->Indent();
236   vars["Method"] = exportName(method->name());
237   vars["Request"] = (vars["CustomMethodIO"] == "")
238                         ? method->get_input_type_name()
239                         : vars["CustomMethodIO"];
240   vars["Response"] = method->get_output_type_name();
241   vars["FullMethodName"] =
242       "/" + vars["ServicePrefix"] + vars["Service"] + "/" + vars["Method"];
243   if (method->NoStreaming()) {
244     printer->Print(vars, "out := new($Response$)\n");
245     printer->Print(
246         vars,
247         "err := c.cc.Invoke(ctx, \"$FullMethodName$\", in, out, opts...)\n");
248     vars["Error_Check"] = "err != nil";
249     GenerateError(printer, vars);
250     printer->Print("return out, nil\n");
251     printer->Outdent();
252     printer->Print("}\n\n");
253     return;
254   }
255   vars["StreamType"] = vars["ServiceUnexported"] + vars["Method"] + "Client";
256   printer->Print(vars,
257                  "stream, err := c.cc.NewStream(ctx, &$MethodDesc$, "
258                  "\"$FullMethodName$\", opts...)\n");
259   vars["Error_Check"] = "err != nil";
260   GenerateError(printer, vars);
261 
262   printer->Print(vars, "x := &$StreamType${stream}\n");
263   if (ServerOnlyStreaming(method)) {
264     vars["Error_Check"] = "err := x.ClientStream.SendMsg(in); err != nil";
265     GenerateError(printer, vars);
266     vars["Error_Check"] = "err := x.ClientStream.CloseSend(); err != nil";
267     GenerateError(printer, vars);
268   }
269   printer->Print("return x, nil\n");
270   printer->Outdent();
271   printer->Print("}\n\n");
272 
273   bool genSend = method->BidiStreaming() || ClientOnlyStreaming(method);
274   bool genRecv = method->BidiStreaming() || ServerOnlyStreaming(method);
275   bool genCloseAndRecv = ClientOnlyStreaming(method);
276 
277   // Stream interface
278   printer->Print(vars, "type $Service$_$Method$Client interface {\n");
279   printer->Indent();
280   if (genSend) { printer->Print(vars, "Send(*$Request$) error\n"); }
281   if (genRecv) { printer->Print(vars, "Recv() (*$Response$, error)\n"); }
282   if (genCloseAndRecv) {
283     printer->Print(vars, "CloseAndRecv() (*$Response$, error)\n");
284   }
285   printer->Print(vars, "$grpc$.ClientStream\n");
286   printer->Outdent();
287   printer->Print("}\n\n");
288 
289   // Stream Client
290   printer->Print(vars, "type $StreamType$ struct {\n");
291   printer->Indent();
292   printer->Print(vars, "$grpc$.ClientStream\n");
293   printer->Outdent();
294   printer->Print("}\n\n");
295 
296   if (genSend) {
297     printer->Print(vars, "func (x *$StreamType$) Send(m *$Request$) error {\n");
298     printer->Indent();
299     printer->Print("return x.ClientStream.SendMsg(m)\n");
300     printer->Outdent();
301     printer->Print("}\n\n");
302   }
303 
304   if (genRecv) {
305     printer->Print(vars,
306                    "func (x *$StreamType$) Recv() (*$Response$, error) {\n");
307     printer->Indent();
308     printer->Print(vars, "m := new($Response$)\n");
309     vars["Error_Check"] = "err := x.ClientStream.RecvMsg(m); err != nil";
310     GenerateError(printer, vars);
311     printer->Print("return m, nil\n");
312     printer->Outdent();
313     printer->Print("}\n\n");
314   }
315 
316   if (genCloseAndRecv) {
317     printer->Print(
318         vars, "func (x *$StreamType$) CloseAndRecv() (*$Response$, error) {\n");
319     printer->Indent();
320     vars["Error_Check"] = "err := x.ClientStream.CloseSend(); err != nil";
321     GenerateError(printer, vars);
322     printer->Print(vars, "m := new($Response$)\n");
323     vars["Error_Check"] = "err := x.ClientStream.RecvMsg(m); err != nil";
324     GenerateError(printer, vars);
325     printer->Print("return m, nil\n");
326     printer->Outdent();
327     printer->Print("}\n\n");
328   }
329 }
330 
331 // Generates client API for the service
GenerateService(const grpc_generator::Service * service,grpc_generator::Printer * printer,std::map<grpc::string,grpc::string> vars)332 void GenerateService(const grpc_generator::Service *service,
333                      grpc_generator::Printer *printer,
334                      std::map<grpc::string, grpc::string> vars) {
335   vars["Service"] = exportName(service->name());
336   // Client Interface
337   printer->Print(vars, "// Client API for $Service$ service\n");
338   printer->Print(vars, "type $Service$Client interface {\n");
339   printer->Indent();
340   vars["Ending"] = "\n";
341   for (int i = 0; i < service->method_count(); i++) {
342     GenerateClientMethodSignature(service->method(i).get(), printer, vars);
343   }
344   printer->Outdent();
345   printer->Print("}\n\n");
346 
347   // Client structure
348   vars["ServiceUnexported"] = unexportName(vars["Service"]);
349   printer->Print(vars, "type $ServiceUnexported$Client struct {\n");
350   printer->Indent();
351   printer->Print(vars, "cc $grpc$.ClientConnInterface\n");
352   printer->Outdent();
353   printer->Print("}\n\n");
354 
355   // NewClient
356   printer->Print(vars,
357                  "func New$Service$Client(cc $grpc$.ClientConnInterface) "
358                  "$Service$Client {\n");
359   printer->Indent();
360   printer->Print(vars, "return &$ServiceUnexported$Client{cc}");
361   printer->Outdent();
362   printer->Print("\n}\n\n");
363 
364   int unary_methods = 0, streaming_methods = 0;
365   vars["ServiceDesc"] = "_" + vars["Service"] + "_serviceDesc";
366   for (int i = 0; i < service->method_count(); i++) {
367     auto method = service->method(i);
368     if (method->NoStreaming()) {
369       vars["MethodDesc"] =
370           vars["ServiceDesc"] + ".Method[" + as_string(unary_methods) + "]";
371       unary_methods++;
372     } else {
373       vars["MethodDesc"] = vars["ServiceDesc"] + ".Streams[" +
374                            as_string(streaming_methods) + "]";
375       streaming_methods++;
376     }
377     GenerateClientMethod(method.get(), printer, vars);
378   }
379 
380   // Server Interface
381   printer->Print(vars, "// Server API for $Service$ service\n");
382   printer->Print(vars, "type $Service$Server interface {\n");
383   printer->Indent();
384   vars["Ending"] = "\n";
385   for (int i = 0; i < service->method_count(); i++) {
386     GenerateServerMethodSignature(service->method(i).get(), printer, vars);
387   }
388   printer->Print(vars, "mustEmbedUnimplemented$Service$Server()\n");
389   printer->Outdent();
390   printer->Print("}\n\n");
391 
392   printer->Print(vars, "type Unimplemented$Service$Server struct {\n");
393   printer->Print("}\n\n");
394 
395   vars["Ending"] = " {\n";
396   for (int i = 0; i < service->method_count(); i++) {
397     auto method = service->method(i);
398     vars["Method"] = exportName(method->name());
399     vars["Nil"] = method->NoStreaming() ? "nil, " : "";
400     printer->Print(vars, "func (Unimplemented$Service$Server) ");
401     GenerateServerMethodSignature(method.get(), printer, vars);
402     printer->Indent();
403     printer->Print(vars,
404                    "return $Nil$status.Errorf(codes.Unimplemented, \"method "
405                    "$Method$ not implemented\")\n");
406     printer->Outdent();
407     printer->Print("}\n");
408     printer->Print("\n");
409   }
410 
411   printer->Print(vars,
412                  "func (Unimplemented$Service$Server) "
413                  "mustEmbedUnimplemented$Service$Server() {}");
414   printer->Print("\n\n");
415 
416   printer->Print(vars, "type Unsafe$Service$Server interface {\n");
417   printer->Indent();
418   printer->Print(vars, "mustEmbedUnimplemented$Service$Server()\n");
419   printer->Outdent();
420   printer->Print("}\n\n");
421   // Server registration.
422   printer->Print(vars,
423                  "func Register$Service$Server(s $grpc$.ServiceRegistrar, srv "
424                  "$Service$Server) {\n");
425   printer->Indent();
426   printer->Print(vars, "s.RegisterService(&$ServiceDesc$, srv)\n");
427   printer->Outdent();
428   printer->Print("}\n\n");
429 
430   for (int i = 0; i < service->method_count(); i++) {
431     GenerateServerMethod(service->method(i).get(), printer, vars);
432   }
433 
434   // Service Descriptor
435   printer->Print(vars, "var $ServiceDesc$ = $grpc$.ServiceDesc{\n");
436   printer->Indent();
437   printer->Print(vars, "ServiceName: \"$ServicePrefix$$Service$\",\n");
438   printer->Print(vars, "HandlerType: (*$Service$Server)(nil),\n");
439   printer->Print(vars, "Methods: []$grpc$.MethodDesc{\n");
440   printer->Indent();
441   for (int i = 0; i < service->method_count(); i++) {
442     auto method = service->method(i);
443     vars["Method"] = exportName(method->name());
444     vars["Handler"] = "_" + vars["Service"] + "_" + vars["Method"] + "_Handler";
445     if (method->NoStreaming()) {
446       printer->Print("{\n");
447       printer->Indent();
448       printer->Print(vars, "MethodName: \"$Method$\",\n");
449       printer->Print(vars, "Handler:    $Handler$,\n");
450       printer->Outdent();
451       printer->Print("},\n");
452     }
453   }
454   printer->Outdent();
455   printer->Print("},\n");
456   printer->Print(vars, "Streams: []$grpc$.StreamDesc{\n");
457   printer->Indent();
458   for (int i = 0; i < service->method_count(); i++) {
459     auto method = service->method(i);
460     vars["Method"] = exportName(method->name());
461     vars["Handler"] = "_" + vars["Service"] + "_" + vars["Method"] + "_Handler";
462     if (!method->NoStreaming()) {
463       printer->Print("{\n");
464       printer->Indent();
465       printer->Print(vars, "StreamName:    \"$Method$\",\n");
466       printer->Print(vars, "Handler:       $Handler$,\n");
467       if (ClientOnlyStreaming(method.get())) {
468         printer->Print("ClientStreams: true,\n");
469       } else if (ServerOnlyStreaming(method.get())) {
470         printer->Print("ServerStreams: true,\n");
471       } else {
472         printer->Print("ServerStreams: true,\n");
473         printer->Print("ClientStreams: true,\n");
474       }
475       printer->Outdent();
476       printer->Print("},\n");
477     }
478   }
479   printer->Outdent();
480   printer->Print("},\n");
481   printer->Outdent();
482   printer->Print("}\n");
483 }
484 }  // namespace
485 
486 // Returns source for the service
GenerateServiceSource(grpc_generator::File * file,const grpc_generator::Service * service,grpc_go_generator::Parameters * parameters)487 grpc::string GenerateServiceSource(grpc_generator::File *file,
488                                    const grpc_generator::Service *service,
489                                    grpc_go_generator::Parameters *parameters) {
490   grpc::string out;
491   auto p = file->CreatePrinter(&out, '\t');
492   p->SetIndentationSize(1);
493   auto printer = p.get();
494   std::map<grpc::string, grpc::string> vars;
495   vars["Package"] = parameters->package_name;
496   vars["ServicePrefix"] = parameters->service_prefix;
497   if (!parameters->service_prefix.empty()) vars["ServicePrefix"].append(".");
498   vars["grpc"] = "grpc";
499   vars["context"] = "context";
500   GenerateImports(file, printer, vars);
501   if (parameters->custom_method_io_type != "") {
502     vars["CustomMethodIO"] = parameters->custom_method_io_type;
503   }
504   GenerateService(service, printer, vars);
505   return out;
506 }
507 }  // Namespace grpc_go_generator
508