xref: /aosp_15_r20/external/flatbuffers/grpc/src/compiler/java_generator.cc (revision 890232f25432b36107d06881e0a25aaa6b473652)
1 /*
2  * Copyright 2016 Google Inc. All rights reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "src/compiler/java_generator.h"

#include <algorithm>
#include <cctype>
#include <iostream>
#include <iterator>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "flatbuffers/util.h"
#define to_string flatbuffers::NumToString
28 
// Two-level stringification helpers, used solely to turn GRPC_VERSION into a
// string literal. STR quotes its argument verbatim; XSTR macro-expands its
// argument first and then quotes the result.
#ifndef STR
#define STR(s) #s
#endif

#ifndef XSTR
#define XSTR(s) STR(s)
#endif
37 
38 typedef grpc_generator::Printer Printer;
39 typedef std::map<grpc::string, grpc::string> VARS;
40 typedef grpc_generator::Service ServiceDescriptor;
41 typedef grpc_generator::CommentHolder
42     DescriptorType;  // base class of all 'descriptors'
43 typedef grpc_generator::Method MethodDescriptor;
44 
45 namespace grpc_java_generator {
using string = std::string;  // unqualified `string` used throughout this file
47 namespace {
48 // Generates imports for the service
GenerateImports(grpc_generator::File * file,grpc_generator::Printer * printer,VARS & vars)49 static void GenerateImports(grpc_generator::File *file,
50                      grpc_generator::Printer *printer, VARS &vars) {
51   vars["filename"] = file->filename();
52   printer->Print(vars,
53                  "//Generated by flatc compiler (version $flatc_version$)\n");
54   printer->Print("//If you make any local changes, they will be lost\n");
55   printer->Print(vars, "//source: $filename$.fbs\n\n");
56   printer->Print(vars, "package $Package$;\n\n");
57   vars["Package"] = vars["Package"] + ".";
58   if (!file->additional_headers().empty()) {
59     printer->Print(file->additional_headers().c_str());
60     printer->Print("\n\n");
61   }
62 }
63 
// Adjusts a method name prefix identifier to follow the JavaBean spec:
//   - decapitalize the first letter
//   - remove embedded underscores & capitalize the following letter
// e.g. "get_SayHello_method" -> "getSayHelloMethod". An empty input yields an
// empty result. Characters are passed to <cctype> as unsigned char because a
// negative char value is undefined behavior for those functions.
static std::string MixedLower(const std::string &word) {
  std::string w;
  if (word.empty()) { return w; }
  w.reserve(word.size());
  w += static_cast<char>(std::tolower(static_cast<unsigned char>(word[0])));
  bool after_underscore = false;
  for (std::size_t i = 1; i < word.length(); ++i) {
    const unsigned char c = static_cast<unsigned char>(word[i]);
    if (c == '_') {
      after_underscore = true;
    } else {
      w += after_underscore ? static_cast<char>(std::toupper(c))
                            : static_cast<char>(c);
      after_underscore = false;
    }
  }
  return w;
}
82 
// Converts an identifier to the ALL_UPPER_CASE format.
//   - An underscore is inserted where a lower case letter is followed by an
//     upper case letter.
//   - All letters are converted to upper case.
// e.g. "SayHello" -> "SAY_HELLO". Characters are passed to <cctype> as
// unsigned char because a negative char value is undefined behavior there.
static std::string ToAllUpperCase(const std::string &word) {
  std::string w;
  w.reserve(word.size() * 2);
  for (std::size_t i = 0; i < word.length(); ++i) {
    const unsigned char c = static_cast<unsigned char>(word[i]);
    w += static_cast<char>(std::toupper(c));
    if (i + 1 < word.length() && std::islower(c) &&
        std::isupper(static_cast<unsigned char>(word[i + 1]))) {
      w += '_';
    }
  }
  return w;
}
97 
LowerMethodName(const MethodDescriptor * method)98 static inline string LowerMethodName(const MethodDescriptor *method) {
99   return MixedLower(method->name());
100 }
101 
MethodPropertiesFieldName(const MethodDescriptor * method)102 static inline string MethodPropertiesFieldName(const MethodDescriptor *method) {
103   return "METHOD_" + ToAllUpperCase(method->name());
104 }
105 
MethodPropertiesGetterName(const MethodDescriptor * method)106 static inline string MethodPropertiesGetterName(
107     const MethodDescriptor *method) {
108   return MixedLower("get_" + method->name() + "_method");
109 }
110 
MethodIdFieldName(const MethodDescriptor * method)111 static inline string MethodIdFieldName(const MethodDescriptor *method) {
112   return "METHODID_" + ToAllUpperCase(method->name());
113 }
114 
JavaClassName(VARS & vars,const string & name)115 static inline string JavaClassName(VARS &vars, const string &name) {
116   // string name = google::protobuf::compiler::java::ClassName(desc);
117   return vars["Package"] + name;
118 }
119 
// Name of the generated outer Java class for a service: "<service>Grpc".
static inline string ServiceClassName(const string &service_name) {
  string class_name(service_name);
  class_name.append("Grpc");
  return class_name;
}
123 
// TODO(nmittler): Remove once protobuf includes javadoc methods in
// distribution.
//
// Splits `full` on any character in `delim` and writes each token through the
// output iterator `result`. Empty tokens (from leading, trailing, or repeated
// delimiters) are never emitted.
template<typename ITR>
static void GrpcSplitStringToIteratorUsing(const string &full,
                                           const char *delim, ITR &result) {
  if (delim[0] != '\0' && delim[1] == '\0') {
    // Fast path for the common single-character delimiter.
    const char sep = delim[0];
    const string::size_type len = full.size();
    string::size_type pos = 0;
    while (pos < len) {
      if (full[pos] == sep) {
        ++pos;
        continue;
      }
      const string::size_type token_begin = pos;
      do {
        ++pos;
      } while (pos < len && full[pos] != sep);
      *result++ = full.substr(token_begin, pos - token_begin);
    }
    return;
  }

  // General path: any character of `delim` separates tokens.
  for (string::size_type begin = full.find_first_not_of(delim);
       begin != string::npos;) {
    const string::size_type end = full.find_first_of(delim, begin);
    // When end == npos, the length wraps to "rest of string" in substr, and
    // find_first_not_of(delim, npos) returns npos, ending the loop.
    *result++ = full.substr(begin, end - begin);
    begin = full.find_first_not_of(delim, end);
  }
}
159 
GrpcSplitStringUsing(const string & full,const char * delim,std::vector<string> * result)160 static void GrpcSplitStringUsing(const string &full, const char *delim,
161                                  std::vector<string> *result) {
162   std::back_insert_iterator<std::vector<string>> it(*result);
163   GrpcSplitStringToIteratorUsing(full, delim, it);
164 }
165 
GrpcSplit(const string & full,const char * delim)166 static std::vector<string> GrpcSplit(const string &full, const char *delim) {
167   std::vector<string> result;
168   GrpcSplitStringUsing(full, delim, &result);
169   return result;
170 }
171 
// TODO(nmittler): Remove once protobuf includes javadoc methods in
// distribution.
//
// Escapes `input` for embedding inside a Javadoc comment:
//   - "/*" and "*/" sequences are broken up with HTML entities;
//   - '@' is escaped so it cannot start a Javadoc tag;
//   - '<', '>' and '&' are HTML-escaped;
//   - '\' is escaped because Java interprets Unicode escapes anywhere.
static string GrpcEscapeJavadoc(const string &input) {
  string escaped;
  escaped.reserve(input.size() * 2);

  // Starting from '*' means a leading '/' is escaped as well.
  char previous = '*';
  for (const char ch : input) {
    if (ch == '*' && previous == '/') {
      escaped += "&#42;";  // would otherwise form "/*"
    } else if (ch == '/' && previous == '*') {
      escaped += "&#47;";  // would otherwise form "*/"
    } else if (ch == '@') {
      // '@' starts javadoc tags including @deprecated, which causes a
      // compile-time error before a declaration lacking @Deprecated.
      escaped += "&#64;";
    } else if (ch == '<') {
      escaped += "&lt;";
    } else if (ch == '>') {
      escaped += "&gt;";
    } else if (ch == '&') {
      escaped += "&amp;";
    } else if (ch == '\\') {
      escaped += "&#92;";
    } else {
      escaped += ch;
    }
    previous = ch;
  }
  return escaped;
}
229 
GrpcGetDocLines(const string & comments)230 static std::vector<string> GrpcGetDocLines(const string &comments) {
231   if (!comments.empty()) {
232     // TODO(kenton):  Ideally we should parse the comment text as Markdown and
233     //   write it back as HTML, but this requires a Markdown parser.  For now
234     //   we just use <pre> to get fixed-width text formatting.
235 
236     // If the comment itself contains block comment start or end markers,
237     // HTML-escape them so that they don't accidentally close the doc comment.
238     string escapedComments = GrpcEscapeJavadoc(comments);
239 
240     std::vector<string> lines = GrpcSplit(escapedComments, "\n");
241     while (!lines.empty() && lines.back().empty()) { lines.pop_back(); }
242     return lines;
243   }
244   return std::vector<string>();
245 }
246 
GrpcGetDocLinesForDescriptor(const DescriptorType * descriptor)247 static std::vector<string> GrpcGetDocLinesForDescriptor(
248     const DescriptorType *descriptor) {
249   return descriptor->GetAllComments();
250   // return GrpcGetDocLines(descriptor->GetLeadingComments("///"));
251 }
252 
GrpcWriteDocCommentBody(Printer * printer,VARS & vars,const std::vector<string> & lines,bool surroundWithPreTag)253 static void GrpcWriteDocCommentBody(Printer *printer, VARS &vars,
254                                     const std::vector<string> &lines,
255                                     bool surroundWithPreTag) {
256   if (!lines.empty()) {
257     if (surroundWithPreTag) { printer->Print(" * <pre>\n"); }
258 
259     for (size_t i = 0; i < lines.size(); i++) {
260       // Most lines should start with a space.  Watch out for lines that start
261       // with a /, since putting that right after the leading asterisk will
262       // close the comment.
263       vars["line"] = lines[i];
264       if (!lines[i].empty() && lines[i][0] == '/') {
265         printer->Print(vars, " * $line$\n");
266       } else {
267         printer->Print(vars, " *$line$\n");
268       }
269     }
270 
271     if (surroundWithPreTag) { printer->Print(" * </pre>\n"); }
272   }
273 }
274 
GrpcWriteDocComment(Printer * printer,VARS & vars,const string & comments)275 static void GrpcWriteDocComment(Printer *printer, VARS &vars,
276                                 const string &comments) {
277   printer->Print("/**\n");
278   std::vector<string> lines = GrpcGetDocLines(comments);
279   GrpcWriteDocCommentBody(printer, vars, lines, false);
280   printer->Print(" */\n");
281 }
282 
GrpcWriteServiceDocComment(Printer * printer,VARS & vars,const ServiceDescriptor * service)283 static void GrpcWriteServiceDocComment(Printer *printer, VARS &vars,
284                                        const ServiceDescriptor *service) {
285   printer->Print("/**\n");
286   std::vector<string> lines = GrpcGetDocLinesForDescriptor(service);
287   GrpcWriteDocCommentBody(printer, vars, lines, true);
288   printer->Print(" */\n");
289 }
290 
GrpcWriteMethodDocComment(Printer * printer,VARS & vars,const MethodDescriptor * method)291 static void GrpcWriteMethodDocComment(Printer *printer, VARS &vars,
292                                const MethodDescriptor *method) {
293   printer->Print("/**\n");
294   std::vector<string> lines = GrpcGetDocLinesForDescriptor(method);
295   GrpcWriteDocCommentBody(printer, vars, lines, true);
296   printer->Print(" */\n");
297 }
298 
299 // outputs static singleton extractor for type stored in "extr_type" and
300 // "extr_type_name" vars
PrintTypeExtractor(Printer * p,VARS & vars)301 static void PrintTypeExtractor(Printer *p, VARS &vars) {
302   p->Print(vars,
303            "private static volatile FlatbuffersUtils.FBExtactor<$extr_type$> "
304            "extractorOf$extr_type_name$;\n"
305            "private static FlatbuffersUtils.FBExtactor<$extr_type$> "
306            "getExtractorOf$extr_type_name$() {\n"
307            "    if (extractorOf$extr_type_name$ != null) return "
308            "extractorOf$extr_type_name$;\n"
309            "    synchronized ($service_class_name$.class) {\n"
310            "        if (extractorOf$extr_type_name$ != null) return "
311            "extractorOf$extr_type_name$;\n"
312            "        extractorOf$extr_type_name$ = new "
313            "FlatbuffersUtils.FBExtactor<$extr_type$>() {\n"
314            "            public $extr_type$ extract (ByteBuffer buffer) {\n"
315            "                return "
316            "$extr_type$.getRootAs$extr_type_name$(buffer);\n"
317            "            }\n"
318            "        };\n"
319            "        return extractorOf$extr_type_name$;\n"
320            "    }\n"
321            "}\n\n");
322 }
PrintMethodFields(Printer * p,VARS & vars,const ServiceDescriptor * service)323 static void PrintMethodFields(Printer *p, VARS &vars,
324                               const ServiceDescriptor *service) {
325   p->Print("// Static method descriptors that strictly reflect the proto.\n");
326   vars["service_name"] = service->name();
327 
328   // set of names of rpc input- and output- types that were already encountered.
329   // this is needed to avoid duplicating type extractor since it's possible that
330   // the same type is used as an input or output type of more than a single RPC
331   // method
332   std::set<std::string> encounteredTypes;
333 
334   for (int i = 0; i < service->method_count(); ++i) {
335     auto method = service->method(i);
336     vars["arg_in_id"] = to_string(2L * i);  // trying to make msvc 10 happy
337     vars["arg_out_id"] = to_string(2L * i + 1);
338     vars["method_name"] = method->name();
339     vars["input_type_name"] = method->get_input_type_name();
340     vars["output_type_name"] = method->get_output_type_name();
341     vars["input_type"] = JavaClassName(vars, method->get_input_type_name());
342     vars["output_type"] = JavaClassName(vars, method->get_output_type_name());
343     vars["method_field_name"] = MethodPropertiesFieldName(method.get());
344     vars["method_new_field_name"] = MethodPropertiesGetterName(method.get());
345     vars["method_method_name"] = MethodPropertiesGetterName(method.get());
346     bool client_streaming =
347         method->ClientStreaming() || method->BidiStreaming();
348     bool server_streaming =
349         method->ServerStreaming() || method->BidiStreaming();
350     if (client_streaming) {
351       if (server_streaming) {
352         vars["method_type"] = "BIDI_STREAMING";
353       } else {
354         vars["method_type"] = "CLIENT_STREAMING";
355       }
356     } else {
357       if (server_streaming) {
358         vars["method_type"] = "SERVER_STREAMING";
359       } else {
360         vars["method_type"] = "UNARY";
361       }
362     }
363 
364     p->Print(
365         vars,
366         "@$ExperimentalApi$(\"https://github.com/grpc/grpc-java/issues/"
367         "1901\")\n"
368         "@$Deprecated$ // Use {@link #$method_method_name$()} instead. \n"
369         "public static final $MethodDescriptor$<$input_type$,\n"
370         "    $output_type$> $method_field_name$ = $method_method_name$();\n"
371         "\n"
372         "private static volatile $MethodDescriptor$<$input_type$,\n"
373         "    $output_type$> $method_new_field_name$;\n"
374         "\n");
375 
376     if (encounteredTypes.insert(vars["input_type_name"]).second) {
377       vars["extr_type"] = vars["input_type"];
378       vars["extr_type_name"] = vars["input_type_name"];
379       PrintTypeExtractor(p, vars);
380     }
381 
382     if (encounteredTypes.insert(vars["output_type_name"]).second) {
383       vars["extr_type"] = vars["output_type"];
384       vars["extr_type_name"] = vars["output_type_name"];
385       PrintTypeExtractor(p, vars);
386     }
387 
388     p->Print(
389         vars,
390         "@$ExperimentalApi$(\"https://github.com/grpc/grpc-java/issues/"
391         "1901\")\n"
392         "public static $MethodDescriptor$<$input_type$,\n"
393         "    $output_type$> $method_method_name$() {\n"
394         "  $MethodDescriptor$<$input_type$, $output_type$> "
395         "$method_new_field_name$;\n"
396         "  if (($method_new_field_name$ = "
397         "$service_class_name$.$method_new_field_name$) == null) {\n"
398         "    synchronized ($service_class_name$.class) {\n"
399         "      if (($method_new_field_name$ = "
400         "$service_class_name$.$method_new_field_name$) == null) {\n"
401         "        $service_class_name$.$method_new_field_name$ = "
402         "$method_new_field_name$ = \n"
403         "            $MethodDescriptor$.<$input_type$, "
404         "$output_type$>newBuilder()\n"
405         "            .setType($MethodType$.$method_type$)\n"
406         "            .setFullMethodName(generateFullMethodName(\n"
407         "                \"$Package$$service_name$\", \"$method_name$\"))\n"
408         "            .setSampledToLocalTracing(true)\n"
409         "            .setRequestMarshaller(FlatbuffersUtils.marshaller(\n"
410         "                $input_type$.class, "
411         "getExtractorOf$input_type_name$()))\n"
412         "            .setResponseMarshaller(FlatbuffersUtils.marshaller(\n"
413         "                $output_type$.class, "
414         "getExtractorOf$output_type_name$()))\n");
415 
416     //            vars["proto_method_descriptor_supplier"] = service->name() +
417     //            "MethodDescriptorSupplier";
418     p->Print(vars, "                .setSchemaDescriptor(null)\n");
419     //"                .setSchemaDescriptor(new
420     //$proto_method_descriptor_supplier$(\"$method_name$\"))\n");
421 
422     p->Print(vars, "                .build();\n");
423     p->Print(vars,
424              "        }\n"
425              "      }\n"
426              "   }\n"
427              "   return $method_new_field_name$;\n"
428              "}\n");
429 
430     p->Print("\n");
431   }
432 }
// Which Java artifact PrintStub generates. *_INTERFACE values emit method
// declarations only; *_CLIENT_IMPL values emit concrete client stub classes;
// ABSTRACT_CLASS emits the server-side "<Service>ImplBase" class.
enum StubType {
  ASYNC_INTERFACE = 0,
  BLOCKING_CLIENT_INTERFACE = 1,
  FUTURE_CLIENT_INTERFACE = 2,
  BLOCKING_SERVER_INTERFACE = 3,
  ASYNC_CLIENT_IMPL = 4,
  BLOCKING_CLIENT_IMPL = 5,
  FUTURE_CLIENT_IMPL = 6,
  ABSTRACT_CLASS = 7,
};

// Calling convention used for the generated RPC methods.
enum CallType {
  ASYNC_CALL = 0,
  BLOCKING_CALL = 1,
  FUTURE_CALL = 2,
};
445 
// Forward declaration: PrintStub (below) calls this to emit the body of the
// generated bindService() override for the ImplBase class.
static void PrintBindServiceMethodBody(Printer *p, VARS &vars,
                                       const ServiceDescriptor *service);
448 
// Prints a client interface or implementation class, or a server interface.
//
// `type` selects the output:
//   - ABSTRACT_CLASS: "public static abstract class <Service>ImplBase
//     implements BindableService". Each async method body replies via
//     asyncUnimplemented*Call; a bindService() override follows.
//   - ASYNC/BLOCKING/FUTURE_CLIENT_IMPL: "public static final class" stubs
//     named <Service>Stub / <Service>BlockingStub / <Service>FutureStub that
//     extend AbstractStub, with constructors, build(), and one method per RPC.
//   - *_INTERFACE: method declarations only (no bodies, no doc comments).
//     NOTE(review): the class head printed for these is still the
//     "final class ... extends AbstractStub" form, not a Java interface.
// Blocking stubs skip client-streaming RPCs and future stubs skip all
// streaming RPCs, since those call styles cannot express them.
// Writes service_name, abstract_name, stub_name, client_name and per-method
// entries (input_type, output_type, lower_method_name, ...) into `vars`.
static void PrintStub(Printer *p, VARS &vars, const ServiceDescriptor *service,
                      StubType type) {
  const string service_name = service->name();
  vars["service_name"] = service_name;
  vars["abstract_name"] = service_name + "ImplBase";
  string stub_name = service_name;
  string client_name = service_name;
  CallType call_type = ASYNC_CALL;
  bool impl_base = false;
  bool interface = false;
  // Derive call style, class-name suffixes and the impl_base/interface flags.
  switch (type) {
    case ABSTRACT_CLASS:
      call_type = ASYNC_CALL;
      impl_base = true;
      break;
    case ASYNC_CLIENT_IMPL:
      call_type = ASYNC_CALL;
      stub_name += "Stub";
      break;
    case BLOCKING_CLIENT_INTERFACE:
      interface = true;
      FLATBUFFERS_FALLTHROUGH();  // fall thru
    case BLOCKING_CLIENT_IMPL:
      call_type = BLOCKING_CALL;
      stub_name += "BlockingStub";
      client_name += "BlockingClient";
      break;
    case FUTURE_CLIENT_INTERFACE:
      interface = true;
      FLATBUFFERS_FALLTHROUGH();  // fall thru
    case FUTURE_CLIENT_IMPL:
      call_type = FUTURE_CALL;
      stub_name += "FutureStub";
      client_name += "FutureClient";
      break;
    case ASYNC_INTERFACE:
      call_type = ASYNC_CALL;
      interface = true;
      break;
    default:
      // e.g. BLOCKING_SERVER_INTERFACE has no handling here.
      GRPC_CODEGEN_FAIL << "Cannot determine class name for StubType: " << type;
  }
  vars["stub_name"] = stub_name;
  vars["client_name"] = client_name;

  // Class head
  if (!interface) { GrpcWriteServiceDocComment(p, vars, service); }
  if (impl_base) {
    p->Print(vars,
             "public static abstract class $abstract_name$ implements "
             "$BindableService$ {\n");
  } else {
    p->Print(vars,
             "public static final class $stub_name$ extends "
             "$AbstractStub$<$stub_name$> {\n");
  }
  p->Indent();

  // Constructor and build() method (client stub classes only).
  if (!impl_base && !interface) {
    p->Print(vars, "private $stub_name$($Channel$ channel) {\n");
    p->Indent();
    p->Print("super(channel);\n");
    p->Outdent();
    p->Print("}\n\n");
    p->Print(vars,
             "private $stub_name$($Channel$ channel,\n"
             "    $CallOptions$ callOptions) {\n");
    p->Indent();
    p->Print("super(channel, callOptions);\n");
    p->Outdent();
    p->Print("}\n\n");
    p->Print(vars,
             "@$Override$\n"
             "protected $stub_name$ build($Channel$ channel,\n"
             "    $CallOptions$ callOptions) {\n");
    p->Indent();
    p->Print(vars, "return new $stub_name$(channel, callOptions);\n");
    p->Outdent();
    p->Print("}\n");
  }

  // RPC methods
  for (int i = 0; i < service->method_count(); ++i) {
    auto method = service->method(i);
    vars["input_type"] = JavaClassName(vars, method->get_input_type_name());
    vars["output_type"] = JavaClassName(vars, method->get_output_type_name());
    vars["lower_method_name"] = LowerMethodName(&*method);
    vars["method_method_name"] = MethodPropertiesGetterName(&*method);
    // Bidi methods count as both client- and server-streaming.
    bool client_streaming =
        method->ClientStreaming() || method->BidiStreaming();
    bool server_streaming =
        method->ServerStreaming() || method->BidiStreaming();

    if (call_type == BLOCKING_CALL && client_streaming) {
      // Blocking client interface with client streaming is not available
      continue;
    }

    if (call_type == FUTURE_CALL && (client_streaming || server_streaming)) {
      // Future interface doesn't support streaming.
      continue;
    }

    // Method signature
    p->Print("\n");
    // TODO(nmittler): Replace with WriteMethodDocComment once included by the
    // protobuf distro.
    if (!interface) { GrpcWriteMethodDocComment(p, vars, &*method); }
    p->Print("public ");
    switch (call_type) {
      case BLOCKING_CALL:
        GRPC_CODEGEN_CHECK(!client_streaming)
            << "Blocking client interface with client streaming is unavailable";
        if (server_streaming) {
          // Server streaming
          p->Print(vars,
                   "$Iterator$<$output_type$> $lower_method_name$(\n"
                   "    $input_type$ request)");
        } else {
          // Simple RPC
          p->Print(vars,
                   "$output_type$ $lower_method_name$($input_type$ request)");
        }
        break;
      case ASYNC_CALL:
        if (client_streaming) {
          // Bidirectional streaming or client streaming
          p->Print(vars,
                   "$StreamObserver$<$input_type$> $lower_method_name$(\n"
                   "    $StreamObserver$<$output_type$> responseObserver)");
        } else {
          // Server streaming or simple RPC
          p->Print(vars,
                   "void $lower_method_name$($input_type$ request,\n"
                   "    $StreamObserver$<$output_type$> responseObserver)");
        }
        break;
      case FUTURE_CALL:
        GRPC_CODEGEN_CHECK(!client_streaming && !server_streaming)
            << "Future interface doesn't support streaming. "
            << "client_streaming=" << client_streaming << ", "
            << "server_streaming=" << server_streaming;
        p->Print(vars,
                 "$ListenableFuture$<$output_type$> $lower_method_name$(\n"
                 "    $input_type$ request)");
        break;
    }

    // Interfaces only declare the method.
    if (interface) {
      p->Print(";\n");
      continue;
    }
    // Method body.
    p->Print(" {\n");
    p->Indent();
    if (impl_base) {
      switch (call_type) {
          // NB: Skipping validation of service methods. If something is wrong,
          // we wouldn't get to this point as compiler would return errors when
          // generating service interface.
        case ASYNC_CALL:
          // Client-streaming handlers must return a request observer.
          if (client_streaming) {
            p->Print(vars,
                     "return "
                     "asyncUnimplementedStreamingCall($method_method_name$(), "
                     "responseObserver);\n");
          } else {
            p->Print(vars,
                     "asyncUnimplementedUnaryCall($method_method_name$(), "
                     "responseObserver);\n");
          }
          break;
        default: break;
      }
    } else if (!interface) {
      // Client stubs delegate to the io.grpc.stub.ClientCalls helper matching
      // the call style and streaming kind.
      switch (call_type) {
        case BLOCKING_CALL:
          GRPC_CODEGEN_CHECK(!client_streaming)
              << "Blocking client streaming interface is not available";
          if (server_streaming) {
            vars["calls_method"] = "blockingServerStreamingCall";
            vars["params"] = "request";
          } else {
            vars["calls_method"] = "blockingUnaryCall";
            vars["params"] = "request";
          }
          p->Print(vars,
                   "return $calls_method$(\n"
                   "    getChannel(), $method_method_name$(), "
                   "getCallOptions(), $params$);\n");
          break;
        case ASYNC_CALL:
          if (server_streaming) {
            if (client_streaming) {
              vars["calls_method"] = "asyncBidiStreamingCall";
              vars["params"] = "responseObserver";
            } else {
              vars["calls_method"] = "asyncServerStreamingCall";
              vars["params"] = "request, responseObserver";
            }
          } else {
            if (client_streaming) {
              vars["calls_method"] = "asyncClientStreamingCall";
              vars["params"] = "responseObserver";
            } else {
              vars["calls_method"] = "asyncUnaryCall";
              vars["params"] = "request, responseObserver";
            }
          }
          // Client-streaming calls hand back the request StreamObserver.
          vars["last_line_prefix"] = client_streaming ? "return " : "";
          p->Print(vars,
                   "$last_line_prefix$$calls_method$(\n"
                   "    getChannel().newCall($method_method_name$(), "
                   "getCallOptions()), $params$);\n");
          break;
        case FUTURE_CALL:
          GRPC_CODEGEN_CHECK(!client_streaming && !server_streaming)
              << "Future interface doesn't support streaming. "
              << "client_streaming=" << client_streaming << ", "
              << "server_streaming=" << server_streaming;
          vars["calls_method"] = "futureUnaryCall";
          p->Print(vars,
                   "return $calls_method$(\n"
                   "    getChannel().newCall($method_method_name$(), "
                   "getCallOptions()), request);\n");
          break;
      }
    }
    p->Outdent();
    p->Print("}\n");
  }

  // The server base class also binds itself as a service definition.
  if (impl_base) {
    p->Print("\n");
    p->Print(
        vars,
        "@$Override$ public final $ServerServiceDefinition$ bindService() {\n");
    vars["instance"] = "this";
    PrintBindServiceMethodBody(p, vars, service);
    p->Print("}\n");
  }

  p->Outdent();
  p->Print("}\n\n");
}
696 
CompareMethodClientStreaming(const std::unique_ptr<const grpc_generator::Method> & method1,const std::unique_ptr<const grpc_generator::Method> & method2)697 static bool CompareMethodClientStreaming(
698     const std::unique_ptr<const grpc_generator::Method> &method1,
699     const std::unique_ptr<const grpc_generator::Method> &method2) {
700   return method1->ClientStreaming() < method2->ClientStreaming();
701 }
702 
// Place all method invocations into a single class to reduce memory footprint
// on Android.
//
// Emits one "private static final int METHODID_<NAME> = <id>;" constant per
// RPC, with ids assigned after a stable sort by CompareMethodClientStreaming.
// It then emits a MethodHandlers<Req, Resp> class implementing the four
// io.grpc.stub.ServerCalls method interfaces:
//   - invoke(request, responseObserver) dispatches RPCs that are neither
//     client- nor bidi-streaming;
//   - invoke(responseObserver) dispatches client- and bidi-streaming RPCs.
// Unknown ids throw AssertionError. Overwrites vars["service_name"] with
// "<Service>ImplBase".
static void PrintMethodHandlerClass(Printer *p, VARS &vars,
                                    const ServiceDescriptor *service) {
  // Sort method ids based on ClientStreaming() so switch tables are compact.
  // NOTE(review): the switches below split on ClientStreaming() ||
  // BidiStreaming(); ids are contiguous per switch only if
  // CompareMethodClientStreaming orders by that same predicate.
  std::vector<std::unique_ptr<const grpc_generator::Method>> sorted_methods(
      service->method_count());
  for (int i = 0; i < service->method_count(); ++i) {
    sorted_methods[i] = service->method(i);
  }
  stable_sort(sorted_methods.begin(), sorted_methods.end(),
              CompareMethodClientStreaming);
  for (size_t i = 0; i < sorted_methods.size(); i++) {
    auto &method = sorted_methods[i];
    vars["method_id"] = to_string(i);
    vars["method_id_name"] = MethodIdFieldName(&*method);
    p->Print(vars,
             "private static final int $method_id_name$ = $method_id$;\n");
  }
  p->Print("\n");
  vars["service_name"] = service->name() + "ImplBase";
  p->Print(vars,
           "private static final class MethodHandlers<Req, Resp> implements\n"
           "    io.grpc.stub.ServerCalls.UnaryMethod<Req, Resp>,\n"
           "    io.grpc.stub.ServerCalls.ServerStreamingMethod<Req, Resp>,\n"
           "    io.grpc.stub.ServerCalls.ClientStreamingMethod<Req, Resp>,\n"
           "    io.grpc.stub.ServerCalls.BidiStreamingMethod<Req, Resp> {\n"
           "  private final $service_name$ serviceImpl;\n"
           "  private final int methodId;\n"
           "\n"
           "  MethodHandlers($service_name$ serviceImpl, int methodId) {\n"
           "    this.serviceImpl = serviceImpl;\n"
           "    this.methodId = methodId;\n"
           "  }\n\n");
  p->Indent();
  // invoke(request, observer): unary and server-streaming RPCs.
  p->Print(vars,
           "@$Override$\n"
           "@java.lang.SuppressWarnings(\"unchecked\")\n"
           "public void invoke(Req request, $StreamObserver$<Resp> "
           "responseObserver) {\n"
           "  switch (methodId) {\n");
  // Two levels: one for the method body, one for the switch body.
  p->Indent();
  p->Indent();

  for (int i = 0; i < service->method_count(); ++i) {
    auto method = service->method(i);
    if (method->ClientStreaming() || method->BidiStreaming()) { continue; }
    vars["method_id_name"] = MethodIdFieldName(&*method);
    vars["lower_method_name"] = LowerMethodName(&*method);
    vars["input_type"] = JavaClassName(vars, method->get_input_type_name());
    vars["output_type"] = JavaClassName(vars, method->get_output_type_name());
    p->Print(vars,
             "case $method_id_name$:\n"
             "  serviceImpl.$lower_method_name$(($input_type$) request,\n"
             "      ($StreamObserver$<$output_type$>) responseObserver);\n"
             "  break;\n");
  }
  p->Print(
      "default:\n"
      "  throw new AssertionError();\n");

  p->Outdent();
  p->Outdent();
  p->Print(
      "  }\n"
      "}\n\n");

  // invoke(observer): client-streaming and bidi-streaming RPCs, which return
  // the request observer.
  p->Print(vars,
           "@$Override$\n"
           "@java.lang.SuppressWarnings(\"unchecked\")\n"
           "public $StreamObserver$<Req> invoke(\n"
           "    $StreamObserver$<Resp> responseObserver) {\n"
           "  switch (methodId) {\n");
  p->Indent();
  p->Indent();

  for (int i = 0; i < service->method_count(); ++i) {
    auto method = service->method(i);
    if (!(method->ClientStreaming() || method->BidiStreaming())) { continue; }
    vars["method_id_name"] = MethodIdFieldName(&*method);
    vars["lower_method_name"] = LowerMethodName(&*method);
    vars["input_type"] = JavaClassName(vars, method->get_input_type_name());
    vars["output_type"] = JavaClassName(vars, method->get_output_type_name());
    p->Print(
        vars,
        "case $method_id_name$:\n"
        "  return ($StreamObserver$<Req>) serviceImpl.$lower_method_name$(\n"
        "      ($StreamObserver$<$output_type$>) responseObserver);\n");
  }
  p->Print(
      "default:\n"
      "  throw new AssertionError();\n");

  p->Outdent();
  p->Outdent();
  p->Print(
      "  }\n"
      "}\n");

  // Closes the MethodHandlers class.
  p->Outdent();
  p->Print("}\n\n");
}
805 
PrintGetServiceDescriptorMethod(Printer * p,VARS & vars,const ServiceDescriptor * service)806 static void PrintGetServiceDescriptorMethod(Printer *p, VARS &vars,
807                                             const ServiceDescriptor *service) {
808   vars["service_name"] = service->name();
809   //        vars["proto_base_descriptor_supplier"] = service->name() +
810   //        "BaseDescriptorSupplier"; vars["proto_file_descriptor_supplier"] =
811   //        service->name() + "FileDescriptorSupplier";
812   //        vars["proto_method_descriptor_supplier"] = service->name() +
813   //        "MethodDescriptorSupplier"; vars["proto_class_name"] =
814   //        google::protobuf::compiler::java::ClassName(service->file());
815   //        p->Print(
816   //                 vars,
817   //                 "private static abstract class
818   //                 $proto_base_descriptor_supplier$\n" "    implements
819   //                 $ProtoFileDescriptorSupplier$,
820   //                 $ProtoServiceDescriptorSupplier$ {\n" "
821   //                 $proto_base_descriptor_supplier$() {}\n"
822   //                 "\n"
823   //                 "  @$Override$\n"
824   //                 "  public com.google.protobuf.Descriptors.FileDescriptor
825   //                 getFileDescriptor() {\n" "    return
826   //                 $proto_class_name$.getDescriptor();\n" "  }\n"
827   //                 "\n"
828   //                 "  @$Override$\n"
829   //                 "  public com.google.protobuf.Descriptors.ServiceDescriptor
830   //                 getServiceDescriptor() {\n" "    return
831   //                 getFileDescriptor().findServiceByName(\"$service_name$\");\n"
832   //                 "  }\n"
833   //                 "}\n"
834   //                 "\n"
835   //                 "private static final class
836   //                 $proto_file_descriptor_supplier$\n" "    extends
837   //                 $proto_base_descriptor_supplier$ {\n" "
838   //                 $proto_file_descriptor_supplier$() {}\n"
839   //                 "}\n"
840   //                 "\n"
841   //                 "private static final class
842   //                 $proto_method_descriptor_supplier$\n" "    extends
843   //                 $proto_base_descriptor_supplier$\n" "    implements
844   //                 $ProtoMethodDescriptorSupplier$ {\n" "  private final
845   //                 String methodName;\n"
846   //                 "\n"
847   //                 "  $proto_method_descriptor_supplier$(String methodName)
848   //                 {\n" "    this.methodName = methodName;\n" "  }\n"
849   //                 "\n"
850   //                 "  @$Override$\n"
851   //                 "  public com.google.protobuf.Descriptors.MethodDescriptor
852   //                 getMethodDescriptor() {\n" "    return
853   //                 getServiceDescriptor().findMethodByName(methodName);\n" "
854   //                 }\n"
855   //                 "}\n\n");
856 
857   p->Print(
858       vars,
859       "private static volatile $ServiceDescriptor$ serviceDescriptor;\n\n");
860 
861   p->Print(vars,
862            "public static $ServiceDescriptor$ getServiceDescriptor() {\n");
863   p->Indent();
864   p->Print(vars, "$ServiceDescriptor$ result = serviceDescriptor;\n");
865   p->Print("if (result == null) {\n");
866   p->Indent();
867   p->Print(vars, "synchronized ($service_class_name$.class) {\n");
868   p->Indent();
869   p->Print("result = serviceDescriptor;\n");
870   p->Print("if (result == null) {\n");
871   p->Indent();
872 
873   p->Print(vars,
874            "serviceDescriptor = result = "
875            "$ServiceDescriptor$.newBuilder(SERVICE_NAME)");
876   p->Indent();
877   p->Indent();
878   p->Print(vars, "\n.setSchemaDescriptor(null)");
879   for (int i = 0; i < service->method_count(); ++i) {
880     auto method = service->method(i);
881     vars["method_method_name"] = MethodPropertiesGetterName(&*method);
882     p->Print(vars, "\n.addMethod($method_method_name$())");
883   }
884   p->Print("\n.build();\n");
885   p->Outdent();
886   p->Outdent();
887 
888   p->Outdent();
889   p->Print("}\n");
890   p->Outdent();
891   p->Print("}\n");
892   p->Outdent();
893   p->Print("}\n");
894   p->Print("return result;\n");
895   p->Outdent();
896   p->Print("}\n");
897 }
898 
PrintBindServiceMethodBody(Printer * p,VARS & vars,const ServiceDescriptor * service)899 static void PrintBindServiceMethodBody(Printer *p, VARS &vars,
900                                        const ServiceDescriptor *service) {
901   vars["service_name"] = service->name();
902   p->Indent();
903   p->Print(vars,
904            "return "
905            "$ServerServiceDefinition$.builder(getServiceDescriptor())\n");
906   p->Indent();
907   p->Indent();
908   for (int i = 0; i < service->method_count(); ++i) {
909     auto method = service->method(i);
910     vars["lower_method_name"] = LowerMethodName(&*method);
911     vars["method_method_name"] = MethodPropertiesGetterName(&*method);
912     vars["input_type"] = JavaClassName(vars, method->get_input_type_name());
913     vars["output_type"] = JavaClassName(vars, method->get_output_type_name());
914     vars["method_id_name"] = MethodIdFieldName(&*method);
915     bool client_streaming =
916         method->ClientStreaming() || method->BidiStreaming();
917     bool server_streaming =
918         method->ServerStreaming() || method->BidiStreaming();
919     if (client_streaming) {
920       if (server_streaming) {
921         vars["calls_method"] = "asyncBidiStreamingCall";
922       } else {
923         vars["calls_method"] = "asyncClientStreamingCall";
924       }
925     } else {
926       if (server_streaming) {
927         vars["calls_method"] = "asyncServerStreamingCall";
928       } else {
929         vars["calls_method"] = "asyncUnaryCall";
930       }
931     }
932     p->Print(vars, ".addMethod(\n");
933     p->Indent();
934     p->Print(vars,
935              "$method_method_name$(),\n"
936              "$calls_method$(\n");
937     p->Indent();
938     p->Print(vars,
939              "new MethodHandlers<\n"
940              "  $input_type$,\n"
941              "  $output_type$>(\n"
942              "    $instance$, $method_id_name$)))\n");
943     p->Outdent();
944     p->Outdent();
945   }
946   p->Print(".build();\n");
947   p->Outdent();
948   p->Outdent();
949   p->Outdent();
950 }
951 
PrintService(Printer * p,VARS & vars,const ServiceDescriptor * service,bool disable_version)952 static void PrintService(Printer *p, VARS &vars,
953                          const ServiceDescriptor *service,
954                          bool disable_version) {
955   vars["service_name"] = service->name();
956   vars["service_class_name"] = ServiceClassName(service->name());
957   vars["grpc_version"] = "";
958 #ifdef GRPC_VERSION
959   if (!disable_version) {
960     vars["grpc_version"] = " (version " XSTR(GRPC_VERSION) ")";
961   }
962 #else
963   (void)disable_version;
964 #endif
965   // TODO(nmittler): Replace with WriteServiceDocComment once included by
966   // protobuf distro.
967   GrpcWriteServiceDocComment(p, vars, service);
968   p->Print(vars,
969            "@$Generated$(\n"
970            "    value = \"by gRPC proto compiler$grpc_version$\",\n"
971            "    comments = \"Source: $file_name$.fbs\")\n"
972            "public final class $service_class_name$ {\n\n");
973   p->Indent();
974   p->Print(vars, "private $service_class_name$() {}\n\n");
975 
976   p->Print(vars,
977            "public static final String SERVICE_NAME = "
978            "\"$Package$$service_name$\";\n\n");
979 
980   PrintMethodFields(p, vars, service);
981 
982   // TODO(nmittler): Replace with WriteDocComment once included by protobuf
983   // distro.
984   GrpcWriteDocComment(
985       p, vars,
986       " Creates a new async stub that supports all call types for the service");
987   p->Print(vars,
988            "public static $service_name$Stub newStub($Channel$ channel) {\n");
989   p->Indent();
990   p->Print(vars, "return new $service_name$Stub(channel);\n");
991   p->Outdent();
992   p->Print("}\n\n");
993 
994   // TODO(nmittler): Replace with WriteDocComment once included by protobuf
995   // distro.
996   GrpcWriteDocComment(
997       p, vars,
998       " Creates a new blocking-style stub that supports unary and streaming "
999       "output calls on the service");
1000   p->Print(vars,
1001            "public static $service_name$BlockingStub newBlockingStub(\n"
1002            "    $Channel$ channel) {\n");
1003   p->Indent();
1004   p->Print(vars, "return new $service_name$BlockingStub(channel);\n");
1005   p->Outdent();
1006   p->Print("}\n\n");
1007 
1008   // TODO(nmittler): Replace with WriteDocComment once included by protobuf
1009   // distro.
1010   GrpcWriteDocComment(
1011       p, vars,
1012       " Creates a new ListenableFuture-style stub that supports unary calls "
1013       "on the service");
1014   p->Print(vars,
1015            "public static $service_name$FutureStub newFutureStub(\n"
1016            "    $Channel$ channel) {\n");
1017   p->Indent();
1018   p->Print(vars, "return new $service_name$FutureStub(channel);\n");
1019   p->Outdent();
1020   p->Print("}\n\n");
1021 
1022   PrintStub(p, vars, service, ABSTRACT_CLASS);
1023   PrintStub(p, vars, service, ASYNC_CLIENT_IMPL);
1024   PrintStub(p, vars, service, BLOCKING_CLIENT_IMPL);
1025   PrintStub(p, vars, service, FUTURE_CLIENT_IMPL);
1026 
1027   PrintMethodHandlerClass(p, vars, service);
1028   PrintGetServiceDescriptorMethod(p, vars, service);
1029   p->Outdent();
1030   p->Print("}\n");
1031 }
1032 
PrintStaticImports(Printer * p)1033 static void PrintStaticImports(Printer *p) {
1034   p->Print(
1035       "import java.nio.ByteBuffer;\n"
1036       "import static "
1037       "io.grpc.MethodDescriptor.generateFullMethodName;\n"
1038       "import static "
1039       "io.grpc.stub.ClientCalls.asyncBidiStreamingCall;\n"
1040       "import static "
1041       "io.grpc.stub.ClientCalls.asyncClientStreamingCall;\n"
1042       "import static "
1043       "io.grpc.stub.ClientCalls.asyncServerStreamingCall;\n"
1044       "import static "
1045       "io.grpc.stub.ClientCalls.asyncUnaryCall;\n"
1046       "import static "
1047       "io.grpc.stub.ClientCalls.blockingServerStreamingCall;\n"
1048       "import static "
1049       "io.grpc.stub.ClientCalls.blockingUnaryCall;\n"
1050       "import static "
1051       "io.grpc.stub.ClientCalls.futureUnaryCall;\n"
1052       "import static "
1053       "io.grpc.stub.ServerCalls.asyncBidiStreamingCall;\n"
1054       "import static "
1055       "io.grpc.stub.ServerCalls.asyncClientStreamingCall;\n"
1056       "import static "
1057       "io.grpc.stub.ServerCalls.asyncServerStreamingCall;\n"
1058       "import static "
1059       "io.grpc.stub.ServerCalls.asyncUnaryCall;\n"
1060       "import static "
1061       "io.grpc.stub.ServerCalls.asyncUnimplementedStreamingCall;\n"
1062       "import static "
1063       "io.grpc.stub.ServerCalls.asyncUnimplementedUnaryCall;\n\n");
1064 }
1065 
GenerateService(const grpc_generator::Service * service,grpc_generator::Printer * printer,VARS & vars,bool disable_version)1066 static void GenerateService(const grpc_generator::Service *service,
1067                      grpc_generator::Printer *printer, VARS &vars,
1068                      bool disable_version) {
1069   // All non-generated classes must be referred by fully qualified names to
1070   // avoid collision with generated classes.
1071   vars["String"] = "java.lang.String";
1072   vars["Deprecated"] = "java.lang.Deprecated";
1073   vars["Override"] = "java.lang.Override";
1074   vars["Channel"] = "io.grpc.Channel";
1075   vars["CallOptions"] = "io.grpc.CallOptions";
1076   vars["MethodType"] = "io.grpc.MethodDescriptor.MethodType";
1077   vars["ServerMethodDefinition"] = "io.grpc.ServerMethodDefinition";
1078   vars["BindableService"] = "io.grpc.BindableService";
1079   vars["ServerServiceDefinition"] = "io.grpc.ServerServiceDefinition";
1080   vars["ServiceDescriptor"] = "io.grpc.ServiceDescriptor";
1081   vars["ProtoFileDescriptorSupplier"] =
1082       "io.grpc.protobuf.ProtoFileDescriptorSupplier";
1083   vars["ProtoServiceDescriptorSupplier"] =
1084       "io.grpc.protobuf.ProtoServiceDescriptorSupplier";
1085   vars["ProtoMethodDescriptorSupplier"] =
1086       "io.grpc.protobuf.ProtoMethodDescriptorSupplier";
1087   vars["AbstractStub"] = "io.grpc.stub.AbstractStub";
1088   vars["MethodDescriptor"] = "io.grpc.MethodDescriptor";
1089   vars["NanoUtils"] = "io.grpc.protobuf.nano.NanoUtils";
1090   vars["StreamObserver"] = "io.grpc.stub.StreamObserver";
1091   vars["Iterator"] = "java.util.Iterator";
1092   vars["Generated"] = "javax.annotation.Generated";
1093   vars["ListenableFuture"] =
1094       "com.google.common.util.concurrent.ListenableFuture";
1095   vars["ExperimentalApi"] = "io.grpc.ExperimentalApi";
1096 
1097   PrintStaticImports(printer);
1098 
1099   PrintService(printer, vars, service, disable_version);
1100 }
1101 } // namespace
1102 
GenerateServiceSource(grpc_generator::File * file,const grpc_generator::Service * service,grpc_java_generator::Parameters * parameters)1103 grpc::string GenerateServiceSource(
1104     grpc_generator::File *file, const grpc_generator::Service *service,
1105     grpc_java_generator::Parameters *parameters) {
1106   grpc::string out;
1107   auto printer = file->CreatePrinter(&out);
1108   VARS vars;
1109   vars["flatc_version"] = grpc::string(
1110       FLATBUFFERS_STRING(FLATBUFFERS_VERSION_MAJOR) "." FLATBUFFERS_STRING(
1111           FLATBUFFERS_VERSION_MINOR) "." FLATBUFFERS_STRING(FLATBUFFERS_VERSION_REVISION));
1112 
1113   vars["file_name"] = file->filename();
1114 
1115   if (!parameters->package_name.empty()) {
1116     vars["Package"] = parameters->package_name;  // ServiceJavaPackage(service);
1117   }
1118   GenerateImports(file, &*printer, vars);
1119   GenerateService(service, &*printer, vars, false);
1120   return out;
1121 }
1122 
1123 }  // namespace grpc_java_generator
1124