xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/mkldnn/ConvPrepack.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/Tensor.h>
4 #include <ATen/native/mkldnn/Common.h>
5 #include <ATen/native/mkldnn/OpContext.h>
6 
7 #if AT_MKLDNN_ENABLED()
8 
9 namespace at {
10 namespace native {
11 namespace mkldnn {
12 namespace internal {
13 namespace convolution {
14 
15 c10::intrusive_ptr<mkldnn::ConvOpContext> createConvPrePackOpContext(
16     Tensor weight,
17     std::optional<Tensor> bias,
18     std::vector<int64_t> stride,
19     std::vector<int64_t> padding,
20     std::vector<int64_t> dilation,
21     int64_t groups,
22     std::vector<int64_t> input_size,
23     std::string attr);
24 
25 Tensor conv_run(
26     const Tensor& input,
27     const c10::intrusive_ptr<mkldnn::ConvOpContext>& op_context);
28 
29 ContextConv create(
30     const Tensor& weight,
31     const std::optional<Tensor>& bias,
32     const IntArrayRef padding,
33     const IntArrayRef stride,
34     const IntArrayRef dilation,
35     const int64_t groups,
36     const IntArrayRef input_size,
37     const ideep::attr_t& attr);
38 
39 Tensor run(ContextConv& context, const Tensor& input);
40 
41 void run(ContextConv& context, const Tensor& input, void* output);
42 
43 } // namespace convolution
44 } // namespace internal
45 } // namespace mkldnn
46 } // namespace native
47 } // namespace at
48 
49 #endif // AT_MKLDNN_ENABLED()
50