1import json 2import logging 3import os 4from pathlib import Path 5from typing import Any, Callable, Dict, List, Optional, Tuple 6from unittest import mock 7 8import torch 9import torch._export 10from torch._inductor.utils import is_cpu_device 11 12from .runtime.runtime_utils import cache_dir 13 14 15log = logging.getLogger(__name__) 16 17 18def aoti_eager_cache_dir(namespace: str, device: str) -> Path: 19 return Path(cache_dir()) / "aoti_eager" / namespace / device 20 21 22def aoti_eager_op_conf_lock(op_func_name_with_overload: str) -> Any: 23 from filelock import FileLock 24 25 # Avoid circular import 26 from torch._inductor.codecache import get_lock_dir, LOCK_TIMEOUT 27 28 op_conf_lock_file = f"{op_func_name_with_overload}.lock" 29 lock_dir = get_lock_dir() 30 return FileLock(os.path.join(lock_dir, op_conf_lock_file), timeout=LOCK_TIMEOUT) 31 32 33def load_aoti_eager_cache( 34 ns: str, op_func_name_with_overload: str, device_type: str 35) -> List[Optional[Dict[str, Any]]]: 36 device_kernel_cache = aoti_eager_cache_dir(ns, device_type) 37 op_conf = device_kernel_cache / f"{op_func_name_with_overload}.json" 38 if not op_conf.exists(): 39 return [] 40 41 try: 42 with aoti_eager_op_conf_lock(op_func_name_with_overload): 43 with open(op_conf) as f: 44 json_data = json.load(f) 45 for item in json_data: 46 # Get absolution path for kernel library 47 kernel_lib_abs_path = device_kernel_cache / item["kernel_path"] 48 item["kernel_path"] = kernel_lib_abs_path.as_posix() 49 50 # Check if the kernel library exists 51 if not kernel_lib_abs_path.exists(): 52 return [] 53 54 for metadata in item["meta_info"]: 55 if metadata.get("is_dynamic"): 56 raise NotImplementedError( 57 "Only support static shape for now" 58 ) 59 if ( 60 "device_type" in metadata 61 and metadata["device_type"] == "cpu" 62 ): 63 metadata["device_index"] = -1 64 for dtype_key in ["dtype", "dtype_value"]: 65 if dtype_key in metadata: 66 metadata[dtype_key] = getattr( 67 torch, metadata[dtype_key].split(".")[-1] 68 ) 69 if "layout_value" in metadata: 70 metadata["layout_value"] = getattr( 71 torch, metadata["layout_value"].split(".")[-1] 72 ) 73 if "memory_format_value" in metadata: 74 metadata["memory_format_value"] = getattr( 75 torch, metadata["memory_format_value"].split(".")[-1] 76 ) 77 78 return json_data 79 except Exception as e: 80 err_msg = f"Failed to load aoti eager cache: {e}" 81 log.exception(err_msg) 82 return [] 83 84 85def supported_builtin_dtype_torch_dtype() -> Dict[type, torch.dtype]: 86 return {int: torch.int32, float: torch.float, bool: torch.bool} 87 88 89def supported_scalar_types() -> Tuple[type, ...]: 90 type_to_torch_dtype = supported_builtin_dtype_torch_dtype() 91 return tuple(type_to_torch_dtype.keys()) 92 93 94def extract_tensor_metadata(dynamic: bool, input: torch.Tensor) -> Dict[str, Any]: 95 metadata: Dict[str, Any] = {} 96 metadata["is_dynamic"] = dynamic 97 98 assert isinstance(input, torch.Tensor) 99 metadata["device_type"] = f"{input.device.type}" 100 if is_cpu_device([input]): 101 metadata["device_index"] = -1 102 else: 103 metadata["device_index"] = input.device.index 104 metadata["dtype"] = f"{input.dtype}" 105 metadata["sizes"] = list(input.size()) 106 metadata["strides"] = list(input.stride()) 107 metadata["requires_grad"] = input.requires_grad 108 metadata["dispatch_key_set"] = torch._C._dispatch_keys(input).raw_repr() 109 return metadata 110 111 112def extract_tensor_list_metadata( 113 dynamic: bool, 114 input: List[torch.Tensor], 115) -> Dict[str, Any]: 116 metadata_list = [] 117 for item in input: 118 assert isinstance(item, torch.Tensor) 119 metadata_list.append(extract_tensor_metadata(dynamic, item)) 120 121 metadata: Dict[str, Any] = {} 122 metadata["tensor_list"] = metadata_list 123 return metadata 124 125 126def extract_scalar_metadata(device_type: str, input: Any) -> Dict[str, Any]: 127 assert isinstance(input, supported_scalar_types()) 128 metadata: Dict[str, Any] = {} 129 metadata["is_dynamic"] = False 130 # Scalar tensor 131 metadata["device_type"] = device_type 132 metadata["device_index"] = -1 if device_type == "cpu" else 0 133 type_to_torch_dtype = supported_builtin_dtype_torch_dtype() 134 metadata["dtype"] = f"{type_to_torch_dtype[type(input)]}" 135 metadata["scalar_value"] = input 136 return metadata 137 138 139def extract_string_metadata(input: str) -> Dict[str, Any]: 140 assert isinstance(input, str) 141 metadata: Dict[str, Any] = {} 142 metadata["string_value"] = input 143 return metadata 144 145 146def extract_dtype_metadata(input: torch.dtype) -> Dict[str, Any]: 147 assert isinstance(input, torch.dtype) 148 metadata: Dict[str, Any] = {} 149 metadata["dtype_value"] = f"{input}" 150 return metadata 151 152 153def extract_device_metadata(input: torch.device) -> Dict[str, Any]: 154 assert isinstance(input, torch.device) 155 metadata: Dict[str, Any] = {} 156 metadata["device_type_value"] = f"{input.type}" 157 metadata["device_index_value"] = input.index 158 return metadata 159 160 161def extract_layout_metadata(input: torch.layout) -> Dict[str, Any]: 162 assert isinstance(input, torch.layout) 163 metadata: Dict[str, Any] = {} 164 metadata["layout_value"] = f"{input}" 165 return metadata 166 167 168def aoti_compile_with_persistent_cache( 169 ns: str, 170 op_func_name_with_overload: str, 171 device_type: str, 172 dynamic: bool, 173 f: Callable[..., Any], 174 args: Tuple[Any], 175 kwargs: Dict[str, Any], 176 *, 177 dynamic_shapes: Optional[Dict[str, Any]] = None, 178 options: Optional[Dict[str, Any]] = None, 179 remove_runtime_assertions: bool = False, 180 disable_constraint_solver: bool = False, 181) -> str: 182 """ 183 Compile the given function with persistent cache for AOTI eager mode. 184 """ 185 assert not dynamic, "Only support static shape for now" 186 flattened_inputs = list(args) + list(kwargs.values()) 187 if not all( 188 isinstance( 189 input, 190 ( 191 supported_scalar_types(), 192 torch.Tensor, 193 list, 194 str, 195 torch.dtype, 196 torch.device, 197 torch.layout, 198 ), 199 ) 200 for input in flattened_inputs 201 ): 202 err_msg = f"Unsupported input types: {flattened_inputs}" 203 log.exception(err_msg) 204 raise NotImplementedError(err_msg) 205 206 for input in flattened_inputs: 207 if isinstance(input, list) and not all( 208 isinstance(item, torch.Tensor) for item in input 209 ): 210 err_msg = f"_impl_with_aoti_compile encounters unsupported input types: {flattened_inputs}" 211 log.exception(err_msg) 212 raise NotImplementedError(err_msg) 213 214 persistent_cache = aoti_eager_cache_dir(ns, device_type) 215 if not persistent_cache.exists(): 216 persistent_cache.mkdir(parents=True) 217 218 persistent_cache_lib = persistent_cache / "lib" 219 if not persistent_cache_lib.exists(): 220 persistent_cache_lib.mkdir() 221 222 with mock.patch.dict( 223 os.environ, 224 {"TORCHINDUCTOR_CACHE_DIR": persistent_cache_lib.absolute().as_posix()}, 225 ): 226 try: 227 kernel_lib_path = torch._export.aot_compile( 228 f, 229 args, 230 kwargs, 231 dynamic_shapes=dynamic_shapes, 232 remove_runtime_assertions=remove_runtime_assertions, 233 disable_constraint_solver=disable_constraint_solver, 234 # Some operations may have non-Tensor parameters like int, float, bool. These 235 # non-Tensor parameters will not be the input of the graph. Therefore, we do 236 # need to keep the same signature. 237 same_signature=False, 238 ) 239 240 kernel_metadata_items = [] 241 242 for idx, input in enumerate(flattened_inputs): 243 if isinstance(input, torch.Tensor): 244 metadata = extract_tensor_metadata(dynamic, input) 245 elif isinstance(input, list): 246 assert all(isinstance(item, torch.Tensor) for item in input) 247 metadata = extract_tensor_list_metadata(dynamic, input) 248 elif isinstance(input, supported_scalar_types()): 249 metadata = extract_scalar_metadata(device_type, input) 250 elif isinstance(input, str): 251 metadata = extract_string_metadata(input) 252 elif isinstance(input, torch.dtype): 253 metadata = extract_dtype_metadata(input) 254 elif isinstance(input, torch.device): 255 metadata = extract_device_metadata(input) 256 elif isinstance(input, torch.layout): 257 metadata = extract_layout_metadata(input) 258 else: 259 raise NotImplementedError(f"Unsupported input type: {type(input)}") 260 261 metadata["arg_order"] = idx 262 kernel_metadata_items.append(metadata) 263 264 kernel_meta_info: Dict[str, Any] = {} 265 kernel_meta_info["meta_info"] = kernel_metadata_items 266 kernel_meta_info["kernel_path"] = ( 267 Path(kernel_lib_path).relative_to(persistent_cache).as_posix() 268 ) 269 270 json_data = [] 271 update_json = True 272 op_conf = persistent_cache / f"{op_func_name_with_overload}.json" 273 mode = "r" if op_conf.exists() else "w" 274 with aoti_eager_op_conf_lock(op_func_name_with_overload): 275 with open(op_conf, mode) as op_conf_file: 276 try: 277 json_data = json.load(op_conf_file) 278 except Exception as e: 279 json_data = [] 280 281 assert isinstance(json_data, list) 282 for item in json_data: 283 assert isinstance(item, dict) 284 # Same kernel meta info already exists in the json file 285 if item["meta_info"] == kernel_metadata_items: 286 update_json = False 287 break 288 289 if update_json: 290 json_data.append(kernel_meta_info) 291 with open(op_conf, "w") as op_conf_file: 292 json.dump(json_data, op_conf_file, indent=4) 293 294 return kernel_lib_path 295 except Exception as e: 296 err_msg = f"Failed to compile {op_func_name_with_overload}: {e}" 297 log.exception(err_msg) 298 return "" 299