#include <ATen/cudnn/Descriptors.h>

#include <ATen/ATen.h>
#include <c10/util/irange.h>

#include <iostream>
#include <sstream>

namespace at::native {

namespace {

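// Map an ATen scalar type to the matching cuDNN data type. Only the types
// handled below have cuDNN equivalents; anything else fails the TORCH_CHECK
// at the bottom.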
inline cudnnDataType_t getDataType(const at::Tensor& t) {
  auto scalar_type = t.scalar_type();
  if (scalar_type == at::kFloat) {
    return CUDNN_DATA_FLOAT;
  } else if (scalar_type == at::kHalf) {
    return CUDNN_DATA_HALF;
  } else if (scalar_type == at::kDouble) {
    return CUDNN_DATA_DOUBLE;
  } else if (scalar_type == at::kBFloat16) {
    return CUDNN_DATA_BFLOAT16;
  } else if (scalar_type == at::kQInt8) {
    return CUDNN_DATA_INT8;
  }
  TORCH_CHECK(false, "TensorDescriptor does not support ", scalar_type);
}

} // anonymous namespace

void RNNDataDescriptor::set(const at::Tensor& t, const cudnnRNNDataLayout_t layout, const int maxSeqLength, const int batchSize, const int vectorSize, const int* seqLengthArray) {
  set(getDataType(t), layout, maxSeqLength, batchSize, vectorSize, seqLengthArray);
}
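// A minimal usage sketch (hypothetical names: `x` is a CUDA tensor laid out
// as [max_seq_length, batch_size, vector_size], with one length per batch
// element):
//
//   RNNDataDescriptor rnn_data;
//   std::vector<int> seq_lengths(batch_size, max_seq_length);
//   rnn_data.set(x, CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED,
//                max_seq_length, batch_size, vector_size, seq_lengths.data());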

void TensorDescriptor::set(const at::Tensor& t, at::MemoryFormat memory_format, size_t pad) {
  set(getDataType(t), t.sizes(), t.strides(), pad,
      memory_format == at::MemoryFormat::ChannelsLast ||
      memory_format == at::MemoryFormat::ChannelsLast3d);
}

void TensorDescriptor::set(const at::Tensor& t, size_t pad) {
  auto memory_format = t.suggest_memory_format();
  set(getDataType(t), t.sizes(), t.strides(), pad,
      memory_format == at::MemoryFormat::ChannelsLast ||
      memory_format == at::MemoryFormat::ChannelsLast3d);
}

void TensorDescriptor::set(cudnnDataType_t datatype, IntArrayRef t_sizes, IntArrayRef t_strides, size_t pad) {
  set(datatype, t_sizes, t_strides, pad,
      is_channels_last_strides_2d(t_sizes, t_strides) ||
      is_channels_last_strides_3d(t_sizes, t_strides));
}

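// The worker overload below copies the sizes and strides into fixed-size
// int arrays and, when pad exceeds the tensor's dimensionality, fills the
// trailing entries with size 1 / stride 1 so low-dimensional tensors can be
// padded up to the rank cuDNN expects.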
void TensorDescriptor::set(cudnnDataType_t datatype, IntArrayRef t_sizes, IntArrayRef t_strides, size_t pad, bool nhwc) {
  size_t dim = t_sizes.size();
  TORCH_CHECK(dim <= CUDNN_DIM_MAX && pad <= CUDNN_DIM_MAX,
              "cuDNN supports only up to ", CUDNN_DIM_MAX, " dimensions");
  int size[CUDNN_DIM_MAX];
  int stride[CUDNN_DIM_MAX];
  for (const auto i : c10::irange(dim)) {
    size[i] = static_cast<int>(t_sizes[i]);
    stride[i] = static_cast<int>(t_strides[i]);
  }
  for (const auto i : c10::irange(dim, pad)) {
    size[i] = 1;
    stride[i] = 1;
  }
  set(datatype, static_cast<int>(std::max(dim, pad)), size, stride, nhwc);
}
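// A minimal usage sketch (hypothetical: `t` is a 2-d CUDA float tensor and
// we pad to the 4 dimensions cuDNN convolution routines expect):
//
//   TensorDescriptor desc;
//   desc.set(t, 4);                      // describes t as [N, C, 1, 1]
//   cudnnTensorDescriptor_t raw = desc.desc();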

std::string cudnnTypeToString(cudnnDataType_t dtype) {
  switch (dtype) {
    case CUDNN_DATA_FLOAT:
      return "CUDNN_DATA_FLOAT";
    case CUDNN_DATA_DOUBLE:
      return "CUDNN_DATA_DOUBLE";
    case CUDNN_DATA_HALF:
      return "CUDNN_DATA_HALF";
    case CUDNN_DATA_BFLOAT16:
      return "CUDNN_DATA_BFLOAT16";
    case CUDNN_DATA_INT8:
      return "CUDNN_DATA_INT8";
    case CUDNN_DATA_INT32:
      return "CUDNN_DATA_INT32";
    case CUDNN_DATA_INT8x4:
      return "CUDNN_DATA_INT8x4";
    case CUDNN_DATA_UINT8:
      return "CUDNN_DATA_UINT8";
    case CUDNN_DATA_UINT8x4:
      return "CUDNN_DATA_UINT8x4";
    default:
      std::ostringstream oss;
      oss << "(unknown data-type " << static_cast<int>(dtype) << ")";
      return oss.str();
  }
}

std::ostream& operator<<(std::ostream& out, const TensorDescriptor& d) {
  out << "TensorDescriptor " << static_cast<void*>(d.desc()) << "\n";
  int nbDims = 0;
  int dimA[CUDNN_DIM_MAX];
  int strideA[CUDNN_DIM_MAX];
  cudnnDataType_t dtype;
  cudnnGetTensorNdDescriptor(d.desc(), CUDNN_DIM_MAX, &dtype, &nbDims, dimA, strideA);
  out << "    type = " << cudnnTypeToString(dtype) << "\n";
  out << "    nbDims = " << nbDims << "\n";
  // Read out only nbDims of the arrays!
  out << "    dimA = ";
  for (auto i : ArrayRef<int>{dimA, static_cast<size_t>(nbDims)}) {
    out << i << ", ";
  }
  out << "\n";
  out << "    strideA = ";
  for (auto i : ArrayRef<int>{strideA, static_cast<size_t>(nbDims)}) {
    out << i << ", ";
  }
  out << "\n";
  return out;
}

void TensorDescriptor::print() { std::cout << *this; }

void FilterDescriptor::set(const at::Tensor& t, const at::MemoryFormat memory_format, int64_t pad) {
  auto dim = t.ndimension();
  TORCH_CHECK(dim <= CUDNN_DIM_MAX && pad <= CUDNN_DIM_MAX,
              "cuDNN supports only up to ", CUDNN_DIM_MAX, " dimensions");
  // NB: It is possible for this test to be insufficient, because the
  // Tensor passed in to set the filter descriptor may not be the actual
  // Tensor whose data pointer is passed to cuDNN. Nevertheless,
  // that is the common case, so we can catch most client errors with this test.
  TORCH_CHECK(t.is_contiguous(memory_format),
              "cuDNN filters (a.k.a. weights) must be contiguous in desired memory_format\n",
              "Weight sizes: ", t.sizes(), "\n",
              "Weight strides: ", t.strides(), "\n",
              "cuDNN suggested memory_format: ", memory_format);

  int size[CUDNN_DIM_MAX];
  for (const auto i : c10::irange(dim)) {
    size[i] = static_cast<int>(t.size(i));
  }
  for (const auto i : c10::irange(dim, pad)) {
    size[i] = 1;
  }
  dim = std::max(dim, pad);
  cudnnTensorFormat_t filter_format;
  switch (memory_format) {
    case at::MemoryFormat::Contiguous:
      filter_format = CUDNN_TENSOR_NCHW;
      break;
    case at::MemoryFormat::ChannelsLast:
    case at::MemoryFormat::ChannelsLast3d:
      filter_format = CUDNN_TENSOR_NHWC;
      break;
    default:
      TORCH_INTERNAL_ASSERT(false, "unsupported memory_format for cuDNN filters");
  }
  set(getDataType(t), static_cast<int>(dim), size, filter_format);
}
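// A minimal usage sketch (hypothetical: `w` is a 4-d Conv2d weight tensor on
// CUDA):
//
//   FilterDescriptor fdesc;
//   fdesc.set(w, w.suggest_memory_format(), 4);
//   // trips the TORCH_CHECK above if `w` is not contiguous in that format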

std::string cudnnMemoryFormatToString(cudnnTensorFormat_t tformat) {
  switch (tformat) {
    case CUDNN_TENSOR_NCHW:
      return "CUDNN_TENSOR_NCHW";
    case CUDNN_TENSOR_NHWC:
      return "CUDNN_TENSOR_NHWC";
    default:
      std::ostringstream oss;
      oss << "(unknown cudnn tensor format " << static_cast<int>(tformat) << ")";
      return oss.str();
  }
}

std::ostream& operator<<(std::ostream& out, const FilterDescriptor& d) {
  out << "FilterDescriptor " << static_cast<void*>(d.desc()) << "\n";
  int nbDims = 0;
  int dimA[CUDNN_DIM_MAX];
  cudnnDataType_t dtype;
  cudnnTensorFormat_t tformat;
  cudnnGetFilterNdDescriptor(d.desc(), CUDNN_DIM_MAX, &dtype, &tformat, &nbDims, dimA);
  out << "    type = " << cudnnTypeToString(dtype) << "\n";
  out << "    tensor_format = " << cudnnMemoryFormatToString(tformat) << "\n";
  out << "    nbDims = " << nbDims << "\n";
  // Read out only nbDims of the arrays!
  out << "    dimA = ";
  for (auto i : ArrayRef<int>{dimA, static_cast<size_t>(nbDims)}) {
    out << i << ", ";
  }
  out << "\n";
  return out;
}

void FilterDescriptor::print() { std::cout << *this; }

} // namespace at::native