# xref: /aosp_15_r20/external/pytorch/torch/_inductor/aoti_eager.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
import json
import logging
import os
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple
from unittest import mock

import torch
import torch._export
from torch._inductor.utils import is_cpu_device

from .runtime.runtime_utils import cache_dir


log = logging.getLogger(__name__)


def aoti_eager_cache_dir(namespace: str, device: str) -> Path:
    """Return the directory holding AOTI eager kernels and configs for this namespace and device."""
    return Path(cache_dir()) / "aoti_eager" / namespace / device

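# For orientation (not part of the original source): cache_dir() resolves to the Inductor
# cache root, i.e. TORCHINDUCTOR_CACHE_DIR if set, otherwise a per-user temp directory, so
# an illustrative result would be:
#
#   aoti_eager_cache_dir("aten", "cpu")
#   # -> Path("/tmp/torchinductor_<user>/aoti_eager/aten/cpu")   (path is illustrative only)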

def aoti_eager_op_conf_lock(op_func_name_with_overload: str) -> Any:
    """Return a file lock guarding the per-op JSON config file."""
    from filelock import FileLock

    # Avoid circular import
    from torch._inductor.codecache import get_lock_dir, LOCK_TIMEOUT

    op_conf_lock_file = f"{op_func_name_with_overload}.lock"
    lock_dir = get_lock_dir()
    return FileLock(os.path.join(lock_dir, op_conf_lock_file), timeout=LOCK_TIMEOUT)


def load_aoti_eager_cache(
    ns: str, op_func_name_with_overload: str, device_type: str
) -> List[Optional[Dict[str, Any]]]:
    """Load the cached kernel entries for the given op and device.

    Returns an empty list if the per-op JSON config or any referenced kernel
    library is missing, or if loading fails for any other reason.
    """
    device_kernel_cache = aoti_eager_cache_dir(ns, device_type)
    op_conf = device_kernel_cache / f"{op_func_name_with_overload}.json"
    if not op_conf.exists():
        return []

    try:
        with aoti_eager_op_conf_lock(op_func_name_with_overload):
            with open(op_conf) as f:
                json_data = json.load(f)
                for item in json_data:
                    # Get the absolute path of the kernel library
                    kernel_lib_abs_path = device_kernel_cache / item["kernel_path"]
                    item["kernel_path"] = kernel_lib_abs_path.as_posix()

                    # Check if the kernel library exists
                    if not kernel_lib_abs_path.exists():
                        return []

                    for metadata in item["meta_info"]:
                        if metadata.get("is_dynamic"):
                            raise NotImplementedError(
                                "Only support static shape for now"
                            )
                        if (
                            "device_type" in metadata
                            and metadata["device_type"] == "cpu"
                        ):
                            metadata["device_index"] = -1
                        # Convert serialized dtype/layout/memory_format strings back to
                        # torch objects, e.g. "torch.float32" -> torch.float32.
                        for dtype_key in ["dtype", "dtype_value"]:
                            if dtype_key in metadata:
                                metadata[dtype_key] = getattr(
                                    torch, metadata[dtype_key].split(".")[-1]
                                )
                        if "layout_value" in metadata:
                            metadata["layout_value"] = getattr(
                                torch, metadata["layout_value"].split(".")[-1]
                            )
                        if "memory_format_value" in metadata:
                            metadata["memory_format_value"] = getattr(
                                torch, metadata["memory_format_value"].split(".")[-1]
                            )

                return json_data
    except Exception as e:
        err_msg = f"Failed to load aoti eager cache: {e}"
        log.exception(err_msg)
        return []

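# The per-op JSON config read above is a list of entries of roughly the following shape.
# Field names come from this module; the concrete values are illustrative only, and each
# meta_info dict may carry further keys such as "dispatch_key_set" or "scalar_value".
#
#   [
#       {
#           "kernel_path": "lib/<inductor-cache-subdir>/<kernel>.so",
#           "meta_info": [
#               {
#                   "is_dynamic": false,
#                   "device_type": "cpu",
#                   "device_index": -1,
#                   "dtype": "torch.float32",
#                   "sizes": [2, 3],
#                   "strides": [3, 1],
#                   "requires_grad": false,
#                   "arg_order": 0
#               }
#           ]
#       }
#   ]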

def supported_builtin_dtype_torch_dtype() -> Dict[type, torch.dtype]:
    """Map the supported builtin scalar types to their torch dtypes."""
    return {int: torch.int32, float: torch.float, bool: torch.bool}


def supported_scalar_types() -> Tuple[type, ...]:
    """Return the builtin scalar types accepted as op arguments."""
    type_to_torch_dtype = supported_builtin_dtype_torch_dtype()
    return tuple(type_to_torch_dtype.keys())


def extract_tensor_metadata(dynamic: bool, input: torch.Tensor) -> Dict[str, Any]:
    """Record device, dtype, sizes, strides, requires_grad, and dispatch keys of a tensor argument."""
    metadata: Dict[str, Any] = {}
    metadata["is_dynamic"] = dynamic

    assert isinstance(input, torch.Tensor)
    metadata["device_type"] = f"{input.device.type}"
    if is_cpu_device([input]):
        metadata["device_index"] = -1
    else:
        metadata["device_index"] = input.device.index
    metadata["dtype"] = f"{input.dtype}"
    metadata["sizes"] = list(input.size())
    metadata["strides"] = list(input.stride())
    metadata["requires_grad"] = input.requires_grad
    metadata["dispatch_key_set"] = torch._C._dispatch_keys(input).raw_repr()
    return metadata

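# For example (illustrative, assuming a contiguous CPU float32 tensor):
#
#   t = torch.randn(2, 3)
#   extract_tensor_metadata(False, t)
#   # -> {"is_dynamic": False, "device_type": "cpu", "device_index": -1,
#   #     "dtype": "torch.float32", "sizes": [2, 3], "strides": [3, 1],
#   #     "requires_grad": False, "dispatch_key_set": <raw dispatch key repr>}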

def extract_tensor_list_metadata(
    dynamic: bool,
    input: List[torch.Tensor],
) -> Dict[str, Any]:
    """Record per-tensor metadata for a list-of-tensors argument."""
    metadata_list = []
    for item in input:
        assert isinstance(item, torch.Tensor)
        metadata_list.append(extract_tensor_metadata(dynamic, item))

    metadata: Dict[str, Any] = {}
    metadata["tensor_list"] = metadata_list
    return metadata


def extract_scalar_metadata(device_type: str, input: Any) -> Dict[str, Any]:
    """Record metadata for a builtin scalar argument (int, float, or bool)."""
    assert isinstance(input, supported_scalar_types())
    metadata: Dict[str, Any] = {}
    metadata["is_dynamic"] = False
    # Scalar tensor
    metadata["device_type"] = device_type
    metadata["device_index"] = -1 if device_type == "cpu" else 0
    type_to_torch_dtype = supported_builtin_dtype_torch_dtype()
    metadata["dtype"] = f"{type_to_torch_dtype[type(input)]}"
    metadata["scalar_value"] = input
    return metadata


def extract_string_metadata(input: str) -> Dict[str, Any]:
    """Record metadata for a string argument."""
    assert isinstance(input, str)
    metadata: Dict[str, Any] = {}
    metadata["string_value"] = input
    return metadata


def extract_dtype_metadata(input: torch.dtype) -> Dict[str, Any]:
    """Record metadata for a torch.dtype argument."""
    assert isinstance(input, torch.dtype)
    metadata: Dict[str, Any] = {}
    metadata["dtype_value"] = f"{input}"
    return metadata


def extract_device_metadata(input: torch.device) -> Dict[str, Any]:
    """Record metadata for a torch.device argument."""
    assert isinstance(input, torch.device)
    metadata: Dict[str, Any] = {}
    metadata["device_type_value"] = f"{input.type}"
    metadata["device_index_value"] = input.index
    return metadata


def extract_layout_metadata(input: torch.layout) -> Dict[str, Any]:
    """Record metadata for a torch.layout argument."""
    assert isinstance(input, torch.layout)
    metadata: Dict[str, Any] = {}
    metadata["layout_value"] = f"{input}"
    return metadata

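# Each extract_* helper above serializes torch objects to strings (e.g. "torch.float32",
# "torch.strided") so the metadata stays JSON-serializable; load_aoti_eager_cache() reverses
# this with getattr(torch, value.split(".")[-1]).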

def aoti_compile_with_persistent_cache(
    ns: str,
    op_func_name_with_overload: str,
    device_type: str,
    dynamic: bool,
    f: Callable[..., Any],
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    *,
    dynamic_shapes: Optional[Dict[str, Any]] = None,
    options: Optional[Dict[str, Any]] = None,
    remove_runtime_assertions: bool = False,
    disable_constraint_solver: bool = False,
) -> str:
    """
    Compile the given function with persistent cache for AOTI eager mode.

    The compiled kernel library is stored under the persistent cache directory and its
    metadata is appended to the per-op JSON config. Returns the kernel library path on
    success, or an empty string on failure.
    """
    assert not dynamic, "Only support static shape for now"
    flattened_inputs = list(args) + list(kwargs.values())
    if not all(
        isinstance(
            input,
            (
                supported_scalar_types(),
                torch.Tensor,
                list,
                str,
                torch.dtype,
                torch.device,
                torch.layout,
            ),
        )
        for input in flattened_inputs
    ):
        err_msg = f"Unsupported input types: {flattened_inputs}"
        log.exception(err_msg)
        raise NotImplementedError(err_msg)

    for input in flattened_inputs:
        if isinstance(input, list) and not all(
            isinstance(item, torch.Tensor) for item in input
        ):
            err_msg = f"_impl_with_aoti_compile encounters unsupported input types: {flattened_inputs}"
            log.exception(err_msg)
            raise NotImplementedError(err_msg)

    persistent_cache = aoti_eager_cache_dir(ns, device_type)
    if not persistent_cache.exists():
        persistent_cache.mkdir(parents=True)

    persistent_cache_lib = persistent_cache / "lib"
    if not persistent_cache_lib.exists():
        persistent_cache_lib.mkdir()

    # Point the Inductor cache at the persistent cache's lib/ directory so the compiled
    # kernel library lands inside the persistent cache.
    with mock.patch.dict(
        os.environ,
        {"TORCHINDUCTOR_CACHE_DIR": persistent_cache_lib.absolute().as_posix()},
    ):
        try:
            kernel_lib_path = torch._export.aot_compile(
                f,
                args,
                kwargs,
                dynamic_shapes=dynamic_shapes,
                remove_runtime_assertions=remove_runtime_assertions,
                disable_constraint_solver=disable_constraint_solver,
                # Some operations may have non-Tensor parameters like int, float, bool. These
                # non-Tensor parameters will not be the input of the graph. Therefore, we do
                # not need to keep the same signature.
                same_signature=False,
            )

            kernel_metadata_items = []

            for idx, input in enumerate(flattened_inputs):
                if isinstance(input, torch.Tensor):
                    metadata = extract_tensor_metadata(dynamic, input)
                elif isinstance(input, list):
                    assert all(isinstance(item, torch.Tensor) for item in input)
                    metadata = extract_tensor_list_metadata(dynamic, input)
                elif isinstance(input, supported_scalar_types()):
                    metadata = extract_scalar_metadata(device_type, input)
                elif isinstance(input, str):
                    metadata = extract_string_metadata(input)
                elif isinstance(input, torch.dtype):
                    metadata = extract_dtype_metadata(input)
                elif isinstance(input, torch.device):
                    metadata = extract_device_metadata(input)
                elif isinstance(input, torch.layout):
                    metadata = extract_layout_metadata(input)
                else:
                    raise NotImplementedError(f"Unsupported input type: {type(input)}")

                metadata["arg_order"] = idx
                kernel_metadata_items.append(metadata)

            kernel_meta_info: Dict[str, Any] = {}
            kernel_meta_info["meta_info"] = kernel_metadata_items
            kernel_meta_info["kernel_path"] = (
                Path(kernel_lib_path).relative_to(persistent_cache).as_posix()
            )

            json_data = []
            update_json = True
            op_conf = persistent_cache / f"{op_func_name_with_overload}.json"
            # Read the existing config if present; otherwise create an empty one.
            mode = "r" if op_conf.exists() else "w"
            with aoti_eager_op_conf_lock(op_func_name_with_overload):
                with open(op_conf, mode) as op_conf_file:
                    try:
                        json_data = json.load(op_conf_file)
                    except Exception:
                        # Empty or newly created config: start from an empty list.
                        json_data = []

                    assert isinstance(json_data, list)
                    for item in json_data:
                        assert isinstance(item, dict)
                        # Same kernel meta info already exists in the json file
                        if item["meta_info"] == kernel_metadata_items:
                            update_json = False
                            break

                if update_json:
                    json_data.append(kernel_meta_info)
                    with open(op_conf, "w") as op_conf_file:
                        json.dump(json_data, op_conf_file, indent=4)

            return kernel_lib_path
        except Exception as e:
            err_msg = f"Failed to compile {op_func_name_with_overload}: {e}"
            log.exception(err_msg)
            return ""
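

# A minimal usage sketch (not part of the original source; the op name, overload string,
# and inputs are illustrative): compile a single aten op for CPU into the persistent cache
# and then read the recorded entries back.
#
#   import torch
#   from torch._inductor.aoti_eager import (
#       aoti_compile_with_persistent_cache,
#       load_aoti_eager_cache,
#   )
#
#   kernel_lib = aoti_compile_with_persistent_cache(
#       "aten",
#       "abs.default",
#       "cpu",
#       False,                      # static shapes only, per the assert above
#       torch.ops.aten.abs,
#       (torch.randn(3, 4),),
#       {},
#   )
#   entries = load_aoti_eager_cache("aten", "abs.default", "cpu")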