1 #include <torch/csrc/lazy/ts_backend/ops/device_data.h> 2 3 #include <torch/csrc/lazy/core/internal_ops/ltc_ops.h> 4 #include <torch/csrc/lazy/core/ir_builder.h> 5 6 #include <sstream> 7 8 namespace torch { 9 namespace lazy { 10 DeviceData(std::shared_ptr<BackendData> data)11DeviceData::DeviceData(std::shared_ptr<BackendData> data) 12 : TsNode( 13 ClassOpKind(), 14 data->shape(), 15 /*num_outputs=*/1, 16 /*hash_seed=*/static_cast<uint32_t>(101)), 17 data_(std::move(data)) {} 18 ToString() const19std::string DeviceData::ToString() const { 20 std::stringstream ss; 21 ss << TsNode::ToString() << ", device=" << data_->device(); 22 return ss.str(); 23 } 24 Cast(const Node * node)25const DeviceData* DeviceData::Cast(const Node* node) { 26 return NodeCast<DeviceData>(node); 27 } 28 Create(std::shared_ptr<BackendData> data)29NodePtr DeviceData::Create(std::shared_ptr<BackendData> data) { 30 NodePtr node = ReuseOrMakeNode<DeviceData>(data); 31 // ReuseOrMakeNode may return a reused node which has the same shape, 32 // however, we need to replace the old data_ with the new one. 33 // Ditching the old data_ is safe because tracing is done iteration 34 // by iteration, and after we lauch the async device execution for the 35 // previous iteration, data_ in DeviceData nodes are not needed anymore. 36 DeviceData* device_data = static_cast<DeviceData*>(node.get()); 37 device_data->SetData(data); 38 return node; 39 } 40 41 } // namespace lazy 42 } // namespace torch 43