xref: /aosp_15_r20/external/pytorch/c10/util/ExclusivelyOwnedTensorTraits.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/core/TensorImpl.h>
4 #include <c10/core/UndefinedTensorImpl.h>
5 
6 #include <utility>
7 
8 namespace c10 {
9 // Shared ExclusivelyOwnedTraits implementation between caffe2::Tensor and
10 // at::TensorBase.
11 template <typename TensorType>
12 struct ExclusivelyOwnedTensorTraits {
13   using repr_type = TensorType;
14   using pointer_type = TensorType*;
15   using const_pointer_type = const TensorType*;
16 
nullReprExclusivelyOwnedTensorTraits17   static repr_type nullRepr() {
18     return TensorType();
19   }
20 
21   template <class... Args>
createInPlaceExclusivelyOwnedTensorTraits22   static repr_type createInPlace(Args&&... args) {
23     return TensorType(std::forward<Args>(args)...);
24   }
25 
moveToReprExclusivelyOwnedTensorTraits26   static repr_type moveToRepr(TensorType&& x) {
27     return std::move(x);
28   }
29 
destroyOwnedExclusivelyOwnedTensorTraits30   static void destroyOwned(TensorType& x) {
31     TensorImpl* const toDestroy = x.unsafeReleaseTensorImpl();
32     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
33         toDestroy != nullptr, "Tensor somehow got null TensorImpl?");
34     // May be 0 because UndefinedTensorImpl doesn't get its refcount
35     // incremented.
36     const bool isUndefined = toDestroy == UndefinedTensorImpl::singleton();
37     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
38         toDestroy->refcount_ == 1 || (toDestroy->refcount_ == 0 && isUndefined),
39         "ExclusivelyOwned<Tensor> destroyed with isUndefined ",
40         isUndefined,
41         " and refcount ",
42         toDestroy->refcount_,
43         ", expected 1 or, if isUndefined, 0!");
44     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
45         toDestroy->weakcount_ == 1 ||
46             (toDestroy->weakcount_ == 0 &&
47              toDestroy == UndefinedTensorImpl::singleton()),
48         "ExclusivelyOwned<Tensor> destroyed with isUndefined ",
49         isUndefined,
50         " and weakcount ",
51         toDestroy->weakcount_,
52         ", expected 1 or, if isUndefined, 0!");
53     if (!isUndefined) {
54 #ifndef NDEBUG
55       // Needed to pass the debug assertions in ~intrusive_ptr_target.
56       toDestroy->refcount_ = 0;
57       toDestroy->weakcount_ = 0;
58 #endif
59       delete toDestroy;
60     }
61   }
62 
takeExclusivelyOwnedTensorTraits63   static TensorType take(TensorType& x) {
64     return std::move(x);
65   }
66 
getImplExclusivelyOwnedTensorTraits67   static pointer_type getImpl(repr_type& x) {
68     return &x;
69   }
70 
getImplExclusivelyOwnedTensorTraits71   static const_pointer_type getImpl(const repr_type& x) {
72     return &x;
73   }
74 };
75 } // namespace c10
76