1 #pragma once 2 3 #include <ATen/core/TensorBody.h> 4 #include <c10/util/Exception.h> 5 6 namespace at { 7 class TORCH_API OptionalTensorRef { 8 public: 9 OptionalTensorRef() = default; 10 ~OptionalTensorRef()11 ~OptionalTensorRef() { 12 ref_.unsafeReleaseTensorImpl(); 13 } 14 OptionalTensorRef(const TensorBase & src)15 OptionalTensorRef(const TensorBase& src) 16 : ref_(Tensor::unsafe_borrow_t{}, src) { 17 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src.defined()); 18 } 19 OptionalTensorRef(const OptionalTensorRef & rhs)20 OptionalTensorRef(const OptionalTensorRef& rhs) 21 : ref_(Tensor::unsafe_borrow_t{}, rhs.ref_) {} 22 23 OptionalTensorRef& operator=(OptionalTensorRef rhs) { 24 std::swap(ref_, rhs.ref_); 25 return *this; 26 } 27 has_value()28 bool has_value() const { 29 return ref_.defined(); 30 } 31 getTensorRef()32 const Tensor& getTensorRef() const & { 33 return ref_; 34 } 35 36 const Tensor& operator*() const & { 37 return ref_; 38 } 39 40 const Tensor* operator->() const & { 41 return &ref_; 42 } 43 44 operator bool() const { 45 return ref_.defined(); 46 } 47 48 private: 49 Tensor ref_; 50 }; 51 52 // Use to convert a TensorBase (that may be undefined) to an at::Tensor 53 // without bumping refcount. 54 class TORCH_API TensorRef { 55 public: ~TensorRef()56 ~TensorRef() { 57 ref_.unsafeReleaseTensorImpl(); 58 } 59 TensorRef(const TensorBase & src)60 TensorRef(const TensorBase& src) 61 : ref_(Tensor::unsafe_borrow_t{}, src) {} 62 63 const Tensor& operator*() const & { 64 return ref_; 65 } 66 private: 67 Tensor ref_; 68 }; 69 70 template <typename T> 71 // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) 72 auto Tensor::register_hook(T&& hook) const -> Tensor::hook_return_void_t<T> { 73 // Return the grad argument in case of a hook with void return type to have an 74 // std::function with Tensor return type 75 static_assert(std::is_same<decltype(hook(Tensor())), void>::value, 76 "Expected hook to return void"); 77 return _register_hook([fn=std::forward<T>(hook)](const TensorBase& grad_base) { 78 TensorRef grad(grad_base); 79 fn(*grad); 80 return Tensor(); 81 }); 82 } 83 84 template <typename T> 85 // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) 86 auto Tensor::register_hook(T&& hook) const -> Tensor::hook_return_var_t<T> { 87 return _register_hook([fn=std::forward<T>(hook)](const TensorBase& grad_base) { 88 TensorRef grad(grad_base); 89 Tensor ret = fn(*grad); 90 return TensorBase(std::move(ret)); 91 }); 92 } 93 94 } // namespace at 95