#pragma once

#include <c10/core/Device.h>
#include <c10/core/Layout.h>
#include <c10/core/MemoryFormat.h>
#include <c10/core/ScalarType.h>
#include <c10/core/ScalarTypeToTypeMeta.h>
#include <c10/core/Storage.h>
#include <c10/core/SymIntArrayRef.h>
#include <c10/core/TensorImpl.h>
#include <c10/core/TensorOptions.h>
#include <c10/core/UndefinedTensorImpl.h>
#include <c10/core/WrapDimMinimal.h>
#include <c10/util/C++17.h>
#include <c10/util/Exception.h>
#include <c10/util/ExclusivelyOwned.h>
#include <c10/util/ExclusivelyOwnedTensorTraits.h>
#include <c10/util/MaybeOwned.h>
#include <optional>
#include <c10/util/intrusive_ptr.h>

#include <ATen/core/NamedTensor.h>
#include <ATen/core/QuantizerBase.h>
#include <ATen/core/TensorAccessor.h>
#include <ATen/StorageUtils.h>

namespace c10 {
class Scalar;
}

namespace torch::autograd {

struct Node;

} // namespace torch::autograd

namespace at {

class Tensor;
class TensorBase;

// Convert Tensor to TensorBase without any need to include Tensor.h
TORCH_API const TensorBase& get_tensor_base(const Tensor& t);

namespace impl {
inline bool variable_excluded_from_dispatch() {
#ifdef C10_MOBILE
  // Please read the comment in `VariableFallbackKernel.cpp` about the background of this change.
  return true;
#else
  return c10::impl::tls_local_dispatch_key_set().excluded_.isSupersetOf(
      c10::autograd_dispatch_keyset);
#endif
}

} // namespace impl

// NOTE: [Tensor vs. TensorBase]
//
// Tensor, being the central data structure in PyTorch, gets used and
// its header included almost everywhere. Unfortunately this means
// every time an operator signature is updated or changed in
// native_functions.yaml, you (and every other PyTorch developer) need
// to recompile all of ATen and its dependencies.
//
// TensorBase aims to break up these header dependencies and improve
// incremental build times for all PyTorch developers. TensorBase
// represents a reference-counted handle to TensorImpl, exactly the
// same as Tensor. However, TensorBase doesn't have code-generated
// methods in its API and thus no dependence on native_functions.yaml.
//
// Usage tips
// ----------
// - You can `#define TORCH_ASSERT_NO_OPERATORS` at the top of a .cpp
//   or .cu file to ensure it has no header dependencies on
//   native_functions.yaml (direct or indirect).
// - Tensor inherits from TensorBase, so functions taking
//   `const TensorBase &` are callable with Tensor as well.
// - TensorBase can be converted to Tensor with `Tensor(tensor_base)`,
//   but this requires a reference-count bump. OptionalTensorRef, on
//   the other hand, can materialize a `const Tensor &` without
//   touching the reference count.
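//
// A minimal illustrative sketch (the helper name `count_elements` is
// hypothetical, not part of ATen): code written against TensorBase only
// needs this header, yet still accepts at::Tensor at call sites because
// Tensor derives from TensorBase.
//
//   int64_t count_elements(const at::TensorBase& t) {
//     int64_t n = 1;
//     for (auto s : t.sizes()) {
//       n *= s;
//     }
//     return n;  // equivalent to t.numel() for a defined tensor
//   }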
class TORCH_API TensorBase {
 public:
  struct unsafe_borrow_t {
    explicit unsafe_borrow_t() = default;
  };

 protected:
  // Create a Tensor with a +0 reference count. Special care must be
  // taken to avoid decrementing this reference count at destruction
  // time. Intended to support MaybeOwnedTraits<Tensor>.
  explicit TensorBase(unsafe_borrow_t, const TensorBase& rhs)
      : impl_(c10::intrusive_ptr<at::TensorImpl, UndefinedTensorImpl>::reclaim(rhs.impl_.get())) {}
  friend MaybeOwnedTraits<TensorBase>;

 public:
  TensorBase() = default;
  // This constructor should not be used by end users and is an implementation
  // detail invoked by autogenerated code.
  explicit TensorBase(
      c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl)
      : impl_(std::move(tensor_impl)) {
    if (impl_.get() == nullptr) {
      throw std::runtime_error("TensorImpl with nullptr is not supported");
    }
  }
  TensorBase(const TensorBase&) = default;
  TensorBase(TensorBase&&) noexcept = default;

 public:
  // Creates a new wrapper from TensorImpl. Intentionally a free method because
  // it should be used with care. Checks necessary invariants
  static TensorBase wrap_tensor_impl(
      c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl) {
    TensorBase r(std::move(tensor_impl));
    r.enforce_invariants();
    return r;
  }

  int64_t dim() const {
    return impl_->dim();
  }
  int64_t storage_offset() const {
    return impl_->storage_offset();
  }

  TensorBase contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) const {
    if (is_contiguous(memory_format)) {
      return *this;
    } else {
      return __dispatch_contiguous(memory_format);
    }
  }

  /// Should be used if *this can reasonably be expected to be contiguous and
  /// performance is important.
  /// Compared to contiguous, it saves a reference count
  /// increment/decrement if *this is already contiguous, at the cost
  /// in all cases of an extra pointer of stack usage, an extra branch
  /// to access, and an extra branch at destruction time.
  c10::MaybeOwned<TensorBase> expect_contiguous(
      MemoryFormat memory_format=MemoryFormat::Contiguous) const &;

  // Use .contiguous() instead. Trying to borrow from a prvalue
  // will only lead to trouble and dangling references.
  c10::MaybeOwned<TensorBase> expect_contiguous(
      MemoryFormat memory_format=MemoryFormat::Contiguous) && = delete;
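  //
  // A minimal usage sketch (illustrative only; `sum_all` is a hypothetical
  // helper, not part of ATen). The MaybeOwned dereferences to the borrowed
  // tensor when *this was already contiguous, and to the owned copy otherwise:
  //
  //   double sum_all(const at::TensorBase& t) {
  //     c10::MaybeOwned<at::TensorBase> c = t.expect_contiguous();
  //     const double* p = c->const_data_ptr<double>();
  //     double acc = 0;
  //     for (int64_t i = 0; i < c->numel(); ++i) {
  //       acc += p[i];
  //     }
  //     return acc;
  //   }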

  const TensorBase& fill_(const c10::Scalar& scalar) const;
  const TensorBase& zero_() const;

  TensorBase to(at::TensorOptions options={}, bool non_blocking=false, bool copy=false, std::optional<at::MemoryFormat> memory_format=std::nullopt) const;

  bool is_complex() const {
    return at::isComplexType(this->scalar_type());
  }

  bool is_floating_point() const {
    return at::isFloatingType(this->scalar_type());
  }

  bool is_signed() const {
    return at::isSignedType(this->scalar_type());
  }

  c10::SymInt sym_size(int64_t dim) const {
    return impl_->sym_size(dim);
  }

  c10::SymInt sym_stride(int64_t dim) const {
    const auto strides = this->sym_strides();
    const auto ndim = static_cast<int64_t>(strides.size());
    // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping)
    return strides[c10::maybe_wrap_dim(dim, ndim, /*wrap_scalar=*/false)];
  }

  int64_t size(int64_t dim) const {
    return impl_->size(dim);
  }

  int64_t stride(int64_t dim) const {
    const auto strides = this->strides();
    const auto ndim = static_cast<int64_t>(strides.size());
    // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping)
    return strides[c10::maybe_wrap_dim(dim, ndim, /*wrap_scalar=*/false)];
  }

  TensorImpl * unsafeGetTensorImpl() const {
    return impl_.get();
  }
  TensorImpl * unsafeReleaseTensorImpl() {
    return impl_.release();
  }
  const c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>& getIntrusivePtr() const {
    return impl_;
  }

  c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> unsafeReleaseIntrusivePtr() {
    return std::move(impl_);
  }

  bool defined() const {
    return impl_;
  }

  void reset() {
    impl_.reset();
  }

#if defined (_MSC_VER)
  TensorBase& operator=(const TensorBase& x) & {
    impl_ = x.impl_;
    return *this;
  };
  TensorBase& operator=(TensorBase&& x) & noexcept {
    impl_ = std::move(x.impl_);
    return *this;
  }
#else
  TensorBase& operator=(const TensorBase& x) & = default;
  TensorBase& operator=(TensorBase&& x) & noexcept = default;
#endif

  // Ban assignment to rvalues, since at::Tensor (weirdly) performs a deep copy here
  TensorBase& operator=(const TensorBase&) && = delete;
  TensorBase& operator=(TensorBase&&) && noexcept = delete;

  bool is_same(const TensorBase& other) const noexcept {
    return impl_ == other.impl_;
  }
  size_t use_count() const noexcept {
    return impl_.use_count();
  }
  size_t weak_use_count() const noexcept {
    return impl_.weak_use_count();
  }

  std::string toString() const;

  IntArrayRef sizes() const {
    return impl_->sizes();
  }
  c10::SymIntArrayRef sym_sizes() const {
    return impl_->sym_sizes();
  }
  c10::SymIntArrayRef sym_strides() const {
    return impl_->sym_strides();
  }
  IntArrayRef strides() const {
    return impl_->strides();
  }
  // See impl::get_opt_names in ATen/NamedTensor.h for docs.
  std::optional<DimnameList> opt_names() const {
    return impl::get_opt_names(unsafeGetTensorImpl());
  }
  // See impl::get_names in ATen/NamedTensor.h for docs.
  DimnameList names() const {
    return impl::get_names(unsafeGetTensorImpl());
  }
  int64_t ndimension() const {
    return dim();
  }

  bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const {
    return impl_->is_contiguous(memory_format);
  }

  bool is_non_overlapping_and_dense() const {
    return impl_->is_non_overlapping_and_dense();
  }

  at::MemoryFormat suggest_memory_format(
      bool channels_last_strides_exact_match = false) const {
    // Setting channels_last_strides_exact_match to true forces function to
    // check 0,1 - sized dimension strides.
    if (layout() == at::kStrided) {
      if (impl_->is_strides_like_channels_last()) {
        if (!channels_last_strides_exact_match ||
            get_channels_last_strides_2d(sizes()) == strides()) {
          return at::MemoryFormat::ChannelsLast;
        }
      } else if (impl_->is_strides_like_channels_last_3d()) {
        if (!channels_last_strides_exact_match ||
            get_channels_last_strides_3d(sizes()) == strides()) {
          return at::MemoryFormat::ChannelsLast3d;
        }
      }
    }
    return at::MemoryFormat::Contiguous;
  }

  // Total bytes consumed by the "view" of elements of the array. Does not
  // include size of metadata. The number reported here does not necessarily
  // correspond to the true physical memory consumed by a tensor; instead,
  // it reports the memory the tensor would take *if* it were contiguous.
  // Defined to be numel() * itemsize()
  size_t nbytes() const {
    TORCH_CHECK(layout() != at::kSparse,
                "nbytes is not defined for sparse tensors. If you want the size of the constituent " \
                "tensors, add the nbytes of the indices and values. If you want the size of the " \
                "equivalent dense tensor, multiply numel() by element_size()");
    return impl_->numel() * impl_->itemsize();
  }

  c10::SymInt sym_nbytes() const {
    TORCH_CHECK(layout() != at::kSparse,
                "nbytes is not defined for sparse tensors. If you want the size of the constituent " \
                "tensors, add the nbytes of the indices and values. If you want the size of the " \
                "equivalent dense tensor, multiply numel() by element_size()");
    return impl_->sym_numel() * impl_->itemsize();
  }

  int64_t numel() const {
    return impl_->numel();
  }

  c10::SymInt sym_numel() const {
    return impl_->sym_numel();
  }

  c10::SymInt sym_storage_offset() const {
    return impl_->sym_storage_offset();
  }

  // Length of one array element in bytes. This is the traditional
  // Numpy naming.
  size_t itemsize() const {
    return impl_->itemsize();
  }

  // Same as itemsize(). This is the PyTorch naming.
  int64_t element_size() const {
    return static_cast<int64_t>(impl_->itemsize());
  }

  DispatchKeySet key_set() const {
    return impl_->key_set();
  }
  ScalarType scalar_type() const {
    return typeMetaToScalarType(impl_->dtype());
  }
  bool has_storage() const {
    return defined() && impl_->has_storage();
  }
  const Storage& storage() const {
    return impl_->storage();
  }
  bool is_alias_of(const at::TensorBase& other) const {
    return impl_->storage().is_alias_of(other.storage());
  }

  // Move the storage backend to shm based
  // to enable memory sharing across processes.
  //
  // NB1: the ideal behavior of this API still requires further discussion
  // but for now we are inclined to keep it consistent with existing THP behavior
  // https://github.com/pytorch/pytorch/blob/4dca9bde0552afc67b5b74f4a0696fe6055709c4/torch/storage.py#L196-L212
  // so we don't assert on anything here and rely on caller knowing
  // what it's doing.
  //
  // NB2: this currently provides Linux fd based shm support only
  // to simplify the storage lifetime management logic in ATen
  // and similarly for now we are not adding support for file system based
  // shm support like in THP due to additional GC manager support needed
  // to prevent leaks.
  // As such, calling this from unsupported systems (e.g. Windows) would fail.
  void share_memory_() {
    at::share_memory_(*this);
  }

  inline bool _is_zerotensor() const {
    return impl_->_is_zerotensor();
  }

  inline void _set_zero(bool zero) const {
    impl_->_set_zero(zero);
  }

  inline bool is_conj() const {
    return impl_->is_conj();
  }

  // sets the conjugate bit of a tensor.
  // NOTE: Conjugate bit is supposed to be a read-only field. Only change this if you are sure
  // that's what you want. Changing this might lead to incorrect behavior since conjugation is
  // a lazy operation and we rely on this bit to determine if a conjugation needs to be materialized.
  inline void _set_conj(bool conjugate) const {
    impl_->_set_conj(conjugate);
  }

  inline bool is_neg() const {
    return impl_->is_neg();
  }

  // sets the negative bit of a tensor.
  // NOTE: Negative bit is supposed to be a read-only field. Only change this if you are sure
  // that's what you want. Changing this might lead to incorrect behavior since we rely on this
  // bit to determine if a negation needs to be materialized.
  inline void _set_neg(bool negative) const {
    impl_->_set_neg(negative);
  }

  /// Returns a `Tensor`'s layout.
  Layout layout() const {
    return impl_->layout();
  }

  /// Returns a `Tensor`'s dtype (`TypeMeta`).
  caffe2::TypeMeta dtype() const {
    return impl_->dtype();
  }

  /// Returns a `Tensor`'s device.
  inline Device device() const {
    return impl_->device();
  }

  /// Returns a `Tensor`'s device index.
  DeviceIndex get_device() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->get_device();
  }

  /// Returns if a `Tensor` has CPU backend.
  bool is_cpu() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_cpu();
  }

  /// Returns if a `Tensor` has CUDA backend.
  bool is_cuda() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_cuda();
  }

  /// Returns if a `Tensor` has IPU backend.
  bool is_ipu() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_ipu();
  }

  /// Returns if a `Tensor` has XPU backend.
  bool is_xpu() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_xpu();
  }

  /// Returns if a `Tensor` has XLA backend.
  bool is_xla() const {
    return impl_->is_xla();
  }

  /// Returns if a `Tensor` has MTIA backend.
  bool is_mtia() const {
    return impl_->is_mtia();
  }

  /// Returns if a `Tensor` has HPU backend.
  bool is_hpu() const {
    return impl_->is_hpu();
  }

  /// Returns if a `Tensor` has Lazy backend.
  bool is_lazy() const {
    return impl_->is_lazy();
  }

  /// Returns if a `Tensor` has HIP backend.
  bool is_hip() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_hip();
  }

  /// Returns if a `Tensor` has VE backend.
  bool is_ve() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_ve();
  }

  /// Returns if a `Tensor` has PrivateUse1 backend.
  bool is_privateuseone() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_privateuseone();
  }

  /// Returns if a `Tensor` has sparse backend.
  bool is_sparse() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_sparse();
  }

  /// Returns if a `Tensor` has a sparse CSR backend.
  bool is_sparse_csr() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_sparse_csr();
  }

  /// Returns if a `Tensor` is an mkldnn tensor.
  bool is_mkldnn() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_mkldnn();
  }

  /// Returns if a `Tensor` is an MPS tensor.
  bool is_mps() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_mps();
  }

  /// Returns if a `Tensor` is a MAIA tensor.
  bool is_maia() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_maia();
  }

  /// Returns if a `Tensor` is a Vulkan tensor.
  bool is_vulkan() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_vulkan();
  }

  /// Returns if a `Tensor` is a Metal tensor.
  bool is_metal() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_metal();
  }

  /// Returns if a `Tensor` has quantized backend.
  bool is_quantized() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_quantized();
  }

  /// Returns if a `Tensor` is a meta tensor. Meta tensors can
  /// also have other designations.
  bool is_meta() const {
    return impl_->is_meta();
  }

  /// Returns if a `Tensor` is an inference tensor.
  bool is_inference() const {
    return impl_->is_inference();
  }

  // Returns if a `Tensor` is a NestedTensor.
  bool is_nested() const {
    return impl_->is_nested();
  }

  /// If a tensor is a quantized tensor, returns its quantizer
  /// TODO: it's not in native_functions.yaml yet as it's not exposed to python
  QuantizerPtr quantizer() const;

  /// Returns if a `Tensor` has any dimension names
  bool has_names() const {
    // If a user is using unnamed tensors, then we can short-circuit right here.
    // Otherwise, impl::has_names attempts to retrieve names.
    if (!impl_->has_named_tensor_meta()) {
      return false;
    }
    return impl::has_names(unsafeGetTensorImpl());
  }

  /// Returns a `Tensor`'s dimension names data structure
  const NamedTensorMeta* get_named_tensor_meta() const {
    return static_cast<const NamedTensorMeta*>(impl_->named_tensor_meta());
  }

  NamedTensorMeta* get_named_tensor_meta() {
    return static_cast<NamedTensorMeta*>(impl_->named_tensor_meta());
  }

  /// Returns the `TensorOptions` corresponding to this `Tensor`. Defined in
  /// TensorOptions.h.
  TensorOptions options() const {
    return TensorOptions().dtype(dtype())
                          .device(device())
                          .layout(layout());
  }

  const void* const_data_ptr() const {
    return this->unsafeGetTensorImpl()->data();
  }

  void* mutable_data_ptr() const {
    return this->unsafeGetTensorImpl()->mutable_data();
  }

  // TODO(#97856) Make this return a const pointer. This currently
  //              returns a non-const pointer because of the large
  //              number of clients that we still want to audit before
  //              migrating to mutable_data_ptr().
  void* data_ptr() const {
    return mutable_data_ptr();
  }

  template <typename T, std::enable_if_t<!std::is_const_v<T>, int> = 0>
  const T* const_data_ptr() const;

  template <typename T, std::enable_if_t<std::is_const_v<T>, int> = 0>
  const std::remove_const_t<T>* const_data_ptr() const;

  template <typename T>
  T* mutable_data_ptr() const;

  // Legacy interface during the migration to indicate that a callsite
  // has not been audited for mutability.
  //
  // Do not add new uses of this, use const_data_ptr() if possible,
  // mutable_data_ptr() otherwise.
  //
  // TODO(#97856) Make this return a const pointer. This is currently
  //              const because of the vast number of clients that
  //              rely on this.
  template <typename T>
  T* data_ptr() const;

  // Purposely not defined here to avoid inlining
  void print() const;

  // Return a `TensorAccessor` for CPU `Tensor`s. You have to specify scalar type and
  // dimension.
  template<typename T, size_t N>
  TensorAccessor<T,N> accessor() const& {
    static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data_ptr<T>()");
    TORCH_CHECK(dim() == N, "TensorAccessor expected ", N, " dims but tensor has ", dim());
    T* ptr = nullptr;
    if constexpr (std::is_const<T>::value) {
      ptr = const_data_ptr<T>();
    } else {
      ptr = mutable_data_ptr<T>();
    }
    return TensorAccessor<T,N>(ptr,sizes().data(),strides().data());
  }
  template<typename T, size_t N>
  TensorAccessor<T,N> accessor() && = delete;
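  //
  // A minimal usage sketch (illustrative only; `row_sums` is a hypothetical
  // helper). The accessor provides stride-aware, unchecked indexing into a
  // CPU tensor whose dtype and rank are known at the call site:
  //
  //   void row_sums(const at::TensorBase& m, const at::TensorBase& out) {
  //     auto a = m.accessor<float, 2>();    // m must be a 2-D float tensor
  //     auto o = out.accessor<float, 1>();  // out must be a 1-D float tensor
  //     for (int64_t i = 0; i < a.size(0); ++i) {
  //       float s = 0;
  //       for (int64_t j = 0; j < a.size(1); ++j) {
  //         s += a[i][j];
  //       }
  //       o[i] = s;
  //     }
  //   }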

  // Return a `GenericPackedTensorAccessor` for CUDA `Tensor`s. You have to specify scalar type and
  // dimension. You can optionally specify RestrictPtrTraits as a template parameter to
  // cast the data pointer to a __restrict__ pointer.
  // In order to use this, your CUDA kernel has to take a corresponding GenericPackedTensorAccessor
  // as an argument.
  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
  GenericPackedTensorAccessor<T,N,PtrTraits,index_t> generic_packed_accessor() const& {
    static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data_ptr<T>()");
    TORCH_CHECK(dim() == N, "TensorAccessor expected ", N, " dims but tensor has ", dim());
    T* ptr = nullptr;
    if constexpr (std::is_const<T>::value) {
      ptr = const_data_ptr<T>();
    } else {
      ptr = mutable_data_ptr<T>();
    }
    return GenericPackedTensorAccessor<T,N,PtrTraits,index_t>(static_cast<typename PtrTraits<T>::PtrType>(ptr),sizes().data(),strides().data());
  }
  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
  GenericPackedTensorAccessor<T,N> generic_packed_accessor() && = delete;

  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
  PackedTensorAccessor32<T,N,PtrTraits> packed_accessor32() const& {
    TORCH_CHECK(
        impl_->numel() <=
            static_cast<int64_t>(std::numeric_limits<int32_t>::max()),
        "numel needs to be smaller than int32_t max; otherwise, please use packed_accessor64");
    return generic_packed_accessor<T,N,PtrTraits,int32_t>();
  }
  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
  PackedTensorAccessor32<T,N,PtrTraits> packed_accessor32() && = delete;

  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
  PackedTensorAccessor64<T,N,PtrTraits> packed_accessor64() const& {
    return generic_packed_accessor<T,N,PtrTraits,int64_t>();
  }
  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
  PackedTensorAccessor64<T,N,PtrTraits> packed_accessor64() && = delete;
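  //
  // A minimal usage sketch (illustrative only; `scale_kernel` and
  // `launch_scale` are hypothetical, not part of ATen). The packed accessor
  // is passed to the kernel by value and indexed on the device:
  //
  //   __global__ void scale_kernel(
  //       at::PackedTensorAccessor32<float, 1, at::RestrictPtrTraits> a,
  //       float factor) {
  //     const int i = blockIdx.x * blockDim.x + threadIdx.x;
  //     if (i < a.size(0)) {
  //       a[i] *= factor;
  //     }
  //   }
  //
  //   void launch_scale(const at::TensorBase& t, float factor) {
  //     auto a = t.packed_accessor32<float, 1, at::RestrictPtrTraits>();
  //     scale_kernel<<<(t.numel() + 255) / 256, 256>>>(a, factor);
  //   }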

  // ~~~~~ Autograd API ~~~~~

  /// \fn bool is_leaf() const;
  ///
  /// All Tensors that have `requires_grad()` which is ``false`` will be leaf Tensors by convention.
  ///
  /// For Tensors that have `requires_grad()` which is ``true``, they will be leaf Tensors if they were
  /// created by the user. This means that they are not the result of an operation and so
  /// `grad_fn()` is `nullptr`.
  ///
  /// Only leaf Tensors will have their `grad()` populated during a call to `backward()`.
  /// To get `grad()` populated for non-leaf Tensors, you can use `retain_grad()`.
  ///
  /// Example:
  /// @code
  /// auto a = torch::rand(10, torch::requires_grad());
  /// std::cout << a.is_leaf() << std::endl; // prints `true`
  ///
  /// auto b = torch::rand(10, torch::requires_grad()).to(torch::kCUDA);
  /// std::cout << b.is_leaf() << std::endl; // prints `false`
  /// // b was created by the operation that cast a cpu Tensor into a cuda Tensor
  ///
  /// auto c = torch::rand(10, torch::requires_grad()) + 2;
  /// std::cout << c.is_leaf() << std::endl; // prints `false`
  /// // c was created by the addition operation
  ///
  /// auto d = torch::rand(10).cuda();
  /// std::cout << d.is_leaf() << std::endl; // prints `true`
  /// // d does not require gradients and so has no operation creating it (that is tracked by the autograd engine)
  ///
  /// auto e = torch::rand(10).cuda().requires_grad_();
  /// std::cout << e.is_leaf() << std::endl; // prints `true`
  /// // e requires gradients and has no operations creating it
  ///
  /// auto f = torch::rand(10, torch::device(torch::kCUDA).requires_grad(true));
  /// std::cout << f.is_leaf() << std::endl; // prints `true`
  /// // f requires grad, has no operation creating it
  /// @endcode

  /// \fn void backward(const Tensor & gradient={}, std::optional<bool> retain_graph=std::nullopt, bool create_graph=false, std::optional<TensorList> inputs=std::nullopt) const;
  ///
  /// Computes the gradient of current tensor with respect to graph leaves.
  ///
  /// The graph is differentiated using the chain rule. If the tensor is
  /// non-scalar (i.e. its data has more than one element) and requires
  /// gradient, the function additionally requires specifying ``gradient``.
  /// It should be a tensor of matching type and location, that contains
  /// the gradient of the differentiated function w.r.t. this Tensor.
  ///
  /// This function accumulates gradients in the leaves - you might need to
  /// zero them before calling it.
  ///
  /// \param gradient Gradient w.r.t. the
  ///     tensor. If it is a tensor, it will be automatically converted
  ///     to a Tensor that does not require grad unless ``create_graph`` is True.
  ///     None values can be specified for scalar Tensors or ones that
  ///     don't require grad. If a None value would be acceptable then
  ///     this argument is optional.
  /// \param retain_graph If ``false``, the graph used to compute
  ///     the grads will be freed. Note that in nearly all cases setting
  ///     this option to True is not needed and often can be worked around
  ///     in a much more efficient way. Defaults to the value of
  ///     ``create_graph``.
  /// \param create_graph If ``true``, graph of the derivative will
  ///     be constructed, allowing higher order derivative
  ///     products to be computed. Defaults to ``false``.
  /// \param inputs Inputs w.r.t. which the gradient will be accumulated into
  ///     ``at::Tensor::grad``. All other Tensors will be ignored. If not
  ///     provided, the gradient is accumulated into all the leaf Tensors
  ///     that were used to compute the current tensor.
  ///     When inputs are provided and a given input is not a leaf,
  ///     the current implementation will call its grad_fn (even though it is not strictly needed to get these gradients).
  ///     It is an implementation detail on which the user should not rely.
  ///     See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details.

  /// \fn Tensor detach() const;
  ///
  /// Returns a new Tensor, detached from the current graph.
  /// The result will never require gradient.

  /// \fn Tensor & detach_() const;
  ///
  /// Detaches the Tensor from the graph that created it, making it a leaf.
  /// Views cannot be detached in-place.

  /// \fn void retain_grad() const;
  ///
  /// Enables this Tensor to have its :attr:`grad` populated during
  /// :func:`backward`. This is a no-op for leaf tensors.

  /// \fn bool retains_grad() const;
  ///
  /// Is ``true`` if this Tensor is non-leaf and its :attr:`grad` is enabled to be
  /// populated during :func:`backward`, ``false`` otherwise.

  const TensorBase& set_requires_grad(bool requires_grad) const {
    impl_->set_requires_grad(requires_grad);
    return *this;
  }
  bool requires_grad() const {
    return impl_->requires_grad();
  }

  // The Forward AD API functions below are low level and are not to be used by end
  // users who should use the API provided in torch/csrc/autograd.h

  /// This function returns the forward gradient for this Tensor at the given level.
  const Tensor& _fw_grad(uint64_t level) const {
    return impl_->_fw_grad(level, *this);
  }

  /// This function can be used to set the value of the forward grad.
  /// Note that the given new_grad might not be used directly if it has different
  /// metadata (size/stride/storage offset) compared to this Tensor. In that case,
  /// new_grad content will be copied into a new Tensor
  void _set_fw_grad(const TensorBase& new_grad, uint64_t level, bool is_inplace_op) const {
    impl_->_set_fw_grad(new_grad, *this, level, is_inplace_op);
  }

  /// NOTE: This is similar to the legacy `.data()` function on `Variable`, and is intended
  /// to be used from functions that need to access the `Variable`'s equivalent `Tensor`
  /// (i.e. `Tensor` that shares the same storage and tensor metadata with the `Variable`).
  ///
  /// One notable difference with the legacy `.data()` function is that changes to the
  /// returned `Tensor`'s tensor metadata (e.g. sizes / strides / storage / storage_offset)
  /// will not update the original `Variable`, due to the fact that this function
  /// shallow-copies the `Variable`'s underlying TensorImpl.
  at::TensorBase tensor_data() const;

  /// NOTE: `var.variable_data()` in C++ has the same semantics as `tensor.data`
  /// in Python, which creates a new `Variable` that shares the same storage and
  /// tensor metadata with the original `Variable`, but with a completely new
  /// autograd history.
  ///
  /// NOTE: If we change the tensor metadata (e.g. sizes / strides /
  /// storage / storage_offset) of a variable created from `var.variable_data()`, those
  /// changes will not update the original variable `var`. In `.variable_data()`, we set
  /// `allow_tensor_metadata_change_` to false to make such changes explicitly illegal,
  /// in order to prevent users from changing metadata of `var.variable_data()`
  /// and expecting the original variable `var` to also be updated.
  at::TensorBase variable_data() const;

  // Gradient Node and Edges
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

  /// Gets the gradient function of the `Variable`. If this is a leaf variable,
  /// the pointer returned will be null.
  ///
  /// For View Variables:
  /// Gets the up-to-date grad_fn. If the shared data or base was modified, we
  /// re-create the grad_fn to express the up-to-date view relationship between
  /// this and the base Variable.
  const std::shared_ptr<torch::autograd::Node>& grad_fn() const;

  // Hooks
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

  template <typename T>
  using hook_return_void_t = std::enable_if_t<std::is_void_v<typename std::invoke_result_t<T&, TensorBase>>, unsigned>;
  template <typename T>
  using hook_return_var_t = std::enable_if_t<std::is_same_v<typename std::invoke_result_t<T&, TensorBase>, TensorBase>, unsigned>;

  /// Registers a backward hook.
  ///
  /// The hook will be called every time a gradient with respect to the Tensor is computed.
  /// The hook should have one of the following signatures:
  /// ```
  /// hook(TensorBase grad) -> TensorBase
  /// ```
  /// ```
  /// hook(TensorBase grad) -> void
  /// ```
  /// The hook should not modify its argument, but it can optionally return a new gradient
  /// which will be used in place of `grad`.
  ///
  /// This function returns the index of the hook in the list, which can be used to remove the hook.
  ///
  /// Example:
  /// @code
  /// auto v = torch::tensor({0., 0., 0.}, torch::requires_grad());
  /// auto h = v.register_hook([](torch::Tensor grad){ return grad * 2; }); // double the gradient
  /// v.backward(torch::tensor({1., 2., 3.}));
  /// // This prints:
  /// // ```
  /// //  2
  /// //  4
  /// //  6
  /// // [ CPUFloatType{3} ]
  /// // ```
  /// std::cout << v.grad() << std::endl;
  /// v.remove_hook(h); // removes the hook
  /// @endcode
  template <typename T>
  hook_return_void_t<T> register_hook(T&& hook) const;
  template <typename T>
  hook_return_var_t<T> register_hook(T&& hook) const;

 protected:
  unsigned _register_hook(std::function<TensorBase(const TensorBase&)> hook) const;

 public:

  /// Remove hook at given position
  void remove_hook(unsigned pos) const;

  // Variable methods
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

  bool is_leaf() const;

  int64_t output_nr() const;

  void set_data(const TensorBase & new_data) const;

  TensorBase data() const;

  int64_t _version() const;

  void retain_grad() const;

  bool retains_grad() const;

  const TensorBase& requires_grad_(bool _requires_grad=true) const;

  // View Variables
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

  /// Returns true if this `Variable` is a view of another `Variable`.
  bool is_view() const;

  /// Returns the `Variable` that this `Variable` is a view of. If this
  /// `Variable` is not a view, throw a `std::runtime_error`.
  const TensorBase& _base() const;

  // Miscellaneous
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

  const std::string& name() const;

 protected:
  void enforce_invariants();
  c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> impl_;

 private:
  TensorBase __dispatch_contiguous(c10::MemoryFormat) const;
};

inline DeviceIndex get_device(const TensorBase& self) {
  return self.get_device();
}

template <typename T>
// NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
auto TensorBase::register_hook(T&& hook) const -> TensorBase::hook_return_void_t<T> {
  // Return the grad argument in case of a hook with void return type to have an
  // std::function with Tensor return type
  static_assert(std::is_same_v<decltype(hook(TensorBase())), void>,
                "Expected hook to return void");
  return _register_hook([fn=std::forward<T>(hook)](const TensorBase& grad) {
    fn(grad);
    return TensorBase();
  });
}

template <typename T>
auto TensorBase::register_hook(T&& hook) const -> TensorBase::hook_return_var_t<T> {
  return _register_hook(std::forward<T>(hook));
}

namespace detail {
// Helper creator for the Tensor class, which doesn't require the user to pass
// in an intrusive_ptr; instead it converts the arguments passed into the
// requested intrusive_ptr type.
template <typename T, typename... Args>
TensorBase make_tensor_base(Args&&... args) {
  return TensorBase(c10::make_intrusive<T>(std::forward<Args>(args)...));
}

} // namespace detail

inline DispatchKey legacyExtractDispatchKey(const TensorBase& t) {
  return legacyExtractDispatchKey(t.key_set());
}

} // namespace at

namespace c10 {
template <>
struct MaybeOwnedTraits<at::TensorBase> {
  using owned_type = at::TensorBase;
  using borrow_type = at::TensorBase;

  static borrow_type createBorrow(const owned_type& from) {
    // NOTE: this can be implemented without the special
    // unsafe_borrow_t Tensor constructor as
    //
    // return borrow_type(c10::intrusive_ptr<at::TensorImpl, at::UndefinedTensorImpl>::reclaim(from.unsafeGetTensorImpl()));
    //
    // but that hurts inlining due to the nullptr check in the
    // Tensor(c10::intrusive_ptr<...>) constructor. We already know
    // that from.impl_ isn't null because from is a valid Tensor, so
    // we needn't do the check again. (using __builtin_assume can
    // avoid this, but wouldn't be portable to MSVC.)
    return borrow_type(borrow_type::unsafe_borrow_t{}, from);
  }

  static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) {
    lhs.unsafeReleaseTensorImpl();
    // See above note: this can be implemented with public API
    // similarly to createBorrow(), but that would hurt inlining.
    lhs = borrow_type(borrow_type::unsafe_borrow_t{}, rhs);
  }

  static void destroyBorrow(borrow_type& toDestroy) {
    toDestroy.unsafeReleaseTensorImpl(); // "leak" it, but it was already +0.
  }

  static const owned_type& referenceFromBorrow(const borrow_type& borrow) {
    return borrow;
  }

  static const owned_type* pointerFromBorrow(const borrow_type& borrow) {
    return &borrow;
  }

  static bool debugBorrowIsValid(const borrow_type& /*borrow*/) {
    return true;
  }
};

template <>
struct ExclusivelyOwnedTraits<at::TensorBase> : public c10::ExclusivelyOwnedTensorTraits<at::TensorBase> {};
} // namespace c10

namespace at {

inline c10::MaybeOwned<TensorBase> borrow_from_optional_tensor(
    const std::optional<TensorBase>& opt) {
  return opt.has_value()
    ? c10::MaybeOwned<TensorBase>::borrowed(*opt)
    : c10::MaybeOwned<TensorBase>::owned(std::in_place);
}
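
// A minimal usage sketch (illustrative only; `maybe_use_bias` is a
// hypothetical helper). When the optional is empty, the MaybeOwned holds a
// default-constructed (undefined) TensorBase, so callers check defined():
//
//   void maybe_use_bias(const std::optional<at::TensorBase>& bias_opt) {
//     c10::MaybeOwned<at::TensorBase> bias =
//         at::borrow_from_optional_tensor(bias_opt);
//     if (bias->defined()) {
//       // ... use *bias without having bumped its reference count ...
//     }
//   }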

inline c10::MaybeOwned<TensorBase> TensorBase::expect_contiguous(MemoryFormat memory_format) const & {
  if (is_contiguous(memory_format)) {
    return c10::MaybeOwned<TensorBase>::borrowed(*this);
  } else {
    return c10::MaybeOwned<TensorBase>::owned(__dispatch_contiguous(memory_format));
  }
}

namespace symint {

template <typename T>
using enable_if_symint = std::enable_if_t<std::is_same_v<T, c10::SymInt>>;
template <typename T>
using enable_if_int = std::enable_if_t<std::is_same_v<T, int64_t>>;

template <typename T, typename = enable_if_symint<T>>
c10::SymIntArrayRef sizes(const TensorBase& t) { return t.sym_sizes(); }
template <typename T, typename = enable_if_int<T>>
IntArrayRef sizes(const TensorBase& t) { return t.sizes(); }

template <typename T, typename = enable_if_symint<T>>
c10::SymInt size(const TensorBase& t, int64_t dim) { return t.sym_size(dim); }
template <typename T, typename = enable_if_int<T>>
int64_t size(const TensorBase& t, int64_t dim) { return t.size(dim); }

template <typename T, typename = enable_if_symint<T>>
c10::SymIntArrayRef strides(const TensorBase& t) { return t.sym_strides(); }
template <typename T, typename = enable_if_int<T>>
IntArrayRef strides(const TensorBase& t) { return t.strides(); }

template <typename T, typename = enable_if_symint<T>>
c10::SymInt numel(const TensorBase& t) { return t.sym_numel(); }
template <typename T, typename = enable_if_int<T>>
int64_t numel(const TensorBase& t) { return t.numel(); }

} // namespace symint
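
// A minimal usage sketch (illustrative only; `first_dim` is a hypothetical
// helper). Code templated on the index type can query either the concrete
// int64_t metadata or the symbolic SymInt metadata through one spelling:
//
//   template <typename index_t>
//   index_t first_dim(const at::TensorBase& t) {
//     return at::symint::size<index_t>(t, 0);
//   }
//
//   // first_dim<int64_t>(t)     -> t.size(0)
//   // first_dim<c10::SymInt>(t) -> t.sym_size(0)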

} // namespace at