1 #pragma once 2 3 #include <torch/csrc/lazy/backend/backend_data.h> 4 #include <torch/csrc/lazy/core/internal_ops/ltc_ops.h> 5 #include <torch/csrc/lazy/ts_backend/ts_node.h> 6 7 namespace torch { 8 namespace lazy { 9 10 class TORCH_API DeviceData : public TsNode { 11 public: ClassOpKind()12 static OpKind ClassOpKind() { 13 return ltc_device_data; 14 } 15 16 explicit DeviceData(std::shared_ptr<BackendData> data); 17 18 // A DeviceData node can be reused if the shape matches, 19 // but we will substitute the actual data_ pointer under 20 // the hood. CanBeReused(std::shared_ptr<BackendData> data)21 bool CanBeReused(std::shared_ptr<BackendData> data) const { 22 return data_->shape() == data->shape(); 23 } 24 25 std::string ToString() const override; 26 data()27 const std::shared_ptr<BackendData>& data() const { 28 return data_; 29 } 30 SetData(std::shared_ptr<BackendData> data)31 void SetData(std::shared_ptr<BackendData> data) { 32 data_ = data; 33 } 34 35 static const DeviceData* Cast(const Node* node); 36 37 // To reuse IR nodes, use this method to create DeviceData nodes 38 // instead of calling the constructor directly. 39 static NodePtr Create(std::shared_ptr<BackendData> data); 40 41 TSOpVector Lower( 42 std::shared_ptr<torch::jit::GraphFunction> function, 43 TSLoweringContext* loctx) const override; 44 45 private: 46 std::shared_ptr<BackendData> data_; 47 }; 48 49 } // namespace lazy 50 } // namespace torch 51