xref: /aosp_15_r20/external/pytorch/torch/csrc/cuda/utils.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/cuda/THCP.h>
2 #include <torch/csrc/python_headers.h>
3 #include <cstdarg>
4 #include <string>
5 
6 #ifdef USE_CUDA
7 // NB: It's a list of *optional* CUDAStream; when nullopt, that means to use
8 // whatever the current stream of the device the input is associated with was.
9 std::vector<std::optional<at::cuda::CUDAStream>>
THPUtils_PySequence_to_CUDAStreamList(PyObject * obj)10 THPUtils_PySequence_to_CUDAStreamList(PyObject* obj) {
11   if (!PySequence_Check(obj)) {
12     throw std::runtime_error(
13         "Expected a sequence in THPUtils_PySequence_to_CUDAStreamList");
14   }
15   THPObjectPtr seq = THPObjectPtr(PySequence_Fast(obj, nullptr));
16   if (seq.get() == nullptr) {
17     throw std::runtime_error(
18         "expected PySequence, but got " + std::string(THPUtils_typename(obj)));
19   }
20 
21   std::vector<std::optional<at::cuda::CUDAStream>> streams;
22   Py_ssize_t length = PySequence_Fast_GET_SIZE(seq.get());
23   for (Py_ssize_t i = 0; i < length; i++) {
24     PyObject* stream = PySequence_Fast_GET_ITEM(seq.get(), i);
25 
26     if (PyObject_IsInstance(stream, THCPStreamClass)) {
27       // Spicy hot reinterpret cast!!
28       streams.emplace_back(at::cuda::CUDAStream::unpack3(
29           (reinterpret_cast<THCPStream*>(stream))->stream_id,
30           (reinterpret_cast<THCPStream*>(stream))->device_index,
31           static_cast<c10::DeviceType>(
32               (reinterpret_cast<THCPStream*>(stream))->device_type)));
33     } else if (stream == Py_None) {
34       streams.emplace_back();
35     } else {
36       // NOLINTNEXTLINE(bugprone-throw-keyword-missing)
37       std::runtime_error(
38           "Unknown data type found in stream list. Need torch.cuda.Stream or None");
39     }
40   }
41   return streams;
42 }
43 
44 #endif
45