1"""Utilities for lowering subgraphs used by higher order operators
2
3"""

import functools
import operator
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
from typing_extensions import ParamSpec

import torch

from . import ir
from .exc import SubgraphLoweringException
from .ops_handler import SimpleCSEHandler
from .sizevars import SizeVarAllocator
from .virtualized import ops, V, WrapperHandler


T = TypeVar("T")
_P = ParamSpec("_P")


class PointwiseSubgraphLowering(torch.fx.Interpreter):
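    """Interprets an FX graph of pointwise ops against a parent GraphLowering,
    producing IR nodes without registering any buffers. Mutation, buffer
    creation, and non-pointwise ops are all rejected.
    """
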
    graph_outputs: Optional[List[ir.IRNode]]

    def __init__(
        self,
        gm: torch.fx.GraphModule,
        root_graph_lowering: "torch._inductor.graph.GraphLowering",
    ) -> None:
        super().__init__(gm)
        self.graph_outputs = None
        self.root_graph = root_graph_lowering

    @property
    def sizevars(self) -> SizeVarAllocator:
        return self.root_graph.sizevars

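    # Lowerings reach these hooks through V.graph; a pointwise subgraph must
    # stay purely functional, so both operations are rejected outright.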
    def mark_buffer_mutated(self, name: str) -> None:
        raise SubgraphLoweringException("Mutations are not supported in this context")

    def register_buffer(self, buffer: ir.Buffer) -> str:
        raise SubgraphLoweringException(
            "Buffer creation is not supported in this context"
        )

    def call_function(
        self,
        target: Callable[[Any], Any],  # type: ignore[override]
        args: Any,
        kwargs: Dict[str, Any],
    ) -> Any:
        from .lowering import lowerings

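        # getitem on a Python collection is structural access (e.g. unpacking
        # a tuple of outputs), not an op that needs a lowering.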
        if target is operator.getitem and isinstance(args[0], (list, tuple, dict)):
            return super().call_function(target, args, kwargs)

        assert isinstance(target, torch._ops.OpOverload)

        if target not in lowerings:
            raise SubgraphLoweringException(
                f"{target} not supported in subgraph (missing lowering)"
            )

        if torch.Tag.pointwise not in target.tags:
            raise SubgraphLoweringException(
                f"Only pointwise operators are supported in this context, but got {target}"
            )

        return lowerings[target](*args, **kwargs)

    def output(self, target: str, args: Tuple[Any], kwargs: Dict[str, Any]) -> None:  # type: ignore[override]
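        # Capture the subgraph outputs rather than creating output buffers.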
        assert len(args) == 1
        self.graph_outputs = args[0]


@dataclass
class InputDescriptor:
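    """Describes a subgraph input to be lowered: for pointwise subgraphs only
    the dtype and device of each input are needed.
    """
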
    dtype: torch.dtype
    device: torch.device


class TracingOpsHandler(WrapperHandler[T]):
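    """Records every op issued through the wrapped handler as a node in an FX
    graph. The graph gets one placeholder for the ops handler itself ("ops")
    plus one placeholder per subgraph input.
    """
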
    def __init__(self, tracer: torch.fx.Tracer, num_inputs: int) -> None:
        parent = tracer.create_proxy("placeholder", "ops", (), {})
        super().__init__(parent)
        self.tracer = tracer

        self.placeholders = [
            self.tracer.create_proxy("placeholder", f"input{i}", (), {})
            for i in range(num_inputs)
        ]

    def placeholder(self, idx: int) -> torch.fx.Proxy:
        return self.placeholders[idx]

    def output(self, *args: Tuple[object]) -> torch.fx.Node:
        return self.tracer.create_node(
            "output", "output", (tuple(self.tracer.create_arg(a) for a in args),), {}
        )


def lower_pointwise_subgraph(
    subgraph: ir.Subgraph, inputs: List[InputDescriptor]
) -> Callable[_P, Any]:
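    """Lower ``subgraph`` to a callable usable as a Pointwise inner_fn.

    The FX graph is first lowered to scalar (zero-dim) ir.Pointwise nodes;
    each output's inner_fn is then re-traced, with CSE, into a single
    torch.fx.GraphModule that replays the recorded ops against whichever ops
    handler is active when the returned callable runs.
    """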
    # Lower subgraph to ir.Pointwise nodes
    def fake_inner_fn(
        loop_idx: int, input_idx: int
    ) -> Union[ir.Expr, ir.TensorBox, None]:
        return ops.placeholder(input_idx)

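    # Each graph input is modeled as a scalar (zero-dim) Pointwise whose
    # inner_fn simply emits ops.placeholder(i), so reads of it record the
    # input index rather than an actual load.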
    graph_inputs = [
        ir.Pointwise.create(
            device=desc.device,
            dtype=desc.dtype,
            inner_fn=functools.partial(fake_inner_fn, input_idx=i),
            ranges=[],
        )
        for i, desc in enumerate(inputs)
    ]
    gm = subgraph.graph_module
    pw_subgraph = PointwiseSubgraphLowering(gm, root_graph_lowering=V.graph)
    with V.set_graph_handler(pw_subgraph):  # type: ignore[arg-type]
        pw_subgraph.run(*graph_inputs)

    # Combine multiple pointwise computations into a single graph module
    # Do this by tracing through each individually and doing CSE
    tracer = torch.fx.Tracer()
    tracer.graph = torch.fx.Graph(tracer_cls=tracer.__class__)
    trace_ops = SimpleCSEHandler(TracingOpsHandler(tracer, len(inputs)))
    assert pw_subgraph.graph_outputs is not None

    with V.set_ops_handler(trace_ops):
        output_irs = []

        for out_var in pw_subgraph.graph_outputs:
            assert isinstance(out_var, ir.TensorBox), type(out_var)
            assert out_var.get_size() == []
            assert isinstance(out_var.data, ir.StorageBox)
            assert isinstance(out_var.data.data, ir.Pointwise)

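            # Outputs are zero-dim, so index with the empty tuple; calling
            # inner_fn under trace_ops replays its ops into the traced graph.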
            idx = ()
            ir_out = out_var.data.data.inner_fn(idx)

            output_irs.append(ir_out)

        ops.output(*output_irs)

    lowered_gm = torch.fx.GraphModule({}, tracer.graph)

    def inner_fn(*args: _P.args, **kwargs: _P.kwargs) -> Any:
        return lowered_gm(V.get_ops_handler(), *args, **kwargs)

    return inner_fn
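
# A minimal usage sketch (hypothetical names, not part of this module; assumes
# it runs during lowering with V.graph set to the root GraphLowering, and that
# `subgraph` wraps an FX graph taking two scalar float inputs):
#
#   combined = lower_pointwise_subgraph(
#       subgraph,
#       [InputDescriptor(torch.float32, device),
#        InputDescriptor(torch.float32, device)],
#   )
#   # later, inside another Pointwise inner_fn, with ops-handler values a, b:
#   (out,) = combined(a, b)  # the traced GraphModule returns a tuple of outputs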