// xref: /aosp_15_r20/external/pytorch/torch/csrc/lazy/backend/backend_interface.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <ATen/Tensor.h>
#include <torch/csrc/lazy/backend/backend_data.h>
#include <torch/csrc/lazy/backend/backend_device.h>
#include <torch/csrc/lazy/backend/lowering_context.h>
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
#include <torch/csrc/lazy/core/shape.h>
#include <torch/csrc/lazy/core/tensor.h>

#include <atomic>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <vector>
11 
12 namespace torch {
13 namespace lazy {
14 
15 struct IrBuilder;
16 
17 /**
18  * Work in progress- don't treat this as a stable interface yet!
19  */
20 class TORCH_API BackendImplInterface {
21  public:
22   virtual ~BackendImplInterface() = default;
23 
24   /**
25    * Initialization/Teardown
26    * */
27   // No-op by default. Allows custom functionality to be exposed through
28   // extension bindings.
InitializeAtenBindings()29   virtual void InitializeAtenBindings() const {}
30 
31   virtual void PrepareToExit() const = 0;
32 
33   /**
34    * Configuration
35    * */
36 
37   virtual void SetRngSeed(size_t seed) const = 0;
38 
39   /**
40    * IR Tracing
41    * */
42 
43   virtual const IrBuilder* GetIrBuilder() const = 0;
44 
45   /**
46    * Data Transfer
47    * */
48 
49   virtual BackendDataPtr MakeComputationDataFromTensor(
50       const at::Tensor& tensor,
51       const Shape& shape,
52       const BackendDevice& device) const = 0;
53   virtual BackendDataPtr MakeComputationDataFromScalar(
54       const at::Scalar& scalar,
55       const torch::lazy::BackendDevice& device) const = 0;
56   virtual BackendDataPtr CreateDataPlaceholder(
57       const BackendDevice& device,
58       const Shape& shape) const = 0;
59 
60   // Gets backend data if the node is a device data node. Otherwise returns
61   // nullptr
62   virtual BackendDataPtr GetComputationDataFromNode(const Node*) const = 0;
63 
64   virtual at::Tensor MakeTensorFromComputationData(
65       const BackendDataPtr data,
66       std::optional<at::ScalarType> logical_scalar_type) const = 0;
67 
68   /**
69    * Lowering, Compilation, Execution
70    * */
71 
72   virtual std::unique_ptr<LoweringContext> CreateLoweringContext(
73       const std::string& name,
74       BackendDevice device,
75       c10::ArrayRef<const torch::lazy::Node*> post_order,
76       Util::EmissionMap emit_status) const = 0;
77 
78   virtual std::unique_ptr<LoweringContext> CreateLoweringContext(
79       const std::string& name,
80       BackendDevice device) const = 0;
81 
82   // TODO(whc) need to keep this?
83   virtual std::vector<std::string> GetCompilationDevices(
84       const std::string& device,
85       c10::ArrayRef<std::string> devices) const = 0;
86 
87   virtual std::vector<ComputationPtr> Compile(
88       std::vector<ComputationPtr> instances) const = 0;
89 
90   virtual std::vector<BackendDataPtr> ExecuteComputation(
91       torch::lazy::ComputationPtr computation,
92       c10::ArrayRef<BackendDataPtr> arguments,
93       const BackendDevice& device) const = 0;
94 
95   /**
96    * Device Configuration
97    * */
98 
99   // Set or get the default device type.
100   // For backends used with virtual c10::Devices, this configures what real
101   // device type the backend should use, and matters if the backend supports
102   // more than one type of real device.
103   virtual std::shared_ptr<BackendDeviceType> GetDefaultDeviceType() const = 0;
104   virtual void SetDefaultDeviceType(int8_t type) = 0;
105 
106   // Set or get the default device ordinal.
107   // For backends that supports multi-device, this configures what the
108   // default device the backend should use.
109   virtual int64_t GetDefaultDeviceOrdinal() const = 0;
110   virtual void SetDefaultDeviceOrdinal(int64_t) = 0;
111 
112   // Specify which aten device should be used for eager fallback
113   // may change depending on current 'Default' DeviceType
114   virtual at::DeviceType EagerFallbackDeviceType() const = 0;
115 
116   // Query all available backend devices
117   virtual std::vector<BackendDevice> GetBackendDevices() const = 0;
118 
CreateMetricReport()119   virtual std::string CreateMetricReport() const {
120     return "";
121   }
122 
123   // Map a particular c10:: device to a concrete backend device
124   // Note:: c10:: devices may be virtual or concrete.  xla:: and lazy:: are
125   // virtual devices, meaning they may map to a gpu, tpu, etc. behind the
126   // scenes. In the future, non-virtual c10:: devices may also use lazy tensors
127   // through a mode, in which case these APIs should still work, but should be
128   // identity mappings.
129   virtual BackendDevice GetBackendDevice(c10::Device device) const = 0;
130 
131   // TODO(whc)
132   // Additional APIs expected for supporting distributed training, to be
133   // designed
134 
135   /**
136    * Debug/Metrics
137    * */
138 
139   //   virtual std::map<std::string, Metric> GetMetrics() const = 0;
140 
141   //   virtual MemoryInfo GetMemoryInfo(const std::string& device) = 0;
142 
143   virtual std::string GetComputationBackendText(
144       const ComputationPtr computation) const = 0;
145 };
146 
147 class TORCH_API BackendRegistrar {
148  public:
149   BackendRegistrar(const BackendImplInterface* backend_impl_interface);
150 };
151 
152 TORCH_API bool hasBackend();
153 TORCH_API const BackendImplInterface* getBackend();
154 
155 TORCH_API const IrBuilder* getIrBuilder();
156 
157 } // namespace lazy
158 } // namespace torch
159