1# mypy: allow-untyped-defs 2import dataclasses 3from dataclasses import field 4from types import CodeType, ModuleType 5from typing import Any, Dict 6 7from torch.utils._import_utils import import_dill 8 9 10dill = import_dill() 11 12 13@dataclasses.dataclass 14class ModuleRecord: 15 module: ModuleType 16 accessed_attrs: Dict[str, Any] = field(default_factory=dict) 17 18 19@dataclasses.dataclass 20class DummyModule: 21 name: str 22 is_torch: bool = False 23 24 @property 25 def __name__(self): 26 return self.name 27 28 29@dataclasses.dataclass 30class ExecutionRecord: 31 code: CodeType 32 globals: Dict[str, Any] = field(default_factory=dict) 33 locals: Dict[str, Any] = field(default_factory=dict) 34 builtins: Dict[str, Any] = field(default_factory=dict) 35 code_options: Dict[str, Any] = field(default_factory=dict) 36 37 def dump(self, f): 38 assert dill is not None, "replay_record requires `pip install dill`" 39 dill.dump(self, f) 40 41 @classmethod 42 def load(cls, f): 43 assert dill is not None, "replay_record requires `pip install dill`" 44 return dill.load(f) 45 46 47@dataclasses.dataclass 48class ExecutionRecorder: 49 LOCAL_MOD_PREFIX = "___local_mod_" 50 51 code: CodeType 52 globals: Dict[str, Any] = field(default_factory=dict) 53 locals: Dict[str, Any] = field(default_factory=dict) 54 builtins: Dict[str, Any] = field(default_factory=dict) 55 code_options: Dict[str, Any] = field(default_factory=dict) 56 name_to_modrec: Dict[str, Any] = field(default_factory=dict) 57 58 def add_local_var(self, name, var): 59 if isinstance(var, ModuleType): 60 self.locals[name] = self._add_mod(var) 61 else: 62 self.locals[name] = var 63 64 def add_global_var(self, name, var): 65 if isinstance(var, ModuleType): 66 self.globals[name] = self._add_mod(var) 67 else: 68 self.globals[name] = var 69 70 def add_local_mod(self, name, mod): 71 assert isinstance(mod, ModuleType) 72 73 self.add_global_var(name, mod) 74 75 def record_module_access(self, mod, name, val): 76 if isinstance(val, ModuleType): 77 self.name_to_modrec[mod.__name__].accessed_attrs[name] = self._add_mod(val) 78 return 79 80 if mod.__name__ in self.name_to_modrec: 81 self.name_to_modrec[mod.__name__].accessed_attrs[name] = val 82 83 def get_record(self): 84 return ExecutionRecord( 85 self.code, 86 ExecutionRecorder._resolve_modules(self.globals), 87 ExecutionRecorder._resolve_modules(self.locals), 88 self.builtins.copy(), 89 self.code_options.copy(), 90 ) 91 92 def _add_mod(self, mod): 93 if mod.__name__ not in self.name_to_modrec: 94 self.name_to_modrec[mod.__name__] = ModuleRecord(mod) 95 96 return self.name_to_modrec[mod.__name__] 97 98 # Convert ModuleRecords -> DummyModule tree 99 @classmethod 100 def _resolve_modules(cls, vars): 101 def resolve_module(var): 102 if not isinstance(var, ModuleRecord): 103 return var 104 105 dummy_mod = DummyModule(var.module.__name__) 106 for attr_name, attr_value in var.accessed_attrs.items(): 107 attr_value = resolve_module(attr_value) 108 dummy_mod.__setattr__(attr_name, attr_value) 109 110 return dummy_mod 111 112 return {k: resolve_module(v) for k, v in vars.items()} 113