xref: /aosp_15_r20/external/pytorch/torch/csrc/utils/tensor_memoryformats.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/utils/tensor_memoryformats.h>
2 
3 #include <c10/core/MemoryFormat.h>
4 #include <torch/csrc/DynamicTypes.h>
5 #include <torch/csrc/Exceptions.h>
6 #include <torch/csrc/MemoryFormat.h>
7 
8 #include <torch/csrc/python_headers.h>
9 #include <torch/csrc/utils/object_ptr.h>
10 
11 namespace torch::utils {
12 
13 namespace {
14 // Intentionally leaked
15 std::array<PyObject*, static_cast<int>(at::MemoryFormat::NumOptions)>
16     memory_format_registry = {};
17 } // anonymous namespace
18 
getTHPMemoryFormat(at::MemoryFormat memory_format)19 PyObject* getTHPMemoryFormat(at::MemoryFormat memory_format) {
20   auto py_memory_format =
21       memory_format_registry[static_cast<int>(memory_format)];
22   if (!py_memory_format) {
23     throw std::invalid_argument("unsupported memory_format");
24   }
25   return py_memory_format;
26 }
27 
initializeMemoryFormats()28 void initializeMemoryFormats() {
29   auto torch_module = THPObjectPtr(PyImport_ImportModule("torch"));
30   if (!torch_module) {
31     throw python_error();
32   }
33 
34   auto add_memory_format = [&](at::MemoryFormat format, const char* name) {
35     std::string module_name = "torch.";
36     PyObject* memory_format = THPMemoryFormat_New(format, module_name + name);
37     Py_INCREF(memory_format);
38     if (PyModule_AddObject(torch_module, name, memory_format) != 0) {
39       Py_DECREF(memory_format);
40       throw python_error();
41     }
42     Py_INCREF(memory_format);
43     memory_format_registry[static_cast<size_t>(format)] = memory_format;
44   };
45 
46   add_memory_format(at::MemoryFormat::Preserve, "preserve_format");
47   add_memory_format(at::MemoryFormat::Contiguous, "contiguous_format");
48   add_memory_format(at::MemoryFormat::ChannelsLast, "channels_last");
49   add_memory_format(at::MemoryFormat::ChannelsLast3d, "channels_last_3d");
50 }
51 
52 } // namespace torch::utils
53