# mypy: allow-untyped-defs
import re
from typing import Callable, List

import torch
from torch import Tensor


__all__: List[str] = []


class _CodeParser:
    def __init__(self, code_string: str):
        optional_ws = r"\s*"
        required_ws = r"\s+"
        template_params = r"(?P<template_params>\<.+\>)"
        return_type = r"(?P<return_type>\w+)"
        function_name = r"(?P<function_name>\w+)"
        function_params = r"(?P<function_params>\(.+\))"
        function_body = r"(?P<function_body>\{.+\})"

        pattern = (
            optional_ws
            + "template"
            + optional_ws
            + template_params
            + optional_ws
            + return_type
            + required_ws
            + function_name
            + optional_ws
            + function_params
            + optional_ws
            + function_body
            + optional_ws
        )

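        # For example (illustrative), for the code string
        #   "template <typename T> T add(T a, T b) { return a + b; }"
        # the pattern captures template_params="<typename T>", return_type="T",
        # function_name="add", function_params="(T a, T b)" and
        # function_body="{ return a + b; }". Because the groups are greedy, a
        # string with several definitions yields the *last* function's name,
        # which is why the last definition is treated as the entry function.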
        result = re.match(
            pattern, code_string, re.DOTALL
        )  # DOTALL for matching multiline

        if result is None:
            raise Exception(  # noqa: TRY002
                f"Couldn't parse code, please check correctness:\n {code_string}"
            )

        self.template_params = result["template_params"]
        self.return_type = result["return_type"]
        self.function_name = result["function_name"]
        self.function_params = result["function_params"]
        self.function_body = result["function_body"]


class _JittedFunction:
    def __init__(
        self, code_string: str, return_by_ref: bool, num_outputs: int, **kwargs
    ):
        self.code_string = code_string

        assert (
            return_by_ref or num_outputs == 1
62        ), "Return by value only works for single output. "
        self.return_by_ref = return_by_ref
        self.num_outputs = num_outputs

        parsed_code = _CodeParser(code_string)
        self.kernel_name = parsed_code.function_name

        self.kwargs_dict = kwargs
        self.is_cuda_available = torch.cuda.is_available()

    def __call__(self, *tensors: Tensor, **kwargs):
        # Jiterator follows torch.cuda's lazy initialization behavior:
        # defer checking CUDA's availability to function invocation time.
        assert (
            self.is_cuda_available
        ), "Jiterator is only supported on CUDA and ROCm GPUs, none are available."

        assert len(tensors) <= 8, "jiterator only supports up to 8 tensor inputs."

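        # Only kwargs that were declared when the function was created may be
        # overridden at call time; unknown keys are rejected below.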
        expanded_kwargs = self.kwargs_dict.copy()
        for key, value in kwargs.items():
            if key in self.kwargs_dict:
                expanded_kwargs[key] = value
            else:
                raise KeyError(f"{key} is not declared in function definition")

        return torch._C._cuda_jiterator_compile_and_launch_kernel(
            self.code_string,
            self.kernel_name,
            self.return_by_ref,
            self.num_outputs,
            tensors,
            expanded_kwargs,
        )


def _create_jit_fn(code_string: str, **kwargs) -> Callable:
    """
    Create a jiterator-generated cuda kernel for an elementwise op.

    The code string has to be a valid CUDA function that describes the computation for a single element. The code
    string has to follow the C++ template pattern, as shown in the example below. This function will be inlined
    into the elementwise kernel template and compiled on the fly. The compiled kernel will be cached in memory as
    well as in a local temp dir.

    Jiterator-generated kernels accept noncontiguous tensors and support broadcasting and type promotion.

    Args:
        code_string (str): CUDA code string to be compiled by jiterator. The entry functor must return by value.
        kwargs (Dict, optional): Keyword arguments for generated function

    Example::

        code_string = "template <typename T> T my_kernel(T x, T y, T alpha) { return -x + alpha * y; }"
        jitted_fn = create_jit_fn(code_string, alpha=1.0)
        a = torch.rand(3, device='cuda')
        b = torch.rand(3, device='cuda')
        # invoke jitted function like a regular python function
        result = jitted_fn(a, b, alpha=3.14)

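    Broadcasting and type promotion follow the usual elementwise rules.
    A minimal sketch (illustrative, reusing jitted_fn from the example above)::

        x = torch.rand(3, 1, device='cuda')
        y = torch.rand(1, 4, device='cuda', dtype=torch.double)
        result = jitted_fn(x, y)  # broadcasts to shape (3, 4) and promotes to double
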
    code_string also allows multiple function definitions, and the last function will be treated as the entry function.

    Example::

        code_string = "template <typename T> T util_fn(T x, T y) { return ::sin(x) + ::cos(y); }"
        code_string += "template <typename T> T my_kernel(T x, T y, T val) { return ::min(val, util_fn(x, y)); }"
        jitted_fn = create_jit_fn(code_string, val=0.0)
        a = torch.rand(3, device='cuda')
        b = torch.rand(3, device='cuda')
        # invoke jitted function like a regular python function
        result = jitted_fn(a, b)  # using default val=0.0

    Jiterator can be used together with python registration to override an operator's cuda kernel.
    The following example overrides gelu's cuda kernel with relu.

    Example::

        code_string = "template <typename T> T my_gelu(T a) { return a > 0 ? a : 0; }"
        my_gelu = create_jit_fn(code_string)
        my_lib = torch.library.Library("aten", "IMPL")
        my_lib.impl('aten::gelu', my_gelu, "CUDA")
        # torch.nn.GELU and torch.nn.functional.gelu are now overridden
        a = torch.rand(3, device='cuda')
        torch.allclose(torch.nn.functional.gelu(a), torch.nn.functional.relu(a))

    .. warning::
        This API is in beta and may change in future releases.

    .. warning::
        This API only supports up to 8 inputs and 1 output

    .. warning::
        All input tensors must be on a CUDA device
155    """
156    return _JittedFunction(code_string, return_by_ref=False, num_outputs=1, **kwargs)
157
158
159def _create_multi_output_jit_fn(
160    code_string: str, num_outputs: int, **kwargs
161) -> Callable:
162    """
163    Create a jiterator-generated cuda kernel for an elementwise op that supports returning one or more outputs.
164
165    Args:
166        code_string (str): CUDA code string to be compiled by jiterator. The entry functor must return value by reference.
        num_outputs (int): number of outputs returned by the kernel
        kwargs (Dict, optional): Keyword arguments for generated function

    Example::

        code_string = "template <typename T> void my_kernel(T x, T y, T alpha, T& out) { out = -x + alpha * y; }"
        jitted_fn = create_multi_output_jit_fn(code_string, num_outputs=1, alpha=1.0)
        a = torch.rand(3, device='cuda')
        b = torch.rand(3, device='cuda')
        # invoke jitted function like a regular python function
        result = jitted_fn(a, b, alpha=3.14)
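
    A kernel with more than one output writes each result through a reference parameter.
    A minimal sketch (illustrative, assuming the outputs are returned together and can be unpacked)::

        code_string = "template <typename T> void my_kernel(T x, T y, T& out0, T& out1) { out0 = x + y; out1 = x - y; }"
        jitted_fn = create_multi_output_jit_fn(code_string, num_outputs=2)
        a = torch.rand(3, device='cuda')
        b = torch.rand(3, device='cuda')
        out0, out1 = jitted_fn(a, b)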

    .. warning::
        This API is in beta and may change in future releases.

    .. warning::
        This API only supports up to 8 inputs and 8 outputs
    """
    return _JittedFunction(
        code_string, return_by_ref=True, num_outputs=num_outputs, **kwargs
    )