xref: /aosp_15_r20/external/pytorch/aten/src/ATen/miopen/Descriptors.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/miopen/Exceptions.h>
4 
5 #include <ATen/miopen/miopen-wrapper.h>
6 #include <ATen/core/Tensor.h>
7 #include <ATen/TensorUtils.h>
8 
9 namespace at { namespace native {
10 
dataSize(miopenDataType_t dataType)11 inline int dataSize(miopenDataType_t dataType)
12 {
13   switch (dataType) {
14     case miopenHalf: return 2;
15     case miopenFloat: return 4;
16     case miopenBFloat16: return 2;
17     default: return 8;
18   }
19 }
20 
21 template <typename T, miopenStatus_t (*dtor)(T*)>
22 struct DescriptorDeleter {
operatorDescriptorDeleter23   void operator()(T* x) {
24     if (x != nullptr) {
25       MIOPEN_CHECK(dtor(x));
26     }
27   }
28 };
29 
// A generic class for wrapping MIOpen descriptor types.  All you need
// to provide is the underlying type that the Descriptor_t points to (for
// example, miopenTensorDescriptor_t points to miopenTensorDescriptor),
// along with its constructor and destructor functions.  Subclasses are
// responsible for defining a set() function to actually set the descriptor.
//
// Descriptors default-construct to a nullptr and have a descriptor
// created the first time you call set() or any other initializing
// function (which obtain it through mut_desc()).
39 template <typename T, miopenStatus_t (*ctor)(T**), miopenStatus_t (*dtor)(T*)>
40 class Descriptor
41 {
42 public:
43   // Use desc() to access the underlying descriptor pointer in
44   // a read-only fashion.  Most client code should use this.
45   // If the descriptor was never initialized, this will return
46   // nullptr.
desc()47   T* desc() const { return desc_.get(); }
desc()48   T* desc() { return desc_.get(); }
49 
50   // Use mut_desc() to access the underlying descriptor pointer
51   // if you intend to modify what it points to (e.g., using
52   // miopenSetFooDescriptor).  This will ensure that the descriptor
53   // is initialized.  Code in this file will use this function.
mut_desc()54   T* mut_desc() { init(); return desc_.get(); }
55 protected:
init()56   void init() {
57     if (desc_ == nullptr) {
58       T* raw_desc;
59       MIOPEN_CHECK(ctor(&raw_desc));
60       desc_.reset(raw_desc);
61     }
62   }
63 private:
64   std::unique_ptr<T, DescriptorDeleter<T, dtor>> desc_;
65 };
66 
67 class TensorDescriptor
68   : public Descriptor<miopenTensorDescriptor,
69                       &miopenCreateTensorDescriptor,
70                       &miopenDestroyTensorDescriptor>
71 {
72 public:
TensorDescriptor()73   TensorDescriptor() {}
74   explicit TensorDescriptor(const at::Tensor &t, size_t pad = 0) {
75     set(t, pad);
76   }
77 
78   void set(const at::Tensor &t, size_t pad = 0);
79   void set(miopenDataType_t dataType, IntArrayRef sizes, IntArrayRef strides, size_t pad = 0);
80 
81   void print();
82 
83 private:
set(miopenDataType_t dataType,int dim,int * size,int * stride)84   void set(miopenDataType_t dataType, int dim, int* size, int* stride) {
85     MIOPEN_CHECK(miopenSetTensorDescriptor(mut_desc(), dataType, dim, size, stride));
86   }
87 };
88 
89 std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d);
90 
91 class FilterDescriptor
92   : public Descriptor<miopenTensorDescriptor,
93                       &miopenCreateTensorDescriptor,
94                       &miopenDestroyTensorDescriptor>
95 {
96  public:
97   void set(const at::Tensor &t, int64_t pad = 0) {
98     set(t, at::MemoryFormat::Contiguous, pad);
99   }
100 
101   void set(const at::Tensor &t, const at::MemoryFormat memory_format, int64_t pad = 0);
102 
103 private:
set(miopenDataType_t dataType,int dim,int * size,int * stride)104   void set(miopenDataType_t dataType, int dim, int* size, int* stride) {
105     MIOPEN_CHECK(miopenSetTensorDescriptor(mut_desc(), dataType, dim, size, stride));
106   }
107 };
108 
109 struct ConvolutionDescriptor
110   : public Descriptor<miopenConvolutionDescriptor,
111                       &miopenCreateConvolutionDescriptor,
112                       &miopenDestroyConvolutionDescriptor>
113 {
setConvolutionDescriptor114   void set(miopenDataType_t dataType, miopenConvolutionMode_t c_mode,  int dim, int* pad, int* stride, int * upscale /* aka dilation */, int groups, bool deterministic) {
115     MIOPEN_CHECK(miopenInitConvolutionNdDescriptor(mut_desc(), dim, pad, stride, upscale, c_mode));
116     MIOPEN_CHECK(miopenSetConvolutionGroupCount(mut_desc(), groups));
117     MIOPEN_CHECK(miopenSetConvolutionAttribute(mut_desc(), MIOPEN_CONVOLUTION_ATTRIB_DETERMINISTIC, deterministic ? 1 : 0));
118   }
119 };
120 
121 
122 struct RNNDescriptor
123   : public Descriptor<miopenRNNDescriptor,
124                       &miopenCreateRNNDescriptor,
125                       &miopenDestroyRNNDescriptor>
126 {
setRNNDescriptor127     void set(int64_t hidden_size, int64_t num_layers, miopenRNNInputMode_t input_mode, miopenRNNDirectionMode_t direction, miopenRNNMode_t rnn_mode,
128               miopenRNNBiasMode_t bias_mode, miopenRNNAlgo_t algorithm, miopenDataType_t datatype) {
129       MIOPEN_CHECK(miopenSetRNNDescriptor(mut_desc(), hidden_size, num_layers, input_mode, direction, rnn_mode, bias_mode, algorithm, datatype));
130     }
131 };
132 
133 union Constant
134 {
135   float f;
136   double d;
Constant(miopenDataType_t dataType,double value)137   Constant(miopenDataType_t dataType, double value) {
138     if (dataType == miopenHalf || dataType == miopenFloat || dataType == miopenBFloat16) {
139       f = static_cast<float>(value);
140     } else {
141       d = value;
142     }
143   }
144 };
145 
146 }}  // namespace
147