#include <ATen/Config.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/operators/conv2d.h>
#include <torch/csrc/jit/tensorexpr/operators/misc.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>

namespace torch::jit::tensorexpr {

namespace {

// All dimensions must be compile-time constants for the static lowering.
void assert_dims_constant(const BufHandle& buf) {
  for (auto const& dim : buf.node()->dims()) {
    TORCH_INTERNAL_ASSERT(dim->isConstant());
  }
}

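// An InitFunc computes the initial value of each output element before the
// reduction accumulates into it: either a bias load or the sum identity.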
using InitFunc = std::function<ExprHandle(const std::vector<VarHandle>&)>;

Tensor conv2d_depthwise_static(
    const BufHandle& input,
    const BufHandle& weight,
    const InitFunc& init_func,
    int stride,
    int pad,
    int groups) {
  TORCH_INTERNAL_ASSERT(input.ndim() == 4);
  TORCH_INTERNAL_ASSERT(weight.ndim() == 4);

  assert_dims_constant(input);
  assert_dims_constant(weight);

  auto const& N = immediateAs<int>(input.dim(0));
  auto const& C = immediateAs<int>(input.dim(1));
  auto const& H = immediateAs<int>(input.dim(2));
  auto const& W = immediateAs<int>(input.dim(3));

  auto const& K = immediateAs<int>(weight.dim(0));
  auto const& CperG = immediateAs<int>(weight.dim(1));
  auto const& R = immediateAs<int>(weight.dim(2));
  auto const& S = immediateAs<int>(weight.dim(3));

  // Depthwise: one square filter per input channel.
  TORCH_INTERNAL_ASSERT(C == K && K == groups && CperG == 1);
  TORCH_INTERNAL_ASSERT(R == S);

  // Standard convolution output size (dilation == 1).
  auto OH = (H - R + 2 * pad) / stride + 1;
  auto OW = (W - S + 2 * pad) / stride + 1;

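  // Reduce over (c, r, s) beneath the four output axes (n, k, oh, ow). The
  // chained CompareSelects evaluate to 1 whenever an input tap falls into the
  // zero-padded border, in which case ifThenElse substitutes 0.f for the
  // out-of-bounds load.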
  Tensor conv = Reduce(
      "conv2d_depthwise",
      {N, K, OH, OW},
      std::nullopt, // TODO
      Sum(),
      [&](const std::vector<VarHandle>& v) { return init_func(v); },
      [&](const std::vector<VarHandle>& v) {
        auto const& n = v[0];
        auto const& k = v[1];
        auto const& oh = v[2];
        auto const& ow = v[3];
        auto const& c = v[4];
        auto const& r = v[5];
        auto const& s = v[6];
        auto cond = CompareSelect::make(oh * stride - pad + r, 0, 1, 0, kLT);
        cond = CompareSelect::make(ow * stride - pad + s, 0, 1, cond, kLT);
        cond = CompareSelect::make(oh * stride - pad + r, H, 1, cond, kGE);
        cond = CompareSelect::make(ow * stride - pad + s, W, 1, cond, kGE);
        auto in = ifThenElse(
            cond,
            0.f,
            input.load(n, k, oh * stride - pad + r, ow * stride - pad + s));
        return in * weight.load(k, c, r, s);
      },
      {C / groups, R, S});

  LoopNest nest({conv});

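  // Peel the output rows/columns that can touch the padded border, so the
  // main loop body no longer needs the boundary conditionals above; later
  // simplification can then drop the checks from the hot inner region.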
  constexpr int kLoopH = 2, kLoopW = 3;
  if (R == 3 && stride == 2 && pad == 1) {
    ForPtr head, tail;
    auto loops = nest.getLoopStmtsFor(conv);
    nest.sliceHead(loops[kLoopW], 2, &head, &tail);
    loops = nest.getLoopStmtsFor(conv);
    nest.sliceHead(loops[kLoopH], 2, &head, &tail);
  } else if (R == 3 && stride == 1 && pad == 1) {
    ForPtr main, peeled;
    auto loops = nest.getAllLoopNestsWritingToBuf(conv.buf());
    main = loops[1][kLoopW];
    nest.sliceHead(main, 1, &peeled, &main);
    nest.sliceTail(main, 1, &main, &peeled);
    main = LoopNest::getParentLoop(main);
    nest.sliceHead(main, 1, &peeled, &main);
    nest.sliceTail(main, 1, &main, &peeled);
  }

  return Tensor(conv.buf(), nest.root_stmt());
}

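// Same lowering as conv2d_depthwise_static, but with symbolic shapes; no
// loop-nest scheduling is applied here since the sizes are not known at
// compile time.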
Tensor conv2d_depthwise_dynamic(
    BufHandle input,
    BufHandle weight,
    const InitFunc& init_func,
    ExprHandle N,
    ExprHandle C,
    ExprHandle H,
    ExprHandle W,
    ExprHandle K,
    ExprHandle CperG,
    ExprHandle R,
    ExprHandle S,
    ExprHandle stride,
    ExprHandle pad,
    ExprHandle groups) {
  TORCH_INTERNAL_ASSERT(input.ndim() == 4);
  TORCH_INTERNAL_ASSERT(weight.ndim() == 4);

  auto OH = (H - R + pad * 2) / stride + 1;
  auto OW = (W - S + pad * 2) / stride + 1;

  return Reduce(
      "conv2d_depthwise",
      {N, K, OH, OW},
      std::nullopt, // TODO
      Sum(),
      [&](const std::vector<VarHandle>& v) { return init_func(v); },
      [&](const std::vector<VarHandle>& v) {
        auto const& n = v[0];
        auto const& k = v[1];
        auto const& oh = v[2];
        auto const& ow = v[3];
        auto const& c = v[4];
        auto const& r = v[5];
        auto const& s = v[6];
        auto cond = CompareSelect::make(oh * stride - pad + r, 0, 1, 0, kLT);
        cond = CompareSelect::make(ow * stride - pad + s, 0, 1, cond, kLT);
        cond = CompareSelect::make(oh * stride - pad + r, H, 1, cond, kGE);
        cond = CompareSelect::make(ow * stride - pad + s, W, 1, cond, kGE);
        auto in = ifThenElse(
            cond,
            0.f,
            input.load(n, k, oh * stride - pad + r, ow * stride - pad + s));
        return in * weight.load(k, c, r, s);
      },
      {C / groups, R, S});
}

} // namespace

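// Public entry points. The four conv2d_depthwise overloads below differ only
// in whether a bias buffer is provided and whether shapes are static (ints)
// or dynamic (ExprHandles).
//
// A minimal usage sketch (hypothetical buffer names, static float32 shapes):
//
//   BufHandle input("input", {1, 72, 56, 56}, kFloat);
//   BufHandle weight("weight", {72, 1, 3, 3}, kFloat);
//   Tensor out = conv2d_depthwise(
//       input, weight, /*stride=*/2, /*pad=*/1, /*groups=*/72);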
Tensor conv2d_depthwise(
    BufHandle input,
    BufHandle weight,
    BufHandle bias,
    int stride,
    int pad,
    int groups) {
  assert_dims_constant(bias);
  auto init_func = [&](const std::vector<VarHandle>& v) {
    return bias.load(v[1]);
  };
  return conv2d_depthwise_static(input, weight, init_func, stride, pad, groups);
}

Tensor conv2d_depthwise(
    BufHandle input,
    BufHandle weight,
    int stride,
    int pad,
    int groups) {
  auto init_func = [](const std::vector<VarHandle>& v) {
    return ExprHandle(Sum().initializer());
  };
  return conv2d_depthwise_static(input, weight, init_func, stride, pad, groups);
}

Tensor conv2d_depthwise(
    BufHandle input,
    BufHandle weight,
    BufHandle bias,
    ExprHandle N,
    ExprHandle C,
    ExprHandle H,
    ExprHandle W,
    ExprHandle K,
    ExprHandle CperG,
    ExprHandle R,
    ExprHandle S,
    ExprHandle stride,
    ExprHandle pad,
    ExprHandle groups) {
  assert_dims_constant(bias);
  auto init_func = [&](const std::vector<VarHandle>& v) {
    return bias.load(v[1]);
  };
  return conv2d_depthwise_dynamic(
      input,
      weight,
      init_func,
      N,
      C,
      H,
      W,
      K,
      CperG,
      R,
      S,
      stride,
      pad,
      groups);
}

Tensor conv2d_depthwise(
    BufHandle input,
    BufHandle weight,
    ExprHandle N,
    ExprHandle C,
    ExprHandle H,
    ExprHandle W,
    ExprHandle K,
    ExprHandle CperG,
    ExprHandle R,
    ExprHandle S,
    ExprHandle stride,
    ExprHandle pad,
    ExprHandle groups) {
  auto init_func = [](const std::vector<VarHandle>& v) {
    return ExprHandle(Sum().initializer());
  };
  return conv2d_depthwise_dynamic(
      input,
      weight,
      init_func,
      N,
      C,
      H,
      W,
      K,
      CperG,
      R,
      S,
      stride,
      pad,
      groups);
}

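// Normalizes an ArgValue that may hold either a single int or an IntList into
// an {h, w} pair, e.g. stride=2 -> {2, 2}.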
static std::vector<int64_t> _pair_int(ArgValue v) {
  if (auto t = std::get_if<IntList>(&v)) {
    return {(*t)[0], (*t)[1]};
  }
  auto i = std::get<int64_t>(v);
  return {i, i};
}

// Same normalization for 1-d ops: a single-element list.
static std::vector<int64_t> _single_int_list(ArgValue v) {
  if (auto t = std::get_if<IntList>(&v)) {
    return {(*t)[0]};
  }
  auto i = std::get<int64_t>(v);
  return {i};
}

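// The TE depthwise path is only taken for float32, 4-d, genuinely depthwise
// (Cin == Cout == groups, one channel per group) 3x3 convolutions with square
// stride/pad and no dilation; anything else falls back to an external call.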
bool conv2dIsSupported(
    const TensorInfo& input,
    const TensorInfo& weight,
    const TensorInfo& bias,
    const std::vector<int64_t>& stride,
    const std::vector<int64_t>& pad,
    const std::vector<int64_t>& dilation,
    int64_t groups) {
  if (input.dtype != c10::ScalarType::Float ||
      weight.dtype != c10::ScalarType::Float ||
      bias.dtype != c10::ScalarType::Float) {
    GRAPH_DEBUG("conv2dIsSupported: only float32 allowed");
    return false;
  }
  if (input.dims.size() != 4 || weight.dims.size() != 4 ||
      bias.dims.size() != 1) {
    GRAPH_DEBUG("conv2dIsSupported: inputs are the wrong size");
    return false;
  }
  auto Cin = input.dims[1];
  auto Cout = weight.dims[0];
  auto CperG = weight.dims[1];
  if (Cin != Cout || Cin != groups || CperG != 1) {
    GRAPH_DEBUG("conv2dIsSupported: not depthwise");
    return false;
  }
  auto KH = weight.dims[2];
  auto KW = weight.dims[3];
  if (KH != 3 || KW != 3) {
    GRAPH_DEBUG("conv2dIsSupported: not 3x3");
    return false;
  }
  if (stride.size() != 2 || stride[0] != stride[1]) {
    GRAPH_DEBUG("conv2dIsSupported: unsupported stride");
    return false;
  }
  if (pad.size() != 2 || pad[0] != pad[1]) {
    GRAPH_DEBUG("conv2dIsSupported: unsupported pad");
    return false;
  }
  if (dilation.size() != 2 || dilation[0] != 1 || dilation[1] != 1) {
    GRAPH_DEBUG("conv2dIsSupported: unsupported dilation");
    return false;
  }
  return true;
}

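// Guarded on AT_MKLDNN_ENABLED(): without MKL-DNN support this always
// returns false.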
bool mkldnnPrepackedConvIsSupported(
    const TensorInfo& input,
    const TensorInfo& weight,
    const std::vector<int64_t>& stride,
    const std::vector<int64_t>& pad,
    const std::vector<int64_t>& dilation,
    int64_t groups) {
#if AT_MKLDNN_ENABLED()
  if (input.dtype != c10::ScalarType::Float ||
      weight.dtype != c10::ScalarType::Float) {
    GRAPH_DEBUG("mkldnnPrepackedConvIsSupported: only float32 allowed");
    return false;
  }
  if (input.dims.size() != 4 || weight.dims.size() != 4) {
    GRAPH_DEBUG("mkldnnPrepackedConvIsSupported: inputs are the wrong size");
    return false;
  }
  if (stride.size() != 2) {
    GRAPH_DEBUG("mkldnnPrepackedConvIsSupported: unsupported stride");
    return false;
  }
  if (pad.size() != 2) {
    GRAPH_DEBUG("mkldnnPrepackedConvIsSupported: unsupported pad");
    return false;
  }
  if (dilation.size() != 2) {
    GRAPH_DEBUG("mkldnnPrepackedConvIsSupported: unsupported dilation");
    return false;
  }

  // Do not rewrite for cases where native is faster than mkldnn.
  // Conditions are from: aten/src/ATen/native/Convolution.cpp:use_mkldnn
  bool use_mkldnn = groups > 1 || (weight.dims[2] > 3 && weight.dims[3] > 3) ||
      input.dims[0] > 1 ||
      input.dims[0] * input.dims[1] * input.dims[2] * input.dims[3] > 20480;
  GRAPH_DEBUG("mkldnnPrepackedConvIsSupported: ", use_mkldnn);
  return use_mkldnn;
#endif
  return false;
}

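// Lowers aten::conv2d: depthwise 3x3 convolutions get the native TE lowering
// above; every other configuration is emitted as an external call to
// nnc_aten_conv2d.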
Tensor computeConv2d(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  Dtype dtype = kFloat;
  if (outputType) {
    dtype = Dtype(*outputType);
  }

  BufHandle ResultBuf("conv", outputShape, dtype);
  const BufHandle& inp = std::get<BufHandle>(inputs[0]);
  const BufHandle& w = std::get<BufHandle>(inputs[1]);
  const BufHandle& b = std::get<BufHandle>(inputs[2]);

  auto strides = _pair_int(inputs[3]);
  auto padding = _pair_int(inputs[4]);
  auto dilation = _pair_int(inputs[5]);

  int groups = std::get<int64_t>(inputs[6]);

  auto inpInfo = getTensorInfo(inp);
  auto wInfo = getTensorInfo(w);
  auto bInfo = getTensorInfo(b);
  // Generate TE for depthwise convolutions.
  if (inpInfo && wInfo && bInfo &&
      conv2dIsSupported(
          *inpInfo, *wInfo, *bInfo, strides, padding, dilation, groups)) {
    return conv2d_depthwise(inp, w, b, strides[0], padding[0], groups);
  }

  // Once we have a performant TE representation for conv2d, we could use it
  // here instead of the external call!
  StmtPtr s = ExternalCall::make(
      ResultBuf,
      "nnc_aten_conv2d",
      {inp, w, b},
      {strides[0],
       strides[1],
       padding[0],
       padding[1],
       dilation[0],
       dilation[1],
       groups});
  return Tensor(ResultBuf.node(), s);
}

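// Lowers aten::conv1d. Unlike conv2d there is no native TE fast path here,
// so this always emits an external call to nnc_aten_conv1d.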
Tensor computeConv1d(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  Dtype dtype = kFloat;
  if (outputType) {
    dtype = Dtype(*outputType);
  }

  BufHandle ResultBuf("conv", outputShape, dtype);
  const BufHandle& inp = std::get<BufHandle>(inputs[0]);
  const BufHandle& w = std::get<BufHandle>(inputs[1]);
  const BufHandle& b = std::get<BufHandle>(inputs[2]);

  auto strides = _single_int_list(inputs[3]);
  auto padding = _single_int_list(inputs[4]);
  auto dilation = _single_int_list(inputs[5]);

  int groups = std::get<int64_t>(inputs[6]);

  // NB: the tensor info is fetched but currently unused, since conv1d always
  // lowers to the external call below.
  auto inpInfo = getTensorInfo(inp);
  auto wInfo = getTensorInfo(w);
  auto bInfo = getTensorInfo(b);

  StmtPtr s = ExternalCall::make(
      ResultBuf,
      "nnc_aten_conv1d",
      {inp, w, b},
      {strides[0], padding[0], dilation[0], groups});
  return Tensor(ResultBuf.node(), s);
}

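// The prepacked-op lowerings below do no TE computation themselves; each
// simply forwards the input and the prepacked op context to the matching
// external call.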
Tensor computePrepackedConv2dClampRun(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  Dtype dtype = kFloat;
  if (outputType) {
    dtype = Dtype(*outputType);
  }

  BufHandle ResultBuf("prepacked_conv2d_clamp_run", outputShape, dtype);
  const BufHandle& inp = std::get<BufHandle>(inputs[0]);
  const BufHandle& prepacked = std::get<BufHandle>(inputs[1]);
  StmtPtr s = ExternalCall::make(
      ResultBuf, "nnc_prepacked_conv2d_clamp_run", {inp, prepacked}, {});
  return Tensor(ResultBuf.node(), s);
}

Tensor computePrepackedLinearClampRun(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  Dtype dtype = kFloat;
  if (outputType) {
    dtype = Dtype(*outputType);
  }

  BufHandle ResultBuf("prepacked_linear_clamp_run", outputShape, dtype);
  const BufHandle& inp = std::get<BufHandle>(inputs[0]);
  const BufHandle& prepacked = std::get<BufHandle>(inputs[1]);
  StmtPtr s = ExternalCall::make(
      ResultBuf, "nnc_prepacked_linear_clamp_run", {inp, prepacked}, {});
  return Tensor(ResultBuf.node(), s);
}

Tensor computeMkldnnPrepackedConvRun(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  Dtype dtype = kFloat;
  if (outputType) {
    dtype = Dtype(*outputType);
  }

  BufHandle ResultBuf(
      "mkldnn_prepacked_conv_run", outputShape, outputStrides, dtype);
  const BufHandle& inp = std::get<BufHandle>(inputs[0]);
  const BufHandle& prepacked = std::get<BufHandle>(inputs[1]);
  StmtPtr s = ExternalCall::make(
      ResultBuf, "nnc_mkldnn_prepacked_conv_run", {inp, prepacked}, {});
  return Tensor(ResultBuf.node(), s);
}

} // namespace torch::jit::tensorexpr