/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUSTOM_CALL_THUNK_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUSTOM_CALL_THUNK_H_

// C++ standard library headers used directly by this header (IWYU: the
// class below names std::optional, std::function, std::vector, std::string
// and size_t, so the header must be self-contained).
#include <cstddef>
#include <functional>
#include <optional>
#include <string>
#include <vector>

#include "tensorflow/compiler/xla/service/custom_call_status_internal.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/stream_executor/gpu/gpu_types.h"
#endif

namespace xla {
namespace gpu {

// Thunk to run a GPU custom call.
//
// This thunk's `ExecuteOnStream` implementation executes a host function
// `call_target` which is expected to enqueue operations onto the GPU.
//
// Note that not all kCustomCall HLOs in XLA:GPU end up being run by this thunk.
// XLA itself creates kCustomCall instructions when lowering kConvolution HLOs
// into calls to cudnn. These internally-created custom-calls are run using
// ConvolutionThunk, not CustomCallThunk. There's no ambiguity because they
// have special call target names (e.g. "__cudnn$convForward") that only the
// compiler is allowed to create.
class CustomCallThunk : public Thunk {
 public:
  // A buffer slice that may be absent (std::nullopt when the corresponding
  // operand/result has no buffer assigned — presumably e.g. token-typed
  // values; confirm against the thunk emitter).
  using OptionalSlice = ::std::optional<BufferAllocation::Slice>;

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  // On CUDA/ROCm builds the call target receives the platform's raw GPU
  // stream handle so it can enqueue work directly.
  using Stream = stream_executor::gpu::GpuStreamHandle;
#else   // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  // Without a GPU platform the stream is an opaque pointer placeholder.
  using Stream = void*;
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

  // Signature of the host function run by ExecuteOnStream:
  //   (stream, buffer pointers, opaque data, opaque length, status).
  // NOTE(review): the layout of the void** buffer array (operands followed
  // by results?) is not visible here — confirm against the .cc file.
  using CustomCallTarget = std::function<void(Stream, void**, const char*,
                                              size_t, XlaCustomCallStatus*)>;

  // Builds a thunk that, when executed, invokes `call_target` with the device
  // addresses resolved from `operands` and `results` plus the `opaque` string
  // captured at compile time.
  CustomCallThunk(ThunkInfo thunk_info, CustomCallTarget call_target,
                  std::vector<OptionalSlice> operands,
                  std::vector<OptionalSlice> results,
                  const std::string& opaque);

  // Runs `call_target_` on the host; the target is expected to enqueue GPU
  // work onto the stream in `params`.
  Status ExecuteOnStream(const ExecuteParams& params) override;

 private:
  const CustomCallTarget call_target_;
  // Buffer slices for the call's operands and results, in HLO order.
  const std::vector<OptionalSlice> operands_;
  const std::vector<OptionalSlice> results_;
  // Backend-config string forwarded verbatim to the call target.
  const std::string opaque_;
};

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUSTOM_CALL_THUNK_H_