1 #pragma once 2 3 #include <ATen/cuda/Exceptions.h> 4 5 #include <cuda.h> 6 #include <cuda_runtime.h> 7 8 namespace at::cuda { 9 getDeviceFromPtr(void * ptr)10inline Device getDeviceFromPtr(void* ptr) { 11 cudaPointerAttributes attr{}; 12 13 AT_CUDA_CHECK(cudaPointerGetAttributes(&attr, ptr)); 14 15 #if !defined(USE_ROCM) 16 TORCH_CHECK(attr.type != cudaMemoryTypeUnregistered, 17 "The specified pointer resides on host memory and is not registered with any CUDA device."); 18 #endif 19 20 return {c10::DeviceType::CUDA, static_cast<DeviceIndex>(attr.device)}; 21 } 22 23 } // namespace at::cuda 24