import os
import sys
from typing import Callable, List, Optional

import torch
from torch.types import Storage


__all__: List[str] = []


def _dummy_fn(name: str) -> Callable:
    """Build a placeholder for a ``torch._C`` binding that is missing from this build.

    Args:
        name (str): Name of the missing ``torch._C`` attribute; only used in
            the error message.

    Returns:
        Callable: A function that accepts any arguments and always raises
        ``RuntimeError``.
    """

    def fn(*args, **kwargs):  # type: ignore[no-untyped-def]
        raise RuntimeError(f"torch._C.{name} is not supported on this platform")

    return fn


# When the GDS bindings are absent from ``torch._C``, install stubs that raise
# at call time so that importing this module does not fail.
if not hasattr(torch._C, "_gds_register_buffer"):
    # The bindings are expected to be present or absent as a group.
    assert not hasattr(torch._C, "_gds_deregister_buffer")
    assert not hasattr(torch._C, "_gds_register_handle")
    assert not hasattr(torch._C, "_gds_deregister_handle")
    assert not hasattr(torch._C, "_gds_load_storage")
    assert not hasattr(torch._C, "_gds_save_storage")
    # Define functions
    torch._C.__dict__["_gds_register_buffer"] = _dummy_fn("_gds_register_buffer")
    torch._C.__dict__["_gds_deregister_buffer"] = _dummy_fn("_gds_deregister_buffer")
    torch._C.__dict__["_gds_register_handle"] = _dummy_fn("_gds_register_handle")
    torch._C.__dict__["_gds_deregister_handle"] = _dummy_fn("_gds_deregister_handle")
    torch._C.__dict__["_gds_load_storage"] = _dummy_fn("_gds_load_storage")
    torch._C.__dict__["_gds_save_storage"] = _dummy_fn("_gds_save_storage")


def _gds_register_buffer(s: Storage) -> None:
    """Registers a buffer.

    Thin wrapper that forwards to ``torch._C._gds_register_buffer``.

    Args:
        s (Storage): Buffer to register.
    """
    torch._C._gds_register_buffer(s)


def _gds_deregister_buffer(s: Storage) -> None:
    """Deregisters a buffer.

    Thin wrapper that forwards to ``torch._C._gds_deregister_buffer``.

    Args:
        s (Storage): Buffer to deregister.
    """
    torch._C._gds_deregister_buffer(s)


class _GdsFile:
    r"""Wrapper around cuFile.

    cuFile is a file-like interface to the GPUDirect Storage (GDS) API.

    The file is opened and its descriptor registered on construction; the
    finalizer deregisters the handle (if still registered) and closes the
    descriptor.

    Args:
        filename (str): Name of the file to open.
        flags (int): Flags to pass to ``os.open`` when opening the file. ``os.O_DIRECT`` will
            be added automatically.

    .. _CUDA GPUDirect Storage Documentation:
        https://docs.nvidia.com/gpudirect-storage/api-reference-guide/index.html#cufile-io-api
    """

    def __init__(self, filename: str, flags: int) -> None:
        """Open ``filename`` and register its descriptor.

        Raises:
            RuntimeError: If ``sys.platform`` is ``"win32"``.
        """
        if sys.platform == "win32":
            raise RuntimeError("GdsFile is not supported on this platform.")
        self.filename = filename
        self.flags = flags
        self.fd = os.open(filename, flags | os.O_DIRECT)
        # ``None`` means "not currently registered".
        self.handle: Optional[int] = None
        self.register_handle()

    def __del__(self) -> None:
        """Deregister the handle (if any) and close the file descriptor.

        Safe to run on a partially constructed instance (``__init__`` raised
        before ``fd``/``handle`` were assigned) and safe to run more than once.
        """
        # Take ownership of the descriptor so a repeated finalizer call can
        # never close the same (possibly recycled) descriptor number twice.
        fd = self.__dict__.pop("fd", None)
        try:
            if getattr(self, "handle", None) is not None:
                self.deregister_handle()
        finally:
            # Close even when deregistration raised, so the descriptor is not leaked.
            if fd is not None:
                os.close(fd)

    def register_handle(self) -> None:
        """Registers file descriptor to cuFile Driver.

        This is a wrapper around ``cuFileHandleRegister``.
        """
        assert (
            self.handle is None
        ), "Cannot register a handle that is already registered."
        self.handle = torch._C._gds_register_handle(self.fd)

    def deregister_handle(self) -> None:
        """Deregisters file descriptor from cuFile Driver.

        This is a wrapper around ``cuFileHandleDeregister``.
        """
        assert (
            self.handle is not None
        ), "Cannot deregister a handle that is not registered."
        torch._C._gds_deregister_handle(self.handle)
        self.handle = None

    def load_storage(self, storage: Storage, offset: int = 0) -> None:
        """Loads data from the file into the storage.

        This is a wrapper around ``cuFileRead``. ``storage.nbytes()`` of data
        will be loaded from the file at ``offset`` into the storage.

        Args:
            storage (Storage): Storage to load data into.
            offset (int, optional): Offset into the file to start loading from. (Default: 0)
        """
        assert (
            self.handle is not None
        ), "Cannot load data from a file that is not registered."
        torch._C._gds_load_storage(self.handle, storage, offset)

    def save_storage(self, storage: Storage, offset: int = 0) -> None:
        """Saves data from the storage into the file.

        This is a wrapper around ``cuFileWrite``. All bytes of the storage
        will be written to the file at ``offset``.

        Args:
            storage (Storage): Storage to save data from.
            offset (int, optional): Offset into the file to start saving to. (Default: 0)
        """
        assert (
            self.handle is not None
        ), "Cannot save data to a file that is not registered."
        torch._C._gds_save_storage(self.handle, storage, offset)