#ifdef USE_XNNPACK
#include <ATen/native/xnnpack/Convolution.h>
#include <ATen/native/xnnpack/Linear.h>
#include <ATen/native/xnnpack/OpContext.h>

#include <ATen/Context.h>

namespace at::native::xnnpack {
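
// Builds a LinearOpContext: the weight and bias are prepacked into an
// XNNPACK fully-connected operator once, so later run() calls skip the
// packing cost. A rough usage sketch (hypothetical caller; the real entry
// points are the TorchScript prepack/run ops registered elsewhere in this
// directory):
//
//   auto ctx = XNNPackLinearOpContext::create_context(
//       std::move(weight), std::move(bias),
//       /*output_min=*/std::nullopt, /*output_max=*/std::nullopt);
//   Tensor output = ctx->run(input);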
c10::intrusive_ptr<LinearOpContext>
XNNPackLinearOpContext::create_context(
    at::Tensor&& weight,
    std::optional<at::Tensor>&& bias,
    const std::optional<Scalar>& output_min,
    const std::optional<Scalar>& output_max) {
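  // Note: std::move() here only casts to an rvalue reference; nothing is
  // moved until the XNNPackLinearOpContext constructor runs. Every argument
  // to make_intrusive (including the linear::create call) is evaluated
  // before that constructor, so create() still sees valid tensors.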
  auto linear_op_context =
      c10::make_intrusive<XNNPackLinearOpContext>(
          std::move(weight),
          std::move(bias),
          output_min,
          output_max,
          xnnpack::internal::linear::create(
              weight,
              bias,
              output_min ? output_min->to<float>()
                         : xnnpack::ContextLinear::kMin,
              output_max ? output_max->to<float>()
                         : xnnpack::ContextLinear::kMax));
  if (at::globalContext().releaseWeightsWhenPrepacking()) {
    linear_op_context->free_orig_weight_and_bias();
  }

  return linear_op_context;
}
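
// Once the prepacked operator exists, the original weight and bias are
// typically only needed if the context has to be serialized again; dropping
// them reclaims memory when releaseWeightsWhenPrepacking() is set.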
void XNNPackLinearOpContext::free_orig_weight_and_bias() {
  orig_weight_and_bias_freed_ = true;
  orig_weight_.reset();
  orig_bias_.reset();
}
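
// Applies the prepacked linear operator to `input`.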
Tensor XNNPackLinearOpContext::run(const Tensor& input) {
  return xnnpack::internal::linear::run(op_context_, input);
}
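
// Builds a Conv2dOpContext: prepacks the convolution weights into XNNPACK's
// layout at creation time. output_padding is fixed to {0, 0} and
// `transposed` to false; the transposed variant below supplies its own.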
c10::intrusive_ptr<Conv2dOpContext>
XNNPackConv2dOpContext::create_context(
    at::Tensor&& weight,
    std::optional<at::Tensor>&& bias,
    std::vector<int64_t>&& padding,
    std::vector<int64_t>&& stride,
    std::vector<int64_t>&& dilation,
    int64_t groups,
    const std::optional<Scalar>& output_min,
    const std::optional<Scalar>& output_max) {
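  // Unlike the linear path above, the XNNPACK operator is created before the
  // tensors are moved into the context object, so there is no argument
  // evaluation-order subtlety here.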
  auto op_context =
      xnnpack::internal::convolution2d::create(
          weight,
          bias,
          padding,
          {0, 0}, // output_padding
          stride,
          dilation,
          groups,
          false, // transposed
          output_min ? output_min->to<float>()
                     : xnnpack::ContextConv2D::kMin,
          output_max ? output_max->to<float>()
                     : xnnpack::ContextConv2D::kMax);

  auto conv2d_op_context =
      c10::make_intrusive<XNNPackConv2dOpContext>(
          std::move(weight),
          std::move(bias),
          std::move(padding),
          std::move(stride),
          std::move(dilation),
          groups,
          output_min,
          output_max,
          std::move(op_context));

  if (at::globalContext().releaseWeightsWhenPrepacking()) {
    conv2d_op_context->free_orig_weight_and_bias();
  }

  return conv2d_op_context;
}
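
// The transposed (deconvolution) variant: identical flow, but the caller
// provides output_padding and `transposed` is true.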
c10::intrusive_ptr<TransposeConv2dOpContext>
XNNPackTransposeConv2dOpContext::create_context(
    at::Tensor&& weight,
    std::optional<at::Tensor>&& bias,
    std::vector<int64_t>&& padding,
    std::vector<int64_t>&& output_padding,
    std::vector<int64_t>&& stride,
    std::vector<int64_t>&& dilation,
    int64_t groups,
    const std::optional<Scalar>& output_min,
    const std::optional<Scalar>& output_max) {
  auto op_context =
      xnnpack::internal::convolution2d::create(
          weight,
          bias,
          padding,
          output_padding,
          stride,
          dilation,
          groups,
          true, // transposed
          output_min ? output_min->to<float>()
                     : xnnpack::ContextConv2D::kMin,
          output_max ? output_max->to<float>()
                     : xnnpack::ContextConv2D::kMax);

  auto conv2d_op_context =
      c10::make_intrusive<XNNPackTransposeConv2dOpContext>(
          std::move(weight),
          std::move(bias),
          std::move(padding),
          std::move(output_padding),
          std::move(stride),
          std::move(dilation),
          groups,
          output_min,
          output_max,
          std::move(op_context));

  if (at::globalContext().releaseWeightsWhenPrepacking()) {
    conv2d_op_context->free_orig_weight_and_bias();
  }

  return conv2d_op_context;
}
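
// run() takes xnnp_mutex_ because the shared op context caches state that
// depends on the input shape; without the lock, concurrent calls on the same
// context could interleave setup and execution.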
Tensor XNNPackConv2dOpContext::run(const Tensor& input) {
  std::lock_guard<std::mutex> lock(xnnp_mutex_);
  return xnnpack::internal::convolution2d::run(op_context_, input);
}

Tensor XNNPackTransposeConv2dOpContext::run(const Tensor& input) {
  std::lock_guard<std::mutex> lock(xnnp_mutex_);
  return xnnpack::internal::convolution2d::run(op_context_, input);
}
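
// Same release logic as the linear context: free the original tensors once
// the prepacked operator is the only runtime dependency.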
void XNNPackConv2dOpContext::free_orig_weight_and_bias() {
  orig_weight_and_bias_freed_ = true;
  orig_weight_.reset();
  orig_bias_.reset();
}

void XNNPackTransposeConv2dOpContext::free_orig_weight_and_bias() {
  orig_weight_and_bias_freed_ = true;
  orig_weight_.reset();
  orig_bias_.reset();
}

} // namespace at::native::xnnpack

#endif /* USE_XNNPACK */