xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/ConvolutionTBC.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <c10/util/irange.h>
4 #include <tuple>
5 
6 #ifndef AT_PER_OPERATOR_HEADERS
7 #include <ATen/Functions.h>
8 #include <ATen/NativeFunctions.h>
9 #else
10 #include <ATen/ops/conv_tbc_backward_native.h>
11 #include <ATen/ops/conv_tbc_native.h>
12 #include <ATen/ops/empty.h>
13 #include <ATen/ops/zeros_like.h>
14 #endif
15 
16 namespace at::native {
17 
// Temporal (1-D) convolution over a time-major layout.
//
//   self   : (time, batch, in_channels)
//   weight : (kernel_width, in_channels, out_channels)
//   bias   : (out_channels)
//   pad    : time steps of padding added on each side of the time dimension.
//            Padded positions contribute nothing (equivalent to zero padding).
//            Negative values are not rejected here; they shrink olen.
//
// Returns a tensor of shape (olen, batch, out_channels), where
// olen = time - kernel_width + 1 + 2 * pad. Output time step o receives
// input time step o + k - pad through kernel tap k, plus the bias.
Tensor conv_tbc(const Tensor& self, const Tensor& weight, const Tensor& bias, int64_t pad) {
  TORCH_CHECK(self.dim() == 3, "Input must have 3 dims: time, batch, "
      "in_channel");
  TORCH_CHECK(weight.dim() == 3, "Weight tensor must have 3 dims: kernel_width,"
      " in_channels, out_channels.");
  TORCH_CHECK(bias.dim() == 1, "Bias must be 1-D");

  const auto input_size = self.sizes();
  const auto weight_size = weight.sizes();

  const int64_t ilen = input_size[0];
  const int64_t batchSize = input_size[1];
  const int64_t inputPlanes = input_size[2];
  const int64_t kw = weight_size[0];
  const int64_t outputPlanes = weight_size[2];
  const int64_t olen = ilen - kw + 1 + pad * 2;
  // (olen - ilen + kw - 1) / 2 == (2 * pad) / 2 == pad exactly.
  const int64_t real_pad = pad;

  // Make sure shapes are correct.
  // Input = (time, batch, in_channels)
  // Weight = (kernel_width, in_channels, out_channels)
  // Bias = (out_channels)
  TORCH_CHECK(inputPlanes == weight_size[1], "Input dim 2 (input channels) "
      "is not == dim 1 in the weight tensor");
  TORCH_CHECK(outputPlanes == bias.sizes()[0], "Bias size must equal dim 2 in "
      "the weight tensor (output channels).");
  TORCH_CHECK(olen >= 0, "conv_tbc: kernel width (", kw,
      ") is too large for input length (", ilen, ") with padding ", pad);

  // input * weights + bias -> output_features
  Tensor output = at::empty({olen, batchSize, outputPlanes}, self.options());
  output.copy_(bias.expand(output.sizes()));
  for (const auto k : c10::irange(kw)) {
    // Keep offsets in int64_t: casting to int truncated long sequences.
    // iShift - oShift == k - real_pad, so output row o pairs with input row
    // o + k - real_pad; t is the number of rows where both are in range.
    const int64_t iShift = std::max<int64_t>(0, k - real_pad);
    const int64_t oShift = std::max<int64_t>(0, real_pad - k);
    const int64_t t = std::min(ilen + real_pad - k, olen) - oShift;
    if (t > 0) {
      // Flatten (time, batch) into rows:
      // O[t*B, out] += I[t*B, in] @ weight[k][in, out].
      // reshape (not view) so non-contiguous inputs work; I is only read.
      const auto I = self.narrow(0, iShift, t).reshape({t * batchSize, inputPlanes});
      // `output` is freshly allocated and contiguous, so view() aliases it and
      // addmm_ writes straight into the result.
      auto O = output.narrow(0, oShift, t).view({t * batchSize, outputPlanes});
      O.addmm_(I, weight[k]);
    }
  }
  return output;
}
70 
conv_tbc_backward(const Tensor & dOutput,const Tensor & input,const Tensor & weight,const Tensor & bias,int64_t pad)71 std::tuple<Tensor, Tensor, Tensor> conv_tbc_backward(const Tensor& dOutput, const Tensor& input, const Tensor& weight, const Tensor& bias, int64_t pad) {
72   auto input_size = input.sizes();
73   auto weight_size = weight.sizes();
74 
75   auto ilen = input_size[0];
76   auto batchSize = input_size[1];
77   auto inputPlanes = input_size[2];
78   auto outputPlanes = weight_size[2];
79   auto kw = weight.sizes()[0];
80   auto olen = input_size[0] - kw + 1 + pad * 2;
81   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
82   int real_pad = (olen - ilen + kw - 1) / 2;
83 
84   Tensor dInput = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
85   for (int k = 0; k < kw; k++) {
86     int iShift = std::max(0, k - real_pad);
87     int oShift = std::max(0, real_pad - k);
88     // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
89     int t = std::min(ilen + real_pad - k, olen) - oShift;
90     // dOutput * T(weight) -> dInput
91     if (t > 0) {
92       auto dO = dOutput.narrow(0, oShift, t).view({t * batchSize, outputPlanes});
93       auto dI = dInput.narrow(0, iShift, t).view({t * batchSize, inputPlanes});
94       dI.addmm_(dO, weight[k].t());
95     }
96   }
97 
98   Tensor dWeight = at::zeros_like(weight, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
99   for (int k = 0; k < kw; k++) {
100     int iShift = std::max(0, k - real_pad);
101     int oShift = std::max(0, real_pad - k);
102     // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
103     int t = std::min(ilen + real_pad - k, olen) - oShift;
104     // T(input) * dOutput -> dWeight
105     if (t > 0) {
106       auto dW = dWeight[k];
107       auto dO = dOutput.narrow(0, oShift, t).view({t * batchSize, outputPlanes});
108       auto I = input.narrow(0, iShift, t).view({t * batchSize, inputPlanes}).t();
109       dW.addmm_(I, dO);
110     }
111   }
112 
113   Tensor dBias = at::zeros_like(bias, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
114   auto tmp = dOutput.sum(0, false);
115   dBias.copy_(tmp.sum(0));
116 
117   return std::make_tuple(dInput, dWeight, dBias);
118 }
119 
120 } // namespace at::native
121