/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/partitioning_utils.h"

#include <algorithm>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>

#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_partition.h"

namespace tensorflow {

namespace {

// A helper to partition a `graph` across a `device_set`.
// `partitions` maps device names to the GraphDef assigned to that device.
Status PartitionFunctionGraph(
    const DeviceSet& device_set, Graph* graph,
    std::unordered_map<string, GraphDef>* partitions,
    std::function<string(const Node*)> node_to_loc,
    std::function<string(const Edge*)> get_tensor_name_attr) {
  PartitionOptions partition_options;
  if (node_to_loc != nullptr) {
    partition_options.node_to_loc = node_to_loc;
  } else {
    partition_options.node_to_loc = [](const Node* node) {
      // TODO(iga): To support the distributed case, first split the graph by
      // worker (e.g., using the master session's `SplitByWorker` policy), and
      // then recursively partition the per-worker shards at the remote
      // worker(s). Currently, we simply split the graph at device boundaries.
      return node->assigned_device_name();
    };
  }
  int64_t edge_name_counter = 0;
  partition_options.new_name = [&edge_name_counter](const string& prefix) {
    return strings::StrCat(prefix, "/_", ++edge_name_counter);
  };
  partition_options.get_incarnation =
      [&device_set](const string& name) -> int64_t {
    const Device* d = device_set.FindDeviceByName(name);
    if (d == nullptr) {
      return PartitionOptions::kIllegalIncarnation;
    } else {
      return d->attributes().incarnation();
    }
  };
  partition_options.control_flow_added = false;
  partition_options.get_tensor_name_attr = get_tensor_name_attr;

  return Partition(partition_options, graph, partitions);
}

}  // namespace

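// Partitions `graph` by the device assigned to each node and converts each
// partition's GraphDef back into a Graph, stored in `subgraphs` keyed by
// device name.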
Status PartitionFunctionGraph(
    const DeviceSet& device_set, std::unique_ptr<Graph> graph,
    std::unordered_map<string, std::unique_ptr<Graph>>* subgraphs,
    std::function<string(const Edge*)> get_tensor_name_attr) {
  std::unordered_map<string, GraphDef> partitions;
  TF_RETURN_IF_ERROR(
      PartitionFunctionGraph(device_set, graph.get(), &partitions,
                             /*node_to_loc=*/nullptr, get_tensor_name_attr));

  for (auto& partition : partitions) {
    const string& device = partition.first;
    GraphDef& graph_def = partition.second;
    // Each partition gets a new graph.
    std::unique_ptr<Graph> subgraph(
        new Graph(graph->flib_def().default_registry()));
    GraphConstructorOptions opts;
    opts.allow_internal_ops = true;
    opts.expect_device_spec = true;
    TF_RETURN_IF_ERROR(
        ConvertGraphDefToGraph(opts, std::move(graph_def), subgraph.get()));
    subgraphs->emplace(device, std::move(subgraph));
  }

  return OkStatus();
}

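// Inserts transfer ops (Send/Recv pairs) at device boundaries. If all op
// nodes in `graph` are assigned to a single device, the graph is returned
// unchanged. Otherwise the graph is partitioned per device and the resulting
// partitions, which contain the inserted transfer ops, are merged back into a
// single graph.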
StatusOr<std::unique_ptr<Graph>> InsertTransferOps(
    const DeviceSet& device_set, std::unique_ptr<Graph> graph) {
  // Skip transfer op insertion if the graph nodes are not assigned to multiple
  // devices.
  auto node_to_loc = [](const Node* node) {
    return node->assigned_device_name();
  };
  bool has_multiple_devices = false;
  std::optional<std::string> location;
  for (const Node* node : graph->op_nodes()) {
    if (location) {
      if (*location != node_to_loc(node)) {
        has_multiple_devices = true;
        break;
      }
    } else {
      location = node_to_loc(node);
    }
  }
  if (!has_multiple_devices) {
    return graph;
  }

  // Transfer ops are needed as there are multiple devices, so proceed with the
  // partitioning.
  auto new_graph = std::make_unique<Graph>(graph->flib_def());

  std::unordered_map<string, GraphDef> partitions;
  TF_RETURN_IF_ERROR(PartitionFunctionGraph(device_set, graph.get(),
                                            &partitions, node_to_loc,
                                            /*get_tensor_name_attr=*/nullptr));

  GraphDef merged_graph_def;
  if (!partitions.empty()) {
    auto iter = partitions.begin();
    merged_graph_def = std::move(iter->second);
    while (++iter != partitions.end()) {
      // TODO(b/220440252): MergeFrom() does memory copies when merging
      // repeated fields. Ideally, we can merge repeated fields by 'moving'
      // data. Consider using `proto2::util::MoveToEnd()` or so, once it is
      // open sourced.
      merged_graph_def.MergeFrom(iter->second);
    }
  }

  GraphConstructorOptions opts;
  opts.allow_internal_ops = true;
  opts.expect_device_spec = true;
  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, std::move(merged_graph_def),
                                            new_graph.get()));
  return std::move(new_graph);
}

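// Rewrites the `_Arg` and `_Retval` nodes of a per-device subgraph so that
// their "index" attrs are dense and in a stable order, records the original
// indices in `arg_indices`/`ret_indices`, and, if requested, fills in the
// allocator attributes for each argument and return value.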
Status UpdateArgAndRetvalMetadata(
    Graph* graph, std::vector<FunctionArgIndex>* arg_indices,
    std::vector<int>* ret_indices,
    std::vector<AllocatorAttributes>* arg_alloc_attrs,
    std::vector<AllocatorAttributes>* ret_alloc_attrs, bool ints_on_device) {
  std::vector<std::pair<Node*, FunctionArgIndex>> arg_nodes;
  std::vector<std::pair<Node*, int>> ret_nodes;
  const AttrValue* attr_value;

  // Find the Arg and Retval nodes, along with their corresponding indices
  // in the original function.
  for (Node* node : graph->op_nodes()) {
    if (node->IsArg()) {
      TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
      int index = static_cast<int>(attr_value->i());
      int sub_index = -1;
      if (node->attrs().Find("sub_index", &attr_value).ok()) {
        sub_index = static_cast<int>(attr_value->i());
      }
      arg_nodes.emplace_back(node, FunctionArgIndex(index, sub_index));
    } else if (node->IsRetval()) {
      TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
      int index = static_cast<int>(attr_value->i());
      ret_nodes.emplace_back(node, index);
    }
  }

  // Sort the nodes by index so that the order is stable.
  //
  // In particular, this enables calling a single-partition function with
  // the same signature as the original unpartitioned function.
  auto arg_comparator = [](std::pair<Node*, FunctionArgIndex> a,
                           std::pair<Node*, FunctionArgIndex> b) {
    return std::tie(a.second.index, a.second.sub_index) <
           std::tie(b.second.index, b.second.sub_index);
  };
  std::sort(arg_nodes.begin(), arg_nodes.end(), arg_comparator);
  auto ret_comparator = [](std::pair<Node*, int> a, std::pair<Node*, int> b) {
    return a.second < b.second;
  };
  std::sort(ret_nodes.begin(), ret_nodes.end(), ret_comparator);

  arg_indices->reserve(arg_nodes.size());
  for (const auto& pair : arg_nodes) arg_indices->push_back(pair.second);
  ret_indices->reserve(ret_nodes.size());
  for (const auto& pair : ret_nodes) ret_indices->push_back(pair.second);

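  // Rewrite each node's "index" attr to its position in the sorted order, and
  // derive allocator attributes from the value's memory type (on host for
  // types that must live in host memory, e.g. int32, unless `ints_on_device`
  // is set).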
  for (int i = 0; i < arg_nodes.size(); ++i) {
    Node* arg = arg_nodes[i].first;
    arg->AddAttr("index", i);
    TF_RETURN_IF_ERROR(arg->attrs().Find("T", &attr_value));
    if (arg_alloc_attrs != nullptr) {
      AllocatorAttributes alloc_attr;
      DataType type = attr_value->type();
      MemoryType mtype = ints_on_device ? MTypeFromDTypeIntsOnDevice(type)
                                        : MTypeFromDType(type);
      if (mtype == HOST_MEMORY) {
        alloc_attr.set_on_host(true);
      }
      arg_alloc_attrs->push_back(alloc_attr);
    }
  }
  for (int i = 0; i < ret_nodes.size(); ++i) {
    Node* ret = ret_nodes[i].first;
    ret->AddAttr("index", i);
    TF_RETURN_IF_ERROR(ret->attrs().Find("T", &attr_value));
    if (ret_alloc_attrs) {
      AllocatorAttributes alloc_attr;
      DataType type = attr_value->type();
      MemoryType mtype = ints_on_device ? MTypeFromDTypeIntsOnDevice(type)
                                        : MTypeFromDType(type);
      if (mtype == HOST_MEMORY) {
        alloc_attr.set_on_host(true);
      }
      ret_alloc_attrs->push_back(alloc_attr);
    }
  }

  return OkStatus();
}

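// Returns a name of the form "<name_>_<counter_>" that is not already
// registered in `flib_def_`.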
string FunctionNameGenerator::GetName() {
  while (true) {
    const string candidate = strings::StrCat(name_, "_", counter_++);
    if (flib_def_->Find(candidate) == nullptr) {
      return candidate;
    }
  }
}

}  // namespace tensorflow