# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. # pyre-strict import logging from typing import Optional, Sequence, Union import torch from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue from torch._subclasses import FakeTensor, FakeTensorMode from torch.fx.node import Argument, Target from torch.utils import _pytree as pytree class GraphBuilder(ExportPass): """Utility class for creating a graph module with user-specified ops. This class allows us to create test graph modules with any ops we want directly, rather than relying on decomposition or passes. Usage: builder = GraphBuilder() # To insert placeholders, use builder.placeholder. x = builder.placeholder("x", torch.randn(1, 3, 224, 224)) # To insert an op, use builder.call_operator. op = builder.call_operator( some_op (x, other_args, ...), ) # Insert outputs as a list of ProxyValues using builder.output. builder.output([op]) # Get GraphModule from builder. gm = builder.get_graph_module() """ def __init__(self) -> None: self.exporter = ExportPass() self.tracer: ExportPass.ExportTracer = self.ExportTracer( self, torch.fx.graph.CodeGen() ) self.fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=False) self.tracer.fake_tensor_mode = self.fake_tensor_mode # This will be called to create nodes in tracer. self.interpreter = torch.fx.Interpreter( torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph()) ) # pyre-ignore[14]: Inconsistent override. def placeholder( self, target: str, fake_tensor: Union[FakeTensor, torch.Tensor] ) -> ProxyValue: if not isinstance(fake_tensor, FakeTensor): fake_tensor = self.fake_tensor_mode.from_tensor(fake_tensor) logging.info(f"Creating placeholder {target} => {fake_tensor.shape}") placeholder = super().placeholder(target, fake_tensor, NodeMetadata({})) return placeholder # pyre-ignore[14]: Inconsistent override. def output(self, results: list[ProxyValue]) -> ProxyValue: logging.info(f"Creating outputs {results}") return super().output(results, NodeMetadata({})) def get_graph_module(self) -> torch.fx.GraphModule: return torch.fx.GraphModule(self.tracer.root, self.tracer.graph) def call_operator( self, op, # pyre-ignore args: tuple[Argument, ...], kwargs: Optional[dict[str, Argument]] = None, meta: Optional[NodeMetadata] = None, ) -> ProxyValue: if meta is None: meta = NodeMetadata({}) if kwargs is None: kwargs = {} return super().call_operator(op, args, kwargs, meta) def single_op_builder( placeholders: Sequence[Union[torch.Tensor, FakeTensor]], op: Target, args: Sequence[Argument], kwargs: Optional[dict[str, Argument]] = None, ) -> torch.fx.GraphModule: """Create a graph module with a single op. Args: placeholders: Placeholders to be used as inputs to the GraphModule. op: The op to be inserted. args: The args to be passed to the op. kwargs: The kwargs to be passed to the op. Returns: A graph module with a single op """ builder = GraphBuilder() op_to_placeholder_dict = { p: builder.placeholder(f"p_{i}", p) for i, p in enumerate(placeholders) } proxy_args, proxy_kwargs = pytree.tree_map_only( (torch.Tensor, FakeTensor), lambda x: op_to_placeholder_dict[x], (args, kwargs) ) node = builder.call_operator(op, proxy_args, proxy_kwargs) builder.output([node]) return builder.get_graph_module()