xref: /aosp_15_r20/external/pigweed/pw_module/py/pw_module/create.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1# Copyright 2022 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Creates a new Pigweed module."""
15
16from __future__ import annotations
17
18import abc
19import argparse
20import dataclasses
21from dataclasses import dataclass
22import datetime
23import difflib
24from enum import Enum
25import functools
26import json
27import logging
28from pathlib import Path
29import re
30import sys
31from typing import Any, Collection, Iterable, Type
32
33from prompt_toolkit import prompt
34
35from pw_build import generate_modules_lists
36import pw_cli.color
37import pw_cli.env
38from pw_cli.diff import colorize_diff
39from pw_cli.status_reporter import StatusReporter
40
41from pw_module.templates import get_template
42
# Shared helpers created once at import time: terminal color formatter, this
# module's logger, and the active Pigweed environment.
_COLOR = pw_cli.color.colors()
_LOG = logging.getLogger(__name__)
_PW_ENV = pw_cli.env.pigweed_environment()
# The _report_* helpers and _prompt_overwrite display file paths relative to
# this directory (presumably the root of the Pigweed checkout).
_PW_ROOT = _PW_ENV.PW_ROOT

# License header in '#' comment form, written at the top of generated files
# for upstream modules. The copyright year is the year in which this module
# is imported; lstrip() drops the newline that follows the opening quotes.
_PIGWEED_LICENSE = f"""
# Copyright {datetime.datetime.now().year} The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.""".lstrip()

# The same license with every '#' replaced by '//' for C++ sources/headers.
_PIGWEED_LICENSE_CC = _PIGWEED_LICENSE.replace('#', '//')

# Colored status labels, each padded to ten characters, that prefix the file
# path in every report line.
_CREATE = _COLOR.green('create    ')
_REPLACE = _COLOR.green('replace   ')
_UPDATE = _COLOR.yellow('update    ')
_UNCHANGED = _COLOR.blue('unchanged ')
_IDENTICAL = _COLOR.blue('identical ')
# Reporter through which all of the status lines above are printed.
_REPORT = StatusReporter()
71
72
def _report_write_file(file_path: Path) -> None:
    """Announce that a file is about to be written.

    Reports "replace" when the path already exists as a regular file and
    "create" otherwise. The path is shown relative to _PW_ROOT.

    Args:
      file_path: The file that will be written; must be inside _PW_ROOT.

    Raises:
      ValueError: If file_path is not located under _PW_ROOT.
    """
    shown_path = str(file_path.relative_to(_PW_ROOT))
    if not file_path.is_file():
        _REPORT.new(_CREATE + shown_path)
    else:
        _REPORT.info(_REPLACE + shown_path)
80
81
def _report_unchanged_file(file_path: Path) -> None:
    """Announce that an existing file was left as it was.

    Args:
      file_path: The untouched file; must be inside _PW_ROOT.

    Raises:
      ValueError: If file_path is not located under _PW_ROOT.
    """
    _REPORT.ok(_UNCHANGED + str(file_path.relative_to(_PW_ROOT)))
86
87
def _report_identical_file(file_path: Path) -> None:
    """Announce that a file already holds exactly the new contents.

    Args:
      file_path: The identical file; must be inside _PW_ROOT.

    Raises:
      ValueError: If file_path is not located under _PW_ROOT.
    """
    _REPORT.ok(_IDENTICAL + str(file_path.relative_to(_PW_ROOT)))
92
93
def _report_edited_file(file_path: Path) -> None:
    """Announce that an existing file was modified in place.

    Args:
      file_path: The edited file; must be inside _PW_ROOT.

    Raises:
      ValueError: If file_path is not located under _PW_ROOT.
    """
    _REPORT.new(_UPDATE + str(file_path.relative_to(_PW_ROOT)))
98
99
class PromptChoice(Enum):
    """Possible prompt responses.

    Returned by _prompt_user() to describe what the user typed.
    """

    # The user agreed to the proposed action.
    YES = 'yes'
    # The user declined, pressed enter, or typed something unrecognized.
    NO = 'no'
    # The user asked to see a diff before deciding.
    DIFF = 'diff'
106
107
def _prompt_user(message: str, allow_diff: bool = False) -> PromptChoice:
    """Ask the user to choose between yes, no and optionally diff.

    Only the first letter of the answer matters and case is ignored. An
    empty answer, or one that starts with anything other than 'y' or 'd',
    counts as NO. Pressing ctrl-c terminates the program via sys.exit(1).

    Args:
      message: The message to display at the start of the prompt.
      allow_diff: If true, advertise the 'd' option in the prompt's help
          text. A 'd' answer is reported as DIFF either way.

    Returns:
      A PromptChoice enum value.
    """
    options = '[y/N/d]' if allow_diff else '[y/N]'

    try:
        answer = prompt(f'{message} {options} ')
    except KeyboardInterrupt:
        sys.exit(1)  # Ctrl-C pressed

    answer = answer.lower()
    if answer.startswith('y'):
        return PromptChoice.YES
    if answer.startswith('d'):
        return PromptChoice.DIFF

    # Empty input, an explicit 'n...', and anything unrecognized all mean NO.
    return PromptChoice.NO
138
139
def _print_diff(file_name: Path | str, in_text: str, out_text: str) -> None:
    """Print a colorized unified diff between two versions of a file.

    Nothing is printed when the two texts are identical.

    Args:
      file_name: Name used to label both sides of the diff.
      in_text: The original contents.
      out_text: The updated contents.
    """
    diff_lines = list(
        difflib.unified_diff(
            in_text.splitlines(keepends=True),
            out_text.splitlines(keepends=True),
            fromfile=f'{file_name}  (original)',
            tofile=f'{file_name}  (updated)',
        )
    )
    if diff_lines:
        # Blank line to separate the diff from the preceding prompt.
        print()
        print(''.join(colorize_diff(diff_lines)))
153
154
def _prompt_overwrite(file_path: Path, new_contents: str) -> bool:
    """Decide whether a file should be written, asking the user if needed.

    A file that does not exist yet is always written. An existing file whose
    contents already equal a non-empty new_contents is reported as identical
    and skipped. Otherwise the user is asked whether to overwrite; they may
    view the diff as many times as they like before answering.

    Args:
      file_path: The file that would be written.
      new_contents: The text that would be written to it.

    Returns:
      True if the caller should write the file, False to leave it alone.
    """
    if not file_path.is_file():
        return True

    existing = file_path.read_text(encoding='utf-8')
    if new_contents and existing == new_contents:
        _report_identical_file(file_path)
        return False

    display_name = file_path.relative_to(_PW_ROOT)
    _REPORT.wrn(f'{display_name} already exists.')

    # Keep showing the diff for as long as the user keeps asking for it.
    choice = _prompt_user('Overwrite?', allow_diff=True)
    while choice == PromptChoice.DIFF:
        _print_diff(display_name, existing, new_contents)
        choice = _prompt_user('Overwrite?', allow_diff=True)

    if choice == PromptChoice.YES:
        return True

    # Anything other than an explicit yes leaves the file untouched.
    _report_unchanged_file(file_path)
    return False
183
184
185# TODO(frolv): Adapted from pw_protobuf. Consolidate them.
class _OutputFile:
    """Accumulates the text of one generated file and writes it on request.

    Text is appended a line at a time with line(). indent() returns a context
    manager that increases the indentation applied to the non-empty lines
    appended while it is active.
    """

    DEFAULT_INDENT_WIDTH = 2

    def __init__(self, file: Path, indent_width: int = DEFAULT_INDENT_WIDTH):
        """Creates an empty output buffer.

        Args:
          file: Path the content will eventually be written to.
          indent_width: Number of spaces added by indent() when no explicit
              width is given.
        """
        self._file = file
        self._content: list[str] = []
        self._indent_width: int = indent_width
        self._indentation = 0

    def line(self, line: str = '') -> None:
        """Appends a line of text followed by a newline.

        Non-empty lines are prefixed with the current indentation; an empty
        line produces a bare newline with no trailing whitespace.
        """
        if line:
            self._content.append(' ' * self._indentation)
            self._content.append(line)
        self._content.append('\n')

    def indent(
        self,
        width: int | None = None,
    ) -> _OutputFile._IndentationContext:
        """Increases the indentation level of the output.

        Args:
          width: Number of spaces to indent by; defaults to the width given
              to the constructor.

        Returns:
          A context manager that applies the indentation while it is active.
        """
        return self._IndentationContext(
            self, width if width is not None else self._indent_width
        )

    @property
    def path(self) -> Path:
        """The path this output will be written to."""
        return self._file

    @property
    def content(self) -> str:
        """All of the text appended so far, as one string."""
        return ''.join(self._content)

    def write(self, content: str | None = None) -> None:
        """Write file contents. Prompts the user if necessary.

        A trailing newline is added if the text does not end with one. The
        file is only written if _prompt_overwrite() approves.

        Args:
          content: If provided (and non-empty) will write this text to the
              file instead of calling self.content.
        """
        output_text = content if content else self.content

        if not output_text.endswith('\n'):
            output_text += '\n'

        if _prompt_overwrite(self._file, new_contents=output_text):
            _report_write_file(self._file)
            # _prompt_overwrite() decodes existing files as UTF-8, so write
            # with the same encoding instead of the locale default; otherwise
            # a file containing non-ASCII text might never compare identical.
            self._file.write_text(output_text, encoding='utf-8')

    def write_template(self, template_name: str, **template_args) -> None:
        """Renders a template and writes the result as the file's contents.

        Args:
          template_name: Name passed to get_template() to locate the template.
          **template_args: Variables passed to the template's render().
        """
        template = get_template(template_name)
        rendered_template = template.render(**template_args)
        self.write(content=rendered_template)

    class _IndentationContext:
        """Context that increases the output's indentation when it is active."""

        def __init__(self, output: _OutputFile, width: int):
            self._output = output
            self._width: int = width

        def __enter__(self):
            self._output._indentation += self._width

        def __exit__(self, typ, value, traceback):
            # Returns None, so exceptions raised in the body propagate.
            self._output._indentation -= self._width
253
254
class _ModuleName:
    """A module name split into its prefix and main part, plus its location.

    For example 'third_party/pw_foo_bar' has the prefix 'pw', the main part
    'foo_bar' and lives in the directory 'third_party'.
    """

    _MODULE_NAME_REGEX = re.compile(
        # The module prefix: two or more letters, e.g. 'pw'.
        r'^(?P<prefix>[a-zA-Z]{2,})'
        # The remainder of the name: one or more groups made of a single
        # underscore followed by alphanumeric characters. Consecutive
        # underscores and a trailing underscore are therefore rejected.
        r'(?P<main>'
        r'(_[a-zA-Z0-9]+)+'
        r')$'
    )

    def __init__(self, prefix: str, main: str, path: Path) -> None:
        """Stores the name parts.

        Args:
          prefix: The module prefix, e.g. 'pw'.
          main: The rest of the name; leading underscores are stripped.
          path: Directory containing the module, or Path('.') if none.
        """
        self._prefix = prefix
        # The regex's 'main' group still carries the separating underscore.
        self._main = main.lstrip('_')
        self._path = path

    @property
    def path(self) -> str:
        """The module's POSIX-style path, including any parent directories."""
        # parse() stores Path('.') when the name was given without any parent
        # directory; in that case the path is just the module name.
        if self._path != Path('.'):
            return (self._path / self.full).as_posix()
        return self.full

    @property
    def full(self) -> str:
        """The complete module name, e.g. 'pw_foo_bar'."""
        return '_'.join((self._prefix, self._main))

    @property
    def prefix(self) -> str:
        """The part of the name before the first underscore."""
        return self._prefix

    @property
    def main(self) -> str:
        """The part of the name after the prefix's underscore."""
        return self._main

    @property
    def default_namespace(self) -> str:
        """The C++ namespace used for the module, e.g. 'pw::foo_bar'."""
        return self._prefix + '::' + self._main

    def upper_camel_case(self) -> str:
        """Returns the main part in UpperCamelCase, e.g. 'FooBar'."""
        return ''.join(map(str.capitalize, self._main.split('_')))

    @property
    def header_line(self) -> str:
        """A row of '=' characters as long as the full module name."""
        return len(self.full) * '='

    def __str__(self) -> str:
        return self.full

    def __repr__(self) -> str:
        return self.full

    @classmethod
    def parse(cls, name: str) -> _ModuleName | None:
        """Parses a module name, optionally preceded by a directory path.

        Args:
          name: The module name or path, e.g. 'pw_foo' or 'dir/pw_foo'.

        Returns:
          The parsed name, or None if the final path component is not a
          valid module name.
        """
        location = Path(name)
        parsed = _ModuleName._MODULE_NAME_REGEX.fullmatch(location.name)
        if parsed is None:
            return None

        return cls(parsed.group('prefix'), parsed.group('main'), location.parent)
322
323
@dataclass
class _ModuleContext:
    """State shared by everything that generates files for one new module."""

    # Parsed name of the module being created.
    name: _ModuleName
    # Directory in which the module's files are generated.
    dir: Path
    # Build files located directly in the module directory.
    root_build_files: list[_BuildFile]
    # Additional build files that do not receive new targets or docs.
    sub_build_files: list[_BuildFile]
    # Names of the build systems to generate files for.
    build_systems: list[str]
    # Whether the module is part of Pigweed itself.
    is_upstream: bool

    def build_files(self) -> Iterable[_BuildFile]:
        """Yields every build file: the root ones first, then the others."""
        for root_file in self.root_build_files:
            yield root_file
        for sub_file in self.sub_build_files:
            yield sub_file

    def add_docs_file(self, file: Path):
        """Registers a docs source, relative to dir, in each root build file."""
        for root_file in self.root_build_files:
            root_file.add_docs_source(str(file.relative_to(self.dir)))

    def add_cc_target(self, target: _BuildFile.CcTarget) -> None:
        """Registers a C++ library target in each root build file."""
        for root_file in self.root_build_files:
            root_file.add_cc_target(target)

    def add_cc_test(self, target: _BuildFile.CcTarget) -> None:
        """Registers a C++ test target in each root build file."""
        for root_file in self.root_build_files:
            root_file.add_cc_test(target)
348
349
class _BuildFile(abc.ABC):
    """Abstract representation of a build file for a module.

    Collects the documentation sources, C++ library targets and C++ test
    targets of a module. Subclasses supply the build-system-specific
    formatting through the abstract _write_* hooks, so this class cannot be
    instantiated directly.
    """

    @dataclass
    class Target:
        """A build target with a name and a list of dependencies."""

        name: str

        # TODO(frolv): Shouldn't be a string list as that's build system
        # specific. Figure out a way to resolve dependencies from targets.
        deps: list[str] = dataclasses.field(default_factory=list)

    @dataclass
    class CcTarget(Target):
        """A C++ target made of source files and public headers."""

        sources: list[Path] = dataclasses.field(default_factory=list)
        headers: list[Path] = dataclasses.field(default_factory=list)

        def rebased_sources(self, rebase_path: Path) -> Iterable[str]:
            """Yields each source relative to rebase_path, '/'-separated."""
            # as_posix(): build files use forward slashes on every platform.
            return (
                src.relative_to(rebase_path).as_posix() for src in self.sources
            )

        def rebased_headers(self, rebase_path: Path) -> Iterable[str]:
            """Yields each header relative to rebase_path, '/'-separated."""
            return (
                hdr.relative_to(rebase_path).as_posix() for hdr in self.headers
            )

    def __init__(self, path: Path, ctx: _ModuleContext):
        """Creates an empty build file description.

        Args:
          path: Location of the build file on disk.
          ctx: Context of the module the build file belongs to.
        """
        self._path = path
        self._ctx = ctx

        self._docs_sources: list[str] = []
        self._cc_targets: list[_BuildFile.CcTarget] = []
        self._cc_tests: list[_BuildFile.CcTarget] = []

    @property
    def path(self) -> Path:
        """Location of the build file."""
        return self._path

    @property
    def dir(self) -> Path:
        """Directory containing the build file."""
        return self._path.parent

    def add_docs_source(self, filename: str) -> None:
        """Adds a documentation source file to the docs target."""
        self._docs_sources.append(filename)

    def add_cc_target(self, target: CcTarget) -> None:
        """Adds a C++ library target."""
        self._cc_targets.append(target)

    def add_cc_test(self, target: CcTarget) -> None:
        """Adds a C++ unit test target."""
        self._cc_tests.append(target)

    @property
    def get_license(self) -> str:
        """The license header for upstream modules, otherwise ''."""
        if self._ctx.is_upstream:
            return _PIGWEED_LICENSE
        return ''

    @property
    def docs_sources(self) -> list[str]:
        """Documentation source files added so far."""
        return self._docs_sources

    @property
    def cc_targets(self) -> list[_BuildFile.CcTarget]:
        """C++ library targets added so far."""
        return self._cc_targets

    @property
    def cc_tests(self) -> list[_BuildFile.CcTarget]:
        """C++ unit test targets added so far."""
        return self._cc_tests

    def relative_file(self, file_path: Path | str) -> str:
        """Returns file_path relative to the build file's directory.

        Strings are assumed to already be relative and are returned as-is.
        Paths are returned with forward slashes on every platform.
        """
        if isinstance(file_path, str):
            return file_path
        return file_path.relative_to(self._path.parent).as_posix()

    def write(self) -> None:
        """Writes the contents of the build file to disk."""
        file = _OutputFile(self._path, self._indent_width())

        if self._ctx.is_upstream:
            file.line(_PIGWEED_LICENSE)
            file.line()

        self._write_preamble(file)

        # Each section is separated from the previous one by a blank line.
        for target in self._cc_targets:
            file.line()
            self._write_cc_target(file, target)

        for target in self._cc_tests:
            file.line()
            self._write_cc_test(file, target)

        if self._docs_sources:
            file.line()
            self._write_docs_target(file, self._docs_sources)

        file.write()

    @abc.abstractmethod
    def _indent_width(self) -> int:
        """Returns the default indent width for the build file's code style."""

    @abc.abstractmethod
    def _write_preamble(self, file: _OutputFile) -> None:
        """Writes the content that precedes the build file's targets."""

    @abc.abstractmethod
    def _write_cc_target(
        self,
        file: _OutputFile,
        target: _BuildFile.CcTarget,
    ) -> None:
        """Defines a C++ library target within the build file."""

    @abc.abstractmethod
    def _write_cc_test(
        self,
        file: _OutputFile,
        target: _BuildFile.CcTarget,
    ) -> None:
        """Defines a C++ unit test target within the build file."""

    @abc.abstractmethod
    def _write_docs_target(
        self,
        file: _OutputFile,
        docs_sources: list[str],
    ) -> None:
        """Defines a documentation target within the build file."""
475
476
# TODO(frolv): The dict here should be dict[str, '_GnVal'] (i.e. _GnScope),
# but mypy does not yet support recursive types:
# https://github.com/python/mypy/issues/731
# A value that _GnBuildFile._format_gn_scope can write as a GN variable.
_GnVal = bool | int | str | list[str] | dict[str, Any]
# A GN scope: variable names mapped to their values. A target's body is
# formatted from one of these.
_GnScope = dict[str, _GnVal]
482
483
class _GnBuildFile(_BuildFile):
    """Generates a module's GN build file (BUILD.gn by default)."""

    _DEFAULT_FILENAME = 'BUILD.gn'
    # Name of the GN config target that adds 'public' to the include path.
    _INCLUDE_CONFIG_TARGET = 'public_include_path'

    def __init__(
        self,
        directory: Path,
        ctx: _ModuleContext,
        filename: str = _DEFAULT_FILENAME,
    ):
        """Creates a GN build file description.

        Args:
          directory: Directory in which the build file is created.
          ctx: Context of the module the build file belongs to.
          filename: Name of the build file within the directory.
        """
        super().__init__(directory / filename, ctx)

    def _indent_width(self) -> int:
        return 2

    def _write_preamble(self, file: _OutputFile) -> None:
        """Writes the imports, the include config and the test group."""
        # Upstream modules always require a tests target, even if it's empty.
        has_tests = bool(self._cc_tests) or self._ctx.is_upstream

        imports = []

        if self._cc_targets:
            imports.append('$dir_pw_build/target_types.gni')

        if has_tests:
            imports.append('$dir_pw_unit_test/test.gni')

        if self._docs_sources:
            imports.append('$dir_pw_docgen/docs.gni')

        # The overrides import comes first, separated from the sorted imports
        # by a blank line (the embedded '\n').
        file.line('import("//build_overrides/pigweed.gni")\n')
        for imp in sorted(imports):
            file.line(f'import("{imp}")')

        if self._cc_targets:
            file.line()
            _GnBuildFile._target(
                file,
                'config',
                _GnBuildFile._INCLUDE_CONFIG_TARGET,
                {
                    'include_dirs': ['public'],
                    'visibility': [':*'],
                },
            )

        if has_tests:
            file.line()
            _GnBuildFile._target(
                file,
                'pw_test_group',
                'tests',
                {
                    'tests': [f':{test.name}' for test in self._cc_tests],
                },
            )

    def _write_cc_target(
        self,
        file: _OutputFile,
        target: _BuildFile.CcTarget,
    ) -> None:
        """Defines a GN source_set for a C++ target."""

        target_vars: _GnScope = {}

        if target.headers:
            # Public headers are found through the shared include config.
            target_vars['public_configs'] = [
                f':{_GnBuildFile._INCLUDE_CONFIG_TARGET}'
            ]
            target_vars['public'] = list(target.rebased_headers(self.dir))

        if target.sources:
            target_vars['sources'] = list(target.rebased_sources(self.dir))

        if target.deps:
            target_vars['deps'] = target.deps

        _GnBuildFile._target(file, 'pw_source_set', target.name, target_vars)

    def _write_cc_test(
        self,
        file: _OutputFile,
        target: _BuildFile.CcTarget,
    ) -> None:
        """Defines a pw_test for a C++ unit test target."""
        _GnBuildFile._target(
            file,
            'pw_test',
            target.name,
            {
                'sources': list(target.rebased_sources(self.dir)),
                'deps': target.deps,
            },
        )

    def _write_docs_target(
        self,
        file: _OutputFile,
        docs_sources: list[str],
    ) -> None:
        """Defines a pw_doc_group for module documentation."""
        _GnBuildFile._target(
            file,
            'pw_doc_group',
            'docs',
            {
                'sources': docs_sources,
            },
        )

    @staticmethod
    def _target(
        file: _OutputFile,
        target_type: str,
        name: str,
        args: _GnScope,
    ) -> None:
        """Formats a GN target.

        Args:
          file: Output to append the target to.
          target_type: The GN template or target function, e.g. 'pw_test'.
          name: The target's name.
          args: Variables assigned within the target's body.
        """

        file.line(f'{target_type}("{name}") {{')

        with file.indent():
            _GnBuildFile._format_gn_scope(file, args)

        file.line('}')

    @staticmethod
    def _format_gn_scope(file: _OutputFile, scope: _GnScope) -> None:
        """Formats all of the variables within a GN scope to a file.

        This function does not write the enclosing braces of the outer scope to
        support use from multiple formatting contexts.

        Booleans are written as GN's lowercase true/false, strings through
        _gn_string(), nested dicts as nested scopes, and lists as lists of
        strings (sorted when they hold more than one element).
        """
        for key, val in scope.items():
            # bool must be tested before int: bool is a subclass of int, so
            # checking int first would write Python's "True"/"False", which
            # are not GN boolean literals.
            if isinstance(val, bool):
                file.line(f'{key} = {str(val).lower()}')
                continue

            if isinstance(val, int):
                file.line(f'{key} = {val}')
                continue

            if isinstance(val, str):
                file.line(f'{key} = {_GnBuildFile._gn_string(val)}')
                continue

            if isinstance(val, dict):
                file.line(f'{key} = {{')
                with file.indent():
                    _GnBuildFile._format_gn_scope(file, val)
                file.line('}')
                continue

            # Format a list of strings.
            # TODO(frolv): Lists of other types?
            assert isinstance(val, list)

            if not val:
                file.line(f'{key} = []')
                continue

            # A single element fits on one line.
            if len(val) == 1:
                file.line(f'{key} = [ {_GnBuildFile._gn_string(val[0])} ]')
                continue

            file.line(f'{key} = [')
            with file.indent():
                for string in sorted(val):
                    file.line(f'{_GnBuildFile._gn_string(string)},')
            file.line(']')

    @staticmethod
    def _gn_string(string: str) -> str:
        """Converts a Python string into a string literal within a GN file.

        Accounts for the possibility of variable interpolation within GN,
        removing quotes if unnecessary:

            "string"           ->  "string"
            "$var"             ->  var
            "$var2"            ->  var2
            "$3var"            ->  "$3var"
            "$dir_pw_foo"      ->  dir_pw_foo
            "$dir_pw_foo:bar"  ->  "$dir_pw_foo:bar"
            "$dir_pw_foo/baz"  ->  "$dir_pw_foo/baz"
            "${dir_pw_foo}"    ->  dir_pw_foo

        """

        # Check if the entire string refers to a interpolated variable.
        #
        # Simple case: '$' followed a single word, e.g. "$my_variable".
        # Note that identifiers can't start with a number.
        if re.fullmatch(r'^\$[a-zA-Z_]\w*$', string):
            return string[1:]

        # GN permits wrapping an interpolated variable in braces.
        # Check for strings of the format "${my_variable}".
        if re.fullmatch(r'^\$\{[a-zA-Z_]\w*\}$', string):
            return string[2:-1]

        return f'"{string}"'
687
688
class _BazelBuildFile(_BuildFile):
    """Bazel build file for a module, rendered from a Jinja template."""

    _DEFAULT_FILENAME = 'BUILD.bazel'

    def __init__(
        self,
        directory: Path,
        ctx: _ModuleContext,
        filename: str = _DEFAULT_FILENAME,
    ):
        build_file_path = directory / filename
        super().__init__(build_file_path, ctx)

    def write(self) -> None:
        """Renders the BUILD.bazel.jinja template and writes it to disk."""
        _OutputFile(self._path).write_template(
            'BUILD.bazel.jinja', build=self, module=self._ctx
        )

    def _indent_width(self) -> int:
        """Returns the indent width for Bazel files."""
        return 4

    # TODO(tonymd): Remove these functions once all file types are created with
    # templates.
    def _write_preamble(self, file: _OutputFile) -> None:
        """Does nothing; write() renders the whole file from the template."""

    def _write_cc_target(
        self,
        file: _OutputFile,
        target: _BuildFile.CcTarget,
    ) -> None:
        """Does nothing; write() renders the whole file from the template."""

    def _write_cc_test(
        self,
        file: _OutputFile,
        target: _BuildFile.CcTarget,
    ) -> None:
        """Does nothing; write() renders the whole file from the template."""

    def _write_docs_target(
        self,
        file: _OutputFile,
        docs_sources: list[str],
    ) -> None:
        """Does nothing; write() renders the whole file from the template."""
733
734
class _CmakeBuildFile(_BuildFile):
    """CMake build file for a module, rendered from a Jinja template."""

    _DEFAULT_FILENAME = 'CMakeLists.txt'

    def __init__(
        self,
        directory: Path,
        ctx: _ModuleContext,
        filename: str = _DEFAULT_FILENAME,
    ):
        build_file_path = directory / filename
        super().__init__(build_file_path, ctx)

    def write(self) -> None:
        """Renders the CMakeLists.txt.jinja template and writes it to disk."""
        _OutputFile(self._path).write_template(
            'CMakeLists.txt.jinja', build=self, module=self._ctx
        )

    def _indent_width(self) -> int:
        """Returns the indent width for CMake files."""
        return 2

    # TODO(tonymd): Remove these functions once all file types are created with
    # templates.
    def _write_preamble(self, file: _OutputFile) -> None:
        """Does nothing; write() renders the whole file from the template."""

    def _write_cc_target(
        self,
        file: _OutputFile,
        target: _BuildFile.CcTarget,
    ) -> None:
        """Does nothing; write() renders the whole file from the template."""

    def _write_cc_test(
        self,
        file: _OutputFile,
        target: _BuildFile.CcTarget,
    ) -> None:
        """Does nothing; write() renders the whole file from the template."""

    def _write_docs_target(
        self,
        file: _OutputFile,
        docs_sources: list[str],
    ) -> None:
        """Does nothing; write() renders the whole file from the template."""
781
782
class _LanguageGenerator(abc.ABC):
    """Generates files for a programming language in a new Pigweed module.

    This is an abstract base class: subclasses must implement
    create_source_files(), and the base class cannot be instantiated.
    """

    def __init__(self, ctx: _ModuleContext) -> None:
        """Stores the context of the module to generate files for."""
        self._ctx = ctx

    @abc.abstractmethod
    def create_source_files(self) -> None:
        """Creates the boilerplate source files required by the language."""
792
793
class _CcLanguageGenerator(_LanguageGenerator):
    """Generates boilerplate source files for a C++ module.

    Produces a public header, a source file and a unit test, and registers
    the matching library and test targets with the module's build files.
    """

    def __init__(self, ctx: _ModuleContext) -> None:
        super().__init__(ctx)

        # Public headers live in <module dir>/public/<full module name>/.
        self._public_dir = ctx.dir / 'public'
        self._headers_dir = self._public_dir / ctx.name.full

    def create_source_files(self) -> None:
        """Creates the header, source and test files and their targets."""
        self._headers_dir.mkdir(parents=True, exist_ok=True)

        main_name = self._ctx.name.main
        main_header = self._new_header(main_name)
        main_source = self._new_source(main_name)
        test_source = self._new_source(f'{main_name}_test')

        # TODO(frolv): This could be configurable.
        namespace = self._ctx.name.default_namespace

        # The header is included relative to the public include directory.
        # as_posix() keeps the #include path '/'-separated on every platform;
        # str() would produce backslashes on Windows.
        include_path = main_header.path.relative_to(
            self._public_dir
        ).as_posix()

        # A trailing '\n' in a line leaves a blank line after it.
        main_source.line(f'#include "{include_path}"\n')
        main_source.line(f'namespace {namespace} {{\n')
        main_source.line('int magic = 42;\n')
        main_source.line(f'}}  // namespace {namespace}')

        main_header.line(f'namespace {namespace} {{\n')
        main_header.line('extern int magic;\n')
        main_header.line(f'}}  // namespace {namespace}')

        test_source.line(f'#include "{include_path}"\n')
        test_source.line('#include "pw_unit_test/framework.h"\n')
        test_source.line(f'namespace {namespace} {{')
        test_source.line('namespace {\n')
        test_source.line(
            f'TEST({self._ctx.name.upper_camel_case()}, GeneratesCorrectly) {{'
        )
        with test_source.indent():
            test_source.line('EXPECT_EQ(magic, 42);')
        test_source.line('}\n')
        test_source.line('}  // namespace')
        test_source.line(f'}}  // namespace {namespace}')

        self._ctx.add_cc_target(
            _BuildFile.CcTarget(
                name=self._ctx.name.full,
                sources=[main_source.path],
                headers=[main_header.path],
            )
        )

        # The test depends on the library target defined just above.
        self._ctx.add_cc_test(
            _BuildFile.CcTarget(
                name=f'{main_name}_test',
                deps=[f':{self._ctx.name.full}'],
                sources=[test_source.path],
            )
        )

        main_header.write()
        main_source.write()
        test_source.write()

    def _new_source(self, name: str) -> _OutputFile:
        """Starts <name>.cc in the module directory.

        Upstream modules get the license header followed by a blank line.
        """
        file = _OutputFile(self._ctx.dir / f'{name}.cc')

        if self._ctx.is_upstream:
            file.line(_PIGWEED_LICENSE_CC)
            file.line()

        return file

    def _new_header(self, name: str) -> _OutputFile:
        """Starts <name>.h in the public headers directory.

        Upstream modules get the license header, immediately followed by
        '#pragma once' and a blank line.
        """
        file = _OutputFile(self._headers_dir / f'{name}.h')

        if self._ctx.is_upstream:
            file.line(_PIGWEED_LICENSE_CC)

        file.line('#pragma once\n')
        return file
876
877
# Maps a build system name to the class that generates its build file.
# _basic_module_setup() looks up every requested build system here.
_BUILD_FILES: dict[str, Type[_BuildFile]] = {
    'bazel': _BazelBuildFile,
    'cmake': _CmakeBuildFile,
    'gn': _GnBuildFile,
}

# Maps a language name to the class that generates its boilerplate sources.
_LANGUAGE_GENERATORS: dict[str, Type[_LanguageGenerator]] = {
    'cc': _CcLanguageGenerator,
}
887
888
def _check_module_name(module: str, is_upstream: bool) -> _ModuleName | None:
    """Parses and validates a module name, logging an error if it is invalid.

    Args:
      module: The module name or path as entered by the user.
      is_upstream: Whether the module is created within Pigweed itself, in
          which case its name must use the 'pw' prefix.

    Returns:
      The parsed module name, or None if the name is not acceptable.
    """
    parsed = _ModuleName.parse(module)

    if parsed is None:
        _LOG.error(
            '"%s" does not conform to the Pigweed module name format', module
        )
    elif is_upstream and parsed.prefix != 'pw':
        _LOG.error('Modules within Pigweed itself must start with "pw_"')
    else:
        return parsed

    return None
907
908
def _create_main_docs_file(ctx: _ModuleContext) -> None:
    """Generates the module's top-level docs.rst from its template.

    The file is also registered as a docs source with the module's root
    build files.

    Args:
      ctx: Context of the module being created.
    """
    rendered_docs = get_template('docs.rst.jinja').render(module=ctx)

    docs_output = _OutputFile(ctx.dir / 'docs.rst')
    ctx.add_docs_file(docs_output.path)
    docs_output.write(content=rendered_docs)
918
919
def _basic_module_setup(
    module_name: _ModuleName,
    module_dir: Path,
    build_systems: Iterable[str],
    is_upstream: bool,
) -> _ModuleContext:
    """Creates a module's directories, context, build files and docs.rst.

    Args:
      module_name: Parsed name of the new module.
      module_dir: Directory in which the module is created.
      build_systems: Names of the build systems to generate files for; each
          must be a key of _BUILD_FILES.
      is_upstream: Whether the module is part of Pigweed itself.

    Returns:
      The context describing the new module.
    """
    module_dir.mkdir(parents=True, exist_ok=True)
    # Public headers go in public/<full module name>/.
    headers_dir = module_dir / 'public' / module_name.full
    headers_dir.mkdir(parents=True, exist_ok=True)

    ctx = _ModuleContext(
        name=module_name,
        dir=module_dir,
        root_build_files=[],
        sub_build_files=[],
        build_systems=list(build_systems),
        is_upstream=is_upstream,
    )

    # The build files need the context, so they are added after it exists.
    for build_system in ctx.build_systems:
        build_file_type = _BUILD_FILES[build_system]
        ctx.root_build_files.append(build_file_type(module_dir, ctx))

    _create_main_docs_file(ctx)

    return ctx
947
948
def _add_to_module_metadata(
    project_root: Path,
    module_name: _ModuleName,
    languages: Iterable[str] | None = None,
) -> None:
    """Update sphinx module metadata.

    Adds an 'experimental' entry for the module to docs/module_metadata.json
    unless one already exists, and rewrites the file with its entries sorted
    by module name. The write goes through _prompt_overwrite(), so the user
    is asked before the existing file is replaced.

    Args:
      project_root: Root directory containing docs/module_metadata.json.
      module_name: Name of the module to add.
      languages: Languages the module was generated for. Only 'cc' maps to
          a language tag ('C++'); other values are ignored.

    Raises:
      FileNotFoundError: If the metadata file does not exist.
    """
    module_metadata_file = project_root / 'docs/module_metadata.json'
    # Read and write explicitly as UTF-8, matching how _prompt_overwrite()
    # decodes this same file, instead of relying on the locale's encoding.
    metadata_dict = json.loads(
        module_metadata_file.read_text(encoding='utf-8')
    )

    language_tags = ['C++' for lang in languages or () if lang == 'cc']

    # Add the new entry if it doesn't exist; existing entries are kept as-is.
    metadata_dict.setdefault(
        module_name.full,
        dict(
            status='experimental',
            languages=language_tags,
        ),
    )

    # Sort by module name. Only the top level is sorted, so the key order
    # within each entry is preserved (hence sort_keys=False below).
    sorted_metadata = {
        name: metadata_dict[name] for name in sorted(metadata_dict)
    }
    output_text = json.dumps(sorted_metadata, sort_keys=False, indent=2)
    output_text += '\n'

    # Write the file.
    if _prompt_overwrite(module_metadata_file, new_contents=output_text):
        _report_write_file(module_metadata_file)
        module_metadata_file.write_text(output_text, encoding='utf-8')
982
983
def _add_to_pigweed_modules_file(
    project_root: Path,
    module_name: _ModuleName,
) -> None:
    """Adds the module to PIGWEED_MODULES and regenerates the GN module list.

    Logs an error and returns if PIGWEED_MODULES cannot be found. If the
    module's path is already listed, the file is reported as unchanged and
    nothing is regenerated.

    Args:
        project_root: Directory that contains the PIGWEED_MODULES file.
        module_name: Name of the module whose path should be listed.
    """
    modules_file = project_root / 'PIGWEED_MODULES'
    if not modules_file.exists():
        _LOG.error(
            'Could not locate PIGWEED_MODULES file; '
            'your repository may be in a bad state.'
        )
        return

    existing_entries = modules_file.read_text().splitlines()
    new_entry = module_name.path
    if new_entry in existing_entries:
        _report_unchanged_file(modules_file)
        return

    # splitlines() dropped the line terminators, so rejoin the sorted entries
    # and restore the single trailing newline.
    updated_entries = sorted([*existing_entries, new_entry])
    modules_file.write_text('\n'.join(updated_entries) + '\n')
    _report_edited_file(modules_file)

    gni_file = (
        project_root / 'pw_build' / 'generated_pigweed_modules_lists.gni'
    )
    generate_modules_lists.main(
        root=project_root,
        modules_list=modules_file,
        modules_gni_file=gni_file,
        mode=generate_modules_lists.Mode.UPDATE,
    )
    _report_edited_file(gni_file)
1018
1019
def _add_to_root_cmakelists(
    project_root: Path,
    module_name: _ModuleName,
) -> None:
    """Adds an `add_subdirectory` entry for the module to the root CMakeLists.

    The entry is inserted into the first run of consecutive
    `add_subdirectory` lines, ahead of the first line in that run that does
    not compare less than it. If the file contains no `add_subdirectory`
    line at all, the entry is appended to the end of the file.

    Logs an error and returns if the root CMakeLists.txt does not exist, and
    reports the file as unchanged if the entry is already present.

    Args:
        project_root: Directory that contains the root CMakeLists.txt.
        module_name: Name of the module whose path should be added.
    """
    new_entry = f'add_subdirectory({module_name.path} EXCLUDE_FROM_ALL)'
    new_line = new_entry + '\n'

    path = project_root / 'CMakeLists.txt'
    if not path.exists():
        _LOG.error('Could not locate root CMakeLists.txt file.')
        return

    lines = path.read_text().splitlines(keepends=True)

    # Compare without line terminators so that an existing entry on an
    # unterminated final line is still recognized.
    if any(line.rstrip('\r\n') == new_entry for line in lines):
        _report_unchanged_file(path)
        return

    line_count = len(lines)

    # Index of the first add_subdirectory line, or the end of the file if
    # there is none.
    insert_point = next(
        (
            index
            for index, line in enumerate(lines)
            if line.startswith('add_subdirectory')
        ),
        line_count,
    )

    # Walk forward through the run of add_subdirectory lines. The bounds
    # check matters when the run extends to the last line of the file (or
    # when there is no run at all).
    while (
        insert_point < line_count
        and lines[insert_point].startswith('add_subdirectory')
        and lines[insert_point] < new_line
    ):
        insert_point += 1

    # When appending after an unterminated final line, terminate it first so
    # the new entry starts on its own line.
    if insert_point == line_count and lines and not lines[-1].endswith('\n'):
        lines[-1] += '\n'

    lines.insert(insert_point, new_line)
    path.write_text(''.join(lines))
    _report_edited_file(path)
1052
1053
def _project_root() -> Path:
    """Returns the current project's root directory.

    The path is read from the environment's `PW_PROJECT_ROOT`. If it does
    not refer to a directory, an error is logged and the process exits with
    status 1.
    """
    root = _PW_ENV.PW_PROJECT_ROOT
    if root.is_dir():
        return root

    _LOG.error(
        'Expected env var $PW_PROJECT_ROOT to point to a directory, but '
        'found `%s` which is not a directory.',
        root,
    )
    sys.exit(1)
1065
1066
def _is_upstream() -> bool:
    """Returns True when the current project root is the Pigweed root."""
    current_root = _project_root()
    return _PW_ROOT == current_root
1070
1071
1072_COMMENTS = re.compile(r'\w*#.*$')
1073
1074
1075def _read_root_owners(project_root: Path) -> Iterable[str]:
1076    for line in (project_root / 'OWNERS').read_text().splitlines():
1077        line = _COMMENTS.sub('', line).strip()
1078        if line:
1079            yield line
1080
1081
def _create_module(
    module: str,
    languages: Iterable[str],
    build_systems: Iterable[str],
    owners: Collection[str] | None = None,
) -> None:
    """Creates a new module in the current project.

    Validates the module name and owners, lays out the module directory,
    writes the OWNERS file, language source files and build files, and then
    registers the module in the project-level module list, docs metadata and
    (when CMake is among the build systems) the root CMakeLists.txt.

    Any validation failure is logged and exits the process with status 1.

    Args:
        module: Name of the module; also used as its directory name under
            the project root.
        languages: Languages to generate source files for. Each must be a
            key of `_LANGUAGE_GENERATORS`.
        build_systems: Build systems to generate build files for.
        owners: Email addresses of the module owners. When provided, there
            must be at least two, and at least one must be listed in the
            project's root OWNERS file.
    """
    # Both arguments are consulted more than once below (including a
    # membership test and a hand-off to another helper), so snapshot them
    # up front; a one-shot iterator would otherwise be exhausted after its
    # first use.
    languages = list(languages)
    build_systems = list(build_systems)

    project_root = _project_root()
    is_upstream = _is_upstream()

    module_name = _check_module_name(module, is_upstream)
    if not module_name:
        sys.exit(1)

    if not is_upstream:
        _LOG.error(
            '`pw module create` is experimental and does '
            'not yet support downstream projects.'
        )
        sys.exit(1)

    module_dir = project_root / module

    if module_dir.is_dir():
        _REPORT.wrn(f'Directory {module} already exists.')
        if _prompt_user('Continue?') == PromptChoice.NO:
            sys.exit(1)

    if module_dir.is_file():
        _LOG.error(
            'Cannot create module %s as a file of that name already exists',
            module,
        )
        sys.exit(1)

    if owners is not None:
        if len(owners) < 2:
            _LOG.error(
                'New modules must have at least two owners, but only `%s` was '
                'provided.',
                owners,
            )
            sys.exit(1)
        for owner in owners:
            if '@' not in owner:
                _LOG.error(
                    'Owners should be email addresses, but found `%s`', owner
                )
                sys.exit(1)
        root_owners = list(_read_root_owners(project_root))
        if not any(owner in root_owners for owner in owners):
            root_owners_str = '\n'.join(root_owners)
            _LOG.error(
                'Module owners must include at least one root owner, but only '
                '`%s` was provided. Root owners include:\n%s',
                owners,
                root_owners_str,
            )
            sys.exit(1)

    ctx = _basic_module_setup(
        module_name, module_dir, build_systems, is_upstream
    )

    if owners is not None:
        owners_file = module_dir / 'OWNERS'
        owners_text = '\n'.join(sorted(owners))
        owners_text += '\n'
        if _prompt_overwrite(owners_file, new_contents=owners_text):
            _report_write_file(owners_file)
            owners_file.write_text(owners_text)

    try:
        generators = [_LANGUAGE_GENERATORS[lang](ctx) for lang in languages]
    except KeyError as key:
        _LOG.error('Unsupported language: %s', key)
        sys.exit(1)

    for generator in generators:
        generator.create_source_files()

    for build_file in ctx.build_files():
        build_file.write()

    if is_upstream:
        _add_to_pigweed_modules_file(project_root, module_name)
        _add_to_module_metadata(project_root, module_name, languages)
        if 'cmake' in build_systems:
            _add_to_root_cmakelists(project_root, module_name)

    print()
    _REPORT.new(f'{module_name} created at: {module_dir.relative_to(_PW_ROOT)}')
1173
1174
def register_subcommand(parser: argparse.ArgumentParser) -> None:
    """Registers the module `create` subcommand with `parser`.

    Adds the `--build-systems` and `--languages` options and the positional
    module name. The required `--owners` option is only added when running
    within upstream Pigweed. The parser's `func` default is set to
    `_create_module`.
    """

    def csv(s: str) -> list[str]:
        """Splits a comma-separated argument into its items."""
        return s.split(",")

    def csv_with_choices(choices: Collection[str], string: str) -> list[str]:
        """Splits a comma-separated argument, rejecting unknown items."""
        chosen_items = string.split(',')
        invalid_items = set(chosen_items) - set(choices)
        if invalid_items:
            # Sort so the error message is the same from run to run; set
            # iteration order is not stable for strings.
            raise argparse.ArgumentTypeError(
                '\n'
                f'  invalid items: [ {", ".join(sorted(invalid_items))} ].\n'
                f'  choose from: [ {", ".join(choices)} ]'
            )

        return chosen_items

    parser.add_argument(
        '--build-systems',
        help=(
            'Comma-separated list of build systems the module supports. '
            f'Options: {", ".join(_BUILD_FILES.keys())}'
        ),
        # A list, like the values produced by `type` below, rather than a
        # live view of the dictionary's keys.
        default=list(_BUILD_FILES.keys()),
        type=functools.partial(csv_with_choices, _BUILD_FILES.keys()),
    )
    parser.add_argument(
        '--languages',
        help=(
            'Comma-separated list of languages the module will use. '
            f'Options: {", ".join(_LANGUAGE_GENERATORS.keys())}'
        ),
        default=[],
        type=functools.partial(csv_with_choices, _LANGUAGE_GENERATORS.keys()),
    )
    if _is_upstream():
        parser.add_argument(
            '--owners',
            help=(
                'Comma-separated list of emails of the people who will own and '
                'maintain the new module. This list must contain at least two '
                'entries, and at least one user must be a top-level OWNER '
                f'(listed in `{_project_root()}/OWNERS`).'
            ),
            required=True,
            # Plain example addresses; this text is shown verbatim in the
            # usage and help output.
            metavar='alice@example.com,bob@example.com',
            type=csv,
        )
    parser.add_argument(
        'module', help='Name of the module to create.', metavar='MODULE_NAME'
    )
    parser.set_defaults(func=_create_module)
1228