xref: /aosp_15_r20/external/tensorflow/tensorflow/core/common_runtime/memory_types.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/memory_types.h"

#include <utility>

#include "tensorflow/core/framework/device_factory.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {

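// A (node id, slot index) pair identifying a tensor endpoint (an input or
// output slot of a node) in the graph.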
struct Endpoint {
  int node_id;
  int output_index;
};

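// Hash and equality functors so that Endpoint can be used as a key in the
// unordered_map containers below.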
struct EndpointHash {
  uint32 operator()(const Endpoint& x) const {
    return Hash32(reinterpret_cast<const char*>(&x.node_id), sizeof(int),
                  x.output_index);
  }
};

struct EndpointEq {
  uint32 operator()(const Endpoint& x, const Endpoint& y) const {
    return (x.node_id == y.node_id) && (x.output_index == y.output_index);
  }
};

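// Invokes 'fn' for every non-control edge in 'g', passing the memory types of
// the edge's source output and destination input. On devices other than GPU
// and pluggable devices, host and device memory are always compatible, so this
// is a no-op.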
static Status ProcessMemoryTypes(
    const DeviceType& device_type, const Graph* g,
    const std::function<Status(const Edge*, MemoryType, MemoryType)>& fn) {
  if (device_type != DEVICE_GPU &&
      !DeviceFactory::IsPluggableDevice(device_type.type_string())) {
    // On non-GPU devices, HOST_MEMORY and DEVICE_MEMORY are always compatible.
    return OkStatus();
  }
  // On GPU devices, HOST_MEMORY and DEVICE_MEMORY are not compatible, i.e., a
  // conversion/transfer must be inserted.
  //
  // {node id, slot id} -> memory type.
  typedef std::unordered_map<Endpoint, MemoryType, EndpointHash, EndpointEq>
      MemTypeMap;
  MemTypeMap inp;
  MemTypeMap out;
  MemoryTypeVector inp_mvec;
  MemoryTypeVector out_mvec;
  for (const Node* n : g->nodes()) {
    TF_RETURN_IF_ERROR(MemoryTypesForNode(g->op_registry(), device_type,
                                          n->def(), &inp_mvec, &out_mvec));
    for (size_t i = 0; i < inp_mvec.size(); ++i) {
      VLOG(2) << "inp mvec " << n->id() << " " << i << " " << inp_mvec[i];
      inp[{n->id(), static_cast<int>(i)}] = inp_mvec[i];
    }
    for (size_t i = 0; i < out_mvec.size(); ++i) {
      VLOG(2) << "out mvec " << n->id() << " " << i << " " << out_mvec[i];
      out[{n->id(), static_cast<int>(i)}] = out_mvec[i];
    }
  }
  for (const Edge* e : g->edges()) {
    if (e->IsControlEdge()) {
      continue;
    }
    MemoryType sm = gtl::FindWithDefault(out, {e->src()->id(), e->src_output()},
                                         DEVICE_MEMORY);
    MemoryType dm = gtl::FindWithDefault(inp, {e->dst()->id(), e->dst_input()},
                                         DEVICE_MEMORY);
    VLOG(1) << e->src()->id() << ":" << e->src_output() << " -> "
            << e->dst()->id() << ":" << e->dst_input() << ": " << sm << " -> "
            << dm;
    TF_RETURN_IF_ERROR(fn(e, sm, dm));
  }
  return OkStatus();
}

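// Returns an Internal error for the first edge in 'g' whose source output and
// destination input disagree on memory type; returns OK if all edges match.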
Status ValidateMemoryTypes(const DeviceType& device_type, const Graph* g) {
  return ProcessMemoryTypes(
      device_type, g, [](const Edge* e, MemoryType sm, MemoryType dm) {
        if (sm == dm) {
          return OkStatus();
        }
        return errors::Internal("Memory type mismatch (", sm, " ", dm,
                                ") between :", e->src()->id(), ":",
                                e->src_output(), " and ", e->dst()->id(), ":",
                                e->dst_input(), " : from ",
                                FormatNodeForError(*e->src()), " to ",
                                FormatNodeForError(*e->dst()));
      });
}

// Given an Edge whose two endpoints have different memory types, for which we
// are going to insert a pair of HostSend/Recv or Send/HostRecv nodes,
// GetTensorName() returns a unique string that we can use as part of the
// rendezvous key. The returned string is guaranteed to be unique within this
// process. That is sufficient because EnsureMemoryTypes is only used on a
// TensorFlow graph that is going to be executed on a single TF device (hence
// within a single process).
static string GetTensorName(const Edge* edge) {
  static std::atomic<int64_t> counter(0);
  return strings::StrCat("memtype_", counter.fetch_add(1), "_",
                         edge->src()->name());
}

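// Adds a _HostSend (if 'host' is true) or _Send node that sends the tensor
// produced at the source endpoint of 'edge'. 'device_name' is used as both
// the send and recv device since the transfer stays within a single device.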
static Node* Send(Graph* g, const string& tensor_name,
                  const string& device_name, bool host, const Edge* edge) {
  Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), host ? "_HostSend" : "_Send")
                  .Input(edge->src(), edge->src_output())
                  .Attr("tensor_name", tensor_name)
                  .Attr("send_device", device_name)
                  .Attr("send_device_incarnation", 0)  // Do not care.
                  .Attr("recv_device", device_name)
                  .Attr("_hostmem_sendrecv", true)
                  .Attr("_src", edge->src()->name())
                  .Attr("_dst", edge->dst()->name())
                  .Finalize(g, &ret));
  return ret;
}

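// Adds the matching _HostRecv (if 'host' is true) or _Recv node that receives
// the tensor sent under 'tensor_name' on 'device_name'.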
static Node* Recv(Graph* g, const string& tensor_name,
                  const string& device_name, bool host, const Edge* edge) {
  Node* ret;
  TF_CHECK_OK(
      NodeBuilder(g->NewName("n"), host ? "_HostRecv" : "_Recv")
          .Attr("tensor_type", edge->src()->output_type(edge->src_output()))
          .Attr("tensor_name", tensor_name)
          .Attr("send_device", device_name)
          .Attr("send_device_incarnation", 0)
          .Attr("recv_device", device_name)
          .Attr("_hostmem_sendrecv", true)
          .Attr("_src", edge->src()->name())
          .Attr("_dst", edge->dst()->name())
          .Finalize(g, &ret));
  return ret;
}

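// Rewrites 'g' so that every edge whose endpoints disagree on memory type is
// routed through a HostSend/Recv or Send/HostRecv pair, then re-validates the
// resulting graph.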
Status EnsureMemoryTypes(const DeviceType& device_type,
                         const string& device_name, Graph* g) {
  struct Item {
    const Edge* edge;
    MemoryType sm;
    MemoryType dm;
  };
  std::vector<Item> edges;
  TF_RETURN_IF_ERROR(ProcessMemoryTypes(
      device_type, g, [&edges](const Edge* e, MemoryType sm, MemoryType dm) {
        if (sm == dm) {
          return OkStatus();
        }
        if (((sm == HOST_MEMORY) && (dm == DEVICE_MEMORY)) ||
            ((sm == DEVICE_MEMORY) && (dm == HOST_MEMORY))) {
          edges.push_back({e, sm, dm});
          return OkStatus();
        }
        return errors::Internal("Unexpected memory type pair on an edge: ", sm,
                                " vs. ", dm);
      }));

  // 'edges' contains the edges in 'g' whose memory types are not compatible.
  // If we found any, we need to insert HostSend/Recv and Send/HostRecv pairs.
  // 'recv_nodes' records the Recv nodes we have already added so that we don't
  // copy the same tensor more than once.
  if (!edges.empty()) {
    std::unordered_map<Endpoint, Node*, EndpointHash, EndpointEq> recv_nodes;
    for (const auto& item : edges) {
      const Edge* e = item.edge;
      const bool has_ref = IsRefType(e->src()->output_type(e->src_output()));
      Node* recv = nullptr;
      Endpoint key{e->src()->id(), e->src_output()};
      auto iter = recv_nodes.find(key);
      if (iter == recv_nodes.end()) {
        const string tensor_name = GetTensorName(e);
        Node* send =
            Send(g, tensor_name, device_name, (item.sm == HOST_MEMORY), e);
        recv = Recv(g, tensor_name, device_name, (item.dm == HOST_MEMORY), e);
        if (!has_ref) {
          // We only cache the Recv node if no ref type is involved.
          recv_nodes[key] = recv;
        }
        g->AddControlEdge(send, recv);
      } else {
        recv = iter->second;
      }
      g->AddEdge(recv, 0, e->dst(), e->dst_input());
      g->RemoveEdge(e);
    }
  }

  if (VLOG_IS_ON(2)) {
    VLOG(2) << "Dumped graph after EnsureMemoryTypes to "
            << DumpGraphToFile("EnsureMemoryTypes", *g);
  }

  return ValidateMemoryTypes(device_type, g);
}

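// Looks up the memory type of the 'index'-th output of node 'n' when placed
// on 'device_type'. Returns an Internal error if 'n' has fewer outputs.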
Status MemoryTypeForOutput(const DeviceType& device_type, const Graph* g,
                           const Node* n, int index, MemoryType* memory_type) {
  MemoryTypeVector inp_mvec;
  MemoryTypeVector out_mvec;
  TF_RETURN_IF_ERROR(MemoryTypesForNode(g->op_registry(), device_type, n->def(),
                                        &inp_mvec, &out_mvec));
  if (out_mvec.size() <= index) {
    return errors::Internal("Trying to get the memory type for ", index,
                            "'th output of node ", FormatNodeForError(*n),
                            " that has only ", out_mvec.size(), " outputs");
  }
  *memory_type = out_mvec[index];
  return OkStatus();
}

}  // end namespace tensorflow