1 #pragma once 2 3 #include <ATen/ATen.h> 4 #include <torch/csrc/distributed/c10d/Store.hpp> 5 6 namespace c10d { 7 namespace symmetric_memory { 8 9 // SymmetricMemory represents symmetric allocations across a group of devices. 10 // The allocations represented by a SymmetricMemory object are accessible by 11 // all devices in the group. The class can be used for op-level custom 12 // communication patterns (via the get_buffer APIs and the synchronization 13 // primitives), as well as custom communication kernels (via the buffer and 14 // signal_pad device pointers). 15 // 16 // To acquire a SymmetricMemory object, each rank first allocates 17 // identical-sized memory via SymmetricMemoryAllocator::alloc(), then invokes 18 // SymmetricMemoryAllocator::rendezvous() on the memory to establish the 19 // association across peer buffers. The rendezvous is a one-time process, and 20 // the mapping between a local memory memory and the associated SymmetricMemory 21 // object is unique. 22 // 23 // NOTE [symmetric memory signal pad] 24 // Signal pads are P2P-accessible memory regions designated for 25 // synchronization. SymmetricMemory offers built-in synchronization primitives 26 // such as barriers, put_signal, and wait_signal, which are all based on signal 27 // pads. Users may utilize signal pads for their own synchronization logic, 28 // provided that the signal pads remain zero-filled following successful 29 // synchronization. 30 // 31 // NOTE [symmetric memory synchronization channel] 32 // Synchronization channels allow users to use a single SymmetricMemory object 33 // to perform isolated synchronizations on different streams. For example, 34 // consider the case in which two barriers are issued on two streams for 35 // different purposes. Without the concept of channels, we cannot guarantee the 36 // correctness of the barriers since signals issued from barrier on stream A 37 // can be received by the barrier on stream B. By specifying different channels 38 // for these two barriers, they can operate correctly in parallel. 39 class TORCH_API SymmetricMemory : public c10::intrusive_ptr_target { 40 public: ~SymmetricMemory()41 virtual ~SymmetricMemory() {} 42 43 virtual std::vector<void*> get_buffer_ptrs() = 0; 44 virtual std::vector<void*> get_signal_pad_ptrs() = 0; 45 46 // get_buffer_ptrs_dev() and get_signal_pad_ptrs_dev() each return a pointer 47 // to a device array of size world_size, containing buffer pointers and 48 // signal pad pointers, respectively. 49 virtual void** get_buffer_ptrs_dev() = 0; 50 virtual void** get_signal_pad_ptrs_dev() = 0; 51 virtual size_t get_buffer_size() = 0; 52 virtual size_t get_signal_pad_size() = 0; 53 54 virtual bool has_multicast_support() = 0; 55 virtual void* get_multicast_ptr() = 0; 56 57 virtual at::Tensor get_buffer( 58 int rank, 59 c10::IntArrayRef sizes, 60 c10::ScalarType dtype, 61 int64_t storage_offset) = 0; 62 63 virtual void barrier(int channel) = 0; 64 virtual void put_signal(int dst_rank, int channel) = 0; 65 virtual void wait_signal(int src_rank, int channel) = 0; 66 67 virtual int get_rank() = 0; 68 virtual int get_world_size() = 0; 69 }; 70 71 class SymmetricMemoryAllocator : public c10::intrusive_ptr_target { 72 public: ~SymmetricMemoryAllocator()73 virtual ~SymmetricMemoryAllocator(){}; 74 75 virtual void* alloc( 76 size_t size, 77 int device_idx, 78 const std::string& group_name) = 0; 79 80 virtual void free(void* ptr) = 0; 81 virtual size_t get_alloc_size(void* ptr) = 0; 82 virtual c10::intrusive_ptr<SymmetricMemory> rendezvous(void* ptr) = 0; 83 virtual bool is_rendezvous_completed(void* ptr) = 0; 84 virtual bool has_multicast_support(int device_idx) = 0; 85 }; 86 87 C10_EXPORT bool is_finalizing(); 88 89 C10_EXPORT void register_allocator( 90 c10::DeviceType device_type, 91 c10::intrusive_ptr<SymmetricMemoryAllocator> allocator); 92 93 C10_EXPORT c10::intrusive_ptr<SymmetricMemoryAllocator> get_allocator( 94 c10::DeviceType device_type); 95 96 // Set a store for rendezvousing symmetric allocations on a group of devices 97 // identified by `group_name`. The concept of groups is logical; users can 98 // utilize predefined groups (e.g., a group of device identified by a 99 // ProcessGroup) or create custom ones. Note that a SymmetricMemoryAllocator 100 // backends might employ a more efficient communication channel for the actual 101 // rendezvous process and only use the store for bootstrapping purposes. 102 TORCH_API void set_group_info( 103 const std::string& group_name, 104 int rank, 105 int world_size, 106 c10::intrusive_ptr<Store> store); 107 108 struct GroupInfo { 109 int rank; 110 int world_size; 111 c10::intrusive_ptr<c10d::Store> store; 112 }; 113 114 C10_EXPORT const GroupInfo& get_group_info(const std::string& group_name); 115 116 // Identical to empty_strided, but allows symmetric memory access to be 117 // established for the allocated tensor via SymmetricMemory::rendezvous(). This 118 // function itself is not a collective operation. It invokes 119 // SymmetricMemoryAllocator::alloc() for the requested device under the hood. 120 // 121 // NOTE [symmetric memory persistent allocation] 122 // If an `alloc_id` is supplied, empty_strided_p2p will perform persistent 123 // allocation. This makes the function cache allocated memory and ensure that 124 // invocations with the same `alloc_id` receive tensors backed by the same 125 // memory address. For safety, if a previous persistent allocation is still 126 // active (i.e., the storage of the returned tensor is still alive), persistent 127 // allocations with the same `alloc_id` will fail. This determinism coupled 128 // with memory planning of communication buffers (e.g., by Inductor) allows 129 // communication algorithms to reliably reuse previously established remote 130 // memory access. 131 TORCH_API at::Tensor empty_strided_p2p( 132 c10::IntArrayRef size, 133 c10::IntArrayRef stride, 134 c10::ScalarType dtype, 135 c10::Device device, 136 const std::string& group_name, 137 std::optional<uint64_t> alloc_id); 138 139 // Establishes symmetric memory access on tensors allocated via 140 // empty_strided_p2p() and empty_strided_p2p_persistent(). rendezvous() is a 141 // one-time process, and the mapping between a local memory region and the 142 // associated SymmetricMemory object is unique. Subsequent calls to 143 // rendezvous() with the same tensor, or tensors allocated with 144 // empty_strided_p2p_persistent() using the same alloc_id, will receive the 145 // cached SymmetricMemory object. 146 // 147 // The function has a collective semantic and must be invoked simultaneously 148 // from all rendezvous participants. 149 TORCH_API c10::intrusive_ptr<SymmetricMemory> rendezvous( 150 const at::Tensor& tensor); 151 152 // Returns the SymmetricMemory object associated with the tensor. It can only 153 // be invoked after rendezvous() but does not need to be invoked collectively. 154 TORCH_API c10::intrusive_ptr<SymmetricMemory> get_symmetric_memory( 155 const at::Tensor& tensor); 156 157 TORCH_API bool has_multicast_support( 158 c10::DeviceType device_type, 159 int device_idx); 160 } // namespace symmetric_memory 161 } // namespace c10d 162