1 #pragma once
2
3 #include <ATen/miopen/Exceptions.h>
4
5 #include <ATen/miopen/miopen-wrapper.h>
6 #include <ATen/core/Tensor.h>
7 #include <ATen/TensorUtils.h>
8
9 namespace at { namespace native {
10
dataSize(miopenDataType_t dataType)11 inline int dataSize(miopenDataType_t dataType)
12 {
13 switch (dataType) {
14 case miopenHalf: return 2;
15 case miopenFloat: return 4;
16 case miopenBFloat16: return 2;
17 default: return 8;
18 }
19 }
20
21 template <typename T, miopenStatus_t (*dtor)(T*)>
22 struct DescriptorDeleter {
operatorDescriptorDeleter23 void operator()(T* x) {
24 if (x != nullptr) {
25 MIOPEN_CHECK(dtor(x));
26 }
27 }
28 };
29
30 // A generic class for wrapping MIOpen descriptor types. All you need
31 // is to give the underlying type the Descriptor_t points to (usually,
32 // if it's miopenTensorDescriptor_t it points to miopenTensorStruct),
33 // the constructor and the destructor. Subclasses are responsible
34 // for defining a set() function to actually set the descriptor.
35 //
36 // Descriptors default construct to a nullptr, and have a descriptor
37 // initialized the first time you call set() or any other initializing
38 // function.
39 template <typename T, miopenStatus_t (*ctor)(T**), miopenStatus_t (*dtor)(T*)>
40 class Descriptor
41 {
42 public:
43 // Use desc() to access the underlying descriptor pointer in
44 // a read-only fashion. Most client code should use this.
45 // If the descriptor was never initialized, this will return
46 // nullptr.
desc()47 T* desc() const { return desc_.get(); }
desc()48 T* desc() { return desc_.get(); }
49
50 // Use mut_desc() to access the underlying descriptor pointer
51 // if you intend to modify what it points to (e.g., using
52 // miopenSetFooDescriptor). This will ensure that the descriptor
53 // is initialized. Code in this file will use this function.
mut_desc()54 T* mut_desc() { init(); return desc_.get(); }
55 protected:
init()56 void init() {
57 if (desc_ == nullptr) {
58 T* raw_desc;
59 MIOPEN_CHECK(ctor(&raw_desc));
60 desc_.reset(raw_desc);
61 }
62 }
63 private:
64 std::unique_ptr<T, DescriptorDeleter<T, dtor>> desc_;
65 };
66
67 class TensorDescriptor
68 : public Descriptor<miopenTensorDescriptor,
69 &miopenCreateTensorDescriptor,
70 &miopenDestroyTensorDescriptor>
71 {
72 public:
TensorDescriptor()73 TensorDescriptor() {}
74 explicit TensorDescriptor(const at::Tensor &t, size_t pad = 0) {
75 set(t, pad);
76 }
77
78 void set(const at::Tensor &t, size_t pad = 0);
79 void set(miopenDataType_t dataType, IntArrayRef sizes, IntArrayRef strides, size_t pad = 0);
80
81 void print();
82
83 private:
set(miopenDataType_t dataType,int dim,int * size,int * stride)84 void set(miopenDataType_t dataType, int dim, int* size, int* stride) {
85 MIOPEN_CHECK(miopenSetTensorDescriptor(mut_desc(), dataType, dim, size, stride));
86 }
87 };
88
89 std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d);
90
91 class FilterDescriptor
92 : public Descriptor<miopenTensorDescriptor,
93 &miopenCreateTensorDescriptor,
94 &miopenDestroyTensorDescriptor>
95 {
96 public:
97 void set(const at::Tensor &t, int64_t pad = 0) {
98 set(t, at::MemoryFormat::Contiguous, pad);
99 }
100
101 void set(const at::Tensor &t, const at::MemoryFormat memory_format, int64_t pad = 0);
102
103 private:
set(miopenDataType_t dataType,int dim,int * size,int * stride)104 void set(miopenDataType_t dataType, int dim, int* size, int* stride) {
105 MIOPEN_CHECK(miopenSetTensorDescriptor(mut_desc(), dataType, dim, size, stride));
106 }
107 };
108
109 struct ConvolutionDescriptor
110 : public Descriptor<miopenConvolutionDescriptor,
111 &miopenCreateConvolutionDescriptor,
112 &miopenDestroyConvolutionDescriptor>
113 {
setConvolutionDescriptor114 void set(miopenDataType_t dataType, miopenConvolutionMode_t c_mode, int dim, int* pad, int* stride, int * upscale /* aka dilation */, int groups, bool deterministic) {
115 MIOPEN_CHECK(miopenInitConvolutionNdDescriptor(mut_desc(), dim, pad, stride, upscale, c_mode));
116 MIOPEN_CHECK(miopenSetConvolutionGroupCount(mut_desc(), groups));
117 MIOPEN_CHECK(miopenSetConvolutionAttribute(mut_desc(), MIOPEN_CONVOLUTION_ATTRIB_DETERMINISTIC, deterministic ? 1 : 0));
118 }
119 };
120
121
122 struct RNNDescriptor
123 : public Descriptor<miopenRNNDescriptor,
124 &miopenCreateRNNDescriptor,
125 &miopenDestroyRNNDescriptor>
126 {
setRNNDescriptor127 void set(int64_t hidden_size, int64_t num_layers, miopenRNNInputMode_t input_mode, miopenRNNDirectionMode_t direction, miopenRNNMode_t rnn_mode,
128 miopenRNNBiasMode_t bias_mode, miopenRNNAlgo_t algorithm, miopenDataType_t datatype) {
129 MIOPEN_CHECK(miopenSetRNNDescriptor(mut_desc(), hidden_size, num_layers, input_mode, direction, rnn_mode, bias_mode, algorithm, datatype));
130 }
131 };
132
133 union Constant
134 {
135 float f;
136 double d;
Constant(miopenDataType_t dataType,double value)137 Constant(miopenDataType_t dataType, double value) {
138 if (dataType == miopenHalf || dataType == miopenFloat || dataType == miopenBFloat16) {
139 f = static_cast<float>(value);
140 } else {
141 d = value;
142 }
143 }
144 };
145
146 }} // namespace
147