// Source: /aosp_15_r20/external/pytorch/torch/csrc/lazy/ts_backend/ops/device_data.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <torch/csrc/lazy/backend/backend_data.h>
#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
#include <torch/csrc/lazy/ts_backend/ts_node.h>

#include <memory>
#include <string>
#include <utility>

7 namespace torch {
8 namespace lazy {
9 
10 class TORCH_API DeviceData : public TsNode {
11  public:
ClassOpKind()12   static OpKind ClassOpKind() {
13     return ltc_device_data;
14   }
15 
16   explicit DeviceData(std::shared_ptr<BackendData> data);
17 
18   // A DeviceData node can be reused if the shape matches,
19   // but we will substitute the actual data_ pointer under
20   // the hood.
CanBeReused(std::shared_ptr<BackendData> data)21   bool CanBeReused(std::shared_ptr<BackendData> data) const {
22     return data_->shape() == data->shape();
23   }
24 
25   std::string ToString() const override;
26 
data()27   const std::shared_ptr<BackendData>& data() const {
28     return data_;
29   }
30 
SetData(std::shared_ptr<BackendData> data)31   void SetData(std::shared_ptr<BackendData> data) {
32     data_ = data;
33   }
34 
35   static const DeviceData* Cast(const Node* node);
36 
37   // To reuse IR nodes, use this method to create DeviceData nodes
38   // instead of calling the constructor directly.
39   static NodePtr Create(std::shared_ptr<BackendData> data);
40 
41   TSOpVector Lower(
42       std::shared_ptr<torch::jit::GraphFunction> function,
43       TSLoweringContext* loctx) const override;
44 
45  private:
46   std::shared_ptr<BackendData> data_;
47 };
48 
49 } // namespace lazy
50 } // namespace torch
51