xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/mkldnn/Common.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/ATen.h>
4 #include <ATen/Config.h>
5 
6 #if AT_MKLDNN_ENABLED()
7 
8 #include <ideep/tensor.hpp>
9 
10 namespace at {
11 namespace native {
12 namespace mkldnn {
13 
14 struct ContextConv final {
15   ideep::tensor weight_packed_;
16   std::optional<at::Tensor> at_bias_;
17   std::vector<int64_t> padding_;
18   std::vector<int64_t> stride_;
19   std::vector<int64_t> dilation_;
20   int64_t groups_;
21   ideep::attr_t attr_;
22 
23   ContextConv() = delete;
24 
ContextConvfinal25   ContextConv(
26       ideep::tensor&& weight_packed,
27       std::optional<at::Tensor> at_bias,
28       std::vector<int64_t> padding,
29       std::vector<int64_t> stride,
30       std::vector<int64_t> dilation,
31       int64_t groups,
32       ideep::attr_t attr)
33       : weight_packed_(std::move(weight_packed)),
34         at_bias_(std::move(at_bias)),
35         padding_(padding),
36         stride_(stride),
37         dilation_(dilation),
38         groups_(groups),
39         attr_(attr) {}
40 };
41 
42 } // namespace mkldnn
43 } // namespace native
44 } // namespace at
45 
46 #endif // AT_MKLDNN_ENABLED()
47