#include <torch/csrc/lazy/python/init.h>

#include <ATen/FunctionalTensorWrapper.h>
#include <c10/core/Device.h>
#include <torch/csrc/jit/python/pybind.h>
#include <torch/csrc/lazy/backend/backend_device.h>
#include <torch/csrc/lazy/backend/backend_interface.h>
#include <torch/csrc/lazy/core/config.h>
#include <torch/csrc/lazy/core/debug_util.h>
#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
#include <torch/csrc/lazy/core/ir_dump_util.h>
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
#include <torch/csrc/lazy/core/metrics.h>
#include <torch/csrc/lazy/core/trie.h>
#include <torch/csrc/lazy/python/python_util.h>
#if !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE))
#include <torch/csrc/lazy/ts_backend/ts_backend_impl.h>
#include <torch/csrc/lazy/ts_backend/ts_lowering_context.h>
#endif // FBCODE_CAFFE2 || OVRSOURCE
#include <functional>
#include <string>
#include <unordered_set>
#include <vector>

namespace torch {
namespace lazy {

// TODO(whc) backend 'device' related APIs are not very clear; this code could
// be simplified, but that should probably be done together with
// designing/refactoring the overall approach to getting/setting default
// eager/lazy device types.
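// Resolves a device string to a BackendDevice. An empty string yields a
// default-constructed BackendDevice, i.e. the backend's default device type.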
torch::lazy::BackendDevice GetDeviceOrCurrent(const std::string& device_str) {
  if (device_str.empty()) {
    getBackend()->GetDefaultDeviceType();
    return torch::lazy::BackendDevice();
  }
  return torch::lazy::atenDeviceToBackendDevice(c10::Device(device_str));
}

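// Returns the unique id of the lazy tensor backing `tensor`; fails if the
// tensor is not a lazy tensor.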
std::ptrdiff_t GetTensorId(const at::Tensor& tensor) {
  torch::lazy::LazyTensorPtr lazy_tensor = torch::lazy::TryGetLtcTensor(tensor);
  TORCH_CHECK(lazy_tensor, "Expected a lazy tensor");
  return lazy_tensor->GetUniqueId();
}

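// Renders the IR graph rooted at the given tensors, using `converter` to
// format the collected root nodes (e.g. as text or dot output).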
std::string GetTensorsDump(
    const std::vector<at::Tensor>& tensors,
    const std::function<std::string(c10::ArrayRef<const torch::lazy::Node*>)>&
        converter) {
  std::vector<const torch::lazy::Node*> nodes;
  std::vector<torch::lazy::Value> values;
  for (auto& tensor : tensors) {
    auto inner = at::functionalization::impl::from_functional_tensor(tensor);
    torch::lazy::LazyTensorPtr lazy_tensor =
        torch::lazy::TryGetLtcTensor(inner);
    values.push_back(lazy_tensor->GetIrValue());
    nodes.push_back(values.back().node.get());
  }
  return converter(nodes);
}

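// Collects the lazy tensors backing `tensors`. With `want_all`, every entry
// is converted (non-lazy tensors yield null pointers); otherwise, non-lazy
// tensors are silently skipped.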
std::vector<torch::lazy::LazyTensorPtr> GetLtcTensors(
    const std::vector<at::Tensor>& tensors,
    bool want_all) {
  std::vector<torch::lazy::LazyTensorPtr> lazy_tensors;
  lazy_tensors.reserve(tensors.size());
  if (want_all) {
    for (auto& tensor : tensors) {
      lazy_tensors.push_back(torch::lazy::TryGetLtcTensor(tensor));
    }
  } else {
    for (auto& tensor : tensors) {
      auto lazy_tensor = torch::lazy::TryGetLtcTensor(tensor);
      if (lazy_tensor) {
        lazy_tensors.push_back(lazy_tensor);
      }
    }
  }
  return lazy_tensors;
}

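// Dumps the lowered backend computation for the graph rooted at the given
// tensors.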
std::string GetTensorsBackendGraph(const std::vector<at::Tensor>& tensors) {
  std::vector<torch::lazy::LazyTensorPtr> lazy_tensors =
      GetLtcTensors(tensors, /*want_all=*/false);
  return torch::lazy::LazyGraphExecutor::Get()->DumpBackendComputation(
      lazy_tensors);
}

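// Flushes the pending IR of the given lazy tensors to the backend. `wait`
// blocks until the async computation finishes; `sync_ltc_data` additionally
// marks the tensors' data with the computation result.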
void SyncTensors(
    const std::vector<at::Tensor>& tensors,
    const std::vector<std::string>& devices,
    bool wait,
    bool sync_ltc_data) {
  std::vector<torch::lazy::LazyTensorPtr> lazy_tensors =
      GetLtcTensors(tensors, /*want_all=*/false);
  torch::lazy::LazyGraphExecutor::Get()->SyncTensorsGraph(
      &lazy_tensors, devices, wait, sync_ltc_data);
}

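// A minimal usage sketch from Python, assuming this initializer is run
// against the torch._C module as in a standard PyTorch build (the calls
// below are illustrative, not a stable API):
//
//   import torch._C._lazy as _lazy
//   _lazy._mark_step("", [], wait=True)  # flush live tensors, begin new step
//   print(_lazy._counter_names())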
void initLazyBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();
  auto lazy = m.def_submodule("_lazy");
  auto lazy_ts_backend = m.def_submodule("_lazy_ts_backend");

  lazy.def(
      "_mark_step",
      // TODO(whc) this API should probably change from vector<string> to
      // vector<c10::device> but in a separate PR
      [](const std::string& device_str,
         const std::vector<std::string>& devices,
         bool wait) {
        pybind11::gil_scoped_release no_gil;
        auto backend_device = GetDeviceOrCurrent(device_str);
        torch::lazy::LazyGraphExecutor::Get()->SyncLiveTensorsGraph(
            &backend_device, devices, wait);
        torch::lazy::LazyGraphExecutor::Get()->MarkStep(backend_device);
      },
      py::arg("device") = "",
      py::arg("devices"),
      py::arg("wait") = true);
  lazy.def(
      "_wait_device_ops",
      [](const std::vector<std::string>& devices) {
        pybind11::gil_scoped_release no_gil;
        // TODO: Add support for non-empty devices.
        if (!devices.empty()) {
          LOG(ERROR) << "Non-empty devices are not supported.";
        }
        torch::lazy::LazyGraphExecutor::Get()->WaitDeviceOps({});
      },
      py::arg("devices"));
  lazy.def("_reset_metrics", []() {
    torch::lazy::MetricsArena::Get()->ResetCounters();
    torch::lazy::MetricsArena::Get()->ResetMetrics();
  });
  lazy.def("_counter_names", []() { return torch::lazy::GetCounterNames(); });
  lazy.def(
      "_metrics_report", []() { return torch::lazy::CreateMetricReport(); });
  lazy.def("_counter_value", [](const std::string& name) -> py::object {
    torch::lazy::CounterData* data = torch::lazy::GetCounter(name);
    return data != nullptr ? py::cast<int64_t>(data->Value()) : py::none();
  });
  lazy.def("_get_tensor_id", [](const at::Tensor& tensor) {
    return GetTensorId(tensor);
  });

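  // Debug dumps of the traced graph: text IR, dot (graphviz) IR, and the
  // lowered backend computation.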
  lazy.def(
      "_get_tensors_text",
      [](const std::vector<at::Tensor>& tensors) -> std::string {
        auto converter = [](c10::ArrayRef<const torch::lazy::Node*> nodes) {
          return torch::lazy::DumpUtil::ToText(nodes);
        };
        return GetTensorsDump(tensors, converter);
      });
  lazy.def(
      "_get_tensors_dot",
      [](const std::vector<at::Tensor>& tensors) -> std::string {
        auto converter = [](c10::ArrayRef<const torch::lazy::Node*> nodes) {
          return torch::lazy::DumpUtil::ToDot(nodes);
        };
        return GetTensorsDump(tensors, converter);
      });
  lazy.def(
      "_get_tensors_backend",
      [](const std::vector<at::Tensor>& tensors) -> std::string {
        return GetTensorsBackendGraph(tensors);
      });
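  // Hashes the graph rooted at the given tensors. The hash is returned as raw
  // bytes so it can round-trip to _run_cached_graph below.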
  lazy.def("_get_graph_hash", [](const std::vector<at::Tensor>& tensors) {
    std::vector<LazyTensorPtr> xtensors;
    xtensors.reserve(tensors.size());
    for (auto& tensor : tensors) {
      xtensors.emplace_back(TryGetLtcTensor(tensor));
    }
    auto hash = LazyGraphExecutor::Get()->GetGraphHash(xtensors);
    std::string bin((const char*)&hash, sizeof(hash));
    return py::bytes(bin);
  });
  lazy.def(
      "_sync_multi",
      [](const std::vector<at::Tensor>& tensors,
         const std::vector<std::string>& devices,
         bool wait,
         bool sync_ltc_data) {
        pybind11::gil_scoped_release no_gil;
        SyncTensors(tensors, devices, wait, sync_ltc_data);
      },
      py::arg("tensors"),
      py::arg("devices"),
      py::arg("wait") = true,
      py::arg("sync_ltc_data") = true);

  lazy.def("_get_force_fallback", []() {
    return torch::lazy::getLTCForceFallback();
  });
  lazy.def("_set_force_fallback", [](const std::string& newval) {
    torch::lazy::getLTCForceFallback() = newval;
  });
  lazy.def("_clear_ir_cache", []() { TrieCache::Get()->Clear(); });
  lazy.def("_dump_ir_cache", [](const std::string& filename) {
    TrieCache::Get()->DumpToDotFile(filename);
  });
  lazy.def("_set_reuse_ir", [](bool val) { FLAGS_torch_lazy_reuse_ir = val; });
  lazy.def("_set_symbolic_shape_mode", [](bool val) {
    FLAGS_ltc_enable_symbolic_shapes = val;
  });
  lazy.def("_get_symbolic_shape_mode", []() {
    return FLAGS_ltc_enable_symbolic_shapes;
  });
  lazy.def("_get_default_device_type", []() {
    return getBackend()->GetDefaultDeviceType()->toString();
  });

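  // The bindings below target the TorchScript (TS) lazy backend; they are
  // stubbed out with TORCH_CHECK failures in builds that do not ship it.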
  lazy_ts_backend.def("_init", []() {
#if !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE))
    torch::lazy::InitTorchScriptBackend();
#else
    TORCH_CHECK(
        false,
        "TorchScript backend not yet supported in FBCODE/OVRSOURCE builds");
#endif // !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE))
  });

  /*
   * Returns the tensor ids and values (tensors or scalars) of the DeviceData
   * nodes reachable from the given tensors, deduplicated by data handle.
   * TODO(shunting) revisit this API for XLA
   */
  lazy_ts_backend.def(
      "_get_tensors_ts_device_data_node",
      [](const std::vector<at::Tensor>& tensors)
          -> std::pair<std::vector<int64_t>, std::vector<at::IValue>> {
#if !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE))
        std::vector<const Node*> roots;
        for (auto& tensor : tensors) {
          auto xtensor = TryGetLtcTensor(tensor);
          roots.push_back(xtensor->GetIrValue().node.get());
        }
        auto post_order = Util::ComputePostOrder(roots);
        std::vector<int64_t> tensor_ids;
        std::vector<at::IValue> ivalues;

        std::unordered_set<BackendData::Handle> data_handles;
        for (auto nodeptr : post_order) {
          if (nodeptr->op() == *torch::lazy::ltc_device_data) {
            const auto backend_data =
                getBackend()->GetComputationDataFromNode(nodeptr);

            auto infoptr = backend_data->info();
            auto deviceDataInfoPtr =
                (torch::lazy::LazyGraphExecutor::DeviceDataInfo*)infoptr;
            auto* tsDataPtr = (torch::lazy::TSData*)backend_data.get();

            // dedup DeviceData nodes by handle
            auto handle = tsDataPtr->GetHandle();
            if (!data_handles.insert(handle).second) {
              continue;
            }
            tensor_ids.push_back(deviceDataInfoPtr->tensor_id);
            /*
             * If the TSData contains a tensor, then the tensor id uniquely
             * identifies the tensor. We use that tensor id to find the tensor
             * in other places, e.g. in the Python forward method parameters.
             *
             * If the TSData contains a scalar, the tensor id itself is not
             * important. We reuse the scalar value in future calls.
             */
            if (tsDataPtr->HasValue()) {
              ivalues.emplace_back(tsDataPtr->data());
            } else {
              TORCH_CHECK(tsDataPtr->scalar.has_value());
              ivalues.emplace_back(tsDataPtr->scalar.value());
            }
          }
        }
        return std::make_pair(tensor_ids, ivalues);
#else
        TORCH_CHECK(
            false, "TorchScript backend not yet supported in FBCODE builds");
        return std::make_pair(
            std::vector<int64_t>(), std::vector<at::IValue>());
#endif // !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE))
      });
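  // A hedged sketch of the cached-graph round trip from Python (illustrative
  // only; the real driver lives on the Python side, e.g. torch._lazy):
  //
  //   hash = torch._C._lazy._get_graph_hash(outputs)
  //   ids, ivalues = \
  //       torch._C._lazy_ts_backend._get_tensors_ts_device_data_node(outputs)
  //   results = torch._C._lazy_ts_backend._run_cached_graph(hash, inputs)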
  // TODO(shunting) revisit this part for XLA
  lazy_ts_backend.def(
      "_run_cached_graph",
      [](const std::string& hash_str,
         const std::vector<at::IValue>& graph_inputs) {
        std::vector<at::Tensor> result;
#if !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE))
        TORCH_CHECK(hash_str.size() == sizeof(hash_t));
        hash_t hash = *(hash_t*)(hash_str.c_str());
        auto cachedComputation =
            LazyGraphExecutor::Get()->GetComputationCache()->Get(hash);
        // TODO: implement a fallback mechanism, or make sure those entries
        // never get kicked out.
        TORCH_CHECK(
            cachedComputation,
            "Failed to get computation by hash. Maybe the entry got kicked "
            "out of the LRU cache");
        auto computationPtr =
            (torch::lazy::TSComputation*)cachedComputation->computation.get();

        std::vector<torch::jit::IValue> stack;
        stack.reserve(graph_inputs.size());
        for (const auto& arg : graph_inputs) {
          stack.emplace_back(arg);
        }
        computationPtr->graph_executor().run(stack);
        result.reserve(stack.size());
        for (const torch::jit::IValue& elem : stack) {
          result.push_back(elem.toTensor());
        }
#else
        TORCH_CHECK(
            false, "TorchScript backend not yet supported in FBCODE builds");
#endif // !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE))
        return result;
      });
  lazy_ts_backend.def("_get_latest_computation_graph", []() {
#if !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE))
    auto computation = LazyGraphExecutor::Get()
                           ->GetComputationCache()
                           ->GetLatest()
                           ->computation;
    auto ts_computation = dynamic_cast<TSComputation*>(computation.get());
    TORCH_CHECK(ts_computation, "Found non-TSComputation in cache");
    return ts_computation->graph()->toString();
#else
    TORCH_CHECK(
        false, "TorchScript backend not yet supported in FBCODE builds");
    return "";
#endif // !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE))
  });

  // GetPythonFramesFunction() has never worked with torchdeploy/multipy,
  // possibly because GetPythonFrames resolves to the external CPython rather
  // than the embedded CPython. So far this problem has only been observed
  // internally, so we will just block it off there.

#if !(defined(USE_DEPLOY))

  // When libtorch_python is loaded, we register the Python frame getter;
  // otherwise, debug util simply omits Python frames.
  GetPythonFramesFunction() = GetPythonFrames;

#endif // USE_DEPLOY
}

} // namespace lazy
} // namespace torch