/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUBLAS_CUDNN_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUBLAS_CUDNN_H_

#include <string>

#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/core/platform/statusor.h"

namespace xla {
namespace gpu {
// Different types of convolutions supported by cudnn.
//
// A way to think about these is that a convolution is defined by three arrays
// -- the "input", the "filter", and the "output" -- and given any two of these,
// we can compute the third.  For example, a backward-input convolution takes as
// input a filter and an "output" and produces an "input" such that if one were
// to do a forward convolution of "input" using filter, the result would be
// something with the same shape as "output".
//
// This way of thinking is not correct if you look at the values produced. For
// example, a backward-input convolution is not actually the mathematical
// inverse of a forward convolution.  But it's right as far as the shapes and
// "connectivity" (i.e. which elements of the input affect which elements of
// the output) are concerned.
enum class CudnnConvKind {
  kForward,            // input  + filter => output
  kBackwardInput,      // filter + output => input
  kBackwardFilter,     // input  + output => filter
  kForwardActivation,  // activation(conv(input, filter) + broadcast(bias) +
                       // (optionally) side_input) => output
};
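
// For example (illustrative only): with a 1D convolution, stride 1, and no
// padding, a forward convolution of a 5-element input with a 3-element filter
// produces a 5 - 3 + 1 = 3-element output.  The matching backward-input
// convolution takes the 3-element filter and the 3-element "output" and
// produces a 5-element "input"; backward-filter takes the 5-element input and
// the 3-element output and produces a 3-element filter.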

// Returns the convolution kind for the given cuDNN convolution custom call,
// or an error if `instr`'s call target is not one of the cuDNN convolution
// custom call targets declared below.
StatusOr<CudnnConvKind> GetCudnnConvKind(const HloCustomCallInstruction* instr);

// Converts a CudnnConvKind value to a string.
std::string CudnnConvKindToString(CudnnConvKind kind);
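
// Illustrative usage of the two functions above (a sketch, not part of this
// header's contract), assuming `conv` is a custom call already known to be
// one of the cuDNN convolutions declared below:
//
//   StatusOr<CudnnConvKind> kind = GetCudnnConvKind(conv);
//   if (kind.ok()) {
//     VLOG(1) << "conv kind: " << CudnnConvKindToString(*kind);
//   }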

// Matrix multiplication rewritten into a GEMM custom call.
// All matrix multiplications should be rewritten as such custom calls
// after a GemmRewriter lowering pass.
bool IsCublasGemm(const HloInstruction& hlo);

// Matrix multiplication that calls into cublasLt.
bool IsCublasLtMatmul(const HloInstruction& hlo);

// A call to the cuBLAS general matrix multiplication (GEMM) API.
extern const char* const kGemmCallTarget;

// A call to the cuBLASLt matrix multiplication API.
extern const char* const kCublasLtMatmulCallTarget;
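
// Illustrative sketch (not part of this header): after GemmRewriter has run,
// a lowered matmul can be recognized with the predicates above, and its call
// target inspected directly:
//
//   if (IsCublasGemm(hlo) || IsCublasLtMatmul(hlo)) {
//     VLOG(1) << "matmul custom call target: " << hlo.custom_call_target();
//   }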

// A call to cuBLAS for a triangular solve.
//
// Like cudnn convolutions, this op returns a tuple (result, scratch_memory).
extern const char* const kTriangularSolveCallTarget;
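
// Since the call returns a (result, scratch_memory) tuple, consumers read the
// result via a get-tuple-element at index 0.  A sketch (`call` and `comp` are
// illustrative names for the custom call and its parent HloComputation):
//
//   HloInstruction* result = comp->AddInstruction(
//       HloInstruction::CreateGetTupleElement(
//           ShapeUtil::GetTupleElementShape(call->shape(), 0), call, 0));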

// A call to cuDNN for convolution (forward, backward filter, or backward input)
// is represented as a CustomCall HLO with a call target equal to one of these
// strings.
//
// These CustomCalls have window() and convolution_dimension_numbers() set like
// regular convolution ops.  They have the same LHS and RHS operands, plus two
// additional constant operands: an int64_t operand for the cudnn algorithm and
// a bool operand for whether tensor_ops is enabled. A value of -1 for the cudnn
// algorithm means that the implementation is free to choose the best algorithm
// it can.
//
// These calls output a tuple (conv_result, scratch_memory), where conv_result
// is the actual result of the convolution, and scratch_memory is temporary
// memory used by cudnn.  Callers shouldn't inspect scratch_memory, as its value
// is not well-defined.
//
// GpuConvRewriter lowers kConvolution HLOs to these custom calls.
// When it does so, it chooses algorithm -1 and 0 bytes of scratch space.  Later
// on in the pipeline, GpuConvAlgorithmPicker chooses an explicit
// algorithm for each conv and sets the amount of scratch space needed.
//
// (Representing the scratch memory as an output may seem strange at first, but
// it's quite sensible, from a certain point of view.  The scratch buffer is a
// location in memory that the conv can write into, but which it can't legally
// read from, at least until it's written something first.  But that's exactly
// the definition of an output buffer.)
extern const char* const kCudnnConvForwardCallTarget;
extern const char* const kCudnnConvBackwardInputCallTarget;
extern const char* const kCudnnConvBackwardFilterCallTarget;
extern const char* const kCudnnConvBiasActivationForwardCallTarget;
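
// A sketch of how such a call can be inspected (illustrative, not part of
// this header): the conv result's shape is element 0 of the call's tuple
// shape.
//
//   if (hlo.custom_call_target() == kCudnnConvForwardCallTarget) {
//     const Shape& conv_result_shape = hlo.shape().tuple_shapes(0);
//     VLOG(1) << "forward conv result shape: "
//             << ShapeUtil::HumanString(conv_result_shape);
//   }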

// Returns true if `hlo` will be implemented as a call to a cuDNN convolution
// routine.
//
// This returns true if `hlo` is a CustomCall HLO with a call target equal to
// one of the kCudnnConvFoo constants above, but returns *false* for HLOs with a
// kConvolution opcode.
bool IsCustomCallToDnnConvolution(const HloInstruction& hlo);
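
// Illustrative usage (a sketch), assuming `computation` is an
// HloComputation*:
//
//   for (HloInstruction* instr : computation->instructions()) {
//     if (IsCustomCallToDnnConvolution(*instr)) {
//       // `instr` is one of the kCudnnConv* custom calls above.
//     }
//   }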
109 
110 }  // namespace gpu
111 }  // namespace xla
112 
113 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUBLAS_CUDNN_H_
114