xref: /aosp_15_r20/external/pytorch/torch/csrc/lazy/ts_backend/ops/device_data.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/csrc/lazy/ts_backend/ops/device_data.h>

#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
#include <torch/csrc/lazy/core/ir_builder.h>

#include <memory>
#include <sstream>
#include <utility>
7 
8 namespace torch {
9 namespace lazy {
10 
DeviceData(std::shared_ptr<BackendData> data)11 DeviceData::DeviceData(std::shared_ptr<BackendData> data)
12     : TsNode(
13           ClassOpKind(),
14           data->shape(),
15           /*num_outputs=*/1,
16           /*hash_seed=*/static_cast<uint32_t>(101)),
17       data_(std::move(data)) {}
18 
ToString() const19 std::string DeviceData::ToString() const {
20   std::stringstream ss;
21   ss << TsNode::ToString() << ", device=" << data_->device();
22   return ss.str();
23 }
24 
Cast(const Node * node)25 const DeviceData* DeviceData::Cast(const Node* node) {
26   return NodeCast<DeviceData>(node);
27 }
28 
Create(std::shared_ptr<BackendData> data)29 NodePtr DeviceData::Create(std::shared_ptr<BackendData> data) {
30   NodePtr node = ReuseOrMakeNode<DeviceData>(data);
31   // ReuseOrMakeNode may return a reused node which has the same shape,
32   // however, we need to replace the old data_ with the new one.
33   // Ditching the old data_ is safe because tracing is done iteration
34   // by iteration, and after we lauch the async device execution for the
35   // previous iteration, data_ in DeviceData nodes are not needed anymore.
36   DeviceData* device_data = static_cast<DeviceData*>(node.get());
37   device_data->SetData(data);
38   return node;
39 }
40 
41 } // namespace lazy
42 } // namespace torch
43