xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/gpu/cudnn_support_utils.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/cudnn_support_utils.h"
17 
18 #include <functional>
19 
20 #include "tensorflow/compiler/xla/primitive_util.h"
21 #include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h"
22 #include "tensorflow/core/platform/status.h"
23 
24 namespace xla {
25 namespace gpu {
26 
CudnnSupportsOptimizedIntegerConvolution(const se::CudaComputeCapability & compute_capability,HloCustomCallInstruction & conv,int vector_size)27 StatusOr<bool> CudnnSupportsOptimizedIntegerConvolution(
28     const se::CudaComputeCapability& compute_capability,
29     HloCustomCallInstruction& conv, int vector_size) {
30   TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(&conv));
31   const Shape& input_shape = conv.operand(0)->shape();
32   const Shape& kernel_shape = conv.operand(1)->shape();
33   const Shape& result_shape = conv.shape().tuple_shapes(0);
34   const auto& dnums = conv.convolution_dimension_numbers();
35 
36   // Only vectorization/padding of 4 or 32 for integers is supported.
37   if (vector_size != 4 && vector_size != 32) {
38     VLOG(3) << "Unsupported vector size for integer convolution: "
39             << vector_size;
40     return false;
41   }
42 
43   // Require cc6.1+ for any vectorized integer convolutions
44   // Require cc7.5+ for any IMMA convolutions
45   if ((vector_size == 32 && !compute_capability.IsAtLeast(7, 5)) ||
46       !compute_capability.IsAtLeast(6, 1)) {
47     VLOG(3) << "Compute capability " << compute_capability.ToString()
48             << " is not sufficent for int8x" << vector_size
49             << " vectorization.";
50     return false;
51   }
52 
53   // kForward and kForwardActivation only
54   if (kind != CudnnConvKind::kForward &&
55       kind != CudnnConvKind::kForwardActivation) {
56     VLOG(3) << "Convolution kind is not forward or foward-activation: "
57             << conv.ToString();
58     return false;
59   }
60 
61   // Integer inputs/weights only
62   if (!primitive_util::IsIntegralType(input_shape.element_type()) ||
63       !primitive_util::IsIntegralType(kernel_shape.element_type())) {
64     VLOG(3) << "Convolution does not accept integer inputs/weights: "
65             << conv.ToString();
66     return false;
67   }
68 
69   // 2D convolutions only
70   if (dnums.input_spatial_dimensions().size() != 2 ||
71       dnums.kernel_spatial_dimensions().size() != 2 ||
72       dnums.output_spatial_dimensions().size() != 2) {
73     VLOG(3) << "Convolution is not 2D: " << conv.ToString();
74     return false;
75   }
76 
77   // Only allow for int8x32 when output is also integer
78   if (vector_size == 32 &&
79       !primitive_util::IsIntegralType(result_shape.element_type())) {
80     VLOG(3) << "int8x32 convolutions only support integer output: "
81             << conv.ToString();
82     return false;
83   }
84 
85   // For int8x32 convolution check to see if the input/filter size are
86   // consistent with the limitation for cuDNN algo1. Per cuDNN release notes:
87   // "In INT8x32 Tensor Core cases, the parameters supported by cuDNN v7.6 are
88   // limited to W >= (R-1) * dilationW && H >= (S-1) * dilationH, whereas, in
89   // cuDNN v8.0.x, W == (R-1) * dilationW || H == (S-1) * dilationH cases are no
90   // longer supported."
91   //
92   // This check is more strict than necessary for cuDNN v7 (allowed for
93   // equality) to avoid checking the version of cuDNN explicitly.
94   if (vector_size == 32) {
95     int64_t W = input_shape.dimensions(dnums.input_spatial_dimensions()[0]);
96     int64_t H = input_shape.dimensions(dnums.input_spatial_dimensions()[1]);
97     int64_t R = kernel_shape.dimensions(dnums.kernel_spatial_dimensions()[0]);
98     int64_t S = kernel_shape.dimensions(dnums.kernel_spatial_dimensions()[1]);
99     const int64_t dilationW = conv.window().dimensions()[0].base_dilation();
100     const int64_t dilationH = conv.window().dimensions()[1].base_dilation();
101     if ((W <= (R - 1) * dilationW) || (H <= (S - 1) * dilationH)) {
102       VLOG(3) << "Conv spatial filter/input dimensions are too small for "
103                  "vecotrized int8x32 convolution: "
104               << conv.ToString();
105       return false;
106     }
107   }
108 
109   // Dilation is not supported with integer convs.
110   if (window_util::HasDilation(conv.window())) {
111     VLOG(3) << "Vectorized integer convolutions do not support dilation: "
112             << conv.ToString();
113     return false;
114   }
115 
116   return true;
117 }
118 }  // namespace gpu
119 }  // namespace xla
120