# mypy: allow-untyped-defs
import torch.fx as fx


def set_trace(gm: fx.GraphModule) -> fx.GraphModule:
    """
    Insert a pdb breakpoint into ``gm``'s generated python code.

    A body transformer that prepends ``import pdb; pdb.set_trace()`` is
    registered on ``gm.graph`` and ``gm`` is recompiled with it active.
    When ``gm`` is subsequently run, execution drops into pdb.

    Args:
        gm: graph module to instrument. It is recompiled in place so the
            breakpoint takes effect.

    Returns:
        The same ``gm`` object, now carrying the breakpoint.
    """
    breakpoint_line = "import pdb; pdb.set_trace()\n"

    def make_transformer(previous):
        # Chain onto whatever transformer is already registered, if any, so
        # its rewrite still applies; the breakpoint goes in front of the result.
        def transformer(body):
            lines = previous(body) if previous else body
            return [breakpoint_line, *lines]

        return transformer

    # Recompile while the transformer is registered so the breakpoint line is
    # baked into the regenerated code.
    with gm.graph.on_generate_code(make_transformer=make_transformer):
        gm.recompile()

    return gm