xref: /aosp_15_r20/external/pytorch/torch/fx/experimental/debug.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import torch.fx as fx
3
def set_trace(gm: fx.GraphModule) -> fx.GraphModule:
    """
    Make `gm`'s generated python code start with a pdb breakpoint, so that
    running `gm` drops into the debugger.

    The breakpoint statement is prepended to the generated function body by
    a code transformer registered through `gm.graph.on_generate_code`. The
    transformer is registered inside a `with` block, and `gm` is recompiled
    inside that block so the regenerated code contains the breakpoint.

    Args:
        gm: graph module to instrument. It is modified in place (recompiled).

    Returns:
        the same `gm` object, with the breakpoint inserted.
    """
    breakpoint_stmt = "import pdb; pdb.set_trace()\n"

    def make_transformer(previous):
        # `previous` is whatever transformer was already registered; it may
        # be falsy (none registered), in which case the body is used as-is.
        def transform(body):
            if previous:
                body = previous(body)
            # Chain after the existing transformer, then put the breakpoint
            # in front of the resulting body lines.
            return [breakpoint_stmt, *body]

        return transform

    with gm.graph.on_generate_code(make_transformer=make_transformer):
        gm.recompile()

    return gm
33