#include <ATen/native/mkldnn/OpContext.h>

#if AT_MKLDNN_ENABLED()
#include <ATen/native/mkldnn/ConvPrepack.h>

namespace at {
namespace native {
namespace mkldnn {

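// Prepacks the convolution weight (and optional bias) for the given
// convolution parameters via mkldnn::internal::convolution::create, then
// wraps the resulting prepacked context together with the original
// arguments in a MkldnnConvOpContext.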
c10::intrusive_ptr<ConvOpContext> MkldnnConvOpContext::create_context(
    at::Tensor&& weight,
    std::optional<at::Tensor>&& bias,
    std::vector<int64_t>&& padding,
    std::vector<int64_t>&& stride,
    std::vector<int64_t>&& dilation,
    int64_t groups,
    std::vector<int64_t>&& input_size,
    const ideep::attr_t& attr) {
  auto op_context = mkldnn::internal::convolution::create(
      weight, bias, padding, stride, dilation, groups, input_size, attr);

  auto conv_op_context = c10::make_intrusive<MkldnnConvOpContext>(
      std::move(weight),
      std::move(bias),
      std::move(padding),
      std::move(stride),
      std::move(dilation),
      groups,
      std::move(input_size),
      std::move(op_context));

  return conv_op_context;
}

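// Runs the prepacked convolution on `input` and returns the result as a
// new tensor.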
Tensor MkldnnConvOpContext::run(const Tensor& input) {
  return mkldnn::internal::convolution::run(op_context_, input);
}

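// Runs the prepacked convolution on `input`, writing the result into the
// caller-provided `output` buffer.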
void MkldnnConvOpContext::run(const Tensor& input, void* output) {
  mkldnn::internal::convolution::run(op_context_, input, output);
}

} // namespace mkldnn
} // namespace native
} // namespace at

#endif // AT_MKLDNN_ENABLED()