xref: /aosp_15_r20/external/cronet/third_party/jni_zero/java_types.py (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1# Copyright 2023 The Chromium Authors
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5import dataclasses
6from typing import Dict
7from typing import Optional
8from typing import Tuple
9
10import java_lang_classes
11
# One row per Java primitive type name (including 'void'):
# (type name, JNI descriptor character, C++ default return value).
# The primitive lookup tables below are all derived from these rows, in order.
_PRIMITIVE_TYPE_INFO = (
    ('boolean', 'Z', 'false'),
    ('byte', 'B', '0'),
    ('char', 'C', '0'),
    ('double', 'D', '0'),
    ('float', 'F', '0'),
    ('int', 'I', '0'),
    ('long', 'J', '0'),
    ('short', 'S', '0'),
    ('void', 'V', ''),
)

# JNI C++ type for every Java type that has a dedicated one: each primitive
# ('j' + name, except 'void' which stays 'void') plus three java.lang classes.
# Class keys use the slash-separated form. Types missing from this table are
# represented as 'jobject' by JavaType.to_cpp().
_CPP_TYPE_BY_JAVA_TYPE = {
    name: name if name == 'void' else f'j{name}'
    for name, _, _ in _PRIMITIVE_TYPE_INFO
}
_CPP_TYPE_BY_JAVA_TYPE.update({
    'java/lang/Class': 'jclass',
    'java/lang/String': 'jstring',
    'java/lang/Throwable': 'jthrowable',
})

# Primitive type name <-> single-character JNI descriptor, in both directions.
_DESCRIPTOR_CHAR_BY_PRIMITIVE_TYPE = {
    name: char for name, char, _ in _PRIMITIVE_TYPE_INFO
}
_PRIMITIVE_TYPE_BY_DESCRIPTOR_CHAR = {
    char: name for name, char, _ in _PRIMITIVE_TYPE_INFO
}

# C++ literal used as a default return value for each primitive type.
_DEFAULT_VALUE_BY_PRIMITIVE_TYPE = {
    name: default for name, _, default in _PRIMITIVE_TYPE_INFO
}

PRIMITIVES = frozenset(name for name, _, _ in _PRIMITIVE_TYPE_INFO)
57
58
@dataclasses.dataclass(frozen=True, order=True)
class JavaClass:
  """Represents a reference type.

  The name is stored in JNI form: '/' between package components and '$'
  between nested class names, e.g. 'java/util/Map$Entry'.
  """
  _fqn: str
  # The class this one was derived from by make_prefixed(), or None. This is
  # only meaningful if make_prefixed() has been called on the original class.
  _class_without_prefix: Optional['JavaClass'] = None

  def __post_init__(self):
    assert '.' not in self._fqn, f'{self._fqn} should have / and $, but not .'

  def __str__(self):
    return self.full_name_with_slashes

  @property
  def name(self):
    """Name after the last '/', e.g. 'Map$Entry'."""
    return self._fqn.rsplit('/', 1)[-1]

  @property
  def name_with_dots(self):
    """`name` with '$' replaced by '.', e.g. 'Map.Entry'."""
    return self.name.replace('$', '.')

  @property
  def nested_name(self):
    """Innermost class name (after the last '$'), e.g. 'Entry'."""
    return self.name.rsplit('$', 1)[-1]

  @property
  def package_with_slashes(self):
    """Everything before the last '/' (the whole name if there is no '/')."""
    return self._fqn.rsplit('/', 1)[0]

  @property
  def package_with_dots(self):
    return self.package_with_slashes.replace('/', '.')

  @property
  def full_name_with_slashes(self):
    return self._fqn

  @property
  def full_name_with_dots(self):
    """Full name with both '/' and '$' replaced by '.'."""
    return self._fqn.replace('/', '.').replace('$', '.')

  @property
  def prefix(self):
    """Dotted prefix that make_prefixed() added, or '' if there is none."""
    if self._class_without_prefix is None:
      return ''
    full_name = self.full_name_with_dots
    unprefixed_name = self._class_without_prefix.full_name_with_dots
    # make_prefixed() appends the original name after the prefix, so it is a
    # suffix: search from the right, since the prefix itself may contain the
    # unprefixed name (e.g. prefix 'a.b.c' on class 'b.c').
    index = full_name.rfind(unprefixed_name)
    # Drop the '.' separating the prefix from the original name.
    return full_name[:index - 1] if index > 0 else ''

  @property
  def class_without_prefix(self):
    """The class before make_prefixed() was applied, or self."""
    return self._class_without_prefix if self._class_without_prefix else self

  @property
  def outer_class_name(self):
    """Outermost class name (before the first '$')."""
    return self.name.split('$', 1)[0]

  def is_nested(self):
    return '$' in self.name

  def get_outer_class(self):
    """Returns the outermost enclosing class, in the same package.

    The result is built from the name alone and carries no make_prefixed()
    link.
    """
    return JavaClass(f'{self.package_with_slashes}/{self.outer_class_name}')

  def is_system_class(self):
    """True for classes under the 'android/' or 'java/' namespaces."""
    return self._fqn.startswith(('android/', 'java/'))

  def to_java(self, type_resolver=None):
    """Returns the Java source spelling of this class via type_resolver."""
    # Empty resolver used to shorten java.lang classes.
    type_resolver = type_resolver or _EMPTY_TYPE_RESOLVER
    return type_resolver.contextualize(self)

  def as_type(self):
    """Wraps this class in a non-array JavaType."""
    return JavaType(java_class=self)

  def make_prefixed(self, prefix=None):
    """Returns this class moved under `prefix` (dotted or slashed).

    The result remembers `self` so that `prefix` and `class_without_prefix`
    can recover the original. Returns self unchanged when prefix is empty.
    """
    if not prefix:
      return self
    prefix = prefix.replace('.', '/')
    return JavaClass(f'{prefix}/{self._fqn}', self)

  def make_nested(self, name):
    """Returns the class nested in this one under `name`."""
    return JavaClass(f'{self._fqn}${name}')
141
142
@dataclasses.dataclass(frozen=True)
class JavaType:
  """Represents a parameter or return type.

  Exactly one of primitive_name / java_class describes the element type;
  array_dimensions counts the '[]' suffixes.
  """
  array_dimensions: int = 0
  primitive_name: Optional[str] = None
  java_class: Optional['JavaClass'] = None
  # Annotation name -> value (e.g. 'JniType'); ignored by == and hashing.
  annotations: Dict[str, Optional[str]] = \
      dataclasses.field(default_factory=dict, compare=False)

  @staticmethod
  def from_descriptor(descriptor):
    """Parses a JNI field descriptor such as '[Ljava/lang/Class;' or 'I'."""
    # E.g.: [Ljava/lang/Class;
    without_arrays = descriptor.lstrip('[')
    array_dimensions = len(descriptor) - len(without_arrays)
    descriptor = without_arrays

    if descriptor[0] == 'L':
      assert descriptor[-1] == ';', 'invalid descriptor: ' + descriptor
      return JavaType(array_dimensions=array_dimensions,
                      java_class=JavaClass(descriptor[1:-1]))
    primitive_name = _PRIMITIVE_TYPE_BY_DESCRIPTOR_CHAR[descriptor[0]]
    return JavaType(array_dimensions=array_dimensions,
                    primitive_name=primitive_name)

  @property
  def non_array_full_name_with_slashes(self):
    """The primitive name, or the element class's slash-separated name."""
    return self.primitive_name or self.java_class.full_name_with_slashes

  # Cannot use dataclass(order=True) because some fields are None.
  def __lt__(self, other):
    """Orders types with a primitive element type before reference types.

    Within each group, orders by array dimensions, then by primitive name or
    java_class. Must return a bool: any truthy value means "self < other".
    """
    if self.primitive_name and not other.primitive_name:
      return True
    if other.primitive_name and not self.primitive_name:
      return False
    lhs = (self.array_dimensions, self.primitive_name or self.java_class)
    rhs = (other.array_dimensions, other.primitive_name or other.java_class)
    return lhs < rhs

  def is_primitive(self):
    """True for a non-array primitive (including 'void')."""
    return self.primitive_name is not None and self.array_dimensions == 0

  def is_array(self):
    return self.array_dimensions > 0

  def is_primitive_array(self):
    """True for any array whose innermost element type is primitive."""
    return self.primitive_name is not None and self.array_dimensions > 0

  def is_object_array(self):
    """True for arrays whose elements are references (incl. nested arrays)."""
    return self.array_dimensions > 1 or (self.primitive_name is None
                                         and self.array_dimensions > 0)

  def is_void(self):
    return self.primitive_name == 'void'

  def to_array_element_type(self):
    """Returns the type with one array dimension removed.

    Annotations are not carried over.
    """
    assert self.is_array()
    return JavaType(array_dimensions=self.array_dimensions - 1,
                    primitive_name=self.primitive_name,
                    java_class=self.java_class)

  def to_descriptor(self):
    """Converts a Java type into a JNI signature type."""
    if self.primitive_name:
      name = _DESCRIPTOR_CHAR_BY_PRIMITIVE_TYPE[self.primitive_name]
    else:
      name = f'L{self.java_class.full_name_with_slashes};'
    return ('[' * self.array_dimensions) + name

  def to_java(self, type_resolver=None):
    """Returns the Java source spelling, e.g. 'int[][]' or 'String'."""
    if self.primitive_name:
      ret = self.primitive_name
    else:
      ret = self.java_class.to_java(type_resolver)
    return ret + '[]' * self.array_dimensions

  def to_cpp(self):
    """Returns a C datatype for the given java type."""
    if self.array_dimensions > 1:
      return 'jobjectArray'
    if self.array_dimensions > 0 and self.primitive_name is None:
      # There is no jstringArray.
      return 'jobjectArray'

    cpp_type = _CPP_TYPE_BY_JAVA_TYPE.get(self.non_array_full_name_with_slashes,
                                          'jobject')
    if self.array_dimensions:
      cpp_type = f'{cpp_type}Array'
    return cpp_type

  def to_cpp_default_value(self):
    """Returns a valid C return value for the given java type."""
    if self.is_primitive():
      return _DEFAULT_VALUE_BY_PRIMITIVE_TYPE[self.primitive_name]
    return 'nullptr'

  def to_proxy(self):
    """Converts to types used over JNI boundary."""
    # All object array types become jobjectArray in native code, but need to
    # be passed as the original type on the java side.
    if self.non_array_full_name_with_slashes in _CPP_TYPE_BY_JAVA_TYPE:
      return self

    # All other types should just be passed as Objects or Object arrays.
    return dataclasses.replace(self, java_class=OBJECT_CLASS)

  def converted_type(self):
    """Returns a C datatype listed in the JniType annotation for this type.

    Returns None when there is no JniType annotation.

    Raises:
      Exception: @JniType("std::vector") was used on a non-array type.
    """
    ret = self.annotations.get('JniType', None)
    # Allow "std::vector" as shorthand for:
    #     std::vector<jni_zero::ScopedJavaLocalRef<jobject>>
    if ret == 'std::vector':
      if self.is_object_array():
        ret += '<jni_zero::ScopedJavaLocalRef<jobject>>'
      elif self.is_array():
        cpp_type = _CPP_TYPE_BY_JAVA_TYPE[self.non_array_full_name_with_slashes]
        ret += f'<{cpp_type}>'
      else:
        # TODO(agrieve): This should be checked at parse time.
        raise Exception(
            'Found non-templatized @JniType("std::vector") on non-array type')
    return ret
264
265
@dataclasses.dataclass(frozen=True)
class JavaParam:
  """A single method parameter: its type and its name."""
  java_type: JavaType
  name: str

  def to_proxy(self):
    """Returns this parameter with its type converted for the JNI boundary.

    The name is kept; only java_type is replaced by java_type.to_proxy().
    """
    proxy_type = self.java_type.to_proxy()
    return JavaParam(java_type=proxy_type, name=self.name)
275
276
class JavaParamList(tuple):
  """An immutable sequence of parameters, in declaration order."""

  def to_proxy(self):
    """Returns a new JavaParamList of each parameter's to_proxy() result.

    This converts every parameter to the types used over the JNI boundary.
    """
    return JavaParamList([param.to_proxy() for param in self])

  def to_java_declaration(self, type_resolver=None):
    """Renders the parameters as Java source, e.g. 'int a, String b'.

    type_resolver is forwarded to each parameter type's to_java().
    """
    rendered = [
        f'{param.java_type.to_java(type_resolver)} {param.name}'
        for param in self
    ]
    return ', '.join(rendered)

  def to_call_str(self):
    """Renders the parameter names as call arguments, e.g. 'a, b'."""
    names = [param.name for param in self]
    return ', '.join(names)
289
290
@dataclasses.dataclass(frozen=True, order=True)
class JavaSignature:
  """Represents a method signature (return type + parameter types)."""
  return_type: JavaType
  # Parameter types in declaration order; holds any number of entries.
  param_types: Tuple[JavaType, ...]
  # Signatures should be considered equal even if parameter names differ, so
  # param_list is excluded from comparisons.
  param_list: JavaParamList = dataclasses.field(compare=False)

  @staticmethod
  def from_params(return_type, param_list):
    """Builds a signature from a return type and a JavaParamList."""
    return JavaSignature(return_type=return_type,
                         param_types=tuple(p.java_type for p in param_list),
                         param_list=param_list)

  @staticmethod
  def from_descriptor(descriptor):
    """Parses a JNI method descriptor into a JavaSignature.

    Descriptors carry no parameter names, so parameters are named p0, p1, ...

    Args:
      descriptor: e.g. '(Ljava/lang/Object;Ljava/lang/Runnable;)Ljava/lang/Class;'
    """
    assert descriptor[0] == '(', 'invalid descriptor: ' + descriptor
    i = 1
    # Start of the current parameter, including any leading '[' characters.
    start_idx = i
    params = []
    while True:
      char = descriptor[i]
      if char == ')':
        break
      elif char == '[':
        # Array dimension: keep scanning for the element type.
        i += 1
        continue
      elif char == 'L':
        # Reference type: runs through the next ';'.
        end_idx = descriptor.index(';', i) + 1
      else:
        # Primitive type: a single character.
        end_idx = i + 1
      param_type = JavaType.from_descriptor(descriptor[start_idx:end_idx])
      params.append(JavaParam(param_type, f'p{len(params)}'))
      i = end_idx
      start_idx = end_idx

    # Everything after ')' is the return type.
    return_type = JavaType.from_descriptor(descriptor[i + 1:])
    return JavaSignature.from_params(return_type, JavaParamList(params))

  def to_descriptor(self):
    """Returns the JNI signature."""
    sb = ['(']
    sb += [t.to_descriptor() for t in self.param_types]
    sb += [')']
    sb += [self.return_type.to_descriptor()]
    return ''.join(sb)

  def to_proxy(self):
    """Converts to types used over JNI boundary."""
    return_type = self.return_type.to_proxy()
    param_list = self.param_list.to_proxy()
    return JavaSignature.from_params(return_type, param_list)
345
346
class TypeResolver:
  """Converts type names to fully qualified names."""

  def __init__(self, java_class):
    # The class being resolved against: defines the "same package" for
    # contextualize() and the fallback package (and prefix) for resolve().
    self.java_class = java_class
    # JavaClass objects registered via add_import().
    self.imports = []
    # JavaClass objects registered via add_nested_class().
    self.nested_classes = []

  def add_import(self, java_class):
    self.imports.append(java_class)

  def add_nested_class(self, java_class):
    self.nested_classes.append(java_class)

  def contextualize(self, java_class):
    """Return the shortest string that resolves to the given class.

    Classes in java.lang, in this class's package, or listed in imports are
    written by their simple name (with '$' shown as '.'); all others are
    written fully qualified with dots.
    """
    type_package = java_class.package_with_slashes
    if type_package in ('java/lang', self.java_class.package_with_slashes):
      return java_class.name_with_dots
    if java_class in self.imports:
      return java_class.name_with_dots

    return java_class.full_name_with_dots

  def resolve(self, name):
    """Return a JavaClass for the given type name.

    Lookup order: slash-separated names (from javap) are used as-is; then
    this class, registered nested classes, and imports are matched by name;
    dotted names are treated as fully qualified when lowercase, otherwise as
    classes nested in an imported class; then java.lang; finally the name is
    assumed to be in this class's package, with this class's prefix (if any)
    re-applied.
    """
    assert name not in PRIMITIVES
    assert ' ' not in name

    if '/' in name:
      # Coming from javap, use the fully qualified name directly.
      return JavaClass(name)

    if self.java_class.name == name:
      return self.java_class

    for clazz in self.nested_classes:
      if name in (clazz.name, clazz.nested_name):
        return clazz

    # Is it from an import? (e.g. referencing Class from import pkg.Class).
    for clazz in self.imports:
      if name in (clazz.name, clazz.nested_name):
        return clazz

    # Is it an inner class from an outer class import? (e.g. referencing
    # Class.Inner from import pkg.Class).
    if '.' in name:
      # Assume lowercase means it's a fully qualified name.
      if name[0].islower():
        return JavaClass(name.replace('.', '/'))
      # Otherwise, try to find an enclosing class in imports. JavaClass.name
      # uses '$' between nesting levels, so candidates are joined with '$'.
      # Longer candidates are tried first: for 'A.B.C', an imported 'A$B' is
      # preferred over an imported 'A'.
      components = name.split('.')
      for split in range(len(components) - 1, 0, -1):
        outer = '$'.join(components[:split])
        inner = '$'.join(components[split:])
        for clazz in self.imports:
          if clazz.name == outer:
            return clazz.make_nested(inner)
      name = name.replace('.', '$')

    # java.lang classes always take priority over types from the same package.
    # To use a type from the same package that has the same name as a java.lang
    # type, it must be explicitly imported.
    if java_lang_classes.contains(name):
      return JavaClass(f'java/lang/{name}')

    # Type not found, falling back to same package as this class.
    package = self.java_class.class_without_prefix.package_with_slashes
    ret = JavaClass(f'{package}/{name}')
    prefix = self.java_class.prefix
    return ret.make_prefixed(prefix) if prefix else ret
415
416
# Frequently used java.lang classes.
CLASS_CLASS = JavaClass('java/lang/Class')
OBJECT_CLASS = JavaClass('java/lang/Object')
STRING_CLASS = JavaClass('java/lang/String')

# Frequently used types and an empty parameter list.
CLASS = CLASS_CLASS.as_type()
LONG = JavaType(primitive_name='long')
VOID = JavaType(primitive_name='void')
EMPTY_PARAM_LIST = JavaParamList(())

# Resolver with no imports, rooted in java.lang. JavaClass.to_java() falls
# back to it, so java.lang classes are written by their simple name and all
# other classes fully qualified.
_EMPTY_TYPE_RESOLVER = TypeResolver(OBJECT_CLASS)
425