1 #pragma once 2 3 #include <c10/core/TensorImpl.h> 4 #include <c10/core/UndefinedTensorImpl.h> 5 6 #include <utility> 7 8 namespace c10 { 9 // Shared ExclusivelyOwnedTraits implementation between caffe2::Tensor and 10 // at::TensorBase. 11 template <typename TensorType> 12 struct ExclusivelyOwnedTensorTraits { 13 using repr_type = TensorType; 14 using pointer_type = TensorType*; 15 using const_pointer_type = const TensorType*; 16 nullReprExclusivelyOwnedTensorTraits17 static repr_type nullRepr() { 18 return TensorType(); 19 } 20 21 template <class... Args> createInPlaceExclusivelyOwnedTensorTraits22 static repr_type createInPlace(Args&&... args) { 23 return TensorType(std::forward<Args>(args)...); 24 } 25 moveToReprExclusivelyOwnedTensorTraits26 static repr_type moveToRepr(TensorType&& x) { 27 return std::move(x); 28 } 29 destroyOwnedExclusivelyOwnedTensorTraits30 static void destroyOwned(TensorType& x) { 31 TensorImpl* const toDestroy = x.unsafeReleaseTensorImpl(); 32 TORCH_INTERNAL_ASSERT_DEBUG_ONLY( 33 toDestroy != nullptr, "Tensor somehow got null TensorImpl?"); 34 // May be 0 because UndefinedTensorImpl doesn't get its refcount 35 // incremented. 36 const bool isUndefined = toDestroy == UndefinedTensorImpl::singleton(); 37 TORCH_INTERNAL_ASSERT_DEBUG_ONLY( 38 toDestroy->refcount_ == 1 || (toDestroy->refcount_ == 0 && isUndefined), 39 "ExclusivelyOwned<Tensor> destroyed with isUndefined ", 40 isUndefined, 41 " and refcount ", 42 toDestroy->refcount_, 43 ", expected 1 or, if isUndefined, 0!"); 44 TORCH_INTERNAL_ASSERT_DEBUG_ONLY( 45 toDestroy->weakcount_ == 1 || 46 (toDestroy->weakcount_ == 0 && 47 toDestroy == UndefinedTensorImpl::singleton()), 48 "ExclusivelyOwned<Tensor> destroyed with isUndefined ", 49 isUndefined, 50 " and weakcount ", 51 toDestroy->weakcount_, 52 ", expected 1 or, if isUndefined, 0!"); 53 if (!isUndefined) { 54 #ifndef NDEBUG 55 // Needed to pass the debug assertions in ~intrusive_ptr_target. 56 toDestroy->refcount_ = 0; 57 toDestroy->weakcount_ = 0; 58 #endif 59 delete toDestroy; 60 } 61 } 62 takeExclusivelyOwnedTensorTraits63 static TensorType take(TensorType& x) { 64 return std::move(x); 65 } 66 getImplExclusivelyOwnedTensorTraits67 static pointer_type getImpl(repr_type& x) { 68 return &x; 69 } 70 getImplExclusivelyOwnedTensorTraits71 static const_pointer_type getImpl(const repr_type& x) { 72 return &x; 73 } 74 }; 75 } // namespace c10 76