/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/pjrt/local_device_state.h"

#include <functional>
#include <limits>
#include <memory>
#include <utility>
#include <vector>

#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/stream_executor/stream.h"

namespace xla {

LocalDeviceState::LocalDeviceState(se::StreamExecutor* executor,
                                   LocalClient* client,
                                   AllocationModel allocation_model,
                                   int max_inflight_computations,
                                   bool allow_event_reuse,
                                   bool use_callback_stream)
    : allocation_model_(allocation_model),
      event_pool_(allow_event_reuse),
      compute_semaphore_(
          /*capacity=*/max_inflight_computations),
      executor_(executor),
      client_(client),
      prng_seed_generator_(prng_seed_device_()),
      prng_seed_distribution_(std::numeric_limits<int>::min(),
                              std::numeric_limits<int>::max()) {
  compute_stream_ = std::make_unique<se::Stream>(executor);
  host_to_device_stream_ = std::make_unique<se::Stream>(executor);
  compute_stream_->Init();
  host_to_device_stream_->Init();
  if (use_callback_stream) {
    // An engaged map signals ThenExecuteCallback to route host callbacks
    // through per-stream callback streams.
    callback_stream_map_ =
        absl::flat_hash_map<se::Stream*, std::unique_ptr<se::Stream>>();
  }
  // Build fixed-size pools of transfer streams; GetDeviceToHostStream() and
  // GetDeviceToDeviceStream() hand them out round-robin.
  device_to_host_streams_.reserve(kNumDeviceToHostStreams);
  for (int i = 0; i < kNumDeviceToHostStreams; ++i) {
    auto stream = std::make_unique<se::Stream>(executor);
    stream->Init();
    device_to_host_streams_.push_back(std::move(stream));
  }
  device_to_device_streams_.reserve(kNumDeviceToDeviceStreams);
  for (int i = 0; i < kNumDeviceToDeviceStreams; ++i) {
    auto stream = std::make_unique<se::Stream>(executor);
    stream->Init();
    device_to_device_streams_.push_back(std::move(stream));
  }
  execute_thread_ = std::make_unique<WorkerThread>(tensorflow::Env::Default(),
                                                   "py_xla_execute");
  callback_thread_ = std::make_unique<WorkerThread>(tensorflow::Env::Default(),
                                                    "py_xla_callback");
}
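
// Illustrative construction (hypothetical values, not code from this file;
// PjRt clients create one LocalDeviceState per device ordinal along these
// lines):
//
//   se::StreamExecutor* executor =
//       client->backend().stream_executor(device_ordinal).ValueOrDie();
//   auto device_state = std::make_unique<LocalDeviceState>(
//       executor, client, LocalDeviceState::kComputeSynchronized,
//       /*max_inflight_computations=*/32, /*allow_event_reuse=*/true,
//       /*use_callback_stream=*/true);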

LocalDeviceState::~LocalDeviceState() {
  Status status = SynchronizeAllActivity();
  if (!status.ok()) {
    LOG(ERROR) << "Error when closing device: " << status;
  }
}

Status LocalDeviceState::SynchronizeAllActivity() {
  Status status;
  // TODO(phawkins): in theory the call to SynchronizeAllActivity below should
  // suffice. However on the Host platform SynchronizeAllActivity is a dummy
  // implementation that doesn't actually block. To make sure activity has
  // stopped, also block on the compute stream. If SynchronizeAllActivity is
  // fixed, we could remove the BlockHostUntilDone call.
  status.Update(compute_stream_->BlockHostUntilDone());
  if (callback_stream_map_.has_value()) {
    for (auto& callback_stream : callback_stream_map_.value()) {
      status.Update(callback_stream.second->BlockHostUntilDone());
    }
  }
  for (auto& stream : device_to_host_streams_) {
    status.Update(stream->BlockHostUntilDone());
  }
  bool ok = compute_stream_->parent()->SynchronizeAllActivity();
  if (!ok) {
    status.Update(Unknown("SynchronizeAllActivity failed."));
  }
  return status;
}

Status LocalDeviceState::ThenMemcpyDeviceToDevice(
    se::Stream* transfer_stream, se::Stream* dst_stream,
    se::DeviceMemoryBase src_buffer, se::DeviceMemoryBase dst_buffer) {
  // The default implementation simply calls ThenMemcpyD2D, and assumes that
  // the buffer addresses identify the devices. This does not work on all
  // platforms; this method is virtual so it can be overridden.
  transfer_stream->ThenMemcpyD2D(&dst_buffer, src_buffer, dst_buffer.size());
  return OkStatus();
}
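
// A hedged sketch, not code from this file: on a platform where raw buffer
// addresses do not identify the owning device, a subclass might override the
// hook above roughly as follows. `MyPlatformDeviceState` and
// `MyPlatformCopyAcrossDevices` are hypothetical names.
//
//   class MyPlatformDeviceState : public LocalDeviceState {
//    public:
//     using LocalDeviceState::LocalDeviceState;
//     Status ThenMemcpyDeviceToDevice(se::Stream* transfer_stream,
//                                     se::Stream* dst_stream,
//                                     se::DeviceMemoryBase src_buffer,
//                                     se::DeviceMemoryBase dst_buffer) override {
//       // Name the devices explicitly via the streams' executors instead of
//       // relying on buffer addresses.
//       return MyPlatformCopyAcrossDevices(transfer_stream->parent(),
//                                          dst_stream->parent(), src_buffer,
//                                          &dst_buffer);
//     }
//   };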

void LocalDeviceState::ThenExecuteCallback(se::Stream* stream,
                                           std::function<void()> callback) {
  tensorflow::profiler::TraceMe traceme("ThenExecuteCallback");
  if (callback_stream_map_.has_value()) {
    // Prevent concurrent updates to the callback stream map.
    absl::MutexLock lock(&mu_);
    auto callback_stream = callback_stream_map_->find(stream);
    if (callback_stream == callback_stream_map_->end()) {
      // Lazily create a dedicated callback stream for `stream`, so host
      // callbacks do not stall work enqueued on `stream` itself.
      auto new_stream = std::make_unique<se::Stream>(executor_);
      new_stream->Init();
      callback_stream =
          callback_stream_map_->insert({stream, std::move(new_stream)}).first;
    }
    callback_stream->second->ThenWaitFor(stream);
    stream = callback_stream->second.get();
  }
  // Hand the callback off to the worker thread so that a long-running
  // callback does not block the thread that services stream callbacks.
  stream->ThenDoHostCallback([this, callback{std::move(callback)}]() mutable {
    callback_thread_->Schedule(std::move(callback));
  });
}
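
// Example use (illustrative, not code from this file): release a buffer once
// all work previously enqueued on `stream` has completed. The callback runs
// on callback_thread_ after the stream reaches this point.
//
//   device_state->ThenExecuteCallback(
//       stream, [buffer = std::move(device_buffer)]() {
//         // `buffer` is destroyed here, on the callback thread.
//       });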

se::Stream* LocalDeviceState::GetDeviceToHostStream() {
  absl::MutexLock lock(&mu_);
  // Hand out streams from the fixed pool in round-robin order.
  int i = next_device_to_host_stream_;
  next_device_to_host_stream_ =
      (next_device_to_host_stream_ + 1) % device_to_host_streams_.size();
  return device_to_host_streams_.at(i).get();
}

se::Stream* LocalDeviceState::GetDeviceToDeviceStream() {
  absl::MutexLock lock(&mu_);
  int i = next_device_to_device_stream_;
  next_device_to_device_stream_ =
      (next_device_to_device_stream_ + 1) % device_to_device_streams_.size();
  return device_to_device_streams_.at(i).get();
}

std::unique_ptr<se::Stream> LocalDeviceState::BorrowStreamFromPool() {
  absl::MutexLock lock(&mu_);
  if (usage_stream_pool_.empty()) {
    auto stream = std::make_unique<se::Stream>(compute_stream_->parent());
    stream->Init();
    return stream;
  } else {
    std::unique_ptr<se::Stream> stream = std::move(usage_stream_pool_.top());
    usage_stream_pool_.pop();
    auto status = stream->RefreshStatus();  // Can return error::Unimplemented
    // Stream may fail with "ABORTED: Bad connection".
    if (status.code() != tensorflow::error::ABORTED) {
      CHECK(stream->ok()) << status;
    }
    return stream;
  }
}

void LocalDeviceState::ReturnStreamToPool(std::unique_ptr<se::Stream> stream) {
  auto status = stream->RefreshStatus();  // Can return error::Unimplemented
  // Stream may fail with "ABORTED: Bad connection".
  if (status.code() != tensorflow::error::ABORTED) {
    CHECK(stream->ok()) << status;
  }
  absl::MutexLock lock(&mu_);
  usage_stream_pool_.push(std::move(stream));
}
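
// Typical borrow/return pattern (illustrative): take a stream from the pool,
// enqueue transient work on it behind the compute stream, then hand it back
// for reuse.
//
//   std::unique_ptr<se::Stream> usage_stream =
//       device_state->BorrowStreamFromPool();
//   usage_stream->ThenWaitFor(device_state->compute_stream());
//   // ... enqueue work on usage_stream ...
//   device_state->ReturnStreamToPool(std::move(usage_stream));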

int LocalDeviceState::GetNewPrngSeed() {
  absl::MutexLock lock(&mu_);
  int x = 0;
  do {
    // Resample so that the returned seed is never zero.
    x = prng_seed_distribution_(prng_seed_generator_);
  } while (x == 0);
  return x;
}

}  // namespace xla