// xref: /aosp_15_r20/external/tensorflow/tensorflow/python/grappler/cluster_wrapper.cc
// (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <cfloat>
#include <cstdint>
#include <iterator>
#include <memory>
#include <set>
#include <stdexcept>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>

#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/clusters/single_machine.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/costs/cost_estimator.h"
#include "tensorflow/core/grappler/costs/graph_memory.h"
#include "tensorflow/core/grappler/costs/measuring_cost_estimator.h"
#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
#include "tensorflow/core/grappler/costs/utils.h"
#include "tensorflow/core/grappler/devices.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/device_properties.pb.h"
#include "tensorflow/python/lib/core/pybind11_status.h"

namespace py = pybind11;
_GetOpPerformanceDataAndRunTime(const tensorflow::grappler::GrapplerItem & item,tensorflow::grappler::CostEstimator * cost_measure,tensorflow::OpPerformanceList * op_performance_data,tensorflow::grappler::Costs * costs)52 tensorflow::Status _GetOpPerformanceDataAndRunTime(
53     const tensorflow::grappler::GrapplerItem& item,
54     tensorflow::grappler::CostEstimator* cost_measure,
55     tensorflow::OpPerformanceList* op_performance_data,
56     tensorflow::grappler::Costs* costs) {
57   tensorflow::Status status = cost_measure->Initialize(item);
58   if (!status.ok()) return status;
59 
60   tensorflow::RunMetadata run_metadata;
61   MaybeRaiseRegisteredFromStatus(
62       cost_measure->PredictCosts(item.graph, &run_metadata, costs));
63 
64   if (op_performance_data) {
65     *op_performance_data = tensorflow::grappler::CostGraphToOpPerformanceData(
66         run_metadata.cost_graph(), item.graph);
67   }
68   return ::tensorflow::OkStatus();
69 }
71 PYBIND11_MAKE_OPAQUE(tensorflow::grappler::Cluster);
72 
PYBIND11_MODULE(_pywrap_tf_cluster,m)73 PYBIND11_MODULE(_pywrap_tf_cluster, m) {
74   py::class_<tensorflow::grappler::Cluster> grappler_cluster(
75       m, "tensorflow::grappler::Cluster");
76 
77   m.def("TF_NewCluster",
78         [](bool allow_soft_placement,
79            bool disable_detailed_stats) -> tensorflow::grappler::Cluster* {
80           // TODO(petebu): Make these named arguments with default values
81           // instead.
82           int num_cpu_cores =
83               tensorflow::grappler::GetNumAvailableLogicalCPUCores();
84           int num_gpus = tensorflow::grappler::GetNumAvailableGPUs();
85           int timeout_s = 60 * 10;
86           std::unique_ptr<tensorflow::grappler::Cluster> cluster =
87               std::make_unique<tensorflow::grappler::SingleMachine>(
88                   timeout_s, num_cpu_cores, num_gpus);
89           cluster->DisableDetailedStats(disable_detailed_stats);
90           cluster->AllowSoftPlacement(allow_soft_placement);
91           cluster->SetNumWarmupSteps(10);
92           MaybeRaiseRegisteredFromStatus(cluster->Provision());
93           return cluster.release();
94         });
95 
96   m.def("TF_NewVirtualCluster",
97         [](const std::vector<py::bytes>& serialized_named_devices)
98             -> tensorflow::grappler::Cluster* {
99           std::vector<tensorflow::NamedDevice> named_devices;
100           for (const auto& s : serialized_named_devices) {
101             tensorflow::NamedDevice named_device;
102             if (!named_device.ParseFromString(std::string(s))) {
103               throw std::invalid_argument(
104                   "The NamedDevice could not be parsed as a valid protocol "
105                   "buffer");
106             }
107             named_devices.push_back(named_device);
108           }
109 
110           std::unordered_map<std::string, tensorflow::DeviceProperties> devices;
111           for (const auto& named_device : named_devices) {
112             devices[named_device.name()] = named_device.properties();
113           }
114           std::unique_ptr<tensorflow::grappler::Cluster> cluster =
115               std::make_unique<tensorflow::grappler::VirtualCluster>(devices);
116           {
117             // TODO(petebu): Do we need to hold the GIL here?
118             py::gil_scoped_acquire acquire;
119             MaybeRaiseRegisteredFromStatus(cluster->Provision());
120           }
121           return cluster.release();
122         });
123 
124   m.def("TF_ShutdownCluster", [](tensorflow::grappler::Cluster* cluster) {
125     // TODO(petebu): Do we need to hold the GIL here?
126     py::gil_scoped_acquire acquire;
127     cluster->Shutdown();
128   });
129 
130   m.def("TF_ListDevices",
131         [](tensorflow::grappler::Cluster* cluster) -> std::vector<py::bytes> {
132           const std::unordered_map<std::string, tensorflow::DeviceProperties>&
133               devices = cluster->GetDevices();
134           std::vector<py::bytes> named_devices;
135           for (auto& dev : devices) {
136             tensorflow::NamedDevice d;
137             d.set_name(dev.first);
138             *d.mutable_properties() = dev.second;
139             named_devices.push_back(d.SerializeAsString());
140           }
141           return named_devices;
142         });
143 
144   m.def("TF_ListAvailableOps", []() -> std::vector<std::string> {
145     tensorflow::OpRegistry* registry = tensorflow::OpRegistry::Global();
146     std::vector<tensorflow::OpDef> ops;
147     registry->GetRegisteredOps(&ops);
148     std::vector<std::string> op_names;
149     op_names.reserve(ops.size());
150     for (const tensorflow::OpDef& op : ops) {
151       op_names.push_back(op.name());
152     }
153     std::sort(op_names.begin(), op_names.end());
154     return op_names;
155   });
156 
157   m.def(
158       "TF_GetSupportedDevices",
159       [](tensorflow::grappler::Cluster* cluster,
160          tensorflow::grappler::GrapplerItem* item)
161           -> std::unordered_map<std::string, std::vector<std::string>> {
162         if (cluster == nullptr || item == nullptr) {
163           MaybeRaiseRegisteredFromStatus(tensorflow::Status(
164               tensorflow::errors::Internal("You need both a cluster and an "
165                                            "item to get supported devices.")));
166         }
167         const std::unordered_map<std::string, tensorflow::DeviceProperties>&
168             devices = cluster->GetDevices();
169         std::unordered_map<std::string, std::vector<std::string>> device_types;
170         for (const auto& dev : devices) {
171           device_types[dev.second.type()].push_back(dev.first);
172         }
173 
174         std::unordered_map<std::string, std::set<std::string>>
175             supported_device_types;
176         std::unordered_map<std::string, std::set<std::string>>
177             device_restrictions;
178 
179         for (const auto& node : item->graph.node()) {
180           for (const auto& dev : device_types) {
181             const std::string& type = dev.first;
182             if (cluster->type() != "single_machine") {
183               // The actual kernel may not be linked in this binary.
184               supported_device_types[node.name()].insert(type);
185             } else {
186               // Check the kernel capabilities
187               const tensorflow::DeviceType dev_type(type);
188               tensorflow::Status s =
189                   tensorflow::FindKernelDef(dev_type, node, nullptr, nullptr);
190               if (s.ok()) {
191                 supported_device_types[node.name()].insert(type);
192 
193                 // Check which inputs are restricted to reside on the host.
194                 // TODO: extends this to support outputs as well
195                 tensorflow::MemoryTypeVector inp_mtypes;
196                 tensorflow::MemoryTypeVector out_mtypes;
197                 tensorflow::Status s = tensorflow::MemoryTypesForNode(
198                     tensorflow::OpRegistry::Global(), dev_type, node,
199                     &inp_mtypes, &out_mtypes);
200                 if (s.ok()) {
201                   for (size_t i = 0; i < inp_mtypes.size(); ++i) {
202                     if (inp_mtypes[i] == tensorflow::HOST_MEMORY) {
203                       device_restrictions[tensorflow::grappler::NodeName(
204                                               node.input(i))]
205                           .insert("CPU");
206                       break;
207                     }
208                   }
209                 }
210               }
211             }
212           }
213         }
214 
215         std::unordered_map<std::string, std::vector<std::string>> result;
216         for (const auto& supported_dev : supported_device_types) {
217           const std::string& node = supported_dev.first;
218           std::set<std::string> feasible;
219           const auto it = device_restrictions.find(node);
220           if (it != device_restrictions.end()) {
221             const std::set<std::string>& candidates = supported_dev.second;
222             const std::set<std::string>& valid = it->second;
223             std::set_intersection(candidates.begin(), candidates.end(),
224                                   valid.begin(), valid.end(),
225                                   std::inserter(feasible, feasible.begin()));
226           } else {
227             feasible = supported_dev.second;
228           }
229 
230           std::vector<std::string> device_names;
231           for (const std::string& type : feasible) {
232             auto it = device_types.find(type);
233             DCHECK(it != device_types.end());
234             for (const std::string& name : it->second) {
235               device_names.push_back(name);
236             }
237           }
238           result[node] = device_names;
239         }
240         return result;
241       });
242 
243   m.def("TF_EstimatePerformance", [](const py::bytes& serialized_device) {
244     tensorflow::NamedDevice device;
245     if (!device.ParseFromString(std::string(serialized_device))) {
246       throw std::invalid_argument(
247           "The NamedDevice could not be parsed as a valid protocol buffer");
248     }
249     tensorflow::grappler::OpLevelCostEstimator estimator;
250     tensorflow::grappler::DeviceInfo info =
251         estimator.GetDeviceInfo(device.properties());
252     return info.gigaops;
253   });
254 
255   m.def("TF_MeasureCosts",
256         [](tensorflow::grappler::GrapplerItem* item,
257            tensorflow::grappler::Cluster* cluster, bool generate_timeline)
258             -> std::tuple<std::vector<py::bytes>, double, py::bytes> {
259           const int num_measurements = cluster->type() == "virtual" ? 1 : 10;
260           tensorflow::grappler::MeasuringCostEstimator cost_measure(
261               cluster, num_measurements, 0);
262 
263           tensorflow::OpPerformanceList op_performance_data;
264           tensorflow::grappler::Costs costs;
265           tensorflow::Status s = _GetOpPerformanceDataAndRunTime(
266               *item, &cost_measure, &op_performance_data, &costs);
267           double run_time = FLT_MAX;
268           if (s.ok()) {
269             run_time = static_cast<double>(costs.execution_time.count()) / 1e9;
270           }
271           tensorflow::StepStats step_stats;
272           if (generate_timeline) {
273             tensorflow::RunMetadata metadata;
274             MaybeRaiseRegisteredFromStatus(
275                 cluster->Run(item->graph, item->feed, item->fetch, &metadata));
276             step_stats = metadata.step_stats();
277           }
278 
279           std::vector<py::bytes> op_perf_objs;
280           op_perf_objs.resize(op_performance_data.op_performance_size());
281           for (int i = 0; i < op_performance_data.op_performance_size(); i++) {
282             op_perf_objs[i] =
283                 op_performance_data.op_performance(i).SerializeAsString();
284           }
285 
286           py::bytes step_stats_str = step_stats.SerializeAsString();
287           return std::make_tuple(op_perf_objs, run_time, step_stats_str);
288         });
289 
290   using DurationType = tensorflow::grappler::Costs::Duration::rep;
291   using MemoryUsage =
292       std::tuple<std::string, int, size_t, DurationType, DurationType>;
293 
294   m.def(
295       "TF_DeterminePeakMemoryUsage",
296       [](tensorflow::grappler::GrapplerItem* item,
297          tensorflow::grappler::Cluster* cluster)
298           -> std::unordered_map<std::string,
299                                 std::tuple<int64_t, std::vector<MemoryUsage>>> {
300         if (item == nullptr || cluster == nullptr) {
301           MaybeRaiseRegisteredFromStatus(
302               tensorflow::Status(tensorflow::errors::Internal(
303                   "You need both a cluster and an item to determine peak "
304                   "memory usage.")));
305         }
306         tensorflow::grappler::GraphMemory memory(*item);
307 
308         if (cluster->DetailedStatsEnabled()) {
309           MaybeRaiseRegisteredFromStatus(memory.InferDynamically(cluster));
310         } else {
311           MaybeRaiseRegisteredFromStatus(
312               memory.InferStatically(cluster->GetDevices()));
313         }
314 
315         std::unordered_map<std::string,
316                            std::tuple<int64_t, std::vector<MemoryUsage>>>
317             result;
318         for (const auto& device : cluster->GetDevices()) {
319           const tensorflow::grappler::GraphMemory::MemoryUsage& usage =
320               memory.GetPeakMemoryUsage(device.first);
321           std::vector<MemoryUsage> per_device;
322           for (size_t i = 0; i < usage.live_tensors.size(); ++i) {
323             const auto& live_tensor = usage.live_tensors[i];
324             per_device.push_back(std::make_tuple(
325                 live_tensor.node, live_tensor.output_id,
326                 live_tensor.memory_used, live_tensor.allocation_time.count(),
327                 live_tensor.deallocation_time.count()));
328           }
329           result[device.first] = std::make_tuple(usage.used_memory, per_device);
330         }
331         return result;
332       });
333 }
334