1 #include <ATen/core/Tensor.h>
2 #include <ATen/core/Formatting.h>
3 #include <ATen/core/VariableHooksInterface.h>
4 #include <ATen/core/LegacyTypeDispatch.h>
5 #include <ATen/FunctionalTensorWrapper.h>
6
7 #ifndef AT_PER_OPERATOR_HEADERS
8 #include <ATen/MethodOperators.h>
9 #else
10 #include <ATen/ops/contiguous_ops.h>
11 #include <ATen/ops/fill_ops.h>
12 #include <ATen/ops/to_ops.h>
13 #include <ATen/ops/zero_ops.h>
14 #endif
15
16 #include <iostream>
17
18 namespace at {
19
get_tensor_base(const Tensor & t)20 const TensorBase& get_tensor_base(const Tensor &t) {
21 return t;
22 }
23
__dispatch_contiguous(c10::MemoryFormat memory_format) const24 TensorBase TensorBase::__dispatch_contiguous(c10::MemoryFormat memory_format) const {
25 OptionalTensorRef self(*this);
26 return at::_ops::contiguous::call(*self, memory_format);
27 }
28
fill_(const c10::Scalar & fill_value) const29 const TensorBase& TensorBase::fill_(const c10::Scalar &fill_value) const {
30 Tensor self(*this);
31 at::_ops::fill__Scalar::call(self, fill_value);
32 return *this;
33 }
34
zero_() const35 const TensorBase& TensorBase::zero_() const {
36 Tensor self(*this);
37 at::_ops::zero_::call(self);
38 return *this;
39 }
40
to(at::TensorOptions options,bool non_blocking,bool copy,std::optional<at::MemoryFormat> memory_format) const41 TensorBase TensorBase::to(
42 at::TensorOptions options,
43 bool non_blocking,
44 bool copy,
45 std::optional<at::MemoryFormat> memory_format) const {
46 Tensor self(*this);
47 return at::_ops::to_dtype_layout::call(
48 self, optTypeMetaToScalarType(options.dtype_opt()),
49 options.layout_opt(), options.device_opt(),
50 options.pinned_memory_opt(), non_blocking, copy, memory_format);
51 }
52
enforce_invariants()53 void TensorBase::enforce_invariants() {
54 if (impl_.get() == nullptr) {
55 throw std::runtime_error("TensorImpl with nullptr is not supported");
56 }
57 // Following line throws if the method is not a POD data type or is not
58 // supported by ATen
59 scalar_type();
60 if (defined()) {
61 TORCH_INTERNAL_ASSERT(
62 impl_->dtype_initialized(),
63 "Partially-initialized tensor not supported by Tensor");
64 TORCH_INTERNAL_ASSERT(
65 !impl_->is_sparse(),
66 "Sparse Tensors are supported by Tensor, but invariant checking isn't implemented. Please file a bug.");
67 TORCH_INTERNAL_ASSERT(
68 !impl_->has_storage() || impl_->is_meta() || impl_->storage_initialized(),
69 "Partially-initialized tensor not supported by Tensor");
70 }
71 }
72
print() const73 void TensorBase::print() const {
74 if (defined()) {
75 std::cerr << "[" << toString() << " " << sizes() << "]" << '\n';
76 } else {
77 std::cerr << "[UndefinedTensor]" << '\n';
78 }
79 }
80
toString() const81 std::string TensorBase::toString() const {
82 std::string base_str;
83 if (scalar_type() == ScalarType::Undefined) {
84 base_str = "UndefinedType";
85 } else {
86 auto dispatchkey = options().computeDispatchKey();
87 std::string dispatchkey_str;
88 if (dispatchkey == c10::DispatchKey::PrivateUse1) {
89 dispatchkey_str = c10::get_privateuse1_backend();
90 } else if (dispatchkey == c10::DispatchKey::AutocastPrivateUse1) {
91 dispatchkey_str = "Autocast" + c10::get_privateuse1_backend();
92 } else {
93 dispatchkey_str = at::toString(dispatchkey);
94 }
95 base_str = dispatchkey_str + at::toString(scalar_type()) + "Type";
96 }
97 return base_str;
98 }
99
variable_data() const100 TensorBase TensorBase::variable_data() const {
101 return impl::GetVariableHooks()->variable_data(*this);
102 }
103
tensor_data() const104 TensorBase TensorBase::tensor_data() const {
105 return impl::GetVariableHooks()->tensor_data(*this);
106 }
107
is_leaf() const108 bool TensorBase::is_leaf() const {
109 return impl::GetVariableHooks()->is_leaf(*this);
110 }
111
output_nr() const112 int64_t TensorBase::output_nr() const {
113 return impl::GetVariableHooks()->output_nr(*this);
114 }
115
set_data(const TensorBase & new_data) const116 void TensorBase::set_data(const TensorBase & new_data) const {
117 impl::GetVariableHooks()->set_data(*this, new_data);
118 }
119
data() const120 TensorBase TensorBase::data() const {
121 return impl::GetVariableHooks()->data(*this);
122 }
123
_version() const124 int64_t TensorBase::_version() const {
125 return impl::GetVariableHooks()->_version(*this);
126 }
127
retain_grad() const128 void TensorBase::retain_grad() const {
129 impl::GetVariableHooks()->retain_grad(*this);
130 }
131
retains_grad() const132 bool TensorBase::retains_grad() const {
133 return impl::GetVariableHooks()->retains_grad(*this);
134 }
135
_backward(TensorList inputs,const std::optional<Tensor> & gradient,std::optional<bool> keep_graph,bool create_graph) const136 void Tensor::_backward(TensorList inputs,
137 const std::optional<Tensor>& gradient,
138 std::optional<bool> keep_graph,
139 bool create_graph) const {
140 return impl::GetVariableHooks()->_backward(*this, inputs, gradient, keep_graph, create_graph);
141 }
142
requires_grad_(bool _requires_grad) const143 const TensorBase& TensorBase::requires_grad_(bool _requires_grad) const {
144 impl::GetVariableHooks()->requires_grad_(*this, _requires_grad);
145 return *this;
146 }
147
148 // View Methods
149 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
150
is_view() const151 bool TensorBase::is_view() const {
152 return impl::GetVariableHooks()->is_view(*this);
153 }
154
_base() const155 const TensorBase& TensorBase::_base() const {
156 return impl::GetVariableHooks()->base(*this);
157 }
158
name() const159 const std::string& TensorBase::name() const {
160 return impl::GetVariableHooks()->name(*this);
161 }
162
grad_fn() const163 const std::shared_ptr<torch::autograd::Node>& TensorBase::grad_fn() const {
164 return impl::GetVariableHooks()->grad_fn(*this);
165 }
166
remove_hook(unsigned pos) const167 void TensorBase::remove_hook(unsigned pos) const {
168 impl::GetVariableHooks()->remove_hook(*this, pos);
169 }
170
_register_hook(std::function<TensorBase (const TensorBase &)> hook) const171 unsigned TensorBase::_register_hook(std::function<TensorBase(const TensorBase&)> hook) const {
172 return impl::GetVariableHooks()->_register_hook(*this, std::move(hook));
173 }
174
175 } // namespace at
176