1 #pragma once 2 3 #include <torch/csrc/lazy/backend/backend_interface.h> 4 5 namespace torch { 6 namespace lazy { 7 8 class TORCH_API TSData : public torch::lazy::BackendData { 9 public: TSData(const at::Scalar & scalar,const torch::lazy::BackendDevice & device)10 TSData(const at::Scalar& scalar, const torch::lazy::BackendDevice& device) 11 : torch::lazy::BackendData(device, torch::lazy::Shape(scalar.type(), {})), 12 scalar(scalar) {} 13 TSData(const at::Tensor & data,const torch::lazy::Shape & shape,const torch::lazy::BackendDevice & device)14 TSData( 15 const at::Tensor& data, 16 const torch::lazy::Shape& shape, 17 const torch::lazy::BackendDevice& device) 18 : torch::lazy::BackendData(device, shape), data_(data) {} 19 TSData(const torch::lazy::Shape & shape,const torch::lazy::BackendDevice & device)20 TSData( 21 const torch::lazy::Shape& shape, 22 const torch::lazy::BackendDevice& device) 23 : torch::lazy::BackendData(device, shape) {} 24 GetHandle()25 Handle GetHandle() override { 26 return reinterpret_cast<int64_t>(this); 27 } 28 Assign(const torch::lazy::BackendData & data)29 void Assign(const torch::lazy::BackendData& data) override { 30 data_ = static_cast<const TSData&>(data).data_; 31 } 32 HasValue()33 bool HasValue() const override { 34 return data_.defined(); 35 } 36 data()37 at::Tensor data() { 38 return data_; 39 } 40 41 std::optional<at::Scalar> scalar; 42 43 private: 44 at::Tensor data_; 45 }; 46 47 TORCH_API torch::lazy::BackendImplInterface* GetTSBackendImpl(); 48 49 TORCH_API void InitTorchScriptBackend(); 50 51 } // namespace lazy 52 } // namespace torch 53