# Copyright 2024 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests for pw_config_loader."""

from pathlib import Path
import tempfile
from typing import Any, Union
import unittest

from pw_config_loader import yaml_config_loader_mixin
import yaml

# pylint: disable=no-member,no-self-use

# Section title shared by the single-title and layered-file tests.
_TITLE = 'title'


class YamlConfigLoader(yaml_config_loader_mixin.YamlConfigLoaderMixin):
    """Concrete loader that exposes the parsed config for assertions."""

    @property
    def config(self) -> dict[str, Any]:
        """Return the mixin's internal ``_config`` dict."""
        return self._config


def _write_yaml(path: Path, data: dict[str, Any]) -> None:
    """Serialize ``data`` to ``path`` as UTF-8 encoded YAML."""
    path.write_text(yaml.safe_dump(data), encoding='utf-8')


def _load_single_file(
    loader: YamlConfigLoader,
    config: dict[str, Any],
    title: Union[str, tuple[str, ...]],
) -> dict[str, Any]:
    """Write ``config`` as the user file, load it, and return the section.

    Args:
        loader: Loader instance to initialize.
        config: Complete file contents to serialize.
        title: Section title, or a tuple of nested title parts, to select.

    Returns:
        The loader's ``config`` after ``config_init()`` has run.
    """
    with tempfile.TemporaryDirectory() as folder:
        path = Path(folder, 'foo.yaml')
        _write_yaml(path, config)
        loader.config_init(user_file=path, config_section_title=title)
    return loader.config


def _load_layered_files(
    loader: YamlConfigLoader,
    project_config: dict[str, Any],
    project_user_config: dict[str, Any],
    user_config: dict[str, Any],
) -> dict[str, Any]:
    """Write project, project-user, and user files, then load all three.

    Each config dict is nested under ``_TITLE`` before being written, and
    ``_TITLE`` is passed to ``config_init()`` as the section to select.

    Returns:
        The loader's ``config`` after ``config_init()`` has run.
    """
    with tempfile.TemporaryDirectory() as folder:
        root = Path(folder)
        user_path = root / 'user.yaml'
        project_user_path = root / 'project_user.yaml'
        project_path = root / 'project.yaml'

        _write_yaml(user_path, {_TITLE: user_config})
        _write_yaml(project_user_path, {_TITLE: project_user_config})
        _write_yaml(project_path, {_TITLE: project_config})

        loader.config_init(
            user_file=user_path,
            project_user_file=project_user_path,
            project_file=project_path,
            config_section_title=_TITLE,
        )
    return loader.config


class TestOneFile(unittest.TestCase):
    """Tests for loading a config section from one file."""

    def setUp(self):
        self._title = _TITLE

    def init(self, config: dict[str, Any]) -> dict[str, Any]:
        """Load ``config`` from a single user file under ``self._title``."""
        return _load_single_file(YamlConfigLoader(), config, self._title)

    def test_normal(self):
        content = {'a': 1, 'b': 2}
        config = self.init({self._title: content})
        self.assertEqual(content['a'], config['a'])
        self.assertEqual(content['b'], config['b'])

    def test_config_title(self):
        # The file is not nested under the title; instead it carries a
        # 'config_title' key naming the section it belongs to.
        content = {'a': 1, 'b': 2, 'config_title': self._title}
        config = self.init(content)
        self.assertEqual(content['a'], config['a'])
        self.assertEqual(content['b'], config['b'])


class TestMultipleFiles(unittest.TestCase):
    """Tests for loading config sections from multiple files."""

    def init(
        self,
        project_config: dict[str, Any],
        project_user_config: dict[str, Any],
        user_config: dict[str, Any],
    ) -> dict[str, Any]:
        """Write config files then read and parse them."""
        return _load_layered_files(
            YamlConfigLoader(),
            project_config=project_config,
            project_user_config=project_user_config,
            user_config=user_config,
        )

    def test_user_override(self):
        config = self.init(
            user_config={'a': 1},
            project_user_config={'a': 2},
            project_config={'a': 3},
        )
        self.assertEqual(config['a'], 1)

    def test_project_user_override(self):
        config = self.init(
            user_config={},
            project_user_config={'a': 2},
            project_config={'a': 3},
        )
        self.assertEqual(config['a'], 2)

    def test_not_overridden(self):
        config = self.init(
            user_config={},
            project_user_config={},
            project_config={'a': 3},
        )
        self.assertEqual(config['a'], 3)

    def test_different_keys(self):
        config = self.init(
            user_config={'a': 1},
            project_user_config={'b': 2},
            project_config={'c': 3},
        )
        self.assertEqual(config['a'], 1)
        self.assertEqual(config['b'], 2)
        self.assertEqual(config['c'], 3)


class TestNestedTitle(unittest.TestCase):
    """Tests for nested config section loading."""

    def setUp(self):
        self._title = ('title', 'subtitle', 'subsubtitle', 'subsubsubtitle')

    def init(self, config: dict[str, Any]) -> dict[str, Any]:
        """Load ``config`` from a single user file under the nested title."""
        return _load_single_file(YamlConfigLoader(), config, self._title)

    def test_normal(self):
        content = {'a': 1, 'b': 2}
        # Wrap innermost-first so the result is {title: {subtitle: ...}}.
        for part in reversed(self._title):
            content = {part: content}
        config = self.init(content)
        self.assertEqual(config['a'], 1)
        self.assertEqual(config['b'], 2)

    def test_config_title(self):
        # A nested title is expressed as a dot-joined 'config_title'.
        content = {'a': 1, 'b': 2, 'config_title': '.'.join(self._title)}
        config = self.init(content)
        self.assertEqual(config['a'], 1)
        self.assertEqual(config['b'], 2)


class CustomOverloadYamlConfigLoader(YamlConfigLoader):
    """Custom config loader that implements handle_overloaded_value()."""

    def handle_overloaded_value(  # pylint: disable=no-self-use
        self,
        key: str,
        stage: yaml_config_loader_mixin.Stage,
        original_value: Any,
        overriding_value: Any,
    ) -> Any:
        """Combine an earlier value with a later one, depending on ``key``.

        ``stage`` is accepted for interface compatibility but not used.

        - ``extend``: concatenate the later value onto the earlier one.
        - ``extend_sort``: concatenate, then return the sorted result.
        - ``do_not_override``: keep the earlier value if it is truthy.
        - ``max``: keep the larger of the two values.
        - any other key: the later value replaces the earlier one.
        """
        if key == 'extend':
            if original_value:
                return original_value + overriding_value
            return overriding_value

        if key == 'extend_sort':
            if original_value:
                result = original_value + overriding_value
            else:
                result = overriding_value
            return sorted(result)

        if key == 'do_not_override':
            return original_value if original_value else overriding_value

        if key == 'max':
            # No earlier value to compare against; max() cannot order None
            # against a number, so just take the new value.
            if original_value is None:
                return overriding_value
            return max(original_value, overriding_value)

        return overriding_value


class TestOverloading(unittest.TestCase):
    """Tests for custom value merging via handle_overloaded_value()."""

    def init(
        self,
        project_config: dict[str, Any],
        project_user_config: dict[str, Any],
        user_config: dict[str, Any],
    ) -> dict[str, Any]:
        """Write config files then read and parse them."""
        return _load_layered_files(
            CustomOverloadYamlConfigLoader(),
            project_config=project_config,
            project_user_config=project_user_config,
            user_config=user_config,
        )

    def test_lists(self):
        config = self.init(
            project_config={
                'extend': list('abc'),
                'extend_sort': list('az'),
                'do_not_override': ['persists'],
                'override': ['hidden'],
            },
            project_user_config={
                'extend': list('def'),
                'extend_sort': list('by'),
                'do_not_override': ['ignored'],
                'override': ['ignored'],
            },
            user_config={
                'extend': list('ghi'),
                'extend_sort': list('cx'),
                'do_not_override': ['ignored_2'],
                'override': ['overrides'],
            },
        )
        self.assertEqual(config['extend'], list('abcdefghi'))
        self.assertEqual(config['extend_sort'], list('abcxyz'))
        self.assertEqual(config['do_not_override'], ['persists'])
        self.assertEqual(config['override'], ['overrides'])

    def test_scalars(self):
        config = self.init(
            project_config={'extend': 'abc', 'max': 1},
            project_user_config={'extend': 'def', 'max': 3},
            user_config={'extend': 'ghi', 'max': 2},
        )
        self.assertEqual(config['extend'], 'abcdefghi')
        self.assertEqual(config['max'], 3)


if __name__ == '__main__':
    unittest.main()