1 // Copyright 2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 //! Code generation for the `call_method!` series of proc-macros. This module is used by creating a
16 //! [`MethodCall`] instance and calling [`MethodCall::generate`]. The generated code will vary
17 //! for each macro based on the contained [`Receiver`] value. If there is a detected issue with the
18 //! [`MethodCall`] provided, then `generate` will return a [`CodegenError`] with information on
19 //! what is wrong. This error can be converted into a [`syn::Error`] for reporting any failures to
20 //! the user.
21 
22 use proc_macro2::{Span, TokenStream};
23 use quote::{quote, quote_spanned, TokenStreamExt};
24 use syn::{parse_quote, spanned::Spanned, Expr, Ident, ItemStatic, LitStr, Type};
25 
26 use crate::type_parser::{JavaType, MethodSig, NonArray, Primitive, ReturnType};
27 
/// The errors that can be encountered during codegen. Used in [`CodegenError`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ErrorKind {
    /// The number of argument expressions does not match the number of parameters declared in
    /// the type signature.
    InvalidArgsLength {
        /// Parameter count taken from the type signature.
        expected: usize,
        /// Number of argument expressions actually supplied.
        found: usize,
    },
    /// A constructor call was given a type signature whose return type is not `void` (`V`).
    ConstructorRetValShouldBeVoid,
    /// The type signature string could not be parsed.
    InvalidTypeSignature,
}
35 
36 /// An error encountered during codegen with span information. Can be converted to [`syn::Error`].
37 /// using [`From`].
38 #[derive(Clone, Debug)]
39 pub struct CodegenError(pub Span, pub ErrorKind);
40 
41 impl From<CodegenError> for syn::Error {
from(CodegenError(span, kind): CodegenError) -> syn::Error42     fn from(CodegenError(span, kind): CodegenError) -> syn::Error {
43         use ErrorKind::*;
44         match kind {
45             InvalidArgsLength { expected, found } => {
46                 syn::Error::new(span, format!("The number of args does not match the type signature provided: expected={expected}, found={found}"))
47             }
48             ConstructorRetValShouldBeVoid => {
49                 syn::Error::new(span, "Return type should be `void` (`V`) for constructor methods")
50             }
51             InvalidTypeSignature => syn::Error::new(span, "Failed to parse type signature"),
52         }
53     }
54 }
55 
/// Result type used throughout codegen. It holds either the generated value or a
/// [`CodegenError`] describing why generation failed and where.
pub type CodegenResult<T> = Result<T, CodegenError>;
58 
/// Describes a method that will be generated. Create one with [`MethodCall::new`] and generate
/// code with [`MethodCall::generate`]. This should be given AST nodes from the macro input so that
/// errors are associated properly to the input span.
pub struct MethodCall {
    /// Expression for the JNI environment. The generated code binds it as
    /// `&mut ::jni::JNIEnv`.
    env: Expr,
    /// The class, method name and type signature of the method being called.
    method: MethodInfo,
    /// Whether this is a constructor, static or instance call. For an instance call it also
    /// holds the `this` expression.
    receiver: Receiver,
    /// The user-provided argument expressions, in order. Their count is checked against the
    /// parsed signature, and each one is type-annotated according to it during generation.
    arg_exprs: Vec<Expr>,
}
68 
69 impl MethodCall {
70     /// Create a new MethodCall instance
new(env: Expr, method: MethodInfo, receiver: Receiver, arg_exprs: Vec<Expr>) -> Self71     pub fn new(env: Expr, method: MethodInfo, receiver: Receiver, arg_exprs: Vec<Expr>) -> Self {
72         Self {
73             env,
74             method,
75             receiver,
76             arg_exprs,
77         }
78     }
79 
80     /// Generate code to call the described method.
generate(&self) -> CodegenResult<TokenStream>81     pub fn generate(&self) -> CodegenResult<TokenStream> {
82         // Needs to be threaded manually to other methods since self-referential structs can't
83         // exist.
84         let sig = self.method.sig()?;
85 
86         let args = self.generate_args(&sig)?;
87 
88         let method_call = self
89             .receiver
90             .generate_call(&self.env, &self.method, &sig, &args)?;
91 
92         // Wrap the generated code in a closure so that we can access the outer scope but any
93         // variables we define aren't accessible by the outer scope. There is small hygiene issue
94         // where the arg exprs have our `env` variable in scope. If this becomes an issue we can
95         // refactor these exprs to be passed as closure parameters instead.
96         Ok(quote! {
97             (|| {
98                 #method_call
99             })()
100         })
101     }
102 
103     /// Generate the `&[jni::jvalue]` arguments slice that will be passed to the `jni` method call.
104     /// This validates the argument count and types.
generate_args(&self, sig: &MethodSig<'_>) -> CodegenResult<Expr>105     fn generate_args(&self, sig: &MethodSig<'_>) -> CodegenResult<Expr> {
106         // Safety: must check that arg count matches the signature
107         if self.arg_exprs.len() != sig.args.len() {
108             return Err(CodegenError(
109                 self.method.sig.span(),
110                 ErrorKind::InvalidArgsLength {
111                     expected: sig.args.len(),
112                     found: self.arg_exprs.len(),
113                 },
114             ));
115         }
116 
117         // Create each `jvalue` expression
118         let type_expr_pairs = core::iter::zip(sig.args.iter().copied(), self.arg_exprs.iter());
119         let jvalues = type_expr_pairs.map(|(ty, expr)| generate_jvalue(ty, expr));
120 
121         // Put the `jvalue` expressions in a slice.
122         Ok(parse_quote! {
123             &[#(#jvalues),*]
124         })
125     }
126 }
127 
/// The receiver of the method call and the type of the method.
pub enum Receiver {
    /// A constructor. Its type signature must return `void` (`V`).
    Constructor,
    /// A static method, invoked on the class of the method descriptor.
    Static,
    /// An instance method. The `Expr` here is the `this` object. It is evaluated before the
    /// `unsafe` JNI call is made.
    Instance(Expr),
}
137 
138 impl Receiver {
139     /// Generate the code that performs the JNI call.
generate_call( &self, env: &Expr, method_info: &MethodInfo, sig: &MethodSig<'_>, args: &Expr, ) -> CodegenResult<TokenStream>140     fn generate_call(
141         &self,
142         env: &Expr,
143         method_info: &MethodInfo,
144         sig: &MethodSig<'_>,
145         args: &Expr,
146     ) -> CodegenResult<TokenStream> {
147         // Constructors are void methods. Validate this fact.
148         if matches!(*self, Receiver::Constructor) && !sig.ret.is_void() {
149             return Err(CodegenError(
150                 method_info.sig.span(),
151                 ErrorKind::ConstructorRetValShouldBeVoid,
152             ));
153         }
154 
155         // The static item containing the `pourover::[Static]MethodDesc`.
156         let method_desc = self.generate_method_desc(method_info);
157 
158         // The `jni::signature::ReturnType` that the `jni` crate uses to perform the correct native
159         // call.
160         let return_type = return_type_from_sig(sig.ret);
161 
162         // A conversion expression to convert from `jni::object::JValueOwned` to the actual return
163         // type. We have this information from the parsed method signature whereas the `jni` crate
164         // only knows this at runtime.
165         let conversion = return_value_conversion_from_sig(sig.ret);
166 
167         // This preamble is used to evaluate all the client-provided expressions outside of the
168         // `unsafe` block. This is the same for all receiver kinds.
169         let mut method_call = quote! {
170             #method_desc
171 
172             let env: &mut ::jni::JNIEnv = #env;
173             let method_id = ::jni::descriptors::Desc::lookup(&METHOD_DESC, env)?;
174             let args: &[::jni::sys::jvalue] = #args;
175         };
176 
177         // Generate the unsafe JNI call.
178         //
179         // Safety: `args` contains the arguments to this method. The type signature of this
180         // method is `#sig`.
181         //
182         // `args` must adhere to the following:
183         //  - `args.len()` must match the number of arguments given in the type signature.
184         //  - The union value of each arg in `args` must match the type specified in the type
185         //    signature.
186         //
187         // These conditions are upheld by this proc macro and a compile error will be caused if
188         // they are broken. No user-provided code is executed within the `unsafe` block.
189         method_call.append_all(match self {
190             Self::Constructor => quote! {
191                 unsafe {
192                     env.new_object_unchecked(
193                         METHOD_DESC.cls(),
194                         method_id,
195                         args,
196                     )
197                 }
198             },
199             Self::Static => quote! {
200                 unsafe {
201                     env.call_static_method_unchecked(
202                         METHOD_DESC.cls(),
203                         method_id,
204                         #return_type,
205                         args,
206                     )
207                 }#conversion
208             },
209             Self::Instance(this) => quote! {
210                 let this_obj: &JObject = #this;
211                 unsafe {
212                     env.call_method_unchecked(
213                         this_obj,
214                         method_id,
215                         #return_type,
216                         args,
217                     )
218                 }#conversion
219             },
220         });
221 
222         Ok(method_call)
223     }
224 
225     fn generate_method_desc(&self, MethodInfo { cls, name, sig, .. }: &MethodInfo) -> ItemStatic {
226         match self {
227             Self::Constructor => parse_quote! {
228                 static METHOD_DESC: ::pourover::desc::MethodDesc = (#cls).constructor(#sig);
229             },
230             Self::Static => parse_quote! {
231                 static METHOD_DESC: ::pourover::desc::StaticMethodDesc = (#cls).static_method(#name, #sig);
232             },
233             Self::Instance(_) => parse_quote! {
234                 static METHOD_DESC: ::pourover::desc::MethodDesc = (#cls).method(#name, #sig);
235             },
236         }
237     }
238 }
239 
/// Information about the method being called.
pub struct MethodInfo {
    /// Expression for the class. The generated code calls `.constructor(..)`, `.method(..)` or
    /// `.static_method(..)` on it to build the method descriptor.
    cls: Expr,
    /// Expression for the method name. It is not used when generating constructor calls.
    name: Expr,
    /// The JNI type signature literal (e.g. `"(II)I"`). Its span is used when reporting
    /// signature-related errors.
    sig: LitStr,
    /// Derived from `sig.value()`. This string must be stored in the struct so that we can return
    /// a `MethodSig` instance that references it from `MethodInfo::sig()`.
    sig_str: String,
}
249 
250 impl MethodInfo {
new(cls: Expr, name: Expr, sig: LitStr) -> Self251     pub fn new(cls: Expr, name: Expr, sig: LitStr) -> Self {
252         let sig_str = sig.value();
253         Self {
254             cls,
255             name,
256             sig,
257             sig_str,
258         }
259     }
260 
261     /// Parse the type signature from `sig`. Will return a [`CodegenError`] if the signature cannot
262     /// be parsed.
sig(&self) -> CodegenResult<MethodSig<'_>>263     fn sig(&self) -> CodegenResult<MethodSig<'_>> {
264         MethodSig::try_from_str(&self.sig_str)
265             .ok_or_else(|| CodegenError(self.sig.span(), ErrorKind::InvalidTypeSignature))
266     }
267 }
268 
269 /// Generate a `jni::sys::jvalue` instance given a Java type and a Rust value.
270 ///
271 /// Safety: The generated `jvalue` must match the given type `ty`.
generate_jvalue(ty: JavaType<'_>, expr: &Expr) -> TokenStream272 fn generate_jvalue(ty: JavaType<'_>, expr: &Expr) -> TokenStream {
273     // The `jvalue` field to inhabit
274     let union_field: Ident;
275     // The expected input type
276     let type_name: Type;
277     // Whether we need to call `JObject::as_raw()` on the input type
278     let needs_as_raw: bool;
279 
280     // Fill the above values based the type signature.
281     match ty {
282         JavaType::Array { depth, ty } => {
283             union_field = parse_quote![l];
284             if let NonArray::Primitive(p) = ty {
285                 if depth.get() == 1 {
286                     let prim_type = prim_to_sys_type(p);
287                     type_name = parse_quote![::jni::objects::JPrimitiveArray<'_, #prim_type>]
288                 } else {
289                     type_name = parse_quote![&::jni::objects::JObjectArray<'_>];
290                 }
291             } else {
292                 type_name = parse_quote![&::jni::objects::JObjectArray<'_>];
293             }
294             needs_as_raw = true;
295         }
296         JavaType::NonArray(NonArray::Object { cls }) => {
297             union_field = parse_quote![l];
298             type_name = match cls {
299                 "java/lang/String" => parse_quote![&::jni::objects::JString<'_>],
300                 "java/util/List" => parse_quote![&::jni::objects::JList<'_>],
301                 "java/util/Map" => parse_quote![&::jni::objects::JMap<'_>],
302                 _ => parse_quote![&::jni::objects::JObject<'_>],
303             };
304             needs_as_raw = true;
305         }
306         JavaType::NonArray(NonArray::Primitive(p)) => {
307             union_field = prim_to_union_field(p);
308             type_name = prim_to_sys_type(p);
309             needs_as_raw = false;
310         }
311     }
312 
313     // The as_raw() tokens if required.
314     let as_raw = if needs_as_raw {
315         quote! { .as_raw() }
316     } else {
317         quote![]
318     };
319 
320     // Create the `jvalue` expression. This uses `identity` to produce nice type error messages.
321     quote_spanned! { expr.span() =>
322         ::jni::sys::jvalue {
323             #union_field: ::core::convert::identity::<#type_name>(#expr) #as_raw
324         }
325     }
326 }
327 
328 /// Get a `::jni::signature::ReturnType` expression from a [`crate::type_parser::ReturnType`]. This
329 /// value is passed to the `jni` crate so that it knows which JNI method to call.
return_type_from_sig(ret: ReturnType<'_>) -> Expr330 fn return_type_from_sig(ret: ReturnType<'_>) -> Expr {
331     let prim_type = |prim| parse_quote![::jni::signature::ReturnType::Primitive(::jni::signature::Primitive::#prim)];
332 
333     use crate::type_parser::{JavaType::*, NonArray::*, Primitive::*};
334 
335     match ret {
336         ReturnType::Void => prim_type(quote![Void]),
337         ReturnType::Returns(NonArray(Primitive(Boolean))) => prim_type(quote![Boolean]),
338         ReturnType::Returns(NonArray(Primitive(Byte))) => prim_type(quote![Byte]),
339         ReturnType::Returns(NonArray(Primitive(Char))) => prim_type(quote![Char]),
340         ReturnType::Returns(NonArray(Primitive(Double))) => prim_type(quote![Double]),
341         ReturnType::Returns(NonArray(Primitive(Float))) => prim_type(quote![Float]),
342         ReturnType::Returns(NonArray(Primitive(Int))) => prim_type(quote![Int]),
343         ReturnType::Returns(NonArray(Primitive(Long))) => prim_type(quote![Long]),
344         ReturnType::Returns(NonArray(Primitive(Short))) => prim_type(quote![Short]),
345         ReturnType::Returns(NonArray(Object { .. })) => {
346             parse_quote![::jni::signature::ReturnType::Object]
347         }
348         ReturnType::Returns(Array { .. }) => parse_quote![::jni::signature::ReturnType::Array],
349     }
350 }
351 
352 /// A postfix call on a `jni::objects::JValueOwned` instance to convert it to the type specified by
353 /// `ret`. Since we have this information from the type signature we can  perform this conversion
354 /// in the macro.
return_value_conversion_from_sig(ret: ReturnType<'_>) -> TokenStream355 fn return_value_conversion_from_sig(ret: ReturnType<'_>) -> TokenStream {
356     use crate::type_parser::{JavaType::*, NonArray::*};
357 
358     match ret {
359         ReturnType::Void => quote! { .and_then(::jni::objects::JValueOwned::v) },
360         ReturnType::Returns(NonArray(Primitive(p))) => {
361             let prim = prim_to_union_field(p);
362             quote! { .and_then(::jni::objects::JValueOwned::#prim) }
363         }
364         ReturnType::Returns(NonArray(Object { cls })) => {
365             let mut conversion = quote! { .and_then(::jni::objects::JValueOwned::l) };
366             match cls {
367                 "java/lang/String" => {
368                     conversion.append_all(quote! { .map(::jni::objects::JString::from) });
369                 }
370                 "java/util/List" => {
371                     conversion.append_all(quote! { .map(::jni::objects::JList::from) });
372                 }
373                 "java/util/Map" => {
374                     conversion.append_all(quote! { .map(::jni::objects::JMap::from) });
375                 }
376                 _ => {
377                     // Already a JObject, so we are good here
378                 }
379             }
380             conversion
381         }
382         ReturnType::Returns(Array {
383             depth,
384             ty: Primitive(p),
385         }) if depth.get() == 1 => {
386             let sys_type = prim_to_sys_type(p);
387             quote! {
388                 .and_then(::jni::objects::JValueOwned::l)
389                 .map(::jni::objects::JPrimitiveArray::<#sys_type>::from)
390             }
391         }
392         ReturnType::Returns(Array { .. }) => quote! {
393             .and_then(::jni::objects::JValueOwned::l)
394             .map(::jni::objects::JObjectArray::from)
395         },
396     }
397 }
398 
399 /// From a [`Primitive`], this gets the `jni::sys::jvalue` union field name for that type. This is
400 /// also the `jni::objects::JValueGen` getter name.
prim_to_union_field(p: Primitive) -> Ident401 fn prim_to_union_field(p: Primitive) -> Ident {
402     quote::format_ident!("{}", p.as_char().to_ascii_lowercase())
403 }
404 
405 /// From a [`Primitive`], this gets the matching `jvalue::sys` type.
prim_to_sys_type(p: Primitive) -> Type406 fn prim_to_sys_type(p: Primitive) -> Type {
407     match p {
408         Primitive::Boolean => parse_quote![::jni::sys::jboolean],
409         Primitive::Byte => parse_quote![::jni::sys::jbyte],
410         Primitive::Char => parse_quote![::jni::sys::jchar],
411         Primitive::Double => parse_quote![::jni::sys::jdouble],
412         Primitive::Float => parse_quote![::jni::sys::jfloat],
413         Primitive::Int => parse_quote![::jni::sys::jint],
414         Primitive::Long => parse_quote![::jni::sys::jlong],
415         Primitive::Short => parse_quote![::jni::sys::jshort],
416     }
417 }
418 
#[cfg(test)]
// Panicking on an unexpected result is the intended failure mode inside tests.
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;
    use crate::test_util::contains_ident;
    use quote::ToTokens;
    use syn::parse_quote;

    /// A valid call to instance method `"example"` on class `&FOO_CLS`, with signature `(II)I`
    /// and two arguments. Individual tests mutate this to reach specific errors or code paths.
    fn example_method_call() -> MethodCall {
        MethodCall::new(
            parse_quote![&mut env],
            MethodInfo::new(
                parse_quote![&FOO_CLS],
                parse_quote!["example"],
                parse_quote!["(II)I"],
            ),
            Receiver::Instance(parse_quote![&foo]),
            vec![parse_quote![123], parse_quote![2 + 3]],
        )
    }

    // Supplying more argument expressions than the signature declares is rejected.
    #[test]
    fn args_are_counted() {
        let mut call = example_method_call();
        call.arg_exprs.push(parse_quote![too_many]);

        let CodegenError(_span, kind) = call.generate().unwrap_err();

        assert_eq!(
            ErrorKind::InvalidArgsLength {
                expected: 2,
                found: 3
            },
            kind
        );
    }

    // The example signature returns `I`, which is not allowed for a constructor.
    #[test]
    fn constructor_return_type_is_void() {
        let mut call = example_method_call();
        call.receiver = Receiver::Constructor;

        let CodegenError(_span, kind) = call.generate().unwrap_err();

        assert_eq!(ErrorKind::ConstructorRetValShouldBeVoid, kind);
    }

    // `sig_str` is a cached copy of `sig`, so both must be updated together.
    #[test]
    fn invalid_type_sig_is_error() {
        let mut call = example_method_call();
        call.method.sig = parse_quote!["L"];
        call.method.sig_str = call.method.sig.value();

        let CodegenError(_span, kind) = call.generate().unwrap_err();

        assert_eq!(ErrorKind::InvalidTypeSignature, kind);
    }

    // Each argument type descriptor must produce a `jvalue` that mentions the expected `jni`
    // wrapper or `jni::sys` type.
    #[test]
    fn jni_types_are_used_for_stdlib_classes_input() {
        let types = [
            ("Ljava/lang/String;", "JString"),
            ("Ljava/util/Map;", "JMap"),
            ("Ljava/util/List;", "JList"),
            ("[Ljava/lang/String;", "JObjectArray"),
            ("[[I", "JObjectArray"),
            ("[I", "JPrimitiveArray"),
            ("Lcom/example/MyObject;", "JObject"),
            ("Z", "jboolean"),
            ("C", "jchar"),
            ("B", "jbyte"),
            ("S", "jshort"),
            ("I", "jint"),
            ("J", "jlong"),
            ("F", "jfloat"),
            ("D", "jdouble"),
        ];

        for (desc, jni_type) in types {
            let jt = JavaType::try_from_str(desc).unwrap();
            let expr = parse_quote![some_value];

            let jvalue = generate_jvalue(jt, &expr);

            assert!(
                contains_ident(jvalue, jni_type),
                "desc: {desc}, jni_type: {jni_type}"
            );
        }
    }

    // Reference-typed return values must be converted into the expected `jni` wrapper type.
    #[test]
    fn jni_types_are_used_for_stdlib_classes_output() {
        let types = [
            ("Ljava/lang/String;", "JString"),
            ("Ljava/util/Map;", "JMap"),
            ("Ljava/util/List;", "JList"),
            ("[Ljava/lang/String;", "JObjectArray"),
            ("[[I", "JObjectArray"),
            ("[I", "JPrimitiveArray"),
        ];

        for (desc, jni_type) in types {
            let rt = ReturnType::try_from_str(desc).unwrap();

            let conversion = return_value_conversion_from_sig(rt);

            assert!(
                contains_ident(conversion, jni_type),
                "desc: {desc}, jni_type: {jni_type}"
            );
        }
    }

    // Each return descriptor must map to the matching `::jni::signature::ReturnType` variant
    // (or `Primitive` variant).
    #[test]
    fn return_type_passed_to_jni_is_correct() {
        let types = [
            ("Ljava/lang/String;", "Object"),
            ("Ljava/util/Map;", "Object"),
            ("Ljava/util/List;", "Object"),
            ("[Ljava/lang/String;", "Array"),
            ("[[I", "Array"),
            ("[I", "Array"),
            ("V", "Void"),
            ("Z", "Boolean"),
            ("C", "Char"),
            ("B", "Byte"),
            ("S", "Short"),
            ("I", "Int"),
            ("J", "Long"),
            ("F", "Float"),
            ("D", "Double"),
        ];

        for (desc, return_type) in types {
            let rt = ReturnType::try_from_str(desc).unwrap();

            let expr = return_type_from_sig(rt).into_token_stream();

            assert!(
                contains_ident(expr, return_type),
                "desc: {desc}, return_type: {return_type}"
            );
        }
    }

    // Each receiver kind must build its descriptor with the matching `pourover` builder method.
    // A void signature is used so that the constructor case is also valid.
    #[test]
    fn method_desc_is_correct() {
        let mut call = example_method_call();
        call.method.sig = parse_quote!["(II)V"];
        call.method.sig_str = call.method.sig.value();

        let tests = [
            (Receiver::Constructor, "constructor"),
            (Receiver::Static, "static_method"),
            (Receiver::Instance(parse_quote![this_value]), "method"),
        ];

        for (receiver, method_ident) in tests {
            let desc = receiver.generate_method_desc(&call.method);
            let rhs = desc.expr.into_token_stream();

            assert!(contains_ident(rhs, method_ident), "method: {method_ident}");
        }
    }

    // Each receiver kind must call the matching unchecked `JNIEnv` method.
    #[test]
    fn jni_call_is_correct() {
        let mut call = example_method_call();
        call.method.sig = parse_quote!["(II)V"];
        call.method.sig_str = call.method.sig.value();
        let sig = call.method.sig().unwrap();
        let args = parse_quote![test_stub];

        let tests = [
            (Receiver::Constructor, "new_object_unchecked"),
            (Receiver::Static, "call_static_method_unchecked"),
            (
                Receiver::Instance(parse_quote![this_value]),
                "call_method_unchecked",
            ),
        ];

        for (receiver, method_ident) in tests {
            let call = receiver
                .generate_call(&call.env, &call.method, &sig, &args)
                .unwrap();

            assert!(contains_ident(call, method_ident), "method: {method_ident}");
        }
    }
}
611