/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements the StreamExecutor interface by passing through to its
// implementation_ value (in pointer-to-implementation style), which
// implements StreamExecutorInterface.
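//
// Most methods below simply forward to implementation_; on top of that, this
// file layers VLOG call logging, allocation bookkeeping (mem_allocs_), and
// TraceListener notification around the forwarded calls.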

#include "tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h"

#include <atomic>
#include <memory>
#include <utility>

#include "absl/base/const_init.h"
#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/synchronization/notification.h"
#include "tensorflow/compiler/xla/stream_executor/blas.h"
#include "tensorflow/compiler/xla/stream_executor/fft.h"
#include "tensorflow/compiler/xla/stream_executor/lib/env.h"
#include "tensorflow/compiler/xla/stream_executor/lib/error.h"
#include "tensorflow/compiler/xla/stream_executor/lib/stacktrace.h"
#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h"
#include "tensorflow/compiler/xla/stream_executor/lib/threadpool.h"
#include "tensorflow/compiler/xla/stream_executor/platform/port.h"
#include "tensorflow/compiler/xla/stream_executor/rng.h"
#include "tensorflow/compiler/xla/stream_executor/stream.h"
#include "tensorflow/compiler/xla/stream_executor/stream_executor_internal.h"
#include "tensorflow/core/util/env_var.h"

namespace {
bool FLAGS_check_device_leaks = false;
}  // namespace

namespace stream_executor {
namespace {

std::string StackTraceIfVLOG10() {
  if (VLOG_IS_ON(10)) {
    return absl::StrCat(" ", port::CurrentStackTrace(), "\n");
  } else {
    return "";
  }
}

// Make sure the executor is done with its work; we know (because this isn't
// publicly visible) that all enqueued work is quick.
void BlockOnThreadExecutor(port::ThreadPool* executor) {
  absl::Notification n;
  executor->Schedule([&n]() { n.Notify(); });
  n.WaitForNotification();
}

std::atomic_int_fast64_t correlation_id_generator(0);

}  // namespace

template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
          typename... BeginArgsT>
class ScopedTracer {
 public:
  ScopedTracer(StreamExecutor* stream_exec, BeginCallT begin_call,
               CompleteCallT complete_call, const ReturnT* result,
               BeginArgsT... begin_args)
      : stream_exec_(stream_exec),
        complete_call_(complete_call),
        result_(result) {
    if (stream_exec_->tracing_enabled_) {
      correlation_id_ =
          correlation_id_generator.fetch_add(1, std::memory_order_relaxed) - 1;
      Trace(begin_call, begin_args...);
    }
  }

  ~ScopedTracer() {
    if (stream_exec_->tracing_enabled_) {
      Trace(complete_call_, result_);
    }
  }

 private:
  template <typename CallbackT, typename... TraceArgsT>
  void Trace(CallbackT callback, TraceArgsT... args) {
    {
      // Instance tracers held in a block to limit the lock lifetime.
      absl::ReaderMutexLock lock{&stream_exec_->mu_};
      for (TraceListener* listener : stream_exec_->listeners_) {
        (listener->*callback)(correlation_id_,
                              std::forward<TraceArgsT>(args)...);
      }
    }
  }

  StreamExecutor* stream_exec_;
  CompleteCallT complete_call_;
  const ReturnT* result_;
  int64_t correlation_id_;
};

template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
          typename... BeginArgsT>
ScopedTracer<BeginCallT, CompleteCallT, ReturnT, BeginArgsT...>
MakeScopedTracer(StreamExecutor* stream_exec, BeginCallT begin_call,
                 CompleteCallT complete_call, ReturnT* result,
                 BeginArgsT... begin_args) {
  return ScopedTracer<BeginCallT, CompleteCallT, ReturnT, BeginArgsT...>(
      stream_exec, begin_call, complete_call, result,
      std::forward<BeginArgsT>(begin_args)...);
}

#define SCOPED_TRACE(LOC, ...) \
  auto tracer =                \
      MakeScopedTracer(this, &LOC##Begin, &LOC##Complete, ##__VA_ARGS__);
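
// Usage sketch: SCOPED_TRACE(TraceListener::SynchronousMemcpyD2H, &result,
// device_src, size, host_dst) (see SynchronousMemcpyD2H below) builds a
// ScopedTracer that, while tracing is enabled, immediately invokes
// listener->SynchronousMemcpyD2HBegin(correlation_id, device_src, size,
// host_dst) on every registered listener, and invokes
// listener->SynchronousMemcpyD2HComplete(correlation_id, &result) when the
// tracer goes out of scope.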

/* static */ absl::Mutex StreamExecutor::static_mu_{absl::kConstInit};

// Get the per-device memory limit in bytes. Returns 0 if the
// TF_PER_DEVICE_MEMORY_LIMIT_MB environment variable is not set.
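// For example, TF_PER_DEVICE_MEMORY_LIMIT_MB=4096 yields a limit of
// 4096 * 2^20 bytes (4 GiB).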
static int64_t GetMemoryLimitBytes() {
  int64_t value;
  SE_CHECK_OK(tensorflow::ReadInt64FromEnvVar("TF_PER_DEVICE_MEMORY_LIMIT_MB",
                                              0, &value));
  return value * (1ll << 20);
}

StreamExecutor::StreamExecutor(
    const Platform* platform,
    std::unique_ptr<internal::StreamExecutorInterface> implementation,
    int device_ordinal)
    : platform_(platform),
      implementation_(std::move(implementation)),
      device_ordinal_(device_ordinal),
      background_threads_(new port::ThreadPool(
          port::Env::Default(), "stream_executor", kNumBackgroundThreads)),
      live_stream_count_(0),
      tracing_enabled_(false),
      mem_alloc_bytes_(0),
      memory_limit_bytes_(GetMemoryLimitBytes()),
      allocator_(this) {
  std::string name = absl::AsciiStrToLower(platform_->Name());
  if (name == "cuda") {
    platform_kind_ = PlatformKind::kCuda;
  } else if (name == "rocm") {
    platform_kind_ = PlatformKind::kROCm;
  } else if (name == "opencl") {
    platform_kind_ = PlatformKind::kOpenCL;
  } else if (name == "host") {
    platform_kind_ = PlatformKind::kHost;
  } else {
    platform_kind_ = PlatformKind::kInvalid;
  }
}

StreamExecutor::~StreamExecutor() {
  BlockOnThreadExecutor(background_threads_.get());

  if (live_stream_count_.load() != 0) {
    LOG(WARNING) << "Not all streams were deallocated at executor destruction "
                 << "time. This may lead to unexpected/bad behavior - "
                 << "especially if any stream is still active!";
  }

  if (FLAGS_check_device_leaks) {
    for (const auto& it : mem_allocs_) {
      LOG(INFO) << "Memory alloced at executor exit: addr: "
                << absl::StrFormat("%p", it.first)
                << ", bytes: " << it.second.bytes << ", trace: \n"
                << it.second.stack_trace;
    }
  }
}

port::Status StreamExecutor::Init(DeviceOptions device_options) {
  return implementation_->Init(device_ordinal_, std::move(device_options));
}

port::Status StreamExecutor::Init() { return Init(DeviceOptions::Default()); }

port::Status StreamExecutor::GetKernel(const MultiKernelLoaderSpec& spec,
                                       KernelBase* kernel) {
  return implementation_->GetKernel(spec, kernel);
}

void StreamExecutor::UnloadKernel(const KernelBase* kernel) {
  implementation_->UnloadKernel(kernel);
}

port::Status StreamExecutor::LoadModule(const MultiModuleLoaderSpec& spec,
                                        ModuleHandle* module_handle) {
  return implementation_->LoadModule(spec, module_handle);
}

bool StreamExecutor::UnloadModule(ModuleHandle module_handle) {
  return implementation_->UnloadModule(module_handle);
}

port::StatusOr<std::shared_ptr<DeviceMemoryBase>>
StreamExecutor::CreateOrShareConstant(Stream* stream,
                                      const std::vector<uint8_t>& content) {
  return implementation_->CreateOrShareConstant(stream, std::move(content));
}

void StreamExecutor::Deallocate(DeviceMemoryBase* mem) {
  VLOG(1) << "Called StreamExecutor::Deallocate(mem=" << mem->opaque()
          << ") mem->size()=" << mem->size() << StackTraceIfVLOG10();

  if (mem->opaque() != nullptr) {
    EraseAllocRecord(mem->opaque());
  }
  implementation_->Deallocate(mem);
  mem->Reset(nullptr, 0);
}

void StreamExecutor::GetMemAllocs(std::map<void*, AllocRecord>* records_out) {
  absl::ReaderMutexLock lock(&mu_);
  *records_out = mem_allocs_;
}

bool StreamExecutor::CanEnablePeerAccessTo(StreamExecutor* other) {
  return implementation_->CanEnablePeerAccessTo(other->implementation_.get());
}

port::Status StreamExecutor::EnablePeerAccessTo(StreamExecutor* other) {
  return implementation_->EnablePeerAccessTo(other->implementation_.get());
}

const DeviceDescription& StreamExecutor::GetDeviceDescription() const {
  absl::MutexLock lock(&mu_);
  if (device_description_ != nullptr) {
    return *device_description_;
  }

  device_description_ = CreateDeviceDescription();
  return *device_description_;
}

int64_t StreamExecutor::GetDeviceLoad() const {
  return implementation_->GetDeviceLoad();
}

int StreamExecutor::PlatformDeviceCount() const {
  return implementation_->PlatformDeviceCount();
}

bool StreamExecutor::SupportsBlas() const {
  return implementation_->SupportsBlas();
}

bool StreamExecutor::SupportsRng() const {
  return implementation_->SupportsRng();
}

bool StreamExecutor::SupportsDnn() const {
  return implementation_->SupportsDnn();
}

bool StreamExecutor::GetConvolveAlgorithms(
    dnn::ConvolutionKind kind,
    std::vector<dnn::AlgorithmDesc>* out_algorithms) {
  dnn::DnnSupport* dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  switch (kind) {
    default:
      return false;
    case dnn::ConvolutionKind::FORWARD:
    case dnn::ConvolutionKind::FORWARD_BIAS_ACTIVATION:
      return dnn_support->GetConvolveAlgorithms(
          GetDeviceDescription().cuda_compute_capability(), out_algorithms);
    case dnn::ConvolutionKind::BACKWARD_DATA:
      return dnn_support->GetConvolveBackwardDataAlgorithms(
          GetDeviceDescription().cuda_compute_capability(), out_algorithms);
    case dnn::ConvolutionKind::BACKWARD_FILTER:
      return dnn_support->GetConvolveBackwardFilterAlgorithms(
          GetDeviceDescription().cuda_compute_capability(), out_algorithms);
  }
}

port::Status StreamExecutor::GetConvolveRunners(
    bool use_cudnn_frontend, dnn::ConvolutionKind kind,
    dnn::DataType input_type, dnn::DataType output_type, Stream* stream,
    const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
    const dnn::FilterDescriptor& filter_descriptor,
    DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor,
    DeviceMemoryBase output_data,
    const dnn::ConvolutionDescriptor& convolution_descriptor, bool use_fallback,
    ScratchAllocator* scratch_allocator,
    std::vector<std::unique_ptr<const dnn::ConvRunner>>* out_exec_plans) {
  dnn::DnnSupport* dnn_support = AsDnn();
  if (!dnn_support) {
    return port::UnimplementedError("DNN library is not found.");
  }
  return dnn_support->GetConvolveRunners(
      use_cudnn_frontend, kind, input_type, output_type, stream,
      input_descriptor, input_data, filter_descriptor, filter_data,
      output_descriptor, output_data, convolution_descriptor, use_fallback,
      scratch_allocator, out_exec_plans);
}

port::Status StreamExecutor::GetFusedConvolveRunners(
    bool use_cudnn_frontend, dnn::ConvolutionKind kind,
    dnn::DataType input_type, dnn::DataType bias_type,
    dnn::DataType output_type, double conv_input_scale, double side_input_scale,
    double leakyrelu_alpha, Stream* stream,
    const dnn::BatchDescriptor& input_descriptor,
    const dnn::FilterDescriptor& filter_descriptor,
    const dnn::BatchDescriptor& bias_descriptor,
    const dnn::BatchDescriptor& output_descriptor,
    const dnn::ConvolutionDescriptor& convolution_descriptor, bool use_fallback,
    dnn::ActivationMode activation_mode,
    std::vector<std::unique_ptr<const dnn::FusedConvRunner>>* out_exec_plans) {
  dnn::DnnSupport* dnn_support = AsDnn();
  if (!dnn_support) {
    return port::UnimplementedError("DNN library is not found.");
  }
  return dnn_support->GetFusedConvolveRunners(
      use_cudnn_frontend, kind, input_type, bias_type, output_type,
      conv_input_scale, side_input_scale, leakyrelu_alpha, stream,
      input_descriptor, filter_descriptor, bias_descriptor, output_descriptor,
      convolution_descriptor, use_fallback, activation_mode, out_exec_plans);
}

bool StreamExecutor::GetMIOpenConvolveAlgorithms(
    dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
    const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
    const dnn::FilterDescriptor& filter_descriptor,
    DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor,
    DeviceMemoryBase output_data,
    const dnn::ConvolutionDescriptor& convolution_descriptor,
    ScratchAllocator* scratch_allocator,
    std::vector<dnn::ProfileResult>* out_algorithms) {
  dnn::DnnSupport* dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  return dnn_support->GetMIOpenConvolveAlgorithms(
      kind, element_type, stream, input_descriptor, input_data,
      filter_descriptor, filter_data, output_descriptor, output_data,
      convolution_descriptor, scratch_allocator, out_algorithms);
}

bool StreamExecutor::GetRnnAlgorithms(
    std::vector<dnn::AlgorithmDesc>* out_algorithms) {
  dnn::DnnSupport* dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  return dnn_support->GetRnnAlgorithms(out_algorithms);
}

bool StreamExecutor::GetBlasGemmAlgorithms(
    Stream* stream, std::vector<blas::AlgorithmType>* out_algorithms) {
  blas::BlasSupport* blas_support = AsBlas();
  if (!blas_support) {
    return false;
  }
  return blas_support->GetBlasGemmAlgorithms(stream, out_algorithms);
}

port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
StreamExecutor::createRnnDescriptor(
    int num_layers, int hidden_size, int input_size, int cell_size,
    int batch_size, dnn::RnnInputMode input_mode,
    dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
    dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config,
    float dropout, uint64_t seed, ScratchAllocator* state_allocator,
    bool use_padded_io) {
  dnn::DnnSupport* dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Fail to find the dnn implementation.");
  }
  return dnn_support->createRnnDescriptor(
      num_layers, hidden_size, input_size, cell_size, batch_size, input_mode,
      direction_mode, rnn_mode, data_type, algorithm_config, dropout, seed,
      state_allocator, use_padded_io);
}

port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
StreamExecutor::createRnnSequenceTensorDescriptor(int max_seq_length,
                                                  int batch_size, int data_size,
                                                  dnn::DataType data_type) {
  dnn::DnnSupport* dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Fail to find the dnn implementation.");
  }
  return dnn_support->createRnnSequenceTensorDescriptor(
      max_seq_length, batch_size, data_size, data_type);
}

port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
StreamExecutor::createRnnSequenceTensorDescriptor(
    int max_seq_length, int batch_size, int data_size,
    const absl::Span<const int>& seq_lengths, bool time_major,
    dnn::DataType data_type) {
  dnn::DnnSupport* dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Fail to find the dnn implementation.");
  }
  return dnn_support->createRnnSequenceTensorDescriptor(
      max_seq_length, batch_size, data_size, seq_lengths, time_major,
      data_type);
}

port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
StreamExecutor::createRnnStateTensorDescriptor(int num_layer, int batch_size,
                                               int data_size,
                                               dnn::DataType data_type) {
  dnn::DnnSupport* dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Fail to find the dnn implementation.");
  }
  return dnn_support->createRnnStateTensorDescriptor(num_layer, batch_size,
                                                     data_size, data_type);
}

dnn::DnnSupport* StreamExecutor::AsDnn() {
  absl::MutexLock lock(&mu_);
  if (dnn_ != nullptr) {
    return dnn_.get();
  }

  dnn_.reset(implementation_->CreateDnn());
  return dnn_.get();
}

blas::BlasSupport* StreamExecutor::AsBlas() {
  absl::MutexLock lock(&mu_);
  if (blas_ != nullptr) {
    return blas_.get();
  }

  blas_.reset(implementation_->CreateBlas());
  return blas_.get();
}

fft::FftSupport* StreamExecutor::AsFft() {
  absl::MutexLock lock(&mu_);
  if (fft_ != nullptr) {
    return fft_.get();
  }

  fft_.reset(implementation_->CreateFft());
  return fft_.get();
}

rng::RngSupport* StreamExecutor::AsRng() {
  absl::MutexLock lock(&mu_);
  if (rng_ != nullptr) {
    return rng_.get();
  }

  rng_.reset(implementation_->CreateRng());
  return rng_.get();
}

port::Status StreamExecutor::Launch(Stream* stream,
                                    const ThreadDim& thread_dims,
                                    const BlockDim& block_dims,
                                    const KernelBase& kernel,
                                    const KernelArgsArrayBase& args) {
  SubmitTrace(&TraceListener::LaunchSubmit, stream, thread_dims, block_dims,
              kernel, args);

  return implementation_->Launch(stream, thread_dims, block_dims, kernel, args);
}

port::Status StreamExecutor::BlockHostUntilDone(Stream* stream) {
  port::Status result;
  SCOPED_TRACE(TraceListener::BlockHostUntilDone, &result, stream);

  result = implementation_->BlockHostUntilDone(stream);
  return result;
}

port::Status StreamExecutor::GetStatus(Stream* stream) {
  return implementation_->GetStatus(stream);
}

DeviceMemoryBase StreamExecutor::Allocate(uint64_t size, int64_t memory_space) {
  if (memory_limit_bytes_ > 0 &&
      static_cast<int64_t>(mem_alloc_bytes_ + size) > memory_limit_bytes_) {
    LOG(WARNING) << "Not enough memory to allocate " << size << " on device "
                 << device_ordinal_
                 << " within provided limit. [used=" << mem_alloc_bytes_
                 << ", limit=" << memory_limit_bytes_ << "]";
    return DeviceMemoryBase();
  }
  DeviceMemoryBase buf = implementation_->Allocate(size, memory_space);
  VLOG(1) << "Called StreamExecutor::Allocate(size=" << size
          << ", memory_space=" << memory_space << ") returns " << buf.opaque()
          << StackTraceIfVLOG10();
  CreateAllocRecord(buf.opaque(), size);

  return buf;
}

port::StatusOr<DeviceMemoryBase> StreamExecutor::GetUntypedSymbol(
    const std::string& symbol_name, ModuleHandle module_handle) {
  // If we fail to get the symbol, opaque/bytes are left unchanged. Initialize
  // them to nullptr/0 for consistency with DeviceMemory semantics.
  void* opaque = nullptr;
  size_t bytes = 0;
  if (GetSymbol(symbol_name, module_handle, &opaque, &bytes)) {
    return DeviceMemoryBase(opaque, bytes);
  }

  return port::Status(
      port::error::NOT_FOUND,
      absl::StrCat("Check if module containing symbol ", symbol_name,
                   " is loaded (module_handle = ",
                   reinterpret_cast<uintptr_t>(module_handle.id()), ")"));
}

bool StreamExecutor::GetSymbol(const std::string& symbol_name,
                               ModuleHandle module_handle, void** mem,
                               size_t* bytes) {
  return implementation_->GetSymbol(symbol_name, module_handle, mem, bytes);
}

void* StreamExecutor::UnifiedMemoryAllocate(uint64_t bytes) {
  void* buffer = implementation_->UnifiedMemoryAllocate(bytes);
  VLOG(1) << "Called StreamExecutor::UnifiedMemoryAllocate(size=" << bytes
          << ") returns " << buffer << StackTraceIfVLOG10();
  return buffer;
}

void StreamExecutor::UnifiedMemoryDeallocate(void* location) {
  VLOG(1) << "Called StreamExecutor::UnifiedMemoryDeallocate(location="
          << location << ")" << StackTraceIfVLOG10();

  return implementation_->UnifiedMemoryDeallocate(location);
}

void* StreamExecutor::HostMemoryAllocate(uint64_t size) {
  void* buffer = implementation_->HostMemoryAllocate(size);
  VLOG(1) << "Called StreamExecutor::HostMemoryAllocate(size=" << size
          << ") returns " << buffer << StackTraceIfVLOG10();
  return buffer;
}

void StreamExecutor::HostMemoryDeallocate(void* location) {
  VLOG(1) << "Called StreamExecutor::HostMemoryDeallocate(location=" << location
          << ")" << StackTraceIfVLOG10();

  return implementation_->HostMemoryDeallocate(location);
}

bool StreamExecutor::HostMemoryRegister(void* location, uint64_t size) {
  VLOG(1) << "Called StreamExecutor::HostMemoryRegister(location=" << location
          << ", size=" << size << ")" << StackTraceIfVLOG10();
  if (location == nullptr || size == 0) {
    LOG(WARNING) << "attempting to register null or zero-sized memory: "
                 << location << "; size " << size;
  }
  return implementation_->HostMemoryRegister(location, size);
}

bool StreamExecutor::HostMemoryUnregister(void* location) {
  VLOG(1) << "Called StreamExecutor::HostMemoryUnregister(location=" << location
          << ")" << StackTraceIfVLOG10();
  return implementation_->HostMemoryUnregister(location);
}

bool StreamExecutor::SynchronizeAllActivity() {
  VLOG(1) << "Called StreamExecutor::SynchronizeAllActivity()"
          << StackTraceIfVLOG10();
  bool ok = implementation_->SynchronizeAllActivity();

  // This should all be quick and infallible work, so we can perform the
  // synchronization even in the case of failure.
  BlockOnThreadExecutor(background_threads_.get());

  return ok;
}

port::Status StreamExecutor::SynchronousMemZero(DeviceMemoryBase* location,
                                                uint64_t size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemZero(location=" << location
          << ", size=" << size << ")" << StackTraceIfVLOG10();

  return implementation_->SynchronousMemZero(location, size);
}

port::Status StreamExecutor::SynchronousMemSet(DeviceMemoryBase* location,
                                               int value, uint64_t size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemSet(location=" << location
          << ", value=" << value << ", size=" << size << ")"
          << StackTraceIfVLOG10();

  return implementation_->SynchronousMemSet(location, value, size);
}

bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase* device_dst,
                                       const void* host_src, uint64_t size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(device_dst="
          << device_dst->opaque() << ", host_src=" << host_src
          << ", size=" << size << ") H2D" << StackTraceIfVLOG10();

  // Tracing overloaded methods is very difficult due to issues with type
  // inference on template args. Since use of these overloaded methods is
  // discouraged anyway, this isn't a huge deal.
  port::Status status =
      implementation_->SynchronousMemcpy(device_dst, host_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

bool StreamExecutor::SynchronousMemcpy(void* host_dst,
                                       const DeviceMemoryBase& device_src,
                                       uint64_t size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(host_dst=" << host_dst
          << ", device_src=" << device_src.opaque() << ", size=" << size
          << ") D2H" << StackTraceIfVLOG10();

  port::Status status =
      implementation_->SynchronousMemcpy(host_dst, device_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase* device_dst,
                                       const DeviceMemoryBase& device_src,
                                       uint64_t size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(device_dst="
          << device_dst->opaque() << ", device_src=" << device_src.opaque()
          << ", size=" << size << ") D2D" << StackTraceIfVLOG10();

  port::Status status = implementation_->SynchronousMemcpyDeviceToDevice(
      device_dst, device_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

port::Status StreamExecutor::SynchronousMemcpyD2H(
    const DeviceMemoryBase& device_src, int64_t size, void* host_dst) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpyD2H(device_src="
          << device_src.opaque() << ", size=" << size
          << ", host_dst=" << host_dst << ")" << StackTraceIfVLOG10();

  port::Status result;
  SCOPED_TRACE(TraceListener::SynchronousMemcpyD2H, &result, device_src, size,
               host_dst);

  result = implementation_->SynchronousMemcpy(host_dst, device_src, size);
  if (!result.ok()) {
    result = port::Status(
        port::error::INTERNAL,
        absl::StrFormat("failed to synchronously memcpy device-to-host: device "
                        "%p to host %p size %d: %s",
                        device_src.opaque(), host_dst, size,
                        result.ToString()));
  }

  return result;
}

port::Status StreamExecutor::SynchronousMemcpyH2D(
    const void* host_src, int64_t size, DeviceMemoryBase* device_dst) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpyH2D(host_src=" << host_src
          << ", size=" << size << ", device_dst=" << device_dst->opaque() << ")"
          << StackTraceIfVLOG10();

  port::Status result;
  SCOPED_TRACE(TraceListener::SynchronousMemcpyH2D, &result, host_src, size,
               device_dst);

  result = implementation_->SynchronousMemcpy(device_dst, host_src, size);
  if (!result.ok()) {
    result = port::Status(
        port::error::INTERNAL,
        absl::StrFormat("failed to synchronously memcpy host-to-device: host "
                        "%p to device %p size %d: %s",
                        host_src, device_dst->opaque(), size,
                        result.ToString()));
  }

  return result;
}

bool StreamExecutor::Memcpy(Stream* stream, void* host_dst,
                            const DeviceMemoryBase& device_src, uint64_t size) {
  return implementation_->Memcpy(stream, host_dst, device_src, size);
}

bool StreamExecutor::Memcpy(Stream* stream, DeviceMemoryBase* device_dst,
                            const void* host_src, uint64_t size) {
  return implementation_->Memcpy(stream, device_dst, host_src, size);
}

bool StreamExecutor::MemcpyDeviceToDevice(Stream* stream,
                                          DeviceMemoryBase* device_dst,
                                          const DeviceMemoryBase& device_src,
                                          uint64_t size) {
  return implementation_->MemcpyDeviceToDevice(stream, device_dst, device_src,
                                               size);
}

port::Status StreamExecutor::MemZero(Stream* stream, DeviceMemoryBase* location,
                                     uint64_t size) {
  return implementation_->MemZero(stream, location, size);
}

port::Status StreamExecutor::Memset32(Stream* stream,
                                      DeviceMemoryBase* location,
                                      uint32 pattern, uint64_t size) {
  CHECK_EQ(0, size % 4)
      << "need 32-bit multiple size to fill with 32-bit pattern";
  return implementation_->Memset32(stream, location, pattern, size);
}

bool StreamExecutor::HostCallback(Stream* stream,
                                  std::function<void()> callback) {
  return implementation_->HostCallback(stream, std::move(callback));
}

bool StreamExecutor::HostCallback(Stream* stream,
                                  std::function<port::Status()> callback) {
  return implementation_->HostCallback(stream, std::move(callback));
}

port::Status StreamExecutor::AllocateEvent(Event* event) {
  return implementation_->AllocateEvent(event);
}

port::Status StreamExecutor::DeallocateEvent(Event* event) {
  return implementation_->DeallocateEvent(event);
}

port::Status StreamExecutor::RecordEvent(Stream* stream, Event* event) {
  return implementation_->RecordEvent(stream, event);
}

port::Status StreamExecutor::WaitForEvent(Stream* stream, Event* event) {
  return implementation_->WaitForEvent(stream, event);
}

Event::Status StreamExecutor::PollForEventStatus(Event* event) {
  return implementation_->PollForEventStatus(event);
}

bool StreamExecutor::AllocateStream(Stream* stream) {
  live_stream_count_.fetch_add(1, std::memory_order_relaxed);
  if (!implementation_->AllocateStream(stream)) {
    auto count = live_stream_count_.fetch_sub(1);
    CHECK_GE(count, 0) << "live stream count should not dip below zero";
    LOG(INFO) << "failed to allocate stream; live stream count: " << count;
    return false;
  }

  return true;
}

void StreamExecutor::DeallocateStream(Stream* stream) {
  dnn::DnnSupport* dnn;
  {
    absl::MutexLock lock(&mu_);
    dnn = dnn_.get();
  }
  if (dnn) {
    dnn->NotifyStreamDestroyed(stream);
  }
  implementation_->DeallocateStream(stream);
  CHECK_GE(live_stream_count_.fetch_sub(1), 0)
      << "live stream count should not dip below zero";
}

bool StreamExecutor::CreateStreamDependency(Stream* dependent, Stream* other) {
  return implementation_->CreateStreamDependency(dependent, other);
}

bool StreamExecutor::AllocateTimer(Timer* timer) {
  return implementation_->AllocateTimer(timer);
}

void StreamExecutor::DeallocateTimer(Timer* timer) {
  return implementation_->DeallocateTimer(timer);
}

bool StreamExecutor::StartTimer(Stream* stream, Timer* timer) {
  return implementation_->StartTimer(stream, timer);
}

bool StreamExecutor::StopTimer(Stream* stream, Timer* timer) {
  return implementation_->StopTimer(stream, timer);
}

std::unique_ptr<DeviceDescription> StreamExecutor::CreateDeviceDescription()
    const {
  return implementation_->CreateDeviceDescription().value();
}

bool StreamExecutor::DeviceMemoryUsage(int64_t* free, int64_t* total) const {
  return implementation_->DeviceMemoryUsage(free, total);
}

void StreamExecutor::EnqueueOnBackgroundThread(std::function<void()> task) {
  background_threads_->Schedule(std::move(task));
}

void StreamExecutor::CreateAllocRecord(void* opaque, uint64_t bytes) {
  if (FLAGS_check_device_leaks && opaque != nullptr && bytes != 0) {
    absl::MutexLock lock(&mu_);
    mem_allocs_[opaque] = AllocRecord{bytes, ""};
    mem_alloc_bytes_ += bytes;
  }
}

void StreamExecutor::EraseAllocRecord(void* opaque) {
  if (FLAGS_check_device_leaks && opaque != nullptr) {
    absl::MutexLock lock(&mu_);
    if (mem_allocs_.find(opaque) == mem_allocs_.end()) {
      LOG(ERROR) << "Deallocating unknown pointer: " << opaque;
    } else {
      mem_alloc_bytes_ -= mem_allocs_[opaque].bytes;
      mem_allocs_.erase(opaque);
    }
  }
}

void StreamExecutor::EnableTracing(bool enabled) { tracing_enabled_ = enabled; }

void StreamExecutor::RegisterTraceListener(TraceListener* listener) {
  {
    absl::MutexLock lock(&mu_);
    if (listeners_.find(listener) != listeners_.end()) {
      LOG(INFO) << "Attempt to register already-registered listener, "
                << listener;
    } else {
      listeners_.insert(listener);
    }
  }

  implementation_->RegisterTraceListener(listener);
}

bool StreamExecutor::UnregisterTraceListener(TraceListener* listener) {
  {
    absl::MutexLock lock(&mu_);
    if (listeners_.find(listener) == listeners_.end()) {
      LOG(INFO) << "Attempt to unregister unknown listener, " << listener;
      return false;
    }
    listeners_.erase(listener);
  }

  implementation_->UnregisterTraceListener(listener);
  return true;
}

std::optional<AllocatorStats> StreamExecutor::GetAllocatorStats() {
  return implementation_->GetAllocatorStats();
}

bool StreamExecutor::ClearAllocatorStats() {
  return implementation_->ClearAllocatorStats();
}

template <typename TraceCallT, typename... ArgsT>
void StreamExecutor::SubmitTrace(TraceCallT trace_call, ArgsT&&... args) {
  if (tracing_enabled_) {
    {
      // Instance tracers held in a block to limit the lock lifetime.
      absl::ReaderMutexLock lock(&mu_);
      for (TraceListener* listener : listeners_) {
        (listener->*trace_call)(std::forward<ArgsT>(args)...);
      }
    }
  }
}

internal::StreamExecutorInterface* StreamExecutor::implementation() {
  return implementation_->GetUnderlyingExecutor();
}

StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator(
    StreamExecutor* executor)
    : DeviceMemoryAllocator(executor->platform()) {
  stream_executors_ = {executor};
}

StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator(
    const Platform* platform,
    absl::Span<StreamExecutor* const> stream_executors)
    : DeviceMemoryAllocator(platform),
      stream_executors_(stream_executors.begin(), stream_executors.end()) {}

port::StatusOr<OwningDeviceMemory> StreamExecutorMemoryAllocator::Allocate(
    int device_ordinal, uint64_t size, bool retry_on_failure,
    int64_t memory_space) {
  TF_ASSIGN_OR_RETURN(StreamExecutor * executor,
                      GetStreamExecutor(device_ordinal));
  DeviceMemoryBase result = executor->AllocateArray<uint8>(size, memory_space);
  if (size > 0 && result == nullptr) {
    return tensorflow::errors::ResourceExhausted(absl::StrFormat(
        "Failed to allocate request for %s (%uB) on device ordinal %d",
        tensorflow::strings::HumanReadableNumBytes(size), size,
        device_ordinal));
  }
  VLOG(3) << absl::StreamFormat(
      "Allocated %s (%uB) on device ordinal %d: %p",
      tensorflow::strings::HumanReadableNumBytes(size), size, device_ordinal,
      result.opaque());
  return OwningDeviceMemory(result, device_ordinal, this);
}

port::Status StreamExecutorMemoryAllocator::Deallocate(int device_ordinal,
                                                       DeviceMemoryBase mem) {
  if (!mem.is_null()) {
    TF_ASSIGN_OR_RETURN(StreamExecutor * executor,
                        GetStreamExecutor(device_ordinal));
    VLOG(3) << absl::StreamFormat("Freeing %p on device ordinal %d",
                                  mem.opaque(), device_ordinal);
    executor->Deallocate(&mem);
  }
  return ::tensorflow::OkStatus();
}

port::StatusOr<StreamExecutor*>
StreamExecutorMemoryAllocator::GetStreamExecutor(int device_ordinal) const {
  if (device_ordinal < 0) {
    return tensorflow::errors::InvalidArgument(absl::StrFormat(
        "device ordinal value (%d) must be non-negative", device_ordinal));
  }
  for (StreamExecutor* se : stream_executors_) {
    if (se->device_ordinal() == device_ordinal) {
      return se;
    }
  }
  return tensorflow::errors::NotFound(
      absl::StrFormat("Device %s:%d present but not supported",
                      platform()->Name(), device_ordinal));
}

bool StreamExecutorMemoryAllocator::AllowsAsynchronousDeallocation() const {
  return false;
}

port::StatusOr<Stream*> StreamExecutorMemoryAllocator::GetStream(
    int device_ordinal) {
  CHECK(!AllowsAsynchronousDeallocation())
      << "The logic below only works for synchronous allocators";
  TF_ASSIGN_OR_RETURN(StreamExecutor * executor,
                      GetStreamExecutor(device_ordinal));
  Stream* out = [&] {
    absl::MutexLock lock(&mutex_);
    if (!streams_.count(device_ordinal)) {
      auto p = streams_.emplace(std::piecewise_construct,
                                std::forward_as_tuple(device_ordinal),
                                std::forward_as_tuple(executor));
      p.first->second.Init();
      return &p.first->second;
    }
    return &streams_.at(device_ordinal);
  }();
  return out;
}

}  // namespace stream_executor