1 // Copyright 2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 //! Code generation for the `call_method!` series of proc-macros. This module is used by creating a
16 //! [`MethodCall`] instance and calling [`MethodCall::generate`]. The generated code will vary
17 //! for each macro based on the contained [`Receiver`] value. If there is a detected issue with the
18 //! [`MethodCall`] provided, then `generate` will return a [`CodegenError`] with information on
19 //! what is wrong. This error can be converted into a [`syn::Error`] for reporting any failures to
20 //! the user.
21
22 use proc_macro2::{Span, TokenStream};
23 use quote::{quote, quote_spanned, TokenStreamExt};
24 use syn::{parse_quote, spanned::Spanned, Expr, Ident, ItemStatic, LitStr, Type};
25
26 use crate::type_parser::{JavaType, MethodSig, NonArray, Primitive, ReturnType};
27
/// Every way codegen can fail. Carried inside a [`CodegenError`] together with the span of the
/// macro input that caused it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ErrorKind {
    /// The macro received a different number of argument expressions than the method type
    /// signature declares.
    InvalidArgsLength {
        /// Number of arguments declared by the type signature.
        expected: usize,
        /// Number of argument expressions actually supplied.
        found: usize,
    },
    /// A constructor's type signature declared a non-`void` return type.
    ConstructorRetValShouldBeVoid,
    /// The type signature string could not be parsed.
    InvalidTypeSignature,
}
35
36 /// An error encountered during codegen with span information. Can be converted to [`syn::Error`].
37 /// using [`From`].
38 #[derive(Clone, Debug)]
39 pub struct CodegenError(pub Span, pub ErrorKind);
40
41 impl From<CodegenError> for syn::Error {
from(CodegenError(span, kind): CodegenError) -> syn::Error42 fn from(CodegenError(span, kind): CodegenError) -> syn::Error {
43 use ErrorKind::*;
44 match kind {
45 InvalidArgsLength { expected, found } => {
46 syn::Error::new(span, format!("The number of args does not match the type signature provided: expected={expected}, found={found}"))
47 }
48 ConstructorRetValShouldBeVoid => {
49 syn::Error::new(span, "Return type should be `void` (`V`) for constructor methods")
50 }
51 InvalidTypeSignature => syn::Error::new(span, "Failed to parse type signature"),
52 }
53 }
54 }
55
56 /// Codegen can fail with [`CodegenError`].
57 pub type CodegenResult<T> = Result<T, CodegenError>;
58
59 /// Describes a method that will be generated. Create one with [`MethodCall::new`] and generate
60 /// code with [`MethodCall::generate`]. This should be given AST nodes from the macro input so that
61 /// errors are associated properly to the input span.
62 pub struct MethodCall {
63 env: Expr,
64 method: MethodInfo,
65 receiver: Receiver,
66 arg_exprs: Vec<Expr>,
67 }
68
69 impl MethodCall {
70 /// Create a new MethodCall instance
new(env: Expr, method: MethodInfo, receiver: Receiver, arg_exprs: Vec<Expr>) -> Self71 pub fn new(env: Expr, method: MethodInfo, receiver: Receiver, arg_exprs: Vec<Expr>) -> Self {
72 Self {
73 env,
74 method,
75 receiver,
76 arg_exprs,
77 }
78 }
79
80 /// Generate code to call the described method.
generate(&self) -> CodegenResult<TokenStream>81 pub fn generate(&self) -> CodegenResult<TokenStream> {
82 // Needs to be threaded manually to other methods since self-referential structs can't
83 // exist.
84 let sig = self.method.sig()?;
85
86 let args = self.generate_args(&sig)?;
87
88 let method_call = self
89 .receiver
90 .generate_call(&self.env, &self.method, &sig, &args)?;
91
92 // Wrap the generated code in a closure so that we can access the outer scope but any
93 // variables we define aren't accessible by the outer scope. There is small hygiene issue
94 // where the arg exprs have our `env` variable in scope. If this becomes an issue we can
95 // refactor these exprs to be passed as closure parameters instead.
96 Ok(quote! {
97 (|| {
98 #method_call
99 })()
100 })
101 }
102
103 /// Generate the `&[jni::jvalue]` arguments slice that will be passed to the `jni` method call.
104 /// This validates the argument count and types.
generate_args(&self, sig: &MethodSig<'_>) -> CodegenResult<Expr>105 fn generate_args(&self, sig: &MethodSig<'_>) -> CodegenResult<Expr> {
106 // Safety: must check that arg count matches the signature
107 if self.arg_exprs.len() != sig.args.len() {
108 return Err(CodegenError(
109 self.method.sig.span(),
110 ErrorKind::InvalidArgsLength {
111 expected: sig.args.len(),
112 found: self.arg_exprs.len(),
113 },
114 ));
115 }
116
117 // Create each `jvalue` expression
118 let type_expr_pairs = core::iter::zip(sig.args.iter().copied(), self.arg_exprs.iter());
119 let jvalues = type_expr_pairs.map(|(ty, expr)| generate_jvalue(ty, expr));
120
121 // Put the `jvalue` expressions in a slice.
122 Ok(parse_quote! {
123 &[#(#jvalues),*]
124 })
125 }
126 }
127
128 /// The receiver of the method call and the type of the method.
129 pub enum Receiver {
130 /// A constructor.
131 Constructor,
132 /// A static method.
133 Static,
134 /// An instance method. The `Expr` here is the `this` object.
135 Instance(Expr),
136 }
137
138 impl Receiver {
139 /// Generate the code that performs the JNI call.
generate_call( &self, env: &Expr, method_info: &MethodInfo, sig: &MethodSig<'_>, args: &Expr, ) -> CodegenResult<TokenStream>140 fn generate_call(
141 &self,
142 env: &Expr,
143 method_info: &MethodInfo,
144 sig: &MethodSig<'_>,
145 args: &Expr,
146 ) -> CodegenResult<TokenStream> {
147 // Constructors are void methods. Validate this fact.
148 if matches!(*self, Receiver::Constructor) && !sig.ret.is_void() {
149 return Err(CodegenError(
150 method_info.sig.span(),
151 ErrorKind::ConstructorRetValShouldBeVoid,
152 ));
153 }
154
155 // The static item containing the `pourover::[Static]MethodDesc`.
156 let method_desc = self.generate_method_desc(method_info);
157
158 // The `jni::signature::ReturnType` that the `jni` crate uses to perform the correct native
159 // call.
160 let return_type = return_type_from_sig(sig.ret);
161
162 // A conversion expression to convert from `jni::object::JValueOwned` to the actual return
163 // type. We have this information from the parsed method signature whereas the `jni` crate
164 // only knows this at runtime.
165 let conversion = return_value_conversion_from_sig(sig.ret);
166
167 // This preamble is used to evaluate all the client-provided expressions outside of the
168 // `unsafe` block. This is the same for all receiver kinds.
169 let mut method_call = quote! {
170 #method_desc
171
172 let env: &mut ::jni::JNIEnv = #env;
173 let method_id = ::jni::descriptors::Desc::lookup(&METHOD_DESC, env)?;
174 let args: &[::jni::sys::jvalue] = #args;
175 };
176
177 // Generate the unsafe JNI call.
178 //
179 // Safety: `args` contains the arguments to this method. The type signature of this
180 // method is `#sig`.
181 //
182 // `args` must adhere to the following:
183 // - `args.len()` must match the number of arguments given in the type signature.
184 // - The union value of each arg in `args` must match the type specified in the type
185 // signature.
186 //
187 // These conditions are upheld by this proc macro and a compile error will be caused if
188 // they are broken. No user-provided code is executed within the `unsafe` block.
189 method_call.append_all(match self {
190 Self::Constructor => quote! {
191 unsafe {
192 env.new_object_unchecked(
193 METHOD_DESC.cls(),
194 method_id,
195 args,
196 )
197 }
198 },
199 Self::Static => quote! {
200 unsafe {
201 env.call_static_method_unchecked(
202 METHOD_DESC.cls(),
203 method_id,
204 #return_type,
205 args,
206 )
207 }#conversion
208 },
209 Self::Instance(this) => quote! {
210 let this_obj: &JObject = #this;
211 unsafe {
212 env.call_method_unchecked(
213 this_obj,
214 method_id,
215 #return_type,
216 args,
217 )
218 }#conversion
219 },
220 });
221
222 Ok(method_call)
223 }
224
225 fn generate_method_desc(&self, MethodInfo { cls, name, sig, .. }: &MethodInfo) -> ItemStatic {
226 match self {
227 Self::Constructor => parse_quote! {
228 static METHOD_DESC: ::pourover::desc::MethodDesc = (#cls).constructor(#sig);
229 },
230 Self::Static => parse_quote! {
231 static METHOD_DESC: ::pourover::desc::StaticMethodDesc = (#cls).static_method(#name, #sig);
232 },
233 Self::Instance(_) => parse_quote! {
234 static METHOD_DESC: ::pourover::desc::MethodDesc = (#cls).method(#name, #sig);
235 },
236 }
237 }
238 }
239
240 /// Information about the method being called
241 pub struct MethodInfo {
242 cls: Expr,
243 name: Expr,
244 sig: LitStr,
245 /// Derived from `sig.value()`. This string must be stored in the struct so that we can return
246 /// a `MethodSig` instance that references it from `MethodInfo::sig()`.
247 sig_str: String,
248 }
249
250 impl MethodInfo {
new(cls: Expr, name: Expr, sig: LitStr) -> Self251 pub fn new(cls: Expr, name: Expr, sig: LitStr) -> Self {
252 let sig_str = sig.value();
253 Self {
254 cls,
255 name,
256 sig,
257 sig_str,
258 }
259 }
260
261 /// Parse the type signature from `sig`. Will return a [`CodegenError`] if the signature cannot
262 /// be parsed.
sig(&self) -> CodegenResult<MethodSig<'_>>263 fn sig(&self) -> CodegenResult<MethodSig<'_>> {
264 MethodSig::try_from_str(&self.sig_str)
265 .ok_or_else(|| CodegenError(self.sig.span(), ErrorKind::InvalidTypeSignature))
266 }
267 }
268
269 /// Generate a `jni::sys::jvalue` instance given a Java type and a Rust value.
270 ///
271 /// Safety: The generated `jvalue` must match the given type `ty`.
generate_jvalue(ty: JavaType<'_>, expr: &Expr) -> TokenStream272 fn generate_jvalue(ty: JavaType<'_>, expr: &Expr) -> TokenStream {
273 // The `jvalue` field to inhabit
274 let union_field: Ident;
275 // The expected input type
276 let type_name: Type;
277 // Whether we need to call `JObject::as_raw()` on the input type
278 let needs_as_raw: bool;
279
280 // Fill the above values based the type signature.
281 match ty {
282 JavaType::Array { depth, ty } => {
283 union_field = parse_quote![l];
284 if let NonArray::Primitive(p) = ty {
285 if depth.get() == 1 {
286 let prim_type = prim_to_sys_type(p);
287 type_name = parse_quote![::jni::objects::JPrimitiveArray<'_, #prim_type>]
288 } else {
289 type_name = parse_quote![&::jni::objects::JObjectArray<'_>];
290 }
291 } else {
292 type_name = parse_quote![&::jni::objects::JObjectArray<'_>];
293 }
294 needs_as_raw = true;
295 }
296 JavaType::NonArray(NonArray::Object { cls }) => {
297 union_field = parse_quote![l];
298 type_name = match cls {
299 "java/lang/String" => parse_quote![&::jni::objects::JString<'_>],
300 "java/util/List" => parse_quote![&::jni::objects::JList<'_>],
301 "java/util/Map" => parse_quote![&::jni::objects::JMap<'_>],
302 _ => parse_quote![&::jni::objects::JObject<'_>],
303 };
304 needs_as_raw = true;
305 }
306 JavaType::NonArray(NonArray::Primitive(p)) => {
307 union_field = prim_to_union_field(p);
308 type_name = prim_to_sys_type(p);
309 needs_as_raw = false;
310 }
311 }
312
313 // The as_raw() tokens if required.
314 let as_raw = if needs_as_raw {
315 quote! { .as_raw() }
316 } else {
317 quote![]
318 };
319
320 // Create the `jvalue` expression. This uses `identity` to produce nice type error messages.
321 quote_spanned! { expr.span() =>
322 ::jni::sys::jvalue {
323 #union_field: ::core::convert::identity::<#type_name>(#expr) #as_raw
324 }
325 }
326 }
327
328 /// Get a `::jni::signature::ReturnType` expression from a [`crate::type_parser::ReturnType`]. This
329 /// value is passed to the `jni` crate so that it knows which JNI method to call.
return_type_from_sig(ret: ReturnType<'_>) -> Expr330 fn return_type_from_sig(ret: ReturnType<'_>) -> Expr {
331 let prim_type = |prim| parse_quote![::jni::signature::ReturnType::Primitive(::jni::signature::Primitive::#prim)];
332
333 use crate::type_parser::{JavaType::*, NonArray::*, Primitive::*};
334
335 match ret {
336 ReturnType::Void => prim_type(quote![Void]),
337 ReturnType::Returns(NonArray(Primitive(Boolean))) => prim_type(quote![Boolean]),
338 ReturnType::Returns(NonArray(Primitive(Byte))) => prim_type(quote![Byte]),
339 ReturnType::Returns(NonArray(Primitive(Char))) => prim_type(quote![Char]),
340 ReturnType::Returns(NonArray(Primitive(Double))) => prim_type(quote![Double]),
341 ReturnType::Returns(NonArray(Primitive(Float))) => prim_type(quote![Float]),
342 ReturnType::Returns(NonArray(Primitive(Int))) => prim_type(quote![Int]),
343 ReturnType::Returns(NonArray(Primitive(Long))) => prim_type(quote![Long]),
344 ReturnType::Returns(NonArray(Primitive(Short))) => prim_type(quote![Short]),
345 ReturnType::Returns(NonArray(Object { .. })) => {
346 parse_quote![::jni::signature::ReturnType::Object]
347 }
348 ReturnType::Returns(Array { .. }) => parse_quote![::jni::signature::ReturnType::Array],
349 }
350 }
351
352 /// A postfix call on a `jni::objects::JValueOwned` instance to convert it to the type specified by
353 /// `ret`. Since we have this information from the type signature we can perform this conversion
354 /// in the macro.
return_value_conversion_from_sig(ret: ReturnType<'_>) -> TokenStream355 fn return_value_conversion_from_sig(ret: ReturnType<'_>) -> TokenStream {
356 use crate::type_parser::{JavaType::*, NonArray::*};
357
358 match ret {
359 ReturnType::Void => quote! { .and_then(::jni::objects::JValueOwned::v) },
360 ReturnType::Returns(NonArray(Primitive(p))) => {
361 let prim = prim_to_union_field(p);
362 quote! { .and_then(::jni::objects::JValueOwned::#prim) }
363 }
364 ReturnType::Returns(NonArray(Object { cls })) => {
365 let mut conversion = quote! { .and_then(::jni::objects::JValueOwned::l) };
366 match cls {
367 "java/lang/String" => {
368 conversion.append_all(quote! { .map(::jni::objects::JString::from) });
369 }
370 "java/util/List" => {
371 conversion.append_all(quote! { .map(::jni::objects::JList::from) });
372 }
373 "java/util/Map" => {
374 conversion.append_all(quote! { .map(::jni::objects::JMap::from) });
375 }
376 _ => {
377 // Already a JObject, so we are good here
378 }
379 }
380 conversion
381 }
382 ReturnType::Returns(Array {
383 depth,
384 ty: Primitive(p),
385 }) if depth.get() == 1 => {
386 let sys_type = prim_to_sys_type(p);
387 quote! {
388 .and_then(::jni::objects::JValueOwned::l)
389 .map(::jni::objects::JPrimitiveArray::<#sys_type>::from)
390 }
391 }
392 ReturnType::Returns(Array { .. }) => quote! {
393 .and_then(::jni::objects::JValueOwned::l)
394 .map(::jni::objects::JObjectArray::from)
395 },
396 }
397 }
398
399 /// From a [`Primitive`], this gets the `jni::sys::jvalue` union field name for that type. This is
400 /// also the `jni::objects::JValueGen` getter name.
prim_to_union_field(p: Primitive) -> Ident401 fn prim_to_union_field(p: Primitive) -> Ident {
402 quote::format_ident!("{}", p.as_char().to_ascii_lowercase())
403 }
404
405 /// From a [`Primitive`], this gets the matching `jvalue::sys` type.
prim_to_sys_type(p: Primitive) -> Type406 fn prim_to_sys_type(p: Primitive) -> Type {
407 match p {
408 Primitive::Boolean => parse_quote![::jni::sys::jboolean],
409 Primitive::Byte => parse_quote![::jni::sys::jbyte],
410 Primitive::Char => parse_quote![::jni::sys::jchar],
411 Primitive::Double => parse_quote![::jni::sys::jdouble],
412 Primitive::Float => parse_quote![::jni::sys::jfloat],
413 Primitive::Int => parse_quote![::jni::sys::jint],
414 Primitive::Long => parse_quote![::jni::sys::jlong],
415 Primitive::Short => parse_quote![::jni::sys::jshort],
416 }
417 }
418
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;
    use crate::test_util::contains_ident;
    use quote::ToTokens;
    use syn::parse_quote;

    /// An instance call of `int example(int, int)` on `foo`, with two argument expressions.
    fn example_method_call() -> MethodCall {
        let info = MethodInfo::new(
            parse_quote![&FOO_CLS],
            parse_quote!["example"],
            parse_quote!["(II)I"],
        );
        let args: Vec<Expr> = vec![parse_quote![123], parse_quote![2 + 3]];
        MethodCall::new(
            parse_quote![&mut env],
            info,
            Receiver::Instance(parse_quote![&foo]),
            args,
        )
    }

    /// Replace the signature literal of `call`, keeping the cached string in sync.
    fn set_sig(call: &mut MethodCall, sig: LitStr) {
        call.method.sig_str = sig.value();
        call.method.sig = sig;
    }

    #[test]
    fn args_are_counted() {
        let mut call = example_method_call();
        call.arg_exprs.push(parse_quote![too_many]);

        let err = call.generate().unwrap_err();

        assert_eq!(
            ErrorKind::InvalidArgsLength {
                expected: 2,
                found: 3
            },
            err.1
        );
    }

    #[test]
    fn constructor_return_type_is_void() {
        let call = MethodCall {
            receiver: Receiver::Constructor,
            ..example_method_call()
        };

        let err = call.generate().unwrap_err();

        assert_eq!(ErrorKind::ConstructorRetValShouldBeVoid, err.1);
    }

    #[test]
    fn invalid_type_sig_is_error() {
        let mut call = example_method_call();
        set_sig(&mut call, parse_quote!["L"]);

        let CodegenError(_, kind) = call.generate().unwrap_err();

        assert_eq!(ErrorKind::InvalidTypeSignature, kind);
    }

    #[test]
    fn jni_types_are_used_for_stdlib_classes_input() {
        const CASES: [(&str, &str); 15] = [
            ("Ljava/lang/String;", "JString"),
            ("Ljava/util/Map;", "JMap"),
            ("Ljava/util/List;", "JList"),
            ("[Ljava/lang/String;", "JObjectArray"),
            ("[[I", "JObjectArray"),
            ("[I", "JPrimitiveArray"),
            ("Lcom/example/MyObject;", "JObject"),
            ("Z", "jboolean"),
            ("C", "jchar"),
            ("B", "jbyte"),
            ("S", "jshort"),
            ("I", "jint"),
            ("J", "jlong"),
            ("F", "jfloat"),
            ("D", "jdouble"),
        ];
        let value: Expr = parse_quote![some_value];

        for (desc, jni_type) in CASES {
            let java_type = JavaType::try_from_str(desc).unwrap();
            let tokens = generate_jvalue(java_type, &value);

            assert!(
                contains_ident(tokens, jni_type),
                "desc: {desc}, jni_type: {jni_type}"
            );
        }
    }

    #[test]
    fn jni_types_are_used_for_stdlib_classes_output() {
        const CASES: [(&str, &str); 6] = [
            ("Ljava/lang/String;", "JString"),
            ("Ljava/util/Map;", "JMap"),
            ("Ljava/util/List;", "JList"),
            ("[Ljava/lang/String;", "JObjectArray"),
            ("[[I", "JObjectArray"),
            ("[I", "JPrimitiveArray"),
        ];

        for (desc, jni_type) in CASES {
            let ret = ReturnType::try_from_str(desc).unwrap();
            let tokens = return_value_conversion_from_sig(ret);

            assert!(
                contains_ident(tokens, jni_type),
                "desc: {desc}, jni_type: {jni_type}"
            );
        }
    }

    #[test]
    fn return_type_passed_to_jni_is_correct() {
        const CASES: [(&str, &str); 15] = [
            ("Ljava/lang/String;", "Object"),
            ("Ljava/util/Map;", "Object"),
            ("Ljava/util/List;", "Object"),
            ("[Ljava/lang/String;", "Array"),
            ("[[I", "Array"),
            ("[I", "Array"),
            ("V", "Void"),
            ("Z", "Boolean"),
            ("C", "Char"),
            ("B", "Byte"),
            ("S", "Short"),
            ("I", "Int"),
            ("J", "Long"),
            ("F", "Float"),
            ("D", "Double"),
        ];

        for (desc, return_type) in CASES {
            let ret = ReturnType::try_from_str(desc).unwrap();
            let tokens = return_type_from_sig(ret).into_token_stream();

            assert!(
                contains_ident(tokens, return_type),
                "desc: {desc}, return_type: {return_type}"
            );
        }
    }

    #[test]
    fn method_desc_is_correct() {
        let mut call = example_method_call();
        set_sig(&mut call, parse_quote!["(II)V"]);

        let cases = [
            (Receiver::Constructor, "constructor"),
            (Receiver::Static, "static_method"),
            (Receiver::Instance(parse_quote![this_value]), "method"),
        ];

        for (receiver, method_ident) in cases {
            let rhs = receiver
                .generate_method_desc(&call.method)
                .expr
                .into_token_stream();

            assert!(contains_ident(rhs, method_ident), "method: {method_ident}");
        }
    }

    #[test]
    fn jni_call_is_correct() {
        let mut call = example_method_call();
        set_sig(&mut call, parse_quote!["(II)V"]);
        let sig = call.method.sig().unwrap();
        let args: Expr = parse_quote![test_stub];

        let cases = [
            (Receiver::Constructor, "new_object_unchecked"),
            (Receiver::Static, "call_static_method_unchecked"),
            (
                Receiver::Instance(parse_quote![this_value]),
                "call_method_unchecked",
            ),
        ];

        for (receiver, method_ident) in cases {
            let generated = receiver
                .generate_call(&call.env, &call.method, &sig, &args)
                .unwrap();

            assert!(
                contains_ident(generated, method_ident),
                "method: {method_ident}"
            );
        }
    }
}
611