xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/Tensor.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/TensorBody.h>
4 #include <c10/util/Exception.h>
5 
6 namespace at {
7 class TORCH_API OptionalTensorRef {
8  public:
9   OptionalTensorRef() = default;
10 
~OptionalTensorRef()11   ~OptionalTensorRef() {
12     ref_.unsafeReleaseTensorImpl();
13   }
14 
OptionalTensorRef(const TensorBase & src)15   OptionalTensorRef(const TensorBase& src)
16       : ref_(Tensor::unsafe_borrow_t{}, src) {
17     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src.defined());
18   }
19 
OptionalTensorRef(const OptionalTensorRef & rhs)20   OptionalTensorRef(const OptionalTensorRef& rhs)
21       : ref_(Tensor::unsafe_borrow_t{}, rhs.ref_) {}
22 
23   OptionalTensorRef& operator=(OptionalTensorRef rhs) {
24     std::swap(ref_, rhs.ref_);
25     return *this;
26   }
27 
has_value()28   bool has_value() const {
29     return ref_.defined();
30   }
31 
getTensorRef()32   const Tensor& getTensorRef() const & {
33     return ref_;
34   }
35 
36   const Tensor& operator*() const & {
37     return ref_;
38   }
39 
40   const Tensor* operator->() const & {
41     return &ref_;
42   }
43 
44   operator bool() const {
45     return ref_.defined();
46   }
47 
48  private:
49   Tensor ref_;
50 };
51 
52 // Use to convert a TensorBase (that may be undefined) to an at::Tensor
53 // without bumping refcount.
54 class TORCH_API TensorRef {
55  public:
~TensorRef()56   ~TensorRef() {
57     ref_.unsafeReleaseTensorImpl();
58   }
59 
TensorRef(const TensorBase & src)60   TensorRef(const TensorBase& src)
61       : ref_(Tensor::unsafe_borrow_t{}, src) {}
62 
63   const Tensor& operator*() const & {
64     return ref_;
65   }
66  private:
67   Tensor ref_;
68 };
69 
70 template <typename T>
71 // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
72 auto Tensor::register_hook(T&& hook) const -> Tensor::hook_return_void_t<T> {
73   // Return the grad argument in case of a hook with void return type to have an
74   // std::function with Tensor return type
75   static_assert(std::is_same<decltype(hook(Tensor())), void>::value,
76                 "Expected hook to return void");
77   return _register_hook([fn=std::forward<T>(hook)](const TensorBase& grad_base) {
78     TensorRef grad(grad_base);
79     fn(*grad);
80     return Tensor();
81   });
82 }
83 
84 template <typename T>
85 // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
86 auto Tensor::register_hook(T&& hook) const -> Tensor::hook_return_var_t<T> {
87   return _register_hook([fn=std::forward<T>(hook)](const TensorBase& grad_base) {
88     TensorRef grad(grad_base);
89     Tensor ret = fn(*grad);
90     return TensorBase(std::move(ret));
91   });
92 }
93 
94 } // namespace at
95