// torch/csrc/lazy/ts_backend/ts_backend_impl.cpp
#include <torch/csrc/lazy/ts_backend/ts_backend_impl.h>

#include <ATen/Functions.h>
#include <torch/csrc/lazy/backend/backend_device.h>
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
#include <torch/csrc/lazy/generated/LazyNativeFunctions.h>
#include <torch/csrc/lazy/ts_backend/config.h>
#include <torch/csrc/lazy/ts_backend/ir_builder.h>
#include <torch/csrc/lazy/ts_backend/ts_eager_fallback.h>
#include <torch/csrc/lazy/ts_backend/ts_lowering_context.h>
#include <memory>

namespace at {
// These functions are defined in the code-generated RegisterDispatchKey.cpp
// file. For the TorchScript backend, we have a special case where the
// registration does not happen immediately (at static initialization time), so
// that if an external backend is loaded, it has a chance to register itself,
// and TorchScript only registers itself if explicitly initialized.
extern TORCH_API void RegisterTorchScriptLazyNativeFunctions();
extern TORCH_API void RegisterTorchScriptAutogradLazyNativeFunctions();
} // namespace at
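// A minimal sketch (not part of this file) of how a host program opts in
// explicitly, once any external backend has had its chance to register first:
//
//   // somewhere in application startup code:
//   torch::lazy::InitTorchScriptBackend();  // defined at the bottom of this file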

namespace torch {
namespace lazy {

struct TSBackendDeviceType : public BackendDeviceType {
  TSBackendDeviceType() = delete;
  TSBackendDeviceType(c10::DeviceType deviceType)
      : BackendDeviceType((int8_t)deviceType) {
    TORCH_CHECK(deviceType == at::kCPU || deviceType == at::kCUDA);
  }

  std::string toString() const override {
    return c10::DeviceTypeName((c10::DeviceType)type);
  }

  c10::DeviceType c10Type() const {
    return (c10::DeviceType)type;
  }
};

class TSBackendImpl : public torch::lazy::BackendImplInterface {
 public:
  TSBackendImpl() {
    // TODO(whc) unify how all our flags are set and parsed as envs
    static bool env_use_cuda = std::getenv("LTC_TS_CUDA") != nullptr;
    auto type =
        (env_use_cuda || FLAGS_torch_lazy_ts_cuda) ? at::kCUDA : at::kCPU;
    default_device_type_ = std::make_shared<TSBackendDeviceType>(type);
  }
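  // Example (a sketch, assuming a CUDA-enabled build): selecting the CUDA
  // device type via the environment variable read above, e.g.
  //
  //   LTC_TS_CUDA=1 python my_lazy_script.py
  //
  // FLAGS_torch_lazy_ts_cuda provides the same switch programmatically.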

  const IrBuilder* GetIrBuilder() const override {
    static const IrBuilder* builder = new TorchScriptIrBuilder();
    return builder;
  }

  std::string CreateMetricReport() const override {
    return "TSBackendImpl: N/A";
  }

  std::unique_ptr<torch::lazy::LoweringContext> CreateLoweringContext(
      const std::string& name,
      torch::lazy::BackendDevice device,
      c10::ArrayRef<const torch::lazy::Node*> post_order,
      torch::lazy::Util::EmissionMap emit_status) const override {
    return std::make_unique<torch::lazy::TSLoweringContext>(
        name, device, post_order, emit_status);
  }

  std::unique_ptr<torch::lazy::LoweringContext> CreateLoweringContext(
      const std::string& name,
      torch::lazy::BackendDevice device) const override {
    return std::make_unique<torch::lazy::TSLoweringContext>(name, device);
  }

  std::vector<std::string> GetCompilationDevices(
      const std::string& device,
      c10::ArrayRef<std::string> devices) const override {
    return std::vector<std::string>(devices.begin(), devices.end());
  }

  at::Tensor MakeTensorFromComputationData(
      const torch::lazy::BackendDataPtr data,
      std::optional<at::ScalarType> logical_scalar_type) const override {
    const auto ts_data = std::static_pointer_cast<TSData>(data);
    return ts_data->data();
  }

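  // The three branches below (a hedged summary): (1) if the source tensor
  // already lives on the default device type and that type is CUDA, issue an
  // async device copy; (2) for a single-element CPU tensor, read the value out
  // with .item() and materialize it on-device with at::full, a safe async path
  // for one value; (3) otherwise, fall back to a blocking copy.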
  torch::lazy::BackendDataPtr MakeComputationDataFromTensor(
      const at::Tensor& tensor,
      const torch::lazy::Shape& shape,
      const torch::lazy::BackendDevice& device) const override {
    at::TensorOptions options = tensor.options().device(
        default_device_type_->c10Type(), device.ordinal());
    if (tensor.device().type() == default_device_type_->c10Type() &&
        default_device_type_->c10Type() == at::kCUDA) {
      return std::make_shared<TSData>(
          tensor.to(options, /*non_blocking=*/true), shape, device);
    } else if (tensor.device().type() == at::kCPU && tensor.numel() == 1) {
      // Calling .item() on a singleton CPU tensor is fast, and using fill is a
      // safe, async way to copy a single value from CPU to CUDA.
      auto device_tensor = at::full(tensor.sizes(), tensor.item(), options);
      return std::make_shared<TSData>(device_tensor, shape, device);
    } else {
      return std::make_shared<TSData>(
          tensor.to(options, /*non_blocking=*/false), shape, device);
    }
  }

  torch::lazy::BackendDataPtr MakeComputationDataFromScalar(
      const at::Scalar& scalar,
      const torch::lazy::BackendDevice& device) const override {
    return std::make_shared<TSData>(scalar, device);
  }

  torch::lazy::BackendDataPtr GetComputationDataFromNode(
      const Node* node) const override {
    auto* device_data_node = DeviceData::Cast(node);
    if (!device_data_node) {
      return nullptr;
    }
    return device_data_node->data();
  }

  std::string GetComputationBackendText(
      const torch::lazy::ComputationPtr computation) const override {
    auto ts_computation =
        static_cast<torch::lazy::TSComputation*>(computation.get());
    return ts_computation->graph()->toString();
  }

  //////////////computation client interfaces///////////////////////

 public:
  torch::lazy::BackendDataPtr CreateDataPlaceholder(
      const torch::lazy::BackendDevice& device,
      const torch::lazy::Shape& shape) const override;

  std::vector<torch::lazy::ComputationPtr> Compile(
      std::vector<torch::lazy::ComputationPtr> instances) const override;

  std::vector<torch::lazy::BackendDataPtr> ExecuteComputation(
      torch::lazy::ComputationPtr computation,
      c10::ArrayRef<torch::lazy::BackendDataPtr> arguments,
      const torch::lazy::BackendDevice& device) const override;

  std::shared_ptr<torch::lazy::BackendDeviceType> GetDefaultDeviceType()
      const override {
    return default_device_type_;
  }

  at::DeviceType EagerFallbackDeviceType() const override;

  void SetDefaultDeviceType(int8_t type) override {
    default_device_type_ = std::make_shared<TSBackendDeviceType>(
        static_cast<c10::DeviceType>(type));
  }

  int64_t GetDefaultDeviceOrdinal() const override {
    return default_device_ordinal_;
  }

  void SetDefaultDeviceOrdinal(int64_t ordinal) override {
    default_device_ordinal_ = ordinal;
  }

  std::vector<torch::lazy::BackendDevice> GetBackendDevices() const override;

  torch::lazy::BackendDevice GetBackendDevice(
      c10::Device device) const override;

  void SetRngSeed(size_t seed) const override {
    LOG(FATAL) << "Not implemented yet.";
  }

  // std::map<std::string, Metric> GetMetrics() const override { return {}; }

  // MemoryInfo GetMemoryInfo(const std::string& device) override {
  //   LOG(FATAL) << "Not implemented yet.";
  // }

  void PrepareToExit() const override;

 private:
  std::shared_ptr<TSBackendDeviceType> default_device_type_;
  int64_t default_device_ordinal_{0};
};

torch::lazy::BackendDataPtr TSBackendImpl::CreateDataPlaceholder(
    const torch::lazy::BackendDevice& device,
    const torch::lazy::Shape& shape) const {
  return std::make_shared<TSData>(shape, device);
}

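// Compilation is effectively a pass-through for the TS backend: the
// TorchScript graph is compiled lazily by the GraphExecutor on first run (see
// ExecuteComputation below), so Compile only sanity-checks its inputs.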
std::vector<torch::lazy::ComputationPtr> TSBackendImpl::Compile(
    std::vector<torch::lazy::ComputationPtr> instances) const {
  for (const auto& instance : instances) {
    auto ts_computation =
        static_cast<torch::lazy::TSComputation*>(instance.get());
    if (!ts_computation->in_mark_step) {
      LOG(WARNING) << "Compile outside of mark step";
    }
  }
  return instances;
}

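// Runs a computation by marshalling the backend data onto a torch::jit stack:
// scalar arguments are pushed as IValues, tensor arguments are pushed
// directly, and GraphExecutor::run consumes the inputs and leaves the outputs
// on the stack.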
std::vector<torch::lazy::BackendDataPtr> TSBackendImpl::ExecuteComputation(
    torch::lazy::ComputationPtr computation,
    c10::ArrayRef<torch::lazy::BackendDataPtr> arguments,
    const torch::lazy::BackendDevice& device) const {
  auto ts_computation =
      std::dynamic_pointer_cast<torch::lazy::TSComputation>(computation);
  TORCH_CHECK(ts_computation, "Computation isn't TSComputation");
  torch::jit::GraphExecutor& graph_executor = ts_computation->graph_executor();
  std::vector<torch::jit::IValue> stack;
  for (const auto& argument : arguments) {
    const auto ts_data = std::static_pointer_cast<TSData>(argument);
    if (ts_data->scalar.has_value()) {
      stack.emplace_back(ts_data->scalar.value());
    } else {
      // TODO(whc) should this check be made more general? it's written
      // somewhat oddly
      TORCH_CHECK(
          static_cast<c10::DeviceType>(default_device_type_->type) !=
              at::kCUDA ||
          ts_data->data().device().type() == at::kCUDA);
      stack.emplace_back(ts_data->data());
    }
  }
  // After run() returns, the stack holds only the computation's outputs.
  graph_executor.run(stack);
  std::vector<torch::lazy::BackendDataPtr> results;
  for (torch::jit::IValue component : stack) {
    at::Tensor result = component.toTensor();
    at::IntArrayRef result_sizes = result.sizes();
    torch::lazy::Shape shape(
        result.scalar_type(),
        std::vector<int64_t>(result_sizes.begin(), result_sizes.end()));
    results.push_back(std::make_shared<TSData>(result, shape, device));
  }
  return results;
}

std::vector<torch::lazy::BackendDevice> TSBackendImpl::GetBackendDevices()
    const {
  std::vector<torch::lazy::BackendDevice> devices;
  // TODO(whc) figure out how to query available devices from pytorch
  devices.emplace_back(GetBackendDevice(c10::Device(c10::kCPU, 0)));
  devices.emplace_back(GetBackendDevice(c10::Device(c10::kCUDA, 0)));
  return devices;
}

torch::lazy::BackendDevice TSBackendImpl::GetBackendDevice(
    c10::Device device) const {
  // Note: we ignore the device type specified by the c10::Device, since it is
  // expected to be a virtual device (lazy::). This will need to change once
  // lazy is supported as a mode.
  return torch::lazy::BackendDevice(GetDefaultDeviceType(), device.index());
}

void TSBackendImpl::PrepareToExit() const {}

c10::DeviceType TSBackendImpl::EagerFallbackDeviceType() const {
  // For the TS backend, the hardware device _is_ the eager device.
  return (c10::DeviceType)GetDefaultDeviceType()->type;
}

torch::lazy::BackendImplInterface* GetTSBackendImpl() {
  static TSBackendImpl* ts_backend_impl = new TSBackendImpl();
  return ts_backend_impl;
}

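// A hedged usage sketch: in recent PyTorch builds this initializer is reached
// from Python via the lazy TS backend module, approximately:
//
//   import torch._lazy.ts_backend
//   torch._lazy.ts_backend.init()  # ends up calling InitTorchScriptBackend()
//
// The exact binding path may differ between releases.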
void InitTorchScriptBackend() {
  at::RegisterTorchScriptLazyNativeFunctions();
  at::RegisterTorchScriptAutogradLazyNativeFunctions();
  register_ts_ltc_eager_fallback();
  static std::unique_ptr<BackendRegistrar> s_registrar;
  s_registrar = std::make_unique<BackendRegistrar>(GetTSBackendImpl());

  static LazyGraphExecutor* executor = new LazyGraphExecutor();
  LazyGraphExecutor::Register(executor);
}

} // namespace lazy
} // namespace torch