1 #include <torch/csrc/utils/tensor_memoryformats.h> 2 3 #include <c10/core/MemoryFormat.h> 4 #include <torch/csrc/DynamicTypes.h> 5 #include <torch/csrc/Exceptions.h> 6 #include <torch/csrc/MemoryFormat.h> 7 8 #include <torch/csrc/python_headers.h> 9 #include <torch/csrc/utils/object_ptr.h> 10 11 namespace torch::utils { 12 13 namespace { 14 // Intentionally leaked 15 std::array<PyObject*, static_cast<int>(at::MemoryFormat::NumOptions)> 16 memory_format_registry = {}; 17 } // anonymous namespace 18 getTHPMemoryFormat(at::MemoryFormat memory_format)19PyObject* getTHPMemoryFormat(at::MemoryFormat memory_format) { 20 auto py_memory_format = 21 memory_format_registry[static_cast<int>(memory_format)]; 22 if (!py_memory_format) { 23 throw std::invalid_argument("unsupported memory_format"); 24 } 25 return py_memory_format; 26 } 27 initializeMemoryFormats()28void initializeMemoryFormats() { 29 auto torch_module = THPObjectPtr(PyImport_ImportModule("torch")); 30 if (!torch_module) { 31 throw python_error(); 32 } 33 34 auto add_memory_format = [&](at::MemoryFormat format, const char* name) { 35 std::string module_name = "torch."; 36 PyObject* memory_format = THPMemoryFormat_New(format, module_name + name); 37 Py_INCREF(memory_format); 38 if (PyModule_AddObject(torch_module, name, memory_format) != 0) { 39 Py_DECREF(memory_format); 40 throw python_error(); 41 } 42 Py_INCREF(memory_format); 43 memory_format_registry[static_cast<size_t>(format)] = memory_format; 44 }; 45 46 add_memory_format(at::MemoryFormat::Preserve, "preserve_format"); 47 add_memory_format(at::MemoryFormat::Contiguous, "contiguous_format"); 48 add_memory_format(at::MemoryFormat::ChannelsLast, "channels_last"); 49 add_memory_format(at::MemoryFormat::ChannelsLast3d, "channels_last_3d"); 50 } 51 52 } // namespace torch::utils 53