xref: /aosp_15_r20/external/pigweed/pw_cli/py/pw_cli/allowed_caller.py (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1# Copyright 2024 The Pigweed Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License"); you may not
4# use this file except in compliance with the License. You may obtain a copy of
5# the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations under
13# the License.
"""Functions for checking that Python call sites match expected callers."""
15
16from dataclasses import dataclass
17from pathlib import Path
18import inspect
19from typing import Iterable
20
21
22@dataclass
23class AllowedCaller:
24    """Container class for storing Python call sites."""
25
26    filename: str
27    function: str
28    name: str
29    self_class: str | None = None
30
31    @staticmethod
32    def from_frame_info(frame_info: inspect.FrameInfo) -> 'AllowedCaller':
33        """Returns an AllowedCaller based on an inspect.FrameInfo object."""
34        self_obj = frame_info.frame.f_locals.get('self', None)
35        global_name_str = frame_info.frame.f_globals.get('__name__', None)
36        module_class = None
37        if self_obj:
38            module_class = self_obj.__class__.__name__
39        return AllowedCaller(
40            filename=frame_info.filename,
41            function=frame_info.function,
42            name=global_name_str,
43            self_class=module_class,
44        )
45
46    def matches(self, other: 'AllowedCaller') -> bool:
47        """Returns true if this AllowedCaller matches another one."""
48        file_matches = Path(other.filename).match(f'**/{self.filename}')
49        name_matches = self.name == other.name
50        if self.name == '*':
51            name_matches = True
52        function_matches = self.function == other.function
53        final_match = file_matches and name_matches and function_matches
54
55        # If self_class is set, check those values too.
56        if self.self_class and other.self_class:
57            self_class_matches = self.self_class == other.self_class
58            final_match = final_match and self_class_matches
59
60        return final_match
61
62    def __repr__(self) -> str:
63        return f'''AllowedCaller(
64  filename='{self.filename}',
65  name='{self.name}',
66  function='{self.function}',
67  self_class='{self.self_class}',
68)'''
69
70
def check_caller_in(allow_list: Iterable[AllowedCaller]) -> bool:
    """Verifies that the current function was invoked from an allowed site.

    Call this from inside a function (the "called function"). The frame that
    invoked the called function (the "call location") is compared against
    each entry of ``allow_list`` using ``AllowedCaller.matches``.

    Args:
        allow_list: Permitted call locations. Iterated at most once, so a
            generator is acceptable.

    Returns:
        True when the call location matches an entry in ``allow_list``.

    Raises:
        RuntimeError: If the call location is not in the allow_list, or if
            the called function has no caller on the Python stack.
    """
    # Only the two frames above this one are needed. context=0 avoids loading
    # source lines for every frame on the stack, which the default
    # inspect.stack() does.
    stack = inspect.stack(context=0)
    try:
        # stack[0] is check_caller_in itself, stack[1] is the called function
        # and stack[2] is its caller (the call location).
        frames = [AllowedCaller.from_frame_info(info) for info in stack[1:3]]
    finally:
        # FrameInfo objects keep frames (including this one) alive; drop them
        # promptly to avoid a reference cycle, as the inspect docs recommend.
        del stack

    if len(frames) < 2:
        raise RuntimeError(
            '\n\nUnable to find a call location for this called function; '
            'it has no caller on the Python stack.\n'
        )
    called_function, call_location = frames

    # Check to see if the caller of the called function is allowed.
    caller_is_allowed = any(
        allowed_call.matches(call_location) for allowed_call in allow_list
    )

    if not caller_is_allowed:
        raise RuntimeError(
            '\n\nThis call location is not in the allow list for this '
            'called function.\n\n'
            f'Called function:\n\n{called_function}\n\n'
            f'Call location:\n\n{call_location}\n'
        )

    return caller_is_allowed
99