1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h" 17 18 namespace xla { 19 namespace gpu { 20 IsCublasGemm(const HloInstruction & hlo)21bool IsCublasGemm(const HloInstruction& hlo) { 22 return hlo.opcode() == HloOpcode::kCustomCall && 23 hlo.custom_call_target() == kGemmCallTarget; 24 } 25 IsCublasLtMatmul(const HloInstruction & hlo)26bool IsCublasLtMatmul(const HloInstruction& hlo) { 27 return hlo.opcode() == HloOpcode::kCustomCall && 28 hlo.custom_call_target() == kCublasLtMatmulCallTarget; 29 } 30 31 const char* const kGemmCallTarget = "__cublas$gemm"; 32 const char* const kCublasLtMatmulCallTarget = "__cublas$lt$matmul"; 33 const char* const kTriangularSolveCallTarget = "__cublas$triangularSolve"; 34 const char* const kCudnnConvForwardCallTarget = "__cudnn$convForward"; 35 const char* const kCudnnConvBackwardInputCallTarget = 36 "__cudnn$convBackwardInput"; 37 const char* const kCudnnConvBackwardFilterCallTarget = 38 "__cudnn$convBackwardFilter"; 39 const char* const kCudnnConvBiasActivationForwardCallTarget = 40 "__cudnn$convBiasActivationForward"; 41 IsCustomCallToDnnConvolution(const HloInstruction & hlo)42bool IsCustomCallToDnnConvolution(const HloInstruction& hlo) { 43 if (hlo.opcode() != HloOpcode::kCustomCall) { 44 return false; 45 } 46 const auto& target = hlo.custom_call_target(); 47 return target == kCudnnConvForwardCallTarget || 48 target == kCudnnConvBackwardInputCallTarget || 49 target == kCudnnConvBackwardFilterCallTarget || 50 target == kCudnnConvBiasActivationForwardCallTarget; 51 } 52 GetCudnnConvKind(const HloCustomCallInstruction * instr)53StatusOr<CudnnConvKind> GetCudnnConvKind( 54 const HloCustomCallInstruction* instr) { 55 absl::string_view target = instr->custom_call_target(); 56 if (target == kCudnnConvForwardCallTarget) { 57 return CudnnConvKind::kForward; 58 } 59 if (target == kCudnnConvBackwardInputCallTarget) { 60 return CudnnConvKind::kBackwardInput; 61 } 62 if (target == kCudnnConvBackwardFilterCallTarget) { 63 return CudnnConvKind::kBackwardFilter; 64 } 65 if (target == kCudnnConvBiasActivationForwardCallTarget) { 66 return CudnnConvKind::kForwardActivation; 67 } 68 return InternalError("Unexpected call target: %s", target); 69 } 70 CudnnConvKindToString(CudnnConvKind kind)71std::string CudnnConvKindToString(CudnnConvKind kind) { 72 switch (kind) { 73 case CudnnConvKind::kForward: 74 return "forward"; 75 case CudnnConvKind::kBackwardFilter: 76 return "backward_filter"; 77 case CudnnConvKind::kBackwardInput: 78 return "backward_input"; 79 case CudnnConvKind::kForwardActivation: 80 return "forward with activation"; 81 } 82 } 83 84 } // namespace gpu 85 } // namespace xla 86