xref: /aosp_15_r20/external/pigweed/pw_config_loader/py/yaml_config_loader_mixin_test.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1# Copyright 2024 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Tests for pw_config_loader."""
15
16from pathlib import Path
17import tempfile
18from typing import Any
19import unittest
20
21from pw_config_loader import yaml_config_loader_mixin
22import yaml
23
24# pylint: disable=no-member,no-self-use
25
26
class YamlConfigLoader(yaml_config_loader_mixin.YamlConfigLoaderMixin):
    """Minimal concrete loader used by the tests below.

    The mixin populates ``self._config`` when ``config_init()`` is called;
    this subclass only adds a public, read-only ``config`` property over it so
    the tests can inspect the loaded section.
    """

    @property
    def config(self) -> dict[str, Any]:
        """Return the config dictionary populated by the mixin."""
        loaded: dict[str, Any] = self._config
        return loaded
31
32
class TestOneFile(unittest.TestCase):
    """Tests for loading a config section from a single YAML file."""

    def setUp(self):
        self._title = 'title'

    def init(self, config: dict[str, Any]) -> dict[str, Any]:
        """Dump ``config`` to a temporary YAML file and load it as the user file.

        Args:
            config: The complete YAML document to write (not just the section).

        Returns:
            The config the loader exposes for section ``self._title``.
        """
        loader = YamlConfigLoader()
        with tempfile.TemporaryDirectory() as tmp_dir:
            yaml_file = Path(tmp_dir) / 'foo.yaml'
            serialized = yaml.safe_dump(config)
            yaml_file.write_bytes(serialized.encode())
            loader.config_init(
                user_file=yaml_file,
                config_section_title=self._title,
            )
            result = loader.config
        return result

    def test_normal(self):
        section = {'a': 1, 'b': 2}
        loaded = self.init({self._title: section})
        for key in ('a', 'b'):
            self.assertEqual(section[key], loaded[key])

    def test_config_title(self):
        # The section is identified by a top-level 'config_title' entry
        # instead of being nested under the title key.
        document = {'a': 1, 'b': 2, 'config_title': self._title}
        loaded = self.init(document)
        for key in ('a', 'b'):
            self.assertEqual(document[key], loaded[key])
61
62
class TestMultipleFiles(unittest.TestCase):
    """Tests for merging one config section across user/project files."""

    def init(
        self,
        project_config: dict[str, Any],
        project_user_config: dict[str, Any],
        user_config: dict[str, Any],
    ) -> dict[str, Any]:
        """Write the three config files, then load and merge them.

        Each argument becomes the contents of the ``title`` section of its own
        YAML file in a temporary directory.

        Returns:
            The merged ``title`` section as exposed by the loader.
        """
        section = 'title'
        loader = YamlConfigLoader()

        with tempfile.TemporaryDirectory() as tmp_dir:
            root = Path(tmp_dir)

            written: dict[str, Path] = {}
            for file_name, body in (
                ('user.yaml', user_config),
                ('project_user.yaml', project_user_config),
                ('project.yaml', project_config),
            ):
                target = root / file_name
                target.write_text(yaml.safe_dump({section: body}))
                written[file_name] = target

            loader.config_init(
                user_file=written['user.yaml'],
                project_user_file=written['project_user.yaml'],
                project_file=written['project.yaml'],
                config_section_title=section,
            )

        return loader.config

    def test_user_override(self):
        merged = self.init(
            user_config={'a': 1},
            project_user_config={'a': 2},
            project_config={'a': 3},
        )
        self.assertEqual(merged['a'], 1)

    def test_project_user_override(self):
        merged = self.init(
            user_config={},
            project_user_config={'a': 2},
            project_config={'a': 3},
        )
        self.assertEqual(merged['a'], 2)

    def test_not_overridden(self):
        merged = self.init(
            user_config={},
            project_user_config={},
            project_config={'a': 3},
        )
        self.assertEqual(merged['a'], 3)

    def test_different_keys(self):
        merged = self.init(
            user_config={'a': 1},
            project_user_config={'b': 2},
            project_config={'c': 3},
        )
        expected = {'a': 1, 'b': 2, 'c': 3}
        for key, value in expected.items():
            self.assertEqual(merged[key], value)
133
134
class TestNestedTitle(unittest.TestCase):
    """Tests for loading a config section addressed by a nested title."""

    def setUp(self):
        self._title = ('title', 'subtitle', 'subsubtitle', 'subsubsubtitle')

    def init(self, config: dict[str, Any]) -> dict[str, Any]:
        """Dump ``config`` to a temporary YAML file and load the nested section.

        Args:
            config: The complete YAML document to write.

        Returns:
            The config the loader exposes for the title path ``self._title``.
        """
        loader = YamlConfigLoader()
        with tempfile.TemporaryDirectory() as tmp_dir:
            yaml_file = Path(tmp_dir) / 'foo.yaml'
            yaml_file.write_bytes(yaml.safe_dump(config).encode())
            loader.config_init(
                user_file=yaml_file,
                config_section_title=self._title,
            )
            result = loader.config
        return result

    def test_normal(self):
        # Wrap the leaf section in one dict per title level, innermost first.
        nested: dict[str, Any] = {'a': 1, 'b': 2}
        for level in self._title[::-1]:
            nested = {level: nested}
        loaded = self.init(nested)
        self.assertEqual(loaded['a'], 1)
        self.assertEqual(loaded['b'], 2)

    def test_config_title(self):
        # A dotted 'config_title' addresses the same nested section.
        dotted_title = '.'.join(self._title)
        loaded = self.init({'a': 1, 'b': 2, 'config_title': dotted_title})
        self.assertEqual(loaded['a'], 1)
        self.assertEqual(loaded['b'], 2)
165
166
class CustomOverloadYamlConfigLoader(
    yaml_config_loader_mixin.YamlConfigLoaderMixin
):
    """Custom config loader that implements handle_overloaded_value().

    The key name selects the merge strategy exercised by the overloading
    tests:

    * ``extend``: concatenate the overriding value onto the original (if any).
    * ``extend_sort``: concatenate as above, then sort the combined result.
    * ``do_not_override``: keep the original value if it is truthy.
    * ``max``: keep the larger of the original and overriding values.
    * anything else: the overriding value replaces the original.
    """

    @property
    def config(self) -> dict[str, Any]:
        """Return the config dictionary populated by the mixin."""
        return self._config

    def handle_overloaded_value(  # pylint: disable=no-self-use
        self,
        key: str,
        stage: yaml_config_loader_mixin.Stage,
        original_value: Any,
        overriding_value: Any,
    ) -> Any:
        """Merge ``overriding_value`` into ``original_value`` based on ``key``.

        Args:
            key: Config key being merged; selects the merge strategy.
            stage: Loading stage reported by the mixin (unused here).
            original_value: Value accumulated so far for ``key``. It may be
                falsy, presumably ``None`` when no earlier file set the key.
                NOTE(review): confirm against the mixin.
            overriding_value: Value from the file currently being merged in.

        Returns:
            The value to store for ``key``.
        """
        if key == 'extend':
            if original_value:
                return original_value + overriding_value
            return overriding_value

        if key == 'extend_sort':
            if original_value:
                result = original_value + overriding_value
            else:
                result = overriding_value
            return sorted(result)

        if key == 'do_not_override':
            # Keep a truthy earlier value; otherwise accept the new one.
            if original_value:
                return original_value
            return overriding_value

        if key == 'max':
            # max() cannot order None against a value, so a missing original
            # simply yields the overriding value, as the extend branches do.
            if original_value is None:
                return overriding_value
            return max(original_value, overriding_value)

        return overriding_value
203
204
class TestOverloading(unittest.TestCase):
    """Tests for CustomOverloadYamlConfigLoader.handle_overloaded_value().

    Each test writes a project, project-user, and user config file and checks
    that the per-key merge strategies combine the three values as expected.
    """

    def init(
        self,
        project_config: dict[str, Any],
        project_user_config: dict[str, Any],
        user_config: dict[str, Any],
    ) -> dict[str, Any]:
        """Write config files then read and parse them.

        Each argument becomes the ``title`` section of its own YAML file in a
        temporary directory, which is removed once loading has finished.

        Returns:
            The merged ``title`` section from CustomOverloadYamlConfigLoader.
        """

        loader = CustomOverloadYamlConfigLoader()
        title = 'title'

        with tempfile.TemporaryDirectory() as folder:
            path = Path(folder)

            user_path = path / 'user.yaml'
            user_path.write_text(yaml.safe_dump({title: user_config}))

            project_user_path = path / 'project_user.yaml'
            project_user_path.write_text(
                yaml.safe_dump({title: project_user_config})
            )

            project_path = path / 'project.yaml'
            project_path.write_text(yaml.safe_dump({title: project_config}))

            loader.config_init(
                user_file=user_path,
                project_user_file=project_user_path,
                project_file=project_path,
                config_section_title=title,
            )

        return loader.config

    def test_lists(self):
        # The expected values imply the merge order project -> project_user
        # -> user (e.g. 'extend' accumulates 'abc', then 'def', then 'ghi').
        config = self.init(
            project_config={
                'extend': list('abc'),
                'extend_sort': list('az'),
                'do_not_override': ['persists'],
                'override': ['hidden'],
            },
            project_user_config={
                'extend': list('def'),
                'extend_sort': list('by'),
                'do_not_override': ['ignored'],
                'override': ['ignored'],
            },
            user_config={
                'extend': list('ghi'),
                'extend_sort': list('cx'),
                'do_not_override': ['ignored_2'],
                'override': ['overrides'],
            },
        )
        self.assertEqual(config['extend'], list('abcdefghi'))
        self.assertEqual(config['extend_sort'], list('abcxyz'))
        self.assertEqual(config['do_not_override'], ['persists'])
        self.assertEqual(config['override'], ['overrides'])

    def test_scalars(self):
        # String concatenation and max() work on scalars as well as lists.
        config = self.init(
            project_config={'extend': 'abc', 'max': 1},
            project_user_config={'extend': 'def', 'max': 3},
            user_config={'extend': 'ghi', 'max': 2},
        )
        self.assertEqual(config['extend'], 'abcdefghi')
        self.assertEqual(config['max'], 3)
276
277
if __name__ == '__main__':
    # Allow running this file directly; unittest.main() collects the
    # TestCase classes defined in this module.
    unittest.main()
280