xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cudnn/Descriptors.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/cudnn/Descriptors.h>
2 
3 #include <ATen/ATen.h>
4 #include <c10/util/irange.h>
5 
6 #include <iostream>
7 #include <sstream>
8 
9 namespace at::native {
10 
11 namespace {
12 
getDataType(const at::Tensor & t)13 inline cudnnDataType_t getDataType(const at::Tensor& t) {
14   auto scalar_type = t.scalar_type();
15   if (scalar_type == at::kFloat) {
16     return CUDNN_DATA_FLOAT;
17   } else if (scalar_type == at::kHalf) {
18     return CUDNN_DATA_HALF;
19   } else if (scalar_type == at::kDouble) {
20     return CUDNN_DATA_DOUBLE;
21   }
22     else if (scalar_type == at::kBFloat16) {
23     return CUDNN_DATA_BFLOAT16;
24   } else if (scalar_type == at::kQInt8) {
25     return CUDNN_DATA_INT8;
26   }
27   TORCH_CHECK(false, "TensorDescriptor does not support ", scalar_type);
28 }
29 
30 } // anonymous namespace
31 
// Convenience overload: derive the cuDNN data type from the tensor and
// forward to the primary set() overload that takes an explicit data type.
void RNNDataDescriptor::set(const at::Tensor &t, const cudnnRNNDataLayout_t layout, const int maxSeqLength, const int batchSize, const int vectorSize, const int* seqLengthArray) {
  set(getDataType(t), layout, maxSeqLength, batchSize, vectorSize, seqLengthArray);
}
35 
set(const at::Tensor & t,at::MemoryFormat memory_format,size_t pad)36 void TensorDescriptor::set(const at::Tensor &t, at::MemoryFormat memory_format, size_t pad) {
37   set(getDataType(t), t.sizes(), t.strides(), pad,
38     memory_format == at::MemoryFormat::ChannelsLast ||
39     memory_format == at::MemoryFormat::ChannelsLast3d);
40 }
41 
set(const at::Tensor & t,size_t pad)42 void TensorDescriptor::set(const at::Tensor &t, size_t pad) {
43   auto memory_format = t.suggest_memory_format();
44   set(getDataType(t), t.sizes(), t.strides(), pad,
45     memory_format == at::MemoryFormat::ChannelsLast ||
46     memory_format == at::MemoryFormat::ChannelsLast3d);
47 }
48 
set(cudnnDataType_t datatype,IntArrayRef t_sizes,IntArrayRef t_strides,size_t pad)49 void TensorDescriptor::set(cudnnDataType_t datatype, IntArrayRef t_sizes, IntArrayRef t_strides, size_t pad) {
50   set(datatype, t_sizes, t_strides, pad,
51     is_channels_last_strides_2d(t_sizes, t_strides) ||
52     is_channels_last_strides_3d(t_sizes, t_strides));
53 }
54 
set(cudnnDataType_t datatype,IntArrayRef t_sizes,IntArrayRef t_strides,size_t pad,bool nhwc)55 void TensorDescriptor::set(cudnnDataType_t datatype, IntArrayRef t_sizes, IntArrayRef t_strides, size_t pad, bool nhwc) {
56   size_t dim = t_sizes.size();
57   if (dim > CUDNN_DIM_MAX || pad > CUDNN_DIM_MAX)
58     TORCH_CHECK(false, "cuDNN supports only up to ", CUDNN_DIM_MAX, " dimensions");
59   int size[CUDNN_DIM_MAX];
60   int stride[CUDNN_DIM_MAX];
61   for (const auto i : c10::irange(dim)) {
62     size[i] = static_cast<int>(t_sizes[i]);
63     stride[i] = static_cast<int>(t_strides[i]);
64   }
65   for (const auto i : c10::irange(dim, pad)) {
66     size[i] = 1;
67     stride[i] = 1;
68   }
69   set(datatype, static_cast<int>(std::max(dim, pad)), size, stride, nhwc);
70 }
71 
cudnnTypeToString(cudnnDataType_t dtype)72 std::string cudnnTypeToString(cudnnDataType_t dtype) {
73   switch (dtype) {
74     case CUDNN_DATA_FLOAT:
75       return "CUDNN_DATA_FLOAT";
76     case CUDNN_DATA_DOUBLE:
77       return "CUDNN_DATA_DOUBLE";
78     case CUDNN_DATA_HALF:
79       return "CUDNN_DATA_HALF";
80     case CUDNN_DATA_BFLOAT16:
81       return "CUDNN_DATA_BFLOAT16";
82     case CUDNN_DATA_INT8:
83       return "CUDNN_DATA_INT8";
84     case CUDNN_DATA_INT32:
85       return "CUDNN_DATA_INT32";
86     case CUDNN_DATA_INT8x4:
87       return "CUDNN_DATA_INT8x4";
88     case CUDNN_DATA_UINT8:
89       return "CUDNN_DATA_UINT8";
90     case CUDNN_DATA_UINT8x4:
91       return "CUDNN_DATA_UINT8x4";
92     default:
93       std::ostringstream oss;
94       oss << "(unknown data-type " << static_cast<int>(dtype) << ")";
95       return oss.str();
96   }
97 }
98 
operator <<(std::ostream & out,const TensorDescriptor & d)99 std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d) {
100   out << "TensorDescriptor " << static_cast<void*>(d.desc()) << "\n";
101   int nbDims = 0;
102   int dimA[CUDNN_DIM_MAX];
103   int strideA[CUDNN_DIM_MAX];
104   cudnnDataType_t dtype;
105   cudnnGetTensorNdDescriptor(d.desc(), CUDNN_DIM_MAX, &dtype, &nbDims, dimA, strideA);
106   out << "    type = " << cudnnTypeToString(dtype) << "\n";
107   out << "    nbDims = " << nbDims << "\n";
108   // Read out only nbDims of the arrays!
109   out << "    dimA = ";
110   for (auto i : ArrayRef<int>{dimA, static_cast<size_t>(nbDims)}) {
111     out << i << ", ";
112   }
113   out << "\n";
114   out << "    strideA = ";
115   for (auto i : ArrayRef<int>{strideA, static_cast<size_t>(nbDims)}) {
116     out << i << ", ";
117   }
118   out << "\n";
119   return out;
120 }
121 
print()122 void TensorDescriptor::print() { std::cout << *this; }
123 
set(const at::Tensor & t,const at::MemoryFormat memory_format,int64_t pad)124 void FilterDescriptor::set(const at::Tensor &t, const at::MemoryFormat memory_format, int64_t pad) {
125   auto dim = t.ndimension();
126   if (dim > CUDNN_DIM_MAX || pad > CUDNN_DIM_MAX)
127   TORCH_CHECK(false, "cuDNN supports only up to ", CUDNN_DIM_MAX, " dimensions");
128   // NB: It is possible for this test to be insufficient, because the
129   // Tensor passed in to set the filter descriptor may not be the actual
130   // Tensor whose data pointer is passed to cuDNN.  Nevertheless,
131   // that is the common case, so we can catch most client errors with this test.
132   TORCH_CHECK(t.is_contiguous(memory_format),
133     "cuDNN filters (a.k.a. weights) must be contiguous in desired memory_format\n",
134     "Weight sizes: ", t.sizes(), "\n",
135     "Weight strides: ", t.strides(), "\n",
136     "cuDNN suggested memory_format: ", memory_format);
137 
138   int size[CUDNN_DIM_MAX];
139   for (const auto i : c10::irange(dim)) {
140     size[i] = (int) t.size(i);
141   }
142   for (const auto i : c10::irange(dim, pad)) {
143     size[i] = (int) 1;
144   }
145   dim = std::max(dim, pad);
146   cudnnTensorFormat_t filter_format;
147   switch(memory_format) {
148     case at::MemoryFormat::Contiguous:
149       filter_format = CUDNN_TENSOR_NCHW;
150       break;
151     case at::MemoryFormat::ChannelsLast:
152     case at::MemoryFormat::ChannelsLast3d:
153       filter_format = CUDNN_TENSOR_NHWC;
154       break;
155     default:
156       TORCH_INTERNAL_ASSERT(false, "unsupported memory_format for cuDNN filters");
157   }
158   set(getDataType(t), (int) dim, size, filter_format);
159 }
160 
cudnnMemoryFormatToString(cudnnTensorFormat_t tformat)161 std::string cudnnMemoryFormatToString(cudnnTensorFormat_t tformat) {
162   switch (tformat) {
163     case CUDNN_TENSOR_NCHW:
164       return "CUDNN_TENSOR_NCHW";
165     case CUDNN_TENSOR_NHWC:
166       return "CUDNN_TENSOR_NHWC";
167     default:
168       std::ostringstream oss;
169       oss << "(unknown cudnn tensor format " << static_cast<int>(tformat) << ")";
170       return oss.str();
171   }
172 }
173 
operator <<(std::ostream & out,const FilterDescriptor & d)174 std::ostream& operator<<(std::ostream & out, const FilterDescriptor& d) {
175   out << "FilterDescriptor " << static_cast<void*>(d.desc()) << "\n";
176   int nbDims = 0;
177   int dimA[CUDNN_DIM_MAX];
178   cudnnDataType_t dtype;
179   cudnnTensorFormat_t tformat;
180   cudnnGetFilterNdDescriptor(d.desc(), CUDNN_DIM_MAX, &dtype, &tformat, &nbDims, dimA);
181   out << "    type = " << cudnnTypeToString(dtype) << "\n";
182   out << "    tensor_format = " << cudnnMemoryFormatToString(tformat) << "\n";
183   out << "    nbDims = " << nbDims << "\n";
184   // Read out only nbDims of the arrays!
185   out << "    dimA = ";
186   for (auto i : ArrayRef<int>{dimA, static_cast<size_t>(nbDims)}) {
187     out << i << ", ";
188   }
189   out << "\n";
190   return out;
191 }
192 
print()193 void FilterDescriptor::print() { std::cout << *this; }
194 
195 }
196