/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/memory_types.h"

#include <utility>

#include "tensorflow/core/framework/device_factory.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {

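// A {node id, output slot} pair identifying one output of a node in the graph.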
struct Endpoint {
  int node_id;
  int output_index;
};

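// Hash and equality functors so that Endpoint can be used as the key of an
// unordered_map.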
struct EndpointHash {
  uint32 operator()(const Endpoint& x) const {
    return Hash32(reinterpret_cast<const char*>(&x.node_id), sizeof(int),
                  x.output_index);
  }
};

struct EndpointEq {
  uint32 operator()(const Endpoint& x, const Endpoint& y) const {
    return (x.node_id == y.node_id) && (x.output_index == y.output_index);
  }
};

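// Computes the expected memory type of every (node, slot) endpoint in 'g' on
// 'device_type' and invokes 'fn' once per data edge with the memory types of
// the edge's source output and destination input.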
static Status ProcessMemoryTypes(
    const DeviceType& device_type, const Graph* g,
    const std::function<Status(const Edge*, MemoryType, MemoryType)>& fn) {
  if (device_type != DEVICE_GPU &&
      !DeviceFactory::IsPluggableDevice(device_type.type_string())) {
    // On non-GPU devices, HOST_MEMORY and DEVICE_MEMORY are always compatible.
    return OkStatus();
  }
  // On GPU (and pluggable devices), HOST_MEMORY and DEVICE_MEMORY are not
  // compatible, i.e., a conversion/transfer must be done.
  //
  // {node id, slot id} -> memory type.
  typedef std::unordered_map<Endpoint, MemoryType, EndpointHash, EndpointEq>
      MemTypeMap;
  MemTypeMap inp;
  MemTypeMap out;
  MemoryTypeVector inp_mvec;
  MemoryTypeVector out_mvec;
  for (const Node* n : g->nodes()) {
    TF_RETURN_IF_ERROR(MemoryTypesForNode(g->op_registry(), device_type,
                                          n->def(), &inp_mvec, &out_mvec));
    for (size_t i = 0; i < inp_mvec.size(); ++i) {
      VLOG(2) << "inp mvec " << n->id() << " " << i << " " << inp_mvec[i];
      inp[{n->id(), static_cast<int>(i)}] = inp_mvec[i];
    }
    for (size_t i = 0; i < out_mvec.size(); ++i) {
      VLOG(2) << "out mvec " << n->id() << " " << i << " " << out_mvec[i];
      out[{n->id(), static_cast<int>(i)}] = out_mvec[i];
    }
  }
  for (const Edge* e : g->edges()) {
    if (e->IsControlEdge()) {
      continue;
    }
    MemoryType sm = gtl::FindWithDefault(
        out, {e->src()->id(), e->src_output()}, DEVICE_MEMORY);
    MemoryType dm = gtl::FindWithDefault(
        inp, {e->dst()->id(), e->dst_input()}, DEVICE_MEMORY);
    VLOG(1) << e->src()->id() << ":" << e->src_output() << " -> "
            << e->dst()->id() << ":" << e->dst_input() << ": " << sm << " -> "
            << dm;
    TF_RETURN_IF_ERROR(fn(e, sm, dm));
  }
  return OkStatus();
}

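// Returns an error if any data edge in 'g' connects two endpoints whose memory
// types differ on 'device_type'.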
Status ValidateMemoryTypes(const DeviceType& device_type, const Graph* g) {
  return ProcessMemoryTypes(
      device_type, g, [](const Edge* e, MemoryType sm, MemoryType dm) {
        if (sm == dm) {
          return OkStatus();
        }
        return errors::Internal("Memory type mismatch (", sm, " ", dm,
                                ") between :", e->src()->id(), ":",
                                e->src_output(), " and ", e->dst()->id(), ":",
                                e->dst_input(), " : from ",
                                FormatNodeForError(*e->src()), " to ",
                                FormatNodeForError(*e->dst()));
      });
}

// Given an Edge whose two endpoints have different memory types, and for which
// we are going to insert a pair of HostSend/Recv or Send/HostRecv nodes,
// GetTensorName() returns a unique string that we can use as part of the
// rendezvous key. The returned string is guaranteed to be unique within this
// process. That is sufficient because EnsureMemoryTypes is only used on a
// TensorFlow graph that is going to be executed on a single TF device (hence
// within a single process).
static string GetTensorName(const Edge* edge) {
  static std::atomic<int64_t> counter(0);
  return strings::StrCat("memtype_", counter.fetch_add(1), "_",
                         edge->src()->name());
}

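// Adds a _HostSend (if 'host' is true) or _Send node that consumes the source
// output of 'edge' and sends it under 'tensor_name' on 'device_name'.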
static Node* Send(Graph* g, const string& tensor_name,
                  const string& device_name, bool host, const Edge* edge) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), host ? "_HostSend" : "_Send")
                  .Input(edge->src(), edge->src_output())
                  .Attr("tensor_name", tensor_name)
                  .Attr("send_device", device_name)
                  .Attr("send_device_incarnation", 0)  // Do not care.
                  .Attr("recv_device", device_name)
                  .Attr("_hostmem_sendrecv", true)
                  .Attr("_src", edge->src()->name())
                  .Attr("_dst", edge->dst()->name())
                  .Finalize(g, &ret));
  return ret;
}

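// Adds a _HostRecv (if 'host' is true) or _Recv node that receives the tensor
// sent under 'tensor_name' by the matching Send node on 'device_name'.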
static Node* Recv(Graph* g, const string& tensor_name,
                  const string& device_name, bool host, const Edge* edge) {
  Node* ret;
  TF_CHECK_OK(
      NodeBuilder(g->NewName("n"), host ? "_HostRecv" : "_Recv")
          .Attr("tensor_type", edge->src()->output_type(edge->src_output()))
          .Attr("tensor_name", tensor_name)
          .Attr("send_device", device_name)
          .Attr("send_device_incarnation", 0)
          .Attr("recv_device", device_name)
          .Attr("_hostmem_sendrecv", true)
          .Attr("_src", edge->src()->name())
          .Attr("_dst", edge->dst()->name())
          .Finalize(g, &ret));
  return ret;
}

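// Rewrites 'g' so that every data edge connects endpoints of the same memory
// type on 'device_type', inserting a Send/Recv pair where they differ. For
// example, an edge from a DEVICE_MEMORY output A:0 to a HOST_MEMORY input B:0
// is rewritten as:
//   A:0 -> _Send --(control edge)--> _HostRecv -> B:0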
Status EnsureMemoryTypes(const DeviceType& device_type,
                         const string& device_name, Graph* g) {
  struct Item {
    const Edge* edge;
    MemoryType sm;
    MemoryType dm;
  };
  std::vector<Item> edges;
  TF_RETURN_IF_ERROR(ProcessMemoryTypes(
      device_type, g, [&edges](const Edge* e, MemoryType sm, MemoryType dm) {
        if (sm == dm) {
          return OkStatus();
        }
        if (((sm == HOST_MEMORY) && (dm == DEVICE_MEMORY)) ||
            ((sm == DEVICE_MEMORY) && (dm == HOST_MEMORY))) {
          edges.push_back({e, sm, dm});
          return OkStatus();
        }
        return errors::Internal("Unexpected memory type pair on an edge: ", sm,
                                " vs. ", dm);
      }));

  // 'edges' contains the edges in 'g' whose endpoint memory types are not
  // compatible. If we found any, we need to insert HostSend/Recv and
  // Send/HostRecv pairs. 'recv_nodes' records all Recv nodes we added so that
  // we don't copy the same tensor more than once.
  if (!edges.empty()) {
    std::unordered_map<Endpoint, Node*, EndpointHash, EndpointEq> recv_nodes;
    for (const auto& item : edges) {
      const Edge* e = item.edge;
      const bool has_ref = IsRefType(e->src()->output_type(e->src_output()));
      Node* recv = nullptr;
      Endpoint key{e->src()->id(), e->src_output()};
      auto iter = recv_nodes.find(key);
      if (iter == recv_nodes.end()) {
        const string tensor_name = GetTensorName(e);
        Node* send =
            Send(g, tensor_name, device_name, (item.sm == HOST_MEMORY), e);
        recv = Recv(g, tensor_name, device_name, (item.dm == HOST_MEMORY), e);
        if (!has_ref) {
          // We only cache the Recv node if no ref type is involved.
          recv_nodes[key] = recv;
        }
        g->AddControlEdge(send, recv);
      } else {
        recv = iter->second;
      }
      g->AddEdge(recv, 0, e->dst(), e->dst_input());
      g->RemoveEdge(e);
    }
  }

  if (VLOG_IS_ON(2)) {
    VLOG(2) << "Dumped graph after EnsureMemoryTypes to "
            << DumpGraphToFile("EnsureMemoryTypes", *g);
  }

  return ValidateMemoryTypes(device_type, g);
}

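// Looks up the memory type of the 'index'-th output of node 'n' on
// 'device_type' and stores it in '*memory_type'.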
Status MemoryTypeForOutput(const DeviceType& device_type, const Graph* g,
                           const Node* n, int index, MemoryType* memory_type) {
  MemoryTypeVector inp_mvec;
  MemoryTypeVector out_mvec;
  TF_RETURN_IF_ERROR(MemoryTypesForNode(g->op_registry(), device_type, n->def(),
                                        &inp_mvec, &out_mvec));
  if (out_mvec.size() <= index) {
    return errors::Internal("Trying to get the memory type for ", index,
                            "'th output of node ", FormatNodeForError(*n),
                            " that has only ", out_mvec.size(), " outputs");
  }
  *memory_type = out_mvec[index];
  return OkStatus();
}

}  // end namespace tensorflow