xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/Tensor.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/core/Tensor.h>
2 #include <ATen/core/Formatting.h>
3 #include <ATen/core/VariableHooksInterface.h>
4 #include <ATen/core/LegacyTypeDispatch.h>
5 #include <ATen/FunctionalTensorWrapper.h>
6 
7 #ifndef AT_PER_OPERATOR_HEADERS
8 #include <ATen/MethodOperators.h>
9 #else
10 #include <ATen/ops/contiguous_ops.h>
11 #include <ATen/ops/fill_ops.h>
12 #include <ATen/ops/to_ops.h>
13 #include <ATen/ops/zero_ops.h>
14 #endif
15 
16 #include <iostream>
17 
18 namespace at {
19 
get_tensor_base(const Tensor & t)20 const TensorBase& get_tensor_base(const Tensor &t) {
21   return t;
22 }
23 
__dispatch_contiguous(c10::MemoryFormat memory_format) const24 TensorBase TensorBase::__dispatch_contiguous(c10::MemoryFormat memory_format) const {
25   OptionalTensorRef self(*this);
26   return at::_ops::contiguous::call(*self, memory_format);
27 }
28 
fill_(const c10::Scalar & fill_value) const29 const TensorBase& TensorBase::fill_(const c10::Scalar &fill_value) const {
30   Tensor self(*this);
31   at::_ops::fill__Scalar::call(self, fill_value);
32   return *this;
33 }
34 
zero_() const35 const TensorBase& TensorBase::zero_() const {
36   Tensor self(*this);
37   at::_ops::zero_::call(self);
38   return *this;
39 }
40 
to(at::TensorOptions options,bool non_blocking,bool copy,std::optional<at::MemoryFormat> memory_format) const41 TensorBase TensorBase::to(
42     at::TensorOptions options,
43     bool non_blocking,
44     bool copy,
45     std::optional<at::MemoryFormat> memory_format) const {
46   Tensor self(*this);
47   return at::_ops::to_dtype_layout::call(
48       self, optTypeMetaToScalarType(options.dtype_opt()),
49       options.layout_opt(), options.device_opt(),
50       options.pinned_memory_opt(), non_blocking, copy, memory_format);
51 }
52 
enforce_invariants()53 void TensorBase::enforce_invariants() {
54   if (impl_.get() == nullptr) {
55     throw std::runtime_error("TensorImpl with nullptr is not supported");
56   }
57   // Following line throws if the method is not a POD data type or is not
58   // supported by ATen
59   scalar_type();
60   if (defined()) {
61     TORCH_INTERNAL_ASSERT(
62         impl_->dtype_initialized(),
63         "Partially-initialized tensor not supported by Tensor");
64     TORCH_INTERNAL_ASSERT(
65         !impl_->is_sparse(),
66         "Sparse Tensors are supported by Tensor, but invariant checking isn't implemented.  Please file a bug.");
67     TORCH_INTERNAL_ASSERT(
68         !impl_->has_storage() || impl_->is_meta() || impl_->storage_initialized(),
69         "Partially-initialized tensor not supported by Tensor");
70   }
71 }
72 
print() const73 void TensorBase::print() const {
74   if (defined()) {
75     std::cerr << "[" << toString() << " " << sizes() << "]" << '\n';
76   } else {
77     std::cerr << "[UndefinedTensor]" << '\n';
78   }
79 }
80 
toString() const81 std::string TensorBase::toString() const {
82   std::string base_str;
83   if (scalar_type() == ScalarType::Undefined) {
84     base_str = "UndefinedType";
85   } else {
86     auto dispatchkey = options().computeDispatchKey();
87     std::string dispatchkey_str;
88     if (dispatchkey == c10::DispatchKey::PrivateUse1) {
89       dispatchkey_str = c10::get_privateuse1_backend();
90     } else if (dispatchkey == c10::DispatchKey::AutocastPrivateUse1) {
91       dispatchkey_str = "Autocast" + c10::get_privateuse1_backend();
92     } else {
93       dispatchkey_str = at::toString(dispatchkey);
94     }
95     base_str = dispatchkey_str + at::toString(scalar_type()) + "Type";
96   }
97   return base_str;
98 }
99 
variable_data() const100 TensorBase TensorBase::variable_data() const {
101   return impl::GetVariableHooks()->variable_data(*this);
102 }
103 
tensor_data() const104 TensorBase TensorBase::tensor_data() const {
105   return impl::GetVariableHooks()->tensor_data(*this);
106 }
107 
is_leaf() const108 bool TensorBase::is_leaf() const {
109   return impl::GetVariableHooks()->is_leaf(*this);
110 }
111 
output_nr() const112 int64_t TensorBase::output_nr() const {
113   return impl::GetVariableHooks()->output_nr(*this);
114 }
115 
set_data(const TensorBase & new_data) const116 void TensorBase::set_data(const TensorBase & new_data) const {
117   impl::GetVariableHooks()->set_data(*this, new_data);
118 }
119 
data() const120 TensorBase TensorBase::data() const {
121   return impl::GetVariableHooks()->data(*this);
122 }
123 
_version() const124 int64_t TensorBase::_version() const {
125   return impl::GetVariableHooks()->_version(*this);
126 }
127 
retain_grad() const128 void TensorBase::retain_grad() const {
129   impl::GetVariableHooks()->retain_grad(*this);
130 }
131 
retains_grad() const132 bool TensorBase::retains_grad() const {
133   return impl::GetVariableHooks()->retains_grad(*this);
134 }
135 
_backward(TensorList inputs,const std::optional<Tensor> & gradient,std::optional<bool> keep_graph,bool create_graph) const136 void Tensor::_backward(TensorList inputs,
137         const std::optional<Tensor>& gradient,
138         std::optional<bool> keep_graph,
139         bool create_graph) const {
140   return impl::GetVariableHooks()->_backward(*this, inputs, gradient, keep_graph, create_graph);
141 }
142 
requires_grad_(bool _requires_grad) const143 const TensorBase& TensorBase::requires_grad_(bool _requires_grad) const {
144   impl::GetVariableHooks()->requires_grad_(*this, _requires_grad);
145   return *this;
146 }
147 
148 // View Methods
149 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
150 
is_view() const151 bool TensorBase::is_view() const {
152   return impl::GetVariableHooks()->is_view(*this);
153 }
154 
_base() const155 const TensorBase& TensorBase::_base() const {
156   return impl::GetVariableHooks()->base(*this);
157 }
158 
name() const159 const std::string& TensorBase::name() const {
160   return impl::GetVariableHooks()->name(*this);
161 }
162 
grad_fn() const163 const std::shared_ptr<torch::autograd::Node>& TensorBase::grad_fn() const {
164   return impl::GetVariableHooks()->grad_fn(*this);
165 }
166 
remove_hook(unsigned pos) const167 void TensorBase::remove_hook(unsigned pos) const {
168   impl::GetVariableHooks()->remove_hook(*this, pos);
169 }
170 
_register_hook(std::function<TensorBase (const TensorBase &)> hook) const171 unsigned TensorBase::_register_hook(std::function<TensorBase(const TensorBase&)> hook) const {
172   return impl::GetVariableHooks()->_register_hook(*this, std::move(hook));
173 }
174 
175 } // namespace at
176