xref: /aosp_15_r20/external/pytorch/aten/src/ATen/mkl/Descriptors.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/mkl/Exceptions.h>
4 #include <mkl_dfti.h>
5 #include <ATen/Tensor.h>
6 
7 namespace at::native {
8 
9 struct DftiDescriptorDeleter {
operatorDftiDescriptorDeleter10   void operator()(DFTI_DESCRIPTOR* desc) {
11     if (desc != nullptr) {
12       MKL_DFTI_CHECK(DftiFreeDescriptor(&desc));
13     }
14   }
15 };
16 
17 class DftiDescriptor {
18 public:
init(DFTI_CONFIG_VALUE precision,DFTI_CONFIG_VALUE signal_type,MKL_LONG signal_ndim,MKL_LONG * sizes)19   void init(DFTI_CONFIG_VALUE precision, DFTI_CONFIG_VALUE signal_type, MKL_LONG signal_ndim, MKL_LONG* sizes) {
20     if (desc_ != nullptr) {
21       throw std::runtime_error("DFTI DESCRIPTOR can only be initialized once");
22     }
23     DFTI_DESCRIPTOR *raw_desc;
24     if (signal_ndim == 1) {
25       MKL_DFTI_CHECK(DftiCreateDescriptor(&raw_desc, precision, signal_type, 1, sizes[0]));
26     } else {
27       MKL_DFTI_CHECK(DftiCreateDescriptor(&raw_desc, precision, signal_type, signal_ndim, sizes));
28     }
29     desc_.reset(raw_desc);
30   }
31 
get()32   DFTI_DESCRIPTOR *get() const {
33     if (desc_ == nullptr) {
34       throw std::runtime_error("DFTI DESCRIPTOR has not been initialized");
35     }
36     return desc_.get();
37   }
38 
39 private:
40   std::unique_ptr<DFTI_DESCRIPTOR, DftiDescriptorDeleter> desc_;
41 };
42 
43 
44 } // namespace at::native
45