xref: /aosp_15_r20/external/tensorflow/tensorflow/core/common_runtime/partitioning_utils.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/common_runtime/partitioning_utils.h"
16 
17 #include <algorithm>
18 #include <functional>
19 #include <memory>
20 #include <optional>
21 #include <string>
22 #include <unordered_map>
23 #include <utility>
24 
25 #include "tensorflow/core/common_runtime/graph_constructor.h"
26 #include "tensorflow/core/framework/function.h"
27 #include "tensorflow/core/framework/types.h"
28 #include "tensorflow/core/graph/graph.h"
29 #include "tensorflow/core/graph/graph_partition.h"
30 
31 namespace tensorflow {
32 
33 namespace {
34 
// A helper to partition a `graph` across the devices in `device_set`.
// `partitions` maps device names to the GraphDef assigned to that device.
PartitionFunctionGraph(const DeviceSet & device_set,Graph * graph,std::unordered_map<string,GraphDef> * partitions,std::function<string (const Node *)> node_to_loc,std::function<string (const Edge *)> get_tensor_name_attr)37 Status PartitionFunctionGraph(
38     const DeviceSet& device_set, Graph* graph,
39     std::unordered_map<string, GraphDef>* partitions,
40     std::function<string(const Node*)> node_to_loc,
41     std::function<string(const Edge*)> get_tensor_name_attr) {
42   PartitionOptions partition_options;
43   if (node_to_loc != nullptr) {
44     partition_options.node_to_loc = node_to_loc;
45   } else {
46     partition_options.node_to_loc = [](const Node* node) {
47       // TODO(iga): To support the distributed case, first split the graph by
48       // worker (e.g,. using the master session's `SplitByWorker` policy), and
49       // then recursively partition the per-worker shards at the remote
50       // worker(s). Currently, we simply split the graph at device boundaries.
51       return node->assigned_device_name();
52     };
53   }
54   int64_t edge_name_counter = 0;
55   partition_options.new_name = [&edge_name_counter](const string& prefix) {
56     return strings::StrCat(prefix, "/_", ++edge_name_counter);
57   };
58   partition_options.get_incarnation =
59       [&device_set](const string& name) -> int64 {
60     const Device* d = device_set.FindDeviceByName(name);
61     if (d == nullptr) {
62       return PartitionOptions::kIllegalIncarnation;
63     } else {
64       return d->attributes().incarnation();
65     }
66   };
67   partition_options.control_flow_added = false;
68   partition_options.get_tensor_name_attr = get_tensor_name_attr;
69 
70   return Partition(partition_options, graph, partitions);
71 }
72 
73 }  // namespace
74 
PartitionFunctionGraph(const DeviceSet & device_set,std::unique_ptr<Graph> graph,std::unordered_map<string,std::unique_ptr<Graph>> * subgraphs,std::function<string (const Edge *)> get_tensor_name_attr)75 Status PartitionFunctionGraph(
76     const DeviceSet& device_set, std::unique_ptr<Graph> graph,
77     std::unordered_map<string, std::unique_ptr<Graph>>* subgraphs,
78     std::function<string(const Edge*)> get_tensor_name_attr) {
79   std::unordered_map<string, GraphDef> partitions;
80   TF_RETURN_IF_ERROR(
81       PartitionFunctionGraph(device_set, graph.get(), &partitions,
82                              /*node_to_loc=*/nullptr, get_tensor_name_attr));
83 
84   for (auto& partition : partitions) {
85     const string& device = partition.first;
86     GraphDef& graph_def = partition.second;
87     // Each partition gets a new graph.
88     std::unique_ptr<Graph> subgraph(
89         new Graph(graph->flib_def().default_registry()));
90     GraphConstructorOptions opts;
91     opts.allow_internal_ops = true;
92     opts.expect_device_spec = true;
93     TF_RETURN_IF_ERROR(
94         ConvertGraphDefToGraph(opts, std::move(graph_def), subgraph.get()));
95     subgraphs->emplace(device, std::move(subgraph));
96   }
97 
98   return OkStatus();
99 }
100 
InsertTransferOps(const DeviceSet & device_set,std::unique_ptr<Graph> graph)101 StatusOr<std::unique_ptr<Graph>> InsertTransferOps(
102     const DeviceSet& device_set, std::unique_ptr<Graph> graph) {
103   // Skip transfer op insertion if the graph nodes are not assigned to multiple
104   // devices.
105   auto node_to_loc = [](const Node* node) {
106     return node->assigned_device_name();
107   };
108   bool has_multiple_devices = false;
109   absl::optional<std::string> location;
110   for (const Node* node : graph->op_nodes()) {
111     if (location) {
112       if (*location != node_to_loc(node)) {
113         has_multiple_devices = true;
114         break;
115       }
116     } else {
117       location = node_to_loc(node);
118     }
119   }
120   if (!has_multiple_devices) {
121     return graph;
122   }
123 
124   // Transfer ops are needed as there are multiple devices, so proceed with the
125   // partitioning.
126   auto new_graph = std::make_unique<Graph>(graph->flib_def());
127 
128   std::unordered_map<string, GraphDef> partitions;
129   TF_RETURN_IF_ERROR(PartitionFunctionGraph(device_set, graph.get(),
130                                             &partitions, node_to_loc,
131                                             /*get_tensor_name_attr=*/nullptr));
132 
133   GraphDef merged_graph_def;
134   if (!partitions.empty()) {
135     auto iter = partitions.begin();
136     merged_graph_def = std::move(iter->second);
137     while (++iter != partitions.end()) {
138       // TODO(b/220440252): MergeFrom() does memory copies when merging repeated
139       // fields. Ideally, we can merge repeated fields by 'moving' data.
140       // Consider using `proto2::util::MoveToEnd()` or so, once it is open
141       // sourced.
142       merged_graph_def.MergeFrom(iter->second);
143     }
144   }
145 
146   GraphConstructorOptions opts;
147   opts.allow_internal_ops = true;
148   opts.expect_device_spec = true;
149   TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, std::move(merged_graph_def),
150                                             new_graph.get()));
151   return std::move(new_graph);
152 }
153 
UpdateArgAndRetvalMetadata(Graph * graph,std::vector<FunctionArgIndex> * arg_indices,std::vector<int> * ret_indices,std::vector<AllocatorAttributes> * arg_alloc_attrs,std::vector<AllocatorAttributes> * ret_alloc_attrs,bool ints_on_device)154 Status UpdateArgAndRetvalMetadata(
155     Graph* graph, std::vector<FunctionArgIndex>* arg_indices,
156     std::vector<int>* ret_indices,
157     std::vector<AllocatorAttributes>* arg_alloc_attrs,
158     std::vector<AllocatorAttributes>* ret_alloc_attrs, bool ints_on_device) {
159   std::vector<std::pair<Node*, FunctionArgIndex>> arg_nodes;
160   std::vector<std::pair<Node*, int>> ret_nodes;
161   const AttrValue* attr_value;
162 
163   // Find the Arg and Retval nodes, along with their corresponding indices
164   // in the original function.
165   for (Node* node : graph->op_nodes()) {
166     if (node->IsArg()) {
167       TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
168       int index = static_cast<int>(attr_value->i());
169       int sub_index = -1;
170       if (node->attrs().Find("sub_index", &attr_value).ok()) {
171         sub_index = static_cast<int>(attr_value->i());
172       }
173       arg_nodes.emplace_back(node, FunctionArgIndex(index, sub_index));
174     } else if (node->IsRetval()) {
175       TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
176       int index = static_cast<int>(attr_value->i());
177       ret_nodes.emplace_back(node, index);
178     }
179   }
180 
181   // Sort the nodes by index so that the order is stable.
182   //
183   // In particular, this enables calling a single-partition function with
184   // the same signature as the original unpartitioned function.
185   auto arg_comparator = [](std::pair<Node*, FunctionArgIndex> a,
186                            std::pair<Node*, FunctionArgIndex> b) {
187     return std::tie(a.second.index, a.second.sub_index) <
188            std::tie(b.second.index, b.second.sub_index);
189   };
190   std::sort(arg_nodes.begin(), arg_nodes.end(), arg_comparator);
191   auto ret_comparator = [](std::pair<Node*, int> a, std::pair<Node*, int> b) {
192     return a.second < b.second;
193   };
194   std::sort(ret_nodes.begin(), ret_nodes.end(), ret_comparator);
195 
196   arg_indices->reserve(arg_nodes.size());
197   for (const auto& pair : arg_nodes) arg_indices->push_back(pair.second);
198   ret_indices->reserve(ret_nodes.size());
199   for (const auto& pair : ret_nodes) ret_indices->push_back(pair.second);
200 
201   for (int i = 0; i < arg_nodes.size(); ++i) {
202     Node* arg = arg_nodes[i].first;
203     arg->AddAttr("index", i);
204     TF_RETURN_IF_ERROR(arg->attrs().Find("T", &attr_value));
205     if (arg_alloc_attrs != nullptr) {
206       AllocatorAttributes alloc_attr;
207       DataType type = attr_value->type();
208       MemoryType mtype = ints_on_device ? MTypeFromDTypeIntsOnDevice(type)
209                                         : MTypeFromDType(type);
210       if (mtype == HOST_MEMORY) {
211         alloc_attr.set_on_host(true);
212       }
213       arg_alloc_attrs->push_back(alloc_attr);
214     }
215   }
216   for (int i = 0; i < ret_nodes.size(); ++i) {
217     Node* ret = ret_nodes[i].first;
218     ret->AddAttr("index", i);
219     TF_RETURN_IF_ERROR(ret->attrs().Find("T", &attr_value));
220     if (ret_alloc_attrs) {
221       AllocatorAttributes alloc_attr;
222       DataType type = attr_value->type();
223       MemoryType mtype = ints_on_device ? MTypeFromDTypeIntsOnDevice(type)
224                                         : MTypeFromDType(type);
225       if (mtype == HOST_MEMORY) {
226         alloc_attr.set_on_host(true);
227       }
228       ret_alloc_attrs->push_back(alloc_attr);
229     }
230   }
231 
232   return OkStatus();
233 }
234 
GetName()235 string FunctionNameGenerator::GetName() {
236   while (true) {
237     const string candidate = strings::StrCat(name_, "_", counter_++);
238     if (flib_def_->Find(candidate) == nullptr) {
239       return candidate;
240     }
241   }
242 }
243 
244 }  // namespace tensorflow
245