xref: /aosp_15_r20/external/pytorch/torch/csrc/Storage.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/python_headers.h>
2 #ifdef _MSC_VER
3 #include <c10/util/win32-headers.h>
4 #endif
5 #include <structmember.h>
6 
7 #include <ATen/mps/MPSDevice.h>
8 #include <c10/core/CPUAllocator.h>
9 #include <c10/core/RefcountedDeleter.h>
10 #include <libshm.h>
11 #include <torch/csrc/CudaIPCTypes.h>
12 #include <torch/csrc/Device.h>
13 #include <torch/csrc/DynamicTypes.h>
14 #include <torch/csrc/StorageMethods.h>
15 #include <torch/csrc/StorageSharing.h>
16 #include <torch/csrc/THP.h>
17 #include <torch/csrc/autograd/utils/wrap_outputs.h>
18 #include <torch/csrc/copy_utils.h>
19 #include <torch/csrc/utils/device_lazy_init.h>
20 #include <torch/csrc/utils/pyobject_preservation.h>
21 #include <torch/csrc/utils/python_arg_parser.h>
22 
23 #include <c10/util/intrusive_ptr.h>
24 #include <fmt/format.h>
25 
26 template <>
free()27 void THPPointer<c10::StorageImpl>::free() {
28   if (ptr) {
29     c10::raw::intrusive_ptr::decref(ptr);
30   }
31 }
32 
33 PyTypeObject* THPStorageClass = nullptr;
34 
THPStorage_NewWithStorage(PyTypeObject * type,c10::Storage _storage,c10::impl::PyInterpreterStatus status,bool allow_preexisting_pyobj)35 PyObject* THPStorage_NewWithStorage(
36     PyTypeObject* type,
37     c10::Storage _storage,
38     c10::impl::PyInterpreterStatus status,
39     bool allow_preexisting_pyobj) {
40   TORCH_CHECK(
41       PyType_IsSubtype(type, &THPStorageType),
42       "Creating a Storage subclass from a class that does not inherit from ",
43       "Storage is not possible. Make sure your class inherits from Storage.");
44 
45   auto maybe_pyobj = _storage.unsafeGetStorageImpl()->pyobj_slot()->check_pyobj(
46       getPyInterpreter(), /*ignore_hermetic_tls=*/false);
47   if (maybe_pyobj.has_value() && maybe_pyobj.value()) {
48     TORCH_CHECK(
49         allow_preexisting_pyobj,
50         "Creating a new Storage subclass ",
51         type->tp_name,
52         " but the raw Storage object is already associated to a python object ",
53         "of type ",
54         maybe_pyobj.value()->ob_type->tp_name);
55     PyObject* obj = *maybe_pyobj;
56     PyTypeObject* obj_type = Py_TYPE(obj);
57     TORCH_CHECK(
58         obj_type == type || PyType_IsSubtype(obj_type, type),
59         "Creating a new Storage subclass ",
60         type->tp_name,
61         " but the raw Storage object is already associated to a python object ",
62         "of type ",
63         maybe_pyobj.value()->ob_type->tp_name,
64         " which is not a subclass of the "
65         "requested type");
66     return THPStorage_Wrap(std::move(_storage));
67   }
68 
69   PyObject* obj = type->tp_alloc(type, 0);
70   TORCH_CHECK(obj, "Failed to allocate a ", type->tp_name, " object");
71 
72   auto s = (THPStorage*)obj;
73 
74   new (&s->cdata) c10::MaybeOwned<c10::Storage>();
75 
76   s->cdata = c10::MaybeOwned<c10::Storage>::owned(std::move(_storage));
77 
78   if (!c10::impl::HermeticPyObjectTLS::get_state()) {
79     s->is_hermetic = false;
80     const auto& storage = THPStorage_Unpack(s);
81     storage.unsafeGetStorageImpl()->pyobj_slot()->init_pyobj(
82         getPyInterpreter(), obj, status);
83   } else {
84     s->is_hermetic = true;
85   }
86 
87   return obj;
88 }
89 
90 // Wraps the c10::Storage with a storage PyObject
THPStorage_Wrap(c10::Storage storage)91 PyObject* THPStorage_Wrap(c10::Storage storage) {
92   c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl();
93   if (c10::impl::HermeticPyObjectTLS::get_state()) {
94     return THPStorage_NewWithStorage(
95         THPStorageClass,
96         std::move(storage),
97         c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
98   }
99   c10::impl::PyObjectSlot* pyobj_slot = storage_impl->pyobj_slot();
100 
101   // If the StorageImpl has a PyObject that is managed by a different
102   // interpreter than the current one, create a new StorageImpl that points to
103   // the same data and then create the Python storage from that.
104   // NOTE: This is only supposed to happen in MultiPy
105   if (pyobj_slot->has_pyobj_nonhermetic() &&
106       !pyobj_slot->check_interpreter(getPyInterpreter())) {
107     return THPStorage_NewWithStorage(
108         THPStorageClass,
109         c10::newStorageImplFromRefcountedDataPtr(storage),
110         c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
111   }
112   std::optional<PyObject*> maybe_pyobj = pyobj_slot->check_pyobj(
113       getPyInterpreter(), /*ignore_hermetic_tls=*/false);
114   c10::impl::PyInterpreterStatus status =
115       c10::impl::PyInterpreterStatus::TAGGED_BY_US;
116   if (maybe_pyobj.has_value()) {
117     auto obj = *maybe_pyobj;
118     if (obj) {
119       TORCH_CHECK(
120           THPStorage_Check(obj),
121           "Expected a storage type, but got ",
122           Py_TYPE(obj)->tp_name);
123 
124       if (pyobj_slot->owns_pyobj()) {
125         pyobj_slot->set_owns_pyobj(false);
126         reinterpret_cast<THPStorage*>(obj)->cdata =
127             c10::MaybeOwned<c10::Storage>::owned(std::move(storage));
128         return obj;
129       } else {
130         Py_INCREF(obj);
131         return obj;
132       }
133     }
134     status = c10::impl::PyInterpreterStatus::TAGGED_BY_US;
135   } else {
136     if (storage.use_count() <= 1) {
137       status = c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED;
138     } else {
139       status = c10::impl::PyInterpreterStatus::MAYBE_UNINITIALIZED;
140     }
141   }
142   return THPStorage_NewWithStorage(THPStorageClass, std::move(storage), status);
143 }
144 
THPStorage_isPreservable(THPStorage * self)145 static bool THPStorage_isPreservable(THPStorage* self) {
146   if (self->cdata.unsafeIsBorrowed()) {
147     return false;
148   }
149   auto const& storage = THPStorage_Unpack(self);
150 
151   if (self->is_hermetic) {
152     return false;
153   }
154 
155   if (storage.unsafeGetStorageImpl()->pyobj_slot()->check_pyobj(
156           getPyInterpreter(), /*ignore_hermetic_tls=*/true) !=
157       std::make_optional((PyObject*)self)) {
158     return false;
159   }
160   if (storage.use_count() <= 1) {
161     return false;
162   }
163   return true;
164 }
165 
THPStorage_tryPreserve(THPStorage * self)166 static bool THPStorage_tryPreserve(THPStorage* self) {
167   if (!THPStorage_isPreservable(self)) {
168     return false;
169   }
170 
171   const auto& storage = THPStorage_Unpack(self);
172   c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl();
173 
174   auto maybe_pyobj = storage_impl->pyobj_slot()->check_pyobj(
175       getPyInterpreter(),
176       /*ignore_hermetic_tls=*/true);
177   // NOTE: It is possible to just set the PyObjectSlot here, but the point is
178   // that we should have already set PyObjectSlot when the storage PyObject was
179   // created.
180   TORCH_INTERNAL_ASSERT(
181       maybe_pyobj.has_value(),
182       "Trying to preserve a Python storage whose PyObjectSlot does not have a PyObject");
183 
184   PyObject* pyobj = *maybe_pyobj;
185 
186   TORCH_CHECK(
187       THPStorage_Check(pyobj),
188       "Expected a storage type, but got ",
189       Py_TYPE(pyobj)->tp_name);
190 
191   TORCH_INTERNAL_ASSERT(
192       (void*)pyobj == (void*)self,
193       "Python storage and the PyObject in the internal PyObjectSlot are not at the same address");
194 
195   TORCH_INTERNAL_ASSERT(!storage_impl->pyobj_slot()->owns_pyobj());
196 
197   storage_impl->pyobj_slot()->set_owns_pyobj(true);
198   Py_INCREF(self);
199 
200   self->cdata = c10::MaybeOwned<c10::Storage>::borrowed(storage);
201   return true;
202 }
203 
THPStorage_subclass_dealloc(PyObject * self)204 static void THPStorage_subclass_dealloc(PyObject* self) {
205   THPStorage* _self = (THPStorage*)self;
206 
207   if (THPStorage_tryPreserve(_self)) {
208     return;
209   }
210 
211   // Some subclass of StorageBase could be GC-tracked objects even
212   // though the base class is not
213   auto* type = Py_TYPE(self);
214   if (PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC) != 0) {
215     PyObject_GC_UnTrack(self);
216   }
217 
218   bool has_finalizer = type->tp_finalize || type->tp_del;
219 
220   if (type->tp_finalize) {
221     PyObject_GC_Track(self);
222     if (PyObject_CallFinalizerFromDealloc(self) < 0) {
223       // The finalizer has resurrected the PyObject and there is a new Python
224       // reference to it, so we can just stop deallocating. Read about
225       // resurrection from `__del__` here:
226       // https://docs.python.org/3/reference/datamodel.html#object.__del__
227       return;
228     }
229     PyObject_GC_UnTrack(self);
230   }
231 
232   // base test is unnecessary as THPStorae does not set this
233   if (type->tp_weaklistoffset) {
234     PyObject_ClearWeakRefs(self);
235   }
236 
237   if (type->tp_del) {
238     PyObject_GC_Track(self);
239     type->tp_del(self);
240     if (Py_REFCNT(self) > 0) {
241       // Resurrected (see above comment about resurrection from `__del__`)
242       return;
243     }
244     PyObject_GC_UnTrack(self);
245   }
246 
247   if (has_finalizer) {
248     /* New weakrefs could be created during the finalizer call.
249        If this occurs, clear them out without calling their
250        finalizers since they might rely on part of the object
251        being finalized that has already been destroyed. */
252     if (type->tp_weaklistoffset) {
253       /* Modeled after GET_WEAKREFS_LISTPTR() */
254       PyWeakReference** list =
255           (PyWeakReference**)PyObject_GET_WEAKREFS_LISTPTR(self);
256       while (*list)
257         _PyWeakref_ClearRef(*list);
258     }
259   }
260 
261   // Clear slots
262   {
263     PyTypeObject* base = type;
264     while (base != &THPStorageType) {
265       if (Py_SIZE(base)) {
266         clear_slots(base, self);
267       }
268       base = base->tp_base;
269       TORCH_INTERNAL_ASSERT(base);
270     }
271   }
272 
273   // Clear __dict__
274   if (C10_LIKELY(type->tp_dictoffset)) {
275     PyObject** dictptr = _PyObject_GetDictPtr(self);
276     if (dictptr != nullptr) {
277       PyObject* dict = *dictptr;
278       if (dict != nullptr) {
279         Py_DECREF(dict);
280         *dictptr = nullptr;
281       }
282     }
283   }
284 
285   TORCH_INTERNAL_ASSERT(Py_TYPE(self) == type);
286 
287   _self->cdata.~MaybeOwned<c10::Storage>();
288   Py_TYPE(_self)->tp_free(self);
289 
290   TORCH_INTERNAL_ASSERT(type->tp_flags & Py_TPFLAGS_HEAPTYPE);
291   Py_DECREF(type);
292 }
293 
THPStorage_pynew(PyTypeObject * type,PyObject * args,PyObject * kwargs)294 static PyObject* THPStorage_pynew(
295     PyTypeObject* type,
296     PyObject* args,
297     PyObject* kwargs) {
298   HANDLE_TH_ERRORS
299   TORCH_CHECK(
300       type != &THPStorageType,
301       "Cannot directly construct StorageBase; subclass it and then construct that");
302   static torch::PythonArgParser parser({
303       THPStorageStr "(*, int64_t allocator=None, Device device=None)",
304       THPStorageStr
305       "(int64_t size, *, int64_t allocator=None, Device device=None)",
306       THPStorageStr
307       "(PyObject* sequence, *, int64_t allocator=None, Device device=None)",
308   });
309   torch::ParsedArgs<3> parsed_args;
310   auto r = parser.parse(args, kwargs, parsed_args);
311 
312   int allocator_arg_idx = 0;
313   int device_arg_idx = 1;
314 
315   if (r.idx > 0) {
316     allocator_arg_idx = 1;
317     device_arg_idx = 2;
318   }
319 
320   std::optional<int64_t> allocator_opt = r.toInt64Optional(allocator_arg_idx);
321   std::optional<at::Device> device_opt = r.deviceOptional(device_arg_idx);
322 
323   TORCH_CHECK(
324       !allocator_opt.has_value() || !device_opt.has_value(),
325       THPStorageStr,
326       "(): only one or neither of 'allocator' or 'device' can ",
327       "be given, but not both");
328 
329   PyObject* self = nullptr;
330   c10::Allocator* allocator = nullptr;
331   at::OptionalDeviceGuard device_guard;
332 
333   if (allocator_opt.has_value()) {
334     // NOLINTNEXTLINE(performance-no-int-to-ptr)
335     allocator = reinterpret_cast<c10::Allocator*>(allocator_opt.value());
336   } else if (device_opt.has_value()) {
337     at::Device device = device_opt.value();
338     torch::utils::maybe_initialize_device(device);
339 
340     switch (device.type()) {
341       case at::kCPU:
342         allocator = c10::GetDefaultCPUAllocator();
343         break;
344 #ifdef USE_CUDA
345       case at::kCUDA:
346         allocator = c10::cuda::CUDACachingAllocator::get();
347         break;
348 #endif
349 #ifdef USE_MPS
350       case at::kMPS:
351         allocator = at::mps::GetMPSAllocator();
352         break;
353 #endif
354       case at::DeviceType::XPU:
355       case at::DeviceType::HPU:
356       case at::DeviceType::Meta:
357       case at::DeviceType::PrivateUse1:
358       case at::DeviceType::MAIA:
359         allocator = c10::GetAllocator(device.type());
360         break;
361       default:
362         // NOLINTEND(bugprone-branch-clone)
363         TORCH_CHECK(
364             false,
365             THPStorageStr,
366             "(): Storage device not recognized: ",
367             device.type());
368     }
369 
370     device_guard.reset_device(device);
371   } else {
372     allocator = c10::GetDefaultCPUAllocator();
373   }
374 
375   // torch.Storage(*, ...)
376   if (r.idx == 0) {
377     self = THPStorage_NewWithStorage(
378         type,
379         make_storage_impl(
380             c10::StorageImpl::use_byte_size_t(),
381             0,
382             at::DataPtr(),
383             allocator,
384             /*resizable=*/true,
385             device_opt),
386         c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
387 
388     // torch.Storage(size, *, ...)
389   } else if (r.idx == 1) {
390     int64_t size = r.toInt64(0);
391     self = THPStorage_NewWithStorage(
392         type,
393         make_storage_impl(
394             c10::StorageImpl::use_byte_size_t(),
395             size,
396             at::DataPtr(),
397             allocator,
398             /*resizable=*/true,
399             device_opt),
400         c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
401 
402     // torch.Storage(sequence, *, ...)
403   } else if (r.idx == 2) {
404     PyObject* sequence = r.pyobject(0);
405     Py_ssize_t length = PySequence_Length(sequence);
406     TORCH_CHECK(
407         PySequence_Check(sequence),
408         THPStorageStr,
409         "(): Expected a sequence type, but got ",
410         THPUtils_typename(sequence));
411     TORCH_CHECK(
412         length >= 0,
413         THPStorageStr,
414         "(): Could not obtain the length of sequence of type ",
415         THPUtils_typename(sequence));
416     self = THPStorage_NewWithStorage(
417         type,
418         make_storage_impl(
419             c10::StorageImpl::use_byte_size_t(),
420             length,
421             at::DataPtr(),
422             allocator,
423             /*resizable=*/true,
424             device_opt),
425         c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
426     THPObjectPtr item;
427     try {
428       const auto& storage = THPStorage_Unpack(self);
429       for (Py_ssize_t i = 0; i < length; i++) {
430         item = PySequence_GetItem(sequence, i);
431         uint8_t value = THPByteUtils_unpackReal(item.get());
432         if (allocator == c10::GetDefaultCPUAllocator()) {
433           static_cast<uint8_t*>(storage.mutable_data())[i] = value;
434         } else {
435           // TODO: this might be slow - consider batched updates?
436           storage_set(storage, i, value);
437         }
438       }
439     } catch (const std::exception& e) {
440       TORCH_CHECK(
441           THPStorageStr "(): tried to construct a storage from a sequence (",
442           THPUtils_typename(sequence),
443           "), ",
444           "but one of the items was of type ",
445           THPUtils_typename(item.get()),
446           " instead of int");
447       return nullptr;
448     }
449   }
450   return self;
451   Py_RETURN_NONE;
452   END_HANDLE_TH_ERRORS
453 }
454 
THPStorage_length(THPStorage * self)455 static Py_ssize_t THPStorage_length(THPStorage* self) {
456   HANDLE_TH_ERRORS
457   THPStorage_assertNotNull(self);
458   return static_cast<Py_ssize_t>(THPStorage_Unpack(self).nbytes());
459   END_HANDLE_TH_ERRORS_RET(-1)
460 }
461 
THPStorage_get(THPStorage * self,PyObject * index)462 static PyObject* THPStorage_get(THPStorage* self, PyObject* index) {
463   HANDLE_TH_ERRORS
464   THPStorage_assertNotNull(self);
465   const auto& storage = THPStorage_Unpack(self);
466   int64_t len = static_cast<int64_t>(storage.nbytes());
467   /* Integer index */
468   if (THPUtils_checkLong(index)) {
469     int64_t nindex = THPUtils_unpackLong(index);
470     if (nindex < 0)
471       nindex += len;
472     if (nindex < 0 || nindex >= len) {
473       PyErr_SetString(
474           PyExc_IndexError,
475           fmt::format(
476               "index {} out of range for storage of size {}", nindex, len));
477       return nullptr;
478     }
479     uint8_t value = storage_get(storage, nindex);
480     return THPByteUtils_newReal(value);
481     /* Slice index */
482   } else if (PySlice_Check(index)) {
483     // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
484     Py_ssize_t start, stop, slicelength, step;
485     if (PySlice_Unpack(index, &start, &stop, &step) < 0) {
486       return nullptr;
487     }
488     slicelength = PySlice_AdjustIndices(len, &start, &stop, step);
489     if (step != 1) {
490       TORCH_CHECK(
491           "Trying to slice with a step of ",
492           step,
493           ", but only a step of "
494           "1 is supported");
495       return nullptr;
496     }
497 
498     const auto& storage = THPStorage_Unpack(self);
499     auto data = static_cast<uint8_t*>(storage.mutable_data());
500 
501     at::StorageImpl* old_storage_impl = storage.unsafeGetStorageImpl();
502     c10::raw::intrusive_ptr::incref(old_storage_impl);
503     std::optional<at::Device> device_opt = old_storage_impl->device();
504     auto new_storage_impl = make_storage_impl(
505         c10::StorageImpl::use_byte_size_t(),
506 #ifdef THQUANTIZED
507         slicelength * sizeof(quantized_t),
508 #else
509         slicelength,
510 #endif
511         at::DataPtr(
512             static_cast<void*>(data + start),
513             old_storage_impl,
514             [](void* s) {
515               c10::raw::intrusive_ptr::decref(static_cast<at::StorageImpl*>(s));
516             },
517             old_storage_impl->device()),
518         old_storage_impl->allocator(),
519         /* resizable */ false,
520         device_opt);
521 
522     PyObject* _ret = THPStorage_NewWithStorage(
523         Py_TYPE(self),
524         std::move(new_storage_impl),
525         c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
526 
527     return _ret;
528   }
529   PyErr_Format(
530       PyExc_TypeError,
531       "can't index a " THPStorageStr " with %s",
532       THPUtils_typename(index));
533   return nullptr;
534   END_HANDLE_TH_ERRORS
535 }
536 
THPStorage_set(THPStorage * self,PyObject * index,PyObject * value)537 static int THPStorage_set(THPStorage* self, PyObject* index, PyObject* value) {
538   HANDLE_TH_ERRORS
539   THPStorage_assertNotNull(self);
540   if (!THPByteUtils_checkReal(value)) {
541     TORCH_CHECK(
542         "can only set storage content with a int types, but got ",
543         THPUtils_typename(value),
544         " instead");
545     return -1;
546   }
547 
548   uint8_t rvalue = THPByteUtils_unpackReal(value);
549   const auto& storage = THPStorage_Unpack(self);
550   if (THPUtils_checkLong(index)) {
551     int64_t nindex = THPUtils_unpackLong(index);
552     storage_set(storage, nindex, rvalue);
553     return 0;
554   } else if (PySlice_Check(index)) {
555     // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
556     Py_ssize_t start, stop, step;
557     Py_ssize_t len = static_cast<Py_ssize_t>(storage.nbytes());
558     if (PySlice_Unpack(index, &start, &stop, &step) < 0) {
559       return -1;
560     }
561     PySlice_AdjustIndices(len, &start, &stop, step);
562     if (step != 1) {
563       TORCH_CHECK(
564           "Trying to slice with a step of ",
565           step,
566           ", but only a step of "
567           "1 is supported");
568       return 0;
569     }
570     // TODO: check the bounds only once
571     // TODO: fill?
572     for (; start < stop; start++)
573       storage_set(storage, start, rvalue);
574     return 0;
575   }
576   TORCH_CHECK(
577       "can't index a " THPStorageStr " with ", THPUtils_typename(index));
578   return -1;
579   END_HANDLE_TH_ERRORS_RET(-1)
580 }
581 
582 static PyMappingMethods THPStorage_mappingmethods = {
583     (lenfunc)THPStorage_length,
584     (binaryfunc)THPStorage_get,
585     (objobjargproc)THPStorage_set};
586 
587 struct THPStorageMeta {
588   PyHeapTypeObject base;
589 };
590 
591 int THPStorageMetaType_init(PyObject* cls, PyObject* args, PyObject* kwargs);
592 
593 PyTypeObject THPStorageMetaType = {
594     PyVarObject_HEAD_INIT(
595         DEFERRED_ADDRESS(&PyType_Type),
596         0) "torch._C._StorageMeta", /* tp_name */
597     sizeof(THPStorageMeta), /* tp_basicsize */
598     0, /* tp_itemsize */
599     nullptr, /* tp_dealloc */
600     0, /* tp_vectorcall_offset */
601     nullptr, /* tp_getattr */
602     nullptr, /* tp_setattr */
603     nullptr, /* tp_reserved */
604     nullptr, /* tp_repr */
605     nullptr, /* tp_as_number */
606     nullptr, /* tp_as_sequence */
607     nullptr, /* tp_as_mapping */
608     nullptr, /* tp_hash  */
609     nullptr, /* tp_call */
610     nullptr, /* tp_str */
611     nullptr, /* tp_getattro */
612     nullptr, /* tp_setattro */
613     nullptr, /* tp_as_buffer */
614     // NOLINTNEXTLINE(misc-redundant-expression)
615     Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
616     nullptr, /* tp_doc */
617     nullptr, /* tp_traverse */
618     nullptr, /* tp_clear */
619     nullptr, /* tp_richcompare */
620     0, /* tp_weaklistoffset */
621     nullptr, /* tp_iter */
622     nullptr, /* tp_iternext */
623     nullptr, /* tp_methods */
624     nullptr, /* tp_members */
625     nullptr, /* tp_getset */
626     DEFERRED_ADDRESS(&PyType_Type), /* tp_base */
627     nullptr, /* tp_dict */
628     nullptr, /* tp_descr_get */
629     nullptr, /* tp_descr_set */
630     0, /* tp_dictoffset */
631     THPStorageMetaType_init, /* tp_init */
632     nullptr, /* tp_alloc */
633     nullptr, /* tp_new */
634 };
635 
636 // TODO: implement equality
637 PyTypeObject THPStorageType = {
638     PyVarObject_HEAD_INIT(
639         &THPStorageMetaType,
640         0) "torch._C.StorageBase", /* tp_name */
641     sizeof(THPStorage), /* tp_basicsize */
642     0, /* tp_itemsize */
643     nullptr, /* tp_dealloc */
644     0, /* tp_vectorcall_offset */
645     nullptr, /* tp_getattr */
646     nullptr, /* tp_setattr */
647     nullptr, /* tp_reserved */
648     nullptr, /* tp_repr */
649     nullptr, /* tp_as_number */
650     nullptr, /* tp_as_sequence */
651     &THPStorage_mappingmethods, /* tp_as_mapping */
652     nullptr, /* tp_hash  */
653     nullptr, /* tp_call */
654     nullptr, /* tp_str */
655     nullptr, /* tp_getattro */
656     nullptr, /* tp_setattro */
657     nullptr, /* tp_as_buffer */
658     // NOLINTNEXTLINE(misc-redundant-expression)
659     Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
660     nullptr, /* tp_doc */
661     nullptr, /* tp_traverse */
662     nullptr, /* tp_clear */
663     nullptr, /* tp_richcompare */
664     0, /* tp_weaklistoffset */
665     nullptr, /* tp_iter */
666     nullptr, /* tp_iternext */
667     nullptr,
668     /* will be assigned in init */ /* tp_methods */
669     nullptr,
670     /* will be assigned in init */ /* tp_members */
671     nullptr, /* tp_getset */
672     nullptr, /* tp_base */
673     nullptr, /* tp_dict */
674     nullptr, /* tp_descr_get */
675     nullptr, /* tp_descr_set */
676     0, /* tp_dictoffset */
677     nullptr, /* tp_init */
678     nullptr, /* tp_alloc */
679     THPStorage_pynew, /* tp_new */
680 };
681 
THPStorageMetaType_init(PyObject * cls,PyObject * args,PyObject * kwargs)682 int THPStorageMetaType_init(PyObject* cls, PyObject* args, PyObject* kwargs) {
683   if (PyType_Type.tp_init(cls, args, kwargs) < 0) {
684     return -1;
685   }
686   ((PyTypeObject*)cls)->tp_dealloc = (destructor)THPStorage_subclass_dealloc;
687   return 0;
688 }
689 
THPStorage_device(THPStorage * self,void * unused)690 static PyObject* THPStorage_device(THPStorage* self, void* unused) {
691   HANDLE_TH_ERRORS
692   THPStorage_assertNotNull(self);
693   return THPDevice_New(THPStorage_Unpack(self).device());
694   END_HANDLE_TH_ERRORS
695 }
696 
THPStorage_get_cdata(THPStorage * self,void * unused)697 PyObject* THPStorage_get_cdata(THPStorage* self, void* unused) {
698   HANDLE_TH_ERRORS
699   return PyLong_FromVoidPtr(THPStorage_Unpack(self).unsafeGetStorageImpl());
700   END_HANDLE_TH_ERRORS
701 }
702 
703 typedef PyObject* (*getter)(PyObject*, void*);
704 
705 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
706 static struct PyGetSetDef THPStorage_properties[] = {
707     {"device", (getter)THPStorage_device, nullptr, nullptr, nullptr},
708     {"_cdata", (getter)THPStorage_get_cdata, nullptr, nullptr, nullptr},
709     {nullptr}};
710 
THPStorage_init(PyObject * module)711 bool THPStorage_init(PyObject* module) {
712   static std::vector<PyMethodDef> methods;
713   THPUtils_addPyMethodDefs(methods, THPStorage_getMethods());
714   THPUtils_addPyMethodDefs(methods, THPStorage_getSharingMethods());
715 
716   THPStorageMetaType.tp_base = &PyType_Type;
717   if (PyType_Ready(&THPStorageMetaType) < 0)
718     return false;
719   Py_INCREF(&THPStorageMetaType);
720   PyModule_AddObject(module, "_StorageMeta", (PyObject*)&THPStorageMetaType);
721 
722   THPStorageType.tp_methods = methods.data();
723   THPStorageType.tp_getset = THPStorage_properties;
724   if (PyType_Ready(&THPStorageType) < 0)
725     return false;
726   Py_INCREF(&THPStorageType);
727   PyModule_AddObject(module, "StorageBase", (PyObject*)&THPStorageType);
728   return true;
729 }
730 
THPStorage_postInit(PyObject * module)731 void THPStorage_postInit(PyObject* module) {
732   THPStorageClass =
733       (PyTypeObject*)PyObject_GetAttrString(module, "UntypedStorage");
734   if (!THPStorageClass)
735     throw python_error();
736 }
737 
THPStorage_assertNotNull(THPStorage * storage)738 void THPStorage_assertNotNull(THPStorage* storage) {
739   TORCH_CHECK(
740       THPStorage_Unpack(storage).unsafeGetStorageImpl(), "Got a null Storage");
741 }
742 
THPStorage_assertNotNull(PyObject * obj)743 void THPStorage_assertNotNull(PyObject* obj) {
744   THPStorage_assertNotNull((THPStorage*)obj);
745 }
746