1# mypy: allow-untyped-defs 2from typing import Dict, List 3from unittest.mock import patch 4 5import sympy 6 7import torch._inductor.virtualized as virtualized 8from torch._inductor.ir import ComputedBuffer, FlexibleLayout, IRNode, Pointwise 9from torch._inductor.utils import IndentedBuffer, sympy_str 10 11 12# Used as a magic string to indicate an unsupported sympy expression 13# became part of generated C++ code. 14_MAGIC_SYMPY_ERROR_STRING = "[!sympy: unsupported expr!]" 15 16 17def _arg_str(a): 18 if isinstance(a, sympy.Expr): 19 # If this return value containing the _MAGIC_SYMPY_ERROR_STRING 20 # is used as part of the final generated C++ code, 21 # a CUTLASSEVTOpNotImplementedError is raised to indicate that 22 # the op could not be converted to a valid EVT expression. 23 return f"{_MAGIC_SYMPY_ERROR_STRING}('{sympy_str(a)}')" 24 return str(a) 25 26 27class CUTLASSEVTOpNotImplementedError(NotImplementedError): 28 pass 29 30 31class CutlassEVTEpilogueTypeFormatter: 32 """ 33 Codegen class, which provides an entry point to generate 34 Cutlass "Epilogue Visitor Tree" (EVT) functor declarations. 35 36 See https://github.com/NVIDIA/cutlass/tree/main/examples/49_hopper_gemm_with_collective_builder 37 for more about EVTs and how they are declared and used to generate. 38 39 Notes: 40 * Used by CUTLASSGemmTemplate. 41 * This class should not be instantiated by users, it is intended to be used 42 by calling CutlassEVTEpilogueTypeFormatter.ir_to_evt_string(...) 43 which instantiates this class as an ops handler for virtualized.V.ops.[op-name] 44 * Extend this with more _op_<whatever> nodes to add support for new pointwise operations. 45 46 47 """ 48 49 def __init__(self, accumulator_node_name, evt_type_name): 50 """ 51 52 Initialize an instance of CutlassEVTEpilogueTypeFormatter. 53 54 Parameters: 55 - accumulator_node_name (str): The name of the output Buffer for the GEMM operation in the original (unfused) 56 IR graph. 57 - evt_type_name (str): The output name of the EVT type we are generating. 58 59 """ 60 self.accumulator_node_name = accumulator_node_name 61 self.output = IndentedBuffer(0) 62 self.var_counter = 0 63 self.evt_type_name = evt_type_name 64 self.aliases = {} 65 66 @staticmethod 67 def ir_to_evt_string( 68 template_output_node_name: str, 69 evt_type_name: str, 70 epilogue_nodes: List[IRNode], 71 ): 72 """ 73 Formats IR nodes into a string representation compatible with Cutlass EVT format. 74 75 Args: 76 template_output_node_name (str): The name of the template output node. 77 evt_type_name (str): The name of the EVT type. 78 epilogue_nodes (List[IRNode]): A list of IR nodes representing the epilogue nodes. As of now, these must be 79 ComputedBuffer nodes wrapping Pointwise nodes. 80 81 Returns: 82 A string representation of the IR nodes formatted according to the Cutlass EVT format. 83 """ 84 formatter = CutlassEVTEpilogueTypeFormatter( 85 template_output_node_name, evt_type_name 86 ) 87 88 with virtualized.V.set_ops_handler(formatter), patch.object( 89 FlexibleLayout, "allow_indexing", True 90 ): 91 for node in epilogue_nodes: 92 if isinstance(node, ComputedBuffer): 93 pnode = node.data 94 else: 95 raise RuntimeError( 96 "Epilogue nodes must be Pointwise nodes, wrapped in a named ComputedBuffer" 97 ) 98 assert isinstance(pnode, Pointwise) 99 index = pnode._index(pnode.ranges) 100 result = pnode.inner_fn(index) 101 # each epilogue node results in a single "using" statement and may refer to the previous steps by name 102 formatter.aliases[node.name] = result 103 res = formatter.getvalue(result) # type: ignore[possibly-undefined] 104 if _MAGIC_SYMPY_ERROR_STRING in res: 105 raise CUTLASSEVTOpNotImplementedError( 106 "sympy / indexing expressions not yet supported in EVT fusion" 107 ) 108 else: 109 return res 110 111 def __getattr__(self, name): 112 """ 113 Resolve V.ops.<whatever> calls, after this instance has been installed as V.ops handler. 114 """ 115 116 def inner(*args, **kwargs): 117 fargs = [_arg_str(a) for a in args] 118 fkwargs = {key: _arg_str(a) for key, a in kwargs.items()} 119 fn = getattr(self, f"_op_{name}") 120 line = fn(*fargs, **fkwargs) 121 self.var_counter += 1 122 varname = f"EVT_expr_{self.var_counter}" 123 # replace line with a new variable name 124 self.output.writeline(f"using {varname} = {line};") 125 return varname 126 127 if name.startswith("_"): 128 raise CUTLASSEVTOpNotImplementedError(name) 129 if hasattr(self, f"_op_{name}"): 130 return inner 131 else: 132 raise CUTLASSEVTOpNotImplementedError(name) 133 134 def _op_load(self, name, index_expr): 135 # Load an input to an operation. Might be the output of the matmul, the result 136 # of a previous epilogue node, a constant or (TODO) an auxiliary input. 137 if name == self.accumulator_node_name: 138 return f"cutlass::epilogue::fusion::Sm90AccFetch /* :={name} (matmul output in accumulator) */" 139 elif name in self.aliases: 140 return self.aliases[name] 141 else: 142 # return f"cutlass::epilogue::fusion::Sm90SrcFetch /* :={name} */" 143 raise CUTLASSEVTOpNotImplementedError( 144 f"Operand {name} not found. Auxiliary inputs not supported yet." 145 ) 146 147 def _op_constant(self, value, dtype): 148 # Load a constant 149 if str(dtype) in ("torch.float16", "torch.float32"): 150 return f"cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementAcc> /* value={value}, dtype={dtype} */" 151 else: 152 raise CUTLASSEVTOpNotImplementedError( 153 f"Unsupported dtype for constant: {dtype}" 154 ) 155 156 def _cutlass_binary_functional_op(self, op, a, b): 157 # Perform a named operation on two inputs 158 # see https://github.com/NVIDIA/cutlass/blob/6407bcdf0a24097b7b016ee105937693c62f9923/include/cutlass/functional.h for ops 159 return f"cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::{op}, ElementAcc, ElementAcc, RoundStyle>,{a},{b}>" # noqa: B950 160 161 def _convert_to_output_dtype(self, a): 162 # Convert the final output to the dtype of the output buffer 163 return f"cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<identity_op, ElementD, ElementAcc, RoundStyle>,{a}>" # noqa: B950 164 165 def _op_to_dtype(self, a, *args, **kwargs): 166 # no-op in our case, since we convert to the output dtype at the end and convert everything to the accumulator 167 # dtype. 168 # Is is asserted ( and ascertained during can_fuse decision ) that the dtype remains compatible 169 # throughout the fusion chain. 170 return a # noqa: B950 171 172 def _op_mul(self, a, b): 173 return self._cutlass_binary_functional_op("multiplies", a, b) 174 175 def _op_div(self, a, b): 176 return self._cutlass_binary_functional_op("divides", a, b) 177 178 def _op_truediv(self, a, b): 179 return self._cutlass_binary_functional_op("divides", a, b) 180 181 def _op_ge(self, a, b): 182 return self._cutlass_binary_functional_op("greater_equal", a, b) 183 184 def _op_add(self, a, b): 185 return self._cutlass_binary_functional_op("plus", a, b) 186 187 def _op_sub(self, a, b): 188 return self._cutlass_binary_functional_op("minus", a, b) 189 190 def _op_minimum(self, a, b): 191 return self._cutlass_binary_functional_op("minimum", a, b) 192 193 def _op_maximum(self, a, b): 194 return self._cutlass_binary_functional_op("maximum", a, b) 195 196 def _op_relu(self, a): 197 const_zero = self._op_constant(0.0, "torch.float32") 198 return f"cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::maximum, ElementAcc, ElementAcc, RoundStyle>,{a}, {const_zero}>" # noqa: B950 199 200 def reduction(self, dtype, src_dtype, reduction_type, value): 201 raise CUTLASSEVTOpNotImplementedError 202 203 # Add more ops here... 204 def getvalue(self, result) -> str: 205 # Return final result 206 dtype_converted_expr = self._convert_to_output_dtype( 207 f"EVT_expr_{self.var_counter}" 208 ) 209 self.output.writeline(f"using {self.evt_type_name} = {dtype_converted_expr};") 210 return self.output.getvalue() 211 212 213class CutlassEVTEpilogueArgumentFormatter: 214 """ 215 Codegen class, which provides an entry point to generate 216 Cutlass "Epilogue Visitor Tree" (EVT) Argument initializers 217 218 See https://github.com/NVIDIA/cutlass/tree/main/examples/49_hopper_gemm_with_collective_builder 219 for more about EVTs and how they are declared and used to generate. 220 221 Notes: 222 * Used by CUTLASSGemmTemplate. 223 * This class should not be instantiated by users, it is intended to be used 224 by calling CutlassEVTEpilogueArgumentFormatter.ir_to_evt_argument_string(...) 225 which instantiates this class as an ops handler for virtualized.V.ops.[op-name] 226 * Extend this with more _op_<whatever> nodes to add support for new pointwise operations. 227 228 229 """ 230 231 def __init__(self, accumulator_node_name: str): 232 """ 233 234 Initializes a CutlassEVTEpilogueArgumentFormatter object. Do not instantiate directly. 235 Use the CutlassEVTEpilogueArgumentFormatter.ir_to_evt_argument_string static method. 236 237 Args: 238 accumulator_node_name (str): The name of the accumulator node which should contain 239 the Matmul result before fusion according to the IR graph. 240 """ 241 self.accumulator_node_name: str = accumulator_node_name # 242 self.output: IndentedBuffer = IndentedBuffer(0) # The output buffer for codegen 243 self.var_counter: int = ( 244 0 # used to generate variable names, incremented for each new variable 245 ) 246 self.aliases: Dict[str, str] = {} # Aliases for subexpression functors 247 248 @staticmethod 249 def ir_to_evt_argument_string( 250 template_output_node_name: str, 251 epilogue_nodes: List[IRNode], 252 ) -> str: 253 formatter = CutlassEVTEpilogueArgumentFormatter( 254 template_output_node_name, 255 ) 256 257 with virtualized.V.set_ops_handler(formatter), patch.object( 258 FlexibleLayout, "allow_indexing", True 259 ): 260 for node in epilogue_nodes: 261 assert isinstance(node, ComputedBuffer) 262 pnode = node.data 263 assert isinstance(pnode, Pointwise) 264 index = pnode._index(pnode.ranges) 265 result = pnode.inner_fn(index) 266 # each epilogue node results in a single "using" statement and may refer to the previous steps by name 267 if node.name is not None: 268 formatter.aliases[node.name] = result 269 270 res: str = formatter.getvalue(result) # type: ignore[possibly-undefined] 271 if _MAGIC_SYMPY_ERROR_STRING in res: 272 raise CUTLASSEVTOpNotImplementedError( 273 "sympy / indexing expressions not yet supported in EVT fusion" 274 ) 275 else: 276 return res 277 278 def __getattr__(self, name): 279 def inner(*args, **kwargs): 280 fargs = [_arg_str(a) for a in args] 281 fkwargs = {key: _arg_str(a) for key, a in kwargs.items()} 282 fn = getattr(self, f"_op_{name}") 283 line = fn(*fargs, **fkwargs) 284 return line 285 286 if name.startswith("_"): 287 raise CUTLASSEVTOpNotImplementedError(name) 288 289 if hasattr(self, f"_op_{name}"): 290 return inner 291 else: 292 raise CUTLASSEVTOpNotImplementedError(name) 293 294 def _op_load(self, name, index_expr): 295 if name == self.accumulator_node_name: 296 return "{}" 297 elif name in self.aliases: 298 return self.aliases[name] 299 else: 300 raise CUTLASSEVTOpNotImplementedError( 301 f"Operand {name} not found. Auxiliary inputs not supported yet." 302 ) 303 304 def _op_constant(self, value, dtype): 305 if str(dtype) in ("torch.float16", "torch.float32"): 306 return "{ static_cast<ElementAcc>(" + str(value) + ") }" 307 else: 308 raise CUTLASSEVTOpNotImplementedError( 309 f"Unsupported dtype for constant: {dtype}" 310 ) 311 312 def _cutlass_binary_functional_op(self, op, a, b): 313 return f"{{ /*{op}: */ {a}, {b} }}" 314 315 def _op_mul(self, a, b): 316 return self._cutlass_binary_functional_op("multiplies", a, b) 317 318 def _op_div(self, a, b): 319 return self._cutlass_binary_functional_op("divides", a, b) 320 321 def _op_truediv(self, a, b): 322 return self._cutlass_binary_functional_op("divides", a, b) 323 324 def _op_ge(self, a, b): 325 return self._cutlass_binary_functional_op("greater_equal", a, b) 326 327 def _op_add(self, a, b): 328 return self._cutlass_binary_functional_op("plus", a, b) 329 330 def _op_sub(self, a, b): 331 return self._cutlass_binary_functional_op("minus", a, b) 332 333 def _op_minimum(self, a, b): 334 return self._cutlass_binary_functional_op("minimum", a, b) 335 336 def _op_maximum(self, a, b): 337 return self._cutlass_binary_functional_op("maximum", a, b) 338 339 def _op_relu(self, a): 340 const_zero = self._op_constant(0.0, "torch.float32") 341 return "{" + str(a) + ", " + const_zero + "}" 342 343 def _op_to_dtype(self, a, dtype, src_dtype=None): 344 # Is is asserted ( and ascertained during can_fuse decision ) that the dtype remains compatible 345 # throughout the fusion chain. 346 assert dtype in ( 347 "torch.float32", 348 "torch.float16", 349 ), f"Unsupported dtype: {dtype}" 350 assert src_dtype in ( 351 None, 352 "torch.float32", 353 "torch.float16", 354 ), f"Unsupported source dtype: {src_dtype}" 355 return a 356 357 def reduction(self, dtype, src_dtype, reduction_type, value): 358 raise CUTLASSEVTOpNotImplementedError 359 360 def getvalue(self, result) -> str: 361 return "{" + str(result) + "}" 362