#include <torch/csrc/lazy/core/tensor_impl.h>

#include <c10/core/Allocator.h>
#include <c10/core/ScalarType.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/macros/Macros.h>
#include <c10/util/irange.h>
#include <torch/csrc/lazy/core/ir_builder.h>
#include <torch/csrc/lazy/core/tensor_util.h>

namespace torch {
namespace lazy {
namespace {

// LTCGuardImpl is used by CompositeExplicitAutograd ops or eager fallbacks to
// make sure that particular tensors within the lifetime of the guard are on
// the same device. For example, in RegisterCompositeExplicitAutograd.cpp, the
// outputs of each op are checked to be on the same device as the supplied
// TensorOptions. For more information, see DeviceGuard.h. For ops that have
// LTC native function implementations, this guard is omitted.
thread_local c10::Device g_device(c10::DeviceType::Lazy);

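// As a rough sketch of the call path: constructing a guard such as
// c10::DeviceGuard guard(c10::Device(c10::DeviceType::Lazy, 0)); dispatches to
// the LTCGuardImpl registered below, which swaps the thread-local g_device.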
struct LTCGuardImpl : public c10::impl::DeviceGuardImplInterface {
  at::DeviceType type() const override {
    return at::DeviceType::Lazy;
  }

  c10::Device exchangeDevice(c10::Device device) const override {
    TORCH_INTERNAL_ASSERT(device.type() == c10::DeviceType::Lazy);
    auto old_device = g_device;
    g_device = device;
    return old_device;
  }

  c10::Device getDevice() const override {
    return g_device;
  }

  void setDevice(c10::Device device) const override {
    TORCH_INTERNAL_ASSERT(device.type() == c10::DeviceType::Lazy);
    g_device = device;
  }

  void uncheckedSetDevice(c10::Device device) const noexcept override {
    TORCH_INTERNAL_ASSERT(device.type() == c10::DeviceType::Lazy);
    g_device = device;
  }

  c10::Stream getStream(c10::Device device) const noexcept override {
    TORCH_INTERNAL_ASSERT(device.type() == c10::DeviceType::Lazy);
    return c10::Stream(c10::Stream::DEFAULT, device);
  }

  c10::Stream exchangeStream(c10::Stream _unused) const noexcept override {
    return c10::Stream(c10::Stream::DEFAULT, g_device);
  }

  c10::DeviceIndex deviceCount() const noexcept override {
    // This gets called when autograd initializes its device pool, regardless
    // of whether a backend has been registered beforehand.
    if (!hasBackend()) {
      return 0;
    }

    return getBackend()->GetBackendDevices().size();
  }
};

C10_REGISTER_GUARD_IMPL(Lazy, LTCGuardImpl);

} // namespace

// TODO(whc) when do we want to clone vs share?
LTCTensorImpl::LTCTensorImpl(const LazyTensorPtr& tensor)
    : LTCTensorImpl(LazyTensor(*tensor)) {}

LTCTensorImpl::LTCTensorImpl(const LazyTensor& tensor)
    : LTCTensorImpl(LazyTensor(tensor)) {}

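// The LazyTensor&& overload does the real work: it registers the Lazy and
// AutogradLazy dispatch keys, derives dtype/device from the wrapped
// LazyTensor, and opts into the CustomSizes policy so sizes and strides are
// answered by the *_custom() overrides below.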
LTCTensorImpl::LTCTensorImpl(LazyTensor&& tensor)
    : c10::TensorImpl(
          c10::DispatchKeySet{
              c10::DispatchKey::Lazy,
              c10::DispatchKey::AutogradLazy},
          c10::scalarTypeToTypeMeta(tensor.dtype()),
          backendDeviceToAtenDevice(tensor.GetDevice())),
      tensor_(c10::make_intrusive<LazyTensor>(std::move(tensor))) {
  set_custom_sizes_strides(SizesStridesPolicy::CustomSizes);
}

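// Swaps in a new LazyTensor and resets generation_ so that the cached sizes
// and strides are recomputed on the next access (see setup_size_properties()).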
void LTCTensorImpl::set_tensor(const LazyTensorPtr& lazy_tensor) {
  tensor_ = c10::make_intrusive<LazyTensor>(*lazy_tensor);
  generation_ = 0;
}

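// Detaching builds a fresh LTCTensorImpl around the same underlying lazy
// tensor data and copies the TensorImpl metadata, along with the supplied
// version counter and metadata-mutability flag.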
c10::intrusive_ptr<c10::TensorImpl> LTCTensorImpl::shallow_copy_and_detach(
    const c10::VariableVersion& version_counter,
    bool allow_tensor_metadata_change) const {
  auto impl = c10::make_intrusive<LTCTensorImpl>(tensor_);
  copy_tensor_metadata(
      /*src_impl=*/this,
      /*dest_impl=*/impl.get(),
      /*version_counter=*/version_counter,
      /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
  return impl;
}

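// Same as above, but takes ownership of an rvalue version counter.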
c10::intrusive_ptr<c10::TensorImpl> LTCTensorImpl::shallow_copy_and_detach(
    c10::VariableVersion&& version_counter,
    bool allow_tensor_metadata_change) const {
  auto impl = c10::make_intrusive<LTCTensorImpl>(tensor_);
  copy_tensor_metadata(
      /*src_impl=*/this,
      /*dest_impl=*/impl.get(),
      /*version_counter=*/std::move(version_counter),
      /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
  return impl;
}

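// In-place counterpart of shallow_copy_and_detach(): copies metadata from the
// source impl into this one and shallow-copies the wrapped lazy tensor.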
void LTCTensorImpl::shallow_copy_from(
    const c10::intrusive_ptr<TensorImpl>& impl) {
  LTCTensorImpl* ltc_impl = dynamic_cast<LTCTensorImpl*>(impl.get());
  TORCH_INTERNAL_ASSERT(ltc_impl);
  copy_tensor_metadata(
      /*src_impl=*/ltc_impl,
      /*dest_impl=*/this,
      /*version_counter=*/version_counter(),
      /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
  ltc_impl->tensor_->ShallowCopyTo(tensor_);
  generation_ = 0;
}

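// Lazy tensor shapes are concrete integers here, so the SymInt accessors
// simply wrap the sizes/strides reported below.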
c10::SymIntArrayRef LTCTensorImpl::sym_strides_custom() const {
  return c10::fromIntArrayRefKnownNonNegative(strides_custom());
}

c10::SymIntArrayRef LTCTensorImpl::sym_sizes_custom() const {
  return c10::fromIntArrayRefKnownNonNegative(sizes_custom());
}

c10::SymInt LTCTensorImpl::sym_numel_custom() const {
  return numel_custom();
}

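// Refreshes the cached numel_ and sizes_and_strides_ from the LazyTensor's
// shape. The work is keyed on the tensor's generation counter, so repeated
// queries are cheap until the wrapped tensor's generation changes.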
void LTCTensorImpl::setup_size_properties() {
  size_t generation = tensor_->generation();
  if (generation != generation_) {
    // Fill up the basic dimension data members which the base class
    // implementation uses in its APIs.
    auto shape = tensor_->shape();
    // We can't call refresh_numel() since we override sizes() too.
    numel_ = shape.Get().numel();
    sizes_and_strides_.set_sizes(shape.Get().sizes());
    // We can't call empty_tensor_restride(c10::MemoryFormat::Contiguous) since
    // we override sizes() too.
    std::vector<int64_t> updated_strides =
        ComputeArrayStrides(shape.Get().sizes());
    for (const auto i : c10::irange(updated_strides.size())) {
      sizes_and_strides_.stride_at_unchecked(i) = updated_strides[i];
    }
    generation_ = generation;
  }
}

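// Each *_custom() accessor first refreshes the cached size properties and then
// defers to the corresponding *_default() implementation in c10::TensorImpl.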
at::IntArrayRef LTCTensorImpl::sizes_custom() const {
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  const_cast<LTCTensorImpl*>(this)->setup_size_properties();
  return sizes_default();
}

at::IntArrayRef LTCTensorImpl::strides_custom() const {
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  const_cast<LTCTensorImpl*>(this)->setup_size_properties();
  return strides_default();
}

int64_t LTCTensorImpl::dim_custom() const {
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  const_cast<LTCTensorImpl*>(this)->setup_size_properties();
  return dim_default();
}

int64_t LTCTensorImpl::numel_custom() const {
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  const_cast<LTCTensorImpl*>(this)->setup_size_properties();
  return numel_default();
}

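// A storage offset is not meaningful for a lazy tensor, so report zero.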
int64_t LTCTensorImpl::storage_offset_custom() const {
  return 0;
}

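// Only non-contiguous memory formats (e.g. channels-last) reach this hook, per
// the assert below; lazy tensors never report matching strides.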
bool LTCTensorImpl::is_strides_like_custom(
    c10::MemoryFormat memory_format) const {
  TORCH_INTERNAL_ASSERT(memory_format != at::MemoryFormat::Contiguous);
  return false;
}

bool LTCTensorImpl::is_non_overlapping_and_dense_custom() const {
  // This should be true, but returning false is a temporary workaround for a
  // PyTorch core issue; see https://github.com/pytorch/xla/pull/2682.
  return false;
}

bool LTCTensorImpl::is_contiguous_custom(c10::MemoryFormat _unused) const {
  // TODO(ezyang): I don't think this branch is actually necessary
  // TODO(ezyang): I don't think this logic is right, shouldn't we pass on
  // the memory format?
  if (tensor_->CurrentTensorData()) {
    return tensor_->CurrentTensorData()->is_contiguous();
  }
  // Only check that the storage is already contiguous.
  TORCH_CHECK(is_contiguous_, "Non-contiguous storage for lazy tensor");
  // TODO: I don't think this logic is right; we should check the requested
  // memory format before returning true.
  return true;
}

} // namespace lazy
} // namespace torch