#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/RNN.h>

#include <ATen/core/Tensor.h>
#include <ATen/core/List.h>
#include <ATen/Context.h>
#include <ATen/TensorOperators.h>
#include <ATen/mps/MPSDevice.h>
#include <ATen/native/quantized/PackedParams.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/cpu/QnnpackUtils.h>
#include <c10/core/GradMode.h>
#include <c10/macros/Macros.h>
#include <c10/util/irange.h>
#include <torch/custom_class.h>
#include <torch/library.h>
#include <ATen/Config.h>
#if AT_MKLDNN_ENABLED()
#include <ATen/native/mkldnn/Utils.h>
#endif

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_lstm_mps.h>
#include <ATen/ops/_thnn_differentiable_gru_cell_backward_native.h>
#include <ATen/ops/_thnn_differentiable_lstm_cell_backward_native.h>
#include <ATen/ops/_thnn_fused_gru_cell.h>
#include <ATen/ops/_thnn_fused_lstm_cell.h>
#include <ATen/ops/_thnn_fused_lstm_cell_backward.h>
#include <ATen/ops/_thnn_fused_lstm_cell_backward_impl.h>
#include <ATen/ops/_thnn_fused_lstm_cell_backward_native.h>
#include <ATen/ops/_use_cudnn_rnn_flatten_weight_native.h>
#include <ATen/ops/cat.h>
#include <ATen/ops/cudnn_is_acceptable.h>
#include <ATen/ops/dropout.h>
#include <ATen/ops/fbgemm_linear_int8_weight_fp32_activation.h>
#include <ATen/ops/fbgemm_linear_quantize_weight_native.h>
#include <ATen/ops/fbgemm_pack_quantized_matrix_native.h>
#include <ATen/ops/gru_cell_native.h>
#include <ATen/ops/gru_native.h>
#include <ATen/ops/linear.h>
#include <ATen/ops/lstm_cell_native.h>
#include <ATen/ops/lstm_native.h>
#include <ATen/ops/matmul.h>
#include <ATen/ops/quantized_gru_cell_native.h>
#include <ATen/ops/quantized_lstm_cell_native.h>
#include <ATen/ops/quantized_rnn_relu_cell_native.h>
#include <ATen/ops/quantized_rnn_tanh_cell_native.h>
#include <ATen/ops/relu.h>
#include <ATen/ops/rnn_relu_cell_native.h>
#include <ATen/ops/rnn_relu_native.h>
#include <ATen/ops/rnn_tanh_cell_native.h>
#include <ATen/ops/rnn_tanh_native.h>
#include <ATen/ops/sigmoid_backward.h>
#include <ATen/ops/stack.h>
#include <ATen/ops/tanh.h>
#include <ATen/ops/tanh_backward.h>
#include <ATen/ops/zeros_like.h>
#include <ATen/ops/zeros_like_ops.h>
#include <utility>
#endif

int register_linear_params();

namespace at::native {

namespace {

// Check whether PyTorch is compiled with MIOpen and this input can use it.
bool use_miopen(const at::Tensor& input, const double dropout_state) {
    bool is_miopen_acceptable = ((input.scalar_type() == at::kFloat) || (input.scalar_type() == at::kHalf)) &&
                                (detail::getCUDAHooks().compiledWithMIOpen()) &&
                                (input.is_cuda()) &&
                                (at::globalContext().userEnabledCuDNN());
    // MIOpen functions return miopenStatusBadParm on empty
    // tensors. Maybe some functions actually support empty tensors, but
    // native kernels shouldn't be much slower because the output is also
    // likely empty.
    if (input.sym_numel() == 0) return false;

    return is_miopen_acceptable;
}

bool use_mkldnn(const Tensor& input, TensorList params, TensorList hx) {
#if AT_MKLDNN_ENABLED()
  if (!at::globalContext().userEnabledMkldnn()) {
    return false;
  }
  auto is_cpu_backend = [&](const TensorList tensors) {
    bool backend_cpu = true;
    for (const auto& t : tensors) {
      if (!(t.options().backend() == at::Backend::CPU)) {
        backend_cpu = false;
        break;
      }
    }
    return backend_cpu;
  };
  return input.options().backend() == at::Backend::CPU &&
      is_cpu_backend(params) && is_cpu_backend(hx) &&
      (input.scalar_type() == kFloat ||
       (input.scalar_type() == kBFloat16 && mkldnn_bf16_device_check()) ||
       (input.scalar_type() == kHalf && !at::GradMode::is_enabled() &&
        mkldnn_fp16_device_check())) &&
      input.numel() != 0;
#endif
  return false;
}

template<typename T>
using pair_of = std::pair<T, T>;

template<typename T>
using tpair_of = std::tuple<T, T>;

// Those could have been function pointers, but MSVC chokes on function pointers as template parameters
struct tanh_f {
  Tensor operator()(const Tensor& t) const { return at::tanh(t); }
};

struct relu_f {
  Tensor operator()(const Tensor& t) const { return at::relu(t); }
};

struct PackedSequence {
  PackedSequence() = default;
  PackedSequence(Tensor _data, Tensor _batch_sizes)
    : data(std::move(_data)), batch_sizes(std::move(_batch_sizes)) {}

  Tensor data;
  Tensor batch_sizes;
};

// Simple type for __getstate__/__setstate__ serialization
//
// Element 0 is a string key to say what kind of CellParams this is. It
// should be a valid key into cell_params_deserializers
// Element 1 is the Tensors contained within the CellParams instance
// Element 2 is the doubles (if any) contained in the CellParams instance
// Element 3 is the longs (if any) contained within the CellParams instance
// Element 4 is the packed linear params (if any) contained within the CellParams instance
using CellParamsSerializationType = std::tuple<
    std::string,
    std::vector<at::Tensor>,
    std::vector<double>,
    std::vector<int64_t>,
    std::vector<c10::intrusive_ptr<LinearPackedParamsBase>>>;

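// For example, QuantizedCellParams::__getstate__ below produces
//   ("quantized",
//    {w_ih, w_hh, b_ih_, b_hh_, col_offsets_ih, col_offsets_hh},
//    {scale_ih, scale_hh},
//    {zero_point_ih, zero_point_hh},
//    {})
// while QuantizedCellParamsFP16 keeps everything inside the packed params:
//   ("quantized_fp16", {}, {}, {}, {packed_ih, packed_hh})
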
// Base class so we can polymorphically handle these
struct CellParamsBase : torch::CustomClassHolder {
  virtual Tensor matmul_ih(const Tensor& input) const = 0;
  virtual Tensor matmul_hh(const Tensor& h) const = 0;
  // By default this does nothing. CellParams overrides it
  // to define the correct behavior for LSTMs with projections.
  // This function is not pure virtual, because it's useful to
  // provide this default implementation, so that all cell params
  // that don't support projections work correctly (e.g. QuantizedCellParams variations)
  virtual Tensor matmul_hr(const Tensor& h) const {
    return h;
  }
  virtual Tensor linear_ih(const Tensor& input_ih) const = 0;
  virtual Tensor linear_hh(const Tensor& input_hh) const = 0;

  virtual const Tensor& b_ih() const = 0;
  virtual const Tensor& b_hh() const = 0;

  virtual CellParamsSerializationType __getstate__() const = 0;
};

// Pretty much all cells we support take the same set of arguments, but threading those
// 4 arguments manually is really annoying. Their lifetime is externally managed, so we only
// pass this struct of references around. LSTMs with projections have a 5th argument, w_hr; for all
// other models it's always going to be undefined.
struct CellParams : public CellParamsBase {
  CellParams(
      const Tensor& _w_ih,
      const Tensor& _w_hh,
      const Tensor& _b_ih,
      const Tensor& _b_hh,
      const Tensor& _w_hr)
      : w_ih(_w_ih), w_hh(_w_hh), b_ih_(_b_ih), b_hh_(_b_hh), w_hr(_w_hr) {}

  const Tensor& w_ih;
  const Tensor& w_hh;
  const Tensor& b_ih_; /* optional */
  const Tensor& b_hh_; /* optional */
  const Tensor& w_hr;  /* only defined for LSTMs with projections */

  Tensor matmul_ih(const Tensor& input) const override {
    return at::matmul(input, w_ih.t());
  }
  Tensor matmul_hh(const Tensor& h) const override {
    return at::matmul(h, w_hh.t());
  }
  Tensor matmul_hr(const Tensor& h) const override {
    if (w_hr.defined()) {
      return at::matmul(h, w_hr.t());
    }
    return h;
  }
  Tensor linear_ih(const Tensor& input) const override {
    return at::linear(input, w_ih, b_ih_);
  }
  Tensor linear_hh(const Tensor& h) const override {
    return at::linear(h, w_hh, b_hh_);
  }
  const Tensor& b_ih() const override {
    return b_ih_;
  }
  const Tensor& b_hh() const override {
    return b_hh_;
  }
  CellParamsSerializationType __getstate__() const override {
    TORCH_INTERNAL_ASSERT(false, "Not yet implemented");
  }
  static c10::intrusive_ptr<CellParamsBase> __setstate__(
      const CellParamsSerializationType& state) {
    TORCH_INTERNAL_ASSERT(false, "Not yet implemented");
  }
};

c10::intrusive_ptr<CellParamsBase> make_quantized_cell_params(
    const at::Tensor& w_ih,
    const at::Tensor& w_hh,
    at::Tensor bias_ih,
    at::Tensor bias_hh);

struct QuantizedCellParams : public CellParamsBase {
  QuantizedCellParams(
      Tensor _w_ih,
      Tensor _w_hh,
      Tensor _b_ih,
      Tensor _b_hh,
      Tensor _packed_ih,
      Tensor _packed_hh,
      Tensor _col_offsets_ih,
      Tensor _col_offsets_hh,
      Scalar _scale_ih,
      Scalar _scale_hh,
      Scalar _zero_point_ih,
      Scalar _zero_point_hh)
      : w_ih(std::move(_w_ih)),
        w_hh(std::move(_w_hh)),
        b_ih_(std::move(_b_ih)),
        b_hh_(std::move(_b_hh)),
        packed_ih(std::move(_packed_ih)),
        packed_hh(std::move(_packed_hh)),
        col_offsets_ih(std::move(_col_offsets_ih)),
        col_offsets_hh(std::move(_col_offsets_hh)),
        scale_ih(std::move(_scale_ih)),
        scale_hh(std::move(_scale_hh)),
        zero_point_ih(std::move(_zero_point_ih)),
        zero_point_hh(std::move(_zero_point_hh)) {}

  const Tensor w_ih;
  const Tensor w_hh;
  const Tensor b_ih_;
  const Tensor b_hh_;
  const Tensor packed_ih;
  const Tensor packed_hh;
  const Tensor col_offsets_ih;
  const Tensor col_offsets_hh;
  const Scalar scale_ih;
  const Scalar scale_hh;
  const Scalar zero_point_ih;
  const Scalar zero_point_hh;

  Tensor matmul_ih(const Tensor& input) const override {
    TORCH_CHECK(false, "matmul is not supported with quantized cell params");
  }
  Tensor matmul_hh(const Tensor& h) const override {
    TORCH_CHECK(false, "matmul is not supported with quantized cell params");
  }
  Tensor linear_ih(const Tensor& input) const override {
    return at::fbgemm_linear_int8_weight_fp32_activation(
        input, w_ih, packed_ih, col_offsets_ih, scale_ih, zero_point_ih, b_ih_);
  }
  Tensor linear_hh(const Tensor& h) const override {
    return at::fbgemm_linear_int8_weight_fp32_activation(
        h, w_hh, packed_hh, col_offsets_hh, scale_hh, zero_point_hh, b_hh_);
  }
  const Tensor& b_ih() const override {
    return b_ih_;
  }
  const Tensor& b_hh() const override {
    return b_hh_;
  }
  CellParamsSerializationType __getstate__() const override {
    std::vector<at::Tensor> tensors_to_serialize = {
        w_ih, w_hh, b_ih_, b_hh_, col_offsets_ih, col_offsets_hh};
    std::vector<double> doubles_to_serialize = {scale_ih.toDouble(),
                                                scale_hh.toDouble()};
    std::vector<int64_t> longs_to_serialize = {zero_point_ih.toLong(),
                                               zero_point_hh.toLong()};
    return CellParamsSerializationType(
        "quantized",
        tensors_to_serialize,
        doubles_to_serialize,
        longs_to_serialize,
        {});
  }
  static c10::intrusive_ptr<CellParamsBase> __setstate__(
      CellParamsSerializationType state) {
    auto [_, tensors, doubles, longs, __] =
        std::move(state);
    TORCH_INTERNAL_ASSERT(tensors.size() == 6);
    TORCH_INTERNAL_ASSERT(doubles.size() == 2);
    TORCH_INTERNAL_ASSERT(longs.size() == 2);

    at::Tensor qw_ih = std::move(tensors[0]), qw_hh = std::move(tensors[1]),
               b_ih = std::move(tensors[2]), b_hh = std::move(tensors[3]),
               col_offsets_ih = std::move(tensors[4]),
               col_offsets_hh = std::move(tensors[5]);
    double scale_ih = doubles[0], scale_hh = doubles[1];
    int64_t zero_point_ih = longs[0], zero_point_hh = longs[1];

    at::Tensor packed_ih = at::native::fbgemm_pack_quantized_matrix(qw_ih);
    at::Tensor packed_hh = at::native::fbgemm_pack_quantized_matrix(qw_hh);

    return c10::make_intrusive<QuantizedCellParams>(
        /*w_ih=*/std::move(qw_ih),
        /*w_hh=*/std::move(qw_hh),
        /*b_ih_=*/std::move(b_ih),
        /*b_hh_=*/std::move(b_hh),
        /*packed_ih=*/std::move(packed_ih),
        /*packed_hh=*/std::move(packed_hh),
        /*col_offsets_ih=*/std::move(col_offsets_ih),
        /*col_offsets_hh=*/std::move(col_offsets_hh),
        /*scale_ih=*/scale_ih,
        /*scale_hh=*/scale_hh,
        /*zero_point_ih=*/zero_point_ih,
        /*zero_point_hh=*/zero_point_hh);
  }
};

c10::intrusive_ptr<CellParamsBase> make_quantized_cell_params(
    const at::Tensor& w_ih,
    const at::Tensor& w_hh,
    at::Tensor b_ih,
    at::Tensor b_hh) {
  auto make_vals = [&](const at::Tensor& W) {
    auto params = at::native::fbgemm_linear_quantize_weight(W);
    at::Tensor packed_weight =
        at::native::fbgemm_pack_quantized_matrix(std::get<0>(params));
    return std::tuple_cat(
        std::make_tuple(std::move(packed_weight)), std::move(params));
  };

  auto [packed_ih, qw_ih, col_offsets_ih, scale_ih, zero_point_ih] =
      make_vals(w_ih);
  auto [packed_hh, qw_hh, col_offsets_hh, scale_hh, zero_point_hh] =
      make_vals(w_hh);

  return c10::make_intrusive<QuantizedCellParams>(
      /*qw_ih=*/std::move(qw_ih),
      /*qw_hh=*/std::move(qw_hh),
      /*b_ih=*/std::move(b_ih),
      /*b_hh=*/std::move(b_hh),
      /*packed_ih=*/std::move(packed_ih),
      /*packed_hh=*/std::move(packed_hh),
      /*col_offsets_ih=*/std::move(col_offsets_ih),
      /*col_offsets_hh=*/std::move(col_offsets_hh),
      /*scale_ih=*/scale_ih,
      /*scale_hh=*/scale_hh,
      /*zero_point_ih=*/zero_point_ih,
      /*zero_point_hh=*/zero_point_hh);
}

// QuantizedCellParams vs. QuantizedCellParamsDynamic
//
// QuantizedCellParams uses the legacy
// fbgemm_linear_int8_weight_fp32_activation API, which requires the explicit
// scale and zero point parameters for the weight. QuantizedCellParamsDynamic
// uses the new fbgemm_linear_dynamic API, which doesn't require the explicit
// scale and zero point parameters. These quantization parameters are
// encapsulated in the `PackedLinearWeight` struct in
// aten/src/ATen/native/quantized/cpu/fbgemm_utils.h.

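// Concretely, the two linear paths below look like:
//   legacy:  at::fbgemm_linear_int8_weight_fp32_activation(
//                x, w, packed_w, col_offsets, scale, zero_point, bias)
//   dynamic: packed_params->apply_dynamic(x, reduce_range)
// where the dynamic path reads scale/zero_point out of the packed weight.
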
c10::intrusive_ptr<CellParamsBase> make_quantized_cell_params_dynamic(
    c10::intrusive_ptr<LinearPackedParamsBase> w_ih_packed,
    c10::intrusive_ptr<LinearPackedParamsBase> w_hh_packed,
    at::Tensor bias_ih,
    at::Tensor bias_hh,
    bool reduce_range);

struct QuantizedCellParamsDynamic : public CellParamsBase {
  QuantizedCellParamsDynamic(
      c10::intrusive_ptr<LinearPackedParamsBase>
          _packed_w_ih, /* Prepacked Weight Tensor */
      c10::intrusive_ptr<LinearPackedParamsBase>
          _packed_w_hh, /* Prepacked Weight Tensor */
      Tensor _b_ih, /* float Bias Tensor */
      Tensor _b_hh, /* float Bias Tensor */
      bool _reduce_range = false /* Use reduced range for activation tensors */)
      : packed_w_ih(std::move(_packed_w_ih)),
        packed_w_hh(std::move(_packed_w_hh)),
        b_ih_(std::move(_b_ih)),
        b_hh_(std::move(_b_hh)),
        reduce_range_(_reduce_range) {}

  c10::intrusive_ptr<LinearPackedParamsBase> packed_w_ih;
  c10::intrusive_ptr<LinearPackedParamsBase> packed_w_hh;
  const Tensor b_ih_;
  const Tensor b_hh_;
  bool reduce_range_;

  Tensor matmul_ih(const Tensor& input) const override {
    TORCH_CHECK(false, "matmul is not supported with quantized cell params");
  }
  Tensor matmul_hh(const Tensor& h) const override {
    TORCH_CHECK(false, "matmul is not supported with quantized cell params");
  }

  Tensor linear_ih(const Tensor& input_ih) const override {
    return packed_w_ih->apply_dynamic(input_ih, reduce_range_);
  }
  Tensor linear_hh(const Tensor& input_hh) const override {
    return packed_w_hh->apply_dynamic(input_hh, reduce_range_);
  }

  const Tensor& b_ih() const override {
    return b_ih_;
  }
  const Tensor& b_hh() const override {
    return b_hh_;
  }
  CellParamsSerializationType __getstate__() const override {
    std::vector<at::Tensor> tensors_to_serialize{
        /*b_ih=*/b_ih_,
        /*b_hh=*/b_hh_,
    };

    std::vector<c10::intrusive_ptr<LinearPackedParamsBase>>
        packed_params_to_serialize{packed_w_ih, packed_w_hh};

    // reduce_range parameter is serialized along with the int field values.
    return CellParamsSerializationType(
        "quantized_dynamic",
        tensors_to_serialize,
        {},
        {reduce_range_},
        packed_params_to_serialize);
  }
  static c10::intrusive_ptr<CellParamsBase> __setstate__(
      CellParamsSerializationType state) {
    auto [_, tensors, __, serialized_ints, packed_params] =
        std::move(state);
    TORCH_INTERNAL_ASSERT(tensors.size() == 2);
    TORCH_INTERNAL_ASSERT(packed_params.size() == 2);

    bool reduce_range = serialized_ints.empty() ? false : serialized_ints[0];
    return make_quantized_cell_params_dynamic(
        /*w_ih_packed=*/std::move(packed_params[0]),
        /*w_hh_packed=*/std::move(packed_params[1]),
        /*bias_ih=*/std::move(tensors[0]),
        /*bias_hh=*/std::move(tensors[1]),
        /*reduce_range=*/reduce_range);
  }
};

c10::intrusive_ptr<CellParamsBase> make_quantized_cell_params_dynamic(
    c10::intrusive_ptr<LinearPackedParamsBase> w_ih_packed,
    c10::intrusive_ptr<LinearPackedParamsBase> w_hh_packed,
    at::Tensor bias_ih,
    at::Tensor bias_hh,
    bool reduce_range) {

  return c10::make_intrusive<QuantizedCellParamsDynamic>(
      /*_packed_w_ih=*/std::move(w_ih_packed),
      /*_packed_w_hh=*/std::move(w_hh_packed),
      /*_b_ih=*/std::move(bias_ih),
      /*_b_hh=*/std::move(bias_hh),
      /*_reduce_range=*/reduce_range);
}

c10::intrusive_ptr<CellParamsBase> make_quantized_cell_params_fp16(
    c10::intrusive_ptr<LinearPackedParamsBase> w_ih_packed,
    c10::intrusive_ptr<LinearPackedParamsBase> w_hh_packed);

struct QuantizedCellParamsFP16 : public CellParamsBase {
  QuantizedCellParamsFP16(
      c10::intrusive_ptr<LinearPackedParamsBase> _packed_ih,
      c10::intrusive_ptr<LinearPackedParamsBase> _packed_hh)
      : packed_ih(std::move(_packed_ih)), packed_hh(std::move(_packed_hh)) {}

  c10::intrusive_ptr<LinearPackedParamsBase> packed_ih;
  c10::intrusive_ptr<LinearPackedParamsBase> packed_hh;
  const Tensor b_ih_;
  const Tensor b_hh_;

  Tensor matmul_ih(const Tensor& /* unused */) const override {
    TORCH_CHECK(false, "matmul is not supported with quantized cell params");
  }
  Tensor matmul_hh(const Tensor& /* unused */) const override {
    TORCH_CHECK(false, "matmul is not supported with quantized cell params");
  }
  Tensor linear_ih(const Tensor& input) const override {
    return packed_ih->apply_dynamic(input);
  }
  Tensor linear_hh(const Tensor& h) const override {
    return packed_hh->apply_dynamic(h);
  }

  const Tensor& b_ih() const override {
    return b_ih_;
  }
  const Tensor& b_hh() const override {
    return b_hh_;
  }
  CellParamsSerializationType __getstate__() const override {
    std::vector<c10::intrusive_ptr<LinearPackedParamsBase>>
        packed_params_to_serialize{packed_ih, packed_hh};

    return CellParamsSerializationType(
        "quantized_fp16", {}, {}, {}, packed_params_to_serialize);
  }
  static c10::intrusive_ptr<CellParamsBase> __setstate__(
      CellParamsSerializationType state) {
    auto packed_params = std::get<4>(std::move(state));
    TORCH_INTERNAL_ASSERT(packed_params.size() == 2);
    return make_quantized_cell_params_fp16(
        /*w_ih_packed=*/std::move(packed_params[0]),
        /*w_hh_packed=*/std::move(packed_params[1]));
  }
};

c10::intrusive_ptr<CellParamsBase> make_quantized_cell_params_fp16(
    c10::intrusive_ptr<LinearPackedParamsBase> w_ih_packed,
    c10::intrusive_ptr<LinearPackedParamsBase> w_hh_packed) {
  return c10::make_intrusive<QuantizedCellParamsFP16>(
      std::move(w_ih_packed), std::move(w_hh_packed));
}

static std::unordered_map<
    std::string,
    c10::intrusive_ptr<CellParamsBase> (*)(CellParamsSerializationType)>
    cell_params_deserializers = {
        {"quantized", &QuantizedCellParams::__setstate__},
        {"quantized_dynamic", &QuantizedCellParamsDynamic::__setstate__},
        {"quantized_fp16", &QuantizedCellParamsFP16::__setstate__}};

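// A serialized CellParams is restored by dispatching on its string key, e.g.:
//   auto cell = cell_params_deserializers.at(std::get<0>(state))(std::move(state));
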
// Stupid wrapper to convert from c10::intrusive_ptr<CellParamsBase> to a
// value type that forwards every call to the wrapped CellParamsBase.
struct QRNNCellParamsWrapper {
  QRNNCellParamsWrapper(c10::intrusive_ptr<CellParamsBase> param)
      : param_(std::move(param)) {}

  Tensor matmul_ih(const Tensor& input) const {
    return param_->matmul_ih(input);
  }
  Tensor matmul_hh(const Tensor& h) const {
    return param_->matmul_hh(h);
  }
  Tensor matmul_hr(const Tensor& h) const {
    return param_->matmul_hr(h);
  }
  Tensor linear_ih(const Tensor& input) const {
    return param_->linear_ih(input);
  }
  Tensor linear_hh(const Tensor& h) const {
    return param_->linear_hh(h);
  }
  const Tensor& b_ih() const {
    return param_->b_ih();
  }
  const Tensor& b_hh() const {
    return param_->b_hh();
  }

  c10::intrusive_ptr<CellParamsBase> param_;
};

// Gathers every two elements of a vector into a vector of pairs
template<typename T>
static std::vector<pair_of<T>> pair_vec(const std::vector<T>& vals) {
  TORCH_CHECK(vals.size() % 2 == 0, "Odd number of params or hiddens given to a bidirectional RNN");
  std::vector<pair_of<T>> result;
  result.reserve(vals.size() / 2);
  for (size_t i = 0; i < vals.size(); i += 2) {
    result.emplace_back(vals[i], vals[i + 1]);
  }
  return result;
}

// Flattens a vector of pairs
template<typename T>
static std::vector<T> unpair_vec(std::vector<pair_of<T>>&& vals) {
  std::vector<T> result;
  result.reserve(vals.size() * 2);
  for (const auto i : c10::irange(vals.size())) {
    result.push_back(std::move(vals[i].first));
    result.push_back(std::move(vals[i].second));
  }
  return result;
}

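// For example, with one bidirectional layer,
//   pair_vec({fw, bw}) == {(fw, bw)}
// pairing the forward and reverse directions of each layer; unpair_vec is its
// inverse and flattens the pairs back into a single list.
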
// Parses a flat list of parameter tensors into a list of CellParams
static std::vector<CellParams> gather_params(TensorList params, bool has_biases, bool has_projections = false) {
  static at::Tensor undefined;
  std::vector<CellParams> result;
  if (has_biases) {
    if (has_projections) {
      TORCH_CHECK(params.size() % 5 == 0, "got an incorrect number of RNN parameters");
      for (size_t i = 0; i < params.size(); i += 5) {
        result.emplace_back(params[i], params[i + 1], params[i + 2], params[i + 3], params[i + 4]);
      }
    } else {
      TORCH_CHECK(params.size() % 4 == 0, "got an incorrect number of RNN parameters");
      for (size_t i = 0; i < params.size(); i += 4) {
        result.emplace_back(params[i], params[i + 1], params[i + 2], params[i + 3], undefined);
      }
    }
  } else {
    if (has_projections) {
      TORCH_CHECK(params.size() % 3 == 0, "got an incorrect number of RNN parameters");
      for (size_t i = 0; i < params.size(); i += 3) {
        result.emplace_back(params[i], params[i + 1], undefined, undefined, params[i + 2]);
      }
    } else {
      TORCH_CHECK(params.size() % 2 == 0, "got an incorrect number of RNN parameters");
      for (size_t i = 0; i < params.size(); i += 2) {
        result.emplace_back(params[i], params[i + 1], undefined, undefined, undefined);
      }
    }
  }
  return result;
}

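// The expected flat layout per layer (and per direction) is thus:
//   biases, no projections: {w_ih, w_hh, b_ih, b_hh}
//   biases, projections:    {w_ih, w_hh, b_ih, b_hh, w_hr}
//   no biases:              {w_ih, w_hh} or {w_ih, w_hh, w_hr}
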
////////////////////////////////////////////////////////////////////////////////
// HIDDEN STATE FUNCTIONS
//
// Functions below are implemented as templates based on hidden type,
// because they need to work both with simple RNNs and GRU (which use a single Tensor),
// as well as with LSTM (or possibly more complicated architectures in the future).
// Still, there are some operations that need to be performed on the hidden states
// alone, and for this purpose we provide an overloaded set of functions below.

Tensor hidden_as_output(const Tensor& t) { return t; }
Tensor hidden_as_output(const tpair_of<Tensor>& t) { return std::get<0>(t); }

template<size_t index>
std::vector<Tensor> project(at::ArrayRef<tpair_of<Tensor>> tuples) {
  std::vector<Tensor> result;
  result.reserve(tuples.size());
  for (auto & t : tuples) {
    result.push_back(std::get<index>(t));
  }
  return result;
}

Tensor hidden_concat(at::ArrayRef<Tensor> hiddens) { return at::cat(hiddens, 0); }
tpair_of<Tensor> hidden_concat(at::ArrayRef<tpair_of<Tensor>> hiddens) {
  return std::make_tuple(hidden_concat(project<0>(hiddens)), hidden_concat(project<1>(hiddens)));
}

Tensor hidden_slice(const Tensor& t, int64_t start, int64_t end) {
  return t.narrow(0, start, end - start);
}
tpair_of<Tensor> hidden_slice(const tpair_of<Tensor>& t, int64_t start, int64_t end) {
  return std::make_tuple(hidden_slice(std::get<0>(t), start, end),
                         hidden_slice(std::get<1>(t), start, end));
}

////////////////////////////////////////////////////////////////////////////////
// CELL IMPLEMENTATIONS
//
// Cell is a basic component of an RNN, representing a single application of the
// recurrent function. You can think of it as a function of signature
//
// (Tensor input, hidden_type hidden, CellParams) -> hidden_type
//
// which means that it consumes an input tensor, and updates the previous hidden state.
// It's a struct only because functional programming in C++ is a pain, and it's easier
// to pass around "vtable pointers" than actual function pointers.

void check_rnn_cell_forward_input(const Tensor& input, const c10::SymInt& input_size) {
  TORCH_CHECK(
    input.sym_size(1) == input_size,
    "input has inconsistent input_size: got ", input.sym_size(1), " expected ", input_size);
}

void check_rnn_cell_forward_hidden(const Tensor& input, const Tensor& hx, const c10::SymInt& hidden_size, const c10::SymInt& hidden_label) {
  TORCH_CHECK(
    input.sym_size(0) == hx.sym_size(0),
    "Input batch size ", input.sym_size(0), " doesn't match hidden", hidden_label, " batch size ", hx.sym_size(0));

  TORCH_CHECK(
    hx.sym_size(1) == hidden_size,
    "hidden", hidden_label, " has inconsistent hidden_size: got ", hx.sym_size(1), ", expected ", hidden_size);
}

template<typename hidden_type_tmpl, typename cell_params_tmpl>
struct Cell {
  using hidden_type = hidden_type_tmpl;
  using cell_params = cell_params_tmpl;

  virtual ~Cell() = default; // This is really dumb, but enables projects with
                             // -Wnon-virtual-dtor to compile...

  virtual hidden_type operator()(
      const Tensor& input,
      const hidden_type& hidden,
      const cell_params& params,
      bool pre_compute_input = false) const = 0;
};

template<typename nonlinearity, typename cell_params>
struct SimpleCell : Cell<Tensor, cell_params> {
  using hidden_type = Tensor;
  Tensor operator()(
      const Tensor& input,
      const Tensor& hidden,
      const cell_params& params,
      bool pre_compute_input = false) const override {
    return nonlinearity{}(params.linear_hh(hidden).add_(
        pre_compute_input ? input : params.linear_ih(input)));
  }
};

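// Note on pre_compute_input: when it is true, `input` already holds
// linear_ih(input) computed once for the whole sequence (see FullLayer below),
// so the cell only adds it to the hidden-to-hidden projection.
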
// TODO: can use inplace ops?
template <typename cell_params>
struct LSTMCell : Cell<std::tuple<Tensor, Tensor>, cell_params> {
  using hidden_type = std::tuple<Tensor, Tensor>;

  hidden_type operator()(
      const Tensor& input,
      const hidden_type& hidden,
      const cell_params& params,
      bool pre_compute_input = false) const override {
    const auto& hx = std::get<0>(hidden);
    const auto& cx = std::get<1>(hidden);

    if (input.is_cuda() || input.is_privateuseone()) {
      TORCH_CHECK(!pre_compute_input);
      auto igates = params.matmul_ih(input);
      auto hgates = params.matmul_hh(hx);
      auto result = at::_thnn_fused_lstm_cell(
          igates, hgates, cx, params.b_ih(), params.b_hh());
      // applying projections if w_hr is defined
      auto hy = params.matmul_hr(std::get<0>(result));
      // Slice off the workspace argument (it's needed only for AD).
      return std::make_tuple(std::move(hy), std::move(std::get<1>(result)));
    }

    const auto gates = params.linear_hh(hx).add_(
        pre_compute_input ? input : params.linear_ih(input));
    auto chunked_gates = gates.unsafe_chunk(4, 1);
    auto ingate = chunked_gates[0].sigmoid_();
    auto forgetgate = chunked_gates[1].sigmoid_();
    auto cellgate = chunked_gates[2].tanh_();
    auto outgate = chunked_gates[3].sigmoid_();
    auto cy = (forgetgate * cx).add_(ingate * cellgate);
    auto hy = outgate * cy.tanh();
    hy = params.matmul_hr(hy);
    return std::make_tuple(std::move(hy), std::move(cy));
  }

};

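// The CPU path above is the standard LSTM update, with the four gates taken
// from equal chunks of (W_ih x + b_ih + W_hh h + b_hh):
//   i = sigmoid(chunk 0)   // input gate
//   f = sigmoid(chunk 1)   // forget gate
//   g = tanh(chunk 2)      // cell gate
//   o = sigmoid(chunk 3)   // output gate
//   cy = f * cx + i * g
//   hy = o * tanh(cy)      // then optionally projected by w_hr
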
template <typename cell_params>
struct GRUCell : Cell<Tensor, cell_params> {
  using hidden_type = Tensor;

  hidden_type operator()(
      const Tensor& input,
      const hidden_type& hidden,
      const cell_params& params,
      bool pre_compute_input = false) const override {
    if (input.is_cuda() || input.is_xpu() || input.is_privateuseone()) {
      TORCH_CHECK(!pre_compute_input);
      auto igates = params.matmul_ih(input);
      auto hgates = params.matmul_hh(hidden);
      auto result = at::_thnn_fused_gru_cell(
          igates, hgates, hidden, params.b_ih(), params.b_hh());
      // Slice off the workspace argument (it's needed only for AD).
      return std::move(std::get<0>(result));
    }
    const auto chunked_igates = pre_compute_input
        ? input.unsafe_chunk(3, 1)
        : params.linear_ih(input).unsafe_chunk(3, 1);
    auto chunked_hgates = params.linear_hh(hidden).unsafe_chunk(3, 1);
    const auto reset_gate =
        chunked_hgates[0].add_(chunked_igates[0]).sigmoid_();
    const auto input_gate =
        chunked_hgates[1].add_(chunked_igates[1]).sigmoid_();
    const auto new_gate =
        chunked_igates[2].add(chunked_hgates[2].mul_(reset_gate)).tanh_();
    return (hidden - new_gate).mul_(input_gate).add_(new_gate);
  }
};

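// The CPU path above matches the usual GRU equations, with the input and
// hidden projections each split into three chunks:
//   r = sigmoid(i0 + h0)   // reset gate
//   z = sigmoid(i1 + h1)   // update gate
//   n = tanh(i2 + r * h2)  // new gate
//   h' = (h - n) * z + n   // == (1 - z) * n + z * h
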
////////////////////////////////////////////////////////////////////////////////
// LAYER IMPLEMENTATIONS
//
// Layers are scan-like higher-order functions, which take in cells, and
// transform them to functions of signature
//
// (io_type input, hidden_type hidden, param_type params) -> (io_type, hidden_type)
//
// which can apply the cell over a sequence of inputs, and produce both a new set
// of hidden states, as well as a concatenated output of each step.

template<typename output_type, typename hidden_type>
struct LayerOutput {
  output_type outputs;
  hidden_type final_hidden;
};

template<typename io_type, typename hidden_type, typename param_type>
struct Layer {
  using output_type = LayerOutput<io_type, hidden_type>;

  virtual ~Layer() = default; // This is really dumb, but enables projects with
                              // -Wnon-virtual-dtor to compile...
  virtual output_type operator()(
      const io_type& input,
      const hidden_type& input_hidden,
      const param_type& params) const = 0;
};

template<typename hidden_type, typename cell_params>
struct FullLayer : Layer<Tensor, hidden_type, cell_params> {
  using output_type =
      typename Layer<Tensor, hidden_type, cell_params>::output_type;
  using unstacked_output_type = LayerOutput<std::vector<Tensor>, hidden_type>;

  FullLayer(Cell<hidden_type, cell_params>& cell)
    : cell_(cell) {}

  unstacked_output_type operator()(
      const std::vector<Tensor>& step_inputs,
      const hidden_type& input_hidden,
      const cell_params& params,
      bool pre_compute_input = false) const {
    std::vector<Tensor> step_outputs;
    auto hidden = input_hidden;
    for (const auto& input : step_inputs) {
      hidden = cell_(input, hidden, params, pre_compute_input);
      step_outputs.emplace_back(hidden_as_output(hidden));
    }
    return {step_outputs, hidden};
  }

  output_type operator()(
      const Tensor& inputs,
      const hidden_type& input_hidden,
      const cell_params& params) const override {
    if (inputs.device().is_cpu()) {
      const auto inputs_w = params.linear_ih(inputs);
      auto unstacked_output =
          (*this)(inputs_w.unbind(0), input_hidden, params, true);
      TORCH_CHECK(unstacked_output.outputs.size() > 0, "Expected sequence length to be larger than 0 in RNN");
      return {at::stack(unstacked_output.outputs, 0),
              unstacked_output.final_hidden};
    }
    auto unstacked_output = (*this)(inputs.unbind(0), input_hidden, params);
    TORCH_CHECK(unstacked_output.outputs.size() > 0, "Expected sequence length to be larger than 0 in RNN");
    return {at::stack(unstacked_output.outputs, 0),
            unstacked_output.final_hidden};
  }

  Cell<hidden_type, cell_params>& cell_;
};

template <typename dir_hidden_type, typename cell_params>
struct FullBidirectionalLayer
    : Layer<Tensor, pair_of<dir_hidden_type>, pair_of<cell_params>> {
  using hidden_type = pair_of<dir_hidden_type>;
  using param_type = pair_of<cell_params>;
  using output_type = typename Layer<Tensor, hidden_type, param_type>::output_type;

  FullBidirectionalLayer(Cell<dir_hidden_type, cell_params>& cell)
    : layer_(cell) {}

  output_type operator()(
      const Tensor& input,
      const hidden_type& input_hidden,
      const param_type& params) const override {
    std::vector<Tensor> step_inputs;
    if (input.device().is_cpu()) {
      auto input_w = params.first.linear_ih(input);
      step_inputs = input_w.unbind(0);
      auto fw_result = layer_(
          step_inputs, input_hidden.first, params.first, true);
      TORCH_CHECK(fw_result.outputs.size() > 0, "Expected sequence length to be larger than 0 in RNN");
      auto fw_output = at::stack(fw_result.outputs, 0);
      input_w = params.second.linear_ih(input);
      step_inputs = input_w.unbind(0);
      auto rev_step_inputs = reverse(std::move(step_inputs));
      auto rev_result =
          layer_(rev_step_inputs, input_hidden.second, params.second, true);
      std::reverse(rev_result.outputs.begin(), rev_result.outputs.end());
      auto rev_output = at::stack(rev_result.outputs, 0);
      return {at::cat({fw_output, rev_output}, fw_output.dim() - 1),
              std::make_pair(fw_result.final_hidden, rev_result.final_hidden)};
    }

    step_inputs = input.unbind(0);
    auto fw_result = layer_(step_inputs, input_hidden.first, params.first);
    TORCH_CHECK(fw_result.outputs.size() > 0, "Expected sequence length to be larger than 0 in RNN");
    auto fw_output = at::stack(fw_result.outputs, 0);
    auto rev_step_inputs = reverse(std::move(step_inputs));
    auto rev_result =
        layer_(rev_step_inputs, input_hidden.second, params.second);
    std::reverse(rev_result.outputs.begin(), rev_result.outputs.end());
    auto rev_output = at::stack(rev_result.outputs, 0);
    return {at::cat({fw_output, rev_output}, fw_output.dim() - 1),
            std::make_pair(fw_result.final_hidden, rev_result.final_hidden)};
  }

  std::vector<Tensor> reverse(std::vector<Tensor>&& x) const {
    std::reverse(x.begin(), x.end());
    return std::move(x);
  }

  FullLayer<dir_hidden_type, cell_params> layer_;
};

template<typename hidden_type, typename cell_params>
struct PackedLayer : Layer<PackedSequence, hidden_type, cell_params> {
  using output_type =
      typename Layer<PackedSequence, hidden_type, cell_params>::output_type;

  PackedLayer(Cell<hidden_type, cell_params>& cell)
    : cell_(cell) {}

  output_type operator()(
      const PackedSequence& input,
      const hidden_type& input_hidden,
      const cell_params& params) const override {

    std::vector<at::Tensor> step_outputs;
    std::vector<hidden_type> hiddens;
    int64_t input_offset = 0;
    int64_t num_steps = input.batch_sizes.size(0);
    int64_t* batch_sizes = input.batch_sizes.data_ptr<int64_t>();
    int64_t last_batch_size = batch_sizes[0];

    const Tensor* input_ptr = &input.data;
    bool pre_compute_input = false;
    Tensor input_w;
    if (input.data.device().is_cpu()) {
      input_w = params.linear_ih(input.data);
      input_ptr = &input_w;
      pre_compute_input = true;
    }

    // Batch sizes is a sequence of decreasing lengths, which are offsets
    // into a 1D list of inputs. At every step we slice out batch_size elements,
    // and possibly account for the decrease in the batch size since the last step,
    // which requires us to slice the hidden state (since some sequences
    // are completed now). The sliced parts are also saved, because we will need
    // to return a tensor of final hidden state.
    auto hidden = input_hidden;
    for (const auto i : c10::irange(num_steps)) {
      const int64_t batch_size = batch_sizes[i];
      auto step_input = input_ptr->narrow(0, input_offset, batch_size);
      input_offset += batch_size;
      const int64_t dec = last_batch_size - batch_size;
      if (dec > 0) {
        hiddens.emplace_back(
            hidden_slice(hidden, last_batch_size - dec, last_batch_size));
        hidden = hidden_slice(hidden, 0, last_batch_size - dec);
      }

      last_batch_size = batch_size;
      hidden = cell_(step_input, hidden, params, pre_compute_input);
      step_outputs.push_back(hidden_as_output(hidden));
    }
    hiddens.emplace_back(hidden);
    std::reverse(hiddens.begin(), hiddens.end());

    return {PackedSequence{at::cat(step_outputs, 0), input.batch_sizes},
            hidden_concat(hiddens)};
  }

  Cell<hidden_type, cell_params>& cell_;
};

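// Worked example: batch_sizes = {3, 2, 1} means the packed data holds
// 3 + 2 + 1 = 6 rows. Step 0 consumes rows [0, 3); step 1 consumes rows [3, 5)
// and saves hidden row [2, 3) (its sequence just ended); step 2 consumes row
// [5, 6) and saves hidden row [1, 2). After the loop the remaining hidden row
// is saved too, and the reversed saves concatenate into the final hidden state
// for all three sequences, in order.
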
template<typename hidden_type, typename cell_params>
struct ReversedPackedLayer : Layer<PackedSequence, hidden_type, cell_params> {
  using output_type =
      typename Layer<PackedSequence, hidden_type, cell_params>::output_type;

  ReversedPackedLayer(Cell<hidden_type, cell_params>& cell)
    : cell_(cell) {}

  output_type operator()(
      const PackedSequence& input,
      const hidden_type& input_hidden,
      const cell_params& params) const override {
    std::vector<at::Tensor> step_outputs;
    int64_t input_offset = input.data.size(0);
    int64_t num_steps = input.batch_sizes.size(0);
    int64_t* batch_sizes = input.batch_sizes.data_ptr<int64_t>();
    int64_t last_batch_size = batch_sizes[num_steps - 1];

    const Tensor* input_ptr = &input.data;
    bool pre_compute_input = false;
    Tensor input_w;
    if (input.data.device().is_cpu()) {
      input_w = params.linear_ih(input.data);
      input_ptr = &input_w;
      pre_compute_input = true;
    }

    // Here the situation is similar to that above, except we start out with
    // the smallest batch size (and a small set of hidden states we actually use),
    // and progressively expand the hidden states, as we move backwards over the
    // 1D list of inputs.
    auto hidden = hidden_slice(input_hidden, 0, batch_sizes[num_steps - 1]);
    for (int64_t i = num_steps - 1; i >= 0; --i) {
      const int64_t batch_size = batch_sizes[i];
      const int64_t inc = batch_size - last_batch_size;
      if (inc > 0) {
        hidden = hidden_concat(ArrayRef<hidden_type>{
            hidden, hidden_slice(input_hidden, last_batch_size, batch_size)});
      }
      auto step_input =
          input_ptr->narrow(0, input_offset - batch_size, batch_size);
      input_offset -= batch_size;
      last_batch_size = batch_size;
      hidden = cell_(step_input, hidden, params, pre_compute_input);
      step_outputs.emplace_back(hidden_as_output(hidden));
    }
    std::reverse(step_outputs.begin(), step_outputs.end());
    return {PackedSequence{at::cat(step_outputs, 0), input.batch_sizes},
            hidden};
  }

  Cell<hidden_type, cell_params>& cell_;
};

template <typename dir_hidden_type, typename cell_params>
struct PackedBidirectionalLayer
    : Layer<PackedSequence, pair_of<dir_hidden_type>, pair_of<cell_params>> {
  using hidden_type = pair_of<dir_hidden_type>;
  using param_type = pair_of<cell_params>;
  using output_type =
      typename Layer<PackedSequence, hidden_type, param_type>::output_type;

  PackedBidirectionalLayer(Cell<dir_hidden_type, cell_params>& cell)
    : layer_(cell), rev_layer_(cell) {}

  output_type operator()(
      const PackedSequence& input,
      const hidden_type& input_hidden,
      const param_type& params) const override {
    auto fw_result = layer_(input, input_hidden.first, params.first);
    auto rev_result = rev_layer_(input, input_hidden.second, params.second);
    PackedSequence output{
        at::cat({fw_result.outputs.data, rev_result.outputs.data}, -1),
        input.batch_sizes};
    return {output,
            std::make_pair(fw_result.final_hidden, rev_result.final_hidden)};
  }

  PackedLayer<dir_hidden_type, cell_params> layer_;
  ReversedPackedLayer<dir_hidden_type, cell_params> rev_layer_;
};

////////////////////////////////////////////////////////////////////////////////
// apply_layer_stack
//
// Layers are convenient, but in reality we often want to stack them. This little
// helper manages slicing of all inputs and parameters, and repeatedly feeds them
// into the given layer. Returns the last layer's outputs, and a vector of final
// hidden states produced at each level.

Tensor dropout(const Tensor& input, double p) {
  return at::dropout(input, p, /*train=*/true);
}

PackedSequence dropout(const PackedSequence& input, double p) {
  return {at::dropout(input.data, p, /*train=*/true), input.batch_sizes};
}

template<typename io_type, typename hidden_type, typename weight_type>
LayerOutput<io_type, std::vector<hidden_type>>
apply_layer_stack(const Layer<io_type, hidden_type, weight_type>& layer, const io_type& input,
                  const std::vector<hidden_type>& hiddens, const std::vector<weight_type>& weights,
                  int64_t num_layers, double dropout_p, bool train) {
  TORCH_CHECK(num_layers == (int64_t)hiddens.size(), "Expected more hidden states in stacked_rnn");
  TORCH_CHECK(num_layers == (int64_t)weights.size(), "Expected more weights in stacked_rnn");

  auto layer_input = input;
  auto hidden_it = hiddens.begin();
  auto weight_it = weights.begin();
  std::vector<hidden_type> final_hiddens;
  for (const auto l : c10::irange(num_layers)) {
    auto layer_output = layer(layer_input, *(hidden_it++), *(weight_it++));
    final_hiddens.push_back(layer_output.final_hidden);
    layer_input = layer_output.outputs;

    if (dropout_p != 0 && train && l < num_layers - 1) {
      layer_input = dropout(layer_input, dropout_p);
    }
  }

  return {layer_input, final_hiddens};
}

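// E.g. with num_layers = 3 and train = true, the data flow is
//   input -> layer0 -> dropout -> layer1 -> dropout -> layer2 -> outputs
// (no dropout after the last layer), while final_hiddens collects the final
// hidden state of each of the three layers.
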
////////////////////////////////////////////////////////////////////////////////
// HELPERS SIMPLIFYING DISPATCH TO FUNCTIONS ABOVE
////////////////////////////////////////////////////////////////////////////////

template<typename CellType, template<typename,typename> class LayerT, template<typename,typename> class BidirLayerT, typename cell_params, typename io_type>
LayerOutput<io_type, std::vector<typename CellType::hidden_type>> _rnn_impl(
      const io_type& input,
      const std::vector<cell_params>& params,
      const std::vector<typename CellType::hidden_type>& hiddens,
      int64_t num_layers, double dropout_p, bool train, bool bidirectional) {
  using hidden_type = typename CellType::hidden_type;
  CellType cell;
  if (bidirectional) {
    using BidirLayer = BidirLayerT<hidden_type, cell_params>;
    auto bidir_result = apply_layer_stack(BidirLayer{cell}, input, pair_vec(hiddens), pair_vec(params), num_layers, dropout_p, train);
    return {bidir_result.outputs, unpair_vec(std::move(bidir_result.final_hidden))};
  } else {
    return apply_layer_stack(LayerT<hidden_type,cell_params>{cell}, input, hiddens, params, num_layers, dropout_p, train);
  }
}

template<typename CellType, template<typename,typename> class LayerT, template<typename,typename> class BidirLayerT, typename cell_params, typename io_type>
std::tuple<io_type, Tensor> _rnn_impl_with_concat(
      const io_type& input,
      const std::vector<cell_params>& params,
      const std::vector<typename CellType::hidden_type>& hiddens,
      int64_t num_layers, double dropout_p, bool train, bool bidirectional) {
  auto result = _rnn_impl<CellType, LayerT, BidirLayerT>(input, params, hiddens, num_layers, dropout_p, train, bidirectional);
  return std::make_tuple(std::move(result.outputs), at::stack(result.final_hidden, 0));
}

template<template<typename,typename> class LayerT, template<typename,typename> class BidirLayerT, typename cell_params, typename io_type>
std::tuple<io_type, Tensor, Tensor> _lstm_impl(
      const io_type& input,
      const std::vector<cell_params>& params, const Tensor& hx, const Tensor& cx,
      int64_t num_layers, double dropout_p, bool train, bool bidirectional) {
  // It's much more useful for us to work on lists of pairs of hx and cx for each layer, so we need
  // to transpose a pair of those tensors.
  auto layer_hx = hx.unbind(0);
  auto layer_cx = cx.unbind(0);
  int64_t total_layers = layer_hx.size();
  std::vector<typename LSTMCell<cell_params>::hidden_type> hiddens;
  hiddens.reserve(total_layers);
  for (const auto i : c10::irange(total_layers)) {
    hiddens.emplace_back(std::move(layer_hx[i]), std::move(layer_cx[i]));
  }

  auto result = _rnn_impl<LSTMCell<cell_params>, LayerT, BidirLayerT>(input, params, hiddens, num_layers, dropout_p, train, bidirectional);

  // Now, we need to reverse the transposition we performed above.
  std::vector<Tensor> hy, cy;
  hy.reserve(total_layers); cy.reserve(total_layers);
  for (auto & hidden : result.final_hidden) {
    hy.push_back(std::move(std::get<0>(hidden)));
    cy.push_back(std::move(std::get<1>(hidden)));
  }

  return std::make_tuple(std::move(result.outputs), at::stack(hy, 0), at::stack(cy, 0));
}
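
// For example, for a 2-layer unidirectional LSTM, hx and cx each have shape
// (2, batch, hidden), and the loop above builds
//   hiddens = {(hx[0], cx[0]), (hx[1], cx[1])}
// before hy and cy are re-stacked from the final hidden pairs at the end.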

} // anonymous namespace

bool _use_cudnn_rnn_flatten_weight() {
  return detail::getCUDAHooks().compiledWithCuDNN();
}

// NB: This is a (composite) wrapper for _thnn_fused_lstm_cell_backward_impl.
//     It duplicates the outputs of this function so the non-composite version doesn't have to.
//     The point is to avoid triggering TensorImpl use count asserts in debug mode
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _thnn_fused_lstm_cell_backward(const std::optional<Tensor>& grad_hy_opt, const std::optional<Tensor>& grad_cy_opt,
      const Tensor& cx, const Tensor& cy,
      const Tensor& workspace, bool has_bias) {
  TORCH_INTERNAL_ASSERT(!GradMode::is_enabled());
  auto ret = at::_thnn_fused_lstm_cell_backward_impl(grad_hy_opt, grad_cy_opt, cx, cy, workspace, has_bias);
  return std::make_tuple(std::get<0>(ret), std::get<0>(ret), std::get<1>(ret), std::get<2>(ret), std::get<2>(ret));
}


////////////////////////////////////////////////////////////////////////////////
// PUBLIC FUNCTIONS
////////////////////////////////////////////////////////////////////////////////

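// ONE_HIDDEN_RNN defines two overloads of NAME: one taking a regular (optionally
// batch-first) input tensor and one taking packed (data, batch_sizes) input. Both
// try the cuDNN and MIOpen fused paths first and otherwise fall back to the
// generic cell-based implementation above.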
#define ONE_HIDDEN_RNN(NAME, CELL)                                          \
  DEFINE_DISPATCH(NAME##_cudnn_stub);                                       \
  DEFINE_DISPATCH(NAME##_miopen_stub);                                      \
  DEFINE_DISPATCH(NAME##_packed_cudnn_stub);                                \
  DEFINE_DISPATCH(NAME##_packed_miopen_stub);                               \
  REGISTER_NO_CPU_DISPATCH(NAME##_cudnn_stub);                              \
  REGISTER_NO_CPU_DISPATCH(NAME##_miopen_stub);                             \
  REGISTER_NO_CPU_DISPATCH(NAME##_packed_cudnn_stub);                       \
  REGISTER_NO_CPU_DISPATCH(NAME##_packed_miopen_stub);                      \
                                                                            \
  std::tuple<Tensor, Tensor> NAME(                                          \
      const Tensor& _input,                                                 \
      const Tensor& hx,                                                     \
      TensorList _params,                                                   \
      bool has_biases,                                                      \
      int64_t num_layers,                                                   \
      double dropout_p,                                                     \
      bool train,                                                           \
      bool bidirectional,                                                   \
      bool batch_first) {                                                   \
    if (at::cudnn_is_acceptable(_input)) {                                  \
      Tensor output, hy;                                                    \
      NAME##_cudnn_stub(                                                    \
          _input.device().type(),                                           \
          output,                                                           \
          hy,                                                               \
          _input,                                                           \
          hx,                                                               \
          _params,                                                          \
          has_biases,                                                       \
          num_layers,                                                       \
          dropout_p,                                                        \
          train,                                                            \
          bidirectional,                                                    \
          batch_first);                                                     \
      return std::make_tuple(std::move(output), std::move(hy));             \
    }                                                                       \
    if (use_miopen(_input, dropout_p)) {                                    \
      Tensor output, hy;                                                    \
      NAME##_miopen_stub(                                                   \
          _input.device().type(),                                           \
          output,                                                           \
          hy,                                                               \
          _input,                                                           \
          hx,                                                               \
          _params,                                                          \
          has_biases,                                                       \
          num_layers,                                                       \
          dropout_p,                                                        \
          train,                                                            \
          bidirectional,                                                    \
          batch_first);                                                     \
      return std::make_tuple(std::move(output), std::move(hy));             \
    }                                                                       \
    check_attributes(_input, _params, hx);                                  \
    auto input = batch_first ? _input.transpose(0, 1) : _input;             \
    auto params = gather_params(_params, has_biases);                       \
    auto results =                                                          \
        _rnn_impl_with_concat<CELL, FullLayer, FullBidirectionalLayer>(     \
            input,                                                          \
            params,                                                         \
            hx.unbind(0),                                                   \
            num_layers,                                                     \
            dropout_p,                                                      \
            train,                                                          \
            bidirectional);                                                 \
    if (batch_first) {                                                      \
      std::get<0>(results).transpose_(0, 1);                                \
    }                                                                       \
    return results;                                                         \
  }                                                                         \
                                                                            \
  std::tuple<Tensor, Tensor> NAME(                                          \
      const Tensor& data,                                                   \
      const Tensor& batch_sizes,                                            \
      const Tensor& hx,                                                     \
      TensorList _params,                                                   \
      bool has_biases,                                                      \
      int64_t num_layers,                                                   \
      double dropout_p,                                                     \
      bool train,                                                           \
      bool bidirectional) {                                                 \
    if (at::cudnn_is_acceptable(data)) {                                    \
      Tensor output, hy;                                                    \
      NAME##_packed_cudnn_stub(                                             \
          data.device().type(),                                             \
          output,                                                           \
          hy,                                                               \
          data,                                                             \
          batch_sizes,                                                      \
          hx,                                                               \
          _params,                                                          \
          has_biases,                                                       \
          num_layers,                                                       \
          dropout_p,                                                        \
          train,                                                            \
          bidirectional);                                                   \
      return std::make_tuple(std::move(output), std::move(hy));             \
    }                                                                       \
    if (use_miopen(data, dropout_p)) {                                      \
      Tensor output, hy;                                                    \
      NAME##_packed_miopen_stub(                                            \
          data.device().type(),                                             \
          output,                                                           \
          hy,                                                               \
          data,                                                             \
          batch_sizes,                                                      \
          hx,                                                               \
          _params,                                                          \
          has_biases,                                                       \
          num_layers,                                                       \
          dropout_p,                                                        \
          train,                                                            \
          bidirectional);                                                   \
      return std::make_tuple(std::move(output), std::move(hy));             \
    }                                                                       \
    PackedSequence input{data, batch_sizes};                                \
    auto params = gather_params(_params, has_biases);                       \
    auto result =                                                           \
        _rnn_impl_with_concat<CELL, PackedLayer, PackedBidirectionalLayer>( \
            input,                                                          \
            params,                                                         \
            hx.unbind(0),                                                   \
            num_layers,                                                     \
            dropout_p,                                                      \
            train,                                                          \
            bidirectional);                                                 \
    auto& packed_output = std::get<0>(result);                              \
    return std::make_tuple(                                                 \
        std::move(packed_output.data), std::move(std::get<1>(result)));     \
  }
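
// ONE_HIDDEN_QRNN is the quantized analogue: parameters arrive as opaque
// CellParamsBase objects rather than plain weight tensors, and there are no
// cuDNN/MIOpen fast paths.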
#define ONE_HIDDEN_QRNN(NAME, CELL)                                         \
  static std::tuple<Tensor, Tensor> NAME##_input(                           \
      const Tensor& _input,                                                 \
      const Tensor& hx,                                                     \
      c10::List<c10::intrusive_ptr<CellParamsBase>> _params,                \
      bool has_biases,                                                      \
      int64_t num_layers,                                                   \
      double dropout_p,                                                     \
      bool train,                                                           \
      bool bidirectional,                                                   \
      bool batch_first) {                                                   \
    std::vector<QRNNCellParamsWrapper> params;                              \
    for (c10::intrusive_ptr<CellParamsBase> x : _params) {                  \
      params.emplace_back(std::move(x));                                    \
    }                                                                       \
    auto input = batch_first ? _input.transpose(0, 1) : _input;             \
    auto results =                                                          \
        _rnn_impl_with_concat<CELL, FullLayer, FullBidirectionalLayer>(     \
            input,                                                          \
            params,                                                         \
            hx.unbind(0),                                                   \
            num_layers,                                                     \
            dropout_p,                                                      \
            train,                                                          \
            bidirectional);                                                 \
    if (batch_first) {                                                      \
      std::get<0>(results).transpose_(0, 1);                                \
    }                                                                       \
    return results;                                                         \
  }                                                                         \
                                                                            \
  static std::tuple<Tensor, Tensor> NAME##_data(                            \
      const Tensor& data,                                                   \
      const Tensor& batch_sizes,                                            \
      const Tensor& hx,                                                     \
      c10::List<c10::intrusive_ptr<CellParamsBase>> _params,                \
      bool has_biases,                                                      \
      int64_t num_layers,                                                   \
      double dropout_p,                                                     \
      bool train,                                                           \
      bool bidirectional) {                                                 \
    std::vector<QRNNCellParamsWrapper> params;                              \
    for (c10::intrusive_ptr<CellParamsBase> x : _params) {                  \
      params.emplace_back(std::move(x));                                    \
    }                                                                       \
    PackedSequence input{data, batch_sizes};                                \
    auto result =                                                           \
        _rnn_impl_with_concat<CELL, PackedLayer, PackedBidirectionalLayer>( \
            input,                                                          \
            params,                                                         \
            hx.unbind(0),                                                   \
            num_layers,                                                     \
            dropout_p,                                                      \
            train,                                                          \
            bidirectional);                                                 \
    auto& packed_output = std::get<0>(result);                              \
    return std::make_tuple(                                                 \
        std::move(packed_output.data), std::move(std::get<1>(result)));     \
  }

ONE_HIDDEN_RNN(gru, GRUCell<CellParams>)
ONE_HIDDEN_QRNN(quantized_gru, GRUCell<QRNNCellParamsWrapper>)
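
// Example of calling the generated entry point from C++ (a minimal sketch; names
// and shapes are illustrative, assuming a single-layer, unidirectional GRU whose
// flat weights {w_ih, w_hh, b_ih, b_hh} are collected in `params`):
//   at::Tensor input = at::randn({seq_len, batch, input_size});
//   at::Tensor h0 = at::zeros({1, batch, hidden_size});
//   auto [output, hy] = at::gru(input, h0, params, /*has_biases=*/true,
//                               /*num_layers=*/1, /*dropout=*/0.0, /*train=*/false,
//                               /*bidirectional=*/false, /*batch_first=*/false);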

// BC wrappers for quantized_gru

static std::tuple<Tensor, Tensor> quantized_gru_input_legacy(
    const Tensor& _input,
    const Tensor& hx,
    c10::List<at::Tensor> _params,
    bool has_biases,
    int64_t num_layers,
    double dropout_p,
    bool train,
    bool bidirectional,
    bool batch_first) {
  TORCH_CHECK(
      false,
      "torch.quantized_gru with List[Tensor] for parameters is "
      "no longer supported. Please re-export your model "
      "using the newer definitions in torch.jit.quantized");
}

static std::tuple<Tensor, Tensor> quantized_gru_data_legacy(
    const Tensor& data,
    const Tensor& batch_sizes,
    const Tensor& hx,
    c10::List<at::Tensor> _params,
    bool has_biases,
    int64_t num_layers,
    double dropout_p,
    bool train,
    bool bidirectional) {
  TORCH_CHECK(
      false,
      "torch.quantized_gru with List[Tensor] for parameters is "
      "no longer supported. Please re-export your model "
      "using the newer definitions in torch.jit.quantized");
}

using tanh_cell_type = SimpleCell<tanh_f, CellParams>;
ONE_HIDDEN_RNN(rnn_tanh, tanh_cell_type)
using relu_cell_type = SimpleCell<relu_f, CellParams>;
ONE_HIDDEN_RNN(rnn_relu, relu_cell_type)

DEFINE_DISPATCH(lstm_cudnn_stub);
DEFINE_DISPATCH(lstm_packed_cudnn_stub);
DEFINE_DISPATCH(lstm_miopen_stub);
DEFINE_DISPATCH(lstm_packed_miopen_stub);
DEFINE_DISPATCH(lstm_mkldnn_stub);
REGISTER_NO_CPU_DISPATCH(lstm_cudnn_stub);
REGISTER_NO_CPU_DISPATCH(lstm_packed_cudnn_stub);
REGISTER_NO_CPU_DISPATCH(lstm_miopen_stub);
REGISTER_NO_CPU_DISPATCH(lstm_packed_miopen_stub);

std::tuple<Tensor, Tensor, Tensor> lstm(
      const Tensor& _input, TensorList hx,
      TensorList _params, bool has_biases,
      int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
  TORCH_CHECK(hx.size() == 2, "lstm expects two hidden states");
  if (at::cudnn_is_acceptable(_input)) {
    Tensor output, hy, cy;
    lstm_cudnn_stub(_input.device().type(), output, hy, cy, _input, hx, _params, has_biases,
            num_layers, dropout_p, train, bidirectional, batch_first);
    return std::make_tuple(std::move(output), std::move(hy), std::move(cy));
  }
#ifdef USE_MPS
  if (_input.is_mps()) {
    // _lstm_mps returns additional tensors beyond (output, hy, cy); only the
    // first three form the public return value.
    auto output = at::_lstm_mps(_input, hx, _params, has_biases,
            num_layers, dropout_p, train, bidirectional, batch_first);
    return std::make_tuple(std::move(std::get<0>(output)),
                           std::move(std::get<1>(output)),
                           std::move(std::get<2>(output)));
  }
#endif
  // If the hidden and cell states differ in size, projections are in use.
  bool has_projections = (hx[0].sym_size(2) != hx[1].sym_size(2));
  if (use_miopen(_input, dropout_p)) {
    if (!has_projections) {
      Tensor output, hy, cy;
      lstm_miopen_stub(_input.device().type(), output, hy, cy, _input, hx, _params, has_biases,
                num_layers, dropout_p, train, bidirectional, batch_first);
      return std::make_tuple(std::move(output), std::move(hy), std::move(cy));
    } else {
      TORCH_WARN_ONCE(
          "LSTM with projections is not supported with MIOpen. Using default implementation.");
    }
  }

  if (use_mkldnn(_input, _params, hx)) {
    if (!has_projections) {
      if (hx[0].unsafeGetTensorImpl()->has_symbolic_sizes_strides()) {
        TORCH_WARN_ONCE(
          "LSTM with symbolic sizes and strides is not supported with oneDNN. Using default implementation.");
      } else {
        Tensor output, hy, cy;
        lstm_mkldnn_stub(_input.device().type(), output, hy, cy, _input, hx, _params, has_biases,
            num_layers, dropout_p, train, bidirectional, batch_first);
        return std::make_tuple(std::move(output), std::move(hy), std::move(cy));
      }
    } else {
      TORCH_WARN_ONCE(
          "LSTM with projections is not supported with oneDNN. Using default implementation.");
    }
  }

  check_attributes(_input, _params, hx);
  auto input = batch_first ? _input.transpose(0, 1) : _input;
  auto params = gather_params(_params, has_biases, has_projections);
  auto results = _lstm_impl<FullLayer, FullBidirectionalLayer>(
      input, params, hx[0], hx[1], num_layers, dropout_p, train, bidirectional);
  if (batch_first) {
    std::get<0>(results) = std::get<0>(results).transpose(0, 1);
  }
  return results;
}
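
// Example usage from C++ (a minimal sketch; shapes are illustrative, assuming a
// single-layer, unidirectional LSTM without projections, with flat weights
// {w_ih, w_hh, b_ih, b_hh} collected in `params`):
//   at::Tensor input = at::randn({seq_len, batch, input_size});
//   at::Tensor h0 = at::zeros({1, batch, hidden_size});
//   at::Tensor c0 = at::zeros({1, batch, hidden_size});
//   auto [output, hy, cy] = at::lstm(input, {h0, c0}, params, /*has_biases=*/true,
//                                    /*num_layers=*/1, /*dropout=*/0.0, /*train=*/false,
//                                    /*bidirectional=*/false, /*batch_first=*/false);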

std::tuple<Tensor, Tensor, Tensor> lstm(
      const Tensor& data, const Tensor& batch_sizes, TensorList hx,
      TensorList _params, bool has_biases,
      int64_t num_layers, double dropout_p, bool train, bool bidirectional) {
  TORCH_CHECK(hx.size() == 2, "lstm expects two hidden states");
  if (at::cudnn_is_acceptable(data)) {
    Tensor output, hy, cy;
    lstm_packed_cudnn_stub(data.device().type(), output, hy, cy, data, batch_sizes, hx,
            _params, has_biases, num_layers, dropout_p, train, bidirectional);
    return std::make_tuple(std::move(output), std::move(hy), std::move(cy));
  }
  // If the hidden and cell states differ in size, projections are in use.
  bool has_projections = (hx[0].size(2) != hx[1].size(2));
  if (use_miopen(data, dropout_p)) {
    if (!has_projections) {
      Tensor output, hy, cy;
      lstm_packed_miopen_stub(data.device().type(), output, hy, cy, data, batch_sizes, hx,
              _params, has_biases, num_layers, dropout_p, train, bidirectional);
      return std::make_tuple(std::move(output), std::move(hy), std::move(cy));
    } else {
      TORCH_WARN_ONCE(
          "LSTM with projections is not supported with MIOpen. Using default implementation.");
    }
  }

  PackedSequence input { data, batch_sizes };
  auto params = gather_params(_params, has_biases, has_projections);
  auto result = _lstm_impl<PackedLayer, PackedBidirectionalLayer>(
      input, params, hx[0], hx[1], num_layers, dropout_p, train, bidirectional);
  auto & packed_output = std::get<0>(result);
  return std::make_tuple(std::move(packed_output.data),
                         std::move(std::get<1>(result)),
                         std::move(std::get<2>(result)));
}

std::tuple<Tensor, Tensor> lstm_cell(
    const Tensor& input, TensorList hx,
    const Tensor& w_ih, const Tensor& w_hh, const std::optional<Tensor>& b_ih_opt, const std::optional<Tensor>& b_hh_opt) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> b_ih_maybe_owned = at::borrow_from_optional_tensor(b_ih_opt);
  const Tensor& b_ih = *b_ih_maybe_owned;
  const Tensor& b_hh = c10::value_or_else(b_hh_opt, [] {return Tensor();});

  TORCH_CHECK(hx.size() == 2, "lstm_cell expects two hidden states");
  check_rnn_cell_forward_input(input, w_ih.sym_size(1));
  auto hidden_size = w_hh.sym_size(1);
  check_rnn_cell_forward_hidden(input, hx[0], hidden_size, 0);
  check_rnn_cell_forward_hidden(input, hx[1], std::move(hidden_size), 1);
  static at::Tensor undefined;
  return LSTMCell<CellParams>{}(input, std::make_tuple(hx[0], hx[1]), CellParams{w_ih, w_hh, b_ih, b_hh, undefined});
}
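
// Example of a single step (a minimal sketch; the 4*hidden_size rows of w_ih/w_hh
// follow the input/forget/cell/output gate packing that unsafe_chunk(4, 1) assumes
// in the backward below; h0 and c0 are {batch, hidden_size}):
//   at::Tensor x = at::randn({batch, input_size});
//   at::Tensor w_ih = at::randn({4 * hidden_size, input_size});
//   at::Tensor w_hh = at::randn({4 * hidden_size, hidden_size});
//   auto [hy, cy] = at::lstm_cell(x, {h0, c0}, w_ih, w_hh, b_ih, b_hh);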

std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor>
_thnn_differentiable_lstm_cell_backward(const std::optional<Tensor>& grad_hy_opt, const std::optional<Tensor>& grad_cy_opt,
    const Tensor& input_gates,
    const Tensor& hidden_gates, const std::optional<Tensor>& input_bias_opt, const std::optional<Tensor>& hidden_bias_opt,
    const Tensor& cx,
    const Tensor& cy) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> grad_hy_maybe_owned = at::borrow_from_optional_tensor(grad_hy_opt);
  const Tensor& grad_hy = *grad_hy_maybe_owned;
  const Tensor& grad_cy = c10::value_or_else(grad_cy_opt, [] {return Tensor();});
  const Tensor& input_bias = c10::value_or_else(input_bias_opt, [] {return Tensor();});
  const Tensor& hidden_bias = c10::value_or_else(hidden_bias_opt, [] {return Tensor();});

  if (!grad_hy.defined() && !grad_cy.defined()) {
    return std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor>();
  }
  Tensor gates = input_gates + hidden_gates;
  if (input_bias.defined()) {
    gates = gates + input_bias;
  }
  if (hidden_bias.defined()) {
    gates = gates + hidden_bias;
  }
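  // Recompute the forward gate activations from the (bias-adjusted) pre-activation
  // gates, packed in chunks of four: input (i), forget (f), candidate cell (c),
  // output (o).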
  auto chunked_gates = gates.unsafe_chunk(4, 1);
  Tensor i = chunked_gates[0].sigmoid();
  Tensor f = chunked_gates[1].sigmoid();
  Tensor c = chunked_gates[2].tanh();
  Tensor o = chunked_gates[3].sigmoid();

  Tensor gcx = cy.tanh();
  Tensor gog;
  TORCH_INTERNAL_ASSERT((grad_hy.defined() || grad_cy.defined()), "either gradient with respect to hy or cy should be defined");
  if (grad_hy.defined()) {
    gog = grad_hy * gcx;
    gog = at::sigmoid_backward(gog, o);
    gcx = at::tanh_backward(grad_hy * o, gcx);
    if (grad_cy.defined()) {
      gcx = gcx + grad_cy;
    }
  } else if (grad_cy.defined()) {
    gog = at::zeros_like(cx, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
    gcx = grad_cy;
  }
  Tensor gig = gcx * c;
  Tensor gfg = gcx * cx;
  Tensor gcg = gcx * i;
  gcx = gcx * f;
  gig = at::sigmoid_backward(gig, i);
  gfg = at::sigmoid_backward(gfg, f);
  gcg = at::tanh_backward(gcg, c);
  Tensor grad_gates = at::cat({std::move(gig), std::move(gfg), std::move(gcg), std::move(gog)}, 1);
  Tensor grad_bias = input_bias.defined() ? grad_gates.sum(0, /*keepdim=*/false) : at::Tensor{};
  return std::make_tuple(grad_gates, grad_gates, std::move(gcx), grad_bias, grad_bias);
}

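// Differentiable GRU cell backward. The forward being differentiated uses
// PyTorch's gate convention:
//   r = sigmoid(i_r + h_r),  z = sigmoid(i_i + h_i)   (z is "ig" below)
//   n = tanh(i_n + r * h_n)
//   h' = (1 - z) * n + z * h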
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _thnn_differentiable_gru_cell_backward(
    const Tensor& grad_hy,
    const Tensor& input_gates,
    const Tensor& hidden_gates,
    const Tensor& hx, const std::optional<Tensor>& input_bias_opt, const std::optional<Tensor>& hidden_bias_opt) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> input_bias_maybe_owned = at::borrow_from_optional_tensor(input_bias_opt);
  const Tensor& input_bias = *input_bias_maybe_owned;
  const Tensor& hidden_bias = c10::value_or_else(hidden_bias_opt, [] {return Tensor();});

  Tensor in_g = input_gates;
  Tensor h_g = hidden_gates;
  if (input_bias.defined()) {
    in_g = in_g + input_bias;
  }
  if (hidden_bias.defined()) {
    h_g = h_g + hidden_bias;
  }
  auto chunked_input_gates = in_g.unsafe_chunk(3, 1);
  Tensor ir = chunked_input_gates[0];
  Tensor ii = chunked_input_gates[1];
  Tensor in = chunked_input_gates[2];
  auto chunked_hidden_gates = h_g.unsafe_chunk(3, 1);
  Tensor hr = chunked_hidden_gates[0];
  Tensor hi = chunked_hidden_gates[1];
  Tensor hn = chunked_hidden_gates[2];
  Tensor rg = (ir + hr).sigmoid();
  Tensor ig = (ii + hi).sigmoid();
  Tensor grad_hx = grad_hy * ig;
  Tensor ng = (in + rg * hn).tanh();
  Tensor gig = at::sigmoid_backward(grad_hy * (hx - ng), ig);
  Tensor gin = at::tanh_backward(grad_hy * (1 - ig), ng);
  Tensor ghn = gin * rg;
  Tensor grg = at::sigmoid_backward(gin * hn, rg);
  Tensor grad_input_gates = at::cat({grg, gig, std::move(gin)}, 1);
  Tensor grad_hidden_gates = at::cat({std::move(grg), std::move(gig), std::move(ghn)}, 1);
  Tensor grad_input_bias = input_bias.defined() ? grad_input_gates.sum(0, /*keepdim=*/false) : at::Tensor{};
  Tensor grad_hidden_bias = hidden_bias.defined() ? grad_hidden_gates.sum(0, /*keepdim=*/false) : at::Tensor{};
  return std::make_tuple(std::move(grad_input_gates), std::move(grad_hidden_gates),
                         std::move(grad_hx), std::move(grad_input_bias), std::move(grad_hidden_bias));
}

Tensor gru_cell(
    const Tensor& input, const Tensor& hx,
    const Tensor& w_ih, const Tensor& w_hh, const std::optional<Tensor>& b_ih_opt, const std::optional<Tensor>& b_hh_opt) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> b_ih_maybe_owned = at::borrow_from_optional_tensor(b_ih_opt);
  const Tensor& b_ih = *b_ih_maybe_owned;
  const Tensor& b_hh = c10::value_or_else(b_hh_opt, [] {return Tensor();});

  check_rnn_cell_forward_input(input, w_ih.size(1));
  check_rnn_cell_forward_hidden(input, hx, w_hh.size(1), 0);
  static at::Tensor undefined;
  return GRUCell<CellParams>{}(input, hx, CellParams{w_ih, w_hh, b_ih, b_hh, undefined});
}
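
// e.g., for a single step (a minimal sketch; w_ih/w_hh pack the r, z, n gates,
// matching the unsafe_chunk(3, 1) in the backward above; h0 is {batch, hidden_size}):
//   at::Tensor x = at::randn({batch, input_size});
//   at::Tensor w_ih = at::randn({3 * hidden_size, input_size});
//   at::Tensor w_hh = at::randn({3 * hidden_size, hidden_size});
//   at::Tensor hy = at::gru_cell(x, h0, w_ih, w_hh, b_ih, b_hh);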

Tensor rnn_tanh_cell(
    const Tensor& input, const Tensor& hx,
    const Tensor& w_ih, const Tensor& w_hh, const std::optional<Tensor>& b_ih_opt, const std::optional<Tensor>& b_hh_opt) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> b_ih_maybe_owned = at::borrow_from_optional_tensor(b_ih_opt);
  const Tensor& b_ih = *b_ih_maybe_owned;
  const Tensor& b_hh = c10::value_or_else(b_hh_opt, [] {return Tensor();});

  static at::Tensor undefined;
  check_rnn_cell_forward_input(input, w_ih.size(1));
  check_rnn_cell_forward_hidden(input, hx, w_hh.size(1), 0);
  return SimpleCell<tanh_f, CellParams>{}(input, hx, CellParams{w_ih, w_hh, b_ih, b_hh, undefined});
}

Tensor rnn_relu_cell(
    const Tensor& input, const Tensor& hx,
    const Tensor& w_ih, const Tensor& w_hh, const std::optional<Tensor>& b_ih_opt, const std::optional<Tensor>& b_hh_opt) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> b_ih_maybe_owned = at::borrow_from_optional_tensor(b_ih_opt);
  const Tensor& b_ih = *b_ih_maybe_owned;
  const Tensor& b_hh = c10::value_or_else(b_hh_opt, [] {return Tensor();});

  static at::Tensor undefined;
  check_rnn_cell_forward_input(input, w_ih.size(1));
  check_rnn_cell_forward_hidden(input, hx, w_hh.size(1), 0);
  return SimpleCell<relu_f, CellParams>{}(input, hx, CellParams{w_ih, w_hh, b_ih, b_hh, undefined});
}

// Quantized implementations
//
// These implementations use FBGEMM to do the i2h and h2h linear layers with
// an int8 or float16 quantized weight. This is advantageous in small-batch-size
// scenarios where runtime is dominated by memory fetches of the weight matrix.

static std::tuple<Tensor, Tensor, Tensor> quantized_lstm_input(
    const Tensor& _input,
    c10::List<at::Tensor> hx_,
    c10::List<c10::intrusive_ptr<CellParamsBase>> _params_,
    bool has_biases,
    int64_t num_layers,
    double dropout_p,
    bool train,
    bool bidirectional,
    bool batch_first,
    std::optional<ScalarType> dtype,
    bool use_dynamic) {
  auto hx = hx_.vec();
  std::vector<QRNNCellParamsWrapper> params;
  params.reserve(_params_.size());
  for (const auto& param : _params_) {
    params.emplace_back(static_cast<c10::intrusive_ptr<CellParamsBase>>(param));
  }
  TORCH_CHECK(hx.size() == 2, "lstm expects two hidden states");
  TORCH_CHECK(hx[0].size(2) == hx[1].size(2), "quantized LSTM with projections is not supported");
  auto result_dtype = dtype.has_value() ? dtype.value() : at::kChar;
  auto input = batch_first ? _input.transpose(0, 1) : _input;
  TORCH_CHECK(has_biases, "quantized LSTM requires biases");
  TORCH_CHECK(
      result_dtype == at::kChar || result_dtype == at::kQInt8 ||
          result_dtype == at::kHalf,
      "dtype is not supported");

  auto results = _lstm_impl<FullLayer, FullBidirectionalLayer>(
        input, params, hx[0], hx[1], num_layers,
        dropout_p, train, bidirectional);

  if (batch_first) {
    std::get<0>(results) = std::get<0>(results).transpose(0, 1);
  }
  return results;
}

// BC wrappers for quantized_lstm

static std::tuple<Tensor, Tensor, Tensor> quantized_lstm_input_legacy(
    const Tensor& _input,
    c10::List<at::Tensor> hx_,
    c10::List<at::Tensor> _params_,
    bool has_biases,
    int64_t num_layers,
    double dropout_p,
    bool train,
    bool bidirectional,
    bool batch_first,
    std::optional<ScalarType> dtype,
    bool use_dynamic) {
  TORCH_CHECK(
      false,
      "torch.quantized_lstm with List[Tensor] for parameters is "
      "no longer supported. Please re-export your model "
      "using the newer definitions in torch.jit.quantized");
}

static std::tuple<Tensor, Tensor, Tensor> quantized_lstm_data(
    const Tensor& data,
    const Tensor& batch_sizes,
    const c10::List<at::Tensor>& hx_,
    const c10::List<c10::intrusive_ptr<CellParamsBase>>& _params_,
    bool has_biases,
    int64_t num_layers,
    double dropout_p,
    bool train,
    bool bidirectional,
    std::optional<ScalarType> dtype,
    bool use_dynamic) {
  auto hx = hx_.vec();
  std::vector<QRNNCellParamsWrapper> params;
  params.reserve(_params_.size());
  for (const auto& param : _params_) {
    params.emplace_back(static_cast<c10::intrusive_ptr<CellParamsBase>>(param));
  }
  TORCH_CHECK(hx.size() == 2, "lstm expects two hidden states");
  TORCH_CHECK(hx[0].size(2) == hx[1].size(2), "quantized LSTM with projections is not supported");

  PackedSequence input { data, batch_sizes };
  auto results = _lstm_impl<PackedLayer, PackedBidirectionalLayer>(
        input, params, hx[0], hx[1], num_layers,
        dropout_p, train, bidirectional);
  auto & packed_output = std::get<0>(results);
  return std::make_tuple(std::move(packed_output.data),
                         std::move(std::get<1>(results)),
                         std::move(std::get<2>(results)));
}

static std::tuple<Tensor, Tensor, Tensor> quantized_lstm_data_legacy(
    const Tensor& data,
    const Tensor& batch_sizes,
    c10::List<at::Tensor> hx_,
    c10::List<at::Tensor> _params_,
    bool has_biases,
    int64_t num_layers,
    double dropout_p,
    bool train,
    bool bidirectional,
    std::optional<ScalarType> dtype,
    bool use_dynamic) {
  TORCH_CHECK(
      false,
      "torch.quantized_lstm with List[Tensor] for parameters is "
      "no longer supported. Please re-export your model "
      "using the newer definitions in torch.jit.quantized");
}

#define DEFINE_QUANTIZED_RNN_CELL(name, hx_type, cell_type, return_type, prepare_hx_fn) \
return_type name( \
    const Tensor& input, \
    hx_type hx, \
    const Tensor& w_ih, \
    const Tensor& w_hh, \
    const Tensor& b_ih, \
    const Tensor& b_hh, \
    const Tensor& packed_ih, \
    const Tensor& packed_hh, \
    const Tensor& col_offsets_ih, \
    const Tensor& col_offsets_hh, \
    const Scalar& scale_ih, \
    const Scalar& scale_hh, \
    const Scalar& zero_point_ih, \
    const Scalar& zero_point_hh) { \
  QuantizedCellParams params( \
      w_ih, \
      w_hh, \
      b_ih, \
      b_hh, \
      packed_ih, \
      packed_hh, \
      col_offsets_ih, \
      col_offsets_hh, \
      scale_ih, \
      scale_hh, \
      zero_point_ih, \
      zero_point_hh); \
  return cell_type{}( \
      input, prepare_hx_fn(hx), params); \
}
// Set reduce_range to true for all RNN cells by default. This flag is used only by
// FBGEMM kernels; QNNPACK does not reduce the range for activations.
#define DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(name, hx_type, cell_type, return_type, prepare_hx_fn) \
return_type name( \
    const Tensor& input, \
    hx_type hx, \
    c10::intrusive_ptr<LinearPackedParamsBase> _packed_w_ih, \
    c10::intrusive_ptr<LinearPackedParamsBase> _packed_w_hh, \
    const Tensor& b_ih, \
    const Tensor& b_hh \
 ) { \
  QuantizedCellParamsDynamic params( \
      _packed_w_ih, \
      _packed_w_hh, \
      b_ih, \
      b_hh, \
      true); \
  return cell_type{}( \
      input, prepare_hx_fn(hx), params); \
}
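
// Unlike DEFINE_QUANTIZED_RNN_CELL above, the dynamic variants consume prepacked
// LinearPackedParamsBase weights instead of raw quantized tensors plus their
// packing metadata, and hard-code reduce_range=true as noted above.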

// Quantized LSTM cell
using quantized_lstm_cell_type = LSTMCell<QuantizedCellParams>;
using quantized_lstm_return_type = std::tuple<Tensor, Tensor>;
static std::tuple<Tensor, Tensor> prepare_quantized_lstm_hx(TensorList hx) {
  return std::make_tuple(hx[0], hx[1]);
}

// Quantized LSTM cell (dynamic)
using quantized_lstm_cell_dynamic_type = LSTMCell<QuantizedCellParamsDynamic>;

DEFINE_QUANTIZED_RNN_CELL(quantized_lstm_cell, TensorList, quantized_lstm_cell_type, quantized_lstm_return_type, prepare_quantized_lstm_hx);

static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_lstm_cell_dynamic, TensorList, quantized_lstm_cell_dynamic_type, quantized_lstm_return_type, prepare_quantized_lstm_hx);

// Helpers for simpler cells
using simple_hx_type = const Tensor&;
static simple_hx_type prepare_quantized_hx(simple_hx_type hx) {
  return hx;
}

// Quantized GRU cell
using quantized_gru_cell_type = GRUCell<QuantizedCellParams>;
using quantized_gru_cell_dynamic_type = GRUCell<QuantizedCellParamsDynamic>;

DEFINE_QUANTIZED_RNN_CELL(quantized_gru_cell, simple_hx_type, quantized_gru_cell_type, Tensor, prepare_quantized_hx);

static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_gru_cell_dynamic, simple_hx_type, quantized_gru_cell_dynamic_type, Tensor, prepare_quantized_hx);

// Quantized RNN w/ ReLU cell
using quantized_rnn_relu_cell_type = SimpleCell<relu_f, QuantizedCellParams>;
DEFINE_QUANTIZED_RNN_CELL(quantized_rnn_relu_cell, simple_hx_type, quantized_rnn_relu_cell_type, Tensor, prepare_quantized_hx);
using quantized_rnn_relu_cell_dynamic_type = SimpleCell<relu_f, QuantizedCellParamsDynamic>;
static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_rnn_relu_cell_dynamic, simple_hx_type, quantized_rnn_relu_cell_dynamic_type, Tensor, prepare_quantized_hx);

// Quantized RNN w/ tanh cell
using quantized_rnn_tanh_cell_type = SimpleCell<tanh_f, QuantizedCellParams>;
DEFINE_QUANTIZED_RNN_CELL(quantized_rnn_tanh_cell, simple_hx_type, quantized_rnn_tanh_cell_type, Tensor, prepare_quantized_hx);
using quantized_rnn_tanh_cell_dynamic_type = SimpleCell<tanh_f, QuantizedCellParamsDynamic>;
static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_rnn_tanh_cell_dynamic, simple_hx_type, quantized_rnn_tanh_cell_dynamic_type, Tensor, prepare_quantized_hx);

namespace {

static C10_UNUSED auto ensure_linear_params_registered = register_linear_params();

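// CellParamsBase is registered as a TorchScript custom class so quantized cell
// parameters can be (de)serialized: __getstate__ produces a tagged tuple, and the
// type string selects the matching entry in cell_params_deserializers on load.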
static auto cell_params_base_registry =
    torch::selective_class_<CellParamsBase>("rnn", TORCH_SELECTIVE_CLASS("CellParamsBase"))
        .def_pickle(
            [](const c10::intrusive_ptr<CellParamsBase>& self)
                -> CellParamsSerializationType { return self->__getstate__(); },
            [](CellParamsSerializationType state)
                -> c10::intrusive_ptr<CellParamsBase> {
              std::string type = std::get<0>(state);
              TORCH_INTERNAL_ASSERT(cell_params_deserializers.count(type));
              return cell_params_deserializers[type](std::move(state));
            });

TORCH_LIBRARY_FRAGMENT(aten, m) {
  m.def(
      TORCH_SELECTIVE_SCHEMA("aten::quantized_lstm.input(Tensor input, Tensor[] hx, __torch__.torch.classes.rnn.CellParamsBase[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first, *, ScalarType? dtype=None, bool use_dynamic=False) -> (Tensor, Tensor, Tensor)"));
  m.def(
      TORCH_SELECTIVE_SCHEMA("aten::quantized_lstm.data(Tensor data, Tensor batch_sizes, Tensor[] hx, __torch__.torch.classes.rnn.CellParamsBase[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, *, ScalarType? dtype=None, bool use_dynamic=False) -> (Tensor, Tensor, Tensor)"));
  m.def(
      TORCH_SELECTIVE_SCHEMA("aten::quantized_lstm.input_legacy(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first, *, ScalarType? dtype=None, bool use_dynamic=False) -> (Tensor, Tensor, Tensor)"));
  m.def(
      TORCH_SELECTIVE_SCHEMA("aten::quantized_lstm.data_legacy(Tensor data, Tensor batch_sizes, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, *, ScalarType? dtype=None, bool use_dynamic=False) -> (Tensor, Tensor, Tensor)"));
  m.def(
      TORCH_SELECTIVE_SCHEMA("aten::quantized_gru.input(Tensor input, Tensor hx, __torch__.torch.classes.rnn.CellParamsBase[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor)"));
  m.def(
      TORCH_SELECTIVE_SCHEMA("aten::quantized_gru.data(Tensor data, Tensor batch_sizes, Tensor hx, __torch__.torch.classes.rnn.CellParamsBase[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor)"));
  m.def(
      TORCH_SELECTIVE_SCHEMA("aten::quantized_gru.input_legacy(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor)"));
  m.def(
      TORCH_SELECTIVE_SCHEMA("aten::quantized_gru.data_legacy(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor)"));
}

TORCH_LIBRARY_FRAGMENT(quantized, m) {
  m.def(TORCH_SELECTIVE_SCHEMA("quantized::make_quantized_cell_params_dynamic(__torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor bias_ih, Tensor bias_hh, bool reduce_range=False) -> __torch__.torch.classes.rnn.CellParamsBase"));
  m.def(TORCH_SELECTIVE_SCHEMA("quantized::make_quantized_cell_params_fp16(__torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh) -> __torch__.torch.classes.rnn.CellParamsBase"));
  m.def(TORCH_SELECTIVE_SCHEMA("quantized::make_quantized_cell_params(Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh) -> __torch__.torch.classes.rnn.CellParamsBase"));
  m.def(TORCH_SELECTIVE_SCHEMA("quantized::quantized_lstm_cell_dynamic(Tensor input, Tensor[] hx, __torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor bias_ih, Tensor bias_hh) -> (Tensor, Tensor)"));
  m.def(TORCH_SELECTIVE_SCHEMA("quantized::quantized_gru_cell_dynamic(Tensor input, Tensor hx, __torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor b_ih, Tensor b_hh) -> Tensor"));
  m.def(TORCH_SELECTIVE_SCHEMA("quantized::quantized_rnn_relu_cell_dynamic(Tensor input, Tensor hx, __torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor b_ih, Tensor b_hh) -> Tensor"));
  m.def(TORCH_SELECTIVE_SCHEMA("quantized::quantized_rnn_tanh_cell_dynamic(Tensor input, Tensor hx, __torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor b_ih, Tensor b_hh) -> Tensor"));
}

TORCH_LIBRARY_IMPL(aten, CPU, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::quantized_lstm.input"), TORCH_FN(quantized_lstm_input));
  m.impl(TORCH_SELECTIVE_NAME("aten::quantized_lstm.data"), TORCH_FN(quantized_lstm_data));
  m.impl(TORCH_SELECTIVE_NAME("aten::quantized_lstm.input_legacy"), TORCH_FN(quantized_lstm_input_legacy));
  m.impl(TORCH_SELECTIVE_NAME("aten::quantized_lstm.data_legacy"), TORCH_FN(quantized_lstm_data_legacy));
  m.impl(TORCH_SELECTIVE_NAME("aten::quantized_gru.input"), TORCH_FN(quantized_gru_input));
  m.impl(TORCH_SELECTIVE_NAME("aten::quantized_gru.data"), TORCH_FN(quantized_gru_data));
  m.impl(TORCH_SELECTIVE_NAME("aten::quantized_gru.input_legacy"), TORCH_FN(quantized_gru_input_legacy));
  m.impl(TORCH_SELECTIVE_NAME("aten::quantized_gru.data_legacy"), TORCH_FN(quantized_gru_data_legacy));
}

TORCH_LIBRARY_IMPL(quantized, CPU, m) {
  m.impl(TORCH_SELECTIVE_NAME("quantized::make_quantized_cell_params_dynamic"), TORCH_FN(make_quantized_cell_params_dynamic));
  m.impl(TORCH_SELECTIVE_NAME("quantized::make_quantized_cell_params"), TORCH_FN(make_quantized_cell_params));
  m.impl(TORCH_SELECTIVE_NAME("quantized::quantized_lstm_cell_dynamic"), TORCH_FN(quantized_lstm_cell_dynamic));
  m.impl(TORCH_SELECTIVE_NAME("quantized::quantized_gru_cell_dynamic"), TORCH_FN(quantized_gru_cell_dynamic));
  m.impl(TORCH_SELECTIVE_NAME("quantized::quantized_rnn_relu_cell_dynamic"), TORCH_FN(quantized_rnn_relu_cell_dynamic));
  m.impl(TORCH_SELECTIVE_NAME("quantized::quantized_rnn_tanh_cell_dynamic"), TORCH_FN(quantized_rnn_tanh_cell_dynamic));
}

TORCH_LIBRARY_IMPL(quantized, CatchAll, m) {
  m.impl(TORCH_SELECTIVE_NAME("quantized::make_quantized_cell_params_fp16"), TORCH_FN(make_quantized_cell_params_fp16));
}

} // namespace
} // namespace at::native