xref: /aosp_15_r20/external/pytorch/torch/_dynamo/replay_record.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import dataclasses
3from dataclasses import field
4from types import CodeType, ModuleType
5from typing import Any, Dict
6
7from torch.utils._import_utils import import_dill
8
9
# Optional serialization backend used by ExecutionRecord.dump/load.
# NOTE(review): `import_dill()` presumably returns None when the `dill` package
# is not installed — both ExecutionRecord.dump and ExecutionRecord.load assert
# `dill is not None` before using it, so this handle must be checked before use.
dill = import_dill()
11
12
@dataclasses.dataclass
class ModuleRecord:
    """Bookkeeping entry for one module seen by ``ExecutionRecorder``.

    Pairs the live module object with the attributes that were read from it.
    An attribute value may itself be another ``ModuleRecord`` when the value
    read from the module was also a module.
    """

    # The real module object being tracked.
    module: ModuleType
    # Attribute name -> value read from ``module``; each instance gets its own dict.
    accessed_attrs: Dict[str, Any] = dataclasses.field(default_factory=dict)
17
18
@dataclasses.dataclass
class DummyModule:
    """Lightweight, module-like stand-in placed in a finalized execution record.

    ``ExecutionRecorder._resolve_modules`` creates one per ``ModuleRecord`` and
    attaches the recorded attributes to it as plain instance attributes.
    Instances expose ``__name__`` (mirroring ``name``) the way real modules do.
    """

    name: str
    is_torch: bool = False

    @property
    def __name__(self):
        """Return the stored module name under the attribute modules use."""
        module_name = self.name
        return module_name
27
28
@dataclasses.dataclass
class ExecutionRecord:
    """Serializable snapshot of a code object and the namespaces it ran with.

    Built by ``ExecutionRecorder.get_record``. It is persisted with ``dump`` and
    restored with ``load``; both go through the optional module-level ``dill``
    handle and fail with an ``AssertionError`` when it is ``None``.
    """

    code: CodeType
    # Namespaces and compile options captured alongside ``code``; each field
    # defaults to a fresh, empty dict.
    globals: Dict[str, Any] = field(default_factory=dict)
    locals: Dict[str, Any] = field(default_factory=dict)
    builtins: Dict[str, Any] = field(default_factory=dict)
    code_options: Dict[str, Any] = field(default_factory=dict)

    def dump(self, f):
        """Serialize this record into the file object ``f`` via ``dill.dump``."""
        assert dill is not None, "replay_record requires `pip install dill`"
        serialize = dill.dump
        serialize(self, f)

    @classmethod
    def load(cls, f):
        """Deserialize and return the object stored in the file object ``f``.

        The result is whatever ``dill.load`` produces; it is not checked to be
        an instance of ``cls``.

        WARNING: dill builds on pickle and can run arbitrary code while
        loading — only load records that come from a trusted source.
        """
        assert dill is not None, "replay_record requires `pip install dill`"
        restored = dill.load(f)
        return restored
45
46
@dataclasses.dataclass
class ExecutionRecorder:
    """Accumulates the execution environment of a code object for later replay.

    Plain values are stored as-is. Module objects are replaced by a shared
    ``ModuleRecord``: one per module ``__name__``, kept in ``name_to_modrec``.
    Each record also remembers which attributes were read from its module.
    ``get_record`` snapshots everything into an ``ExecutionRecord``, in which
    every ``ModuleRecord`` reachable from ``globals``/``locals`` has been
    converted to a ``DummyModule``.
    """

    # Class constant (not a dataclass field: it has no annotation). It is not
    # referenced inside this class — presumably a prefix callers use when
    # naming entries passed to add_local_mod; confirm against callers.
    LOCAL_MOD_PREFIX = "___local_mod_"

    code: CodeType
    globals: Dict[str, Any] = field(default_factory=dict)
    locals: Dict[str, Any] = field(default_factory=dict)
    builtins: Dict[str, Any] = field(default_factory=dict)
    code_options: Dict[str, Any] = field(default_factory=dict)
    # Module ``__name__`` -> ModuleRecord for every module seen so far.
    name_to_modrec: Dict[str, "ModuleRecord"] = field(default_factory=dict)

    def _record_value(self, var):
        """Return the ModuleRecord for ``var`` if it is a module, else ``var`` itself."""
        return self._add_mod(var) if isinstance(var, ModuleType) else var

    def add_local_var(self, name, var):
        """Record local variable ``name``; modules are stored as their ModuleRecord."""
        self.locals[name] = self._record_value(var)

    def add_global_var(self, name, var):
        """Record global variable ``name``; modules are stored as their ModuleRecord."""
        self.globals[name] = self._record_value(var)

    def add_local_mod(self, name, mod):
        """Record the module ``mod`` under ``name``.

        NOTE(review): despite the method name, the module is stored in
        ``globals`` (via add_global_var), not ``locals``. This looks
        deliberate — presumably replay looks these names up in the frame
        globals — so confirm with callers before changing it.
        """
        assert isinstance(mod, ModuleType)

        self.add_global_var(name, mod)

    def record_module_access(self, mod, name, val):
        """Remember that attribute ``name`` of module ``mod`` evaluated to ``val``.

        Accesses on modules that were never recorded are ignored, whether or
        not ``val`` is itself a module. A module-valued ``val`` is stored as its
        own ModuleRecord, so later accesses on it are tracked too.
        """
        parent = self.name_to_modrec.get(mod.__name__)
        if parent is None:
            # Previously a module-valued access on an untracked parent raised
            # KeyError, while non-module accesses were silently skipped; both
            # paths now skip consistently.
            return
        parent.accessed_attrs[name] = self._record_value(val)

    def get_record(self):
        """Snapshot the recorder into an ``ExecutionRecord``.

        ``globals`` and ``locals`` are rebuilt with every ModuleRecord converted
        to a DummyModule. ``builtins`` and ``code_options`` are shallow-copied.
        ``code`` is shared, not copied.
        """
        return ExecutionRecord(
            self.code,
            ExecutionRecorder._resolve_modules(self.globals),
            ExecutionRecorder._resolve_modules(self.locals),
            self.builtins.copy(),
            self.code_options.copy(),
        )

    def _add_mod(self, mod):
        """Return the ModuleRecord for ``mod``, creating it the first time it is seen.

        Records are keyed by ``mod.__name__``, so all references to a module with
        that name share one record and one set of accessed attributes.
        """
        mod_name = mod.__name__
        if mod_name not in self.name_to_modrec:
            self.name_to_modrec[mod_name] = ModuleRecord(mod)
        return self.name_to_modrec[mod_name]

    # Convert ModuleRecords -> DummyModule tree
    @classmethod
    def _resolve_modules(cls, vars):
        """Return a copy of ``vars`` with every ModuleRecord turned into a DummyModule.

        Conversion recurses through each record's ``accessed_attrs``. Non-record
        values are returned unchanged. Each ModuleRecord is converted at most
        once per call, so a record reached several times maps to a single
        DummyModule. Cyclic references (module A's attribute is module B and
        vice versa) therefore terminate instead of recursing without bound.
        """
        resolved = {}  # id(ModuleRecord) -> DummyModule, scoped to this call

        def resolve_module(var):
            if not isinstance(var, ModuleRecord):
                return var

            key = id(var)
            if key in resolved:
                return resolved[key]

            dummy_mod = DummyModule(var.module.__name__)
            # Register before recursing so a cycle back to this record reuses it.
            resolved[key] = dummy_mod
            for attr_name, attr_value in var.accessed_attrs.items():
                setattr(dummy_mod, attr_name, resolve_module(attr_value))

            return dummy_mod

        return {k: resolve_module(v) for k, v in vars.items()}
113