# mypy: ignore-errors

import copy
import math
import os
import sys
from dataclasses import dataclass
from functools import partial, wraps
from typing import Callable, List

import torch
import torch.fx as fx
from torch.hub import tqdm
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._content_store import ContentStoreWriter

from .compile_utils import get_outputs, get_placeholders


# Sentinel recorded in node.meta["concrete_value"] for non-tensor (e.g. tuple)
# node outputs.
is_tuple = object()


@dataclass
class LoadTensorMeta:
    size: List[int]
    stride: List[int]
    dtype: torch.dtype
    device: torch.device


class ConcreteProp(torch.fx.Interpreter):
    def __init__(self, mod, *, writer=None, skip_offload=False):
        super().__init__(mod)
        self.writer = writer
        self.skip_offload = skip_offload
        self.seen_storages = set()

    def run_node(self, n):
        self.pbar.update(1)
        r = super().run_node(n)
        name = n.name

        if isinstance(r, torch.Tensor):
            if self.writer is None:
                n.meta["concrete_value"] = r
            else:
                if StorageWeakRef(r.untyped_storage()) in self.seen_storages:
                    # Refuse to offload tensors which alias other live
                    # tensors, because this will violate operator contracts
                    n.meta["concrete_value"] = None
                else:
                    if not self.skip_offload:
                        self.writer.write_tensor(os.path.join("eager", name), r)
                    n.meta["concrete_value"] = LoadTensorMeta(
                        r.size(), r.stride(), r.dtype, r.device
                    )
                    self.seen_storages.add(StorageWeakRef(r.untyped_storage()))
        else:
            n.meta["concrete_value"] = is_tuple

        return r

    def propagate(self, *args):
        with tqdm(
            desc="Saving intermediates for delta debugging",
            total=len(self.module.graph.nodes),
            disable=self.writer is None,
        ) as pbar:
            self.pbar = pbar
            r = super().run(*args)
            if not self.skip_offload:
                pbar.set_description(
                    "Saved! To skip next time, run with --skip-saving-eager-intermediates"
                )
            return r
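

# Usage sketch for ConcreteProp (illustrative only; `f` and `example_inputs`
# are hypothetical): run the module once, recording each node's concrete
# value (or, when a writer is supplied, its offloaded-tensor metadata) in
# node.meta["concrete_value"].
#
#   gm = torch.fx.symbolic_trace(f)
#   ConcreteProp(gm).propagate(*example_inputs)
#   for node in gm.graph.nodes:
#       print(node.name, node.meta.get("concrete_value"))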


def is_load_tensor_node(node):
    return (
        node.op == "call_function"
        and node.target is torch.ops.debugprims.load_tensor.default
    )


# Modifies node and inps in place.
def _convert_node_to_placeholder(graph, node, inps):
    if node.op == "output" or node.op == "placeholder":
        return False

    if is_load_tensor_node(node):
        return False

    concrete_val = node.meta.get("concrete_value", None)

    if isinstance(concrete_val, torch.Tensor):
        node.op = "placeholder"
        node.target = node.name
        node.args = ()
        node.kwargs = {}

        inps.append(concrete_val)
        return True

    elif concrete_val is None:
        return False

    elif concrete_val is is_tuple:
        r = False
        for tuple_user in list(node.users):
            r = _convert_node_to_placeholder(graph, tuple_user, inps) or r
        # NB: We must not erase the node at this point, because
        # we are iterating over the nodes and this would change
        # the iteration order
        # graph.erase_node(node)
        return r

    elif isinstance(concrete_val, LoadTensorMeta):
        node.op = "call_function"
        node.target = torch.ops.debugprims.load_tensor.default
        node.args = (
            os.path.join("eager", node.name),
            concrete_val.size,
            concrete_val.stride,
        )
        node.kwargs = {
            "device": concrete_val.device,
            "dtype": concrete_val.dtype,
        }
        return True

    return False


def create_minified_hlo_graph(minified_fx_graph, inputs):
    """
    Takes a minified FX graph as primary input and ports it to HLO via StableHLO.
    Provides the minified HLO graph as output, archived to a local directory.
    """
    hlo_dir = f"{os.getcwd()}/hlo_files"
    os.makedirs(hlo_dir, exist_ok=True)

    from torch_xla.stablehlo import save_torch_model_as_stablehlo

    save_torch_model_as_stablehlo(minified_fx_graph, inputs, hlo_dir)


def dump_state(fx_g, inps):
    print(
        f"""
# Working Repro with {len(fx_g.graph.nodes)} nodes
inps = {[(i.shape, i.dtype, i.device.type) for i in inps]}
inps = [torch.zeros(())] + [torch.ones(shape, dtype=dtype, device=device) for (shape, dtype, device) in inps]
{fx_g.code}
"""
    )


def is_power_of_two(n):
    if n == 0:
        return False
    return (n & (n - 1)) == 0


@dataclass
class ReproState:
    graph: fx.Graph
    inps: List[torch.Tensor]

    def __post_init__(self):
        ph_nodes = get_placeholders(self.graph)
        assert len(ph_nodes) == len(self.inps)
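

# The `module_fails` predicate handed to `minifier` below receives a
# GraphModule and its inputs, and returns True while the failure under
# investigation still reproduces. A minimal sketch, assuming a hypothetical
# backend `my_compiler` and tuple-of-tensor outputs:
#
#   def module_fails(gm, inps):
#       try:
#           ref = gm(*inps)
#           res = my_compiler(gm)(*inps)
#           return not all(torch.allclose(a, b) for a, b in zip(ref, res))
#       except Exception:
#           return True  # a crash counts as a failure, too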


def minifier(
    fail_f: fx.GraphModule,
    inps,
    module_fails,
    dump_state: Callable = dump_state,
    *,
    save_dir=None,
    offload_to_disk=False,
    skip_offload=False,
    skip_sanity=False,
    max_granularity=None,
):
    """
    Minimizes an FX graph with the given inputs, such that the resulting FX
    graph still returns True for module_fails.

    Uses two main strategies:
    1. Truncate suffix: Removes some suffix from the graph and sets a new output.
    2. Delta debugging: Tries replacing half of the graph with inputs. If that
       fails, tries replacing a quarter of the graph, etc.

    >>> # xdoctest: +SKIP(failing)
    >>> failing_function = fx.symbolic_trace(f)
    >>> minifier(failing_function, [torch.randn(5)], lambda fx_g, inps: fx_g(*inps))

    Note: module_fails returns True if the module fails.
    """
    assert isinstance(inps, (tuple, list))

    failing_graph = fail_f.graph
    cur_size = len(failing_graph.nodes)

    if max_granularity is not None and not is_power_of_two(max_granularity):
        raise RuntimeError(f"max_granularity {max_granularity} is not a power of two")

    num_queries = 0

    def deepcopy_fx_graph(fx_graph):
        return fx.GraphModule(fail_f, copy.deepcopy(fx_graph)).graph

    def graph_fails(graph, inps):
        nonlocal num_queries
        graph = copy.deepcopy(graph)
        num_queries += 1
        mod = fx.GraphModule(fail_f, graph)
        mod.graph.lint()
        return module_fails(mod, inps)

    writer = None
    if offload_to_disk:
        writer = ContentStoreWriter(save_dir)

    ConcreteProp(fail_f, writer=writer, skip_offload=skip_offload).propagate(*inps)
    if not skip_sanity and not graph_fails(failing_graph, inps):
        raise RuntimeError("Input graph did not fail the tester")
    print(f"Started off with {cur_size} nodes", file=sys.stderr)

    def _register_strategy(strategy: Callable, name: str):
        @wraps(strategy)
        def new_func(old_state: ReproState, granularity=1):
            print(file=sys.stderr)
            print(
                f"Strategy: {name} (G: {granularity}) "
                f"({len(old_state.graph.nodes)} nodes, {len(old_state.inps)} inputs)",
                file=sys.stderr,
            )
            new_state = strategy(
                deepcopy_fx_graph(old_state.graph), list(old_state.inps), granularity
            )
            if new_state is not None:
                new_nodes = len(new_state.graph.nodes)
                old_nodes = len(old_state.graph.nodes)
                new_inps = len(new_state.inps)
                old_inps = len(old_state.inps)
                new_outs = len(get_outputs(new_state.graph))
                old_outs = len(get_outputs(old_state.graph))
                progress_made = False
                if new_nodes < old_nodes:
                    progress_made = True
                    print(
                        f"SUCCESS: Went from {old_nodes} to {new_nodes} nodes",
                        file=sys.stderr,
                    )
                if new_inps > old_inps:
                    progress_made = True
                    print(
                        f"SUCCESS: Went from {old_inps} to {new_inps} inputs",
                        file=sys.stderr,
                    )
                if new_outs < old_outs:
                    progress_made = True
                    print(
                        f"SUCCESS: Went from {old_outs} to {new_outs} outputs",
                        file=sys.stderr,
                    )

                if not progress_made:
                    raise RuntimeError("Success raised but no progress made?")

                if not graph_fails(new_state.graph, new_state.inps):
                    print(
                        "WARNING: Something went wrong, not applying this minification",
                        file=sys.stderr,
                    )
                    return None
                return new_state
            else:
                print(f"FAIL: {name}", file=sys.stderr)
                return None

        return new_func

    def register_strategy(name: str):
        return partial(_register_strategy, name=name)

    @register_strategy("Truncate suffix")
    def remove_suffix(cur_graph, cur_inps, granularity):
        tested = set()
        new_graph = fx.Graph()
        env = {}
        for idx, node in enumerate(cur_graph.nodes):
            new_node = new_graph.node_copy(node, lambda x: env[x])
            if node.op not in ["placeholder", "output"]:
                # If idx is divisible by (granularity * 2), it would have been
                # checked already.
                if (
                    idx % granularity == 0
                    and (idx % (granularity * 2) != 0)
                    and idx not in tested
                ):
                    output_node = new_graph.output((new_node,))
                    if len(new_graph.nodes) < len(cur_graph.nodes) and graph_fails(
                        new_graph, cur_inps
                    ):
                        return ReproState(new_graph, cur_inps)
                    else:
                        tested.add(idx)
                        new_graph.erase_node(output_node)
            env[node] = new_node
        return None
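
    # At granularity g, the strategy above only tries cut points whose node
    # index is an odd multiple of g (even multiples of g were already tried at
    # granularity 2g), so over the halving schedule g = N, N/2, ..., 1 each
    # index is attempted at most once. For example, over indices 0..7:
    # g=4 tries index 4, g=2 tries 2 and 6, and g=1 tries 1, 3, 5, and 7.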

    @register_strategy("Remove outputs")
    def remove_outputs(cur_graph, cur_inps, granularity):
        granularity = max(1, granularity // 2)
        for idx, node in enumerate(cur_graph.nodes):
            node.idx = idx
            if node.op == "output":
                output = node
                break

        if isinstance(output.args[0], fx.Node):
            return None

        output_args = sorted(
            output.args[0], key=lambda x: x.idx if isinstance(x, fx.Node) else int(1e9)
        )
        if len(output_args) == 1:
            return None

        for idx in range(0, len(output_args), granularity):
            output.args = (output_args[:idx] + output_args[idx + granularity :],)
            if graph_fails(cur_graph, cur_inps):
                return ReproState(cur_graph, cur_inps)
        return None

    def remove_unused_inputs_unchecked(cur_state: ReproState):
        cur_graph = cur_state.graph
        cur_inps = cur_state.inps
        ph_nodes = get_placeholders(cur_graph)
        assert len(ph_nodes) == len(cur_inps)

        new_inps = []
        for idx in range(len(ph_nodes)):
            if len(ph_nodes[idx].users) == 0:
                cur_graph.erase_node(ph_nodes[idx])
            else:
                new_inps.append(cur_inps[idx])
        if len(new_inps) < len(cur_inps):
            return ReproState(cur_graph, new_inps)
        return None

    def remove_unused_inputs_checked(cur_state: ReproState):
        new_state = remove_unused_inputs_unchecked(cur_state)
        if new_state is not None and graph_fails(new_state.graph, new_state.inps):
            return new_state
        return None

    def _remove_unused_wrapper(cur_graph, cur_inps, granularity):
        return remove_unused_inputs_checked(ReproState(cur_graph, cur_inps))

    remove_unused_inputs = register_strategy("Remove unused inputs")(
        _remove_unused_wrapper
    )

    @register_strategy("Eliminate dead code")
    def eliminate_dead_code(cur_graph, cur_inps, granularity):
        if cur_graph.eliminate_dead_code() and graph_fails(cur_graph, cur_inps):
            return ReproState(cur_graph, cur_inps)
        return None

    def _consolidate_placeholders(cur_graph, inps):
        new_graph = fx.Graph()
        env = {}
        seen_non_placeholder = False

        # Move all placeholders to the front; also, if any load_tensor
        # is at the front, convert it into an input (because it can be live
        # all the time)
        for node in cur_graph.nodes:
            if node.op == "placeholder":
                new_node = new_graph.node_copy(node, lambda x: env[x])
                env[node] = new_node
            elif not seen_non_placeholder and is_load_tensor_node(node):
                new_node = new_graph.placeholder(node.name)
                env[node] = new_node
                inps.append(
                    torch.ops.debugprims.load_tensor.default(*node.args, **node.kwargs)
                )
            else:
                seen_non_placeholder = True

        # Move everything else
        for node in cur_graph.nodes:
            if node not in env:
                new_node = new_graph.node_copy(node, lambda x: env[x])
                env[node] = new_node
        return new_graph
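
    # Delta debugging (below) replaces each window of `granularity`
    # consecutive nodes with placeholders fed by the concrete values that
    # ConcreteProp recorded, then keeps the smaller graph if it still fails.
    # A sketch of one such window, assuming a hypothetical graph `g` with
    # inputs `inps` and the window [4, 6):
    #
    #   for node in list(g.nodes)[4:6]:
    #       _convert_node_to_placeholder(g, node, inps)
    #   g.eliminate_dead_code()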

    @register_strategy("Delta Debugging")
    def delta_debugging(cur_graph: fx.Graph, cur_inps, granularity):
        num_nodes = len(cur_graph.nodes)
        for start_range in range(0, num_nodes, granularity):
            is_removing = False
            new_graph = deepcopy_fx_graph(cur_graph)
            new_inps = cur_inps[:]
            end_range = min(num_nodes, start_range + granularity)
            for idx in range(start_range, end_range):
                new_node = list(new_graph.nodes)[idx]
                if _convert_node_to_placeholder(new_graph, new_node, new_inps):
                    is_removing = True
            if not is_removing:
                continue
            new_graph.eliminate_dead_code()
            new_graph = _consolidate_placeholders(new_graph, new_inps)
            new_state = remove_unused_inputs_unchecked(ReproState(new_graph, new_inps))
            if new_state is None:
                new_state = ReproState(new_graph, new_inps)
            if graph_fails(new_state.graph, new_state.inps):
                return ReproState(new_state.graph, new_state.inps)

        return None

    @register_strategy("Consolidate Inputs")
    def consolidate_inputs(cur_graph, cur_inps, granularity):
        old_len = len(cur_inps)
        cur_graph = _consolidate_placeholders(cur_graph, cur_inps)
        if len(cur_inps) > old_len and graph_fails(cur_graph, cur_inps):
            return ReproState(cur_graph, cur_inps)
        return None

    failing_state = ReproState(failing_graph, inps)

    def try_granularity(failing_state, granularity, use_non_granular):
        print(f"Trying granularity {granularity}", file=sys.stderr)

        strategies = []
        num_nodes = len(failing_state.graph.nodes)
        num_outputs = len(get_outputs(failing_state.graph))
        if num_outputs > num_nodes // 2:
            strategies += [remove_outputs]

        if use_non_granular:
            strategies += [
                eliminate_dead_code,
                remove_unused_inputs,
                consolidate_inputs,
            ]

        strategies += [remove_suffix, delta_debugging]

        for strategy in strategies:
            new_state = strategy(failing_state, granularity)
            if new_state is not None:
                return new_state
        return None

    while True:
        dump_state(fx.GraphModule(fail_f, failing_state.graph), failing_state.inps)
        granularity = int(2 ** (math.floor(math.log2(len(failing_state.graph.nodes)))))
        if max_granularity is not None:
            granularity = min(max_granularity, granularity)
        new_state = try_granularity(failing_state, granularity, use_non_granular=True)
        if new_state is not None:
            failing_state = new_state
            continue

        granularity //= 2
        has_progress = False
        while granularity >= 1:
            new_state = try_granularity(
                failing_state, granularity, use_non_granular=False
            )
            if new_state is not None:
                failing_state = new_state
                has_progress = True
                break
            granularity //= 2
        if has_progress:
            continue

        new_state = remove_outputs(failing_state, 1)
        if new_state is not None:
            failing_state = new_state
            continue

        break

    if not graph_fails(failing_state.graph, failing_state.inps):
        raise RuntimeError("Uh oh, something went wrong :( Final graph is not failing")

    print(f"Made {num_queries} queries", file=sys.stderr)
    failing_fx = fx.GraphModule(fail_f, failing_state.graph)

    # If the XLA debugging environment variable is set, create a minified HLO
    # graph as well
    if "XLA_HLO_DEBUG" in os.environ:
        create_minified_hlo_graph(failing_fx, failing_state.inps)

    dump_state(failing_fx, failing_state.inps)
    print("Wrote minimal repro out to repro.py", file=sys.stderr)
    return failing_fx, failing_state.inps
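

# End-to-end usage sketch (illustrative; `f`, `example_inputs`, and
# `module_fails` are user-supplied, as sketched above):
#
#   gm = torch.fx.symbolic_trace(f)
#   min_gm, min_inps = minifier(gm, example_inputs, module_fails)
#   # `min_gm` is the smallest failing GraphModule found; `min_inps` are its
#   # matching inputs.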