xref: /aosp_15_r20/external/iptables/xlate-test.py (revision a71a954618bbadd4a345637e5edcf36eec826889)
1#!/usr/bin/env python3
2# encoding: utf-8
3
4import os
5import sys
6import shlex
7import argparse
8from subprocess import Popen, PIPE
9
def run_proc(args, shell=False, input=None):
    """Run a command to completion and capture everything it prints.

    stdin, stdout and stderr are connected to pipes in text mode; *input*,
    when given, is written to the child's stdin before waiting for it.

    Returns:
        Tuple ``(returncode, stdout_text, stderr_text)``.
    """
    child = Popen(args, stdin=PIPE, stdout=PIPE, stderr=PIPE,
                  shell=shell, text=True)
    out_text, err_text = child.communicate(input)
    return child.returncode, out_text, err_text
15    return (process.returncode, output, error)
16
# A test-file line is treated as a test case only if it starts with one of
# these translator commands.
keywords = (
    "iptables-translate",
    "ip6tables-translate",
    "ebtables-translate",
)
# Multi-call binary providing the translators; main() may replace this with
# an absolute path into the build tree.
xtables_nft_multi = 'xtables-nft-multi'

# ANSI colour escapes, used only when stdout is a terminal so that redirected
# output stays free of control sequences.
if sys.stdout.isatty():
    colors = {
        "magenta": "\033[95m",
        "green": "\033[92m",
        "yellow": "\033[93m",
        "red": "\033[91m",
        "end": "\033[0m",
    }
else:
    colors = dict.fromkeys(("magenta", "green", "yellow", "red", "end"), "")
25
26
def magenta(string):
    """Wrap *string* in the magenta escape code from ``colors``.

    If ``colors`` holds empty strings, the text comes back unchanged.
    """
    prefix, suffix = colors["magenta"], colors["end"]
    return prefix + string + suffix
29
30
def red(string):
    """Wrap *string* in the red escape code from ``colors``.

    If ``colors`` holds empty strings, the text comes back unchanged.
    """
    prefix, suffix = colors["red"], colors["end"]
    return prefix + string + suffix
33
34
def yellow(string):
    """Wrap *string* in the yellow escape code from ``colors``.

    If ``colors`` holds empty strings, the text comes back unchanged.
    """
    prefix, suffix = colors["yellow"], colors["end"]
    return prefix + string + suffix
37
38
def green(string):
    """Wrap *string* in the green escape code from ``colors``.

    If ``colors`` holds empty strings, the text comes back unchanged.
    """
    prefix, suffix = colors["green"], colors["end"]
    return prefix + string + suffix
41
42
def test_one_xlate(name, sourceline, expected, result):
    """Run one translator command and compare its output to *expected*.

    Args:
        name: Test file name, used to prefix report lines.
        sourceline: Translator command line; it is shlex-split and run
            through ``xtables_nft_multi``.
        expected: Expected translation text.
        result: List collecting report lines; extended on failure.

    Returns:
        True if the command succeeded and its output matched, else False.
    """
    cmd = [xtables_nft_multi, *shlex.split(sourceline)]
    rc, output, error = run_proc(cmd)
    if rc != 0:
        result.extend([
            name + ": " + red("Error: ") + "iptables-translate failure",
            error,
        ])
        return False

    # Ignore trailing blanks/newlines the translator may emit.
    translation = output.rstrip(" \n")
    if translation == expected:
        return True

    result.extend([
        name + ": " + red("Fail"),
        magenta("src: ") + sourceline.rstrip(" \n"),
        magenta("exp: ") + expected,
        magenta("res: ") + translation + "\n",
    ])
    return False
59
def test_one_replay(name, sourceline, expected, result):
    """Load a translation into nft and check that iptables-nft recognises it.

    The expected nft commands are fed to ``nft -f -`` on top of a freshly
    flushed ruleset holding just the rule's table and chain. Then
    ``<ipt> --check`` is run through ``xtables_nft_multi`` to verify that the
    rule nft created parses back to the original rule.

    Args:
        name: Test file name, used to prefix report lines.
        sourceline: The original ``*-translate`` command line. Text after
            the first ``;`` is an alternative rule spec (without table and
            chain) to pass to ``--check`` instead of the source rule.
        expected: Expected nft output, one command per line. Each line may
            be prefixed with ``nft `` and wrapped in single quotes.
        result: List collecting report lines; appended to on failure.

    Returns:
        True on success, or when the source line adds no rule to a chain.
        False otherwise.
    """
    searchline = None
    if ';' in sourceline:
        # Only the first ';' separates the source rule from the search rule.
        sourceline, _, searchline = sourceline.partition(';')

    srcwords = shlex.split(sourceline)

    srccmd = srcwords[0]
    ipt = srccmd.split('-')[0]      # e.g. "iptables" from "iptables-translate"
    chain_idx = -1
    table_name = "filter"
    chain_name = None
    # Both option groups take their value from the next word. Stopping one
    # word early means a trailing "-A"/"-t" with no value is ignored instead
    # of raising IndexError.
    for idx in range(1, len(srcwords) - 1):
        word = srcwords[idx]
        if word in ("-A", "-I", "--append", "--insert"):
            chain_idx = idx
            chain_name = srcwords[idx + 1]
        elif word in ("-t", "--table"):
            table_name = srcwords[idx + 1]

    if not chain_name:
        return True     # no rule is added to a chain, so nothing to replay

    if searchline is None:
        # Reuse the source rule, turning its append/insert into a check.
        checkcmd = srcwords[:]
        checkcmd[0] = ipt
        checkcmd[chain_idx] = "--check"
    else:
        # Tokenise the search rule like the source rule. Passed as a single
        # argv word, "-p tcp ..." would reach iptables as one mangled option.
        checkcmd = [ipt, "-t", table_name, "--check", chain_name]
        checkcmd += shlex.split(searchline)

    fam = ""
    if srccmd.startswith("ip6"):
        fam = "ip6 "
    elif srccmd.startswith("ebt"):
        fam = "bridge "

    # Strip an optional "nft " prefix and surrounding quotes from each line.
    nft_cmds = [line.removeprefix("nft ").strip(" '")
                for line in expected.split("\n")]
    nft_input = [
            "flush ruleset",
            "add table " + fam + table_name,
            "add chain " + fam + table_name + " " + chain_name,
    ] + nft_cmds

    rc, output, error = run_proc([args.nft, "-f", "-"], shell = False,
                                 input = "\n".join(nft_input))
    if rc != 0:
        result.append(name + ": " + red("Replay Fail"))
        result.append(args.nft + " call failed: " + error.rstrip('\n'))
        for line in nft_input:
            result.append(magenta("input: ") + line)
        return False

    rc, output, error = run_proc([xtables_nft_multi] + checkcmd)
    if rc != 0:
        result.append(name + ": " + red("Check Fail"))
        result.append(magenta("check: ") + " ".join(checkcmd))
        result.append(magenta("error: ") + error)
        # Dump both views of the current ruleset to help diagnose the failure.
        _, output, _ = run_proc([xtables_nft_multi, ipt + "-save"])
        for line in output.split("\n"):
            result.append(magenta("ipt: ") + line)
        _, output, _ = run_proc([args.nft, "list", "ruleset"])
        for line in output.split("\n"):
            result.append(magenta("nft: ") + line)
        return False

    return True
130
131
def run_test(name, payload):
    """Run every test case found in one ``.txlate`` file.

    A test case is a line starting with one of ``keywords``, followed by one
    mandatory expected-output line and any number of further lines starting
    with ``nft``. Lines outside a test case are skipped. Each case is
    translated; in replay mode (``args.replay``), it is also replayed through
    nft, and the ruleset is flushed afterwards.

    Args:
        name: Test file name, used in report lines.
        payload: Open text file positioned at the start of the test data.

    Returns:
        Tuple ``(tests, passed, failed, errors)``. ``failed`` is currently
        never incremented; every failing case is counted in ``errors``.
    """
    test_passed = True
    tests = passed = failed = errors = 0
    result = []

    line = payload.readline()
    while line:
        if not line.startswith(keywords):
            line = payload.readline()
            continue

        # The translator only sees the part before an optional ';'. The
        # replay check receives the whole line.
        replayline = line.rstrip("\n")
        sourceline = replayline.partition(';')[0]

        # The first expected line is always consumed. Any following lines
        # starting with "nft" belong to the same test case.
        expected_lines = [payload.readline().rstrip(" \n")]
        line = payload.readline()
        while line.startswith("nft"):
            expected_lines.append(line.rstrip(" \n"))
            line = payload.readline()
        expected = "\n".join(expected_lines)

        tests += 1
        if not test_one_xlate(name, sourceline, expected, result):
            errors += 1
            test_passed = False
            continue    # no point replaying a wrong translation
        passed += 1

        if args.replay:
            tests += 1
            if test_one_replay(name, replayline, expected, result):
                passed += 1
            else:
                errors += 1
                test_passed = False

            rc, _, error = run_proc([args.nft, "flush", "ruleset"])
            if rc != 0:
                result.append(name + ": " + red("Fail"))
                result.append("nft flush ruleset call failed: " + error)
                # Without this, the message above would never be printed.
                test_passed = False

    if test_passed:
        print(name + ": " + green("OK"))
    else:
        print("\n".join(result), file=sys.stderr)
    return tests, passed, failed, errors
184
185
def load_test_files():
    """Run every ``.txlate`` file in the ``extensions`` directory.

    Files are processed in sorted name order.

    Returns:
        Tuple ``(files, tests, passed, failed, errors)`` summed over all files.
    """
    n_files = n_tests = n_passed = n_failed = n_errors = 0
    for entry in sorted(os.listdir("extensions")):
        if not entry.endswith(".txlate"):
            continue
        path = "extensions/" + entry
        with open(path, "r") as payload:
            tests, passed, failed, errors = run_test(path, payload)
        n_files += 1
        n_tests += tests
        n_passed += passed
        n_failed += failed
        n_errors += errors
    return (n_files, n_tests, n_passed, n_failed, n_errors)
198
199
def spawn_netns():
    """Try to move this test run into its own network namespace.

    First tries the optional ``unshare`` Python module in-process. If that
    fails, it re-executes the script under the ``unshare -n`` command, with
    ``--no-netns`` appended so the new copy does not try again. When that
    ``os.execv`` succeeds, the current process is replaced and this function
    never returns.

    Returns:
        True if the namespace was created in-process, False if neither
        method worked.
    """
    # Preferred: the optional "unshare" module.
    try:
        import unshare
        unshare.unshare(unshare.CLONE_NEWNET)
        return True
    except Exception:
        # Best effort (module missing, different API, syscall refused).
        # Unlike a bare except, this lets KeyboardInterrupt/SystemExit through.
        pass

    # Sledgehammer: re-run ourselves prefixed by 'unshare -n' if available.
    import shutil

    unshare_bin = shutil.which("unshare")
    if unshare_bin is None:
        return False

    # Build a new argv rather than mutating sys.argv, so nothing changes in
    # this process if the exec fails.
    new_argv = [unshare_bin, "-n", sys.executable] + sys.argv + ["--no-netns"]
    try:
        os.execv(unshare_bin, new_argv)
    except OSError:
        pass

    return False
225
226
def main():
    """Run the requested translation tests and print a summary line.

    Runs the files named in ``args.test`` (adding a ``.txlate`` suffix when
    missing). If none are given, it runs every file under ``extensions``.
    Unless ``args.host`` is set, the binaries and extensions of the source
    tree in the current directory are used.

    Returns:
        Exit status for ``sys.exit``:
        - 0 if every test passed.
        - 1 if any test failed.
        - 77 (skipped) if replay mode was requested without root.
        - 99 if a named test file cannot be opened.
    """
    global xtables_nft_multi

    if args.replay:
        if os.getuid() != 0:
            print("Replay test requires root, sorry", file=sys.stderr)
            # Must be non-zero so the run is not mistaken for a pass.
            # 77 is the conventional "skipped" test status.
            return 77
        # spawn_netns() may re-exec this script and not return.
        if not args.no_netns and not spawn_netns():
            print("Cannot run in own namespace, connectivity might break",
                  file=sys.stderr)

    if not args.host:
        # Assigning os.environ (unlike os.putenv) keeps the mapping in sync;
        # child processes started by run_proc inherit it.
        os.environ["XTABLES_LIBDIR"] = os.path.abspath("extensions")
        xtables_nft_multi = os.path.join(os.path.abspath(os.path.curdir),
                                         "iptables", xtables_nft_multi)

    files = tests = passed = failed = errors = 0
    for test in args.test:
        if not test.endswith(".txlate"):
            test += ".txlate"
        # Guard only the open(). An OSError raised while running the test
        # (e.g. a missing xtables-nft-multi binary) must not be reported as
        # a missing test file.
        try:
            payload = open(test, "r")
        except OSError:
            print(red("Error: ") + "test file does not exist: " + test,
                  file=sys.stderr)
            return 99
        with payload:
            t, p, f, e = run_test(test, payload)
        files += 1
        tests += t
        passed += p
        failed += f
        errors += e

    if files == 0:
        files, tests, passed, failed, errors = load_test_files()

    file_word = "file" if files == 1 else "files"
    print("%d test %s, %d tests, %d tests passed, %d tests failed, %d errors"
            % (files, file_word, tests, passed, failed, errors))
    # sys.exit() keeps only the low 8 bits of the status, so returning the
    # raw failure count could wrap to 0 (e.g. 256 failures). Report 0/1.
    return 0 if passed == tests else 1
269
270
def _build_parser():
    """Create the command-line parser for the translation test suite."""
    ap = argparse.ArgumentParser()
    ap.add_argument('-H', '--host', action='store_true',
                    help='Run tests against installed binaries')
    ap.add_argument('-R', '--replay', action='store_true',
                    help='Replay tests to check iptables-nft parser')
    ap.add_argument('-n', '--nft', type=str, default='nft',
                    help="Replay using given nft binary (default: '%(default)s')")
    ap.add_argument('--no-netns', action='store_true',
                    help='Do not run testsuite in own network namespace')
    ap.add_argument("test", nargs="*",
                    help="run only the specified test file(s)")
    return ap


# The functions above read the parsed options through the global ``args``.
parser = _build_parser()
args = parser.parse_args()
sys.exit(main())
282sys.exit(main())
283