xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/pjrt/local_device_state.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/pjrt/local_device_state.h"
17 
18 #include <memory>
19 #include <vector>
20 
21 #include "absl/synchronization/mutex.h"
22 #include "tensorflow/compiler/xla/util.h"
23 #include "tensorflow/core/profiler/lib/traceme.h"
24 #include "tensorflow/core/protobuf/error_codes.pb.h"
25 #include "tensorflow/stream_executor/stream.h"
26 
27 namespace xla {
28 
LocalDeviceState(se::StreamExecutor * executor,LocalClient * client,AllocationModel allocation_model,int max_inflight_computations,bool allow_event_reuse,bool use_callback_stream)29 LocalDeviceState::LocalDeviceState(se::StreamExecutor* executor,
30                                    LocalClient* client,
31                                    AllocationModel allocation_model,
32                                    int max_inflight_computations,
33                                    bool allow_event_reuse,
34                                    bool use_callback_stream)
35     : allocation_model_(allocation_model),
36       event_pool_(allow_event_reuse),
37       compute_semaphore_(
38           /*capacity=*/max_inflight_computations),
39       executor_(executor),
40       client_(client),
41       prng_seed_generator_(prng_seed_device_()),
42       prng_seed_distribution_(std::numeric_limits<int>::min(),
43                               std::numeric_limits<int>::max()) {
44   compute_stream_ = std::make_unique<se::Stream>(executor);
45   host_to_device_stream_ = std::make_unique<se::Stream>(executor);
46   compute_stream_->Init();
47   host_to_device_stream_->Init();
48   if (use_callback_stream) {
49     callback_stream_map_ =
50         absl::flat_hash_map<se::Stream*, std::unique_ptr<se::Stream>>();
51   }
52   device_to_host_streams_.reserve(kNumDeviceToHostStreams);
53   for (int i = 0; i < kNumDeviceToHostStreams; ++i) {
54     auto stream = std::make_unique<se::Stream>(executor);
55     stream->Init();
56     device_to_host_streams_.push_back(std::move(stream));
57   }
58   device_to_device_streams_.reserve(kNumDeviceToDeviceStreams);
59   for (int i = 0; i < kNumDeviceToDeviceStreams; ++i) {
60     auto stream = std::make_unique<se::Stream>(executor);
61     stream->Init();
62     device_to_device_streams_.push_back(std::move(stream));
63   }
64   execute_thread_ = std::make_unique<WorkerThread>(tensorflow::Env::Default(),
65                                                    "py_xla_execute");
66   callback_thread_ = std::make_unique<WorkerThread>(tensorflow::Env::Default(),
67                                                     "py_xla_callback");
68 }
69 
~LocalDeviceState()70 LocalDeviceState::~LocalDeviceState() {
71   Status status = SynchronizeAllActivity();
72   if (!status.ok()) {
73     LOG(ERROR) << "Error when closing device: " << status;
74   }
75 }
76 
SynchronizeAllActivity()77 Status LocalDeviceState::SynchronizeAllActivity() {
78   Status status;
79   // TODO(phawkins): in theory the call to SynchronizeAllActivity below should
80   // suffice. However on the Host platform SynchronizeAllActivity is a dummy
81   // implementation that doesn't actually block. To make sure activity has
82   // stopped, also block on the compute stream. If SynchronizeAllActivity is
83   // fixed, we could remove the BlockHostUntilDone call.
84   status.Update(compute_stream_->BlockHostUntilDone());
85   if (callback_stream_map_.has_value()) {
86     for (auto& callback_stream : callback_stream_map_.value()) {
87       status.Update(callback_stream.second->BlockHostUntilDone());
88     }
89   }
90   for (auto& stream : device_to_host_streams_) {
91     status.Update(stream->BlockHostUntilDone());
92   }
93   bool ok = compute_stream_->parent()->SynchronizeAllActivity();
94   if (!ok) {
95     status.Update(Unknown("SynchronizeAllActivity failed."));
96   }
97   return status;
98 }
99 
ThenMemcpyDeviceToDevice(se::Stream * transfer_stream,se::Stream * dst_stream,se::DeviceMemoryBase src_buffer,se::DeviceMemoryBase dst_buffer)100 Status LocalDeviceState::ThenMemcpyDeviceToDevice(
101     se::Stream* transfer_stream, se::Stream* dst_stream,
102     se::DeviceMemoryBase src_buffer, se::DeviceMemoryBase dst_buffer) {
103   // The default implementation simply calls ThenMemcpyD2D, and assumes that
104   // the buffer addresses identify the devices. This does not work
105   // on all platforms; this method is virtual so it can be overridden.
106   transfer_stream->ThenMemcpyD2D(&dst_buffer, src_buffer, dst_buffer.size());
107   return OkStatus();
108 }
109 
ThenExecuteCallback(se::Stream * stream,std::function<void ()> callback)110 void LocalDeviceState::ThenExecuteCallback(se::Stream* stream,
111                                            std::function<void()> callback) {
112   tensorflow::profiler::TraceMe traceme("ThenExecuteCallback");
113   if (callback_stream_map_.has_value()) {
114     // Prevent concurrent updates to the callback stream map.
115     absl::MutexLock lock(&mu_);
116     auto callback_stream = callback_stream_map_->find(stream);
117     if (callback_stream == callback_stream_map_->end()) {
118       auto new_stream = std::make_unique<se::Stream>(executor_);
119       new_stream->Init();
120       callback_stream =
121           callback_stream_map_->insert({stream, std::move(new_stream)}).first;
122     }
123     callback_stream->second->ThenWaitFor(stream);
124     stream = callback_stream->second.get();
125   }
126   stream->ThenDoHostCallback([this, callback{std::move(callback)}]() mutable {
127     callback_thread_->Schedule(std::move(callback));
128   });
129 }
130 
GetDeviceToHostStream()131 se::Stream* LocalDeviceState::GetDeviceToHostStream() {
132   absl::MutexLock lock(&mu_);
133   int i = next_device_to_host_stream_;
134   next_device_to_host_stream_ =
135       (next_device_to_host_stream_ + 1) % device_to_host_streams_.size();
136   return device_to_host_streams_.at(i).get();
137 }
138 
GetDeviceToDeviceStream()139 se::Stream* LocalDeviceState::GetDeviceToDeviceStream() {
140   absl::MutexLock lock(&mu_);
141   int i = next_device_to_device_stream_;
142   next_device_to_device_stream_ =
143       (next_device_to_device_stream_ + 1) % device_to_device_streams_.size();
144   return device_to_device_streams_.at(i).get();
145 }
146 
BorrowStreamFromPool()147 std::unique_ptr<se::Stream> LocalDeviceState::BorrowStreamFromPool() {
148   absl::MutexLock lock(&mu_);
149   if (usage_stream_pool_.empty()) {
150     auto stream = std::make_unique<se::Stream>(compute_stream_->parent());
151     stream->Init();
152     return stream;
153   } else {
154     std::unique_ptr<se::Stream> stream = std::move(usage_stream_pool_.top());
155     usage_stream_pool_.pop();
156     auto status = stream->RefreshStatus();  // Can return error::Unimplemented
157     // Stream may fail with "ABORTED: Bad connection".
158     if (status.code() != tensorflow::error::ABORTED) {
159       CHECK(stream->ok()) << status;
160     }
161     return stream;
162   }
163 }
164 
ReturnStreamToPool(std::unique_ptr<se::Stream> stream)165 void LocalDeviceState::ReturnStreamToPool(std::unique_ptr<se::Stream> stream) {
166   auto status = stream->RefreshStatus();  // Can return error::Unimplemented
167   // Stream may fail with "ABORTED: Bad connection".
168   if (status.code() != tensorflow::error::ABORTED) {
169     CHECK(stream->ok()) << status;
170   }
171   absl::MutexLock lock(&mu_);
172   usage_stream_pool_.push(std::move(stream));
173 }
174 
GetNewPrngSeed()175 int LocalDeviceState::GetNewPrngSeed() {
176   absl::MutexLock lock(&mu_);
177   int x = 0;
178   do {
179     x = prng_seed_distribution_(prng_seed_generator_);
180   } while (x == 0);
181   return x;
182 }
183 
184 }  // namespace xla
185