xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/mkldnn/ConvPrepack.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <vector>
2 
3 #include <ATen/native/ConvUtils.h>
4 #include <ATen/native/mkldnn/Common.h>
5 #include <ATen/native/mkldnn/ConvPrepack.h>
6 #include <ATen/native/mkldnn/MKLDNNCommon.h>
7 #include <ATen/native/mkldnn/OpContext.h>
8 #include <ATen/native/utils/Factory.h>
9 #include <ATen/native/utils/ParamUtils.h>
10 #include <c10/util/irange.h>
11 
12 #if AT_MKLDNN_ENABLED()
13 
14 namespace at {
15 namespace native {
16 namespace mkldnn {
17 namespace internal {
18 namespace convolution {
19 
createConvPrePackOpContext(Tensor weight,std::optional<Tensor> bias,std::vector<int64_t> stride,std::vector<int64_t> padding,std::vector<int64_t> dilation,int64_t groups,std::vector<int64_t> input_size,std::string attr)20 c10::intrusive_ptr<mkldnn::ConvOpContext> createConvPrePackOpContext(
21     Tensor weight,
22     std::optional<Tensor> bias,
23     std::vector<int64_t> stride,
24     std::vector<int64_t> padding,
25     std::vector<int64_t> dilation,
26     int64_t groups,
27     std::vector<int64_t> input_size,
28     std::string attr) {
29   auto it = fusion_attr_map.find(attr);
30   TORCH_CHECK(it != fusion_attr_map.end(), "Fusion behavior undefined.");
31   ideep::attr_t op_attr = it->second;
32 
33   return mkldnn::MkldnnConvOpContext::create_context(
34       std::move(weight),
35       std::move(bias),
36       std::move(padding),
37       std::move(stride),
38       std::move(dilation),
39       groups,
40       std::move(input_size),
41       op_attr);
42 }
43 
create(const Tensor & weight,const std::optional<Tensor> & bias,const IntArrayRef padding,const IntArrayRef stride,const IntArrayRef dilation,const int64_t groups,const IntArrayRef input_size,const ideep::attr_t & attr)44 ContextConv create(
45     const Tensor& weight,
46     const std::optional<Tensor>& bias,
47     const IntArrayRef padding,
48     const IntArrayRef stride,
49     const IntArrayRef dilation,
50     const int64_t groups,
51     const IntArrayRef input_size,
52     const ideep::attr_t& attr) {
53   auto k = weight.ndimension();
54   int64_t dim = k - 2;
55   const auto padding_expanded = expand_param_if_needed(padding, "padding", dim);
56   const auto stride_expanded = expand_param_if_needed(stride, "stride", dim);
57   const auto dilation_expanded =
58       expand_param_if_needed(dilation, "dilation", dim);
59   const auto input_size_expanded =
60       expand_param_if_needed(input_size, "input_size", k);
61 
62   c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset);
63   auto w = itensor_view_from_dense(weight);
64   // TODO: what if input is nhwc but w is nchw
65   bool is_channels_last =
66       weight.suggest_memory_format() == at::MemoryFormat::ChannelsLast;
67   ideep::tensor::desc expected_weight_desc =
68       ideep::convolution_forward::expected_weights_desc(
69           w.get_dims(),
70           w.get_data_type(),
71           {stride_expanded.begin(), stride_expanded.end()},
72           {padding_expanded.begin(), padding_expanded.end()},
73           {padding_expanded.begin(), padding_expanded.end()},
74           {dilation_expanded.begin(), dilation_expanded.end()},
75           groups,
76           ideep::algorithm::convolution_direct,
77           ideep::prop_kind::forward,
78           /*x_dtype*/ w.get_data_type(),
79           {input_size_expanded.begin(), input_size_expanded.end()},
80           attr,
81           is_channels_last);
82 
83   ideep::tensor packed_weight;
84   packed_weight.init(expected_weight_desc);
85   packed_weight.feed_from(w);
86 
87   return ContextConv{
88       std::move(packed_weight),
89       bias.has_value() ? std::make_optional(*bias) : std::nullopt,
90       {padding_expanded.begin(), padding_expanded.end()},
91       {stride_expanded.begin(), stride_expanded.end()},
92       {dilation_expanded.begin(), dilation_expanded.end()},
93       groups,
94       attr};
95 }
96 
_mkldnn_convolution_out(const ideep::tensor & x,ideep::tensor & y,const ideep::tensor & w,const std::optional<ideep::tensor> & b,IntArrayRef padding,IntArrayRef stride,IntArrayRef dilation,IntArrayRef output_sizes,int64_t groups,const ideep::attr_t & attr=ideep::attr_t ())97 static void _mkldnn_convolution_out(
98     const ideep::tensor& x,
99     ideep::tensor& y,
100     const ideep::tensor& w,
101     const std::optional<ideep::tensor>& b,
102     IntArrayRef padding,
103     IntArrayRef stride,
104     IntArrayRef dilation,
105     IntArrayRef output_sizes,
106     int64_t groups,
107     const ideep::attr_t& attr = ideep::attr_t()) {
108   if (b.has_value()) {
109     ideep::convolution_forward::compute_v2(
110         x,
111         w,
112         b.value(),
113         {output_sizes.cbegin(), output_sizes.cend()},
114         y,
115         {stride.begin(), stride.end()},
116         {dilation.begin(), dilation.end()},
117         {padding.begin(), padding.end()},
118         {padding.begin(), padding.end()},
119         groups,
120         ideep::scale_t(),
121         ideep::scale_t(),
122         ideep::scale_t(),
123         ideep::zero_point_t(),
124         ideep::zero_point_t(),
125         attr);
126   } else {
127     ideep::convolution_forward::compute_v2(
128         x,
129         w,
130         {output_sizes.cbegin(), output_sizes.cend()},
131         y,
132         {stride.begin(), stride.end()},
133         {dilation.begin(), dilation.end()},
134         {padding.begin(), padding.end()},
135         {padding.begin(), padding.end()},
136         groups,
137         ideep::scale_t(),
138         ideep::scale_t(),
139         ideep::scale_t(),
140         ideep::zero_point_t(),
141         ideep::zero_point_t(),
142         attr);
143   }
144 }
145 
mkldnn_convolution_out(const Tensor & input,ideep::tensor & mkldnn_output,const ideep::tensor & mkldnn_weight,const std::optional<Tensor> & bias_opt,IntArrayRef padding,IntArrayRef stride,IntArrayRef dilation,IntArrayRef output_sizes,int64_t groups,const ideep::attr_t & attr=ideep::attr_t ())146 static void mkldnn_convolution_out(
147     const Tensor& input,
148     ideep::tensor& mkldnn_output,
149     const ideep::tensor& mkldnn_weight,
150     const std::optional<Tensor>& bias_opt,
151     IntArrayRef padding,
152     IntArrayRef stride,
153     IntArrayRef dilation,
154     IntArrayRef output_sizes,
155     int64_t groups,
156     const ideep::attr_t& attr = ideep::attr_t()) {
157   c10::MaybeOwned<Tensor> bias_maybe_owned =
158       at::borrow_from_optional_tensor(bias_opt);
159   const Tensor& bias = *bias_maybe_owned;
160 
161   c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset);
162   const ideep::tensor mkldnn_input = itensor_from_tensor(input);
163   std::optional<ideep::tensor> mkldnn_bias{std::nullopt};
164   if (bias.defined()) {
165     mkldnn_bias = itensor_from_tensor(bias);
166   }
167 
168   _mkldnn_convolution_out(
169       mkldnn_input,
170       mkldnn_output,
171       mkldnn_weight,
172       mkldnn_bias,
173       padding,
174       stride,
175       dilation,
176       output_sizes,
177       groups,
178       attr);
179 }
180 
get_output_sizes(ContextConv & context,const Tensor & input)181 static std::vector<int64_t> get_output_sizes(
182     ContextConv& context,
183     const Tensor& input) {
184   const ideep::tensor& mkldnn_weight = context.weight_packed_;
185   IntArrayRef padding = context.padding_;
186   IntArrayRef stride = context.stride_;
187   IntArrayRef dilation = context.dilation_;
188 
189   auto kernel_size = mkldnn_weight.get_dims();
190 
191   std::vector<int64_t> input_size = input.sizes().vec();
192   return conv_output_size(input_size, kernel_size, padding, stride, dilation);
193 }
194 
run(ContextConv & context,const Tensor & input)195 Tensor run(ContextConv& context, const Tensor& input) {
196   std::vector<int64_t> output_sizes = get_output_sizes(context, input);
197   auto output = at::empty(
198       output_sizes,
199       input.options().memory_format(input.suggest_memory_format()));
200 
201   bool is_channels_last =
202       input.suggest_memory_format() == at::MemoryFormat::ChannelsLast;
203   ideep::tensor y;
204 
205   c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset);
206   ideep::tensor mkldnn_output = itensor_from_tensor(output);
207 
208   if (is_channels_last) {
209     mkldnn_convolution_out(
210         input,
211         mkldnn_output,
212         context.weight_packed_,
213         context.at_bias_,
214         context.padding_,
215         context.stride_,
216         context.dilation_,
217         output_sizes,
218         context.groups_,
219         context.attr_);
220   } else {
221     mkldnn_convolution_out(
222         input,
223         y,
224         context.weight_packed_,
225         context.at_bias_,
226         context.padding_,
227         context.stride_,
228         context.dilation_,
229         output_sizes,
230         context.groups_,
231         context.attr_);
232     mkldnn_output.feed_from(y);
233   }
234   return output;
235 }
236 
run(ContextConv & context,const Tensor & input,void * output)237 void run(ContextConv& context, const Tensor& input, void* output) {
238   std::vector<int64_t> output_sizes = get_output_sizes(context, input);
239 
240   bool is_channels_last =
241       input.suggest_memory_format() == at::MemoryFormat::ChannelsLast;
242   ideep::tensor y;
243 
244   ideep::tag o_tag = is_channels_last ? ideep::tag::nhwc : ideep::tag::nchw;
245   ideep::tensor::desc o_desc = {
246       output_sizes, get_mkldnn_dtype(input.scalar_type()), o_tag};
247   ideep::tensor mkldnn_output = {o_desc, output};
248 
249   if (is_channels_last) {
250     mkldnn_convolution_out(
251         input,
252         mkldnn_output,
253         context.weight_packed_,
254         context.at_bias_,
255         context.padding_,
256         context.stride_,
257         context.dilation_,
258         output_sizes,
259         context.groups_,
260         context.attr_);
261   } else {
262     mkldnn_convolution_out(
263         input,
264         y,
265         context.weight_packed_,
266         context.at_bias_,
267         context.padding_,
268         context.stride_,
269         context.dilation_,
270         output_sizes,
271         context.groups_,
272         context.attr_);
273     mkldnn_output.feed_from(y);
274   }
275 }
276 
// Registered-op entry point for the prepacked convolution: forwards `input`
// to the op context, which performs the actual oneDNN convolution via run().
Tensor conv_run(
    const Tensor& input,
    const c10::intrusive_ptr<mkldnn::ConvOpContext>& op_context) {
  return op_context->run(input);
}
282 
283 } // namespace convolution
284 } // namespace internal
285 } // namespace mkldnn
286 } // namespace native
287 } // namespace at
288 
289 #endif // AT_MKLDNN_ENABLED()
290