# Copyright 2022 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Yaml config file loader mixin."""

import enum
import os
import logging
from pathlib import Path
from typing import Any, Sequence

import yaml

_LOG = logging.getLogger(__package__)


class MissingConfigTitle(Exception):
    """Exception for when an existing YAML file is missing config_title."""


class Stage(enum.Enum):
    """Identifies which config source a set of values is being loaded from.

    Passed to ``YamlConfigLoaderMixin.handle_overloaded_value`` so subclasses
    can decide how a value from one source overrides an earlier one.

    Every member must have a distinct value: ``enum.Enum`` turns a member
    whose value repeats an earlier member's into an alias of that member,
    which would make the two stages indistinguishable.
    """

    DEFAULT = 0
    PROJECT_FILE = 1
    USER_PROJECT_FILE = 2
    USER_FILE = 3
    ENVIRONMENT_VAR_FILE = 4
    OUT_OF_BAND = 5


class YamlConfigLoaderMixin:
    """Yaml Config file loader mixin.

    Use this mixin to load yaml file settings and save them into
    ``self._config``. For example:

    ::

        class ConsolePrefs(YamlConfigLoaderMixin):
            def __init__(self) -> None:
                self.config_init(
                    config_section_title='pw_console',
                    project_file=Path('project_file.yaml'),
                    project_user_file=Path('project_user_file.yaml'),
                    user_file=Path('~/user_file.yaml'),
                    default_config={},
                    environment_var='PW_CONSOLE_CONFIG_FILE',
                )

    """

    def config_init(
        self,
        config_section_title: str | Sequence[str],
        project_file: Path | bool | None = None,
        project_user_file: Path | bool | None = None,
        user_file: Path | bool | None = None,
        default_config: dict[Any, Any] | None = None,
        environment_var: str | None = None,
        skip_files_without_sections: bool = False,
    ) -> None:
        """Call this to load YAML config files in order of precedence.

        The following files are loaded in this order, each one overriding the
        top-level keys set by the ones before it:

        1. project_file
        2. project_user_file
        3. user_file

        Files that do not exist are skipped. Lastly, if a valid file path is
        specified at ``os.environ[environment_var]`` then load that file
        overriding all config options.

        Args:
            config_section_title: String name of this config section. For
                example: ``pw_console`` or ``pw_watch``. In the YAML file this
                is represented by a ``config_title`` key. A list or tuple of
                strings names a nested section; the equivalent
                ``config_title`` is the titles joined with ``.``.

                ::

                    ---
                    config_title: pw_console

            project_file: Project level config file. This is intended to be a
                file living somewhere under a project folder and is checked into
                the repo. It serves as a base config all developers can inherit
                from.
            project_user_file: User's personal config file for a specific
                project. This can be a file that lives in a project folder that
                is git-ignored and not checked into the repo.
            user_file: A global user based config file. This is typically a file
                in the user's home directory and settings here apply to all
                projects.
            default_config: A Python dict representing the base default
                config. This dict will be applied as a starting point before
                loading any yaml files.
            environment_var: The name of an environment variable to check for a
                config file. If a config file exists there it will be loaded on
                top of the default_config ignoring project and user files.
            skip_files_without_sections: Don't produce an exception if a
                config file doesn't include the relevant section. Instead, just
                move on to the next file.

            For project_file, project_user_file and user_file, only ``Path``
            values are used (``None`` or a bool is ignored), and a leading
            ``~`` and any ``$VARIABLE`` references in them are expanded.

        Raises:
            TypeError: If config_section_title is not a str, list or tuple.
            FileNotFoundError: If the environment variable is set to a path
                that is not an existing file.
            MissingConfigTitle: If a loaded file lacks the expected section and
                skip_files_without_sections is False.
        """

        def _expand(path: Path) -> Path:
            # Resolve a leading '~' and any $VARIABLE references.
            return Path(os.path.expandvars(str(path.expanduser())))

        self._config_section_title: tuple[str, ...]
        if isinstance(config_section_title, (list, tuple)):
            self._config_section_title = tuple(config_section_title)
        elif isinstance(config_section_title, str):
            self._config_section_title = (config_section_title,)
        else:
            raise TypeError(
                f'unexpected config section title {config_section_title!r}'
            )
        self.default_config = default_config if default_config else {}
        self.reset_config()

        if isinstance(project_file, Path):
            self.project_file = _expand(project_file)
            self.load_config_file(
                self.project_file,
                skip_files_without_sections=skip_files_without_sections,
                stage=Stage.PROJECT_FILE,
            )

        if isinstance(project_user_file, Path):
            self.project_user_file = _expand(project_user_file)
            self.load_config_file(
                self.project_user_file,
                skip_files_without_sections=skip_files_without_sections,
                stage=Stage.USER_PROJECT_FILE,
            )

        if isinstance(user_file, Path):
            self.user_file = _expand(user_file)
            self.load_config_file(
                self.user_file,
                skip_files_without_sections=skip_files_without_sections,
                stage=Stage.USER_FILE,
            )

        # Check for a config file specified by an environment variable.
        if environment_var is None:
            return
        environment_config = os.environ.get(environment_var, None)
        if environment_config:
            env_file_path = Path(environment_config)
            if not env_file_path.is_file():
                raise FileNotFoundError(
                    f'Cannot load config file: {env_file_path}'
                )
            # The environment file replaces everything loaded so far; only
            # default_config is kept as its base.
            self.reset_config()
            self.load_config_file(
                env_file_path,
                skip_files_without_sections=skip_files_without_sections,
                stage=Stage.ENVIRONMENT_VAR_FILE,
            )

    def _update_config(
        self, cfg: dict[Any, Any] | None, stage: Stage
    ) -> None:
        """Merge the top-level keys of ``cfg`` into ``self._config``.

        Values from the DEFAULT stage are stored as-is. Values from any other
        stage are passed through ``handle_overloaded_value`` together with the
        value currently stored for that key (``None`` if the key is new).
        """
        if cfg is None:
            # e.g. a section key present with no value, like ``pw_console:``.
            return
        for key, value in cfg.items():
            if stage == Stage.DEFAULT:
                self._config[key] = value
            else:
                self._config[key] = self.handle_overloaded_value(
                    key=key,
                    stage=stage,
                    original_value=self._config.get(key),
                    overriding_value=value,
                )

    def handle_overloaded_value(  # pylint: disable=no-self-use
        self,
        key: str,  # pylint: disable=unused-argument
        stage: Stage,  # pylint: disable=unused-argument
        original_value: Any,  # pylint: disable=unused-argument
        overriding_value: Any,
    ) -> Any:
        """Overload this in subclasses to handle overloaded values.

        Called for every top-level key loaded from a non-default stage.

        Args:
            key: The top-level config key being set.
            stage: The config source the new value comes from.
            original_value: The value already stored for ``key``, or ``None``
                if there is none yet.
            overriding_value: The value just loaded for ``key``.

        Returns:
            The value to store for ``key``. The default implementation simply
            returns ``overriding_value``.
        """
        return overriding_value

    def reset_config(self) -> None:
        """Discard all loaded settings and start again from default_config."""
        self._config: dict[Any, Any] = {}
        self._update_config(self.default_config, Stage.DEFAULT)

    def _load_config_from_string(  # pylint: disable=no-self-use
        self, file_contents: str
    ) -> list[dict[Any, Any]]:
        """Parse every YAML document in ``file_contents``.

        Uses ``yaml.safe_load_all``, so only plain YAML types are constructed.
        Empty documents come back as ``None``.
        """
        return list(yaml.safe_load_all(file_contents))

    def load_config_file(
        self,
        file_path: Path,
        skip_files_without_sections: bool = False,
        stage: Stage = Stage.OUT_OF_BAND,
    ) -> None:
        """Load a config file and extract the appropriate section.

        Does nothing if ``file_path`` is not an existing file. Each document
        in the file is merged in order if it either nests the settings under
        the config section title(s) or has a top-level ``config_title`` equal
        to the dot-joined section titles. Empty documents are skipped.

        Raises:
            MissingConfigTitle: If a non-empty document matches neither form
                and skip_files_without_sections is False.
        """
        if not file_path.is_file():
            return

        # YAML is a Unicode format; don't depend on the locale's encoding.
        cfgs = self._load_config_from_string(
            file_path.read_text(encoding='utf-8')
        )
        config_title_value = '.'.join(self._config_section_title)

        for cfg in cfgs:
            # An empty YAML document (e.g. a lone '---' or only comments)
            # loads as None and contains no settings.
            if cfg is None:
                continue

            # Walk down the nested section titles, e.g. ('a', 'b') looks up
            # cfg['a']['b']. Stop early if a level is missing or is not a
            # mapping.
            section: Any = cfg
            for config_section_title in self._config_section_title:
                if not isinstance(section, dict) or (
                    config_section_title not in section
                ):
                    break
                section = section[config_section_title]
            else:
                self._update_config(section, stage)
                continue

            # Fall back to a flat document tagged with a matching
            # ``config_title`` key; the whole document (including that key)
            # is merged.
            if (
                isinstance(cfg, dict)
                and cfg.get('config_title', False) == config_title_value
            ):
                self._update_config(cfg, stage)
                continue

            if not skip_files_without_sections:
                raise MissingConfigTitle(
                    f'\n\nThe config file "{file_path}" is missing the '
                    f'expected "config_title: {config_title_value}" '
                    'setting.'
                )