#include <vector>

#include <ATen/native/ConvUtils.h>
#include <ATen/native/mkldnn/Common.h>
#include <ATen/native/mkldnn/ConvPrepack.h>
#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <ATen/native/mkldnn/OpContext.h>
#include <ATen/native/utils/Factory.h>
#include <ATen/native/utils/ParamUtils.h>
#include <c10/util/irange.h>

#if AT_MKLDNN_ENABLED()

namespace at {
namespace native {
namespace mkldnn {
namespace internal {
namespace convolution {

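// Create a prepacked convolution op context: look up the requested fusion
// attribute in fusion_attr_map (unknown attributes are a hard error) and move
// the weight, bias, and convolution parameters into an MkldnnConvOpContext.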
c10::intrusive_ptr<mkldnn::ConvOpContext> createConvPrePackOpContext(
    Tensor weight,
    std::optional<Tensor> bias,
    std::vector<int64_t> stride,
    std::vector<int64_t> padding,
    std::vector<int64_t> dilation,
    int64_t groups,
    std::vector<int64_t> input_size,
    std::string attr) {
  auto it = fusion_attr_map.find(attr);
  TORCH_CHECK(it != fusion_attr_map.end(), "Fusion behavior undefined.");
  ideep::attr_t op_attr = it->second;

  return mkldnn::MkldnnConvOpContext::create_context(
      std::move(weight),
      std::move(bias),
      std::move(padding),
      std::move(stride),
      std::move(dilation),
      groups,
      std::move(input_size),
      op_attr);
}

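// Prepack the weight for the given convolution configuration: expand the
// parameters to the weight's spatial dimensionality, query the weight
// descriptor that oneDNN expects for this shape/dtype/attr, reorder the dense
// weight into that layout, and bundle everything into a ContextConv.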
ContextConv create(
    const Tensor& weight,
    const std::optional<Tensor>& bias,
    const IntArrayRef padding,
    const IntArrayRef stride,
    const IntArrayRef dilation,
    const int64_t groups,
    const IntArrayRef input_size,
    const ideep::attr_t& attr) {
  auto k = weight.ndimension();
  int64_t dim = k - 2;
  const auto padding_expanded = expand_param_if_needed(padding, "padding", dim);
  const auto stride_expanded = expand_param_if_needed(stride, "stride", dim);
  const auto dilation_expanded =
      expand_param_if_needed(dilation, "dilation", dim);
  const auto input_size_expanded =
      expand_param_if_needed(input_size, "input_size", k);

  c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset);
  auto w = itensor_view_from_dense(weight);
  // TODO: what if input is nhwc but w is nchw
  bool is_channels_last =
      weight.suggest_memory_format() == at::MemoryFormat::ChannelsLast;
  ideep::tensor::desc expected_weight_desc =
      ideep::convolution_forward::expected_weights_desc(
          w.get_dims(),
          w.get_data_type(),
          {stride_expanded.begin(), stride_expanded.end()},
          {padding_expanded.begin(), padding_expanded.end()},
          {padding_expanded.begin(), padding_expanded.end()},
          {dilation_expanded.begin(), dilation_expanded.end()},
          groups,
          ideep::algorithm::convolution_direct,
          ideep::prop_kind::forward,
          /*x_dtype*/ w.get_data_type(),
          {input_size_expanded.begin(), input_size_expanded.end()},
          attr,
          is_channels_last);

  ideep::tensor packed_weight;
  packed_weight.init(expected_weight_desc);
  packed_weight.feed_from(w);

  return ContextConv{
      std::move(packed_weight),
      bias.has_value() ? std::make_optional(*bias) : std::nullopt,
      {padding_expanded.begin(), padding_expanded.end()},
      {stride_expanded.begin(), stride_expanded.end()},
      {dilation_expanded.begin(), dilation_expanded.end()},
      groups,
      attr};
}

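// Thin wrapper over ideep::convolution_forward::compute_v2 that dispatches on
// whether a bias is present. Scales and zero points are passed as empty
// defaults, i.e. no quantization parameters.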
static void _mkldnn_convolution_out(
    const ideep::tensor& x,
    ideep::tensor& y,
    const ideep::tensor& w,
    const std::optional<ideep::tensor>& b,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    IntArrayRef output_sizes,
    int64_t groups,
    const ideep::attr_t& attr = ideep::attr_t()) {
  if (b.has_value()) {
    ideep::convolution_forward::compute_v2(
        x,
        w,
        b.value(),
        {output_sizes.cbegin(), output_sizes.cend()},
        y,
        {stride.begin(), stride.end()},
        {dilation.begin(), dilation.end()},
        {padding.begin(), padding.end()},
        {padding.begin(), padding.end()},
        groups,
        ideep::scale_t(),
        ideep::scale_t(),
        ideep::scale_t(),
        ideep::zero_point_t(),
        ideep::zero_point_t(),
        attr);
  } else {
    ideep::convolution_forward::compute_v2(
        x,
        w,
        {output_sizes.cbegin(), output_sizes.cend()},
        y,
        {stride.begin(), stride.end()},
        {dilation.begin(), dilation.end()},
        {padding.begin(), padding.end()},
        {padding.begin(), padding.end()},
        groups,
        ideep::scale_t(),
        ideep::scale_t(),
        ideep::scale_t(),
        ideep::zero_point_t(),
        ideep::zero_point_t(),
        attr);
  }
}

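// Bridge from ATen to ideep: view the input (and optional bias) as ideep
// tensors, with autograd dispatch keys excluded, and run the convolution into
// the caller-provided mkldnn_output.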
static void mkldnn_convolution_out(
    const Tensor& input,
    ideep::tensor& mkldnn_output,
    const ideep::tensor& mkldnn_weight,
    const std::optional<Tensor>& bias_opt,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    IntArrayRef output_sizes,
    int64_t groups,
    const ideep::attr_t& attr = ideep::attr_t()) {
  c10::MaybeOwned<Tensor> bias_maybe_owned =
      at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset);
  const ideep::tensor mkldnn_input = itensor_from_tensor(input);
  std::optional<ideep::tensor> mkldnn_bias{std::nullopt};
  if (bias.defined()) {
    mkldnn_bias = itensor_from_tensor(bias);
  }

  _mkldnn_convolution_out(
      mkldnn_input,
      mkldnn_output,
      mkldnn_weight,
      mkldnn_bias,
      padding,
      stride,
      dilation,
      output_sizes,
      groups,
      attr);
}

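// Compute the convolution output shape from the input sizes, the packed
// weight dims (kernel size), and the stored padding/stride/dilation.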
static std::vector<int64_t> get_output_sizes(
    ContextConv& context,
    const Tensor& input) {
  const ideep::tensor& mkldnn_weight = context.weight_packed_;
  IntArrayRef padding = context.padding_;
  IntArrayRef stride = context.stride_;
  IntArrayRef dilation = context.dilation_;

  auto kernel_size = mkldnn_weight.get_dims();

  std::vector<int64_t> input_size = input.sizes().vec();
  return conv_output_size(input_size, kernel_size, padding, stride, dilation);
}

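// Run the prepacked convolution, allocating the output tensor. For a
// channels-last input the result is written directly through an ideep view of
// the output; otherwise it is computed into a scratch ideep tensor y
// (presumably left in oneDNN's preferred blocked layout) and reordered back
// into the output via feed_from.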
Tensor run(ContextConv& context, const Tensor& input) {
  std::vector<int64_t> output_sizes = get_output_sizes(context, input);
  auto output = at::empty(
      output_sizes,
      input.options().memory_format(input.suggest_memory_format()));

  bool is_channels_last =
      input.suggest_memory_format() == at::MemoryFormat::ChannelsLast;
  ideep::tensor y;

  c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset);
  ideep::tensor mkldnn_output = itensor_from_tensor(output);

  if (is_channels_last) {
    mkldnn_convolution_out(
        input,
        mkldnn_output,
        context.weight_packed_,
        context.at_bias_,
        context.padding_,
        context.stride_,
        context.dilation_,
        output_sizes,
        context.groups_,
        context.attr_);
  } else {
    mkldnn_convolution_out(
        input,
        y,
        context.weight_packed_,
        context.at_bias_,
        context.padding_,
        context.stride_,
        context.dilation_,
        output_sizes,
        context.groups_,
        context.attr_);
    mkldnn_output.feed_from(y);
  }
  return output;
}

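// Variant of run() that writes into caller-provided raw memory: an ideep
// tensor is constructed over the output pointer with an nhwc/nchw descriptor
// chosen to match the input's suggested memory format.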
void run(ContextConv& context, const Tensor& input, void* output) {
  std::vector<int64_t> output_sizes = get_output_sizes(context, input);

  bool is_channels_last =
      input.suggest_memory_format() == at::MemoryFormat::ChannelsLast;
  ideep::tensor y;

  ideep::tag o_tag = is_channels_last ? ideep::tag::nhwc : ideep::tag::nchw;
  ideep::tensor::desc o_desc = {
      output_sizes, get_mkldnn_dtype(input.scalar_type()), o_tag};
  ideep::tensor mkldnn_output = {o_desc, output};

  if (is_channels_last) {
    mkldnn_convolution_out(
        input,
        mkldnn_output,
        context.weight_packed_,
        context.at_bias_,
        context.padding_,
        context.stride_,
        context.dilation_,
        output_sizes,
        context.groups_,
        context.attr_);
  } else {
    mkldnn_convolution_out(
        input,
        y,
        context.weight_packed_,
        context.at_bias_,
        context.padding_,
        context.stride_,
        context.dilation_,
        output_sizes,
        context.groups_,
        context.attr_);
    mkldnn_output.feed_from(y);
  }
}

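// Entry point that takes the prepacked op context and forwards the input to
// its run() method.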
Tensor conv_run(
    const Tensor& input,
    const c10::intrusive_ptr<mkldnn::ConvOpContext>& op_context) {
  return op_context->run(input);
}

} // namespace convolution
} // namespace internal
} // namespace mkldnn
} // namespace native
} // namespace at

#endif // AT_MKLDNN_ENABLED()