xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/mkldnn/OpContext.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/native/mkldnn/OpContext.h>
2 
3 #if AT_MKLDNN_ENABLED()
4 #include <ATen/native/mkldnn/ConvPrepack.h>
5 
6 namespace at {
7 namespace native {
8 namespace mkldnn {
9 
create_context(at::Tensor && weight,std::optional<at::Tensor> && bias,std::vector<int64_t> && padding,std::vector<int64_t> && stride,std::vector<int64_t> && dilation,int64_t groups,std::vector<int64_t> && input_size,const ideep::attr_t & attr)10 c10::intrusive_ptr<ConvOpContext> MkldnnConvOpContext::create_context(
11     at::Tensor&& weight,
12     std::optional<at::Tensor>&& bias,
13     std::vector<int64_t>&& padding,
14     std::vector<int64_t>&& stride,
15     std::vector<int64_t>&& dilation,
16     int64_t groups,
17     std::vector<int64_t>&& input_size,
18     const ideep::attr_t& attr) {
19   auto op_context = mkldnn::internal::convolution::create(
20       weight, bias, padding, stride, dilation, groups, input_size, attr);
21 
22   auto conv_op_context = c10::make_intrusive<MkldnnConvOpContext>(
23       std::move(weight),
24       std::move(bias),
25       std::move(padding),
26       std::move(stride),
27       std::move(dilation),
28       groups,
29       std::move(input_size),
30       std::move(op_context));
31 
32   return conv_op_context;
33 }
34 
// Runs the prepacked convolution on `input` and returns the result tensor.
//
// This is a thin forwarder to mkldnn::internal::convolution::run(), using the
// context state (op_context_) captured when this object was created.
Tensor MkldnnConvOpContext::run(const Tensor& input) {
  namespace conv_impl = mkldnn::internal::convolution;
  return conv_impl::run(op_context_, input);
}
38 
run(const Tensor & input,void * output)39 void MkldnnConvOpContext::run(const Tensor& input, void* output) {
40   mkldnn::internal::convolution::run(op_context_, input, output);
41 }
42 
43 } // namespace mkldnn
44 } // namespace native
45 } // namespace at
46 
47 #endif // AT_MKLDNN_ENABLED()
48