"""Utilities for lowering subgraphs used by higher order operators."""

import functools
import operator
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
from typing_extensions import ParamSpec

import torch

from . import ir
from .exc import SubgraphLoweringException
from .ops_handler import SimpleCSEHandler
from .sizevars import SizeVarAllocator
from .virtualized import ops, V, WrapperHandler


T = TypeVar("T")
_P = ParamSpec("_P")


class PointwiseSubgraphLowering(torch.fx.Interpreter):
    """FX interpreter that lowers a subgraph made only of pointwise ops.

    Every ``call_function`` node is dispatched through the ``lowerings``
    table; operators without a lowering, or without the pointwise tag, are
    rejected with ``SubgraphLoweringException``.  Buffer registration and
    buffer mutation are rejected the same way.  The value passed to the
    graph's ``output`` node is stored on ``graph_outputs``.
    """

    # Set by ``output``; stays ``None`` until the graph has been run.
    graph_outputs: Optional[List[ir.IRNode]]

    def __init__(
        self,
        gm: torch.fx.GraphModule,
        root_graph_lowering: "torch._inductor.graph.GraphLowering",
    ) -> None:
        """Wrap ``gm`` and remember the enclosing graph lowering."""
        super().__init__(gm)
        self.graph_outputs = None
        self.root_graph = root_graph_lowering

    @property
    def sizevars(self) -> SizeVarAllocator:
        """Size variable allocator, shared with the root graph."""
        return self.root_graph.sizevars

    def mark_buffer_mutated(self, name: str) -> None:
        """Always raises: mutation is not allowed while lowering here."""
        raise SubgraphLoweringException("Mutations are not supported in this context")

    def register_buffer(self, buffer: ir.Buffer) -> str:
        """Always raises: new buffers cannot be created while lowering here."""
        raise SubgraphLoweringException(
            "Buffer creation is not supported in this context"
        )

    def call_function(
        self,
        target: Callable[[Any], Any],  # type: ignore[override]
        args: Any,
        kwargs: Dict[str, Any],
    ) -> Any:
        """Lower one ``call_function`` node.

        Raises:
            SubgraphLoweringException: if ``target`` has no registered
                lowering, or is not tagged as pointwise.
        """
        # Imported locally, as in the original (presumably to avoid an
        # import cycle -- confirm before hoisting).
        from .lowering import lowerings

        # Indexing into a plain Python container is not a tensor op; let the
        # base interpreter evaluate it directly.
        indexes_python_container = target is operator.getitem and isinstance(
            args[0], (list, tuple, dict)
        )
        if indexes_python_container:
            return super().call_function(target, args, kwargs)

        assert isinstance(target, torch._ops.OpOverload)

        # The missing-lowering check comes first so its message wins when
        # both conditions would fail.
        if target not in lowerings:
            raise SubgraphLoweringException(
                f"{target} not supported in subgraph, (missing lowering)"
            )

        if torch.Tag.pointwise not in target.tags:
            raise SubgraphLoweringException(
                f"Only pointwise operators are supported in this context, but got {target}"
            )

        lowering_fn = lowerings[target]
        return lowering_fn(*args, **kwargs)

    def output(self, target: str, args: Tuple[Any], kwargs: Dict[str, Any]) -> None:  # type: ignore[override]
        """Record the single argument of the output node on ``graph_outputs``."""
        assert len(args) == 1
        self.graph_outputs = args[0]


@dataclass
class InputDescriptor:
    """dtype and device of one subgraph input."""

    dtype: torch.dtype
    device: torch.device


class TracingOpsHandler(WrapperHandler[T]):
    """Ops handler backed by FX proxies of a ``torch.fx.Tracer``.

    The wrapped parent is a proxy for a placeholder named ``"ops"``; how
    calls are forwarded to it is defined by ``WrapperHandler`` (not visible
    here).  One further placeholder proxy, ``"input{i}"``, is created per
    subgraph input.
    """

    def __init__(self, tracer: torch.fx.Tracer, num_inputs: int) -> None:
        # The "ops" placeholder must be created before the inputs so that it
        # becomes the first argument of the traced graph.
        ops_proxy = tracer.create_proxy("placeholder", "ops", (), {})
        super().__init__(ops_proxy)
        self.tracer = tracer

        self.placeholders = [
            tracer.create_proxy("placeholder", f"input{index}", (), {})
            for index in range(num_inputs)
        ]

    def placeholder(self, idx: int) -> torch.fx.Proxy:
        """Return the proxy standing in for input number ``idx``."""
        return self.placeholders[idx]

    def output(self, *args: Tuple[object]) -> torch.fx.Node:
        """Emit the graph's output node, returning all ``args`` as one tuple."""
        returned_values = tuple(self.tracer.create_arg(value) for value in args)
        return self.tracer.create_node("output", "output", (returned_values,), {})


def _load_placeholder(
    loop_idx: int, input_idx: int
) -> Union[ir.Expr, ir.TensorBox, None]:
    """Inner function of a stand-in input: ask the ops handler for a placeholder.

    ``loop_idx`` is ignored; the stand-in inputs are created with empty ranges.
    """
    return ops.placeholder(input_idx)


def _trace_scalar_output(out_var: Any) -> Any:
    """Check that ``out_var`` is a zero-dim pointwise result and evaluate it.

    The pointwise node's ``inner_fn`` is called with an empty index under
    whatever ops handler is currently installed.
    """
    assert isinstance(out_var, ir.TensorBox), type(out_var)
    assert out_var.get_size() == []
    storage = out_var.data
    assert isinstance(storage, ir.StorageBox)
    pointwise = storage.data
    assert isinstance(pointwise, ir.Pointwise)
    return pointwise.inner_fn(())


def lower_pointwise_subgraph(
    subgraph: ir.Subgraph, inputs: List[InputDescriptor]
) -> Callable[_P, Any]:
    """Lower a pointwise subgraph into a callable that replays it on ops values.

    Args:
        subgraph: subgraph whose ``graph_module`` is lowered.
        inputs: dtype/device descriptor for each subgraph input, in order.

    Returns:
        A function that invokes the traced graph module with the currently
        active ops handler as first argument, followed by the caller's
        arguments.

    Raises:
        SubgraphLoweringException: propagated from ``PointwiseSubgraphLowering``
            when the subgraph contains something it does not support.
    """
    # Lower the subgraph to ir.Pointwise nodes, feeding it zero-dim stand-in
    # inputs whose bodies just request the matching placeholder.
    graph_inputs = []
    for position, desc in enumerate(inputs):
        load_input = functools.partial(_load_placeholder, input_idx=position)
        graph_inputs.append(
            ir.Pointwise.create(
                device=desc.device,
                dtype=desc.dtype,
                inner_fn=load_input,
                ranges=[],
            )
        )

    pw_subgraph = PointwiseSubgraphLowering(
        subgraph.graph_module, root_graph_lowering=V.graph
    )
    with V.set_graph_handler(pw_subgraph):  # type: ignore[arg-type]
        pw_subgraph.run(*graph_inputs)

    # Combine the (possibly several) pointwise outputs into one graph module
    # by tracing each of them through a CSE-ing handler.
    tracer = torch.fx.Tracer()
    tracer.graph = torch.fx.Graph(tracer_cls=type(tracer))
    trace_ops = SimpleCSEHandler(TracingOpsHandler(tracer, len(inputs)))

    lowered_outputs = pw_subgraph.graph_outputs
    assert lowered_outputs is not None

    with V.set_ops_handler(trace_ops):
        output_irs = [_trace_scalar_output(out_var) for out_var in lowered_outputs]
        ops.output(*output_irs)

    lowered_gm = torch.fx.GraphModule({}, tracer.graph)

    def inner_fn(*args: _P.args, **kwargs: _P.kwargs) -> Any:
        # The handler is looked up at call time and fills the "ops"
        # placeholder, the first argument of the traced graph.
        active_handler = V.get_ops_handler()
        return lowered_gm(active_handler, *args, **kwargs)

    return inner_fn