/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"

#include <algorithm>
#include <complex>
#include <cstdarg>
#include <cstddef>
#include <cstdio>
#include <cstring>
#include <functional>
#include <limits>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/base/casts.h"
#include "absl/base/dynamic_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/refcounting_hash_map.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/stream_executor.h"

namespace se = ::stream_executor;

namespace {
template <class T>
struct is_complex : std::false_type {};
template <class T>
struct is_complex<std::complex<T>> : std::true_type {};
}  // namespace

namespace xla {
namespace cpu {
namespace runtime {

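// Returns the XfeedManager for the given device ordinal, creating it lazily
// on first use. Managers live in a process-wide map guarded by a mutex and
// are never destroyed, so the returned pointer remains valid for the lifetime
// of the process.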
XfeedManager* GetXfeedManager(int device_ordinal) {
  static auto* managers = new absl::flat_hash_map<int, XfeedManager*>();
  static absl::Mutex* mutex = new absl::Mutex();

  absl::MutexLock lock(mutex);
  auto it = managers->find(device_ordinal);
  if (it == managers->end()) {
    it = managers->emplace(device_ordinal, new XfeedManager()).first;
  }
  return it->second;
}

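// Symbol names of the XLA CPU runtime entry points. Compiled XLA programs
// emit calls to these symbols, so each string must match the name of the
// corresponding extern "C" function exported by the runtime.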
extern const char* const kEigenMatMulF16SymbolName =
    "__xla_cpu_runtime_EigenMatMulF16";
extern const char* const kEigenMatMulF32SymbolName =
    "__xla_cpu_runtime_EigenMatMulF32";
extern const char* const kEigenMatMulF64SymbolName =
    "__xla_cpu_runtime_EigenMatMulF64";
extern const char* const kEigenMatMulC64SymbolName =
    "__xla_cpu_runtime_EigenMatMulC64";
extern const char* const kEigenMatMulC128SymbolName =
    "__xla_cpu_runtime_EigenMatMulC128";
extern const char* const kEigenMatMulS32SymbolName =
    "__xla_cpu_runtime_EigenMatMulS32";
extern const char* const kEigenBatchMatMulF32SymbolName =
    "__xla_cpu_runtime_EigenBatchMatMulF32";
extern const char* const kMKLConv2DF32SymbolName =
    "__xla_cpu_runtime_MKLConv2DF32";
extern const char* const kACLConv2DF32SymbolName =
    "__xla_cpu_runtime_ACLConv2DF32";
extern const char* const kMKLMatMulF32SymbolName =
    "__xla_cpu_runtime_MKLMatMulF32";
extern const char* const kMKLMatMulF64SymbolName =
    "__xla_cpu_runtime_MKLMatMulF64";
extern const char* const kACLMatMulF32SymbolName =
    "__xla_cpu_runtime_ACLMatMulF32";
extern const char* const kACLBatchMatMulF32SymbolName =
    "__xla_cpu_runtime_ACLBatchMatMulF32";
extern const char* const kMKLSingleThreadedMatMulF32SymbolName =
    "__xla_cpu_runtime_MKLSingleThreadedMatMulF32";
extern const char* const kMKLSingleThreadedMatMulF64SymbolName =
    "__xla_cpu_runtime_MKLSingleThreadedMatMulF64";
extern const char* const kEigenConv2DF16SymbolName =
    "__xla_cpu_runtime_EigenConv2DF16";
extern const char* const kEigenConv2DF32SymbolName =
    "__xla_cpu_runtime_EigenConv2DF32";
extern const char* const kEigenConv3DF16SymbolName =
    "__xla_cpu_runtime_EigenConv3DF16";
extern const char* const kEigenConv3DF32SymbolName =
    "__xla_cpu_runtime_EigenConv3DF32";
extern const char* const kEigenFftSymbolName = "__xla_cpu_runtime_EigenFft";
extern const char* const kEigenSingleThreadedFftSymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedFft";
extern const char* const kEigenSingleThreadedMatMulF16SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedMatMulF16";
extern const char* const kEigenSingleThreadedMatMulF32SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedMatMulF32";
extern const char* const kEigenSingleThreadedMatMulF64SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedMatMulF64";
extern const char* const kEigenSingleThreadedMatMulC64SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedMatMulC64";
extern const char* const kEigenSingleThreadedMatMulC128SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedMatMulC128";
extern const char* const kEigenSingleThreadedMatMulS32SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedMatMulS32";
extern const char* const kEigenSingleThreadedConv2DF16SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedConv2DF16";
extern const char* const kEigenSingleThreadedConv2DF32SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedConv2DF32";
extern const char* const kEigenSingleThreadedConv3DF16SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedConv3DF16";
extern const char* const kEigenSingleThreadedConv3DF32SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedConv3DF32";
extern const char* const kAcquireInfeedBufferForDequeueSymbolName =
    "__xla_cpu_runtime_AcquireInfeedBufferForDequeue";
extern const char* const kReleaseInfeedBufferAfterDequeueSymbolName =
    "__xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue";
extern const char* const kAcquireOutfeedBufferForPopulationSymbolName =
    "__xla_cpu_runtime_AcquireOutfeedBufferForPopulation";
extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName =
    "__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation";
extern const char* const kParallelForkJoinSymbolName =
    "__xla_cpu_runtime_ParallelForkJoin";
extern const char* const kPrintfToStderrSymbolName =
    "__xla_cpu_runtime_PrintfToStderr";
extern const char* const kStatusIsSuccessSymbolName =
    "__xla_cpu_runtime_StatusIsSuccess";
extern const char* const kKeyValueSortSymbolName =
    "__xla_cpu_runtime_KeyValueSort";
extern const char* const kTopKF32SymbolName = "__xla_cpu_runtime_TopKF32";
extern const char* const kTracingStartSymbolName =
    "__xla_cpu_runtime_TracingStart";
extern const char* const kTracingEndSymbolName = "__xla_cpu_runtime_TracingEnd";
extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_";
extern const char* const kAllReduceSymbolName = "__xla_cpu_runtime_AllReduce";
extern const char* const kAllToAllSymbolName = "__xla_cpu_runtime_AllToAll";
extern const char* const kCollectivePermuteSymbolName =
    "__xla_cpu_runtime_CollectivePermute";
extern const char* const kPartitionIdSymbolName =
    "__xla_cpu_runtime_PartitionId";
extern const char* const kReplicaIdSymbolName = "__xla_cpu_runtime_ReplicaId";

}  // namespace runtime
}  // namespace cpu
}  // namespace xla

namespace {

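// Participant state handed to the in-process collective-permute rendezvous:
// this replica's source and destination buffers and the replica ids its data
// should be copied to.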
struct CollectivePermuteParticipantData : xla::ParticipantData {
  CollectivePermuteParticipantData(const xla::RendezvousKey& rendezvous_key_p,
                                   int64_t device_ordinal_p,
                                   se::Stream* stream_p)
      : ParticipantData(rendezvous_key_p),
        device_ordinal(device_ordinal_p),
        stream(stream_p) {}

  int64_t device_ordinal;
  se::Stream* stream;
  int replica_id;
  se::DeviceMemoryBase source_data;
  se::DeviceMemoryBase destination_data;
  int64_t byte_size;
  std::vector<int> replica_ids_to_copy_to;

  std::string ToString() const override {
    return absl::StrFormat(
        "CollectivePermuteParticipantData{replica_id=%d, "
        "source_data=%p, destination_data=%p, byte_size=%d, "
        "replica_ids_to_copy_to=[%s], device_ordinal=%d, stream=%p}",
        replica_id, source_data.opaque(), destination_data.opaque(), byte_size,
        absl::StrJoin(replica_ids_to_copy_to, ", "), device_ordinal, stream);
  }
};

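// Participant state for the in-process all-to-all rendezvous: one source and
// one destination buffer per participating device, plus the global device ids
// taking part in the exchange.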
struct AllToAllParticipantData : xla::ParticipantData {
  AllToAllParticipantData(const xla::RendezvousKey& rendezvous_key_p,
                          int64_t device_ordinal_p, se::Stream* stream_p)
      : ParticipantData(rendezvous_key_p),
        device_ordinal(device_ordinal_p),
        stream(stream_p) {}

  int64_t device_ordinal;
  se::Stream* stream;
  std::vector<se::DeviceMemoryBase> source_buffers;
  std::vector<se::DeviceMemoryBase> destination_buffers;
  xla::GlobalDeviceId device_id;

  // Replica ids participating in AllToAll; concatenation happens in the order
  // of appearance.
  std::vector<xla::GlobalDeviceId> devices_to_copy_to;

  std::string ToString() const override {
    auto addr_formatter = [](std::string* out,
                             const se::DeviceMemoryBase& mem) {
      absl::StrAppend(out, absl::StrFormat("%p", mem.opaque()));
    };
    auto device_formatter = [](std::string* out,
                               const xla::GlobalDeviceId& device) {
      absl::StrAppend(out, device.value());
    };
    return absl::StrFormat(
        "AllToAllParticipantData{replica_id=%d, "
        "replica_ids_to_copy_to=[%s], source_buffers=[%s], "
        "destination_buffers=[%s], device_ordinal=%d, stream=%p}",
        device_id.value(),
        absl::StrJoin(devices_to_copy_to, ", ", device_formatter),
        absl::StrJoin(source_buffers, ", ", addr_formatter),
        absl::StrJoin(destination_buffers, ", ", addr_formatter),
        device_ordinal, stream);
  }
};

// Inverts the encoding of a Shape protobuf into an LLVM global variable.
xla::StatusOr<xla::Shape> DecodeSelfDescribingShapeConstant(
    const void* shape_ptr, int32_t size_bytes) {
  xla::ShapeProto shape_proto;
  if (!shape_proto.ParseFromArray(shape_ptr, size_bytes)) {
    return tensorflow::errors::Internal("Failed parsing the shape proto");
  }
  xla::Shape shape(shape_proto);
  auto status = xla::ShapeUtil::ValidateShape(shape);
  if (!status.ok()) {
    return status;
  }
  return std::move(shape);
}

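// Returns a human-readable string for the encoded shape, or "<invalid shape>"
// if the encoding cannot be decoded.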
std::string ShapeString(const void* shape_ptr, int32_t shape_length) {
  xla::StatusOr<xla::Shape> shape =
      DecodeSelfDescribingShapeConstant(shape_ptr, shape_length);
  if (shape.ok()) {
    return xla::ShapeUtil::HumanStringWithLayout(shape.ValueOrDie());
  }
  return "<invalid shape>";
}

// TODO(zhangqiaorjc): Prefer to make callers set and use device_ordinal
// directly since callers may not have a Stream*.
int GetDeviceOrdinal(const xla::ExecutableRunOptions* run_options) {
  if (!run_options) {
    return 0;
  } else if (run_options->device_ordinal() != -1) {
    return run_options->device_ordinal();
  }
  return run_options->stream()->parent()->device_ordinal();
}

}  // namespace

extern "C" {

ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY int __xla_cpu_runtime_PrintfToStderr(
    const char* format, ...) {
  VLOG(3) << "__xla_cpu_runtime_PrintfToStderr " << format;
  va_list args;
  va_start(args, format);
  int result = vfprintf(stderr, format, args);
  va_end(args);
  return result;
}

ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY int64_t __xla_cpu_runtime_TracingStart(
    const void* /* xla::ExecutableRunOptions* */ run_options_ptr,
    const char* name) {
  VLOG(3) << "TracingStart " << name;
  return tensorflow::profiler::TraceMe::ActivityStart(name);
}

ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_TracingEnd(
    const void* /* xla::ExecutableRunOptions* */ run_options_ptr, int64_t id) {
  VLOG(3) << "TracingEnd " << id;
  tensorflow::profiler::TraceMe::ActivityEnd(id);
}

}  // extern "C"

ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void*
__xla_cpu_runtime_AcquireInfeedBufferForDequeue(
    const xla::ExecutableRunOptions* run_options, int32_t buffer_length,
    const void* shape, int32_t shape_length) {
  int device_ordinal = GetDeviceOrdinal(run_options);

  VLOG(2) << "AcquireInfeedBufferForDequeue: "
          << ShapeString(shape, shape_length) << " on stream executor "
          << device_ordinal;

  xla::cpu::runtime::XfeedManager* xfeed =
      xla::cpu::runtime::GetXfeedManager(device_ordinal);
  // Wait until there's a buffer to dequeue.
  xla::cpu::runtime::XfeedBuffer* buffer =
      xfeed->infeed()->BlockingDequeueBuffer();
  CHECK_EQ(buffer->length(), buffer_length)
      << "XLA program infeed request buffer size " << buffer_length
      << " did not match the runtime's infeed buffer length "
      << buffer->length() << "; program reports desired shape: "
      << ShapeString(shape, shape_length);
  return buffer->data();
}

ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void
__xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(
    const xla::ExecutableRunOptions* run_options, int32_t buffer_length,
    void* buffer_ptr, const void* shape_ptr, int32_t shape_length) {
  int device_ordinal = GetDeviceOrdinal(run_options);

  VLOG(2) << "ReleaseInfeedBufferAfterDequeue: "
          << ShapeString(shape_ptr, shape_length) << " on stream executor "
          << device_ordinal;

  xla::cpu::runtime::XfeedManager* xfeed =
      xla::cpu::runtime::GetXfeedManager(device_ordinal);
  xla::StatusOr<xla::Shape> shape =
      DecodeSelfDescribingShapeConstant(shape_ptr, shape_length);
  xfeed->infeed()->ReleaseCurrentBuffer(buffer_length, buffer_ptr,
                                        std::move(shape));
}

ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void*
__xla_cpu_runtime_AcquireOutfeedBufferForPopulation(
    const xla::ExecutableRunOptions* run_options, int32_t buffer_length,
    const void* shape_ptr, int32_t shape_length) {
  int device_ordinal = GetDeviceOrdinal(run_options);

  VLOG(2) << "AcquireOutfeedBufferForPopulation: "
          << ShapeString(shape_ptr, shape_length) << " on stream executor "
          << device_ordinal;

  xla::cpu::runtime::XfeedManager* xfeed =
      xla::cpu::runtime::GetXfeedManager(device_ordinal);
  // Wait until there's a buffer to dequeue.
  xla::cpu::runtime::XfeedBuffer* buffer =
      xfeed->outfeed()->BlockingDequeueBuffer();
  CHECK_EQ(buffer->length(), buffer_length)
      << "XLA program outfeed request buffer size " << buffer_length
      << " did not match the runtime's outfeed buffer length "
      << buffer->length() << "; program reports outfeed shape: "
      << ShapeString(shape_ptr, shape_length);
  return buffer->data();
}

ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void
__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(
    const xla::ExecutableRunOptions* run_options, int32_t buffer_length,
    void* buffer_ptr, const void* shape_ptr, int32_t shape_length) {
  int device_ordinal = GetDeviceOrdinal(run_options);

  VLOG(2) << "ReleaseOutfeedBufferAfterPopulation: "
          << ShapeString(shape_ptr, shape_length) << " on stream executor "
          << device_ordinal;

  xla::cpu::runtime::XfeedManager* xfeed =
      xla::cpu::runtime::GetXfeedManager(device_ordinal);
  xla::StatusOr<xla::Shape> shape =
      DecodeSelfDescribingShapeConstant(shape_ptr, shape_length);
  xfeed->outfeed()->ReleaseCurrentBuffer(buffer_length, buffer_ptr,
                                         std::move(shape));
}

namespace {

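// Rendezvous for AllToAll on the CPU backend. Once every participant has
// arrived, the primary thread copies sender.source_buffers[i] into slot
// rank(sender) of the destination buffers of the i-th receiving device, so
// each participant ends up with one chunk from every sender.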
class CpuAllToAllRendezvous
    : public xla::Rendezvous<AllToAllParticipantData, std::nullptr_t> {
 public:
  explicit CpuAllToAllRendezvous(const xla::RendezvousKey& k)
      : xla::Rendezvous<AllToAllParticipantData, std::nullptr_t>(k) {}

 protected:
  xla::StatusOr<std::nullptr_t> RunCollectiveOp(
      const AllToAllParticipantData& /*participant*/) override {
    bool is_primary = InitializationBarrier();

    if (is_primary) {
      absl::MutexLock lock(&mu_);

      CHECK(!participants_.empty());
      CHECK(!participants_[0].source_buffers.empty());
      int expected_buffer_size = participants_[0].source_buffers[0].size();

      // Device id -> position in participants_.
      absl::flat_hash_map<xla::GlobalDeviceId, int> device_map;

      for (int pos = 0; pos < participants_.size(); pos++) {
        const AllToAllParticipantData& p = participants_[pos];
        CHECK_EQ(p.source_buffers.size(), p.destination_buffers.size());
        CHECK_EQ(p.source_buffers.size(), participants_.size());
        for (int i = 0; i < p.source_buffers.size(); i++) {
          CHECK_EQ(p.destination_buffers[i].size(), expected_buffer_size);
          CHECK_EQ(p.source_buffers[i].size(), expected_buffer_size);
        }
        device_map[p.device_id] = pos;
      }

      const std::vector<xla::GlobalDeviceId>& devices_to_copy_to =
          participants_[0].devices_to_copy_to;

      // Device id -> rank.
      absl::flat_hash_map<xla::GlobalDeviceId, int> device_ranks;
      for (int rank = 0; rank < devices_to_copy_to.size(); ++rank) {
        auto device_id = devices_to_copy_to[rank];
        device_ranks[device_id] = rank;
      }

      for (const AllToAllParticipantData& sender : participants_) {
        VLOG(3) << "Processing AllToAll participant: " << sender.ToString();

        int rank = xla::FindOrDie(device_ranks, sender.device_id);

        for (int i = 0; i < participants_.size(); ++i) {
          auto device_id = devices_to_copy_to[i];
          int participant_num = xla::FindOrDie(device_map, device_id);
          AllToAllParticipantData& receiver = participants_[participant_num];

          std::memcpy(receiver.destination_buffers[rank].opaque(),
                      sender.source_buffers[i].opaque(), expected_buffer_size);
        }
      }
    }
    return nullptr;
  }
};

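// Rendezvous for CollectivePermute on the CPU backend. The primary thread
// copies each participant's source buffer to the participants named in its
// replica_ids_to_copy_to list and zero-fills the destination of any replica
// that was not copied into.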
class CpuCollectivePermuteRendezvous
    : public xla::Rendezvous<CollectivePermuteParticipantData, std::nullptr_t> {
 public:
  explicit CpuCollectivePermuteRendezvous(const xla::RendezvousKey& k)
      : xla::Rendezvous<CollectivePermuteParticipantData, std::nullptr_t>(k) {}

 protected:
  xla::StatusOr<std::nullptr_t> RunCollectiveOp(
      const CollectivePermuteParticipantData& /*participant*/) override {
    bool primary = InitializationBarrier();

    // Perform all copies from the primary thread.
    if (primary) {
      absl::MutexLock lock(&mu_);

      std::map<int, int> replica_idx_to_participant_idx;
      for (int p_idx = 0; p_idx < participants_.size(); p_idx++) {
        replica_idx_to_participant_idx[participants_[p_idx].replica_id] = p_idx;
      }
      for (auto& p : participants_) {
        for (int dest_replica : p.replica_ids_to_copy_to) {
          auto& dest_p = participants_[xla::FindOrDie(
              replica_idx_to_participant_idx, dest_replica)];
          std::memcpy(dest_p.destination_data.opaque(), p.source_data.opaque(),
                      p.byte_size);

          // Each replica may be copied into only once.
          replica_idx_to_participant_idx.erase(dest_replica);
        }
      }

      // Zero out untouched participants.
      for (auto& replica_p : replica_idx_to_participant_idx) {
        auto& p = participants_[replica_p.second];
        std::memset(p.destination_data.opaque(), 0, p.byte_size);
      }
    }
    return nullptr;
  }
};

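// Rendezvous for AllReduce on the CPU backend. The primary thread dispatches
// on the element type and reduces the participants' input buffers
// element-wise, writing the same reduced values into every participant's
// output buffers.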
class CpuAllReduceRendezvous
    : public xla::Rendezvous<xla::AllReduceParticipantData, std::nullptr_t> {
 public:
  explicit CpuAllReduceRendezvous(const xla::RendezvousKey& k)
      : xla::Rendezvous<xla::AllReduceParticipantData, std::nullptr_t>(k) {}

 protected:
  xla::StatusOr<std::nullptr_t> RunCollectiveOp(
      const xla::AllReduceParticipantData& participant) override {
    xla::PrimitiveType datatype = participant.buffers.front().primitive_type;
    bool primary = InitializationBarrier();

    if (primary) {
      switch (datatype) {
        case xla::S8:
          DoAllReduce<xla::S8>(participant);
          break;
        case xla::PRED:
        case xla::U8:
          DoAllReduce<xla::U8>(participant);
          break;
        case xla::S32:
          DoAllReduce<xla::S32>(participant);
          break;
        case xla::U32:
          DoAllReduce<xla::U32>(participant);
          break;
        case xla::S64:
          DoAllReduce<xla::S64>(participant);
          break;
        case xla::U64:
          DoAllReduce<xla::U64>(participant);
          break;
        case xla::F16:
          DoAllReduce<xla::F16>(participant);
          break;
        case xla::F32:
          DoAllReduce<xla::F32>(participant);
          break;
        case xla::F64:
          DoAllReduce<xla::F64>(participant);
          break;
        case xla::C64:
          DoAllReduce<xla::C64>(participant);
          break;
        case xla::C128:
          DoAllReduce<xla::C128>(participant);
          break;
        default:
          LOG(FATAL) << "Unexpected datatype";
      }
    }
    return nullptr;
  }

 private:
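  // For element type PT, reduces each element position across all
  // participants and writes the reduced value into every participant's
  // output buffers.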
  template <xla::PrimitiveType PT>
  void DoAllReduce(xla::AllReduceParticipantData participant) {
    using T = typename xla::primitive_util::PrimitiveTypeToNative<PT>::type;
    absl::MutexLock lock(&mu_);
    CHECK(!participants_.empty());
    xla::ReductionKind reduction_kind = participant.reduction_kind;
    for (const auto& p : participants_) {
      CHECK(p.reduction_kind == reduction_kind);
    }
    int num_participants = participants_.size();

    // participant_idx -> buffer_idx -> buffer.
    std::vector<std::vector<absl::Span<T>>> input_buffers;
    std::vector<std::vector<absl::Span<T>>> output_buffers;
    input_buffers.reserve(num_participants);
    output_buffers.reserve(num_participants);
    const xla::AllReduceParticipantData& first_participant =
        participants_.front();

    int buffers_per_participant = first_participant.buffers.size();
    for (xla::AllReduceParticipantData& p : participants_) {
      CHECK_EQ(p.buffers.size(), buffers_per_participant);

      input_buffers.emplace_back();
      output_buffers.emplace_back();
      std::vector<absl::Span<T>>& participant_input_buffers =
          input_buffers.back();
      std::vector<absl::Span<T>>& participant_output_buffers =
          output_buffers.back();
      participant_input_buffers.reserve(p.buffers.size());
      participant_output_buffers.reserve(p.buffers.size());

      for (int buffer_idx = 0; buffer_idx < buffers_per_participant;
           buffer_idx++) {
        auto& participant_buffer = p.buffers[buffer_idx];
        participant_input_buffers.emplace_back(
            static_cast<T*>(participant_buffer.source_data.opaque()),
            participant_buffer.element_count);
        participant_output_buffers.emplace_back(
            static_cast<T*>(participant_buffer.destination_data.opaque()),
            participant_buffer.element_count);
        CHECK_EQ(participant_buffer.element_count,
                 first_participant.buffers[buffer_idx].element_count);
      }
    }

    for (int buffer_idx = 0; buffer_idx < buffers_per_participant;
         buffer_idx++) {
      int element_count = first_participant.buffers[buffer_idx].element_count;
      for (int idx = 0; idx < element_count; idx++) {
        T out = GetInitialValue<T>(reduction_kind);
        for (int participant_idx = 0; participant_idx < participants_.size();
             participant_idx++) {
          out = PerformReductionStep<T>(
              reduction_kind, out,
              input_buffers[participant_idx][buffer_idx][idx]);
        }
        for (int participant_idx = 0; participant_idx < participants_.size();
             participant_idx++) {
          output_buffers[participant_idx][buffer_idx][idx] = out;
        }
      }
    }
  }

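  // Returns the identity element for the reduction: 0 for SUM, 1 for PRODUCT,
  // the largest representable value for MIN, and the lowest for MAX.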
  template <typename T>
  T GetInitialValue(xla::ReductionKind reduction_kind) {
    switch (reduction_kind) {
      case xla::ReductionKind::SUM:
        return static_cast<T>(0);
      case xla::ReductionKind::PRODUCT:
        return static_cast<T>(1);
      case xla::ReductionKind::MIN:
        return std::numeric_limits<T>::max();
      case xla::ReductionKind::MAX:
        return std::numeric_limits<T>::lowest();
    }
  }

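  // For signed integral element types, SUM and PRODUCT are carried out in the
  // corresponding unsigned type (via bit_cast) so that overflow wraps instead
  // of being undefined behavior.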
  template <typename T, bool kIsSignedIntegralType>
  struct SumProductTypeForReductionStep {
    using type = T;
  };

  template <typename T>
  struct SumProductTypeForReductionStep<T, /*kIsSignedIntegralType=*/true> {
    using type = typename std::make_unsigned_t<T>;
  };

  template <typename T,
            typename std::enable_if<!is_complex<T>::value>::type* = nullptr>
  T PerformReductionStep(xla::ReductionKind reduction_kind, T a, T b) {
    using SumProductType = typename SumProductTypeForReductionStep<
        T, std::is_integral<T>::value && std::is_signed<T>::value>::type;
    switch (reduction_kind) {
      case xla::ReductionKind::SUM:
        return absl::bit_cast<T>(
            static_cast<SumProductType>(absl::bit_cast<SumProductType>(a) +
                                        absl::bit_cast<SumProductType>(b)));
      case xla::ReductionKind::PRODUCT:
        return absl::bit_cast<T>(
            static_cast<SumProductType>(absl::bit_cast<SumProductType>(a) *
                                        absl::bit_cast<SumProductType>(b)));
      case xla::ReductionKind::MIN:
        return std::min(a, b);
      case xla::ReductionKind::MAX:
        return std::max(a, b);
    }
  }

  template <typename T,
            typename std::enable_if<is_complex<T>::value>::type* = nullptr>
  T PerformReductionStep(xla::ReductionKind reduction_kind, T a, T b) {
    using SumProductType = typename SumProductTypeForReductionStep<
        T, std::is_integral<T>::value && std::is_signed<T>::value>::type;
    switch (reduction_kind) {
      case xla::ReductionKind::SUM:
        return absl::bit_cast<T>(
            static_cast<SumProductType>(absl::bit_cast<SumProductType>(a) +
                                        absl::bit_cast<SumProductType>(b)));
      case xla::ReductionKind::PRODUCT:
        return absl::bit_cast<T>(
            static_cast<SumProductType>(absl::bit_cast<SumProductType>(a) *
                                        absl::bit_cast<SumProductType>(b)));
      case xla::ReductionKind::MIN:
      case xla::ReductionKind::MAX:
        LOG(FATAL) << "min/max not valid for complex types";
    }
  }
};

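// Process-wide maps from rendezvous key to the rendezvous object shared by
// all local participants of a collective. Entries are reference counted, so a
// rendezvous is destroyed once every participant has released its reference.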
xla::RefcountingHashMap<xla::RendezvousKey, CpuAllReduceRendezvous>&
GlobalAllReduceRendezvousMap() {
  static auto& m =
      *new xla::RefcountingHashMap<xla::RendezvousKey, CpuAllReduceRendezvous>;
  return m;
}

xla::RefcountingHashMap<xla::RendezvousKey, CpuCollectivePermuteRendezvous>&
GlobalCollectivePermuteRendezvousMap() {
  static auto& m = *new xla::RefcountingHashMap<xla::RendezvousKey,
                                                CpuCollectivePermuteRendezvous>;
  return m;
}

xla::RefcountingHashMap<xla::RendezvousKey, CpuAllToAllRendezvous>&
GlobalAllToAllRendezvousMap() {
  static auto& m =
      *new xla::RefcountingHashMap<xla::RendezvousKey, CpuAllToAllRendezvous>;
  return m;
}

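// Builds the rendezvous key that identifies this collective: the run id, the
// set of participating devices derived from the replica groups and the device
// assignment, whether the op is cross-module (channel id present) or
// cross-replica, and the op id.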
xla::RendezvousKey GetRendezvousKey(
    const xla::ExecutableRunOptions* run_options,
    std::vector<xla::ReplicaGroup> group, int32_t channel_id_present,
    std::optional<bool> use_global_device_ids, int64_t op_id) {
  const xla::DeviceAssignment& device_assignment =
      *run_options->device_assignment();
  int device_ordinal = GetDeviceOrdinal(run_options);
  xla::RendezvousKey::CollectiveOpKind op_kind =
      channel_id_present ? xla::RendezvousKey::kCrossModule
                         : xla::RendezvousKey::kCrossReplica;
  std::vector<xla::GlobalDeviceId> participating_devices =
      xla::GetParticipatingDevices(
          xla::GlobalDeviceId(device_ordinal), device_assignment, group,
          xla::GetCollectiveOpGroupMode(channel_id_present != 0,
                                        use_global_device_ids)
              .ValueOrDie())
          .ValueOrDie();
  int num_local_participants = participating_devices.size();
  return xla::RendezvousKey{run_options->run_id(),
                            std::move(participating_devices),
                            num_local_participants, op_kind, op_id};
}

}  // namespace

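// Runtime entry point for AllToAll: registers this device's source and
// destination buffers with the shared rendezvous and blocks until the
// exchange has completed.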
ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllToAll(
    const xla::ExecutableRunOptions* run_options, int32_t channel_id_present,
    int64_t op_id, const void* replica_groups_str,
    int32_t replica_groups_str_size, int32_t num_buffers, int64_t buffer_size,
    void** source_buffers, void** destination_buffers) {
  int device_ordinal = GetDeviceOrdinal(run_options);
  absl::string_view replica_groups_serialized(
      static_cast<const char*>(replica_groups_str), replica_groups_str_size);
  std::vector<xla::ReplicaGroup> group =
      xla::ParseReplicaGroupsOnly(replica_groups_serialized).ValueOrDie();
  xla::RendezvousKey rendezvous_key =
      GetRendezvousKey(run_options, group, channel_id_present,
                       /*use_global_device_ids=*/std::nullopt, op_id);

  AllToAllParticipantData participant(rendezvous_key, device_ordinal,
                                      run_options->stream());
  participant.device_id = xla::GlobalDeviceId(device_ordinal);
  participant.devices_to_copy_to =
      xla::GetParticipatingDevices(
          xla::GlobalDeviceId(device_ordinal),
          *run_options->device_assignment(), group,
          xla::GetCollectiveOpGroupMode(channel_id_present != 0,
                                        /*use_global_device_ids=*/std::nullopt)
              .ValueOrDie())
          .ValueOrDie();
  for (int i = 0; i < num_buffers; i++) {
    participant.source_buffers.emplace_back(source_buffers[i], buffer_size);
    participant.destination_buffers.emplace_back(destination_buffers[i],
                                                 buffer_size);
  }
  auto make_cpu_rendezvous = [](const xla::RendezvousKey& k) {
    return std::make_unique<CpuAllToAllRendezvous>(k);
  };
  TF_CHECK_OK(CpuAllToAllRendezvous::SubmitParticipant(
                  [&] {
                    return GlobalAllToAllRendezvousMap().GetOrCreateIfAbsent(
                        rendezvous_key, make_cpu_rendezvous);
                  },
                  participant)
                  .status());
}

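// Runtime entry point for AllReduce: decodes the operand shape, wraps each
// input/output buffer pair in a participant descriptor, and blocks on the
// shared rendezvous until the reduced values have been written to the output
// buffers.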
ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
    const xla::ExecutableRunOptions* run_options,
    const void* replica_groups_str, int32_t replica_groups_str_size,
    int32_t channel_id_present, int32_t use_global_device_ids, int64_t op_id,
    int32_t reduction_kind, const void* shape_ptr, int32_t shape_length,
    int32_t num_buffers, void** input_buffers, void** output_buffers) {
  int device_ordinal = GetDeviceOrdinal(run_options);
  absl::string_view replica_groups_serialized(
      static_cast<const char*>(replica_groups_str), replica_groups_str_size);
  std::vector<xla::ReplicaGroup> group =
      xla::ParseReplicaGroupsOnly(replica_groups_serialized).ValueOrDie();
  xla::RendezvousKey rendezvous_key = GetRendezvousKey(
      run_options, group, channel_id_present, use_global_device_ids, op_id);
  auto shape_str = ShapeString(shape_ptr, shape_length);
  VLOG(2) << "All-reduce input/output shape : " << shape_str;

  xla::Shape shape =
      DecodeSelfDescribingShapeConstant(shape_ptr, shape_length).ValueOrDie();

  CHECK((num_buffers > 1 && shape.IsTuple()) ||
        (num_buffers == 1 && xla::LayoutUtil::IsDenseArray(shape)));

  xla::AllReduceParticipantData participant(rendezvous_key, device_ordinal,
                                            run_options->stream());
  participant.reduction_kind = static_cast<xla::ReductionKind>(reduction_kind);
  for (int i = 0; i < num_buffers; i++) {
    xla::Shape subshape = num_buffers == 1 ? shape : shape.tuple_shapes(i);
    xla::AllReduceParticipantData::Buffer buffer;
    buffer.element_count = xla::ShapeUtil::ElementsIn(subshape);
    buffer.primitive_type = subshape.element_type();
    buffer.source_data = se::DeviceMemoryBase(
        input_buffers[i], xla::ShapeUtil::ByteSizeOf(subshape));
    buffer.destination_data = se::DeviceMemoryBase(
        output_buffers[i], xla::ShapeUtil::ByteSizeOf(subshape));
    participant.buffers.push_back(buffer);
  }

  auto make_cpu_rendezvous = [](const xla::RendezvousKey& k) {
    return std::make_unique<CpuAllReduceRendezvous>(k);
  };

  TF_CHECK_OK(CpuAllReduceRendezvous::SubmitParticipant(
                  [&] {
                    return GlobalAllReduceRendezvousMap().GetOrCreateIfAbsent(
                        rendezvous_key, make_cpu_rendezvous);
                  },
                  participant)
                  .status());
}

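// Writes this device's replica id (looked up in the device assignment) into
// the 4-byte output buffer.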
ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ReplicaId(
    const xla::ExecutableRunOptions* run_options, void* output_buffer) {
  int device_ordinal = GetDeviceOrdinal(run_options);
  int32_t replica_id =
      run_options->device_assignment()
          ->ReplicaIdForDevice(xla::GlobalDeviceId(device_ordinal))
          .ValueOrDie();
  std::memcpy(output_buffer, &replica_id, 4);
}

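// Writes this device's partition (computation) id into the 4-byte output
// buffer.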
ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_PartitionId(
    const xla::ExecutableRunOptions* run_options, void* output_buffer) {
  int device_ordinal = GetDeviceOrdinal(run_options);
  const xla::DeviceAssignment::LogicalID logical_id =
      run_options->device_assignment()
          ->LogicalIdForDevice(xla::GlobalDeviceId(device_ordinal))
          .ValueOrDie();
  std::memcpy(output_buffer, &logical_id.computation_id, 4);
}

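// Runtime entry point for CollectivePermute: parses the serialized
// "source=target" pairs, determines which logical ids this device must copy
// its data to, and blocks on the shared rendezvous until all copies are done.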
ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_CollectivePermute(
    const xla::ExecutableRunOptions* run_options, int32_t channel_id_present,
    int64_t op_id, int32_t byte_size, void* input_buffer, void* output_buffer,
    const void* source_target_pairs, int32_t source_target_pairs_size) {
  int device_ordinal = GetDeviceOrdinal(run_options);
  absl::string_view source_target_pairs_serialized(
      static_cast<const char*>(source_target_pairs), source_target_pairs_size);
  auto pairs = absl::StrSplit(source_target_pairs_serialized, ',');
  const xla::DeviceAssignment::LogicalID logical_id =
      run_options->device_assignment()
          ->LogicalIdForDevice(xla::GlobalDeviceId(device_ordinal))
          .ValueOrDie();
  int32_t logical_device_id =
      channel_id_present ? logical_id.computation_id : logical_id.replica_id;

  std::vector<int> copy_to;
  for (auto& p : pairs) {
    std::vector<std::string> mapping = absl::StrSplit(p, '=');
    CHECK_EQ(mapping.size(), 2);
    int from = std::stoi(mapping[0]);
    int to = std::stoi(mapping[1]);
    if (from == logical_device_id) {
      copy_to.push_back(to);
    }
  }
  xla::RendezvousKey rendezvous_key =
      GetRendezvousKey(run_options, {}, channel_id_present,
                       /*use_global_device_ids=*/std::nullopt, op_id);

  CollectivePermuteParticipantData participant(rendezvous_key, device_ordinal,
                                               run_options->stream());
  participant.replica_id = logical_device_id;
  participant.source_data = se::DeviceMemoryBase(input_buffer, byte_size);
  participant.destination_data = se::DeviceMemoryBase(output_buffer, byte_size);
  participant.replica_ids_to_copy_to = copy_to;
  participant.byte_size = byte_size;

  auto make_cpu_rendezvous = [](const xla::RendezvousKey& k) {
    return std::make_unique<CpuCollectivePermuteRendezvous>(k);
  };
  TF_CHECK_OK(
      CpuCollectivePermuteRendezvous::SubmitParticipant(
          [&] {
            return GlobalCollectivePermuteRendezvousMap().GetOrCreateIfAbsent(
                rendezvous_key, make_cpu_rendezvous);
          },
          participant)
          .status());
}