xref: /aosp_15_r20/external/pytorch/torch/csrc/lazy/backend/backend_data.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/lazy/backend/backend_device.h>
4 #include <torch/csrc/lazy/core/shape.h>
5 #include <cstring>
6 
7 namespace torch {
8 namespace lazy {
9 
10 class TORCH_API BackendData {
11  public:
12   struct Info {
13     /**
14      * Used by Lazy Graph Executor to tag info on BackendData objs
15      * */
16     virtual ~Info() = default;
17   };
18   /**
19    * Represents (Tensor) data stored on a backend device
20    * in its native format.
21    * */
22   using Handle = int64_t;
23 
BackendData(BackendDevice device,Shape shape)24   BackendData(BackendDevice device, Shape shape)
25       : device_(std::move(device)), shape_(std::move(shape)) {}
26 
27   virtual ~BackendData() = default;
28 
device()29   const BackendDevice& device() const {
30     return device_;
31   }
32 
shape()33   const Shape& shape() const {
34     return shape_;
35   }
36 
info()37   Info* info() const {
38     return info_.get();
39   }
40 
SetInfo(std::shared_ptr<Info> info)41   std::shared_ptr<Info> SetInfo(std::shared_ptr<Info> info) {
42     std::swap(info, info_);
43     return info;
44   }
45 
46   virtual Handle GetHandle() = 0;
47 
48   virtual void Assign(const BackendData& data) = 0;
49 
50   virtual bool HasValue() const = 0;
51 
52  private:
53   BackendDevice device_;
54   Shape shape_;
55   std::shared_ptr<Info> info_;
56 };
57 
// Alias for a std::shared_ptr that owns a BackendData (or a subclass of it).
58 using BackendDataPtr = std::shared_ptr<BackendData>;
59 
60 } // namespace lazy
61 } // namespace torch
62