xref: /aosp_15_r20/external/pytorch/torch/_inductor/codegen/cuda/cutlass_epilogue_gen.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from typing import Dict, List
3from unittest.mock import patch
4
5import sympy
6
7import torch._inductor.virtualized as virtualized
8from torch._inductor.ir import ComputedBuffer, FlexibleLayout, IRNode, Pointwise
9from torch._inductor.utils import IndentedBuffer, sympy_str
10
11
# Sentinel that _arg_str() embeds in its return value whenever it is handed a
# sympy expression. If the sentinel is still present in the final generated
# C++ code, the formatters below raise CUTLASSEVTOpNotImplementedError instead
# of returning that code.
_MAGIC_SYMPY_ERROR_STRING = "[!sympy: unsupported expr!]"
15
16
def _arg_str(a):
    """
    Convert an ops-handler argument into the string handed to the _op_* methods.

    Anything that is not a sympy expression is passed through ``str``. A sympy
    expression is rendered as the _MAGIC_SYMPY_ERROR_STRING sentinel followed by
    the quoted expression; the formatters check their final output for that
    sentinel and raise CUTLASSEVTOpNotImplementedError if it is present, because
    such an op could not be converted to a valid EVT expression.
    """
    if not isinstance(a, sympy.Expr):
        return str(a)
    rendered_expr = sympy_str(a)
    return f"{_MAGIC_SYMPY_ERROR_STRING}('{rendered_expr}')"
25
26
class CUTLASSEVTOpNotImplementedError(NotImplementedError):
    """Raised when an op or expression cannot be expressed as a Cutlass EVT."""
29
30
class CutlassEVTEpilogueTypeFormatter:
    """
    Codegen class, which provides an entry point to generate
    Cutlass "Epilogue Visitor Tree" (EVT) functor declarations.

    See https://github.com/NVIDIA/cutlass/tree/main/examples/49_hopper_gemm_with_collective_builder
    for more about EVTs and how they are declared and used.

    Notes:
        * Used by CUTLASSGemmTemplate.
        * This class should not be instantiated by users, it is intended to be used
            by calling CutlassEVTEpilogueTypeFormatter.ir_to_evt_string(...)
            which instantiates this class as an ops handler for virtualized.V.ops.[op-name]
        * Every op resolved through __getattr__ emits one C++
            ``using EVT_expr_<n> = ...;`` line into ``self.output`` and returns
            the alias name ``EVT_expr_<n>``.
        * Extend this with more _op_<whatever> methods to add support for new pointwise operations.
    """

    def __init__(self, accumulator_node_name, evt_type_name):
        """
        Initialize an instance of CutlassEVTEpilogueTypeFormatter.

        Parameters:
        - accumulator_node_name (str): The name of the output Buffer for the GEMM operation in the original (unfused)
                                       IR graph.
        - evt_type_name (str):      The output name of the EVT type we are generating.
        """
        self.accumulator_node_name = accumulator_node_name
        # Collects the generated "using ..." declarations, one per line.
        self.output = IndentedBuffer(0)
        # Number of EVT_expr_<n> aliases emitted so far; also names the latest one.
        self.var_counter = 0
        self.evt_type_name = evt_type_name
        # Maps the name of an already processed epilogue node to its result.
        self.aliases = {}

    @staticmethod
    def ir_to_evt_string(
        template_output_node_name: str,
        evt_type_name: str,
        epilogue_nodes: List[IRNode],
    ):
        """
        Formats IR nodes into a string representation compatible with Cutlass EVT format.

        Args:
            template_output_node_name (str): The name of the template output node.
            evt_type_name (str): The name of the EVT type.
            epilogue_nodes (List[IRNode]): A list of IR nodes representing the epilogue nodes. As of now, these must be
                ComputedBuffer nodes wrapping Pointwise nodes.

        Returns:
            A string representation of the IR nodes formatted according to the Cutlass EVT format.

        Raises:
            RuntimeError: If ``epilogue_nodes`` is empty, or if a node is not a
                ComputedBuffer wrapping a Pointwise node.
            CUTLASSEVTOpNotImplementedError: If an op has no _op_<name> handler, or
                if a sympy / indexing expression ended up in the generated code.
        """
        if not epilogue_nodes:
            # Without this guard the code below would fail with an opaque
            # UnboundLocalError on ``result``.
            raise RuntimeError(
                "At least one epilogue node is required to generate an EVT type"
            )

        invalid_node_msg = "Epilogue nodes must be Pointwise nodes, wrapped in a named ComputedBuffer"
        formatter = CutlassEVTEpilogueTypeFormatter(
            template_output_node_name, evt_type_name
        )

        # While the inner functions are evaluated, ops calls are routed to the
        # formatter and FlexibleLayout.allow_indexing is temporarily set to True.
        with virtualized.V.set_ops_handler(formatter), patch.object(
            FlexibleLayout, "allow_indexing", True
        ):
            for node in epilogue_nodes:
                if not isinstance(node, ComputedBuffer):
                    raise RuntimeError(invalid_node_msg)
                pnode = node.data
                # Explicit raise instead of an assert, so the check also holds
                # when Python runs with assertions disabled (-O).
                if not isinstance(pnode, Pointwise):
                    raise RuntimeError(invalid_node_msg)
                index = pnode._index(pnode.ranges)
                result = pnode.inner_fn(index)
                # each epilogue node results in a single "using" statement and may refer to the previous steps by name
                formatter.aliases[node.name] = result
            res = formatter.getvalue(result)  # type: ignore[possibly-undefined]
            if _MAGIC_SYMPY_ERROR_STRING in res:
                raise CUTLASSEVTOpNotImplementedError(
                    "sympy / indexing expressions not yet supported in EVT fusion"
                )
            return res

    def __getattr__(self, name):
        """
        Resolve V.ops.<whatever> calls, after this instance has been installed as V.ops handler.

        Returns a callable that stringifies its arguments, delegates to the
        matching ``_op_<name>`` method, writes the resulting expression to
        ``self.output`` as a new ``using EVT_expr_<n> = ...;`` line and returns
        the alias name.

        Raises:
            CUTLASSEVTOpNotImplementedError: If ``name`` starts with an underscore
                or no ``_op_<name>`` handler exists. The error carries ``name``.
        """
        if name.startswith("_"):
            raise CUTLASSEVTOpNotImplementedError(name)
        try:
            # Look the handler up without going through __getattr__ again.
            # hasattr(self, "_op_<name>") would re-enter this method with an
            # underscore-prefixed name and raise from inside hasattr, reporting
            # "_op_<name>" instead of the op that was actually requested.
            fn = object.__getattribute__(self, f"_op_{name}")
        except AttributeError:
            raise CUTLASSEVTOpNotImplementedError(name) from None

        def inner(*args, **kwargs):
            fargs = [_arg_str(a) for a in args]
            fkwargs = {key: _arg_str(a) for key, a in kwargs.items()}
            line = fn(*fargs, **fkwargs)
            self.var_counter += 1
            varname = f"EVT_expr_{self.var_counter}"
            # replace line with a new variable name
            self.output.writeline(f"using {varname} = {line};")
            return varname

        return inner

    def _op_load(self, name, index_expr):
        # Load an input to an operation. Might be the output of the matmul, the result
        # of a previous epilogue node, a constant or (TODO) an auxiliary input.
        # index_expr is not used.
        if name == self.accumulator_node_name:
            return f"cutlass::epilogue::fusion::Sm90AccFetch /* :={name} (matmul output in accumulator) */"
        elif name in self.aliases:
            return self.aliases[name]
        else:
            # return f"cutlass::epilogue::fusion::Sm90SrcFetch /* :={name} */"
            raise CUTLASSEVTOpNotImplementedError(
                f"Operand {name} not found. Auxiliary inputs not supported yet."
            )

    def _op_constant(self, value, dtype):
        # Load a constant. Only float16 / float32 constants are supported; the
        # value itself only appears in a comment of the generated type.
        if str(dtype) in ("torch.float16", "torch.float32"):
            return f"cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementAcc> /* value={value}, dtype={dtype} */"
        else:
            raise CUTLASSEVTOpNotImplementedError(
                f"Unsupported dtype for constant: {dtype}"
            )

    def _cutlass_binary_functional_op(self, op, a, b):
        # Perform a named operation on two inputs
        # see https://github.com/NVIDIA/cutlass/blob/6407bcdf0a24097b7b016ee105937693c62f9923/include/cutlass/functional.h for ops
        return f"cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::{op}, ElementAcc, ElementAcc, RoundStyle>,{a},{b}>"  # noqa: B950

    def _convert_to_output_dtype(self, a):
        # Convert the final output to the dtype of the output buffer
        return f"cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<identity_op, ElementD, ElementAcc, RoundStyle>,{a}>"  # noqa: B950

    def _op_to_dtype(self, a, *args, **kwargs):
        # no-op in our case, since we convert to the output dtype at the end and convert everything to the accumulator
        # dtype.
        # It is asserted ( and ascertained during can_fuse decision ) that the dtype remains compatible
        # throughout the fusion chain.
        return a

    def _op_mul(self, a, b):
        return self._cutlass_binary_functional_op("multiplies", a, b)

    def _op_div(self, a, b):
        return self._cutlass_binary_functional_op("divides", a, b)

    def _op_truediv(self, a, b):
        return self._cutlass_binary_functional_op("divides", a, b)

    def _op_ge(self, a, b):
        return self._cutlass_binary_functional_op("greater_equal", a, b)

    def _op_add(self, a, b):
        return self._cutlass_binary_functional_op("plus", a, b)

    def _op_sub(self, a, b):
        return self._cutlass_binary_functional_op("minus", a, b)

    def _op_minimum(self, a, b):
        return self._cutlass_binary_functional_op("minimum", a, b)

    def _op_maximum(self, a, b):
        return self._cutlass_binary_functional_op("maximum", a, b)

    def _op_relu(self, a):
        # relu(a) is expressed as maximum(a, 0)
        const_zero = self._op_constant(0.0, "torch.float32")
        return f"cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::maximum, ElementAcc, ElementAcc, RoundStyle>,{a}, {const_zero}>"  # noqa: B950

    def reduction(self, dtype, src_dtype, reduction_type, value):
        # Reductions cannot be expressed by this formatter.
        raise CUTLASSEVTOpNotImplementedError

    # Add more ops here...
    def getvalue(self, result) -> str:
        """
        Finish the EVT declaration and return all generated code.

        Appends a final ``using <evt_type_name> = ...;`` line that converts the
        most recently emitted ``EVT_expr_<n>`` alias to the output dtype.
        ``result`` is not read; the latest alias is derived from ``var_counter``.
        """
        dtype_converted_expr = self._convert_to_output_dtype(
            f"EVT_expr_{self.var_counter}"
        )
        self.output.writeline(f"using {self.evt_type_name} = {dtype_converted_expr};")
        return self.output.getvalue()
211
212
class CutlassEVTEpilogueArgumentFormatter:
    """
    Codegen class, which provides an entry point to generate
    Cutlass "Epilogue Visitor Tree" (EVT) Argument initializers

    See https://github.com/NVIDIA/cutlass/tree/main/examples/49_hopper_gemm_with_collective_builder
    for more about EVTs and how they are declared and used.

    Notes:
        * Used by CUTLASSGemmTemplate.
        * This class should not be instantiated by users, it is intended to be used
            by calling CutlassEVTEpilogueArgumentFormatter.ir_to_evt_argument_string(...)
            which instantiates this class as an ops handler for virtualized.V.ops.[op-name]
        * Every op returns a brace-enclosed initializer string; nested ops
            therefore produce nested initializers.
        * Extend this with more _op_<whatever> methods to add support for new pointwise operations.
    """

    def __init__(self, accumulator_node_name: str):
        """
        Initializes a CutlassEVTEpilogueArgumentFormatter object. Do not instantiate directly.
        Use the CutlassEVTEpilogueArgumentFormatter.ir_to_evt_argument_string static method.

        Args:
            accumulator_node_name (str): The name of the accumulator node which should contain
                                          the Matmul result before fusion according to the IR graph.
        """
        self.accumulator_node_name: str = accumulator_node_name
        self.output: IndentedBuffer = IndentedBuffer(0)  # The output buffer for codegen
        self.var_counter: int = (
            0  # used to generate variable names, incremented for each new variable
        )
        self.aliases: Dict[str, str] = {}  # Aliases for subexpression functors

    @staticmethod
    def ir_to_evt_argument_string(
        template_output_node_name: str,
        epilogue_nodes: List[IRNode],
    ) -> str:
        """
        Formats IR nodes into the argument initializer string for a Cutlass EVT.

        Args:
            template_output_node_name (str): The name of the template output node.
            epilogue_nodes (List[IRNode]): The epilogue nodes; each must be a
                ComputedBuffer wrapping a Pointwise node.

        Returns:
            The brace-enclosed initializer string for the last epilogue node.

        Raises:
            RuntimeError: If ``epilogue_nodes`` is empty, or if a node is not a
                ComputedBuffer wrapping a Pointwise node.
            CUTLASSEVTOpNotImplementedError: If an op has no _op_<name> handler, or
                if a sympy / indexing expression ended up in the generated code.
        """
        if not epilogue_nodes:
            # Without this guard the code below would fail with an opaque
            # UnboundLocalError on ``result``.
            raise RuntimeError(
                "At least one epilogue node is required to generate EVT arguments"
            )

        invalid_node_msg = "Epilogue nodes must be Pointwise nodes, wrapped in a named ComputedBuffer"
        formatter = CutlassEVTEpilogueArgumentFormatter(
            template_output_node_name,
        )

        # While the inner functions are evaluated, ops calls are routed to the
        # formatter and FlexibleLayout.allow_indexing is temporarily set to True.
        with virtualized.V.set_ops_handler(formatter), patch.object(
            FlexibleLayout, "allow_indexing", True
        ):
            for node in epilogue_nodes:
                # Explicit raises instead of asserts, so the checks also hold
                # when Python runs with assertions disabled (-O). The error is
                # the same one CutlassEVTEpilogueTypeFormatter.ir_to_evt_string uses.
                if not isinstance(node, ComputedBuffer):
                    raise RuntimeError(invalid_node_msg)
                pnode = node.data
                if not isinstance(pnode, Pointwise):
                    raise RuntimeError(invalid_node_msg)
                index = pnode._index(pnode.ranges)
                result = pnode.inner_fn(index)
                # each epilogue node results in a single initializer and may refer to the previous steps by name
                if node.name is not None:
                    formatter.aliases[node.name] = result

            res: str = formatter.getvalue(result)  # type: ignore[possibly-undefined]
            if _MAGIC_SYMPY_ERROR_STRING in res:
                raise CUTLASSEVTOpNotImplementedError(
                    "sympy / indexing expressions not yet supported in EVT fusion"
                )
            return res

    def __getattr__(self, name):
        """
        Resolve V.ops.<whatever> calls, after this instance has been installed as V.ops handler.

        Returns a callable that stringifies its arguments and delegates to the
        matching ``_op_<name>`` method.

        Raises:
            CUTLASSEVTOpNotImplementedError: If ``name`` starts with an underscore
                or no ``_op_<name>`` handler exists. The error carries ``name``.
        """
        if name.startswith("_"):
            raise CUTLASSEVTOpNotImplementedError(name)
        try:
            # Look the handler up without going through __getattr__ again.
            # hasattr(self, "_op_<name>") would re-enter this method with an
            # underscore-prefixed name and raise from inside hasattr, reporting
            # "_op_<name>" instead of the op that was actually requested.
            fn = object.__getattribute__(self, f"_op_{name}")
        except AttributeError:
            raise CUTLASSEVTOpNotImplementedError(name) from None

        def inner(*args, **kwargs):
            fargs = [_arg_str(a) for a in args]
            fkwargs = {key: _arg_str(a) for key, a in kwargs.items()}
            return fn(*fargs, **fkwargs)

        return inner

    def _op_load(self, name, index_expr):
        # The accumulator takes no arguments; a previous epilogue node is
        # replaced by its already generated initializer. index_expr is not used.
        if name == self.accumulator_node_name:
            return "{}"
        elif name in self.aliases:
            return self.aliases[name]
        else:
            raise CUTLASSEVTOpNotImplementedError(
                f"Operand {name} not found. Auxiliary inputs not supported yet."
            )

    def _op_constant(self, value, dtype):
        # Only float16 / float32 constants are supported; the value is cast to
        # ElementAcc in the generated initializer.
        if str(dtype) in ("torch.float16", "torch.float32"):
            return "{ static_cast<ElementAcc>(" + str(value) + ") }"
        else:
            raise CUTLASSEVTOpNotImplementedError(
                f"Unsupported dtype for constant: {dtype}"
            )

    def _cutlass_binary_functional_op(self, op, a, b):
        # The op name only appears as a comment; the initializer consists of
        # the initializers of both operands.
        return f"{{ /*{op}: */ {a}, {b} }}"

    def _op_mul(self, a, b):
        return self._cutlass_binary_functional_op("multiplies", a, b)

    def _op_div(self, a, b):
        return self._cutlass_binary_functional_op("divides", a, b)

    def _op_truediv(self, a, b):
        return self._cutlass_binary_functional_op("divides", a, b)

    def _op_ge(self, a, b):
        return self._cutlass_binary_functional_op("greater_equal", a, b)

    def _op_add(self, a, b):
        return self._cutlass_binary_functional_op("plus", a, b)

    def _op_sub(self, a, b):
        return self._cutlass_binary_functional_op("minus", a, b)

    def _op_minimum(self, a, b):
        return self._cutlass_binary_functional_op("minimum", a, b)

    def _op_maximum(self, a, b):
        return self._cutlass_binary_functional_op("maximum", a, b)

    def _op_relu(self, a):
        # relu(a) is expressed as maximum(a, 0), so a zero constant is the second operand
        const_zero = self._op_constant(0.0, "torch.float32")
        return "{" + str(a) + ", " + const_zero + "}"

    def _op_to_dtype(self, a, dtype, src_dtype=None, **kwargs):
        # It is asserted ( and ascertained during can_fuse decision ) that the dtype remains compatible
        # throughout the fusion chain.
        #
        # The dtypes are compared via str(), like _op_constant does: when this
        # method is reached through __getattr__, every argument has already been
        # stringified, so an explicitly passed src_dtype=None arrives as "None".
        # Additional keyword options are accepted and ignored, matching
        # CutlassEVTEpilogueTypeFormatter._op_to_dtype.
        assert str(dtype) in (
            "torch.float32",
            "torch.float16",
        ), f"Unsupported dtype: {dtype}"
        assert str(src_dtype) in (
            "None",
            "torch.float32",
            "torch.float16",
        ), f"Unsupported source dtype: {src_dtype}"
        return a

    def reduction(self, dtype, src_dtype, reduction_type, value):
        # Reductions cannot be expressed by this formatter.
        raise CUTLASSEVTOpNotImplementedError

    def getvalue(self, result) -> str:
        """Return the final argument initializer: ``result`` wrapped in braces."""
        return "{" + str(result) + "}"
362