#include <torch/csrc/cuda/THCP.h>
#include <torch/csrc/python_headers.h>
#include <cstdarg>
#include <string>

#ifdef USE_CUDA
// NB: It's a list of *optional* CUDAStream; when nullopt, that means to use
// whatever the current stream of the device the input is associated with was.
std::vector<std::optional<at::cuda::CUDAStream>>
THPUtils_PySequence_to_CUDAStreamList(PyObject* obj) {
  if (!PySequence_Check(obj)) {
    throw std::runtime_error(
        "Expected a sequence in THPUtils_PySequence_to_CUDAStreamList");
  }
  THPObjectPtr seq = THPObjectPtr(PySequence_Fast(obj, nullptr));
  if (seq.get() == nullptr) {
    throw std::runtime_error(
        "expected PySequence, but got " + std::string(THPUtils_typename(obj)));
  }

  std::vector<std::optional<at::cuda::CUDAStream>> streams;
  Py_ssize_t length = PySequence_Fast_GET_SIZE(seq.get());
  for (Py_ssize_t i = 0; i < length; i++) {
    PyObject* stream = PySequence_Fast_GET_ITEM(seq.get(), i);

    if (PyObject_IsInstance(stream, THCPStreamClass)) {
      // Spicy hot reinterpret cast!!
      // Rebuild the CUDAStream from the (stream_id, device_index, device_type)
      // triple stored on the Python stream object.
      streams.emplace_back(at::cuda::CUDAStream::unpack3(
          (reinterpret_cast<THCPStream*>(stream))->stream_id,
          (reinterpret_cast<THCPStream*>(stream))->device_index,
          static_cast<c10::DeviceType>(
              (reinterpret_cast<THCPStream*>(stream))->device_type)));
    } else if (stream == Py_None) {
      // None means "use the current stream of the input's device".
      streams.emplace_back();
    } else {
      throw std::runtime_error(
          "Unknown data type found in stream list. Need torch.cuda.Stream or None");
    }
  }
  return streams;
}

#endif
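// Usage sketch (hypothetical caller, not part of this file): the returned
// vector is meant to be walked alongside the corresponding inputs, falling
// back to the current stream of the input's device whenever an entry is
// nullopt. Names such as `py_streams` and `input_device_index` below are
// placeholders for illustration only.
//
//   std::vector<std::optional<at::cuda::CUDAStream>> streams =
//       THPUtils_PySequence_to_CUDAStreamList(py_streams);
//   for (const auto& maybe_stream : streams) {
//     at::cuda::CUDAStream stream = maybe_stream.has_value()
//         ? *maybe_stream
//         : at::cuda::getCurrentCUDAStream(input_device_index);
//     // ... enqueue work for the matching input on `stream` ...
//   }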