xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/SymmetricMemory.hpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/ATen.h>
4 #include <torch/csrc/distributed/c10d/Store.hpp>
5 
6 namespace c10d {
7 namespace symmetric_memory {
8 
// SymmetricMemory represents symmetric allocations across a group of devices.
// The allocations represented by a SymmetricMemory object are accessible by
// all devices in the group. The class can be used for op-level custom
// communication patterns (via the get_buffer APIs and the synchronization
// primitives), as well as custom communication kernels (via the buffer and
// signal_pad device pointers).
//
// To acquire a SymmetricMemory object, each rank first allocates
// identical-sized memory via SymmetricMemoryAllocator::alloc(), then invokes
// SymmetricMemoryAllocator::rendezvous() on the memory to establish the
// association across peer buffers. The rendezvous is a one-time process, and
// the mapping between a local memory region and the associated SymmetricMemory
// object is unique.
//
// NOTE [symmetric memory signal pad]
// Signal pads are P2P-accessible memory regions designated for
// synchronization. SymmetricMemory offers built-in synchronization primitives
// such as barriers, put_signal, and wait_signal, which are all based on signal
// pads. Users may utilize signal pads for their own synchronization logic,
// provided that the signal pads remain zero-filled following successful
// synchronization.
//
// NOTE [symmetric memory synchronization channel]
// Synchronization channels allow users to use a single SymmetricMemory object
// to perform isolated synchronizations on different streams. For example,
// consider the case in which two barriers are issued on two streams for
// different purposes. Without the concept of channels, we cannot guarantee the
// correctness of the barriers since signals issued from barrier on stream A
// can be received by the barrier on stream B. By specifying different channels
// for these two barriers, they can operate correctly in parallel.
39 class TORCH_API SymmetricMemory : public c10::intrusive_ptr_target {
40  public:
~SymmetricMemory()41   virtual ~SymmetricMemory() {}
42 
43   virtual std::vector<void*> get_buffer_ptrs() = 0;
44   virtual std::vector<void*> get_signal_pad_ptrs() = 0;
45 
46   // get_buffer_ptrs_dev() and get_signal_pad_ptrs_dev() each return a pointer
47   // to a device array of size world_size, containing buffer pointers and
48   // signal pad pointers, respectively.
49   virtual void** get_buffer_ptrs_dev() = 0;
50   virtual void** get_signal_pad_ptrs_dev() = 0;
51   virtual size_t get_buffer_size() = 0;
52   virtual size_t get_signal_pad_size() = 0;
53 
54   virtual bool has_multicast_support() = 0;
55   virtual void* get_multicast_ptr() = 0;
56 
57   virtual at::Tensor get_buffer(
58       int rank,
59       c10::IntArrayRef sizes,
60       c10::ScalarType dtype,
61       int64_t storage_offset) = 0;
62 
63   virtual void barrier(int channel) = 0;
64   virtual void put_signal(int dst_rank, int channel) = 0;
65   virtual void wait_signal(int src_rank, int channel) = 0;
66 
67   virtual int get_rank() = 0;
68   virtual int get_world_size() = 0;
69 };
70 
71 class SymmetricMemoryAllocator : public c10::intrusive_ptr_target {
72  public:
~SymmetricMemoryAllocator()73   virtual ~SymmetricMemoryAllocator(){};
74 
75   virtual void* alloc(
76       size_t size,
77       int device_idx,
78       const std::string& group_name) = 0;
79 
80   virtual void free(void* ptr) = 0;
81   virtual size_t get_alloc_size(void* ptr) = 0;
82   virtual c10::intrusive_ptr<SymmetricMemory> rendezvous(void* ptr) = 0;
83   virtual bool is_rendezvous_completed(void* ptr) = 0;
84   virtual bool has_multicast_support(int device_idx) = 0;
85 };
86 
87 C10_EXPORT bool is_finalizing();
88 
89 C10_EXPORT void register_allocator(
90     c10::DeviceType device_type,
91     c10::intrusive_ptr<SymmetricMemoryAllocator> allocator);
92 
93 C10_EXPORT c10::intrusive_ptr<SymmetricMemoryAllocator> get_allocator(
94     c10::DeviceType device_type);
95 
96 // Set a store for rendezvousing symmetric allocations on a group of devices
97 // identified by `group_name`. The concept of groups is logical; users can
98 // utilize predefined groups (e.g., a group of device identified by a
99 // ProcessGroup) or create custom ones. Note that a SymmetricMemoryAllocator
100 // backends might employ a more efficient communication channel for the actual
101 // rendezvous process and only use the store for bootstrapping purposes.
102 TORCH_API void set_group_info(
103     const std::string& group_name,
104     int rank,
105     int world_size,
106     c10::intrusive_ptr<Store> store);
107 
108 struct GroupInfo {
109   int rank;
110   int world_size;
111   c10::intrusive_ptr<c10d::Store> store;
112 };
113 
114 C10_EXPORT const GroupInfo& get_group_info(const std::string& group_name);
115 
116 // Identical to empty_strided, but allows symmetric memory access to be
117 // established for the allocated tensor via SymmetricMemory::rendezvous(). This
118 // function itself is not a collective operation. It invokes
119 // SymmetricMemoryAllocator::alloc() for the requested device under the hood.
120 //
121 // NOTE [symmetric memory persistent allocation]
122 // If an `alloc_id` is supplied, empty_strided_p2p will perform persistent
123 // allocation. This makes the function cache allocated memory and ensure that
124 // invocations with the same `alloc_id` receive tensors backed by the same
125 // memory address. For safety, if a previous persistent allocation is still
126 // active (i.e., the storage of the returned tensor is still alive), persistent
127 // allocations with the same `alloc_id` will fail. This determinism coupled
128 // with memory planning of communication buffers (e.g., by Inductor) allows
129 // communication algorithms to reliably reuse previously established remote
130 // memory access.
131 TORCH_API at::Tensor empty_strided_p2p(
132     c10::IntArrayRef size,
133     c10::IntArrayRef stride,
134     c10::ScalarType dtype,
135     c10::Device device,
136     const std::string& group_name,
137     std::optional<uint64_t> alloc_id);
138 
139 // Establishes symmetric memory access on tensors allocated via
140 // empty_strided_p2p() and empty_strided_p2p_persistent(). rendezvous() is a
141 // one-time process, and the mapping between a local memory region and the
142 // associated SymmetricMemory object is unique. Subsequent calls to
143 // rendezvous() with the same tensor, or tensors allocated with
144 // empty_strided_p2p_persistent() using the same alloc_id, will receive the
145 // cached SymmetricMemory object.
146 //
147 // The function has a collective semantic and must be invoked simultaneously
148 // from all rendezvous participants.
149 TORCH_API c10::intrusive_ptr<SymmetricMemory> rendezvous(
150     const at::Tensor& tensor);
151 
152 // Returns the SymmetricMemory object associated with the tensor. It can only
153 // be invoked after rendezvous() but does not need to be invoked collectively.
154 TORCH_API c10::intrusive_ptr<SymmetricMemory> get_symmetric_memory(
155     const at::Tensor& tensor);
156 
157 TORCH_API bool has_multicast_support(
158     c10::DeviceType device_type,
159     int device_idx);
160 } // namespace symmetric_memory
161 } // namespace c10d
162