/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements the StreamExecutor interface by passing through to its
// implementation_ value (in pointer-to-implementation style), which
// implements StreamExecutorInterface.

#include "tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h"

#include <atomic>
#include <memory>
#include <utility>

#include "absl/base/const_init.h"
#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/notification.h"
#include "tensorflow/compiler/xla/stream_executor/blas.h"
#include "tensorflow/compiler/xla/stream_executor/fft.h"
#include "tensorflow/compiler/xla/stream_executor/lib/env.h"
#include "tensorflow/compiler/xla/stream_executor/lib/error.h"
#include "tensorflow/compiler/xla/stream_executor/lib/stacktrace.h"
#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h"
#include "tensorflow/compiler/xla/stream_executor/lib/threadpool.h"
#include "tensorflow/compiler/xla/stream_executor/platform/port.h"
#include "tensorflow/compiler/xla/stream_executor/rng.h"
#include "tensorflow/compiler/xla/stream_executor/stream.h"
#include "tensorflow/compiler/xla/stream_executor/stream_executor_internal.h"
#include "tensorflow/core/util/env_var.h"

namespace {
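// When true, device memory allocations are tracked (see CreateAllocRecord and
// EraseAllocRecord below) so that any allocations still outstanding when the
// StreamExecutor is destroyed can be logged.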
bool FLAGS_check_device_leaks = false;
}  // namespace

namespace stream_executor {
namespace {

std::string StackTraceIfVLOG10() {
  if (VLOG_IS_ON(10)) {
    return absl::StrCat(" ", port::CurrentStackTrace(), "\n");
  } else {
    return "";
  }
}

// Make sure the executor is done with its work; we know (because this isn't
// publicly visible) that all enqueued work is quick.
void BlockOnThreadExecutor(port::ThreadPool* executor) {
  absl::Notification n;
  executor->Schedule([&n]() { n.Notify(); });
  n.WaitForNotification();
}

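// Monotonically increasing counter used to assign a correlation id to each
// traced call so that a listener can match the Begin and Complete callbacks
// of the same call.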
std::atomic_int_fast64_t correlation_id_generator(0);

}  // namespace

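// RAII helper that reports a traced call to all registered TraceListeners:
// the "begin" callback fires at construction and the matching "complete"
// callback (with the call's result) fires at destruction. Both are no-ops
// when tracing is not enabled on the executor.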
template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
          typename... BeginArgsT>
class ScopedTracer {
 public:
  ScopedTracer(StreamExecutor* stream_exec, BeginCallT begin_call,
               CompleteCallT complete_call, const ReturnT* result,
               BeginArgsT... begin_args)
      : stream_exec_(stream_exec),
        complete_call_(complete_call),
        result_(result) {
    if (stream_exec_->tracing_enabled_) {
      correlation_id_ =
          correlation_id_generator.fetch_add(1, std::memory_order_relaxed) - 1;
      Trace(begin_call, begin_args...);
    }
  }

  ~ScopedTracer() {
    if (stream_exec_->tracing_enabled_) {
      Trace(complete_call_, result_);
    }
  }

 private:
  template <typename CallbackT, typename... TraceArgsT>
  void Trace(CallbackT callback, TraceArgsT... args) {
    {
      // Instance tracers held in a block to limit the lock lifetime.
      absl::ReaderMutexLock lock{&stream_exec_->mu_};
      for (TraceListener* listener : stream_exec_->listeners_) {
        (listener->*callback)(correlation_id_,
                              std::forward<TraceArgsT>(args)...);
      }
    }
  }

  StreamExecutor* stream_exec_;
  CompleteCallT complete_call_;
  const ReturnT* result_;
  int64_t correlation_id_;
};

template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
          typename... BeginArgsT>
ScopedTracer<BeginCallT, CompleteCallT, ReturnT, BeginArgsT...>
MakeScopedTracer(StreamExecutor* stream_exec, BeginCallT begin_call,
                 CompleteCallT complete_call, ReturnT* result,
                 BeginArgsT... begin_args) {
  return ScopedTracer<BeginCallT, CompleteCallT, ReturnT, BeginArgsT...>(
      stream_exec, begin_call, complete_call, result,
      std::forward<BeginArgsT>(begin_args)...);
}

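// Declares a ScopedTracer local that brackets the enclosing scope with the
// listener callbacks LOC##Begin and LOC##Complete. Example usage (see
// BlockHostUntilDone below):
//
//   port::Status result;
//   SCOPED_TRACE(TraceListener::BlockHostUntilDone, &result, stream);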
#define SCOPED_TRACE(LOC, ...) \
  auto tracer =                \
      MakeScopedTracer(this, &LOC##Begin, &LOC##Complete, ##__VA_ARGS__);

/* static */ absl::Mutex StreamExecutor::static_mu_{absl::kConstInit};

// Get per-device memory limit in bytes. Returns 0 if
// TF_PER_DEVICE_MEMORY_LIMIT_MB environment variable is not set.
static int64_t GetMemoryLimitBytes() {
  int64_t value;
  SE_CHECK_OK(tensorflow::ReadInt64FromEnvVar("TF_PER_DEVICE_MEMORY_LIMIT_MB",
                                              0, &value));
  return value * (1ll << 20);
}

StreamExecutor::StreamExecutor(
    const Platform* platform,
    std::unique_ptr<internal::StreamExecutorInterface> implementation,
    int device_ordinal)
    : platform_(platform),
      implementation_(std::move(implementation)),
      device_ordinal_(device_ordinal),
      background_threads_(new port::ThreadPool(
          port::Env::Default(), "stream_executor", kNumBackgroundThreads)),
      live_stream_count_(0),
      tracing_enabled_(false),
      mem_alloc_bytes_(0),
      memory_limit_bytes_(GetMemoryLimitBytes()),
      allocator_(this) {
  std::string name = absl::AsciiStrToLower(platform_->Name());
  if (name == "cuda") {
    platform_kind_ = PlatformKind::kCuda;
  } else if (name == "rocm") {
    platform_kind_ = PlatformKind::kROCm;
  } else if (name == "opencl") {
    platform_kind_ = PlatformKind::kOpenCL;
  } else if (name == "host") {
    platform_kind_ = PlatformKind::kHost;
  } else {
    platform_kind_ = PlatformKind::kInvalid;
  }
}

StreamExecutor::~StreamExecutor() {
  BlockOnThreadExecutor(background_threads_.get());

  if (live_stream_count_.load() != 0) {
    LOG(WARNING) << "Not all streams were deallocated at executor destruction "
                 << "time. This may lead to unexpected/bad behavior - "
                 << "especially if any stream is still active!";
  }

  if (FLAGS_check_device_leaks) {
    for (const auto& it : mem_allocs_) {
      LOG(INFO) << "Memory alloced at executor exit: addr: "
                << absl::StrFormat("%p", it.first)
                << ", bytes: " << it.second.bytes << ", trace: \n"
                << it.second.stack_trace;
    }
  }
}

port::Status StreamExecutor::Init(DeviceOptions device_options) {
  return implementation_->Init(device_ordinal_, std::move(device_options));
}

port::Status StreamExecutor::Init() { return Init(DeviceOptions::Default()); }

port::Status StreamExecutor::GetKernel(const MultiKernelLoaderSpec& spec,
                                       KernelBase* kernel) {
  return implementation_->GetKernel(spec, kernel);
}

void StreamExecutor::UnloadKernel(const KernelBase* kernel) {
  implementation_->UnloadKernel(kernel);
}

port::Status StreamExecutor::LoadModule(const MultiModuleLoaderSpec& spec,
                                        ModuleHandle* module_handle) {
  return implementation_->LoadModule(spec, module_handle);
}

bool StreamExecutor::UnloadModule(ModuleHandle module_handle) {
  return implementation_->UnloadModule(module_handle);
}

port::StatusOr<std::shared_ptr<DeviceMemoryBase>>
StreamExecutor::CreateOrShareConstant(Stream* stream,
                                      const std::vector<uint8_t>& content) {
  return implementation_->CreateOrShareConstant(stream, std::move(content));
}

void StreamExecutor::Deallocate(DeviceMemoryBase* mem) {
  VLOG(1) << "Called StreamExecutor::Deallocate(mem=" << mem->opaque()
          << ") mem->size()=" << mem->size() << StackTraceIfVLOG10();

  if (mem->opaque() != nullptr) {
    EraseAllocRecord(mem->opaque());
  }
  implementation_->Deallocate(mem);
  mem->Reset(nullptr, 0);
}

void StreamExecutor::GetMemAllocs(std::map<void*, AllocRecord>* records_out) {
  absl::ReaderMutexLock lock(&mu_);
  *records_out = mem_allocs_;
}

bool StreamExecutor::CanEnablePeerAccessTo(StreamExecutor* other) {
  return implementation_->CanEnablePeerAccessTo(other->implementation_.get());
}

port::Status StreamExecutor::EnablePeerAccessTo(StreamExecutor* other) {
  return implementation_->EnablePeerAccessTo(other->implementation_.get());
}

const DeviceDescription& StreamExecutor::GetDeviceDescription() const {
  absl::MutexLock lock(&mu_);
  if (device_description_ != nullptr) {
    return *device_description_;
  }

  device_description_ = CreateDeviceDescription();
  return *device_description_;
}

int64_t StreamExecutor::GetDeviceLoad() const {
  return implementation_->GetDeviceLoad();
}

int StreamExecutor::PlatformDeviceCount() const {
  return implementation_->PlatformDeviceCount();
}

bool StreamExecutor::SupportsBlas() const {
  return implementation_->SupportsBlas();
}

bool StreamExecutor::SupportsRng() const {
  return implementation_->SupportsRng();
}

bool StreamExecutor::SupportsDnn() const {
  return implementation_->SupportsDnn();
}

bool StreamExecutor::GetConvolveAlgorithms(
    dnn::ConvolutionKind kind,
    std::vector<dnn::AlgorithmDesc>* out_algorithms) {
  dnn::DnnSupport* dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  switch (kind) {
    default:
      return false;
    case dnn::ConvolutionKind::FORWARD:
    case dnn::ConvolutionKind::FORWARD_BIAS_ACTIVATION:
      return dnn_support->GetConvolveAlgorithms(
          GetDeviceDescription().cuda_compute_capability(), out_algorithms);
    case dnn::ConvolutionKind::BACKWARD_DATA:
      return dnn_support->GetConvolveBackwardDataAlgorithms(
          GetDeviceDescription().cuda_compute_capability(), out_algorithms);
    case dnn::ConvolutionKind::BACKWARD_FILTER:
      return dnn_support->GetConvolveBackwardFilterAlgorithms(
          GetDeviceDescription().cuda_compute_capability(), out_algorithms);
  }
}

port::Status StreamExecutor::GetConvolveRunners(
    bool use_cudnn_frontend, dnn::ConvolutionKind kind,
    dnn::DataType input_type, dnn::DataType output_type, Stream* stream,
    const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
    const dnn::FilterDescriptor& filter_descriptor,
    DeviceMemoryBase filter_data,
    const dnn::BatchDescriptor& output_descriptor,
    DeviceMemoryBase output_data,
    const dnn::ConvolutionDescriptor& convolution_descriptor,
    bool use_fallback, ScratchAllocator* scratch_allocator,
    std::vector<std::unique_ptr<const dnn::ConvRunner>>* out_exec_plans) {
  dnn::DnnSupport* dnn_support = AsDnn();
  if (!dnn_support) {
    return port::UnimplementedError("DNN library is not found.");
  }
  return dnn_support->GetConvolveRunners(
      use_cudnn_frontend, kind, input_type, output_type, stream,
      input_descriptor, input_data, filter_descriptor, filter_data,
      output_descriptor, output_data, convolution_descriptor, use_fallback,
      scratch_allocator, out_exec_plans);
}

port::Status StreamExecutor::GetFusedConvolveRunners(
    bool use_cudnn_frontend, dnn::ConvolutionKind kind,
    dnn::DataType input_type, dnn::DataType bias_type,
    dnn::DataType output_type, double conv_input_scale,
    double side_input_scale, double leakyrelu_alpha, Stream* stream,
    const dnn::BatchDescriptor& input_descriptor,
    const dnn::FilterDescriptor& filter_descriptor,
    const dnn::BatchDescriptor& bias_descriptor,
    const dnn::BatchDescriptor& output_descriptor,
    const dnn::ConvolutionDescriptor& convolution_descriptor,
    bool use_fallback, dnn::ActivationMode activation_mode,
    std::vector<std::unique_ptr<const dnn::FusedConvRunner>>* out_exec_plans) {
  dnn::DnnSupport* dnn_support = AsDnn();
  if (!dnn_support) {
    return port::UnimplementedError("DNN library is not found.");
  }
  return dnn_support->GetFusedConvolveRunners(
      use_cudnn_frontend, kind, input_type, bias_type, output_type,
      conv_input_scale, side_input_scale, leakyrelu_alpha, stream,
      input_descriptor, filter_descriptor, bias_descriptor, output_descriptor,
      convolution_descriptor, use_fallback, activation_mode, out_exec_plans);
}

bool StreamExecutor::GetMIOpenConvolveAlgorithms(
    dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
    const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
    const dnn::FilterDescriptor& filter_descriptor,
    DeviceMemoryBase filter_data,
    const dnn::BatchDescriptor& output_descriptor,
    DeviceMemoryBase output_data,
    const dnn::ConvolutionDescriptor& convolution_descriptor,
    ScratchAllocator* scratch_allocator,
    std::vector<dnn::ProfileResult>* out_algorithms) {
  dnn::DnnSupport* dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  return dnn_support->GetMIOpenConvolveAlgorithms(
      kind, element_type, stream, input_descriptor, input_data,
      filter_descriptor, filter_data, output_descriptor, output_data,
      convolution_descriptor, scratch_allocator, out_algorithms);
}

bool StreamExecutor::GetRnnAlgorithms(
    std::vector<dnn::AlgorithmDesc>* out_algorithms) {
  dnn::DnnSupport* dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  return dnn_support->GetRnnAlgorithms(out_algorithms);
}

bool StreamExecutor::GetBlasGemmAlgorithms(
    Stream* stream, std::vector<blas::AlgorithmType>* out_algorithms) {
  blas::BlasSupport* blas_support = AsBlas();
  if (!blas_support) {
    return false;
  }
  return blas_support->GetBlasGemmAlgorithms(stream, out_algorithms);
}

port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
StreamExecutor::createRnnDescriptor(
    int num_layers, int hidden_size, int input_size, int cell_size,
    int batch_size, dnn::RnnInputMode input_mode,
    dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
    dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config,
    float dropout, uint64_t seed, ScratchAllocator* state_allocator,
    bool use_padded_io) {
  dnn::DnnSupport* dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Fail to find the dnn implementation.");
  }
  return dnn_support->createRnnDescriptor(
      num_layers, hidden_size, input_size, cell_size, batch_size, input_mode,
      direction_mode, rnn_mode, data_type, algorithm_config, dropout, seed,
      state_allocator, use_padded_io);
}

port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
StreamExecutor::createRnnSequenceTensorDescriptor(int max_seq_length,
                                                  int batch_size,
                                                  int data_size,
                                                  dnn::DataType data_type) {
  dnn::DnnSupport* dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Fail to find the dnn implementation.");
  }
  return dnn_support->createRnnSequenceTensorDescriptor(
      max_seq_length, batch_size, data_size, data_type);
}

port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
StreamExecutor::createRnnSequenceTensorDescriptor(
    int max_seq_length, int batch_size, int data_size,
    const absl::Span<const int>& seq_lengths, bool time_major,
    dnn::DataType data_type) {
  dnn::DnnSupport* dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Fail to find the dnn implementation.");
  }
  return dnn_support->createRnnSequenceTensorDescriptor(
      max_seq_length, batch_size, data_size, seq_lengths, time_major,
      data_type);
}

port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
StreamExecutor::createRnnStateTensorDescriptor(int num_layer, int batch_size,
                                               int data_size,
                                               dnn::DataType data_type) {
  dnn::DnnSupport* dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Fail to find the dnn implementation.");
  }
  return dnn_support->createRnnStateTensorDescriptor(num_layer, batch_size,
                                                     data_size, data_type);
}

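// The As*() accessors below lazily create the corresponding support library
// wrapper (DNN, BLAS, FFT, RNG) on first use, cache it under mu_, and return
// nullptr if the underlying implementation cannot provide it.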
dnn::DnnSupport* StreamExecutor::AsDnn() {
  absl::MutexLock lock(&mu_);
  if (dnn_ != nullptr) {
    return dnn_.get();
  }

  dnn_.reset(implementation_->CreateDnn());
  return dnn_.get();
}

blas::BlasSupport* StreamExecutor::AsBlas() {
  absl::MutexLock lock(&mu_);
  if (blas_ != nullptr) {
    return blas_.get();
  }

  blas_.reset(implementation_->CreateBlas());
  return blas_.get();
}

fft::FftSupport* StreamExecutor::AsFft() {
  absl::MutexLock lock(&mu_);
  if (fft_ != nullptr) {
    return fft_.get();
  }

  fft_.reset(implementation_->CreateFft());
  return fft_.get();
}

rng::RngSupport* StreamExecutor::AsRng() {
  absl::MutexLock lock(&mu_);
  if (rng_ != nullptr) {
    return rng_.get();
  }

  rng_.reset(implementation_->CreateRng());
  return rng_.get();
}

port::Status StreamExecutor::Launch(Stream* stream,
                                    const ThreadDim& thread_dims,
                                    const BlockDim& block_dims,
                                    const KernelBase& kernel,
                                    const KernelArgsArrayBase& args) {
  SubmitTrace(&TraceListener::LaunchSubmit, stream, thread_dims, block_dims,
              kernel, args);

  return implementation_->Launch(stream, thread_dims, block_dims, kernel,
                                 args);
}

port::Status StreamExecutor::BlockHostUntilDone(Stream* stream) {
  port::Status result;
  SCOPED_TRACE(TraceListener::BlockHostUntilDone, &result, stream);

  result = implementation_->BlockHostUntilDone(stream);
  return result;
}

port::Status StreamExecutor::GetStatus(Stream* stream) {
  return implementation_->GetStatus(stream);
}

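// Allocates device memory, enforcing the optional per-device limit from
// TF_PER_DEVICE_MEMORY_LIMIT_MB; returns a null DeviceMemoryBase if the limit
// would be exceeded or the underlying allocation fails.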
DeviceMemoryBase StreamExecutor::Allocate(uint64_t size, int64_t memory_space) {
  if (memory_limit_bytes_ > 0 &&
      static_cast<int64_t>(mem_alloc_bytes_ + size) > memory_limit_bytes_) {
    LOG(WARNING) << "Not enough memory to allocate " << size << " on device "
                 << device_ordinal_
                 << " within provided limit. [used=" << mem_alloc_bytes_
                 << ", limit=" << memory_limit_bytes_ << "]";
    return DeviceMemoryBase();
  }
  DeviceMemoryBase buf = implementation_->Allocate(size, memory_space);
  VLOG(1) << "Called StreamExecutor::Allocate(size=" << size
          << ", memory_space=" << memory_space << ") returns " << buf.opaque()
          << StackTraceIfVLOG10();
  CreateAllocRecord(buf.opaque(), size);

  return buf;
}

port::StatusOr<DeviceMemoryBase> StreamExecutor::GetUntypedSymbol(
    const std::string& symbol_name, ModuleHandle module_handle) {
  // If failed to get the symbol, opaque/bytes are unchanged. Initialize them
  // to be nullptr/0 for consistency with DeviceMemory semantics.
  void* opaque = nullptr;
  size_t bytes = 0;
  if (GetSymbol(symbol_name, module_handle, &opaque, &bytes)) {
    return DeviceMemoryBase(opaque, bytes);
  }

  return port::Status(
      port::error::NOT_FOUND,
      absl::StrCat("Check if module containing symbol ", symbol_name,
                   " is loaded (module_handle = ",
                   reinterpret_cast<uintptr_t>(module_handle.id()), ")"));
}

bool StreamExecutor::GetSymbol(const std::string& symbol_name,
                               ModuleHandle module_handle, void** mem,
                               size_t* bytes) {
  return implementation_->GetSymbol(symbol_name, module_handle, mem, bytes);
}

void* StreamExecutor::UnifiedMemoryAllocate(uint64_t bytes) {
  void* buffer = implementation_->UnifiedMemoryAllocate(bytes);
  VLOG(1) << "Called StreamExecutor::UnifiedMemoryAllocate(size=" << bytes
          << ") returns " << buffer << StackTraceIfVLOG10();
  return buffer;
}

void StreamExecutor::UnifiedMemoryDeallocate(void* location) {
  VLOG(1) << "Called StreamExecutor::UnifiedMemoryDeallocate(location="
          << location << ")" << StackTraceIfVLOG10();

  return implementation_->UnifiedMemoryDeallocate(location);
}

void* StreamExecutor::HostMemoryAllocate(uint64_t size) {
  void* buffer = implementation_->HostMemoryAllocate(size);
  VLOG(1) << "Called StreamExecutor::HostMemoryAllocate(size=" << size
          << ") returns " << buffer << StackTraceIfVLOG10();
  return buffer;
}

void StreamExecutor::HostMemoryDeallocate(void* location) {
  VLOG(1) << "Called StreamExecutor::HostMemoryDeallocate(location="
          << location << ")" << StackTraceIfVLOG10();

  return implementation_->HostMemoryDeallocate(location);
}

bool StreamExecutor::HostMemoryRegister(void* location, uint64_t size) {
  VLOG(1) << "Called StreamExecutor::HostMemoryRegister(location=" << location
          << ", size=" << size << ")" << StackTraceIfVLOG10();
  if (location == nullptr || size == 0) {
    LOG(WARNING) << "attempting to register null or zero-sized memory: "
                 << location << "; size " << size;
  }
  return implementation_->HostMemoryRegister(location, size);
}

bool StreamExecutor::HostMemoryUnregister(void* location) {
  VLOG(1) << "Called StreamExecutor::HostMemoryUnregister(location="
          << location << ")" << StackTraceIfVLOG10();
  return implementation_->HostMemoryUnregister(location);
}

bool StreamExecutor::SynchronizeAllActivity() {
  VLOG(1) << "Called StreamExecutor::SynchronizeAllActivity()"
          << StackTraceIfVLOG10();
  bool ok = implementation_->SynchronizeAllActivity();

  // This should all be quick and infallible work, so we can perform the
  // synchronization even in the case of failure.
  BlockOnThreadExecutor(background_threads_.get());

  return ok;
}

port::Status StreamExecutor::SynchronousMemZero(DeviceMemoryBase* location,
                                                uint64_t size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemZero(location=" << location
          << ", size=" << size << ")" << StackTraceIfVLOG10();

  return implementation_->SynchronousMemZero(location, size);
}

port::Status StreamExecutor::SynchronousMemSet(DeviceMemoryBase* location,
                                               int value, uint64_t size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemSet(location=" << location
          << ", value=" << value << ", size=" << size << ")"
          << StackTraceIfVLOG10();

  return implementation_->SynchronousMemSet(location, value, size);
}

bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase* device_dst,
                                       const void* host_src, uint64_t size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(device_dst="
          << device_dst->opaque() << ", host_src=" << host_src
          << ", size=" << size << ") H2D" << StackTraceIfVLOG10();

  // Tracing overloaded methods is very difficult due to issues with type
  // inference on template args. Since use of these overloaded methods is
  // discouraged anyway, this isn't a huge deal.
  port::Status status =
      implementation_->SynchronousMemcpy(device_dst, host_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

bool StreamExecutor::SynchronousMemcpy(void* host_dst,
                                       const DeviceMemoryBase& device_src,
                                       uint64_t size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(host_dst=" << host_dst
          << ", device_src=" << device_src.opaque() << ", size=" << size
          << ") D2H" << StackTraceIfVLOG10();

  port::Status status =
      implementation_->SynchronousMemcpy(host_dst, device_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase* device_dst,
                                       const DeviceMemoryBase& device_src,
                                       uint64_t size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(device_dst="
          << device_dst->opaque() << ", device_src=" << device_src.opaque()
          << ", size=" << size << ") D2D" << StackTraceIfVLOG10();

  port::Status status = implementation_->SynchronousMemcpyDeviceToDevice(
      device_dst, device_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

port::Status StreamExecutor::SynchronousMemcpyD2H(
    const DeviceMemoryBase& device_src, int64_t size, void* host_dst) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpyD2H(device_src="
          << device_src.opaque() << ", size=" << size
          << ", host_dst=" << host_dst << ")" << StackTraceIfVLOG10();

  port::Status result;
  SCOPED_TRACE(TraceListener::SynchronousMemcpyD2H, &result, device_src, size,
               host_dst);

  result = implementation_->SynchronousMemcpy(host_dst, device_src, size);
  if (!result.ok()) {
    result = port::Status(
        port::error::INTERNAL,
        absl::StrFormat("failed to synchronously memcpy device-to-host: device "
                        "%p to host %p size %d: %s",
                        device_src.opaque(), host_dst, size,
                        result.ToString()));
  }

  return result;
}

port::Status StreamExecutor::SynchronousMemcpyH2D(
    const void* host_src, int64_t size, DeviceMemoryBase* device_dst) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpyH2D(host_src=" << host_src
          << ", size=" << size << ", device_dst=" << device_dst->opaque() << ")"
          << StackTraceIfVLOG10();

  port::Status result;
  SCOPED_TRACE(TraceListener::SynchronousMemcpyH2D, &result, host_src, size,
               device_dst);

  result = implementation_->SynchronousMemcpy(device_dst, host_src, size);
  if (!result.ok()) {
    result = port::Status(
        port::error::INTERNAL,
        absl::StrFormat("failed to synchronously memcpy host-to-device: host "
                        "%p to device %p size %d: %s",
                        host_src, device_dst->opaque(), size,
                        result.ToString()));
  }

  return result;
}

bool StreamExecutor::Memcpy(Stream* stream, void* host_dst,
                            const DeviceMemoryBase& device_src,
                            uint64_t size) {
  return implementation_->Memcpy(stream, host_dst, device_src, size);
}

bool StreamExecutor::Memcpy(Stream* stream, DeviceMemoryBase* device_dst,
                            const void* host_src, uint64_t size) {
  return implementation_->Memcpy(stream, device_dst, host_src, size);
}

bool StreamExecutor::MemcpyDeviceToDevice(Stream* stream,
                                          DeviceMemoryBase* device_dst,
                                          const DeviceMemoryBase& device_src,
                                          uint64_t size) {
  return implementation_->MemcpyDeviceToDevice(stream, device_dst, device_src,
                                               size);
}

port::Status StreamExecutor::MemZero(Stream* stream,
                                     DeviceMemoryBase* location,
                                     uint64_t size) {
  return implementation_->MemZero(stream, location, size);
}

port::Status StreamExecutor::Memset32(Stream* stream,
                                      DeviceMemoryBase* location,
                                      uint32 pattern, uint64_t size) {
  CHECK_EQ(0, size % 4)
      << "need 32-bit multiple size to fill with 32-bit pattern";
  return implementation_->Memset32(stream, location, pattern, size);
}

bool StreamExecutor::HostCallback(Stream* stream,
                                  std::function<void()> callback) {
  return implementation_->HostCallback(stream, std::move(callback));
}

bool StreamExecutor::HostCallback(Stream* stream,
                                  std::function<port::Status()> callback) {
  return implementation_->HostCallback(stream, std::move(callback));
}

port::Status StreamExecutor::AllocateEvent(Event* event) {
  return implementation_->AllocateEvent(event);
}

port::Status StreamExecutor::DeallocateEvent(Event* event) {
  return implementation_->DeallocateEvent(event);
}

port::Status StreamExecutor::RecordEvent(Stream* stream, Event* event) {
  return implementation_->RecordEvent(stream, event);
}

port::Status StreamExecutor::WaitForEvent(Stream* stream, Event* event) {
  return implementation_->WaitForEvent(stream, event);
}

Event::Status StreamExecutor::PollForEventStatus(Event* event) {
  return implementation_->PollForEventStatus(event);
}

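// Allocates the platform-specific backing for |stream|, keeping
// live_stream_count_ up to date so that leaked streams can be reported when
// the executor is destroyed; DeallocateStream performs the matching decrement.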
bool StreamExecutor::AllocateStream(Stream* stream) {
  live_stream_count_.fetch_add(1, std::memory_order_relaxed);
  if (!implementation_->AllocateStream(stream)) {
    auto count = live_stream_count_.fetch_sub(1);
    CHECK_GE(count, 0) << "live stream count should not dip below zero";
    LOG(INFO) << "failed to allocate stream; live stream count: " << count;
    return false;
  }

  return true;
}

void StreamExecutor::DeallocateStream(Stream* stream) {
  dnn::DnnSupport* dnn;
  {
    absl::MutexLock lock(&mu_);
    dnn = dnn_.get();
  }
  if (dnn) {
    dnn->NotifyStreamDestroyed(stream);
  }
  implementation_->DeallocateStream(stream);
  CHECK_GE(live_stream_count_.fetch_sub(1), 0)
      << "live stream count should not dip below zero";
}

bool StreamExecutor::CreateStreamDependency(Stream* dependent, Stream* other) {
  return implementation_->CreateStreamDependency(dependent, other);
}

bool StreamExecutor::AllocateTimer(Timer* timer) {
  return implementation_->AllocateTimer(timer);
}

void StreamExecutor::DeallocateTimer(Timer* timer) {
  return implementation_->DeallocateTimer(timer);
}

bool StreamExecutor::StartTimer(Stream* stream, Timer* timer) {
  return implementation_->StartTimer(stream, timer);
}

bool StreamExecutor::StopTimer(Stream* stream, Timer* timer) {
  return implementation_->StopTimer(stream, timer);
}

std::unique_ptr<DeviceDescription> StreamExecutor::CreateDeviceDescription()
    const {
  return implementation_->CreateDeviceDescription().value();
}

bool StreamExecutor::DeviceMemoryUsage(int64_t* free, int64_t* total) const {
  return implementation_->DeviceMemoryUsage(free, total);
}

void StreamExecutor::EnqueueOnBackgroundThread(std::function<void()> task) {
  background_threads_->Schedule(std::move(task));
}

void StreamExecutor::CreateAllocRecord(void* opaque, uint64_t bytes) {
  if (FLAGS_check_device_leaks && opaque != nullptr && bytes != 0) {
    absl::MutexLock lock(&mu_);
    mem_allocs_[opaque] = AllocRecord{bytes, ""};
    mem_alloc_bytes_ += bytes;
  }
}

void StreamExecutor::EraseAllocRecord(void* opaque) {
  if (FLAGS_check_device_leaks && opaque != nullptr) {
    absl::MutexLock lock(&mu_);
    if (mem_allocs_.find(opaque) == mem_allocs_.end()) {
      LOG(ERROR) << "Deallocating unknown pointer: " << opaque;
    } else {
      mem_alloc_bytes_ -= mem_allocs_[opaque].bytes;
      mem_allocs_.erase(opaque);
    }
  }
}

void StreamExecutor::EnableTracing(bool enabled) { tracing_enabled_ = enabled; }

void StreamExecutor::RegisterTraceListener(TraceListener* listener) {
  {
    absl::MutexLock lock(&mu_);
    if (listeners_.find(listener) != listeners_.end()) {
      LOG(INFO) << "Attempt to register already-registered listener, "
                << listener;
    } else {
      listeners_.insert(listener);
    }
  }

  implementation_->RegisterTraceListener(listener);
}

bool StreamExecutor::UnregisterTraceListener(TraceListener* listener) {
  {
    absl::MutexLock lock(&mu_);
    if (listeners_.find(listener) == listeners_.end()) {
      LOG(INFO) << "Attempt to unregister unknown listener, " << listener;
      return false;
    }
    listeners_.erase(listener);
  }

  implementation_->UnregisterTraceListener(listener);
  return true;
}

std::optional<AllocatorStats> StreamExecutor::GetAllocatorStats() {
  return implementation_->GetAllocatorStats();
}

bool StreamExecutor::ClearAllocatorStats() {
  return implementation_->ClearAllocatorStats();
}

template <typename TraceCallT, typename... ArgsT>
void StreamExecutor::SubmitTrace(TraceCallT trace_call, ArgsT&&... args) {
  if (tracing_enabled_) {
    {
      // instance tracers held in a block to limit the lock lifetime.
      absl::ReaderMutexLock lock(&mu_);
      for (TraceListener* listener : listeners_) {
        (listener->*trace_call)(std::forward<ArgsT>(args)...);
      }
    }
  }
}

internal::StreamExecutorInterface* StreamExecutor::implementation() {
  return implementation_->GetUnderlyingExecutor();
}

StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator(
    StreamExecutor* executor)
    : DeviceMemoryAllocator(executor->platform()) {
  stream_executors_ = {executor};
}

StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator(
    const Platform* platform,
    absl::Span<StreamExecutor* const> stream_executors)
    : DeviceMemoryAllocator(platform),
      stream_executors_(stream_executors.begin(), stream_executors.end()) {}

port::StatusOr<OwningDeviceMemory> StreamExecutorMemoryAllocator::Allocate(
    int device_ordinal, uint64_t size, bool retry_on_failure,
    int64_t memory_space) {
  TF_ASSIGN_OR_RETURN(StreamExecutor * executor,
                      GetStreamExecutor(device_ordinal));
  DeviceMemoryBase result = executor->AllocateArray<uint8>(size, memory_space);
  if (size > 0 && result == nullptr) {
    return tensorflow::errors::ResourceExhausted(absl::StrFormat(
        "Failed to allocate request for %s (%uB) on device ordinal %d",
        tensorflow::strings::HumanReadableNumBytes(size), size,
        device_ordinal));
  }
  VLOG(3) << absl::StreamFormat(
      "Allocated %s (%uB) on device ordinal %d: %p",
      tensorflow::strings::HumanReadableNumBytes(size), size, device_ordinal,
      result.opaque());
  return OwningDeviceMemory(result, device_ordinal, this);
}

port::Status StreamExecutorMemoryAllocator::Deallocate(int device_ordinal,
                                                       DeviceMemoryBase mem) {
  if (!mem.is_null()) {
    TF_ASSIGN_OR_RETURN(StreamExecutor * executor,
                        GetStreamExecutor(device_ordinal));
    VLOG(3) << absl::StreamFormat("Freeing %p on device ordinal %d",
                                  mem.opaque(), device_ordinal);
    executor->Deallocate(&mem);
  }
  return ::tensorflow::OkStatus();
}

port::StatusOr<StreamExecutor*>
StreamExecutorMemoryAllocator::GetStreamExecutor(int device_ordinal) const {
  if (device_ordinal < 0) {
    return tensorflow::errors::InvalidArgument(absl::StrFormat(
        "device ordinal value (%d) must be non-negative", device_ordinal));
  }
  for (StreamExecutor* se : stream_executors_) {
    if (se->device_ordinal() == device_ordinal) {
      return se;
    }
  }
  return tensorflow::errors::NotFound(
      absl::StrFormat("Device %s:%d present but not supported",
                      platform()->Name(), device_ordinal));
}

bool StreamExecutorMemoryAllocator::AllowsAsynchronousDeallocation() const {
  return false;
}

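// Returns a stream for the given device ordinal, lazily creating and
// initializing it on first request and caching it so later calls return the
// same instance.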
port::StatusOr<Stream*> StreamExecutorMemoryAllocator::GetStream(
    int device_ordinal) {
  CHECK(!AllowsAsynchronousDeallocation())
      << "The logic below only works for synchronous allocators";
  TF_ASSIGN_OR_RETURN(StreamExecutor * executor,
                      GetStreamExecutor(device_ordinal));
  Stream* out = [&] {
    absl::MutexLock lock(&mutex_);
    if (!streams_.count(device_ordinal)) {
      auto p = streams_.emplace(std::piecewise_construct,
                                std::forward_as_tuple(device_ordinal),
                                std::forward_as_tuple(executor));
      p.first->second.Init();
      return &p.first->second;
    }
    return &streams_.at(device_ordinal);
  }();
  return out;
}

}  // namespace stream_executor