import importlib from typing import List, Optional from torch.testing._internal.common_utils import TestCase class AOMigrationTestCase(TestCase): def _test_function_import( self, package_name: str, function_list: List[str], base: Optional[str] = None, new_package_name: Optional[str] = None, ): r"""Tests individual function list import by comparing the functions and their hashes.""" if base is None: base = "quantization" old_base = "torch." + base new_base = "torch.ao." + base if new_package_name is None: new_package_name = package_name old_location = importlib.import_module(f"{old_base}.{package_name}") new_location = importlib.import_module(f"{new_base}.{new_package_name}") for fn_name in function_list: old_function = getattr(old_location, fn_name) new_function = getattr(new_location, fn_name) assert old_function == new_function, f"Functions don't match: {fn_name}" assert hash(old_function) == hash(new_function), ( f"Hashes don't match: {old_function}({hash(old_function)}) vs. " f"{new_function}({hash(new_function)})" ) def _test_dict_import( self, package_name: str, dict_list: List[str], base: Optional[str] = None ): r"""Tests individual function list import by comparing the functions and their hashes.""" if base is None: base = "quantization" old_base = "torch." + base new_base = "torch.ao." + base old_location = importlib.import_module(f"{old_base}.{package_name}") new_location = importlib.import_module(f"{new_base}.{package_name}") for dict_name in dict_list: old_dict = getattr(old_location, dict_name) new_dict = getattr(new_location, dict_name) assert old_dict == new_dict, f"Dicts don't match: {dict_name}" for key in new_dict.keys(): assert ( old_dict[key] == new_dict[key] ), f"Dicts don't match: {dict_name} for key {key}"