1 #include <torch/csrc/lazy/core/config.h>
2 #include <torch/csrc/lazy/core/tensor.h>
3
4 #include <c10/util/irange.h>
5 #include <torch/csrc/lazy/core/helpers.h>
6 #include <torch/csrc/lazy/core/ir_builder.h>
7 #include <torch/csrc/lazy/core/ir_dump_util.h>
8 #include <torch/csrc/lazy/core/lazy_graph_executor.h>
9 #include <torch/csrc/lazy/core/metrics.h>
10 #include <torch/csrc/lazy/core/tensor_impl.h>
11 #include <torch/csrc/lazy/core/tensor_util.h>
12
13 #include <ATen/FunctionalTensorWrapper.h>
14
15 namespace torch {
16 namespace lazy {
17 namespace {
GetOrCreateLtcTensor(const at::Tensor & tensor,const BackendDevice & device)18 LazyTensorPtr GetOrCreateLtcTensor(
19 const at::Tensor& tensor,
20 const BackendDevice& device) {
21 if (!tensor.defined()) {
22 return torch::lazy::LazyTensorPtr();
23 }
24 auto lazy_tensor = TryGetLtcTensor(tensor);
25 return lazy_tensor ? lazy_tensor : LazyTensor::Create(tensor, device);
26 }
27 } // namespace
28
~Data()29 LazyTensor::Data::~Data() {
30 LazyGraphExecutor::Get()->UnregisterTensor(this);
31 }
32
Create(const at::Tensor & tensor,const BackendDevice & device)33 LazyTensorPtr LazyTensor::Create(
34 const at::Tensor& tensor,
35 const BackendDevice& device) {
36 TORCH_CHECK(tensor.device().type() != at::kLazy);
37 LazyTensorPtr lazy_tensor =
38 c10::make_intrusive<LazyTensor>(LazyTensor(tensor, device));
39 LazyGraphExecutor::Get()->RegisterTensor(lazy_tensor->data());
40 return lazy_tensor;
41 }
42
Create(Value ir_value,const BackendDevice & device)43 LazyTensorPtr LazyTensor::Create(Value ir_value, const BackendDevice& device) {
44 LazyTensorPtr lazy_tensor =
45 c10::make_intrusive<LazyTensor>(LazyTensor(std::move(ir_value), device));
46 LazyGraphExecutor::Get()->RegisterTensor(lazy_tensor->data());
47 return lazy_tensor;
48 }
49
Create(BackendDataPtr handle)50 LazyTensorPtr LazyTensor::Create(BackendDataPtr handle) {
51 LazyTensorPtr lazy_tensor =
52 c10::make_intrusive<LazyTensor>(LazyTensor(std::move(handle)));
53 LazyGraphExecutor::Get()->RegisterTensor(lazy_tensor->data());
54 return lazy_tensor;
55 }
56
Create(std::shared_ptr<Data> data)57 LazyTensorPtr LazyTensor::Create(std::shared_ptr<Data> data) {
58 return c10::make_intrusive<LazyTensor>(LazyTensor(std::move(data)));
59 }
60
LazyTensor(const at::Tensor & tensor,const BackendDevice & device)61 LazyTensor::LazyTensor(const at::Tensor& tensor, const BackendDevice& device)
62 : LazyTensor(std::make_shared<Data>(tensor, device)) {}
63
LazyTensor(BackendDataPtr handle)64 LazyTensor::LazyTensor(BackendDataPtr handle)
65 : LazyTensor(std::make_shared<Data>(handle, handle->device())) {}
66
LazyTensor(Value ir_value,const BackendDevice & device)67 LazyTensor::LazyTensor(Value ir_value, const BackendDevice& device)
68 : LazyTensor(std::make_shared<Data>(std::move(ir_value), device)) {
69 TryLimitGraphSize();
70 }
71
LazyTensor(std::shared_ptr<Data> data)72 LazyTensor::LazyTensor(std::shared_ptr<Data> data) : data_(std::move(data)) {}
73
data() const74 auto LazyTensor::data() const -> const std::shared_ptr<Data>& {
75 TORCH_CHECK(data_ != nullptr, "Trying to access a null cursor");
76 return data_;
77 }
78
size(int64_t dim) const79 int64_t LazyTensor::size(int64_t dim) const {
80 auto tensor_shape = shape();
81 int rank = tensor_shape.Get().dim();
82 int dim_index = GetCanonicalDimensionIndex(dim, rank);
83 return tensor_shape.Get().size(dim_index);
84 }
85
dtype() const86 at::ScalarType LazyTensor::dtype() const {
87 return shape().Get().scalar_type();
88 }
89
shape() const90 MaybeRef<Shape> LazyTensor::shape() const {
91 if (data()->handle != nullptr) {
92 return Shape(data()->handle->shape());
93 }
94 if (data()->ir_value) {
95 // TODO(whc) remove shape from LazyTensor API too!
96 return data()->ir_value.shape();
97 }
98 TORCH_CHECK(data()->tensor_data);
99 return Shape(
100 data()->tensor_data->scalar_type(),
101 ToI64Vector(data()->tensor_data->sizes()));
102 }
103
GetDevice() const104 const BackendDevice& LazyTensor::GetDevice() const {
105 return data()->device;
106 }
107
GetUniqueId() const108 int64_t LazyTensor::GetUniqueId() const {
109 return data()->unique_id;
110 }
111
GetDataHandle()112 BackendDataPtr LazyTensor::GetDataHandle() {
113 BackendDataPtr handle = CurrentDataHandle();
114 if (handle != nullptr) {
115 TORCH_CHECK(
116 handle->HasValue(),
117 "Trying to access data while an async operation is in flight: ",
118 handle->shape().to_string());
119 return handle;
120 }
121
122 if (data()->ir_value) {
123 ApplyPendingGraph();
124 } else {
125 TORCH_CHECK(data()->tensor_data);
126 data()->handle = TensorToDataHandle(*data()->tensor_data, GetDevice());
127 }
128
129 return data()->handle;
130 }
131
CurrentDataHandle() const132 BackendDataPtr LazyTensor::CurrentDataHandle() const {
133 return data()->handle;
134 }
135
SetDataHandle(BackendDataPtr handle)136 void LazyTensor::SetDataHandle(BackendDataPtr handle) {
137 SetDataHandle(std::move(handle), /*sync=*/true);
138 }
139
SetDataHandle(BackendDataPtr handle,bool sync)140 void LazyTensor::SetDataHandle(BackendDataPtr handle, bool sync) {
141 data()->handle = std::move(handle);
142 // Assigning a device data should always clear the IR node, to allow graph
143 // trimming.
144 AssignIrValue(Value());
145 if (sync) {
146 data()->tensor_data = std::nullopt;
147 }
148 }
149
SetIrValue(Value ir_value)150 void LazyTensor::SetIrValue(Value ir_value) {
151 data()->handle = nullptr;
152 data()->tensor_data = std::nullopt;
153 AssignIrValue(std::move(ir_value));
154 TryLimitGraphSize();
155 }
156
SetInPlaceIrValue(Value ir_value)157 void LazyTensor::SetInPlaceIrValue(Value ir_value) {
158 auto tensor_shape = shape();
159 if (tensor_shape.Get().scalar_type() != ir_value.shape().scalar_type()) {
160 ir_value =
161 MakeCast(ir_value, tensor_shape.Get().scalar_type(), std::nullopt);
162 }
163 SetIrValue(std::move(ir_value));
164 }
165
AssignIrValue(Value ir_value) const166 void LazyTensor::AssignIrValue(Value ir_value) const {
167 data()->ir_value = std::move(ir_value);
168 data()->generation += 1;
169 }
170
TryLimitGraphSize()171 void LazyTensor::TryLimitGraphSize() {
172 if (data()->ir_value &&
173 LazyGraphExecutor::Get()->IncTrimCounter() %
174 FLAGS_torch_lazy_trim_graph_check_frequency ==
175 0) {
176 size_t graph_size = Util::GetGraphSize({data()->ir_value.node.get()});
177 if (static_cast<int64_t>(graph_size) > FLAGS_torch_lazy_trim_graph_size) {
178 TORCH_LAZY_COUNTER("TrimIrGraph", 1);
179 ApplyPendingGraph();
180 }
181 }
182 }
183
GetIrValue() const184 Value LazyTensor::GetIrValue() const {
185 Value ir_value = CurrentIrValue();
186 if (ir_value) {
187 return ir_value;
188 }
189 BackendDataPtr handle = CurrentDataHandle();
190 if (handle != nullptr) {
191 // In case of tensor node, we do not clear the data when we set the IR
192 // node. This because we want further calls to GetIrValue() to fetch the
193 // same IR node, and not create new ones (even though the lowering context
194 // will still collapse them all into a single parameter op). So the call
195 // which wants the data will still find it, w/out having to fetch it via
196 // a computation client from-server call.
197 AssignIrValue(CreateTensorNode(handle, /*read_only=*/false));
198 return data()->ir_value;
199 }
200 std::optional<at::Tensor> tensor_data = CurrentTensorData();
201 TORCH_CHECK(tensor_data);
202 AssignIrValue(GetIrValueForTensor(*tensor_data, GetDevice()));
203 return data()->ir_value;
204 }
205
CurrentIrValue() const206 Value LazyTensor::CurrentIrValue() const {
207 return data()->ir_value;
208 }
209
SetTensorData(at::Tensor tensor_data)210 void LazyTensor::SetTensorData(at::Tensor tensor_data) {
211 data()->tensor_data = std::move(tensor_data);
212 }
213
CurrentTensorData() const214 std::optional<at::Tensor> LazyTensor::CurrentTensorData() const {
215 return data()->tensor_data;
216 }
217
GetIrValueForTensor(const at::Tensor & tensor,const BackendDevice & device) const218 Value LazyTensor::GetIrValueForTensor(
219 const at::Tensor& tensor,
220 const BackendDevice& device) const {
221 BackendDataPtr data;
222 bool read_only = false;
223 if (tensor.dim() == 0 && tensor.numel() == 1) {
224 at::Scalar value = tensor.item();
225 if (IsSpecialScalar(value)) {
226 return MakeScalar(value, tensor.scalar_type());
227 }
228 data = LazyGraphExecutor::Get()->GetDeviceData(tensor.cpu(), device);
229 read_only = true;
230 } else {
231 TORCH_LAZY_TIMED("IrValueTensorToDataHandle");
232 data = TensorToDataHandle(tensor, device);
233 }
234 return CreateTensorNode(std::move(data), read_only);
235 }
236
ToTensor(bool detached)237 at::Tensor LazyTensor::ToTensor(bool detached) {
238 at::Tensor tensor;
239 std::optional<at::Tensor> tensor_data = CurrentTensorData();
240 if (!tensor_data) {
241 LazyGraphExecutor::Get()->DeviceBarrier(GetDevice());
242 // The GetDataHandle() call will trigger an ApplyPendingGraph() if an IR
243 // Node is available on the tensor.
244 std::vector<at::Tensor> tensors =
245 DataHandlesToTensors({GetDataHandle()}, dtype());
246 tensor = std::move(tensors.front());
247 if (!detached) {
248 SetTensorData(tensor);
249 }
250 } else {
251 tensor = *tensor_data;
252 if (detached) {
253 if (data()->ir_value || data()->handle != nullptr) {
254 // If we have other authoritive sources, just drop our reference and
255 // transfer it to the caller.
256 data()->tensor_data = std::nullopt;
257 } else {
258 // Otherwise we need to make a copy to prevent the caller changing our
259 // version.
260 tensor = CopyTensor(tensor);
261 }
262 }
263 }
264 return tensor;
265 }
266
ShallowCopyTo(LazyTensorPtr dest) const267 void LazyTensor::ShallowCopyTo(LazyTensorPtr dest) const {
268 dest->SetIrValue(GetIrValue());
269 }
270
SetTensor(at::Tensor tensor)271 void LazyTensor::SetTensor(at::Tensor tensor) {
272 SetTensorData(tensor);
273 data()->handle = nullptr;
274 AssignIrValue(Value());
275 }
276
UpdateFromTensor(at::Tensor tensor,bool sync)277 void LazyTensor::UpdateFromTensor(at::Tensor tensor, bool sync) {
278 if (sync) {
279 at::Tensor typed_tensor = CopyTensor(tensor, dtype(), /*copy=*/false);
280 SetIrValue(GetIrValueForTensor(typed_tensor, GetDevice()));
281 } else {
282 SetTensorData(tensor);
283 data()->handle = nullptr;
284 AssignIrValue(Value());
285 }
286 }
287
UpdateFromTensorOut(at::Tensor tensor)288 void LazyTensor::UpdateFromTensorOut(at::Tensor tensor) {
289 UpdateFromTensor(std::move(tensor), /*sync=*/false);
290 }
291
UpdateFromTensorOut(const LazyTensorPtr & tensor)292 void LazyTensor::UpdateFromTensorOut(const LazyTensorPtr& tensor) {
293 SetIrValue(tensor->GetIrValue());
294 }
295
CreateTensorNode(BackendDataPtr data,bool read_only) const296 Value LazyTensor::CreateTensorNode(BackendDataPtr data, bool read_only) const {
297 data->SetInfo(std::make_shared<LazyGraphExecutor::DeviceDataInfo>(
298 GetUniqueId(), read_only));
299 return MakeDeviceData(std::move(data));
300 }
301
MakeOutputTensors(NodePtr node) const302 std::vector<LazyTensorPtr> LazyTensor::MakeOutputTensors(NodePtr node) const {
303 std::vector<LazyTensorPtr> tensors;
304 tensors.reserve(node->num_outputs());
305 for (const auto i : c10::irange(node->num_outputs())) {
306 tensors.push_back(Create(Value(node, i), GetDevice()));
307 }
308 return tensors;
309 }
310
CopyTensorToDevice(const BackendDevice & device)311 LazyTensorPtr LazyTensor::CopyTensorToDevice(const BackendDevice& device) {
312 // TODO: This can be optimized.
313 return Create(ToTensor(/*detached=*/true), device);
314 }
315
ApplyPendingGraph()316 void LazyTensor::ApplyPendingGraph() {
317 LazyGraphExecutor::Get()->DeviceBarrier(GetDevice());
318 // This method is called to ensure that the tensor data is available on
319 // device, so that a call to CurrentDataHandle() returns a valid pointer.
320 if (CurrentDataHandle() == nullptr) {
321 std::vector<LazyTensorPtr> tensors(
322 {c10::make_intrusive<LazyTensor>(LazyTensor(*this))});
323 LazyGraphExecutor::Get()->SyncTensorsGraph(
324 &tensors,
325 {},
326 /*wait=*/true,
327 /*sync_ltc_data=*/false);
328 }
329 }
330
GetNextTensorId()331 int64_t LazyTensor::GetNextTensorId() {
332 static std::atomic<int64_t>* id_generator = new std::atomic<int64_t>(1);
333 return id_generator->fetch_add(1);
334 }
335
GetTensorList(at::ITensorListRef tensors)336 torch::lazy::Value GetTensorList(at::ITensorListRef tensors) {
337 std::vector<Value> values;
338 for (const auto& t : tensors) {
339 auto* impl = dynamic_cast<LTCTensorImpl*>(t.unsafeGetTensorImpl());
340 TORCH_INTERNAL_ASSERT(
341 impl,
342 "GetTensorList only supports lists of valid tensors, but optional support could be added");
343 values.push_back(impl->tensor()->GetIrValue());
344 }
345
346 return torch::lazy::Value(torch::lazy::MakeTensorList(std::move(values)));
347 }
348
TryGetLtcTensor(const at::Tensor & tensor)349 LazyTensorPtr TryGetLtcTensor(const at::Tensor& tensor) {
350 auto* impl = dynamic_cast<LTCTensorImpl*>(
351 maybe_unwrap_functional(tensor).unsafeGetTensorImpl());
352 if (impl == nullptr) {
353 // return c10::make_intrusive<LazyTensor>();
354 return LazyTensorPtr();
355 }
356 return impl->tensor();
357 }
358
GetLtcTensor(const at::Tensor & tensor)359 LazyTensorPtr GetLtcTensor(const at::Tensor& tensor) {
360 auto lazy_tensor = TryGetLtcTensor(tensor);
361 TORCH_CHECK(
362 lazy_tensor, "Input tensor is not a lazy tensor: ", tensor.toString());
363 return lazy_tensor;
364 }
365
GetLtcTensors(c10::ArrayRef<at::Tensor> tensors)366 std::vector<LazyTensorPtr> GetLtcTensors(c10::ArrayRef<at::Tensor> tensors) {
367 std::vector<LazyTensorPtr> ltc_tensors;
368 ltc_tensors.reserve(tensors.size());
369 for (const auto& tensor : tensors) {
370 ltc_tensors.emplace_back(TryGetLtcTensor(tensor));
371 }
372 return ltc_tensors;
373 }
374
GetOrCreateLtcTensor(const std::optional<at::Tensor> & tensor,const BackendDevice & device)375 LazyTensorPtr GetOrCreateLtcTensor(
376 const std::optional<at::Tensor>& tensor,
377 const BackendDevice& device) {
378 return GetOrCreateLtcTensor(tensor.value_or(at::Tensor()), device);
379 }
380
GetLtcTensorOrCreateForWrappedNumber(const at::Tensor & tensor,const BackendDevice & device)381 LazyTensorPtr GetLtcTensorOrCreateForWrappedNumber(
382 const at::Tensor& tensor,
383 const BackendDevice& device) {
384 // TODO: There are places in core where a scalar is wrapped but not marked as
385 // wrapped.
386 return (tensor.unsafeGetTensorImpl()->is_wrapped_number() ||
387 (tensor.dim() == 0 && tensor.numel() == 1))
388 ? GetOrCreateLtcTensor(tensor, device)
389 : GetLtcTensor(tensor);
390 }
391
CreateAtenFromLtcTensor(const LazyTensorPtr & ltc_tensor)392 at::Tensor CreateAtenFromLtcTensor(const LazyTensorPtr& ltc_tensor) {
393 return ltc_tensor ? at::Tensor(c10::make_intrusive<LTCTensorImpl>(ltc_tensor))
394 : at::Tensor();
395 }
396
CreateAtenFromLtcTensor(LazyTensor && ltc_tensor)397 at::Tensor CreateAtenFromLtcTensor(LazyTensor&& ltc_tensor) {
398 return at::Tensor(c10::make_intrusive<LTCTensorImpl>(std::move(ltc_tensor)));
399 }
400
// Converts the non-lazy tensor `self` into a lazy tensor on `device`.
// `self` is first copied with `self.to(options, non_blocking, copy=true)`;
// the copy is then wrapped in (or resolved to) a LazyTensor for the backend
// device matching `device`, and exposed as an at::Tensor. When
// `functionalize_output` is set, the result is additionally wrapped by
// to_functional_tensor().
// Asserts that `self` is not on the lazy device and that `device` is.
at::Tensor to_lazy_tensor(
    const at::Tensor& self,
    const c10::TensorOptions& options,
    at::Device device,
    bool non_blocking,
    bool functionalize_output) {
  TORCH_INTERNAL_ASSERT(self.device().type() != c10::kLazy);
  TORCH_INTERNAL_ASSERT(device.type() == c10::kLazy);

  auto eager_copy =
      self.to(options, /*non_blocking=*/non_blocking, /*copy=*/true);
  const auto backend_device = torch::lazy::atenDeviceToBackendDevice(device);
  auto lazy_result =
      torch::lazy::GetOrCreateLtcTensor(eager_copy, backend_device);
  auto wrapped = torch::lazy::CreateAtenFromLtcTensor(lazy_result);
  if (!functionalize_output) {
    return wrapped;
  }
  // See Note [Lazy Tensor Functionalization]
  return at::functionalization::impl::to_functional_tensor(wrapped);
}
422
423 } // namespace lazy
424 } // namespace torch
425