xref: /aosp_15_r20/external/sandboxed-api/sandboxed_api/tools/generator2/code_test.py (revision ec63e07ab9515d95e79c211197c445ef84cefa6a)
1# Copyright 2019 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
"""Tests for the generator2 `code` module (libclang analysis, SAPI codegen)."""
15
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19from absl.testing import absltest
20from absl.testing import parameterized
21from clang import cindex
22from com_google_sandboxed_api.sandboxed_api.tools.generator2 import code
23from com_google_sandboxed_api.sandboxed_api.tools.generator2 import code_test_util
24
# C++ snippet shared by the AST traversal tests: one function-pointer typedef,
# two extern "C" functions and one struct with a function-pointer field. It is
# kept syntactically valid so the cursor counts asserted in
# testSimpleASTTraversal do not depend on clang's error recovery.
CODE = """
typedef int (*fun)(int, int);
extern "C" int function_a(int x, int y) { return x + y; }
extern "C" int function_b(int a, int b) { return a + b; }

struct a {
  void (*fun_ptr)(char, long);
};
"""
34
35
def analyze_string(content, path='tmp.cc', limit_scan_depth=False):
  """Parses a single in-memory source file.

  Args:
    content: Source text to parse.
    path: Virtual file name under which `content` is registered and parsed.
    limit_scan_depth: Forwarded unchanged to analyze_strings.

  Returns:
    The result of analyze_strings for the one-element file list
    [(path, content)].
  """
  single_file = (path, content)
  return analyze_strings(path, [single_file], limit_scan_depth)
39
40
def analyze_strings(path, unsaved_files, limit_scan_depth=False):
  """Parses `path` using in-memory file contents.

  Args:
    path: Name of the main file to parse; callers in this file pass one of
      the names listed in unsaved_files.
    unsaved_files: List of (file name, content) pairs handed to the analyzer;
      the main file may #include the other entries (see testTypeOrder).
    limit_scan_depth: Forwarded to code.Analyzer._analyze_file_for_tu; see
      testFilterFunctionsFromInputFilesOnly for its observable effect.

  Returns:
    The translation-unit wrapper produced by
    code.Analyzer._analyze_file_for_tu.
  """
  analyzer = code.Analyzer
  # The two positional options between path and unsaved_files stay at
  # None / False, exactly as every test in this file expects.
  return analyzer._analyze_file_for_tu(
      path, None, False, unsaved_files, limit_scan_depth)
45
46
class CodeAnalysisTest(parameterized.TestCase):
  """Tests for code.Analyzer parsing and code.Generator output.

  Most tests parse small C/C++ snippets from in-memory strings (see
  analyze_string / analyze_strings) and inspect the resulting cursors,
  functions, related types, defines or generated header text.
  """

  def _get_function_cursors(self, translation_unit):
    """Returns the FUNCTION_DECL cursors of translation_unit, in pre-order."""
    return [
        cursor for cursor in translation_unit._walk_preorder()
        if cursor.kind == cindex.CursorKind.FUNCTION_DECL
    ]

  def testInMemoryFile(self):
    """Parsing an in-memory string produces a translation unit cursor."""
    translation_unit = analyze_string(CODE)
    self.assertIsNotNone(translation_unit._tu.cursor)

  def testSimpleASTTraversal(self):
    """Counts function, struct, parameter and typedef cursors in CODE."""
    translation_unit = analyze_string(CODE)

    structs = 0
    functions = 0
    params = 0
    typedefs = 0

    for cursor in translation_unit._walk_preorder():
      if cursor.kind == cindex.CursorKind.FUNCTION_DECL:
        functions += 1
      elif cursor.kind == cindex.CursorKind.STRUCT_DECL:
        structs += 1
      elif cursor.kind == cindex.CursorKind.PARM_DECL:
        params += 1
      elif cursor.kind == cindex.CursorKind.TYPEDEF_DECL:
        typedefs += 1

    self.assertEqual(functions, 2)
    self.assertEqual(structs, 1)
    self.assertEqual(params, 8)
    self.assertEqual(typedefs, 1)

  def testParseSkipFunctionBodies(self):
    """Function cursors have no definition because bodies are skipped."""
    function_body = 'extern "C"  int function(bool a1) { return a1 ? 1 : 2; }'
    translation_unit = analyze_string(function_body)
    function_cursors = self._get_function_cursors(translation_unit)
    # Without this check the loop below would pass vacuously if the function
    # were not found at all.
    self.assertLen(function_cursors, 1)
    for cursor in function_cursors:
      # cursor.get_definition() is None when we skip parsing function bodies
      self.assertIsNone(cursor.get_definition())

  def testExternC(self):
    """Checks the cursor kinds produced for an extern "C" declaration."""
    translation_unit = analyze_string('extern "C" int function(char* a);')
    cursor_kinds = [
        x.kind
        for x in translation_unit._walk_preorder()
        if x.kind != cindex.CursorKind.MACRO_DEFINITION
    ]
    self.assertListEqual(cursor_kinds, [
        cindex.CursorKind.TRANSLATION_UNIT, cindex.CursorKind.LINKAGE_SPEC,
        cindex.CursorKind.FUNCTION_DECL, cindex.CursorKind.PARM_DECL
    ])

  @parameterized.named_parameters(
      ('1:', '/tmp/test.h', 'tmp', 'tmp/test.h'),
      ('2:', '/a/b/c/d/tmp/test.h', 'c/d', 'c/d/tmp/test.h'),
      ('3:', '/tmp/test.h', None, '/tmp/test.h'),
      ('4:', '/tmp/test.h', '', '/tmp/test.h'),
      ('5:', '/tmp/test.h', 'xxx', 'xxx/test.h'),
  )
  def testGetIncludes(self, path, prefix, expected):
    """Checks Function.get_include_path for several path/prefix pairs."""
    function_body = 'extern "C" int function(bool a1) { return a1 ? 1 : 2; }'
    translation_unit = analyze_string(function_body)
    function_cursors = self._get_function_cursors(translation_unit)
    # Fail loudly instead of passing vacuously if no function was parsed.
    self.assertLen(function_cursors, 1)
    fn = code.Function(translation_unit, function_cursors[0])
    # Override the file lookup so the test controls the reported location.
    fn.get_absolute_path = lambda: path
    self.assertEqual(fn.get_include_path(prefix), expected)

  def testCodeGeneratorOutput(self):
    """Generated code for basic scalar argument types matches CODE_GOLD."""
    body = """
      extern "C" {
        int function_a(int x, int y) { return x + y; }

        int types_1(bool a0, unsigned char a1, char a2, unsigned short a3, short a4);
        int types_2(int a0, unsigned int a1, long a2, unsigned long a3);
        int types_3(long long a0, unsigned long long a1, float a2, double a3);
        int types_4(signed char a0, signed short a1, signed int a2, signed long a3);
        int types_5(signed long long a0, long double a1);
        void types_6(char* a0);
      }
    """
    functions = [
        'function_a', 'types_1', 'types_2', 'types_3', 'types_4', 'types_5',
        'types_6'
    ]
    generator = code.Generator([analyze_string(body)])
    result = generator.generate('Test', functions, 'sapi::Tests', None, None)
    self.assertMultiLineEqual(code_test_util.CODE_GOLD, result)

  def testElaboratedArgument(self):
    """A struct passed by value makes generation raise ValueError."""
    body = """
      struct x { int a; };
      extern "C" int function(struct x a) { return a.a; }
    """
    generator = code.Generator([analyze_string(body)])
    with self.assertRaisesRegex(ValueError, r'Elaborate.*mapped.*'):
      generator.generate('Test', ['function'], 'sapi::Tests', None, None)

  def testElaboratedArgument2(self):
    """A typedef'd anonymous struct passed by value also raises ValueError."""
    body = """
      typedef struct { int a; char b; } x;
      extern "C" int function(x a) { return a.a; }
    """
    generator = code.Generator([analyze_string(body)])
    with self.assertRaisesRegex(ValueError, r'Elaborate.*mapped.*'):
      generator.generate('Test', ['function'], 'sapi::Tests', None, None)

  def testGetMappedType(self):
    """Output for a function using typedef'd types matches the gold file."""
    body = """
      typedef unsigned int uint;
      typedef uint* uintp;
      extern "C" uint function(uintp a) { return *a; }
    """
    generator = code.Generator([analyze_string(body)])
    result = generator.generate('Test', [], 'sapi::Tests', None, None)
    self.assertMultiLineEqual(code_test_util.CODE_GOLD_MAPPED, result)

  @parameterized.named_parameters(
      ('1:', '/tmp/test.h', '_TMP_TEST_H_'),
      ('2:', 'tmp/te-st.h', 'TMP_TE_ST_H_'),
      ('3:', 'tmp/te-st.h.gen', 'TMP_TE_ST_H_'),
      ('4:', 'xx/genfiles/tmp/te-st.h', 'TMP_TE_ST_H_'),
      ('5:', 'xx/genfiles/tmp/te-st.h.gen', 'TMP_TE_ST_H_'),
      ('6:', 'xx/genfiles/.gen/tmp/te-st.h', '_GEN_TMP_TE_ST_H_'),
  )
  def testGetHeaderGuard(self, path, expected):
    """Checks header guard names derived from various output paths."""
    self.assertEqual(code.get_header_guard(path), expected)

  @parameterized.named_parameters(
      ('function with return value and arguments',
       'extern "C" int function(bool arg_bool, char* arg_ptr);',
       ['arg_bool', 'arg_ptr']),
      ('function without return value and no arguments',
       'extern "C" void function();', []),
  )
  def testArgumentNames(self, body, names):
    """Argument types expose the parameter names declared in the source."""
    generator = code.Generator([analyze_string(body)])
    functions = generator._get_functions()
    self.assertLen(functions, 1)
    self.assertLen(functions[0].argument_types, len(names))
    # Extra check for generation, in case rendering throws error for this test.
    generator.generate('Test', [], 'sapi::Tests', None, None)
    for t in functions[0].argument_types:
      self.assertIn(t.name, names)

  def testStaticFunctions(self):
    """Static functions are not returned by _get_functions."""
    body = 'static int function() { return 7; };'
    generator = code.Generator([analyze_string(body)])
    self.assertEmpty(generator._get_functions())

  def testEnumGeneration(self):
    """Output for a function taking/returning an enum matches the gold."""
    body = """
      enum ProcessStatus {
        OK = 0,
        ERROR = 1,
      };

      extern "C" ProcessStatus ProcessDatapoint(ProcessStatus status) {
        return status;
      }
    """
    generator = code.Generator([analyze_string(body)])
    result = generator.generate('Test', [], 'sapi::Tests', None, None)
    self.assertMultiLineEqual(code_test_util.CODE_ENUM_GOLD, result)

  def testTypeEq(self):
    """Arguments of the same type compare and hash equal."""
    body = """
    typedef unsigned int uint;
    extern "C" void function(uint a1, uint a2, char a3);
    """
    generator = code.Generator([analyze_string(body)])
    functions = generator._get_functions()
    self.assertLen(functions, 1)

    args = functions[0].arguments()
    self.assertLen(args, 3)
    self.assertEqual(args[0], args[1])
    self.assertNotEqual(args[0], args[2])
    self.assertNotEqual(args[1], args[2])

    self.assertLen(set(args), 2)
    # Extra check for generation, in case rendering throws error for this test.
    generator.generate('Test', [], 'sapi::Tests', None, None)

  def testTypedefRelatedTypes(self):
    """Related types of an argument follow typedef and pointer chains."""
    body = """
      typedef unsigned int uint;
      typedef uint* uint_p;
      typedef uint_p* uint_pp;

      typedef struct data {
        int a;
        int b;
      } data_s;
      typedef data_s* data_p;

      extern "C" uint function_using_typedefs(uint_p a1, uint_pp a2, data_p a3);
    """
    generator = code.Generator([analyze_string(body)])
    functions = generator._get_functions()
    self.assertLen(functions, 1)

    args = functions[0].arguments()
    self.assertLen(args, 3)

    types = args[0].get_related_types()
    names = [t._clang_type.spelling for t in types]
    self.assertLen(types, 2)
    self.assertSameElements(names, ['uint_p', 'uint'])

    types = args[1].get_related_types()
    names = [t._clang_type.spelling for t in types]
    self.assertLen(types, 3)
    self.assertSameElements(names, ['uint_pp', 'uint_p', 'uint'])

    types = args[2].get_related_types()
    names = [t._clang_type.spelling for t in types]
    self.assertLen(types, 2)
    self.assertSameElements(names, ['data_s', 'data_p'])

    # Extra check for generation, in case rendering throws error for this test.
    generator.generate('Test', [], 'sapi::Tests', None, None)

  def testTypedefDuplicateType(self):
    """A struct referenced both by tag and by typedef is reported once."""
    # The return type is plain int: `uint` is not declared in this snippet.
    body = """
      typedef struct data {
        int a;
        int b;
      } data_s;

      struct s {
        struct data* f1;
      };

      extern "C" int function_using_typedefs(struct s* a1, data_s* a2);
    """
    generator = code.Generator([analyze_string(body)])
    functions = generator._get_functions()
    self.assertLen(functions, 1)

    args = functions[0].arguments()
    self.assertLen(args, 2)

    # _get_related_types must run before types_to_skip is inspected.
    types = generator._get_related_types()
    self.assertLen(generator.translation_units[0].types_to_skip, 1)

    names = [t._clang_type.spelling for t in types]
    self.assertSameElements(['data_s', 's'], names)

    # Extra check for generation, in case rendering throws error for this test.
    generator.generate('Test', [], 'sapi::Tests', None, None)

  def testStructureRelatedTypes(self):
    """Related types are collected through struct fields, in emit order."""
    body = """
      typedef unsigned int uint;

      typedef struct {
        uint a;
        struct {
          int a;
          int b;
        } b;
      } struct_1;

      struct struct_2 {
        uint a;
        char b;
        struct_1* c;
      };

      typedef struct a {
        int b;
      } struct_a;

      extern "C" int function_using_structures(struct struct_2* a1, struct_1* a2,
      struct_a* a3);
    """
    generator = code.Generator([analyze_string(body)])
    functions = generator._get_functions()
    self.assertLen(functions, 1)

    args = functions[0].arguments()
    self.assertLen(args, 3)

    types = args[0].get_related_types()
    names = [t._clang_type.spelling for t in types]
    self.assertLen(types, 3)
    self.assertSameElements(names, ['struct_2', 'uint', 'struct_1'])

    types = args[1].get_related_types()
    names = [t._clang_type.spelling for t in types]
    self.assertLen(types, 2)
    self.assertSameElements(names, ['struct_1', 'uint'])

    names = [t._clang_type.spelling for t in generator._get_related_types()]
    self.assertEqual(names, ['uint', 'struct_1', 'struct_2', 'struct_a'])

    types = args[2].get_related_types()
    self.assertLen(types, 1)

    # Extra check for generation, in case rendering throws error for this test.
    generator.generate('Test', [], 'sapi::Tests', None, None)

  def testUnionRelatedTypes(self):
    """Related types are collected through union fields."""
    body = """
      typedef unsigned int uint;

      typedef union {
        uint a;
        union {
          int a;
          int b;
        } b;
      } union_1;

      union union_2 {
        uint a;
        char b;
        union_1* c;
      };

      extern "C" int function_using_unions(union union_2* a1, union_1* a2);
    """
    generator = code.Generator([analyze_string(body)])
    functions = generator._get_functions()
    self.assertLen(functions, 1)

    args = functions[0].arguments()
    self.assertLen(args, 2)

    types = args[0].get_related_types()
    names = [t._clang_type.spelling for t in types]
    self.assertLen(types, 3)
    self.assertSameElements(names, ['union_2', 'uint', 'union_1'])

    types = args[1].get_related_types()
    names = [t._clang_type.spelling for t in types]
    self.assertLen(types, 2)
    self.assertSameElements(names, ['union_1', 'uint'])

    # Extra check for generation, in case rendering throws error for this test.
    generator.generate('Test', [], 'sapi::Tests', None, None)

  def testFunctionPointerRelatedTypes(self):
    """Related types include function pointer return and parameter types."""
    body = """
      typedef unsigned int uint;
      typedef unsigned char uchar;
      typedef uint (*funcp)(uchar, uchar);

      struct struct_1 {
        uint (*func)(uchar);
        int a;
      };

      extern "C" void function(struct struct_1* a1, funcp a2);
    """
    generator = code.Generator([analyze_string(body)])
    functions = generator._get_functions()
    self.assertLen(functions, 1)

    args = functions[0].arguments()
    self.assertLen(args, 2)

    types = args[0].get_related_types()
    names = [t._clang_type.spelling for t in types]
    self.assertLen(types, 3)
    self.assertSameElements(names, ['struct_1', 'uint', 'uchar'])

    types = args[1].get_related_types()
    names = [t._clang_type.spelling for t in types]
    self.assertLen(types, 3)
    self.assertSameElements(names, ['funcp', 'uint', 'uchar'])

    # Extra check for generation, in case rendering throws error for this test.
    generator.generate('Test', [], 'sapi::Tests', None, None)

  def testForwardDeclaration(self):
    """A type cycle through a forward-declared struct yields one forward decl."""
    body = """
      struct struct_6_def;
      typedef struct struct_6_def struct_6;
      typedef struct_6* struct_6p;
      typedef void (*function_p3)(struct_6p);
      struct struct_6_def {
        function_p3 fn;
      };

      extern "C" void function_using_type_loop(struct_6p a1);
    """
    generator = code.Generator([analyze_string(body)])
    functions = generator._get_functions()
    self.assertLen(functions, 1)

    args = functions[0].arguments()
    self.assertLen(args, 1)

    types = args[0].get_related_types()
    names = [t._clang_type.spelling for t in types]
    self.assertLen(types, 4)
    self.assertSameElements(
        names, ['struct_6p', 'struct_6', 'struct_6_def', 'function_p3'])

    self.assertLen(generator.translation_units, 1)
    self.assertLen(generator.translation_units[0].forward_decls, 1)

    t = next(x for x in types if x._clang_type.spelling == 'struct_6_def')
    self.assertIn(t, generator.translation_units[0].forward_decls)

    names = [t._clang_type.spelling for t in generator._get_related_types()]
    self.assertEqual(
        names, ['struct_6', 'struct_6p', 'function_p3', 'struct_6_def'])

    # Extra check for generation, in case rendering throws error for this test.
    forward_decls = generator._get_forward_decls(generator._get_related_types())
    self.assertLen(forward_decls, 1)
    self.assertEqual(forward_decls[0], 'struct struct_6_def;')
    generator.generate('Test', [], 'sapi::Tests', None, None)

  def testEnumRelatedTypes(self):
    """Each kind of enum argument has exactly one related type."""
    body = """
      enum Enumeration { ONE, TWO, THREE };
      typedef enum Numbers { UNKNOWN, FIVE = 5, SE7EN = 7 } Nums;
      typedef enum { SIX = 6, TEN = 10 } SixOrTen;
      enum class Color : long long { RED, GREEN = 20, BLUE };  // NOLINT
      enum struct Direction { LEFT = 'l', RIGHT = 'r' };
      enum __rlimit_resource {  RLIMIT_CPU = 0, RLIMIT_MEM = 1};

      extern "C" int function_using_enums(Enumeration a1, SixOrTen a2, Color a3,
                           Direction a4, Nums a5, enum __rlimit_resource a6);
     """
    generator = code.Generator([analyze_string(body)])
    functions = generator._get_functions()
    self.assertLen(functions, 1)

    args = functions[0].arguments()
    self.assertLen(args, 6)

    self.assertLen(args[0].get_related_types(), 1)
    self.assertLen(args[1].get_related_types(), 1)
    self.assertLen(args[2].get_related_types(), 1)
    self.assertLen(args[3].get_related_types(), 1)
    self.assertLen(args[4].get_related_types(), 1)
    self.assertLen(args[5].get_related_types(), 1)

    # Extra check for generation, in case rendering throws error for this test.
    generator.generate('Test', [], 'sapi::Tests', None, None)

  def testArrayAsParam(self):
    """Array parameters are parsed as ordinary function arguments."""
    body = """
      extern "C" int function_using_arrays(char a[10], char *const __argv[]);
     """
    generator = code.Generator([analyze_string(body)])
    functions = generator._get_functions()
    self.assertLen(functions, 1)

    args = functions[0].arguments()
    self.assertLen(args, 2)

  @parameterized.named_parameters(
      ('uint < ushort  ', 'assertLess', 1, 2),
      ('uint < chr     ', 'assertLess', 1, 3),
      ('uint < uchar   ', 'assertLess', 1, 4),
      ('uint < u32     ', 'assertLess', 1, 5),
      ('uint < ulong   ', 'assertLess', 1, 6),
      ('ushort < chr   ', 'assertLess', 2, 3),
      ('ushort < uchar ', 'assertLess', 2, 4),
      ('ushort < u32   ', 'assertLess', 2, 5),
      ('ushort < ulong ', 'assertLess', 2, 6),
      ('chr < uchar    ', 'assertLess', 3, 4),
      ('chr < u32      ', 'assertLess', 3, 5),
      ('chr < ulong    ', 'assertLess', 3, 6),
      ('uchar < u32    ', 'assertLess', 4, 5),
      ('uchar < ulong  ', 'assertLess', 4, 6),
      ('u32 < ulong    ', 'assertLess', 5, 6),
      ('ushort > uint  ', 'assertGreater', 2, 1),
      ('chr > uint     ', 'assertGreater', 3, 1),
      ('uchar > uint   ', 'assertGreater', 4, 1),
      ('u32 > uint     ', 'assertGreater', 5, 1),
      ('ulong > uint   ', 'assertGreater', 6, 1),
      ('chr > ushort   ', 'assertGreater', 3, 2),
      ('uchar > ushort ', 'assertGreater', 4, 2),
      ('u32 > ushort   ', 'assertGreater', 5, 2),
      ('ulong > ushort ', 'assertGreater', 6, 2),
      ('uchar > chr    ', 'assertGreater', 4, 3),
      ('u32 > chr      ', 'assertGreater', 5, 3),
      ('ulong > chr    ', 'assertGreater', 6, 3),
      ('u32 > uchar    ', 'assertGreater', 5, 4),
      ('ulong > uchar  ', 'assertGreater', 6, 4),
      ('ulong > u32    ', 'assertGreater', 6, 5),
  )
  def testTypeOrder(self, func, a1, a2):
    """Checks if comparison functions of Type class work properly.

    This is necessary for Generator._get_related_types to return types in
    proper order, ready to be emitted in the generated file. To be more
    specific: emitted types will be ordered in a way that would allow
    compilation ie. if structure field type is a typedef, typedef definition
    will end up before structure definition.

    Args:
      func: comparison assert to call
      a1: function argument number to take the type to compare
      a2: function argument number to take the type to compare
    """

    file1_code = """
    typedef unsigned int uint;
    #include "/f2.h"
    typedef uint u32;
    #include "/f3.h"

    struct args {
      u32 a;
      uchar b;
      ulong c;
      ushort d;
      chr e;
    };
    extern "C" int function(struct args* a0, uint a1, ushort a2, chr a3,
                 uchar a4, u32 a5, ulong a6, struct args* a7);
    """
    file2_code = """
    typedef unsigned short ushort;
    #include "/f4.h"
    typedef unsigned char uchar;"""
    file3_code = 'typedef unsigned long ulong;'
    file4_code = 'typedef char chr;'
    files = [('f1.h', file1_code), ('/f2.h', file2_code), ('/f3.h', file3_code),
             ('/f4.h', file4_code)]
    generator = code.Generator([analyze_strings('f1.h', files)])
    functions = generator._get_functions()
    self.assertLen(functions, 1)

    args = functions[0].arguments()
    getattr(self, func)(args[a1], args[a2])
    # Extra check for generation, in case rendering throws error for this test.
    generator.generate('Test', [], 'sapi::Tests', None, None)

  def testFilterFunctionsFromInputFilesOnly(self):
    """limit_scan_depth=True drops functions declared in included files."""
    file1_code = """
      #include "/f2.h"

      extern "C" int function1();
    """
    file2_code = """
      extern "C" int function2();
    """

    files = [('f1.h', file1_code), ('/f2.h', file2_code)]
    generator = code.Generator([analyze_strings('f1.h', files)])
    functions = generator._get_functions()
    self.assertLen(functions, 2)

    generator = code.Generator([analyze_strings('f1.h', files, True)])
    functions = generator._get_functions()
    self.assertLen(functions, 1)

  def testTypeToString(self):
    """stringify() reproduces type definitions, including #if blocks."""
    body = """
      #define SIZE 1024
      typedef unsigned int uint;

      typedef struct {
      #if SOME_DEFINE >= 12 \
      && SOME_OTHER == 13
        uint a;
      #else
        uint aa;
      #endif
        struct {
          uint a;
          int b;
          char c[SIZE];
        } b;
      } struct_1;

      extern "C" int function_using_structures(struct_1* a1);
    """

    # pylint: disable=trailing-whitespace
    expected = """typedef struct {
#if SOME_DEFINE >= 12 && SOME_OTHER == 13
\tuint a ;
#else
\tuint aa ;
#endif
\tstruct {
\t\tuint a ;
\t\tint b ;
\t\tchar c [ SIZE ] ;
\t} b ;
} struct_1"""
    generator = code.Generator([analyze_string(body)])
    functions = generator._get_functions()
    self.assertLen(functions, 1)

    types = generator._get_related_types()
    self.assertLen(types, 2)
    self.assertEqual('typedef unsigned int uint', types[0].stringify())
    self.assertMultiLineEqual(expected, types[1].stringify())

    # Extra check for generation, in case rendering throws error for this test.
    generator.generate('Test', [], 'sapi::Tests', None, None)

  def testCollectDefines(self):
    """Only macros used by related types become required defines."""
    body = """
      #define SIZE 1024
      #define NOT_USED 7
      #define SIZE2 2*1024
      #define SIZE3 1337
      #define SIZE4 10
      struct test {
        int a[SIZE];
        char b[SIZE2];
        float c[777];
        int (*d)[SIZE3*SIZE4];
      };
      extern "C" int function_1(struct test* a1);
    """
    generator = code.Generator([analyze_string(body)])
    self.assertLen(generator.translation_units, 1)

    generator._get_related_types()
    tu = generator.translation_units[0]
    tu._process()

    self.assertLen(tu.required_defines, 4)
    defines = generator._get_defines()
    self.assertLen(defines, 4)
    self.assertIn('#define SIZE 1024', defines)
    self.assertIn('#define SIZE2 2 * 1024', defines)
    self.assertIn('#define SIZE3 1337', defines)
    self.assertIn('#define SIZE4 10', defines)

    # Extra check for generation, in case rendering throws error for this test.
    generator.generate('Test', [], 'sapi::Tests', None, None)

  def testYaraCase(self):
    """Checks define collection for a macro-built union field (YARA style)."""
    body = """
      #define YR_ALIGN(n) __attribute__((aligned(n)))
      #define DECLARE_REFERENCE(type, name) union {    \
        type name;            \
        int64_t name##_;      \
      } YR_ALIGN(8)
      struct YR_NAMESPACE {
        int32_t t_flags[1337];
        DECLARE_REFERENCE(char*, name);
      };

      extern "C" int function_1(struct YR_NAMESPACE* a1);
    """
    generator = code.Generator([analyze_string(body)])
    self.assertLen(generator.translation_units, 1)

    generator._get_related_types()
    tu = generator.translation_units[0]
    tu._process()

    self.assertLen(tu.required_defines, 2)
    defines = generator._get_defines()
    # _get_defines will add dependent defines to tu.required_defines
    self.assertLen(defines, 2)
    gold = '#define DECLARE_REFERENCE('
    # DECLARE_REFERENCE must be second to pass this test
    self.assertStartsWith(defines[1], gold)

    # Extra check for generation, in case rendering throws error for this test.
    generator.generate('Test', [], 'sapi::Tests', None, None)

  def testDoubleFunction(self):
    """A declaration plus a definition of one function counts once."""
    body = """
      extern "C" int function_1(int a);
      extern "C" int function_1(int a) {
        return a + 1;
      };
    """
    generator = code.Generator([analyze_string(body)])
    self.assertLen(generator.translation_units, 1)

    tu = generator.translation_units[0]
    tu._process()

    self.assertLen(tu.functions, 1)

    # Extra check for generation, in case rendering throws error for this test.
    generator.generate('Test', [], 'sapi::Tests', None, None)

  def testDefineStructBody(self):
    """A macro that supplies struct fields is collected as a define."""
    body = """
      #define STRUCT_BODY \
      int a;  \
      char b; \
      long c
      struct test {
        STRUCT_BODY;
      };
      extern "C" void function(struct test* a1);
    """

    generator = code.Generator([analyze_string(body)])
    self.assertLen(generator.translation_units, 1)

    # initialize all internal data
    generator.generate('Test', [], 'sapi::Tests', None, None)
    tu = generator.translation_units[0]

    self.assertLen(tu.functions, 1)
    self.assertLen(tu.required_defines, 1)

  def testJpegTurboCase(self):
    """An array typedef sized by a macro pulls in that macro."""
    body = """
      typedef short JCOEF;
      #define DCTSIZE2 1024
      typedef JCOEF JBLOCK[DCTSIZE2];

      extern "C" void function(JBLOCK* a);
    """
    generator = code.Generator([analyze_string(body)])
    self.assertLen(generator.translation_units, 1)

    # initialize all internal data
    generator.generate('Test', [], 'sapi::Tests', None, None)

    tu = generator.translation_units[0]
    self.assertLen(tu.functions, 1)
    self.assertLen(generator._get_defines(), 1)
    self.assertLen(generator._get_related_types(), 2)

  def testMultipleTypesWhenConst(self):
    """Const and non-const pointers to a struct share one related type."""
    body = """
      struct Instance {
        void* instance = nullptr;
        void* state_memory = nullptr;
        void* scratch_memory = nullptr;
      };

      extern "C" void function1(Instance* a);
      extern "C" void function2(const Instance* a);
    """
    generator = code.Generator([analyze_string(body)])
    self.assertLen(generator.translation_units, 1)

    # Initialize all internal data
    generator.generate('Test', [], 'sapi::Tests', None, None)

    tu = generator.translation_units[0]
    self.assertLen(tu.functions, 2)
    self.assertLen(generator._get_related_types(), 1)

  def testReference(self):
    """Lvalue and rvalue reference parameters resolve to one struct type."""
    # Parameter names must differ: reusing one is a redeclaration in C++.
    body = """
      struct Instance {
        int a;
      };

      void Function1(Instance& a, Instance&& b);
    """
    generator = code.Generator([analyze_string(body)])
    self.assertLen(generator.translation_units, 1)

    # Initialize all internal data
    generator.generate('Test', [], 'sapi::Tests', None, None)

    tu = generator.translation_units[0]
    self.assertLen(tu.functions, 1)

    # this will return 0 related types because function will be mangled and
    # filtered out by generator
    self.assertEmpty(generator._get_related_types())
    self.assertLen(next(iter(tu.functions)).get_related_types(), 1)

  def testCppHeader(self):
    """Only the extern "C" overload survives; the mangled one is dropped."""
    path = 'tmp.h'
    content = """
      int sum(int a, float b);

      extern "C" int sum(int a, float b);
    """
    unsaved_files = [(path, content)]
    generator = code.Generator([analyze_strings(path, unsaved_files)])
    # Initialize all internal data
    generator.generate('Test', [], 'sapi::Tests', None, None)

    # generator should filter out mangled function
    functions = generator._get_functions()
    self.assertLen(functions, 1)

    tu = generator.translation_units[0]
    functions = tu.get_functions()
    self.assertLen(functions, 2)

    mangled_names = [f.cursor.mangled_name for f in functions]
    self.assertSameElements(mangled_names, ['sum', '_Z3sumif'])
838
if __name__ == '__main__':
  # Running this file directly hands control to absl's test runner.
  absltest.main()
841