1#!/usr/bin/env python3
2#
3# Copyright 2024, The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""Repacks the ramdisk image to add kernel modules.
18
19Unpacks a ramdisk image, extracts and replaces kernel modules from another
20initramfs image, and repacks the ramdisk.
21"""
22
23import argparse
24import enum
25import logging
26import os
27import pathlib
28import shutil
29import subprocess
30import tempfile
31
logger = logging.getLogger(__name__)

# Names of the scratch sub-directories, created inside a temporary directory,
# into which the Android ramdisk and the kernel ramdisk are extracted.
_ANDROID_RAMDISK_DIR = 'android_ramdisk'
_KERNEL_RAMDISK_DIR = 'kernel_ramdisk'
36
def _parse_args():
    """Build the command-line parser and return the parsed namespace.

    All three options are mandatory:
      --android-ramdisk: the Android ramdisk image to update.
      --kernel-ramdisk: a ramdisk image, or an existing directory, supplying
          the replacement kernel modules.
      --output-ramdisk: where the repacked ramdisk is written.
    """
    option_table = (
        ('--android-ramdisk', 'filename of input android ramdisk'),
        ('--kernel-ramdisk',
         'filename of ramdisk to extract kernel modules from, '
         'or the path of an existing directory containing the modules'),
        ('--output-ramdisk', 'filename of repacked ramdisk'),
    )

    arg_parser = argparse.ArgumentParser(
        description='Repacks ramdisk image with modules from --kernel-ramdisk')
    for flag, help_text in option_table:
        arg_parser.add_argument(flag, help=help_text, required=True)

    return arg_parser.parse_args()
58
59
class RamdiskFormat(enum.Enum):
    """Compression formats a ramdisk image may be stored in.

    Members are numbered in declaration order starting at 1.
    """
    LZ4 = enum.auto()
    GZIP = enum.auto()
64
65
66# Based on system/tools/mkbootimg/repack_bootimg.py
class RamdiskImage:
    """A class that supports packing/unpacking a ramdisk.

    The image is handled as a cpio archive compressed with lz4 or gzip. On
    construction it is extracted into a working directory, unless the caller
    passes a directory that already holds the extracted contents.
    """

    def __init__(self, ramdisk_img, directory, allow_dir):
        """Opens a ramdisk, extracting it if it is an image.

        Args:
            ramdisk_img: path of a compressed ramdisk image, or (only when
                allow_dir is true) of a directory that already holds the
                extracted ramdisk contents.
            directory: directory to extract ramdisk_img into. It is used as
                the working directory of cpio, so it must already exist.
                Ignored when ramdisk_img is a directory.
            allow_dir: whether ramdisk_img may be a directory.

        Raises:
            RuntimeError: ramdisk_img is a directory but allow_dir is false,
                or the image could not be decompressed/extracted.
        """
        # The caller gave us a directory instead of an image.
        # Assume it's already been extracted.
        if os.path.isdir(ramdisk_img):
            if not allow_dir:
                raise RuntimeError(
                    f"Directory not allowed for image {ramdisk_img}")

            self._ramdisk_img = None
            self._ramdisk_format = None
            self._ramdisk_dir = ramdisk_img
            return

        self._ramdisk_img = ramdisk_img
        self._ramdisk_format = None
        self._ramdisk_dir = directory

        self._unpack()

    def _unpack(self):
        """Decompresses self._ramdisk_img and extracts it into self._ramdisk_dir.

        Records the detected compression format in self._ramdisk_format so
        that repack() can use the same one.

        Raises:
            RuntimeError: no available decompressor accepted the image, or
                cpio failed for a reason other than "Operation not permitted".
        """
        decompressed_result = None
        missing_tools = []
        # The compression format might be in 'lz4' or 'gzip' format,
        # trying lz4 first.
        for compression_type, compression_util in [
                (RamdiskFormat.LZ4, 'lz4'),
                (RamdiskFormat.GZIP, 'gzip')]:
            # Command arguments:
            #   -d: decompression
            #   -c: write to stdout
            decompression_cmd = [
                compression_util, '-d', '-c', self._ramdisk_img]
            try:
                result = subprocess.run(
                    decompression_cmd, check=False, capture_output=True)
            except FileNotFoundError:
                # The tool is not installed. Keep trying the remaining
                # formats instead of aborting before they have been tried.
                missing_tools.append(compression_util)
                continue

            if result.returncode == 0:
                self._ramdisk_format = compression_type
                decompressed_result = result
                break

        if decompressed_result is None:
            message = 'Failed to decompress ramdisk.'
            if missing_tools:
                message += f" Tools not found: {', '.join(missing_tools)}."
            raise RuntimeError(message)

        # toybox cpio arguments:
        #   -i: extract files from stdin
        #   -d: create directories if needed
        #   -u: override existing files
        cpio_run = subprocess.run(
            ['toybox', 'cpio', '-idu'], check=False,
            input=decompressed_result.stdout, cwd=self._ramdisk_dir,
            capture_output=True)
        # Failures reporting "Operation not permitted" are tolerated
        # (presumably entries that need elevated privileges to create when
        # not running as root -- confirm).
        if (cpio_run.returncode != 0 and
                b"Operation not permitted" not in cpio_run.stderr):
            # stderr is bytes; decode it so the message is readable text
            # rather than a b'...' repr.
            stderr_text = cpio_run.stderr.decode('utf-8', errors='replace')
            raise RuntimeError(f"cpio failed:\n{stderr_text}")

        print(f"=== Unpacked ramdisk: '{self._ramdisk_img}' at "
              f"'{self._ramdisk_dir}' ===")

    def repack(self, out_ramdisk_file):
        """Repacks a ramdisk from self._ramdisk_dir.

        The archive is built with mkbootfs and compressed with gzip if the
        source image was gzip, otherwise with lz4. The lz4 path includes a
        ramdisk that was given as a directory, whose format is unknown.

        Args:
            out_ramdisk_file: the output ramdisk file to save.

        Raises:
            subprocess.CalledProcessError: mkbootfs or the compressor failed.
        """
        compression_cmd = ['lz4', '-l', '-12', '--favor-decSpeed']
        if self._ramdisk_format == RamdiskFormat.GZIP:
            compression_cmd = ['gzip']

        print('Repacking ramdisk, which might take a few seconds ...')

        mkbootfs_result = subprocess.run(
            ['mkbootfs', self._ramdisk_dir], check=True, capture_output=True)

        with open(out_ramdisk_file, 'wb') as output_fd:
            subprocess.run(compression_cmd, check=True,
                           input=mkbootfs_result.stdout, stdout=output_fd)

        print(f"=== Repacked ramdisk: '{out_ramdisk_file}' ===")

    @property
    def ramdisk_dir(self):
        """Returns the internal ramdisk dir."""
        return self._ramdisk_dir

    def get_modules(self):
        """Returns the list of modules used in this ramdisk.

        Reads lib/modules/modules.load, one module per line. Surrounding
        whitespace is stripped, and blank lines are skipped so that callers
        never receive an empty module name.
        """
        modules_file_path = os.path.join(
            self._ramdisk_dir, "lib/modules/modules.load")
        with open(modules_file_path, "r", encoding="utf-8") as modules_file:
            stripped = (line.strip() for line in modules_file)
            return [module for module in stripped if module]

    def write_modules(self, modules):
        """Writes the list of modules used in this ramdisk.

        Overwrites lib/modules/modules.load with one module per line.
        """
        modules_file_path = os.path.join(
            self._ramdisk_dir, "lib/modules/modules.load")
        with open(modules_file_path, "w", encoding="utf-8") as modules_file:
            modules_file.writelines(f"{module}\n" for module in modules)
167
168
def _replace_modules(dest_ramdisk, src_ramdisk):
    """Replace any modules in dest_ramdisk with modules from src_ramdisk.

    For every entry of dest_ramdisk's module list, a file of the same name is
    searched for anywhere under src_ramdisk's directory and copied over
    dest_ramdisk's lib/modules/<module>. Modules with no match are deleted
    from dest_ramdisk (if present) and dropped from its module list. The
    module list is then rewritten.

    Args:
        dest_ramdisk: ramdisk whose modules are replaced (modified in place).
        src_ramdisk: ramdisk providing the replacement modules.

    Raises:
        RuntimeError: a module name matches more than one file in
            src_ramdisk.
    """
    src_dir = pathlib.Path(src_ramdisk.ramdisk_dir)
    dest_dir = os.path.join(dest_ramdisk.ramdisk_dir, "lib/modules")
    updated_modules = []
    for module in dest_ramdisk.get_modules():
        if not module:
            # A blank entry names no module. Searching for it would build the
            # pattern "**/", which does not describe a module file.
            continue
        dest_module = os.path.join(dest_dir, module)
        matches = list(src_dir.glob(f"**/{module}"))
        if len(matches) > 1:
            raise RuntimeError(
                f"Found multiple candidates for module {module}")
        if not matches:
            logger.warning(
                "Could not find module %s, deleting this module.",
                module)
            try:
                os.remove(dest_module)
            except FileNotFoundError:
                # Listed but not actually present in the ramdisk: there is
                # no file to delete, so only the list entry is dropped.
                pass
            continue
        shutil.copy(matches[0], dest_module)
        updated_modules.append(module)

    dest_ramdisk.write_modules(updated_modules)
190
191
def main():
    """Parse arguments and repack ramdisk image.

    Extracts the Android ramdisk (which must be an image) and the kernel
    ramdisk (an image or an already-extracted directory) into scratch
    directories under a temporary directory. It then replaces the Android
    ramdisk's kernel modules with those from the kernel ramdisk and writes
    the repacked result to --output-ramdisk. The temporary directory is
    removed when the function exits.
    """
    args = _parse_args()
    with tempfile.TemporaryDirectory() as tempdir:
        android_dir = os.path.join(tempdir, _ANDROID_RAMDISK_DIR)
        kernel_dir = os.path.join(tempdir, _KERNEL_RAMDISK_DIR)
        # Create the extraction targets up front; RamdiskImage extracts into
        # the directory it is given rather than creating it.
        os.mkdir(android_dir)
        os.mkdir(kernel_dir)

        android_ramdisk = RamdiskImage(
            args.android_ramdisk, android_dir, allow_dir=False)
        kernel_ramdisk = RamdiskImage(
            args.kernel_ramdisk, kernel_dir, allow_dir=True)
        _replace_modules(android_ramdisk, kernel_ramdisk)
        android_ramdisk.repack(args.output_ramdisk)
208
209
if __name__ == "__main__":
    # Only run when executed as a script, not when imported.
    main()
212