xref: /aosp_15_r20/external/pigweed/pw_presubmit/py/pw_presubmit/rst_format.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1#!/usr/bin/env python3
2
3# Copyright 2023 The Pigweed Authors
4#
5# Licensed under the Apache License, Version 2.0 (the "License"); you may not
6# use this file except in compliance with the License. You may obtain a copy of
7# the License at
8#
9#     https://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
14# License for the specific language governing permissions and limitations under
15# the License.
16"""Restructured Text Formatting."""
17
18import argparse
19from dataclasses import dataclass, field
20import difflib
21from functools import cached_property
22from pathlib import Path
23import textwrap
24from typing import Iterable
25
26from pw_cli.diff import colorize_diff
27
# Number of spaces substituted for each '\t' (outside code blocks, and inside
# code blocks whose language does not preserve tabs).
TAB_WIDTH: int = 8
# Spaces by which code-block options and content are indented relative to the
# column of their directive line.
CODE_BLOCK_INDENTATION: int = 3
30
31
32def _parse_args() -> argparse.Namespace:
33    parser = argparse.ArgumentParser(description=__doc__)
34
35    parser.add_argument(
36        '--diff',
37        action='store_true',
38        help='Print a diff of formatting changes.',
39    )
40    parser.add_argument(
41        '-i',
42        '--in-place',
43        action='store_true',
44        help='Replace existing file with the reformatted copy.',
45    )
46    parser.add_argument(
47        'rst_files',
48        nargs='+',
49        default=[],
50        type=Path,
51        help='Paths to rst files.',
52    )
53
54    return parser.parse_args()
55
56
57def _indent_amount(line: str) -> int:
58    return len(line) - len(line.lstrip())
59
60
61def _reindent(input_text: str, amount: int) -> Iterable[str]:
62    for line in textwrap.dedent(input_text).splitlines():
63        if len(line.strip()) == 0:
64            yield '\n'
65            continue
66        yield f'{" " * amount}{line}\n'
67
68
def _fix_whitespace(line: str) -> str:
    """Normalize whitespace on a single line outside any code block.

    Every tab becomes ``TAB_WIDTH`` spaces, trailing whitespace (including any
    existing line terminator) is removed, and exactly one newline is appended.
    """
    spaces = ' ' * TAB_WIDTH
    normalized = line.replace('\t', spaces).rstrip()
    return normalized + '\n'
71
72
@dataclass
class CodeBlock:
    """Accumulate and re-render a single code-block directive.

    The lines following the directive are fed to ``append_line`` one at a
    time. Leading ``:option:`` lines (and their continuation lines) are kept
    in ``option_lines``. Everything from the first content line onward is kept
    in ``code_lines`` until a non-blank line is indented less than that first
    content line. ``lines()`` re-emits the directive with options and code
    re-indented ``CODE_BLOCK_INDENTATION`` spaces past the directive's column.
    """

    # Line index of the directive within the document, as given by the caller.
    directive_lineno: int
    # The directive line itself, as read (normally including its newline).
    directive_line: str
    # Indentation of the first content line; None until it has been seen.
    first_line_indent: int | None = None
    # Index of the first line found to be outside the block; None while open.
    end_lineno: int | None = None
    # Directive option lines such as ':caption: ...', right-stripped.
    option_lines: list[str] = field(default_factory=list)
    # Content lines, right-stripped and (for most languages) tab-expanded.
    code_lines: list[str] = field(default_factory=list)

    # Languages whose code keeps literal tabs. 'golang' is Pygments' second
    # alias for Go. (Unannotated, so it is a class constant, not a field.)
    _TAB_PRESERVING_LANGUAGES = frozenset({'none', 'go', 'golang'})

    def __post_init__(self) -> None:
        # Becomes True once the option list is known to be over: either a
        # blank line followed the options, or the first content line arrived.
        self._blank_line_after_options_found = False

    def finished(self) -> bool:
        """Return True once a line outside the block has been seen."""
        return self.end_lineno is not None

    def append_line(self, index: int, line: str) -> None:
        """Process a line for this code block.

        Args:
            index: Position of ``line`` in the document; stored as
                ``end_lineno`` if this line turns out to end the block.
            line: The raw line, possibly including its newline.
        """
        is_blank = not line.strip()

        if not is_blank:
            # Options and content must be indented deeper than the directive.
            # Before any content is seen, a line at or left of the
            # directive's column ends the block; this stops a content-less
            # directive from swallowing the rest of the document. Once content
            # is seen, anything shallower than the first content line ends it.
            if self.first_line_indent is None:
                min_indent = self.directive_indent_amount + 1
            else:
                min_indent = self.first_line_indent
            if _indent_amount(line) < min_indent:
                self.end_lineno = index
                return

        if self.first_line_indent is None:
            # A leading ':name:' word marks a directive option such as
            # ':caption:'. The length check rejects a bare '::'.
            line_words = line.split()
            if (
                line_words
                and line_words[0].startswith(':')
                and line_words[0].endswith(':')
                and len(line_words[0]) > 2
            ):
                self.option_lines.append(line.rstrip())
                return

            if is_blank:
                # A blank line after options terminates the option list.
                if self.option_lines:
                    self._blank_line_after_options_found = True
                return

            # A non-blank line right after an option (no blank line in
            # between) is a continuation of that option.
            if self.option_lines and not self._blank_line_after_options_found:
                self.option_lines.append(line.rstrip())
                return

            # First content line: the options (if any) are over, and its
            # indentation defines the block's content indentation.
            self._blank_line_after_options_found = True
            self.first_line_indent = _indent_amount(line)

        # Save this line as code.
        self.code_lines.append(self._clean_up_line(line))

    def _clean_up_line(self, line: str) -> str:
        """Right-strip ``line`` and expand tabs unless the language keeps them."""
        line = line.rstrip()
        if not self._keep_codeblock_tabs:
            line = line.replace('\t', ' ' * TAB_WIDTH)
        return line

    @cached_property
    def _keep_codeblock_tabs(self) -> bool:
        """True if tabs should NOT be replaced: language 'none' or Go.

        The language is the first word after '::' on the directive line. It is
        compared exactly (case-insensitively). A substring test would also
        match unrelated languages such as 'django'.
        """
        _, _, argument = self.directive_line.partition('::')
        words = argument.split()
        return (
            bool(words)
            and words[0].lower() in self._TAB_PRESERVING_LANGUAGES
        )

    @cached_property
    def directive_indent_amount(self) -> int:
        """Leading-whitespace width of the directive line."""
        return _indent_amount(self.directive_line)

    def options_block_lines(self) -> Iterable[str]:
        """Yield the option lines re-indented past the directive."""
        yield from _reindent(
            input_text='\n'.join(self.option_lines),
            amount=self.directive_indent_amount + CODE_BLOCK_INDENTATION,
        )

    def code_block_lines(self) -> Iterable[str]:
        """Yield the code lines re-indented past the directive."""
        yield from _reindent(
            input_text='\n'.join(self.code_lines),
            amount=self.directive_indent_amount + CODE_BLOCK_INDENTATION,
        )

    def lines(self) -> Iterable[str]:
        """Yields the code block directive's lines."""
        yield self.directive_line
        if self.option_lines:
            yield from self.options_block_lines()
        yield '\n'
        # A directive without content needs only one separating blank line.
        if self.code_lines:
            yield from self.code_block_lines()
            yield '\n'

    def __repr__(self) -> str:
        # Intentionally renders the formatted block text rather than fields.
        return ''.join(self.lines())
182
183
def _parse_and_format_rst(in_text: str) -> Iterable[str]:
    """Reindents code blocks and fixes whitespace.

    Yields the reformatted document line by line. Lines outside code blocks go
    through ``_fix_whitespace``. Each ``.. code-block::`` / ``.. code::``
    directive and its body are buffered in a ``CodeBlock`` and emitted once
    the block ends, or at the end of the input.
    """
    current_block: CodeBlock | None = None
    for index, line in enumerate(in_text.splitlines(keepends=True)):
        # If a code block is active, offer it this line first.
        if current_block is not None:
            current_block.append_line(index, line)
            if not current_block.finished():
                continue
            yield from current_block.lines()
            current_block = None
            # The line that ended the block is not part of it. Fall through so
            # it is handled like any other line. This includes starting a new
            # code block when one directive immediately follows another block.

        if line.lstrip().startswith(('.. code-block::', '.. code::')):
            current_block = CodeBlock(
                directive_lineno=index,
                # Change `.. code::` to Sphinx's `.. code-block::`.
                directive_line=line.replace('code::', 'code-block::'),
            )
        else:
            yield _fix_whitespace(line)

    # If the document ends with a code block it may still need to be written.
    if current_block is not None:
        yield from current_block.lines()
210
211
def reformat_rst(
    file_name: Path,
    diff: bool = False,
    in_place: bool = False,
    suppress_stdout: bool = False,
) -> list[str]:
    """Reformats an rst file and returns a list of diff lines.

    Args:
        file_name: The .rst file to read (and, with ``in_place``, rewrite).
        diff: Print a colorized unified diff when formatting changes the text.
        in_place: Write the reformatted text back to ``file_name``.
        suppress_stdout: With ``diff``, compute the diff but do not print it.

    Returns:
        Unified-diff lines between the text as read and the reformatted text.
        The list is empty when reformatting changes nothing.
    """
    # Decode and encode explicitly as UTF-8 so the result does not depend on
    # the platform's locale-specific default encoding.
    in_text = file_name.read_text(encoding='utf-8')
    out_lines = list(_parse_and_format_rst(in_text))

    # Remove blank lines from the end of the output, if any.
    while out_lines and not out_lines[-1].strip():
        out_lines.pop()

    # Add a trailing \n if needed.
    if out_lines and not out_lines[-1].endswith('\n'):
        out_lines[-1] += '\n'

    result_diff = list(
        difflib.unified_diff(
            in_text.splitlines(True),
            out_lines,
            f'{file_name}  (original)',
            f'{file_name}  (reformatted)',
        )
    )
    if diff and result_diff and not suppress_stdout:
        print(''.join(colorize_diff(result_diff)))

    if in_place:
        file_name.write_text(''.join(out_lines), encoding='utf-8')

    return result_diff
246
247
def rst_format_main(
    rst_files: list[Path],
    diff: bool = False,
    in_place: bool = False,
) -> None:
    """Run ``reformat_rst`` on each path in ``rst_files``, in order.

    The diff lists returned by ``reformat_rst`` are discarded. Any visible
    effect comes from its ``diff`` printing and ``in_place`` rewriting.
    """
    for path in rst_files:
        reformat_rst(file_name=path, diff=diff, in_place=in_place)
255
256
if __name__ == '__main__':
    # Forward every parsed CLI option to the entry point by keyword.
    _cli_args = _parse_args()
    rst_format_main(
        rst_files=_cli_args.rst_files,
        diff=_cli_args.diff,
        in_place=_cli_args.in_place,
    )
259