#pragma once

#include <c10/core/Device.h>
#include <c10/core/Layout.h>
#include <c10/core/MemoryFormat.h>
#include <c10/core/ScalarType.h>
#include <c10/core/ScalarTypeToTypeMeta.h>
#include <c10/core/Storage.h>
#include <c10/core/SymIntArrayRef.h>
#include <c10/core/TensorImpl.h>
#include <c10/core/TensorOptions.h>
#include <c10/core/UndefinedTensorImpl.h>
#include <c10/core/WrapDimMinimal.h>
#include <c10/util/C++17.h>
#include <c10/util/Exception.h>
#include <c10/util/ExclusivelyOwned.h>
#include <c10/util/ExclusivelyOwnedTensorTraits.h>
#include <c10/util/MaybeOwned.h>
#include <optional>
#include <c10/util/intrusive_ptr.h>

#include <ATen/core/NamedTensor.h>
#include <ATen/core/QuantizerBase.h>
#include <ATen/core/TensorAccessor.h>
#include <ATen/StorageUtils.h>

namespace c10 {
class Scalar;
}

namespace torch::autograd {

struct Node;

} // namespace torch::autograd

namespace at {

class Tensor;
class TensorBase;

// Convert Tensor to TensorBase without any need to include Tensor.h
TORCH_API const TensorBase& get_tensor_base(const Tensor& t);

namespace impl {
inline bool variable_excluded_from_dispatch() {
#ifdef C10_MOBILE
  // Please read the comment in `VariableFallbackKernel.cpp` about the background of this change.
  return true;
#else
  return c10::impl::tls_local_dispatch_key_set().excluded_.isSupersetOf(c10::autograd_dispatch_keyset);
#endif
}

}

// NOTE: [Tensor vs. TensorBase]
//
// Tensor, being the central data structure in PyTorch, gets used and
// its header included almost everywhere. Unfortunately this means
// every time an operator signature is updated or changed in
// native_functions.yaml, you (and every other PyTorch developer) need
// to recompile all of ATen and its dependencies.
//
// TensorBase aims to break up these header dependencies, and improve
// incremental build times for all PyTorch developers. TensorBase
// represents a reference counted handle to TensorImpl, exactly the
// same as Tensor. However, TensorBase doesn't have code generated
// methods in its API and thus no dependence on native_functions.yaml.
//
// Usage tips
// ----------
// - You can `#define TORCH_ASSERT_NO_OPERATORS` at the top of a .cpp
//   or .cu file to ensure it has no header dependencies on
//   native_functions.yaml (direct or indirect); see the sketch below.
// - Tensor inherits from TensorBase, so functions taking
//   `const TensorBase &` are callable with Tensor as well.
// - TensorBase can be converted to Tensor with `Tensor(tensor_base)`,
//   but this requires a reference-count bump. OptionalTensorRef on
//   the other hand can materialize a `const Tensor &` without
//   touching the reference-count.
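//
// A minimal sketch (not part of this header) of that pattern: a translation
// unit that only needs tensor metadata can opt out of the generated-operator
// headers entirely and still accept Tensor arguments through
// `const TensorBase &`. The helper name `same_shape` is purely illustrative.
//
//   #define TORCH_ASSERT_NO_OPERATORS
//   #include <ATen/core/TensorBase.h>
//
//   static bool same_shape(const at::TensorBase& a, const at::TensorBase& b) {
//     return a.sizes().equals(b.sizes());  // metadata only, no codegen needed
//   }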
class TORCH_API TensorBase {
 public:
  struct unsafe_borrow_t { explicit unsafe_borrow_t() = default; };

 protected:
  // Create a Tensor with a +0 reference count. Special care must be
  // taken to avoid decrementing this reference count at destruction
  // time. Intended to support MaybeOwnedTraits<Tensor>.
  explicit TensorBase(unsafe_borrow_t, const TensorBase& rhs)
      : impl_(c10::intrusive_ptr<at::TensorImpl, UndefinedTensorImpl>::reclaim(rhs.impl_.get())) {}
  friend MaybeOwnedTraits<TensorBase>;

 public:
  TensorBase() = default;
  // This constructor should not be used by end users and is an implementation
  // detail invoked by autogenerated code.
  explicit TensorBase(
      c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl)
      : impl_(std::move(tensor_impl)) {
    if (impl_.get() == nullptr) {
      throw std::runtime_error("TensorImpl with nullptr is not supported");
    }
  }
  TensorBase(const TensorBase&) = default;
  TensorBase(TensorBase&&) noexcept = default;

 public:
  // Creates a new wrapper from TensorImpl. Intentionally a free method because
  // it should be used with care. Checks necessary invariants
  static TensorBase wrap_tensor_impl(
      c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl) {
    TensorBase r(std::move(tensor_impl));
    r.enforce_invariants();
    return r;
  }

  int64_t dim() const {
    return impl_->dim();
  }
  int64_t storage_offset() const {
    return impl_->storage_offset();
  }

  TensorBase contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) const {
    if (is_contiguous(memory_format)) {
      return *this;
    } else {
      return __dispatch_contiguous(memory_format);
    }
  }

  /// Should be used if *this can reasonably be expected to be contiguous and
  /// performance is important.
  /// Compared to contiguous, it saves a reference count
  /// increment/decrement if *this is already contiguous, at the cost
  /// in all cases of an extra pointer of stack usage, an extra branch
  /// to access, and an extra branch at destruction time.
  c10::MaybeOwned<TensorBase> expect_contiguous(
      MemoryFormat memory_format=MemoryFormat::Contiguous) const &;

  // Use .contiguous() instead. Trying to borrow from a prvalue
  // will only lead to trouble and dangling references.
  c10::MaybeOwned<TensorBase> expect_contiguous(
      MemoryFormat memory_format=MemoryFormat::Contiguous) && = delete;
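
  // A minimal usage sketch (assumed caller code, not an API defined here):
  // when the input is usually already contiguous, expect_contiguous() hands
  // back a borrowed reference instead of bumping the refcount.
  //
  //   void process(const at::TensorBase& t) {
  //     c10::MaybeOwned<at::TensorBase> c = t.expect_contiguous();
  //     const void* data = c->const_data_ptr();  // c behaves like a pointer to TensorBase
  //     // ... use data ...
  //   }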

  const TensorBase& fill_(const c10::Scalar& scalar) const;
  const TensorBase& zero_() const;

  TensorBase to(at::TensorOptions options={}, bool non_blocking=false, bool copy=false, std::optional<at::MemoryFormat> memory_format=std::nullopt) const;

  bool is_complex() const {
    return at::isComplexType(this->scalar_type());
  }

  bool is_floating_point() const {
    return at::isFloatingType(this->scalar_type());
  }

  bool is_signed() const {
    return at::isSignedType(this->scalar_type());
  }

  c10::SymInt sym_size(int64_t dim) const {
    return impl_->sym_size(dim);
  }

  c10::SymInt sym_stride(int64_t dim) const {
    const auto strides = this->sym_strides();
    const auto ndim = static_cast<int64_t>(strides.size());
    // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping)
    return strides[c10::maybe_wrap_dim(dim, ndim, /*wrap_scalar=*/false)];
  }

  int64_t size(int64_t dim) const {
    return impl_->size(dim);
  }

  int64_t stride(int64_t dim) const {
    const auto strides = this->strides();
    const auto ndim = static_cast<int64_t>(strides.size());
    // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping)
    return strides[c10::maybe_wrap_dim(dim, ndim, /*wrap_scalar=*/false)];
  }
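
  // Illustrative sketch of the wrapping behavior above: size()/stride()
  // accept negative dims and wrap them, matching Python indexing semantics.
  //
  //   // For a tensor of shape {2, 3, 4} with default contiguous strides:
  //   //   t.size(-1)   == 4    // same as t.size(2)
  //   //   t.stride(-1) == 1
  //   //   t.stride(0)  == 12   // 3 * 4 elements per step along dim 0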
186 
unsafeGetTensorImpl()187   TensorImpl * unsafeGetTensorImpl() const {
188     return impl_.get();
189   }
unsafeReleaseTensorImpl()190   TensorImpl * unsafeReleaseTensorImpl() {
191     return impl_.release();
192   }
getIntrusivePtr()193   const c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>& getIntrusivePtr() const {
194     return impl_;
195   }
196 
unsafeReleaseIntrusivePtr()197   c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> unsafeReleaseIntrusivePtr() {
198     return std::move(impl_);
199   }
200 
defined()201   bool defined() const {
202     return impl_;
203   }
204 
reset()205   void reset() {
206     impl_.reset();
207   }
208 
209 #if defined (_MSC_VER)
210   TensorBase& operator=(const TensorBase& x) & {
211     impl_ = x.impl_;
212     return *this;
213   };
214   TensorBase& operator=(TensorBase&& x) & noexcept {
215     impl_ = std::move(x.impl_);
216     return *this;
217   }
218 #else
219   TensorBase& operator=(const TensorBase& x) & = default;
220   TensorBase& operator=(TensorBase&& x) & noexcept = default;
221 #endif
222 
223   // Ban assignment to rvalues, since at::Tensor (weirdly) performs a deep copy here
224   TensorBase& operator=(const TensorBase&) && = delete;
225   TensorBase& operator=(TensorBase&&) && noexcept = delete;
226 
is_same(const TensorBase & other)227   bool is_same(const TensorBase& other) const noexcept {
228     return impl_ == other.impl_;
229   }
use_count()230   size_t use_count() const noexcept {
231     return impl_.use_count();
232   }
weak_use_count()233   size_t weak_use_count() const noexcept {
234     return impl_.weak_use_count();
235   }
236 
237   std::string toString() const;
238 
sizes()239   IntArrayRef sizes() const {
240     return impl_->sizes();
241   }
sym_sizes()242   c10::SymIntArrayRef sym_sizes() const {
243     return impl_->sym_sizes();
244   }
sym_strides()245   c10::SymIntArrayRef sym_strides() const {
246     return impl_->sym_strides();
247   }
strides()248   IntArrayRef strides() const {
249     return impl_->strides();
250   }
251   // See impl::get_opt_names in ATen/NamedTensor.h for docs.
opt_names()252   std::optional<DimnameList> opt_names() const {
253     return impl::get_opt_names(unsafeGetTensorImpl());
254   }
255   // See impl::get_names in ATen/NamedTensor.h for docs.
names()256   DimnameList names() const {
257     return impl::get_names(unsafeGetTensorImpl());
258   }
ndimension()259   int64_t ndimension() const {
260     return dim();
261   }
262 
263   bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const {
264     return impl_->is_contiguous(memory_format);
265   }
266 
is_non_overlapping_and_dense()267   bool is_non_overlapping_and_dense() const {
268     return impl_->is_non_overlapping_and_dense();
269   }

  at::MemoryFormat suggest_memory_format(
      bool channels_last_strides_exact_match = false) const {
    // Setting channels_last_strides_exact_match to true forces the function to
    // also check that the strides of dimensions of size 0 or 1 match exactly.
    if (layout() == at::kStrided) {
      if (impl_->is_strides_like_channels_last()) {
        if (!channels_last_strides_exact_match ||
            get_channels_last_strides_2d(sizes()) == strides()) {
          return at::MemoryFormat::ChannelsLast;
        }
      }
      else if (impl_->is_strides_like_channels_last_3d()) {
        if (!channels_last_strides_exact_match ||
            get_channels_last_strides_3d(sizes()) == strides()) {
          return at::MemoryFormat::ChannelsLast3d;
        }
      }
    }
    return at::MemoryFormat::Contiguous;
  }
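
  // Usage sketch (assumed caller code): preserving the input's layout when
  // producing an output of the same shape. `make_output_like` is a
  // hypothetical helper, not an API defined here.
  //
  //   at::Tensor make_output_like(const at::Tensor& input) {
  //     return at::empty(input.sizes(),
  //                      input.options(),
  //                      input.suggest_memory_format());
  //   }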

  // Total bytes consumed by the "view" of elements of the array.  Does not
  // include size of metadata.  The number reported here does not necessarily
  // correspond to the true physical memory consumed by a tensor; instead,
  // it reports the memory the tensor would take *if* it were contiguous.
  // Defined to be numel() * itemsize()
  size_t nbytes() const {
    TORCH_CHECK(layout() != at::kSparse,
                "nbytes is not defined for sparse tensors.  If you want the size of the constituent " \
                "tensors, add the nbytes of the indices and values.  If you want the size of the " \
                "equivalent dense tensor, multiply numel() by element_size()");
    return impl_->numel() * impl_->itemsize();
  }

  c10::SymInt sym_nbytes() const {
    TORCH_CHECK(layout() != at::kSparse,
                "nbytes is not defined for sparse tensors.  If you want the size of the constituent " \
                "tensors, add the nbytes of the indices and values.  If you want the size of the " \
                "equivalent dense tensor, multiply numel() by element_size()");
    return impl_->sym_numel() * impl_->itemsize();
  }

  int64_t numel() const {
    return impl_->numel();
  }

  c10::SymInt sym_numel() const {
    return impl_->sym_numel();
  }

  c10::SymInt sym_storage_offset() const {
    return impl_->sym_storage_offset();
  }

  // Length of one array element in bytes.  This is the traditional
  // Numpy naming.
  size_t itemsize() const {
    return impl_->itemsize();
  }

  // Same as itemsize().  This is the PyTorch naming.
  int64_t element_size() const {
    return static_cast<int64_t>(impl_->itemsize());
  }

  DispatchKeySet key_set() const {
    return impl_->key_set();
  }
  ScalarType scalar_type() const {
    return typeMetaToScalarType(impl_->dtype());
  }
  bool has_storage() const {
    return defined() && impl_->has_storage();
  }
  const Storage& storage() const {
    return impl_->storage();
  }
  bool is_alias_of(const at::TensorBase& other) const {
    return impl_->storage().is_alias_of(other.storage());
  }

  // Moves the underlying storage to a shared-memory (shm) backend
  // to enable memory sharing across processes.
  //
  // NB1: the ideal behavior of this API still requires further discussion
  // but for now we are inclined to keep it consistent with existing THP behavior
  // https://github.com/pytorch/pytorch/blob/4dca9bde0552afc67b5b74f4a0696fe6055709c4/torch/storage.py#L196-L212
  // so we don't assert on anything here and rely on the caller knowing
  // what it's doing.
  //
  // NB2: this currently provides Linux fd based shm support only
  // to simplify the storage lifetime management logic in ATen
  // and similarly for now we are not adding support for file system based
  // shm support like in THP due to additional GC manager support needed
  // to prevent leaks.
  // As such, calling this from unsupported systems (e.g. Windows) will fail.
  void share_memory_() {
    at::share_memory_(*this);
  }

  inline bool _is_zerotensor() const {
    return impl_->_is_zerotensor();
  }

  inline void _set_zero(bool zero) const {
    impl_->_set_zero(zero);
  }

  inline bool is_conj() const {
    return impl_->is_conj();
  }

  // sets the conjugate bit of a tensor.
  // NOTE: Conjugate bit is supposed to be a read-only field. Only change this if you are sure
  // that's what you want. Changing this might lead to incorrect behavior since conjugation is
  // a lazy operation and we rely on this bit to determine if a conjugation needs to be materialized.
  inline void _set_conj(bool conjugate) const {
    impl_->_set_conj(conjugate);
  }

  inline bool is_neg() const {
    return impl_->is_neg();
  }

  // sets the negative bit of a tensor.
  // NOTE: Negative bit is supposed to be a read-only field. Only change this if you are sure
  // that's what you want. Changing this might lead to incorrect behavior since we rely on this
  // bit to determine if a negation needs to be materialized.
  inline void _set_neg(bool negative) const {
    impl_->_set_neg(negative);
  }

  /// Returns a `Tensor`'s layout.
  Layout layout() const {
    return impl_->layout();
  }

  /// Returns a `Tensor`'s dtype (`TypeMeta`).
  caffe2::TypeMeta dtype() const {
    return impl_->dtype();
  }

  /// Returns a `Tensor`'s device.
  inline Device device() const {
    return impl_->device();
  }

  /// Returns a `Tensor`'s device index.
  DeviceIndex get_device() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->get_device();
  }

  /// Returns if a `Tensor` has CPU backend.
  bool is_cpu() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_cpu();
  }

  /// Returns if a `Tensor` has CUDA backend.
  bool is_cuda() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_cuda();
  }

  /// Returns if a `Tensor` has IPU backend.
  bool is_ipu() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_ipu();
  }

  /// Returns if a `Tensor` has XPU backend.
  bool is_xpu() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_xpu();
  }

  /// Returns if a `Tensor` has XLA backend.
  bool is_xla() const {
    return impl_->is_xla();
  }

  /// Returns if a `Tensor` has MTIA backend.
  bool is_mtia() const {
    return impl_->is_mtia();
  }

  /// Returns if a `Tensor` has HPU backend.
  bool is_hpu() const {
    return impl_->is_hpu();
  }

  /// Returns if a `Tensor` has Lazy backend.
  bool is_lazy() const {
    return impl_->is_lazy();
  }

  /// Returns if a `Tensor` has HIP backend.
  bool is_hip() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_hip();
  }

  /// Returns if a `Tensor` has VE backend.
  bool is_ve() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_ve();
  }

  /// Returns if a `Tensor` has PrivateUse1 backend.
  bool is_privateuseone() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_privateuseone();
  }

  /// Returns if a `Tensor` has sparse backend.
  bool is_sparse() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_sparse();
  }

  /// Returns if a `Tensor` has a sparse CSR backend.
  bool is_sparse_csr() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_sparse_csr();
  }

  /// Returns if a `Tensor` is an mkldnn tensor.
  bool is_mkldnn() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_mkldnn();
  }

  /// Returns if a `Tensor` is an mps tensor.
  bool is_mps() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_mps();
  }

  /// Returns if a `Tensor` is a maia tensor.
  bool is_maia() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_maia();
  }

  /// Returns if a `Tensor` is a vulkan tensor.
  bool is_vulkan() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_vulkan();
  }

  /// Returns if a `Tensor` is a metal tensor.
  bool is_metal() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_metal();
  }

  /// Returns if a `Tensor` has quantized backend.
  bool is_quantized() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_quantized();
  }

  /// Returns if a `Tensor` is a meta tensor.  Meta tensors can
  /// also have other designations.
  bool is_meta() const {
    return impl_->is_meta();
  }

  /// Returns if a `Tensor` is an inference tensor.
  bool is_inference() const {
    return impl_->is_inference();
  }

  // Returns if a `Tensor` is a NestedTensor.
  bool is_nested() const {
    return impl_->is_nested();
  }

  /// If a tensor is a quantized tensor, returns its quantizer
  /// TODO: it's not in native_functions.yaml yet as it's not exposed to python
  QuantizerPtr quantizer() const;

  /// Returns if a `Tensor` has any dimension names
  bool has_names() const {
    // If a user is using unnamed tensors, then we can short-circuit right here.
    // Otherwise, impl::has_names attempts to retrieve names.
    if (!impl_->has_named_tensor_meta()) {
      return false;
    }
    return impl::has_names(unsafeGetTensorImpl());
  }

  /// Returns a `Tensor`'s dimension names data structure
  const NamedTensorMeta* get_named_tensor_meta() const {
    return static_cast<NamedTensorMeta*>(impl_->named_tensor_meta());
  }

  NamedTensorMeta* get_named_tensor_meta() {
    return static_cast<NamedTensorMeta*>(impl_->named_tensor_meta());
  }

  /// Returns the `TensorOptions` corresponding to this `Tensor`. Defined in
  /// TensorOptions.h.
  TensorOptions options() const {
    return TensorOptions().dtype(dtype())
                          .device(device())
                          .layout(layout());
  }
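
  // Usage sketch (assumed caller code): options() is the idiomatic way to
  // create a new tensor matching this one's dtype/device/layout, e.g.
  //
  //   at::Tensor buffer = at::empty({n}, self.options());
  //
  // where `self` and `n` stand in for the caller's tensor and size.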

  const void* const_data_ptr() const {
    return this->unsafeGetTensorImpl()->data();
  }

  void* mutable_data_ptr() const {
    return this->unsafeGetTensorImpl()->mutable_data();
  }

  // TODO(#97856) Make this return a const pointer. This currently
  //              returns a non-const pointer because of the large
  //              number of clients that we still want to audit before
  //              migrating to mutable_data_ptr().
  void* data_ptr() const {
    return mutable_data_ptr();
  }

  template <typename T, std::enable_if_t<!std::is_const_v<T>, int> = 0>
  const T* const_data_ptr() const;

  template <typename T, std::enable_if_t<std::is_const_v<T>, int> = 0>
  const std::remove_const_t<T>* const_data_ptr() const;

  template <typename T>
  T* mutable_data_ptr() const;

  // Legacy interface during the migration to indicate that a callsite
  // has not been audited for mutability.
  //
  // Do not add new uses of this, use const_data_ptr() if possible,
  // mutable_data_ptr() otherwise.
  //
  // TODO(#97856) Make this return a const pointer. This is currently
  //              const because of the vast number of clients that
  //              rely on this.
  template <typename T>
  T* data_ptr() const;
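
  // Minimal sketch (assumed caller code, tensor assumed to hold float data)
  // of the preferred accessors: use const_data_ptr<T>() for read-only access
  // and mutable_data_ptr<T>() only when the data is actually written, so
  // mutations stay auditable.
  //
  //   float read_first(const at::TensorBase& t) {
  //     return t.const_data_ptr<float>()[0];   // read-only view of the data
  //   }
  //   void write_first(const at::TensorBase& t, float v) {
  //     t.mutable_data_ptr<float>()[0] = v;    // declares intent to mutate
  //   }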

  // Purposely not defined here to avoid inlining
  void print() const;

  // Return a `TensorAccessor` for CPU `Tensor`s. You have to specify scalar type and
  // dimension.
  template<typename T, size_t N>
  TensorAccessor<T,N> accessor() const& {
    static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data_ptr<T>()");
    TORCH_CHECK(dim() == N, "TensorAccessor expected ", N, " dims but tensor has ", dim());
    T* ptr = nullptr;
    if constexpr (std::is_const<T>::value) {
      ptr = const_data_ptr<T>();
    } else {
      ptr = mutable_data_ptr<T>();
    }
    return TensorAccessor<T,N>(ptr,sizes().data(),strides().data());
  }
  template<typename T, size_t N>
  TensorAccessor<T,N> accessor() && = delete;
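
  // Usage sketch (assumed caller code, assumes a 2-D float CPU tensor):
  // element-wise traversal through a TensorAccessor avoids per-element
  // dispatch overhead.
  //
  //   float sum2d(const at::TensorBase& t) {
  //     auto a = t.accessor<const float, 2>();  // checks dim() == 2
  //     float s = 0;
  //     for (int64_t i = 0; i < a.size(0); ++i)
  //       for (int64_t j = 0; j < a.size(1); ++j)
  //         s += a[i][j];
  //     return s;
  //   }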

  // Return a `GenericPackedTensorAccessor` for CUDA `Tensor`s. You have to specify scalar type and
  // dimension. You can optionally specify RestrictPtrTraits as a template parameter to
  // cast the data pointer to a __restrict__ pointer.
  // In order to use this, your CUDA kernel has to take a corresponding GenericPackedTensorAccessor
  // as an argument.
  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
  GenericPackedTensorAccessor<T,N,PtrTraits,index_t> generic_packed_accessor() const& {
    static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data_ptr<T>()");
    TORCH_CHECK(dim() == N, "TensorAccessor expected ", N, " dims but tensor has ", dim());
    T* ptr = nullptr;
    if constexpr (std::is_const<T>::value) {
      ptr = const_data_ptr<T>();
    } else {
      ptr = mutable_data_ptr<T>();
    }
    return GenericPackedTensorAccessor<T,N,PtrTraits,index_t>(static_cast<typename PtrTraits<T>::PtrType>(ptr),sizes().data(),strides().data());
  }
  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
  GenericPackedTensorAccessor<T,N> generic_packed_accessor() && = delete;

  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
  PackedTensorAccessor32<T,N,PtrTraits> packed_accessor32() const& {
    TORCH_CHECK(
        impl_->numel() <=
            static_cast<int64_t>(std::numeric_limits<int32_t>::max()),
        "numel needs to be smaller than int32_t max; otherwise, please use packed_accessor64");
    return generic_packed_accessor<T,N,PtrTraits,int32_t>();
  }
  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
  PackedTensorAccessor32<T,N,PtrTraits> packed_accessor32() && = delete;

  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
  PackedTensorAccessor64<T,N,PtrTraits> packed_accessor64() const& {
    return generic_packed_accessor<T,N,PtrTraits,int64_t>();
  }
  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
  PackedTensorAccessor64<T,N,PtrTraits> packed_accessor64() && = delete;
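
  // Usage sketch (assumed CUDA caller code): the kernel takes the packed
  // accessor by value, so sizes/strides travel with the pointer to the device.
  // `scale_kernel` is a hypothetical kernel, not something defined here.
  //
  //   __global__ void scale_kernel(
  //       at::PackedTensorAccessor32<float, 2, at::RestrictPtrTraits> a,
  //       float alpha) {
  //     int i = blockIdx.x, j = threadIdx.x;
  //     if (i < a.size(0) && j < a.size(1)) a[i][j] *= alpha;
  //   }
  //
  //   // Host side (assumes a 2-D float CUDA tensor t):
  //   //   scale_kernel<<<t.size(0), t.size(1)>>>(
  //   //       t.packed_accessor32<float, 2, at::RestrictPtrTraits>(), 2.0f);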

  // ~~~~~ Autograd API ~~~~~

  /// \fn bool is_leaf() const;
  ///
  /// All Tensors that have `requires_grad()` which is ``false`` will be leaf Tensors by convention.
  ///
  /// For Tensors that have `requires_grad()` which is ``true``, they will be leaf Tensors if they were
  /// created by the user. This means that they are not the result of an operation and so
  /// `grad_fn()` is `nullptr`.
  ///
  /// Only leaf Tensors will have their `grad()` populated during a call to `backward()`.
  /// To get `grad()` populated for non-leaf Tensors, you can use `retain_grad()`.
  ///
  /// Example:
  /// @code
  /// auto a = torch::rand(10, torch::requires_grad());
  /// std::cout << a.is_leaf() << std::endl; // prints `true`
  ///
  /// auto b = torch::rand(10, torch::requires_grad()).to(torch::kCUDA);
  /// std::cout << b.is_leaf() << std::endl; // prints `false`
  /// // b was created by the operation that cast a cpu Tensor into a cuda Tensor
  ///
  /// auto c = torch::rand(10, torch::requires_grad()) + 2;
  /// std::cout << c.is_leaf() << std::endl; // prints `false`
  /// // c was created by the addition operation
  ///
  /// auto d = torch::rand(10).cuda();
  /// std::cout << d.is_leaf() << std::endl; // prints `true`
  /// // d does not require gradients and so has no operation creating it (that is tracked by the autograd engine)
  ///
  /// auto e = torch::rand(10).cuda().requires_grad_();
  /// std::cout << e.is_leaf() << std::endl; // prints `true`
  /// // e requires gradients and has no operations creating it
  ///
  /// auto f = torch::rand(10, torch::device(torch::kCUDA).requires_grad(true));
  /// std::cout << f.is_leaf() << std::endl; // prints `true`
  /// // f requires grad, has no operation creating it
  /// @endcode

  /// \fn void backward(const Tensor & gradient={}, std::optional<bool> retain_graph=std::nullopt, bool create_graph=false, std::optional<TensorList> inputs=std::nullopt) const;
  ///
  /// Computes the gradient of the current tensor with respect to graph leaves.
  ///
  /// The graph is differentiated using the chain rule. If the tensor is
  /// non-scalar (i.e. its data has more than one element) and requires
  /// gradient, the function additionally requires specifying ``gradient``.
  /// It should be a tensor of matching type and location, that contains
  /// the gradient of the differentiated function w.r.t. this Tensor.
  ///
  /// This function accumulates gradients in the leaves - you might need to
  /// zero them before calling it.
  ///
  /// \param gradient Gradient w.r.t. the
  ///     tensor. If it is a tensor, it will be automatically converted
  ///     to a Tensor that does not require grad unless ``create_graph`` is True.
  ///     None values can be specified for scalar Tensors or ones that
  ///     don't require grad. If a None value would be acceptable then
  ///     this argument is optional.
  /// \param retain_graph If ``false``, the graph used to compute
  ///     the grads will be freed. Note that in nearly all cases setting
  ///     this option to True is not needed and often can be worked around
  ///     in a much more efficient way. Defaults to the value of
  ///     ``create_graph``.
  /// \param create_graph If ``true``, graph of the derivative will
  ///     be constructed, allowing computation of higher order derivative
  ///     products. Defaults to ``false``.
  /// \param inputs Inputs w.r.t. which the gradient will be accumulated into
  ///     ``at::Tensor::grad``. All other Tensors will be ignored. If not
  ///     provided, the gradient is accumulated into all the leaf Tensors
  ///     that were used to compute the current tensor.
  ///     When inputs are provided and a given input is not a leaf,
  ///     the current implementation will call its grad_fn (even though it is not strictly needed to get these gradients).
  ///     It is an implementation detail on which the user should not rely.
  ///     See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details.

  /// \fn Tensor detach() const;
  ///
  /// Returns a new Tensor, detached from the current graph.
  /// The result will never require gradient.

  /// \fn Tensor & detach_() const;
  ///
  /// Detaches the Tensor from the graph that created it, making it a leaf.
  /// Views cannot be detached in-place.

  /// \fn void retain_grad() const;
  ///
  /// Enables this Tensor to have its :attr:`grad` populated during
  /// :func:`backward`. This is a no-op for leaf tensors.

  /// \fn bool retains_grad() const;
  ///
  /// Is ``true`` if this Tensor is non-leaf and its :attr:`grad` is enabled to be
  /// populated during :func:`backward`, ``false`` otherwise.

  const TensorBase& set_requires_grad(bool requires_grad) const {
    impl_->set_requires_grad(requires_grad);
    return *this;
  }
  bool requires_grad() const {
    return impl_->requires_grad();
  }

  // The Forward AD API functions below are low level and are not to be used by end
  // users, who should use the API provided in torch/csrc/autograd.h

  /// This function returns the forward gradient for this Tensor at the given level.
  const Tensor& _fw_grad(uint64_t level) const {
    return impl_->_fw_grad(level, *this);
  }

  /// This function can be used to set the value of the forward grad.
  /// Note that the given new_grad might not be used directly if it has different
  /// metadata (size/stride/storage offset) compared to this Tensor. In that case,
  /// new_grad's content will be copied into a new Tensor.
  void _set_fw_grad(const TensorBase& new_grad, uint64_t level, bool is_inplace_op) const {
    impl_->_set_fw_grad(new_grad, *this, level, is_inplace_op);
  }

  /// NOTE: This is similar to the legacy `.data()` function on `Variable`, and is intended
  /// to be used from functions that need to access the `Variable`'s equivalent `Tensor`
  /// (i.e. `Tensor` that shares the same storage and tensor metadata with the `Variable`).
  ///
  /// One notable difference with the legacy `.data()` function is that changes to the
  /// returned `Tensor`'s tensor metadata (e.g. sizes / strides / storage / storage_offset)
  /// will not update the original `Variable`, due to the fact that this function
  /// shallow-copies the `Variable`'s underlying TensorImpl.
  at::TensorBase tensor_data() const;

  /// NOTE: `var.variable_data()` in C++ has the same semantics as `tensor.data`
  /// in Python, which creates a new `Variable` that shares the same storage and
  /// tensor metadata with the original `Variable`, but with a completely new
  /// autograd history.
  ///
  /// NOTE: If we change the tensor metadata (e.g. sizes / strides /
  /// storage / storage_offset) of a variable created from `var.variable_data()`, those
  /// changes will not update the original variable `var`. In `.variable_data()`, we set
  /// `allow_tensor_metadata_change_` to false to make such changes explicitly illegal,
  /// in order to prevent users from changing metadata of `var.variable_data()`
  /// and expecting the original variable `var` to also be updated.
  at::TensorBase variable_data() const;

  // Gradient Node and Edges
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

  /// Gets the gradient function of the `Variable`. If this is a leaf variable,
  /// the pointer returned will be null.
  ///
  /// For View Variables:
  /// Gets the up-to-date grad_fn. If the shared data or base was modified, we
  /// re-create the grad_fn to express the up-to-date view relationship between
  /// this and the base Variable.
  const std::shared_ptr<torch::autograd::Node>& grad_fn() const;

  // Hooks
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

  template <typename T>
  using hook_return_void_t = std::enable_if_t<std::is_void_v<typename std::invoke_result_t<T&, TensorBase>>, unsigned>;
  template <typename T>
  using hook_return_var_t = std::enable_if_t<std::is_same_v<typename std::invoke_result_t<T&, TensorBase>, TensorBase>, unsigned>;

  /// Registers a backward hook.
  ///
  /// The hook will be called every time a gradient with respect to the Tensor is computed.
  /// The hook should have one of the following signatures:
  /// ```
  /// hook(TensorBase grad) -> TensorBase
  /// ```
  /// ```
  /// hook(TensorBase grad) -> void
  /// ```
  /// The hook should not modify its argument, but it can optionally return a new gradient
  /// which will be used in place of `grad`.
  ///
  /// This function returns the index of the hook in the list, which can be used to remove the hook.
  ///
  /// Example:
  /// @code
  /// auto v = torch::tensor({0., 0., 0.}, torch::requires_grad());
  /// auto h = v.register_hook([](torch::Tensor grad){ return grad * 2; }); // double the gradient
  /// v.backward(torch::tensor({1., 2., 3.}));
  /// // This prints:
  /// // ```
  /// //  2
  /// //  4
  /// //  6
  /// // [ CPUFloatType{3} ]
  /// // ```
  /// std::cout << v.grad() << std::endl;
  /// v.remove_hook(h);  // removes the hook
  /// @endcode
  template <typename T>
  hook_return_void_t<T> register_hook(T&& hook) const;
  template <typename T>
  hook_return_var_t<T> register_hook(T&& hook) const;

protected:
  unsigned _register_hook(std::function<TensorBase(const TensorBase&)> hook) const;

public:

  /// Remove hook at given position
  void remove_hook(unsigned pos) const;

  // Variable methods
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

  bool is_leaf() const;

  int64_t output_nr() const;

  void set_data(const TensorBase & new_data) const;

  TensorBase data() const;

  int64_t _version() const;

  void retain_grad() const;

  bool retains_grad() const;

  const TensorBase& requires_grad_(bool _requires_grad=true) const;

  // View Variables
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

  /// Returns true if this `Variable` is a view of another `Variable`.
  bool is_view() const;

  /// Returns the `Variable` that this `Variable` is a view of. If this
  /// `Variable` is not a view, throws a `std::runtime_error`.
  const TensorBase& _base() const;

  // Miscellaneous
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

  const std::string& name() const;

protected:
  void enforce_invariants();
  c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> impl_;

private:
  TensorBase __dispatch_contiguous(c10::MemoryFormat) const;
};

inline DeviceIndex get_device(const TensorBase& self) {
  return self.get_device();
}

template <typename T>
// NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
auto TensorBase::register_hook(T&& hook) const -> TensorBase::hook_return_void_t<T> {
  // Return the grad argument in case of a hook with void return type to have an
  // std::function with Tensor return type
  static_assert(std::is_same_v<decltype(hook(TensorBase())), void>,
                "Expected hook to return void");
  return _register_hook([fn=std::forward<T>(hook)](const TensorBase& grad) {
    fn(grad);
    return TensorBase();
  });
}

template <typename T>
auto TensorBase::register_hook(T&& hook) const -> TensorBase::hook_return_var_t<T> {
  return _register_hook(std::forward<T>(hook));
}

namespace detail {
// Helper creator for the Tensor class which doesn't require the user to pass
// in an intrusive_ptr; instead it converts the arguments passed into the
// requested intrusive_ptr type.
template <typename T, typename... Args>
TensorBase make_tensor_base(Args&&... args) {
  return TensorBase(c10::make_intrusive<T>(std::forward<Args>(args)...));
}

} // namespace detail

inline DispatchKey legacyExtractDispatchKey(const TensorBase& t) {
  return legacyExtractDispatchKey(t.key_set());
}

} // namespace at

namespace c10 {
template <>
struct MaybeOwnedTraits<at::TensorBase> {
  using owned_type = at::TensorBase;
  using borrow_type = at::TensorBase;

  static borrow_type createBorrow(const owned_type& from) {
    // NOTE: this can be implemented without the special
    // unsafe_borrow_t Tensor constructor as
    //
    // return borrow_type(c10::intrusive_ptr<at::TensorImpl, at::UndefinedTensorImpl>::reclaim(from.unsafeGetTensorImpl()));
    //
    // but that hurts inlining due to the nullptr check in the
    // Tensor(c10::intrusive_ptr<...>) constructor. We already know
    // that from.impl_ isn't null because from is a valid Tensor, so
    // we needn't do the check again. (using __builtin_assume can
    // avoid this, but wouldn't be portable to MSVC.)
    return borrow_type(borrow_type::unsafe_borrow_t{}, from);
  }

  static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) {
    lhs.unsafeReleaseTensorImpl();
    // See above note: this can be implemented with public API
    // similarly to createBorrow(), but that would hurt inlining.
    lhs = borrow_type(borrow_type::unsafe_borrow_t{}, rhs);
  }

  static void destroyBorrow(borrow_type& toDestroy) {
    toDestroy.unsafeReleaseTensorImpl(); // "leak" it, but it was already +0.
  }

  static const owned_type& referenceFromBorrow(const borrow_type& borrow) {
    return borrow;
  }

  static const owned_type* pointerFromBorrow(const borrow_type& borrow) {
    return &borrow;
  }

  static bool debugBorrowIsValid(const borrow_type& /*borrow*/) {
    return true;
  }
};

template <>
struct ExclusivelyOwnedTraits<at::TensorBase> : public c10::ExclusivelyOwnedTensorTraits<at::TensorBase> {};
} // namespace c10

namespace at {

inline c10::MaybeOwned<TensorBase> borrow_from_optional_tensor(
    const std::optional<TensorBase>& opt) {
  return opt.has_value()
    ? c10::MaybeOwned<TensorBase>::borrowed(*opt)
    : c10::MaybeOwned<TensorBase>::owned(std::in_place);
}
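
// Usage sketch (assumed caller code): materialize a usable (possibly
// undefined) TensorBase from an optional argument without a refcount bump
// when a value is present. `maybe_accumulate` is a hypothetical helper.
//
//   void maybe_accumulate(const std::optional<at::TensorBase>& maybe_out) {
//     c10::MaybeOwned<at::TensorBase> out = at::borrow_from_optional_tensor(maybe_out);
//     if (out->defined()) {
//       // ... write into *out ...
//     }
//   }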

inline c10::MaybeOwned<TensorBase> TensorBase::expect_contiguous(MemoryFormat memory_format) const & {
  if (is_contiguous(memory_format)) {
    return c10::MaybeOwned<TensorBase>::borrowed(*this);
  } else {
    return c10::MaybeOwned<TensorBase>::owned(__dispatch_contiguous(memory_format));
  }
}

namespace symint {

template <typename T>
using enable_if_symint = std::enable_if_t<std::is_same_v<T, c10::SymInt>>;
template <typename T>
using enable_if_int = std::enable_if_t<std::is_same_v<T, int64_t>>;

template <typename T, typename = enable_if_symint<T>>
c10::SymIntArrayRef sizes(const TensorBase& t) { return t.sym_sizes(); }
template <typename T, typename = enable_if_int<T>>
IntArrayRef sizes(const TensorBase& t) { return t.sizes(); }

template <typename T, typename = enable_if_symint<T>>
c10::SymInt size(const TensorBase& t, int64_t dim) { return t.sym_size(dim); }
template <typename T, typename = enable_if_int<T>>
int64_t size(const TensorBase& t, int64_t dim) { return t.size(dim); }

template <typename T, typename = enable_if_symint<T>>
c10::SymIntArrayRef strides(const TensorBase& t) { return t.sym_strides(); }
template <typename T, typename = enable_if_int<T>>
IntArrayRef strides(const TensorBase& t) { return t.strides(); }

template <typename T, typename = enable_if_symint<T>>
c10::SymInt numel(const TensorBase& t) { return t.sym_numel(); }
template <typename T, typename = enable_if_int<T>>
int64_t numel(const TensorBase& t) { return t.numel(); }

} // namespace symint
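
// Usage sketch (assumed caller code): these overloads let templated code be
// written once and instantiated for either concrete or symbolic shapes.
// `flat_numel` is a hypothetical helper, not part of this header.
//
//   template <typename index_t>  // index_t is int64_t or c10::SymInt
//   index_t flat_numel(const at::TensorBase& t) {
//     return at::symint::numel<index_t>(t);
//   }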

} // namespace at