#include <torch/csrc/lazy/core/config.h>
#include <torch/csrc/lazy/core/tensor.h>

#include <c10/util/irange.h>
#include <torch/csrc/lazy/core/helpers.h>
#include <torch/csrc/lazy/core/ir_builder.h>
#include <torch/csrc/lazy/core/ir_dump_util.h>
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
#include <torch/csrc/lazy/core/metrics.h>
#include <torch/csrc/lazy/core/tensor_impl.h>
#include <torch/csrc/lazy/core/tensor_util.h>

#include <ATen/FunctionalTensorWrapper.h>

namespace torch {
namespace lazy {
namespace {
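// Returns the lazy tensor already backing `tensor`, or creates a new one on
// `device`. An undefined tensor maps to a null LazyTensorPtr.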
LazyTensorPtr GetOrCreateLtcTensor(
    const at::Tensor& tensor,
    const BackendDevice& device) {
  if (!tensor.defined()) {
    return torch::lazy::LazyTensorPtr();
  }
  auto lazy_tensor = TryGetLtcTensor(tensor);
  return lazy_tensor ? lazy_tensor : LazyTensor::Create(tensor, device);
}
} // namespace

LazyTensor::Data::~Data() {
  LazyGraphExecutor::Get()->UnregisterTensor(this);
}

LazyTensorPtr LazyTensor::Create(
    const at::Tensor& tensor,
    const BackendDevice& device) {
  TORCH_CHECK(tensor.device().type() != at::kLazy);
  LazyTensorPtr lazy_tensor =
      c10::make_intrusive<LazyTensor>(LazyTensor(tensor, device));
  LazyGraphExecutor::Get()->RegisterTensor(lazy_tensor->data());
  return lazy_tensor;
}

LazyTensorPtr LazyTensor::Create(Value ir_value, const BackendDevice& device) {
  LazyTensorPtr lazy_tensor =
      c10::make_intrusive<LazyTensor>(LazyTensor(std::move(ir_value), device));
  LazyGraphExecutor::Get()->RegisterTensor(lazy_tensor->data());
  return lazy_tensor;
}

LazyTensorPtr LazyTensor::Create(BackendDataPtr handle) {
  LazyTensorPtr lazy_tensor =
      c10::make_intrusive<LazyTensor>(LazyTensor(std::move(handle)));
  LazyGraphExecutor::Get()->RegisterTensor(lazy_tensor->data());
  return lazy_tensor;
}

LazyTensorPtr LazyTensor::Create(std::shared_ptr<Data> data) {
  return c10::make_intrusive<LazyTensor>(LazyTensor(std::move(data)));
}

LazyTensor::LazyTensor(const at::Tensor& tensor, const BackendDevice& device)
    : LazyTensor(std::make_shared<Data>(tensor, device)) {}

LazyTensor::LazyTensor(BackendDataPtr handle)
    : LazyTensor(std::make_shared<Data>(handle, handle->device())) {}

LazyTensor::LazyTensor(Value ir_value, const BackendDevice& device)
    : LazyTensor(std::make_shared<Data>(std::move(ir_value), device)) {
  TryLimitGraphSize();
}

LazyTensor::LazyTensor(std::shared_ptr<Data> data) : data_(std::move(data)) {}

auto LazyTensor::data() const -> const std::shared_ptr<Data>& {
  TORCH_CHECK(data_ != nullptr, "Trying to access a null cursor");
  return data_;
}

int64_t LazyTensor::size(int64_t dim) const {
  auto tensor_shape = shape();
  int rank = tensor_shape.Get().dim();
  int dim_index = GetCanonicalDimensionIndex(dim, rank);
  return tensor_shape.Get().size(dim_index);
}

at::ScalarType LazyTensor::dtype() const {
  return shape().Get().scalar_type();
}
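
// Shape is resolved from the backend data handle when present, then from the
// pending IR value, and finally from the locally stored at::Tensor data.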
MaybeRef<Shape> LazyTensor::shape() const {
  if (data()->handle != nullptr) {
    return Shape(data()->handle->shape());
  }
  if (data()->ir_value) {
    // TODO(whc) remove shape from LazyTensor API too!
    return data()->ir_value.shape();
  }
  TORCH_CHECK(data()->tensor_data);
  return Shape(
      data()->tensor_data->scalar_type(),
      ToI64Vector(data()->tensor_data->sizes()));
}

const BackendDevice& LazyTensor::GetDevice() const {
  return data()->device;
}

int64_t LazyTensor::GetUniqueId() const {
  return data()->unique_id;
}
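
// Materializes the backend data for this tensor: reuses the current handle if
// it already has a value, otherwise executes the pending graph or uploads the
// locally stored tensor data.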
BackendDataPtr LazyTensor::GetDataHandle() {
  BackendDataPtr handle = CurrentDataHandle();
  if (handle != nullptr) {
    TORCH_CHECK(
        handle->HasValue(),
        "Trying to access data while an async operation is in flight: ",
        handle->shape().to_string());
    return handle;
  }

  if (data()->ir_value) {
    ApplyPendingGraph();
  } else {
    TORCH_CHECK(data()->tensor_data);
    data()->handle = TensorToDataHandle(*data()->tensor_data, GetDevice());
  }

  return data()->handle;
}

BackendDataPtr LazyTensor::CurrentDataHandle() const {
  return data()->handle;
}

void LazyTensor::SetDataHandle(BackendDataPtr handle) {
  SetDataHandle(std::move(handle), /*sync=*/true);
}

void LazyTensor::SetDataHandle(BackendDataPtr handle, bool sync) {
  data()->handle = std::move(handle);
  // Assigning a device data should always clear the IR node, to allow graph
  // trimming.
  AssignIrValue(Value());
  if (sync) {
    data()->tensor_data = std::nullopt;
  }
}

void LazyTensor::SetIrValue(Value ir_value) {
  data()->handle = nullptr;
  data()->tensor_data = std::nullopt;
  AssignIrValue(std::move(ir_value));
  TryLimitGraphSize();
}

void LazyTensor::SetInPlaceIrValue(Value ir_value) {
  auto tensor_shape = shape();
  if (tensor_shape.Get().scalar_type() != ir_value.shape().scalar_type()) {
    ir_value =
        MakeCast(ir_value, tensor_shape.Get().scalar_type(), std::nullopt);
  }
  SetIrValue(std::move(ir_value));
}

void LazyTensor::AssignIrValue(Value ir_value) const {
  data()->ir_value = std::move(ir_value);
  data()->generation += 1;
}
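
// Every FLAGS_torch_lazy_trim_graph_check_frequency increments of the trim
// counter, flushes the pending graph via ApplyPendingGraph() once it grows
// beyond FLAGS_torch_lazy_trim_graph_size.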
void LazyTensor::TryLimitGraphSize() {
  if (data()->ir_value &&
      LazyGraphExecutor::Get()->IncTrimCounter() %
              FLAGS_torch_lazy_trim_graph_check_frequency ==
          0) {
    size_t graph_size = Util::GetGraphSize({data()->ir_value.node.get()});
    if (static_cast<int64_t>(graph_size) > FLAGS_torch_lazy_trim_graph_size) {
      TORCH_LAZY_COUNTER("TrimIrGraph", 1);
      ApplyPendingGraph();
    }
  }
}

Value LazyTensor::GetIrValue() const {
  Value ir_value = CurrentIrValue();
  if (ir_value) {
    return ir_value;
  }
  BackendDataPtr handle = CurrentDataHandle();
  if (handle != nullptr) {
    // In the case of a tensor node, we do not clear the data when we set the
    // IR node. This is because we want further calls to GetIrValue() to fetch
    // the same IR node, and not create new ones (even though the lowering
    // context will still collapse them all into a single parameter op). So the
    // call which wants the data will still find it, without having to fetch it
    // again through a computation-client from-server call.
    AssignIrValue(CreateTensorNode(handle, /*read_only=*/false));
    return data()->ir_value;
  }
  std::optional<at::Tensor> tensor_data = CurrentTensorData();
  TORCH_CHECK(tensor_data);
  AssignIrValue(GetIrValueForTensor(*tensor_data, GetDevice()));
  return data()->ir_value;
}

Value LazyTensor::CurrentIrValue() const {
  return data()->ir_value;
}

void LazyTensor::SetTensorData(at::Tensor tensor_data) {
  data()->tensor_data = std::move(tensor_data);
}

std::optional<at::Tensor> LazyTensor::CurrentTensorData() const {
  return data()->tensor_data;
}
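
// Special scalars are lowered to scalar IR nodes. Other zero-dim scalars are
// fetched (read-only) through LazyGraphExecutor::GetDeviceData(), while full
// tensors are uploaded with TensorToDataHandle(); both become device-data
// nodes.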
Value LazyTensor::GetIrValueForTensor(
    const at::Tensor& tensor,
    const BackendDevice& device) const {
  BackendDataPtr data;
  bool read_only = false;
  if (tensor.dim() == 0 && tensor.numel() == 1) {
    at::Scalar value = tensor.item();
    if (IsSpecialScalar(value)) {
      return MakeScalar(value, tensor.scalar_type());
    }
    data = LazyGraphExecutor::Get()->GetDeviceData(tensor.cpu(), device);
    read_only = true;
  } else {
    TORCH_LAZY_TIMED("IrValueTensorToDataHandle");
    data = TensorToDataHandle(tensor, device);
  }
  return CreateTensorNode(std::move(data), read_only);
}
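
// Returns an at::Tensor holding this tensor's value, syncing from the device
// when no local tensor data is cached. With detached=true the result never
// aliases the internally cached tensor data.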
at::Tensor LazyTensor::ToTensor(bool detached) {
  at::Tensor tensor;
  std::optional<at::Tensor> tensor_data = CurrentTensorData();
  if (!tensor_data) {
    LazyGraphExecutor::Get()->DeviceBarrier(GetDevice());
    // The GetDataHandle() call will trigger an ApplyPendingGraph() if an IR
    // Node is available on the tensor.
    std::vector<at::Tensor> tensors =
        DataHandlesToTensors({GetDataHandle()}, dtype());
    tensor = std::move(tensors.front());
    if (!detached) {
      SetTensorData(tensor);
    }
  } else {
    tensor = *tensor_data;
    if (detached) {
      if (data()->ir_value || data()->handle != nullptr) {
        // If we have other authoritative sources, just drop our reference and
        // transfer it to the caller.
        data()->tensor_data = std::nullopt;
      } else {
        // Otherwise we need to make a copy to prevent the caller from changing
        // our version.
        tensor = CopyTensor(tensor);
      }
    }
  }
  return tensor;
}

void LazyTensor::ShallowCopyTo(LazyTensorPtr dest) const {
  dest->SetIrValue(GetIrValue());
}

void LazyTensor::SetTensor(at::Tensor tensor) {
  SetTensorData(tensor);
  data()->handle = nullptr;
  AssignIrValue(Value());
}

void LazyTensor::UpdateFromTensor(at::Tensor tensor, bool sync) {
  if (sync) {
    at::Tensor typed_tensor = CopyTensor(tensor, dtype(), /*copy=*/false);
    SetIrValue(GetIrValueForTensor(typed_tensor, GetDevice()));
  } else {
    SetTensorData(tensor);
    data()->handle = nullptr;
    AssignIrValue(Value());
  }
}

void LazyTensor::UpdateFromTensorOut(at::Tensor tensor) {
  UpdateFromTensor(std::move(tensor), /*sync=*/false);
}

void LazyTensor::UpdateFromTensorOut(const LazyTensorPtr& tensor) {
  SetIrValue(tensor->GetIrValue());
}

Value LazyTensor::CreateTensorNode(BackendDataPtr data, bool read_only) const {
  data->SetInfo(std::make_shared<LazyGraphExecutor::DeviceDataInfo>(
      GetUniqueId(), read_only));
  return MakeDeviceData(std::move(data));
}

std::vector<LazyTensorPtr> LazyTensor::MakeOutputTensors(NodePtr node) const {
  std::vector<LazyTensorPtr> tensors;
  tensors.reserve(node->num_outputs());
  for (const auto i : c10::irange(node->num_outputs())) {
    tensors.push_back(Create(Value(node, i), GetDevice()));
  }
  return tensors;
}

LazyTensorPtr LazyTensor::CopyTensorToDevice(const BackendDevice& device) {
  // TODO: This can be optimized.
  return Create(ToTensor(/*detached=*/true), device);
}

void LazyTensor::ApplyPendingGraph() {
  LazyGraphExecutor::Get()->DeviceBarrier(GetDevice());
  // This method is called to ensure that the tensor data is available on
  // device, so that a call to CurrentDataHandle() returns a valid pointer.
  if (CurrentDataHandle() == nullptr) {
    std::vector<LazyTensorPtr> tensors(
        {c10::make_intrusive<LazyTensor>(LazyTensor(*this))});
    LazyGraphExecutor::Get()->SyncTensorsGraph(
        &tensors,
        {},
        /*wait=*/true,
        /*sync_ltc_data=*/false);
  }
}

int64_t LazyTensor::GetNextTensorId() {
  static std::atomic<int64_t>* id_generator = new std::atomic<int64_t>(1);
  return id_generator->fetch_add(1);
}
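
// Packs the IR values of a list of lazy tensors into a single TensorList IR
// value; every element must already be backed by an LTCTensorImpl.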
torch::lazy::Value GetTensorList(at::ITensorListRef tensors) {
  std::vector<Value> values;
  for (const auto& t : tensors) {
    auto* impl = dynamic_cast<LTCTensorImpl*>(t.unsafeGetTensorImpl());
    TORCH_INTERNAL_ASSERT(
        impl,
        "GetTensorList only supports lists of valid tensors, but optional support could be added");
    values.push_back(impl->tensor()->GetIrValue());
  }

  return torch::lazy::Value(torch::lazy::MakeTensorList(std::move(values)));
}

LazyTensorPtr TryGetLtcTensor(const at::Tensor& tensor) {
  auto* impl = dynamic_cast<LTCTensorImpl*>(
      maybe_unwrap_functional(tensor).unsafeGetTensorImpl());
  if (impl == nullptr) {
    // return c10::make_intrusive<LazyTensor>();
    return LazyTensorPtr();
  }
  return impl->tensor();
}

LazyTensorPtr GetLtcTensor(const at::Tensor& tensor) {
  auto lazy_tensor = TryGetLtcTensor(tensor);
  TORCH_CHECK(
      lazy_tensor, "Input tensor is not a lazy tensor: ", tensor.toString());
  return lazy_tensor;
}

std::vector<LazyTensorPtr> GetLtcTensors(c10::ArrayRef<at::Tensor> tensors) {
  std::vector<LazyTensorPtr> ltc_tensors;
  ltc_tensors.reserve(tensors.size());
  for (const auto& tensor : tensors) {
    ltc_tensors.emplace_back(TryGetLtcTensor(tensor));
  }
  return ltc_tensors;
}

LazyTensorPtr GetOrCreateLtcTensor(
    const std::optional<at::Tensor>& tensor,
    const BackendDevice& device) {
  return GetOrCreateLtcTensor(tensor.value_or(at::Tensor()), device);
}
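
// Wrapped numbers (and zero-dim, single-element tensors) are allowed to live
// on an eager device and are converted on demand; any other input must
// already be a lazy tensor.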
LazyTensorPtr GetLtcTensorOrCreateForWrappedNumber(
    const at::Tensor& tensor,
    const BackendDevice& device) {
  // TODO: There are places in core where a scalar is wrapped but not marked as
  // wrapped.
  return (tensor.unsafeGetTensorImpl()->is_wrapped_number() ||
          (tensor.dim() == 0 && tensor.numel() == 1))
      ? GetOrCreateLtcTensor(tensor, device)
      : GetLtcTensor(tensor);
}

at::Tensor CreateAtenFromLtcTensor(const LazyTensorPtr& ltc_tensor) {
  return ltc_tensor ? at::Tensor(c10::make_intrusive<LTCTensorImpl>(ltc_tensor))
                    : at::Tensor();
}

at::Tensor CreateAtenFromLtcTensor(LazyTensor&& ltc_tensor) {
  return at::Tensor(c10::make_intrusive<LTCTensorImpl>(std::move(ltc_tensor)));
}
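
// Copies an eager tensor onto the lazy device, optionally wrapping the result
// in a functional tensor (see Note [Lazy Tensor Functionalization]).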
at::Tensor to_lazy_tensor(
    const at::Tensor& self,
    const c10::TensorOptions& options,
    at::Device device,
    bool non_blocking,
    bool functionalize_output) {
  TORCH_INTERNAL_ASSERT(self.device().type() != c10::kLazy);
  TORCH_INTERNAL_ASSERT(device.type() == c10::kLazy);

  auto eager_tensor =
      self.to(options, /*non_blocking=*/non_blocking, /*copy=*/true);
  auto lazy_self = torch::lazy::GetOrCreateLtcTensor(
      eager_tensor, torch::lazy::atenDeviceToBackendDevice(device));
  auto out = torch::lazy::CreateAtenFromLtcTensor(lazy_self);
  if (functionalize_output) {
    // See Note [Lazy Tensor Functionalization]
    return at::functionalization::impl::to_functional_tensor(out);
  } else {
    return out;
  }
}

} // namespace lazy
} // namespace torch