xref: /aosp_15_r20/external/pigweed/pw_config_loader/py/pw_config_loader/yaml_config_loader_mixin.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1# Copyright 2022 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Yaml config file loader mixin."""
15
16import enum
17import os
18import logging
19from pathlib import Path
20from typing import Any, Sequence
21
22import yaml
23
# Logger registered under this package's name (``__package__``).
_LOG = logging.getLogger(name=__package__)
25
26
class MissingConfigTitle(Exception):
    """Raised when an existing config file lacks the expected section.

    ``YamlConfigLoaderMixin.load_config_file`` raises this for a document
    that neither nests settings under the section title nor declares a
    matching ``config_title``, unless ``skip_files_without_sections`` is set.
    """
29
30
@enum.unique
class Stage(enum.Enum):
    """Identifies which config source is currently being applied.

    ``config_init`` loads the file-based stages in ascending value order
    (project file, project user file, user file, then the environment
    variable file). The stage is passed to ``handle_overloaded_value`` so
    subclasses can tailor how a later source overrides an earlier one.
    ``OUT_OF_BAND`` is the default for direct ``load_config_file`` calls.
    """

    # Every member needs a distinct value. Enum turns a repeated value into
    # an alias, and PROJECT_FILE used to share ``1`` with DEFAULT. That made
    # project-file values look like defaults, so they skipped
    # ``handle_overloaded_value``. ``@enum.unique`` prevents a regression.
    DEFAULT = 0
    PROJECT_FILE = 1
    USER_PROJECT_FILE = 2
    USER_FILE = 3
    ENVIRONMENT_VAR_FILE = 4
    OUT_OF_BAND = 5
38
39
class YamlConfigLoaderMixin:
    """Yaml Config file loader mixin.

    Use this mixin to load yaml file settings and save them into
    ``self._config``. For example:

    ::

       class ConsolePrefs(YamlConfigLoaderMixin):
           def __init__(self) -> None:
               self.config_init(
                   config_section_title='pw_console',
                   project_file=Path('project_file.yaml'),
                   project_user_file=Path('project_user_file.yaml'),
                   user_file=Path('~/user_file.yaml'),
                   default_config={},
                   environment_var='PW_CONSOLE_CONFIG_FILE',
               )

    Subclasses can override ``handle_overloaded_value`` to control how a
    value from a later source is combined with one loaded earlier.
    """

    def config_init(
        self,
        config_section_title: str | Sequence[str],
        project_file: Path | bool | None = None,
        project_user_file: Path | bool | None = None,
        user_file: Path | bool | None = None,
        default_config: dict[Any, Any] | None = None,
        environment_var: str | None = None,
        skip_files_without_sections: bool = False,
    ) -> None:
        """Call this to load YAML config files in order of precedence.

        The following files are loaded in this order. Each file's keys
        override those set by the files before it:

        1. project_file
        2. project_user_file
        3. user_file

        Lastly, if ``os.environ[environment_var]`` is set to a non-empty
        value, the config is reset to ``default_config`` and that file is
        loaded instead, overriding all config options from the files above.

        ``~`` and ``$VARS`` are expanded in every file path, including the
        one read from ``environment_var``.

        Args:
            config_section_title: String name of this config section. For
                example: ``pw_console`` or ``pw_watch``. In the YAML file this
                is represented by a ``config_title`` key.

                ::

                   ---
                   config_title: pw_console

                Alternatively, the settings may be nested under a top-level
                key of that name. A list or tuple of strings names a nested
                section: ``('a', 'b')`` matches ``a: {b: ...}`` or
                ``config_title: a.b``.
            project_file: Project level config file. This is intended to be a
                file living somewhere under a project folder and is checked into
                the repo. It serves as a base config all developers can inherit
                from. Ignored unless it is a ``Path``.
            project_user_file: User's personal config file for a specific
                project. This can be a file that lives in a project folder that
                is git-ignored and not checked into the repo. Ignored unless it
                is a ``Path``.
            user_file: A global user based config file. This is typically a file
                in the users home directory and settings here apply to all
                projects. Ignored unless it is a ``Path``.
            default_config: A Python dict representing the base default
                config. This dict will be applied as a starting point before
                loading any yaml files.
            environment_var: The name of an environment variable to check for a
                config file. If the variable is set, that file is loaded on top
                of the default_config, ignoring project and user files.
            skip_files_without_sections: Don't produce an exception if a
                config file doesn't include the relevant section. Instead, just
                move on to the next file.

        Raises:
            TypeError: If ``config_section_title`` is not a str, list or
                tuple.
            FileNotFoundError: If ``environment_var`` is set but does not
                name an existing file.
            MissingConfigTitle: If a file lacks the relevant section and
                ``skip_files_without_sections`` is False.
        """

        self._config_section_title: tuple[str, ...]
        if isinstance(config_section_title, (list, tuple)):
            self._config_section_title = tuple(config_section_title)
        elif isinstance(config_section_title, str):
            self._config_section_title = (config_section_title,)
        else:
            raise TypeError(
                f'unexpected config section title {config_section_title!r}'
            )
        self.default_config = default_config if default_config else {}
        self.reset_config()

        # A Path object is always truthy, so the isinstance() check alone
        # decides whether each file was requested (bool/None skip it).
        if isinstance(project_file, Path):
            self.project_file = self._expand_config_path(project_file)
            self.load_config_file(
                self.project_file,
                skip_files_without_sections=skip_files_without_sections,
                stage=Stage.PROJECT_FILE,
            )

        if isinstance(project_user_file, Path):
            self.project_user_file = self._expand_config_path(
                project_user_file
            )
            self.load_config_file(
                self.project_user_file,
                skip_files_without_sections=skip_files_without_sections,
                stage=Stage.USER_PROJECT_FILE,
            )

        if isinstance(user_file, Path):
            self.user_file = self._expand_config_path(user_file)
            self.load_config_file(
                self.user_file,
                skip_files_without_sections=skip_files_without_sections,
                stage=Stage.USER_FILE,
            )

        # Check for a config file specified by an environment variable.
        if environment_var is None:
            return
        environment_config = os.environ.get(environment_var)
        if not environment_config:
            return
        env_file_path = self._expand_config_path(Path(environment_config))
        if not env_file_path.is_file():
            raise FileNotFoundError(
                f'Cannot load config file: {env_file_path}'
            )
        # The environment file replaces the project and user files rather
        # than layering on top of them, so start over from the defaults.
        self.reset_config()
        self.load_config_file(
            env_file_path,
            skip_files_without_sections=skip_files_without_sections,
            stage=Stage.ENVIRONMENT_VAR_FILE,
        )

    @staticmethod
    def _expand_config_path(path: Path) -> Path:
        """Return ``path`` with ``~`` and ``$VARS`` expanded."""
        return Path(os.path.expandvars(str(path.expanduser())))

    def _update_config(
        self, cfg: dict[Any, Any] | None, stage: Stage
    ) -> None:
        """Merge the top-level keys of ``cfg`` into ``self._config``.

        ``DEFAULT`` values are stored as-is. Values from any other stage go
        through ``handle_overloaded_value``. ``None`` (e.g. an empty section)
        is treated as an empty mapping.
        """
        if cfg is None:
            return
        for key, value in cfg.items():
            if stage is Stage.DEFAULT:
                self._config[key] = value
            else:
                self._config[key] = self.handle_overloaded_value(
                    key=key,
                    stage=stage,
                    original_value=self._config.get(key),
                    overriding_value=value,
                )

    def handle_overloaded_value(  # pylint: disable=no-self-use
        self,
        key: str,  # pylint: disable=unused-argument
        stage: Stage,  # pylint: disable=unused-argument
        original_value: Any,  # pylint: disable=unused-argument
        overriding_value: Any,
    ) -> Any:
        """Overload this in subclasses to handle overloaded values.

        Called for every key loaded from a non-``DEFAULT`` stage. The base
        implementation lets the newer value win.

        Args:
            key: Top-level config key being set.
            stage: Source the new value comes from.
            original_value: Current value for ``key``, or ``None`` if unset.
            overriding_value: Value just loaded for ``key``.

        Returns:
            The value to store under ``key``.
        """
        return overriding_value

    def reset_config(self) -> None:
        """Drop all loaded settings and restart from ``default_config``."""
        self._config: dict[Any, Any] = {}
        self._update_config(self.default_config, Stage.DEFAULT)

    def _load_config_from_string(  # pylint: disable=no-self-use
        self, file_contents: str
    ) -> list[dict[Any, Any]]:
        """Parse every YAML document in ``file_contents``.

        ``yaml.safe_load_all`` builds only plain Python objects, so it is
        safe on untrusted input. Empty documents come back as ``None``.
        """
        return list(yaml.safe_load_all(file_contents))

    def _find_config_section(self, document: Any) -> tuple[bool, Any]:
        """Follow the section title keys down into ``document``.

        Returns:
            ``(True, section)`` when every title key is present along a chain
            of dicts (``section`` itself may be ``None`` for an empty
            section), otherwise ``(False, None)``.
        """
        section = document
        for title in self._config_section_title:
            # Only descend into mappings: ``in`` on a string is a substring
            # test, and on None it raises TypeError.
            if not isinstance(section, dict) or title not in section:
                return False, None
            section = section[title]
        return True, section

    def load_config_file(
        self,
        file_path: Path,
        skip_files_without_sections: bool = False,
        stage: Stage = Stage.OUT_OF_BAND,
    ) -> None:
        """Load a config file and extract the appropriate section.

        A missing file is silently ignored. Each document in the file
        contributes settings if it either:

        * nests them under the section title key(s), e.g. ``pw_console:``
          (or ``a: {b: ...}`` for a title of ``('a', 'b')``). In this case
          only that nested mapping is merged; or
        * has a top-level ``config_title`` equal to the dot-joined section
          title. In this case the whole document, including the
          ``config_title`` key, is merged.

        Empty documents (for example a trailing ``---``) are skipped.

        Args:
            file_path: File to read; decoded as UTF-8.
            skip_files_without_sections: Ignore documents matching neither
                form instead of raising.
            stage: Source stage passed on to ``handle_overloaded_value``.

        Raises:
            MissingConfigTitle: If a non-empty document matches neither form
                and ``skip_files_without_sections`` is False.
        """
        if not file_path.is_file():
            return

        config_title_value = '.'.join(self._config_section_title)
        # Decode explicitly: the platform default encoding is
        # locale-dependent (e.g. cp1252 on Windows).
        documents = self._load_config_from_string(
            file_path.read_text(encoding='utf-8')
        )

        for document in documents:
            # An empty YAML document parses to None and holds no settings.
            if document is None:
                continue

            found, section = self._find_config_section(document)
            if found:
                self._update_config(section, stage)
                continue

            if (
                isinstance(document, dict)
                and document.get('config_title', False) == config_title_value
            ):
                self._update_config(document, stage)
                continue

            if not skip_files_without_sections:
                raise MissingConfigTitle(
                    f'\n\nThe config file "{file_path}" is missing the '
                    f'expected "config_title: {config_title_value}" '
                    'setting.'
                )
240                )
241