Lines Matching refs:GraphModule
52 PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]
288 def __init__(self, callback: "_ExportPassBase", gm: fx.GraphModule) -> None: argument
383 torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
417 def inputs(self, graph_module: torch.fx.GraphModule) -> List[Argument]: argument
483 true_fn: torch.fx.GraphModule, argument
484 false_fn: torch.fx.GraphModule, argument
502 f: torch.fx.GraphModule, argument
527 self, graph_module: fx.GraphModule, inputs: Tuple[Argument, ...] argument
535 torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
541 new_graph_module = torch.fx.GraphModule(self.tracer.root, self.tracer.graph)
550 def call(self, graph_module: fx.GraphModule) -> PassResult: argument
597 def __init__(self, callback: "ExportPass", gm: fx.GraphModule) -> None: argument
660 self, graph_module: fx.GraphModule, inputs: Tuple[Argument, ...] argument
665 gm: torch.fx.GraphModule, new_gm: torch.fx.GraphModule argument
667 def get_phs(gm: torch.fx.GraphModule) -> List[torch.fx.Node]: argument