# mypy: allow-untyped-defs
import multiprocessing
import os
import threading
from multiprocessing.reduction import ForkingPickler
from multiprocessing.util import register_after_fork
from typing import Union

import torch
from torch._namedtensor_internals import check_serializing_named_tensor


try:
    # Early load resource_sharer to prevent a partially initialized instance
    # from being inherited in a forked child process. The reduce_storage method
    # requires this module indirectly through DupFd(). The built-in mp.Queue
    # class pickles arguments in a background thread which may overlap with the
    # fork.
    import multiprocessing.resource_sharer
except ImportError:
    pass


class StorageWeakRef:
    r"""A weak reference to a Storage.

    The cdata member is a Python number containing the integer representation of
    the Storage pointer.
    """

    __slots__ = ["cdata", "_free_weak_ref"]

    def __init__(self, storage):
        self.cdata = storage._weak_ref()
        # Save a direct reference to _free_weak_ref because the `torch` module
        # might be cleared during Python shutdown before this module is cleared.
        self._free_weak_ref = torch.Storage._free_weak_ref  # type: ignore[attr-defined]

    @classmethod
    def from_weakref(cls, cdata):
        instance = cls.__new__(cls)
        instance.cdata = cdata
        instance._free_weak_ref = torch.Storage._free_weak_ref  # type: ignore[attr-defined]
        return instance

    def expired(self):
        return torch.Storage._expired(self.cdata)  # type: ignore[attr-defined]

    def __del__(self):
        self._free_weak_ref(self.cdata)

    def __hash__(self):
        return self.cdata

    def __eq__(self, other):
        if id(self) == id(other):
            return True
        return self.cdata == other.cdata
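

# Illustrative sketch (added for clarity, not part of the original module): how
# a StorageWeakRef observes storage liveness without keeping the storage alive.
# The helper name `_demo_storage_weakref` is hypothetical.
def _demo_storage_weakref():
    s = torch.ones(4).untyped_storage()
    ref = StorageWeakRef(s)
    assert not ref.expired()  # the storage is still alive
    del s  # drop the last strong reference; the storage is freed
    assert ref.expired()  # the weak reference observes the expiry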


class SharedCache(dict):
    """Dictionary from multiprocessing handles to StorageWeakRef."""

    def __init__(self) -> None:
        # free_dead_references() is called if the len exceeds the current
        # limit. The limit scales with the number of remaining live objects.
        self.limit = 128
        # `fork` inherits lock state, so in case we fork when the lock is held,
        # we register a function to reset the lock to a new object to avoid
        # possible deadlocks, following python multiprocessing library design.
        self._after_fork()
        register_after_fork(self, SharedCache._after_fork)

    def _after_fork(self):
        self.lock = threading.Lock()

    def get(self, key):
        with self.lock:
            return dict.get(self, key)

    def __setitem__(self, key, storage_ref):
        with self.lock:
            dict.__setitem__(self, key, storage_ref)
            if len(self) > self.limit:
                self.free_dead_references()

    def free_dead_references(self):
        live = 0
        for key, storage_ref in list(self.items()):
            if storage_ref.expired():
                del self[key]
            else:
                live += 1
        self.limit = max(128, live * 2)


# mapping from handles to StorageWeakRef objects
shared_cache = SharedCache()


def rebuild_event(device, handle):
    return torch.cuda.Event.from_ipc_handle(device, handle)


def reduce_event(event):
    handle = event.ipc_handle()
    return (rebuild_event, (event.device, handle))


def rebuild_tensor(cls, storage, metadata):
    storage_offset, size, stride, requires_grad = metadata
    t = torch._utils._rebuild_tensor(storage, storage_offset, size, stride)
    if cls == torch.nn.parameter.Parameter:
        # we have to pass requires_grad into the constructor, rather than set it
        # as an attribute later, because it's an important check for Integer
        # Tensors to have requires_grad=False (or else they raise an error)
        t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad)
    else:
        t.requires_grad = requires_grad
    return t
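

# Illustrative sketch (added for clarity; the helper name is hypothetical):
# every `reduce_*` function in this module follows the ForkingPickler reducer
# contract: it returns a `(rebuild_fn, args)` pair, and the receiving process
# reconstructs the object by calling `rebuild_fn(*args)`. A CPU round trip in a
# single process shows the contract; the rebuilt tensor shares its storage.
def _demo_reduce_rebuild_roundtrip():
    t = torch.arange(6.0).view(2, 3)
    rebuild_fn, args = reduce_tensor(t)  # what pickling does on the sender side
    u = rebuild_fn(*args)  # what unpickling does on the receiver side
    assert torch.equal(t, u) and u.data_ptr() == t.data_ptr()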


def rebuild_meta_tensor(
    tensor_cls,
    tensor_size,
    tensor_stride,
    tensor_offset,
    dtype,
    storage_size_bytes,
    requires_grad,
):
    untyped_storage = torch.UntypedStorage(storage_size_bytes, device="meta")

    typed_storage = torch.TypedStorage(
        wrap_storage=untyped_storage, dtype=dtype, _internal=True
    )

    t = torch._utils._rebuild_tensor(
        typed_storage,
        tensor_offset,
        tensor_size,
        tensor_stride,
    )

    if tensor_cls == torch.nn.parameter.Parameter:
        # It is crucial for integer tensors to receive
        # requires_grad=False as an argument in the constructor
        t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad)
    else:
        t.requires_grad = requires_grad

    return t
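

# Illustrative sketch (added; the helper name is hypothetical): meta tensors
# carry no data, so only shape, stride, dtype, and storage size cross the
# process boundary. Rebuilding one directly:
def _demo_rebuild_meta_tensor():
    m = rebuild_meta_tensor(
        torch.Tensor,
        tensor_size=(2, 3),
        tensor_stride=(3, 1),
        tensor_offset=0,
        dtype=torch.float32,
        storage_size_bytes=24,  # 6 elements * 4 bytes each
        requires_grad=False,
    )
    assert m.device.type == "meta" and m.shape == (2, 3)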


def rebuild_cuda_tensor(
    tensor_cls,
    tensor_size,
    tensor_stride,
    tensor_offset,
    storage_cls,
    dtype,
    storage_device,
    storage_handle,
    storage_size_bytes,
    storage_offset_bytes,
    requires_grad,
    ref_counter_handle,
    ref_counter_offset,
    event_handle,
    event_sync_required,
):
    # If storage_handle is None, storage points to nullptr.
    if storage_handle is None or storage_size_bytes == 0:
        storage = storage_cls(0, dtype=dtype, device=storage_device, _internal=True)
    else:
        storage = storage_from_cache(
            storage_cls, (storage_handle, storage_offset_bytes)
        )
        if storage is None:
            torch.cuda._lazy_init()
            storage = storage_cls._new_shared_cuda(
                storage_device,
                storage_handle,
                storage_size_bytes,
                storage_offset_bytes,
                ref_counter_handle,
                ref_counter_offset,
                event_handle,
                event_sync_required,
            )
            shared_cache[(storage_handle, storage_offset_bytes)] = StorageWeakRef(
                storage
            )
        else:
            # We are already ref-counting this Storage, but the producer needs
            # its new ref-counter to be released.
            storage_cls._release_ipc_counter(
                ref_counter_handle, ref_counter_offset, device=storage_device
            )

    _storage = (
        storage
        if isinstance(storage, torch.UntypedStorage)
        else storage._untyped_storage
    )

    t = torch._utils._rebuild_tensor(
        torch.storage.TypedStorage(wrap_storage=_storage, dtype=dtype, _internal=True),
        tensor_offset,
        tensor_size,
        tensor_stride,
    )

    if tensor_cls == torch.nn.parameter.Parameter:
        # It is crucial for integer tensors to receive
        # requires_grad=False as an argument in the constructor
        t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad)
    else:
        t.requires_grad = requires_grad

    return t
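

# Illustrative sketch (added; hypothetical helper, requires a CUDA build):
# CUDA tensors travel over IPC as (memory handle, offset, size) triples rather
# than as bytes, so the producer must keep the original allocation alive for
# as long as any consumer uses the rebuilt tensor.
def _demo_cuda_ipc_producer(queue):
    t = torch.ones(4, device="cuda")
    queue.put(t)  # pickled via reduce_tensor -> rebuilt by rebuild_cuda_tensor
    # Do not free `t` here until the consumer is done with it; releasing the
    # allocation would invalidate the receiver's mapping.
    return t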


def reduce_tensor(tensor):
    if tensor.requires_grad and not tensor.is_leaf:
        raise RuntimeError(
            "Cowardly refusing to serialize non-leaf tensor which requires_grad, "
            "since autograd does not support crossing process boundaries. "
            "If you just want to transfer the data, call detach() on the tensor "
            "before serializing (e.g., putting it on the queue)."
        )

    check_serializing_named_tensor(tensor)
    torch.utils.hooks.warn_if_has_hooks(tensor)

    # Note [CUDA IPC and the caching allocator]
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # When you send a CUDA tensor over IPC, you might expect that you will
    # get out the same storage from the other end. However, the CUDA caching
    # allocator makes it difficult to preserve this invariant. Consider
    # the following situation: a tensor of size 0x100 points to offset 0x20 of
    # a storage at 0xA100 of size 0x100. (For simplicity, all of these
    # sizes are given in bytes). HOWEVER, with the caching allocator, this storage
    # might be part of a larger cudaMalloc allocation 0xA000 of size 0x4000.
    #
    # When we want to send this CUDA tensor over IPC, we must send the
    # *entire* cudaMalloc allocation, i.e., the 0xA000 region, not just
    # the storage 0xA100 (because that is what CUDA supports). So, on the
    # other end, there simply isn't any way to say, "Wait, you gave me
    # a bigger region (0xA000) than the one I wanted (0xA100)".
    #
    # OK, so if you sent the cudaMalloc allocation, can you just wrap that up as
    # one storage itself? No, because this cudaMalloc allocation might contain
    # storages of mixed types: float, bytes, double... If you make the entire
    # allocation a single storage of type A, we'll hit an error when constructing
    # a tensor of type B on the storage.
    #
    # cudaIpcMemHandle is an identifier to access the sender cudaMalloc allocation on the
    # receiver side.
    # However, cudaIpcMemHandles from each device in a given process may
    # only be opened by one context per device per other process.
    # If we open and close a memory handle multiple times in a process, CUDA is
    # allowed to give it a different address; similarly, once we close the
    # memory, we're not allowed to access it (or the storage/tensor built on top
    # of it), even if it is still live in the original process. As we cannot map
    # a cudaMalloc allocation to a single storage in one go, this requires us to
    # cache the device pointer for each cudaIpcMemHandle on the C++ side to
    # reconstruct the types of storages, while keeping the old ones alive.
    # See [https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html]
    #
    # This is fine, because all we need to do is to save our position in the
    # allocation, and reconstruct storage and tensor from it.
    # 0xA000 ->  -------CUDA Allocation------
    #           |                            |
    #           |                            |
    #           |                            |
    #           |                            |
    # 0xA100 ->  --------storage1 begin------
    #           |                            |
    # 0xA120 ->  --------tensor1 begin ------
    #           |                            |
    #           |                            |
    #           |                            |
    #           |                            |
    #           |                            |
    # 0xA160 ->  --------tensor1 end---------
    #           |                            |
    #           |                            |
    #           |                            |
    # 0xA200 ->  --------storage1 end--------
    #           |                            |
    # 0xE000 ->  --------CUDA allocation-----
    #
    # To send tensor1, the following info is required from sender to receiver
    # for storage reconstruction.
    # 1. cudaIpcMemHandle of 0xA000 (which can be mapped to a basePtr in the
    #    receiver process). basePtr may not be exactly 0xA000 since it's a
    #    different process.
    # 2. offset (0xA100) of storage1 in the CUDA allocation.
    # 3. size of storage1 (0x100).
    #
    # On the receiver side:
    # 1. Get the devPtr of the MemHandle to access the memory, reconstruct a
    #    storage of the same type using (basePtr, offset, size).
    # 2. We can reconstruct the tensor on top of the reconstructed storage:
    #    Tensor(size=0x040, offset=0x020, storage=Storage(data=basePtr+0xA100, size=0x0100))
    #
    # This strategy has a few implications:
    #
    # 1. When we serialize a CUDA tensor for IPC, we cannot do it all in one
    #    go (non-compositionally), and this requires us to have a global map
    #    memHandle -> devPtr for each process.
    #
    # 2. We MUST NOT let the new IPC tensor be resizable. Originally, a resize
    #    of the storage beyond 0x100 would merely have caused us to do a
    #    reallocation. You don't really want to do this, but if you did,
    #    all that would happen is that you would lose IPC sharing. But if
    #    you do this in the new world, we will happily let you write out of
    #    bounds of your "allocation", clobbering unrelated data in the cached
    #    allocator block. BAD!
    #
    # By the way, in old versions of PyTorch, we supported this situation
    # natively using a "storage view", which permitted multiple storages to be
    # views on each other. But this was the *only* use of storage views, so we
    # eliminated it so that we could just use tensor views to implement the same
    # thing.
    #

    # TODO: Handle distinguishing between subclass and non-subclass versions of NT better
    # https://github.com/pytorch/pytorch/issues/110543
    from torch.nested._internal.nested_tensor import NestedTensor

    if tensor.is_nested and not isinstance(tensor, NestedTensor):
        return reduce_nested_tensor(tensor)

    if tensor.layout in {
        torch.sparse_coo,
        torch.sparse_csr,
        torch.sparse_bsr,
        torch.sparse_csc,
        torch.sparse_bsc,
    }:
        return reduce_sparse_tensor(tensor)

    storage = tensor._typed_storage()

    if storage._untyped_storage.device.type == "cuda":
        (
            device,
            handle,
            storage_size_bytes,
            storage_offset_bytes,
            ref_counter_handle,
            ref_counter_offset,
            event_handle,
            event_sync_required,
        ) = storage._share_cuda_()
        tensor_offset = tensor.storage_offset()
        shared_cache[handle] = StorageWeakRef(storage)
        # _backward_hooks purposely omitted here, see
        # Note [Don't serialize hooks]
        return (
            rebuild_cuda_tensor,
            (
                type(tensor),
                tensor.size(),
                tensor.stride(),
                tensor_offset,  # tensor offset in its storage
                type(storage),
                tensor.dtype,
                device,
                handle,  # identifier of the CUDA allocation the storage lives in
                storage_size_bytes,  # size (in bytes) of the storage
                storage_offset_bytes,  # offset (in bytes) of the storage in the CUDA allocation
                tensor.requires_grad,
                ref_counter_handle,
                ref_counter_offset,
                event_handle,
                event_sync_required,
            ),
        )
    elif storage._untyped_storage.device.type == "meta":
        return (
            rebuild_meta_tensor,
            (
                type(tensor),
                tensor.size(),
                tensor.stride(),
                tensor.storage_offset(),
                tensor.dtype,
                tensor.untyped_storage().size(),
                tensor.requires_grad,
            ),
        )

    # _backward_hooks purposely omitted here, see Note [Don't serialize hooks]
    metadata = (
        tensor.storage_offset(),
        tensor.size(),
        tensor.stride(),
        tensor.requires_grad,
    )
    return (rebuild_tensor, (type(tensor), storage, metadata))


def rebuild_nested_tensor(
    rebuild_buffer_func,
    rebuild_buffer_args,
    rebuild_sizes_func,
    rebuild_sizes_args,
    rebuild_strides_func,
    rebuild_strides_args,
    rebuild_offsets_func,
    rebuild_offsets_args,
):
    buffer = rebuild_buffer_func(*rebuild_buffer_args)
    sizes = rebuild_sizes_func(*rebuild_sizes_args)
    strides = rebuild_strides_func(*rebuild_strides_args)
    offsets = rebuild_offsets_func(*rebuild_offsets_args)
    return torch._nested_view_from_buffer_copy(buffer, sizes, strides, offsets)
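

# Illustrative sketch (added; hypothetical helper): a nested tensor round trip
# through its reducer. Note that rebuild_nested_tensor copies the buffer view,
# so the rebuilt tensor does not alias the original.
def _demo_nested_roundtrip():
    nt = torch.nested.nested_tensor([torch.ones(2), torch.ones(3)])
    rebuild_fn, args = reduce_nested_tensor(nt)
    nt2 = rebuild_fn(*args)
    assert nt2.is_nested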


def reduce_nested_tensor(nt):
    rebuild_buffer_func, rebuild_buffer_args = reduce_tensor(nt.values())
    rebuild_sizes_func, rebuild_sizes_args = reduce_tensor(nt._nested_tensor_size())
    rebuild_strides_func, rebuild_strides_args = reduce_tensor(
        nt._nested_tensor_strides()
    )
    rebuild_offsets_func, rebuild_offsets_args = reduce_tensor(
        nt._nested_tensor_storage_offsets()
    )

    return (
        rebuild_nested_tensor,
        (
            rebuild_buffer_func,
            rebuild_buffer_args,
            rebuild_sizes_func,
            rebuild_sizes_args,
            rebuild_strides_func,
            rebuild_strides_args,
            rebuild_offsets_func,
            rebuild_offsets_args,
        ),
    )


def rebuild_sparse_coo_tensor(
    rebuild_indices_func,
    rebuild_indices_args,
    rebuild_values_func,
    rebuild_values_args,
    shape,
    is_coalesced,
):
    indices = rebuild_indices_func(*rebuild_indices_args)
    values = rebuild_values_func(*rebuild_values_args)
    return torch.sparse_coo_tensor(indices, values, shape, is_coalesced=is_coalesced)
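

# Illustrative sketch (added; hypothetical helper): sparse COO tensors are
# reduced component-wise, so indices and values are shared independently.
def _demo_sparse_coo_roundtrip():
    s = torch.sparse_coo_tensor([[0, 1], [1, 0]], [1.0, 2.0], (2, 2)).coalesce()
    rebuild_fn, args = reduce_sparse_tensor(s)
    s2 = rebuild_fn(*args)
    assert torch.equal(s2.to_dense(), s.to_dense())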


def rebuild_sparse_compressed_tensor(
    rebuild_compressed_indices_func,
    rebuild_compressed_indices_args,
    rebuild_plain_indices_func,
    rebuild_plain_indices_args,
    rebuild_values_func,
    rebuild_values_args,
    shape,
    layout,
):
    compressed_indices = rebuild_compressed_indices_func(
        *rebuild_compressed_indices_args
    )
    plain_indices = rebuild_plain_indices_func(*rebuild_plain_indices_args)
    values = rebuild_values_func(*rebuild_values_args)
    return torch.sparse_compressed_tensor(
        compressed_indices, plain_indices, values, shape, layout=layout
    )


def reduce_sparse_tensor(sparse):
    if sparse.layout is torch.sparse_coo:
        rebuild_indices_func, rebuild_indices_args = reduce_tensor(sparse._indices())
        rebuild_values_func, rebuild_values_args = reduce_tensor(sparse._values())
        return (
            rebuild_sparse_coo_tensor,
            (
                rebuild_indices_func,
                rebuild_indices_args,
                rebuild_values_func,
                rebuild_values_args,
                sparse.shape,
                sparse.is_coalesced(),
            ),
        )
    else:
        if sparse.layout in {torch.sparse_csr, torch.sparse_bsr}:
            compressed_indices = sparse.crow_indices()
            plain_indices = sparse.col_indices()
        elif sparse.layout in {torch.sparse_csc, torch.sparse_bsc}:
            compressed_indices = sparse.ccol_indices()
            plain_indices = sparse.row_indices()
        else:
            raise NotImplementedError(sparse.layout)
        (
            rebuild_compressed_indices_func,
            rebuild_compressed_indices_args,
        ) = reduce_tensor(compressed_indices)
        rebuild_plain_indices_func, rebuild_plain_indices_args = reduce_tensor(
            plain_indices
        )
        rebuild_values_func, rebuild_values_args = reduce_tensor(sparse.values())
        return (
            rebuild_sparse_compressed_tensor,
            (
                rebuild_compressed_indices_func,
                rebuild_compressed_indices_args,
                rebuild_plain_indices_func,
                rebuild_plain_indices_args,
                rebuild_values_func,
                rebuild_values_args,
                sparse.shape,
                sparse.layout,
            ),
        )
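

# Illustrative sketch (added; hypothetical helper): compressed layouts follow
# the same pattern with three component tensors instead of two.
def _demo_sparse_csr_roundtrip():
    s = torch.eye(3).to_sparse_csr()
    rebuild_fn, args = reduce_sparse_tensor(s)
    s2 = rebuild_fn(*args)
    assert torch.equal(s2.to_dense(), torch.eye(3))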


def fd_id(fd):
    # Returns a tuple which uniquely identifies a file descriptor. On macOS,
    # this doesn't work with shared memory handles, which is why we don't
    # support the "file_descriptor" sharing method on that platform.
    stat = os.fstat(fd)
    return (stat.st_ino, stat.st_dev)


def storage_from_cache(cls, key):
    storage_ref = shared_cache.get(key)
    if storage_ref is None:
        return None
    return torch.UntypedStorage._new_with_weak_ptr(storage_ref.cdata)


def rebuild_storage_fd(cls, df, size):
    fd = df.detach()
    try:
        storage = storage_from_cache(cls, fd_id(fd))
        if storage is not None:
            return storage
        storage = cls._new_shared_fd_cpu(fd, size)
        shared_cache[fd_id(fd)] = StorageWeakRef(storage)
        return storage
    finally:
        os.close(fd)
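

# Illustrative sketch (added; hypothetical helper): fd_id lets the cache
# recognize two descriptors that refer to the same underlying file, so a
# storage received twice maps back to a single cached object.
def _demo_fd_id():
    r, w = os.pipe()
    dup = os.dup(r)
    try:
        assert fd_id(r) == fd_id(dup)  # same (inode, device) identity
    finally:
        for fd in (r, w, dup):
            os.close(fd)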


def rebuild_storage_filename(cls, manager, handle, size, dtype=None):
    storage: Union[torch.TypedStorage, torch.UntypedStorage] = storage_from_cache(
        cls, handle
    )
    if storage is not None:
        return storage._shared_decref()
    if dtype is None:
        storage = torch.UntypedStorage._new_shared_filename_cpu(manager, handle, size)
    else:
        byte_size = size * torch._utils._element_size(dtype)
        untyped_storage: torch.UntypedStorage = (
            torch.UntypedStorage._new_shared_filename_cpu(manager, handle, byte_size)
        )
        storage = torch.TypedStorage(
            wrap_storage=untyped_storage, dtype=dtype, _internal=True
        )
    shared_cache[handle] = StorageWeakRef(storage)
    return storage._shared_decref()


def rebuild_storage_empty(cls):
    return cls()


def rebuild_typed_storage(storage, dtype):
    return torch.storage.TypedStorage(wrap_storage=storage, dtype=dtype, _internal=True)


# Use for torch.storage.TypedStorage
def reduce_typed_storage(storage):
    return (rebuild_typed_storage, (storage._untyped_storage, storage.dtype))


def rebuild_typed_storage_child(storage, storage_type):
    return storage_type(wrap_storage=storage, _internal=True)


# Use for child classes of torch.storage.TypedStorage, like torch.FloatStorage
def reduce_typed_storage_child(storage):
    return (rebuild_typed_storage_child, (storage._untyped_storage, type(storage)))
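

# Illustrative sketch (added; hypothetical helper): a TypedStorage reduces to
# its untyped buffer plus a dtype tag; the rebuilt storage wraps the very same
# buffer.
def _demo_typed_storage_roundtrip():
    ts = torch.ones(3)._typed_storage()
    rebuild_fn, args = reduce_typed_storage(ts)
    ts2 = rebuild_fn(*args)
    assert ts2.dtype == ts.dtype and ts2._untyped_storage is ts._untyped_storage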


def reduce_storage(storage):
    from . import get_sharing_strategy

    if storage.is_cuda:
        raise RuntimeError(
            "Cannot pickle CUDA storage; try pickling a CUDA tensor instead"
        )
    elif storage.device.type == "meta":
        raise RuntimeError(
            "Cannot pickle meta storage; try pickling a meta tensor instead"
        )
    elif get_sharing_strategy() == "file_system":
        metadata = storage._share_filename_cpu_()
        cache_key = metadata[1]
        rebuild = rebuild_storage_filename
        if isinstance(storage, torch.TypedStorage):
            metadata += (storage.dtype,)
        storage._shared_incref()
    elif storage.size() == 0:
        # This is special cased because empty storages
        # (with size 0) cannot be mmapped.
        return (rebuild_storage_empty, (type(storage),))
    else:
        fd, size = storage._share_fd_cpu_()
        df = multiprocessing.reduction.DupFd(fd)
        cache_key = fd_id(fd)
        metadata = (df, size)
        rebuild = rebuild_storage_fd  # type: ignore[assignment]

    shared_cache[cache_key] = StorageWeakRef(storage)
    return (rebuild, (type(storage),) + metadata)


def init_reductions():
    ForkingPickler.register(torch.cuda.Event, reduce_event)

    for t in torch._storage_classes:
        if t.__name__ == "UntypedStorage":
            ForkingPickler.register(t, reduce_storage)
        else:
            ForkingPickler.register(t, reduce_typed_storage_child)

    ForkingPickler.register(torch.storage.TypedStorage, reduce_typed_storage)

    for t in torch._tensor_classes:
        ForkingPickler.register(t, reduce_tensor)

    # TODO: Maybe this should be in tensor_classes? :)
    ForkingPickler.register(torch.Tensor, reduce_tensor)

    from torch.nn.parameter import Parameter

    ForkingPickler.register(Parameter, reduce_tensor)
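

# Illustrative usage sketch (added; not part of this module): once
# init_reductions() has registered these reducers (torch.multiprocessing does
# this on import), a CPU tensor sent through a torch.multiprocessing Queue is
# moved into shared memory and shares its storage with the receiver:
#
#     import torch
#     import torch.multiprocessing as mp
#
#     def worker(q):
#         t = q.get()   # rebuilt via rebuild_storage_fd + rebuild_tensor
#         t.fill_(1.0)  # in-place writes land in the shared storage
#
#     if __name__ == "__main__":
#         ctx = mp.get_context("spawn")
#         q = ctx.Queue()
#         t = torch.zeros(3)
#         p = ctx.Process(target=worker, args=(q,))
#         p.start()
#         q.put(t)
#         p.join()
#         assert torch.equal(t, torch.ones(3))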