1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #define EIGEN_USE_THREADS
16
17 #include <stddef.h>
18
19 #include <atomic>
20 #include <cmath>
21 #include <functional>
22 #include <limits>
23 #include <string>
24 #include <unordered_set>
25 #include <utility>
26
27 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
28 #include "tensorflow/core/framework/device_base.h"
29 #include "tensorflow/core/framework/kernel_def_builder.h"
30 #include "tensorflow/core/framework/op.h"
31 #include "tensorflow/core/framework/op_def_builder.h"
32 #include "tensorflow/core/framework/op_kernel.h"
33 #include "tensorflow/core/framework/register_types.h"
34 #include "tensorflow/core/framework/tensor.h"
35 #include "tensorflow/core/framework/tensor_shape.h"
36 #include "tensorflow/core/framework/tensor_types.h"
37 #include "tensorflow/core/framework/types.h"
38 #include "tensorflow/core/kernels/gpu_utils.h"
39 #include "tensorflow/core/lib/core/errors.h"
40 #include "tensorflow/core/lib/core/status.h"
41 #include "tensorflow/core/lib/core/stringpiece.h"
42 #include "tensorflow/core/lib/gtl/inlined_vector.h"
43 #include "tensorflow/core/lib/hash/hash.h"
44 #include "tensorflow/core/lib/strings/stringprintf.h"
45 #include "tensorflow/core/platform/fingerprint.h"
46 #include "tensorflow/core/platform/mutex.h"
47 #include "tensorflow/core/platform/types.h"
48 #include "tensorflow/core/profiler/lib/scoped_annotation.h"
49 #include "tensorflow/core/util/env_var.h"
50 #include "tensorflow/core/util/use_cudnn.h"
51
52 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
53 #include "tensorflow/core/platform/stream_executor.h"
54 #include "tensorflow/core/util/stream_executor_util.h"
55 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
56
57 /*
58 * This module implements ops that fuse a multi-layer multi-step RNN/LSTM model
59 * using the underlying Cudnn library.
60 *
61 * The Cudnn RNN library exposes an opaque parameter buffer with an unspecified
62 * layout and format, and a buffer saved on one GPU is very likely unusable on a
63 * different GPU. Users therefore need to query the size of the opaque
64 * parameter buffer and convert it to and from its canonical form, while each
65 * actual training step is carried out with the opaque parameter buffer.
66 *
67 * Similar to many other ops, the forward op has two flavors: training and
68 * inference. When training is specified, additional data is produced in
69 * reserve_space for the backward pass, which incurs a performance penalty.
70 *
71 * In addition to the actual data and reserve_space, Cudnn also needs extra
72 * memory as temporary workspace. Memory management to and from
73 * stream-executor is done through ScratchAllocator: in general,
74 * stream-executor is responsible for requesting memory of the proper size, and
75 * TensorFlow is responsible for making sure the memory stays alive long enough
76 * and is recycled afterwards.
77 *
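 * As a concrete example of this split (implemented below):
 * CudnnRnnAllocatorInTemp hands out temporary workspace that only needs to
 * survive a single Cudnn RNN call, CudnnRnnAllocatorInOutput backs the
 * reserve_space that the forward op exposes as a kernel output for the
 * backward pass, and CudnnRNNSpaceAllocator holds the dropout state that has
 * to persist across kernel invocations.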
78 */
79 namespace tensorflow {
80
81 using CPUDevice = Eigen::ThreadPoolDevice;
82
83 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
84
85 using GPUDevice = Eigen::GpuDevice;
86 using se::Stream;
87 using se::StreamExecutor;
88 using se::dnn::RnnDescriptor;
89
90 template <typename Device, typename T, typename Index>
91 class CudnnRNNParamsSizeOp;
92
93 template <typename Device, typename T>
94 class CudnnRNNParamsToCanonical;
95
96 template <typename Device, typename T>
97 class CudnnRNNCanonicalToParams;
98
99 template <typename Device, typename T>
100 class CudnnRNNForwardOp;
101
102 template <typename Device, typename T>
103 class CudnnRNNBackwardOp;
104
105 template <typename Device, typename T>
106 class CudnnRNNForwardOpV2;
107
108 template <typename Device, typename T>
109 class CudnnRNNBackwardOpV2;
110
111 template <typename Device, typename T>
112 class CudnnRNNForwardOpV3;
113
114 template <typename Device, typename T>
115 class CudnnRNNBackwardOpV3;
116
117 enum class TFRNNInputMode {
118 kRNNLinearInput = 0,
119 kRNNSkipInput = 1,
120 kAutoSelect = 9999999
121 };
122
123 namespace {
124 using se::DeviceMemory;
125 using se::DeviceMemoryBase;
126 using se::ScratchAllocator;
127 using se::dnn::AlgorithmConfig;
128 using se::dnn::AlgorithmDesc;
129 using se::dnn::ProfileResult;
130 using se::dnn::RnnDirectionMode;
131 using se::dnn::RnnInputMode;
132 using se::dnn::RnnMode;
133 using se::dnn::RnnSequenceTensorDescriptor;
134 using se::dnn::RnnStateTensorDescriptor;
135 using se::dnn::ToDataType;
136 using se::port::StatusOr;
137
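// Combines the list elements into a single 64-bit hash via Hash64Combine; an
// empty list hashes to 0.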
138 uint64 HashList(const std::vector<int>& list) {
139 if (list.empty()) {
140 return 0;
141 }
142 uint64 hash_code = list[0];
143 for (int i = 1; i < list.size(); i++) {
144 hash_code = Hash64Combine(hash_code, list[i]);
145 }
146 return hash_code;
147 }
148
149 // Encapsulate all the shape information that is used in both forward and
150 // backward rnn operations.
151 class CudnnRnnParameters {
152 public:
153 CudnnRnnParameters(int num_layers, int input_size, int num_units,
154 int max_seq_length, int batch_size, int dir_count,
155 bool has_dropout, bool is_training, RnnMode rnn_mode,
156 TFRNNInputMode rnn_input_mode, DataType dtype)
157 : num_layers_(num_layers),
158 input_size_(input_size),
159 num_units_(num_units),
160 seq_length_(max_seq_length),
161 batch_size_(batch_size),
162 dir_count_(dir_count),
163 has_dropout_(has_dropout),
164 is_training_(is_training),
165 rnn_mode_(rnn_mode),
166 rnn_input_mode_(rnn_input_mode),
167 dtype_(dtype) {
168 hash_code_ =
169 HashList({num_layers, input_size, num_units, max_seq_length, batch_size,
170 dir_count, static_cast<int>(has_dropout),
171 static_cast<int>(is_training), static_cast<int>(rnn_mode),
172 static_cast<int>(rnn_input_mode), dtype});
173 }
174
175 bool operator==(const CudnnRnnParameters& other) const {
176 return this->get_data_as_tuple() == other.get_data_as_tuple();
177 }
178
179 bool operator!=(const CudnnRnnParameters& other) const {
180 return !(*this == other);
181 }
182 uint64 hash() const { return hash_code_; }
183
184 string ToString() const {
185 std::vector<string> fields = {
186 std::to_string(num_layers_),
187 std::to_string(input_size_),
188 std::to_string(num_units_),
189 std::to_string(seq_length_),
190 std::to_string(batch_size_),
191 std::to_string(dir_count_),
192 std::to_string(has_dropout_),
193 std::to_string(is_training_),
194 std::to_string(static_cast<int>(rnn_mode_)),
195 std::to_string(static_cast<int>(rnn_input_mode_)),
196 std::to_string(static_cast<int>(dtype_))};
197 return absl::StrJoin(fields, ", ");
198 }
199
200 private:
201 using ParameterDataType = std::tuple<int, int, int, int, int, int, bool, bool,
202 RnnMode, TFRNNInputMode, DataType>;
203
204 ParameterDataType get_data_as_tuple() const {
205 return std::make_tuple(num_layers_, input_size_, num_units_, seq_length_,
206 batch_size_, dir_count_, has_dropout_, is_training_,
207 rnn_mode_, rnn_input_mode_, dtype_);
208 }
209
210 const int num_layers_;
211 const int input_size_;
212 const int num_units_;
213 const int seq_length_;
214 const int batch_size_;
215 const int dir_count_;
216 const bool has_dropout_;
217 const bool is_training_;
218 const RnnMode rnn_mode_;
219 const TFRNNInputMode rnn_input_mode_;
220 const DataType dtype_;
221 uint64 hash_code_;
222 };
223
224 struct RnnAutotuneGroup {
225 static string name() { return "Rnn"; }
226 };
227
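// Autotuning results cache (via AutotuneSingleton from gpu_utils.h): maps a
// CudnnRnnParameters key to the best AlgorithmConfig found so far.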
228 using AutotuneRnnConfigMap =
229 AutotuneSingleton<RnnAutotuneGroup, CudnnRnnParameters, AlgorithmConfig>;
230
231 Status ParseRNNMode(const string& str, RnnMode* rnn_mode) {
232 if (str == "rnn_relu") {
233 *rnn_mode = RnnMode::kRnnRelu;
234 return OkStatus();
235 } else if (str == "rnn_tanh") {
236 *rnn_mode = RnnMode::kRnnTanh;
237 return OkStatus();
238 } else if (str == "lstm") {
239 *rnn_mode = RnnMode::kRnnLstm;
240 return OkStatus();
241 } else if (str == "gru") {
242 *rnn_mode = RnnMode::kRnnGru;
243 return OkStatus();
244 }
245 return errors::InvalidArgument("Invalid RNN mode: ", str);
246 }
247
248 Status ParseTFRNNInputMode(const string& str, TFRNNInputMode* rnn_input_mode) {
249 if (str == "linear_input") {
250 *rnn_input_mode = TFRNNInputMode::kRNNLinearInput;
251 return OkStatus();
252 } else if (str == "skip_input") {
253 *rnn_input_mode = TFRNNInputMode::kRNNSkipInput;
254 return OkStatus();
255 } else if (str == "auto_select") {
256 *rnn_input_mode = TFRNNInputMode::kAutoSelect;
257 return OkStatus();
258 }
259 return errors::InvalidArgument("Invalid RNN input mode: ", str);
260 }
261
262 Status ParseRNNDirectionMode(const string& str,
263 RnnDirectionMode* rnn_dir_mode) {
264 if (str == "unidirectional") {
265 *rnn_dir_mode = RnnDirectionMode::kRnnUnidirectional;
266 return OkStatus();
267 } else if (str == "bidirectional") {
268 *rnn_dir_mode = RnnDirectionMode::kRnnBidirectional;
269 return OkStatus();
270 }
271 return errors::InvalidArgument("Invalid RNN direction mode: ", str);
272 }
273
274 Status ToRNNInputMode(TFRNNInputMode tf_input_mode, int num_units,
275 int input_size, RnnInputMode* input_mode) {
276 switch (tf_input_mode) {
277 case TFRNNInputMode::kRNNLinearInput:
278 *input_mode = RnnInputMode::kRnnLinearSkip;
279 break;
280 case TFRNNInputMode::kRNNSkipInput:
281 *input_mode = RnnInputMode::kRnnSkipInput;
282 break;
283 case TFRNNInputMode::kAutoSelect:
284 *input_mode = (input_size == num_units) ? RnnInputMode::kRnnSkipInput
285 : RnnInputMode::kRnnLinearSkip;
286 break;
287 default:
288 return errors::InvalidArgument("Invalid TF input mode: ",
289 static_cast<int>(tf_input_mode));
290 }
291 return OkStatus();
292 }
293
294 // TODO(zhengxq): Merge those into stream_executor_util.h.
295 template <typename T>
296 const DeviceMemory<T> AsDeviceMemory(const Tensor* tensor) {
297 return DeviceMemory<T>::MakeFromByteSize(
298 const_cast<T*>(tensor->template flat<T>().data()),
299 tensor->template flat<T>().size() * sizeof(T));
300 }
301
302 template <typename T>
303 DeviceMemory<T> AsDeviceMemory(Tensor* tensor) {
304 return DeviceMemory<T>::MakeFromByteSize(
305 tensor->template flat<T>().data(),
306 tensor->template flat<T>().size() * sizeof(T));
307 }
308
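// Reinterprets the storage of a tensor with element type T as a
// DeviceMemory<U> covering the same bytes (no copy is made).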
309 template <typename U, typename T>
310 DeviceMemory<U> CastDeviceMemory(Tensor* tensor) {
311 return DeviceMemory<U>::MakeFromByteSize(
312 tensor->template flat<T>().data(),
313 tensor->template flat<T>().size() * sizeof(T));
314 }
315
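// Returns a non-owning view of the byte range [offset, offset + size) within
// device_memory; CHECK-fails if the requested slice is out of range.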
316 DeviceMemoryBase SliceDeviceMemory(const DeviceMemoryBase& device_memory,
317 int64_t offset, int64_t size) {
318 const void* base_ptr = device_memory.opaque();
319 void* offset_ptr =
320 const_cast<char*>(reinterpret_cast<const char*>(base_ptr) + offset);
321 CHECK(offset + size <= device_memory.size())
322 << "The slice is not within the region of DeviceMemory.";
323 return DeviceMemoryBase(offset_ptr, size);
324 }
325
326 inline Status FromExecutorStatus(const se::port::Status& s) {
327 return s.ok() ? OkStatus()
328 : Status(static_cast<error::Code>(static_cast<int>(s.code())),
329 s.error_message());
330 }
331
332 template <typename T>
333 inline Status FromExecutorStatus(const se::port::StatusOr<T>& s) {
334 return FromExecutorStatus(s.status());
335 }
336
337 inline se::port::Status ToExecutorStatus(const Status& s) {
338 return s.ok() ? OkStatus()
339 : se::port::Status(static_cast<se::port::error::Code>(
340 static_cast<int>(s.code())),
341 s.error_message());
342 }
343
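// Compile-time mapping from C++ element types to the corresponding TensorFlow
// DataType enum, used by the scratch allocators below when allocating
// temporary tensors.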
344 template <typename>
345 struct ToTFDataType;
346
347 template <>
348 struct ToTFDataType<Eigen::half> : std::integral_constant<DataType, DT_HALF> {};
349
350 template <>
351 struct ToTFDataType<float> : std::integral_constant<DataType, DT_FLOAT> {};
352
353 template <>
354 struct ToTFDataType<double> : std::integral_constant<DataType, DT_DOUBLE> {};
355
356 template <>
357 struct ToTFDataType<uint8> : std::integral_constant<DataType, DT_UINT8> {};
358
359 // A helper to allocate temporary scratch memory for Cudnn RNN models. It
360 // takes ownership of the underlying memory. The expectation is that the
361 // memory stays alive for the span of the Cudnn RNN call itself.
362 template <typename T>
363 class CudnnRnnAllocatorInTemp : public ScratchAllocator {
364 public:
365 ~CudnnRnnAllocatorInTemp() override = default;
366
367 explicit CudnnRnnAllocatorInTemp(OpKernelContext* context)
368 : context_(context) {}
369 int64_t GetMemoryLimitInBytes() override {
370 return std::numeric_limits<int64_t>::max();
371 }
372
373 StatusOr<DeviceMemory<uint8>> AllocateBytes(int64_t byte_size) override {
374 Tensor temporary_memory;
375 const DataType tf_data_type = ToTFDataType<T>::value;
376 int64_t allocate_count =
377 Eigen::divup(byte_size, static_cast<int64_t>(sizeof(T)));
378 Status allocation_status(context_->allocate_temp(
379 tf_data_type, TensorShape({allocate_count}), &temporary_memory));
380 if (!allocation_status.ok()) {
381 return ToExecutorStatus(allocation_status);
382 }
383 // Hold the reference of the allocated tensors until the end of the
384 // allocator.
385 allocated_tensors_.push_back(temporary_memory);
386 total_byte_size_ += byte_size;
387 return DeviceMemory<uint8>::MakeFromByteSize(
388 temporary_memory.template flat<T>().data(),
389 temporary_memory.template flat<T>().size() * sizeof(T));
390 }
391
392 int64_t TotalByteSize() const { return total_byte_size_; }
393
394 Tensor get_allocated_tensor(int index) const {
395 return allocated_tensors_[index];
396 }
397
398 private:
399 int64_t total_byte_size_ = 0;
400 OpKernelContext* context_; // not owned
401 std::vector<Tensor> allocated_tensors_;
402 };
403
404 // A helper to allocate memory for Cudnn RNN models as a kernel output. It is
405 // used by the forward-pass kernel to feed its output to the backward pass.
406 // The memory is expected to stay alive at least until the backward pass has
407 // finished.
408 template <typename T>
409 class CudnnRnnAllocatorInOutput : public ScratchAllocator {
410 public:
411 ~CudnnRnnAllocatorInOutput() override {}
412 CudnnRnnAllocatorInOutput(OpKernelContext* context, int output_index)
413 : context_(context), output_index_(output_index) {}
414 int64_t GetMemoryLimitInBytes() override {
415 return std::numeric_limits<int64_t>::max();
416 }
417 StatusOr<DeviceMemory<uint8>> AllocateBytes(int64_t byte_size) override {
418 CHECK(total_byte_size_ == 0)
419 << "Reserve space allocator can only be called once";
420 int64_t allocate_count =
421 Eigen::divup(byte_size, static_cast<int64_t>(sizeof(T)));
422
423 Tensor* temporary_memory = nullptr;
424 Status allocation_status(context_->allocate_output(
425 output_index_, TensorShape({allocate_count}), &temporary_memory));
426 if (!allocation_status.ok()) {
427 return ToExecutorStatus(allocation_status);
428 }
429 total_byte_size_ += byte_size;
430 auto memory_uint8 = DeviceMemory<uint8>::MakeFromByteSize(
431 temporary_memory->template flat<T>().data(),
432 temporary_memory->template flat<T>().size() * sizeof(T));
433 return StatusOr<DeviceMemory<uint8>>(memory_uint8);
434 }
435 int64_t TotalByteSize() { return total_byte_size_; }
436
437 private:
438 int64_t total_byte_size_ = 0;
439 OpKernelContext* context_; // not owned
440 int output_index_;
441 };
442
443 // A helper to allocate memory for Cudnn RNN models that is
444 // expected to persist across kernel invocations.
445 // This class is not thread-safe.
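// In this file it backs the cuDNN dropout state kept in RnnScratchSpace, which
// must outlive the cached RnnDescriptor that references it.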
446 class CudnnRNNSpaceAllocator : public ScratchAllocator {
447 public:
448 explicit CudnnRNNSpaceAllocator(OpKernelContext* context)
449 : context_(context) {}
450
451 ~CudnnRNNSpaceAllocator() override {}
452
453 int64_t GetMemoryLimitInBytes() override {
454 return std::numeric_limits<int64_t>::max();
455 }
456
457 StatusOr<DeviceMemory<uint8>> AllocateBytes(int64_t byte_size) override {
458 if (total_byte_size_ != 0) {
459 return Status(error::FAILED_PRECONDITION,
460 "Space allocator can only be called once");
461 }
462
463 Status allocation_status =
464 context_->allocate_temp(DT_UINT8, TensorShape({byte_size}), &tensor_);
465 if (!allocation_status.ok()) {
466 return ToExecutorStatus(allocation_status);
467 }
468 total_byte_size_ += byte_size;
469 return AsDeviceMemory<uint8>(&tensor_);
470 }
471 int64_t TotalByteSize() { return total_byte_size_; }
472
473 private:
474 int64_t total_byte_size_ = 0;
475 Tensor tensor_;
476 OpKernelContext* context_; // not owned
477 };
478
479 struct CudnnModelTypes {
480 RnnMode rnn_mode;
481 TFRNNInputMode rnn_input_mode;
482 RnnDirectionMode rnn_direction_mode;
483 bool HasInputC() const {
484 // For Cudnn 5.0, only LSTM has input-c. All other models use only
485 // input-h.
486 return rnn_mode == RnnMode::kRnnLstm;
487 }
488
489 string DebugString() const {
490 return strings::Printf(
491 "[rnn_mode, rnn_input_mode, rnn_direction_mode]: %d, %d, %d ",
492 static_cast<int>(rnn_mode), static_cast<int>(rnn_input_mode),
493 static_cast<int>(rnn_direction_mode));
494 }
495 };
496
497 // A helper struct that collects the shapes that describe an RNN model.
498 struct CudnnRnnModelShapes {
499 int num_layers;
500 int input_size;
501 int num_units;
502 int dir_count;
503 int max_seq_length;
504 int batch_size;
505 int cell_num_units = 0;
506 // If you add a new field to this structure, please take care of
507 // updating IsCompatibleWith() below as well as the hash function in
508 // CudnnRnnConfigHasher.
509 TensorShape input_shape;
510 TensorShape output_shape;
511 TensorShape hidden_state_shape;
512 TensorShape cell_state_shape;
513 // At present only the fields relevant to the cached RnnDescriptor are compared.
514 bool IsCompatibleWith(const CudnnRnnModelShapes& rhs) const {
515 return num_layers == rhs.num_layers && input_size == rhs.input_size &&
516 num_units == rhs.num_units && dir_count == rhs.dir_count &&
517 cell_num_units == rhs.cell_num_units &&
518 max_seq_length == rhs.max_seq_length;
519 }
520 string DebugString() const {
521 return strings::Printf(
522 "[num_layers, input_size, num_units, dir_count, max_seq_length, "
523 "batch_size, cell_num_units]: [%d, %d, %d, %d, %d, %d, %d] ",
524 num_layers, input_size, num_units, dir_count, max_seq_length,
525 batch_size, cell_num_units);
526 }
527 };
528
529 // Utility class for using a CudnnRnnModelShapes and AlgorithmDesc pair as a
530 // hash table key.
531 struct CudnnRnnConfigHasher {
532 uint64 operator()(
533 const std::pair<CudnnRnnModelShapes, absl::optional<AlgorithmDesc>>&
534 to_hash) const {
535 auto& shapes = to_hash.first;
536 auto& algo_desc = to_hash.second;
537
538 uint64 hash =
539 HashList({shapes.num_layers, shapes.input_size, shapes.num_units,
540 shapes.dir_count, shapes.max_seq_length, shapes.batch_size});
541 if (algo_desc.has_value()) {
542 hash = Hash64Combine(hash, algo_desc->hash());
543 }
544 return hash;
545 }
546 };
547
548 // Utility class for using CudnnRnnModelShapes and AlgorithmDesc pair as a hash
549 // table key.
550 struct CudnnRnnConfigComparator {
551 bool operator()(
552 const std::pair<CudnnRnnModelShapes, absl::optional<AlgorithmDesc>>& lhs,
553 const std::pair<CudnnRnnModelShapes, absl::optional<AlgorithmDesc>>& rhs)
554 const {
555 return lhs.first.IsCompatibleWith(rhs.first) && lhs.second == rhs.second;
556 }
557 };
558
559 // Pointers to RNN scratch space for a specific set of shape parameters (used as
560 // a hash table value in CudnnRNNForwardOp and CudnnRNNBackwardOp).
561 struct RnnScratchSpace {
562 std::unique_ptr<RnnDescriptor> rnn_desc;
563 std::unique_ptr<CudnnRNNSpaceAllocator> dropout_state_allocator;
564 };
565
566 // Extracts and checks the forward input tensors, parameters, and shapes from
567 // the OpKernelContext.
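// With time_major, 'input' is expected to be [max_seq_length, batch_size,
// input_size] and 'input_h' to be [num_layers * dir_count, batch_size,
// num_units]; in batch-major mode the first two dimensions are swapped.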
568 Status ExtractForwardInput(OpKernelContext* context,
569 const CudnnModelTypes& model_types, bool time_major,
570 const Tensor** input, const Tensor** input_h,
571 const Tensor** input_c, const Tensor** params,
572 const int num_proj,
573 CudnnRnnModelShapes* model_shapes) {
574 TF_RETURN_IF_ERROR(context->input("input", input));
575 TF_RETURN_IF_ERROR(context->input("input_h", input_h));
576 if (model_types.HasInputC()) {
577 TF_RETURN_IF_ERROR(context->input("input_c", input_c));
578 }
579 TF_RETURN_IF_ERROR(context->input("params", params));
580
581 if ((*input)->dims() != 3) {
582 return errors::InvalidArgument("RNN input must be a 3-D vector.");
583 }
584 if (time_major) {
585 model_shapes->max_seq_length = (*input)->dim_size(0);
586 model_shapes->batch_size = (*input)->dim_size(1);
587 } else {
588 model_shapes->max_seq_length = (*input)->dim_size(1);
589 model_shapes->batch_size = (*input)->dim_size(0);
590 }
591 model_shapes->input_size = (*input)->dim_size(2);
592 model_shapes->input_shape = (*input)->shape();
593 model_shapes->dir_count =
594 (model_types.rnn_direction_mode == RnnDirectionMode::kRnnBidirectional)
595 ? 2
596 : 1;
597
598 if ((*input_h)->dims() != 3) {
599 return errors::InvalidArgument("RNN input_h must be a 3-D vector.");
600 }
601 if (time_major) {
602 model_shapes->num_layers =
603 (*input_h)->dim_size(0) / model_shapes->dir_count;
604 } else {
605 model_shapes->num_layers =
606 (*input_h)->dim_size(1) / model_shapes->dir_count;
607 }
608 model_shapes->num_units = (*input_h)->dim_size(2);
609
610 if (time_major) {
611 model_shapes->hidden_state_shape =
612 TensorShape({model_shapes->dir_count * model_shapes->num_layers,
613 model_shapes->batch_size, model_shapes->num_units});
614 } else {
615 model_shapes->hidden_state_shape =
616 TensorShape({model_shapes->batch_size,
617 model_shapes->dir_count * model_shapes->num_layers,
618 model_shapes->num_units});
619 }
620 if ((*input_h)->shape() != model_shapes->hidden_state_shape) {
621 return errors::InvalidArgument(
622 "Invalid input_h shape: ", (*input_h)->shape().DebugString(), " ",
623 model_shapes->hidden_state_shape.DebugString());
624 }
625 if (model_types.HasInputC()) {
626 model_shapes->cell_num_units = (*input_c)->dim_size(2);
627 if (time_major) {
628 model_shapes->cell_state_shape =
629 TensorShape({model_shapes->dir_count * model_shapes->num_layers,
630 model_shapes->batch_size, model_shapes->cell_num_units});
631 } else {
632 model_shapes->cell_state_shape =
633 TensorShape({model_shapes->batch_size,
634 model_shapes->dir_count * model_shapes->num_layers,
635 model_shapes->cell_num_units});
636 }
637 if (num_proj == 0) {
638 if ((*input_h)->shape() != (*input_c)->shape()) {
639 return errors::InvalidArgument(
640 "input_h and input_c must have the same shape w/o projection: ",
641 (*input_h)->shape().DebugString(), " ",
642 (*input_c)->shape().DebugString());
643 }
644 } else {
645 if ((*input_h)->dim_size(2) > (*input_c)->dim_size(2) ||
646 num_proj != (*input_h)->dim_size(2) ||
647 (*input_h)->dim_size(0) != (*input_c)->dim_size(0) ||
648 (*input_h)->dim_size(1) != (*input_c)->dim_size(1)) {
649 return errors::InvalidArgument(
650 "Invalid input_h and input_c w/ projection size: ", num_proj, " ",
651 (*input_h)->shape().DebugString(), " ",
652 (*input_c)->shape().DebugString());
653 }
654 }
655 } else {
656 // Set a dummy cell_state_shape. TODO(kaixih): remove the time_major branch.
657 if (time_major) {
658 model_shapes->cell_state_shape =
659 TensorShape({model_shapes->dir_count * model_shapes->num_layers,
660 model_shapes->batch_size, model_shapes->num_units});
661 } else {
662 model_shapes->cell_state_shape =
663 TensorShape({model_shapes->batch_size,
664 model_shapes->dir_count * model_shapes->num_layers,
665 model_shapes->num_units});
666 }
667 model_shapes->cell_num_units = 0;
668 }
669 if (time_major) {
670 model_shapes->output_shape =
671 TensorShape({model_shapes->max_seq_length, model_shapes->batch_size,
672 model_shapes->dir_count * model_shapes->num_units});
673 } else {
674 model_shapes->output_shape =
675 TensorShape({model_shapes->batch_size, model_shapes->max_seq_length,
676 model_shapes->dir_count * model_shapes->num_units});
677 }
678 return OkStatus();
679 }
680
681 // Overload that additionally extracts the sequence_lengths input.
682 Status ExtractForwardInput(OpKernelContext* context,
683 const CudnnModelTypes& model_types, bool time_major,
684 const Tensor** input, const Tensor** input_h,
685 const Tensor** input_c, const Tensor** params,
686 const Tensor** sequence_lengths, const int num_proj,
687 CudnnRnnModelShapes* model_shapes) {
688 TF_RETURN_IF_ERROR(context->input("sequence_lengths", sequence_lengths));
689 return ExtractForwardInput(context, model_types, time_major, input, input_h,
690 input_c, params, num_proj, model_shapes);
691 }
692
693 template <typename T>
694 Status CreateForwardAndBackwardIODescriptors(
695 OpKernelContext* context, const CudnnRnnModelShapes& model_shapes,
696 std::unique_ptr<RnnSequenceTensorDescriptor>* input_desc,
697 std::unique_ptr<RnnStateTensorDescriptor>* h_state_desc,
698 std::unique_ptr<RnnStateTensorDescriptor>* c_state_desc,
699 std::unique_ptr<RnnSequenceTensorDescriptor>* output_desc,
700 const absl::Span<const int> seq_lengths, bool time_major) {
701 StreamExecutor* executor = context->op_device_context()->stream()->parent();
702 se::dnn::DataType data_type = ToDataType<T>::value;
703
704 const TensorShape& input_shape = model_shapes.input_shape;
705 const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
706 const TensorShape& cell_state_shape = model_shapes.cell_state_shape;
707 const TensorShape& output_shape = model_shapes.output_shape;
708
709 DCHECK_EQ(input_shape.dims(), 3);
710 if (seq_lengths.data() != nullptr) {
711 if (time_major) {
712 auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
713 input_shape.dim_size(0), input_shape.dim_size(1),
714 input_shape.dim_size(2), seq_lengths, time_major, data_type);
715 TF_RETURN_IF_ERROR(input_desc_s.status());
716 *input_desc = std::move(input_desc_s).value();
717 } else {
718 auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
719 input_shape.dim_size(1), input_shape.dim_size(0),
720 input_shape.dim_size(2), seq_lengths, time_major, data_type);
721 TF_RETURN_IF_ERROR(input_desc_s.status());
722 *input_desc = std::move(input_desc_s).value();
723 }
724 } else {
725 auto input_desc_s = executor->createRnnSequenceTensorDescriptor(
726 input_shape.dim_size(0), input_shape.dim_size(1),
727 input_shape.dim_size(2), data_type);
728 TF_RETURN_IF_ERROR(input_desc_s.status());
729 *input_desc = std::move(input_desc_s).value();
730 }
731
732 DCHECK_EQ(hidden_state_shape.dims(), 3);
733 if (time_major) {
734 auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor(
735 hidden_state_shape.dim_size(0), hidden_state_shape.dim_size(1),
736 hidden_state_shape.dim_size(2), data_type);
737 TF_RETURN_IF_ERROR(hidden_state_desc_s.status());
738 *h_state_desc = std::move(hidden_state_desc_s).value();
739 } else {
740 auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor(
741 hidden_state_shape.dim_size(1), hidden_state_shape.dim_size(0),
742 hidden_state_shape.dim_size(2), data_type);
743 TF_RETURN_IF_ERROR(hidden_state_desc_s.status());
744 *h_state_desc = std::move(hidden_state_desc_s).value();
745 }
746
747 DCHECK_EQ(cell_state_shape.dims(), 3);
748 if (time_major) {
749 auto cell_state_desc_s = executor->createRnnStateTensorDescriptor(
750 cell_state_shape.dim_size(0), cell_state_shape.dim_size(1),
751 cell_state_shape.dim_size(2), data_type);
752 TF_RETURN_IF_ERROR(cell_state_desc_s.status());
753 *c_state_desc = std::move(cell_state_desc_s).value();
754 } else {
755 auto cell_state_desc_s = executor->createRnnStateTensorDescriptor(
756 cell_state_shape.dim_size(1), cell_state_shape.dim_size(0),
757 cell_state_shape.dim_size(2), data_type);
758 TF_RETURN_IF_ERROR(cell_state_desc_s.status());
759 *c_state_desc = std::move(cell_state_desc_s).value();
760 }
761
762 DCHECK_EQ(output_shape.dims(), 3);
763 if (seq_lengths.data() != nullptr) {
764 if (time_major) {
765 auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
766 output_shape.dim_size(0), output_shape.dim_size(1),
767 output_shape.dim_size(2), seq_lengths, time_major, data_type);
768 TF_RETURN_IF_ERROR(output_desc_s.status());
769 *output_desc = std::move(output_desc_s).value();
770 } else {
771 auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
772 output_shape.dim_size(1), output_shape.dim_size(0),
773 output_shape.dim_size(2), seq_lengths, time_major, data_type);
774 TF_RETURN_IF_ERROR(output_desc_s.status());
775 *output_desc = std::move(output_desc_s).value();
776 }
777 } else {
778 auto output_desc_s = executor->createRnnSequenceTensorDescriptor(
779 output_shape.dim_size(0), output_shape.dim_size(1),
780 output_shape.dim_size(2), data_type);
781 TF_RETURN_IF_ERROR(output_desc_s.status());
782 *output_desc = std::move(output_desc_s).value();
783 }
784
785 return OkStatus();
786 }
787
788 template <typename T>
789 Status DoForward(OpKernelContext* context, const RnnDescriptor& rnn_desc,
790 const CudnnModelTypes& model_types,
791 const CudnnRnnModelShapes& model_shapes,
792 /* forward inputs */
793 const Tensor* input, const Tensor* input_h,
794 const Tensor* input_c, const Tensor* params,
795 const bool is_training,
796 /* forward outputs, outputs of the function */
797 Tensor* output, Tensor* output_h, Tensor* output_c,
798 const Tensor* sequence_lengths, bool time_major,
799 ScratchAllocator* reserve_space_allocator,
800 ScratchAllocator* workspace_allocator,
801 ProfileResult* output_profile_result) {
802 std::unique_ptr<RnnSequenceTensorDescriptor> input_desc;
803 std::unique_ptr<RnnStateTensorDescriptor> h_state_desc;
804 std::unique_ptr<RnnStateTensorDescriptor> c_state_desc;
805 std::unique_ptr<RnnSequenceTensorDescriptor> output_desc;
806
807 absl::Span<const int> seq_lengths;
808 if (sequence_lengths != nullptr) {
809 seq_lengths = absl::Span<const int>(
810 sequence_lengths->template flat<int>().data(), model_shapes.batch_size);
811 }
812 TF_RETURN_IF_ERROR(CreateForwardAndBackwardIODescriptors<T>(
813 context, model_shapes, &input_desc, &h_state_desc, &c_state_desc,
814 &output_desc, seq_lengths, time_major));
815
816 auto input_data = AsDeviceMemory<T>(input);
817 auto input_h_data = AsDeviceMemory<T>(input_h);
818 DeviceMemory<T> input_c_data;
819 if (model_types.HasInputC()) {
820 input_c_data = AsDeviceMemory<T>(input_c);
821 }
822
823 auto params_data = AsDeviceMemory<T>(params);
824 auto output_data = AsDeviceMemory<T>(output);
825 auto output_h_data = AsDeviceMemory<T>(output_h);
826 DeviceMemory<T> output_c_data;
827 if (model_types.HasInputC()) {
828 output_c_data = AsDeviceMemory<T>(output_c);
829 }
830
831 Stream* stream = context->op_device_context()->stream();
832
833 Tensor seq_lengths_tensor;
834 DeviceMemory<int> seq_lengths_ptr;
835 if (sequence_lengths != nullptr) {
836 TF_RETURN_IF_ERROR(context->allocate_temp(
837 DT_INT32, {static_cast<long>(seq_lengths.size())},
838 &seq_lengths_tensor));
839 seq_lengths_ptr = AsDeviceMemory<int>(&seq_lengths_tensor);
840 if (!stream
841 ->ThenMemcpy(&seq_lengths_ptr, seq_lengths.data(),
842 seq_lengths.size() * sizeof(int))
843 .ok()) {
844 return errors::InvalidArgument(
845 "Failed to copy memory from host to "
846 "device for sequence_lengths in "
847 "CudnnRNNV3");
848 }
849 }
850
851 bool launch_success =
852 stream
853 ->ThenRnnForward(rnn_desc, *input_desc, input_data, seq_lengths_ptr,
854 *h_state_desc, input_h_data, *c_state_desc,
855 input_c_data, params_data, *output_desc,
856 &output_data, *h_state_desc, &output_h_data,
857 *c_state_desc, &output_c_data, is_training,
858 reserve_space_allocator, workspace_allocator,
859 output_profile_result)
860 .ok();
861 return launch_success
862 ? OkStatus()
863 : errors::Internal(
864 "Failed to call ThenRnnForward with model config: ",
865 model_types.DebugString(), ", ", model_shapes.DebugString());
866 }
867
868 template <typename T>
869 Status DoBackward(
870 OpKernelContext* context, const RnnDescriptor& rnn_desc,
871 const CudnnModelTypes& model_types, const CudnnRnnModelShapes& model_shapes,
872 /* forward inputs */
873 const Tensor* input, const Tensor* input_h, const Tensor* input_c,
874 const Tensor* params,
875 /* forward outputs */
876 const Tensor* output, const Tensor* output_h, const Tensor* output_c,
877 /* backprop inputs */
878 const Tensor* output_backprop, const Tensor* output_h_backprop,
879 const Tensor* output_c_backprop, const Tensor* reserve_space,
880 /* backprop outputs, output of the function */
881 Tensor* input_backprop, Tensor* input_h_backprop, Tensor* input_c_backprop,
882 Tensor* params_backprop, const Tensor* sequence_lengths, bool time_major,
883 ScratchAllocator* workspace_allocator,
884 ProfileResult* output_profile_result) {
885 std::unique_ptr<RnnSequenceTensorDescriptor> input_desc;
886 std::unique_ptr<RnnStateTensorDescriptor> h_state_desc;
887 std::unique_ptr<RnnStateTensorDescriptor> c_state_desc;
888 std::unique_ptr<RnnSequenceTensorDescriptor> output_desc;
889
890 absl::Span<const int> seq_lengths;
891 if (sequence_lengths != nullptr) {
892 seq_lengths = absl::Span<const int>(
893 sequence_lengths->template flat<int>().data(), model_shapes.batch_size);
894 }
895 TF_RETURN_IF_ERROR(CreateForwardAndBackwardIODescriptors<T>(
896 context, model_shapes, &input_desc, &h_state_desc, &c_state_desc,
897 &output_desc, seq_lengths, time_major));
898
899 auto input_data = AsDeviceMemory<T>(input);
900 auto input_h_data = AsDeviceMemory<T>(input_h);
901 DeviceMemory<T> input_c_data;
902 if (model_types.HasInputC()) {
903 input_c_data = AsDeviceMemory<T>(input_c);
904 }
905 auto params_data = AsDeviceMemory<T>(params);
906 auto output_data = AsDeviceMemory<T>(output);
907 auto output_h_data = AsDeviceMemory<T>(output_h);
908 DeviceMemory<T> output_c_data;
909 if (model_types.HasInputC()) {
910 output_c_data = AsDeviceMemory<T>(output_c);
911 }
912 auto output_backprop_data = AsDeviceMemory<T>(output_backprop);
913 auto output_h_backprop_data = AsDeviceMemory<T>(output_h_backprop);
914 DeviceMemory<T> output_c_backprop_data;
915 if (model_types.HasInputC()) {
916 output_c_backprop_data = AsDeviceMemory<T>(output_c_backprop);
917 }
918 auto input_backprop_data = AsDeviceMemory<T>(input_backprop);
919 auto input_h_backprop_data = AsDeviceMemory<T>(input_h_backprop);
920 DeviceMemory<T> input_c_backprop_data;
921 if (model_types.HasInputC()) {
922 input_c_backprop_data = AsDeviceMemory<T>(input_c_backprop);
923 }
924 auto params_backprop_data = AsDeviceMemory<T>(params_backprop);
925 auto reserve_space_uint8 =
926 CastDeviceMemory<uint8, T>(const_cast<Tensor*>(reserve_space));
927
928 // Creates a memory callback for the workspace. The memory lives until the end
929 // of this kernel call.
930 Stream* stream = context->op_device_context()->stream();
931
932 Tensor seq_lengths_tensor;
933 DeviceMemory<int> seq_lengths_ptr;
934 if (sequence_lengths != nullptr) {
935 TF_RETURN_IF_ERROR(context->allocate_temp(
936 DT_INT32, {static_cast<long>(seq_lengths.size())},
937 &seq_lengths_tensor));
938 seq_lengths_ptr = AsDeviceMemory<int>(&seq_lengths_tensor);
939 if (!stream
940 ->ThenMemcpy(&seq_lengths_ptr, seq_lengths.data(),
941 seq_lengths.size() * sizeof(int))
942 .ok()) {
943 return errors::InvalidArgument(
944 "Failed to copy memory from host to "
945 "device for sequence_lengths in "
946 "CudnnRNNBackwardOpV3");
947 }
948 }
949
950 bool launch_success =
951 stream
952 ->ThenRnnBackward(
953 rnn_desc, *input_desc, input_data, seq_lengths_ptr, *h_state_desc,
954 input_h_data, *c_state_desc, input_c_data, params_data,
955 *output_desc, output_data, *h_state_desc, output_h_data,
956 *c_state_desc, output_c_data, output_backprop_data,
957 output_h_backprop_data, output_c_backprop_data,
958 &input_backprop_data, &input_h_backprop_data,
959 &input_c_backprop_data, ¶ms_backprop_data,
960 &reserve_space_uint8, workspace_allocator, output_profile_result)
961 .ok();
962 return launch_success
963 ? OkStatus()
964 : errors::Internal(
965 "Failed to call ThenRnnBackward with model config: ",
966 model_types.DebugString(), ", ", model_shapes.DebugString());
967 }
968
969 template <typename T>
970 void RestoreParams(const OpInputList params_input,
971 const std::vector<RnnDescriptor::ParamsRegion>& params,
972 DeviceMemoryBase* data_dst, Stream* stream) {
973 int num_params = params.size();
974 CHECK(params_input.size() == num_params)
975 << "Number of params mismatch. Expected " << params_input.size()
976 << ", got " << num_params;
977 for (int i = 0; i < params.size(); i++) {
978 int64_t size_in_bytes = params[i].size;
979 int64_t size = size_in_bytes / sizeof(T);
980 CHECK(size == params_input[i].NumElements())
981 << "Params size mismatch. Expected " << size << ", got "
982 << params_input[i].NumElements();
983 auto data_src_ptr = StreamExecutorUtil::AsDeviceMemory<T>(params_input[i]);
984 DeviceMemoryBase data_dst_ptr =
985 SliceDeviceMemory(*data_dst, params[i].offset, size_in_bytes);
986 stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes);
987 }
988 }
989
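// Padded I/O can be skipped only when the input is time-major and every
// sequence in the batch has the maximum length; otherwise the padded
// (variable-length) cuDNN path is required.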
990 bool ShouldUsePaddedIO(const Tensor* sequence_lengths,
991 const CudnnRnnModelShapes& model_shapes,
992 bool time_major) {
993 auto seq_array = sequence_lengths->template flat<int>().data();
994 bool all_max_seq_length = true;
995 for (int i = 0; i < model_shapes.batch_size; i++) {
996 if (seq_array[i] != model_shapes.max_seq_length) {
997 all_max_seq_length = false;
998 break;
999 }
1000 }
1001 return !(time_major && all_max_seq_length);
1002 }
1003
1004 } // namespace
1005
1006 // Note: all of the following kernels depend on an RnnDescriptor instance,
1007 // which according to the official Cudnn docs should be kept around and reused
1008 // across all Cudnn kernels in the same model.
1009 // In TensorFlow we don't pass the reference across different OpKernels;
1010 // rather, we recreate it separately in each OpKernel, which does not cause
1011 // issues: CudnnDropoutDescriptor keeps a reference to the memory holding the
1012 // random number generator state, and this state is lost during recreation.
1013 // However, only forward-pass Cudnn APIs make use of the state.
1014
1015 // A common base class for RNN kernels. It extracts common attributes and
1016 // performs common shape validations.
1017 class CudnnRNNKernelCommon : public OpKernel {
1018 protected:
1019 explicit CudnnRNNKernelCommon(OpKernelConstruction* context)
1020 : OpKernel(context) {
1021 OP_REQUIRES_OK(context, context->GetAttr("dropout", &dropout_));
1022 OP_REQUIRES_OK(context, context->GetAttr("seed", &seed_));
1023 OP_REQUIRES_OK(context, context->GetAttr("seed2", &seed2_));
1024 string str;
1025 OP_REQUIRES_OK(context, context->GetAttr("rnn_mode", &str));
1026 OP_REQUIRES_OK(context, ParseRNNMode(str, &model_types_.rnn_mode));
1027 OP_REQUIRES_OK(context, context->GetAttr("input_mode", &str));
1028 OP_REQUIRES_OK(context,
1029 ParseTFRNNInputMode(str, &model_types_.rnn_input_mode));
1030 OP_REQUIRES_OK(context, context->GetAttr("direction", &str));
1031 OP_REQUIRES_OK(
1032 context, ParseRNNDirectionMode(str, &model_types_.rnn_direction_mode));
1033 // Reset the CudnnRnnDescriptor and the related random number generator state
1034 // in every Compute() call.
1035 OP_REQUIRES_OK(context, ReadBoolFromEnvVar("TF_CUDNN_RESET_RND_GEN_STATE",
1036 false, &reset_rnd_gen_state_));
1037 }
1038
1039 bool HasInputC() const { return model_types_.HasInputC(); }
1040 RnnMode rnn_mode() const { return model_types_.rnn_mode; }
1041 TFRNNInputMode rnn_input_mode() const { return model_types_.rnn_input_mode; }
1042 RnnDirectionMode rnn_direction_mode() const {
1043 return model_types_.rnn_direction_mode;
1044 }
1045 const CudnnModelTypes& model_types() const { return model_types_; }
1046 float dropout() const { return dropout_; }
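// Packs the two 32-bit seed attributes into the single 64-bit dropout seed
// passed to the cuDNN RNN descriptor.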
1047 uint64 seed() { return (static_cast<uint64>(seed_) << 32) | seed2_; }
1048 bool ResetRndGenState() { return reset_rnd_gen_state_; }
1049
1050 template <typename T>
1051 Status ExtractCudnnRNNParamsInfo(OpKernelContext* context, int num_proj,
1052 std::unique_ptr<RnnDescriptor>* rnn_desc) {
1053 const Tensor* num_layers_t = nullptr;
1054 TF_RETURN_IF_ERROR(context->input("num_layers", &num_layers_t));
1055 if (!TensorShapeUtils::IsScalar(num_layers_t->shape())) {
1056 return errors::InvalidArgument("num_layers is not a scalar");
1057 }
1058 int num_layers = num_layers_t->scalar<int>()();
1059 const Tensor* num_units_t = nullptr;
1060 TF_RETURN_IF_ERROR(context->input("num_units", &num_units_t));
1061 if (!TensorShapeUtils::IsScalar(num_units_t->shape())) {
1062 return errors::InvalidArgument("num_units is not a scalar");
1063 }
1064 int num_units = num_units_t->scalar<int>()();
1065 const Tensor* input_size_t = nullptr;
1066 TF_RETURN_IF_ERROR(context->input("input_size", &input_size_t));
1067 if (!TensorShapeUtils::IsScalar(input_size_t->shape())) {
1068 return errors::InvalidArgument("input_size is not a scalar");
1069 }
1070 int input_size = input_size_t->scalar<int>()();
1071
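// With projection (num_proj > 0) the recurrent/hidden output size is num_proj
// and the cell size is num_units; without projection the cell size argument
// is left at 0.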
1072 int h_num_units = (num_proj == 0 ? num_units : num_proj);
1073 int c_num_units = (num_proj == 0 ? 0 : num_units);
1074
1075 RnnInputMode input_mode;
1076 TF_RETURN_IF_ERROR(
1077 ToRNNInputMode(rnn_input_mode(), num_units, input_size, &input_mode));
1078
1079 Stream* stream = context->op_device_context()->stream();
1080 // ExtractCudnnRNNParamsInfo is only called by op kernels that do not require
1081 // a random number generator, therefore state_allocator is set to nullptr.
1082 const AlgorithmConfig algo_config;
1083 auto rnn_desc_s = stream->parent()->createRnnDescriptor(
1084 num_layers, h_num_units, input_size, /*cell_size=*/c_num_units,
1085 /*batch_size=*/0, input_mode, rnn_direction_mode(), rnn_mode(),
1086 ToDataType<T>::value, algo_config, dropout(), seed(),
1087 /* state_allocator=*/nullptr, /*use_padded_io=*/false);
1088 if (!rnn_desc_s.ok()) {
1089 return FromExecutorStatus(rnn_desc_s);
1090 }
1091 *rnn_desc = std::move(rnn_desc_s).value();
1092 return OkStatus();
1093 }
1094
1095 template <typename T>
1096 Status CreateRnnDescriptor(OpKernelContext* context,
1097 const CudnnRnnModelShapes& model_shapes,
1098 const RnnInputMode& input_mode,
1099 const AlgorithmConfig& algo_config,
1100 ScratchAllocator* dropout_state_allocator,
1101 std::unique_ptr<RnnDescriptor>* rnn_desc,
1102 bool use_padded_io) {
1103 StreamExecutor* executor = context->op_device_context()->stream()->parent();
1104 se::dnn::DataType data_type = ToDataType<T>::value;
1105 auto rnn_desc_s = executor->createRnnDescriptor(
1106 model_shapes.num_layers, model_shapes.num_units,
1107 model_shapes.input_size, model_shapes.cell_num_units,
1108 model_shapes.batch_size, input_mode, rnn_direction_mode(), rnn_mode(),
1109 data_type, algo_config, dropout(), seed(), dropout_state_allocator,
1110 use_padded_io);
1111 TF_RETURN_IF_ERROR(rnn_desc_s.status());
1112
1113 *rnn_desc = std::move(rnn_desc_s).value();
1114 return OkStatus();
1115 }
1116
1117 using RnnStateCache = gtl::FlatMap<
1118 std::pair<CudnnRnnModelShapes, absl::optional<AlgorithmDesc>>,
1119 RnnScratchSpace, CudnnRnnConfigHasher, CudnnRnnConfigComparator>;
1120 // Returns a raw rnn descriptor pointer. The cache owns the rnn descriptor and
1121 // should outlive the returned pointer.
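// A new descriptor (and a fresh dropout state allocator) is created when no
// entry exists for the (model_shapes, algorithm) key, or on every call when
// TF_CUDNN_RESET_RND_GEN_STATE is set.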
1122 template <typename T>
1123 Status GetCachedRnnDescriptor(OpKernelContext* context,
1124 const CudnnRnnModelShapes& model_shapes,
1125 const RnnInputMode& input_mode,
1126 const AlgorithmConfig& algo_config,
1127 RnnStateCache* cache, RnnDescriptor** rnn_desc,
1128 bool use_padded_io) {
1129 auto key = std::make_pair(model_shapes, algo_config.algorithm());
1130 RnnScratchSpace& rnn_state = (*cache)[key];
1131 if (rnn_state.rnn_desc == nullptr || ResetRndGenState()) {
1132 CudnnRNNSpaceAllocator* dropout_state_allocator =
1133 new CudnnRNNSpaceAllocator(context);
1134 rnn_state.dropout_state_allocator.reset(dropout_state_allocator);
1135 Status status = CreateRnnDescriptor<T>(
1136 context, model_shapes, input_mode, algo_config,
1137 dropout_state_allocator, &rnn_state.rnn_desc, use_padded_io);
1138 TF_RETURN_IF_ERROR(status);
1139 }
1140 *rnn_desc = rnn_state.rnn_desc.get();
1141 return OkStatus();
1142 }
1143
1144 private:
1145 int seed_;
1146 int seed2_;
1147 float dropout_;
1148 bool reset_rnd_gen_state_;
1149
1150 CudnnModelTypes model_types_;
1151 };
1152
1153 // A class that returns the size of the opaque parameter buffer. The user
1154 // should use it to create the actual parameter buffer for training. However,
1155 // the opaque buffer itself should not be used for saving and restoring.
1156 template <typename T, typename Index>
1157 class CudnnRNNParamsSizeOp<GPUDevice, T, Index> : public CudnnRNNKernelCommon {
1158 public:
1159 explicit CudnnRNNParamsSizeOp(OpKernelConstruction* context)
1160 : CudnnRNNKernelCommon(context) {
1161 if (context->HasAttr("num_proj")) {
1162 OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
1163 } else {
1164 num_proj_ = 0;
1165 }
1166 }
1167
1168 void Compute(OpKernelContext* context) override {
1169 std::unique_ptr<RnnDescriptor> rnn_desc;
1170 OP_REQUIRES_OK(context,
1171 ExtractCudnnRNNParamsInfo<T>(context, num_proj_, &rnn_desc));
1172 int64_t params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
1173 CHECK(params_size_in_bytes % sizeof(T) == 0)
1174 << "params_size_in_bytes must be multiple of element size";
1175 int64_t params_size = params_size_in_bytes / sizeof(T);
1176
1177 Tensor* output_t = nullptr;
1178 OP_REQUIRES_OK(context, context->allocate_output(0, {1}, &output_t));
1179 *output_t->template flat<Index>().data() = params_size;
1180 }
1181
1182 private:
1183 int num_proj_;
1184 };
1185
1186 #define REGISTER_GPU(T) \
1187 REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsSize") \
1188 .Device(DEVICE_GPU) \
1189 .HostMemory("num_layers") \
1190 .HostMemory("num_units") \
1191 .HostMemory("input_size") \
1192 .HostMemory("params_size") \
1193 .TypeConstraint<T>("T") \
1194 .TypeConstraint<int32>("S"), \
1195 CudnnRNNParamsSizeOp<GPUDevice, T, int32>);
1196
1197 TF_CALL_half(REGISTER_GPU);
1198 TF_CALL_float(REGISTER_GPU);
1199 TF_CALL_double(REGISTER_GPU);
1200 #undef REGISTER_GPU
1201
1202 // Convert weight and bias params from a platform-specific layout to the
1203 // canonical form.
1204 template <typename T>
1205 class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon {
1206 public:
1207 explicit CudnnRNNParamsToCanonical(OpKernelConstruction* context)
1208 : CudnnRNNKernelCommon(context) {
1209 if (context->HasAttr("num_params")) {
1210 OP_REQUIRES_OK(context, context->GetAttr("num_params", &num_params_));
1211 } else {
1212 num_params_ = 0;
1213 }
1214 if (context->HasAttr("num_params_weights")) {
1215 OP_REQUIRES_OK(context, context->GetAttr("num_params_weights",
1216 &num_params_weights_));
1217 } else {
1218 num_params_weights_ = 0;
1219 }
1220 if (context->HasAttr("num_params_biases")) {
1221 OP_REQUIRES_OK(
1222 context, context->GetAttr("num_params_biases", &num_params_biases_));
1223 } else {
1224 num_params_biases_ = 0;
1225 }
1226 if (context->HasAttr("num_proj")) {
1227 OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
1228 } else {
1229 num_proj_ = 0;
1230 }
1231 if (num_proj_ == 0) {
1232 num_params_weights_ = num_params_;
1233 num_params_biases_ = num_params_;
1234 }
1235 }
1236
1237 void Compute(OpKernelContext* context) override {
1238 const Tensor& input = context->input(3);
1239 auto input_ptr = StreamExecutorUtil::AsDeviceMemory<T>(input);
1240 Stream* stream = context->op_device_context()->stream();
1241
1242 std::unique_ptr<RnnDescriptor> rnn_desc;
1243 OP_REQUIRES_OK(context,
1244 ExtractCudnnRNNParamsInfo<T>(context, num_proj_, &rnn_desc));
1245 int64_t params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
1246 CHECK(params_size_in_bytes % sizeof(T) == 0)
1247 << "params_size_in_bytes must be multiple of element size";
1248
1249 const Tensor* num_units_t = nullptr;
1250 OP_REQUIRES_OK(context, context->input("num_units", &num_units_t));
1251 CHECK(TensorShapeUtils::IsScalar(num_units_t->shape()))
1252 << "num_units is not a scalar";
1253 int num_units = num_units_t->scalar<int>()();
1254
1255 const Tensor* input_size_t = nullptr;
1256 OP_REQUIRES_OK(context, context->input("input_size", &input_size_t));
1257 CHECK(TensorShapeUtils::IsScalar(input_size_t->shape()))
1258 << "input_size is not a scalar";
1259 int input_size = input_size_t->scalar<int>()();
1260
1261 const Tensor* num_layers_t = nullptr;
1262 OP_REQUIRES_OK(context, context->input("num_layers", &num_layers_t));
1263 CHECK(TensorShapeUtils::IsScalar(num_layers_t->shape()))
1264 << "num_layers is not a scalar";
1265 int num_layers = num_layers_t->scalar<int>()();
1266 int num_dirs = 1;
1267 if (rnn_direction_mode() == RnnDirectionMode::kRnnBidirectional) {
1268 num_dirs = 2;
1269 }
1270 const int num_params_weights_per_layer =
1271 num_params_weights_ / num_layers / num_dirs;
1272 // Number of params applied on inputs. The rest are applied on recurrent
1273 // hidden states.
1274 const int num_params_input_state = num_params_weights_per_layer / 2;
1275 OP_REQUIRES(
1276 context, num_params_weights_ % (num_layers * num_dirs) == 0,
1277 errors::InvalidArgument("Number of params (weights) is not a multiple"
1278 "of num_layers * num_dirs."));
1279 OP_REQUIRES(
1280 context, num_params_biases_ % (num_layers * num_dirs) == 0,
1281 errors::InvalidArgument("Number of params (biases) is not a multiple"
1282 "of num_layers * num_dirs."));
1283 if (num_proj_ == 0) {
1284 OP_REQUIRES(
1285 context, num_params_weights_per_layer % 2 == 0,
1286 errors::InvalidArgument("Number of params (weights) per layer is not"
1287 "an even number with no projection."));
1288 } else {
1289 OP_REQUIRES(
1290 context, num_params_weights_per_layer % 2 != 0,
1291 errors::InvalidArgument("Number of params (weights) per layer is not"
1292 "an odl number with projection."));
1293 }
1294
1295 OP_REQUIRES(
1296 context, num_params_weights_ == rnn_desc->ParamsWeightRegions().size(),
1297 errors::InvalidArgument("C Number of params mismatch. Expected ",
1298 num_params_weights_, ", got ",
1299 rnn_desc->ParamsWeightRegions().size()));
1300 int h_num_units = (num_proj_ == 0 ? num_units : num_proj_);
1301 int c_num_units = (num_proj_ == 0 ? 0 : num_units);
1302 for (int i = 0; i < rnn_desc->ParamsWeightRegions().size(); i++) {
1303 int64_t size_in_bytes = rnn_desc->ParamsWeightRegions()[i].size;
1304 int64_t size = size_in_bytes / sizeof(T);
1305 const int layer_idx = i / num_params_weights_per_layer;
1306 const int index_within_layer = i % num_params_weights_per_layer;
1307 int width = 0, height = (num_proj_ == 0 ? h_num_units : c_num_units);
1308 // In the cuDNN layout, each layer has num_params_weights_per_layer params:
1309 // the first half, i.e. the num_params_input_state params, is applied to
1310 // the inputs, and the second half is applied to the recurrent hidden
1311 // states.
1312 bool apply_on_input_state = index_within_layer < num_params_input_state;
1313 if (rnn_direction_mode() == RnnDirectionMode::kRnnUnidirectional) {
1314 if (layer_idx == 0 && apply_on_input_state) {
1315 width = input_size;
1316 } else {
1317 width = h_num_units;
1318 }
1319 } else {
1320 if (apply_on_input_state) {
1321 if (layer_idx <= 1) {
1322 // First forward or backward layer.
1323 width = input_size;
1324 } else {
1325 // For the following layers, the cell inputs are the concatenated
1326 // outputs of the prior layer.
1327 width = 2 * h_num_units;
1328 }
1329 } else {
1330 width = h_num_units;
1331 }
1332 }
1333 CHECK(size == width * height) << "Params size mismatch. Expected "
1334 << width * height << ", got " << size;
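      // When num_proj != 0, the last weight region of each layer holds the
      // recurrent projection matrix; its height and width are swapped below
      // before the canonical output tensor is allocated.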
1335 Tensor* output = nullptr;
1336 int id_in_layer = i % num_params_weights_per_layer;
1337 if (num_proj_ != 0 && id_in_layer == num_params_weights_per_layer - 1) {
1338 std::swap(height, width);
1339 }
1340 OP_REQUIRES_OK(context, context->allocate_output(
1341 i, TensorShape({height, width}), &output));
1342 DeviceMemoryBase data_src_ptr = SliceDeviceMemory(
1343 input_ptr, rnn_desc->ParamsWeightRegions()[i].offset, size_in_bytes);
1344 auto data_dst_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output);
1345 stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes);
1346 }
1347
1348 OP_REQUIRES(
1349 context, num_params_biases_ == rnn_desc->ParamsBiasRegions().size(),
1350         errors::InvalidArgument("Number of params (biases) mismatch. "
1351                                 "Expected ", num_params_biases_, ", got ",
1352                                 rnn_desc->ParamsBiasRegions().size()));
1353 for (int i = 0; i < rnn_desc->ParamsBiasRegions().size(); i++) {
1354 int64_t size_in_bytes = rnn_desc->ParamsBiasRegions()[i].size;
1355 int64_t size = size_in_bytes / sizeof(T);
1356 OP_REQUIRES(context, size == num_units,
1357 errors::InvalidArgument("Params size mismatch. Expected ",
1358 num_units, ", got ", size));
1359
1360 Tensor* output = nullptr;
1361 OP_REQUIRES_OK(context,
1362 context->allocate_output(num_params_weights_ + i,
1363 TensorShape({size}), &output));
1364 DeviceMemoryBase data_src_ptr = SliceDeviceMemory(
1365 input_ptr, rnn_desc->ParamsBiasRegions()[i].offset, size_in_bytes);
1366 auto data_dst_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output);
1367 stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes);
1368 }
1369 }
1370
1371 private:
1372 int num_params_;
1373 int num_params_weights_;
1374 int num_params_biases_;
1375 int num_proj_;
1376 };
1377
1378 #define REGISTER_GPU(T) \
1379 REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsToCanonical") \
1380 .Device(DEVICE_GPU) \
1381 .HostMemory("num_layers") \
1382 .HostMemory("num_units") \
1383 .HostMemory("input_size") \
1384 .TypeConstraint<T>("T"), \
1385 CudnnRNNParamsToCanonical<GPUDevice, T>);
1386 TF_CALL_half(REGISTER_GPU);
1387 TF_CALL_float(REGISTER_GPU);
1388 TF_CALL_double(REGISTER_GPU);
1389 #undef REGISTER_GPU
1390
1391 #define REGISTER_GPU(T) \
1392 REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsToCanonicalV2") \
1393 .Device(DEVICE_GPU) \
1394 .HostMemory("num_layers") \
1395 .HostMemory("num_units") \
1396 .HostMemory("input_size") \
1397 .TypeConstraint<T>("T"), \
1398 CudnnRNNParamsToCanonical<GPUDevice, T>);
1399 TF_CALL_half(REGISTER_GPU);
1400 TF_CALL_float(REGISTER_GPU);
1401 TF_CALL_double(REGISTER_GPU);
1402 #undef REGISTER_GPU
1403
1404 // Convert weight and bias params from the canonical form to a
1405 // platform-specific layout.
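// The op consumes the canonical weights and biases as two input lists and
// copies each tensor into its corresponding region of a single opaque output
// buffer, using the offsets reported by the RnnDescriptor
// (ParamsWeightRegions() for weights, ParamsBiasRegions() for biases).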
1406 template <typename T>
1407 class CudnnRNNCanonicalToParams<GPUDevice, T> : public CudnnRNNKernelCommon {
1408 public:
CudnnRNNCanonicalToParams(OpKernelConstruction * context)1409 explicit CudnnRNNCanonicalToParams(OpKernelConstruction* context)
1410 : CudnnRNNKernelCommon(context) {
1411 if (context->HasAttr("num_proj")) {
1412 OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
1413 } else {
1414 num_proj_ = 0;
1415 }
1416 }
1417
1418   void Compute(OpKernelContext* context) override {
1419 std::unique_ptr<RnnDescriptor> rnn_desc;
1420 OP_REQUIRES_OK(context,
1421 ExtractCudnnRNNParamsInfo<T>(context, num_proj_, &rnn_desc));
1422 int64_t params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
1423 CHECK(params_size_in_bytes % sizeof(T) == 0)
1424 << "params_size_in_bytes must be multiple of element size";
1425 Tensor* output = nullptr;
1426 int params_size = params_size_in_bytes / sizeof(T);
1427 OP_REQUIRES_OK(context,
1428 context->allocate_output(0, {params_size}, &output));
1429 auto output_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*output);
1430 Stream* stream = context->op_device_context()->stream();
1431
1432 OpInputList weights;
1433 OP_REQUIRES_OK(context, context->input_list("weights", &weights));
1434 RestoreParams<T>(weights, rnn_desc->ParamsWeightRegions(), &output_ptr,
1435 stream);
1436
1437 OpInputList biases;
1438 OP_REQUIRES_OK(context, context->input_list("biases", &biases));
1439 RestoreParams<T>(biases, rnn_desc->ParamsBiasRegions(), &output_ptr,
1440 stream);
1441 }
1442
1443 private:
1444 int num_proj_;
1445 };
1446
1447 #define REGISTER_GPU(T) \
1448 REGISTER_KERNEL_BUILDER(Name("CudnnRNNCanonicalToParams") \
1449 .Device(DEVICE_GPU) \
1450 .HostMemory("num_layers") \
1451 .HostMemory("num_units") \
1452 .HostMemory("input_size") \
1453 .TypeConstraint<T>("T"), \
1454 CudnnRNNCanonicalToParams<GPUDevice, T>);
1455 TF_CALL_half(REGISTER_GPU);
1456 TF_CALL_float(REGISTER_GPU);
1457 TF_CALL_double(REGISTER_GPU);
1458 #undef REGISTER_GPU
1459
1460 #define REGISTER_GPU(T) \
1461 REGISTER_KERNEL_BUILDER(Name("CudnnRNNCanonicalToParamsV2") \
1462 .Device(DEVICE_GPU) \
1463 .HostMemory("num_layers") \
1464 .HostMemory("num_units") \
1465 .HostMemory("input_size") \
1466 .TypeConstraint<T>("T"), \
1467 CudnnRNNCanonicalToParams<GPUDevice, T>);
1468 TF_CALL_half(REGISTER_GPU);
1469 TF_CALL_float(REGISTER_GPU);
1470 TF_CALL_double(REGISTER_GPU);
1471 #undef REGISTER_GPU
1472
1473 // Run the forward operation of the RNN model.
1474 template <typename T>
1475 class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
1476 public:
1477   explicit CudnnRNNForwardOp(OpKernelConstruction* context)
1478 : CudnnRNNKernelCommon(context) {
1479 OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
1480
1481 // Read debug env variables.
1482 is_debug_mode_ = DebugCudnnRnn();
1483 debug_cudnn_rnn_algo_ = DebugCudnnRnnAlgo();
1484 debug_use_tensor_ops_ = DebugCudnnRnnUseTensorOps();
1485 }
1486
1487   void Compute(OpKernelContext* context) override {
1488 AlgorithmConfig algo_config;
1489 ComputeAndReturnAlgorithm(context, &algo_config, /*var_seq_lengths=*/false,
1490 /*time_major=*/true, /*num_proj=*/0);
1491 }
1492
1493 protected:
1494   virtual void ComputeAndReturnAlgorithm(OpKernelContext* context,
1495 AlgorithmConfig* output_algo_config,
1496 bool var_seq_lengths, bool time_major,
1497 int num_proj) {
1498 CHECK_NE(output_algo_config, nullptr);
1499
1500 const Tensor* input = nullptr;
1501 const Tensor* input_h = nullptr;
1502 const Tensor* input_c = nullptr;
1503 const Tensor* params = nullptr;
1504 const Tensor* sequence_lengths = nullptr;
1505 CudnnRnnModelShapes model_shapes;
1506 bool use_padded_io = false;
1507 if (var_seq_lengths) {
1508 OP_REQUIRES_OK(context, ExtractForwardInput(
1509 context, model_types(), time_major, &input,
1510 &input_h, &input_c, ¶ms,
1511 &sequence_lengths, num_proj, &model_shapes));
1512 use_padded_io =
1513 ShouldUsePaddedIO(sequence_lengths, model_shapes, time_major);
1514 } else {
1515 OP_REQUIRES_OK(context,
1516 ExtractForwardInput(context, model_types(), time_major,
1517 &input, &input_h, &input_c, ¶ms,
1518 num_proj, &model_shapes));
1519 }
1520 RnnInputMode input_mode;
1521 OP_REQUIRES_OK(context,
1522 ToRNNInputMode(rnn_input_mode(), model_shapes.num_units,
1523 model_shapes.input_size, &input_mode));
1524
1525 Tensor* output = nullptr;
1526 Tensor* output_h = nullptr;
1527 Tensor* output_c = nullptr;
1528 OP_REQUIRES_OK(context, AllocateOutputs(context, model_shapes, &output,
1529 &output_h, &output_c));
1530
1531     // Create a memory callback for the reserve_space. The memory lives in
1532     // an output of this kernel and is fed into the backward pass when
1533     // needed.
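    // (In training mode the reserve_space is emitted as output 3 and is
    // consumed again as the "reserve_space" input of the CudnnRNNBackprop*
    // kernels below.)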
1534 CudnnRnnAllocatorInOutput<T> reserve_space_allocator(context, 3);
1535     // Create a memory callback for the workspace. The memory only needs to
1536     // live until the end of this kernel call.
1537 CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context);
1538
1539 if (is_debug_mode_) {
1540 AlgorithmDesc algo_desc(debug_cudnn_rnn_algo_, debug_use_tensor_ops_,
1541 absl::nullopt);
1542 output_algo_config->set_algorithm(algo_desc);
1543 } else {
1544 OP_REQUIRES_OK(context,
1545 MaybeAutotune(context, model_shapes, input_mode, input,
1546 input_h, input_c, params, output, output_h,
1547 output_c, output_algo_config));
1548 }
1549
1550 Status launch_status;
1551 {
1552 mutex_lock l(mu_);
1553 RnnDescriptor* rnn_desc_ptr = nullptr;
1554 OP_REQUIRES_OK(context,
1555 GetCachedRnnDescriptor<T>(
1556 context, model_shapes, input_mode, *output_algo_config,
1557 &rnn_state_cache_, &rnn_desc_ptr, use_padded_io));
1558 launch_status = DoForward<T>(
1559 context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h,
1560 input_c, params, is_training_, output, output_h, output_c,
1561 sequence_lengths, time_major, &reserve_space_allocator,
1562 &workspace_allocator, /*output_profile_result=*/nullptr);
1563 }
1564 OP_REQUIRES_OK(context, launch_status);
1565 }
1566
1567 protected:
1568   virtual Status MaybeAutotune(OpKernelContext* context,
1569 const CudnnRnnModelShapes& model_shapes,
1570 const RnnInputMode& input_mode,
1571 const Tensor* input, const Tensor* input_h,
1572 const Tensor* input_c, const Tensor* params,
1573 Tensor* output, Tensor* output_h,
1574 Tensor* output_c,
1575 AlgorithmConfig* best_algo_config) {
1576 CHECK_NE(best_algo_config, nullptr);
1577 *best_algo_config = AlgorithmConfig();
1578 return OkStatus();
1579 }
1580
1581   bool is_training() const { return is_training_; }
1582 bool is_debug_mode_;
1583 bool debug_use_tensor_ops_;
1584 int64_t debug_cudnn_rnn_algo_;
1585
1586 private:
1587   Status AllocateOutputs(OpKernelContext* context,
1588 const CudnnRnnModelShapes& model_shapes,
1589 Tensor** output, Tensor** output_h,
1590 Tensor** output_c) {
1591 const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
1592 const TensorShape& output_shape = model_shapes.output_shape;
1593 const TensorShape& cell_state_shape = model_shapes.cell_state_shape;
1594
1595 TF_RETURN_IF_ERROR(context->allocate_output(0, output_shape, output));
1596 TF_RETURN_IF_ERROR(
1597 context->allocate_output(1, hidden_state_shape, output_h));
1598 if (HasInputC()) {
1599 TF_RETURN_IF_ERROR(
1600 context->allocate_output(2, cell_state_shape, output_c));
1601 } else {
1602 // Only LSTM uses input_c and output_c. So for all other models, we only
1603 // need to create dummy outputs.
1604 TF_RETURN_IF_ERROR(context->allocate_output(2, {}, output_c));
1605 }
1606 if (!is_training_) {
1607 Tensor* dummy_reserve_space = nullptr;
1608 TF_RETURN_IF_ERROR(context->allocate_output(3, {}, &dummy_reserve_space));
1609 }
1610 return OkStatus();
1611 }
1612
1613 mutex mu_;
1614 bool is_training_;
1615 RnnStateCache rnn_state_cache_ TF_GUARDED_BY(mu_);
1616 };
1617
1618 #define REGISTER_GPU(T) \
1619 REGISTER_KERNEL_BUILDER( \
1620 Name("CudnnRNN").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
1621 CudnnRNNForwardOp<GPUDevice, T>);
1622
1623 TF_CALL_half(REGISTER_GPU);
1624 TF_CALL_float(REGISTER_GPU);
1625 TF_CALL_double(REGISTER_GPU);
1626 #undef REGISTER_GPU
1627
1628 template <typename T>
1629 class CudnnRNNForwardOpV2<GPUDevice, T>
1630 : public CudnnRNNForwardOp<GPUDevice, T> {
1631 private:
1632 using CudnnRNNForwardOp<GPUDevice, T>::is_training;
1633 using CudnnRNNKernelCommon::CreateRnnDescriptor;
1634 using CudnnRNNKernelCommon::dropout;
1635 using CudnnRNNKernelCommon::HasInputC;
1636 using CudnnRNNKernelCommon::model_types;
1637
1638 public:
1639   explicit CudnnRNNForwardOpV2(OpKernelConstruction* context)
1640 : CudnnRNNForwardOp<GPUDevice, T>(context) {}
1641
1642   void Compute(OpKernelContext* context) override {
1643 AlgorithmConfig best_algo_config;
1644 CudnnRNNForwardOp<GPUDevice, T>::ComputeAndReturnAlgorithm(
1645 context, &best_algo_config, /*var_seq_lengths=*/false,
1646 /*time_major=*/true, /*num_proj=*/0);
1647 if (!context->status().ok()) {
1648 return;
1649 }
1650
1651 Tensor* output_host_reserved = nullptr;
1652 // output_host_reserved stores opaque info used for backprop when running
1653 // in training mode. At present, it includes a serialization of the best
1654 // AlgorithmDesc picked during rnn forward pass autotune.
1655 // int8 algorithm_id
1656 // int8 use_tensor_op
1657 // If autotune is not enabled, the algorithm_id is
1658 // stream_executor::dnn::kDefaultAlgorithm and use_tensor_op is false. If
1659 // running in inference mode, the output_host_reserved is currently not
1660 // populated.
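    // (CudnnRNNBackwardOpV2::GetAlgorithm() below reads these two int8 values
    // back to reconstruct the AlgorithmDesc for the backward pass.)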
1661 if (is_training()) {
1662 OP_REQUIRES_OK(context, context->allocate_output(4, TensorShape({2}),
1663 &output_host_reserved));
1664 auto output_host_reserved_int8 = output_host_reserved->vec<int8>();
1665 output_host_reserved_int8(0) = best_algo_config.algorithm()->algo_id();
1666 output_host_reserved_int8(1) =
1667 best_algo_config.algorithm()->tensor_ops_enabled();
1668 } else {
1669 OP_REQUIRES_OK(context,
1670 context->allocate_output(4, {}, &output_host_reserved));
1671 }
1672 }
1673
1674 protected:
1675   Status MaybeAutotune(OpKernelContext* context,
1676 const CudnnRnnModelShapes& model_shapes,
1677 const RnnInputMode& input_mode, const Tensor* input,
1678 const Tensor* input_h, const Tensor* input_c,
1679 const Tensor* params, Tensor* output, Tensor* output_h,
1680 Tensor* output_c,
1681 AlgorithmConfig* algo_config) override {
1682 CHECK_NE(algo_config, nullptr);
1683 if (!CudnnRnnUseAutotune() || this->is_debug_mode_) {
1684 *algo_config = AlgorithmConfig();
1685 return OkStatus();
1686 }
1687
1688 std::vector<AlgorithmDesc> algorithms;
1689 auto* stream = context->op_device_context()->stream();
1690 CHECK(stream->parent()->GetRnnAlgorithms(&algorithms));
1691 if (algorithms.empty()) {
1692 LOG(WARNING) << "No Rnn algorithm found";
1693 return OkStatus();
1694 }
1695
1696 const auto& modeltypes = model_types();
1697 CudnnRnnParameters rnn_params(
1698 model_shapes.num_layers, model_shapes.input_size,
1699 model_shapes.num_units, model_shapes.max_seq_length,
1700 model_shapes.batch_size, model_shapes.dir_count,
1701 /*has_dropout=*/std::abs(dropout()) > 1e-8, is_training(),
1702 modeltypes.rnn_mode, modeltypes.rnn_input_mode, input->dtype());
1703
1704 if (AutotuneRnnConfigMap::GetInstance()->Find(rnn_params, algo_config)) {
1705 VLOG(1) << "Using existing best Cudnn RNN algorithm "
1706 << "(algo, tensor_op_enabled) = ("
1707 << algo_config->algorithm()->algo_id() << ", "
1708 << algo_config->algorithm()->tensor_ops_enabled() << ").";
1709 return OkStatus();
1710 }
1711 profiler::ScopedAnnotation trace("cudnn_autotuning");
1712
1713 // Create temp tensors when profiling backprop pass.
1714 auto data_type = input->dtype();
1715 Tensor output_backprop;
1716 Tensor output_h_backprop;
1717 Tensor output_c_backprop;
1718 Tensor input_backprop;
1719 Tensor input_h_backprop;
1720 Tensor input_c_backprop;
1721 Tensor params_backprop;
1722 if (is_training()) {
1723 TF_RETURN_IF_ERROR(context->allocate_temp(
1724 data_type, model_shapes.output_shape, &output_backprop));
1725 TF_RETURN_IF_ERROR(context->allocate_temp(
1726 data_type, model_shapes.hidden_state_shape, &output_h_backprop));
1727
1728 TF_RETURN_IF_ERROR(
1729 context->allocate_temp(data_type, params->shape(), ¶ms_backprop));
1730 TF_RETURN_IF_ERROR(context->allocate_temp(
1731 data_type, model_shapes.input_shape, &input_backprop));
1732 TF_RETURN_IF_ERROR(context->allocate_temp(
1733 data_type, model_shapes.hidden_state_shape, &input_h_backprop));
1734 if (HasInputC()) {
1735 TF_RETURN_IF_ERROR(context->allocate_temp(
1736 data_type, model_shapes.hidden_state_shape, &output_c_backprop));
1737 TF_RETURN_IF_ERROR(context->allocate_temp(
1738 data_type, model_shapes.hidden_state_shape, &input_c_backprop));
1739 }
1740 }
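    // Profile each candidate algorithm end to end: time the forward pass
    // (plus the backward pass when training) and keep the fastest one.
    // Algorithms whose descriptor creation or execution fails are skipped.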
1741 ProfileResult best_result;
1742 for (auto& algo : algorithms) {
1743 VLOG(1) << "Profile Cudnn RNN algorithm (algo, tensor_op_enabled) = ("
1744 << algo.algo_id() << ", " << algo.tensor_ops_enabled() << ").";
1745 Status status;
1746 ProfileResult final_profile_result;
1747
1748 ProfileResult fwd_profile_result;
1749 ProfileResult bak_profile_result;
1750
1751 // RnnDescriptor is algorithm-dependent, thus not reusable.
1752 std::unique_ptr<RnnDescriptor> rnn_desc;
1753       // Use a temp scratch allocator for the random number generator.
1754 CudnnRnnAllocatorInTemp<uint8> dropout_state_allocator(context);
1755 if (!this->template CreateRnnDescriptor<T>(
1756 context, model_shapes, input_mode, AlgorithmConfig(algo),
1757 &dropout_state_allocator, &rnn_desc,
1758 /*use_padded_io=*/false)
1759 .ok()) {
1760 continue;
1761 }
1762
1763 // Again use temp scratch allocator during profiling.
1764 CudnnRnnAllocatorInTemp<T> reserve_space_allocator(context);
1765 CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context);
1766 status = DoForward<T>(context, *rnn_desc, model_types(), model_shapes,
1767 input, input_h, input_c, params, is_training(),
1768 output, output_h, output_c, nullptr, true,
1769 &reserve_space_allocator, &workspace_allocator,
1770 &fwd_profile_result);
1771 if (!status.ok()) {
1772 continue;
1773 }
1774
1775 if (is_training()) {
1776 // Get reserve space from the forward pass.
1777 Tensor reserve_space = reserve_space_allocator.get_allocated_tensor(0);
1778 status = DoBackward<T>(
1779 context, *rnn_desc, model_types(), model_shapes, input, input_h,
1780 input_c, params, output, output_h, output_c, &output_backprop,
1781 &output_h_backprop, &output_c_backprop, &reserve_space,
1782 &input_backprop, &input_h_backprop, &input_c_backprop,
1783 ¶ms_backprop, nullptr, true, &workspace_allocator,
1784 &bak_profile_result);
1785 if (!status.ok()) {
1786 continue;
1787 }
1788 final_profile_result.set_elapsed_time_in_ms(
1789 fwd_profile_result.elapsed_time_in_ms() +
1790 bak_profile_result.elapsed_time_in_ms());
1791 } else {
1792 final_profile_result = fwd_profile_result;
1793 }
1794
1795 auto total_time = final_profile_result.elapsed_time_in_ms();
1796 VLOG(1) << "Cudnn RNN algorithm (algo, tensor_op_enabled) = ("
1797 << algo.algo_id() << ", " << algo.tensor_ops_enabled() << ")"
1798 << " run time: " << total_time << " ms.";
1799 if (total_time < best_result.elapsed_time_in_ms()) {
1800 best_result.set_elapsed_time_in_ms(total_time);
1801 best_result.set_algorithm(algo);
1802 }
1803 }
1804
1805 if (!best_result.is_valid()) {
1806 return Status(error::Code::INTERNAL, "No algorithm worked!");
1807 }
1808 algo_config->set_algorithm(best_result.algorithm());
1809 VLOG(1) << "Best Cudnn RNN algorithm (algo, tensor_op_enabled) = ("
1810 << best_result.algorithm().algo_id() << ", "
1811 << best_result.algorithm().tensor_ops_enabled() << ").";
1812 AutotuneRnnConfigMap::GetInstance()->Insert(rnn_params, *algo_config);
1813 return OkStatus();
1814 }
1815 };
1816
1817 #define REGISTER_GPU(T) \
1818 REGISTER_KERNEL_BUILDER(Name("CudnnRNNV2") \
1819 .Device(DEVICE_GPU) \
1820 .HostMemory("host_reserved") \
1821 .TypeConstraint<T>("T"), \
1822 CudnnRNNForwardOpV2<GPUDevice, T>);
1823
1824 TF_CALL_half(REGISTER_GPU);
1825 TF_CALL_float(REGISTER_GPU);
1826 TF_CALL_double(REGISTER_GPU);
1827 #undef REGISTER_GPU
1828
1829 template <typename T>
1830 class CudnnRNNForwardOpV3<GPUDevice, T>
1831 : public CudnnRNNForwardOp<GPUDevice, T> {
1832 private:
1833 using CudnnRNNForwardOp<GPUDevice, T>::is_training;
1834 using CudnnRNNKernelCommon::CreateRnnDescriptor;
1835 using CudnnRNNKernelCommon::dropout;
1836 using CudnnRNNKernelCommon::HasInputC;
1837 using CudnnRNNKernelCommon::model_types;
1838 bool time_major_;
1839
1840 protected:
1841   bool time_major() { return time_major_; }
1842
1843 public:
1844   explicit CudnnRNNForwardOpV3(OpKernelConstruction* context)
1845 : CudnnRNNForwardOp<GPUDevice, T>(context) {
1846 OP_REQUIRES_OK(context, context->GetAttr("time_major", &time_major_));
1847 if (context->HasAttr("num_proj")) {
1848 OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
1849 } else {
1850 num_proj_ = 0;
1851 }
1852 }
1853
1854   void Compute(OpKernelContext* context) override {
1855 AlgorithmConfig best_algo_config;
1856 CudnnRNNForwardOp<GPUDevice, T>::ComputeAndReturnAlgorithm(
1857 context, &best_algo_config, /*var_seq_lengths=*/true,
1858 /*time_major=*/time_major(), num_proj_);
1859 if (!context->status().ok()) {
1860 return;
1861 }
1862
1863 Tensor* output_host_reserved = nullptr;
1864     // TODO: The V3 op currently only uses the default (standard) algorithm
1865     // to process batches with variable sequence lengths, and the inputs
1866     // should be padded. Autotune is not supported yet.
1867 OP_REQUIRES_OK(context,
1868 context->allocate_output(4, {}, &output_host_reserved));
1869 }
1870
1871 private:
1872 int num_proj_;
1873 };
1874
1875 #define REGISTER_GPU(T) \
1876 REGISTER_KERNEL_BUILDER(Name("CudnnRNNV3") \
1877 .Device(DEVICE_GPU) \
1878 .HostMemory("sequence_lengths") \
1879 .HostMemory("host_reserved") \
1880 .TypeConstraint<T>("T"), \
1881 CudnnRNNForwardOpV3<GPUDevice, T>);
1882
1883 TF_CALL_half(REGISTER_GPU);
1884 TF_CALL_float(REGISTER_GPU);
1885 TF_CALL_double(REGISTER_GPU);
1886 #undef REGISTER_GPU
1887
1888 // Run the backward operation of the RNN model.
1889 template <typename T>
1890 class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
1891 public:
1892   explicit CudnnRNNBackwardOp(OpKernelConstruction* context)
1893 : CudnnRNNKernelCommon(context) {}
1894
1895   void Compute(OpKernelContext* context) override {
1896     ComputeImpl(context, /*var_seq_lengths=*/false, /*time_major=*/true,
                     /*num_proj=*/0);
1897 }
1898
1899 protected:
1900   virtual void ComputeImpl(OpKernelContext* context, bool var_seq_lengths,
1901 bool time_major, int num_proj) {
1902 const Tensor* input = nullptr;
1903 const Tensor* input_h = nullptr;
1904 const Tensor* input_c = nullptr;
1905 const Tensor* params = nullptr;
1906 const Tensor* sequence_lengths = nullptr;
1907 CudnnRnnModelShapes model_shapes;
1908 bool use_padded_io = false;
1909 if (var_seq_lengths) {
1910 OP_REQUIRES_OK(context, ExtractForwardInput(
1911 context, model_types(), time_major, &input,
1912 &input_h, &input_c, ¶ms,
1913 &sequence_lengths, num_proj, &model_shapes));
1914 use_padded_io =
1915 ShouldUsePaddedIO(sequence_lengths, model_shapes, time_major);
1916 } else {
1917 OP_REQUIRES_OK(context,
1918 ExtractForwardInput(context, model_types(), time_major,
1919 &input, &input_h, &input_c, ¶ms,
1920 num_proj, &model_shapes));
1921 }
1922 RnnInputMode input_mode;
1923 OP_REQUIRES_OK(context,
1924 ToRNNInputMode(rnn_input_mode(), model_shapes.num_units,
1925 model_shapes.input_size, &input_mode));
1926
1927 const Tensor* output = nullptr;
1928 const Tensor* output_h = nullptr;
1929 const Tensor* output_c = nullptr;
1930 const Tensor* output_backprop = nullptr;
1931 const Tensor* output_h_backprop = nullptr;
1932 const Tensor* output_c_backprop = nullptr;
1933 const Tensor* reserve_space = nullptr;
1934 OP_REQUIRES_OK(context,
1935 ExtractBackwardInputs(context, model_shapes, model_types(),
1936 &output, &output_h, &output_c,
1937 &output_backprop, &output_h_backprop,
1938 &output_c_backprop, &reserve_space));
1939
1940 Tensor* input_backprop = nullptr;
1941 Tensor* input_h_backprop = nullptr;
1942 Tensor* input_c_backprop = nullptr;
1943 Tensor* params_backprop = nullptr;
1944 OP_REQUIRES_OK(context,
1945 AllocateOutputs(context, model_shapes, params->shape(),
1946 &input_backprop, &input_h_backprop,
1947 &input_c_backprop, ¶ms_backprop));
1948
1949     // Create a memory callback for the workspace. The memory only needs to
1950     // live until the end of this kernel call.
1951 CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context);
1952 AlgorithmConfig algo_config;
1953 OP_REQUIRES_OK(context, GetAlgorithm(context, &algo_config));
1954 Status launch_status;
1955 {
1956 mutex_lock l(mu_);
1957 RnnDescriptor* rnn_desc_ptr = nullptr;
1958 OP_REQUIRES_OK(
1959 context, GetCachedRnnDescriptor<T>(context, model_shapes, input_mode,
1960 algo_config, &rnn_state_cache_,
1961 &rnn_desc_ptr, use_padded_io));
1962 launch_status = DoBackward<T>(
1963 context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h,
1964 input_c, params, output, output_h, output_c, output_backprop,
1965 output_h_backprop, output_c_backprop, reserve_space, input_backprop,
1966 input_h_backprop, input_c_backprop, params_backprop, sequence_lengths,
1967 time_major, &workspace_allocator,
1968 /*output_profile_result=*/nullptr);
1969 }
1970 OP_REQUIRES_OK(context, launch_status);
1971 }
1972
1973 protected:
1974   virtual Status GetAlgorithm(OpKernelContext* context,
1975 AlgorithmConfig* algo_config) {
1976 CHECK_NE(algo_config, nullptr);
1977 *algo_config = AlgorithmConfig();
1978 return OkStatus();
1979 }
1980
1981 private:
1982 mutex mu_;
1983 RnnStateCache rnn_state_cache_ TF_GUARDED_BY(mu_);
1984
1985   Status ExtractBackwardInputs(
1986 OpKernelContext* context, const CudnnRnnModelShapes& model_shapes,
1987 const CudnnModelTypes& model_types, const Tensor** output,
1988 const Tensor** output_h, const Tensor** output_c,
1989 const Tensor** output_backprop, const Tensor** output_h_backprop,
1990 const Tensor** output_c_backprop, const Tensor** reserve_space) {
1991 TF_RETURN_IF_ERROR(context->input("output", output));
1992 TF_RETURN_IF_ERROR(context->input("output_backprop", output_backprop));
1993 TF_RETURN_IF_ERROR(context->input("output_h", output_h));
1994 TF_RETURN_IF_ERROR(context->input("output_h_backprop", output_h_backprop));
1995 if (model_types.HasInputC()) {
1996 TF_RETURN_IF_ERROR(context->input("output_c", output_c));
1997 TF_RETURN_IF_ERROR(
1998 context->input("output_c_backprop", output_c_backprop));
1999 }
2000 TF_RETURN_IF_ERROR(context->input("reserve_space", reserve_space));
2001 const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
2002 const TensorShape& output_shape = model_shapes.output_shape;
2003 const TensorShape& cell_state_shape = model_shapes.cell_state_shape;
2004
2005 if (output_shape != (*output)->shape()) {
2006 return errors::InvalidArgument(
2007 "Invalid output shape: ", (*output)->shape().DebugString(), " ",
2008 output_shape.DebugString());
2009 }
2010 if (hidden_state_shape != (*output_h)->shape()) {
2011 return errors::InvalidArgument(
2012 "Invalid output_h shape: ", (*output_h)->shape().DebugString(), " ",
2013 hidden_state_shape.DebugString());
2014 }
2015
2016 if (output_shape != (*output_backprop)->shape()) {
2017 return errors::InvalidArgument("Invalid output_backprop shape: ",
2018 (*output_backprop)->shape().DebugString(),
2019 " ", output_shape.DebugString());
2020 }
2021 if (hidden_state_shape != (*output_h_backprop)->shape()) {
2022 return errors::InvalidArgument(
2023 "Invalid output_h_backprop shape: ",
2024 (*output_h_backprop)->shape().DebugString(), " ",
2025 hidden_state_shape.DebugString());
2026 }
2027
2028 if (model_types.HasInputC()) {
2029 if (cell_state_shape != (*output_c)->shape()) {
2030 return errors::InvalidArgument(
2031 "Invalid output_c shape: ", (*output_c)->shape().DebugString(), " ",
2032 cell_state_shape.DebugString());
2033 }
2034 if (cell_state_shape != (*output_c_backprop)->shape()) {
2035 return errors::InvalidArgument(
2036 "Invalid output_c_backprop shape: ",
2037 (*output_c_backprop)->shape().DebugString(), " ",
2038 cell_state_shape.DebugString());
2039 }
2040 }
2041 return OkStatus();
2042 }
2043
2044   Status AllocateOutputs(OpKernelContext* context,
2045 const CudnnRnnModelShapes& model_shapes,
2046 const TensorShape& params_shape,
2047 Tensor** input_backprop, Tensor** input_h_backprop,
2048 Tensor** input_c_backprop, Tensor** params_backprop) {
2049 const TensorShape& input_shape = model_shapes.input_shape;
2050 const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
2051 const TensorShape& cell_state_shape = model_shapes.cell_state_shape;
2052
2053 TF_RETURN_IF_ERROR(
2054 context->allocate_output(0, input_shape, input_backprop));
2055 TF_RETURN_IF_ERROR(
2056 context->allocate_output(1, hidden_state_shape, input_h_backprop));
2057 if (HasInputC()) {
2058 TF_RETURN_IF_ERROR(
2059 context->allocate_output(2, cell_state_shape, input_c_backprop));
2060 } else {
2061 // Only LSTM uses input_c and output_c. So for all other models, we only
2062 // need to create dummy outputs.
2063 TF_RETURN_IF_ERROR(context->allocate_output(2, {}, input_c_backprop));
2064 }
2065 TF_RETURN_IF_ERROR(
2066 context->allocate_output(3, params_shape, params_backprop));
2067 return OkStatus();
2068 }
2069 };
2070
2071 #define REGISTER_GPU(T) \
2072 REGISTER_KERNEL_BUILDER( \
2073 Name("CudnnRNNBackprop").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
2074 CudnnRNNBackwardOp<GPUDevice, T>);
2075
2076 TF_CALL_half(REGISTER_GPU);
2077 TF_CALL_float(REGISTER_GPU);
2078 TF_CALL_double(REGISTER_GPU);
2079 #undef REGISTER_GPU
2080
2081 template <typename T>
2082 class CudnnRNNBackwardOpV2<GPUDevice, T>
2083 : public CudnnRNNBackwardOp<GPUDevice, T> {
2084 public:
2085   explicit CudnnRNNBackwardOpV2(OpKernelConstruction* context)
2086 : CudnnRNNBackwardOp<GPUDevice, T>(context) {}
2087
2088 protected:
2089   Status GetAlgorithm(OpKernelContext* context,
2090 AlgorithmConfig* algo_config) override {
2091 CHECK_NE(algo_config, nullptr);
2092 const Tensor* host_reserved = nullptr;
2093 TF_RETURN_IF_ERROR(context->input("host_reserved", &host_reserved));
2094
2095 auto host_reserved_int8 = host_reserved->vec<int8>();
2096 const AlgorithmDesc algo_desc(host_reserved_int8(0), host_reserved_int8(1),
2097 absl::nullopt);
2098 algo_config->set_algorithm(algo_desc);
2099 return OkStatus();
2100 }
2101 };
2102
2103 #define REGISTER_GPU(T) \
2104 REGISTER_KERNEL_BUILDER(Name("CudnnRNNBackpropV2") \
2105 .Device(DEVICE_GPU) \
2106 .HostMemory("host_reserved") \
2107 .TypeConstraint<T>("T"), \
2108 CudnnRNNBackwardOpV2<GPUDevice, T>);
2109
2110 TF_CALL_half(REGISTER_GPU);
2111 TF_CALL_float(REGISTER_GPU);
2112 TF_CALL_double(REGISTER_GPU);
2113 #undef REGISTER_GPU
2114
2115 template <typename T>
2116 class CudnnRNNBackwardOpV3<GPUDevice, T>
2117 : public CudnnRNNBackwardOp<GPUDevice, T> {
2118 private:
2119 bool time_major_;
2120
2121 protected:
2122   bool time_major() { return time_major_; }
2123
2124 public:
2125   explicit CudnnRNNBackwardOpV3(OpKernelConstruction* context)
2126 : CudnnRNNBackwardOp<GPUDevice, T>(context) {
2127 OP_REQUIRES_OK(context, context->GetAttr("time_major", &time_major_));
2128 if (context->HasAttr("num_proj")) {
2129 OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
2130 } else {
2131 num_proj_ = 0;
2132 }
2133 }
2134
2135   void Compute(OpKernelContext* context) override {
2136     CudnnRNNBackwardOp<GPUDevice, T>::ComputeImpl(
2137         context, /*var_seq_lengths=*/true, time_major(), num_proj_);
2138 }
2139
2140 private:
2141 int num_proj_;
2142 };
2143
2144 #define REGISTER_GPU(T) \
2145 REGISTER_KERNEL_BUILDER(Name("CudnnRNNBackpropV3") \
2146 .Device(DEVICE_GPU) \
2147 .HostMemory("sequence_lengths") \
2148 .HostMemory("host_reserved") \
2149 .TypeConstraint<T>("T"), \
2150 CudnnRNNBackwardOpV3<GPUDevice, T>);
2151
2152 TF_CALL_half(REGISTER_GPU);
2153 TF_CALL_float(REGISTER_GPU);
2154 TF_CALL_double(REGISTER_GPU);
2155 #undef REGISTER_GPU
2156
2157 // TODO(zhengxq): Add the conversion of Cudnn RNN Params from and to
2158 // its canonical form.
2159
2160 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
2161
2162 } // namespace tensorflow
2163