/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"

#include <algorithm>
#include <complex>
#include <cstdarg>
#include <cstddef>
#include <cstdio>
#include <cstring>
#include <functional>
#include <limits>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/base/casts.h"
#include "absl/base/dynamic_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/refcounting_hash_map.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/stream_executor.h"

namespace se = ::stream_executor;

namespace {
template <class T>
struct is_complex : std::false_type {};
template <class T>
struct is_complex<std::complex<T>> : std::true_type {};
}  // namespace

namespace xla {
namespace cpu {
namespace runtime {

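// Returns the XfeedManager for `device_ordinal`, creating it on first use.
// The managers (and the map holding them) are intentionally leaked.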
XfeedManager* GetXfeedManager(int device_ordinal) {
  static auto* managers = new absl::flat_hash_map<int, XfeedManager*>();
  static absl::Mutex* mutex = new absl::Mutex();

  absl::MutexLock lock(mutex);
  auto it = managers->find(device_ordinal);
  if (it == managers->end()) {
    it = managers->emplace(device_ordinal, new XfeedManager()).first;
  }
  return it->second;
}

extern const char* const kEigenMatMulF16SymbolName =
    "__xla_cpu_runtime_EigenMatMulF16";
extern const char* const kEigenMatMulF32SymbolName =
    "__xla_cpu_runtime_EigenMatMulF32";
extern const char* const kEigenMatMulF64SymbolName =
    "__xla_cpu_runtime_EigenMatMulF64";
extern const char* const kEigenMatMulC64SymbolName =
    "__xla_cpu_runtime_EigenMatMulC64";
extern const char* const kEigenMatMulC128SymbolName =
    "__xla_cpu_runtime_EigenMatMulC128";
extern const char* const kEigenMatMulS32SymbolName =
    "__xla_cpu_runtime_EigenMatMulS32";
extern const char* const kEigenBatchMatMulF32SymbolName =
    "__xla_cpu_runtime_EigenBatchMatMulF32";
extern const char* const kMKLConv2DF32SymbolName =
    "__xla_cpu_runtime_MKLConv2DF32";
extern const char* const kACLConv2DF32SymbolName =
    "__xla_cpu_runtime_ACLConv2DF32";
extern const char* const kMKLMatMulF32SymbolName =
    "__xla_cpu_runtime_MKLMatMulF32";
extern const char* const kMKLMatMulF64SymbolName =
    "__xla_cpu_runtime_MKLMatMulF64";
extern const char* const kACLMatMulF32SymbolName =
    "__xla_cpu_runtime_ACLMatMulF32";
extern const char* const kACLBatchMatMulF32SymbolName =
    "__xla_cpu_runtime_ACLBatchMatMulF32";
extern const char* const kMKLSingleThreadedMatMulF32SymbolName =
    "__xla_cpu_runtime_MKLSingleThreadedMatMulF32";
extern const char* const kMKLSingleThreadedMatMulF64SymbolName =
    "__xla_cpu_runtime_MKLSingleThreadedMatMulF64";
extern const char* const kEigenConv2DF16SymbolName =
    "__xla_cpu_runtime_EigenConv2DF16";
extern const char* const kEigenConv2DF32SymbolName =
    "__xla_cpu_runtime_EigenConv2DF32";
extern const char* const kEigenConv3DF16SymbolName =
    "__xla_cpu_runtime_EigenConv3DF16";
extern const char* const kEigenConv3DF32SymbolName =
    "__xla_cpu_runtime_EigenConv3DF32";
extern const char* const kEigenFftSymbolName = "__xla_cpu_runtime_EigenFft";
extern const char* const kEigenSingleThreadedFftSymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedFft";
extern const char* const kEigenSingleThreadedMatMulF16SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedMatMulF16";
extern const char* const kEigenSingleThreadedMatMulF32SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedMatMulF32";
extern const char* const kEigenSingleThreadedMatMulF64SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedMatMulF64";
extern const char* const kEigenSingleThreadedMatMulC64SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedMatMulC64";
extern const char* const kEigenSingleThreadedMatMulC128SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedMatMulC128";
extern const char* const kEigenSingleThreadedMatMulS32SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedMatMulS32";
extern const char* const kEigenSingleThreadedConv2DF16SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedConv2DF16";
extern const char* const kEigenSingleThreadedConv2DF32SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedConv2DF32";
extern const char* const kEigenSingleThreadedConv3DF16SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedConv3DF16";
extern const char* const kEigenSingleThreadedConv3DF32SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedConv3DF32";
extern const char* const kAcquireInfeedBufferForDequeueSymbolName =
    "__xla_cpu_runtime_AcquireInfeedBufferForDequeue";
extern const char* const kReleaseInfeedBufferAfterDequeueSymbolName =
    "__xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue";
extern const char* const kAcquireOutfeedBufferForPopulationSymbolName =
    "__xla_cpu_runtime_AcquireOutfeedBufferForPopulation";
extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName =
    "__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation";
extern const char* const kParallelForkJoinSymbolName =
    "__xla_cpu_runtime_ParallelForkJoin";
extern const char* const kPrintfToStderrSymbolName =
    "__xla_cpu_runtime_PrintfToStderr";
extern const char* const kStatusIsSuccessSymbolName =
    "__xla_cpu_runtime_StatusIsSuccess";
extern const char* const kKeyValueSortSymbolName =
    "__xla_cpu_runtime_KeyValueSort";
extern const char* const kTopKF32SymbolName = "__xla_cpu_runtime_TopKF32";
extern const char* const kTracingStartSymbolName =
    "__xla_cpu_runtime_TracingStart";
extern const char* const kTracingEndSymbolName = "__xla_cpu_runtime_TracingEnd";
extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_";
extern const char* const kAllReduceSymbolName = "__xla_cpu_runtime_AllReduce";
extern const char* const kAllToAllSymbolName = "__xla_cpu_runtime_AllToAll";
extern const char* const kCollectivePermuteSymbolName =
    "__xla_cpu_runtime_CollectivePermute";
extern const char* const kPartitionIdSymbolName =
    "__xla_cpu_runtime_PartitionId";
extern const char* const kReplicaIdSymbolName = "__xla_cpu_runtime_ReplicaId";

}  // namespace runtime
}  // namespace cpu
}  // namespace xla

namespace {

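// Per-participant state for the CPU collective-permute rendezvous: which
// replica this participant is, its source/destination buffers, and the
// replica ids it must copy its data to.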
struct CollectivePermuteParticipantData : xla::ParticipantData {
  CollectivePermuteParticipantData(const xla::RendezvousKey& rendezvous_key_p,
                                   int64_t device_ordinal_p,
                                   se::Stream* stream_p)
      : ParticipantData(rendezvous_key_p),
        device_ordinal(device_ordinal_p),
        stream(stream_p) {}

  int64_t device_ordinal;
  se::Stream* stream;
  int replica_id;
  se::DeviceMemoryBase source_data;
  se::DeviceMemoryBase destination_data;
  int64_t byte_size;
  std::vector<int> replica_ids_to_copy_to;

  std::string ToString() const override {
    return absl::StrFormat(
        "CollectivePermuteParticipantData{replica_id=%d, "
        "source_data=%p, destination_data=%p, byte_size=%d, "
        "replica_ids_to_copy_to=[%s], device_ordinal=%d, stream=%p}",
        replica_id, source_data.opaque(), destination_data.opaque(), byte_size,
        absl::StrJoin(replica_ids_to_copy_to, ", "), device_ordinal, stream);
  }
};

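// Per-participant state for the CPU all-to-all rendezvous: one source and one
// destination buffer per participating device.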
struct AllToAllParticipantData : xla::ParticipantData {
  AllToAllParticipantData(const xla::RendezvousKey& rendezvous_key_p,
                          int64_t device_ordinal_p, se::Stream* stream_p)
      : ParticipantData(rendezvous_key_p),
        device_ordinal(device_ordinal_p),
        stream(stream_p) {}

  int64_t device_ordinal;
  se::Stream* stream;
  std::vector<se::DeviceMemoryBase> source_buffers;
  std::vector<se::DeviceMemoryBase> destination_buffers;
  xla::GlobalDeviceId device_id;

  // Replica ids participating in AllToAll, concatenation happens in the order
  // of appearance.
  std::vector<xla::GlobalDeviceId> devices_to_copy_to;

  std::string ToString() const override {
    auto addr_formatter = [](std::string* out,
                             const se::DeviceMemoryBase& mem) {
      absl::StrAppend(out, absl::StrFormat("%p", mem.opaque()));
    };
    auto device_formatter = [](std::string* out,
                               const xla::GlobalDeviceId& device) {
      absl::StrAppend(out, device.value());
    };
    return absl::StrFormat(
        "AllToAllParticipantData{replica_id=%d, "
        "replica_ids_to_copy_to=[%s], source_buffers=[%s], "
        "destination_buffers=[%s], device_ordinal=%d, stream=%p}",
        device_id.value(),
        absl::StrJoin(devices_to_copy_to, ", ", device_formatter),
        absl::StrJoin(source_buffers, ", ", addr_formatter),
        absl::StrJoin(destination_buffers, ", ", addr_formatter),
        device_ordinal, stream);
  }
};

// Decodes a Shape protobuf that was encoded into an LLVM global variable
// (the inverse of the compiler-side encoding) and validates the result.
xla::StatusOr<xla::Shape> DecodeSelfDescribingShapeConstant(
    const void* shape_ptr, int32_t size_bytes) {
  xla::ShapeProto shape_proto;
  if (!shape_proto.ParseFromArray(shape_ptr, size_bytes)) {
    return tensorflow::errors::Internal("Failed parsing the shape proto");
  }
  xla::Shape shape(shape_proto);
  auto status = xla::ShapeUtil::ValidateShape(shape);
  if (!status.ok()) {
    return status;
  }
  return std::move(shape);
}

std::string ShapeString(const void* shape_ptr, int32_t shape_length) {
  xla::StatusOr<xla::Shape> shape =
      DecodeSelfDescribingShapeConstant(shape_ptr, shape_length);
  if (shape.ok()) {
    return xla::ShapeUtil::HumanStringWithLayout(shape.ValueOrDie());
  }
  return "<invalid shape>";
}

// TODO(zhangqiaorjc): Prefer to make callers set and use device_ordinal
// directly since callers may not have a Stream*.
int GetDeviceOrdinal(const xla::ExecutableRunOptions* run_options) {
  if (!run_options) {
    return 0;
  } else if (run_options->device_ordinal() != -1) {
    return run_options->device_ordinal();
  }
  return run_options->stream()->parent()->device_ordinal();
}

}  // namespace

extern "C" {

ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY int __xla_cpu_runtime_PrintfToStderr(
    const char* format, ...) {
  VLOG(3) << "__xla_cpu_runtime_PrintfToStderr " << format;
  va_list args;
  va_start(args, format);
  int result = vfprintf(stderr, format, args);
  va_end(args);
  return result;
}

ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY int64_t __xla_cpu_runtime_TracingStart(
    const void* /* xla::ExecutableRunOptions* */ run_options_ptr,
    const char* name) {
  VLOG(3) << "TracingStart " << name;
  return tensorflow::profiler::TraceMe::ActivityStart(name);
}

ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_TracingEnd(
    const void* /* xla::ExecutableRunOptions* */ run_options_ptr, int64_t id) {
  VLOG(3) << "TracingEnd " << id;
  tensorflow::profiler::TraceMe::ActivityEnd(id);
}

}  // extern "C"

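// Blocks until the infeed for this device has a buffer available, checks that
// its length matches what the compiled program expects, and returns a pointer
// to its data.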
ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void*
__xla_cpu_runtime_AcquireInfeedBufferForDequeue(
    const xla::ExecutableRunOptions* run_options, int32_t buffer_length,
    const void* shape, int32_t shape_length) {
  int device_ordinal = GetDeviceOrdinal(run_options);

  VLOG(2) << "AcquireInfeedBufferForDequeue: "
          << ShapeString(shape, shape_length) << " on stream executor "
          << device_ordinal;

  xla::cpu::runtime::XfeedManager* xfeed =
      xla::cpu::runtime::GetXfeedManager(device_ordinal);
  // Wait until there's a buffer to dequeue.
  xla::cpu::runtime::XfeedBuffer* buffer =
      xfeed->infeed()->BlockingDequeueBuffer();
  CHECK_EQ(buffer->length(), buffer_length)
      << "XLA program infeed request buffer size " << buffer_length
      << " did not match the runtime's infeed buffer length "
      << buffer->length() << "; program reports desired shape: "
      << ShapeString(shape, shape_length);
  return buffer->data();
}

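// Returns the infeed buffer previously handed out by
// AcquireInfeedBufferForDequeue back to the XfeedManager, together with the
// decoded shape of the data it carried.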
ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void
__xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(
    const xla::ExecutableRunOptions* run_options, int32_t buffer_length,
    void* buffer_ptr, const void* shape_ptr, int32_t shape_length) {
  int device_ordinal = GetDeviceOrdinal(run_options);

  VLOG(2) << "ReleaseInfeedBufferAfterDequeue: "
          << ShapeString(shape_ptr, shape_length) << " on stream executor "
          << device_ordinal;

  xla::cpu::runtime::XfeedManager* xfeed =
      xla::cpu::runtime::GetXfeedManager(device_ordinal);
  xla::StatusOr<xla::Shape> shape =
      DecodeSelfDescribingShapeConstant(shape_ptr, shape_length);
  xfeed->infeed()->ReleaseCurrentBuffer(buffer_length, buffer_ptr,
                                        std::move(shape));
}

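// Blocks until the outfeed for this device has a buffer waiting to be
// populated, checks that its length matches what the compiled program will
// write, and returns a pointer to its data.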
ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void*
__xla_cpu_runtime_AcquireOutfeedBufferForPopulation(
    const xla::ExecutableRunOptions* run_options, int32_t buffer_length,
    const void* shape_ptr, int32_t shape_length) {
  int device_ordinal = GetDeviceOrdinal(run_options);

  VLOG(2) << "AcquireOutfeedBufferForPopulation: "
          << ShapeString(shape_ptr, shape_length) << " on stream executor "
          << device_ordinal;

  xla::cpu::runtime::XfeedManager* xfeed =
      xla::cpu::runtime::GetXfeedManager(device_ordinal);
  // Wait until there's a buffer to dequeue.
  xla::cpu::runtime::XfeedBuffer* buffer =
      xfeed->outfeed()->BlockingDequeueBuffer();
  CHECK_EQ(buffer->length(), buffer_length)
      << "XLA program outfeed request buffer size " << buffer_length
      << " did not match the runtime's outfeed buffer length "
      << buffer->length() << "; program reports outfed shape: "
      << ShapeString(shape_ptr, shape_length);
  return buffer->data();
}

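// Returns the outfeed buffer previously handed out by
// AcquireOutfeedBufferForPopulation back to the XfeedManager, together with
// the decoded shape of the data written into it.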
ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void
__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(
    const xla::ExecutableRunOptions* run_options, int32_t buffer_length,
    void* buffer_ptr, const void* shape_ptr, int32_t shape_length) {
  int device_ordinal = GetDeviceOrdinal(run_options);

  VLOG(2) << "ReleaseOutfeedBufferAfterPopulation: "
          << ShapeString(shape_ptr, shape_length) << " on stream executor "
          << device_ordinal;

  xla::cpu::runtime::XfeedManager* xfeed =
      xla::cpu::runtime::GetXfeedManager(device_ordinal);
  xla::StatusOr<xla::Shape> shape =
      DecodeSelfDescribingShapeConstant(shape_ptr, shape_length);
  xfeed->outfeed()->ReleaseCurrentBuffer(buffer_length, buffer_ptr,
                                         std::move(shape));
}

namespace {

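// Rendezvous for AllToAll: after all participants arrive, the primary thread
// copies each sender's i-th source buffer into the i-th receiving device's
// destination buffer at the index given by the sender's rank.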
class CpuAllToAllRendezvous
    : public xla::Rendezvous<AllToAllParticipantData, std::nullptr_t> {
 public:
  explicit CpuAllToAllRendezvous(const xla::RendezvousKey& k)
      : xla::Rendezvous<AllToAllParticipantData, std::nullptr_t>(k) {}

 protected:
  xla::StatusOr<std::nullptr_t> RunCollectiveOp(
      const AllToAllParticipantData& /*participant*/) override {
    bool is_primary = InitializationBarrier();

    if (is_primary) {
      absl::MutexLock lock(&mu_);

      CHECK(!participants_.empty());
      CHECK(!participants_[0].source_buffers.empty());
      int expected_buffer_size = participants_[0].source_buffers[0].size();

      // Device id -> position in participants_.
      absl::flat_hash_map<xla::GlobalDeviceId, int> device_map;

      for (int pos = 0; pos < participants_.size(); pos++) {
        const AllToAllParticipantData& p = participants_[pos];
        CHECK_EQ(p.source_buffers.size(), p.destination_buffers.size());
        CHECK_EQ(p.source_buffers.size(), participants_.size());
        for (int i = 0; i < p.source_buffers.size(); i++) {
          CHECK_EQ(p.destination_buffers[i].size(), expected_buffer_size);
          CHECK_EQ(p.source_buffers[i].size(), expected_buffer_size);
        }
        device_map[p.device_id] = pos;
      }

      const std::vector<xla::GlobalDeviceId>& devices_to_copy_to =
          participants_[0].devices_to_copy_to;

      // Device id -> rank
      absl::flat_hash_map<xla::GlobalDeviceId, int> device_ranks;
      for (int rank = 0; rank < devices_to_copy_to.size(); ++rank) {
        auto device_id = devices_to_copy_to[rank];
        device_ranks[device_id] = rank;
      }

      for (const AllToAllParticipantData& sender : participants_) {
        VLOG(3) << "Processing AllToAll participant: " << sender.ToString();

        int rank = xla::FindOrDie(device_ranks, sender.device_id);

        for (int i = 0; i < participants_.size(); ++i) {
          auto device_id = devices_to_copy_to[i];
          int participant_num = xla::FindOrDie(device_map, device_id);
          AllToAllParticipantData& receiver = participants_[participant_num];

          std::memcpy(receiver.destination_buffers[rank].opaque(),
                      sender.source_buffers[i].opaque(), expected_buffer_size);
        }
      }
    }
    return nullptr;
  }
};

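// Rendezvous for CollectivePermute: the primary thread copies each
// participant's source buffer into the destinations of the replicas it
// targets, then zero-fills the destinations of replicas nobody copied into.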
class CpuCollectivePermuteRendezvous
    : public xla::Rendezvous<CollectivePermuteParticipantData, std::nullptr_t> {
 public:
  explicit CpuCollectivePermuteRendezvous(const xla::RendezvousKey& k)
      : xla::Rendezvous<CollectivePermuteParticipantData, std::nullptr_t>(k) {}

 protected:
  xla::StatusOr<std::nullptr_t> RunCollectiveOp(
      const CollectivePermuteParticipantData& /*participant*/) override {
    bool primary = InitializationBarrier();

    // Perform all copies from the primary thread.
    if (primary) {
      absl::MutexLock lock(&mu_);

      std::map<int, int> replica_idx_to_participant_idx;
      for (int p_idx = 0; p_idx < participants_.size(); p_idx++) {
        replica_idx_to_participant_idx[participants_[p_idx].replica_id] = p_idx;
      }
      for (auto& p : participants_) {
        for (int dest_replica : p.replica_ids_to_copy_to) {
          auto& dest_p = participants_[xla::FindOrDie(
              replica_idx_to_participant_idx, dest_replica)];
          std::memcpy(dest_p.destination_data.opaque(), p.source_data.opaque(),
                      p.byte_size);

          // Each replica may be copied into only once.
          replica_idx_to_participant_idx.erase(dest_replica);
        }
      }

      // Zero out untouched participants.
      for (auto& replica_p : replica_idx_to_participant_idx) {
        auto& p = participants_[replica_p.second];
        std::memset(p.destination_data.opaque(), 0, p.byte_size);
      }
    }
    return nullptr;
  }
};

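// Rendezvous for AllReduce: the primary thread dispatches on the element type
// and reduces the participants' input buffers element-wise into every
// participant's output buffers.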
class CpuAllReduceRendezvous
    : public xla::Rendezvous<xla::AllReduceParticipantData, std::nullptr_t> {
 public:
  explicit CpuAllReduceRendezvous(const xla::RendezvousKey& k)
      : xla::Rendezvous<xla::AllReduceParticipantData, std::nullptr_t>(k) {}

 protected:
  xla::StatusOr<std::nullptr_t> RunCollectiveOp(
      const xla::AllReduceParticipantData& participant) override {
    xla::PrimitiveType datatype = participant.buffers.front().primitive_type;
    bool primary = InitializationBarrier();

    if (primary) {
      switch (datatype) {
        case xla::S8:
          DoAllReduce<xla::S8>(participant);
          break;
        case xla::PRED:
        case xla::U8:
          DoAllReduce<xla::U8>(participant);
          break;
        case xla::S32:
          DoAllReduce<xla::S32>(participant);
          break;
        case xla::U32:
          DoAllReduce<xla::U32>(participant);
          break;
        case xla::S64:
          DoAllReduce<xla::S64>(participant);
          break;
        case xla::U64:
          DoAllReduce<xla::U64>(participant);
          break;
        case xla::F16:
          DoAllReduce<xla::F16>(participant);
          break;
        case xla::F32:
          DoAllReduce<xla::F32>(participant);
          break;
        case xla::F64:
          DoAllReduce<xla::F64>(participant);
          break;
        case xla::C64:
          DoAllReduce<xla::C64>(participant);
          break;
        case xla::C128:
          DoAllReduce<xla::C128>(participant);
          break;
        default:
          LOG(FATAL) << "Unexpected datatype: "
                     << xla::primitive_util::LowercasePrimitiveTypeName(
                            datatype);
      }
    }
    return nullptr;
  }

 private:
  template <xla::PrimitiveType PT>
  void DoAllReduce(xla::AllReduceParticipantData participant) {
    using T = typename xla::primitive_util::PrimitiveTypeToNative<PT>::type;
    absl::MutexLock lock(&mu_);
    CHECK(!participants_.empty());
    xla::ReductionKind reduction_kind = participant.reduction_kind;
    for (const auto& p : participants_) {
      CHECK(p.reduction_kind == reduction_kind);
    }
    int num_participants = participants_.size();

    // participant_idx -> buffer_idx -> buffer.
    std::vector<std::vector<absl::Span<T>>> input_buffers;
    std::vector<std::vector<absl::Span<T>>> output_buffers;
    input_buffers.reserve(num_participants);
    output_buffers.reserve(num_participants);
    const xla::AllReduceParticipantData& first_participant =
        participants_.front();

    int buffers_per_participant = first_participant.buffers.size();
    for (xla::AllReduceParticipantData& p : participants_) {
      CHECK_EQ(p.buffers.size(), buffers_per_participant);

      input_buffers.emplace_back();
      output_buffers.emplace_back();
      std::vector<absl::Span<T>>& participant_input_buffers =
          input_buffers.back();
      std::vector<absl::Span<T>>& participant_output_buffers =
          output_buffers.back();
      participant_input_buffers.reserve(p.buffers.size());
      participant_output_buffers.reserve(p.buffers.size());

      for (int buffer_idx = 0; buffer_idx < buffers_per_participant;
           buffer_idx++) {
        auto& participant_buffer = p.buffers[buffer_idx];
        participant_input_buffers.emplace_back(
            static_cast<T*>(participant_buffer.source_data.opaque()),
            participant_buffer.element_count);
        participant_output_buffers.emplace_back(
            static_cast<T*>(participant_buffer.destination_data.opaque()),
            participant_buffer.element_count);
        CHECK_EQ(participant_buffer.element_count,
                 first_participant.buffers[buffer_idx].element_count);
      }
    }

    for (int buffer_idx = 0; buffer_idx < buffers_per_participant;
         buffer_idx++) {
      int element_count = first_participant.buffers[buffer_idx].element_count;
      for (int idx = 0; idx < element_count; idx++) {
        T out = GetInitialValue<T>(reduction_kind);
        for (int participant_idx = 0; participant_idx < participants_.size();
             participant_idx++) {
          out = PerformReductionStep<T>(
              reduction_kind, out,
              input_buffers[participant_idx][buffer_idx][idx]);
        }
        for (int participant_idx = 0; participant_idx < participants_.size();
             participant_idx++) {
          output_buffers[participant_idx][buffer_idx][idx] = out;
        }
      }
    }
  }

  template <typename T>
  T GetInitialValue(xla::ReductionKind reduction_kind) {
    switch (reduction_kind) {
      case xla::ReductionKind::SUM:
        return static_cast<T>(0);
      case xla::ReductionKind::PRODUCT:
        return static_cast<T>(1);
      case xla::ReductionKind::MIN:
        return std::numeric_limits<T>::max();
      case xla::ReductionKind::MAX:
        // lowest() rather than min() so that MAX also works for negative
        // floating-point values; for integral types the two are identical.
        return std::numeric_limits<T>::lowest();
    }
  }

  template <typename T, bool kIsSignedIntegralType>
  struct SumProductTypeForReductionStep {
    using type = T;
  };

  template <typename T>
  struct SumProductTypeForReductionStep<T, /*kIsSignedIntegralType=*/true> {
    using type = typename std::make_unsigned_t<T>;
  };

  template <typename T,
            typename std::enable_if<!is_complex<T>::value>::type* = nullptr>
  T PerformReductionStep(xla::ReductionKind reduction_kind, T a, T b) {
    using SumProductType = typename SumProductTypeForReductionStep<
        T, std::is_integral<T>::value && std::is_signed<T>::value>::type;
    switch (reduction_kind) {
      case xla::ReductionKind::SUM:
        return absl::bit_cast<T>(
            static_cast<SumProductType>(absl::bit_cast<SumProductType>(a) +
                                        absl::bit_cast<SumProductType>(b)));
      case xla::ReductionKind::PRODUCT:
        return absl::bit_cast<T>(
            static_cast<SumProductType>(absl::bit_cast<SumProductType>(a) *
                                        absl::bit_cast<SumProductType>(b)));
      case xla::ReductionKind::MIN:
        return std::min(a, b);
      case xla::ReductionKind::MAX:
        return std::max(a, b);
    }
  }

  template <typename T,
            typename std::enable_if<is_complex<T>::value>::type* = nullptr>
  T PerformReductionStep(xla::ReductionKind reduction_kind, T a, T b) {
    using SumProductType = typename SumProductTypeForReductionStep<
        T, std::is_integral<T>::value && std::is_signed<T>::value>::type;
    switch (reduction_kind) {
      case xla::ReductionKind::SUM:
        return absl::bit_cast<T>(
            static_cast<SumProductType>(absl::bit_cast<SumProductType>(a) +
                                        absl::bit_cast<SumProductType>(b)));
      case xla::ReductionKind::PRODUCT:
        return absl::bit_cast<T>(
            static_cast<SumProductType>(absl::bit_cast<SumProductType>(a) *
                                        absl::bit_cast<SumProductType>(b)));
      case xla::ReductionKind::MIN:
      case xla::ReductionKind::MAX:
        LOG(FATAL) << "min/max not valid for complex types";
    }
  }
};

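// Process-wide maps from rendezvous key to the rendezvous object shared by
// all local participants of a given collective invocation. Entries are
// reference-counted, so a rendezvous goes away once its last participant
// releases it.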
xla::RefcountingHashMap<xla::RendezvousKey, CpuAllReduceRendezvous>&
GlobalAllReduceRendezvousMap() {
  static auto& m =
      *new xla::RefcountingHashMap<xla::RendezvousKey, CpuAllReduceRendezvous>;
  return m;
}

xla::RefcountingHashMap<xla::RendezvousKey, CpuCollectivePermuteRendezvous>&
GlobalCollectivePermuteRendezvousMap() {
  static auto& m = *new xla::RefcountingHashMap<xla::RendezvousKey,
                                                CpuCollectivePermuteRendezvous>;
  return m;
}

xla::RefcountingHashMap<xla::RendezvousKey, CpuAllToAllRendezvous>&
GlobalAllToAllRendezvousMap() {
  static auto& m =
      *new xla::RefcountingHashMap<xla::RendezvousKey, CpuAllToAllRendezvous>;
  return m;
}

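// Builds the RendezvousKey identifying this collective invocation: the run
// id, the set of participating devices derived from the replica groups and
// device assignment, and whether the op is cross-module or cross-replica.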
xla::RendezvousKey GetRendezvousKey(
    const xla::ExecutableRunOptions* run_options,
    std::vector<xla::ReplicaGroup> group, int32_t channel_id_present,
    std::optional<bool> use_global_device_ids, int64_t op_id) {
  const xla::DeviceAssignment& device_assignment =
      *run_options->device_assignment();
  int device_ordinal = GetDeviceOrdinal(run_options);
  xla::RendezvousKey::CollectiveOpKind op_kind =
      channel_id_present ? xla::RendezvousKey::kCrossModule
                         : xla::RendezvousKey::kCrossReplica;
  std::vector<xla::GlobalDeviceId> participating_devices =
      xla::GetParticipatingDevices(
          xla::GlobalDeviceId(device_ordinal), device_assignment, group,
          xla::GetCollectiveOpGroupMode(channel_id_present != 0,
                                        use_global_device_ids)
              .ValueOrDie())
          .ValueOrDie();
  int num_local_participants = participating_devices.size();
  return xla::RendezvousKey{run_options->run_id(),
                            std::move(participating_devices),
                            num_local_participants, op_kind, op_id};
}

}  // namespace

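// Runtime entry point called from code generated by the XLA CPU backend for
// all-to-all: registers this device's source/destination buffers with the
// process-wide rendezvous and blocks until the exchange has completed.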
ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllToAll(
    const xla::ExecutableRunOptions* run_options, int32_t channel_id_present,
    int64_t op_id, const void* replica_groups_str,
    int32_t replica_groups_str_size, int32_t num_buffers, int64_t buffer_size,
    void** source_buffers, void** destination_buffers) {
  int device_ordinal = GetDeviceOrdinal(run_options);
  absl::string_view replica_groups_serialized(
      static_cast<const char*>(replica_groups_str), replica_groups_str_size);
  std::vector<xla::ReplicaGroup> group =
      xla::ParseReplicaGroupsOnly(replica_groups_serialized).ValueOrDie();
  xla::RendezvousKey rendezvous_key =
      GetRendezvousKey(run_options, group, channel_id_present,
                       /*use_global_device_ids=*/std::nullopt, op_id);

  AllToAllParticipantData participant(rendezvous_key, device_ordinal,
                                      run_options->stream());
  participant.device_id = xla::GlobalDeviceId(device_ordinal);
  participant.devices_to_copy_to =
      xla::GetParticipatingDevices(
          xla::GlobalDeviceId(device_ordinal),
          *run_options->device_assignment(), group,
          xla::GetCollectiveOpGroupMode(channel_id_present != 0,
                                        /*use_global_device_ids=*/std::nullopt)
              .ValueOrDie())
          .ValueOrDie();
  for (int i = 0; i < num_buffers; i++) {
    participant.source_buffers.emplace_back(source_buffers[i], buffer_size);
    participant.destination_buffers.emplace_back(destination_buffers[i],
                                                 buffer_size);
  }
  auto make_cpu_rendezvous = [](const xla::RendezvousKey& k) {
    return std::make_unique<CpuAllToAllRendezvous>(k);
  };
  TF_CHECK_OK(CpuAllToAllRendezvous::SubmitParticipant(
                  [&] {
                    return GlobalAllToAllRendezvousMap().GetOrCreateIfAbsent(
                        rendezvous_key, make_cpu_rendezvous);
                  },
                  participant)
                  .status());
}

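// Runtime entry point called from code generated by the XLA CPU backend for
// all-reduce: decodes the (possibly tuple) shape, registers one buffer per
// leaf with the process-wide rendezvous, and blocks until the reduction has
// been written to the output buffers.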
ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
    const xla::ExecutableRunOptions* run_options,
    const void* replica_groups_str, int32_t replica_groups_str_size,
    int32_t channel_id_present, int32_t use_global_device_ids, int64_t op_id,
    int32_t reduction_kind, const void* shape_ptr, int32_t shape_length,
    int32_t num_buffers, void** input_buffers, void** output_buffers) {
  int device_ordinal = GetDeviceOrdinal(run_options);
  absl::string_view replica_groups_serialized(
      static_cast<const char*>(replica_groups_str), replica_groups_str_size);
  std::vector<xla::ReplicaGroup> group =
      xla::ParseReplicaGroupsOnly(replica_groups_serialized).ValueOrDie();
  xla::RendezvousKey rendezvous_key = GetRendezvousKey(
      run_options, group, channel_id_present, use_global_device_ids, op_id);
  auto shape_str = ShapeString(shape_ptr, shape_length);
  VLOG(2) << "All-reduce input/output shape : " << shape_str;

  xla::Shape shape =
      DecodeSelfDescribingShapeConstant(shape_ptr, shape_length).ValueOrDie();

  CHECK((num_buffers > 1 && shape.IsTuple()) ||
        (num_buffers == 1 && xla::LayoutUtil::IsDenseArray(shape)));

  xla::AllReduceParticipantData participant(rendezvous_key, device_ordinal,
                                            run_options->stream());
  participant.reduction_kind = static_cast<xla::ReductionKind>(reduction_kind);
  for (int i = 0; i < num_buffers; i++) {
    xla::Shape subshape = num_buffers == 1 ? shape : shape.tuple_shapes(i);
    xla::AllReduceParticipantData::Buffer buffer;
    buffer.element_count = xla::ShapeUtil::ElementsIn(subshape);
    buffer.primitive_type = subshape.element_type();
    buffer.source_data = se::DeviceMemoryBase(
        input_buffers[i], xla::ShapeUtil::ByteSizeOf(subshape));
    buffer.destination_data = se::DeviceMemoryBase(
        output_buffers[i], xla::ShapeUtil::ByteSizeOf(subshape));
    participant.buffers.push_back(buffer);
  }

  auto make_cpu_rendezvous = [](const xla::RendezvousKey& k) {
    return std::make_unique<CpuAllReduceRendezvous>(k);
  };

  TF_CHECK_OK(CpuAllReduceRendezvous::SubmitParticipant(
                  [&] {
                    return GlobalAllReduceRendezvousMap().GetOrCreateIfAbsent(
                        rendezvous_key, make_cpu_rendezvous);
                  },
                  participant)
                  .status());
}

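// Writes the current device's replica id (from the run's device assignment)
// into `output_buffer` as a 32-bit integer.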
ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ReplicaId(
    const xla::ExecutableRunOptions* run_options, void* output_buffer) {
  int device_ordinal = GetDeviceOrdinal(run_options);
  int32_t replica_id =
      run_options->device_assignment()
          ->ReplicaIdForDevice(xla::GlobalDeviceId(device_ordinal))
          .ValueOrDie();
  std::memcpy(output_buffer, &replica_id, 4);
}

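// Writes the current device's partition (computation) id from the run's
// device assignment into `output_buffer` as a 32-bit integer.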
ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_PartitionId(
    const xla::ExecutableRunOptions* run_options, void* output_buffer) {
  int device_ordinal = GetDeviceOrdinal(run_options);
  const xla::DeviceAssignment::LogicalID logical_id =
      run_options->device_assignment()
          ->LogicalIdForDevice(xla::GlobalDeviceId(device_ordinal))
          .ValueOrDie();
  std::memcpy(output_buffer, &logical_id.computation_id, 4);
}

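// Runtime entry point called from code generated by the XLA CPU backend for
// collective-permute: parses the comma-separated `source=target` pairs,
// determines which logical devices this device must copy to, and joins the
// process-wide rendezvous that performs the copies.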
ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_CollectivePermute(
    const xla::ExecutableRunOptions* run_options, int32_t channel_id_present,
    int64_t op_id, int32_t byte_size, void* input_buffer, void* output_buffer,
    const void* source_target_pairs, int32_t source_target_pairs_size) {
  int device_ordinal = GetDeviceOrdinal(run_options);
  absl::string_view source_target_pairs_serialized(
      static_cast<const char*>(source_target_pairs), source_target_pairs_size);
  auto pairs = absl::StrSplit(source_target_pairs_serialized, ',');
  const xla::DeviceAssignment::LogicalID logical_id =
      run_options->device_assignment()
          ->LogicalIdForDevice(xla::GlobalDeviceId(device_ordinal))
          .ValueOrDie();
  int32_t logical_device_id =
      channel_id_present ? logical_id.computation_id : logical_id.replica_id;

  std::vector<int> copy_to;
  for (auto& p : pairs) {
    std::vector<std::string> mapping = absl::StrSplit(p, '=');
    CHECK_EQ(mapping.size(), 2);
    int from = std::stoi(mapping[0]);
    int to = std::stoi(mapping[1]);
    if (from == logical_device_id) {
      copy_to.push_back(to);
    }
  }
  xla::RendezvousKey rendezvous_key =
      GetRendezvousKey(run_options, {}, channel_id_present,
                       /*use_global_device_ids=*/std::nullopt, op_id);

  CollectivePermuteParticipantData participant(rendezvous_key, device_ordinal,
                                               run_options->stream());
  participant.replica_id = logical_device_id;
  participant.source_data = se::DeviceMemoryBase(input_buffer, byte_size);
  participant.destination_data = se::DeviceMemoryBase(output_buffer, byte_size);
  participant.replica_ids_to_copy_to = copy_to;
  participant.byte_size = byte_size;

  auto make_cpu_rendezvous = [](const xla::RendezvousKey& k) {
    return std::make_unique<CpuCollectivePermuteRendezvous>(k);
  };
  TF_CHECK_OK(
      CpuCollectivePermuteRendezvous::SubmitParticipant(
          [&] {
            return GlobalCollectivePermuteRendezvousMap().GetOrCreateIfAbsent(
                rendezvous_key, make_cpu_rendezvous);
          },
          participant)
          .status());
}
879