# mypy: allow-untyped-defs
import re
from typing import Callable, List

import torch
from torch import Tensor


__all__: List[str] = []


class _CodeParser:
    def __init__(self, code_string: str):
        optional_ws = r"\s*"
        required_ws = r"\s+"
        template_params = r"(?P<template_params>\<.+\>)"
        return_type = r"(?P<return_type>\w+)"
        function_name = r"(?P<function_name>\w+)"
        function_params = r"(?P<function_params>\(.+\))"
        function_body = r"(?P<function_body>\{.+\})"

        pattern = (
            optional_ws
            + "template"
            + optional_ws
            + template_params
            + optional_ws
            + return_type
            + required_ws
            + function_name
            + optional_ws
            + function_params
            + optional_ws
            + function_body
            + optional_ws
        )

        result = re.match(
            pattern, code_string, re.DOTALL
        )  # DOTALL for matching multiline

        if result is None:
            raise Exception(  # noqa: TRY002
                f"Couldn't parse code, please check correctness:\n {code_string}"
            )

        self.template_params = result["template_params"]
        self.return_type = result["return_type"]
        self.function_name = result["function_name"]
        self.function_params = result["function_params"]
        self.function_body = result["function_body"]


class _JittedFunction:
    def __init__(
        self, code_string: str, return_by_ref: bool, num_outputs: int, **kwargs
    ):
        self.code_string = code_string

        assert (
            return_by_ref or num_outputs == 1
        ), "Return by value only works for a single output."
        self.return_by_ref = return_by_ref
        self.num_outputs = num_outputs

        parsed_code = _CodeParser(code_string)
        self.kernel_name = parsed_code.function_name

        self.kwargs_dict = kwargs
        self.is_cuda_available = torch.cuda.is_available()

    def __call__(self, *tensors: Tensor, **kwargs):
        # Jiterator follows torch.cuda's lazy initialization behavior,
        # deferring the CUDA availability check to function invocation time.
        assert (
            self.is_cuda_available
        ), "Jiterator is only supported on CUDA and ROCm GPUs, none are available."

        assert len(tensors) <= 8, "jiterator only supports up to 8 tensor inputs."

        expanded_kwargs = self.kwargs_dict.copy()
        for key, value in kwargs.items():
            if key in self.kwargs_dict:
                expanded_kwargs[key] = value
            else:
                raise KeyError(f"{key} is not declared in function definition")

        return torch._C._cuda_jiterator_compile_and_launch_kernel(
            self.code_string,
            self.kernel_name,
            self.return_by_ref,
            self.num_outputs,
            tensors,
            expanded_kwargs,
        )


def _create_jit_fn(code_string: str, **kwargs) -> Callable:
    """
    Create a jiterator-generated cuda kernel for an elementwise op.

    The code string has to be a valid CUDA function that describes the computation for a single element. The code
    string has to follow the C++ template pattern, as shown in the example below. This function will be inlined
    into the elementwise kernel template, and compiled on the fly. The compiled kernel will be cached in memory,
    as well as in a local temp dir.

    Jiterator-generated kernels accept noncontiguous tensors, and support broadcasting and type promotion.

    Args:
        code_string (str): CUDA code string to be compiled by jiterator. The entry functor must return by value.
        kwargs (Dict, optional): Keyword arguments for the generated function

    Example::

        code_string = "template <typename T> T my_kernel(T x, T y, T alpha) { return -x + alpha * y; }"
        jitted_fn = create_jit_fn(code_string, alpha=1.0)
        a = torch.rand(3, device='cuda')
        b = torch.rand(3, device='cuda')
        # invoke jitted function like a regular python function
        result = jitted_fn(a, b, alpha=3.14)

    code_string also allows multiple function definitions, and the last function will be treated as the entry function.

    Example::

        code_string = "template <typename T> T util_fn(T x, T y) { return ::sin(x) + ::cos(y); }"
        code_string += "template <typename T> T my_kernel(T x, T y, T val) { return ::min(val, util_fn(x, y)); }"
        jitted_fn = create_jit_fn(code_string, val=0.0)
        a = torch.rand(3, device='cuda')
        b = torch.rand(3, device='cuda')
        # invoke jitted function like a regular python function
        result = jitted_fn(a, b)  # using default val=0.0

    Jiterator can be used together with python registration to override an operator's cuda kernel.
    The following example overrides gelu's cuda kernel with relu.

    Example::

        code_string = "template <typename T> T my_gelu(T a) { return a > 0 ? a : 0; }"
        my_gelu = create_jit_fn(code_string)
        my_lib = torch.library.Library("aten", "IMPL")
        my_lib.impl('aten::gelu', my_gelu, "CUDA")
        # torch.nn.GELU and torch.nn.functional.gelu are now overridden
        a = torch.rand(3, device='cuda')
        torch.allclose(torch.nn.functional.gelu(a), torch.nn.functional.relu(a))

    .. warning::
        This API is in beta and may change in future releases.

    .. warning::
        This API only supports up to 8 inputs and 1 output.

    .. warning::
        All input tensors must live on a CUDA device.
    """
    return _JittedFunction(code_string, return_by_ref=False, num_outputs=1, **kwargs)


def _create_multi_output_jit_fn(
    code_string: str, num_outputs: int, **kwargs
) -> Callable:
    """
    Create a jiterator-generated cuda kernel for an elementwise op that supports returning one or more outputs.

    Args:
        code_string (str): CUDA code string to be compiled by jiterator. The entry functor must return values by reference.
        num_outputs (int): number of outputs returned by the kernel
        kwargs (Dict, optional): Keyword arguments for the generated function

    Example::

        code_string = "template <typename T> void my_kernel(T x, T y, T alpha, T& out) { out = -x + alpha * y; }"
        jitted_fn = create_multi_output_jit_fn(code_string, num_outputs=1, alpha=1.0)
        a = torch.rand(3, device='cuda')
        b = torch.rand(3, device='cuda')
        # invoke jitted function like a regular python function
        result = jitted_fn(a, b, alpha=3.14)

    .. warning::
        This API is in beta and may change in future releases.

    .. warning::
        This API only supports up to 8 inputs and 8 outputs.
    """
    return _JittedFunction(
        code_string, return_by_ref=True, num_outputs=num_outputs, **kwargs
    )
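

# Minimal illustrative sketch of the multi-output path: it builds a hypothetical
# two-output kernel (``sum_and_diff`` is an invented name) and assumes the two
# outputs unpack positionally in the same order as the ``T&`` parameters. It only
# runs when the module is executed directly on a machine with a CUDA/ROCm device.
if __name__ == "__main__" and torch.cuda.is_available():
    _demo_code = """
    template <typename T>
    void sum_and_diff(T x, T y, T& out0, T& out1) {
        out0 = x + y;
        out1 = x - y;
    }
    """
    _demo_fn = _create_multi_output_jit_fn(_demo_code, num_outputs=2)
    _a = torch.rand(3, device="cuda")
    _b = torch.rand(3, device="cuda")
    # Assumption: outputs come back in the order of the reference parameters.
    _s, _d = _demo_fn(_a, _b)
    print(torch.allclose(_s, _a + _b), torch.allclose(_d, _a - _b))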