xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/operators/conv2d.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/Config.h>
2 #include <torch/csrc/jit/jit_log.h>
3 #include <torch/csrc/jit/tensorexpr/loopnest.h>
4 #include <torch/csrc/jit/tensorexpr/operators/conv2d.h>
5 #include <torch/csrc/jit/tensorexpr/operators/misc.h>
6 #include <torch/csrc/jit/tensorexpr/tensor.h>
7 
8 namespace torch::jit::tensorexpr {
9 
10 namespace {
11 
assert_dims_constant(const BufHandle & buf)12 void assert_dims_constant(const BufHandle& buf) {
13   for (auto const& dim : buf.node()->dims()) {
14     TORCH_INTERNAL_ASSERT(dim->isConstant());
15   }
16 }
17 
18 using InitFunc = std::function<ExprHandle(const std::vector<VarHandle>&)>;
19 
// Builds a depthwise conv2d (groups == Cin == Cout, one channel per group)
// over statically-shaped buffers as a TE reduction, then applies
// loop-slicing schedules that peel the padding-affected border iterations
// for the two common 3x3 configurations.
//
// input: NCHW activations; weight: KCRS filter (CperG must be 1).
// init_func supplies the accumulator's initial value for each output index
// (the bias, or the reduction identity when there is no bias).
Tensor conv2d_depthwise_static(
    const BufHandle& input,
    const BufHandle& weight,
    const InitFunc& init_func,
    int stride,
    int pad,
    int groups) {
  TORCH_INTERNAL_ASSERT(input.ndim() == 4);
  TORCH_INTERNAL_ASSERT(weight.ndim() == 4);

  // The schedule below reads extents as plain ints, so every dim must be a
  // compile-time constant.
  assert_dims_constant(input);
  assert_dims_constant(weight);

  auto const& N = immediateAs<int>(input.dim(0));
  auto const& C = immediateAs<int>(input.dim(1));
  auto const& H = immediateAs<int>(input.dim(2));
  auto const& W = immediateAs<int>(input.dim(3));

  auto const& K = immediateAs<int>(weight.dim(0));
  auto const& CperG = immediateAs<int>(weight.dim(1));
  auto const& R = immediateAs<int>(weight.dim(2));
  auto const& S = immediateAs<int>(weight.dim(3));

  // Depthwise only: one filter per input channel, square kernel.
  TORCH_INTERNAL_ASSERT(C == K && K == groups && CperG == 1);
  TORCH_INTERNAL_ASSERT(R == S);

  // Standard conv output extents (dilation == 1).
  auto OH = (H - R + 2 * pad) / stride + 1;
  auto OW = (W - S + 2 * pad) / stride + 1;

  Tensor conv = Reduce(
      "conv2d_depthwise",
      {N, K, OH, OW},
      std::nullopt, // TODO
      Sum(),
      [&](const std::vector<VarHandle>& v) { return init_func(v); },
      [&](const std::vector<VarHandle>& v) {
        auto const& n = v[0];
        auto const& k = v[1];
        auto const& oh = v[2];
        auto const& ow = v[3];
        auto const& c = v[4]; // extent C/groups == 1 here, so c is always 0
        auto const& r = v[5];
        auto const& s = v[6];
        // cond becomes 1 iff the input tap (oh*stride - pad + r,
        // ow*stride - pad + s) falls in the zero-padding halo, i.e. either
        // coordinate is < 0 or >= H/W.
        auto cond = CompareSelect::make(oh * stride - pad + r, 0, 1, 0, kLT);
        cond = CompareSelect::make(ow * stride - pad + s, 0, 1, cond, kLT);
        cond = CompareSelect::make(oh * stride - pad + r, H, 1, cond, kGE);
        cond = CompareSelect::make(ow * stride - pad + s, W, 1, cond, kGE);
        // Out-of-bounds taps contribute 0 (zero padding).
        auto in = ifThenElse(
            cond,
            0.f,
            input.load(n, k, oh * stride - pad + r, ow * stride - pad + s));
        return in * weight.load(k, c, r, s);
      },
      {C / groups, R, S});

  LoopNest nest({conv});

  // Output-axis order of the generated nest: 0=n, 1=k, 2=oh, 3=ow.
  constexpr int kLoopH = 2, kLoopW = 3;
  if (R == 3 && stride == 2 && pad == 1) {
    // 3x3 / stride-2 / pad-1: only the first two output cols/rows can read
    // from the padding, so peel them off the W and H loops, leaving the
    // main loops free of the boundary condition.
    ForPtr head, tail;
    auto loops = nest.getLoopStmtsFor(conv);
    nest.sliceHead(loops[kLoopW], 2, &head, &tail);
    loops = nest.getLoopStmtsFor(conv);
    nest.sliceHead(loops[kLoopH], 2, &head, &tail);
  } else if (R == 3 && stride == 1 && pad == 1) {
    // 3x3 / stride-1 / pad-1: padding touches exactly one iteration at each
    // border, so peel one off both ends of the W loop, then of its parent
    // (H) loop.
    ForPtr main, peeled;
    // NOTE(review): loops[1] is taken to be the reduction loop nest (with
    // loops[0] the init-stores nest) — confirm against Reduce's lowering.
    auto loops = nest.getAllLoopNestsWritingToBuf(conv.buf());
    main = loops[1][kLoopW];
    nest.sliceHead(main, 1, &peeled, &main);
    nest.sliceTail(main, 1, &main, &peeled);
    main = LoopNest::getParentLoop(main);
    nest.sliceHead(main, 1, &peeled, &main);
    nest.sliceTail(main, 1, &main, &peeled);
  }

  return Tensor(conv.buf(), nest.root_stmt());
}
97 
// Dynamic-shape counterpart of conv2d_depthwise_static: builds the same
// depthwise reduction with symbolic extents, but applies no loop-slicing
// schedule (the peeling decisions require constant extents).
Tensor conv2d_depthwise_dynamic(
    BufHandle input,
    BufHandle weight,
    const InitFunc& init_func,
    ExprHandle N,
    ExprHandle C,
    ExprHandle H,
    ExprHandle W,
    ExprHandle K,
    ExprHandle CperG,
    ExprHandle R,
    ExprHandle S,
    ExprHandle stride,
    ExprHandle pad,
    ExprHandle groups) {
  TORCH_INTERNAL_ASSERT(input.ndim() == 4);
  TORCH_INTERNAL_ASSERT(weight.ndim() == 4);

  // Standard conv output extents (dilation == 1).
  auto OH = (H - R + pad * 2) / stride + 1;
  auto OW = (W - S + pad * 2) / stride + 1;

  return Reduce(
      "conv2d_depthwise",
      {N, K, OH, OW},
      std::nullopt, // TODO
      Sum(),
      [&](const std::vector<VarHandle>& v) { return init_func(v); },
      [&](const std::vector<VarHandle>& v) {
        auto const& n = v[0];
        auto const& k = v[1];
        auto const& oh = v[2];
        auto const& ow = v[3];
        auto const& c = v[4];
        auto const& r = v[5];
        auto const& s = v[6];
        // cond becomes 1 iff the input tap lands in the zero-padding halo
        // (either coordinate < 0 or >= H/W).
        auto cond = CompareSelect::make(oh * stride - pad + r, 0, 1, 0, kLT);
        cond = CompareSelect::make(ow * stride - pad + s, 0, 1, cond, kLT);
        cond = CompareSelect::make(oh * stride - pad + r, H, 1, cond, kGE);
        cond = CompareSelect::make(ow * stride - pad + s, W, 1, cond, kGE);
        // Out-of-bounds taps contribute 0 (zero padding).
        auto in = ifThenElse(
            cond,
            0.f,
            input.load(n, k, oh * stride - pad + r, ow * stride - pad + s));
        return in * weight.load(k, c, r, s);
      },
      {C / groups, R, S});
}
145 
146 } // namespace
147 
conv2d_depthwise(BufHandle input,BufHandle weight,BufHandle bias,int stride,int pad,int groups)148 Tensor conv2d_depthwise(
149     BufHandle input,
150     BufHandle weight,
151     BufHandle bias,
152     int stride,
153     int pad,
154     int groups) {
155   assert_dims_constant(bias);
156   auto init_func = [&](const std::vector<VarHandle>& v) {
157     return bias.load(v[1]);
158   };
159   return conv2d_depthwise_static(input, weight, init_func, stride, pad, groups);
160 }
161 
conv2d_depthwise(BufHandle input,BufHandle weight,int stride,int pad,int groups)162 Tensor conv2d_depthwise(
163     BufHandle input,
164     BufHandle weight,
165     int stride,
166     int pad,
167     int groups) {
168   auto init_func = [](const std::vector<VarHandle>& v) {
169     return ExprHandle(Sum().initializer());
170   };
171   return conv2d_depthwise_static(input, weight, init_func, stride, pad, groups);
172 }
173 
conv2d_depthwise(BufHandle input,BufHandle weight,BufHandle bias,ExprHandle N,ExprHandle C,ExprHandle H,ExprHandle W,ExprHandle K,ExprHandle CperG,ExprHandle R,ExprHandle S,ExprHandle stride,ExprHandle pad,ExprHandle groups)174 Tensor conv2d_depthwise(
175     BufHandle input,
176     BufHandle weight,
177     BufHandle bias,
178     ExprHandle N,
179     ExprHandle C,
180     ExprHandle H,
181     ExprHandle W,
182     ExprHandle K,
183     ExprHandle CperG,
184     ExprHandle R,
185     ExprHandle S,
186     ExprHandle stride,
187     ExprHandle pad,
188     ExprHandle groups) {
189   assert_dims_constant(bias);
190   auto init_func = [&](const std::vector<VarHandle>& v) {
191     return bias.load(v[1]);
192   };
193   return conv2d_depthwise_dynamic(
194       input,
195       weight,
196       init_func,
197       N,
198       C,
199       H,
200       W,
201       K,
202       CperG,
203       R,
204       S,
205       stride,
206       pad,
207       groups);
208 }
209 
conv2d_depthwise(BufHandle input,BufHandle weight,ExprHandle N,ExprHandle C,ExprHandle H,ExprHandle W,ExprHandle K,ExprHandle CperG,ExprHandle R,ExprHandle S,ExprHandle stride,ExprHandle pad,ExprHandle groups)210 Tensor conv2d_depthwise(
211     BufHandle input,
212     BufHandle weight,
213     ExprHandle N,
214     ExprHandle C,
215     ExprHandle H,
216     ExprHandle W,
217     ExprHandle K,
218     ExprHandle CperG,
219     ExprHandle R,
220     ExprHandle S,
221     ExprHandle stride,
222     ExprHandle pad,
223     ExprHandle groups) {
224   auto init_func = [](const std::vector<VarHandle>& v) {
225     return ExprHandle(Sum().initializer());
226   };
227   return conv2d_depthwise_dynamic(
228       input,
229       weight,
230       init_func,
231       N,
232       C,
233       H,
234       W,
235       K,
236       CperG,
237       R,
238       S,
239       stride,
240       pad,
241       groups);
242 }
243 
_pair_int(ArgValue v)244 static std::vector<int64_t> _pair_int(ArgValue v) {
245   if (auto t = std::get_if<IntList>(&v)) {
246     return {(*t)[0], (*t)[1]};
247   }
248   auto i = std::get<int64_t>(v);
249   return {i, i};
250 }
251 
_single_int_list(ArgValue v)252 static std::vector<int64_t> _single_int_list(ArgValue v) {
253   if (auto t = std::get_if<IntList>(&v)) {
254     return {(*t)[0]};
255   }
256   auto i = std::get<int64_t>(v);
257   return {i};
258 }
259 
conv2dIsSupported(const TensorInfo & input,const TensorInfo & weight,const TensorInfo & bias,const std::vector<int64_t> & stride,const std::vector<int64_t> & pad,const std::vector<int64_t> & dilation,int64_t groups)260 bool conv2dIsSupported(
261     const TensorInfo& input,
262     const TensorInfo& weight,
263     const TensorInfo& bias,
264     const std::vector<int64_t>& stride,
265     const std::vector<int64_t>& pad,
266     const std::vector<int64_t>& dilation,
267     int64_t groups) {
268   if (input.dtype != c10::ScalarType::Float ||
269       weight.dtype != c10::ScalarType::Float ||
270       bias.dtype != c10::ScalarType::Float) {
271     GRAPH_DEBUG("conv2dIsSupported: only float32 allowed");
272     return false;
273   }
274   if (input.dims.size() != 4 || weight.dims.size() != 4 ||
275       bias.dims.size() != 1) {
276     GRAPH_DEBUG("conv2dIsSupported: inputs are the wrong size");
277     return false;
278   }
279   auto Cin = input.dims[1];
280   auto Cout = weight.dims[0];
281   auto CperG = weight.dims[1];
282   if (Cin != Cout || Cin != groups || CperG != 1) {
283     GRAPH_DEBUG("conv2dIsSupported: not depthwise");
284     return false;
285   }
286   auto KH = weight.dims[2];
287   auto KW = weight.dims[3];
288   if (KH != 3 || KW != 3) {
289     GRAPH_DEBUG("conv2dIsSupported: not 3x3");
290     return false;
291   }
292   if (stride.size() != 2 || stride[0] != stride[1]) {
293     GRAPH_DEBUG("conv2dIsSupported: unsupported stride");
294     return false;
295   }
296   if (pad.size() != 2 || pad[0] != pad[1]) {
297     GRAPH_DEBUG("conv2dIsSupported: unsupported pad");
298     return false;
299   }
300   if (dilation.size() != 2 || dilation[0] != 1 || dilation[1] != 1) {
301     GRAPH_DEBUG("conv2dIsSupported: unsupported dilation");
302     return false;
303   }
304   return true;
305 }
306 
// Returns true iff a conv should be rewritten to use the MKLDNN prepacked
// kernel: float32 4-D input/weight, 2-element stride/pad/dilation, and a
// size/shape profile where MKLDNN beats the native kernel. Always false
// when PyTorch is built without MKLDNN.
bool mkldnnPrepackedConvIsSupported(
    const TensorInfo& input,
    const TensorInfo& weight,
    const std::vector<int64_t>& stride,
    const std::vector<int64_t>& pad,
    const std::vector<int64_t>& dilation,
    int64_t groups) {
#if AT_MKLDNN_ENABLED()
  if (input.dtype != c10::ScalarType::Float ||
      weight.dtype != c10::ScalarType::Float) {
    GRAPH_DEBUG("mkldnnPrepackedConvIsSupported: only float32 allowed");
    return false;
  }
  if (input.dims.size() != 4 || weight.dims.size() != 4) {
    GRAPH_DEBUG("mkldnnPrepackedConvIsSupported: inputs are the wrong size");
    return false;
  }
  if (stride.size() != 2) {
    GRAPH_DEBUG("mkldnnPrepackedConvIsSupported: unsupported stride");
    return false;
  }
  if (pad.size() != 2) {
    GRAPH_DEBUG("mkldnnPrepackedConvIsSupported: unsupported pad");
    return false;
  }
  if (dilation.size() != 2) {
    GRAPH_DEBUG("mkldnnPrepackedConvIsSupported: unsupported dilation");
    return false;
  }

  // Do not rewrite for cases where native is faster than mkldnn
  // Conditions are from: aten/src/ATen/native/Convolution.cpp:use_mkldnn
  bool use_mkldnn = groups > 1 || (weight.dims[2] > 3 && weight.dims[3] > 3) ||
      input.dims[0] > 1 ||
      input.dims[0] * input.dims[1] * input.dims[2] * input.dims[3] > 20480;
  GRAPH_DEBUG("mkldnnPrepackedConvIsSupported: ", use_mkldnn);
  return use_mkldnn;
#endif
  // Fallback for builds without MKLDNN (unreachable when it is enabled).
  return false;
}
347 
computeConv2d(const std::vector<ArgValue> & inputs,const std::vector<ExprHandle> & outputShape,const std::vector<ExprHandle> & outputStrides,const std::optional<ScalarType> & outputType,at::Device device)348 Tensor computeConv2d(
349     const std::vector<ArgValue>& inputs,
350     const std::vector<ExprHandle>& outputShape,
351     const std::vector<ExprHandle>& outputStrides,
352     const std::optional<ScalarType>& outputType,
353     at::Device device) {
354   Dtype dtype = kFloat;
355   if (outputType) {
356     dtype = Dtype(*outputType);
357   }
358 
359   BufHandle ResultBuf("conv", outputShape, dtype);
360   const BufHandle& inp = std::get<BufHandle>(inputs[0]);
361   const BufHandle& w = std::get<BufHandle>(inputs[1]);
362   const BufHandle& b = std::get<BufHandle>(inputs[2]);
363 
364   auto strides = _pair_int(inputs[3]);
365   auto padding = _pair_int(inputs[4]);
366   auto dilation = _pair_int(inputs[5]);
367 
368   int groups = std::get<int64_t>(inputs[6]);
369 
370   auto inpInfo = getTensorInfo(inp);
371   auto wInfo = getTensorInfo(w);
372   auto bInfo = getTensorInfo(b);
373   // Generate TE for depthwise convolutions.
374   if (inpInfo && wInfo && bInfo &&
375       conv2dIsSupported(
376           *inpInfo, *wInfo, *bInfo, strides, padding, dilation, groups)) {
377     return conv2d_depthwise(inp, w, b, strides[0], padding[0], groups);
378   }
379 
380   // Once we have a performant TE representation for conv2d, we could use it
381   // here instead of the external call!
382   StmtPtr s = ExternalCall::make(
383       ResultBuf,
384       "nnc_aten_conv2d",
385       {inp, w, b},
386       {strides[0],
387        strides[1],
388        padding[0],
389        padding[1],
390        dilation[0],
391        dilation[1],
392        groups});
393   return Tensor(ResultBuf.node(), s);
394 }
395 
computeConv1d(const std::vector<ArgValue> & inputs,const std::vector<ExprHandle> & outputShape,const std::vector<ExprHandle> & outputStrides,const std::optional<ScalarType> & outputType,at::Device device)396 Tensor computeConv1d(
397     const std::vector<ArgValue>& inputs,
398     const std::vector<ExprHandle>& outputShape,
399     const std::vector<ExprHandle>& outputStrides,
400     const std::optional<ScalarType>& outputType,
401     at::Device device) {
402   Dtype dtype = kFloat;
403   if (outputType) {
404     dtype = Dtype(*outputType);
405   }
406 
407   BufHandle ResultBuf("conv", outputShape, dtype);
408   const BufHandle& inp = std::get<BufHandle>(inputs[0]);
409   const BufHandle& w = std::get<BufHandle>(inputs[1]);
410   const BufHandle& b = std::get<BufHandle>(inputs[2]);
411 
412   auto strides = _single_int_list(inputs[3]);
413   auto padding = _single_int_list(inputs[4]);
414   auto dilation = _single_int_list(inputs[5]);
415 
416   int groups = std::get<int64_t>(inputs[6]);
417 
418   auto inpInfo = getTensorInfo(inp);
419   auto wInfo = getTensorInfo(w);
420   auto bInfo = getTensorInfo(b);
421 
422   StmtPtr s = ExternalCall::make(
423       ResultBuf,
424       "nnc_aten_conv1d",
425       {inp, w, b},
426       {strides[0], padding[0], dilation[0], groups});
427   return Tensor(ResultBuf.node(), s);
428 }
429 
computePrepackedConv2dClampRun(const std::vector<ArgValue> & inputs,const std::vector<ExprHandle> & outputShape,const std::vector<ExprHandle> & outputStrides,const std::optional<ScalarType> & outputType,at::Device device)430 Tensor computePrepackedConv2dClampRun(
431     const std::vector<ArgValue>& inputs,
432     const std::vector<ExprHandle>& outputShape,
433     const std::vector<ExprHandle>& outputStrides,
434     const std::optional<ScalarType>& outputType,
435     at::Device device) {
436   Dtype dtype = kFloat;
437   if (outputType) {
438     dtype = Dtype(*outputType);
439   }
440 
441   BufHandle ResultBuf("prepacked_conv2d_clamp_run", outputShape, dtype);
442   const BufHandle& inp = std::get<BufHandle>(inputs[0]);
443   const BufHandle& prepacked = std::get<BufHandle>(inputs[1]);
444   StmtPtr s = ExternalCall::make(
445       ResultBuf, "nnc_prepacked_conv2d_clamp_run", {inp, prepacked}, {});
446   return Tensor(ResultBuf.node(), s);
447 }
448 
computePrepackedLinearClampRun(const std::vector<ArgValue> & inputs,const std::vector<ExprHandle> & outputShape,const std::vector<ExprHandle> & outputStrides,const std::optional<ScalarType> & outputType,at::Device device)449 Tensor computePrepackedLinearClampRun(
450     const std::vector<ArgValue>& inputs,
451     const std::vector<ExprHandle>& outputShape,
452     const std::vector<ExprHandle>& outputStrides,
453     const std::optional<ScalarType>& outputType,
454     at::Device device) {
455   Dtype dtype = kFloat;
456   if (outputType) {
457     dtype = Dtype(*outputType);
458   }
459 
460   BufHandle ResultBuf("prepacked_linear_clamp_run", outputShape, dtype);
461   const BufHandle& inp = std::get<BufHandle>(inputs[0]);
462   const BufHandle& prepacked = std::get<BufHandle>(inputs[1]);
463   StmtPtr s = ExternalCall::make(
464       ResultBuf, "nnc_prepacked_linear_clamp_run", {inp, prepacked}, {});
465   return Tensor(ResultBuf.node(), s);
466 }
467 
computeMkldnnPrepackedConvRun(const std::vector<ArgValue> & inputs,const std::vector<ExprHandle> & outputShape,const std::vector<ExprHandle> & outputStrides,const std::optional<ScalarType> & outputType,at::Device device)468 Tensor computeMkldnnPrepackedConvRun(
469     const std::vector<ArgValue>& inputs,
470     const std::vector<ExprHandle>& outputShape,
471     const std::vector<ExprHandle>& outputStrides,
472     const std::optional<ScalarType>& outputType,
473     at::Device device) {
474   Dtype dtype = kFloat;
475   if (outputType) {
476     dtype = Dtype(*outputType);
477   }
478 
479   BufHandle ResultBuf(
480       "mkldnn_prepacked_conv_run", outputShape, outputStrides, dtype);
481   const BufHandle& inp = std::get<BufHandle>(inputs[0]);
482   const BufHandle& prepacked = std::get<BufHandle>(inputs[1]);
483   StmtPtr s = ExternalCall::make(
484       ResultBuf, "nnc_mkldnn_prepacked_conv_run", {inp, prepacked}, {});
485   return Tensor(ResultBuf.node(), s);
486 }
487 
488 } // namespace torch::jit::tensorexpr
489