xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/cudnn_rnn_ops.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #define EIGEN_USE_THREADS
16 
17 #include <stddef.h>
18 
19 #include <atomic>
20 #include <cmath>
21 #include <functional>
22 #include <limits>
23 #include <string>
24 #include <unordered_set>
25 #include <utility>
26 
27 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
28 #include "tensorflow/core/framework/device_base.h"
29 #include "tensorflow/core/framework/kernel_def_builder.h"
30 #include "tensorflow/core/framework/op.h"
31 #include "tensorflow/core/framework/op_def_builder.h"
32 #include "tensorflow/core/framework/op_kernel.h"
33 #include "tensorflow/core/framework/register_types.h"
34 #include "tensorflow/core/framework/tensor.h"
35 #include "tensorflow/core/framework/tensor_shape.h"
36 #include "tensorflow/core/framework/tensor_types.h"
37 #include "tensorflow/core/framework/types.h"
38 #include "tensorflow/core/kernels/gpu_utils.h"
39 #include "tensorflow/core/lib/core/errors.h"
40 #include "tensorflow/core/lib/core/status.h"
41 #include "tensorflow/core/lib/core/stringpiece.h"
42 #include "tensorflow/core/lib/gtl/inlined_vector.h"
43 #include "tensorflow/core/lib/hash/hash.h"
44 #include "tensorflow/core/lib/strings/stringprintf.h"
45 #include "tensorflow/core/platform/fingerprint.h"
46 #include "tensorflow/core/platform/mutex.h"
47 #include "tensorflow/core/platform/types.h"
48 #include "tensorflow/core/profiler/lib/scoped_annotation.h"
49 #include "tensorflow/core/util/env_var.h"
50 #include "tensorflow/core/util/use_cudnn.h"
51 
52 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
53 #include "tensorflow/core/platform/stream_executor.h"
54 #include "tensorflow/core/util/stream_executor_util.h"
55 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
56 
57 /*
58  * This module implements ops that fuse a multi-layer multi-step RNN/LSTM model
59  * using the underlying Cudnn library.
60  *
61  * The Cudnn RNN library exposes an opaque parameter buffer with unknown layout
62  * and format. It is very likely that, if saved, the buffer cannot be used
63  * across different GPUs. So users need to first query the size of the opaque
64  * parameter buffer, and convert it to and from its canonical form. But each
65  * actual training step is carried out with the opaque parameter buffer.
66  *
67  * Similar to many other ops, the forward op has two flavors: training and
68  * inference. When training is specified, additional data in reserve_space will
69  * be produced for the backward pass. So there is a performance penalty.
70  *
71  * In addition to the actual data and reserve_space, Cudnn also needs more
72  * memory as temporary workspace. The memory management to and from
73  * stream-executor is done through ScratchAllocator. In general,
74  * stream-executor is responsible for creating the memory of proper size. And
75  * TensorFlow is responsible for making sure the memory is alive long enough
76  * and is recycled afterwards.
77  *
78  */
79 namespace tensorflow {
80 
81 using CPUDevice = Eigen::ThreadPoolDevice;
82 
83 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
84 
85 using GPUDevice = Eigen::GpuDevice;
86 using se::Stream;
87 using se::StreamExecutor;
88 using se::dnn::RnnDescriptor;
89 
// Forward declarations of the class templates named below. Each one is
// parameterized on a device type and an element type T; the first additionally
// takes an Index type.
90 template <typename Device, typename T, typename Index>
91 class CudnnRNNParamsSizeOp;
92 
93 template <typename Device, typename T>
94 class CudnnRNNParamsToCanonical;
95 
96 template <typename Device, typename T>
97 class CudnnRNNCanonicalToParams;
98 
99 template <typename Device, typename T>
100 class CudnnRNNForwardOp;
101 
102 template <typename Device, typename T>
103 class CudnnRNNBackwardOp;
104 
105 template <typename Device, typename T>
106 class CudnnRNNForwardOpV2;
107 
108 template <typename Device, typename T>
109 class CudnnRNNBackwardOpV2;
110 
111 template <typename Device, typename T>
112 class CudnnRNNForwardOpV3;
113 
114 template <typename Device, typename T>
115 class CudnnRNNBackwardOpV3;
116 
// Input-mode selector used by this file. It is produced from a string by
// ParseTFRNNInputMode() and translated into a stream-executor RnnInputMode by
// ToRNNInputMode().
117 enum class TFRNNInputMode {
118   kRNNLinearInput = 0,
119   kRNNSkipInput = 1,
  // ToRNNInputMode() resolves this value by comparing input_size and num_units.
120   kAutoSelect = 9999999
121 };
122 
123 namespace {
124 using se::DeviceMemory;
125 using se::DeviceMemoryBase;
126 using se::ScratchAllocator;
127 using se::dnn::AlgorithmConfig;
128 using se::dnn::AlgorithmDesc;
129 using se::dnn::ProfileResult;
130 using se::dnn::RnnDirectionMode;
131 using se::dnn::RnnInputMode;
132 using se::dnn::RnnMode;
133 using se::dnn::RnnSequenceTensorDescriptor;
134 using se::dnn::RnnStateTensorDescriptor;
135 using se::dnn::ToDataType;
136 using se::port::StatusOr;
137 
HashList(const std::vector<int> & list)138 uint64 HashList(const std::vector<int>& list) {
139   if (list.empty()) {
140     return 0;
141   }
142   uint64 hash_code = list[0];
143   for (int i = 1; i < list.size(); i++) {
144     hash_code = Hash64Combine(hash_code, list[i]);
145   }
146   return hash_code;
147 }
148 
149 // Encapsulate all the shape information that is used in both forward and
150 // backward rnn operations.
151 class CudnnRnnParameters {
152  public:
CudnnRnnParameters(int num_layers,int input_size,int num_units,int max_seq_length,int batch_size,int dir_count,bool has_dropout,bool is_training,RnnMode rnn_mode,TFRNNInputMode rnn_input_mode,DataType dtype)153   CudnnRnnParameters(int num_layers, int input_size, int num_units,
154                      int max_seq_length, int batch_size, int dir_count,
155                      bool has_dropout, bool is_training, RnnMode rnn_mode,
156                      TFRNNInputMode rnn_input_mode, DataType dtype)
157       : num_layers_(num_layers),
158         input_size_(input_size),
159         num_units_(num_units),
160         seq_length_(max_seq_length),
161         batch_size_(batch_size),
162         dir_count_(dir_count),
163         has_dropout_(has_dropout),
164         is_training_(is_training),
165         rnn_mode_(rnn_mode),
166         rnn_input_mode_(rnn_input_mode),
167         dtype_(dtype) {
168     hash_code_ =
169         HashList({num_layers, input_size, num_units, max_seq_length, batch_size,
170                   dir_count, static_cast<int>(has_dropout),
171                   static_cast<int>(is_training), static_cast<int>(rnn_mode),
172                   static_cast<int>(rnn_input_mode), dtype});
173   }
174 
operator ==(const CudnnRnnParameters & other) const175   bool operator==(const CudnnRnnParameters& other) const {
176     return this->get_data_as_tuple() == other.get_data_as_tuple();
177   }
178 
operator !=(const CudnnRnnParameters & other) const179   bool operator!=(const CudnnRnnParameters& other) const {
180     return !(*this == other);
181   }
hash() const182   uint64 hash() const { return hash_code_; }
183 
ToString() const184   string ToString() const {
185     std::vector<string> fields = {
186         std::to_string(num_layers_),
187         std::to_string(input_size_),
188         std::to_string(num_units_),
189         std::to_string(seq_length_),
190         std::to_string(batch_size_),
191         std::to_string(dir_count_),
192         std::to_string(has_dropout_),
193         std::to_string(is_training_),
194         std::to_string(static_cast<int>(rnn_mode_)),
195         std::to_string(static_cast<int>(rnn_input_mode_)),
196         std::to_string(static_cast<int>(dtype_))};
197     return absl::StrJoin(fields, ", ");
198   }
199 
200  private:
201   using ParameterDataType = std::tuple<int, int, int, int, int, int, bool, bool,
202                                        RnnMode, TFRNNInputMode, DataType>;
203 
get_data_as_tuple() const204   ParameterDataType get_data_as_tuple() const {
205     return std::make_tuple(num_layers_, input_size_, num_units_, seq_length_,
206                            batch_size_, dir_count_, has_dropout_, is_training_,
207                            rnn_mode_, rnn_input_mode_, dtype_);
208   }
209 
210   const int num_layers_;
211   const int input_size_;
212   const int num_units_;
213   const int seq_length_;
214   const int batch_size_;
215   const int dir_count_;
216   const bool has_dropout_;
217   const bool is_training_;
218   const RnnMode rnn_mode_;
219   const TFRNNInputMode rnn_input_mode_;
220   const DataType dtype_;
221   uint64 hash_code_;
222 };
223 
// Tag type passed to AutotuneSingleton below; its name() returns the literal
// "Rnn".
224 struct RnnAutotuneGroup {
nametensorflow::__anon980e2df00111::RnnAutotuneGroup225   static string name() { return "Rnn"; }
226 };
227 
// AutotuneSingleton instantiated with the Rnn group tag, CudnnRnnParameters and
// AlgorithmConfig.
// NOTE(review): presumably CudnnRnnParameters is the lookup key and
// AlgorithmConfig the cached value -- confirm against AutotuneSingleton.
228 using AutotuneRnnConfigMap =
229     AutotuneSingleton<RnnAutotuneGroup, CudnnRnnParameters, AlgorithmConfig>;
230 
ParseRNNMode(const string & str,RnnMode * rnn_mode)231 Status ParseRNNMode(const string& str, RnnMode* rnn_mode) {
232   if (str == "rnn_relu") {
233     *rnn_mode = RnnMode::kRnnRelu;
234     return OkStatus();
235   } else if (str == "rnn_tanh") {
236     *rnn_mode = RnnMode::kRnnTanh;
237     return OkStatus();
238   } else if (str == "lstm") {
239     *rnn_mode = RnnMode::kRnnLstm;
240     return OkStatus();
241   } else if (str == "gru") {
242     *rnn_mode = RnnMode::kRnnGru;
243     return OkStatus();
244   }
245   return errors::InvalidArgument("Invalid RNN mode: ", str);
246 }
247 
ParseTFRNNInputMode(const string & str,TFRNNInputMode * rnn_input_mode)248 Status ParseTFRNNInputMode(const string& str, TFRNNInputMode* rnn_input_mode) {
249   if (str == "linear_input") {
250     *rnn_input_mode = TFRNNInputMode::kRNNLinearInput;
251     return OkStatus();
252   } else if (str == "skip_input") {
253     *rnn_input_mode = TFRNNInputMode::kRNNSkipInput;
254     return OkStatus();
255   } else if (str == "auto_select") {
256     *rnn_input_mode = TFRNNInputMode::kAutoSelect;
257     return OkStatus();
258   }
259   return errors::InvalidArgument("Invalid RNN input mode: ", str);
260 }
261 
ParseRNNDirectionMode(const string & str,RnnDirectionMode * rnn_dir_mode)262 Status ParseRNNDirectionMode(const string& str,
263                              RnnDirectionMode* rnn_dir_mode) {
264   if (str == "unidirectional") {
265     *rnn_dir_mode = RnnDirectionMode::kRnnUnidirectional;
266     return OkStatus();
267   } else if (str == "bidirectional") {
268     *rnn_dir_mode = RnnDirectionMode::kRnnBidirectional;
269     return OkStatus();
270   }
271   return errors::InvalidArgument("Invalid RNN direction mode: ", str);
272 }
273 
ToRNNInputMode(TFRNNInputMode tf_input_mode,int num_units,int input_size,RnnInputMode * input_mode)274 Status ToRNNInputMode(TFRNNInputMode tf_input_mode, int num_units,
275                       int input_size, RnnInputMode* input_mode) {
276   switch (tf_input_mode) {
277     case TFRNNInputMode::kRNNLinearInput:
278       *input_mode = RnnInputMode::kRnnLinearSkip;
279       break;
280     case TFRNNInputMode::kRNNSkipInput:
281       *input_mode = RnnInputMode::kRnnSkipInput;
282       break;
283     case TFRNNInputMode::kAutoSelect:
284       *input_mode = (input_size == num_units) ? RnnInputMode::kRnnSkipInput
285                                               : RnnInputMode::kRnnLinearSkip;
286       break;
287     default:
288       return errors::InvalidArgument("Invalid TF input mode: ",
289                                      static_cast<int>(tf_input_mode));
290   }
291   return OkStatus();
292 }
293 
294 // TODO(zhengxq): Merge those into stream_executor_util.h.
295 template <typename T>
AsDeviceMemory(const Tensor * tensor)296 const DeviceMemory<T> AsDeviceMemory(const Tensor* tensor) {
297   return DeviceMemory<T>::MakeFromByteSize(
298       const_cast<T*>(tensor->template flat<T>().data()),
299       tensor->template flat<T>().size() * sizeof(T));
300 }
301 
302 template <typename T>
AsDeviceMemory(Tensor * tensor)303 DeviceMemory<T> AsDeviceMemory(Tensor* tensor) {
304   return DeviceMemory<T>::MakeFromByteSize(
305       tensor->template flat<T>().data(),
306       tensor->template flat<T>().size() * sizeof(T));
307 }
308 
309 template <typename U, typename T>
CastDeviceMemory(Tensor * tensor)310 DeviceMemory<U> CastDeviceMemory(Tensor* tensor) {
311   return DeviceMemory<U>::MakeFromByteSize(
312       tensor->template flat<T>().data(),
313       tensor->template flat<T>().size() * sizeof(T));
314 }
315 
// Returns a DeviceMemoryBase view of `size` bytes starting `offset` bytes into
// `device_memory`. No data is copied.
//
// CHECK-fails when `offset` or `size` is negative or when the requested range
// extends past the end of `device_memory`.
DeviceMemoryBase SliceDeviceMemory(const DeviceMemoryBase& device_memory,
                                   int64_t offset, int64_t size) {
  // Validate before doing any pointer arithmetic. A negative offset combined
  // with a larger positive size would otherwise pass the range check below.
  CHECK(offset >= 0 && size >= 0)
      << "The slice offset and size must be non-negative.";
  // Both operands are non-negative here, so the unsigned sum cannot wrap.
  CHECK(static_cast<uint64_t>(offset) + static_cast<uint64_t>(size) <=
        device_memory.size())
      << "The slice is not within the region of DeviceMemory.";
  const char* base_ptr = reinterpret_cast<const char*>(device_memory.opaque());
  void* offset_ptr = const_cast<char*>(base_ptr + offset);
  return DeviceMemoryBase(offset_ptr, size);
}
325 
FromExecutorStatus(const se::port::Status & s)326 inline Status FromExecutorStatus(const se::port::Status& s) {
327   return s.ok() ? OkStatus()
328                 : Status(static_cast<error::Code>(static_cast<int>(s.code())),
329                          s.error_message());
330 }
331 
332 template <typename T>
FromExecutorStatus(const se::port::StatusOr<T> & s)333 inline Status FromExecutorStatus(const se::port::StatusOr<T>& s) {
334   return FromExecutorStatus(s.status());
335 }
336 
ToExecutorStatus(const Status & s)337 inline se::port::Status ToExecutorStatus(const Status& s) {
338   return s.ok() ? OkStatus()
339                 : se::port::Status(static_cast<se::port::error::Code>(
340                                        static_cast<int>(s.code())),
341                                    s.error_message());
342 }
343 
// Compile-time map from a C++ element type to the matching TensorFlow DataType
// enumerator, exposed as ToTFDataType<T>::value. The primary template is only
// declared here, so using ::value with a type that has no specialization is a
// compile error.
344 template <typename>
345 struct ToTFDataType;
346 
// Eigen::half -> DT_HALF
347 template <>
348 struct ToTFDataType<Eigen::half> : std::integral_constant<DataType, DT_HALF> {};
349 
// float -> DT_FLOAT
350 template <>
351 struct ToTFDataType<float> : std::integral_constant<DataType, DT_FLOAT> {};
352 
// double -> DT_DOUBLE
353 template <>
354 struct ToTFDataType<double> : std::integral_constant<DataType, DT_DOUBLE> {};
355 
// uint8 -> DT_UINT8
356 template <>
357 struct ToTFDataType<uint8> : std::integral_constant<DataType, DT_UINT8> {};
358 
359 // A helper to allocate temporary scratch memory for Cudnn RNN models. It
360 // takes the ownership of the underlying memory. The expectation is that the
361 // memory should be alive for the span of the Cudnn RNN itself.
362 template <typename T>
363 class CudnnRnnAllocatorInTemp : public ScratchAllocator {
364  public:
365   ~CudnnRnnAllocatorInTemp() override = default;
366 
CudnnRnnAllocatorInTemp(OpKernelContext * context)367   explicit CudnnRnnAllocatorInTemp(OpKernelContext* context)
368       : context_(context) {}
GetMemoryLimitInBytes()369   int64_t GetMemoryLimitInBytes() override {
370     return std::numeric_limits<int64_t>::max();
371   }
372 
AllocateBytes(int64_t byte_size)373   StatusOr<DeviceMemory<uint8>> AllocateBytes(int64_t byte_size) override {
374     Tensor temporary_memory;
375     const DataType tf_data_type = ToTFDataType<T>::value;
376     int64_t allocate_count =
377         Eigen::divup(byte_size, static_cast<int64_t>(sizeof(T)));
378     Status allocation_status(context_->allocate_temp(
379         tf_data_type, TensorShape({allocate_count}), &temporary_memory));
380     if (!allocation_status.ok()) {
381       return ToExecutorStatus(allocation_status);
382     }
383     // Hold the reference of the allocated tensors until the end of the
384     // allocator.
385     allocated_tensors_.push_back(temporary_memory);
386     total_byte_size_ += byte_size;
387     return DeviceMemory<uint8>::MakeFromByteSize(
388         temporary_memory.template flat<T>().data(),
389         temporary_memory.template flat<T>().size() * sizeof(T));
390   }
391 
TotalByteSize() const392   int64_t TotalByteSize() const { return total_byte_size_; }
393 
get_allocated_tensor(int index) const394   Tensor get_allocated_tensor(int index) const {
395     return allocated_tensors_[index];
396   }
397 
398  private:
399   int64_t total_byte_size_ = 0;
400   OpKernelContext* context_;  // not owned
401   std::vector<Tensor> allocated_tensors_;
402 };
403 
404 // A helper to allocate memory for Cudnn RNN models as a kernel output. It is
405 // used by forward pass kernel to feed the output to the backward pass.
406 // The memory is expected to live long enough after the backward pass is
407 // finished.
408 template <typename T>
409 class CudnnRnnAllocatorInOutput : public ScratchAllocator {
410  public:
~CudnnRnnAllocatorInOutput()411   ~CudnnRnnAllocatorInOutput() override {}
CudnnRnnAllocatorInOutput(OpKernelContext * context,int output_index)412   CudnnRnnAllocatorInOutput(OpKernelContext* context, int output_index)
413       : context_(context), output_index_(output_index) {}
GetMemoryLimitInBytes()414   int64_t GetMemoryLimitInBytes() override {
415     return std::numeric_limits<int64_t>::max();
416   }
AllocateBytes(int64_t byte_size)417   StatusOr<DeviceMemory<uint8>> AllocateBytes(int64_t byte_size) override {
418     CHECK(total_byte_size_ == 0)
419         << "Reserve space allocator can only be called once";
420     int64_t allocate_count =
421         Eigen::divup(byte_size, static_cast<int64_t>(sizeof(T)));
422 
423     Tensor* temporary_memory = nullptr;
424     Status allocation_status(context_->allocate_output(
425         output_index_, TensorShape({allocate_count}), &temporary_memory));
426     if (!allocation_status.ok()) {
427       return ToExecutorStatus(allocation_status);
428     }
429     total_byte_size_ += byte_size;
430     auto memory_uint8 = DeviceMemory<uint8>::MakeFromByteSize(
431         temporary_memory->template flat<T>().data(),
432         temporary_memory->template flat<T>().size() * sizeof(T));
433     return StatusOr<DeviceMemory<uint8>>(memory_uint8);
434   }
TotalByteSize()435   int64_t TotalByteSize() { return total_byte_size_; }
436 
437  private:
438   int64_t total_byte_size_ = 0;
439   OpKernelContext* context_;  // not owned
440   int output_index_;
441 };
442 
443 // A helper to allocate memory for Cudnn RNN models, which is
444 // expected to live between kernel invocations.
445 // This class is not thread-safe.
446 class CudnnRNNSpaceAllocator : public ScratchAllocator {
447  public:
CudnnRNNSpaceAllocator(OpKernelContext * context)448   explicit CudnnRNNSpaceAllocator(OpKernelContext* context)
449       : context_(context) {}
450 
~CudnnRNNSpaceAllocator()451   ~CudnnRNNSpaceAllocator() override {}
452 
GetMemoryLimitInBytes()453   int64_t GetMemoryLimitInBytes() override {
454     return std::numeric_limits<int64_t>::max();
455   }
456 
AllocateBytes(int64_t byte_size)457   StatusOr<DeviceMemory<uint8>> AllocateBytes(int64_t byte_size) override {
458     if (total_byte_size_ != 0) {
459       return Status(error::FAILED_PRECONDITION,
460                     "Space allocator can only be called once");
461     }
462 
463     Status allocation_status =
464         context_->allocate_temp(DT_UINT8, TensorShape({byte_size}), &tensor_);
465     if (!allocation_status.ok()) {
466       return ToExecutorStatus(allocation_status);
467     }
468     total_byte_size_ += byte_size;
469     return AsDeviceMemory<uint8>(&tensor_);
470   }
TotalByteSize()471   int64_t TotalByteSize() { return total_byte_size_; }
472 
473  private:
474   int64_t total_byte_size_ = 0;
475   Tensor tensor_;
476   OpKernelContext* context_;  // not owned
477 };
478 
479 struct CudnnModelTypes {
480   RnnMode rnn_mode;
481   TFRNNInputMode rnn_input_mode;
482   RnnDirectionMode rnn_direction_mode;
HasInputCtensorflow::__anon980e2df00111::CudnnModelTypes483   bool HasInputC() const {
484     // For Cudnn 5.0, only LSTM has input-c. All other models use only
485     // input-h.
486     return rnn_mode == RnnMode::kRnnLstm;
487   }
488 
DebugStringtensorflow::__anon980e2df00111::CudnnModelTypes489   string DebugString() const {
490     return strings::Printf(
491         "[rnn_mode, rnn_input_mode, rnn_direction_mode]: %d, %d, %d ",
492         static_cast<int>(rnn_mode), static_cast<int>(rnn_input_mode),
493         static_cast<int>(rnn_direction_mode));
494   }
495 };
496 
497 // A helper class that collects the shapes to describe a RNN model.
498 struct CudnnRnnModelShapes {
499   int num_layers;
500   int input_size;
501   int num_units;
502   int dir_count;
503   int max_seq_length;
504   int batch_size;
505   int cell_num_units = 0;
506   // If you add new field to this structure, please take care of
507   // updating IsCompatibleWith() below as well as the hash function in
508   // CudnnRnnConfigHasher.
509   TensorShape input_shape;
510   TensorShape output_shape;
511   TensorShape hidden_state_shape;
512   TensorShape cell_state_shape;
513   // At present only fields related to cached RnnDescriptor are concerned.
IsCompatibleWithtensorflow::__anon980e2df00111::CudnnRnnModelShapes514   bool IsCompatibleWith(const CudnnRnnModelShapes& rhs) const {
515     return num_layers == rhs.num_layers && input_size == rhs.input_size &&
516            num_units == rhs.num_units && dir_count == rhs.dir_count &&
517            cell_num_units == rhs.cell_num_units &&
518            max_seq_length == rhs.max_seq_length;
519   }
DebugStringtensorflow::__anon980e2df00111::CudnnRnnModelShapes520   string DebugString() const {
521     return strings::Printf(
522         "[num_layers, input_size, num_units, dir_count, max_seq_length, "
523         "batch_size, cell_num_units]: [%d, %d, %d, %d, %d, %d, %d] ",
524         num_layers, input_size, num_units, dir_count, max_seq_length,
525         batch_size, cell_num_units);
526   }
527 };
528 
529 // Utility class for using CudnnRnnConfig and AlgorithmDesc pair a hash table
530 // key.
531 struct CudnnRnnConfigHasher {
operator ()tensorflow::__anon980e2df00111::CudnnRnnConfigHasher532   uint64 operator()(
533       const std::pair<CudnnRnnModelShapes, absl::optional<AlgorithmDesc>>&
534           to_hash) const {
535     auto& shapes = to_hash.first;
536     auto& algo_desc = to_hash.second;
537 
538     uint64 hash =
539         HashList({shapes.num_layers, shapes.input_size, shapes.num_units,
540                   shapes.dir_count, shapes.max_seq_length, shapes.batch_size});
541     if (algo_desc.has_value()) {
542       hash = Hash64Combine(hash, algo_desc->hash());
543     }
544     return hash;
545   }
546 };
547 
548 // Utility class for using CudnnRnnModelShapes and AlgorithmDesc pair as a hash
549 // table key.
550 struct CudnnRnnConfigComparator {
operator ()tensorflow::__anon980e2df00111::CudnnRnnConfigComparator551   bool operator()(
552       const std::pair<CudnnRnnModelShapes, absl::optional<AlgorithmDesc>>& lhs,
553       const std::pair<CudnnRnnModelShapes, absl::optional<AlgorithmDesc>>& rhs)
554       const {
555     return lhs.first.IsCompatibleWith(rhs.first) && lhs.second == rhs.second;
556   }
557 };
558 
559 // Pointers to RNN scratch space for a specific set of shape parameters (used as
560 // a hash table value in CudnnRNNForwardOp and CudnnRNNBackwardOp).
// Both members are owning unique_ptrs, so destroying the struct releases them.
561 struct RnnScratchSpace {
  // RNN descriptor for this set of shape parameters.
562   std::unique_ptr<RnnDescriptor> rnn_desc;
  // Allocator that holds the dropout state buffer.
  // NOTE(review): presumably this must outlive rnn_desc's use of the buffer --
  // confirm.
563   std::unique_ptr<CudnnRNNSpaceAllocator> dropout_state_allocator;
564 };
565 
566 // Extract and checks the forward input tensors, parameters, and shapes from the
567 // OpKernelContext.
ExtractForwardInput(OpKernelContext * context,const CudnnModelTypes & model_types,bool time_major,const Tensor ** input,const Tensor ** input_h,const Tensor ** input_c,const Tensor ** params,const int num_proj,CudnnRnnModelShapes * model_shapes)568 Status ExtractForwardInput(OpKernelContext* context,
569                            const CudnnModelTypes& model_types, bool time_major,
570                            const Tensor** input, const Tensor** input_h,
571                            const Tensor** input_c, const Tensor** params,
572                            const int num_proj,
573                            CudnnRnnModelShapes* model_shapes) {
574   TF_RETURN_IF_ERROR(context->input("input", input));
575   TF_RETURN_IF_ERROR(context->input("input_h", input_h));
576   if (model_types.HasInputC()) {
577     TF_RETURN_IF_ERROR(context->input("input_c", input_c));
578   }
579   TF_RETURN_IF_ERROR(context->input("params", params));
580 
581   if ((*input)->dims() != 3) {
582     return errors::InvalidArgument("RNN input must be a 3-D vector.");
583   }
584   if (time_major) {
585     model_shapes->max_seq_length = (*input)->dim_size(0);
586     model_shapes->batch_size = (*input)->dim_size(1);
587   } else {
588     model_shapes->max_seq_length = (*input)->dim_size(1);
589     model_shapes->batch_size = (*input)->dim_size(0);
590   }
591   model_shapes->input_size = (*input)->dim_size(2);
592   model_shapes->input_shape = (*input)->shape();
593   model_shapes->dir_count =
594       (model_types.rnn_direction_mode == RnnDirectionMode::kRnnBidirectional)
595           ? 2
596           : 1;
597 
598   if ((*input_h)->dims() != 3) {
599     return errors::InvalidArgument("RNN input_h must be a 3-D vector.");
600   }
601   if (time_major) {
602     model_shapes->num_layers =
603         (*input_h)->dim_size(0) / model_shapes->dir_count;
604   } else {
605     model_shapes->num_layers =
606         (*input_h)->dim_size(1) / model_shapes->dir_count;
607   }
608   model_shapes->num_units = (*input_h)->dim_size(2);
609 
610   if (time_major) {
611     model_shapes->hidden_state_shape =
612         TensorShape({model_shapes->dir_count * model_shapes->num_layers,
613                      model_shapes->batch_size, model_shapes->num_units});
614   } else {
615     model_shapes->hidden_state_shape =
616         TensorShape({model_shapes->batch_size,
617                      model_shapes->dir_count * model_shapes->num_layers,
618                      model_shapes->num_units});
619   }
620   if ((*input_h)->shape() != model_shapes->hidden_state_shape) {
621     return errors::InvalidArgument(
622         "Invalid input_h shape: ", (*input_h)->shape().DebugString(), " ",
623         model_shapes->hidden_state_shape.DebugString());
624   }
625   if (model_types.HasInputC()) {
626     model_shapes->cell_num_units = (*input_c)->dim_size(2);
627     if (time_major) {
628       model_shapes->cell_state_shape =
629           TensorShape({model_shapes->dir_count * model_shapes->num_layers,
630                        model_shapes->batch_size, model_shapes->cell_num_units});
631     } else {
632       model_shapes->cell_state_shape =
633           TensorShape({model_shapes->batch_size,
634                        model_shapes->dir_count * model_shapes->num_layers,
635                        model_shapes->cell_num_units});
636     }
637     if (num_proj == 0) {
638       if ((*input_h)->shape() != (*input_c)->shape()) {
639         return errors::InvalidArgument(
640             "input_h and input_c must have the same shape w/o projection: ",
641             (*input_h)->shape().DebugString(), " ",
642             (*input_c)->shape().DebugString());
643       }
644     } else {
645       if ((*input_h)->dim_size(2) > (*input_c)->dim_size(2) ||
646           num_proj != (*input_h)->dim_size(2) ||
647           (*input_h)->dim_size(0) != (*input_c)->dim_size(0) ||
648           (*input_h)->dim_size(1) != (*input_c)->dim_size(1)) {
649         return errors::InvalidArgument(
650             "Invalid input_h and input_c w/ projection size: ", num_proj, " ",
651             (*input_h)->shape().DebugString(), " ",
652             (*input_c)->shape().DebugString());
653       }
654     }
655   } else {
656     // dummy cell_state_shape TODO(kaixih): remove the time_major branch
657     if (time_major) {
658       model_shapes->cell_state_shape =
659           TensorShape({model_shapes->dir_count * model_shapes->num_layers,
660                        model_shapes->batch_size, model_shapes->num_units});
661     } else {
662       model_shapes->cell_state_shape =
663           TensorShape({model_shapes->batch_size,
664                        model_shapes->dir_count * model_shapes->num_layers,
665                        model_shapes->num_units});
666     }
667     model_shapes->cell_num_units = 0;
668   }
669   if (time_major) {
670     model_shapes->output_shape =
671         TensorShape({model_shapes->max_seq_length, model_shapes->batch_size,
672                      model_shapes->dir_count * model_shapes->num_units});
673   } else {
674     model_shapes->output_shape =
675         TensorShape({model_shapes->batch_size, model_shapes->max_seq_length,
676                      model_shapes->dir_count * model_shapes->num_units});
677   }
678   return OkStatus();
679 }
680 
681 // Overloaded function to process the sequence_lengths
ExtractForwardInput(OpKernelContext * context,const CudnnModelTypes & model_types,bool time_major,const Tensor ** input,const Tensor ** input_h,const Tensor ** input_c,const Tensor ** params,const Tensor ** sequence_lengths,const int num_proj,CudnnRnnModelShapes * model_shapes)682 Status ExtractForwardInput(OpKernelContext* context,
683                            const CudnnModelTypes& model_types, bool time_major,
684                            const Tensor** input, const Tensor** input_h,
685                            const Tensor** input_c, const Tensor** params,
686                            const Tensor** sequence_lengths, const int num_proj,
687                            CudnnRnnModelShapes* model_shapes) {
688   TF_RETURN_IF_ERROR(context->input("sequence_lengths", sequence_lengths));
689   return ExtractForwardInput(context, model_types, time_major, input, input_h,
690                              input_c, params, num_proj, model_shapes);
691 }
692 
693 template <typename T>
CreateForwardAndBackwardIODescriptors(OpKernelContext * context,const CudnnRnnModelShapes & model_shapes,std::unique_ptr<RnnSequenceTensorDescriptor> * input_desc,std::unique_ptr<RnnStateTensorDescriptor> * h_state_desc,std::unique_ptr<RnnStateTensorDescriptor> * c_state_desc,std::unique_ptr<RnnSequenceTensorDescriptor> * output_desc,const absl::Span<const int> seq_lengths,bool time_major)694 Status CreateForwardAndBackwardIODescriptors(
695     OpKernelContext* context, const CudnnRnnModelShapes& model_shapes,
696     std::unique_ptr<RnnSequenceTensorDescriptor>* input_desc,
697     std::unique_ptr<RnnStateTensorDescriptor>* h_state_desc,
698     std::unique_ptr<RnnStateTensorDescriptor>* c_state_desc,
699     std::unique_ptr<RnnSequenceTensorDescriptor>* output_desc,
700     const absl::Span<const int> seq_lengths, bool time_major) {
701   StreamExecutor* executor = context->op_device_context()->stream()->parent();
702   se::dnn::DataType data_type = ToDataType<T>::value;
703 
704   const TensorShape& input_shape = model_shapes.input_shape;
705   const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
706   const TensorShape& cell_state_shape = model_shapes.cell_state_shape;
707   const TensorShape& output_shape = model_shapes.output_shape;
708 
709   DCHECK_EQ(input_shape.dims(), 3);
710   if (seq_lengths.data() != nullptr) {
711     if (time_major) {
712       auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
713           input_shape.dim_size(0), input_shape.dim_size(1),
714           input_shape.dim_size(2), seq_lengths, time_major, data_type);
715       TF_RETURN_IF_ERROR(input_desc_s.status());
716       *input_desc = std::move(input_desc_s).value();
717     } else {
718       auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
719           input_shape.dim_size(1), input_shape.dim_size(0),
720           input_shape.dim_size(2), seq_lengths, time_major, data_type);
721       TF_RETURN_IF_ERROR(input_desc_s.status());
722       *input_desc = std::move(input_desc_s).value();
723     }
724   } else {
725     auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
726         input_shape.dim_size(0), input_shape.dim_size(1),
727         input_shape.dim_size(2), data_type);
728     TF_RETURN_IF_ERROR(input_desc_s.status());
729     *input_desc = std::move(input_desc_s).value();
730   }
731 
732   DCHECK_EQ(hidden_state_shape.dims(), 3);
733   if (time_major) {
734     auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor(
735         hidden_state_shape.dim_size(0), hidden_state_shape.dim_size(1),
736         hidden_state_shape.dim_size(2), data_type);
737     TF_RETURN_IF_ERROR(hidden_state_desc_s.status());
738     *h_state_desc = std::move(hidden_state_desc_s).value();
739   } else {
740     auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor(
741         hidden_state_shape.dim_size(1), hidden_state_shape.dim_size(0),
742         hidden_state_shape.dim_size(2), data_type);
743     TF_RETURN_IF_ERROR(hidden_state_desc_s.status());
744     *h_state_desc = std::move(hidden_state_desc_s).value();
745   }
746 
747   DCHECK_EQ(cell_state_shape.dims(), 3);
748   if (time_major) {
749     auto cell_state_desc_s = executor->createRnnStateTensorDescriptor(
750         cell_state_shape.dim_size(0), cell_state_shape.dim_size(1),
751         cell_state_shape.dim_size(2), data_type);
752     TF_RETURN_IF_ERROR(cell_state_desc_s.status());
753     *c_state_desc = std::move(cell_state_desc_s).value();
754   } else {
755     auto cell_state_desc_s = executor->createRnnStateTensorDescriptor(
756         cell_state_shape.dim_size(1), cell_state_shape.dim_size(0),
757         cell_state_shape.dim_size(2), data_type);
758     TF_RETURN_IF_ERROR(cell_state_desc_s.status());
759     *c_state_desc = std::move(cell_state_desc_s).value();
760   }
761 
762   DCHECK_EQ(output_shape.dims(), 3);
763   if (seq_lengths.data() != nullptr) {
764     if (time_major) {
765       auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
766           output_shape.dim_size(0), output_shape.dim_size(1),
767           output_shape.dim_size(2), seq_lengths, time_major, data_type);
768       TF_RETURN_IF_ERROR(output_desc_s.status());
769       *output_desc = std::move(output_desc_s).value();
770     } else {
771       auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
772           output_shape.dim_size(1), output_shape.dim_size(0),
773           output_shape.dim_size(2), seq_lengths, time_major, data_type);
774       TF_RETURN_IF_ERROR(output_desc_s.status());
775       *output_desc = std::move(output_desc_s).value();
776     }
777   } else {
778     auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
779         output_shape.dim_size(0), output_shape.dim_size(1),
780         output_shape.dim_size(2), data_type);
781     TF_RETURN_IF_ERROR(output_desc_s.status());
782     *output_desc = std::move(output_desc_s).value();
783   }
784 
785   return OkStatus();
786 }
787 
788 template <typename T>
DoForward(OpKernelContext * context,const RnnDescriptor & rnn_desc,const CudnnModelTypes & model_types,const CudnnRnnModelShapes & model_shapes,const Tensor * input,const Tensor * input_h,const Tensor * input_c,const Tensor * params,const bool is_training,Tensor * output,Tensor * output_h,Tensor * output_c,const Tensor * sequence_lengths,bool time_major,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator,ProfileResult * output_profile_result)789 Status DoForward(OpKernelContext* context, const RnnDescriptor& rnn_desc,
790                  const CudnnModelTypes& model_types,
791                  const CudnnRnnModelShapes& model_shapes,
792                  /* forward inputs */
793                  const Tensor* input, const Tensor* input_h,
794                  const Tensor* input_c, const Tensor* params,
795                  const bool is_training,
796                  /* forward outputs, outputs of the function */
797                  Tensor* output, Tensor* output_h, Tensor* output_c,
798                  const Tensor* sequence_lengths, bool time_major,
799                  ScratchAllocator* reserve_space_allocator,
800                  ScratchAllocator* workspace_allocator,
801                  ProfileResult* output_profile_result) {
802   std::unique_ptr<RnnSequenceTensorDescriptor> input_desc;
803   std::unique_ptr<RnnStateTensorDescriptor> h_state_desc;
804   std::unique_ptr<RnnStateTensorDescriptor> c_state_desc;
805   std::unique_ptr<RnnSequenceTensorDescriptor> output_desc;
806 
807   absl::Span<const int> seq_lengths;
808   if (sequence_lengths != nullptr) {
809     seq_lengths = absl::Span<const int>(
810         sequence_lengths->template flat<int>().data(), model_shapes.batch_size);
811   }
812   TF_RETURN_IF_ERROR(CreateForwardAndBackwardIODescriptors<T>(
813       context, model_shapes, &input_desc, &h_state_desc, &c_state_desc,
814       &output_desc, seq_lengths, time_major));
815 
816   auto input_data = AsDeviceMemory<T>(input);
817   auto input_h_data = AsDeviceMemory<T>(input_h);
818   DeviceMemory<T> input_c_data;
819   if (model_types.HasInputC()) {
820     input_c_data = AsDeviceMemory<T>(input_c);
821   }
822 
823   auto params_data = AsDeviceMemory<T>(params);
824   auto output_data = AsDeviceMemory<T>(output);
825   auto output_h_data = AsDeviceMemory<T>(output_h);
826   DeviceMemory<T> output_c_data;
827   if (model_types.HasInputC()) {
828     output_c_data = AsDeviceMemory<T>(output_c);
829   }
830 
831   Stream* stream = context->op_device_context()->stream();
832 
833   Tensor seq_lengths_tensor;
834   DeviceMemory<int> seq_lengths_ptr;
835   if (sequence_lengths != nullptr) {
836     TF_RETURN_IF_ERROR(context->allocate_temp(
837         DT_INT32, {static_cast<long>(seq_lengths.size())},
838         &seq_lengths_tensor));
839     seq_lengths_ptr = AsDeviceMemory<int>(&seq_lengths_tensor);
840     if (!stream
841              ->ThenMemcpy(&seq_lengths_ptr, seq_lengths.data(),
842                           seq_lengths.size() * sizeof(int))
843              .ok()) {
844       return errors::InvalidArgument(
845           "Failed to copy memory from host to "
846           "device for sequence_lengths in "
847           "CudnnRNNV3");
848     }
849   }
850 
851   bool launch_success =
852       stream
853           ->ThenRnnForward(rnn_desc, *input_desc, input_data, seq_lengths_ptr,
854                            *h_state_desc, input_h_data, *c_state_desc,
855                            input_c_data, params_data, *output_desc,
856                            &output_data, *h_state_desc, &output_h_data,
857                            *c_state_desc, &output_c_data, is_training,
858                            reserve_space_allocator, workspace_allocator,
859                            output_profile_result)
860           .ok();
861   return launch_success
862              ? OkStatus()
863              : errors::Internal(
864                    "Failed to call ThenRnnForward with model config: ",
865                    model_types.DebugString(), ", ", model_shapes.DebugString());
866 }
867 
868 template <typename T>
DoBackward(OpKernelContext * context,const RnnDescriptor & rnn_desc,const CudnnModelTypes & model_types,const CudnnRnnModelShapes & model_shapes,const Tensor * input,const Tensor * input_h,const Tensor * input_c,const Tensor * params,const Tensor * output,const Tensor * output_h,const Tensor * output_c,const Tensor * output_backprop,const Tensor * output_h_backprop,const Tensor * output_c_backprop,const Tensor * reserve_space,Tensor * input_backprop,Tensor * input_h_backprop,Tensor * input_c_backprop,Tensor * params_backprop,const Tensor * sequence_lengths,bool time_major,ScratchAllocator * workspace_allocator,ProfileResult * output_profile_result)869 Status DoBackward(
870     OpKernelContext* context, const RnnDescriptor& rnn_desc,
871     const CudnnModelTypes& model_types, const CudnnRnnModelShapes& model_shapes,
872     /* forward inputs */
873     const Tensor* input, const Tensor* input_h, const Tensor* input_c,
874     const Tensor* params,
875     /* forward outputs */
876     const Tensor* output, const Tensor* output_h, const Tensor* output_c,
877     /* backprop inputs */
878     const Tensor* output_backprop, const Tensor* output_h_backprop,
879     const Tensor* output_c_backprop, const Tensor* reserve_space,
880     /* backprop outputs, output of the function */
881     Tensor* input_backprop, Tensor* input_h_backprop, Tensor* input_c_backprop,
882     Tensor* params_backprop, const Tensor* sequence_lengths, bool time_major,
883     ScratchAllocator* workspace_allocator,
884     ProfileResult* output_profile_result) {
885   std::unique_ptr<RnnSequenceTensorDescriptor> input_desc;
886   std::unique_ptr<RnnStateTensorDescriptor> h_state_desc;
887   std::unique_ptr<RnnStateTensorDescriptor> c_state_desc;
888   std::unique_ptr<RnnSequenceTensorDescriptor> output_desc;
889 
890   absl::Span<const int> seq_lengths;
891   if (sequence_lengths != nullptr) {
892     seq_lengths = absl::Span<const int>(
893         sequence_lengths->template flat<int>().data(), model_shapes.batch_size);
894   }
895   TF_RETURN_IF_ERROR(CreateForwardAndBackwardIODescriptors<T>(
896       context, model_shapes, &input_desc, &h_state_desc, &c_state_desc,
897       &output_desc, seq_lengths, time_major));
898 
899   auto input_data = AsDeviceMemory<T>(input);
900   auto input_h_data = AsDeviceMemory<T>(input_h);
901   DeviceMemory<T> input_c_data;
902   if (model_types.HasInputC()) {
903     input_c_data = AsDeviceMemory<T>(input_c);
904   }
905   auto params_data = AsDeviceMemory<T>(params);
906   auto output_data = AsDeviceMemory<T>(output);
907   auto output_h_data = AsDeviceMemory<T>(output_h);
908   DeviceMemory<T> output_c_data;
909   if (model_types.HasInputC()) {
910     output_c_data = AsDeviceMemory<T>(output_c);
911   }
912   auto output_backprop_data = AsDeviceMemory<T>(output_backprop);
913   auto output_h_backprop_data = AsDeviceMemory<T>(output_h_backprop);
914   DeviceMemory<T> output_c_backprop_data;
915   if (model_types.HasInputC()) {
916     output_c_backprop_data = AsDeviceMemory<T>(output_c_backprop);
917   }
918   auto input_backprop_data = AsDeviceMemory<T>(input_backprop);
919   auto input_h_backprop_data = AsDeviceMemory<T>(input_h_backprop);
920   DeviceMemory<T> input_c_backprop_data;
921   if (model_types.HasInputC()) {
922     input_c_backprop_data = AsDeviceMemory<T>(input_c_backprop);
923   }
924   auto params_backprop_data = AsDeviceMemory<T>(params_backprop);
925   auto reserve_space_uint8 =
926       CastDeviceMemory<uint8, T>(const_cast<Tensor*>(reserve_space));
927 
928   // Creates a memory callback for the workspace. The memory lives to the end
929   // of this kernel calls.
930   Stream* stream = context->op_device_context()->stream();
931 
932   Tensor seq_lengths_tensor;
933   DeviceMemory<int> seq_lengths_ptr;
934   if (sequence_lengths != nullptr) {
935     TF_RETURN_IF_ERROR(context->allocate_temp(
936         DT_INT32, {static_cast<long>(seq_lengths.size())},
937         &seq_lengths_tensor));
938     seq_lengths_ptr = AsDeviceMemory<int>(&seq_lengths_tensor);
939     if (!stream
940              ->ThenMemcpy(&seq_lengths_ptr, seq_lengths.data(),
941                           seq_lengths.size() * sizeof(int))
942              .ok()) {
943       return errors::InvalidArgument(
944           "Failed to copy memory from host to "
945           "device for sequence_lengths in "
946           "CudnnRNNBackwardOpV3");
947     }
948   }
949 
950   bool launch_success =
951       stream
952           ->ThenRnnBackward(
953               rnn_desc, *input_desc, input_data, seq_lengths_ptr, *h_state_desc,
954               input_h_data, *c_state_desc, input_c_data, params_data,
955               *output_desc, output_data, *h_state_desc, output_h_data,
956               *c_state_desc, output_c_data, output_backprop_data,
957               output_h_backprop_data, output_c_backprop_data,
958               &input_backprop_data, &input_h_backprop_data,
959               &input_c_backprop_data, &params_backprop_data,
960               &reserve_space_uint8, workspace_allocator, output_profile_result)
961           .ok();
962   return launch_success
963              ? OkStatus()
964              : errors::Internal(
965                    "Failed to call ThenRnnBackward with model config: ",
966                    model_types.DebugString(), ", ", model_shapes.DebugString());
967 }
968 
969 template <typename T>
RestoreParams(const OpInputList params_input,const std::vector<RnnDescriptor::ParamsRegion> & params,DeviceMemoryBase * data_dst,Stream * stream)970 void RestoreParams(const OpInputList params_input,
971                    const std::vector<RnnDescriptor::ParamsRegion>& params,
972                    DeviceMemoryBase* data_dst, Stream* stream) {
973   int num_params = params.size();
974   CHECK(params_input.size() == num_params)
975       << "Number of params mismatch. Expected " << params_input.size()
976       << ", got " << num_params;
977   for (int i = 0; i < params.size(); i++) {
978     int64_t size_in_bytes = params[i].size;
979     int64_t size = size_in_bytes / sizeof(T);
980     CHECK(size == params_input[i].NumElements())
981         << "Params size mismatch. Expected " << size << ", got "
982         << params_input[i].NumElements();
983     auto data_src_ptr = StreamExecutorUtil::AsDeviceMemory<T>(params_input[i]);
984     DeviceMemoryBase data_dst_ptr =
985         SliceDeviceMemory(*data_dst, params[i].offset, size_in_bytes);
986     stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes);
987   }
988 }
989 
ShouldUsePaddedIO(const Tensor * sequence_lengths,const CudnnRnnModelShapes & model_shapes,bool time_major)990 bool ShouldUsePaddedIO(const Tensor* sequence_lengths,
991                        const CudnnRnnModelShapes& model_shapes,
992                        bool time_major) {
993   auto seq_array = sequence_lengths->template flat<int>().data();
994   bool all_max_seq_length = true;
995   for (int i = 0; i < model_shapes.batch_size; i++) {
996     if (seq_array[i] != model_shapes.max_seq_length) {
997       all_max_seq_length = false;
998       break;
999     }
1000   }
1001   return !(time_major && all_max_seq_length);
1002 }
1003 
1004 }  // namespace
1005 
// Note: all following kernels depend on a RnnDescriptor instance, which
// according to Cudnn official doc should be kept around and reused across all
// Cudnn kernels in the same model.
// In Tensorflow, we don't pass the reference across different OpKernels,
// rather, recreate it separately in each OpKernel, which does not cause issues:
// CudnnDropoutDescriptor keeps a reference to a memory for
// random number generator state. During recreation, this state is lost.
// However, only forward-pass Cudnn APIs make use of the state.
1014 
1015 // A common base class for RNN kernels. It extracts common attributes and
1016 // shape validations.
1017 class CudnnRNNKernelCommon : public OpKernel {
1018  protected:
CudnnRNNKernelCommon(OpKernelConstruction * context)1019   explicit CudnnRNNKernelCommon(OpKernelConstruction* context)
1020       : OpKernel(context) {
1021     OP_REQUIRES_OK(context, context->GetAttr("dropout", &dropout_));
1022     OP_REQUIRES_OK(context, context->GetAttr("seed", &seed_));
1023     OP_REQUIRES_OK(context, context->GetAttr("seed2", &seed2_));
1024     string str;
1025     OP_REQUIRES_OK(context, context->GetAttr("rnn_mode", &str));
1026     OP_REQUIRES_OK(context, ParseRNNMode(str, &model_types_.rnn_mode));
1027     OP_REQUIRES_OK(context, context->GetAttr("input_mode", &str));
1028     OP_REQUIRES_OK(context,
1029                    ParseTFRNNInputMode(str, &model_types_.rnn_input_mode));
1030     OP_REQUIRES_OK(context, context->GetAttr("direction", &str));
1031     OP_REQUIRES_OK(
1032         context, ParseRNNDirectionMode(str, &model_types_.rnn_direction_mode));
1033     // Reset CudnnRnnDescriptor and related random number generate states in
1034     // every Compute() call.
1035     OP_REQUIRES_OK(context, ReadBoolFromEnvVar("TF_CUDNN_RESET_RND_GEN_STATE",
1036                                                false, &reset_rnd_gen_state_));
1037   }
1038 
HasInputC() const1039   bool HasInputC() const { return model_types_.HasInputC(); }
rnn_mode() const1040   RnnMode rnn_mode() const { return model_types_.rnn_mode; }
rnn_input_mode() const1041   TFRNNInputMode rnn_input_mode() const { return model_types_.rnn_input_mode; }
rnn_direction_mode() const1042   RnnDirectionMode rnn_direction_mode() const {
1043     return model_types_.rnn_direction_mode;
1044   }
model_types() const1045   const CudnnModelTypes& model_types() const { return model_types_; }
dropout() const1046   float dropout() const { return dropout_; }
seed()1047   uint64 seed() { return (static_cast<uint64>(seed_) << 32) | seed2_; }
ResetRndGenState()1048   bool ResetRndGenState() { return reset_rnd_gen_state_; }
1049 
1050   template <typename T>
ExtractCudnnRNNParamsInfo(OpKernelContext * context,int num_proj,std::unique_ptr<RnnDescriptor> * rnn_desc)1051   Status ExtractCudnnRNNParamsInfo(OpKernelContext* context, int num_proj,
1052                                    std::unique_ptr<RnnDescriptor>* rnn_desc) {
1053     const Tensor* num_layers_t = nullptr;
1054     TF_RETURN_IF_ERROR(context->input("num_layers", &num_layers_t));
1055     if (!TensorShapeUtils::IsScalar(num_layers_t->shape())) {
1056       return errors::InvalidArgument("num_layers is not a scalar");
1057     }
1058     int num_layers = num_layers_t->scalar<int>()();
1059     const Tensor* num_units_t = nullptr;
1060     TF_RETURN_IF_ERROR(context->input("num_units", &num_units_t));
1061     if (!TensorShapeUtils::IsScalar(num_units_t->shape())) {
1062       return errors::InvalidArgument("num_units is not a scalar");
1063     }
1064     int num_units = num_units_t->scalar<int>()();
1065     const Tensor* input_size_t = nullptr;
1066     TF_RETURN_IF_ERROR(context->input("input_size", &input_size_t));
1067     if (!TensorShapeUtils::IsScalar(input_size_t->shape())) {
1068       return errors::InvalidArgument("input_size is not a scalar");
1069     }
1070     int input_size = input_size_t->scalar<int>()();
1071 
1072     int h_num_units = (num_proj == 0 ? num_units : num_proj);
1073     int c_num_units = (num_proj == 0 ? 0 : num_units);
1074 
1075     RnnInputMode input_mode;
1076     TF_RETURN_IF_ERROR(
1077         ToRNNInputMode(rnn_input_mode(), num_units, input_size, &input_mode));
1078 
1079     Stream* stream = context->op_device_context()->stream();
1080     // ExtracCudnnRNNParamsInfo is only called by op_kernels that do not require
1081     // random number generator, therefore set state_allocator to nullptr.
1082     const AlgorithmConfig algo_config;
1083     auto rnn_desc_s = stream->parent()->createRnnDescriptor(
1084         num_layers, h_num_units, input_size, /*cell_size=*/c_num_units,
1085         /*batch_size=*/0, input_mode, rnn_direction_mode(), rnn_mode(),
1086         ToDataType<T>::value, algo_config, dropout(), seed(),
1087         /* state_allocator=*/nullptr, /*use_padded_io=*/false);
1088     if (!rnn_desc_s.ok()) {
1089       return FromExecutorStatus(rnn_desc_s);
1090     }
1091     *rnn_desc = std::move(rnn_desc_s).value();
1092     return OkStatus();
1093   }
1094 
1095   template <typename T>
CreateRnnDescriptor(OpKernelContext * context,const CudnnRnnModelShapes & model_shapes,const RnnInputMode & input_mode,const AlgorithmConfig & algo_config,ScratchAllocator * dropout_state_allocator,std::unique_ptr<RnnDescriptor> * rnn_desc,bool use_padded_io)1096   Status CreateRnnDescriptor(OpKernelContext* context,
1097                              const CudnnRnnModelShapes& model_shapes,
1098                              const RnnInputMode& input_mode,
1099                              const AlgorithmConfig& algo_config,
1100                              ScratchAllocator* dropout_state_allocator,
1101                              std::unique_ptr<RnnDescriptor>* rnn_desc,
1102                              bool use_padded_io) {
1103     StreamExecutor* executor = context->op_device_context()->stream()->parent();
1104     se::dnn::DataType data_type = ToDataType<T>::value;
1105     auto rnn_desc_s = executor->createRnnDescriptor(
1106         model_shapes.num_layers, model_shapes.num_units,
1107         model_shapes.input_size, model_shapes.cell_num_units,
1108         model_shapes.batch_size, input_mode, rnn_direction_mode(), rnn_mode(),
1109         data_type, algo_config, dropout(), seed(), dropout_state_allocator,
1110         use_padded_io);
1111     TF_RETURN_IF_ERROR(rnn_desc_s.status());
1112 
1113     *rnn_desc = std::move(rnn_desc_s).value();
1114     return OkStatus();
1115   }
1116 
1117   using RnnStateCache = gtl::FlatMap<
1118       std::pair<CudnnRnnModelShapes, absl::optional<AlgorithmDesc>>,
1119       RnnScratchSpace, CudnnRnnConfigHasher, CudnnRnnConfigComparator>;
1120   // Returns a raw rnn descriptor pointer. The cache owns the rnn descriptor and
1121   // should outlive the returned pointer.
1122   template <typename T>
GetCachedRnnDescriptor(OpKernelContext * context,const CudnnRnnModelShapes & model_shapes,const RnnInputMode & input_mode,const AlgorithmConfig & algo_config,RnnStateCache * cache,RnnDescriptor ** rnn_desc,bool use_padded_io)1123   Status GetCachedRnnDescriptor(OpKernelContext* context,
1124                                 const CudnnRnnModelShapes& model_shapes,
1125                                 const RnnInputMode& input_mode,
1126                                 const AlgorithmConfig& algo_config,
1127                                 RnnStateCache* cache, RnnDescriptor** rnn_desc,
1128                                 bool use_padded_io) {
1129     auto key = std::make_pair(model_shapes, algo_config.algorithm());
1130     RnnScratchSpace& rnn_state = (*cache)[key];
1131     if (rnn_state.rnn_desc == nullptr || ResetRndGenState()) {
1132       CudnnRNNSpaceAllocator* dropout_state_allocator =
1133           new CudnnRNNSpaceAllocator(context);
1134       rnn_state.dropout_state_allocator.reset(dropout_state_allocator);
1135       Status status = CreateRnnDescriptor<T>(
1136           context, model_shapes, input_mode, algo_config,
1137           dropout_state_allocator, &rnn_state.rnn_desc, use_padded_io);
1138       TF_RETURN_IF_ERROR(status);
1139     }
1140     *rnn_desc = rnn_state.rnn_desc.get();
1141     return OkStatus();
1142   }
1143 
1144  private:
1145   int seed_;
1146   int seed2_;
1147   float dropout_;
1148   bool reset_rnd_gen_state_;
1149 
1150   CudnnModelTypes model_types_;
1151 };
1152 
1153 // A class that returns the size of the opaque parameter buffer. The user should
1154 // use that to create the actual parameter buffer for training. However, it
1155 // should not be used for saving and restoring.
1156 template <typename T, typename Index>
1157 class CudnnRNNParamsSizeOp<GPUDevice, T, Index> : public CudnnRNNKernelCommon {
1158  public:
CudnnRNNParamsSizeOp(OpKernelConstruction * context)1159   explicit CudnnRNNParamsSizeOp(OpKernelConstruction* context)
1160       : CudnnRNNKernelCommon(context) {
1161     if (context->HasAttr("num_proj")) {
1162       OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
1163     } else {
1164       num_proj_ = 0;
1165     }
1166   }
1167 
Compute(OpKernelContext * context)1168   void Compute(OpKernelContext* context) override {
1169     std::unique_ptr<RnnDescriptor> rnn_desc;
1170     OP_REQUIRES_OK(context,
1171                    ExtractCudnnRNNParamsInfo<T>(context, num_proj_, &rnn_desc));
1172     int64_t params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
1173     CHECK(params_size_in_bytes % sizeof(T) == 0)
1174         << "params_size_in_bytes must be multiple of element size";
1175     int64_t params_size = params_size_in_bytes / sizeof(T);
1176 
1177     Tensor* output_t = nullptr;
1178     OP_REQUIRES_OK(context, context->allocate_output(0, {1}, &output_t));
1179     *output_t->template flat<Index>().data() = params_size;
1180   }
1181 
1182  private:
1183   int num_proj_;
1184 };
1185 
1186 #define REGISTER_GPU(T)                                    \
1187   REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsSize")       \
1188                               .Device(DEVICE_GPU)          \
1189                               .HostMemory("num_layers")    \
1190                               .HostMemory("num_units")     \
1191                               .HostMemory("input_size")    \
1192                               .HostMemory("params_size")   \
1193                               .TypeConstraint<T>("T")      \
1194                               .TypeConstraint<int32>("S"), \
1195                           CudnnRNNParamsSizeOp<GPUDevice, T, int32>);
1196 
1197 TF_CALL_half(REGISTER_GPU);
1198 TF_CALL_float(REGISTER_GPU);
1199 TF_CALL_double(REGISTER_GPU);
1200 #undef REGISTER_GPU
1201 
1202 // Convert weight and bias params from a platform-specific layout to the
1203 // canonical form.
1204 template <typename T>
1205 class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon {
1206  public:
CudnnRNNParamsToCanonical(OpKernelConstruction * context)1207   explicit CudnnRNNParamsToCanonical(OpKernelConstruction* context)
1208       : CudnnRNNKernelCommon(context) {
1209     if (context->HasAttr("num_params")) {
1210       OP_REQUIRES_OK(context, context->GetAttr("num_params", &num_params_));
1211     } else {
1212       num_params_ = 0;
1213     }
1214     if (context->HasAttr("num_params_weights")) {
1215       OP_REQUIRES_OK(context, context->GetAttr("num_params_weights",
1216                                                &num_params_weights_));
1217     } else {
1218       num_params_weights_ = 0;
1219     }
1220     if (context->HasAttr("num_params_biases")) {
1221       OP_REQUIRES_OK(
1222           context, context->GetAttr("num_params_biases", &num_params_biases_));
1223     } else {
1224       num_params_biases_ = 0;
1225     }
1226     if (context->HasAttr("num_proj")) {
1227       OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
1228     } else {
1229       num_proj_ = 0;
1230     }
1231     if (num_proj_ == 0) {
1232       num_params_weights_ = num_params_;
1233       num_params_biases_ = num_params_;
1234     }
1235   }
1236 
Compute(OpKernelContext * context)1237   void Compute(OpKernelContext* context) override {
1238     const Tensor& input = context->input(3);
1239     auto input_ptr = StreamExecutorUtil::AsDeviceMemory<T>(input);
1240     Stream* stream = context->op_device_context()->stream();
1241 
1242     std::unique_ptr<RnnDescriptor> rnn_desc;
1243     OP_REQUIRES_OK(context,
1244                    ExtractCudnnRNNParamsInfo<T>(context, num_proj_, &rnn_desc));
1245     int64_t params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
1246     CHECK(params_size_in_bytes % sizeof(T) == 0)
1247         << "params_size_in_bytes must be multiple of element size";
1248 
1249     const Tensor* num_units_t = nullptr;
1250     OP_REQUIRES_OK(context, context->input("num_units", &num_units_t));
1251     CHECK(TensorShapeUtils::IsScalar(num_units_t->shape()))
1252         << "num_units is not a scalar";
1253     int num_units = num_units_t->scalar<int>()();
1254 
1255     const Tensor* input_size_t = nullptr;
1256     OP_REQUIRES_OK(context, context->input("input_size", &input_size_t));
1257     CHECK(TensorShapeUtils::IsScalar(input_size_t->shape()))
1258         << "input_size is not a scalar";
1259     int input_size = input_size_t->scalar<int>()();
1260 
1261     const Tensor* num_layers_t = nullptr;
1262     OP_REQUIRES_OK(context, context->input("num_layers", &num_layers_t));
1263     CHECK(TensorShapeUtils::IsScalar(num_layers_t->shape()))
1264         << "num_layers is not a scalar";
1265     int num_layers = num_layers_t->scalar<int>()();
1266     int num_dirs = 1;
1267     if (rnn_direction_mode() == RnnDirectionMode::kRnnBidirectional) {
1268       num_dirs = 2;
1269     }
1270     const int num_params_weights_per_layer =
1271         num_params_weights_ / num_layers / num_dirs;
1272     // Number of params applied on inputs. The rest are applied on recurrent
1273     // hidden states.
1274     const int num_params_input_state = num_params_weights_per_layer / 2;
1275     OP_REQUIRES(
1276         context, num_params_weights_ % (num_layers * num_dirs) == 0,
1277         errors::InvalidArgument("Number of params (weights) is not a multiple"
1278                                 "of num_layers * num_dirs."));
1279     OP_REQUIRES(
1280         context, num_params_biases_ % (num_layers * num_dirs) == 0,
1281         errors::InvalidArgument("Number of params (biases) is not a multiple"
1282                                 "of num_layers * num_dirs."));
1283     if (num_proj_ == 0) {
1284       OP_REQUIRES(
1285           context, num_params_weights_per_layer % 2 == 0,
1286           errors::InvalidArgument("Number of params (weights) per layer is not"
1287                                   "an even number with no projection."));
1288     } else {
1289       OP_REQUIRES(
1290           context, num_params_weights_per_layer % 2 != 0,
1291           errors::InvalidArgument("Number of params (weights) per layer is not"
1292                                   "an odl number with projection."));
1293     }
1294 
1295     OP_REQUIRES(
1296         context, num_params_weights_ == rnn_desc->ParamsWeightRegions().size(),
1297         errors::InvalidArgument("C Number of params mismatch. Expected ",
1298                                 num_params_weights_, ", got ",
1299                                 rnn_desc->ParamsWeightRegions().size()));
1300     int h_num_units = (num_proj_ == 0 ? num_units : num_proj_);
1301     int c_num_units = (num_proj_ == 0 ? 0 : num_units);
1302     for (int i = 0; i < rnn_desc->ParamsWeightRegions().size(); i++) {
1303       int64_t size_in_bytes = rnn_desc->ParamsWeightRegions()[i].size;
1304       int64_t size = size_in_bytes / sizeof(T);
1305       const int layer_idx = i / num_params_weights_per_layer;
1306       const int index_within_layer = i % num_params_weights_per_layer;
1307       int width = 0, height = (num_proj_ == 0 ? h_num_units : c_num_units);
1308       // In CuDNN layout, each layer has num_params_weights_per_layer params,
1309       // with the
1310       // first half a.k.a num_params_input_state params applied on the inputs,
1311       // and the second half on the recurrent hidden states.
1312       bool apply_on_input_state = index_within_layer < num_params_input_state;
1313       if (rnn_direction_mode() == RnnDirectionMode::kRnnUnidirectional) {
1314         if (layer_idx == 0 && apply_on_input_state) {
1315           width = input_size;
1316         } else {
1317           width = h_num_units;
1318         }
1319       } else {
1320         if (apply_on_input_state) {
1321           if (layer_idx <= 1) {
1322             // First fwd or bak layer.
1323             width = input_size;
1324           } else {
1325             // Following layers, cell inputs are concatenated outputs of
1326             // its prior layer.
1327             width = 2 * h_num_units;
1328           }
1329         } else {
1330           width = h_num_units;
1331         }
1332       }
1333       CHECK(size == width * height) << "Params size mismatch. Expected "
1334                                     << width * height << ", got " << size;
1335       Tensor* output = nullptr;
1336       int id_in_layer = i % num_params_weights_per_layer;
1337       if (num_proj_ != 0 && id_in_layer == num_params_weights_per_layer - 1) {
1338         std::swap(height, width);
1339       }
1340       OP_REQUIRES_OK(context, context->allocate_output(
1341                                   i, TensorShape({height, width}), &output));
1342       DeviceMemoryBase data_src_ptr = SliceDeviceMemory(
1343           input_ptr, rnn_desc->ParamsWeightRegions()[i].offset, size_in_bytes);
1344       auto data_dst_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output);
1345       stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes);
1346     }
1347 
1348     OP_REQUIRES(
1349         context, num_params_biases_ == rnn_desc->ParamsBiasRegions().size(),
1350         errors::InvalidArgument("A Number of params mismatch. Expected ",
1351                                 num_params_biases_, ", got ",
1352                                 rnn_desc->ParamsBiasRegions().size()));
1353     for (int i = 0; i < rnn_desc->ParamsBiasRegions().size(); i++) {
1354       int64_t size_in_bytes = rnn_desc->ParamsBiasRegions()[i].size;
1355       int64_t size = size_in_bytes / sizeof(T);
1356       OP_REQUIRES(context, size == num_units,
1357                   errors::InvalidArgument("Params size mismatch. Expected ",
1358                                           num_units, ", got ", size));
1359 
1360       Tensor* output = nullptr;
1361       OP_REQUIRES_OK(context,
1362                      context->allocate_output(num_params_weights_ + i,
1363                                               TensorShape({size}), &output));
1364       DeviceMemoryBase data_src_ptr = SliceDeviceMemory(
1365           input_ptr, rnn_desc->ParamsBiasRegions()[i].offset, size_in_bytes);
1366       auto data_dst_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output);
1367       stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes);
1368     }
1369   }
1370 
1371  private:
1372   int num_params_;
1373   int num_params_weights_;
1374   int num_params_biases_;
1375   int num_proj_;
1376 };
1377 
1378 #define REGISTER_GPU(T)                                     \
1379   REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsToCanonical") \
1380                               .Device(DEVICE_GPU)           \
1381                               .HostMemory("num_layers")     \
1382                               .HostMemory("num_units")      \
1383                               .HostMemory("input_size")     \
1384                               .TypeConstraint<T>("T"),      \
1385                           CudnnRNNParamsToCanonical<GPUDevice, T>);
1386 TF_CALL_half(REGISTER_GPU);
1387 TF_CALL_float(REGISTER_GPU);
1388 TF_CALL_double(REGISTER_GPU);
1389 #undef REGISTER_GPU
1390 
1391 #define REGISTER_GPU(T)                                       \
1392   REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsToCanonicalV2") \
1393                               .Device(DEVICE_GPU)             \
1394                               .HostMemory("num_layers")       \
1395                               .HostMemory("num_units")        \
1396                               .HostMemory("input_size")       \
1397                               .TypeConstraint<T>("T"),        \
1398                           CudnnRNNParamsToCanonical<GPUDevice, T>);
1399 TF_CALL_half(REGISTER_GPU);
1400 TF_CALL_float(REGISTER_GPU);
1401 TF_CALL_double(REGISTER_GPU);
1402 #undef REGISTER_GPU
1403 
1404 // Convert weight and bias params from the canonical form to a
1405 // platform-specific layout.
1406 template <typename T>
1407 class CudnnRNNCanonicalToParams<GPUDevice, T> : public CudnnRNNKernelCommon {
1408  public:
CudnnRNNCanonicalToParams(OpKernelConstruction * context)1409   explicit CudnnRNNCanonicalToParams(OpKernelConstruction* context)
1410       : CudnnRNNKernelCommon(context) {
1411     if (context->HasAttr("num_proj")) {
1412       OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
1413     } else {
1414       num_proj_ = 0;
1415     }
1416   }
1417 
Compute(OpKernelContext * context)1418   void Compute(OpKernelContext* context) override {
1419     std::unique_ptr<RnnDescriptor> rnn_desc;
1420     OP_REQUIRES_OK(context,
1421                    ExtractCudnnRNNParamsInfo<T>(context, num_proj_, &rnn_desc));
1422     int64_t params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
1423     CHECK(params_size_in_bytes % sizeof(T) == 0)
1424         << "params_size_in_bytes must be multiple of element size";
1425     Tensor* output = nullptr;
1426     int params_size = params_size_in_bytes / sizeof(T);
1427     OP_REQUIRES_OK(context,
1428                    context->allocate_output(0, {params_size}, &output));
1429     auto output_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output);
1430     Stream* stream = context->op_device_context()->stream();
1431 
1432     OpInputList weights;
1433     OP_REQUIRES_OK(context, context->input_list("weights", &weights));
1434     RestoreParams<T>(weights, rnn_desc->ParamsWeightRegions(), &output_ptr,
1435                      stream);
1436 
1437     OpInputList biases;
1438     OP_REQUIRES_OK(context, context->input_list("biases", &biases));
1439     RestoreParams<T>(biases, rnn_desc->ParamsBiasRegions(), &output_ptr,
1440                      stream);
1441   }
1442 
1443  private:
1444   int num_proj_;
1445 };
1446 
1447 #define REGISTER_GPU(T)                                     \
1448   REGISTER_KERNEL_BUILDER(Name("CudnnRNNCanonicalToParams") \
1449                               .Device(DEVICE_GPU)           \
1450                               .HostMemory("num_layers")     \
1451                               .HostMemory("num_units")      \
1452                               .HostMemory("input_size")     \
1453                               .TypeConstraint<T>("T"),      \
1454                           CudnnRNNCanonicalToParams<GPUDevice, T>);
1455 TF_CALL_half(REGISTER_GPU);
1456 TF_CALL_float(REGISTER_GPU);
1457 TF_CALL_double(REGISTER_GPU);
1458 #undef REGISTER_GPU
1459 
1460 #define REGISTER_GPU(T)                                       \
1461   REGISTER_KERNEL_BUILDER(Name("CudnnRNNCanonicalToParamsV2") \
1462                               .Device(DEVICE_GPU)             \
1463                               .HostMemory("num_layers")       \
1464                               .HostMemory("num_units")        \
1465                               .HostMemory("input_size")       \
1466                               .TypeConstraint<T>("T"),        \
1467                           CudnnRNNCanonicalToParams<GPUDevice, T>);
1468 TF_CALL_half(REGISTER_GPU);
1469 TF_CALL_float(REGISTER_GPU);
1470 TF_CALL_double(REGISTER_GPU);
1471 #undef REGISTER_GPU
1472 
1473 // Run the forward operation of the RNN model.
1474 template <typename T>
1475 class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
1476  public:
CudnnRNNForwardOp(OpKernelConstruction * context)1477   explicit CudnnRNNForwardOp(OpKernelConstruction* context)
1478       : CudnnRNNKernelCommon(context) {
1479     OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
1480 
1481     // Read debug env variables.
1482     is_debug_mode_ = DebugCudnnRnn();
1483     debug_cudnn_rnn_algo_ = DebugCudnnRnnAlgo();
1484     debug_use_tensor_ops_ = DebugCudnnRnnUseTensorOps();
1485   }
1486 
Compute(OpKernelContext * context)1487   void Compute(OpKernelContext* context) override {
1488     AlgorithmConfig algo_config;
1489     ComputeAndReturnAlgorithm(context, &algo_config, /*var_seq_lengths=*/false,
1490                               /*time_major=*/true, /*num_proj=*/0);
1491   }
1492 
1493  protected:
ComputeAndReturnAlgorithm(OpKernelContext * context,AlgorithmConfig * output_algo_config,bool var_seq_lengths,bool time_major,int num_proj)1494   virtual void ComputeAndReturnAlgorithm(OpKernelContext* context,
1495                                          AlgorithmConfig* output_algo_config,
1496                                          bool var_seq_lengths, bool time_major,
1497                                          int num_proj) {
1498     CHECK_NE(output_algo_config, nullptr);
1499 
1500     const Tensor* input = nullptr;
1501     const Tensor* input_h = nullptr;
1502     const Tensor* input_c = nullptr;
1503     const Tensor* params = nullptr;
1504     const Tensor* sequence_lengths = nullptr;
1505     CudnnRnnModelShapes model_shapes;
1506     bool use_padded_io = false;
1507     if (var_seq_lengths) {
1508       OP_REQUIRES_OK(context, ExtractForwardInput(
1509                                   context, model_types(), time_major, &input,
1510                                   &input_h, &input_c, &params,
1511                                   &sequence_lengths, num_proj, &model_shapes));
1512       use_padded_io =
1513           ShouldUsePaddedIO(sequence_lengths, model_shapes, time_major);
1514     } else {
1515       OP_REQUIRES_OK(context,
1516                      ExtractForwardInput(context, model_types(), time_major,
1517                                          &input, &input_h, &input_c, &params,
1518                                          num_proj, &model_shapes));
1519     }
1520     RnnInputMode input_mode;
1521     OP_REQUIRES_OK(context,
1522                    ToRNNInputMode(rnn_input_mode(), model_shapes.num_units,
1523                                   model_shapes.input_size, &input_mode));
1524 
1525     Tensor* output = nullptr;
1526     Tensor* output_h = nullptr;
1527     Tensor* output_c = nullptr;
1528     OP_REQUIRES_OK(context, AllocateOutputs(context, model_shapes, &output,
1529                                             &output_h, &output_c));
1530 
1531     // Creates a memory callback for the reserve_space. The memory lives in the
1532     // output of this kernel. And it will be fed into the backward pass when
1533     // needed.
1534     CudnnRnnAllocatorInOutput<T> reserve_space_allocator(context, 3);
1535     // Creates a memory callback for the workspace. The memory lives to the end
1536     // of this kernel calls.
1537     CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context);
1538 
1539     if (is_debug_mode_) {
1540       AlgorithmDesc algo_desc(debug_cudnn_rnn_algo_, debug_use_tensor_ops_,
1541                               absl::nullopt);
1542       output_algo_config->set_algorithm(algo_desc);
1543     } else {
1544       OP_REQUIRES_OK(context,
1545                      MaybeAutotune(context, model_shapes, input_mode, input,
1546                                    input_h, input_c, params, output, output_h,
1547                                    output_c, output_algo_config));
1548     }
1549 
1550     Status launch_status;
1551     {
1552       mutex_lock l(mu_);
1553       RnnDescriptor* rnn_desc_ptr = nullptr;
1554       OP_REQUIRES_OK(context,
1555                      GetCachedRnnDescriptor<T>(
1556                          context, model_shapes, input_mode, *output_algo_config,
1557                          &rnn_state_cache_, &rnn_desc_ptr, use_padded_io));
1558       launch_status = DoForward<T>(
1559           context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h,
1560           input_c, params, is_training_, output, output_h, output_c,
1561           sequence_lengths, time_major, &reserve_space_allocator,
1562           &workspace_allocator, /*output_profile_result=*/nullptr);
1563     }
1564     OP_REQUIRES_OK(context, launch_status);
1565   }
1566 
1567  protected:
MaybeAutotune(OpKernelContext * context,const CudnnRnnModelShapes & model_shapes,const RnnInputMode & input_mode,const Tensor * input,const Tensor * input_h,const Tensor * input_c,const Tensor * params,Tensor * output,Tensor * output_h,Tensor * output_c,AlgorithmConfig * best_algo_config)1568   virtual Status MaybeAutotune(OpKernelContext* context,
1569                                const CudnnRnnModelShapes& model_shapes,
1570                                const RnnInputMode& input_mode,
1571                                const Tensor* input, const Tensor* input_h,
1572                                const Tensor* input_c, const Tensor* params,
1573                                Tensor* output, Tensor* output_h,
1574                                Tensor* output_c,
1575                                AlgorithmConfig* best_algo_config) {
1576     CHECK_NE(best_algo_config, nullptr);
1577     *best_algo_config = AlgorithmConfig();
1578     return OkStatus();
1579   }
1580 
is_training() const1581   bool is_training() const { return is_training_; }
1582   bool is_debug_mode_;
1583   bool debug_use_tensor_ops_;
1584   int64_t debug_cudnn_rnn_algo_;
1585 
1586  private:
AllocateOutputs(OpKernelContext * context,const CudnnRnnModelShapes & model_shapes,Tensor ** output,Tensor ** output_h,Tensor ** output_c)1587   Status AllocateOutputs(OpKernelContext* context,
1588                          const CudnnRnnModelShapes& model_shapes,
1589                          Tensor** output, Tensor** output_h,
1590                          Tensor** output_c) {
1591     const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
1592     const TensorShape& output_shape = model_shapes.output_shape;
1593     const TensorShape& cell_state_shape = model_shapes.cell_state_shape;
1594 
1595     TF_RETURN_IF_ERROR(context->allocate_output(0, output_shape, output));
1596     TF_RETURN_IF_ERROR(
1597         context->allocate_output(1, hidden_state_shape, output_h));
1598     if (HasInputC()) {
1599       TF_RETURN_IF_ERROR(
1600           context->allocate_output(2, cell_state_shape, output_c));
1601     } else {
1602       // Only LSTM uses input_c and output_c. So for all other models, we only
1603       // need to create dummy outputs.
1604       TF_RETURN_IF_ERROR(context->allocate_output(2, {}, output_c));
1605     }
1606     if (!is_training_) {
1607       Tensor* dummy_reserve_space = nullptr;
1608       TF_RETURN_IF_ERROR(context->allocate_output(3, {}, &dummy_reserve_space));
1609     }
1610     return OkStatus();
1611   }
1612 
1613   mutex mu_;
1614   bool is_training_;
1615   RnnStateCache rnn_state_cache_ TF_GUARDED_BY(mu_);
1616 };
1617 
1618 #define REGISTER_GPU(T)                                           \
1619   REGISTER_KERNEL_BUILDER(                                        \
1620       Name("CudnnRNN").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
1621       CudnnRNNForwardOp<GPUDevice, T>);
1622 
1623 TF_CALL_half(REGISTER_GPU);
1624 TF_CALL_float(REGISTER_GPU);
1625 TF_CALL_double(REGISTER_GPU);
1626 #undef REGISTER_GPU
1627 
1628 template <typename T>
1629 class CudnnRNNForwardOpV2<GPUDevice, T>
1630     : public CudnnRNNForwardOp<GPUDevice, T> {
1631  private:
1632   using CudnnRNNForwardOp<GPUDevice, T>::is_training;
1633   using CudnnRNNKernelCommon::CreateRnnDescriptor;
1634   using CudnnRNNKernelCommon::dropout;
1635   using CudnnRNNKernelCommon::HasInputC;
1636   using CudnnRNNKernelCommon::model_types;
1637 
1638  public:
CudnnRNNForwardOpV2(OpKernelConstruction * context)1639   explicit CudnnRNNForwardOpV2(OpKernelConstruction* context)
1640       : CudnnRNNForwardOp<GPUDevice, T>(context) {}
1641 
Compute(OpKernelContext * context)1642   void Compute(OpKernelContext* context) override {
1643     AlgorithmConfig best_algo_config;
1644     CudnnRNNForwardOp<GPUDevice, T>::ComputeAndReturnAlgorithm(
1645         context, &best_algo_config, /*var_seq_lengths=*/false,
1646         /*time_major=*/true, /*num_proj=*/0);
1647     if (!context->status().ok()) {
1648       return;
1649     }
1650 
1651     Tensor* output_host_reserved = nullptr;
1652     // output_host_reserved stores opaque info used for backprop when running
1653     // in training mode. At present, it includes a serialization of the best
1654     // AlgorithmDesc picked during rnn forward pass autotune.
1655     // int8 algorithm_id
1656     // int8 use_tensor_op
1657     // If autotune is not enabled, the algorithm_id is
1658     // stream_executor::dnn::kDefaultAlgorithm and use_tensor_op is false. If
1659     // running in inference mode, the output_host_reserved is currently not
1660     // populated.
1661     if (is_training()) {
1662       OP_REQUIRES_OK(context, context->allocate_output(4, TensorShape({2}),
1663                                                        &output_host_reserved));
1664       auto output_host_reserved_int8 = output_host_reserved->vec<int8>();
1665       output_host_reserved_int8(0) = best_algo_config.algorithm()->algo_id();
1666       output_host_reserved_int8(1) =
1667           best_algo_config.algorithm()->tensor_ops_enabled();
1668     } else {
1669       OP_REQUIRES_OK(context,
1670                      context->allocate_output(4, {}, &output_host_reserved));
1671     }
1672   }
1673 
1674  protected:
MaybeAutotune(OpKernelContext * context,const CudnnRnnModelShapes & model_shapes,const RnnInputMode & input_mode,const Tensor * input,const Tensor * input_h,const Tensor * input_c,const Tensor * params,Tensor * output,Tensor * output_h,Tensor * output_c,AlgorithmConfig * algo_config)1675   Status MaybeAutotune(OpKernelContext* context,
1676                        const CudnnRnnModelShapes& model_shapes,
1677                        const RnnInputMode& input_mode, const Tensor* input,
1678                        const Tensor* input_h, const Tensor* input_c,
1679                        const Tensor* params, Tensor* output, Tensor* output_h,
1680                        Tensor* output_c,
1681                        AlgorithmConfig* algo_config) override {
1682     CHECK_NE(algo_config, nullptr);
1683     if (!CudnnRnnUseAutotune() || this->is_debug_mode_) {
1684       *algo_config = AlgorithmConfig();
1685       return OkStatus();
1686     }
1687 
1688     std::vector<AlgorithmDesc> algorithms;
1689     auto* stream = context->op_device_context()->stream();
1690     CHECK(stream->parent()->GetRnnAlgorithms(&algorithms));
1691     if (algorithms.empty()) {
1692       LOG(WARNING) << "No Rnn algorithm found";
1693       return OkStatus();
1694     }
1695 
1696     const auto& modeltypes = model_types();
1697     CudnnRnnParameters rnn_params(
1698         model_shapes.num_layers, model_shapes.input_size,
1699         model_shapes.num_units, model_shapes.max_seq_length,
1700         model_shapes.batch_size, model_shapes.dir_count,
1701         /*has_dropout=*/std::abs(dropout()) > 1e-8, is_training(),
1702         modeltypes.rnn_mode, modeltypes.rnn_input_mode, input->dtype());
1703 
1704     if (AutotuneRnnConfigMap::GetInstance()->Find(rnn_params, algo_config)) {
1705       VLOG(1) << "Using existing best Cudnn RNN algorithm "
1706               << "(algo, tensor_op_enabled) = ("
1707               << algo_config->algorithm()->algo_id() << ", "
1708               << algo_config->algorithm()->tensor_ops_enabled() << ").";
1709       return OkStatus();
1710     }
1711     profiler::ScopedAnnotation trace("cudnn_autotuning");
1712 
1713     // Create temp tensors when profiling backprop pass.
1714     auto data_type = input->dtype();
1715     Tensor output_backprop;
1716     Tensor output_h_backprop;
1717     Tensor output_c_backprop;
1718     Tensor input_backprop;
1719     Tensor input_h_backprop;
1720     Tensor input_c_backprop;
1721     Tensor params_backprop;
1722     if (is_training()) {
1723       TF_RETURN_IF_ERROR(context->allocate_temp(
1724           data_type, model_shapes.output_shape, &output_backprop));
1725       TF_RETURN_IF_ERROR(context->allocate_temp(
1726           data_type, model_shapes.hidden_state_shape, &output_h_backprop));
1727 
1728       TF_RETURN_IF_ERROR(
1729           context->allocate_temp(data_type, params->shape(), &params_backprop));
1730       TF_RETURN_IF_ERROR(context->allocate_temp(
1731           data_type, model_shapes.input_shape, &input_backprop));
1732       TF_RETURN_IF_ERROR(context->allocate_temp(
1733           data_type, model_shapes.hidden_state_shape, &input_h_backprop));
1734       if (HasInputC()) {
1735         TF_RETURN_IF_ERROR(context->allocate_temp(
1736             data_type, model_shapes.hidden_state_shape, &output_c_backprop));
1737         TF_RETURN_IF_ERROR(context->allocate_temp(
1738             data_type, model_shapes.hidden_state_shape, &input_c_backprop));
1739       }
1740     }
1741     ProfileResult best_result;
1742     for (auto& algo : algorithms) {
1743       VLOG(1) << "Profile Cudnn RNN algorithm (algo, tensor_op_enabled) =  ("
1744               << algo.algo_id() << ", " << algo.tensor_ops_enabled() << ").";
1745       Status status;
1746       ProfileResult final_profile_result;
1747 
1748       ProfileResult fwd_profile_result;
1749       ProfileResult bak_profile_result;
1750 
1751       // RnnDescriptor is algorithm-dependent, thus not reusable.
1752       std::unique_ptr<RnnDescriptor> rnn_desc;
1753       // Use a temp scratch allocator for the random num generator.
1754       CudnnRnnAllocatorInTemp<uint8> dropout_state_allocator(context);
1755       if (!this->template CreateRnnDescriptor<T>(
1756                    context, model_shapes, input_mode, AlgorithmConfig(algo),
1757                    &dropout_state_allocator, &rnn_desc,
1758                    /*use_padded_io=*/false)
1759                .ok()) {
1760         continue;
1761       }
1762 
1763       // Again use temp scratch allocator during profiling.
1764       CudnnRnnAllocatorInTemp<T> reserve_space_allocator(context);
1765       CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context);
1766       status = DoForward<T>(context, *rnn_desc, model_types(), model_shapes,
1767                             input, input_h, input_c, params, is_training(),
1768                             output, output_h, output_c, nullptr, true,
1769                             &reserve_space_allocator, &workspace_allocator,
1770                             &fwd_profile_result);
1771       if (!status.ok()) {
1772         continue;
1773       }
1774 
1775       if (is_training()) {
1776         // Get reserve space from the forward pass.
1777         Tensor reserve_space = reserve_space_allocator.get_allocated_tensor(0);
1778         status = DoBackward<T>(
1779             context, *rnn_desc, model_types(), model_shapes, input, input_h,
1780             input_c, params, output, output_h, output_c, &output_backprop,
1781             &output_h_backprop, &output_c_backprop, &reserve_space,
1782             &input_backprop, &input_h_backprop, &input_c_backprop,
1783             &params_backprop, nullptr, true, &workspace_allocator,
1784             &bak_profile_result);
1785         if (!status.ok()) {
1786           continue;
1787         }
1788         final_profile_result.set_elapsed_time_in_ms(
1789             fwd_profile_result.elapsed_time_in_ms() +
1790             bak_profile_result.elapsed_time_in_ms());
1791       } else {
1792         final_profile_result = fwd_profile_result;
1793       }
1794 
1795       auto total_time = final_profile_result.elapsed_time_in_ms();
1796       VLOG(1) << "Cudnn RNN algorithm (algo, tensor_op_enabled) =  ("
1797               << algo.algo_id() << ", " << algo.tensor_ops_enabled() << ")"
1798               << " run time: " << total_time << " ms.";
1799       if (total_time < best_result.elapsed_time_in_ms()) {
1800         best_result.set_elapsed_time_in_ms(total_time);
1801         best_result.set_algorithm(algo);
1802       }
1803     }
1804 
1805     if (!best_result.is_valid()) {
1806       return Status(error::Code::INTERNAL, "No algorithm worked!");
1807     }
1808     algo_config->set_algorithm(best_result.algorithm());
1809     VLOG(1) << "Best Cudnn RNN algorithm (algo, tensor_op_enabled) =  ("
1810             << best_result.algorithm().algo_id() << ", "
1811             << best_result.algorithm().tensor_ops_enabled() << ").";
1812     AutotuneRnnConfigMap::GetInstance()->Insert(rnn_params, *algo_config);
1813     return OkStatus();
1814   }
1815 };
1816 
1817 #define REGISTER_GPU(T)                                    \
1818   REGISTER_KERNEL_BUILDER(Name("CudnnRNNV2")               \
1819                               .Device(DEVICE_GPU)          \
1820                               .HostMemory("host_reserved") \
1821                               .TypeConstraint<T>("T"),     \
1822                           CudnnRNNForwardOpV2<GPUDevice, T>);
1823 
1824 TF_CALL_half(REGISTER_GPU);
1825 TF_CALL_float(REGISTER_GPU);
1826 TF_CALL_double(REGISTER_GPU);
1827 #undef REGISTER_GPU
1828 
1829 template <typename T>
1830 class CudnnRNNForwardOpV3<GPUDevice, T>
1831     : public CudnnRNNForwardOp<GPUDevice, T> {
1832  private:
1833   using CudnnRNNForwardOp<GPUDevice, T>::is_training;
1834   using CudnnRNNKernelCommon::CreateRnnDescriptor;
1835   using CudnnRNNKernelCommon::dropout;
1836   using CudnnRNNKernelCommon::HasInputC;
1837   using CudnnRNNKernelCommon::model_types;
1838   bool time_major_;
1839 
1840  protected:
time_major()1841   bool time_major() { return time_major_; }
1842 
1843  public:
CudnnRNNForwardOpV3(OpKernelConstruction * context)1844   explicit CudnnRNNForwardOpV3(OpKernelConstruction* context)
1845       : CudnnRNNForwardOp<GPUDevice, T>(context) {
1846     OP_REQUIRES_OK(context, context->GetAttr("time_major", &time_major_));
1847     if (context->HasAttr("num_proj")) {
1848       OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
1849     } else {
1850       num_proj_ = 0;
1851     }
1852   }
1853 
Compute(OpKernelContext * context)1854   void Compute(OpKernelContext* context) override {
1855     AlgorithmConfig best_algo_config;
1856     CudnnRNNForwardOp<GPUDevice, T>::ComputeAndReturnAlgorithm(
1857         context, &best_algo_config, /*var_seq_lengths=*/true,
1858         /*time_major=*/time_major(), num_proj_);
1859     if (!context->status().ok()) {
1860       return;
1861     }
1862 
1863     Tensor* output_host_reserved = nullptr;
1864     // TODO: Current V3 only uses the default standard algorithm to process
1865     // batches with variable sequences and the inputs should be padded.
1866     // Autotune is not supported yet.
1867     OP_REQUIRES_OK(context,
1868                    context->allocate_output(4, {}, &output_host_reserved));
1869   }
1870 
1871  private:
1872   int num_proj_;
1873 };
1874 
1875 #define REGISTER_GPU(T)                                       \
1876   REGISTER_KERNEL_BUILDER(Name("CudnnRNNV3")                  \
1877                               .Device(DEVICE_GPU)             \
1878                               .HostMemory("sequence_lengths") \
1879                               .HostMemory("host_reserved")    \
1880                               .TypeConstraint<T>("T"),        \
1881                           CudnnRNNForwardOpV3<GPUDevice, T>);
1882 
1883 TF_CALL_half(REGISTER_GPU);
1884 TF_CALL_float(REGISTER_GPU);
1885 TF_CALL_double(REGISTER_GPU);
1886 #undef REGISTER_GPU
1887 
1888 // Run the backward operation of the RNN model.
1889 template <typename T>
1890 class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
1891  public:
CudnnRNNBackwardOp(OpKernelConstruction * context)1892   explicit CudnnRNNBackwardOp(OpKernelConstruction* context)
1893       : CudnnRNNKernelCommon(context) {}
1894 
Compute(OpKernelContext * context)1895   void Compute(OpKernelContext* context) override {
1896     ComputeImpl(context, false, true, 0);
1897   }
1898 
1899  protected:
ComputeImpl(OpKernelContext * context,bool var_seq_lengths,bool time_major,int num_proj)1900   virtual void ComputeImpl(OpKernelContext* context, bool var_seq_lengths,
1901                            bool time_major, int num_proj) {
1902     const Tensor* input = nullptr;
1903     const Tensor* input_h = nullptr;
1904     const Tensor* input_c = nullptr;
1905     const Tensor* params = nullptr;
1906     const Tensor* sequence_lengths = nullptr;
1907     CudnnRnnModelShapes model_shapes;
1908     bool use_padded_io = false;
1909     if (var_seq_lengths) {
1910       OP_REQUIRES_OK(context, ExtractForwardInput(
1911                                   context, model_types(), time_major, &input,
1912                                   &input_h, &input_c, &params,
1913                                   &sequence_lengths, num_proj, &model_shapes));
1914       use_padded_io =
1915           ShouldUsePaddedIO(sequence_lengths, model_shapes, time_major);
1916     } else {
1917       OP_REQUIRES_OK(context,
1918                      ExtractForwardInput(context, model_types(), time_major,
1919                                          &input, &input_h, &input_c, &params,
1920                                          num_proj, &model_shapes));
1921     }
1922     RnnInputMode input_mode;
1923     OP_REQUIRES_OK(context,
1924                    ToRNNInputMode(rnn_input_mode(), model_shapes.num_units,
1925                                   model_shapes.input_size, &input_mode));
1926 
1927     const Tensor* output = nullptr;
1928     const Tensor* output_h = nullptr;
1929     const Tensor* output_c = nullptr;
1930     const Tensor* output_backprop = nullptr;
1931     const Tensor* output_h_backprop = nullptr;
1932     const Tensor* output_c_backprop = nullptr;
1933     const Tensor* reserve_space = nullptr;
1934     OP_REQUIRES_OK(context,
1935                    ExtractBackwardInputs(context, model_shapes, model_types(),
1936                                          &output, &output_h, &output_c,
1937                                          &output_backprop, &output_h_backprop,
1938                                          &output_c_backprop, &reserve_space));
1939 
1940     Tensor* input_backprop = nullptr;
1941     Tensor* input_h_backprop = nullptr;
1942     Tensor* input_c_backprop = nullptr;
1943     Tensor* params_backprop = nullptr;
1944     OP_REQUIRES_OK(context,
1945                    AllocateOutputs(context, model_shapes, params->shape(),
1946                                    &input_backprop, &input_h_backprop,
1947                                    &input_c_backprop, &params_backprop));
1948 
1949     // Creates a memory callback for the workspace. The memory lives to the end
1950     // of this kernel calls.
1951     CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context);
1952     AlgorithmConfig algo_config;
1953     OP_REQUIRES_OK(context, GetAlgorithm(context, &algo_config));
1954     Status launch_status;
1955     {
1956       mutex_lock l(mu_);
1957       RnnDescriptor* rnn_desc_ptr = nullptr;
1958       OP_REQUIRES_OK(
1959           context, GetCachedRnnDescriptor<T>(context, model_shapes, input_mode,
1960                                              algo_config, &rnn_state_cache_,
1961                                              &rnn_desc_ptr, use_padded_io));
1962       launch_status = DoBackward<T>(
1963           context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h,
1964           input_c, params, output, output_h, output_c, output_backprop,
1965           output_h_backprop, output_c_backprop, reserve_space, input_backprop,
1966           input_h_backprop, input_c_backprop, params_backprop, sequence_lengths,
1967           time_major, &workspace_allocator,
1968           /*output_profile_result=*/nullptr);
1969     }
1970     OP_REQUIRES_OK(context, launch_status);
1971   }
1972 
1973  protected:
GetAlgorithm(OpKernelContext * context,AlgorithmConfig * algo_config)1974   virtual Status GetAlgorithm(OpKernelContext* context,
1975                               AlgorithmConfig* algo_config) {
1976     CHECK_NE(algo_config, nullptr);
1977     *algo_config = AlgorithmConfig();
1978     return OkStatus();
1979   }
1980 
1981  private:
1982   mutex mu_;
1983   RnnStateCache rnn_state_cache_ TF_GUARDED_BY(mu_);
1984 
ExtractBackwardInputs(OpKernelContext * context,const CudnnRnnModelShapes & model_shapes,const CudnnModelTypes & model_types,const Tensor ** output,const Tensor ** output_h,const Tensor ** output_c,const Tensor ** output_backprop,const Tensor ** output_h_backprop,const Tensor ** output_c_backprop,const Tensor ** reserve_space)1985   Status ExtractBackwardInputs(
1986       OpKernelContext* context, const CudnnRnnModelShapes& model_shapes,
1987       const CudnnModelTypes& model_types, const Tensor** output,
1988       const Tensor** output_h, const Tensor** output_c,
1989       const Tensor** output_backprop, const Tensor** output_h_backprop,
1990       const Tensor** output_c_backprop, const Tensor** reserve_space) {
1991     TF_RETURN_IF_ERROR(context->input("output", output));
1992     TF_RETURN_IF_ERROR(context->input("output_backprop", output_backprop));
1993     TF_RETURN_IF_ERROR(context->input("output_h", output_h));
1994     TF_RETURN_IF_ERROR(context->input("output_h_backprop", output_h_backprop));
1995     if (model_types.HasInputC()) {
1996       TF_RETURN_IF_ERROR(context->input("output_c", output_c));
1997       TF_RETURN_IF_ERROR(
1998           context->input("output_c_backprop", output_c_backprop));
1999     }
2000     TF_RETURN_IF_ERROR(context->input("reserve_space", reserve_space));
2001     const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
2002     const TensorShape& output_shape = model_shapes.output_shape;
2003     const TensorShape& cell_state_shape = model_shapes.cell_state_shape;
2004 
2005     if (output_shape != (*output)->shape()) {
2006       return errors::InvalidArgument(
2007           "Invalid output shape: ", (*output)->shape().DebugString(), " ",
2008           output_shape.DebugString());
2009     }
2010     if (hidden_state_shape != (*output_h)->shape()) {
2011       return errors::InvalidArgument(
2012           "Invalid output_h shape: ", (*output_h)->shape().DebugString(), " ",
2013           hidden_state_shape.DebugString());
2014     }
2015 
2016     if (output_shape != (*output_backprop)->shape()) {
2017       return errors::InvalidArgument("Invalid output_backprop shape: ",
2018                                      (*output_backprop)->shape().DebugString(),
2019                                      " ", output_shape.DebugString());
2020     }
2021     if (hidden_state_shape != (*output_h_backprop)->shape()) {
2022       return errors::InvalidArgument(
2023           "Invalid output_h_backprop shape: ",
2024           (*output_h_backprop)->shape().DebugString(), " ",
2025           hidden_state_shape.DebugString());
2026     }
2027 
2028     if (model_types.HasInputC()) {
2029       if (cell_state_shape != (*output_c)->shape()) {
2030         return errors::InvalidArgument(
2031             "Invalid output_c shape: ", (*output_c)->shape().DebugString(), " ",
2032             cell_state_shape.DebugString());
2033       }
2034       if (cell_state_shape != (*output_c_backprop)->shape()) {
2035         return errors::InvalidArgument(
2036             "Invalid output_c_backprop shape: ",
2037             (*output_c_backprop)->shape().DebugString(), " ",
2038             cell_state_shape.DebugString());
2039       }
2040     }
2041     return OkStatus();
2042   }
2043 
AllocateOutputs(OpKernelContext * context,const CudnnRnnModelShapes & model_shapes,const TensorShape & params_shape,Tensor ** input_backprop,Tensor ** input_h_backprop,Tensor ** input_c_backprop,Tensor ** params_backprop)2044   Status AllocateOutputs(OpKernelContext* context,
2045                          const CudnnRnnModelShapes& model_shapes,
2046                          const TensorShape& params_shape,
2047                          Tensor** input_backprop, Tensor** input_h_backprop,
2048                          Tensor** input_c_backprop, Tensor** params_backprop) {
2049     const TensorShape& input_shape = model_shapes.input_shape;
2050     const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
2051     const TensorShape& cell_state_shape = model_shapes.cell_state_shape;
2052 
2053     TF_RETURN_IF_ERROR(
2054         context->allocate_output(0, input_shape, input_backprop));
2055     TF_RETURN_IF_ERROR(
2056         context->allocate_output(1, hidden_state_shape, input_h_backprop));
2057     if (HasInputC()) {
2058       TF_RETURN_IF_ERROR(
2059           context->allocate_output(2, cell_state_shape, input_c_backprop));
2060     } else {
2061       // Only LSTM uses input_c and output_c. So for all other models, we only
2062       // need to create dummy outputs.
2063       TF_RETURN_IF_ERROR(context->allocate_output(2, {}, input_c_backprop));
2064     }
2065     TF_RETURN_IF_ERROR(
2066         context->allocate_output(3, params_shape, params_backprop));
2067     return OkStatus();
2068   }
2069 };
2070 
2071 #define REGISTER_GPU(T)                                                   \
2072   REGISTER_KERNEL_BUILDER(                                                \
2073       Name("CudnnRNNBackprop").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
2074       CudnnRNNBackwardOp<GPUDevice, T>);
2075 
2076 TF_CALL_half(REGISTER_GPU);
2077 TF_CALL_float(REGISTER_GPU);
2078 TF_CALL_double(REGISTER_GPU);
2079 #undef REGISTER_GPU
2080 
2081 template <typename T>
2082 class CudnnRNNBackwardOpV2<GPUDevice, T>
2083     : public CudnnRNNBackwardOp<GPUDevice, T> {
2084  public:
CudnnRNNBackwardOpV2(OpKernelConstruction * context)2085   explicit CudnnRNNBackwardOpV2(OpKernelConstruction* context)
2086       : CudnnRNNBackwardOp<GPUDevice, T>(context) {}
2087 
2088  protected:
GetAlgorithm(OpKernelContext * context,AlgorithmConfig * algo_config)2089   Status GetAlgorithm(OpKernelContext* context,
2090                       AlgorithmConfig* algo_config) override {
2091     CHECK_NE(algo_config, nullptr);
2092     const Tensor* host_reserved = nullptr;
2093     TF_RETURN_IF_ERROR(context->input("host_reserved", &host_reserved));
2094 
2095     auto host_reserved_int8 = host_reserved->vec<int8>();
2096     const AlgorithmDesc algo_desc(host_reserved_int8(0), host_reserved_int8(1),
2097                                   absl::nullopt);
2098     algo_config->set_algorithm(algo_desc);
2099     return OkStatus();
2100   }
2101 };
2102 
2103 #define REGISTER_GPU(T)                                    \
2104   REGISTER_KERNEL_BUILDER(Name("CudnnRNNBackpropV2")       \
2105                               .Device(DEVICE_GPU)          \
2106                               .HostMemory("host_reserved") \
2107                               .TypeConstraint<T>("T"),     \
2108                           CudnnRNNBackwardOpV2<GPUDevice, T>);
2109 
2110 TF_CALL_half(REGISTER_GPU);
2111 TF_CALL_float(REGISTER_GPU);
2112 TF_CALL_double(REGISTER_GPU);
2113 #undef REGISTER_GPU
2114 
2115 template <typename T>
2116 class CudnnRNNBackwardOpV3<GPUDevice, T>
2117     : public CudnnRNNBackwardOp<GPUDevice, T> {
2118  private:
2119   bool time_major_;
2120 
2121  protected:
time_major()2122   bool time_major() { return time_major_; }
2123 
2124  public:
CudnnRNNBackwardOpV3(OpKernelConstruction * context)2125   explicit CudnnRNNBackwardOpV3(OpKernelConstruction* context)
2126       : CudnnRNNBackwardOp<GPUDevice, T>(context) {
2127     OP_REQUIRES_OK(context, context->GetAttr("time_major", &time_major_));
2128     if (context->HasAttr("num_proj")) {
2129       OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
2130     } else {
2131       num_proj_ = 0;
2132     }
2133   }
2134 
Compute(OpKernelContext * context)2135   void Compute(OpKernelContext* context) override {
2136     CudnnRNNBackwardOp<GPUDevice, T>::ComputeImpl(context, true, time_major(),
2137                                                   num_proj_);
2138   }
2139 
2140  private:
2141   int num_proj_;
2142 };
2143 
2144 #define REGISTER_GPU(T)                                       \
2145   REGISTER_KERNEL_BUILDER(Name("CudnnRNNBackpropV3")          \
2146                               .Device(DEVICE_GPU)             \
2147                               .HostMemory("sequence_lengths") \
2148                               .HostMemory("host_reserved")    \
2149                               .TypeConstraint<T>("T"),        \
2150                           CudnnRNNBackwardOpV3<GPUDevice, T>);
2151 
2152 TF_CALL_half(REGISTER_GPU);
2153 TF_CALL_float(REGISTER_GPU);
2154 TF_CALL_double(REGISTER_GPU);
2155 #undef REGISTER_GPU
2156 
2157 // TODO(zhengxq): Add the conversion of Cudnn RNN Params from and to
2158 // its canonical form.
2159 
2160 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
2161 
2162 }  // namespace tensorflow
2163