/*
 * CuDNN ReLU extension. Simple function but contains the general structure of
 * most CuDNN extensions:
 * 1) Check arguments. torch::check* functions provide a standard way to
 *    validate input and provide pretty errors.
 * 2) Create descriptors. Most CuDNN functions require creating and setting a
 *    variety of descriptors.
 * 3) Apply the CuDNN function.
 * 4) Destroy your descriptors.
 * 5) Return something (optional).
 */

#include <torch/extension.h>

#include <ATen/cuda/Exceptions.h> // for AT_CUDNN_CHECK
#include <ATen/cudnn/Descriptors.h> // for TensorDescriptor
#include <ATen/cudnn/Handle.h> // for getCudnnHandle

// Name of the function in the Python module and the name used for error
// messages by the torch::check* functions.
const char* cudnn_relu_name = "cudnn_relu";

// Check arguments to cudnn_relu
void cudnn_relu_check(
    const torch::Tensor& inputs,
    const torch::Tensor& outputs) {
  // Create TensorArgs. These record the names and positions of each tensor as
  // a parameter.
  torch::TensorArg arg_inputs(inputs, "inputs", 0);
  torch::TensorArg arg_outputs(outputs, "outputs", 1);
  // Check arguments. No need to return anything. These functions will throw an
  // error if they fail. Messages are populated using information from the
  // TensorArgs.
  torch::checkContiguous(cudnn_relu_name, arg_inputs);
  torch::checkScalarType(cudnn_relu_name, arg_inputs, torch::kFloat);
  torch::checkBackend(cudnn_relu_name, arg_inputs.tensor, torch::Backend::CUDA);
  torch::checkContiguous(cudnn_relu_name, arg_outputs);
  torch::checkScalarType(cudnn_relu_name, arg_outputs, torch::kFloat);
  torch::checkBackend(
      cudnn_relu_name, arg_outputs.tensor, torch::Backend::CUDA);
  torch::checkSameSize(cudnn_relu_name, arg_inputs, arg_outputs);
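  // Further checks from the same helper family (e.g. torch::checkDim or
  // torch::checkSameGPU) could be added here if stricter validation is wanted.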
}

void cudnn_relu(const torch::Tensor& inputs, const torch::Tensor& outputs) {
  // Most CuDNN extensions will follow a similar pattern.
  // Step 1: Check inputs. This will throw an error if inputs are invalid, so
  // there is no need to check return codes here.
  cudnn_relu_check(inputs, outputs);
  // Step 2: Create descriptors
  cudnnHandle_t cuDnn = torch::native::getCudnnHandle();
  // Note: 4 is the minimum dim for a TensorDescriptor. Input and output are
  // the same size and type, and both are contiguous, so one descriptor is
  // sufficient.
  torch::native::TensorDescriptor input_tensor_desc(inputs, 4);
  cudnnActivationDescriptor_t activationDesc;
  // Note: Always check the return value of cudnn functions using AT_CUDNN_CHECK
  AT_CUDNN_CHECK(cudnnCreateActivationDescriptor(&activationDesc));
  AT_CUDNN_CHECK(cudnnSetActivationDescriptor(
      activationDesc,
      /*mode=*/CUDNN_ACTIVATION_RELU,
      /*reluNanOpt=*/CUDNN_PROPAGATE_NAN,
      /*coef=*/1.));
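  // (Other cuDNN activation modes, e.g. CUDNN_ACTIVATION_SIGMOID or
  // CUDNN_ACTIVATION_TANH, are set the same way; the coef argument is only
  // consulted for the clipped-ReLU and ELU modes.)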
  // Step 3: Apply CuDNN function
  float alpha = 1.;
  float beta = 0.;
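  // cudnnActivationForward blends as dst = alpha * relu(src) + beta * dst, so
  // alpha = 1 and beta = 0 simply overwrite outputs with ReLU(inputs).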
  AT_CUDNN_CHECK(cudnnActivationForward(
      cuDnn,
      activationDesc,
      &alpha,
      input_tensor_desc.desc(),
      inputs.data_ptr(),
      &beta,
      input_tensor_desc.desc(), // output descriptor same as input
      outputs.data_ptr()));
  // Step 4: Destroy descriptors
  AT_CUDNN_CHECK(cudnnDestroyActivationDescriptor(activationDesc));
  // Step 5: Return something (optional)
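  // (This example writes the result into `outputs` in place and returns void;
  // changing the signature to return `outputs` is a common alternative.)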
}

// Create the pybind11 module
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Use the same name as the check functions so error messages make sense
  m.def(cudnn_relu_name, &cudnn_relu, "CuDNN ReLU");
}
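
// A minimal usage sketch from Python (assumptions: the extension is built
// just-in-time with torch.utils.cpp_extension.load, cuDNN is linked, and the
// module name "cudnn_ext" is a placeholder, not part of this test):
//
//   import torch
//   from torch.utils.cpp_extension import load
//
//   ext = load(name="cudnn_ext",
//              sources=["cudnn_extension.cpp"],
//              extra_ldflags=["-lcudnn"],
//              with_cuda=True)
//   x = torch.randn(4, 3, 8, 8, device="cuda")  # float32, contiguous, CUDA
//   y = torch.empty_like(x)
//   ext.cudnn_relu(x, y)  # y now holds relu(x)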