xref: /aosp_15_r20/external/pytorch/torch/csrc/lazy/core/tensor_impl.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/csrc/lazy/core/tensor_impl.h>

#include <c10/core/Allocator.h>
#include <c10/core/ScalarType.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/macros/Macros.h>
#include <c10/util/irange.h>
#include <torch/csrc/lazy/core/ir_builder.h>
#include <torch/csrc/lazy/core/tensor_util.h>

namespace torch {
namespace lazy {
namespace {

// LTCGuardImpl is used by CompositeExplicitAutograd ops or eager fallbacks to
// make sure that certain tensors within the lifetime of the guard are on the
// same device. For example, in RegisterCompositeExplicitAutograd.cpp, the
// outputs of each op are checked to be on the same device as the supplied
// TensorOptions. For more information, see DeviceGuard.h. For ops that have
// LTC native function implementations, this guard is omitted.
thread_local c10::Device g_device(c10::DeviceType::Lazy);

struct LTCGuardImpl : public c10::impl::DeviceGuardImplInterface {
  at::DeviceType type() const override {
    return at::DeviceType::Lazy;
  }

  c10::Device exchangeDevice(c10::Device device) const override {
    TORCH_INTERNAL_ASSERT(device.type() == c10::DeviceType::Lazy);
    auto old_device = g_device;
    g_device = device;
    return old_device;
  }

  c10::Device getDevice() const override {
    return g_device;
  }

  void setDevice(c10::Device device) const override {
    TORCH_INTERNAL_ASSERT(device.type() == c10::DeviceType::Lazy);
    g_device = device;
  }

  void uncheckedSetDevice(c10::Device device) const noexcept override {
    TORCH_INTERNAL_ASSERT(device.type() == c10::DeviceType::Lazy);
    g_device = device;
  }

  c10::Stream getStream(c10::Device device) const noexcept override {
    TORCH_INTERNAL_ASSERT(device.type() == c10::DeviceType::Lazy);
    return c10::Stream(c10::Stream::DEFAULT, device);
  }

  c10::Stream exchangeStream(c10::Stream _unused) const noexcept override {
    return c10::Stream(c10::Stream::DEFAULT, g_device);
  }

  c10::DeviceIndex deviceCount() const noexcept override {
    // This will get called when autograd initializes its device pool,
    // regardless of whether we have a backend registered beforehand.
    if (!hasBackend()) {
      return 0;
    }

    return getBackend()->GetBackendDevices().size();
  }
};

C10_REGISTER_GUARD_IMPL(Lazy, LTCGuardImpl);
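
// Illustrative sketch (not code in this file): with LTCGuardImpl registered
// above, a c10::DeviceGuard constructed for a lazy device dispatches to it,
// e.g.
//
//   c10::Device lazy_device(c10::DeviceType::Lazy, /*index=*/0);
//   c10::DeviceGuard guard(lazy_device);  // calls LTCGuardImpl::exchangeDevice
//   // ... run CompositeExplicitAutograd ops / eager fallbacks here ...
//   // On destruction, the guard restores the previous thread-local g_device.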

} // namespace

// TODO(whc) when do we want to clone vs share?
LTCTensorImpl::LTCTensorImpl(const LazyTensorPtr& tensor)
    : LTCTensorImpl(LazyTensor(*tensor)) {}

LTCTensorImpl::LTCTensorImpl(const LazyTensor& tensor)
    : LTCTensorImpl(LazyTensor(tensor)) {}

LTCTensorImpl::LTCTensorImpl(LazyTensor&& tensor)
    : c10::TensorImpl(
          c10::DispatchKeySet{
              c10::DispatchKey::Lazy,
              c10::DispatchKey::AutogradLazy},
          c10::scalarTypeToTypeMeta(tensor.dtype()),
          backendDeviceToAtenDevice(tensor.GetDevice())),
      tensor_(c10::make_intrusive<LazyTensor>(std::move(tensor))) {
  set_custom_sizes_strides(SizesStridesPolicy::CustomSizes);
}
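
// Illustrative sketch (assumption, not part of this translation unit): other
// lazy-tensor code typically wraps an LTCTensorImpl into an at::Tensor roughly
// as follows (the helper name is hypothetical here):
//
//   at::Tensor WrapLazyTensor(const LazyTensorPtr& lazy) {
//     return at::Tensor(c10::make_intrusive<LTCTensorImpl>(lazy));
//   }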

void LTCTensorImpl::set_tensor(const LazyTensorPtr& lazy_tensor) {
  tensor_ = c10::make_intrusive<LazyTensor>(*lazy_tensor);
  generation_ = 0;
}

c10::intrusive_ptr<c10::TensorImpl> LTCTensorImpl::shallow_copy_and_detach(
    const c10::VariableVersion& version_counter,
    bool allow_tensor_metadata_change) const {
  auto impl = c10::make_intrusive<LTCTensorImpl>(tensor_);
  copy_tensor_metadata(
      /*src_impl=*/this,
      /*dest_impl=*/impl.get(),
      /*version_counter=*/version_counter,
      /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
  return impl;
}

c10::intrusive_ptr<c10::TensorImpl> LTCTensorImpl::shallow_copy_and_detach(
    c10::VariableVersion&& version_counter,
    bool allow_tensor_metadata_change) const {
  auto impl = c10::make_intrusive<LTCTensorImpl>(tensor_);
  copy_tensor_metadata(
      /*src_impl=*/this,
      /*dest_impl=*/impl.get(),
      /*version_counter=*/std::move(version_counter),
      /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
  return impl;
}

void LTCTensorImpl::shallow_copy_from(
    const c10::intrusive_ptr<TensorImpl>& impl) {
  LTCTensorImpl* ltc_impl = dynamic_cast<LTCTensorImpl*>(impl.get());
  TORCH_INTERNAL_ASSERT(ltc_impl);
  copy_tensor_metadata(
      /*src_impl=*/ltc_impl,
      /*dest_impl=*/this,
      /*version_counter=*/version_counter(),
      /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
  ltc_impl->tensor_->ShallowCopyTo(tensor_);
  generation_ = 0;
}
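
// Illustrative sketch (assumption): autograd typically reaches the overloads
// above for operations along the lines of
//
//   at::Tensor detached = lazy_tensor.detach();   // shallow_copy_and_detach
//   lazy_tensor.set_data(other_lazy_tensor);      // shallow_copy_from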

c10::SymIntArrayRef LTCTensorImpl::sym_strides_custom() const {
  return c10::fromIntArrayRefKnownNonNegative(strides_custom());
}

c10::SymIntArrayRef LTCTensorImpl::sym_sizes_custom() const {
  return c10::fromIntArrayRefKnownNonNegative(sizes_custom());
}

c10::SymInt LTCTensorImpl::sym_numel_custom() const {
  return numel_custom();
}

void LTCTensorImpl::setup_size_properties() {
  size_t generation = tensor_->generation();
  if (generation != generation_) {
    // Fill up the basic dimension data members which the base class
    // implementation uses in its APIs.
    auto shape = tensor_->shape();
    // We can't call refresh_numel() given we override sizes() too.
    numel_ = shape.Get().numel();
    sizes_and_strides_.set_sizes(shape.Get().sizes());
    // We can't call empty_tensor_restride(c10::MemoryFormat::Contiguous) given
    // we override sizes() too.
    std::vector<int64_t> updated_strides =
        ComputeArrayStrides(shape.Get().sizes());
    for (const auto i : c10::irange(updated_strides.size())) {
      sizes_and_strides_.stride_at_unchecked(i) = updated_strides[i];
    }
    generation_ = generation;
  }
}
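
// Worked example (illustrative, assuming ComputeArrayStrides yields row-major
// contiguous strides): sizes [2, 3, 4] produce strides [12, 4, 1], i.e.
// stride[i] is the product of sizes[i+1..].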

at::IntArrayRef LTCTensorImpl::sizes_custom() const {
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  const_cast<LTCTensorImpl*>(this)->setup_size_properties();
  return sizes_default();
}

at::IntArrayRef LTCTensorImpl::strides_custom() const {
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  const_cast<LTCTensorImpl*>(this)->setup_size_properties();
  return strides_default();
}

int64_t LTCTensorImpl::dim_custom() const {
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  const_cast<LTCTensorImpl*>(this)->setup_size_properties();
  return dim_default();
}

int64_t LTCTensorImpl::numel_custom() const {
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  const_cast<LTCTensorImpl*>(this)->setup_size_properties();
  return numel_default();
}

int64_t LTCTensorImpl::storage_offset_custom() const {
  return 0;
}

bool LTCTensorImpl::is_strides_like_custom(
    c10::MemoryFormat memory_format) const {
  TORCH_INTERNAL_ASSERT(memory_format != at::MemoryFormat::Contiguous);
  return false;
}

bool LTCTensorImpl::is_non_overlapping_and_dense_custom() const {
  // This should be true, but we return false as a temporary fix for a PyTorch
  // core issue (see https://github.com/pytorch/xla/pull/2682).
  return false;
}

bool LTCTensorImpl::is_contiguous_custom(c10::MemoryFormat _unused) const {
  // TODO(ezyang): I don't think this branch is actually necessary
  // TODO(ezyang): I don't think this logic is right, shouldn't we pass on
  // the memory format?
  if (tensor_->CurrentTensorData()) {
    return tensor_->CurrentTensorData()->is_contiguous();
  }
  // Only check that the storage is already contiguous.
  TORCH_CHECK(is_contiguous_, "Non-contiguous storage for lazy tensor");
  // TODO: I don't think this logic is right; we should check the requested
  // memory format before returning true.
  return true;
}

} // namespace lazy
} // namespace torch