1 // Copyright 2023 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 //! Cacheable descriptors. These `jni` crate compatible descriptors will cache their value on first
16 //! lookup. They are meant to be combined with static variables in order to globally cache their
17 //! associated id values.
18 //!
19 //! ### Example
20 //!
21 //! ```java
22 //! package com.example;
23 //!
24 //! public class MyClass {
25 //!     public int foo;
26 //!     public int getBar() { /* ... */ }
27 //! }
28 //! ```
29 //!
30 //! ```rust
31 //! use pourover::desc::*;
32 //!
33 //! static MY_CLASS_DESC: ClassDesc = ClassDesc::new("com/example/MyClass");
34 //! static MY_CLASS_FOO_FIELD: FieldDesc = MY_CLASS_DESC.field("foo", "I");
//! static MY_CLASS_GET_BAR_METHOD: MethodDesc = MY_CLASS_DESC.method("getBar", "()I");
36 //! ```
37 
38 #![allow(unsafe_code)]
39 
40 use jni::{
41     descriptors::Desc,
42     objects::{GlobalRef, JClass, JFieldID, JMethodID, JObject, JStaticFieldID, JStaticMethodID},
43     JNIEnv,
44 };
45 use std::sync::{LockResult, RwLock, RwLockReadGuard};
46 
/// JNI descriptor that caches a Java class.
///
/// The class is looked up lazily the first time this descriptor is used as a `Desc<JClass>`.
/// It is then held in a [`GlobalRef`] until [`ClassDesc::free`] is called or the descriptor is
/// dropped.
pub struct ClassDesc {
    /// The JNI descriptor for this class, e.g. `"com/example/MyClass"`.
    descriptor: &'static str,
    /// The cached class; `None` until the first successful lookup.
    ///
    /// The implementation assumes that `None` is only written to the lock when `&mut self` is
    /// held. Only `Some` is valid to write to this lock while `&self` is held.
    cls: RwLock<Option<GlobalRef>>,
}
57 
58 impl ClassDesc {
59     /// Create a new descriptor with the given JNI descriptor string.
new(descriptor: &'static str) -> Self60     pub const fn new(descriptor: &'static str) -> Self {
61         Self {
62             descriptor,
63             cls: RwLock::new(None),
64         }
65     }
66 
67     /// Create a new descriptor for a field member of this class.
field<'cls>(&'cls self, name: &'static str, sig: &'static str) -> FieldDesc<'cls>68     pub const fn field<'cls>(&'cls self, name: &'static str, sig: &'static str) -> FieldDesc<'cls> {
69         FieldDesc::new(self, name, sig)
70     }
71 
72     /// Create a new descriptor for a static field member of this class.
static_field<'cls>( &'cls self, name: &'static str, sig: &'static str, ) -> StaticFieldDesc<'cls>73     pub const fn static_field<'cls>(
74         &'cls self,
75         name: &'static str,
76         sig: &'static str,
77     ) -> StaticFieldDesc<'cls> {
78         StaticFieldDesc::new(self, name, sig)
79     }
80 
81     /// Create a new descriptor for a method member of this class.
method<'cls>( &'cls self, name: &'static str, sig: &'static str, ) -> MethodDesc<'cls>82     pub const fn method<'cls>(
83         &'cls self,
84         name: &'static str,
85         sig: &'static str,
86     ) -> MethodDesc<'cls> {
87         MethodDesc::new(self, name, sig)
88     }
89 
90     /// Create a new descriptor for a constructor for this class.
constructor<'cls>(&'cls self, sig: &'static str) -> MethodDesc<'cls>91     pub const fn constructor<'cls>(&'cls self, sig: &'static str) -> MethodDesc<'cls> {
92         MethodDesc::new(self, "<init>", sig)
93     }
94 
95     /// Create a new descriptor for a static method member of this class.
static_method<'cls>( &'cls self, name: &'static str, sig: &'static str, ) -> StaticMethodDesc<'cls>96     pub const fn static_method<'cls>(
97         &'cls self,
98         name: &'static str,
99         sig: &'static str,
100     ) -> StaticMethodDesc<'cls> {
101         StaticMethodDesc::new(self, name, sig)
102     }
103 
104     /// Free the cached GlobalRef to the class object. This will happen automatically on drop, but
105     /// this method is provided to allow the value to be dropped early. This can be used to perform
106     /// cleanup on a thread that is already attached to the JVM.
free(&mut self)107     pub fn free(&mut self) {
108         // Get a mutable reference ignoring poison state
109         let mut guard = self.cls.write().ignore_poison();
110         let global = guard.take();
111         // Drop the guard before global in case it panics. We don't want to poison the lock.
112         core::mem::drop(guard);
113         // Drop the GlobalRef value to cleanup
114         core::mem::drop(global);
115     }
116 
get_cached(&self) -> Option<CachedClass<'_>>117     fn get_cached(&self) -> Option<CachedClass<'_>> {
118         CachedClass::from_lock(&self.cls)
119     }
120 }
121 
122 /// Wrapper to allow RwLock references to be returned from Desc. Use the `AsRef` impl to get the
123 /// associated `JClass` reference. The inner `Option` must always be `Some`. This is enfocred by
124 /// the `from_lock` constructor.
125 pub struct CachedClass<'lock>(RwLockReadGuard<'lock, Option<GlobalRef>>);
126 
127 impl<'lock> CachedClass<'lock> {
128     /// Read from the given lock and create a `CachedClass` instance if the lock contains a cached
129     /// class value. The given lock must have valid data even if it is poisoned.
from_lock(lock: &'lock RwLock<Option<GlobalRef>>) -> Option<CachedClass<'lock>>130     fn from_lock(lock: &'lock RwLock<Option<GlobalRef>>) -> Option<CachedClass<'lock>> {
131         let guard = lock.read().ignore_poison();
132 
133         // Validate that there is a GlobalRef value already, otherwise avoid constructing `Self`.
134         if guard.is_some() {
135             Some(Self(guard))
136         } else {
137             None
138         }
139     }
140 }
141 
142 // Implement AsRef so that we can use this type as `Desc::Output` in [`ClassDesc`].
143 impl<'lock> AsRef<JClass<'static>> for CachedClass<'lock> {
as_ref(&self) -> &JClass<'static>144     fn as_ref(&self) -> &JClass<'static> {
145         // `unwrap` is valid since we checked for `Some` in the constructor.
146         #[allow(clippy::expect_used)]
147         let global = self
148             .0
149             .as_ref()
150             .expect("Created CachedClass in an invalid state");
151         // No direct conversion to JClass, so let's go through JObject first.
152         let obj: &JObject<'static> = global.as_ref();
153         // This assumes our object is a class object.
154         let cls: &JClass<'static> = From::from(obj);
155         cls
156     }
157 }
158 
159 /// # Safety
160 ///
161 /// This returns the correct class instance via `JNIEnv::find_class`. The cached class is held in a
162 /// [`GlobalRef`] so that it cannot be unloaded.
163 unsafe impl<'a, 'local> Desc<'local, JClass<'static>> for &'a ClassDesc {
164     type Output = CachedClass<'a>;
165 
lookup(self, env: &mut JNIEnv<'local>) -> jni::errors::Result<Self::Output>166     fn lookup(self, env: &mut JNIEnv<'local>) -> jni::errors::Result<Self::Output> {
167         // Check the cache
168         if let Some(cls) = self.get_cached() {
169             return Ok(cls);
170         }
171 
172         {
173             // Ignoring poison is fine because we only write fully-constructed values.
174             let mut guard = self.cls.write().ignore_poison();
175 
176             // Multiple threads could have hit this block at the same time. Only allocate the
177             // `GlobalRef` if it was not already allocated.
178             if guard.is_none() {
179                 let cls = env.find_class(self.descriptor)?;
180                 let global = env.new_global_ref(cls)?;
181 
182                 // Only directly assign valid values. That way poison state can't have broken
183                 // invariants. If the above panicked then it will poison without changing the
184                 // lock's data, and everything will still be fine albeit with a sprinkle of
185                 // possibly-leaked memory.
186                 *guard = Some(global);
187             }
188         }
189 
190         // Safe to unwrap since we just set `self.cls` to `Some`. `ClassDesc::free` can't be called
191         // before this point because it takes a mutable reference to `*self`.
192         #[allow(clippy::unwrap_used)]
193         Ok(self.get_cached().unwrap())
194     }
195 }
196 
/// A descriptor for a class member. `Id` is expected to implement the [`MemberId`] trait.
///
/// The member's id is looked up the first time the descriptor is used and is cached afterwards.
///
/// See [`FieldDesc`], [`StaticFieldDesc`], [`MethodDesc`], and [`StaticMethodDesc`] aliases.
pub struct MemberDesc<'cls, Id> {
    /// The class this member belongs to.
    cls: &'cls ClassDesc,
    /// The member's name (`"<init>"` for constructors).
    name: &'static str,
    /// The member's JNI signature, e.g. `"I"` for a field or `"()I"` for a method.
    sig: &'static str,
    /// The cached id; `None` until the first successful lookup.
    id: RwLock<Option<Id>>,
}
206 
/// Descriptor for a field. Its id is resolved with `JNIEnv::get_field_id`.
pub type FieldDesc<'cls> = MemberDesc<'cls, JFieldID>;
/// Descriptor for a static field. Its id is resolved with `JNIEnv::get_static_field_id`.
pub type StaticFieldDesc<'cls> = MemberDesc<'cls, JStaticFieldID>;
/// Descriptor for a method, including constructors. Its id is resolved with
/// `JNIEnv::get_method_id`.
pub type MethodDesc<'cls> = MemberDesc<'cls, JMethodID>;
/// Descriptor for a static method. Its id is resolved with `JNIEnv::get_static_method_id`.
pub type StaticMethodDesc<'cls> = MemberDesc<'cls, JStaticMethodID>;
215 
216 impl<'cls, Id: MemberId> MemberDesc<'cls, Id> {
217     /// Create a new descriptor for a member of the given class with the given name and signature.
218     ///
219     /// Please use the helpers on [`ClassDesc`] instead of directly calling this method.
new(cls: &'cls ClassDesc, name: &'static str, sig: &'static str) -> Self220     pub const fn new(cls: &'cls ClassDesc, name: &'static str, sig: &'static str) -> Self {
221         Self {
222             cls,
223             name,
224             sig,
225             id: RwLock::new(None),
226         }
227     }
228 
229     /// Get the class descriptor that this member is associated to.
cls(&self) -> &'cls ClassDesc230     pub const fn cls(&self) -> &'cls ClassDesc {
231         self.cls
232     }
233 }
234 
235 /// # Safety
236 ///
237 /// This returns the correct id. It is the same id obtained from the JNI. This id can be a pointer
238 /// in some JVM implementations. See trait [`MemberId`].
239 unsafe impl<'cls, 'local, Id: MemberId> Desc<'local, Id> for &MemberDesc<'cls, Id> {
240     type Output = Id;
241 
lookup(self, env: &mut JNIEnv<'local>) -> jni::errors::Result<Self::Output>242     fn lookup(self, env: &mut JNIEnv<'local>) -> jni::errors::Result<Self::Output> {
243         // Check the cache.
244         if let Some(id) = *self.id.read().ignore_poison() {
245             return Ok(id);
246         }
247 
248         {
249             // Ignoring poison is fine because we only write valid values.
250             let mut guard = self.id.write().ignore_poison();
251 
252             // Multiple threads could have hit this block at the same time. Only lookup the id if
253             // the lookup was not already performed.
254             if guard.is_none() {
255                 let id = Id::lookup(env, self)?;
256 
257                 // Only directly assign valid values. That way poison state can't have broken
258                 // invariants. If the above panicked then it will poison without changing the
259                 // lock's data and everything will still be fine.
260                 *guard = Some(id);
261 
262                 Ok(id)
263             } else {
264                 // Can unwrap since we just checked for `None`.
265                 #[allow(clippy::unwrap_used)]
266                 Ok(*guard.as_ref().unwrap())
267             }
268         }
269     }
270 }
271 
/// Helper trait that calls into `jni` to lookup the id values. This is specialized on the id's
/// type to call the correct lookup function.
///
/// The `Copy` bound lets a cached id be returned by value. The `AsRef<Self>` bound lets the id
/// serve directly as the `Output` of the `Desc` impl for `&MemberDesc`, which uses
/// `type Output = Id`.
///
/// # Safety
///
/// Implementers must be an ID returned from the JNI. `lookup` must be implemented such that the
/// returned ID matches the JNI descriptor given. All values must be sourced from the JNI APIs.
/// See [`::jni::descriptors::Desc`].
pub unsafe trait MemberId: Sized + Copy + AsRef<Self> {
    /// Lookup the id of the given descriptor via the given environment.
    ///
    /// The implementations in this file forward to the matching `JNIEnv` getter. They pass the
    /// descriptor's class, name and signature, and return the getter's result, including any
    /// error, unchanged.
    fn lookup(env: &mut JNIEnv, desc: &MemberDesc<Self>) -> jni::errors::Result<Self>;
}
284 
285 /// # Safety
286 ///
287 /// This fetches the matching ID from the JNI APIs.
288 unsafe impl MemberId for JFieldID {
lookup(env: &mut JNIEnv, desc: &MemberDesc<Self>) -> jni::errors::Result<Self>289     fn lookup(env: &mut JNIEnv, desc: &MemberDesc<Self>) -> jni::errors::Result<Self> {
290         env.get_field_id(desc.cls(), desc.name, desc.sig)
291     }
292 }
293 
294 /// # Safety
295 ///
296 /// This fetches the matching ID from the JNI APIs.
297 unsafe impl MemberId for JStaticFieldID {
lookup(env: &mut JNIEnv, desc: &MemberDesc<Self>) -> jni::errors::Result<Self>298     fn lookup(env: &mut JNIEnv, desc: &MemberDesc<Self>) -> jni::errors::Result<Self> {
299         env.get_static_field_id(desc.cls(), desc.name, desc.sig)
300     }
301 }
302 
303 /// # Safety
304 ///
305 /// This fetches the matching ID from the JNI APIs.
306 unsafe impl MemberId for JMethodID {
lookup(env: &mut JNIEnv, desc: &MemberDesc<Self>) -> jni::errors::Result<Self>307     fn lookup(env: &mut JNIEnv, desc: &MemberDesc<Self>) -> jni::errors::Result<Self> {
308         env.get_method_id(desc.cls(), desc.name, desc.sig)
309     }
310 }
311 
312 /// # Safety
313 ///
314 /// This fetches the matching ID from the JNI APIs.
315 unsafe impl MemberId for JStaticMethodID {
lookup(env: &mut JNIEnv, desc: &MemberDesc<Self>) -> jni::errors::Result<Self>316     fn lookup(env: &mut JNIEnv, desc: &MemberDesc<Self>) -> jni::errors::Result<Self> {
317         env.get_static_method_id(desc.cls(), desc.name, desc.sig)
318     }
319 }
320 
/// Internal helper for taking a lock without caring whether it is poisoned.
///
/// A lock becomes poisoned when a thread panics while holding it. That signals that the protected
/// data may have broken invariants. This module only writes fully-constructed values while
/// holding its locks, so a panic cannot leave the data half-updated. The locks are effectively
/// used only for synchronization, and poisoning can safely be ignored.
trait IgnoreLockPoison {
    /// The guard type carried by the lock result.
    type Guard;

    /// Return the guard whether or not the lock was poisoned.
    fn ignore_poison(self) -> Self::Guard;
}

impl<G> IgnoreLockPoison for LockResult<G> {
    type Guard = G;

    fn ignore_poison(self) -> Self::Guard {
        match self {
            Ok(guard) => guard,
            // A poisoned lock still hands back a usable guard.
            Err(poisoned) => poisoned.into_inner(),
        }
    }
}
342 
#[cfg(test)]
mod test {
    use super::*;

    const DESC: &str = "com/example/Foo";
    static FOO: ClassDesc = ClassDesc::new(DESC);

    static FIELD: FieldDesc = FOO.field("foo", "I");
    static STATIC_FIELD: StaticFieldDesc = FOO.static_field("sfoo", "J");
    static CONSTRUCTOR: MethodDesc = FOO.constructor("(I)V");
    static METHOD: MethodDesc = FOO.method("mfoo", "()Z");
    static STATIC_METHOD: StaticMethodDesc = FOO.static_method("smfoo", "()I");

    #[test]
    fn class_desc_created_properly() {
        assert_eq!(DESC, FOO.descriptor);
        assert!(FOO.cls.read().ignore_poison().is_none());
    }

    #[test]
    fn field_desc_created_properly() {
        assert!(std::ptr::eq(&FOO, FIELD.cls()));
        assert_eq!("foo", FIELD.name);
        assert_eq!("I", FIELD.sig);
    }

    #[test]
    fn static_field_desc_created_properly() {
        assert!(std::ptr::eq(&FOO, STATIC_FIELD.cls()));
        assert_eq!("sfoo", STATIC_FIELD.name);
        assert_eq!("J", STATIC_FIELD.sig);
    }

    #[test]
    fn constructor_desc_created_properly() {
        assert!(std::ptr::eq(&FOO, CONSTRUCTOR.cls()));
        assert_eq!("<init>", CONSTRUCTOR.name);
        assert_eq!("(I)V", CONSTRUCTOR.sig);
    }

    #[test]
    fn method_desc_created_properly() {
        assert!(std::ptr::eq(&FOO, METHOD.cls()));
        assert_eq!("mfoo", METHOD.name);
        assert_eq!("()Z", METHOD.sig);
    }

    #[test]
    fn static_method_desc_created_properly() {
        assert!(std::ptr::eq(&FOO, STATIC_METHOD.cls()));
        assert_eq!("smfoo", STATIC_METHOD.name);
        assert_eq!("()I", STATIC_METHOD.sig);
    }

    /// No member id may be cached before any lookup has run.
    #[test]
    fn member_ids_start_uncached() {
        assert!(FIELD.id.read().ignore_poison().is_none());
        assert!(STATIC_FIELD.id.read().ignore_poison().is_none());
        assert!(CONSTRUCTOR.id.read().ignore_poison().is_none());
        assert!(METHOD.id.read().ignore_poison().is_none());
        assert!(STATIC_METHOD.id.read().ignore_poison().is_none());
    }

    /// `CachedClass::from_lock` must refuse to wrap an empty cache and must release its read
    /// lock when it does so.
    #[test]
    fn cached_class_not_built_from_empty_cache() {
        let lock = RwLock::new(None);
        assert!(CachedClass::from_lock(&lock).is_none());
        assert!(lock.try_write().is_ok());
    }

    /// A poisoned lock must still yield its data through `ignore_poison`.
    #[test]
    fn ignore_poison_recovers_poisoned_lock() {
        let lock = std::sync::Arc::new(RwLock::new(3_u8));
        let writer = std::sync::Arc::clone(&lock);
        let joined = std::thread::spawn(move || {
            let _guard = writer.write().ignore_poison();
            panic!("poison the lock on purpose");
        })
        .join();
        assert!(joined.is_err());
        assert!(lock.is_poisoned());
        assert_eq!(3, *lock.read().ignore_poison());
    }
}
397