# mypy: ignore-errors

import copy
import math
import os
import sys
from dataclasses import dataclass
from functools import partial, wraps
from typing import Callable, List

import torch
import torch.fx as fx
from torch.hub import tqdm
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._content_store import ContentStoreWriter

from .compile_utils import get_outputs, get_placeholders


# Sentinel stored in node.meta["concrete_value"] for nodes whose result is
# not a single Tensor (typically a tuple), and so cannot be offloaded directly.
is_tuple = object()


@dataclass
class LoadTensorMeta:
    size: List[int]
    stride: List[int]
    dtype: torch.dtype
    device: torch.device


class ConcreteProp(torch.fx.Interpreter):
    """Interpreter that runs the module on real inputs and records each
    node's concrete value in node.meta["concrete_value"], optionally
    offloading tensors to disk through a ContentStoreWriter."""

    def __init__(self, mod, *, writer=None, skip_offload=False):
        super().__init__(mod)
        self.writer = writer
        self.skip_offload = skip_offload
        self.seen_storages = set()

    def run_node(self, n):
        self.pbar.update(1)
        r = super().run_node(n)
        name = n.name

        if isinstance(r, torch.Tensor):
            if self.writer is None:
                n.meta["concrete_value"] = r
            else:
                if StorageWeakRef(r.untyped_storage()) in self.seen_storages:
                    # Refuse to offload tensors which alias other live
                    # tensors, because this will violate operator contracts
                    n.meta["concrete_value"] = None
                else:
                    if not self.skip_offload:
                        self.writer.write_tensor(os.path.join("eager", name), r)
                    n.meta["concrete_value"] = LoadTensorMeta(
                        r.size(), r.stride(), r.dtype, r.device
                    )
                    self.seen_storages.add(StorageWeakRef(r.untyped_storage()))
        else:
            n.meta["concrete_value"] = is_tuple

        return r

    def propagate(self, *args):
        with tqdm(
            desc="Saving intermediates for delta debugging",
            total=len(self.module.graph.nodes),
            disable=self.writer is None,
        ) as pbar:
            self.pbar = pbar
            r = super().run(*args)
            if not self.skip_offload:
                pbar.set_description(
                    "Saved!  To skip next time, run with --skip-saving-eager-intermediates"
                )
            return r


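# A minimal sketch of how ConcreteProp is driven (it mirrors the call made in
# minifier() below; the save directory name here is only illustrative):
#
#   ConcreteProp(gm).propagate(*inps)  # stash concrete values in node.meta
#   writer = ContentStoreWriter("/tmp/minifier_store")
#   ConcreteProp(gm, writer=writer).propagate(*inps)  # offload tensors to disk

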
def is_load_tensor_node(node):
    return (
        node.op == "call_function"
        and node.target is torch.ops.debugprims.load_tensor.default
    )


# Modifies node and inps in place.
def _convert_node_to_placeholder(graph, node, inps):
    if node.op == "output" or node.op == "placeholder":
        return False

    if is_load_tensor_node(node):
        return False

    concrete_val = node.meta.get("concrete_value", None)

    if isinstance(concrete_val, torch.Tensor):
        node.op = "placeholder"
        node.target = node.name
        node.args = ()
        node.kwargs = {}

        inps.append(concrete_val)
        return True

    elif concrete_val is None:
        return False

    elif concrete_val is is_tuple:
        r = False
        for tuple_user in list(node.users):
            r = _convert_node_to_placeholder(graph, tuple_user, inps) or r
        # NB: We must not erase the node at this point, because
        # we are iterating over the nodes and this would change
        # the iteration order
        # graph.erase_node(node)
        return r

    elif isinstance(concrete_val, LoadTensorMeta):
        node.op = "call_function"
        node.target = torch.ops.debugprims.load_tensor.default
        node.args = (
            os.path.join("eager", node.name),
            concrete_val.size,
            concrete_val.stride,
        )
        node.kwargs = {
            "device": concrete_val.device,
            "dtype": concrete_val.dtype,
        }
        return True

    return False


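# Sketch of what the conversion above does (node names are illustrative):
#
#   before: add = torch.ops.aten.add.Tensor(x, y)   # meta: concrete Tensor
#   after:  add becomes a placeholder, and the recorded tensor is appended to
#           inps, so the repro can run without the producer subgraph.
#
# With offloading enabled, the node instead becomes a call to
# torch.ops.debugprims.load_tensor.default that reloads the saved tensor from
# the content store at run time.

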
def create_minified_hlo_graph(minified_fx_graph, inputs):
    """
    Takes the minified FX graph as primary input and ports it to HLO via
    StableHLO, archiving the resulting HLO graph to a local directory.
    """
    hlo_dir = f"{os.getcwd()}/hlo_files"
    os.makedirs(hlo_dir, exist_ok=True)

    from torch_xla.stablehlo import save_torch_model_as_stablehlo

    save_torch_model_as_stablehlo(minified_fx_graph, inputs, hlo_dir)


def dump_state(fx_g, inps):
    # The leading torch.zeros(()) in the printed repro fills the `self`
    # parameter of the forward() generated by fx_g.code.
    print(
        f"""
# Working Repro with {len(fx_g.graph.nodes)} nodes
inps = {[(i.shape, i.dtype, i.device.type) for i in inps]}
inps = [torch.zeros(())] + [torch.ones(shape, dtype=dtype, device=device) for (shape, dtype, device) in inps]
{fx_g.code}
"""
    )


def is_power_of_two(n):
    if n == 0:
        return False
    return (n & (n - 1)) == 0


@dataclass
class ReproState:
    graph: fx.Graph
    inps: List[torch.Tensor]

    def __post_init__(self):
        ph_nodes = get_placeholders(self.graph)
        assert len(ph_nodes) == len(self.inps)


def minifier(
    fail_f: fx.GraphModule,
    inps,
    module_fails,
    dump_state: Callable = dump_state,
    *,
    save_dir=None,
    offload_to_disk=False,
    skip_offload=False,
    skip_sanity=False,
    max_granularity=None,
):
    """
    Minimizes an FX graph with the given inputs, such that the resulting FX
    graph still returns True for module_fails.

    Uses two main strategies:
    1. Truncate suffix: removes some suffix from the graph and sets a new output.
    2. Delta debugging: tries replacing half of the graph with inputs; if that
       fails, tries replacing a quarter of the graph, and so on.

    >>> # xdoctest: +SKIP(failing)
    >>> failing_function = fx.symbolic_trace(f)
    >>> minifier(failing_function, [torch.randn(5)], lambda fx_g, inps: fx_g(*inps))

    Note: module_fails returns True if the module fails.
    """
    assert isinstance(inps, (tuple, list))

    failing_graph = fail_f.graph
    cur_size = len(failing_graph.nodes)

    if max_granularity is not None and not is_power_of_two(max_granularity):
        raise RuntimeError(f"max_granularity {max_granularity} not power of two")

    num_queries = 0

    def deepcopy_fx_graph(fx_graph):
        return fx.GraphModule(fail_f, copy.deepcopy(fx_graph)).graph

    def graph_fails(graph, inps):
        nonlocal num_queries
        graph = copy.deepcopy(graph)
        num_queries += 1
        mod = fx.GraphModule(fail_f, graph)
        mod.graph.lint()
        return module_fails(mod, inps)

    writer = None
    if offload_to_disk:
        writer = ContentStoreWriter(save_dir)

    ConcreteProp(fail_f, writer=writer, skip_offload=skip_offload).propagate(*inps)
    if not skip_sanity and not graph_fails(failing_graph, inps):
        raise RuntimeError("Input graph did not fail the tester")
    print(f"Started off with {cur_size} nodes", file=sys.stderr)

    def _register_strategy(strategy: Callable, name: str):
        @wraps(strategy)
        def new_func(old_state: ReproState, granularity=1):
            print(file=sys.stderr)
            print(
                f"Strategy: {name} (G: {granularity}) "
                f"({len(old_state.graph.nodes)} nodes, {len(old_state.inps)} inputs)",
                file=sys.stderr,
            )
            new_state = strategy(
                deepcopy_fx_graph(old_state.graph), list(old_state.inps), granularity
            )
            if new_state is not None:
                new_nodes = len(new_state.graph.nodes)
                old_nodes = len(old_state.graph.nodes)
                new_inps = len(new_state.inps)
                old_inps = len(old_state.inps)
                new_outs = len(get_outputs(new_state.graph))
                old_outs = len(get_outputs(old_state.graph))
                progress_made = False
                if new_nodes < old_nodes:
                    progress_made = True
                    print(
                        f"SUCCESS: Went from {old_nodes} to {new_nodes} nodes",
                        file=sys.stderr,
                    )
                if new_inps > old_inps:
                    progress_made = True
                    print(
                        f"SUCCESS: Went from {old_inps} to {new_inps} inputs",
                        file=sys.stderr,
                    )
                if new_outs < old_outs:
                    progress_made = True
                    print(
                        f"SUCCESS: Went from {old_outs} to {new_outs} outputs",
                        file=sys.stderr,
                    )

                if not progress_made:
                    raise RuntimeError("Success raised but no progress made?")

                if not graph_fails(new_state.graph, new_state.inps):
                    print(
                        "WARNING: Something went wrong, not applying this minification",
                        file=sys.stderr,
                    )
                    return None
                return new_state
            else:
                print(f"FAIL: {name}", file=sys.stderr)
            return None

        return new_func

    def register_strategy(name: str):
        return partial(_register_strategy, name=name)

    @register_strategy("Truncate suffix")
    def remove_suffix(cur_graph, cur_inps, granularity):
        tested = set()
        new_graph = fx.Graph()
        env = {}
        for idx, node in enumerate(cur_graph.nodes):
            new_node = new_graph.node_copy(node, lambda x: env[x])
            if node.op not in ["placeholder", "output"]:
                # If idx is divisible by (granularity * 2), it would have been checked already.
                if (
                    idx % granularity == 0
                    and (idx % (granularity * 2) != 0)
                    and idx not in tested
                ):
                    output_node = new_graph.output((new_node,))
                    if len(new_graph.nodes) < len(cur_graph.nodes) and graph_fails(
                        new_graph, cur_inps
                    ):
                        return ReproState(new_graph, cur_inps)
                    else:
                        tested.add(idx)
                        new_graph.erase_node(output_node)
            env[node] = new_node
        return None

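    # Worked example of the schedule above: at granularity 4, truncation is
    # attempted after indices 4, 12, 20, ... (odd multiples of 4); indices
    # 8, 16, ... are skipped since they are multiples of 8 and were already
    # tried at the coarser granularity.
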
    @register_strategy("Remove outputs")
    def remove_outputs(cur_graph, cur_inps, granularity):
        granularity = max(1, granularity // 2)
        for idx, node in enumerate(cur_graph.nodes):
            node.idx = idx
            if node.op == "output":
                output = node
                break

        if isinstance(output.args[0], fx.Node):
            return None

        output_args = sorted(
            output.args[0], key=lambda x: x.idx if isinstance(x, fx.Node) else int(1e9)
        )
        if len(output_args) == 1:
            return None

        for idx in range(0, len(output_args), granularity):
            output.args = (output_args[:idx] + output_args[idx + granularity :],)
            if graph_fails(cur_graph, cur_inps):
                return ReproState(cur_graph, cur_inps)
        return None

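    # E.g., called with granularity 8 on a graph with 8 outputs: granularity
    # is halved to 4, then the loop tries dropping outputs [0:4] and [4:8],
    # returning the first variant that still fails.
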
    def remove_unused_inputs_unchecked(cur_state: ReproState):
        cur_graph = cur_state.graph
        cur_inps = cur_state.inps
        ph_nodes = get_placeholders(cur_graph)
        assert len(ph_nodes) == len(cur_inps)

        new_inps = []
        for idx in range(len(ph_nodes)):
            if len(ph_nodes[idx].users) == 0:
                cur_graph.erase_node(ph_nodes[idx])
            else:
                new_inps.append(cur_inps[idx])
        if len(new_inps) < len(cur_inps):
            return ReproState(cur_graph, new_inps)
        return None

    def remove_unused_inputs_checked(cur_state: ReproState):
        new_state = remove_unused_inputs_unchecked(cur_state)
        if new_state is not None and graph_fails(new_state.graph, new_state.inps):
            return new_state
        return None

    def _remove_unused_wrapper(cur_graph, cur_inps, granularity):
        return remove_unused_inputs_checked(ReproState(cur_graph, cur_inps))

    remove_unused_inputs = register_strategy("Remove unused inputs")(
        _remove_unused_wrapper
    )

    @register_strategy("Eliminate dead code")
    def eliminate_dead_code(cur_graph, cur_inps, granularity):
        if cur_graph.eliminate_dead_code() and graph_fails(cur_graph, cur_inps):
            return ReproState(cur_graph, cur_inps)
        return None

    def _consolidate_placeholders(cur_graph, inps):
        new_graph = fx.Graph()
        env = {}
        seen_non_placeholder = False

        # Move all placeholders to the front; also, if any load_tensor is at
        # the front, convert it into an input (its value is available up
        # front, so it can be kept live as an input instead)
        for node in cur_graph.nodes:
            if node.op == "placeholder":
                new_node = new_graph.node_copy(node, lambda x: env[x])
                env[node] = new_node
            elif not seen_non_placeholder and is_load_tensor_node(node):
                new_node = new_graph.placeholder(node.name)
                env[node] = new_node
                inps.append(
                    torch.ops.debugprims.load_tensor.default(*node.args, **node.kwargs)
                )
            else:
                seen_non_placeholder = True

        # Move everyone else
        for node in cur_graph.nodes:
            if node not in env:
                new_node = new_graph.node_copy(node, lambda x: env[x])
                env[node] = new_node
        return new_graph

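    # Sketch of the transformation (node names are illustrative):
    #
    #   [ph(a), load_tensor("t", ...), add(a, t)]
    #       -> [ph(a), ph(t), add(a, t)], with the loaded tensor appended
    #          to inps so the new placeholder has a matching input value.
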
    @register_strategy("Delta Debugging")
    def delta_debugging(cur_graph: fx.Graph, cur_inps, granularity):
        num_nodes = len(cur_graph.nodes)
        for start_range in range(0, num_nodes, granularity):
            is_removing = False
            new_graph = deepcopy_fx_graph(cur_graph)
            new_inps = cur_inps[:]
            end_range = min(num_nodes, start_range + granularity)
            for idx in range(start_range, end_range):
                new_node = list(new_graph.nodes)[idx]
                if _convert_node_to_placeholder(new_graph, new_node, new_inps):
                    is_removing = True
            if not is_removing:
                continue
            new_graph.eliminate_dead_code()
            new_graph = _consolidate_placeholders(new_graph, new_inps)
            new_state = remove_unused_inputs_unchecked(ReproState(new_graph, new_inps))
            if new_state is None:
                new_state = ReproState(new_graph, new_inps)
            if graph_fails(new_state.graph, new_state.inps):
                return ReproState(new_state.graph, new_state.inps)

        return None

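    # Worked example: with 8 nodes and granularity 4, the loop above first
    # tries turning nodes [0:4] into placeholders, then nodes [4:8]; the
    # driver below then retries with granularity 2, then 1.
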
    @register_strategy("Consolidate Inputs")
    def consolidate_inputs(cur_graph, cur_inps, granularity):
        old_len = len(cur_inps)
        cur_graph = _consolidate_placeholders(cur_graph, cur_inps)
        if len(cur_inps) > old_len and graph_fails(cur_graph, cur_inps):
            return ReproState(cur_graph, cur_inps)
        return None

    failing_state = ReproState(failing_graph, inps)

    def try_granularity(failing_state, granularity, use_non_granular):
        print(f"Trying granularity {granularity}", file=sys.stderr)

        strategies = []
        num_nodes = len(failing_state.graph.nodes)
        num_outputs = len(get_outputs(failing_state.graph))
        if num_outputs > num_nodes // 2:
            strategies += [remove_outputs]

        if use_non_granular:
            strategies += [
                eliminate_dead_code,
                remove_unused_inputs,
                consolidate_inputs,
            ]

        strategies += [remove_suffix, delta_debugging]

        for strategy in strategies:
            new_state = strategy(failing_state, granularity)
            if new_state is not None:
                return new_state
        return None

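    # Main loop: start at the largest power-of-two granularity not exceeding
    # the node count, run the cheap non-granular passes once, then halve the
    # granularity until some strategy makes progress; stop when nothing does.
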
    while True:
        dump_state(fx.GraphModule(fail_f, failing_state.graph), failing_state.inps)
        granularity = int(2 ** (math.floor(math.log2(len(failing_state.graph.nodes)))))
        if max_granularity is not None:
            granularity = min(max_granularity, granularity)
        new_state = try_granularity(failing_state, granularity, use_non_granular=True)
        if new_state is not None:
            failing_state = new_state
            continue

        granularity //= 2
        has_progress = False
        while granularity >= 1:
            new_state = try_granularity(
                failing_state, granularity, use_non_granular=False
            )
            if new_state is not None:
                failing_state = new_state
                has_progress = True
                break
            granularity //= 2
        if has_progress:
            continue

        new_state = remove_outputs(failing_state, 1)
        if new_state is not None:
            failing_state = new_state
            continue

        break

    if not graph_fails(failing_state.graph, failing_state.inps):
        raise RuntimeError("Uh oh, something went wrong :( Final graph is not failing")

    print(f"Made {num_queries} queries", file=sys.stderr)
    failing_fx = fx.GraphModule(fail_f, failing_state.graph)

    # If XLA debugging environment is enabled, create minified HLO graph as well
    if "XLA_HLO_DEBUG" in os.environ:
        create_minified_hlo_graph(failing_fx, failing_state.inps)

    dump_state(failing_fx, failing_state.inps)
    print("Wrote minimal repro out to repro.py", file=sys.stderr)
    return failing_fx, failing_state.inps

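
if __name__ == "__main__":
    # A minimal usage sketch; the toy function and the relu-presence
    # "failure" predicate below are illustrative only.
    def _toy(x):
        y = x * 2
        z = torch.relu(y)
        return torch.sin(z) + torch.cos(z)

    def _fails_if_relu_present(mod, inps):
        # "Fails" whenever the candidate graph still calls torch.relu, so the
        # minifier shrinks the graph down around that single call.
        return any(
            n.op == "call_function" and n.target is torch.relu
            for n in mod.graph.nodes
        )

    minifier(fx.symbolic_trace(_toy), [torch.randn(5)], _fails_if_relu_present)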