xref: /aosp_15_r20/external/minijail/tools/compiler_unittest.py (revision 4b9c6d91573e8b3a96609339b46361b5476dd0f9)
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3#
4# Copyright (C) 2018 The Android Open Source Project
5#
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10#      http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17"""Unittests for the compiler module."""
18
19from __future__ import print_function
20
21import os
22import random
23import shutil
24import tempfile
25import unittest
26from importlib import resources
27
28import arch
29import bpf
30import compiler
31import parser  # pylint: disable=wrong-import-order
32
# Architecture description shared by every test case in this module. It is
# parsed from the arch_64.json resource shipped in the `testdata` package.
ARCH_64 = arch.Arch.load_from_json_bytes(
    (resources.files("testdata") / "arch_64.json").read_bytes()
)
36
37
class CompileFilterStatementTests(unittest.TestCase):
    """Tests for PolicyCompiler.compile_filter_statement.

    Each test compiles a single policy line for the `read` syscall and then
    inspects the resulting block, mostly by simulating it with concrete
    argument values.
    """

    def setUp(self):
        self.arch = ARCH_64
        self.compiler = compiler.PolicyCompiler(self.arch)

    def _compile(self, line):
        """Compile one policy line into a filter block.

        Args:
            line: Policy text that must parse to exactly one filter statement.

        Returns:
            The result of PolicyCompiler.compile_filter_statement for the only
            filter statement parsed from |line|.
        """
        with tempfile.NamedTemporaryFile(mode='w') as policy_file:
            policy_file.write(line)
            # parse_file() is handed the file name rather than this handle, so
            # make sure the buffered text has reached the file first.
            policy_file.flush()
            policy_parser = parser.PolicyParser(
                self.arch, kill_action=bpf.KillProcess())
            parsed_policy = policy_parser.parse_file(policy_file.name)
            # Use a unittest assertion instead of a bare `assert` so that the
            # check is not stripped when running under `python -O`.
            self.assertEqual(len(parsed_policy.filter_statements), 1)
            return self.compiler.compile_filter_statement(
                parsed_policy.filter_statements[0],
                kill_action=bpf.KillProcess())

    def _simulate_read(self, block, *args):
        """Simulate |block| for the `read` syscall with the given arguments."""
        return block.simulate(self.arch.arch_nr, self.arch.syscalls['read'],
                              *args)

    def _assert_read_actions(self, block, expectations):
        """Check the simulated action for every (arg0, action) pair.

        Args:
            block: A compiled block, as returned by _compile().
            expectations: Iterable of (arg0, expected_action) tuples.
        """
        for arg0, expected_action in expectations:
            # subTest makes the failure report name the offending argument.
            with self.subTest(arg0=arg0):
                self.assertEqual(
                    self._simulate_read(block, arg0)[1], expected_action)

    def test_allow(self):
        """Accept lines where the syscall is accepted unconditionally."""
        block = self._compile('read: allow')
        self.assertIsNone(block.filter)
        self._assert_read_actions(block, (
            (0, 'ALLOW'),
            (1, 'ALLOW'),
        ))

    def test_arg0_eq_generated_code(self):
        """Accept lines with an argument filter with ==."""
        block = self._compile('read: arg0 == 0x100')
        # It might be a bit brittle to check the generated code in each test
        # case instead of just the behavior, but there should be at least one
        # test where this happens.
        self.assertEqual(
            block.filter.instructions,
            [
                bpf.SockFilter(bpf.BPF_LD | bpf.BPF_W | bpf.BPF_ABS, 0, 0,
                               bpf.arg_offset(0, True)),
                # Jump to KILL_PROCESS if the high word does not match.
                bpf.SockFilter(bpf.BPF_JMP | bpf.BPF_JEQ | bpf.BPF_K, 0, 2, 0),
                bpf.SockFilter(bpf.BPF_LD | bpf.BPF_W | bpf.BPF_ABS, 0, 0,
                               bpf.arg_offset(0, False)),
                # Jump to KILL_PROCESS if the low word does not match.
                bpf.SockFilter(bpf.BPF_JMP | bpf.BPF_JEQ | bpf.BPF_K, 1, 0,
                               0x100),
                bpf.SockFilter(bpf.BPF_RET, 0, 0,
                               bpf.SECCOMP_RET_KILL_PROCESS),
                bpf.SockFilter(bpf.BPF_RET, 0, 0, bpf.SECCOMP_RET_ALLOW),
            ])

    def test_arg0_comparison_operators(self):
        """Accept lines with an argument filter with comparison operators."""
        biases = (-1, 0, 1)
        # For each operator, store the expectations of simulating the program
        # against the constant plus each entry from the |biases| array.
        cases = (
            ('==', ('KILL_PROCESS', 'ALLOW', 'KILL_PROCESS')),
            ('!=', ('ALLOW', 'KILL_PROCESS', 'ALLOW')),
            ('<', ('ALLOW', 'KILL_PROCESS', 'KILL_PROCESS')),
            ('<=', ('ALLOW', 'ALLOW', 'KILL_PROCESS')),
            ('>', ('KILL_PROCESS', 'KILL_PROCESS', 'ALLOW')),
            ('>=', ('KILL_PROCESS', 'ALLOW', 'ALLOW')),
        )
        for operator, expectations in cases:
            block = self._compile('read: arg0 %s 0x100' % operator)

            # Check the filter's behavior.
            for bias, expectation in zip(biases, expectations):
                with self.subTest(operator=operator, bias=bias):
                    self.assertEqual(
                        self._simulate_read(block, 0x100 + bias)[1],
                        expectation)

    def test_arg0_mask_operator(self):
        """Accept lines with an argument filter with &."""
        block = self._compile('read: arg0 & 0x3')

        self._assert_read_actions(block, (
            (0, 'KILL_PROCESS'),
            (1, 'ALLOW'),
            (2, 'ALLOW'),
            (3, 'ALLOW'),
            (4, 'KILL_PROCESS'),
            (5, 'ALLOW'),
            (6, 'ALLOW'),
            (7, 'ALLOW'),
            (8, 'KILL_PROCESS'),
        ))

    def test_arg0_in_operator(self):
        """Accept lines with an argument filter with in."""
        block = self._compile('read: arg0 in 0x3')

        # The 'in' operator only ensures that no bits outside the mask are set,
        # which means that 0 is always allowed.
        self._assert_read_actions(block, (
            (0, 'ALLOW'),
            (1, 'ALLOW'),
            (2, 'ALLOW'),
            (3, 'ALLOW'),
            (4, 'KILL_PROCESS'),
            (5, 'KILL_PROCESS'),
            (6, 'KILL_PROCESS'),
            (7, 'KILL_PROCESS'),
            (8, 'KILL_PROCESS'),
        ))

    def test_arg0_short_gt_ge_comparisons(self):
        """Ensure that the short comparison optimization kicks in."""
        if self.arch.bits == 32:
            # Report the test as skipped instead of silently passing.
            self.skipTest('64-bit constants are needed for this comparison')
        short_constant_str = '0xdeadbeef'
        short_constant = int(short_constant_str, base=0)
        long_constant_str = '0xbadc0ffee0ddf00d'
        long_constant = int(long_constant_str, base=0)
        biases = (-1, 0, 1)
        # For each operator, store the expectations of simulating the program
        # against the constant plus each entry from the |biases| array.
        cases = (
            ('<', ('ALLOW', 'KILL_PROCESS', 'KILL_PROCESS')),
            ('<=', ('ALLOW', 'ALLOW', 'KILL_PROCESS')),
            ('>', ('KILL_PROCESS', 'KILL_PROCESS', 'ALLOW')),
            ('>=', ('KILL_PROCESS', 'ALLOW', 'ALLOW')),
        )
        for operator, expectations in cases:
            with self.subTest(operator=operator):
                short_block = self._compile(
                    'read: arg0 %s %s' % (operator, short_constant_str))
                long_block = self._compile(
                    'read: arg0 %s %s' % (operator, long_constant_str))

                # Check that the emitted code is shorter when the high word of
                # the constant is zero.
                self.assertLess(
                    len(short_block.filter.instructions),
                    len(long_block.filter.instructions))

                # Check the filter's behavior.
                for bias, expectation in zip(biases, expectations):
                    self.assertEqual(
                        self._simulate_read(
                            long_block, long_constant + bias)[1],
                        expectation)
                    self.assertEqual(
                        self._simulate_read(
                            short_block, short_constant + bias)[1],
                        expectation)

    def test_and_or(self):
        """Accept lines with a complex expression in DNF."""
        block = self._compile('read: arg0 == 0 && arg1 == 0 || arg0 == 1')

        # Each entry is ((arg0, arg1), expected action).
        cases = (
            ((0, 0), 'ALLOW'),
            ((0, 1), 'KILL_PROCESS'),
            ((1, 0), 'ALLOW'),
            ((1, 1), 'ALLOW'),
        )
        for args, expectation in cases:
            with self.subTest(args=args):
                self.assertEqual(
                    self._simulate_read(block, *args)[1], expectation)

    def test_trap(self):
        """Accept lines that trap unconditionally."""
        block = self._compile('read: trap')

        self.assertEqual(self._simulate_read(block, 0)[1], 'TRAP')

    def test_ret_errno(self):
        """Accept lines that return errno."""
        block = self._compile('read : arg0 == 0 || arg0 == 1 ; return 1')

        self.assertEqual(self._simulate_read(block, 0)[1:], ('ERRNO', 1))
        self.assertEqual(self._simulate_read(block, 1)[1:], ('ERRNO', 1))
        self.assertEqual(self._simulate_read(block, 2)[1], 'KILL_PROCESS')

    def test_ret_errno_unconditionally(self):
        """Accept lines that return errno unconditionally."""
        block = self._compile('read: return 1')

        self.assertEqual(self._simulate_read(block, 0)[1:], ('ERRNO', 1))

    def test_trace(self):
        """Accept lines that trace unconditionally."""
        block = self._compile('read: trace')

        self.assertEqual(self._simulate_read(block, 0)[1], 'TRACE')

    def test_user_notify(self):
        """Accept lines that notify unconditionally."""
        block = self._compile('read: user-notify')

        self.assertEqual(self._simulate_read(block, 0)[1], 'USER_NOTIF')

    def test_log(self):
        """Accept lines that log unconditionally."""
        block = self._compile('read: log')

        self.assertEqual(self._simulate_read(block, 0)[1], 'LOG')

    def test_mmap_write_xor_exec(self):
        """Accept the idiomatic filter for mmap."""
        block = self._compile(
            'read : arg0 in ~PROT_WRITE || arg0 in ~PROT_EXEC')

        prot_exec_and_write = 6
        for prot in range(0, 0xf):
            # Only values with both bits of |prot_exec_and_write| set are
            # expected to be rejected.
            if (prot & prot_exec_and_write) == prot_exec_and_write:
                expectation = 'KILL_PROCESS'
            else:
                expectation = 'ALLOW'
            with self.subTest(prot=prot):
                self.assertEqual(
                    self._simulate_read(block, prot)[1], expectation)
308
309
class CompileFileTests(unittest.TestCase):
    """Tests for PolicyCompiler.compile_file.

    Policy and frequency files are written into a per-test temporary directory
    that is removed again in tearDown().
    """

    def setUp(self):
        self.arch = ARCH_64
        self.compiler = compiler.PolicyCompiler(self.arch)
        self.tempdir = tempfile.mkdtemp()

    def tearDown(self):
        shutil.rmtree(self.tempdir)

    def _write_file(self, filename, contents):
        """Helper to write out a file for testing.

        Args:
            filename: Name of the file, relative to the test's temp directory.
            contents: Text to store in the file.

        Returns:
            The full path of the file that was written.
        """
        path = os.path.join(self.tempdir, filename)
        with open(path, 'w') as outf:
            outf.write(contents)
        return path

    def _write_frequency_policy(self):
        """Write a read/close policy with a frequency file; return its path."""
        self._write_file(
            'test.frequency', """
            read: 1
            close: 10
        """)
        return self._write_file(
            'test.policy', """
            @frequency ./test.frequency
            read: 1
            close: 1
        """)

    def _sample_syscalls(self, num_entries):
        """Return a dict of |num_entries| randomly chosen name -> number pairs.

        random.sample() requires a sequence: passing a dict view was deprecated
        in Python 3.9 and raises TypeError from 3.11 on, so the items are
        copied into a list first.
        """
        return dict(
            random.sample(list(self.arch.syscalls.items()), num_entries))

    def _simulate(self, program, number, arg0):
        """Simulate |program| for syscall |number| with first argument |arg0|."""
        return bpf.simulate(program.instructions, self.arch.arch_nr, number,
                            arg0)

    def test_compile(self):
        """Ensure compilation with the linear strategy honors frequencies."""
        path = self._write_frequency_policy()

        program = self.compiler.compile_file(
            path,
            optimization_strategy=compiler.OptimizationStrategy.LINEAR,
            kill_action=bpf.KillProcess())
        # The frequency file ranks close above read. The first element of the
        # simulation result (presumably the cost -- confirm against
        # bpf.simulate) must therefore be larger for read.
        self.assertGreater(
            self._simulate(program, self.arch.syscalls['read'], 0)[0],
            self._simulate(program, self.arch.syscalls['close'], 0)[0],
        )

    def test_compile_bst(self):
        """Ensure every strategy honors frequencies and allows both syscalls."""
        path = self._write_frequency_policy()

        for strategy in compiler.OptimizationStrategy:
            with self.subTest(strategy=strategy):
                program = self.compiler.compile_file(
                    path,
                    optimization_strategy=strategy,
                    kill_action=bpf.KillProcess())
                read_result = self._simulate(
                    program, self.arch.syscalls['read'], 0)
                close_result = self._simulate(
                    program, self.arch.syscalls['close'], 0)
                self.assertGreater(read_result[0], close_result[0])
                self.assertEqual(read_result[1], 'ALLOW')
                self.assertEqual(close_result[1], 'ALLOW')

    def test_compile_empty_file(self):
        """Accept empty files."""
        path = self._write_file(
            'test.policy', """
            @default kill-thread
        """)

        for strategy in compiler.OptimizationStrategy:
            with self.subTest(strategy=strategy):
                program = self.compiler.compile_file(
                    path,
                    optimization_strategy=strategy,
                    kill_action=bpf.KillProcess())
                self.assertEqual(
                    self._simulate(program, self.arch.syscalls['read'], 0)[1],
                    'KILL_THREAD')

    def test_compile_simulate(self):
        """Ensure policy reflects script by testing some random scripts."""
        iterations = 5
        for i in range(iterations):
            num_entries = 64 * (i + 1) // iterations
            # Map each randomly chosen syscall name to a random frequency.
            syscalls = {
                name: random.randint(1, 1024)
                for name in self._sample_syscalls(num_entries)
            }

            frequency_contents = '\n'.join(
                '%s: %d' % entry for entry in syscalls.items())
            policy_contents = '@frequency ./test.frequency\n' + '\n'.join(
                '%s: 1' % name for name in syscalls)

            self._write_file('test.frequency', frequency_contents)
            path = self._write_file('test.policy', policy_contents)

            for strategy in compiler.OptimizationStrategy:
                program = self.compiler.compile_file(
                    path,
                    optimization_strategy=strategy,
                    kill_action=bpf.KillProcess())
                for name, number in self.arch.syscalls.items():
                    expected_result = ('ALLOW'
                                       if name in syscalls else 'KILL_PROCESS')
                    self.assertEqual(
                        self._simulate(program, number, 0)[1],
                        expected_result,
                        ('syscall name: %s, syscall number: %d, '
                         'strategy: %s, policy:\n%s') %
                        (name, number, strategy, policy_contents))

    @unittest.skipIf(not int(os.getenv('SLOW_TESTS', '0')), 'slow')
    def test_compile_huge_policy(self):
        """Ensure jumps while compiling a huge policy are still valid."""
        # Given that the BST strategy is O(n^3), don't choose a crazy large
        # value, but it still needs to be around 128 so that we exercise the
        # codegen paths that depend on the length of the jump.
        #
        # Immediate jump offsets in BPF comparison instructions are limited to
        # 256 instructions, so given that every syscall filter consists of a
        # load and jump instructions, with 128 syscalls there will be at least
        # one jump that's further than 256 instructions.
        num_entries = 128
        syscalls = self._sample_syscalls(num_entries)
        # Here we force every single filter to be distinct. Otherwise the
        # codegen layer will coalesce filters that compile to the same
        # instructions.
        policy_contents = '\n'.join(
            '%s: arg0 == %d' % entry for entry in syscalls.items())

        path = self._write_file('test.policy', policy_contents)

        program = self.compiler.compile_file(
            path,
            optimization_strategy=compiler.OptimizationStrategy.BST,
            kill_action=bpf.KillProcess())
        for name, number in self.arch.syscalls.items():
            expected_result = ('ALLOW'
                               if name in syscalls else 'KILL_PROCESS')
            # Each filter only accepts arg0 equal to the syscall's own number.
            self.assertEqual(
                self._simulate(program, number, number)[1],
                expected_result)
            self.assertEqual(
                self._simulate(program, number, number + 1)[1],
                'KILL_PROCESS')

    def test_compile_huge_filter(self):
        """Ensure filters with many alternatives compile and behave correctly."""
        # This is intended to force cases where the AST visitation would result
        # in a combinatorial explosion of calls to Block.accept(). An optimized
        # implementation should be O(n).
        num_entries = 128
        syscalls = {}
        # Here we force every single filter to be distinct. Otherwise the
        # codegen layer will coalesce filters that compile to the same
        # instructions.
        policy_contents = []
        for name in self._sample_syscalls(num_entries):
            values = random.sample(range(1024), num_entries)
            syscalls[name] = values
            policy_contents.append(
                '%s: %s' % (name, ' || '.join('arg0 == %d' % value
                                              for value in values)))

        path = self._write_file('test.policy', '\n'.join(policy_contents))

        program = self.compiler.compile_file(
            path,
            optimization_strategy=compiler.OptimizationStrategy.LINEAR,
            kill_action=bpf.KillProcess())
        for name, values in syscalls.items():
            number = self.arch.syscalls[name]
            self.assertEqual(
                self._simulate(program, number, random.choice(values))[1],
                'ALLOW')
            # 1025 lies outside range(1024), so no filter can accept it.
            self.assertEqual(
                self._simulate(program, number, 1025)[1],
                'KILL_PROCESS')
507
508
# Run every test case in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
511