xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/gpu/stream_executor_util.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_
18 
#include <cstdint>
#include <memory>
#include <optional>
#include <tuple>

#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout.h"
#include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h"
#include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/protobuf/autotuning.pb.h"
#include "tensorflow/stream_executor/kernel_spec.h"
31 
32 // Helper functions for interacting with StreamExecutor.
33 
34 namespace xla {
35 namespace gpu {
36 
37 // Returns (input, filter, output) XLA Layout protos given the StreamExecutor
38 // layouts.
39 StatusOr<std::tuple<Layout, Layout, Layout>>
40 StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums,
41                                       se::dnn::DataLayout input,
42                                       se::dnn::FilterLayout filter,
43                                       se::dnn::DataLayout output);
44 
45 // Returns (input, filter, output) StreamExecutor layouts given the XLA layouts.
46 StatusOr<
47     std::tuple<se::dnn::DataLayout, se::dnn::FilterLayout, se::dnn::DataLayout>>
48 XlaConvShapesToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums,
49                                      const Shape& input, const Shape& filter,
50                                      const Shape& output);
51 
52 // Finds the VECT_C dimension in input/filter/output, if present.
53 //
54 // A cudnn convolution may have layout NCHW_VECT_C, which means instead of
55 // [N,C,H,W], the layout is [N,C/k,H,W,k] for some k (usually 4 or 32).
56 //
57 // ConvolutionDimensionNumbers doesn't explicitly store which is the `k`
58 // dimension, because only cudnn convolutions have this feature; it's not
59 // applicable elsewhere.  We find it by finding a dimension in the
60 // input/filter/output shape that is *not* in dnums.
61 std::tuple<std::optional<int64_t>, std::optional<int64_t>,
62            std::optional<int64_t>>
63 FindVectorizedFeatureDims(const ConvolutionDimensionNumbers& dnums,
64                           const Shape& input, const Shape& filter,
65                           const Shape& output);
66 
67 // Generates and returns a unique lock per the provided executor.
68 // Guarantees that blocks of code running for the same provided
69 // executor will not be running concurrently if they lock the returned mutex.
70 //
71 // This is used to prevent other XLA instances from trying to autotune on a
72 // device while another thread is using it.
73 absl::Mutex& GetGpuMutex(const se::StreamExecutor* stream_exec);
74 
75 // Creates a kernel with a provided name, based from provided PTX in ptx.
76 // The kernel should be executed using the provided executor.
77 // The argument cubin_data represents compiled PTX and may be left empty.
78 //
79 // The canonical storage for both ptx and cubin_data should outlive
80 // the lifetime of the kernel.
81 StatusOr<std::unique_ptr<se::KernelBase>> CreateKernel(
82     absl::string_view kernel_name, uint64_t num_args, absl::string_view ptx,
83     absl::Span<const uint8_t> cubin_data, se::StreamExecutor* stream_exec);
84 
85 // Runs loaded kernel on the stream with the provided arguments.
86 Status ExecuteKernelOnStream(const se::KernelBase& kernel,
87                              absl::Span<const se::DeviceMemoryBase> args,
88                              const LaunchDimensions& dims, se::Stream* stream);
89 
90 // Initializes `buffer` with random data on `stream`.
91 // `rng_state` is an inout parameter for the pseudorandom generator state.
92 // `buffer_type` determines what buffer would be filled out with.
93 //
94 // Precondition: `buffer_type` is a floating point type, `rng_state` needs to be
95 // initialized to zero on the first use.
96 void InitializeBuffer(se::Stream* stream, PrimitiveType buffer_type,
97                       int64_t* rng_state, se::DeviceMemoryBase buffer);
98 
99 StatusOr<se::dnn::ConvolutionKind> GetDNNConvKindFromCudnnConvKind(
100     CudnnConvKind kind);
101 StatusOr<se::dnn::DataType> GetDNNDataTypeFromPrimitiveType(PrimitiveType type);
102 
103 // Returns result with the smallest time which has not failed.
104 // If deterministic output is requested, returns first (not failing) result.
105 StatusOr<tensorflow::AutotuneResult> PickBestResult(
106     absl::Span<tensorflow::AutotuneResult const> profile_results,
107     const HloInstruction& instr);
108 
109 // Returns whether determinism is required.
110 bool RequireDeterminism(const HloModuleConfig& config);
111 
112 }  // namespace gpu
113 }  // namespace xla
114 
115 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_
116