1# Copyright 2024 The Pigweed Authors 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); you may not 4# use this file except in compliance with the License. You may obtain a copy of 5# the License at 6# 7# https://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 12# License for the specific language governing permissions and limitations under 13# the License. 14"""Functons for checking Python call sites match expected callers.""" 15 16from dataclasses import dataclass 17from pathlib import Path 18import inspect 19from typing import Iterable 20 21 22@dataclass 23class AllowedCaller: 24 """Container class for storing Python call sites.""" 25 26 filename: str 27 function: str 28 name: str 29 self_class: str | None = None 30 31 @staticmethod 32 def from_frame_info(frame_info: inspect.FrameInfo) -> 'AllowedCaller': 33 """Returns an AllowedCaller based on an inspect.FrameInfo object.""" 34 self_obj = frame_info.frame.f_locals.get('self', None) 35 global_name_str = frame_info.frame.f_globals.get('__name__', None) 36 module_class = None 37 if self_obj: 38 module_class = self_obj.__class__.__name__ 39 return AllowedCaller( 40 filename=frame_info.filename, 41 function=frame_info.function, 42 name=global_name_str, 43 self_class=module_class, 44 ) 45 46 def matches(self, other: 'AllowedCaller') -> bool: 47 """Returns true if this AllowedCaller matches another one.""" 48 file_matches = Path(other.filename).match(f'**/{self.filename}') 49 name_matches = self.name == other.name 50 if self.name == '*': 51 name_matches = True 52 function_matches = self.function == other.function 53 final_match = file_matches and name_matches and function_matches 54 55 # If self_class is set, check those values too. 56 if self.self_class and other.self_class: 57 self_class_matches = self.self_class == other.self_class 58 final_match = final_match and self_class_matches 59 60 return final_match 61 62 def __repr__(self) -> str: 63 return f'''AllowedCaller( 64 filename='{self.filename}', 65 name='{self.name}', 66 function='{self.function}', 67 self_class='{self.self_class}', 68)''' 69 70 71def check_caller_in(allow_list: Iterable[AllowedCaller]) -> bool: 72 """Return true if the called function is in the allowed call list. 73 74 Raises a RuntimeError if the call location is not in the allow_list. 75 """ 76 # Get the current Python call stack. 77 call_stack = [ 78 AllowedCaller.from_frame_info(frame_info) 79 for frame_info in inspect.stack() 80 ] 81 82 called_function = call_stack[1] 83 call_location = call_stack[2] 84 85 # Check to see if the caller of this function is allowed. 86 caller_is_allowed = any( 87 allowed_call.matches(call_location) for allowed_call in allow_list 88 ) 89 90 if not caller_is_allowed: 91 raise RuntimeError( 92 '\n\nThis call location is not in the allow list for this ' 93 'called function.\n\n' 94 f'Called function:\n\n{called_function}\n\n' 95 f'Call location:\n\n{call_location}\n' 96 ) 97 98 return caller_is_allowed 99