1import functools 2import logging 3import math 4import os 5from collections import Counter, defaultdict 6from functools import partial 7from typing import Any, Dict, Generator, Iterable, Tuple 8 9import torch 10from torch.testing import make_tensor 11from torch.utils import _pytree as pytree 12from torch.utils._python_dispatch import TorchDispatchMode 13from torch.utils._pytree import tree_map 14 15 16log = logging.getLogger(__name__) 17 18OP_INP_DIRECTORY = os.path.join(os.path.dirname(__file__), "operator_inp_logs") 19 20TIMM_DIR = os.path.join(OP_INP_DIRECTORY, "timm_train") 21HF_DIR = os.path.join(OP_INP_DIRECTORY, "hf_train") 22TORCHBENCH_DIR = os.path.join(OP_INP_DIRECTORY, "torchbench_train") 23 24aten = torch.ops.aten 25tensor_type = torch._C.TensorType.get() 26 27dtype_abbrs = { 28 torch.bfloat16: "bf16", 29 torch.float64: "f64", 30 torch.float32: "f32", 31 torch.float16: "f16", 32 torch.complex32: "c32", 33 torch.complex64: "c64", 34 torch.complex128: "c128", 35 torch.int8: "i8", 36 torch.int16: "i16", 37 torch.int32: "i32", 38 torch.int64: "i64", 39 torch.bool: "b8", 40 torch.uint8: "u8", 41} 42 43dtype_abbrs_parsing = {value: key for key, value in dtype_abbrs.items()} 44 45 46def truncate_inp(arg): 47 if arg in dtype_abbrs: 48 return dtype_abbrs[arg] 49 elif isinstance(arg, torch.device): 50 return arg.type 51 else: 52 return arg 53 54 55# Serialize Function Call 56class FuncCallWrapper: 57 def __init__(self, call, *args, **kwargs): 58 self.call = call 59 self.args = tree_map(truncate_inp, args) 60 self.kwargs = tree_map(truncate_inp, kwargs) if kwargs is not None else {} 61 62 def __repr__(self): 63 args = ", ".join([repr(arg) for arg in self.args]) 64 kwargs = "".join( 65 [f", {str(key)}={value}" for key, value in self.kwargs.items()] 66 ) 67 out = f"{self.call}({args}{kwargs})".strip('"') 68 # f strings introduce quotations we dont want 69 for key in dtype_abbrs_parsing: 70 out = out.replace(f"'{key}'", key) 71 return out 72 73 74def serialize_sparse_tensor(e): 75 if isinstance(e, torch._subclasses.FakeTensor): 76 return FuncCallWrapper("ST", list(e.shape), e.dtype, e.layout, e.is_coalesced()) 77 else: 78 return FuncCallWrapper( 79 "ST", list(e.shape), e.dtype, e.layout, e.is_coalesced(), e._nnz() 80 ) 81 82 83def deserialize_sparse_tensor(size, dtype, layout, is_coalesced, nnz=None): 84 raise NotImplementedError 85 86 87def deserialize_tensor(size, dtype, stride=None): 88 if stride is not None: 89 out = torch.empty_strided(size, stride, dtype=dtype) 90 else: 91 out = torch.empty(size, dtype=dtype) 92 try: 93 out.copy_(make_tensor(size, dtype=dtype, device="cpu")) 94 except Exception as e: 95 print(e) 96 return out 97 return out 98 99 100def serialize_tensor(e): 101 if not e.is_contiguous(): 102 return FuncCallWrapper("T", list(e.shape), e.dtype, stride=e.stride()) 103 else: 104 return FuncCallWrapper("T", list(e.shape), e.dtype) 105 106 107def serialize_torch_args(e): 108 if isinstance(e, torch.Tensor): 109 if e.is_sparse: 110 return serialize_sparse_tensor(e) 111 return serialize_tensor(e) 112 else: 113 return truncate_inp(e) 114 115 116def contains_tensor(elems): 117 for elem in pytree.tree_leaves(elems): 118 if isinstance(elem, torch.Tensor): 119 return True 120 return False 121 122 123def skip_args(elems): 124 for i in pytree.tree_leaves(elems): 125 # only shows up in constructors and ops like that 126 if isinstance(i, (torch.memory_format, torch.storage.UntypedStorage)): 127 return True 128 return False 129 130 131def contains_tensor_types(type): 132 return type.isSubtypeOf(tensor_type) or any( 133 contains_tensor_types(e) for e in type.containedTypes() 134 ) 135 136 137@functools.lru_cache(None) 138def non_compute_operator(op): 139 schema = op._schema 140 141 # skip constructors 142 if not any(contains_tensor_types(arg.type) for arg in schema.arguments): 143 return True 144 if "_like" in op.name(): 145 return True 146 147 # allow in place writes 148 if schema.is_mutable: 149 return False 150 151 tensor_inps = [arg for arg in schema.arguments if arg.type is tensor_type] 152 tensor_outputs = [ret for ret in schema.returns if ret.type is tensor_type] 153 154 # skip aliasing unless there are multiple outputs 155 if len(tensor_outputs) != 1: 156 return False 157 158 for inp in tensor_inps: 159 if inp.alias_info and tensor_outputs[0].alias_info: 160 if inp.alias_info.before_set.intersection( 161 tensor_outputs[0].alias_info.after_set 162 ): 163 return True 164 165 return False 166 167 168class OperatorInputsMode(TorchDispatchMode): 169 def __init__(self, func_db=None): 170 self.func_db = defaultdict(Counter) if func_db is None else func_db 171 172 def __torch_dispatch__(self, func_overload, types, args=(), kwargs=None): 173 kwargs = kwargs if kwargs else {} 174 arg_meta, kwarg_meta = tree_map(serialize_torch_args, (args, kwargs)) 175 176 out = func_overload(*args, **kwargs) 177 178 inps = (args, kwargs) 179 if contains_tensor(inps) and not skip_args(inps) and contains_tensor(out): 180 serialized_str = repr((arg_meta, kwarg_meta)) 181 self.func_db[str(func_overload)][serialized_str] += 1 182 183 return out 184 185 def log_to_file(self, output_filename, *, skip_non_compute_operators=True): 186 sorted_operators = sorted(self.func_db.keys()) 187 with open(output_filename, "w") as f: 188 for operator in sorted_operators: 189 if skip_non_compute_operators and non_compute_operator(eval(operator)): 190 continue 191 f.write(f"Operator: {operator}\n") 192 operator_inputs = self.func_db[operator] 193 for inps, count in operator_inputs.items(): 194 f.write(f"cnt: {count}, ") 195 # repr will add quotation marks around the dtype strings 196 for dtype_abbr in dtype_abbrs.values(): 197 inps = inps.replace("'" + dtype_abbr + "'", dtype_abbr) 198 f.write(inps) 199 f.write("\n") 200 201 202def map_to_device(e, device): 203 if isinstance(e, torch.Tensor): 204 return e.to(device) 205 elif isinstance(e, torch.device): 206 return device 207 elif isinstance(e, str): 208 if e == "cuda" or e == "cpu": 209 return device.type 210 else: 211 return e 212 213 214def map_to_dtype(e, dtype): 215 if isinstance(e, torch.Tensor) and e.is_floating_point(): 216 return e.to(dtype) 217 elif isinstance(e, torch.dtype): 218 return dtype 219 else: 220 return e 221 222 223def deserialize_args(inps): 224 inps = inps.strip().strip("'") 225 global_vals = { 226 "T": deserialize_tensor, 227 "ST": deserialize_sparse_tensor, 228 "th": torch, 229 "inf": math.inf, 230 "torch": torch, 231 **dtype_abbrs_parsing, 232 } 233 # f strings introduce quotations we dont want 234 for key in dtype_abbrs_parsing: 235 inps = inps.replace(f"'{key}'", key) 236 return eval(inps.strip().strip("'").strip('"'), global_vals) 237 238 239class OperatorInputsLoader: 240 def __init__(self, json_file_path): 241 self.operator_db = defaultdict(Counter) 242 243 with open(json_file_path) as f: 244 lines = f.readlines() 245 246 i = 0 247 while i < len(lines): 248 op_line = lines[i].strip("\n") 249 assert "Operator: " in op_line, op_line 250 operator = op_line[len("Operator: ") :] 251 operator = ( 252 operator if operator != "aten.sum.SymInt" else "aten.sum.dim_IntList" 253 ) 254 op_inps = Counter() 255 i += 1 256 while i < len(lines) and "Operator: " not in lines[i]: 257 line = lines[i] 258 cnt = eval(line[len("cnt: ") : line.find(",")]) 259 inps = line[line.find(",") + 2 :].strip("'") 260 op_inps[inps] += cnt 261 i += 1 262 self.operator_db[operator] = op_inps 263 264 def get_inputs_for_operator( 265 self, operator, dtype=None, device="cuda" 266 ) -> Generator[Tuple[Iterable[Any], Dict[str, Any]], None, None]: 267 assert ( 268 str(operator) in self.operator_db 269 ), f"Could not find {operator}, must provide overload" 270 271 if "embedding" in str(operator): 272 log.warning("Embedding inputs NYI, input data cannot be randomized") 273 yield 274 return 275 276 # line[1] represents number of times these inputs occured, ignored for now 277 for line in self.operator_db[str(operator)].items(): 278 inps = line[0] 279 280 args, kwargs = deserialize_args(inps) 281 282 # Backwards require some inputs to be float16 and some to be float32 283 # So we record on half and upcast to float when specified 284 if dtype and dtype != torch.float16: 285 to_dtype = partial(map_to_dtype, dtype=dtype) 286 args, kwargs = tree_map(to_dtype, (args, kwargs)) 287 288 if device: 289 to_device = partial(map_to_device, device=torch.device(device)) 290 args, kwargs = tree_map(to_device, (args, kwargs)) 291 292 yield args, kwargs 293 294 def get_all_ops(self): 295 for key in self.operator_db.keys(): 296 try: 297 op = eval(key) 298 except AttributeError as ae: 299 log.warning("Evaluating an op name into an OpOverload: %s", ae) 300 continue 301 yield op 302 303 def get_call_frequency(self, op): 304 assert ( 305 str(op) in self.operator_db 306 ), f"Could not find {op}, must provide overload" 307 308 count = 0 309 for counter in self.operator_db[str(op)].values(): 310 count += counter 311 return count 312 313 def merge(self, other): 314 for operator, counter_dict in other.operator_db.items(): 315 for inps, cnt in counter_dict.items(): 316 self.operator_db[operator][inps] += cnt 317 318 @staticmethod 319 def get_timm_loader(): 320 return OperatorInputsLoader._load_directory(TIMM_DIR) 321 322 @staticmethod 323 def get_huggingface_loader(): 324 return OperatorInputsLoader._load_directory(HF_DIR) 325 326 @staticmethod 327 def get_torchbench_loader(): 328 return OperatorInputsLoader._load_directory(TORCHBENCH_DIR) 329 330 @staticmethod 331 def _load_directory(inp_dir): 332 assert os.path.isdir(inp_dir), inp_dir 333 union = None 334 for inp in os.listdir(inp_dir): 335 if inp[-4:] != ".txt": 336 continue 337 path = os.path.join(inp_dir, inp) 338 if union is None: 339 union = OperatorInputsLoader(path) 340 else: 341 union.merge(OperatorInputsLoader(path)) 342 return union 343