xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cuda/CUDADevice.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/cuda/Exceptions.h>
4 
5 #include <cuda.h>
6 #include <cuda_runtime.h>
7 
8 namespace at::cuda {
9 
getDeviceFromPtr(void * ptr)10 inline Device getDeviceFromPtr(void* ptr) {
11   cudaPointerAttributes attr{};
12 
13   AT_CUDA_CHECK(cudaPointerGetAttributes(&attr, ptr));
14 
15 #if !defined(USE_ROCM)
16   TORCH_CHECK(attr.type != cudaMemoryTypeUnregistered,
17     "The specified pointer resides on host memory and is not registered with any CUDA device.");
18 #endif
19 
20   return {c10::DeviceType::CUDA, static_cast<DeviceIndex>(attr.device)};
21 }
22 
23 } // namespace at::cuda
24