1 #ifdef USE_XNNPACK
2 #include <ATen/native/xnnpack/Convolution.h>
3 #include <ATen/native/xnnpack/Linear.h>
4 #include <ATen/native/xnnpack/OpContext.h>
5
6 #include <ATen/Context.h>
7
8 namespace at::native::xnnpack {
9
10 c10::intrusive_ptr<LinearOpContext>
create_context(at::Tensor && weight,std::optional<at::Tensor> && bias,const std::optional<Scalar> & output_min,const std::optional<Scalar> & output_max)11 XNNPackLinearOpContext::create_context(
12 at::Tensor&& weight,
13 std::optional<at::Tensor>&& bias,
14 const std::optional<Scalar>& output_min,
15 const std::optional<Scalar>& output_max) {
16 auto linear_op_context =
17 c10::make_intrusive<XNNPackLinearOpContext>(
18 std::move(weight),
19 std::move(bias),
20 output_min,
21 output_max,
22 xnnpack::internal::linear::create(
23 weight,
24 bias,
25 output_min ? output_min->to<float>()
26 : xnnpack::ContextLinear::kMin,
27 output_max ? output_max->to<float>()
28 : xnnpack::ContextLinear::kMax)
29 );
30 if (at::globalContext().releaseWeightsWhenPrepacking()) {
31 linear_op_context->free_orig_weight_and_bias();
32 }
33
34 return linear_op_context;
35 }
36
free_orig_weight_and_bias()37 void XNNPackLinearOpContext::free_orig_weight_and_bias() {
38 orig_weight_and_bias_freed_ = true;
39 orig_weight_.reset();
40 orig_bias_.reset();
41 }
42
// Executes the prepacked linear (fully-connected) operator on `input` and
// returns the result. All heavy lifting is delegated to the internal
// XNNPACK linear runner using the stored op_context_.
Tensor XNNPackLinearOpContext::run(const Tensor& input) {
  Tensor output = xnnpack::internal::linear::run(op_context_, input);
  return output;
}
46
// Factory: prepacks weight/bias into an XNNPACK convolution context and wraps
// everything in a ref-counted XNNPackConv2dOpContext.
//
// NOTE: the internal create() below MUST run before the std::move()s in the
// make_intrusive call — it reads `weight`/`bias` (and the vector params) by
// reference while they are still intact.
c10::intrusive_ptr<Conv2dOpContext>
XNNPackConv2dOpContext::create_context(at::Tensor&& weight,
    std::optional<at::Tensor>&& bias,
    std::vector<int64_t>&& padding,
    std::vector<int64_t>&& stride,
    std::vector<int64_t>&& dilation,
    int64_t groups,
    const std::optional<Scalar>& output_min,
    const std::optional<Scalar>& output_max) {
  auto op_context =
      xnnpack::internal::convolution2d::create(
          weight,
          bias,
          padding,
          {0, 0}, // output_padding — unused for a regular (non-transposed) conv
          stride,
          dilation,
          groups,
          false, // transposed
          // Default to the context-wide clamp limits when min/max not given.
          output_min ? output_min->to<float>()
                     : xnnpack::ContextConv2D::kMin,
          output_max ? output_max->to<float>()
                     : xnnpack::ContextConv2D::kMax);

  // The wrapper keeps the original tensors/params (moved in) alongside the
  // prepacked op_context so it can be serialized / re-created later.
  auto conv2d_op_context =
      c10::make_intrusive<XNNPackConv2dOpContext>(
          std::move(weight),
          std::move(bias),
          std::move(padding),
          std::move(stride),
          std::move(dilation),
          groups,
          output_min,
          output_max,
          std::move(op_context));

  // Optionally drop the originals to save memory once prepacking succeeded
  // (global toggle).
  if (at::globalContext().releaseWeightsWhenPrepacking()) {
    conv2d_op_context->free_orig_weight_and_bias();
  }

  return conv2d_op_context;
}
89
// Factory: prepacks weight/bias into an XNNPACK *transposed* convolution
// context and wraps everything in a ref-counted XNNPackTransposeConv2dOpContext.
//
// NOTE: the internal create() below MUST run before the std::move()s in the
// make_intrusive call — it reads `weight`/`bias` (and the vector params) by
// reference while they are still intact.
c10::intrusive_ptr<TransposeConv2dOpContext>
XNNPackTransposeConv2dOpContext::create_context(at::Tensor&& weight,
    std::optional<at::Tensor>&& bias,
    std::vector<int64_t>&& padding,
    std::vector<int64_t>&& output_padding,
    std::vector<int64_t>&& stride,
    std::vector<int64_t>&& dilation,
    int64_t groups,
    const std::optional<Scalar>& output_min,
    const std::optional<Scalar>& output_max) {
  auto op_context =
      xnnpack::internal::convolution2d::create(
          weight,
          bias,
          padding,
          output_padding, // meaningful here, unlike the regular-conv factory
          stride,
          dilation,
          groups,
          true, // transposed
          // Default to the context-wide clamp limits when min/max not given.
          output_min ? output_min->to<float>()
                     : xnnpack::ContextConv2D::kMin,
          output_max ? output_max->to<float>()
                     : xnnpack::ContextConv2D::kMax);

  // The wrapper keeps the original tensors/params (moved in) alongside the
  // prepacked op_context so it can be serialized / re-created later.
  auto conv2d_op_context =
      c10::make_intrusive<XNNPackTransposeConv2dOpContext>(
          std::move(weight),
          std::move(bias),
          std::move(padding),
          std::move(output_padding),
          std::move(stride),
          std::move(dilation),
          groups,
          output_min,
          output_max,
          std::move(op_context));

  // Optionally drop the originals to save memory once prepacking succeeded
  // (global toggle).
  if (at::globalContext().releaseWeightsWhenPrepacking()) {
    conv2d_op_context->free_orig_weight_and_bias();
  }

  return conv2d_op_context;
}
134
// Executes the prepacked 2-D convolution on `input`. The mutex serializes
// concurrent run() calls on this context — presumably because the underlying
// XNNPACK operator state cannot be executed concurrently (confirm against
// the internal convolution2d::run implementation).
Tensor XNNPackConv2dOpContext::run(const Tensor& input) {
  const std::lock_guard<std::mutex> guard(xnnp_mutex_);
  Tensor output = xnnpack::internal::convolution2d::run(op_context_, input);
  return output;
}
139
// Executes the prepacked transposed 2-D convolution on `input`, serialized by
// the per-context mutex just like the regular-conv run() above.
Tensor XNNPackTransposeConv2dOpContext::run(const Tensor& input) {
  const std::lock_guard<std::mutex> guard(xnnp_mutex_);
  Tensor output = xnnpack::internal::convolution2d::run(op_context_, input);
  return output;
}
144
free_orig_weight_and_bias()145 void XNNPackConv2dOpContext::free_orig_weight_and_bias() {
146 orig_weight_and_bias_freed_ = true;
147 orig_weight_.reset();
148 orig_bias_.reset();
149 }
150
free_orig_weight_and_bias()151 void XNNPackTransposeConv2dOpContext::free_orig_weight_and_bias() {
152 orig_weight_and_bias_freed_ = true;
153 orig_weight_.reset();
154 orig_bias_.reset();
155 }
156
157 } // namespace at::native::xnnpack
158
159 #endif /* USE_XNNPACK */
160