xref: /aosp_15_r20/external/pytorch/torch/csrc/lazy/ts_backend/ts_backend_impl.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/lazy/backend/backend_interface.h>
4 
5 namespace torch {
6 namespace lazy {
7 
8 class TORCH_API TSData : public torch::lazy::BackendData {
9  public:
TSData(const at::Scalar & scalar,const torch::lazy::BackendDevice & device)10   TSData(const at::Scalar& scalar, const torch::lazy::BackendDevice& device)
11       : torch::lazy::BackendData(device, torch::lazy::Shape(scalar.type(), {})),
12         scalar(scalar) {}
13 
TSData(const at::Tensor & data,const torch::lazy::Shape & shape,const torch::lazy::BackendDevice & device)14   TSData(
15       const at::Tensor& data,
16       const torch::lazy::Shape& shape,
17       const torch::lazy::BackendDevice& device)
18       : torch::lazy::BackendData(device, shape), data_(data) {}
19 
TSData(const torch::lazy::Shape & shape,const torch::lazy::BackendDevice & device)20   TSData(
21       const torch::lazy::Shape& shape,
22       const torch::lazy::BackendDevice& device)
23       : torch::lazy::BackendData(device, shape) {}
24 
GetHandle()25   Handle GetHandle() override {
26     return reinterpret_cast<int64_t>(this);
27   }
28 
Assign(const torch::lazy::BackendData & data)29   void Assign(const torch::lazy::BackendData& data) override {
30     data_ = static_cast<const TSData&>(data).data_;
31   }
32 
HasValue()33   bool HasValue() const override {
34     return data_.defined();
35   }
36 
data()37   at::Tensor data() {
38     return data_;
39   }
40 
41   std::optional<at::Scalar> scalar;
42 
43  private:
44   at::Tensor data_;
45 };
46 
47 TORCH_API torch::lazy::BackendImplInterface* GetTSBackendImpl();
48 
49 TORCH_API void InitTorchScriptBackend();
50 
51 } // namespace lazy
52 } // namespace torch
53