xref: /aosp_15_r20/external/pytorch/aten/src/ATen/miopen/Descriptors.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/miopen/Descriptors.h>
#include <ATen/ATen.h>
#include <c10/util/irange.h>

#include <algorithm>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
6 
7 namespace at { namespace native {
8 
9 namespace {
10 
getDataType(const at::Tensor & t)11 inline miopenDataType_t getDataType(const at::Tensor& t) {
12   auto scalar_type = t.scalar_type();
13   if (scalar_type == at::kFloat) {
14     return miopenFloat;
15   } else if (scalar_type == at::kHalf) {
16     return miopenHalf;
17   } else if (scalar_type == at::kBFloat16) {
18     return miopenBFloat16;
19   } else {
20   throw std::runtime_error("TensorDescriptor only supports float, half and bfloat16 tensors");
21   }
22 }
23 
24 } // anonymous namespace
25 
26 
set(const at::Tensor & t,size_t pad)27 void TensorDescriptor::set(const at::Tensor &t, size_t pad) {
28   set(getDataType(t), t.sizes(), t.strides(), pad);
29 }
30 
// Upper bound on tensor rank (and on the `pad` argument) accepted by the
// descriptor setters in this file; also sizes their fixed dimension buffers.
constexpr size_t MIOPEN_DIM_MAX{5};
32 
set(miopenDataType_t datatype,IntArrayRef t_sizes,IntArrayRef t_strides,size_t pad)33 void TensorDescriptor::set(miopenDataType_t datatype, IntArrayRef t_sizes, IntArrayRef t_strides, size_t pad) {
34   size_t dim = t_sizes.size();
35   if (dim > MIOPEN_DIM_MAX || pad > MIOPEN_DIM_MAX)
36 #define _STR(X) #X
37 #define STR(X) _STR(X)
38     throw std::runtime_error("MIOpen supports only up to " STR(MIOPEN_DIM_MAX) " dimensions");
39 #undef _STR
40 #undef STR
41   int size[MIOPEN_DIM_MAX];
42   int stride[MIOPEN_DIM_MAX];
43   for (const auto i : c10::irange(dim)) {
44     size[i] = static_cast<int>(t_sizes[i]);
45     stride[i] = static_cast<int>(t_strides[i]);
46   }
47   for (const auto i : c10::irange(dim, pad)) {
48     size[i] = 1;
49     stride[i] = 1;
50   }
51   set(datatype, static_cast<int>(std::max(dim, pad)), size, stride);
52 }
53 
miopenTypeToString(miopenDataType_t dtype)54 std::string miopenTypeToString(miopenDataType_t dtype) {
55   switch (dtype) {
56     case miopenFloat:
57       return "miopenFloat";
58     case miopenHalf:
59       return "miopenHalf";
60     case miopenBFloat16:
61       return "miopenBFloat16";
62     default:
63       std::ostringstream oss;
64       oss << "(unknown data-type " << static_cast<int>(dtype) << ")";
65       return oss.str();
66   }
67 }
68 
operator <<(std::ostream & out,const TensorDescriptor & d)69 std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d) {
70   out << "TensorDescriptor " << static_cast<void*>(d.desc()) << "\n";
71   int nbDims = 4;
72   int dimA[MIOPEN_DIM_MAX];
73   int strideA[MIOPEN_DIM_MAX];
74   miopenDataType_t dtype;
75   miopenGetTensorDescriptor(d.desc(), &dtype, dimA, strideA);
76   out << "    type = " << miopenTypeToString(dtype) << "\n";
77   out << "    nbDims = " << nbDims << "\n";
78   // Read out only nbDims of the arrays!
79   out << "    dimA = ";
80   for (auto i : ArrayRef<int>{dimA, static_cast<size_t>(nbDims)}) {
81     out << i << ", ";
82   }
83   out << "\n";
84   out << "    strideA = ";
85   for (auto i : ArrayRef<int>{strideA, static_cast<size_t>(nbDims)}) {
86     out << i << ", ";
87   }
88   out << "\n";
89   return out;
90 }
91 
print()92 void TensorDescriptor::print() { std::cout << *this; }
93 
set(const at::Tensor & t,const at::MemoryFormat memory_format,int64_t pad)94 void FilterDescriptor::set(const at::Tensor &t, const at::MemoryFormat memory_format, int64_t pad) {
95   auto dim = t.ndimension();
96   if (dim > static_cast<int64_t>(MIOPEN_DIM_MAX) || pad > static_cast<int64_t>(MIOPEN_DIM_MAX)) {
97 #define _STR(X) #X
98 #define STR(X) _STR(X)
99     throw std::runtime_error("MIOpen supports only up to " STR(MIOPEN_DIM_MAX) " dimensions");
100 #undef _STR
101 #undef STR
102   }
103   TORCH_CHECK(t.is_contiguous(memory_format),
104       "MIOpen filters (a.k.a. weights) must be contiguous");
105 
106   int size[MIOPEN_DIM_MAX];
107   int stride[MIOPEN_DIM_MAX];
108   for (const auto i : c10::irange(dim)) {
109     size[i] = (int) t.size(i);
110   }
111   for (const auto i : c10::irange(dim, pad)) {
112     size[i] = (int) 1;
113   }
114 
115   for (int i = pad; i >= dim; --i ) {
116       stride[i] = 1;
117   }
118   for (int i = dim-1 ; i >=0; --i ) {
119       // Pass-through
120       stride[i] = t.stride(i);
121   }
122 
123   dim = std::max<int64_t>(dim, pad);
124   set(getDataType(t), (int) dim, size, stride);
125 }
126 
127 }}
128