xref: /aosp_15_r20/external/pigweed/pw_env_setup/py/pw_env_setup/environment.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1# Copyright 2020 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
14"""Stores the environment changes necessary for Pigweed."""
15
16import contextlib
17import os
18import re
19
20# The order here is important. On Python 2 we want StringIO.StringIO and not
21# io.StringIO. On Python 3 there is no StringIO module so we want io.StringIO.
22# Not using six because six is not a standard package we can expect to have
23# installed in the system Python.
24try:
25    from StringIO import StringIO  # type: ignore
26except ImportError:
27    from io import StringIO
28
29from . import apply_visitor
30from . import batch_visitor
31from . import github_visitor
32from . import gni_visitor
33from . import json_visitor
34from . import shell_visitor
35
36
class BadNameType(TypeError):
    """Raised when a variable name is not a string."""
39
40
class BadValueType(TypeError):
    """Raised when a variable value is not a string."""
43
44
class EmptyValue(ValueError):
    """Raised when a value is '' and empty values are not allowed."""
47
48
class NewlineInValue(TypeError):
    """Raised when a variable value contains a newline character."""
51
52
class BadVariableName(ValueError):
    """Raised when a variable name has characters outside [A-Za-z0-9_].

    Also raised when the name starts with a digit.
    """
55
56
class UnexpectedAction(ValueError):
    """Error for an unexpected action.

    NOTE(review): presumably raised by the visitor modules; confirm there.
    """
59
60
class AcceptNotOverridden(TypeError):
    """Raised by _Action.accept() when a subclass does not override it."""
63
64
65class _Action:
66    def unapply(self, env, orig_env):
67        pass
68
69    def accept(self, visitor):
70        del visitor
71        raise AcceptNotOverridden(
72            'accept() not overridden for {}'.format(self.__class__.__name__)
73        )
74
75    def write_deactivate(
76        self, outs, windows=(os.name == 'nt'), replacements=()
77    ):
78        pass
79
80
class _VariableAction(_Action):
    """Base class for actions that operate on a single named variable.

    The name and value are validated as soon as the object is constructed.

    Args:
      name: Variable name. Must be a string made of letters, digits and
        underscores that does not start with a digit.
      value: Variable value. Must be a string without newlines.
      allow_empty_values: When true, the empty string is accepted as a
        value.

    Raises:
      BadNameType: name is not a string.
      BadValueType: value is not a string.
      EmptyValue: value is '' and allow_empty_values is false.
      NewlineInValue: value contains a newline.
      BadVariableName: name is not a valid variable name.
    """

    def __init__(self, name, value, *args, allow_empty_values=False, **kwargs):
        super().__init__(*args, **kwargs)
        self.name = name
        self.value = value
        self.allow_empty_values = allow_empty_values

        self._check()

    def _check(self):
        """Validate self.name and self.value, raising on the first problem."""
        try:
            # In python2, unicode is a distinct type.
            valid_types = (str, unicode)
        except NameError:
            valid_types = (str,)

        if not isinstance(self.name, valid_types):
            raise BadNameType(
                'variable name {!r} not of type str'.format(self.name)
            )
        if not isinstance(self.value, valid_types):
            raise BadValueType(
                '{!r} value {!r} not of type str'.format(self.name, self.value)
            )

        # Empty strings as environment variable values have different behavior
        # on different operating systems. Just don't allow them.
        if not self.allow_empty_values and self.value == '':
            raise EmptyValue(
                '{!r} value {!r} is the empty string'.format(
                    self.name, self.value
                )
            )

        # Many tools have issues with newlines in environment variable values.
        # Just don't allow them.
        if '\n' in self.value:
            raise NewlineInValue(
                '{!r} value {!r} contains a newline'.format(
                    self.name, self.value
                )
            )

        # Anchor with \Z rather than $: '$' also matches just before a
        # trailing newline, which would let a name such as 'FOO\n' through.
        if not re.match(r'[A-Z_][A-Z0-9_]*\Z', self.name, re.IGNORECASE):
            raise BadVariableName('bad variable name {!r}'.format(self.name))

    def unapply(self, env, orig_env):
        """Restore this variable in env to its state in orig_env.

        If the variable is absent from orig_env it is removed from env.
        """
        if self.name in orig_env:
            env[self.name] = orig_env[self.name]
        else:
            env.pop(self.name, None)

    def __repr__(self):
        return '{}({}, {})'.format(
            self.__class__.__name__, self.name, self.value
        )
137
138
class Set(_VariableAction):
    """Assign a value to a variable.

    Accepts an optional 'deactivate' keyword (default True) that is stored
    on the instance as self.deactivate.
    """

    def __init__(self, *args, **kwargs):
        # Strip 'deactivate' before the base class sees the keyword args.
        should_deactivate = kwargs.pop('deactivate', True)
        super().__init__(*args, **kwargs)
        self.deactivate = should_deactivate

    def accept(self, visitor):
        """Dispatch to visitor.visit_set()."""
        visitor.visit_set(self)
149
150
class Clear(_VariableAction):
    """Drop a variable from the environment."""

    def __init__(self, *args, **kwargs):
        # A cleared variable carries no value, so force an empty one and tell
        # the base-class validation to accept it.
        kwargs.update(value='', allow_empty_values=True)
        super().__init__(*args, **kwargs)

    def accept(self, visitor):
        """Dispatch to visitor.visit_clear()."""
        visitor.visit_clear(self)
161
162
class Remove(_VariableAction):
    """Take a single value out of a PATH-like variable."""

    def accept(self, visitor):
        """Dispatch to visitor.visit_remove()."""
        visitor.visit_remove(self)
168
169
class BadVariableValue(ValueError):
    """Raised when a value to prepend or append contains '='."""
172
173
174def _append_prepend_check(action):
175    if '=' in action.value:
176        raise BadVariableValue('"{}" contains "="'.format(action.value))
177
178
class Prepend(_VariableAction):
    """Put a value at the front of a PATH-like variable.

    Args:
      name: Variable name.
      value: Value to add; must not contain '='.
      join: Object stored as self._join for later use when joining values.
    """

    def __init__(self, name, value, join, *args, **kwargs):
        super().__init__(name, value, *args, **kwargs)
        self._join = join

    def _check(self):
        """Run the base validation, then reject values containing '='."""
        super()._check()
        _append_prepend_check(self)

    def accept(self, visitor):
        """Dispatch to visitor.visit_prepend()."""
        visitor.visit_prepend(self)
192
193
class Append(_VariableAction):
    """Put a value at the end of a PATH-like variable.

    This is the uncommon case; see Prepend for the usual one.

    Args:
      name: Variable name.
      value: Value to add; must not contain '='.
      join: Object stored as self._join for later use when joining values.
    """

    def __init__(self, name, value, join, *args, **kwargs):
        super().__init__(name, value, *args, **kwargs)
        self._join = join

    def _check(self):
        """Run the base validation, then reject values containing '='."""
        super()._check()
        _append_prepend_check(self)

    def accept(self, visitor):
        """Dispatch to visitor.visit_append()."""
        visitor.visit_append(self)
207
208
class BadEchoValue(ValueError):
    """Raised when an Echo value is 'on' or 'off' (in any letter case)."""
211
212
class Echo(_Action):
    """Print a value to the terminal.

    Args:
      value: Text to print. 'on' and 'off' (any letter case) are rejected.
      newline: Stored as self.newline for the visitors to interpret.

    Raises:
      BadEchoValue: value is 'on' or 'off'.
    """

    def __init__(self, value, newline, *args, **kwargs):
        # 'on' and 'off' behave unexpectedly on Windows, so refuse them.
        lowered = value.lower()
        if lowered == 'off' or lowered == 'on':
            raise BadEchoValue(value)
        super().__init__(*args, **kwargs)
        self.value = value
        self.newline = newline

    def accept(self, visitor):
        """Dispatch to visitor.visit_echo()."""
        visitor.visit_echo(self)

    def __repr__(self):
        template = 'Echo({}, newline={})'
        return template.format(self.value, self.newline)
229
230
class Comment(_Action):
    """Insert a comment into the init script.

    Args:
      value: The comment text, stored as self.value.
    """

    def __init__(self, value, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.value = value

    def accept(self, visitor):
        """Dispatch to visitor.visit_comment()."""
        visitor.visit_comment(self)

    def __repr__(self):
        template = 'Comment({})'
        return template.format(self.value)
243
244
class Command(_Action):
    """Execute a command.

    Args:
      command: The command as a list or tuple of arguments.
      exit_on_error: Optional keyword (default True), stored on the instance
        for the visitors to interpret.
    """

    def __init__(self, command, *args, **kwargs):
        # Strip 'exit_on_error' before the base class sees the keyword args.
        stop_on_failure = kwargs.pop('exit_on_error', True)
        super().__init__(*args, **kwargs)
        assert isinstance(command, (list, tuple))
        self.command = command
        self.exit_on_error = stop_on_failure

    def accept(self, visitor):
        """Dispatch to visitor.visit_command()."""
        visitor.visit_command(self)

    def __repr__(self):
        template = 'Command({})'
        return template.format(self.command)
260
261
class Doctor(Command):
    """Run 'pw doctor'.

    The log level is 'warn' when PW_ENVSETUP_QUIET is present in the
    process environment at construction time, and 'info' otherwise.
    """

    def __init__(self, *args, **kwargs):
        if 'PW_ENVSETUP_QUIET' in os.environ:
            log_level = 'warn'
        else:
            log_level = 'info'
        cmd = ['pw', '--no-banner', '--loglevel', log_level, 'doctor']
        super().__init__(command=cmd, *args, **kwargs)

    def accept(self, visitor):
        """Dispatch to visitor.visit_doctor()."""
        visitor.visit_doctor(self)

    def __repr__(self):
        return 'Doctor()'
279
280
class BlankLine(_Action):
    """Emit an empty line in the init script."""

    def accept(self, visitor):
        """Dispatch to visitor.visit_blank_line()."""
        visitor.visit_blank_line(self)

    def __repr__(self):
        return 'BlankLine()'
289
290
class Function(_Action):
    """Define a named function.

    Args:
      name: Function name, stored as self.name.
      body: Function body text, stored as self.body.
    """

    def __init__(self, name, body, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.name = name
        self.body = body

    def accept(self, visitor):
        """Dispatch to visitor.visit_function()."""
        visitor.visit_function(self)

    def __repr__(self):
        template = 'Function({}, {})'
        return template.format(self.name, self.body)
302
303
class Hash(_Action):
    """Marker action that is dispatched to visitor.visit_hash().

    NOTE(review): presumably resets the shell's command lookup cache;
    confirm against the visitor implementations.
    """

    def accept(self, visitor):
        """Dispatch to visitor.visit_hash()."""
        visitor.visit_hash(self)

    def __repr__(self):
        return 'Hash()'
310
311
class Join:
    """Holds the separator used between entries of PATH-like variables.

    Args:
      pathsep: Separator string; defaults to os.pathsep.
    """

    def __init__(self, pathsep=os.pathsep):
        self.pathsep = pathsep
315
316
class Environment:
    """Stores the environment changes necessary for Pigweed.

    These changes can be accessed by writing them to a file for bash-like
    shells to source or by using this as a context manager.
    """

    def __init__(self, *args, **kwargs):
        """Create an environment with no recorded actions.

        Keyword Args:
          pathsep: Separator for PATH-like variables (default: os.pathsep).
          windows: Whether to produce Windows output (default: true when
            os.name is 'nt').
          allcaps: Whether normalize_key() upper-cases variable names
            (default: the value of 'windows').
        """
        pathsep = kwargs.pop('pathsep', os.pathsep)
        windows = kwargs.pop('windows', os.name == 'nt')
        allcaps = kwargs.pop('allcaps', windows)
        super().__init__(*args, **kwargs)
        self._actions = []
        self._pathsep = pathsep
        self._windows = windows
        self._allcaps = allcaps
        self.replacements = []
        self._join = Join(pathsep)
        self._finalized = False
        self._shell_file = ''

    def add_replacement(self, variable, value=None):
        """Record a (variable, value) pair in self.replacements."""
        self.replacements.append((variable, value))

    def normalize_key(self, name):
        """Return name upper-cased when allcaps is set, otherwise unchanged.

        Names without an upper() method are returned as-is so that the
        action classes can report the type error themselves.
        """
        if self._allcaps:
            try:
                return name.upper()
            except AttributeError:
                # The _Action class has code to handle incorrect types, so
                # we just ignore this error here.
                pass
        return name

    # A newline is printed after each high-level operation. Top-level
    # operations should not invoke each other (this is why _remove() exists).

    def set(self, name, value, deactivate=True):
        """Set a variable."""
        assert not self._finalized
        name = self.normalize_key(name)
        self._actions.append(Set(name, value, deactivate=deactivate))
        self._blankline()

    def clear(self, name):
        """Remove a variable."""
        assert not self._finalized
        name = self.normalize_key(name)
        self._actions.append(Clear(name))
        self._blankline()

    def _remove(self, name, value):
        """Remove a value from a variable.

        Nothing is recorded when the variable is unset or empty in the
        context of this object. No blank line is added; callers do that.
        """
        assert not self._finalized
        name = self.normalize_key(name)
        if self.get(name, None):
            self._actions.append(Remove(name, value))

    def remove(self, name, value):
        """Remove a value from a PATH-like variable."""
        assert not self._finalized
        self._remove(name, value)
        self._blankline()

    def append(self, name, value):
        """Add a value to a PATH-like variable. Rarely used, see prepend()."""
        assert not self._finalized
        name = self.normalize_key(name)
        if self.get(name, None):
            # Drop any existing occurrence first so the value is not repeated.
            self._remove(name, value)
            self._actions.append(Append(name, value, self._join))
        else:
            # Nothing to append to yet, so this is a plain assignment.
            self._actions.append(Set(name, value))
        self._blankline()

    def prepend(self, name, value):
        """Add a value to the beginning of a PATH-like variable."""
        assert not self._finalized
        name = self.normalize_key(name)
        if self.get(name, None):
            # Drop any existing occurrence first so the value is not repeated.
            self._remove(name, value)
            self._actions.append(Prepend(name, value, self._join))
        else:
            # Nothing to prepend to yet, so this is a plain assignment.
            self._actions.append(Set(name, value))
        self._blankline()

    def echo(self, value='', newline=True):
        """Echo a value to the terminal."""
        # echo() deliberately ignores self._finalized.
        self._actions.append(Echo(value, newline))
        if value:
            self._blankline()

    def comment(self, comment):
        """Add a comment to the init script."""
        # comment() deliberately ignores self._finalized.
        self._actions.append(Comment(comment))
        self._blankline()

    def command(self, command, exit_on_error=True):
        """Run a command."""
        # command() deliberately ignores self._finalized.
        self._actions.append(Command(command, exit_on_error=exit_on_error))
        self._blankline()

    def doctor(self):
        """Run 'pw doctor'."""
        self._actions.append(Doctor())

    def function(self, name, body):
        """Define a function."""
        assert not self._finalized
        # This must be a Function action. Constructing Command(name, body)
        # here always raised TypeError, because Command forwards its extra
        # positional argument to object.__init__().
        self._actions.append(Function(name, body))
        self._blankline()

    def _blankline(self):
        """Append a BlankLine action."""
        self._actions.append(BlankLine())

    def finalize(self):
        """Run cleanup at the end of environment setup."""
        assert not self._finalized
        self._finalized = True
        self._actions.append(Hash())
        self._blankline()

        if not self._windows:
            # Capture the deactivate script as text and record it as the body
            # of a '_pw_deactivate' function.
            buf = StringIO()
            self.write_deactivate(buf, shell_file=self._shell_file)
            self._actions.append(Function('_pw_deactivate', buf.getvalue()))
            self._blankline()

    def accept(self, visitor):
        """Pass every recorded action, in order, to the visitor."""
        for action in self._actions:
            action.accept(visitor)

    def github(self, root):
        """Serialize this environment with GitHubVisitor, given root."""
        github_visitor.GitHubVisitor().serialize(self, root)

    def gni(self, outs, project_root, gni_file):
        """Serialize this environment to outs with GNIVisitor."""
        gni_visitor.GNIVisitor(project_root, gni_file).serialize(self, outs)

    def json(self, outs):
        """Serialize this environment to outs with JSONVisitor."""
        json_visitor.JSONVisitor().serialize(self, outs)

    def write(self, outs, shell_file):
        """Serialize this environment to outs as an init script.

        Uses BatchVisitor on Windows. Otherwise uses FishShellVisitor when
        shell_file ends with '.fish' and ShellVisitor for anything else.
        """
        if self._windows:
            visitor = batch_visitor.BatchVisitor(pathsep=self._pathsep)
        else:
            if shell_file.endswith('.fish'):
                visitor = shell_visitor.FishShellVisitor()
            else:
                visitor = shell_visitor.ShellVisitor(pathsep=self._pathsep)
        visitor.serialize(self, outs)

    def write_deactivate(self, outs, shell_file):
        """Serialize the deactivate script to outs.

        Does nothing on Windows. Otherwise uses DeactivateFishShellVisitor
        when shell_file ends with '.fish' and DeactivateShellVisitor for
        anything else.
        """
        if self._windows:
            return
        if shell_file.endswith('.fish'):
            visitor = shell_visitor.DeactivateFishShellVisitor(
                pathsep=self._pathsep
            )
        else:
            visitor = shell_visitor.DeactivateShellVisitor(
                pathsep=self._pathsep
            )
        visitor.serialize(self, outs)

    @contextlib.contextmanager
    def __call__(self, export=True):
        """Set environment as if this was written to a file and sourced.

        Within this context os.environ is updated with the environment
        defined by this object. If export is False, os.environ is not updated,
        but in both cases the updated environment is yielded.

        On exit, previous environment is restored. See contextlib documentation
        for details on how this function is structured.

        Args:
          export(bool): modify the environment of the running process (and
            thus, its subprocesses)

        Yields the new environment object.
        """
        orig_env = {}
        try:
            if export:
                orig_env = os.environ.copy()
                env = os.environ
            else:
                env = os.environ.copy()

            apply = apply_visitor.ApplyVisitor(pathsep=self._pathsep)
            apply.apply(self, env)

            yield env

        finally:
            if export:
                # Reset keys that still exist and drop keys that were added.
                for key in set(os.environ):
                    try:
                        os.environ[key] = orig_env[key]
                    except KeyError:
                        del os.environ[key]
                # Bring back keys that were removed inside the context.
                for key in set(orig_env) - set(os.environ):
                    os.environ[key] = orig_env[key]

    def get(self, key, default=None):
        """Get the value of a variable within context of this object."""
        key = self.normalize_key(key)
        with self(export=False) as env:
            return env.get(key, default)

    def __getitem__(self, key):
        """Get the value of a variable within context of this object."""
        key = self.normalize_key(key)
        with self(export=False) as env:
            return env[key]
535