1 #pragma once 2 3 #include <ATen/mkl/Exceptions.h> 4 #include <mkl_dfti.h> 5 #include <ATen/Tensor.h> 6 7 namespace at::native { 8 9 struct DftiDescriptorDeleter { operatorDftiDescriptorDeleter10 void operator()(DFTI_DESCRIPTOR* desc) { 11 if (desc != nullptr) { 12 MKL_DFTI_CHECK(DftiFreeDescriptor(&desc)); 13 } 14 } 15 }; 16 17 class DftiDescriptor { 18 public: init(DFTI_CONFIG_VALUE precision,DFTI_CONFIG_VALUE signal_type,MKL_LONG signal_ndim,MKL_LONG * sizes)19 void init(DFTI_CONFIG_VALUE precision, DFTI_CONFIG_VALUE signal_type, MKL_LONG signal_ndim, MKL_LONG* sizes) { 20 if (desc_ != nullptr) { 21 throw std::runtime_error("DFTI DESCRIPTOR can only be initialized once"); 22 } 23 DFTI_DESCRIPTOR *raw_desc; 24 if (signal_ndim == 1) { 25 MKL_DFTI_CHECK(DftiCreateDescriptor(&raw_desc, precision, signal_type, 1, sizes[0])); 26 } else { 27 MKL_DFTI_CHECK(DftiCreateDescriptor(&raw_desc, precision, signal_type, signal_ndim, sizes)); 28 } 29 desc_.reset(raw_desc); 30 } 31 get()32 DFTI_DESCRIPTOR *get() const { 33 if (desc_ == nullptr) { 34 throw std::runtime_error("DFTI DESCRIPTOR has not been initialized"); 35 } 36 return desc_.get(); 37 } 38 39 private: 40 std::unique_ptr<DFTI_DESCRIPTOR, DftiDescriptorDeleter> desc_; 41 }; 42 43 44 } // namespace at::native 45