#pragma once
#include <ATen/core/Tensor.h>

#include <ATen/cudnn/Descriptors.h>
#include <ATen/cudnn/Types.h>
#include <ATen/cudnn/cudnn-wrapper.h>
#include <ATen/native/ConvUtils.h>

namespace at {
namespace native {

// ---------------------------------------------------------------------
//
// Helper classes
//
// ---------------------------------------------------------------------

// This POD struct is used to let us easily compute hashes of the
// parameters.
struct ConvolutionParams {
  c10::DeviceIndex device_id;
  cudnnDataType_t dataType;
  int input_size[2 + max_dim];
  uint8_t input_dim;
  at::MemoryFormat memory_format;
  int weight_size[2 + max_dim];
  int padding[max_dim];
  int stride[max_dim];
  int dilation[max_dim];
  int64_t groups;
  bool deterministic;
  bool allow_tf32;
  // NB: transposed purposely omitted: transposed just swaps
  // forward and backward, so you can reuse the benchmark entry.
};

std::ostream& operator<<(std::ostream& out, const ConvolutionParams& params);

// NB: This can't be a constructor, because then ConvolutionParams
// would not be a POD anymore.
// TODO: Use TensorGeometry here instead of the entire Tensor, which we
// don't actually need. (OTOH: We can always pass in
// grad_input/grad_output, so this is not very pressing.)
void setConvolutionParams(
    ConvolutionParams* params,
    const at::Tensor& input,
    const at::Tensor& weight,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool deterministic,
    bool allow_tf32,
    at::MemoryFormat memory_format);

std::string repro_from_args(const ConvolutionParams& args);

// ---------------------------------------------------------------------
//
// Raw functions
//
// ---------------------------------------------------------------------

void raw_cudnn_convolution_forward_out(
    const Tensor& output,
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32);

void raw_cudnn_convolution_backward_input_out(
    const at::Tensor& grad_input,
    const at::Tensor& grad_output,
    const at::Tensor& weight,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32);

void raw_cudnn_convolution_backward_weight_out(
    const Tensor& grad_weight,
    const Tensor& grad_output,
    const Tensor& input,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32);

void raw_cudnn_convolution_add_relu_out(
    const Tensor& output,
    const Tensor& input,
    const Tensor& weight,
    const Tensor& z,
    float alpha,
    const Tensor& bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32);

void raw_cudnn_convolution_add_relu_fallback_out(
    const Tensor& output,
    const Tensor& input,
    const Tensor& weight,
    const Tensor& z,
    float alpha,
    const Tensor& bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32);
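
// Usage sketch (illustrative only, not part of this header's API): the
// raw functions above are typically reached through a wrapper that first
// fills a ConvolutionParams key, so benchmark/algorithm caches can hash
// the POD bytewise before dispatching. A minimal sketch, assuming
// `output` has already been resized to the convolution's output shape:
//
//   ConvolutionParams params;
//   setConvolutionParams(
//       &params, input, weight, padding, stride, dilation, groups,
//       deterministic, allow_tf32, input.suggest_memory_format());
//   raw_cudnn_convolution_forward_out(
//       output, input, weight, padding, stride, dilation, groups,
//       benchmark, deterministic, allow_tf32);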

#if AT_CUDNN_ENABLED()

// v7 functions are preserved here to allow for runtime switching to v7
// (e.g., TORCH_CUDNN_V8_API_DISABLED=1).
// Note that v7 forward/backward out can have different behavior from the v8
// versions, as v7 explicitly splits large tensors as a 32-bit indexing
// workaround, whereas v8 expects cuDNN to handle large tensors.
void raw_cudnn_convolution_forward_out_v7(
    const Tensor& output,
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32);

void raw_cudnn_convolution_backward_input_out_v7(
    const at::Tensor& grad_input,
    const at::Tensor& grad_output,
    const at::Tensor& weight,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32);

void raw_cudnn_convolution_backward_weight_out_v7(
    const Tensor& grad_weight,
    const Tensor& grad_output,
    const Tensor& input,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32);

void raw_cudnn_convolution_add_relu_out_v7(
    const Tensor& output,
    const Tensor& input,
    const Tensor& weight,
    const Tensor& z,
    float alpha,
    const Tensor& bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool allow_tf32);
#endif
} // namespace native
} // namespace at
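
// Runtime-switching sketch (an assumption about caller-side logic; the
// actual selection between the v8 and v7 paths lives in the
// implementation files, not in this header): with
// TORCH_CUDNN_V8_API_DISABLED=1 in the environment, a caller could fall
// back to the v7 entry point like so:
//
//   if (c10::utils::check_env("TORCH_CUDNN_V8_API_DISABLED") == true) {
//     raw_cudnn_convolution_forward_out_v7(
//         output, input, weight, padding, stride, dilation, groups,
//         benchmark, deterministic, allow_tf32);
//   } else {
//     raw_cudnn_convolution_forward_out(
//         output, input, weight, padding, stride, dilation, groups,
//         benchmark, deterministic, allow_tf32);
//   }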