xref: /aosp_15_r20/external/tensorflow/tensorflow/stream_executor/rocm/rocm_dnn.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/stream_executor/rocm/rocm_dnn.h"
17 
18 #include <functional>
19 #include <memory>
20 
21 #include "absl/algorithm/container.h"
22 #include "absl/base/thread_annotations.h"
23 #include "absl/strings/str_cat.h"
24 #include "absl/strings/str_format.h"
25 #include "absl/types/span.h"
26 #include "third_party/eigen3/Eigen/Core"
27 #include "rocm/include/miopen/miopen.h"
28 #include "tensorflow/core/lib/hash/hash.h"
29 #include "tensorflow/core/util/env_var.h"
30 #include "tensorflow/stream_executor/dnn.h"
31 #include "tensorflow/stream_executor/gpu/gpu_activation.h"
32 #include "tensorflow/stream_executor/gpu/gpu_driver.h"
33 #include "tensorflow/stream_executor/gpu/gpu_executor.h"
34 #include "tensorflow/stream_executor/gpu/gpu_stream.h"
35 #include "tensorflow/stream_executor/gpu/gpu_timer.h"
36 #include "tensorflow/stream_executor/lib/env.h"
37 #include "tensorflow/stream_executor/lib/error.h"
38 #include "tensorflow/stream_executor/lib/initialize.h"
39 #include "tensorflow/stream_executor/lib/threadpool.h"
40 #include "tensorflow/stream_executor/platform/dso_loader.h"
41 #include "tensorflow/stream_executor/platform/logging.h"
42 #include "tensorflow/stream_executor/plugin_registry.h"
43 #include "tensorflow/stream_executor/rocm/rocm_diagnostics.h"
44 #include "tensorflow/stream_executor/rocm/rocm_platform_id.h"
45 #include "tensorflow/stream_executor/scratch_allocator.h"
46 #include "tensorflow/stream_executor/stream.h"
47 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
48 
49 namespace {
50 
51 // Converts (via narrowing) a WideT value to a NarrowT value, and checks that
52 // the value is unchanged by the conversion.
53 template <typename WideT, typename NarrowT>
54 NarrowT CheckedNarrowing(const WideT& wide) {
55   NarrowT narrow = wide;
56   CHECK_EQ(narrow, wide)
57       << "checked narrowing failed; values not equal post-conversion";
58   return narrow;
59 }
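// Illustrative usage, matching how the helper is applied later in this file
// when converting 64-bit descriptor dimensions for MIOpen (dims64/dims are
// stand-in names here):
//   std::transform(dims64.cbegin(), dims64.cend(), dims.begin(),
//                  &CheckedNarrowing<int64_t, int>);
// Any element that does not survive the narrowing unchanged triggers the
// CHECK failure above.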
60 
61 const int kConvDebugVlogLevel = 3;
62 
63 }  // namespace
64 
65 namespace stream_executor {
66 
67 using dnn::AlgorithmDesc;
68 using dnn::BatchDescriptor;
69 using dnn::ConvolutionDescriptor;
70 using dnn::FilterDescriptor;
71 using dnn::NormalizeDescriptor;
72 using dnn::PoolingDescriptor;
73 
74 namespace gpu {
75 
76 PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kMIOpenPlugin);
77 
78 string ToString(miopenStatus_t status) {
79   switch (status) {
80     case miopenStatusSuccess:
81       return "miopenStatusSuccess";
82     case miopenStatusNotInitialized:
83       return "miopenStatusNotInitialized";
84     case miopenStatusAllocFailed:
85       return "miopenStatusAllocFailed";
86     case miopenStatusBadParm:
87       return "miopenStatusBadParm";
88     case miopenStatusInternalError:
89       return "miopenStatusInternalError";
90     case miopenStatusInvalidValue:
91       return "miopenStatusInvalidValue";
92     case miopenStatusNotImplemented:
93       return "miopenStatusNotImplemented";
94     case miopenStatusUnknownError:
95       return "miopenStatusUnknownError";
96     default:
97       return absl::StrCat("<unknown miopen status: ", static_cast<int>(status),
98                           ">");
99   }
100 }
101 
102 string ToString(miopenConvFwdAlgorithm_t algorithm) {
103   string s;
104   switch (algorithm) {
105     case miopenConvolutionFwdAlgoGEMM:
106       s = "GEMM";
107       break;
108     case miopenConvolutionFwdAlgoDirect:
109       s = "Direct";
110       break;
111     case miopenConvolutionFwdAlgoFFT:
112       s = "FFT";
113       break;
114     case miopenConvolutionFwdAlgoWinograd:
115       s = "Winograd";
116       break;
117     case miopenConvolutionFwdAlgoImplicitGEMM:
118       s = "Implicit GEMM";
119       break;
120   }
121   return s;
122 }
123 
124 string ToString(miopenConvBwdWeightsAlgorithm_t algorithm) {
125   string s;
126   switch (algorithm) {
127     case miopenConvolutionBwdWeightsAlgoGEMM:
128       s = "GEMM";
129       break;
130     case miopenConvolutionBwdWeightsAlgoDirect:
131       s = "Direct";
132       break;
133     case miopenConvolutionBwdWeightsAlgoWinograd:
134       s = "Winograd";
135       break;
136     case miopenConvolutionBwdWeightsAlgoImplicitGEMM:
137       s = "Implicit GEMM";
138       break;
139   }
140   return s;
141 }
142 
143 string ToString(miopenConvBwdDataAlgorithm_t algorithm) {
144   string s;
145   switch (algorithm) {
146     case miopenConvolutionBwdDataAlgoGEMM:
147       s = "GEMM";
148       break;
149     case miopenConvolutionBwdDataAlgoDirect:
150       s = "Direct";
151       break;
152     case miopenConvolutionBwdDataAlgoFFT:
153       s = "FFT";
154       break;
155     case miopenConvolutionBwdDataAlgoWinograd:
156       s = "Winograd";
157       break;
158     case miopenTransposeBwdDataAlgoGEMM:
159       s = "Transpose GEMM";
160       break;
161     case miopenConvolutionBwdDataAlgoImplicitGEMM:
162       s = "Implicit GEMM";
163       break;
164   }
165   return s;
166 }
167 
168 string ToString(miopenConvAlgorithm_t algorithm) {
169   string s;
170   switch (algorithm) {
171     case miopenConvolutionAlgoGEMM:
172       s = "GEMM";
173       break;
174     case miopenConvolutionAlgoDirect:
175       s = "Direct";
176       break;
177     case miopenConvolutionAlgoFFT:
178       s = "FFT";
179       break;
180     case miopenConvolutionAlgoWinograd:
181       s = "Winograd";
182       break;
183     case miopenConvolutionAlgoImplicitGEMM:
184       s = "Implicit GEMM";
185       break;
186   }
187   return s;
188 }
189 
190 // RAII wrapper for all calls to MIOpen with a MIOpen handle argument.
191 //
192 // See MIOpenAccess::GetHandle() for details.
193 class MIOpenHandle {
194  public:
195   // Takes ownership of the executor context and the lock to access MIOpen
196   // using handle.
197   MIOpenHandle(gpu::ScopedActivateExecutorContext context,
198                std::unique_ptr<absl::MutexLock> lock, miopenHandle_t handle)
199       : context_(std::move(context)), lock_(std::move(lock)), handle_(handle) {}
200 
201   // Returns the MIOpen handle. Pass it directly to MIOpen APIs; do not keep
202   // a copy.
203   miopenHandle_t handle() const { return handle_; }
204 
205  private:
206   gpu::ScopedActivateExecutorContext context_;
207   std::unique_ptr<absl::MutexLock> lock_;
208   miopenHandle_t handle_;  // Not owned.
209 };
210 
211 namespace wrap {
212 
213 #ifdef PLATFORM_GOOGLE
214 #define STREAM_EXECUTOR_MIOPEN_WRAP(__name)      \
215   struct WrapperShim__##__name {                 \
216     template <typename... Args>                  \
217     miopenStatus_t operator()(Args... args) {    \
218       miopenStatus_t retval = ::__name(args...); \
219       return retval;                             \
220     }                                            \
221   } __name;
222 
223 #else
224 
225 #define STREAM_EXECUTOR_MIOPEN_WRAP(__name)                               \
226   struct DynLoadShim__##__name {                                          \
227     static const char* kName;                                             \
228     using FuncPtrT = std::add_pointer<decltype(::__name)>::type;          \
229     static void* GetDsoHandle() {                                         \
230       auto s = internal::CachedDsoLoader::GetMiopenDsoHandle();           \
231       return s.ValueOrDie();                                              \
232     }                                                                     \
233     static FuncPtrT LoadOrDie() {                                         \
234       void* f;                                                            \
235       auto s = port::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \
236                                                           kName, &f);     \
237       CHECK(s.ok()) << "could not find " << kName                         \
238                     << " in miopen DSO; dlerror: " << s.error_message();  \
239       return reinterpret_cast<FuncPtrT>(f);                               \
240     }                                                                     \
241     static FuncPtrT DynLoad() {                                           \
242       static FuncPtrT f = LoadOrDie();                                    \
243       return f;                                                           \
244     }                                                                     \
245     template <typename... Args>                                           \
246     miopenStatus_t operator()(Args... args) {                             \
247       return DynLoad()(args...);                                          \
248     }                                                                     \
249   } __name;                                                               \
250   const char* DynLoadShim__##__name::kName = #__name;
251 
252 #endif
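// Rough usage sketch for the shims defined above: each wrapped entry point is
// called exactly like the native MIOpen function, e.g.
//   miopenTensorDescriptor_t desc;
//   auto status = wrap::miopenCreateTensorDescriptor(&desc);
// In the dynamic-load build, the first call through a shim opens the MIOpen
// DSO (via CachedDsoLoader) and resolves the symbol; subsequent calls reuse
// the cached function pointer.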
253 
254 // clang-format off
255 #define MIOPEN_DNN_ROUTINE_EACH(__macro)                             \
256   __macro(miopenBatchNormalizationBackward)                          \
257   __macro(miopenBatchNormalizationForwardInference)                  \
258   __macro(miopenBatchNormalizationForwardTraining)                   \
259   __macro(miopenGetConvolutionForwardOutputDim)                      \
260   __macro(miopenGetConvolutionNdForwardOutputDim)                    \
261   __macro(miopenFindConvolutionForwardAlgorithm)                     \
262   __macro(miopenCreateTensorDescriptor)                              \
263   __macro(miopenDestroyTensorDescriptor)                             \
264   __macro(miopenSetNdPoolingDescriptor)                              \
265   __macro(miopenSetPoolingIndexType)                                 \
266   __macro(miopenSetLRNDescriptor)                                    \
267   __macro(miopenLRNGetWorkSpaceSize)                                 \
268   __macro(miopenCreateConvolutionDescriptor)                         \
269   __macro(miopenCreatePoolingDescriptor)                             \
270   __macro(miopenDestroyPoolingDescriptor)                            \
271   __macro(miopenCreateLRNDescriptor)                                 \
272   __macro(miopenDestroyLRNDescriptor)                                \
273   __macro(miopenDestroyConvolutionDescriptor)                        \
274   __macro(miopenCreateWithStream)                                    \
275   __macro(miopenDestroy)                                             \
276   __macro(miopenSetStream)                                           \
277   __macro(miopenSetAllocator)                                        \
278   __macro(miopenActivationForward)                                   \
279   __macro(miopenConvolutionForward)                                  \
280   __macro(miopenConvolutionBackwardBias)                             \
281   __macro(miopenConvolutionForwardGetWorkSpaceSize)                  \
282   __macro(miopenInitConvolutionDescriptor)                           \
283   __macro(miopenInitConvolutionNdDescriptor)                         \
284   __macro(miopenGetConvolutionDescriptor)                            \
285   __macro(miopenGetConvolutionNdDescriptor)                          \
286   __macro(miopenSetConvolutionGroupCount)                            \
287   __macro(miopenSet4dTensorDescriptor)                               \
288   __macro(miopenGetTensorDescriptor)                                 \
289   __macro(miopenSetTensorDescriptor)                                 \
290   __macro(miopenGetTensorDescriptorSize)                             \
291   __macro(miopenPoolingForward)                                      \
292   __macro(miopenPoolingGetWorkSpaceSizeV2)                           \
293   __macro(miopenPoolingBackward)                                     \
294   __macro(miopenLRNForward)                                          \
295   __macro(miopenLRNBackward)                                         \
296   __macro(miopenOpTensor)                                            \
297   __macro(miopenConvolutionBackwardData)                             \
298   __macro(miopenConvolutionBackwardWeights)                          \
299   __macro(miopenConvolutionBackwardWeightsGetWorkSpaceSize)          \
300   __macro(miopenFindConvolutionBackwardDataAlgorithm)                \
301   __macro(miopenFindConvolutionBackwardWeightsAlgorithm)             \
302   __macro(miopenConvolutionBackwardDataGetWorkSpaceSize)             \
303   __macro(miopenCreateRNNDescriptor)                                 \
304   __macro(miopenSetRNNDescriptor)                                    \
305   __macro(miopenDestroyRNNDescriptor)                                \
306   __macro(miopenGetRNNParamsSize)                                    \
307   __macro(miopenGetRNNLayerParam)                                    \
308   __macro(miopenGetRNNLayerBias)                                     \
309   __macro(miopenGetRNNWorkspaceSize)                                 \
310   __macro(miopenGetRNNTrainingReserveSize)                           \
311   __macro(miopenRNNForwardInference)                                 \
312   __macro(miopenRNNForwardTraining)                                  \
313   __macro(miopenRNNBackwardData)                                     \
314   __macro(miopenRNNBackwardWeights)                                  \
315   __macro(miopenGetRNNLayerParamOffset)                              \
316   __macro(miopenGetRNNLayerParamSize)                                \
317   __macro(miopenGetRNNLayerBiasOffset)                               \
318   __macro(miopenGetRNNLayerBiasSize)                                 \
319   __macro(miopenGetRNNParamsDescriptor)                              \
320   __macro(miopenCreateActivationDescriptor)                          \
321   __macro(miopenSetActivationDescriptor)                             \
322   __macro(miopenGetActivationDescriptor)                             \
323   __macro(miopenDestroyActivationDescriptor)                         \
324   __macro(miopenCreateFusionPlan)                                    \
325   __macro(miopenCreateOpConvForward)                                 \
326   __macro(miopenCreateOpBiasForward)                                 \
327   __macro(miopenCreateOpActivationForward)                           \
328   __macro(miopenCreateOpActivationBackward)                          \
329   __macro(miopenCreateOpBatchNormInference)                          \
330   __macro(miopenCreateOpBatchNormForward)                            \
331   __macro(miopenCreateOpBatchNormBackward)                           \
332   __macro(miopenCompileFusionPlan)                                   \
333   __macro(miopenFusionPlanGetOp)                                     \
334   __macro(miopenCreateOperatorArgs)                                  \
335   __macro(miopenSetOpArgsConvForward)                                \
336   __macro(miopenSetOpArgsBiasForward)                                \
337   __macro(miopenSetOpArgsActivForward)                               \
338   __macro(miopenSetOpArgsActivBackward)                              \
339   __macro(miopenSetOpArgsBatchNormInference)                         \
340   __macro(miopenSetOpArgsBatchNormForward)                           \
341   __macro(miopenSetOpArgsBatchNormBackward)                          \
342   __macro(miopenExecuteFusionPlan)                                   \
343   __macro(miopenDestroyOperatorArgs)                                 \
344   __macro(miopenDestroyFusionPlan)                                   \
345   __macro(miopenConvolutionForwardGetSolutionCount)                  \
346   __macro(miopenConvolutionForwardGetSolution)                       \
347   __macro(miopenConvolutionForwardGetSolutionWorkspaceSize)          \
348   __macro(miopenConvolutionForwardCompileSolution)                   \
349   __macro(miopenConvolutionForwardImmediate)                         \
350   __macro(miopenConvolutionBackwardDataGetSolutionCount)             \
351   __macro(miopenConvolutionBackwardDataGetSolution)                  \
352   __macro(miopenConvolutionBackwardDataGetSolutionWorkspaceSize)     \
353   __macro(miopenConvolutionBackwardDataCompileSolution)              \
354   __macro(miopenConvolutionBackwardDataImmediate)                    \
355   __macro(miopenConvolutionBackwardWeightsGetSolutionCount)          \
356   __macro(miopenConvolutionBackwardWeightsGetSolution)               \
357   __macro(miopenConvolutionBackwardWeightsGetSolutionWorkspaceSize)  \
358   __macro(miopenConvolutionBackwardWeightsCompileSolution)           \
359   __macro(miopenConvolutionBackwardWeightsImmediate)                 \
360   __macro(miopenCreateCTCLossDescriptor)                             \
361   __macro(miopenSetCTCLossDescriptor)                                \
362   __macro(miopenGetCTCLossWorkspaceSize)                             \
363   __macro(miopenCTCLoss)                                             \
364   __macro(miopenDestroyCTCLossDescriptor)
365 // clang-format on
366 
367 MIOPEN_DNN_ROUTINE_EACH(STREAM_EXECUTOR_MIOPEN_WRAP)
368 
369 #undef MIOPEN_DNN_ROUTINE_EACH
370 
371 }  // namespace wrap
372 
373 namespace {
374 
375 // These routines should ideally be provided as an MIOpen API.
376 // They are called for *every* _ROCMmFusedOp*::Compute call, and they need to
377 // be efficient! Instead of calculating the hash value by querying the MIOpen
378 // Get* APIs for the descriptor components, it would be a lot more efficient if
379 // MIOpen calculated the hash value when creating the descriptor, stored it on
380 // the descriptor data structure, and provided an API routine to query it.
381 
382 const int kMaxMIOpenTensorSize = 5;
383 
384 uint64_t GetHashValue(miopenTensorDescriptor_t tensor_desc) {
385   miopenDataType_t datatype = miopenFloat;
386   int dims[kMaxMIOpenTensorSize] = {0};
387   int strides[kMaxMIOpenTensorSize] = {0};
388   wrap::miopenGetTensorDescriptor(tensor_desc, &datatype, dims, strides);
389 
390   uint64_t hash_value = tensorflow::hash<int>()(datatype);
391   for (int dim : dims)
392     hash_value =
393         tensorflow::Hash64Combine(hash_value, tensorflow::hash<int>()(dim));
394   for (int stride : strides)
395     hash_value =
396         tensorflow::Hash64Combine(hash_value, tensorflow::hash<int>()(stride));
397 
398   return hash_value;
399 }
400 
401 uint64_t GetHashValue(miopenConvolutionDescriptor_t conv_desc) {
402   miopenConvolutionMode_t c_mode = miopenConvolution;
403   int nd = 0;
404   wrap::miopenGetConvolutionNdDescriptor(conv_desc, 0, &nd, nullptr, nullptr,
405                                          nullptr, &c_mode);
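  // The descriptor is queried twice: the first call only retrieves the number
  // of spatial dimensions (nd), so that the pad/stride/dilation vectors below
  // can be sized before the second call fills them in.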
406 
407   std::vector<int> stride(nd);
408   std::vector<int> pad(nd);
409   std::vector<int> dilation(nd);
410 
411   wrap::miopenGetConvolutionNdDescriptor(
412       conv_desc, nd, &nd, pad.data(), stride.data(), dilation.data(), &c_mode);
413 
414   uint64_t hash_value = tensorflow::hash<int>()(c_mode);
415   auto hash64Combine = [&hash_value](int element) {
416     hash_value = tensorflow::Hash64Combine(hash_value, tensorflow::hash<int>()(element));
417   };
418   std::for_each(pad.begin(), pad.end(), hash64Combine);
419   std::for_each(stride.begin(), stride.end(), hash64Combine);
420   std::for_each(dilation.begin(), dilation.end(), hash64Combine);
421 
422   return hash_value;
423 }
424 
425 // Class to implement a cache of compiled fusion plans
426 class CachedFusionPlans {
427  public:
428   // Check if we already have a fusion_plan corresponding to the given hash
429   // value.
430   // If we do, then
431   //   return true (+ the cached fusion plan via given pointer)
432   // Else
433   //   create a new fusion plan descriptor,
434   //   associate it with the given hash value in the cache
435   //   return false (+ newly created fusion plan via given pointer)
436   static bool FindOrCreate(uint64_t hash,
437                            miopenFusionPlanDescriptor_t* fusion_plan,
438                            miopenFusionDirection_t fusion_direction,
439                            miopenTensorDescriptor_t input_descriptor) {
440     absl::MutexLock lock{&cached_plans_mutex};
441 
442     bool found_cached_plan = false;
443 
444     auto it = cached_plans.find(hash);
445     if (it != cached_plans.end()) {
446       *fusion_plan = it->second;
447       found_cached_plan = true;
448     } else {
449       auto status = wrap::miopenCreateFusionPlan(fusion_plan, fusion_direction,
450                                                  input_descriptor);
451       if (status != miopenStatusSuccess) {
452         LOG(FATAL) << "call to miopenCreateFusionPlan failed: "
453                    << ToString(status);
454       } else {
455         cached_plans[hash] = *fusion_plan;
456       }
457     }
458 
459     return found_cached_plan;
460   }
461 
462   // Need to figure out the right place to call this routine
463   static void Clear() {
464     absl::MutexLock lock{&cached_plans_mutex};
465 
466     for (auto it : cached_plans) {
467       auto status = wrap::miopenDestroyFusionPlan(it.second);
468       if (status != miopenStatusSuccess) {
469         LOG(FATAL) << "call to miopenDestroyFusionPlan failed: "
470                    << ToString(status);
471       }
472     }
473 
474     cached_plans.clear();
475 
476     unsupported_plans.clear();
477   }
478 
479   // Is the fusion plan corresponding to this hash unsupported?
480   static bool IsUnsupportedFusionPlan(uint64_t hash) {
481     absl::MutexLock lock{&cached_plans_mutex};
482     return unsupported_plans.count(hash) > 0;
483   }
484 
485   // Mark the given hash value as corresponding to an unsupported fusion plan
486   static void MarkFusionPlanUnsupported(uint64_t hash) {
487     absl::MutexLock lock{&cached_plans_mutex};
488     unsupported_plans.insert(hash);
489   }
490 
491  private:
492   // Mutex to guard access to all data within this class
493   static absl::Mutex cached_plans_mutex;
494 
495   // Map of hash-value to MIOpen Fusion plan descriptors
496   // Needs to be shareable across more than one stream, and hence is static.
497   static std::map<uint64_t, miopenFusionPlanDescriptor_t> cached_plans;
498 
499   // Set of hash-values that correspond to MIOpen Fusion plans that fail to
500   // compile and hence are not supported.
501   static std::set<uint64_t> unsupported_plans;
502 };
503 
504 absl::Mutex CachedFusionPlans::cached_plans_mutex;
505 std::map<uint64_t, miopenFusionPlanDescriptor_t>
506     CachedFusionPlans::cached_plans;
507 std::set<uint64_t> CachedFusionPlans::unsupported_plans;
508 
509 dnn::ProfileResult GetProfileResultFromConvSolution(
510     miopenConvSolution_t solution) {
511   dnn::ProfileResult profile_result;
512   profile_result.set_algorithm(
513       {solution.solution_id, false, solution.workspace_size});
514   profile_result.set_elapsed_time_in_ms(solution.time);
515   profile_result.set_scratch_size(solution.workspace_size);
516   return profile_result;
517 }
518 
519 dnn::ProfileResult GetProfileResultFromConvAlgoPerf(
520     dnn::ConvolutionKind kind, miopenConvAlgoPerf_t algorithm) {
521   int64_t algo_id;
522   switch (kind) {
523     case dnn::ConvolutionKind::FORWARD:
524       algo_id = algorithm.fwd_algo;
525       break;
526     case dnn::ConvolutionKind::BACKWARD_DATA:
527       algo_id = algorithm.bwd_data_algo;
528       break;
529     case dnn::ConvolutionKind::BACKWARD_FILTER:
530       algo_id = algorithm.bwd_weights_algo;
531       break;
532     default:
533       LOG(FATAL) << "Unexpected convolution kind " << static_cast<int>(kind);
534       break;
535   }
536 
537   dnn::ProfileResult profile_result;
538   profile_result.set_algorithm({algo_id, false, algorithm.memory});
539   profile_result.set_elapsed_time_in_ms(algorithm.time);
540   profile_result.set_scratch_size(algorithm.memory);
541   return profile_result;
542 }
543 }  // namespace
544 
545 // Wraps a MIOpen handle and provides access to it through MIOpenHandle
546 // instances, which also lock a mutex, acquire the ROCm context, and set
547 // the stream that MIOpen should use to enqueue any work.
548 //
549 // Note: MIOpenSupport::miopen_ should be the only instantiation of this class.
550 class MIOpenAccess {
551  public:
552   // Takes ownership of the handle.
553   explicit MIOpenAccess(miopenHandle_t handle) : handle_(handle) {}
554 
555   ~MIOpenAccess() {
556     absl::MutexLock lock(&mutex_);
557     wrap::miopenDestroy(handle_);
558   }
559 
560   // Creates a MIOpenHandle instance for stream.
561   //
562   // MIOpen API calls using the same handle instance need to be serialized
563   // across threads. This is guaranteed by MIOpenHandle instances locking the
564   // mutex owned by this class.
565   //
566   // Most MIOpen APIs taking a handle perform work on a HIP stream. The
567   // MIOpenHandle instance acquires the executor's ROCm context and sets MIOpen
568   // to use the provided stream.
569   //
570   // The stream argument may be null, which translates to the null stream.
571   // The null stream synchronizes with all other streams and it is
572   // therefore a bad idea (performance wise) to call any MIOpen APIs that
573   // enqueue work in the stream.
574   MIOpenHandle GetHandle(GpuExecutor* executor, Stream* stream) {
575     auto lock = std::make_unique<absl::MutexLock>(&mutex_);
576     mutex_.AssertHeld();
577     gpu::ScopedActivateExecutorContext context(executor);
578     hipStream_t hip_stream = stream ? AsGpuStreamValue(stream) : nullptr;
579     auto status = wrap::miopenSetStream(handle_, hip_stream);
580     CHECK_EQ(status, miopenStatusSuccess) << "Failed to set MIOpen stream.";
581     return MIOpenHandle(std::move(context), std::move(lock), handle_);
582   }
583 
584  private:
585   // Guards the enqueueing of MIOpen operations via the handle_ below.
586   absl::Mutex mutex_;
587 
588   // MIOpen library handle.
589   miopenHandle_t handle_ ABSL_GUARDED_BY(mutex_);  // Owned.
590 };
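// Illustrative usage sketch, mirroring how MIOpen calls are made elsewhere in
// this file (miopenSomeApi is a placeholder for any wrapped entry point):
//   auto miopen = miopen_->GetHandle(parent_, stream);
//   auto status = wrap::miopenSomeApi(miopen.handle(), ...);
// The raw handle must not outlive the MIOpenHandle object, since the lock and
// the ROCm context activation are released when it is destroyed.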
591 
592 MIOpenSupport::MIOpenSupport(GpuExecutor* parent) : parent_(parent) {
593   // by default, the Get*Algorithm API will return the list of all applicable
594   // algorithms
595   return_best_algo_only_ = false;
596   // but if the env var TF_ROCM_RETURN_BEST_ALGO_ONLY is set, only the best
597   // (i.e. most efficient) algorithm will be returned
598   tensorflow::ReadBoolFromEnvVar("TF_ROCM_RETURN_BEST_ALGO_ONLY", false,
599                                  &return_best_algo_only_);
600 
601   // by default, use Find Mode APIs for convolution
602   use_immediate_mode_ = false;
603   // switch to Immediate Mode if the env var TF_ROCM_USE_IMMEDIATE_MODE is set
604   tensorflow::ReadBoolFromEnvVar("TF_ROCM_USE_IMMEDIATE_MODE", false,
605                                  &use_immediate_mode_);
606 
607   bool enable_pooling_cache = false;
608   tensorflow::ReadBoolFromEnvVar("TF_ROCM_BW_POOL_CACHE", false,
609                                  &enable_pooling_cache);
610   if (enable_pooling_cache) m_pooling_cache_allowed = true;
611 }
612 
613 port::Status MIOpenSupport::Init() {
614   ScopedActivateExecutorContext context(parent_);
615   miopenHandle_t miopen_handle = nullptr;
616   auto status = wrap::miopenCreateWithStream(
617       reinterpret_cast<miopenHandle_t*>(&miopen_handle), (hipStream_t)(0));
618   if (status == miopenStatusSuccess) {
619     miopen_.reset(new MIOpenAccess(miopen_handle));
620     return port::Status::OK();
621   }
622 
623   CHECK_EQ(miopen_handle, nullptr);
624   LOG(ERROR) << "could not create miopen handle: " << ToString(status);
625   if (status == miopenStatusNotInitialized) {
626     auto result = rocm::Diagnostician::FindKernelDriverVersion();
627     if (!result.ok()) {
628       LOG(ERROR) << "error retrieving driver version: "
629                  << rocm::DriverVersionStatusToString(result);
630     } else {
631       const auto& version = result.ValueOrDie();
632       LOG(INFO) << "possibly insufficient driver version: "
633                 << rocm::DriverVersionToString(version);
634     }
635   }
636 
637   return port::Status{port::error::INTERNAL,
638                       absl::StrCat("miopen library could not create a handle: ",
639                                    ToString(status))};
640 }
641 
642 port::StatusOr<perftools::gputools::dnn::VersionInfo>
643 MIOpenSupport::GetVersion() {
644   // ROCM TODO: retrieve MIOpen version with its API
645   return perftools::gputools::dnn::VersionInfo(1, 3, 0);
646 }
647 
648 // Turns a BatchDescriptor structure into a miopen tensor handle within a scope.
649 class ScopedTensorDescriptor {
650  public:
651   ScopedTensorDescriptor(const BatchDescriptor& batch_descriptor,
652                          miopenDataType_t elem_type)
653       : handle_(nullptr) {
654     auto status = wrap::miopenCreateTensorDescriptor(&handle_);
655     if (status != miopenStatusSuccess) {
656       LOG(FATAL) << "could not create miopen tensor descriptor: "
657                  << ToString(status);
658     }
659 
660     switch (batch_descriptor.layout()) {
661       case dnn::DataLayout::kBatchYXDepth:
662       case dnn::DataLayout::kBatchDepthYX: {
663         const int nd = batch_descriptor.ndims() + 2;
664 
665         // MIOpen requires the strides and dims to be ordered as BDYX.
666         std::vector<int64_t> strides64 =
667             batch_descriptor.full_strides(dnn::DataLayout::kBatchDepthYX);
668         std::vector<int64_t> dims64 =
669             batch_descriptor.full_dims(dnn::DataLayout::kBatchDepthYX);
670 
671         // MIOpen requires arrays of ints.
672         std::vector<int> strides(nd);
673         std::vector<int> dims(nd);
674         std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
675                        &CheckedNarrowing<int64_t, int>);
676         std::transform(dims64.cbegin(), dims64.cend(), dims.begin(),
677                        &CheckedNarrowing<int64_t, int>);
678         status = wrap::miopenSetTensorDescriptor(handle_, elem_type, nd,
679                                                  dims.data(), strides.data());
680 
681         if (status != miopenStatusSuccess) {
682           LOG(FATAL) << "could not convert BatchDescriptor "
683                      << batch_descriptor.ToString()
684                      << " to miopen tensor descriptor: " << ToString(status);
685         }
686       } break;
687       default:
688         LOG(FATAL) << "Unsupported tensor format "
689                    << DataLayoutString(batch_descriptor.layout());
690         break;
691     }
692   }
693 
694   ~ScopedTensorDescriptor() {
695     auto status = wrap::miopenDestroyTensorDescriptor(handle_);
696     if (status != miopenStatusSuccess) {
697       LOG(ERROR) << "could not destroy miopen tensor descriptor: "
698                  << ToString(status);
699     }
700   }
701 
702   miopenTensorDescriptor_t handle() const { return handle_; }
703 
704  private:
705   miopenTensorDescriptor_t handle_;  // Owned.
706 
707   SE_DISALLOW_COPY_AND_ASSIGN(ScopedTensorDescriptor);
708 };
709 
710 // Turns a FilterDescriptor structure into a miopen filter handle within a
711 // scope.
712 class ScopedFilterDescriptor {
713  public:
714   ScopedFilterDescriptor(const FilterDescriptor& filter_descriptor,
715                          miopenDataType_t elem_type)
716       : handle_(nullptr) {
717     auto status = wrap::miopenCreateTensorDescriptor(&handle_);
718     if (status != miopenStatusSuccess) {
719       LOG(FATAL) << "could not create miopen filter descriptor: "
720                  << ToString(status);
721     }
722 
723     // We need to pass two vectors to the miopenSetTensorDescriptor routine
724     // "dims" (length == number of dims, elem value == dimension size)
725     // "strides" (length == number of dims, elem value == stride size)
726     //
727     // Irrespective of the actual filter layout, the indexing of both those
728     // vectors must be the following (because that is what MIOpen expects):
729     // dims[0] = strides[0] = N or output
730     // dims[1] = strides[1] = C or input
731     // dims[2] = strides[2] = H or spatial dim 0
732     // dims[3] = strides[3] = W or spatial dim 1
733     //
734     // assume you have a tensor with dimensions
735     // batch descriptor name    filter descriptor name    value
736     //   N (batch size)            O (output features)    256
737     //   C (channels)              I (input features)       3
738     //   H (height)                H (height)               7
739     //   W (width)                 W (width)                5
740     //
741     // The content of "dims" will be the same irrespective of the
742     // layout (NCHW or NHWC), and MIOpen expects it to be
743     //                           NCHW layout   NHWC layout
744     // dims[0] = size of N dim =    256           256
745     // dims[1] = size of C dim =      3             3
746     // dims[2] = size of H dim =      7             7
747     // dims[3] = size of W dim =      5             5
748     //
749     // The content of "strides" will be different based on layout
750     //                                  NCHW layout   NHWC layout
751     //  strides[0] = stride of N dim =     7x5x3       7x5x3
752     //  strides[1] = stride of C dim =     7x5         1
753     //  strides[2] = stride of H dim =     5           5x3
754     //  strides[3] = stride of W dim =     1           3
755 
756     switch (filter_descriptor.layout()) {
757       case dnn::FilterLayout::kOutputYXInput:
758       case dnn::FilterLayout::kOutputInputYX: {
759         const int nd = filter_descriptor.ndims() + 2;
760 
761         // MIOpen requires the strides and dims to be ordered as BDYX.
762         std::vector<int64_t> strides64 =
763             filter_descriptor.full_strides(dnn::FilterLayout::kOutputInputYX);
764         std::vector<int64_t> dims64 =
765             filter_descriptor.full_dims(dnn::FilterLayout::kOutputInputYX);
766 
767         // MIOpen requires arrays of ints.
768         std::vector<int> strides;
769         std::vector<int> dims;
770         absl::c_transform(strides64, std::back_inserter(strides),
771                           &CheckedNarrowing<int64_t, int>);
772         absl::c_transform(dims64, std::back_inserter(dims),
773                           &CheckedNarrowing<int64_t, int>);
774         status = wrap::miopenSetTensorDescriptor(handle_, elem_type, nd,
775                                                  dims.data(), strides.data());
776 
777         if (status != miopenStatusSuccess) {
778           LOG(FATAL) << "could not convert FilterDescriptor "
779                      << filter_descriptor.ToString()
780                      << " to miopen tensor descriptor: " << ToString(status);
781         }
782       } break;
783       default:
784         LOG(FATAL) << "Unsupported tensor format "
785                    << FilterLayoutString(filter_descriptor.layout());
786         break;
787     }
788   }
789 
790   ~ScopedFilterDescriptor() {
791     auto status = wrap::miopenDestroyTensorDescriptor(handle_);
792     if (status != miopenStatusSuccess) {
793       LOG(ERROR) << "could not destroy miopen filter descriptor: "
794                  << ToString(status);
795     }
796   }
797 
798   miopenTensorDescriptor_t handle() const { return handle_; }
799 
800  private:
801   // miopen filter descriptor this object creates. Owned.
802   miopenTensorDescriptor_t handle_;
803 
804   SE_DISALLOW_COPY_AND_ASSIGN(ScopedFilterDescriptor);
805 };
806 
807 // Turns a ConvolutionDescriptor structure into a miopen convolution handle
808 // within a scope.
809 class ScopedConvolutionDescriptor {
810  public:
811   ScopedConvolutionDescriptor(
812       const ConvolutionDescriptor& convolution_descriptor,
813       miopenDataType_t data_type)
814       : handle_(nullptr) {
815     auto status = wrap::miopenCreateConvolutionDescriptor(&handle_);
816     if (status != miopenStatusSuccess) {
817       LOG(FATAL) << "could not create miopen convolution descriptor: "
818                  << ToString(status);
819     }
820     const auto& strides64 = convolution_descriptor.strides();
821     const auto& padding64 = convolution_descriptor.padding();
822     if (convolution_descriptor.pad_alignment() ==
823         dnn::PadAlignment::kTensorFlowPadding) {
824       LOG(ERROR) << "TensorFlow padding alignment is not supported.";
825     }
826 
827     // MIOpen requires arrays of ints.
828     std::vector<int> strides(convolution_descriptor.ndims());
829     std::vector<int> padding(convolution_descriptor.ndims());
830     std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
831                    &CheckedNarrowing<int64_t, int>);
832     std::transform(padding64.cbegin(), padding64.cend(), padding.begin(),
833                    &CheckedNarrowing<int64_t, int>);
834 
835     std::vector<int> upscale(convolution_descriptor.ndims());
836     const auto& dilations64 = convolution_descriptor.dilations();
837     std::transform(dilations64.cbegin(), dilations64.cend(), upscale.begin(),
838                    &CheckedNarrowing<int64_t, int>);
839 
840     status = wrap::miopenInitConvolutionNdDescriptor(
841         handle_, convolution_descriptor.ndims(), padding.data(), strides.data(),
842         upscale.data(), miopenConvolution);
843     if (status != miopenStatusSuccess) {
844       LOG(FATAL) << "could not set miopen convolution descriptor: "
845                  << ToString(status);
846     }
847 
848     VLOG(2) << "Requesting grouped convolution: "
849             << convolution_descriptor.group_count();
850     status = wrap::miopenSetConvolutionGroupCount(
851         handle_, convolution_descriptor.group_count());
852     if (status != miopenStatusSuccess) {
853       LOG(FATAL) << "could not set miopen convolution group count: "
854                  << ToString(status);
855     }
856   }
857   ~ScopedConvolutionDescriptor() {
858     auto status = wrap::miopenDestroyConvolutionDescriptor(handle_);
859     if (status != miopenStatusSuccess) {
860       LOG(ERROR) << "could not destroy miopen convolution descriptor: "
861                  << ToString(status);
862     }
863   }
864 
865   miopenConvolutionDescriptor_t handle() const { return handle_; }
866 
867  private:
868   miopenConvolutionDescriptor_t handle_;  // Owned.
869 
870   SE_DISALLOW_COPY_AND_ASSIGN(ScopedConvolutionDescriptor);
871 };
872 
873 // Turns a PoolingDescriptor structure into a miopen pooling descriptor handle
874 // within a scope.
875 class ScopedPoolingDescriptor {
876  public:
877   ScopedPoolingDescriptor(const PoolingDescriptor& pooling_descriptor)
878       : handle_(nullptr) {
879     auto status = wrap::miopenCreatePoolingDescriptor(&handle_);
880     if (status != miopenStatusSuccess) {
881       LOG(FATAL) << "could not create miopen pooling descriptor: "
882                  << ToString(status);
883     }
884 
885     absl::Span<const int64_t> strides64 = pooling_descriptor.strides();
886     absl::Span<const int64_t> padding64 = pooling_descriptor.padding();
887     absl::Span<const int64_t> shape64 = pooling_descriptor.window();
888 
889     const int nd = pooling_descriptor.ndims();
890     std::vector<int> shape(nd);
891     std::vector<int> padding(nd);
892     std::vector<int> strides(nd);
893     std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
894                    &CheckedNarrowing<int64_t, int>);
895     std::transform(padding64.cbegin(), padding64.cend(), padding.begin(),
896                    &CheckedNarrowing<int64_t, int>);
897     std::transform(shape64.cbegin(), shape64.cend(), shape.begin(),
898                    &CheckedNarrowing<int64_t, int>);
899 
900     status = wrap::miopenSetNdPoolingDescriptor(
901         handle_,
902         (pooling_descriptor.mode() == dnn::PoolingMode::kMaximum
903              ? miopenPoolingMax
904              : miopenPoolingAverage),
905         nd, shape.data(), padding.data(), strides.data());
906 
907     // Note: The index type has to be uint32 for now because the MIOpen
908     // API assumes all input indexes to be of the same type. Since a tensor
909     // descriptor can only use the int32 type, the index type here needs to be
910     // aligned with the tensor index type of the (input) tensor descriptor.
911     status = wrap::miopenSetPoolingIndexType(handle_, miopenIndexUint32);
912 
913     if (status != miopenStatusSuccess) {
914       LOG(FATAL) << "could not set miopen pooling descriptor: "
915                  << ToString(status);
916     }
917   }
918   ~ScopedPoolingDescriptor() {
919     auto status = wrap::miopenDestroyPoolingDescriptor(handle_);
920     if (status != miopenStatusSuccess) {
921       LOG(ERROR) << "could not destroy miopen pooling descriptor: "
922                  << ToString(status);
923     }
924   }
925 
926   miopenPoolingDescriptor_t handle() const { return handle_; }
927 
928  private:
929   miopenPoolingDescriptor_t handle_;  // Owned.
930 
931   SE_DISALLOW_COPY_AND_ASSIGN(ScopedPoolingDescriptor);
932 };
933 
934 // Turns a NormalizeDescriptor structure into a miopen LRN descriptor handle.
935 class ScopedNormalizeDescriptor {
936  public:
937   ScopedNormalizeDescriptor(const NormalizeDescriptor& normalize_descriptor)
938       : handle_(nullptr) {
939     auto status = wrap::miopenCreateLRNDescriptor(&handle_);
940     if (status != miopenStatusSuccess) {
941       LOG(FATAL) << "could not create miopen LRN descriptor: "
942                  << ToString(status);
943     }
944 
945     // The range specifies that the indices in the closed range
946     // [i - range, i + range] should be included in the normalization for index
947     // i. The lrnN value is the total number of elements in the range, so
948     // lrnN = 2*range + 1.
949     unsigned lrn_N = 2 * normalize_descriptor.range() + 1;
950 
951     // Note that SE defines the normalization operation as
952     //
953     //  U_i = V_i / ((bias +  alpha      * (sum_j V_j^2)) ^ beta)
954     //
955     // but MIOpen defines it as
956     //
957     //  U_i = V_i / ((bias + (alpha / n) * (sum_j V_j^2)) ^ beta)
958     //
959     // i.e. there is a factor of n difference between the meaning of the alphas
960     // in the two contexts. The MIOpen alpha is n times the SE alpha.
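    // For example, with range = 2 the window size lrn_N is 5, so an SE alpha
    // of 1e-4 corresponds to a MIOpen alpha of 5e-4.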
961     double lrn_alpha = lrn_N * normalize_descriptor.alpha();
962 
963     double lrn_beta = normalize_descriptor.beta();
964     double lrn_k = normalize_descriptor.bias();
965     status = wrap::miopenSetLRNDescriptor(handle_, miopenLRNCrossChannel, lrn_N,
966                                           lrn_alpha, lrn_beta, lrn_k);
967     if (status != miopenStatusSuccess) {
968       LOG(FATAL) << "could not set miopen LRN descriptor: " << ToString(status);
969     }
970   }
971 
972   ~ScopedNormalizeDescriptor() {
973     auto status = wrap::miopenDestroyLRNDescriptor(handle_);
974     if (status != miopenStatusSuccess) {
975       LOG(ERROR) << "could not destroy miopen LRN descriptor: "
976                  << ToString(status);
977     }
978   }
979 
980   miopenLRNDescriptor_t handle() const { return handle_; }
981 
982  private:
983   miopenLRNDescriptor_t handle_;  // Owned.
984 
985   SE_DISALLOW_COPY_AND_ASSIGN(ScopedNormalizeDescriptor);
986 };
987 
988 // Turns an activation mode into a miopen activation descriptor with a scope
989 // around it.
990 class ScopedActivationDescriptor {
991  public:
992   ScopedActivationDescriptor(dnn::ActivationMode activation_mode)
993       : handle_(nullptr),
994         miopen_activation_mode_(miopenActivationPASTHRU),
995         alpha_(0.0),
996         beta_(0.0),
997         gamma_(0.0) {
998     auto status = wrap::miopenCreateActivationDescriptor(&handle_);
999     if (status != miopenStatusSuccess) {
1000       LOG(FATAL) << "call to miopenCreateActivationDescriptor failed: "
1001                  << ToString(status);
1002     } else {
1003       switch (activation_mode) {
1004         case dnn::ActivationMode::kNone:
1005           miopen_activation_mode_ = miopenActivationPASTHRU;
1006           break;
1007 
1008         case dnn::ActivationMode::kSigmoid:
1009           miopen_activation_mode_ = miopenActivationLOGISTIC;
1010           break;
1011 
1012         case dnn::ActivationMode::kRelu:
1013           miopen_activation_mode_ = miopenActivationRELU;
1014           break;
1015 
1016         case dnn::ActivationMode::kRelu6:
1017           miopen_activation_mode_ = miopenActivationRELU;
1018           alpha_ = 6.0;
1019           break;
1020 
1021         case dnn::ActivationMode::kTanh:
1022           miopen_activation_mode_ = miopenActivationTANH;
1023           break;
1024 
1025         default:
1026           LOG(FATAL) << "Activation mode ("
1027                      << dnn::ActivationModeString(activation_mode)
1028                      << ") not yet implemented";
1029           break;
1030       }
1031 
1032       status = wrap::miopenSetActivationDescriptor(
1033           handle_, miopen_activation_mode_, alpha_, beta_, gamma_);
1034       if (status != miopenStatusSuccess) {
1035         LOG(FATAL) << "call to miopenSetActivationDescriptor failed: "
1036                    << ToString(status);
1037       }
1038     }
1039   }
1040 
1041   ~ScopedActivationDescriptor() {
1042     auto status = wrap::miopenDestroyActivationDescriptor(handle_);
1043     if (status != miopenStatusSuccess) {
1044       LOG(FATAL) << "call to miopenDestroyActivationDescriptor failed: "
1045                  << ToString(status);
1046     }
1047   }
1048 
1049   miopenActivationDescriptor_t handle() const { return handle_; }
1050 
1051   uint64_t GetHashValue() {
1052     uint64_t hash_value = tensorflow::hash<int>()(miopen_activation_mode_);
1053     hash_value = tensorflow::Hash64Combine(hash_value,
1054                                            tensorflow::hash<double>()(alpha_));
1055     hash_value = tensorflow::Hash64Combine(hash_value,
1056                                            tensorflow::hash<double>()(beta_));
1057     hash_value = tensorflow::Hash64Combine(hash_value,
1058                                            tensorflow::hash<double>()(gamma_));
1059 
1060     return hash_value;
1061   }
1062 
1063  private:
1064   miopenActivationDescriptor_t handle_;  // Owned.
1065 
1066   SE_DISALLOW_COPY_AND_ASSIGN(ScopedActivationDescriptor);
1067 
1068  public:
1069   // These values are cached here to avoid calling miopenGetActivationDescriptor
1070   // to retrieve them. miopenGetActivationDescriptor gets called twice during
1071   // each call to execute a fusion plan (that involves the activation op): once
1072   // while calculating the hash value for the fusion op, and once more before
1073   // calling SetOpArgs for the activation op.
1074   miopenActivationMode_t miopen_activation_mode_;
1075   double alpha_;
1076   double beta_;
1077   double gamma_;
1078 };
1079 
1080 // base class for all fusion plan implementations to derive from
1081 class ScopedFusionPlanBase {
1082  public:
1083   ScopedFusionPlanBase(miopenHandle_t miopen_handle,
1084                        const miopenFusionDirection_t fuse_direction,
1085                        const miopenTensorDescriptor_t input_descriptor)
1086       : miopen_handle_(miopen_handle),
1087         fusion_plan_(nullptr),
1088         fusion_args_(nullptr),
1089         fusion_plan_compiled_(false) {
1090     auto status = wrap::miopenCreateOperatorArgs(&fusion_args_);
1091     if (status != miopenStatusSuccess) {
1092       LOG(FATAL) << "call to miopenCreateOperatorArgs failed: "
1093                  << ToString(status);
1094     }
1095   }
1096 
1097   virtual ~ScopedFusionPlanBase() {
1098     auto status = wrap::miopenDestroyOperatorArgs(fusion_args_);
1099     if (status != miopenStatusSuccess) {
1100       LOG(FATAL) << "call to miopenDestroyOperatorArgs failed: "
1101                  << ToString(status);
1102     }
1103   }
1104 
1105   miopenStatus_t Execute(miopenTensorDescriptor_t input_descriptor,
1106                          const void* input_data,
1107                          miopenTensorDescriptor_t output_descriptor,
1108                          void* output_data) {
1109     auto status = wrap::miopenExecuteFusionPlan(
1110         miopen_handle_, fusion_plan_, input_descriptor, input_data,
1111         output_descriptor, output_data, fusion_args_);
1112     if (status != miopenStatusSuccess) {
1113       LOG(FATAL) << "call to miopenExecuteFusionPlan failed: "
1114                  << ToString(status);
1115     }
1116 
1117     return status;
1118   }
1119 
1120   bool CompilationSucceeded() { return fusion_plan_compiled_; }
1121 
1122  protected:
1123   miopenStatus_t SetConvolutionArgs(const int op_idx, const float* alpha,
1124                                     const float* beta, const void* data) {
1125     miopenFusionOpDescriptor_t conv_op;
1126     auto status = wrap::miopenFusionPlanGetOp(fusion_plan_, op_idx, &conv_op);
1127     if (status != miopenStatusSuccess) {
1128       LOG(FATAL) << "call to miopenFusionPlanGetOp failed: "
1129                  << ToString(status);
1130     }
1131 
1132     status = wrap::miopenSetOpArgsConvForward(fusion_args_, conv_op, alpha,
1133                                               beta, data);
1134     if (status != miopenStatusSuccess) {
1135       LOG(FATAL) << "call to miopenSetOpArgsConvForward failed: "
1136                  << ToString(status);
1137     }
1138     return status;
1139   }
1140 
1141   miopenStatus_t SetBiasArgs(const int op_idx, const float* alpha,
1142                              const float* beta, const void* data) {
1143     miopenFusionOpDescriptor_t bias_op;
1144     auto status = wrap::miopenFusionPlanGetOp(fusion_plan_, op_idx, &bias_op);
1145     if (status != miopenStatusSuccess) {
1146       LOG(FATAL) << "call to miopenFusionPlanGetOp failed: "
1147                  << ToString(status);
1148     }
1149 
1150     status = wrap::miopenSetOpArgsBiasForward(fusion_args_, bias_op, alpha,
1151                                               beta, data);
1152     if (status != miopenStatusSuccess) {
1153       LOG(FATAL) << "call to miopenSetOpArgsBiasForward failed: "
1154                  << ToString(status);
1155     }
1156     return status;
1157   }
1158 
1159   miopenStatus_t SetBatchNormInferenceArgs(const int op_idx, const float* alpha,
1160                                            const float* beta, const void* scale,
1161                                            const void* offset, const void* mean,
1162                                            const void* variance,
1163                                            double epsilon) {
1164     miopenFusionOpDescriptor_t batchnorm_op;
1165     auto status =
1166         wrap::miopenFusionPlanGetOp(fusion_plan_, op_idx, &batchnorm_op);
1167     if (status != miopenStatusSuccess) {
1168       LOG(FATAL) << "call to miopenFusionPlanGetOp failed: "
1169                  << ToString(status);
1170     }
1171 
1172     status = wrap::miopenSetOpArgsBatchNormInference(fusion_args_, batchnorm_op,
1173                                                      alpha, beta, scale, offset,
1174                                                      mean, variance, epsilon);
1175     if (status != miopenStatusSuccess) {
1176       LOG(FATAL) << "call to miopenSetOpArgsBatchNormInference failed: "
1177                  << ToString(status);
1178     }
1179     return status;
1180   }
1181 
1182   miopenStatus_t SetBatchNormForwardArgs(
1183       const int op_idx, const float* alpha, const float* beta,
1184       const void* scale, const void* offset, void* running_mean,
1185       void* running_variance, void* saved_mean, void* saved_inv_variance,
1186       double epsilon, double exponential_average_factor) {
1187     miopenFusionOpDescriptor_t batchnorm_op;
1188     auto status =
1189         wrap::miopenFusionPlanGetOp(fusion_plan_, op_idx, &batchnorm_op);
1190     if (status != miopenStatusSuccess) {
1191       LOG(FATAL) << "call to miopenFusionPlanGetOp failed: "
1192                  << ToString(status);
1193     }
1194 
1195     status = wrap::miopenSetOpArgsBatchNormForward(
1196         fusion_args_, batchnorm_op, alpha, beta, scale, offset, saved_mean,
1197         saved_inv_variance, running_mean, running_variance, epsilon,
1198         exponential_average_factor);
1199     if (status != miopenStatusSuccess) {
1200       LOG(FATAL) << "call to miopenSetOpArgsBatchNormForward failed: "
1201                  << ToString(status);
1202     }
1203     return status;
1204   }
1205 
1206   miopenStatus_t SetBatchNormBackwardArgs(const int op_idx, const float* alpha,
1207                                           const float* beta, const void* x,
1208                                           const void* scale, const void* offset,
1209                                           void* scale_grad, void* offset_grad,
1210                                           const void* saved_mean,
1211                                           const void* saved_inv_variance) {
1212     miopenFusionOpDescriptor_t batchnorm_op;
1213     auto status =
1214         wrap::miopenFusionPlanGetOp(fusion_plan_, op_idx, &batchnorm_op);
1215     if (status != miopenStatusSuccess) {
1216       LOG(FATAL) << "call to miopenFusionPlanGetOp failed: "
1217                  << ToString(status);
1218     }
1219 
1220     status = wrap::miopenSetOpArgsBatchNormBackward(
1221         fusion_args_, batchnorm_op, alpha, beta, x, scale, offset, scale_grad,
1222         offset_grad, saved_mean, saved_inv_variance);
1223     if (status != miopenStatusSuccess) {
1224       LOG(FATAL) << "call to miopenSetOpArgsBatchNormBackward failed: "
1225                  << ToString(status);
1226     }
1227     return status;
1228   }
1229 
1230   miopenStatus_t SetActivationForwardArgs(const int op_idx, const float* alpha,
1231                                           const float* beta, double activ_alpha,
1232                                           double activ_beta,
1233                                           double activ_gamma) {
1234     miopenFusionOpDescriptor_t actv_op;
1235     auto status = wrap::miopenFusionPlanGetOp(fusion_plan_, op_idx, &actv_op);
1236     if (status != miopenStatusSuccess) {
1237       LOG(FATAL) << "call to miopenFusionPlanGetOp failed: "
1238                  << ToString(status);
1239     }
1240 
1241     status =
1242         wrap::miopenSetOpArgsActivForward(fusion_args_, actv_op, alpha, beta,
1243                                           activ_alpha, activ_beta, activ_gamma);
1244     if (status != miopenStatusSuccess) {
1245       LOG(FATAL) << "call to miopenSetOpArgsActivForward failed: "
1246                  << ToString(status);
1247     }
1248     return status;
1249   }
1250 
1251   miopenStatus_t SetActivationBackwardArgs(const int op_idx, const float* alpha,
1252                                            const float* beta, const void* y,
1253                                            double activ_alpha,
1254                                            double activ_beta,
1255                                            double activ_gamma) {
1256     miopenFusionOpDescriptor_t actv_op;
1257     auto status = wrap::miopenFusionPlanGetOp(fusion_plan_, op_idx, &actv_op);
1258     if (status != miopenStatusSuccess) {
1259       LOG(FATAL) << "call to miopenFusionPlanGetOp failed: "
1260                  << ToString(status);
1261     }
1262 
1263     status = wrap::miopenSetOpArgsActivBackward(fusion_args_, actv_op, alpha,
1264                                                 beta, y, nullptr, activ_alpha,
1265                                                 activ_beta, activ_gamma);
1266     if (status != miopenStatusSuccess) {
1267       LOG(FATAL) << "call to miopenSetOpArgsActivBackward failed: "
1268                  << ToString(status);
1269     }
1270     return status;
1271   }
1272 
1273   miopenHandle_t miopen_handle_;
1274   miopenFusionPlanDescriptor_t fusion_plan_;
1275   miopenOperatorArgs_t fusion_args_;  // Owned.
1276   bool fusion_plan_compiled_;
1277 
1278   SE_DISALLOW_COPY_AND_ASSIGN(ScopedFusionPlanBase);
1279 };
1280 
1281 // class to represent the Convolution+Bias+Activation fusion plan
1282 class ScopedFusionPlanConvolutionBiasActivation : public ScopedFusionPlanBase {
1283  public:
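  // Fusion plans are cached by a hash of the MIOpen handle and all of the
  // descriptors. Only on a cache miss are the convolution, bias, and
  // activation ops added and the plan compiled; a plan that fails to compile
  // is marked unsupported so the compile is not retried on later lookups.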
1284   ScopedFusionPlanConvolutionBiasActivation(
1285       miopenHandle_t miopen_handle, miopenTensorDescriptor_t input_descriptor,
1286       miopenTensorDescriptor_t filter_descriptor,
1287       miopenConvolutionDescriptor_t conv_descriptor,
1288       miopenTensorDescriptor_t bias_descriptor,
1289       ScopedActivationDescriptor& activation_descriptor)
1290       : ScopedFusionPlanBase(miopen_handle, miopenVerticalFusion,
1291                              input_descriptor) {
1292     uint64_t hash = GetFusionOpHashValue(
1293         miopen_handle, input_descriptor, filter_descriptor, conv_descriptor,
1294         bias_descriptor, activation_descriptor);
1295 
1296     bool is_compiled = CachedFusionPlans::FindOrCreate(
1297         hash, &fusion_plan_, miopenVerticalFusion, input_descriptor);
1298     if (!is_compiled) {
1299       miopenFusionOpDescriptor_t conv_op;
1300       auto status = wrap::miopenCreateOpConvForward(
1301           fusion_plan_, &conv_op, conv_descriptor, filter_descriptor);
1302       if (status != miopenStatusSuccess) {
1303         LOG(FATAL) << "call to miopenCreateOpConvForward failed: "
1304                    << ToString(status);
1305       }
1306 
1307       miopenFusionOpDescriptor_t bias_op;
1308       status = wrap::miopenCreateOpBiasForward(fusion_plan_, &bias_op,
1309                                                bias_descriptor);
1310       if (status != miopenStatusSuccess) {
1311         LOG(FATAL) << "call to miopenCreateOpBiasForward failed: "
1312                    << ToString(status);
1313       }
1314 
1315       miopenFusionOpDescriptor_t actv_op;
1316       status = wrap::miopenCreateOpActivationForward(
1317           fusion_plan_, &actv_op,
1318           activation_descriptor.miopen_activation_mode_);
1319       if (status != miopenStatusSuccess) {
1320         LOG(FATAL) << "call to miopenCreateOpActivationForward failed: "
1321                    << ToString(status);
1322       }
1323 
1324       status = wrap::miopenCompileFusionPlan(miopen_handle_, fusion_plan_);
1325       if (status != miopenStatusSuccess) {
1326         VLOG(2) << "call to miopenCompileFusionPlan (CBA) failed: "
1327                 << ToString(status);
1328 
1329         CachedFusionPlans::MarkFusionPlanUnsupported(hash);
1330       } else {
1331         VLOG(2) << "Fusion Plan compile succeeded (CBA)";
1332         fusion_plan_compiled_ = true;
1333       }
1334     } else {
1335       // Fusion plan was already compiled; check whether compilation failed.
1336       fusion_plan_compiled_ = !CachedFusionPlans::IsUnsupportedFusionPlan(hash);
1337     }
1338   }
1339 
1340   miopenStatus_t SetConvolutionArgs(const void* filter_data) {
1341     float alpha = 1.0;
1342     float beta = 0.0;
1343     return ScopedFusionPlanBase::SetConvolutionArgs(k_conv_op_idx, &alpha,
1344                                                     &beta, filter_data);
1345   }
1346 
1347   miopenStatus_t SetBiasArgs(const void* bias_data) {
1348     float alpha = 1.0;
1349     float beta = 0.0;
1350     return ScopedFusionPlanBase::SetBiasArgs(k_bias_op_idx, &alpha, &beta,
1351                                              bias_data);
1352   }
1353 
1354   miopenStatus_t SetActivationForwardArgs(
1355       ScopedActivationDescriptor& activation_descriptor) {
1356     float alpha = 1.0;
1357     float beta = 0.0;
1358 
1359     return ScopedFusionPlanBase::SetActivationForwardArgs(
1360         k_actv_op_idx, &alpha, &beta, activation_descriptor.alpha_,
1361         activation_descriptor.beta_, activation_descriptor.gamma_);
1362   }
1363 
1364   uint64_t GetFusionOpHashValue(
1365       miopenHandle_t miopen_handle, miopenTensorDescriptor_t input_descriptor,
1366       miopenTensorDescriptor_t filter_descriptor,
1367       miopenConvolutionDescriptor_t conv_descriptor,
1368       miopenTensorDescriptor_t bias_descriptor,
1369       ScopedActivationDescriptor& activation_descriptor) {
1370     uint64_t hash_value = tensorflow::Hash64("ConvolutionBiasActivation");
1371 
1372     hash_value = tensorflow::Hash64Combine(
1373         hash_value, tensorflow::hash<miopenHandle_t>()(miopen_handle));
1374 
1375     hash_value =
1376         tensorflow::Hash64Combine(hash_value, GetHashValue(input_descriptor));
1377     hash_value =
1378         tensorflow::Hash64Combine(hash_value, GetHashValue(filter_descriptor));
1379     hash_value =
1380         tensorflow::Hash64Combine(hash_value, GetHashValue(conv_descriptor));
1381     hash_value =
1382         tensorflow::Hash64Combine(hash_value, GetHashValue(bias_descriptor));
1383     hash_value = tensorflow::Hash64Combine(
1384         hash_value, activation_descriptor.GetHashValue());
1385     return hash_value;
1386   }
1387 
1388  private:
1389   const int k_conv_op_idx = 0;
1390   const int k_bias_op_idx = 1;
1391   const int k_actv_op_idx = 2;
1392 
1393   SE_DISALLOW_COPY_AND_ASSIGN(ScopedFusionPlanConvolutionBiasActivation);
1394 };
1395 
1396 // class to represent the BatchNorm+Activation (inference) fusion plan
1397 class ScopedFusionPlanBatchNormActivationInference
1398     : public ScopedFusionPlanBase {
1399  public:
1400   ScopedFusionPlanBatchNormActivationInference(
1401       miopenHandle_t miopen_handle, miopenTensorDescriptor_t input_descriptor,
1402       miopenTensorDescriptor_t scale_offset_mean_variance_descriptor,
1403       ScopedActivationDescriptor& activation_descriptor)
1404       : ScopedFusionPlanBase(miopen_handle, miopenVerticalFusion,
1405                              input_descriptor) {
1406     uint64_t hash = GetFusionOpHashValue(miopen_handle, input_descriptor,
1407                                          scale_offset_mean_variance_descriptor,
1408                                          activation_descriptor);
1409 
1410     bool is_compiled = CachedFusionPlans::FindOrCreate(
1411         hash, &fusion_plan_, miopenVerticalFusion, input_descriptor);
1412 
1413     if (!is_compiled) {
1414       miopenFusionOpDescriptor_t batchnorm_op;
1415       auto status = wrap::miopenCreateOpBatchNormInference(
1416           fusion_plan_, &batchnorm_op, miopenBNSpatial,
1417           scale_offset_mean_variance_descriptor);
1418 
1419       if (status != miopenStatusSuccess) {
1420         LOG(FATAL) << "call to miopenCreateOpBatchNormInference failed: "
1421                    << ToString(status);
1422       }
1423 
1424       miopenFusionOpDescriptor_t actv_op;
1425       status = wrap::miopenCreateOpActivationForward(
1426           fusion_plan_, &actv_op,
1427           activation_descriptor.miopen_activation_mode_);
1428       if (status != miopenStatusSuccess) {
1429         LOG(FATAL) << "call to miopenCreateOpActivationForward failed: "
1430                    << ToString(status);
1431       }
1432 
1433       status = wrap::miopenCompileFusionPlan(miopen_handle_, fusion_plan_);
1434       if (status != miopenStatusSuccess) {
1435         VLOG(2) << "call to miopenCompileFusionPlan (BnA inference) failed: "
1436                 << ToString(status);
1437 
1438         CachedFusionPlans::MarkFusionPlanUnsupported(hash);
1439       } else {
1440         VLOG(2) << "Fusion Plan compile succeeded (BnA inference)";
1441         fusion_plan_compiled_ = true;
1442       }
1443     } else {
1444       // Fusion plan was already compiled; check whether compilation failed.
1445       fusion_plan_compiled_ = !CachedFusionPlans::IsUnsupportedFusionPlan(hash);
1446     }
1447   }
1448 
1449   miopenStatus_t SetBatchNormInferenceArgs(const void* scale,
1450                                            const void* offset, const void* mean,
1451                                            const void* variance,
1452                                            double epsilon) {
1453     float alpha = 1.0;
1454     float beta = 0.0;
1455     return ScopedFusionPlanBase::SetBatchNormInferenceArgs(
1456         k_batchnorm_op_idx, &alpha, &beta, scale, offset, mean, variance,
1457         epsilon);
1458   }
1459 
1460   miopenStatus_t SetActivationForwardArgs(
1461       ScopedActivationDescriptor& activation_descriptor) {
1462     float alpha = 1.0;
1463     float beta = 0.0;
1464 
1465     return ScopedFusionPlanBase::SetActivationForwardArgs(
1466         k_actv_op_idx, &alpha, &beta, activation_descriptor.alpha_,
1467         activation_descriptor.beta_, activation_descriptor.gamma_);
1468   }
1469 
1470   uint64_t GetFusionOpHashValue(
1471       miopenHandle_t miopen_handle, miopenTensorDescriptor_t input_descriptor,
1472       miopenTensorDescriptor_t scale_offset_mean_variance_descriptor,
1473       ScopedActivationDescriptor& activation_descriptor) {
1474     uint64_t hash_value = tensorflow::Hash64("BatchNormActivationInference");
1475 
1476     hash_value = tensorflow::Hash64Combine(
1477         hash_value, tensorflow::hash<miopenHandle_t>()(miopen_handle));
1478 
1479     hash_value =
1480         tensorflow::Hash64Combine(hash_value, GetHashValue(input_descriptor));
1481 
1482     hash_value = tensorflow::Hash64Combine(
1483         hash_value, GetHashValue(scale_offset_mean_variance_descriptor));
1484 
1485     hash_value = tensorflow::Hash64Combine(
1486         hash_value, activation_descriptor.GetHashValue());
1487     return hash_value;
1488   }
1489 
1490  private:
1491   const int k_batchnorm_op_idx = 0;
1492   const int k_actv_op_idx = 1;
1493 
1494   SE_DISALLOW_COPY_AND_ASSIGN(ScopedFusionPlanBatchNormActivationInference);
1495 };
1496 
1497 // class to represent the BatchNorm+Activation (training-forward) fusion plan
1498 class ScopedFusionPlanBatchNormActivationForward : public ScopedFusionPlanBase {
1499  public:
1500   ScopedFusionPlanBatchNormActivationForward(
1501       miopenHandle_t miopen_handle, miopenTensorDescriptor_t input_descriptor,
1502       miopenTensorDescriptor_t scale_offset_mean_variance_descriptor,
1503       ScopedActivationDescriptor& activation_descriptor)
1504       : ScopedFusionPlanBase(miopen_handle, miopenVerticalFusion,
1505                              input_descriptor) {
1506     uint64_t hash = GetFusionOpHashValue(miopen_handle, input_descriptor,
1507                                          scale_offset_mean_variance_descriptor,
1508                                          activation_descriptor);
1509 
1510     bool is_compiled = CachedFusionPlans::FindOrCreate(
1511         hash, &fusion_plan_, miopenVerticalFusion, input_descriptor);
1512 
1513     if (!is_compiled) {
1514       miopenFusionOpDescriptor_t batchnorm_op;
1515       auto status = wrap::miopenCreateOpBatchNormForward(
1516           fusion_plan_, &batchnorm_op, miopenBNSpatial,
1517           true /* runningMeanVariance */);
1518 
1519       if (status != miopenStatusSuccess) {
1520         LOG(FATAL) << "call to miopenCreateOpBatchNormForward failed: "
1521                    << ToString(status);
1522       }
1523 
1524       miopenFusionOpDescriptor_t actv_op;
1525       status = wrap::miopenCreateOpActivationForward(
1526           fusion_plan_, &actv_op,
1527           activation_descriptor.miopen_activation_mode_);
1528       if (status != miopenStatusSuccess) {
1529         LOG(FATAL) << "call to miopenCreateOpActivationForward failed: "
1530                    << ToString(status);
1531       }
1532 
1533       status = wrap::miopenCompileFusionPlan(miopen_handle_, fusion_plan_);
1534       if (status != miopenStatusSuccess) {
1535         VLOG(2) << "call to miopenCompileFusionPlan (BnA forward) failed: "
1536                 << ToString(status);
1537 
1538         CachedFusionPlans::MarkFusionPlanUnsupported(hash);
1539       } else {
1540         VLOG(2) << "Fusion Plan compile succeeded (BnA forward)";
1541         fusion_plan_compiled_ = true;
1542       }
1543     } else {
1544       // Fusion plan was already compiled; check whether compilation failed.
1545       fusion_plan_compiled_ = !CachedFusionPlans::IsUnsupportedFusionPlan(hash);
1546     }
1547   }
1548 
1549   miopenStatus_t SetBatchNormForwardArgs(const void* scale, const void* offset,
1550                                          void* batch_mean, void* batch_var,
1551                                          void* saved_mean, void* saved_var,
1552                                          double epsilon) {
1553     float alpha = 1.0;
1554     float beta = 0.0;
1555     return ScopedFusionPlanBase::SetBatchNormForwardArgs(
1556         k_batchnorm_op_idx, &alpha, &beta, scale, offset, batch_mean, batch_var,
1557         saved_mean, saved_var, epsilon, /*exponential_average_factor=*/1.0);
1558   }
1559 
1560   miopenStatus_t SetActivationForwardArgs(
1561       ScopedActivationDescriptor& activation_descriptor) {
1562     float alpha = 1.0;
1563     float beta = 0.0;
1564 
1565     return ScopedFusionPlanBase::SetActivationForwardArgs(
1566         k_actv_op_idx, &alpha, &beta, activation_descriptor.alpha_,
1567         activation_descriptor.beta_, activation_descriptor.gamma_);
1568   }
1569 
1570   uint64_t GetFusionOpHashValue(
1571       miopenHandle_t miopen_handle, miopenTensorDescriptor_t input_descriptor,
1572       miopenTensorDescriptor_t scale_offset_mean_variance_descriptor,
1573       ScopedActivationDescriptor& activation_descriptor) {
1574     uint64_t hash_value = tensorflow::Hash64("BatchNormActivationForward");
1575 
1576     hash_value = tensorflow::Hash64Combine(
1577         hash_value, tensorflow::hash<miopenHandle_t>()(miopen_handle));
1578 
1579     hash_value =
1580         tensorflow::Hash64Combine(hash_value, GetHashValue(input_descriptor));
1581 
1582     hash_value = tensorflow::Hash64Combine(
1583         hash_value, GetHashValue(scale_offset_mean_variance_descriptor));
1584 
1585     hash_value = tensorflow::Hash64Combine(
1586         hash_value, activation_descriptor.GetHashValue());
1587     return hash_value;
1588   }
1589 
1590  private:
1591   const int k_batchnorm_op_idx = 0;
1592   const int k_actv_op_idx = 1;
1593 
1594   SE_DISALLOW_COPY_AND_ASSIGN(ScopedFusionPlanBatchNormActivationForward);
1595 };
1596 
1597 // class to represent the BatchNorm+Activation (training-backward) fusion plan
1598 class ScopedFusionPlanBatchNormActivationBackward
1599     : public ScopedFusionPlanBase {
1600  public:
1601   ScopedFusionPlanBatchNormActivationBackward(
1602       miopenHandle_t miopen_handle, miopenTensorDescriptor_t input_descriptor,
1603       miopenTensorDescriptor_t scale_offset_mean_variance_descriptor,
1604       ScopedActivationDescriptor& activation_descriptor)
1605       : ScopedFusionPlanBase(miopen_handle, miopenVerticalFusion,
1606                              input_descriptor) {
1607     uint64_t hash = GetFusionOpHashValue(miopen_handle, input_descriptor,
1608                                          scale_offset_mean_variance_descriptor,
1609                                          activation_descriptor);
1610 
1611     bool is_compiled = CachedFusionPlans::FindOrCreate(
1612         hash, &fusion_plan_, miopenVerticalFusion, input_descriptor);
1613 
1614     if (!is_compiled) {
1615       miopenFusionOpDescriptor_t batchnorm_op;
1616       auto status = wrap::miopenCreateOpBatchNormBackward(
1617           fusion_plan_, &batchnorm_op, miopenBNSpatial);
1618 
1619       if (status != miopenStatusSuccess) {
1620         LOG(FATAL) << "call to miopenCreateOpBatchNormBackward failed: "
1621                    << ToString(status);
1622       }
1623 
1624       miopenFusionOpDescriptor_t actv_op;
1625       status = wrap::miopenCreateOpActivationBackward(
1626           fusion_plan_, &actv_op,
1627           activation_descriptor.miopen_activation_mode_);
1628       if (status != miopenStatusSuccess) {
1629         LOG(FATAL) << "call to miopenCreateOpActivationBackward failed: "
1630                    << ToString(status);
1631       }
1632 
1633       status = wrap::miopenCompileFusionPlan(miopen_handle_, fusion_plan_);
1634       if (status != miopenStatusSuccess) {
1635         VLOG(2) << "call to miopenCompileFusionPlan (BnA backward) failed: "
1636                 << ToString(status);
1637 
1638         CachedFusionPlans::MarkFusionPlanUnsupported(hash);
1639       } else {
1640         VLOG(2) << "Fusion Plan compile succeeded (BnA backward)";
1641         fusion_plan_compiled_ = true;
1642       }
1643     } else {
1644       // Fusion plan was already compiled; check whether compilation failed.
1645       fusion_plan_compiled_ = !CachedFusionPlans::IsUnsupportedFusionPlan(hash);
1646     }
1647   }
1648 
1649   miopenStatus_t SetBatchNormBackwardArgs(const void* x, const void* scale,
1650                                           const void* offset,
1651                                           const void* saved_mean,
1652                                           const void* saved_var,
1653                                           void* scale_grad, void* offset_grad) {
1654     float alpha = 1.0;
1655     float beta = 0.0;
1656 
1657     return ScopedFusionPlanBase::SetBatchNormBackwardArgs(
1658         k_batchnorm_op_idx, &alpha, &beta, x, scale, offset, scale_grad,
1659         offset_grad, saved_mean, saved_var);
1660   }
1661 
1662   miopenStatus_t SetActivationBackwardArgs(
1663       ScopedActivationDescriptor& activation_descriptor, const void* y) {
1664     float alpha = 1.0;
1665     float beta = 0.0;
1666 
1667     return ScopedFusionPlanBase::SetActivationBackwardArgs(
1668         k_actv_op_idx, &alpha, &beta, y, activation_descriptor.alpha_,
1669         activation_descriptor.beta_, activation_descriptor.gamma_);
1670   }
1671 
1672   uint64_t GetFusionOpHashValue(
1673       miopenHandle_t miopen_handle, miopenTensorDescriptor_t input_descriptor,
1674       miopenTensorDescriptor_t scale_offset_mean_variance_descriptor,
1675       ScopedActivationDescriptor& activation_descriptor) {
1676     uint64_t hash_value = tensorflow::Hash64("BatchNormActivationBackward");
1677 
1678     hash_value = tensorflow::Hash64Combine(
1679         hash_value, tensorflow::hash<miopenHandle_t>()(miopen_handle));
1680 
1681     hash_value =
1682         tensorflow::Hash64Combine(hash_value, GetHashValue(input_descriptor));
1683 
1684     hash_value = tensorflow::Hash64Combine(
1685         hash_value, GetHashValue(scale_offset_mean_variance_descriptor));
1686 
1687     hash_value = tensorflow::Hash64Combine(
1688         hash_value, activation_descriptor.GetHashValue());
1689     return hash_value;
1690   }
1691 
1692  private:
1693   const int k_batchnorm_op_idx = 0;
1694   const int k_actv_op_idx = 1;
1695 
1696   SE_DISALLOW_COPY_AND_ASSIGN(ScopedFusionPlanBatchNormActivationBackward);
1697 };
1698 
1699 namespace {
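// Helpers that map StreamExecutor dnn:: enums to their MIOpen equivalents.
// Values with no MIOpen counterpart are treated as fatal errors.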
1700 miopenDataType_t ToMIOpenDataType(
1701     dnn::DataType data_type,
1702     dnn::DataLayout data_layout = dnn::DataLayout::kBatchDepthYX) {
1703   switch (data_type) {
1704     case dnn::DataType::kFloat:
1705       return miopenFloat;
1706     case dnn::DataType::kHalf:
1707       return miopenHalf;
1708     case dnn::DataType::kDouble:
1709     default:
1710       LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
1711   }
1712 }
1713 
1714 miopenRNNInputMode_t ToMIOpenRnnInputMode(dnn::RnnInputMode input_mode) {
1715   switch (input_mode) {
1716     case dnn::RnnInputMode::kRnnLinearSkip:
1717       return miopenRNNlinear;
1718     case dnn::RnnInputMode::kRnnSkipInput:
1719       return miopenRNNskip;
1720     default:
1721       LOG(FATAL) << "Invalid RNN input mode: " << static_cast<int>(input_mode);
1722   }
1723 }
1724 
1725 miopenRNNDirectionMode_t ToMIOpenRnnDirectionMode(
1726     dnn::RnnDirectionMode direction_mode) {
1727   switch (direction_mode) {
1728     case dnn::RnnDirectionMode::kRnnUnidirectional:
1729       return miopenRNNunidirection;
1730     case dnn::RnnDirectionMode::kRnnBidirectional:
1731       return miopenRNNbidirection;
1732     default:
1733       LOG(FATAL) << "Invalid RNN direction mode: "
1734                  << static_cast<int>(direction_mode);
1735   }
1736 }
1737 
1738 miopenRNNMode_t ToMIOpenRnnMode(dnn::RnnMode rnn_mode) {
1739   switch (rnn_mode) {
1740     case dnn::RnnMode::kRnnRelu:
1741       return miopenRNNRELU;
1742     case dnn::RnnMode::kRnnTanh:
1743       return miopenRNNTANH;
1744     case dnn::RnnMode::kRnnLstm:
1745       return miopenLSTM;
1746     case dnn::RnnMode::kRnnGru:
1747       return miopenGRU;
1748     default:
1749       LOG(FATAL) << "Invalid RNN Mode: " << static_cast<int>(rnn_mode);
1750   }
1751 }
1752 
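// MixinBase lets the descriptor wrappers below derive either from a dnn::
// interface class (Base) or, when Base is void, from no base class at all.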
1753 template <typename Base>
1754 class MixinBase : public Base {};
1755 template <>
1756 class MixinBase<void> {};
1757 
1758 }  // namespace
1759 
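// On a MIOpen failure, records the error on the enclosing descriptor object,
// logs it, and returns from the surrounding (void) function.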
1760 #define RETURN_IF_MIOPEN_ERROR(STATUS, ...)                              \
1761   if (!SE_PREDICT_TRUE((STATUS) == miopenStatusSuccess)) {               \
1762     string error_msg = absl::StrCat(ToString(STATUS), " ", __VA_ARGS__); \
1763     SetFailure(port::Status(port::error::UNKNOWN, error_msg));           \
1764     LOG(ERROR) << error_msg;                                             \
1765     return;                                                              \
1766   }
1767 
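// Common base for the MIOpen descriptor wrappers below; records failures
// encountered during construction/destruction and exposes them through
// ok() / Status().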
1768 template <typename Base>
1769 class MIOpenDescriptorCommon : public MixinBase<Base> {
1770  public:
1771   bool ok() const { return status_.ok(); }
1772   port::Status Status() const { return status_; }
1773 
1774  protected:
1775   void SetFailure(const port::Status& status) { status_.Update(status); }
1776   port::Status status_;
1777 };
1778 
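// Wraps the MIOpen tensor descriptor for the packed RNN params (weights and
// biases) blob, along with its total size in bytes.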
1779 class MIOpenRnnParamsDescriptor : public MIOpenDescriptorCommon<void> {
1780  public:
1781   typedef dnn::RnnDescriptor::ParamsRegion ParamsRegion;
1782   typedef dnn::RnnDescriptor::ParamsRegions ParamsRegions;
1783   MIOpenRnnParamsDescriptor(miopenHandle_t miopen_handle,
1784                             const MIOpenRnnDescriptor& rnn_desc);
1785   ~MIOpenRnnParamsDescriptor() {
1786     auto status = wrap::miopenDestroyTensorDescriptor(handle_);
1787     RETURN_IF_MIOPEN_ERROR(status, "Failed to destroy RNN tensor descriptor");
1788   }
1789   miopenTensorDescriptor_t handle() const {
1790     if (!ok()) return nullptr;
1791     return handle_;
1792   }
1793   int64_t params_size_in_bytes() const { return params_size_in_bytes_; }
1794   ParamsRegions params_weights() const {
1795     if (!ok()) return ParamsRegions();
1796     return weights_;
1797   }
1798   ParamsRegions params_biases() const {
1799     if (!ok()) return ParamsRegions();
1800     return biases_;
1801   }
1802 
1803  private:
1804   int GetRegionCountPerLayer() const;
1805   miopenTensorDescriptor_t handle_;
1806   const MIOpenRnnDescriptor* rnn_desc_;
1807   int64_t params_size_in_bytes_;
1808   ParamsRegions weights_;
1809   ParamsRegions biases_;
1810   port::Status status_;
1811   SE_DISALLOW_COPY_AND_ASSIGN(MIOpenRnnParamsDescriptor);
1812 };
1813 
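// Wraps a miopenRNNDescriptor_t together with the hyper-parameters it was
// created with and the matching params (weights/biases) descriptor.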
1814 class MIOpenRnnDescriptor : public MIOpenDescriptorCommon<dnn::RnnDescriptor> {
1815  public:
1816   MIOpenRnnDescriptor(miopenHandle_t miopen_handle, int num_layers,
1817                       int hidden_size, int input_size,
1818                       miopenRNNInputMode_t input_mode,
1819                       miopenRNNDirectionMode_t direction_mode,
1820                       miopenRNNMode_t rnn_mode, miopenDataType_t data_type,
1821                       float dropout, uint64_t seed,
1822                       ScratchAllocator* state_allocator)
1823       : rnn_desc_(nullptr),
1824         num_layers_(num_layers),
1825         hidden_size_(hidden_size),
1826         input_size_(input_size),
1827         input_mode_(input_mode),
1828         direction_mode_(direction_mode),
1829         rnn_mode_(rnn_mode),
1830         data_type_(data_type) {
1831     // Create the RNN handle
1832     auto status = wrap::miopenCreateRNNDescriptor(&rnn_desc_);
1833     RETURN_IF_MIOPEN_ERROR(status, "Unable to create RNN descriptor");
1834     status = wrap::miopenSetRNNDescriptor(
1835         rnn_desc_ /*rnnDesc*/, hidden_size /*hiddenSize*/,
1836         num_layers /*numLayers*/, input_mode /*inputMode*/,
1837         direction_mode /*direction*/, rnn_mode /*mode*/,
1838         miopenRNNwithBias /*biasMode*/, miopenRNNdefault /*algo*/,
1839         data_type /*dataType*/);
1840     RETURN_IF_MIOPEN_ERROR(status, "Unable to update RNN descriptor");
1841     // Create the params handle.
1842     miopen_params_desc_.reset(
1843         new MIOpenRnnParamsDescriptor(miopen_handle, *this));
1844     if (!miopen_params_desc_->ok()) {
1845       SetFailure(miopen_params_desc_->Status());
1846       return;
1847     }
1848   }
1849   ~MIOpenRnnDescriptor() override {
1850     if (rnn_desc_) {
1851       auto status = wrap::miopenDestroyRNNDescriptor(rnn_desc_);
1852       RETURN_IF_MIOPEN_ERROR(status, "Unable to destroy RNN descriptor");
1853     }
1854   }
1855   miopenRNNDescriptor_t handle() const {
1856     if (!ok()) return nullptr;
1857     return rnn_desc_;
1858   }
1859   int num_layers() const { return num_layers_; }
1860   int hidden_size() const { return hidden_size_; }
1861   int input_size() const { return input_size_; }
1862   miopenRNNInputMode_t input_mode() const { return input_mode_; }
1863   miopenRNNDirectionMode_t direction_mode() const { return direction_mode_; }
1864   miopenRNNMode_t rnn_mode() const { return rnn_mode_; }
1865   miopenDataType_t data_type() const { return data_type_; }
1866   int64_t ParamsSizeInBytes() const override {
1867     return miopen_params_desc_->params_size_in_bytes();
1868   }
1869   miopenTensorDescriptor_t params_handle() const {
1870     if (!miopen_params_desc_) return nullptr;
1871     return miopen_params_desc_->handle();
1872   }
1873   ParamsRegions ParamsWeightRegions() const override {
1874     if (!ok()) return ParamsRegions();
1875     return miopen_params_desc_->params_weights();
1876   }
1877   ParamsRegions ParamsBiasRegions() const override {
1878     if (!ok()) return ParamsRegions();
1879     return miopen_params_desc_->params_biases();
1880   }
1881 
1882  private:
1883   miopenRNNDescriptor_t rnn_desc_;
1884   int num_layers_;
1885   int hidden_size_;
1886   int input_size_;
1887   miopenRNNInputMode_t input_mode_;
1888   miopenRNNDirectionMode_t direction_mode_;
1889   miopenRNNMode_t rnn_mode_;
1890   miopenDataType_t data_type_;
1891   port::Status status_;
1892   // no dropout in MIOpen.
1893   // std::unique_ptr<miopenDropoutDescriptor> miopen_dropout_desc_;
1894   std::unique_ptr<MIOpenRnnParamsDescriptor> miopen_params_desc_;
1895   SE_DISALLOW_COPY_AND_ASSIGN(MIOpenRnnDescriptor);
1896 };
1897 
1898 // Returns the number of weight/bias parameter regions per layer for the
1899 // given RNN mode: 2 for plain RNN (ReLU/tanh), 8 for LSTM, 6 for GRU.
1900 int MIOpenRnnParamsDescriptor::GetRegionCountPerLayer() const {
1901   auto rnn_mode = rnn_desc_->rnn_mode();
1902   switch (rnn_mode) {
1903     case miopenRNNRELU:
1904     case miopenRNNTANH:
1905       return 2;
1906     case miopenLSTM:
1907       return 8;
1908     case miopenGRU:
1909       return 6;
1910     default:
1911       LOG(FATAL) << "Invalid RNN Mode: " << static_cast<int>(rnn_mode);
1912   }
1913 }
1914 
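// Per-timestep RNN input/output descriptor: one 2-D {batch_size, data_size}
// tensor descriptor replicated seq_length times.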
1915 class MIOpenRnnSequenceTensorDescriptor
1916     : public MIOpenDescriptorCommon<dnn::RnnSequenceTensorDescriptor> {
1917  public:
1918   MIOpenRnnSequenceTensorDescriptor(int seq_length, int batch_size,
1919                                     int data_size, miopenDataType_t data_type)
1920       : seq_length_(seq_length),
1921         batch_size_(batch_size),
1922         data_size_(data_size),
1923         data_type_(data_type) {
1924     miopenTensorDescriptor_t handle = nullptr;
1925     if (seq_length <= 0) {
1926       string error_msg =
1927           absl::StrCat("sequence length must be positive: ", seq_length);
1928       LOG(ERROR) << error_msg;
1929       SetFailure(port::Status(port::error::UNKNOWN, error_msg));
1930       return;
1931     }
1932     auto status = wrap::miopenCreateTensorDescriptor(&handle);
1933     RETURN_IF_MIOPEN_ERROR(status, "Failed to create tensor descriptor");
1934     std::array<int, 2> dims = {{batch_size, data_size}};
1935     status = wrap::miopenSetTensorDescriptor(
1936         handle /*tensorDesc*/, data_type /*dataType*/, 2 /*nbDims*/,
1937         dims.data() /*dimA*/, nullptr /*strideA*/);
1938     RETURN_IF_MIOPEN_ERROR(status, "Failed to update tensor descriptor");
1939     // Replicate handle across the number of steps.
1940     handles_.assign(seq_length, handle);
1941   }
1942 
1943   ~MIOpenRnnSequenceTensorDescriptor() override {
1944     // Only the first one needs to be destroyed. All others are the same.
1945     auto status = wrap::miopenDestroyTensorDescriptor(handles_[0]);
1946     RETURN_IF_MIOPEN_ERROR(status,
1947                            "Failed to destroy sequence tensor descriptor");
1948   }
1949 
1950   const miopenTensorDescriptor_t* handles() const {
1951     if (!ok()) return nullptr;
1952     CHECK(!handles_.empty()) << "handles cannot be empty";
1953     return handles_.data();
1954   }
1955 
1956   int seq_length() const { return seq_length_; }
1957   int batch_size() const { return batch_size_; }
1958   int data_size() const { return data_size_; }
1959 
1960  private:
1961   int seq_length_;
1962   int batch_size_;
1963   int data_size_;
1964   miopenDataType_t data_type_;
1965   std::vector<miopenTensorDescriptor_t> handles_;
1966   port::Status status_;
1967   SE_DISALLOW_COPY_AND_ASSIGN(MIOpenRnnSequenceTensorDescriptor);
1968 };
1969 
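// RNN hidden/cell state descriptor with 3-D shape
// {num_layers, batch_size, data_size}.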
1970 class MIOpenRnnStateTensorDescriptor
1971     : public MIOpenDescriptorCommon<dnn::RnnStateTensorDescriptor> {
1972  public:
1973   MIOpenRnnStateTensorDescriptor(int num_layers, int batch_size, int data_size,
1974                                  miopenDataType_t data_type)
1975       : handle_(nullptr),
1976         num_layers_(num_layers),
1977         batch_size_(batch_size),
1978         data_size_(data_size),
1979         data_type_(data_type) {
1980     auto status = wrap::miopenCreateTensorDescriptor(&handle_);
1981     RETURN_IF_MIOPEN_ERROR(status, "Failed to create tensor descriptor");
1982     std::array<int, 3> dims = {{num_layers, batch_size, data_size}};
1983     status = wrap::miopenSetTensorDescriptor(
1984         handle_ /*tensorDesc*/, data_type /*dataType*/, 3 /*nbDims*/,
1985         dims.data() /*dimA*/, nullptr /*strideA*/);
1986     RETURN_IF_MIOPEN_ERROR(status, "Failed to update tensor descriptor");
1987   }
1988 
1989   ~MIOpenRnnStateTensorDescriptor() override {
1990     if (handle_) {
1991       auto status = wrap::miopenDestroyTensorDescriptor(handle_);
1992       RETURN_IF_MIOPEN_ERROR(status, "Unable to destroy RNN state tensor");
1993     }
1994   }
1995 
1996   miopenTensorDescriptor_t handle() const {
1997     if (!ok()) return nullptr;
1998     return handle_;
1999   }
2000   int num_layers() const { return num_layers_; }
2001   int batch_size() const { return batch_size_; }
2002   int data_size() const { return data_size_; }
2003 
2004  private:
2005   miopenTensorDescriptor_t handle_;
2006   int num_layers_;
2007   int batch_size_;
2008   int data_size_;
2009   port::Status status_;
2010   miopenDataType_t data_type_;
2011   SE_DISALLOW_COPY_AND_ASSIGN(MIOpenRnnStateTensorDescriptor);
2012 };
2013 
2014 namespace {
2015 
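// Dimensions extracted from the RNN and tensor descriptors, used to validate
// the shapes passed to the RNN forward/backward calls below.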
2016 struct RnnModelDims {
2017   int num_layers = 0;
2018   int batch_size = 0;
2019   int seq_length = 0;
2020   int hidden_size = 0;
2021   int input_size = 0;
2022   int dir_count = 0;
2023 };
2024 
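// Fills in model_dims from the descriptors and checks that the hidden/cell
// state and output descriptors are mutually consistent with those dimensions.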
2025 template <class T>
2026 bool ExtractAndCheckRnnForward(
2027     const MIOpenRnnDescriptor& rnn_desc,
2028     const MIOpenRnnSequenceTensorDescriptor& input_desc,
2029     const DeviceMemory<T>& input_data,
2030     const MIOpenRnnStateTensorDescriptor& input_h_desc,
2031     const DeviceMemory<T>& input_h_data,
2032     const MIOpenRnnStateTensorDescriptor& input_c_desc,
2033     const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
2034     const MIOpenRnnSequenceTensorDescriptor& output_desc,
2035     const DeviceMemory<T>& output_data,
2036     const MIOpenRnnStateTensorDescriptor& output_h_desc,
2037     const DeviceMemory<T>& output_h_data,
2038     const MIOpenRnnStateTensorDescriptor& output_c_desc,
2039     const DeviceMemory<T>& output_c_data, RnnModelDims* model_dims) {
2040   // extract model parameters
2041   model_dims->num_layers = rnn_desc.num_layers();
2042   model_dims->batch_size = input_desc.batch_size();
2043   model_dims->seq_length = input_desc.seq_length();
2044   model_dims->hidden_size = rnn_desc.hidden_size();
2045   model_dims->input_size = input_desc.data_size();
2046   model_dims->dir_count =
2047       (rnn_desc.direction_mode() == miopenRNNbidirection) ? 2 : 1;
2048 
2049   // check parameters
2050   if (!(input_h_desc.num_layers() ==
2051             model_dims->num_layers * model_dims->dir_count &&
2052         input_h_desc.batch_size() == model_dims->batch_size &&
2053         input_h_desc.data_size() == model_dims->hidden_size)) {
2054     LOG(ERROR) << "Invalid input_h shape";
2055     return false;
2056   }
2057   if (!(input_h_desc.num_layers() == input_c_desc.num_layers() &&
2058         input_h_desc.batch_size() == input_c_desc.batch_size() &&
2059         input_h_desc.data_size() == input_c_desc.data_size())) {
2060     LOG(ERROR) << "Invalid input_c shape";
2061     return false;
2062   }
2063   if (!(output_desc.seq_length() == model_dims->seq_length &&
2064         output_desc.batch_size() == model_dims->batch_size &&
2065         output_desc.data_size() ==
2066             model_dims->hidden_size * model_dims->dir_count)) {
2067     LOG(ERROR) << "Invalid output shape";
2068     return false;
2069   }
2070   if (!(input_h_desc.num_layers() == output_h_desc.num_layers() &&
2071         input_h_desc.batch_size() == output_h_desc.batch_size() &&
2072         input_h_desc.data_size() == output_h_desc.data_size())) {
2073     LOG(ERROR) << "Invalid output_h shape";
2074     return false;
2075   }
2076   if (!(input_h_desc.num_layers() == output_c_desc.num_layers() &&
2077         input_h_desc.batch_size() == output_c_desc.batch_size() &&
2078         input_h_desc.data_size() == output_c_desc.data_size())) {
2079     LOG(ERROR) << "Invalid output_c shape";
2080     return false;
2081   }
2082 
2083   return true;
2084 }
2085 
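// Verifies that the params size MIOpen reports for this RNN matches the size
// recorded in the RNN descriptor.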
2086 bool CheckRNNParameterSize(
2087     miopenHandle_t miopen_handle, const MIOpenRnnDescriptor& rnn_desc,
2088     const MIOpenRnnSequenceTensorDescriptor& input_desc) {
2089   size_t params_size_in_bytes = 0;
2090   auto status = wrap::miopenGetRNNParamsSize(
2091       miopen_handle /*handle*/, rnn_desc.handle() /*rnnDesc*/,
2092       input_desc.handles()[0] /*xDesc*/, &params_size_in_bytes /*sizeInBytes*/,
2093       rnn_desc.data_type() /*dataType*/);
2094   if (status != miopenStatusSuccess) {
2095     LOG(ERROR) << "Unable to check RNN param size: " << ToString(status);
2096     return false;
2097   }
2098   return static_cast<int64_t>(params_size_in_bytes) ==
2099          rnn_desc.ParamsSizeInBytes();
2100 }
2101 
2102 bool CreateRnnWorkspace(Stream* stream, miopenHandle_t miopen_handle,
2103                         const MIOpenRnnDescriptor& rnn_desc,
2104                         const MIOpenRnnSequenceTensorDescriptor& input_desc,
2105                         ScratchAllocator* workspace_allocator,
2106                         DeviceMemory<uint8>* workspace) {
2107   // Query the workspace size.
2108   size_t workspace_size_in_bytes = 0;
2109   auto status = wrap::miopenGetRNNWorkspaceSize(
2110       miopen_handle /*handle*/, rnn_desc.handle() /*rnnDesc*/,
2111       input_desc.seq_length() /*seqLength*/, input_desc.handles() /*xDesc*/,
2112       &workspace_size_in_bytes /*sizeInBytes*/);
2113   if (status != miopenStatusSuccess) {
2114     LOG(ERROR) << "Unable to query workspace size: " << ToString(status);
2115     return false;
2116   }
2117   // Allocate the workspace.
2118   if (workspace_size_in_bytes > 0) {
2119     auto allocated =
2120         workspace_allocator->AllocateBytes(workspace_size_in_bytes);
2121     if (!allocated.ok() || (*workspace = allocated.ValueOrDie()) == nullptr) {
2122       LOG(ERROR) << "Failed to allocate RNN workspace";
2123 
2124       return false;
2125     }
2126     stream->ThenMemZero(workspace, workspace_size_in_bytes);
2127   } else {
2128     *workspace = DeviceMemory<uint8>();
2129   }
2130   return true;
2131 }
2132 
2133 }  // namespace
2134 
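// Validates the RNN shapes, sets up the workspace and (for training) the
// reserve space, then dispatches to miopenRNNForwardInference or
// miopenRNNForwardTraining.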
2135 template <class T>
2136 bool MIOpenSupport::DoRnnForwardImpl(
2137     Stream* stream, const MIOpenRnnDescriptor& rnn_desc,
2138     const MIOpenRnnSequenceTensorDescriptor& input_desc,
2139     const DeviceMemory<T>& input_data,
2140     const MIOpenRnnStateTensorDescriptor& input_h_desc,
2141     const DeviceMemory<T>& input_h_data,
2142     const MIOpenRnnStateTensorDescriptor& input_c_desc,
2143     const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
2144     const MIOpenRnnSequenceTensorDescriptor& output_desc,
2145     DeviceMemory<T>* output_data,
2146     const MIOpenRnnStateTensorDescriptor& output_h_desc,
2147     DeviceMemory<T>* output_h_data,
2148     const MIOpenRnnStateTensorDescriptor& output_c_desc,
2149     DeviceMemory<T>* output_c_data, bool is_training,
2150     ScratchAllocator* reserve_space_allocator,
2151     ScratchAllocator* workspace_allocator) {
2152   // extract model parameters
2153   RnnModelDims model_dims;
2154   bool res = ExtractAndCheckRnnForward(
2155       rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
2156       input_c_desc, input_c_data, params, output_desc, *output_data,
2157       output_h_desc, *output_h_data, output_c_desc, *output_c_data,
2158       &model_dims);
2159   if (!res) {
2160     LOG(ERROR) << "Invalid parameters for RNN Model";
2161     return false;
2162   }
2163 
2164   auto miopen = miopen_->GetHandle(parent_, stream);
2165 
2166   // check params size
2167 
2168   if (!CheckRNNParameterSize(miopen.handle(), rnn_desc, input_desc)) {
2169     LOG(ERROR) << "Invalid parameters";
2170     return false;
2171   }
2172 
2173   // create the workspace
2174   DeviceMemory<uint8> workspace;
2175   if (!CreateRnnWorkspace(stream, miopen.handle(), rnn_desc, input_desc,
2176                           workspace_allocator, &workspace)) {
2177     LOG(ERROR) << "Unable to create rnn workspace";
2178 
2179     return false;
2180   }
2181 
2182   // query the reserve space size
2183   // allocate the reserve space
2184   DeviceMemory<uint8> reserve_space;
2185   if (is_training) {
2186     size_t reserve_space_size_in_bytes = 0;
2187     auto status = wrap::miopenGetRNNTrainingReserveSize(
2188         miopen.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/,
2189         model_dims.seq_length /*seqLength*/, input_desc.handles() /*xDesc*/,
2190         &reserve_space_size_in_bytes /*sizeInBytes*/);
2191     if (status != miopenStatusSuccess) {
2192       LOG(ERROR) << "Unable to query reserve space size: " << ToString(status);
2193       return false;
2194     }
2195 
2196     if (reserve_space_size_in_bytes > 0) {
2197       auto allocated =
2198           reserve_space_allocator->AllocateBytes(reserve_space_size_in_bytes);
2199       if (!allocated.ok() ||
2200           (reserve_space = allocated.ValueOrDie()) == nullptr) {
2201         LOG(ERROR) << "Failed to allocate RNN reserve space";
2202         return false;
2203       }
2204       stream->ThenMemZero(&reserve_space, reserve_space_size_in_bytes);
2205     }
2206   }
2207 
2208   // make the forward call
2209   if (!is_training) {
2210     auto status = wrap::miopenRNNForwardInference(
2211         miopen.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/,
2212         model_dims.seq_length /*seqLength*/, input_desc.handles() /*xDesc*/,
2213         input_data.opaque() /*x*/, input_h_desc.handle() /*hxDesc*/,
2214         input_h_data.opaque() /*hx*/, input_c_desc.handle() /*cxDesc*/,
2215         input_c_data.opaque() /*cx*/, rnn_desc.params_handle() /*wDesc*/,
2216         params.opaque() /*w*/, output_desc.handles() /*yDesc*/,
2217         output_data->opaque() /*y*/, output_h_desc.handle() /*hyDesc*/,
2218         output_h_data->opaque() /*hy*/, output_c_desc.handle() /*cyDesc*/,
2219         output_c_data->opaque() /*cy*/, workspace.opaque() /*workspace*/,
2220         workspace.size() /*workSpaceSizeInBytes*/);
2221 
2222     if (status != miopenStatusSuccess) {
2223       LOG(ERROR) << "Failed to call miopenRNNForwardInference: "
2224                  << ToString(status);
2225       return false;
2226     }
2227   } else {
2228     auto status = wrap::miopenRNNForwardTraining(
2229         miopen.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/,
2230         model_dims.seq_length /*seqLength*/, input_desc.handles() /*xDesc*/,
2231         input_data.opaque() /*x*/, input_h_desc.handle() /*hxDesc*/,
2232         input_h_data.opaque() /*hx*/, input_c_desc.handle() /*cxDesc*/,
2233         input_c_data.opaque() /*cx*/, rnn_desc.params_handle() /*wDesc*/,
2234         params.opaque() /*w*/, output_desc.handles() /*yDesc*/,
2235         output_data->opaque() /*y*/, output_h_desc.handle() /*hyDesc*/,
2236         output_h_data->opaque() /*hy*/, output_c_desc.handle() /*cyDesc*/,
2237         output_c_data->opaque() /*cy*/, workspace.opaque() /*workspace*/,
2238         workspace.size() /*workSpaceSizeInBytes*/,
2239         reserve_space.opaque() /*reserveSpace*/,
2240         reserve_space.size() /*reserveSpaceSizeInBytes*/);
2241     if (status != miopenStatusSuccess) {
2242       LOG(ERROR) << "Failed to call miopenRNNForwardTraining: "
2243                  << ToString(status);
2244       return false;
2245     }
2246   }
2247   return true;
2248 }
2249 
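// Validates the RNN shapes, sets up the workspace, zero-initializes the input
// gradient buffers (workaround for missing initialization in MIOpen), then
// calls miopenRNNBackwardData and, if requested, miopenRNNBackwardWeights.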
2250 template <class T>
2251 bool MIOpenSupport::DoRnnBackwardImpl(
2252     Stream* stream, const MIOpenRnnDescriptor& rnn_desc,
2253     const MIOpenRnnSequenceTensorDescriptor& input_desc,
2254     const DeviceMemory<T>& input_data,
2255     const MIOpenRnnStateTensorDescriptor& input_h_desc,
2256     const DeviceMemory<T>& input_h_data,
2257     const MIOpenRnnStateTensorDescriptor& input_c_desc,
2258     const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
2259     const MIOpenRnnSequenceTensorDescriptor& output_desc,
2260     const DeviceMemory<T>& output_data,
2261     const MIOpenRnnStateTensorDescriptor& output_h_desc,
2262     const DeviceMemory<T>& output_h_data,
2263     const MIOpenRnnStateTensorDescriptor& output_c_desc,
2264     const DeviceMemory<T>& output_c_data,
2265     const DeviceMemory<T>& output_backprop_data,
2266     const DeviceMemory<T>& output_h_backprop_data,
2267     const DeviceMemory<T>& output_c_backprop_data,
2268     DeviceMemory<T>* input_backprop_data,
2269     DeviceMemory<T>* input_h_backprop_data,
2270     DeviceMemory<T>* input_c_backprop_data,
2271     DeviceMemory<T>* params_backprop_data,
2272     DeviceMemory<uint8>* reserve_space_data,
2273     ScratchAllocator* workspace_allocator) {
2274   // extract model parameters
2275   RnnModelDims model_dims;
2276   bool res = ExtractAndCheckRnnForward(
2277       rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
2278       input_c_desc, input_c_data, params, output_desc, output_data,
2279       output_h_desc, output_h_data, output_c_desc, output_c_data, &model_dims);
2280   if (!res) {
2281     LOG(ERROR) << "Invalid parameters for RNN Model";
2282     return false;
2283   }
2284 
2285   auto miopen = miopen_->GetHandle(parent_, stream);
2286 
2287   // check params size
2288 
2289   if (!CheckRNNParameterSize(miopen.handle(), rnn_desc, input_desc)) {
2290     LOG(ERROR) << "Invalid parameters";
2291     return false;
2292   }
2293 
2294   // create the workspace
2295   DeviceMemory<uint8> workspace;
2296   if (!CreateRnnWorkspace(stream, miopen.handle(), rnn_desc, input_desc,
2297                           workspace_allocator, &workspace)) {
2298     LOG(ERROR) << "Unable to create rnn workspace";
2299     return false;
2300   }
2301 
2302   // workaround for missing initialization support in MIOpen.
2303   // TODO: remove this when MIOpen is ready.
2304   auto type_size = std::is_same<T, Eigen::half>::value ? 2 : sizeof(T);
2305   auto size_data = input_desc.seq_length() * input_desc.batch_size() *
2306                    input_desc.data_size();
2307   if ((size_data > 0) && (input_backprop_data->opaque() != nullptr))
2308     stream->ThenMemZero(input_backprop_data, size_data * type_size);
2309 
2310   size_data = input_h_desc.num_layers() * input_h_desc.batch_size() *
2311               input_h_desc.data_size();
2312   if ((size_data > 0) && (input_h_backprop_data->opaque() != nullptr))
2313     stream->ThenMemZero(input_h_backprop_data, size_data * type_size);
2314 
2315   size_data = input_c_desc.num_layers() * input_c_desc.batch_size() *
2316               input_c_desc.data_size();
2317   if ((size_data > 0) && (input_c_backprop_data->opaque() != nullptr))
2318     stream->ThenMemZero(input_c_backprop_data, size_data * type_size);
2319 
2320   // make the backward data call
2321   auto status = wrap::miopenRNNBackwardData(
2322       miopen.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/,
2323       model_dims.seq_length /*seqLength*/, output_desc.handles() /*yDesc*/,
2324       output_data.opaque() /*y*/, output_desc.handles() /*dyDesc*/,
2325       output_backprop_data.opaque() /*dy*/, output_h_desc.handle() /*dhyDesc*/,
2326       output_h_backprop_data.opaque() /*dhy*/,
2327       output_c_desc.handle() /*dcyDesc*/,
2328       output_c_backprop_data.opaque() /*dcy*/,
2329       rnn_desc.params_handle() /*wDesc*/, params.opaque() /*w*/,
2330       input_h_desc.handle() /*hxDesc*/, input_h_data.opaque() /*hx*/,
2331       input_c_desc.handle() /*cxDesc*/, input_c_data.opaque() /*cx*/,
2332       input_desc.handles() /*dxDesc*/, input_backprop_data->opaque() /*dx*/,
2333       input_h_desc.handle() /*dhxDesc*/,
2334       input_h_backprop_data->opaque() /*dhx*/,
2335       input_c_desc.handle() /*dcxDesc*/,
2336       input_c_backprop_data->opaque() /*dcx*/, workspace.opaque() /*workspace*/,
2337       workspace.size() /*workSpaceSizeInBytes*/,
2338       reserve_space_data->opaque() /*reserveSpace*/,
2339       reserve_space_data->size() /*reserveSpaceSizeInBytes*/);
2340   if (status != miopenStatusSuccess) {
2341     LOG(ERROR) << "Failed to call miopenRNNBackwardData: " << ToString(status);
2342     return false;
2343   }
2344 
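  // Note: the backward-weights call below reuses the workspace allocated above
  // and the reserve space produced by the corresponding forward training pass.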
2345   if (params_backprop_data != nullptr) {
2346     // Clear the dw to zeros.
2347     stream->ThenMemZero(params_backprop_data, params_backprop_data->size());
2348     // make the backward weight call
2349     status = wrap::miopenRNNBackwardWeights(
2350         miopen.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/,
2351         model_dims.seq_length /*seqLength*/, input_desc.handles() /*xDesc*/,
2352         input_data.opaque() /*x*/, input_h_desc.handle() /*hxDesc*/,
2353         input_h_data.opaque() /*hx*/, output_desc.handles() /*yDesc*/,
2354         output_data.opaque() /*y*/, rnn_desc.params_handle() /*dwDesc*/,
2355         params_backprop_data->opaque() /*dw*/, workspace.opaque() /*workspace*/,
2356         workspace.size() /*workSpaceSizeInBytes*/,
2357         reserve_space_data->opaque() /*reserveSpace*/,
2358         reserve_space_data->size() /*reserveSpaceSizeInBytes*/);
2359     if (status != miopenStatusSuccess) {
2360       LOG(ERROR) << "Failed to call miopenRNNBackwardWeights: "
2361                  << ToString(status);
2362       return false;
2363     }
2364   }
2365 
2366   return true;
2367 }
2368 
MIOpenRnnParamsDescriptor(miopenHandle_t miopen_handle,const MIOpenRnnDescriptor & rnn_desc)2369 MIOpenRnnParamsDescriptor::MIOpenRnnParamsDescriptor(
2370     miopenHandle_t miopen_handle, const MIOpenRnnDescriptor& rnn_desc)
2371     : handle_(nullptr), rnn_desc_(&rnn_desc), params_size_in_bytes_(0) {
2372   miopenTensorDescriptor_t input_desc = nullptr;
2373   {
2374     // Query the params size.
2375     auto status = wrap::miopenCreateTensorDescriptor(&input_desc);
2376     RETURN_IF_MIOPEN_ERROR(status, "MIOpen fails to create tensor descriptor");
2377     std::array<int, 2> dims = {{1, rnn_desc.input_size()}};
2378     status = wrap::miopenSetTensorDescriptor(
2379         input_desc /*tensorDesc*/, rnn_desc.data_type() /*dataType*/,
2380         2 /*nbDims*/, dims.data() /*dimA*/, nullptr /*strideA*/);
2381     RETURN_IF_MIOPEN_ERROR(status, "MIOpen fails to set tensor descriptor");
2382 
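    // The dummy [1, input_size] descriptor only conveys the per-timestep input
    // width; combined with the RNN descriptor it is sufficient for MIOpen to
    // report the total parameter size, so it can be destroyed right after.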
2383     size_t params_size = 0;
2384     status = wrap::miopenGetRNNParamsSize(
2385         miopen_handle /*handle*/, rnn_desc.handle() /*rnnDesc*/,
2386         input_desc /*xDesc*/, &params_size /*sizeInBytes*/,
2387         rnn_desc.data_type() /*dataType*/);
2388     RETURN_IF_MIOPEN_ERROR(status, "MIOpen fails to get RNN parameter size");
2389     params_size_in_bytes_ = static_cast<int64_t>(params_size);
2390   }
2391 
2392   {
2393     // Create the params descriptor.
2394     auto status = wrap::miopenCreateTensorDescriptor(&handle_);
2395     RETURN_IF_MIOPEN_ERROR(status,
2396                            "MIOpen fails to create RNN params descriptor");
2397     status = wrap::miopenGetRNNParamsDescriptor(miopen_handle,
2398                                                 rnn_desc.handle(), input_desc,
2399                                                 handle_, rnn_desc.data_type());
2400     RETURN_IF_MIOPEN_ERROR(status,
2401                            "MIOpen fails to update RNN filter descriptor");
2402   }
2403   {
2404     // Release the dummy input tensor descriptor.
2405     auto status = wrap::miopenDestroyTensorDescriptor(input_desc);
2406     RETURN_IF_MIOPEN_ERROR(status, "MIOpen fails to destroy tensor descriptor");
2407   }
2408 }
2409 
2410 class MIOpenCTCLossDescriptor {
2411  public:
MIOpenCTCLossDescriptor(miopenDataType_t data_type)2412   explicit MIOpenCTCLossDescriptor(miopenDataType_t data_type) {
2413     auto status = wrap::miopenCreateCTCLossDescriptor(&handle_);
2414     if (status != miopenStatusSuccess) {
2415       LOG(FATAL) << "call to miopenCreateCTCLossDescriptor failed: "
2416                  << ToString(status);
2417     }
2418 
2419     bool apply_softmax_layer = true;
2420     status = wrap::miopenSetCTCLossDescriptor(handle_, data_type, 0,
2421                                               apply_softmax_layer);
2422     if (status != miopenStatusSuccess) {
2423       LOG(FATAL) << "call to miopenSetCTCLossDescriptor failed: "
2424                  << ToString(status);
2425     }
2426   }
2427 
~MIOpenCTCLossDescriptor()2428   ~MIOpenCTCLossDescriptor() {
2429     auto status = wrap::miopenDestroyCTCLossDescriptor(handle_);
2430     if (status != miopenStatusSuccess) {
2431       LOG(FATAL) << "call to miopenDestroyCTCLossDescriptor failed: "
2432                  << ToString(status);
2433     }
2434   }
2435 
handle() const2436   miopenCTCLossDescriptor_t handle() const { return handle_; }
2437 
2438  private:
2439   miopenCTCLossDescriptor_t handle_;  // Owned
2440 
2441   SE_DISALLOW_COPY_AND_ASSIGN(MIOpenCTCLossDescriptor);
2442 };
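// Usage sketch (illustrative only): the descriptor is scoped to a single call
// site and cleaned up by RAII, e.g.
//   {
//     MIOpenCTCLossDescriptor ctc_desc(miopenFloat);
//     // ... pass ctc_desc.handle() to miopenGetCTCLossWorkspaceSize and
//     // miopenCTCLoss ...
//   }  // miopenDestroyCTCLossDescriptor runs in the destructor.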
2443 
DoPrepareForCtcLoss(Stream * stream,dnn::DataType element_type,const dnn::RnnStateTensorDescriptor & probs_desc,const dnn::RnnStateTensorDescriptor & grads_desc,absl::Span<const int> labels_data,absl::Span<const int> labels_lengths_data,absl::Span<const int> input_lengths_data,ScratchAllocator * scratch_allocator,DeviceMemory<uint8> * scratch_memory,int * ctc_loss_algo_id)2444 port::Status MIOpenSupport::DoPrepareForCtcLoss(
2445     Stream* stream, dnn::DataType element_type,
2446     const dnn::RnnStateTensorDescriptor& probs_desc,
2447     const dnn::RnnStateTensorDescriptor& grads_desc,
2448     absl::Span<const int> labels_data,
2449     absl::Span<const int> labels_lengths_data,
2450     absl::Span<const int> input_lengths_data,
2451     ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory,
2452     int* ctc_loss_algo_id) {
2453   auto miopen = miopen_->GetHandle(parent_, stream);
2454 
2455   MIOpenCTCLossDescriptor miopen_ctc_loss_desc(ToMIOpenDataType(element_type));
2456 
2457   // Query the workspace size.
2458   size_t workspace_size_in_bytes = 0;
2459 
2460   const MIOpenRnnStateTensorDescriptor& miopen_probs_desc =
2461       static_cast<const MIOpenRnnStateTensorDescriptor&>(probs_desc);
2462 
2463   const MIOpenRnnStateTensorDescriptor& miopen_grads_desc =
2464       static_cast<const MIOpenRnnStateTensorDescriptor&>(grads_desc);
2465 
2466   auto status = wrap::miopenGetCTCLossWorkspaceSize(
2467       miopen.handle(), miopen_probs_desc.handle(), miopen_grads_desc.handle(),
2468       labels_data.data(), labels_lengths_data.data(), input_lengths_data.data(),
2469       MIOPEN_CTC_LOSS_ALGO_DETERMINISTIC, miopen_ctc_loss_desc.handle(),
2470       &workspace_size_in_bytes);
2471 
2472   if (status != miopenStatusSuccess) {
2473     LOG(FATAL) << "call to miopenDestroyCTCLossDescriptor failed: "
2474                << ToString(status);
2475     return port::InternalError(
2476         "Failed to determine scratch memory size for MIOpen CTC Loss");
2477   }
2478 
2479   *scratch_memory = DeviceMemory<uint8>();
2480 
2481   // Allocate the workspace.
2482   if (workspace_size_in_bytes != 0) {
2483     if (scratch_allocator == nullptr) {
2484       return port::InternalError(
2485           absl::StrCat("An allocator must be specified when scratch memory is "
2486                        "needed"));
2487     }
2488     auto scratch_or = scratch_allocator->AllocateBytes(workspace_size_in_bytes);
2489     if (scratch_or.ok()) {
2490       *scratch_memory = scratch_or.ValueOrDie();
2491     } else {
2492       LOG(ERROR)
2493           << "Failed to allocate scratch memory - "
2494           << scratch_or.status().error_message() << "\n"
2495           << "\tYou can set the env var TF_CUDNN_WORKSPACE_LIMIT_IN_MB to a "
2496              "larger number (e.g. 8192) to increase the max memory limit.\n"
2497           << "\tIncreasing the max memory limit might help resolve this "
2498              "error";
2499       return port::InternalError(absl::StrCat(
2500           "Failed to allocate scratch memory for MIOpen CTC Loss, of size: ",
2501           workspace_size_in_bytes));
2502     }
2503   }
2504 
2505   return port::Status::OK();
2506 }
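// Callers are expected to pair the two entry points: DoPrepareForCtcLoss sizes
// and allocates the scratch buffer, which is then handed to DoCtcLoss below.
// A minimal sketch, with illustrative argument names:
//   DeviceMemory<uint8> scratch;
//   int algo_id = 0;
//   auto prep_status = dnn_support->DoPrepareForCtcLoss(
//       stream, dnn::DataType::kFloat, probs_desc, grads_desc, labels,
//       label_lengths, input_lengths, allocator, &scratch, &algo_id);
//   if (prep_status.ok()) {
//     dnn_support->DoCtcLoss(stream, dnn::DataType::kFloat, probs_desc, probs,
//                            labels, label_lengths, input_lengths, costs,
//                            grads_desc, grads, scratch, algo_id);
//   }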
2507 
DoCtcLossImpl(Stream * stream,const MIOpenRnnStateTensorDescriptor & probs_desc,const DeviceMemoryBase probs_data,absl::Span<const int> labels_data,absl::Span<const int> labels_lengths_data,absl::Span<const int> input_lengths_data,DeviceMemoryBase costs_data,const MIOpenRnnStateTensorDescriptor & grads_desc,DeviceMemoryBase grads_data,const MIOpenCTCLossDescriptor & ctc_loss_desc,DeviceMemory<uint8> scratch_memory,int ctc_loss_algo_id)2508 port::Status MIOpenSupport::DoCtcLossImpl(
2509     Stream* stream, const MIOpenRnnStateTensorDescriptor& probs_desc,
2510     const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
2511     absl::Span<const int> labels_lengths_data,
2512     absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
2513     const MIOpenRnnStateTensorDescriptor& grads_desc,
2514     DeviceMemoryBase grads_data, const MIOpenCTCLossDescriptor& ctc_loss_desc,
2515     DeviceMemory<uint8> scratch_memory, int ctc_loss_algo_id) {
2516   auto miopen = miopen_->GetHandle(parent_, stream);
2517 
2518   int kNumTimestamps = probs_desc.num_layers();
2519   int kBatchSize = probs_desc.batch_size();
2520   int kNumLabels = probs_desc.data_size();
2521   int total_size = kNumLabels * kNumTimestamps * kBatchSize;
2522   (void)total_size;
2523 
2524   auto status = wrap::miopenCTCLoss(
2525       miopen.handle(), probs_desc.handle(), probs_data.opaque(),
2526       labels_data.data(), labels_lengths_data.data(), input_lengths_data.data(),
2527       costs_data.opaque(), grads_desc.handle(), grads_data.opaque(),
2528       MIOPEN_CTC_LOSS_ALGO_DETERMINISTIC, ctc_loss_desc.handle(),
2529       scratch_memory.opaque(), scratch_memory.size());
2530   if (status != miopenStatusSuccess) {
2531     LOG(FATAL) << "call to miopenCTCLoss failed: " << ToString(status);
2532     return port::InternalError("Failure during MIOpen CTC Loss");
2533   }
2534 
2535   return port::Status::OK();
2536 }
2537 
DoCtcLoss(Stream * stream,dnn::DataType element_type,const dnn::RnnStateTensorDescriptor & probs_desc,const DeviceMemoryBase probs_data,absl::Span<const int> labels_data,absl::Span<const int> labels_lengths_data,absl::Span<const int> input_lengths_data,DeviceMemoryBase costs_data,const dnn::RnnStateTensorDescriptor & grads_desc,DeviceMemoryBase grads_data,DeviceMemory<uint8> scratch_memory,int ctc_loss_algo_id)2538 port::Status MIOpenSupport::DoCtcLoss(
2539     Stream* stream, dnn::DataType element_type,
2540     const dnn::RnnStateTensorDescriptor& probs_desc,
2541     const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
2542     absl::Span<const int> labels_lengths_data,
2543     absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
2544     const dnn::RnnStateTensorDescriptor& grads_desc,
2545     DeviceMemoryBase grads_data, DeviceMemory<uint8> scratch_memory,
2546     int ctc_loss_algo_id) {
2547   // Current MIOpen CTC Loss only supports the float data type.
2548   if (element_type != dnn::DataType::kFloat) {
2549     return port::Status(port::error::INVALID_ARGUMENT,
2550                         "MIOpenCTCLossDescriptor is supported only when the "
2551                         "DataType is float");
2552   }
2553 
2554   MIOpenCTCLossDescriptor miopen_ctc_loss_desc(ToMIOpenDataType(element_type));
2555 
2556   const MIOpenRnnStateTensorDescriptor& miopen_probs_desc =
2557       static_cast<const MIOpenRnnStateTensorDescriptor&>(probs_desc);
2558 
2559   const MIOpenRnnStateTensorDescriptor& miopen_grads_desc =
2560       static_cast<const MIOpenRnnStateTensorDescriptor&>(grads_desc);
2561 
2562   return DoCtcLossImpl(stream, miopen_probs_desc, probs_data, labels_data,
2563                        labels_lengths_data, input_lengths_data, costs_data,
2564                        miopen_grads_desc, grads_data, miopen_ctc_loss_desc,
2565                        scratch_memory, ctc_loss_algo_id);
2566 }
2567 
2568 port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
createRnnDescriptor(int num_layers,int hidden_size,int input_size,int cell_size,int batch_size,dnn::RnnInputMode input_mode,dnn::RnnDirectionMode direction_mode,dnn::RnnMode rnn_mode,dnn::DataType data_type,const dnn::AlgorithmConfig & algorithm_config,float dropout,uint64_t seed,ScratchAllocator * state_allocator,bool use_padded_io)2569 MIOpenSupport::createRnnDescriptor(
2570     int num_layers, int hidden_size, int input_size, int cell_size,
2571     int batch_size, dnn::RnnInputMode input_mode,
2572     dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
2573     dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config,
2574     float dropout, uint64_t seed, ScratchAllocator* state_allocator,
2575     bool use_padded_io) {
2576   // ROCM TODO: batch_size is only used by the dynamic persistent RNN
2577   // algorithm, which MIOpen does not support yet.
2578   if (use_padded_io) {
2579     return port::Status(port::error::INVALID_ARGUMENT,
2580                         "ROCm MIOpen only supports packed input output.");
2581   }
2582 
2583   bool use_projection = cell_size != 0 && hidden_size < cell_size;
2584   if (use_projection) {
2585     return port::Status(
2586         port::error::INVALID_ARGUMENT,
2587         "ROCm MIOpen does not support RNN ProjectionLayers yet.");
2588   }
2589 
2590   auto miopen = miopen_->GetHandle(parent_, nullptr);
2591   std::unique_ptr<MIOpenRnnDescriptor> rnn_desc(new MIOpenRnnDescriptor(
2592       miopen.handle(), num_layers, hidden_size, input_size,
2593       ToMIOpenRnnInputMode(input_mode),
2594       ToMIOpenRnnDirectionMode(direction_mode), ToMIOpenRnnMode(rnn_mode),
2595       ToMIOpenDataType(data_type), dropout, seed, state_allocator));
2596   if (!rnn_desc->ok()) {
2597     return rnn_desc->Status();
2598   }
2599   return port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>(
2600       std::move(rnn_desc));
2601 }
2602 
2603 port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
createRnnSequenceTensorDescriptor(int seq_length,int batch_size,int data_size,dnn::DataType data_type)2604 MIOpenSupport::createRnnSequenceTensorDescriptor(int seq_length, int batch_size,
2605                                                  int data_size,
2606                                                  dnn::DataType data_type) {
2607   std::unique_ptr<MIOpenRnnSequenceTensorDescriptor> seq_desc(
2608       new MIOpenRnnSequenceTensorDescriptor(seq_length, batch_size, data_size,
2609                                             ToMIOpenDataType(data_type)));
2610   if (!seq_desc->ok()) {
2611     return seq_desc->Status();
2612   }
2613   return port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>(
2614       std::move(seq_desc));
2615 }
2616 
2617 port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
createRnnStateTensorDescriptor(int num_layer,int batch_size,int data_size,dnn::DataType data_type)2618 MIOpenSupport::createRnnStateTensorDescriptor(int num_layer, int batch_size,
2619                                               int data_size,
2620                                               dnn::DataType data_type) {
2621   std::unique_ptr<MIOpenRnnStateTensorDescriptor> state_desc(
2622       new MIOpenRnnStateTensorDescriptor(num_layer, batch_size, data_size,
2623                                          ToMIOpenDataType(data_type)));
2624   if (!state_desc->ok()) {
2625     return state_desc->Status();
2626   }
2627   return port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>(
2628       std::move(state_desc));
2629 }
2630 
DoRnnForward(Stream * stream,const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<Eigen::half> & input_data,const DeviceMemory<int> & seq_lengths_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<Eigen::half> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<Eigen::half> & input_c_data,const DeviceMemory<Eigen::half> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,DeviceMemory<Eigen::half> * output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,DeviceMemory<Eigen::half> * output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,DeviceMemory<Eigen::half> * output_c_data,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)2631 bool MIOpenSupport::DoRnnForward(
2632     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2633     const dnn::RnnSequenceTensorDescriptor& input_desc,
2634     const DeviceMemory<Eigen::half>& input_data,
2635     const DeviceMemory<int>& seq_lengths_data,
2636     const dnn::RnnStateTensorDescriptor& input_h_desc,
2637     const DeviceMemory<Eigen::half>& input_h_data,
2638     const dnn::RnnStateTensorDescriptor& input_c_desc,
2639     const DeviceMemory<Eigen::half>& input_c_data,
2640     const DeviceMemory<Eigen::half>& params,
2641     const dnn::RnnSequenceTensorDescriptor& output_desc,
2642     DeviceMemory<Eigen::half>* output_data,
2643     const dnn::RnnStateTensorDescriptor& output_h_desc,
2644     DeviceMemory<Eigen::half>* output_h_data,
2645     const dnn::RnnStateTensorDescriptor& output_c_desc,
2646     DeviceMemory<Eigen::half>* output_c_data, bool is_training,
2647     ScratchAllocator* reserve_space_allocator,
2648     ScratchAllocator* workspace_allocator,
2649     dnn::ProfileResult* output_profile_result) {
2650   // ROCM TODO: output_profile_result is ignored for now.
2651 
2652   const MIOpenRnnDescriptor& miopen_rnn_desc =
2653       static_cast<const MIOpenRnnDescriptor&>(rnn_desc);
2654   const MIOpenRnnSequenceTensorDescriptor& miopen_input_desc =
2655       static_cast<const MIOpenRnnSequenceTensorDescriptor&>(input_desc);
2656   const MIOpenRnnStateTensorDescriptor& miopen_input_h_desc =
2657       static_cast<const MIOpenRnnStateTensorDescriptor&>(input_h_desc);
2658   const MIOpenRnnStateTensorDescriptor& miopen_input_c_desc =
2659       static_cast<const MIOpenRnnStateTensorDescriptor&>(input_c_desc);
2660   const MIOpenRnnSequenceTensorDescriptor& miopen_output_desc =
2661       static_cast<const MIOpenRnnSequenceTensorDescriptor&>(output_desc);
2662   const MIOpenRnnStateTensorDescriptor& miopen_output_h_desc =
2663       static_cast<const MIOpenRnnStateTensorDescriptor&>(output_h_desc);
2664   const MIOpenRnnStateTensorDescriptor& miopen_output_c_desc =
2665       static_cast<const MIOpenRnnStateTensorDescriptor&>(output_c_desc);
2666 
2667   return DoRnnForwardImpl<Eigen::half>(
2668       stream, miopen_rnn_desc, miopen_input_desc, input_data,
2669       miopen_input_h_desc, input_h_data, miopen_input_c_desc, input_c_data,
2670       params, miopen_output_desc, output_data, miopen_output_h_desc,
2671       output_h_data, miopen_output_c_desc, output_c_data, is_training,
2672       reserve_space_allocator, workspace_allocator);
2673 }
2674 
DoRnnForward(Stream * stream,const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<float> & input_data,const DeviceMemory<int> & seq_lengths_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<float> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<float> & input_c_data,const DeviceMemory<float> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,DeviceMemory<float> * output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,DeviceMemory<float> * output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,DeviceMemory<float> * output_c_data,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)2675 bool MIOpenSupport::DoRnnForward(
2676     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2677     const dnn::RnnSequenceTensorDescriptor& input_desc,
2678     const DeviceMemory<float>& input_data,
2679     const DeviceMemory<int>& seq_lengths_data,
2680     const dnn::RnnStateTensorDescriptor& input_h_desc,
2681     const DeviceMemory<float>& input_h_data,
2682     const dnn::RnnStateTensorDescriptor& input_c_desc,
2683     const DeviceMemory<float>& input_c_data, const DeviceMemory<float>& params,
2684     const dnn::RnnSequenceTensorDescriptor& output_desc,
2685     DeviceMemory<float>* output_data,
2686     const dnn::RnnStateTensorDescriptor& output_h_desc,
2687     DeviceMemory<float>* output_h_data,
2688     const dnn::RnnStateTensorDescriptor& output_c_desc,
2689     DeviceMemory<float>* output_c_data, bool is_training,
2690     ScratchAllocator* reserve_space_allocator,
2691     ScratchAllocator* workspace_allocator,
2692     dnn::ProfileResult* output_profile_result) {
2693   // ROCM TODO: output_profile_result is ignored for now.
2694 
2695   const MIOpenRnnDescriptor& miopen_rnn_desc =
2696       static_cast<const MIOpenRnnDescriptor&>(rnn_desc);
2697   const MIOpenRnnSequenceTensorDescriptor& miopen_input_desc =
2698       static_cast<const MIOpenRnnSequenceTensorDescriptor&>(input_desc);
2699   const MIOpenRnnStateTensorDescriptor& miopen_input_h_desc =
2700       static_cast<const MIOpenRnnStateTensorDescriptor&>(input_h_desc);
2701   const MIOpenRnnStateTensorDescriptor& miopen_input_c_desc =
2702       static_cast<const MIOpenRnnStateTensorDescriptor&>(input_c_desc);
2703   const MIOpenRnnSequenceTensorDescriptor& miopen_output_desc =
2704       static_cast<const MIOpenRnnSequenceTensorDescriptor&>(output_desc);
2705   const MIOpenRnnStateTensorDescriptor& miopen_output_h_desc =
2706       static_cast<const MIOpenRnnStateTensorDescriptor&>(output_h_desc);
2707   const MIOpenRnnStateTensorDescriptor& miopen_output_c_desc =
2708       static_cast<const MIOpenRnnStateTensorDescriptor&>(output_c_desc);
2709 
2710   return DoRnnForwardImpl<float>(
2711       stream, miopen_rnn_desc, miopen_input_desc, input_data,
2712       miopen_input_h_desc, input_h_data, miopen_input_c_desc, input_c_data,
2713       params, miopen_output_desc, output_data, miopen_output_h_desc,
2714       output_h_data, miopen_output_c_desc, output_c_data, is_training,
2715       reserve_space_allocator, workspace_allocator);
2716 }
2717 
DoRnnForward(Stream * stream,const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<double> & input_data,const DeviceMemory<int> & seq_lengths_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<double> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<double> & input_c_data,const DeviceMemory<double> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,DeviceMemory<double> * output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,DeviceMemory<double> * output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,DeviceMemory<double> * output_c_data,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)2718 bool MIOpenSupport::DoRnnForward(
2719     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2720     const dnn::RnnSequenceTensorDescriptor& input_desc,
2721     const DeviceMemory<double>& input_data,
2722     const DeviceMemory<int>& seq_lengths_data,
2723     const dnn::RnnStateTensorDescriptor& input_h_desc,
2724     const DeviceMemory<double>& input_h_data,
2725     const dnn::RnnStateTensorDescriptor& input_c_desc,
2726     const DeviceMemory<double>& input_c_data,
2727     const DeviceMemory<double>& params,
2728     const dnn::RnnSequenceTensorDescriptor& output_desc,
2729     DeviceMemory<double>* output_data,
2730     const dnn::RnnStateTensorDescriptor& output_h_desc,
2731     DeviceMemory<double>* output_h_data,
2732     const dnn::RnnStateTensorDescriptor& output_c_desc,
2733     DeviceMemory<double>* output_c_data, bool is_training,
2734     ScratchAllocator* reserve_space_allocator,
2735     ScratchAllocator* workspace_allocator,
2736     dnn::ProfileResult* output_profile_result) {
2737   LOG(ERROR) << "miopen does not support double type RNN fwd yet";
2738   return false;
2739 }
2740 
DoRnnBackward(Stream * stream,const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<Eigen::half> & input_data,const DeviceMemory<int> & seq_lengths_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<Eigen::half> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<Eigen::half> & input_c_data,const DeviceMemory<Eigen::half> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,const DeviceMemory<Eigen::half> & output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,const DeviceMemory<Eigen::half> & output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,const DeviceMemory<Eigen::half> & output_c_data,const DeviceMemory<Eigen::half> & output_backprop_data,const DeviceMemory<Eigen::half> & output_h_backprop_data,const DeviceMemory<Eigen::half> & output_c_backprop_data,DeviceMemory<Eigen::half> * input_backprop_data,DeviceMemory<Eigen::half> * input_h_backprop_data,DeviceMemory<Eigen::half> * input_c_backprop_data,DeviceMemory<Eigen::half> * params_backprop_data,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)2741 bool MIOpenSupport::DoRnnBackward(
2742     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2743     const dnn::RnnSequenceTensorDescriptor& input_desc,
2744     const DeviceMemory<Eigen::half>& input_data,
2745     const DeviceMemory<int>& seq_lengths_data,
2746     const dnn::RnnStateTensorDescriptor& input_h_desc,
2747     const DeviceMemory<Eigen::half>& input_h_data,
2748     const dnn::RnnStateTensorDescriptor& input_c_desc,
2749     const DeviceMemory<Eigen::half>& input_c_data,
2750     const DeviceMemory<Eigen::half>& params,
2751     const dnn::RnnSequenceTensorDescriptor& output_desc,
2752     const DeviceMemory<Eigen::half>& output_data,
2753     const dnn::RnnStateTensorDescriptor& output_h_desc,
2754     const DeviceMemory<Eigen::half>& output_h_data,
2755     const dnn::RnnStateTensorDescriptor& output_c_desc,
2756     const DeviceMemory<Eigen::half>& output_c_data,
2757     const DeviceMemory<Eigen::half>& output_backprop_data,
2758     const DeviceMemory<Eigen::half>& output_h_backprop_data,
2759     const DeviceMemory<Eigen::half>& output_c_backprop_data,
2760     DeviceMemory<Eigen::half>* input_backprop_data,
2761     DeviceMemory<Eigen::half>* input_h_backprop_data,
2762     DeviceMemory<Eigen::half>* input_c_backprop_data,
2763     DeviceMemory<Eigen::half>* params_backprop_data,
2764     DeviceMemory<uint8>* reserve_space_data,
2765     ScratchAllocator* workspace_allocator,
2766     dnn::ProfileResult* output_profile_result) {
2767   // ROCM TODO: output_profile_result is ignored for now.
2768 
2769   const MIOpenRnnDescriptor& miopen_rnn_desc =
2770       static_cast<const MIOpenRnnDescriptor&>(rnn_desc);
2771   const MIOpenRnnSequenceTensorDescriptor& miopen_input_desc =
2772       static_cast<const MIOpenRnnSequenceTensorDescriptor&>(input_desc);
2773   const MIOpenRnnStateTensorDescriptor& miopen_input_h_desc =
2774       static_cast<const MIOpenRnnStateTensorDescriptor&>(input_h_desc);
2775   const MIOpenRnnStateTensorDescriptor& miopen_input_c_desc =
2776       static_cast<const MIOpenRnnStateTensorDescriptor&>(input_c_desc);
2777   const MIOpenRnnSequenceTensorDescriptor& miopen_output_desc =
2778       static_cast<const MIOpenRnnSequenceTensorDescriptor&>(output_desc);
2779   const MIOpenRnnStateTensorDescriptor& miopen_output_h_desc =
2780       static_cast<const MIOpenRnnStateTensorDescriptor&>(output_h_desc);
2781   const MIOpenRnnStateTensorDescriptor& miopen_output_c_desc =
2782       static_cast<const MIOpenRnnStateTensorDescriptor&>(output_c_desc);
2783 
2784   return DoRnnBackwardImpl<Eigen::half>(
2785       stream, miopen_rnn_desc, miopen_input_desc, input_data,
2786       miopen_input_h_desc, input_h_data, miopen_input_c_desc, input_c_data,
2787       params, miopen_output_desc, output_data, miopen_output_h_desc,
2788       output_h_data, miopen_output_c_desc, output_c_data, output_backprop_data,
2789       output_h_backprop_data, output_c_backprop_data, input_backprop_data,
2790       input_h_backprop_data, input_c_backprop_data, params_backprop_data,
2791       reserve_space_data, workspace_allocator);
2792 }
2793 
DoRnnBackward(Stream * stream,const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<float> & input_data,const DeviceMemory<int> & seq_lengths_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<float> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<float> & input_c_data,const DeviceMemory<float> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,const DeviceMemory<float> & output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,const DeviceMemory<float> & output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,const DeviceMemory<float> & output_c_data,const DeviceMemory<float> & output_backprop_data,const DeviceMemory<float> & output_h_backprop_data,const DeviceMemory<float> & output_c_backprop_data,DeviceMemory<float> * input_backprop_data,DeviceMemory<float> * input_h_backprop_data,DeviceMemory<float> * input_c_backprop_data,DeviceMemory<float> * params_backprop_data,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)2794 bool MIOpenSupport::DoRnnBackward(
2795     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2796     const dnn::RnnSequenceTensorDescriptor& input_desc,
2797     const DeviceMemory<float>& input_data,
2798     const DeviceMemory<int>& seq_lengths_data,
2799     const dnn::RnnStateTensorDescriptor& input_h_desc,
2800     const DeviceMemory<float>& input_h_data,
2801     const dnn::RnnStateTensorDescriptor& input_c_desc,
2802     const DeviceMemory<float>& input_c_data, const DeviceMemory<float>& params,
2803     const dnn::RnnSequenceTensorDescriptor& output_desc,
2804     const DeviceMemory<float>& output_data,
2805     const dnn::RnnStateTensorDescriptor& output_h_desc,
2806     const DeviceMemory<float>& output_h_data,
2807     const dnn::RnnStateTensorDescriptor& output_c_desc,
2808     const DeviceMemory<float>& output_c_data,
2809     const DeviceMemory<float>& output_backprop_data,
2810     const DeviceMemory<float>& output_h_backprop_data,
2811     const DeviceMemory<float>& output_c_backprop_data,
2812     DeviceMemory<float>* input_backprop_data,
2813     DeviceMemory<float>* input_h_backprop_data,
2814     DeviceMemory<float>* input_c_backprop_data,
2815     DeviceMemory<float>* params_backprop_data,
2816     DeviceMemory<uint8>* reserve_space_data,
2817     ScratchAllocator* workspace_allocator,
2818     dnn::ProfileResult* output_profile_result) {
2819   // ROCM TODO: output_profile_result is ignored for now.
2820 
2821   const MIOpenRnnDescriptor& miopen_rnn_desc =
2822       static_cast<const MIOpenRnnDescriptor&>(rnn_desc);
2823   const MIOpenRnnSequenceTensorDescriptor& miopen_input_desc =
2824       static_cast<const MIOpenRnnSequenceTensorDescriptor&>(input_desc);
2825   const MIOpenRnnStateTensorDescriptor& miopen_input_h_desc =
2826       static_cast<const MIOpenRnnStateTensorDescriptor&>(input_h_desc);
2827   const MIOpenRnnStateTensorDescriptor& miopen_input_c_desc =
2828       static_cast<const MIOpenRnnStateTensorDescriptor&>(input_c_desc);
2829   const MIOpenRnnSequenceTensorDescriptor& miopen_output_desc =
2830       static_cast<const MIOpenRnnSequenceTensorDescriptor&>(output_desc);
2831   const MIOpenRnnStateTensorDescriptor& miopen_output_h_desc =
2832       static_cast<const MIOpenRnnStateTensorDescriptor&>(output_h_desc);
2833   const MIOpenRnnStateTensorDescriptor& miopen_output_c_desc =
2834       static_cast<const MIOpenRnnStateTensorDescriptor&>(output_c_desc);
2835 
2836   return DoRnnBackwardImpl<float>(
2837       stream, miopen_rnn_desc, miopen_input_desc, input_data,
2838       miopen_input_h_desc, input_h_data, miopen_input_c_desc, input_c_data,
2839       params, miopen_output_desc, output_data, miopen_output_h_desc,
2840       output_h_data, miopen_output_c_desc, output_c_data, output_backprop_data,
2841       output_h_backprop_data, output_c_backprop_data, input_backprop_data,
2842       input_h_backprop_data, input_c_backprop_data, params_backprop_data,
2843       reserve_space_data, workspace_allocator);
2844 }
2845 
DoRnnBackward(Stream * stream,const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<double> & input_data,const DeviceMemory<int> & seq_lengths_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<double> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<double> & input_c_data,const DeviceMemory<double> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,const DeviceMemory<double> & output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,const DeviceMemory<double> & output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,const DeviceMemory<double> & output_c_data,const DeviceMemory<double> & output_backprop_data,const DeviceMemory<double> & output_h_backprop_data,const DeviceMemory<double> & output_c_backprop_data,DeviceMemory<double> * input_backprop_data,DeviceMemory<double> * input_h_backprop_data,DeviceMemory<double> * input_c_backprop_data,DeviceMemory<double> * params_backprop_data,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)2846 bool MIOpenSupport::DoRnnBackward(
2847     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2848     const dnn::RnnSequenceTensorDescriptor& input_desc,
2849     const DeviceMemory<double>& input_data,
2850     const DeviceMemory<int>& seq_lengths_data,
2851     const dnn::RnnStateTensorDescriptor& input_h_desc,
2852     const DeviceMemory<double>& input_h_data,
2853     const dnn::RnnStateTensorDescriptor& input_c_desc,
2854     const DeviceMemory<double>& input_c_data,
2855     const DeviceMemory<double>& params,
2856     const dnn::RnnSequenceTensorDescriptor& output_desc,
2857     const DeviceMemory<double>& output_data,
2858     const dnn::RnnStateTensorDescriptor& output_h_desc,
2859     const DeviceMemory<double>& output_h_data,
2860     const dnn::RnnStateTensorDescriptor& output_c_desc,
2861     const DeviceMemory<double>& output_c_data,
2862     const DeviceMemory<double>& output_backprop_data,
2863     const DeviceMemory<double>& output_h_backprop_data,
2864     const DeviceMemory<double>& output_c_backprop_data,
2865     DeviceMemory<double>* input_backprop_data,
2866     DeviceMemory<double>* input_h_backprop_data,
2867     DeviceMemory<double>* input_c_backprop_data,
2868     DeviceMemory<double>* params_backprop_data,
2869     DeviceMemory<uint8>* reserve_space_data,
2870     ScratchAllocator* workspace_allocator,
2871     dnn::ProfileResult* output_profile_result) {
2872   LOG(ERROR) << "miopen does not support half type RNN bwd yet";
2873   return false;
2874 }
2875 
2876 // This is the context required to use the TF scratch allocator:
2877 struct MIOpenAllocatorContext {
MIOpenAllocatorContextstream_executor::gpu::MIOpenAllocatorContext2878   MIOpenAllocatorContext(ScratchAllocator* scratch_allocator, Stream* stream)
2879       : scratch_allocator_(scratch_allocator), stream_(stream) {}
2880 
2881   ScratchAllocator* scratch_allocator_;
2882   Stream* stream_;
2883 };
2884 
MIOpenAllocatorCallback(void * ctx,size_t size_in_bytes)2885 void* MIOpenAllocatorCallback(void* ctx, size_t size_in_bytes) {
2886   auto* mac = static_cast<MIOpenAllocatorContext*>(ctx);
2887   auto allocated = mac->scratch_allocator_->AllocateBytes(size_in_bytes);
2888 
2889   DeviceMemory<uint8> scratch;
2890   if (allocated.ok()) {
2891     scratch = allocated.ValueOrDie();
2892     return scratch.opaque();
2893   } else {
2894     return nullptr;
2895   }
2896 }
2897 
MIOpenDeallocatorCallback(void * ctx,void * mem)2898 void MIOpenDeallocatorCallback(void* ctx, void* mem) {
2899   // No deallocator is needed since the TensorFlow heap automatically
2900   // reclaims the memory.
2901 }
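// These callbacks are meant to be registered on a MIOpen handle together with
// a MIOpenAllocatorContext, so that MIOpen-internal scratch requests are
// served by the TF scratch allocator. A sketch (where exactly this is done is
// up to the caller):
//   MIOpenAllocatorContext ctx(scratch_allocator, stream);
//   wrap::miopenSetAllocator(miopen.handle(), MIOpenAllocatorCallback,
//                            MIOpenDeallocatorCallback, &ctx);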
2902 
DoPrepareForConvolution(dnn::ConvolutionKind kind,dnn::DataType element_type,Stream * stream,const dnn::BatchDescriptor & input_descriptor,DeviceMemoryBase input_data,const dnn::FilterDescriptor & filter_descriptor,DeviceMemoryBase filter_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemoryBase output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::AlgorithmConfig & algorithm_config,ScratchAllocator * scratch_allocator,dnn::AlgorithmDesc * algorithm_desc,DeviceMemory<uint8> * scratch_memory)2903 port::Status MIOpenSupport::DoPrepareForConvolution(
2904     dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
2905     const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
2906     const dnn::FilterDescriptor& filter_descriptor,
2907     DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor,
2908     DeviceMemoryBase output_data,
2909     const dnn::ConvolutionDescriptor& convolution_descriptor,
2910     const dnn::AlgorithmConfig& algorithm_config,
2911     ScratchAllocator* scratch_allocator, dnn::AlgorithmDesc* algorithm_desc,
2912     DeviceMemory<uint8>* scratch_memory) {
2913   std::optional<dnn::AlgorithmDesc> input_algo_desc =
2914       algorithm_config.algorithm();
2915 
2916   assert(input_algo_desc.has_value());
2917 
2918   // An algorithm has been specified.
2919   *algorithm_desc = *input_algo_desc;
2920 
2921   assert(algorithm_config.scratch_size().has_value());
2922 
2923   size_t scratch_memory_size = *(algorithm_config.scratch_size());
2924 
2925   // allocate scratch memory
2926   if (scratch_memory_size != 0) {
2927     if (scratch_allocator == nullptr) {
2928       return port::InternalError(
2929           absl::StrCat("An allocator must be specified when scratch memory is "
2930                        "needed"));
2931     }
2932     auto allocated = scratch_allocator->AllocateBytes(scratch_memory_size);
2933     if (allocated.ok()) {
2934       *scratch_memory = allocated.ValueOrDie();
2935     } else {
2936       LOG(ERROR)
2937           << "Failed to allocate scratch memory - "
2938           << allocated.status().error_message() << "\n"
2939           << "\tYou can set the env var TF_CUDNN_WORKSPACE_LIMIT_IN_MB to a "
2940              "larger number (e.g. 8192) to increase the max memory limit.\n"
2941           << "\tIncreasing the max memory limit might help resolve this "
2942              "error";
2943       return port::InternalError(absl::StrCat(
2944           "Failed to allocate scratch memory of size: ", scratch_memory_size));
2945     }
2946   }
2947 
2948   return port::Status::OK();
2949 }
2950 
2951 class RocmConvRunner : public dnn::ConvRunner {
2952  public:
RocmConvRunner(GpuExecutor * parent,MIOpenAccess * miopen,int64_t algo_id,size_t workspace_size,dnn::ConvolutionKind kind,dnn::DataType input_type,bool use_immediate_mode,BatchDescriptor input_descriptor,BatchDescriptor output_descriptor,FilterDescriptor filter_descriptor,ConvolutionDescriptor conv_descriptor)2953   RocmConvRunner(GpuExecutor* parent, MIOpenAccess* miopen, int64_t algo_id,
2954                  size_t workspace_size, dnn::ConvolutionKind kind,
2955                  dnn::DataType input_type, bool use_immediate_mode,
2956                  BatchDescriptor input_descriptor,
2957                  BatchDescriptor output_descriptor,
2958                  FilterDescriptor filter_descriptor,
2959                  ConvolutionDescriptor conv_descriptor)
2960       : parent_(parent),
2961         miopen_(miopen),
2962         algo_id_(algo_id),
2963         workspace_size_(workspace_size),
2964         kind_(kind),
2965         use_immediate_mode_(use_immediate_mode),
2966         input_desc_{input_descriptor, ToMIOpenDataType(input_type)},
2967         output_desc_{output_descriptor, ToMIOpenDataType(input_type)},
2968         filter_desc_{filter_descriptor, ToMIOpenDataType(input_type)},
2969         conv_desc_{conv_descriptor, ToMIOpenDataType(input_type)} {}
2970 
ToString() const2971   std::string ToString() const override {
2972     return dnn::AlgorithmDesc{algo_id_, false, workspace_size_}.ToString();
2973   }
2974 
GetWorkspaceSize() const2975   size_t GetWorkspaceSize() const override { return workspace_size_; }
2976 
ToAlgorithmDesc() const2977   port::StatusOr<AlgorithmDesc> ToAlgorithmDesc() const override {
2978     return {{algo_id_, false, workspace_size_}};
2979   }
2980 
operator ()(Stream * stream,dnn::ProfileResult * profile_result,DeviceMemoryBase scratch_memory,DeviceMemoryBase input_data,DeviceMemoryBase filter_data,DeviceMemoryBase output_data) const2981   port::Status operator()(Stream* stream, dnn::ProfileResult* profile_result,
2982                           DeviceMemoryBase scratch_memory,
2983                           DeviceMemoryBase input_data,
2984                           DeviceMemoryBase filter_data,
2985                           DeviceMemoryBase output_data) const override {
2986     auto miopen = miopen_->GetHandle(parent_, stream);
2987     // Alpha is the scaling factor for input.
2988     float alpha = 1.0;
2989     // Beta is the scaling factor for output.
2990     float beta = 0.0;
2991 
2992     const bool is_profiling = profile_result != nullptr;
2993 
2994     std::unique_ptr<GpuTimer> timer;
2995     if (is_profiling) {
2996       timer.reset(new GpuTimer(parent_));
2997       if (!timer->Init()) {
2998         return port::Status(port::error::INTERNAL, "Failed to init timer");
2999       }
3000       // The start and stop of the timer should be as close to the MIOpen call
3001       // as possible. Other threads may still issue work onto this stream, so
3002       // multiple profiling measurements may be needed to obtain a stable timing.
3003       if (!timer->Start(AsGpuStream(stream))) {
3004         timer->Destroy();
3005         return port::Status(port::error::INTERNAL, "Failed to start timer");
3006       }
3007     }
3008 
3009     miopenStatus_t status = miopenStatusSuccess;
3010     switch (kind_) {
3011       case dnn::ConvolutionKind::FORWARD: {
3012         if (use_immediate_mode_) {
3013           status = wrap::miopenConvolutionForwardImmediate(
3014               miopen.handle(), filter_desc_.handle(), filter_data.opaque(),
3015               input_desc_.handle(), input_data.opaque(), conv_desc_.handle(),
3016               output_desc_.handle(), output_data.opaque(),
3017               scratch_memory.opaque(), scratch_memory.size(),
3018               static_cast<uint64_t>(algo_id_));
3019         } else {
3020           status = wrap::miopenConvolutionForward(
3021               miopen.handle(), &alpha, input_desc_.handle(),
3022               input_data.opaque(), filter_desc_.handle(), filter_data.opaque(),
3023               conv_desc_.handle(),
3024               static_cast<miopenConvFwdAlgorithm_t>(algo_id_), &beta,
3025               output_desc_.handle(), output_data.opaque(),
3026               scratch_memory.opaque(), scratch_memory.size());
3027         }
3028 
3029         break;
3030       }
3031       case dnn::ConvolutionKind::BACKWARD_DATA: {
3032         if (use_immediate_mode_) {
3033           status = wrap::miopenConvolutionBackwardDataImmediate(
3034               miopen.handle(), output_desc_.handle(), output_data.opaque(),
3035               filter_desc_.handle(), filter_data.opaque(), conv_desc_.handle(),
3036               input_desc_.handle(), input_data.opaque(),
3037               scratch_memory.opaque(), scratch_memory.size(),
3038               static_cast<uint64_t>(algo_id_));
3039         } else {
3040           status = wrap::miopenConvolutionBackwardData(
3041               miopen.handle(), &alpha, output_desc_.handle(),
3042               output_data.opaque(), filter_desc_.handle(), filter_data.opaque(),
3043               conv_desc_.handle(),
3044               static_cast<miopenConvBwdDataAlgorithm_t>(algo_id_), &beta,
3045               input_desc_.handle(), input_data.opaque(),
3046               scratch_memory.opaque(), scratch_memory.size());
3047         }
3048         break;
3049       }
3050       case dnn::ConvolutionKind::BACKWARD_FILTER: {
3051         if (use_immediate_mode_) {
3052           status = wrap::miopenConvolutionBackwardWeightsImmediate(
3053               miopen.handle(), output_desc_.handle(), output_data.opaque(),
3054               input_desc_.handle(), input_data.opaque(), conv_desc_.handle(),
3055               filter_desc_.handle(), filter_data.opaque(),
3056               scratch_memory.opaque(), scratch_memory.size(),
3057               static_cast<uint64_t>(algo_id_));
3058         } else {
3059           status = wrap::miopenConvolutionBackwardWeights(
3060               miopen.handle(), &alpha, output_desc_.handle(),
3061               output_data.opaque(), input_desc_.handle(), input_data.opaque(),
3062               conv_desc_.handle(),
3063               static_cast<miopenConvBwdWeightsAlgorithm_t>(algo_id_), &beta,
3064               filter_desc_.handle(), filter_data.opaque(),
3065               scratch_memory.opaque(), scratch_memory.size());
3066         }
3067         break;
3068       }
3069       default:
3070         return port::InternalError(absl::StrCat("Unexpected convolution kind ",
3071                                                 static_cast<int>(kind_)));
3072     }
3073 
3074     if (is_profiling) {
3075       if (!timer->Stop(AsGpuStream(stream))) {
3076         timer->Destroy();
3077         return port::Status(port::error::INTERNAL, "Failed to stop timer");
3078       }
3079       if (status == miopenStatusSuccess) {
3080         dnn::AlgorithmDesc algotype(algo_id_, false);
3081         profile_result->set_algorithm(algotype);
3082         profile_result->set_elapsed_time_in_ms(timer->GetElapsedMilliseconds());
3083         profile_result->set_scratch_size(scratch_memory.size());
3084       }
3085       timer->Destroy();
3086     }
3087 
3088     if (status != miopenStatusSuccess) {
3089       return port::InternalError(
3090           absl::StrCat("Failed to enqueue convolution on stream: ",
3091                        ::stream_executor::gpu::ToString(status)));
3092     }
3093 
3094     return port::Status::OK();
3095   }
3096 
3097  private:
3098   GpuExecutor* parent_;
3099   MIOpenAccess* miopen_;
3100   int64_t algo_id_;
3101   size_t workspace_size_;
3102   dnn::ConvolutionKind kind_;
3103   bool use_immediate_mode_;
3104 
3105   ScopedTensorDescriptor input_desc_;
3106   ScopedTensorDescriptor output_desc_;
3107   ScopedFilterDescriptor filter_desc_;
3108   ScopedConvolutionDescriptor conv_desc_;
3109 };
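// RocmConvRunner instances are created by ConvolveRunnerFromDesc() below and
// invoked through operator(); DoConvolve() shows the typical pattern, and
// passing a non-null dnn::ProfileResult enables the GpuTimer-based profiling
// path above.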
3110 
DoConvolve(dnn::ConvolutionKind kind,dnn::DataType element_type,dnn::DataType output_type,Stream * stream,const dnn::BatchDescriptor & input_descriptor,DeviceMemoryBase input_data,const dnn::FilterDescriptor & filter_descriptor,DeviceMemoryBase filter_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemoryBase output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,dnn::AlgorithmDesc algorithm_desc,DeviceMemory<uint8> scratch_memory,dnn::ProfileResult * output_profile_result)3111 port::Status MIOpenSupport::DoConvolve(
3112     dnn::ConvolutionKind kind, dnn::DataType element_type,
3113     dnn::DataType output_type, Stream* stream,
3114     const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
3115     const dnn::FilterDescriptor& filter_descriptor,
3116     DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor,
3117     DeviceMemoryBase output_data,
3118     const dnn::ConvolutionDescriptor& convolution_descriptor,
3119     dnn::AlgorithmDesc algorithm_desc, DeviceMemory<uint8> scratch_memory,
3120     dnn::ProfileResult* output_profile_result) {
3121   TF_ASSIGN_OR_RETURN(
3122       auto runner,
3123       ConvolveRunnerFromDesc(stream, algorithm_desc, kind, element_type,
3124                              output_type, input_descriptor, filter_descriptor,
3125                              output_descriptor, convolution_descriptor));
3126 
3127   return (*runner)(stream, output_profile_result, scratch_memory, input_data,
3128                    filter_data, output_data);
3129 }
3130 
GetConvolveAlgorithms(CudaComputeCapability cuda_compute_capability,std::vector<dnn::AlgorithmDesc> * out_algorithms)3131 bool MIOpenSupport::GetConvolveAlgorithms(
3132     // ROCM TODO: refactor cc_major / cc_minor
3133     CudaComputeCapability cuda_compute_capability,
3134     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
3135   out_algorithms->assign({
3136       // clang-format off
3137       dnn::AlgorithmDesc(miopenConvolutionFwdAlgoGEMM, false),
3138       dnn::AlgorithmDesc(miopenConvolutionFwdAlgoDirect, false),
3139       dnn::AlgorithmDesc(miopenConvolutionFwdAlgoFFT, false),
3140       dnn::AlgorithmDesc(miopenConvolutionFwdAlgoWinograd, false),
3141       // clang-format on
3142   });
3143   return true;
3144 }
3145 
GetConvolveRunners(bool use_cudnn_frontend,dnn::ConvolutionKind kind,dnn::DataType input_type,dnn::DataType output_type,Stream * stream,const dnn::BatchDescriptor & input_descriptor,DeviceMemoryBase input_data,const dnn::FilterDescriptor & filter_descriptor,DeviceMemoryBase filter_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemoryBase output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,bool use_fallback,ScratchAllocator * scratch_allocator,std::vector<std::unique_ptr<const dnn::ConvRunner>> * out_runners)3146 port::Status MIOpenSupport::GetConvolveRunners(
3147     bool use_cudnn_frontend, dnn::ConvolutionKind kind,
3148     dnn::DataType input_type, dnn::DataType output_type, Stream* stream,
3149     const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
3150     const dnn::FilterDescriptor& filter_descriptor,
3151     DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor,
3152     DeviceMemoryBase output_data,
3153     const dnn::ConvolutionDescriptor& convolution_descriptor, bool use_fallback,
3154     ScratchAllocator* scratch_allocator,
3155     std::vector<std::unique_ptr<const dnn::ConvRunner>>* out_runners) {
3156   if (input_type != output_type) {
3157     return port::UnimplementedError(
3158         absl::StrFormat("MIOpen backend does not support different input and "
3159                         "output types: %d != %d",
3160                         input_type, output_type));
3161   }
3162 
3163   std::vector<dnn::ProfileResult> profile_results;
3164   if (!GetMIOpenConvolveAlgorithms(
3165           kind, input_type, stream, input_descriptor, input_data,
3166           filter_descriptor, filter_data, output_descriptor, output_data,
3167           convolution_descriptor, scratch_allocator, &profile_results)) {
3168     return port::Status(
3169         port::error::UNKNOWN,
3170         "GetConvolveRunners: GetMIOpenConvolveAlgorithms failed");
3171   }
3172 
3173   for (const auto& profile_result : profile_results) {
3174     TF_ASSIGN_OR_RETURN(
3175         auto runner, ConvolveRunnerFromDesc(
3176                          stream, profile_result.algorithm(), kind, input_type,
3177                          output_type, input_descriptor, filter_descriptor,
3178                          output_descriptor, convolution_descriptor));
3179     out_runners->push_back(std::move(runner));
3180   }
3181 
3182   return port::Status::OK();
3183 }
3184 
3185 port::StatusOr<std::unique_ptr<const dnn::ConvRunner>>
ConvolveRunnerFromDesc(Stream * stream,const dnn::AlgorithmDesc & algorithm_desc,dnn::ConvolutionKind kind,dnn::DataType input_type,dnn::DataType output_type,const dnn::BatchDescriptor & input_descriptor,const dnn::FilterDescriptor & filter_descriptor,const dnn::BatchDescriptor & output_descriptor,const dnn::ConvolutionDescriptor & convolution_descriptor)3186 MIOpenSupport::ConvolveRunnerFromDesc(
3187     Stream* stream, const dnn::AlgorithmDesc& algorithm_desc,
3188     dnn::ConvolutionKind kind, dnn::DataType input_type,
3189     dnn::DataType output_type, const dnn::BatchDescriptor& input_descriptor,
3190     const dnn::FilterDescriptor& filter_descriptor,
3191     const dnn::BatchDescriptor& output_descriptor,
3192     const dnn::ConvolutionDescriptor& convolution_descriptor) {
3193   if (input_type != output_type) {
3194     return port::UnimplementedError(
3195         absl::StrFormat("MIOpen backend does not support different input and "
3196                         "output types: %d != %d",
3197                         input_type, output_type));
3198   }
3199 
3200   auto workspace_size = algorithm_desc.workspace_size();
3201   if (!workspace_size) {
3202     return port::InvalidArgumentError(
3203         "MIOpenSupport::ConvolveRunnerFromDesc requires "
3204         "AlgorithmProto.workspace_size, but it was missing.");
3205   }
3206   return {std::make_unique<RocmConvRunner>(
3207       parent_, miopen_.get(), algorithm_desc.algo_id(), *workspace_size, kind,
3208       input_type, use_immediate_mode_, input_descriptor, output_descriptor,
3209       filter_descriptor, convolution_descriptor)};
3210 }
3211 
GetMIOpenConvolveAlgorithms(dnn::ConvolutionKind kind,dnn::DataType element_type,Stream * stream,const dnn::BatchDescriptor & input_descriptor,DeviceMemoryBase input_data,const dnn::FilterDescriptor & filter_descriptor,DeviceMemoryBase filter_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemoryBase output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,ScratchAllocator * scratch_allocator,std::vector<dnn::ProfileResult> * out_algorithms)3212 bool MIOpenSupport::GetMIOpenConvolveAlgorithms(
3213     dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
3214     const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
3215     const dnn::FilterDescriptor& filter_descriptor,
3216     DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor,
3217     DeviceMemoryBase output_data,
3218     const dnn::ConvolutionDescriptor& convolution_descriptor,
3219     ScratchAllocator* scratch_allocator,
3220     std::vector<dnn::ProfileResult>* out_algorithms) {
3221   return use_immediate_mode_
3222              ? GetMIOpenConvolveAlgorithmsImmediateMode(
3223                    kind, element_type, stream, input_descriptor, input_data,
3224                    filter_descriptor, filter_data, output_descriptor,
3225                    output_data, convolution_descriptor, scratch_allocator,
3226                    out_algorithms)
3227              : GetMIOpenConvolveAlgorithmsFindMode(
3228                    kind, element_type, stream, input_descriptor, input_data,
3229                    filter_descriptor, filter_data, output_descriptor,
3230                    output_data, convolution_descriptor, scratch_allocator,
3231                    out_algorithms);
3232 }
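// Immediate mode enumerates MIOpen's precompiled solutions through the
// *GetSolutionCount / *GetSolution queries below, while find mode (defined
// later in this file) relies on MIOpen's auto-tuning search. Both paths fill
// out_algorithms with dnn::ProfileResult entries that GetConvolveRunners()
// above turns into RocmConvRunner instances.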
3233 
GetMIOpenConvolveAlgorithmsImmediateMode(dnn::ConvolutionKind kind,dnn::DataType element_type,Stream * stream,const dnn::BatchDescriptor & input_descriptor,DeviceMemoryBase input_data,const dnn::FilterDescriptor & filter_descriptor,DeviceMemoryBase filter_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemoryBase output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,ScratchAllocator * scratch_allocator,std::vector<dnn::ProfileResult> * out_algorithms)3234 bool MIOpenSupport::GetMIOpenConvolveAlgorithmsImmediateMode(
3235     dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
3236     const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
3237     const dnn::FilterDescriptor& filter_descriptor,
3238     DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor,
3239     DeviceMemoryBase output_data,
3240     const dnn::ConvolutionDescriptor& convolution_descriptor,
3241     ScratchAllocator* scratch_allocator,
3242     std::vector<dnn::ProfileResult>* out_algorithms) {
3243   auto miopen = miopen_->GetHandle(parent_, stream);
3244 
3245   ScopedTensorDescriptor input_nd{input_descriptor,
3246                                   ToMIOpenDataType(element_type)};
3247   ScopedTensorDescriptor output_nd{output_descriptor,
3248                                    ToMIOpenDataType(element_type)};
3249   ScopedFilterDescriptor filter{filter_descriptor,
3250                                 ToMIOpenDataType(element_type)};
3251   ScopedConvolutionDescriptor conv{convolution_descriptor,
3252                                    ToMIOpenDataType(element_type)};
3253 
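  // MIOpen "immediate mode" flow: query how many solutions are applicable,
  // fetch up to that many, compile each one so it is ready to run, and
  // report them back as ProfileResults. When TF_ROCM_RETURN_BEST_ALGO_ONLY
  // is set, only the first (preferred) solution is requested.
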
3254   // First determine the number of algorithms available
3255   size_t maxSolutionCount = 0;
3256 
3257   switch (kind) {
3258     case dnn::ConvolutionKind::FORWARD: {
3259       auto status = wrap::miopenConvolutionForwardGetSolutionCount(
3260           miopen.handle(), filter.handle(), input_nd.handle(), conv.handle(),
3261           output_nd.handle(), &maxSolutionCount);
3262       if (status != miopenStatusSuccess) {
3263         LOG(FATAL)
3264             << "call to miopenConvolutionForwardGetSolutionCount failed: "
3265             << ToString(status);
3266         return false;
3267       }
3268       break;
3269     }
3270     case dnn::ConvolutionKind::BACKWARD_DATA: {
3271       auto status = wrap::miopenConvolutionBackwardDataGetSolutionCount(
3272           miopen.handle(), output_nd.handle(), filter.handle(), conv.handle(),
3273           input_nd.handle(), &maxSolutionCount);
3274       if (status != miopenStatusSuccess) {
3275         LOG(FATAL) << "call to miopenConvolutionBackwardDataGetSolutionCount "
3276                       "failed: "
3277                    << ToString(status);
3278         return false;
3279       }
3280       break;
3281     }
3282     case dnn::ConvolutionKind::BACKWARD_FILTER: {
3283       auto status = wrap::miopenConvolutionBackwardWeightsGetSolutionCount(
3284           miopen.handle(), output_nd.handle(), input_nd.handle(), conv.handle(),
3285           filter.handle(), &maxSolutionCount);
3286       if (status != miopenStatusSuccess) {
3287         LOG(FATAL)
3288             << "call to miopenConvolutionBackwardWeightsGetSolutionCount "
3289                "failed: "
3290             << ToString(status);
3291         return false;
3292       }
3293       break;
3294     }
3295     default: {
3296       LOG(FATAL) << "Unexpected convolution kind " << static_cast<int>(kind);
3297       return false;
3298       break;
3299     }
3300   }
3301 
3302   VLOG(kConvDebugVlogLevel)
3303       << "Number of conv solutions max: " << maxSolutionCount;
3304 
3305   if (return_best_algo_only_) {
3306     VLOG(kConvDebugVlogLevel) << "TF_ROCM_RETURN_BEST_ALGO_ONLY is set, "
3307                               << "setting maxSolutionCount to 1";
3308     maxSolutionCount = 1;
3309   }
3310 
3311   size_t solutionCount = 0;
3312   std::unique_ptr<miopenConvSolution_t[]> solutions(
3313       new miopenConvSolution_t[maxSolutionCount]);
3314 
3315   switch (kind) {
3316     case dnn::ConvolutionKind::FORWARD: {
3317       auto status = wrap::miopenConvolutionForwardGetSolution(
3318           miopen.handle(), filter.handle(), input_nd.handle(), conv.handle(),
3319           output_nd.handle(), maxSolutionCount, &solutionCount,
3320           solutions.get());
3321 
3322       if (status != miopenStatusSuccess) {
3323         LOG(FATAL) << "call to miopenConvolutionForwardGetSolution failed: "
3324                    << ToString(status);
3325         return false;
3326       }
3327 
3328       VLOG(kConvDebugVlogLevel)
3329           << "Number of conv solutions actual: " << solutionCount;
3330 
3331       for (size_t i = 0; i < solutionCount; i++) {
3332         miopenConvSolution_t solution = solutions[i];
3333 
3334         VLOG(kConvDebugVlogLevel)
3335             << "solution " << i << " (time, mem, id, algo) =  " << solution.time
3336             << ", " << solution.workspace_size << ", " << solution.solution_id
3337             << ", " << ToString(solution.algorithm);
3338 
3339         status = wrap::miopenConvolutionForwardCompileSolution(
3340             miopen.handle(), filter.handle(), input_nd.handle(), conv.handle(),
3341             output_nd.handle(), solution.solution_id);
3342 
3343         if (status != miopenStatusSuccess) {
3344           LOG(FATAL)
3345               << "call to miopenConvolutionForwardCompileSolution failed: "
3346               << ToString(status);
3347           return false;
3348         }
3349 
3350         out_algorithms->emplace_back(
3351             GetProfileResultFromConvSolution(solution));
3352       }
3353       break;
3354     }
3355 
3356     case dnn::ConvolutionKind::BACKWARD_DATA: {
3357       auto status = wrap::miopenConvolutionBackwardDataGetSolution(
3358           miopen.handle(), output_nd.handle(), filter.handle(), conv.handle(),
3359           input_nd.handle(), maxSolutionCount, &solutionCount, solutions.get());
3360       if (status != miopenStatusSuccess) {
3361         LOG(FATAL)
3362             << "call to miopenConvolutionBackwardDataGetSolution failed: "
3363             << ToString(status);
3364         return false;
3365       }
3366 
3367       VLOG(kConvDebugVlogLevel)
3368           << "Number of conv solutions actual: " << solutionCount;
3369 
3370       for (size_t i = 0; i < solutionCount; i++) {
3371         miopenConvSolution_t solution = solutions[i];
3372 
3373         VLOG(kConvDebugVlogLevel)
3374             << "solution " << i << " (time, mem, id, algo) =  " << solution.time
3375             << ", " << solution.workspace_size << ", " << solution.solution_id
3376             << ", " << ToString(solution.algorithm);
3377 
3378         status = wrap::miopenConvolutionBackwardDataCompileSolution(
3379             miopen.handle(), output_nd.handle(), filter.handle(), conv.handle(),
3380             input_nd.handle(), solution.solution_id);
3381 
3382         if (status != miopenStatusSuccess) {
3383           LOG(FATAL) << "call to miopenConvolutionBackwardDataCompileSolution "
3384                         "failed: "
3385                      << ToString(status);
3386           return false;
3387         }
3388 
3389         out_algorithms->emplace_back(
3390             GetProfileResultFromConvSolution(solution));
3391       }
3392       break;
3393     }
3394     case dnn::ConvolutionKind::BACKWARD_FILTER: {
3395       auto status = wrap::miopenConvolutionBackwardWeightsGetSolution(
3396           miopen.handle(), output_nd.handle(), input_nd.handle(), conv.handle(),
3397           filter.handle(), maxSolutionCount, &solutionCount, solutions.get());
3398       if (status != miopenStatusSuccess) {
3399         LOG(FATAL)
3400             << "call to miopenConvolutionBackwardWeightsGetSolution failed: "
3401             << ToString(status);
3402         return false;
3403       }
3404 
3405       VLOG(kConvDebugVlogLevel)
3406           << "Number of conv solutions actual: " << solutionCount;
3407 
3408       for (size_t i = 0; i < solutionCount; i++) {
3409         miopenConvSolution_t solution = solutions[i];
3410 
3411         VLOG(kConvDebugVlogLevel)
3412             << "solution " << i << " (time, mem, id, algo) =  " << solution.time
3413             << ", " << solution.workspace_size << ", " << solution.solution_id
3414             << ", " << ToString(solution.algorithm);
3415 
3416         status = wrap::miopenConvolutionBackwardWeightsCompileSolution(
3417             miopen.handle(), output_nd.handle(), input_nd.handle(),
3418             conv.handle(), filter.handle(), solution.solution_id);
3419 
3420         if (status != miopenStatusSuccess) {
3421           LOG(FATAL)
3422               << "call to miopenConvolutionBackwardWeightsCompileSolution "
3423                  "failed: "
3424               << ToString(status);
3425           return false;
3426         }
3427 
3428         out_algorithms->emplace_back(
3429             GetProfileResultFromConvSolution(solution));
3430       }
3431       break;
3432     }
3433     default: {
3434       LOG(FATAL) << "Unexpected convolution kind " << static_cast<int>(kind);
3435       return false;
3436       break;
3437     }
3438   }
3439 
3440   return true;
3441 }
3442 
GetMIOpenConvolveAlgorithmsFindMode(dnn::ConvolutionKind kind,dnn::DataType element_type,Stream * stream,const dnn::BatchDescriptor & input_descriptor,DeviceMemoryBase input_data,const dnn::FilterDescriptor & filter_descriptor,DeviceMemoryBase filter_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemoryBase output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,ScratchAllocator * scratch_allocator,std::vector<dnn::ProfileResult> * out_algorithms)3443 bool MIOpenSupport::GetMIOpenConvolveAlgorithmsFindMode(
3444     dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
3445     const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
3446     const dnn::FilterDescriptor& filter_descriptor,
3447     DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor,
3448     DeviceMemoryBase output_data,
3449     const dnn::ConvolutionDescriptor& convolution_descriptor,
3450     ScratchAllocator* scratch_allocator,
3451     std::vector<dnn::ProfileResult>* out_algorithms) {
3452   auto miopen = miopen_->GetHandle(parent_, stream);
3453 
3454   ScopedTensorDescriptor input_nd{input_descriptor,
3455                                   ToMIOpenDataType(element_type)};
3456   ScopedTensorDescriptor output_nd{output_descriptor,
3457                                    ToMIOpenDataType(element_type)};
3458   ScopedFilterDescriptor filter{filter_descriptor,
3459                                 ToMIOpenDataType(element_type)};
3460   ScopedConvolutionDescriptor conv{convolution_descriptor,
3461                                    ToMIOpenDataType(element_type)};
3462 
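  // "Find mode" flow: size and allocate the scratch workspace, then run the
  // non-exhaustive miopenFindConvolution*Algorithm call requesting a single
  // result, so only the best algorithm found is reported to the caller.
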
3463   // Determine the workspace memory size needed by the call to Find
3464   size_t scratch_memory_size = 0;
3465   switch (kind) {
3466     case dnn::ConvolutionKind::FORWARD: {
3467       auto status = wrap::miopenConvolutionForwardGetWorkSpaceSize(
3468           miopen.handle(), filter.handle(), input_nd.handle(), conv.handle(),
3469           output_nd.handle(), &scratch_memory_size);
3470       if (status != miopenStatusSuccess) {
3471         LOG(FATAL)
3472             << "call to miopenConvolutionForwardGetWorkspaceSize failed: "
3473             << ToString(status);
3474         return false;
3475       }
3476       break;
3477     }
3478     case dnn::ConvolutionKind::BACKWARD_DATA: {
3479       auto status = wrap::miopenConvolutionBackwardDataGetWorkSpaceSize(
3480           miopen.handle(), output_nd.handle(), filter.handle(), conv.handle(),
3481           input_nd.handle(), &scratch_memory_size);
3482       if (status != miopenStatusSuccess) {
3483         LOG(FATAL)
3484             << "call to miopenConvolutionBackwardDataGetWorkspaceSize failed: "
3485             << ToString(status);
3486         return false;
3487       }
3488       break;
3489     }
3490     case dnn::ConvolutionKind::BACKWARD_FILTER: {
3491       auto status = wrap::miopenConvolutionBackwardWeightsGetWorkSpaceSize(
3492           miopen.handle(), output_nd.handle(), input_nd.handle(), conv.handle(),
3493           filter.handle(), &scratch_memory_size);
3494       if (status != miopenStatusSuccess) {
3495         LOG(FATAL)
3496             << "call to miopenConvolutionBackwardWeightsGetWorkspaceSize "
3497                "failed: "
3498             << ToString(status);
3499         return false;
3500       }
3501       break;
3502     }
3503     default: {
3504       LOG(FATAL) << "Unexpected convolution kind " << static_cast<int>(kind);
3505       return false;
3506       break;
3507     }
3508   }
3509 
3510   // allocate scratch memory
3511   DeviceMemory<uint8> scratch_memory;
3512   if (scratch_memory_size != 0) {
3513     if (scratch_allocator == nullptr) {
3514       LOG(FATAL)
3515           << "An allocator must be specified when scratch memory is needed";
3516       return false;
3517     }
3518     auto allocated = scratch_allocator->AllocateBytes(scratch_memory_size);
3519     if (allocated.ok()) {
3520       scratch_memory = allocated.ValueOrDie();
3521     } else {
3522       LOG(FATAL)
3523           << "Failed to allocate scratch memory - "
3524           << allocated.status().error_message() << "\n"
3525           << "\tYou can set the env var TF_CUDNN_WORKSPACE_LIMIT_IN_MB to a "
3526              "larger number (e.g. 8192) to increase the max memory limit.\n"
3527           << "\tIncreasing the max memory limit might help resolve this "
3528              "error";
3529       return false;
3530     }
3531   }
3532 
3533   // Only get the best algorithm for Find Mode
3534   size_t requestedAlgorithmCount = 1;
3535 
3536   VLOG(kConvDebugVlogLevel)
3537       << "Number of conv algortihms to request: " << requestedAlgorithmCount;
3538 
3539   miopenConvAlgoPerf_t returnedAlgorithm;
3540 
3541   int returnedAlgorithmCount = 0;
3542   bool exhaustiveSearch = false;
3543 
3544   switch (kind) {
3545     case dnn::ConvolutionKind::FORWARD: {
3546       auto status = wrap::miopenFindConvolutionForwardAlgorithm(
3547           miopen.handle(), input_nd.handle(), input_data.opaque(),
3548           filter.handle(), filter_data.opaque(), conv.handle(),
3549           output_nd.handle(), output_data.opaque(), requestedAlgorithmCount,
3550           &returnedAlgorithmCount, &returnedAlgorithm, scratch_memory.opaque(),
3551           scratch_memory_size, exhaustiveSearch);
3552       if (status != miopenStatusSuccess) {
3553         LOG(FATAL) << "call to miopenFindConvolutionForwardAlgorithm failed: "
3554                    << ToString(status);
3555         return false;
3556       }
3557       break;
3558     }
3559     case dnn::ConvolutionKind::BACKWARD_DATA: {
3560       auto status = wrap::miopenFindConvolutionBackwardDataAlgorithm(
3561           miopen.handle(), output_nd.handle(), output_data.opaque(),
3562           filter.handle(), filter_data.opaque(), conv.handle(),
3563           input_nd.handle(), input_data.opaque(), requestedAlgorithmCount,
3564           &returnedAlgorithmCount, &returnedAlgorithm, scratch_memory.opaque(),
3565           scratch_memory_size, exhaustiveSearch);
3566       if (status != miopenStatusSuccess) {
3567         LOG(FATAL)
3568             << "call to miopenFindConvolutionBackwardDataAlgorithm failed: "
3569             << ToString(status);
3570         return false;
3571       }
3572       break;
3573     }
3574     case dnn::ConvolutionKind::BACKWARD_FILTER: {
3575       auto status = wrap::miopenFindConvolutionBackwardWeightsAlgorithm(
3576           miopen.handle(), output_nd.handle(), output_data.opaque(),
3577           input_nd.handle(), input_data.opaque(), conv.handle(),
3578           filter.handle(), filter_data.opaque(), requestedAlgorithmCount,
3579           &returnedAlgorithmCount, &returnedAlgorithm, scratch_memory.opaque(),
3580           scratch_memory_size, exhaustiveSearch);
3581       if (status != miopenStatusSuccess) {
3582         LOG(FATAL) << "call to miopenConvolutionBackwardWeightsAlgorithm "
3583                       "failed: "
3584                    << ToString(status);
3585         return false;
3586       }
3587       break;
3588     }
3589     default: {
3590       LOG(FATAL) << "Unexpected convolution kind " << static_cast<int>(kind);
3591       return false;
3592       break;
3593     }
3594   }
3595 
3596   out_algorithms->emplace_back(
3597       GetProfileResultFromConvAlgoPerf(kind, returnedAlgorithm));
3598 
3599   return true;
3600 }
3601 
GetRnnAlgorithms(std::vector<dnn::AlgorithmDesc> * out_algorithms)3602 bool MIOpenSupport::GetRnnAlgorithms(
3603     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
3604   // ROCM TODO: implement this with proper MIOpen API
3605   return true;
3606 }
3607 
GetConvolveBackwardDataAlgorithms(CudaComputeCapability cuda_compute_capability,std::vector<dnn::AlgorithmDesc> * out_algorithms)3608 bool MIOpenSupport::GetConvolveBackwardDataAlgorithms(
3609     // ROCM TODO: refactor cc_major / cc_minor
3610     CudaComputeCapability cuda_compute_capability,
3611     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
3612   out_algorithms->assign({
3613       // clang-format off
3614       dnn::AlgorithmDesc(miopenConvolutionBwdDataAlgoGEMM, false),
3615       dnn::AlgorithmDesc(miopenConvolutionBwdDataAlgoDirect, false),
3616       dnn::AlgorithmDesc(miopenConvolutionBwdDataAlgoFFT, false),
3617       dnn::AlgorithmDesc(miopenConvolutionBwdDataAlgoWinograd, false),
3618       // clang-format on
3619   });
3620   return true;
3621 }
3622 
GetConvolveBackwardFilterAlgorithms(CudaComputeCapability cuda_compute_capability,std::vector<dnn::AlgorithmDesc> * out_algorithms)3623 bool MIOpenSupport::GetConvolveBackwardFilterAlgorithms(
3624     // ROCM TODO: refactor cc_major / cc_minor
3625     CudaComputeCapability cuda_compute_capability,
3626     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
3627   out_algorithms->assign({
3628       // clang-format off
3629       dnn::AlgorithmDesc(miopenConvolutionBwdWeightsAlgoGEMM, false),
3630       dnn::AlgorithmDesc(miopenConvolutionBwdWeightsAlgoDirect, false),
3631       // clang-format on
3632   });
3633   return true;
3634 }
3635 
DoBatchNormalizationForward(Stream * stream,const DeviceMemory<Eigen::half> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & offset,const DeviceMemory<float> & estimated_mean,const DeviceMemory<float> & estimated_variance,const DeviceMemory<Eigen::half> & side_input,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,const double exponential_average_factor,dnn::ActivationMode activation_mode,DeviceMemory<Eigen::half> * y,DeviceMemory<float> * batch_mean,DeviceMemory<float> * batch_var,DeviceMemory<float> * saved_mean,DeviceMemory<float> * saved_inv_var,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator)3636 bool MIOpenSupport::DoBatchNormalizationForward(
3637     Stream* stream, const DeviceMemory<Eigen::half>& x,
3638     const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
3639     const DeviceMemory<float>& estimated_mean,
3640     const DeviceMemory<float>& estimated_variance,
3641     const DeviceMemory<Eigen::half>& side_input,
3642     const dnn::BatchDescriptor& x_desc,
3643     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3644     const double exponential_average_factor,
3645     dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y,
3646     DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
3647     DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
3648     bool is_training, ScratchAllocator* reserve_space_allocator,
3649     ScratchAllocator* workspace_allocator) {
3650   return DoBatchNormalizationForwardImpl<Eigen::half, float>(
3651       stream, dnn::DataType::kHalf, dnn::DataType::kFloat, x, scale, offset,
3652       estimated_mean, estimated_variance, side_input, x_desc, scale_offset_desc,
3653       epsilon, exponential_average_factor, activation_mode, y, batch_mean,
3654       batch_var, saved_mean, saved_inv_var, is_training);
3655 }
3656 
DoBatchNormalizationForward(Stream * stream,const DeviceMemory<float> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & offset,const DeviceMemory<float> & estimated_mean,const DeviceMemory<float> & estimated_variance,const DeviceMemory<float> & side_input,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,const double exponential_average_factor,dnn::ActivationMode activation_mode,DeviceMemory<float> * y,DeviceMemory<float> * batch_mean,DeviceMemory<float> * batch_var,DeviceMemory<float> * saved_mean,DeviceMemory<float> * saved_inv_var,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator)3657 bool MIOpenSupport::DoBatchNormalizationForward(
3658     Stream* stream, const DeviceMemory<float>& x,
3659     const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
3660     const DeviceMemory<float>& estimated_mean,
3661     const DeviceMemory<float>& estimated_variance,
3662     const DeviceMemory<float>& side_input, const dnn::BatchDescriptor& x_desc,
3663     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3664     const double exponential_average_factor,
3665     dnn::ActivationMode activation_mode, DeviceMemory<float>* y,
3666     DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
3667     DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
3668     bool is_training, ScratchAllocator* reserve_space_allocator,
3669     ScratchAllocator* workspace_allocator) {
3670   return DoBatchNormalizationForwardImpl<float, float>(
3671       stream, dnn::DataType::kFloat, dnn::DataType::kFloat, x, scale, offset,
3672       estimated_mean, estimated_variance, side_input, x_desc, scale_offset_desc,
3673       epsilon, exponential_average_factor, activation_mode, y, batch_mean,
3674       batch_var, saved_mean, saved_inv_var, is_training);
3675 }
3676 
3677 template <class T, class U>
DoBatchNormalizationForwardImpl(Stream * stream,dnn::DataType input_data_type,dnn::DataType scale_data_type,const DeviceMemory<T> & x,const DeviceMemory<U> & scale,const DeviceMemory<U> & offset,const DeviceMemory<U> & estimated_mean,const DeviceMemory<U> & estimated_variance,const DeviceMemory<T> & side_input,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,const double exponential_average_factor,dnn::ActivationMode activation_mode,DeviceMemory<T> * y,DeviceMemory<U> * batch_mean,DeviceMemory<U> * batch_var,DeviceMemory<U> * saved_mean,DeviceMemory<U> * saved_inv_var,bool is_training)3678 bool MIOpenSupport::DoBatchNormalizationForwardImpl(
3679     Stream* stream, dnn::DataType input_data_type,
3680     dnn::DataType scale_data_type, const DeviceMemory<T>& x,
3681     const DeviceMemory<U>& scale, const DeviceMemory<U>& offset,
3682     const DeviceMemory<U>& estimated_mean,
3683     const DeviceMemory<U>& estimated_variance,
3684     const DeviceMemory<T>& side_input, const dnn::BatchDescriptor& x_desc,
3685     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3686     const double exponential_average_factor,
3687     dnn::ActivationMode activation_mode, DeviceMemory<T>* y,
3688     DeviceMemory<U>* batch_mean, DeviceMemory<U>* batch_var,
3689     DeviceMemory<U>* saved_mean, DeviceMemory<U>* saved_inv_var,
3690     bool is_training) {
3691   auto miopen = miopen_->GetHandle(parent_, stream);
3692 
3693   ScopedTensorDescriptor x_descriptor{x_desc,
3694                                       ToMIOpenDataType(input_data_type)};
3695   ScopedTensorDescriptor scale_offset_descriptor{
3696       scale_offset_desc, ToMIOpenDataType(scale_data_type)};
3697   miopenBatchNormMode_t mode = miopenBNSpatial;
3698   float one = 1.0;
3699   float zero = 0.0;
3700 
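  // In training mode MIOpen computes the batch statistics itself, updates
  // the running mean/variance using exponential_average_factor, and writes
  // out the saved mean / inverse variance for reuse by the backward pass.
  // In inference mode the previously estimated running statistics are used
  // instead.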
3701   auto status = miopenStatusInvalidValue;
3702   if (is_training) {
3703     status = wrap::miopenBatchNormalizationForwardTraining(
3704         miopen.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(),
3705         x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(),
3706         const_cast<void*>(scale.opaque()), const_cast<void*>(offset.opaque()),
3707         exponential_average_factor, batch_mean->opaque(), batch_var->opaque(),
3708         epsilon, saved_mean->opaque(), saved_inv_var->opaque());
3709   } else {
3710     const void* maybe_inv_var = estimated_variance.opaque();
3711     status = wrap::miopenBatchNormalizationForwardInference(
3712         miopen.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(),
3713         x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(),
3714         const_cast<void*>(scale.opaque()), const_cast<void*>(offset.opaque()),
3715         const_cast<void*>(estimated_mean.opaque()),
3716         const_cast<void*>(maybe_inv_var), epsilon);
3717   }
3718   if (status != miopenStatusSuccess) {
3719     LOG(ERROR) << "failed to enqueue forward batch normalization on stream: "
3720                << ToString(status);
3721     return false;
3722   }
3723   return true;
3724 }
3725 
DoBatchNormalizationBackward(Stream * stream,const DeviceMemory<Eigen::half> & y_backprop,const DeviceMemory<Eigen::half> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & offset,const DeviceMemory<float> & mean,const DeviceMemory<float> & inv_var,const DeviceMemory<Eigen::half> & y,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,dnn::ActivationMode activation_mode,DeviceMemory<Eigen::half> * x_backprop,DeviceMemory<float> * scale_backprop,DeviceMemory<float> * offset_backprop,DeviceMemory<Eigen::half> * side_input_backprop,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator)3726 bool MIOpenSupport::DoBatchNormalizationBackward(
3727     Stream* stream, const DeviceMemory<Eigen::half>& y_backprop,
3728     const DeviceMemory<Eigen::half>& x, const DeviceMemory<float>& scale,
3729     const DeviceMemory<float>& offset, const DeviceMemory<float>& mean,
3730     const DeviceMemory<float>& inv_var, const DeviceMemory<Eigen::half>& y,
3731     const dnn::BatchDescriptor& x_desc,
3732     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3733     dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* x_backprop,
3734     DeviceMemory<float>* scale_backprop, DeviceMemory<float>* offset_backprop,
3735     DeviceMemory<Eigen::half>* side_input_backprop,
3736     DeviceMemory<uint8>* reserve_space_data,
3737     ScratchAllocator* workspace_allocator) {
3738   return DoBatchNormalizationBackwardImpl<Eigen::half, float>(
3739       stream, miopenHalf, miopenFloat, y_backprop, x, scale, mean, inv_var,
3740       x_desc, scale_offset_desc, epsilon, x_backprop, scale_backprop,
3741       offset_backprop);
3742 }
3743 
DoBatchNormalizationBackward(Stream * stream,const DeviceMemory<float> & y_backprop,const DeviceMemory<float> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & offset,const DeviceMemory<float> & mean,const DeviceMemory<float> & variance,const DeviceMemory<float> & y,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,dnn::ActivationMode activation_mode,DeviceMemory<float> * x_backprop,DeviceMemory<float> * scale_backprop,DeviceMemory<float> * offset_backprop,DeviceMemory<float> * side_input_backprop,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator)3744 bool MIOpenSupport::DoBatchNormalizationBackward(
3745     Stream* stream, const DeviceMemory<float>& y_backprop,
3746     const DeviceMemory<float>& x, const DeviceMemory<float>& scale,
3747     const DeviceMemory<float>& offset, const DeviceMemory<float>& mean,
3748     const DeviceMemory<float>& variance, const DeviceMemory<float>& y,
3749     const dnn::BatchDescriptor& x_desc,
3750     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3751     dnn::ActivationMode activation_mode, DeviceMemory<float>* x_backprop,
3752     DeviceMemory<float>* scale_backprop, DeviceMemory<float>* offset_backprop,
3753     DeviceMemory<float>* side_input_backprop,
3754     DeviceMemory<uint8>* reserve_space_data,
3755     ScratchAllocator* workspace_allocator) {
3756   return DoBatchNormalizationBackwardImpl<float, float>(
3757       stream, miopenFloat, miopenFloat, y_backprop, x, scale, mean, variance,
3758       x_desc, scale_offset_desc, epsilon, x_backprop, scale_backprop,
3759       offset_backprop);
3760 }
3761 
3762 template <class T, class U>
DoBatchNormalizationBackwardImpl(Stream * stream,int miopen_input_type,int miopen_scale_type,const DeviceMemory<T> & y_backprop,const DeviceMemory<T> & x,const DeviceMemory<U> & scale,const DeviceMemory<U> & mean,const DeviceMemory<U> & variance,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,DeviceMemory<T> * x_backprop,DeviceMemory<U> * scale_backprop,DeviceMemory<U> * offset_backprop)3763 bool MIOpenSupport::DoBatchNormalizationBackwardImpl(
3764     Stream* stream, int miopen_input_type, int miopen_scale_type,
3765     const DeviceMemory<T>& y_backprop, const DeviceMemory<T>& x,
3766     const DeviceMemory<U>& scale, const DeviceMemory<U>& mean,
3767     const DeviceMemory<U>& variance, const dnn::BatchDescriptor& x_desc,
3768     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3769     DeviceMemory<T>* x_backprop, DeviceMemory<U>* scale_backprop,
3770     DeviceMemory<U>* offset_backprop) {
3771   auto miopen = miopen_->GetHandle(parent_, stream);
3772   ScopedTensorDescriptor x_descriptor{
3773       x_desc, static_cast<miopenDataType_t>(miopen_input_type)};
3774   ScopedTensorDescriptor scale_offset_descriptor{
3775       scale_offset_desc, static_cast<miopenDataType_t>(miopen_scale_type)};
3776   miopenBatchNormMode_t mode = miopenBNSpatial;
3777   float one = 1.0;
3778   float zero = 0.0;
3779 
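  // The two alpha/beta pairs below scale the data gradient and the
  // scale/offset gradients respectively; with 1.0/0.0 the results simply
  // overwrite x_backprop, scale_backprop and offset_backprop.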
3780   auto status = wrap::miopenBatchNormalizationBackward(
3781       miopen.handle(), mode, &one, &zero, &one, &zero, x_descriptor.handle(),
3782       x.opaque(), x_descriptor.handle(), y_backprop.opaque(),
3783       x_descriptor.handle(), x_backprop->opaque(),
3784       scale_offset_descriptor.handle(), scale.opaque(),
3785       scale_backprop->opaque(), offset_backprop->opaque(), epsilon,
3786       mean.opaque(), variance.opaque());
3787   if (status != miopenStatusSuccess) {
3788     LOG(ERROR) << "failed to enqueue backward batch normalization on stream: "
3789                << ToString(status);
3790     return false;
3791   }
3792   return true;
3793 }
3794 
DoFusedConvolve(Stream * stream,dnn::DataType input_type,dnn::DataType side_input_type,dnn::DataType bias_type,dnn::DataType output_type,const dnn::BatchDescriptor & conv_input_descriptor,DeviceMemoryBase conv_input_data,double conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,DeviceMemoryBase filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,DeviceMemoryBase side_input_data,double side_input_scale,const dnn::BatchDescriptor & bias_descriptor,DeviceMemoryBase biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemoryBase output_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)3795 port::Status MIOpenSupport::DoFusedConvolve(
3796     Stream* stream, dnn::DataType input_type, dnn::DataType side_input_type,
3797     dnn::DataType bias_type, dnn::DataType output_type,
3798     const dnn::BatchDescriptor& conv_input_descriptor,
3799     DeviceMemoryBase conv_input_data, double conv_input_scale,
3800     const dnn::FilterDescriptor& filter_descriptor,
3801     DeviceMemoryBase filter_data,
3802     const dnn::ConvolutionDescriptor& convolution_descriptor,
3803     DeviceMemoryBase side_input_data, double side_input_scale,
3804     const dnn::BatchDescriptor& bias_descriptor, DeviceMemoryBase biases,
3805     dnn::ActivationMode activation_mode,
3806     const dnn::BatchDescriptor& output_descriptor, DeviceMemoryBase output_data,
3807     ScratchAllocator* scratch_allocator,
3808     const dnn::AlgorithmConfig& algorithm_config,
3809     dnn::ProfileResult* output_profile_result) {
3810   return port::UnimplementedError("fused convolve not implemented yet");
3811 }
3812 
DoTransformTensor(Stream * stream,const dnn::BatchDescriptor & input_desc,dnn::DataType input_type,const DeviceMemoryBase & input_data,const dnn::BatchDescriptor & output_desc,dnn::DataType output_type,float scale,DeviceMemoryBase * output_data)3813 bool MIOpenSupport::DoTransformTensor(Stream* stream,
3814                                       const dnn::BatchDescriptor& input_desc,
3815                                       dnn::DataType input_type,
3816                                       const DeviceMemoryBase& input_data,
3817                                       const dnn::BatchDescriptor& output_desc,
3818                                       dnn::DataType output_type, float scale,
3819                                       DeviceMemoryBase* output_data) {
3820   // ROCM TODO implement this operation
3821   LOG(ERROR) << "transform tensor not implemented yet";
3822   return false;
3823 }
3824 
DoMatMul(Stream * stream,const DeviceMemory<float> & input_data,const DeviceMemory<float> & weights,const dnn::BatchDescriptor & input_dimensions,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)3825 bool MIOpenSupport::DoMatMul(Stream* stream,
3826                              const DeviceMemory<float>& input_data,
3827                              const DeviceMemory<float>& weights,
3828                              const dnn::BatchDescriptor& input_dimensions,
3829                              const dnn::BatchDescriptor& output_dimensions,
3830                              DeviceMemory<float>* output_data) {
3831   if (input_dimensions.count() != output_dimensions.count()) {
3832     LOG(ERROR) << "MatMul input and output dimensions are not compatible.";
3833     return false;
3834   }
3835 
3836   // We do not permute the input or output; instead we just
3837   // reinterpret the layout. We are working with row-major matrices
3838   // and the rows of the input and output correspond to batch, so
3839   // batch has to be outermost in both the input and output.
3840   //
3841   // By adding transposes to the BLAS gemm call we could perhaps make
3842   // the kYXDepthBatch layout work as well, but there has been no need
3843   // for that so far.
3844   if (input_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
3845       input_dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
3846     LOG(ERROR) << "Unsupported MatMul input layout.";
3847     return false;
3848   }
3849   if (output_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
3850       output_dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
3851     LOG(ERROR) << "Unsupported MatMul output layout.";
3852     return false;
3853   }
3854 
3855   if (output_dimensions.width() == 1 && output_dimensions.height() == 1) {
3856     // This is a fast path that also supports the kBatchYXDepth layout.
3857 
3858     // The matrices here are in row-major format while BLAS expects
3859     // column-major, i.e. our matrices are transposed as far as BLAS
3860     // is concerned. So we need to compute output^T =
3861     // input^T*weights^T. There is no parameter for transposing the
3862     // output in BLAS gemm, but instead we can transpose both sides of
3863     // the equality to see that this is equivalent to
3864     // output=weights*input. So we only need to swap the order of
3865     // weights and input in the matrix product to correct for the
3866     // row-major versus column-major difference.
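    // Worked example: the row-major input is a (batch x k) matrix and the
    // row-major weights a (k x m) matrix. Reinterpreted as column-major,
    // those same buffers read as input^T (k x batch) and weights^T (m x k),
    // so the GEMM below computes
    //   output^T (m x batch) = weights^T * input^T
    // and the resulting column-major buffer is exactly the row-major
    // (batch x m) output we want.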
3867     const int64_t m = output_dimensions.NodesAcrossFeatureMaps();
3868     const int64_t n = input_dimensions.count();
3869     const int64_t k = input_dimensions.NodesAcrossFeatureMaps();
3870     if (!stream
3871              ->ThenBlasGemm(blas::Transpose::kNoTranspose,
3872                             blas::Transpose::kNoTranspose, m, n, k, weights, m,
3873                             input_data, k, output_data, m,
3874                             blas::kDefaultComputePrecision)
3875              .ok()) {
3876       return false;
3877     }
3878   } else {
3879     // This is a slower and more complex path that supports output
3880     // width() * height() > 1, though it only supports the
3881     // kBatchYXDepth layout. Does support kBatchDepthYX if output
3882     // feature_map_count() == 1, as then there is no difference
3883     // between the two layouts.
3884     //
3885     // The operation here is the same as above, except that we have to
3886     // do the matrix multiplication for each (y,x) output coordinate
3887     // separately. We then interpret weights as containing K = width()
3888     // * height() different matrices, which we all multiply onto the
3889     // matrix from input_data, yielding K matrix products. We then
3890     // combine these together into one matrix by concatenating all the
3891     // first rows of these matrices, then all the second rows and so
3892     // on. We can do this with a batched matrix multiplication, where
3893     // the result is written to a different submatrix of the output
3894     // for each matrix multiplication.
3895     //
3896     // The reason that we only support the kBatchYXDepth output layout
3897     // is that we have to do something in the depth for each (y,x)
3898     // coordinate. The kBatchYXDepth layout has the depth information
3899     // for each point (y,x) in contiguous memory while the
3900     // kBatchDepthYX layout does not.
3901     //
3902     // TODO(broune): Consider a special case for when output depth ==
3903     // 1, as then possibly this could all be done as one matrix
3904     // multiplication instead of a batched one, which should be
3905     // faster. Another possibility would be to add a weights layout
3906     // parameter and then support kBatchDepthYX for a different
3907     // weights layout.
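    // Concretely: for each output coordinate i, a[i] points at the i-th
    // row-major (k x m) weights sub-matrix and c[i] at element offset
    // i * feature_map_count() of the output, with ldc equal to the output's
    // NodesAcrossFeatureMaps(). The batched GEMM therefore interleaves the
    // per-coordinate results so that the depth values of each (y,x) point
    // stay contiguous, which is the kBatchYXDepth layout.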
3908     if (output_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
3909         !(output_dimensions.layout() == dnn::DataLayout::kBatchDepthYX &&
3910           output_dimensions.feature_map_count() == 1)) {
3911       LOG(ERROR) << "Unsupported MatMul output layout.";
3912       return false;
3913     }
3914 
3915     const float alpha = 1.0f;  // Take the matrix product without scaling it.
3916     const float beta = 0.0f;   // Ignore the original values in output_data.
3917     const uint64_t m = output_dimensions.feature_map_count();
3918     const uint64_t n = input_dimensions.count();
3919     const uint64_t k = input_dimensions.NodesAcrossFeatureMaps();
3920     const int lda = m;
3921     const int ldb = k;
3922     const int ldc = output_dimensions.NodesAcrossFeatureMaps();
3923     const int batch_count = output_dimensions.NodesPerFeatureMap();
3924 
3925     std::vector<DeviceMemory<float>> a(batch_count);
3926     std::vector<DeviceMemory<float>> b(batch_count);
3927     std::vector<DeviceMemory<float>> c(batch_count);
3928     for (int i = 0; i < batch_count; ++i) {
3929       const int weights_offset = i * input_dimensions.NodesAcrossFeatureMaps() *
3930                                  output_dimensions.feature_map_count();
3931       a[i] = DeviceMemory<float>::MakeFromByteSize(
3932           const_cast<float*>(reinterpret_cast<const float*>(weights.opaque())) +
3933               weights_offset,
3934           weights.ElementCount() - weights_offset);
3935 
3936       b[i] = input_data;
3937 
3938       const int output_offset = i * output_dimensions.feature_map_count();
3939       c[i] = DeviceMemory<float>::MakeFromByteSize(
3940           const_cast<float*>(
3941               reinterpret_cast<const float*>(output_data->opaque())) +
3942               output_offset,
3943           output_data->ElementCount() - output_offset);
3944     }
3945     const auto toPtrs = [](std::vector<DeviceMemory<float>>& v) {
3946       std::vector<DeviceMemory<float>*> ptrs;
3947       ptrs.reserve(v.size());
3948       for (auto& mem : v) {
3949         ptrs.push_back(&mem);
3950       }
3951       return ptrs;
3952     };
3953 
3954     stream->ThenBlasGemmBatched(blas::Transpose::kNoTranspose,
3955                                 blas::Transpose::kNoTranspose, m, n, k, alpha,
3956                                 toPtrs(a), lda, toPtrs(b), ldb, beta, toPtrs(c),
3957                                 ldc, batch_count);
3958   }
3959 
3960   return stream->ok();
3961 }
3962 
DoBiasAdd(Stream * stream,const DeviceMemory<float> & input_data,const DeviceMemory<float> & biases,const dnn::BatchDescriptor & dimensions,DeviceMemory<float> * output_data)3963 bool MIOpenSupport::DoBiasAdd(Stream* stream,
3964                               const DeviceMemory<float>& input_data,
3965                               const DeviceMemory<float>& biases,
3966                               const dnn::BatchDescriptor& dimensions,
3967                               DeviceMemory<float>* output_data) {
3968   ScopedTensorDescriptor input_descriptor{dimensions, miopenFloat};
3969 
3970   BatchDescriptor bias_dimensions;
3971   bias_dimensions.set_count(1)
3972       .set_feature_map_count(dimensions.feature_map_count())
3973       .set_height(1)
3974       .set_width(1)
3975       .set_layout(dnn::DataLayout::kBatchYXDepth);
3976   ScopedTensorDescriptor bias_descriptor{bias_dimensions, miopenFloat};
3977 
3978   if (input_data.opaque() != output_data->opaque()) {
3979     stream->ThenMemcpy(output_data, input_data,
3980                        dimensions.ElementCount() * sizeof(float));
3981     if (!stream->ok()) {
3982       LOG(ERROR)
3983           << "stream " << stream
3984           << " could not enqueue a tensor copy as part of bias addition.";
3985       return false;
3986     }
3987   }
3988 
3989   auto miopen = miopen_->GetHandle(parent_, stream);
3990 
3991   const float alpha1 = 1.0f;
3992   const float alpha2 = 0.0f;
3993   const float beta = 1.0f;
3994 
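  // Per MIOpen's OpTensor semantics (C = op(alpha1*A, alpha2*B) + beta*C),
  // the factors above reduce the Add op to output = bias + output, i.e. the
  // per-feature-map bias is broadcast over batch, height and width and
  // added in place to the copy of the input staged in output_data.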
3995   auto status = wrap::miopenOpTensor(
3996       miopen.handle(), miopenTensorOpAdd, &alpha1, bias_descriptor.handle(),
3997       biases.opaque(), &alpha2, bias_descriptor.handle(), biases.opaque(),
3998       &beta, input_descriptor.handle(), output_data->opaque());
3999 
4000   if (status != miopenStatusSuccess) {
4001     LOG(ERROR) << "stream " << stream << " could not enqueue bias addition.";
4002     return false;
4003   }
4004 
4005   return true;
4006 }
4007 
DoActivate(Stream * stream,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,DeviceMemory<float> * output_data,uint64_t options)4008 bool MIOpenSupport::DoActivate(Stream* stream,
4009                                dnn::ActivationMode activation_mode,
4010                                const dnn::BatchDescriptor& dimensions,
4011                                const DeviceMemory<float>& input_data,
4012                                DeviceMemory<float>* output_data,
4013                                uint64_t options) {
4014   LOG(ERROR) << "miopen does not support activation yet";
4015   return false;
4016 }
4017 
DoPoolForward(dnn::DataType element_type,Stream * stream,const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,DeviceMemoryBase input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemoryBase output_data,ScratchAllocator * workspace_allocator)4018 port::Status MIOpenSupport::DoPoolForward(
4019     dnn::DataType element_type, Stream* stream,
4020     const dnn::PoolingDescriptor& pooling_dimensions,
4021     const dnn::BatchDescriptor& input_dimensions, DeviceMemoryBase input_data,
4022     const dnn::BatchDescriptor& output_dimensions, DeviceMemoryBase output_data,
4023     ScratchAllocator* workspace_allocator) {
4024   if (element_type == dnn::DataType::kDouble) {
4025     return port::Status(port::error::INVALID_ARGUMENT,
4026                         "MIOpen does not support pooling for double type yet");
4027   }
4028 
4029   auto miopen = miopen_->GetHandle(parent_, stream);
4030   // Alpha is the scaling factor for input.
4031   float alpha = 1.0;
4032   // Beta is the scaling factor for output.
4033   float beta = 0.0;
4034 
4035   auto miopen_dtype =
4036       element_type == dnn::DataType::kFloat ? miopenFloat : miopenHalf;
4037 
4038   ScopedTensorDescriptor src_desc{input_dimensions, miopen_dtype};
4039   ScopedTensorDescriptor dest_desc{output_dimensions, miopen_dtype};
4040   ScopedPoolingDescriptor pooling_desc{pooling_dimensions};
4041 
4042   bool do_backward = false;
4043   uint8* workspace = nullptr;
4044   size_t workspace_size = 0;
4045   std::unique_ptr<TemporaryDeviceMemory<uint8>> wsp_mem;
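  // When workspace caching is enabled (float pooling only), the forward
  // call below is issued with do_backward == true so MIOpen also produces
  // the backward workspace. That workspace is cached keyed on the input
  // pointer and descriptors, letting a later DoPoolBackward reuse it
  // instead of re-running the forward pass.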
4046   if (m_pooling_cache_enabled && element_type == dnn::DataType::kFloat) {
4047     do_backward = true;
4048     auto status = wrap::miopenPoolingGetWorkSpaceSizeV2(
4049         pooling_desc.handle(), dest_desc.handle(), &workspace_size);
4050     if (status != miopenStatusSuccess) {
4051       return port::InternalError(absl::StrCat(
4052           "Failed to obtain workspace size for backward pooling on stream: ",
4053           ToString(status)));
4054     }
4055     if (workspace_size != 0) {
4056       PoolingWorkspaceDescriptor* pdesc = 0;
4057       bool cache_hit =
4058           m_pooling_cache_allowed &&
4059           m_pooling_cache.find(input_data.opaque(), input_dimensions,
4060                                output_dimensions, pooling_dimensions,
4061                                miopenFloat, pdesc);
4062       if (cache_hit) {
4063         // reusing the same buffer
4064         workspace = reinterpret_cast<uint8*>(
4065             pdesc->workspace->mutable_device_memory()->opaque());
4066       } else {
4067         wsp_mem = stream->AllocateTemporaryArray<uint8>(workspace_size).value();
4068         workspace = reinterpret_cast<uint8*>(
4069             wsp_mem->mutable_device_memory()->opaque());
4070         m_pooling_cache.insert(input_data.opaque(), input_dimensions,
4071                                output_dimensions, pooling_dimensions,
4072                                miopenFloat, wsp_mem, workspace_size,
4073                                AsGpuStreamValue(stream));
4074       }
4075     }
4076   }
4077 
4078   auto status = wrap::miopenPoolingForward(
4079       miopen.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
4080       input_data.opaque(), &beta, dest_desc.handle(), output_data.opaque(),
4081       do_backward, workspace, workspace_size);
4082   if (status != miopenStatusSuccess) {
4083     return port::InternalError(absl::StrCat(
4084         "Failed to enqueue forward pooling on stream: ", ToString(status)));
4085   }
4086   return port::Status::OK();
4087 }
4088 
IsSame(const dnn::BatchDescriptor & input_dimensions,const dnn::BatchDescriptor & output_dimensions,const dnn::PoolingDescriptor & pooling_dimensions,int _type)4089 bool PoolingWorkspaceDescriptor::IsSame(
4090     const dnn::BatchDescriptor& input_dimensions,
4091     const dnn::BatchDescriptor& output_dimensions,
4092     const dnn::PoolingDescriptor& pooling_dimensions, int _type) {
4093   return dtype == _type &&
4094          input_dims ==
4095              input_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX) &&
4096          output_dims ==
4097              output_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX) &&
4098          op.mode() == pooling_dimensions.mode() &&
4099          op.window() == pooling_dimensions.window() &&
4100          op.padding() == pooling_dimensions.padding() &&
4101          op.strides() == pooling_dimensions.strides();
4102 }
4103 
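// The pooling workspace cache is keyed on the input-data pointer; a lookup
// only hits if the tensor/pooling descriptors and data type recorded at
// insert time also match, since the same buffer may later be used for a
// different pooling problem.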
find(const void * p,const dnn::BatchDescriptor & input_dimensions,const dnn::BatchDescriptor & output_dimensions,const dnn::PoolingDescriptor & pooling_dimensions,int _type,PoolingWorkspaceDescriptor * & pdesc)4104 bool PoolingWorkspaceCache::find(
4105     const void* p, const dnn::BatchDescriptor& input_dimensions,
4106     const dnn::BatchDescriptor& output_dimensions,
4107     const dnn::PoolingDescriptor& pooling_dimensions, int _type,
4108     PoolingWorkspaceDescriptor*& pdesc) {
4109   pdesc = 0;
4110   auto it = cache.find(p);
4111   if (it == cache.end()) {
4112     return false;
4113   }
4114   if (!it->second.IsSame(input_dimensions, output_dimensions,
4115                          pooling_dimensions, _type)) {
4116     return false;
4117   }
4118   pdesc = &it->second;
4119   return true;
4120 }
4121 
insert(const void * p,const dnn::BatchDescriptor & input_dimensions,const dnn::BatchDescriptor & output_dimensions,const dnn::PoolingDescriptor & pooling_dimensions,int _type,std::unique_ptr<TemporaryDeviceMemory<uint8>> & workspace,size_t wsp_size,hipStream_t hip_stream)4122 void PoolingWorkspaceCache::insert(
4123     const void* p, const dnn::BatchDescriptor& input_dimensions,
4124     const dnn::BatchDescriptor& output_dimensions,
4125     const dnn::PoolingDescriptor& pooling_dimensions, int _type,
4126     std::unique_ptr<TemporaryDeviceMemory<uint8>>& workspace, size_t wsp_size,
4127     hipStream_t hip_stream) {
4128   PoolingWorkspaceDescriptor* desc = 0;
4129   auto it = cache.find(p);
4130   if (it != cache.end()) {
4131     // replacing an entry with the same pointer but different attributes
4132     // (if everything matches, the caller is expected to reuse the entry)
4133     desc = &it->second;
4134     hipStreamSynchronize(hip_stream);
4135     memory_used -= desc->workspace_size;
4136   } else {
4137     cache[p] = PoolingWorkspaceDescriptor();
4138     desc = &cache[p];
4139   }
4140   desc->input_dims = input_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX);
4141   desc->output_dims =
4142       output_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX);
4143   desc->op = pooling_dimensions;
4144   desc->dtype = _type;
4145   desc->timestamp = timestamp;
4146   timestamp++;
4147   desc->workspace = std::move(workspace);
4148   desc->workspace_size = wsp_size;
4149   memory_used += wsp_size;
4150   trim(hip_stream);
4151 }
4152 
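// Trims the cache once the memory budget or entry-count threshold is
// exceeded. Each pass evicts entries older than the newest ~75% of
// insertions (roughly the oldest quarter), synchronizing the stream first so
// that no in-flight work still references the workspaces being freed, and
// repeats until usage is back under budget or fewer than 10 entries remain.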
trim(hipStream_t hip_stream)4153 void PoolingWorkspaceCache::trim(hipStream_t hip_stream) {
4154   if (memory_used < memory_budget && cache.size() < trim_size) return;
4155   bool must_sync = true;
4156   while (true) {
4157     int new_size = cache.size() - (cache.size() >> 2);
4158     std::vector<const void*> old_entries;
4159     for (auto& x : cache)
4160       if (x.second.timestamp + new_size < timestamp)
4161         old_entries.push_back(x.first);
4162     if (old_entries.empty()) break;
4163     if (must_sync) hipStreamSynchronize(hip_stream);
4164     must_sync = true;
4165     for (auto x : old_entries) {
4166       memory_used -= cache[x].workspace_size;
4167       cache.erase(x);
4168     }
4169     if (memory_used < memory_budget || cache.size() < 10) break;
4170   }
4171 }
4172 
DoPoolBackward(dnn::DataType element_type,Stream * stream,const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,DeviceMemoryBase input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemoryBase output_data,DeviceMemoryBase input_diff_data,DeviceMemoryBase output_diff_data,ScratchAllocator * workspace_allocator)4173 port::Status MIOpenSupport::DoPoolBackward(
4174     dnn::DataType element_type, Stream* stream,
4175     const dnn::PoolingDescriptor& pooling_dimensions,
4176     const dnn::BatchDescriptor& input_dimensions, DeviceMemoryBase input_data,
4177     const dnn::BatchDescriptor& output_dimensions, DeviceMemoryBase output_data,
4178     DeviceMemoryBase input_diff_data, DeviceMemoryBase output_diff_data,
4179     ScratchAllocator* workspace_allocator) {
4180   if (element_type == dnn::DataType::kDouble) {
4181     return port::Status(port::error::INVALID_ARGUMENT,
4182                         "MIOpen does not support pooling for double type yet");
4183   }
4184 
4185   auto miopen = miopen_->GetHandle(parent_, stream);
4186   if (m_pooling_cache_allowed) m_pooling_cache_enabled = true;
4187   // Alpha is the scaling factor for input.
4188   float alpha = 1.0;
4189   // Beta is the scaling factor for output.
4190   float beta = 0.0;
4191 
4192   auto miopen_dtype =
4193       element_type == dnn::DataType::kFloat ? miopenFloat : miopenHalf;
4194 
4195   ScopedTensorDescriptor src_desc{input_dimensions, miopen_dtype};
4196   ScopedTensorDescriptor dest_desc{output_dimensions, miopen_dtype};
4197   ScopedPoolingDescriptor pooling_desc{pooling_dimensions};
4198 
4199   uint8* workspace_ptr = 0;
4200   DeviceMemory<uint8> workspace;
4201   PoolingWorkspaceDescriptor* pdesc = 0;
4202 
4203   size_t workspace_size_in_bytes = 0;
4204   auto status = wrap::miopenPoolingGetWorkSpaceSizeV2(
4205       pooling_desc.handle(), dest_desc.handle(), &workspace_size_in_bytes);
4206   if (status != miopenStatusSuccess) {
4207     return port::InternalError(absl::StrCat(
4208         "Failed to obtain workspace size for backward pooling on stream: ",
4209         ToString(status)));
4210   }
4211 
4212   // Allocate the workspace.
4213   if (workspace_size_in_bytes > 0) {
4214     bool cache_hit = m_pooling_cache_allowed &&
4215                      m_pooling_cache.find(input_data.opaque(), input_dimensions,
4216                                           output_dimensions, pooling_dimensions,
4217                                           miopen_dtype, pdesc);
4218     if (cache_hit) {
4219       assert(pdesc != 0);
4220       workspace_ptr = reinterpret_cast<uint8*>(
4221           pdesc->workspace->mutable_device_memory()->opaque());
4222       VLOG(1) << "Pooling cache hit";
4223     } else {
4224       VLOG(1) << "Pooling cache miss";
4225       assert(workspace_allocator);
4226       auto allocated =
4227           workspace_allocator->AllocateBytes(workspace_size_in_bytes);
4228       if (!allocated.ok() || (workspace = allocated.ValueOrDie()) == nullptr) {
4229         return port::InternalError(
4230             "Failed to allocate backward pooling workspace");
4231       }
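      // On a cache miss the workspace produced by forward pooling (which
      // miopenPoolingBackward requires) is not available, so the forward
      // pass is re-run here into a throwaway buffer (dest2) purely to
      // regenerate that workspace.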
4232       DeviceMemory<uint8> dest2;  // duplicated dest from forward:
4233       int64_t dest2_size = 0;
4234 
4235       // miopen requires the strides and dims to be ordered as BDYX.
4236       std::vector<int64_t> dims64 =
4237           output_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX);
4238       // miopen does not use strides and must have 4D tensor.
4239       // std::vector<int> dims(pooling_dimensions.ndims() + 2);
4240 
4241       dest2_size = (element_type == dnn::DataType::kFloat)
4242                        ? sizeof(float)
4243                        : sizeof(Eigen::half);
4244       for (auto& x : dims64) dest2_size *= x;
4245 
4246       if (dest2_size > 0) {
4247         assert(workspace_allocator);
4248         auto allocated = workspace_allocator->AllocateBytes(dest2_size);
4249         if (!allocated.ok() || (dest2 = allocated.ValueOrDie()) == nullptr) {
4250           return port::InternalError(
4251               "Failed to allocate backward pooling workspace");
4252         }
4253       } else {
4254         LOG(ERROR) << "Failed to calculate tensor size to chain forward and "
4255                       "backward pooling";
4256       }
4257 
4258       status = wrap::miopenPoolingForward(
4259           miopen.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
4260           input_data.opaque(), &beta, dest_desc.handle(), dest2.opaque(), true,
4261           workspace.opaque(), workspace_size_in_bytes);
4262 
4263       if (status != miopenStatusSuccess) {
4264         return port::InternalError(absl::StrCat(
4265             "Failed to enqueue forward pooling (before backward) on stream: ",
4266             ToString(status)));
4267       }
4268       workspace_ptr = reinterpret_cast<uint8*>(workspace.opaque());
4269     }
4270   }
4271 
4272   status = wrap::miopenPoolingBackward(
4273       miopen.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
4274       output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
4275       src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
4276       output_diff_data.opaque(), workspace_ptr);
4277 
4278   if (status != miopenStatusSuccess) {
4279     return port::InternalError(absl::StrCat(
4280         "Failed to enqueue backward pooling on stream: ", ToString(status)));
4281   }
4282   return port::Status::OK();
4283 }
4284 
DoNormalizeWithDimensions(Stream * stream,const dnn::NormalizeDescriptor & normalize_descriptor,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,DeviceMemory<float> * output_data)4285 bool MIOpenSupport::DoNormalizeWithDimensions(
4286     Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
4287     const dnn::BatchDescriptor& dimensions,
4288     const DeviceMemory<float>& input_data, DeviceMemory<float>* output_data) {
4289   // Check for unsupported modes.
4290   if (normalize_descriptor.wrap_around()) {
4291     LOG(ERROR) << "MIOpen LRN does not support wrap-around mode";
4292     return false;
4293   }
4294   if (normalize_descriptor.segment_size()) {
4295     LOG(ERROR) << "MIOpen LRN does not support segmentation";
4296     return false;
4297   }
4298 
4299   auto miopen = miopen_->GetHandle(parent_, stream);
4300 
4301   // Launch the normalization.
4302   ScopedTensorDescriptor dims{dimensions, miopenFloat};
4303   ScopedNormalizeDescriptor normalize{normalize_descriptor};
4304 
4305   // Alpha is the scaling factor for input.
4306   float alpha = 1.0f;
4307   // Beta is the scaling factor for output.
4308   float beta = 0.0f;
4309 
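  // The trailing (false, nullptr) arguments indicate that no backward pass
  // will follow, so MIOpen does not need to retain a workspace.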
4310   auto status = wrap::miopenLRNForward(
4311       miopen.handle(), normalize.handle(), &alpha, dims.handle(),
4312       input_data.opaque(), &beta, dims.handle(), output_data->opaque(), false,
4313       nullptr);
4314   if (status != miopenStatusSuccess) {
4315     LOG(ERROR) << "failed to run miopenLRNForward";
4316     return false;
4317   }
4318   return true;
4319 }
4320 
4321 bool MIOpenSupport::DoNormalizeBackwardWithDimensions(
4322     Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
4323     const dnn::BatchDescriptor& dimensions, const DeviceMemory<float>& raw_data,
4324     const DeviceMemory<float>& normalized_data,
4325     const DeviceMemory<float>& normalized_variable_gradient,
4326     DeviceMemory<float>* raw_variable_gradient,
4327     ScratchAllocator* workspace_allocator) {
4328   // Check for unsupported modes.
4329   if (normalize_descriptor.wrap_around()) {
4330     LOG(ERROR) << "MIOpen LRN does not support wrap-around mode";
4331     return false;
4332   }
4333   if (normalize_descriptor.segment_size()) {
4334     LOG(ERROR) << "MIOpen LRN does not support segmentation";
4335     return false;
4336   }
4337 
4338   auto miopen = miopen_->GetHandle(parent_, stream);
4339 
4340   ScopedTensorDescriptor dims{dimensions, miopenFloat};
4341   ScopedNormalizeDescriptor normalize{normalize_descriptor};
4342 
4343   float alpha = 1.0f;
4344   float beta = 0.0f;
4345 
4346   DeviceMemory<uint8> workspace;
4347   size_t workspace_size_in_bytes = 0;
4348   auto status =
4349       wrap::miopenLRNGetWorkSpaceSize(dims.handle(), &workspace_size_in_bytes);
4350 
4351   if (status != miopenStatusSuccess) {
4352     LOG(ERROR) << "failed to obtain workspace size for miopenLRNBackward";
4353     return false;
4354   }
4355 
4356   // Allocate the workspace.
4357   if (workspace_size_in_bytes > 0) {
4358     assert(workspace_allocator);
4359     auto allocated =
4360         workspace_allocator->AllocateBytes(workspace_size_in_bytes);
4361     if (!allocated.ok() || (workspace = allocated.ValueOrDie()) == nullptr) {
4362       LOG(ERROR) << "Failed to allocate backward LRN workspace";
4363       return false;
4364     }
4365   }
4366 
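  // miopenLRNBackward needs the workspace populated by a forward LRN call.
  // The original forward result was not kept, so recompute it here into a
  // scratch buffer (dest2) solely to fill that workspace.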
4367   DeviceMemory<uint8> dest2;  // Scratch copy of the forward LRN output.
4368   int64_t dest2_size = 0;
4369 
4370   // miopen requires the strides and dims to be ordered as BDYX.
4371   std::vector<int64_t> dims64 =
4372       dimensions.full_dims(dnn::DataLayout::kBatchDepthYX);
4373 
4374   // miopen does not use strides and must have 4D tensor.
4375   std::vector<int> dimsint(4);
4376 
4377   std::transform(dims64.cbegin(), dims64.cend(), dimsint.begin(),
4378                  &CheckedNarrowing<int64_t, int>);
4379 
4380   dest2_size = static_cast<int64_t>(dimsint[0]) * dimsint[1] * dimsint[2] *
4381                dimsint[3] * sizeof(float);
4382 
4383   if (dest2_size > 0) {
4384     assert(workspace_allocator);
4385     auto allocated = workspace_allocator->AllocateBytes(dest2_size);
4386     if (!allocated.ok() || (dest2 = allocated.ValueOrDie()) == nullptr) {
4387       LOG(ERROR)
4388           << "Failed to allocate tensor to chain forward and backward LRN";
4389       return false;
4390     }
4391   } else {
4392     LOG(ERROR) << "Failed to calculate tensor size to chain forward and "
4393                   "backward LRN";
4394   }
4395 
4396   status = wrap::miopenLRNForward(miopen.handle(), normalize.handle(), &alpha,
4397                                   dims.handle(), raw_data.opaque(), &beta,
4398                                   dims.handle(), dest2.opaque(), true,
4399                                   workspace.opaque());
4400 
4401   if (status != miopenStatusSuccess) {
4402     LOG(ERROR) << "failed to run miopenLRNForward";
4403     return false;
4404   }
4405 
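  // Arguments: the normalized output and its gradient, the original input,
  // and the workspace filled by the forward call above; the result is the
  // gradient with respect to the raw input.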
4406   status = wrap::miopenLRNBackward(
4407       miopen.handle(), normalize.handle(), &alpha, dims.handle(),
4408       normalized_data.opaque(), dims.handle(),
4409       normalized_variable_gradient.opaque(), dims.handle(), raw_data.opaque(),
4410       &beta, dims.handle(), raw_variable_gradient->opaque(),
4411       workspace.opaque());
4412 
4413   if (status != miopenStatusSuccess) {
4414     LOG(ERROR) << "failed to run miopenLRNBackward";
4415     return false;
4416   }
4417   return true;
4418 }
4419 
4420 bool MIOpenSupport::DoDepthConcatenate(
4421     Stream* stream, absl::Span<const dnn::BatchDescriptor> input_dimensions,
4422     absl::Span<const DeviceMemory<float>* const> input_data,
4423     DeviceMemory<float>* output_data) {
4424   CHECK_EQ(input_dimensions.size(), input_data.size());
4425 
4426   for (const auto& dimensions : input_dimensions) {
4427     if (dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
4428       LOG(ERROR) << "MIOpenSupport::DoDepthConcatenate currently only "
4429                     "supports the kBatchDepthYX layout.";
4430       return false;
4431     }
4432   }
4433 
4434   if (input_dimensions.empty()) {
4435     return true;  // Nothing to do.
4436   }
4437 
4438   dnn::BatchDescriptor output_dimensions =
4439       dnn::BatchDescriptor::DepthConcatenateOutputDescriptor(input_dimensions);
4440 
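  // Depth concatenation is done on the host: each input is copied to the
  // host, interleaved into output_host at its depth offset, and the combined
  // result is copied back to the device. Correct, but not a fast path.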
4441   const int64_t area = output_dimensions.width() * output_dimensions.height();
4442   const auto index = [area](int64_t batch, int64_t depth, int64_t yx,
4443                             int64_t max_depth) {
4444     return (batch * max_depth + depth) * area + yx;
4445   };
4446 
4447   std::vector<float> output_host(output_dimensions.ElementCount());
4448   std::vector<float> tmp;
4449   int64_t depth_sum = 0;
4450   for (size_t i = 0; i < input_data.size(); ++i) {
4451     const auto& dimensions = input_dimensions[i];
4452     tmp.resize(dimensions.ElementCount());
4453     stream->ThenMemcpyD2H<float>(*input_data[i], absl::MakeSpan(tmp));
4454     port::Status block_status = stream->BlockHostUntilDone();
4455     if (!block_status.ok()) {
4456       LOG(ERROR) << "BlockHostUntilDone failed: " << block_status;
4457       return false;
4458     }
4459 
4460     for (int64_t batch = 0; batch < output_dimensions.count(); ++batch) {
4461       for (int64_t yx = 0; yx < area; ++yx) {
4462         for (int64_t depth = 0; depth < dimensions.feature_map_count();
4463              ++depth) {
4464           VLOG(3) << output_dimensions.ElementCount() << ' ' << batch << ' '
4465                   << yx << ' ' << depth;
4466           output_host[index(batch, depth + depth_sum, yx,
4467                             output_dimensions.feature_map_count())] =
4468               tmp[index(batch, depth, yx, dimensions.feature_map_count())];
4469         }
4470       }
4471     }
4472     depth_sum += dimensions.feature_map_count();
4473   }
4474   stream->ThenMemcpyH2D<float>(output_host, output_data);
4475   return true;
4476 }
4477 
4478 bool MIOpenSupport::DoElementwiseOperate(
4479     Stream* stream, dnn::ElementwiseOperation operation,
4480     absl::Span<const dnn::BatchDescriptor> input_dimensions,
4481     absl::Span<const DeviceMemory<float>* const> input_data,
4482     const dnn::BatchDescriptor& output_dimensions,
4483     DeviceMemory<float>* output_data) {
4484   LOG(FATAL) << "not yet implemented";  // TODO(leary)
4485   return false;
4486 }
4487 
4488 bool MIOpenSupport::DoXYPad(Stream* stream,
4489                             const dnn::BatchDescriptor& dimensions,
4490                             const DeviceMemory<float>& input_data,
4491                             int64_t left_pad, int64_t right_pad,
4492                             int64_t top_pad, int64_t bottom_pad,
4493                             DeviceMemory<float>* output_data) {
4494   LOG(FATAL) << "not yet implemented";  // TODO(leary)
4495   return false;
4496 }
4497 
4498 bool MIOpenSupport::DoXYSlice(Stream* stream,
4499                               const dnn::BatchDescriptor& dimensions,
4500                               const DeviceMemory<float>& input_data,
4501                               int64_t left_trim, int64_t right_trim,
4502                               int64_t top_trim, int64_t bottom_trim,
4503                               DeviceMemory<float>* output_data) {
4504   LOG(FATAL) << "not yet implemented";  // TODO(leary)
4505   return false;
4506 }
4507 
4508 bool MIOpenSupport::DoMemcpyD2HQuantized(
4509     Stream* stream, const DeviceMemory<float>& gpu_unquantized_src,
4510     dnn::QuantizedActivationMode mode, void* host_dst, int64_t size) {
4511   LOG(ERROR) << "quantized memcpy not supported by MIOpen";
4512   return false;
4513 }
4514 
4515 bool MIOpenSupport::DoMemcpyH2DQuantized(
4516     Stream* stream, const void* host_src, int64_t size,
4517     dnn::QuantizedActivationMode mode,
4518     DeviceMemory<float>* gpu_unquantized_dst) {
4519   LOG(ERROR) << "quantized memcpy not supported by MIOpen";
4520   return false;
4521 }
4522 
4523 bool MIOpenSupport::DeriveOutputBatchDescriptor(
4524     const BatchDescriptor& batch_descriptor,
4525     const FilterDescriptor& filter_descriptor,
4526     const dnn::ConvolutionDescriptor& convolution_descriptor,
4527     dnn::BatchDescriptor* output_batch_descriptor) {
4528   ScopedTensorDescriptor input_nd{batch_descriptor, miopenFloat};
4529   ScopedFilterDescriptor filter{filter_descriptor, miopenFloat};
4530   ScopedConvolutionDescriptor conv{convolution_descriptor, miopenFloat};
4531 
4532   int dn = batch_descriptor.ndims() + 2;
4533   std::vector<int> dims(dn);  // in BDYX
4534   auto status = wrap::miopenGetConvolutionNdForwardOutputDim(
4535       conv.handle(), input_nd.handle(), filter.handle(), &dn, dims.data());
4536   if (status != miopenStatusSuccess) {
4537     LOG(ERROR) << "could not get output tensor for convolution: "
4538                << ToString(status);
4539     return false;
4540   }
4541 
4542   output_batch_descriptor->set_count(dims[0])
4543       .set_feature_map_count(dims[1])
4544       .set_layout(batch_descriptor.layout());
4545 
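  // dims is in BDYX order while dnn::DimIndex counts spatial dimensions
  // starting at X, so the spatial entries are read back-to-front.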
4546   for (int i = 0; i < batch_descriptor.ndims(); i++) {
4547     output_batch_descriptor->set_spatial_dim(static_cast<dnn::DimIndex>(i),
4548                                              dims.rbegin()[i]);
4549   }
4550 
4551   return true;
4552 }
4553 
4554 template <typename T>
4555 bool MIOpenSupport::DoFusedConvolutionBiasActivationImpl(
4556     Stream* stream,
4557     int miopen_type,  // Actually miopenDataType_t.
4558     const dnn::BatchDescriptor& conv_input_descriptor,
4559     const DeviceMemory<T>& conv_input_data,
4560     const dnn::FilterDescriptor& filter_descriptor,
4561     const DeviceMemory<T>& filter_data,
4562     const dnn::ConvolutionDescriptor& convolution_descriptor,
4563     const dnn::BatchDescriptor& bias_descriptor,
4564     const DeviceMemory<T>& bias_data, dnn::ActivationMode activation_mode,
4565     const dnn::BatchDescriptor& output_descriptor, DeviceMemory<T>* output_data,
4566     dnn::ProfileResult* output_profile_result) {
4567   auto miopen = miopen_->GetHandle(parent_, stream);
4568 
4569   ScopedTensorDescriptor conv_input_nd{
4570       conv_input_descriptor, static_cast<miopenDataType_t>(miopen_type)};
4571 
4572   ScopedTensorDescriptor bias_nd{bias_descriptor,
4573                                  static_cast<miopenDataType_t>(miopen_type)};
4574 
4575   ScopedTensorDescriptor output_nd{output_descriptor,
4576                                    static_cast<miopenDataType_t>(miopen_type)};
4577 
4578   ScopedConvolutionDescriptor conv{convolution_descriptor,
4579                                    static_cast<miopenDataType_t>(miopen_type)};
4580 
4581   ScopedFilterDescriptor filter{filter_descriptor,
4582                                 static_cast<miopenDataType_t>(miopen_type)};
4583 
4584   ScopedActivationDescriptor activation_desc{activation_mode};
4585 
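  // Build a MIOpen fusion plan that runs the convolution, bias addition, and
  // activation as one fused operation; if the plan does not compile for this
  // configuration, the function simply returns false.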
4586   ScopedFusionPlanConvolutionBiasActivation fusion_plan{
4587       miopen.handle(), conv_input_nd.handle(), filter.handle(),
4588       conv.handle(),   bias_nd.handle(),       activation_desc};
4589 
4590   bool retval = false;
4591 
4592   if (fusion_plan.CompilationSucceeded()) {
4593     const bool is_profiling = output_profile_result != nullptr;
4594 
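    // When profiling, bracket the fused launch with GPU timer events so the
    // elapsed time can be reported through output_profile_result.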
4595     std::unique_ptr<GpuTimer> timer;
4596     if (is_profiling) {
4597       timer.reset(new GpuTimer(parent_));
4598       timer->Init();
4599       timer->Start(AsGpuStream(stream));
4600     }
4601 
4602     miopenStatus_t status = miopenStatusSuccess;
4603 
4604     if (status == miopenStatusSuccess) {
4605       status = fusion_plan.SetConvolutionArgs(filter_data.opaque());
4606     }
4607 
4608     if (status == miopenStatusSuccess) {
4609       status = fusion_plan.SetBiasArgs(bias_data.opaque());
4610     }
4611 
4612     if (status == miopenStatusSuccess) {
4613       status = fusion_plan.SetActivationForwardArgs(activation_desc);
4614     }
4615 
4616     if (status == miopenStatusSuccess) {
4617       status =
4618           fusion_plan.Execute(conv_input_nd.handle(), conv_input_data.opaque(),
4619                               output_nd.handle(), output_data->opaque());
4620     }
4621 
4622     if (is_profiling) {
4623       timer->Stop(AsGpuStream(stream));
4624       if (status == miopenStatusSuccess) {
4625         output_profile_result->set_elapsed_time_in_ms(
4626             timer->GetElapsedMilliseconds());
4627       }
4628       timer->Destroy();
4629     }
4630 
4631     if (status != miopenStatusSuccess) {
4632       // Silently return when we are profiling.
4633       if (!is_profiling) {
4634         LOG(FATAL) << "failed to enqueue fused-convolution on stream: "
4635                    << ToString(status);
4636       }
4637     }
4638 
4639     retval = true;
4640   }
4641 
4642   return retval;
4643 }
4644 
4645 bool MIOpenSupport::DoFusedConvolutionBiasActivation(
4646     Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
4647     const DeviceMemory<float>& conv_input_data,
4648     const dnn::FilterDescriptor& filter_descriptor,
4649     const DeviceMemory<float>& filter_data,
4650     const dnn::ConvolutionDescriptor& convolution_descriptor,
4651     const dnn::BatchDescriptor& bias_descriptor,
4652     const DeviceMemory<float>& bias_data, dnn::ActivationMode activation_mode,
4653     const dnn::BatchDescriptor& output_descriptor,
4654     DeviceMemory<float>* output_data,
4655     dnn::ProfileResult* output_profile_result) {
4656   return DoFusedConvolutionBiasActivationImpl<float>(
4657       stream, miopenFloat, conv_input_descriptor, conv_input_data,
4658       filter_descriptor, filter_data, convolution_descriptor, bias_descriptor,
4659       bias_data, activation_mode, output_descriptor, output_data,
4660       output_profile_result);
4661 }
4662 
4663 template <typename T, typename U>
4664 bool MIOpenSupport::DoFusedBatchNormActivationInferenceImpl(
4665     Stream* stream,
4666     int miopen_type,  // Actually miopenDataType_t.
4667     const dnn::BatchDescriptor& x_descriptor, const DeviceMemory<T>& x_data,
4668     const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
4669     const DeviceMemory<U>& scale_data, const DeviceMemory<U>& offset_data,
4670     const DeviceMemory<U>& mean_data, const DeviceMemory<U>& variance_data,
4671     double epsilon, dnn::ActivationMode activation_mode,
4672     DeviceMemory<T>* y_data, dnn::ProfileResult* output_profile_result) {
4673   auto miopen = miopen_->GetHandle(parent_, stream);
4674 
4675   ScopedTensorDescriptor x_nd{x_descriptor,
4676                               static_cast<miopenDataType_t>(miopen_type)};
4677 
4678   ScopedTensorDescriptor scale_offset_mean_variance_nd{
4679       scale_offset_mean_variance_descriptor,
4680       static_cast<miopenDataType_t>(miopen_type)};
4681 
4682   ScopedActivationDescriptor activation_desc{activation_mode};
4683 
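  // The inference fusion plan normalizes x with the precomputed population
  // mean/variance, applies scale and offset, and then the requested
  // activation, all as one fused operation.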
4684   ScopedFusionPlanBatchNormActivationInference fusion_plan{
4685       miopen.handle(), x_nd.handle(), scale_offset_mean_variance_nd.handle(),
4686       activation_desc};
4687 
4688   bool retval = false;
4689 
4690   if (fusion_plan.CompilationSucceeded()) {
4691     const bool is_profiling = output_profile_result != nullptr;
4692 
4693     std::unique_ptr<GpuTimer> timer;
4694     if (is_profiling) {
4695       timer.reset(new GpuTimer(parent_));
4696       timer->Init();
4697       timer->Start(AsGpuStream(stream));
4698     }
4699 
4700     miopenStatus_t status = miopenStatusSuccess;
4701 
4702     if (status == miopenStatusSuccess) {
4703       status = fusion_plan.SetBatchNormInferenceArgs(
4704           scale_data.opaque(), offset_data.opaque(), mean_data.opaque(),
4705           variance_data.opaque(), epsilon);
4706     }
4707 
4708     if (status == miopenStatusSuccess) {
4709       status = fusion_plan.SetActivationForwardArgs(activation_desc);
4710     }
4711 
4712     if (status == miopenStatusSuccess) {
4713       status = fusion_plan.Execute(x_nd.handle(), x_data.opaque(),
4714                                    x_nd.handle(), y_data->opaque());
4715     }
4716 
4717     if (is_profiling) {
4718       timer->Stop(AsGpuStream(stream));
4719       if (status == miopenStatusSuccess) {
4720         output_profile_result->set_elapsed_time_in_ms(
4721             timer->GetElapsedMilliseconds());
4722       }
4723       timer->Destroy();
4724     }
4725 
4726     if (status != miopenStatusSuccess) {
4727       // Silently return when we are profiling.
4728       if (!is_profiling) {
4729         LOG(FATAL) << "failed to enqueue fused batchnorm-activation "
4730                    << "inference on stream: " << ToString(status);
4731       }
4732     }
4733 
4734     retval = true;
4735   }
4736 
4737   return retval;
4738 }
4739 
4740 bool MIOpenSupport::DoFusedBatchNormActivationInference(
4741     Stream* stream, const dnn::BatchDescriptor& x_descriptor,
4742     const DeviceMemory<float>& x_data,
4743     const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
4744     const DeviceMemory<float>& scale_data,
4745     const DeviceMemory<float>& offset_data,
4746     const DeviceMemory<float>& mean_data,
4747     const DeviceMemory<float>& variance_data, double epsilon,
4748     dnn::ActivationMode activation_mode, DeviceMemory<float>* y_data,
4749     dnn::ProfileResult* output_profile_result) {
4750   return DoFusedBatchNormActivationInferenceImpl<float, float>(
4751       stream, miopenFloat, x_descriptor, x_data,
4752       scale_offset_mean_variance_descriptor, scale_data, offset_data, mean_data,
4753       variance_data, epsilon, activation_mode, y_data, output_profile_result);
4754 }
4755 
4756 bool MIOpenSupport::DoFusedBatchNormActivationInference(
4757     Stream* stream, const dnn::BatchDescriptor& x_descriptor,
4758     const DeviceMemory<Eigen::half>& x_data,
4759     const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
4760     const DeviceMemory<float>& scale_data,
4761     const DeviceMemory<float>& offset_data,
4762     const DeviceMemory<float>& mean_data,
4763     const DeviceMemory<float>& variance_data, double epsilon,
4764     dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y_data,
4765     dnn::ProfileResult* output_profile_result) {
4766   return DoFusedBatchNormActivationInferenceImpl<Eigen::half, float>(
4767       stream, miopenHalf, x_descriptor, x_data,
4768       scale_offset_mean_variance_descriptor, scale_data, offset_data, mean_data,
4769       variance_data, epsilon, activation_mode, y_data, output_profile_result);
4770 }
4771 
4772 template <typename T, typename U>
4773 bool MIOpenSupport::DoFusedBatchNormActivationForwardImpl(
4774     Stream* stream,
4775     int miopen_type,  // Actually miopenDataType_t.
4776     const dnn::BatchDescriptor& x_descriptor, const DeviceMemory<T>& x_data,
4777     const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
4778     const DeviceMemory<U>& scale_data, const DeviceMemory<U>& offset_data,
4779     double epsilon, dnn::ActivationMode activation_mode,
4780     DeviceMemory<T>* y_data, DeviceMemory<U>* batch_mean_data,
4781     DeviceMemory<U>* batch_var_data, DeviceMemory<U>* saved_mean_data,
4782     DeviceMemory<U>* saved_var_data,
4783     dnn::ProfileResult* output_profile_result) {
4784   auto miopen = miopen_->GetHandle(parent_, stream);
4785 
4786   ScopedTensorDescriptor x_nd{x_descriptor,
4787                               static_cast<miopenDataType_t>(miopen_type)};
4788 
4789   ScopedTensorDescriptor scale_offset_mean_variance_nd{
4790       scale_offset_mean_variance_descriptor,
4791       static_cast<miopenDataType_t>(miopen_type)};
4792 
4793   ScopedActivationDescriptor activation_desc{activation_mode};
4794 
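  // The training-time fusion plan normalizes with statistics computed from
  // the current batch and also emits the batch mean/variance plus the saved
  // mean/variance needed by the backward pass.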
4795   ScopedFusionPlanBatchNormActivationForward fusion_plan{
4796       miopen.handle(), x_nd.handle(), scale_offset_mean_variance_nd.handle(),
4797       activation_desc};
4798 
4799   bool retval = false;
4800 
4801   if (fusion_plan.CompilationSucceeded()) {
4802     const bool is_profiling = output_profile_result != nullptr;
4803 
4804     std::unique_ptr<GpuTimer> timer;
4805     if (is_profiling) {
4806       timer.reset(new GpuTimer(parent_));
4807       timer->Init();
4808       timer->Start(AsGpuStream(stream));
4809     }
4810 
4811     miopenStatus_t status = miopenStatusSuccess;
4812 
4813     if (status == miopenStatusSuccess) {
4814       status = fusion_plan.SetBatchNormForwardArgs(
4815           scale_data.opaque(), offset_data.opaque(), batch_mean_data->opaque(),
4816           batch_var_data->opaque(), saved_mean_data->opaque(),
4817           saved_var_data->opaque(), epsilon);
4818     }
4819 
4820     if (status == miopenStatusSuccess) {
4821       status = fusion_plan.SetActivationForwardArgs(activation_desc);
4822     }
4823 
4824     if (status == miopenStatusSuccess) {
4825       status = fusion_plan.Execute(x_nd.handle(), x_data.opaque(),
4826                                    x_nd.handle(), y_data->opaque());
4827     }
4828 
4829     if (is_profiling) {
4830       timer->Stop(AsGpuStream(stream));
4831       if (status == miopenStatusSuccess) {
4832         output_profile_result->set_elapsed_time_in_ms(
4833             timer->GetElapsedMilliseconds());
4834       }
4835       timer->Destroy();
4836     }
4837 
4838     if (status != miopenStatusSuccess) {
4839       // Silently return when we are profiling.
4840       if (!is_profiling) {
4841         LOG(FATAL) << "failed to enqueue fused batchnorm-activation "
4842                    << "forward on stream: " << ToString(status);
4843       }
4844     }
4845 
4846     retval = true;
4847   }
4848 
4849   return retval;
4850 }
4851 
4852 bool MIOpenSupport::DoFusedBatchNormActivationForward(
4853     Stream* stream, const dnn::BatchDescriptor& x_descriptor,
4854     const DeviceMemory<float>& x_data,
4855     const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
4856     const DeviceMemory<float>& scale_data,
4857     const DeviceMemory<float>& offset_data, double epsilon,
4858     dnn::ActivationMode activation_mode, DeviceMemory<float>* y_data,
4859     DeviceMemory<float>* batch_mean_data, DeviceMemory<float>* batch_var_data,
4860     DeviceMemory<float>* saved_mean_data, DeviceMemory<float>* saved_var_data,
4861     dnn::ProfileResult* output_profile_result) {
4862   return DoFusedBatchNormActivationForwardImpl<float, float>(
4863       stream, miopenFloat, x_descriptor, x_data,
4864       scale_offset_mean_variance_descriptor, scale_data, offset_data, epsilon,
4865       activation_mode, y_data, batch_mean_data, batch_var_data, saved_mean_data,
4866       saved_var_data, output_profile_result);
4867 }
4868 
4869 bool MIOpenSupport::DoFusedBatchNormActivationForward(
4870     Stream* stream, const dnn::BatchDescriptor& x_descriptor,
4871     const DeviceMemory<Eigen::half>& x_data,
4872     const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
4873     const DeviceMemory<float>& scale_data,
4874     const DeviceMemory<float>& offset_data, double epsilon,
4875     dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y_data,
4876     DeviceMemory<float>* batch_mean_data, DeviceMemory<float>* batch_var_data,
4877     DeviceMemory<float>* saved_mean_data, DeviceMemory<float>* saved_var_data,
4878     dnn::ProfileResult* output_profile_result) {
4879   return DoFusedBatchNormActivationForwardImpl<Eigen::half, float>(
4880       stream, miopenHalf, x_descriptor, x_data,
4881       scale_offset_mean_variance_descriptor, scale_data, offset_data, epsilon,
4882       activation_mode, y_data, batch_mean_data, batch_var_data, saved_mean_data,
4883       saved_var_data, output_profile_result);
4884 }
4885 
4886 template <typename T, typename U>
4887 bool MIOpenSupport::DoFusedBatchNormActivationBackwardImpl(
4888     Stream* stream,
4889     int miopen_type,  // Actually miopenDataType_t.
4890     const dnn::BatchDescriptor& y_act_backprop_descriptor,
4891     const DeviceMemory<T>& y_act_backprop_data,
4892     const DeviceMemory<T>& y_act_data, dnn::ActivationMode activation_mode,
4893     const DeviceMemory<T>& x_bn_data,
4894     const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
4895     const DeviceMemory<U>& scale_data, const DeviceMemory<U>& offset_data,
4896     const DeviceMemory<U>& saved_mean_data,
4897     const DeviceMemory<U>& saved_var_data, DeviceMemory<T>* x_bn_backprop_data,
4898     DeviceMemory<U>* scale_backprop_data, DeviceMemory<U>* offset_backprop_data,
4899     dnn::ProfileResult* output_profile_result) {
4900   auto miopen = miopen_->GetHandle(parent_, stream);
4901 
4902   ScopedTensorDescriptor y_act_backprop_nd{
4903       y_act_backprop_descriptor, static_cast<miopenDataType_t>(miopen_type)};
4904 
4905   ScopedTensorDescriptor scale_offset_mean_variance_nd{
4906       scale_offset_mean_variance_descriptor,
4907       static_cast<miopenDataType_t>(miopen_type)};
4908 
4909   ScopedActivationDescriptor activation_desc{activation_mode};
4910 
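  // The backward fusion plan combines the activation gradient with the
  // batchnorm gradient, using the saved statistics from the forward pass to
  // produce gradients for x, scale, and offset.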
4911   ScopedFusionPlanBatchNormActivationBackward fusion_plan{
4912       miopen.handle(), y_act_backprop_nd.handle(),
4913       scale_offset_mean_variance_nd.handle(), activation_desc};
4914 
4915   bool retval = false;
4916 
4917   if (fusion_plan.CompilationSucceeded()) {
4918     const bool is_profiling = output_profile_result != nullptr;
4919 
4920     std::unique_ptr<GpuTimer> timer;
4921     if (is_profiling) {
4922       timer.reset(new GpuTimer(parent_));
4923       timer->Init();
4924       timer->Start(AsGpuStream(stream));
4925     }
4926 
4927     miopenStatus_t status = miopenStatusSuccess;
4928 
4929     if (status == miopenStatusSuccess) {
4930       status = fusion_plan.SetBatchNormBackwardArgs(
4931           x_bn_data.opaque(), scale_data.opaque(), offset_data.opaque(),
4932           saved_mean_data.opaque(), saved_var_data.opaque(),
4933           scale_backprop_data->opaque(), offset_backprop_data->opaque());
4934     }
4935 
4936     if (status == miopenStatusSuccess) {
4937       status = fusion_plan.SetActivationBackwardArgs(activation_desc,
4938                                                      y_act_data.opaque());
4939     }
4940 
4941     if (status == miopenStatusSuccess) {
4942       status = fusion_plan.Execute(
4943           y_act_backprop_nd.handle(), y_act_backprop_data.opaque(),
4944           y_act_backprop_nd.handle(), x_bn_backprop_data->opaque());
4945     }
4946 
4947     if (is_profiling) {
4948       timer->Stop(AsGpuStream(stream));
4949       if (status == miopenStatusSuccess) {
4950         output_profile_result->set_elapsed_time_in_ms(
4951             timer->GetElapsedMilliseconds());
4952       }
4953       timer->Destroy();
4954     }
4955 
4956     if (status != miopenStatusSuccess) {
4957       // Silently return when we are profiling.
4958       if (!is_profiling) {
4959         LOG(FATAL) << "failed to enqueue fused batchnorm-activation "
4960                    << "backward on stream: " << ToString(status);
4961       }
4962     }
4963 
4964     retval = true;
4965   }
4966 
4967   return retval;
4968 }
4969 
4970 bool MIOpenSupport::DoFusedBatchNormActivationBackward(
4971     Stream* stream, const dnn::BatchDescriptor& y_act_backprop_descriptor,
4972     const DeviceMemory<float>& y_act_backprop_data,
4973     const DeviceMemory<float>& y_act_data, dnn::ActivationMode activation_mode,
4974     const DeviceMemory<float>& x_bn_data,
4975     const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
4976     const DeviceMemory<float>& scale_data,
4977     const DeviceMemory<float>& offset_data,
4978     const DeviceMemory<float>& saved_mean_data,
4979     const DeviceMemory<float>& saved_var_data,
4980     DeviceMemory<float>* x_bn_backprop_data,
4981     DeviceMemory<float>* scale_backprop_data,
4982     DeviceMemory<float>* offset_backprop_data,
4983     dnn::ProfileResult* output_profile_result) {
4984   return DoFusedBatchNormActivationBackwardImpl<float, float>(
4985       stream, miopenFloat, y_act_backprop_descriptor, y_act_backprop_data,
4986       y_act_data, activation_mode, x_bn_data,
4987       scale_offset_mean_variance_descriptor, scale_data, offset_data,
4988       saved_mean_data, saved_var_data, x_bn_backprop_data, scale_backprop_data,
4989       offset_backprop_data, output_profile_result);
4990 }
4991 
4992 bool MIOpenSupport::DoFusedBatchNormActivationBackward(
4993     Stream* stream, const dnn::BatchDescriptor& y_act_backprop_descriptor,
4994     const DeviceMemory<Eigen::half>& y_act_backprop_data,
4995     const DeviceMemory<Eigen::half>& y_act_data,
4996     dnn::ActivationMode activation_mode,
4997     const DeviceMemory<Eigen::half>& x_bn_data,
4998     const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
4999     const DeviceMemory<float>& scale_data,
5000     const DeviceMemory<float>& offset_data,
5001     const DeviceMemory<float>& saved_mean_data,
5002     const DeviceMemory<float>& saved_var_data,
5003     DeviceMemory<Eigen::half>* x_bn_backprop_data,
5004     DeviceMemory<float>* scale_backprop_data,
5005     DeviceMemory<float>* offset_backprop_data,
5006     dnn::ProfileResult* output_profile_result) {
5007   return DoFusedBatchNormActivationBackwardImpl<Eigen::half, float>(
5008       stream, miopenHalf, y_act_backprop_descriptor, y_act_backprop_data,
5009       y_act_data, activation_mode, x_bn_data,
5010       scale_offset_mean_variance_descriptor, scale_data, offset_data,
5011       saved_mean_data, saved_var_data, x_bn_backprop_data, scale_backprop_data,
5012       offset_backprop_data, output_profile_result);
5013 }
5014 
5015 }  // namespace gpu
5016 
5017 void initialize_miopen() {
5018   auto miopenAlreadyRegistered = PluginRegistry::Instance()->HasFactory(
5019       rocm::kROCmPlatformId, PluginKind::kDnn, gpu::kMIOpenPlugin);
5020 
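  // The factory lambda below constructs an MIOpenSupport instance for a ROCm
  // StreamExecutor when the plugin registry requests DNN support.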
5021   if (!miopenAlreadyRegistered) {
5022     port::Status status =
5023         PluginRegistry::Instance()->RegisterFactory<PluginRegistry::DnnFactory>(
5024             rocm::kROCmPlatformId, gpu::kMIOpenPlugin, "MIOpen",
5025             [](internal::StreamExecutorInterface* parent) -> dnn::DnnSupport* {
5026               gpu::GpuExecutor* rocm_executor =
5027                   dynamic_cast<gpu::GpuExecutor*>(parent);
5028               if (rocm_executor == nullptr) {
5029                 LOG(ERROR)
5030                     << "Attempting to initialize an instance of the MIOpen "
5031                     << "support library with a non-ROCM StreamExecutor";
5032                 return nullptr;
5033               }
5034 
5035               gpu::MIOpenSupport* dnn = new gpu::MIOpenSupport(rocm_executor);
5036               if (!dnn->Init().ok()) {
5037                 // Note: Init() will log a more specific error.
5038                 delete dnn;
5039                 return nullptr;
5040               }
5041               return dnn;
5042             });
5043 
5044     if (!status.ok()) {
5045       LOG(ERROR) << "Unable to register MIOpen factory: "
5046                  << status.error_message();
5047     }
5048 
5049     PluginRegistry::Instance()->SetDefaultFactory(
5050         rocm::kROCmPlatformId, PluginKind::kDnn, gpu::kMIOpenPlugin);
5051   }
5052 }
5053 
5054 }  // namespace stream_executor
5055 
5056 REGISTER_MODULE_INITIALIZER(register_miopen,
5057                             { stream_executor::initialize_miopen(); });
5058