xref: /aosp_15_r20/external/tensorflow/tensorflow/core/framework/resource_mgr.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/resource_mgr.h"
17 
18 #include <atomic>
19 
20 #include "tensorflow/core/framework/device_attributes.pb.h"
21 #include "tensorflow/core/framework/node_def.pb.h"
22 #include "tensorflow/core/framework/node_def_util.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/gtl/map_util.h"
25 #include "tensorflow/core/lib/strings/scanner.h"
26 #include "tensorflow/core/lib/strings/str_util.h"
27 #include "tensorflow/core/lib/strings/stringprintf.h"
28 #include "tensorflow/core/platform/demangle.h"
29 #include "tensorflow/core/platform/stacktrace.h"
30 
31 namespace tensorflow {
32 
MakeResourceHandle(const string & container,const string & name,const DeviceBase & device,const TypeIndex & type_index,const std::vector<DtypeAndPartialTensorShape> & dtypes_and_shapes,const absl::optional<ManagedStackTrace> & definition_stack_trace)33 ResourceHandle MakeResourceHandle(
34     const string& container, const string& name, const DeviceBase& device,
35     const TypeIndex& type_index,
36     const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes,
37     const absl::optional<ManagedStackTrace>& definition_stack_trace) {
38   ResourceHandle result;
39   result.set_device(device.name());
40   result.set_container(container);
41   result.set_definition_stack_trace(definition_stack_trace);
42   if (name == ResourceHandle::ANONYMOUS_NAME) {
43     result.set_name(
44         strings::StrCat("_AnonymousVar", ResourceHandle::GenerateUniqueId()));
45   } else {
46     result.set_name(name);
47   }
48   result.set_hash_code(type_index.hash_code());
49   result.set_maybe_type_name(type_index.name());
50   result.set_dtypes_and_shapes(dtypes_and_shapes);
51   return result;
52 }
53 
MakeResourceHandleToOutput(OpKernelContext * context,int output_index,const string & container,const string & name,const TypeIndex & type_index)54 Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index,
55                                   const string& container, const string& name,
56                                   const TypeIndex& type_index) {
57   Tensor* handle;
58   TF_RETURN_IF_ERROR(
59       context->allocate_output(output_index, TensorShape({}), &handle));
60   handle->scalar<ResourceHandle>()() =
61       MakeResourceHandle(container, name, *context->device(), type_index);
62   return OkStatus();
63 }
64 
65 namespace internal {
66 
ValidateDevice(OpKernelContext * ctx,const ResourceHandle & p)67 Status ValidateDevice(OpKernelContext* ctx, const ResourceHandle& p) {
68   if (ctx->device()->attributes().name() != p.device()) {
69     return errors::InvalidArgument(
70         "Trying to access resource ", p.name(), " located in device ",
71         p.device(), " from device ", ctx->device()->attributes().name());
72   }
73   return OkStatus();
74 }
75 
76 }  // end namespace internal
77 
InsertDebugTypeName(uint64 hash_code,const string & type_name)78 Status ResourceMgr::InsertDebugTypeName(uint64 hash_code,
79                                         const string& type_name) {
80   auto iter = debug_type_names_.emplace(hash_code, type_name);
81   if (iter.first->second != type_name) {
82     return errors::AlreadyExists("Duplicate hash code found for type ",
83                                  type_name);
84   }
85   return OkStatus();
86 }
87 
DebugTypeName(uint64 hash_code) const88 const char* ResourceMgr::DebugTypeName(uint64 hash_code) const {
89   auto type_name_iter = debug_type_names_.find(hash_code);
90   if (type_name_iter == debug_type_names_.end()) {
91     return "<unknown>";
92   } else {
93     return type_name_iter->second.c_str();
94   }
95 }
96 
ResourceAndName()97 ResourceMgr::ResourceAndName::ResourceAndName() : name(nullptr) {}
98 
ResourceAndName(const string & name)99 ResourceMgr::ResourceAndName::ResourceAndName(const string& name)
100     : name(absl::make_unique<string>(name)) {}
101 
GetResource() const102 core::RefCountPtr<ResourceBase> ResourceMgr::ResourceAndName::GetResource()
103     const {
104   if (absl::holds_alternative<core::RefCountPtr<ResourceBase>>(resource)) {
105     ResourceBase* ptr =
106         absl::get<core::RefCountPtr<ResourceBase>>(resource).get();
107     ptr->Ref();
108     return core::RefCountPtr<ResourceBase>(ptr);
109   } else if (absl::holds_alternative<core::WeakPtr<ResourceBase>>(resource)) {
110     return absl::get<core::WeakPtr<ResourceBase>>(resource).GetNewRef();
111   } else {
112     return nullptr;
113   }
114 }
115 
ResourceAndName(ResourceAndName && other)116 ResourceMgr::ResourceAndName::ResourceAndName(
117     ResourceAndName&& other) noexcept {
118   name = std::move(other.name);
119   resource = std::move(other.resource);
120 }
121 
~ResourceAndName()122 ResourceMgr::ResourceAndName::~ResourceAndName() {}
123 
operator =(ResourceAndName && other)124 ResourceMgr::ResourceAndName& ResourceMgr::ResourceAndName::operator=(
125     ResourceAndName&& other) noexcept {
126   name = std::move(other.name);
127   resource = std::move(other.resource);
128   return *this;
129 }
130 
ResourceMgr()131 ResourceMgr::ResourceMgr() : default_container_("localhost") {}
132 
ResourceMgr(const string & default_container)133 ResourceMgr::ResourceMgr(const string& default_container)
134     : default_container_(default_container) {}
135 
~ResourceMgr()136 ResourceMgr::~ResourceMgr() { Clear(); }
137 
Clear()138 void ResourceMgr::Clear() {
139   // We do the deallocation outside of the lock to avoid a potential deadlock
140   // in case any of the destructors access the resource manager.
141   absl::flat_hash_map<string, Container*> tmp_containers;
142   {
143     mutex_lock l(mu_);
144     tmp_containers = std::move(containers_);
145   }
146   for (const auto& p : tmp_containers) {
147     delete p.second;
148   }
149   tmp_containers.clear();
150 }
151 
DebugString() const152 string ResourceMgr::DebugString() const {
153   mutex_lock l(mu_);
154   struct Line {
155     const string* container;
156     const string type;
157     const string* resource;
158     const string detail;
159   };
160   std::vector<Line> lines;
161   for (const auto& p : containers_) {
162     const string& container = p.first;
163     for (const auto& q : *p.second) {
164       const Key& key = q.first;
165       const char* type = DebugTypeName(key.first);
166       const core::RefCountPtr<ResourceBase> resource = q.second.GetResource();
167       Line l{&container, port::Demangle(type), q.second.name.get(),
168              resource ? resource->DebugString() : "<nullptr>"};
169       lines.push_back(l);
170     }
171   }
172   std::vector<string> text;
173   text.reserve(lines.size());
174   for (const Line& line : lines) {
175     text.push_back(strings::Printf(
176         "%-20s | %-40s | %-40s | %-s", line.container->c_str(),
177         line.type.c_str(), line.resource->c_str(), line.detail.c_str()));
178   }
179   std::sort(text.begin(), text.end());
180   return absl::StrJoin(text, "\n");
181 }
182 
DoCreate(const string & container_name,TypeIndex type,const string & name,ResourceBase * resource,bool owns_resource)183 Status ResourceMgr::DoCreate(const string& container_name, TypeIndex type,
184                              const string& name, ResourceBase* resource,
185                              bool owns_resource) {
186   Container* container = [&]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
187     Container** ptr = &containers_[container_name];
188     if (*ptr == nullptr) {
189       *ptr = new Container;
190     }
191     return *ptr;
192   }();
193 
194   // NOTE: Separating out the construction of the map key and value so that the
195   // key can contain a StringPiece that borrows from the string in the value.
196   ResourceAndName resource_and_name(name);
197 
198   StringPiece borrowed_name(*resource_and_name.name);
199 
200   if (owns_resource) {
201     resource_and_name.resource = core::RefCountPtr<ResourceBase>(resource);
202   } else {
203     auto cleanup_fn = [this, container, type, borrowed_name]() {
204       mutex_lock l(mu_);
205       auto iter = container->find({type.hash_code(), borrowed_name});
206       if (iter != container->end()) {
207         container->erase(iter);
208       }
209     };
210     resource_and_name.resource =
211         core::WeakPtr<ResourceBase>(resource, cleanup_fn);
212   }
213 
214   Container::value_type key_and_value(Key(type.hash_code(), borrowed_name),
215                                       std::move(resource_and_name));
216 
217   auto st = container->insert(std::move(key_and_value));
218   if (st.second) {
219     TF_RETURN_IF_ERROR(InsertDebugTypeName(type.hash_code(), type.name()));
220     return OkStatus();
221   }
222   return errors::AlreadyExists("Resource ", container_name, "/", name, "/",
223                                type.name());
224 }
225 
Lookup(const ResourceHandle & handle,ResourceBase ** resource) const226 Status ResourceMgr::Lookup(const ResourceHandle& handle,
227                            ResourceBase** resource) const {
228   tf_shared_lock l(mu_);
229   return DoLookup(handle.container(), handle.hash_code(),
230                   /*type_name=*/"ResourceBase", handle.name(), resource);
231 }
232 
DoLookup(const string & container,TypeIndex type,const string & name,ResourceBase ** resource) const233 Status ResourceMgr::DoLookup(const string& container, TypeIndex type,
234                              const string& name,
235                              ResourceBase** resource) const {
236   return DoLookup(container, type.hash_code(), type.name(), name, resource);
237 }
238 
DoLookup(const string & container,uint64 type_hash_code,const string & type_name,const string & resource_name,ResourceBase ** resource) const239 Status ResourceMgr::DoLookup(const string& container, uint64 type_hash_code,
240                              const string& type_name,
241                              const string& resource_name,
242                              ResourceBase** resource) const {
243   const Container* b = gtl::FindPtrOrNull(containers_, container);
244   if (b == nullptr) {
245     return errors::NotFound("Container ", container,
246                             " does not exist. (Could not find resource: ",
247                             container, "/", resource_name, ")");
248   }
249   auto iter = b->find({type_hash_code, resource_name});
250   if (iter == b->end()) {
251     return errors::NotFound("Resource ", container, "/", resource_name, "/",
252                             type_name, " does not exist.");
253   }
254   ResourceBase* ptr = iter->second.GetResource().release();
255   if (ptr == nullptr) {
256     return errors::NotFound("Resource ", container, "/", resource_name, "/",
257                             type_name, " has been destroyed.");
258   }
259   *resource = ptr;
260   return OkStatus();
261 }
262 
PopResourceAndName(const string & container,uint64 type_hash_code,const string & resource_name,const string & type_name,ResourceAndName & resource_and_name)263 Status ResourceMgr::PopResourceAndName(const string& container,
264                                        uint64 type_hash_code,
265                                        const string& resource_name,
266                                        const string& type_name,
267                                        ResourceAndName& resource_and_name) {
268   mutex_lock l(mu_);
269   Container* b = gtl::FindPtrOrNull(containers_, container);
270   if (b == nullptr) {
271     return errors::NotFound("Container ", container, " does not exist.");
272   }
273   auto iter = b->find({type_hash_code, resource_name});
274   if (iter == b->end()) {
275     return errors::NotFound("Resource ", container, "/", resource_name, "/",
276                             type_name, " does not exist.");
277   }
278   std::swap(resource_and_name, iter->second);
279   b->erase(iter);
280   return OkStatus();
281 }
282 
DoDelete(const string & container,uint64 type_hash_code,const string & resource_name,const string & type_name)283 Status ResourceMgr::DoDelete(const string& container, uint64 type_hash_code,
284                              const string& resource_name,
285                              const string& type_name) {
286   ResourceAndName resource_and_name;
287   TF_RETURN_IF_ERROR(PopResourceAndName(
288       container, type_hash_code, resource_name, type_name, resource_and_name));
289 
290   if (absl::holds_alternative<core::WeakPtr<ResourceBase>>(
291           resource_and_name.resource)) {
292     return errors::Internal(
293         "Cannot delete an unowned Resource ", container, "/", resource_name,
294         "/", type_name, " from ResourceMgr. ",
295         "This indicates ref-counting ResourceHandle is exposed to weak "
296         "ResourceHandle code paths.");
297   }
298   return OkStatus();
299 }
300 
DoDelete(const string & container,TypeIndex type,const string & resource_name)301 Status ResourceMgr::DoDelete(const string& container, TypeIndex type,
302                              const string& resource_name) {
303   return DoDelete(container, type.hash_code(), resource_name, type.name());
304 }
305 
Delete(const ResourceHandle & handle)306 Status ResourceMgr::Delete(const ResourceHandle& handle) {
307   return DoDelete(handle.container(), handle.hash_code(), handle.name(),
308                   "<unknown>");
309 }
310 
Cleanup(const string & container)311 Status ResourceMgr::Cleanup(const string& container) {
312   {
313     tf_shared_lock l(mu_);
314     if (!gtl::FindOrNull(containers_, container)) {
315       // Nothing to cleanup.
316       return OkStatus();
317     }
318   }
319   Container* b = nullptr;
320   {
321     mutex_lock l(mu_);
322     auto iter = containers_.find(container);
323     if (iter == containers_.end()) {
324       // Nothing to cleanup, it's OK (concurrent cleanup).
325       return OkStatus();
326     }
327     b = iter->second;
328     containers_.erase(iter);
329   }
330   CHECK(b != nullptr);
331   delete b;
332   return OkStatus();
333 }
334 
IsValidContainerName(StringPiece s)335 static bool IsValidContainerName(StringPiece s) {
336   using ::tensorflow::strings::Scanner;
337   return Scanner(s)
338       .One(Scanner::LETTER_DIGIT_DOT)
339       .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH)
340       .Eos()
341       .GetResult();
342 }
343 
Init(ResourceMgr * rmgr,const NodeDef & ndef,bool use_node_name_as_default)344 Status ContainerInfo::Init(ResourceMgr* rmgr, const NodeDef& ndef,
345                            bool use_node_name_as_default) {
346   CHECK(rmgr);
347   rmgr_ = rmgr;
348   string attr_container;
349   TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "container", &attr_container));
350   if (!attr_container.empty() && !IsValidContainerName(attr_container)) {
351     return errors::InvalidArgument("container contains invalid characters: ",
352                                    attr_container);
353   }
354   string attr_shared_name;
355   TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "shared_name", &attr_shared_name));
356   if (!attr_shared_name.empty() && (attr_shared_name[0] == '_')) {
357     return errors::InvalidArgument("shared_name cannot start with '_':",
358                                    attr_shared_name);
359   }
360   if (!attr_container.empty()) {
361     container_ = attr_container;
362   } else {
363     container_ = rmgr_->default_container();
364   }
365   if (!attr_shared_name.empty()) {
366     name_ = attr_shared_name;
367   } else if (use_node_name_as_default) {
368     name_ = ndef.name();
369   } else {
370     resource_is_private_to_kernel_ = true;
371     static std::atomic<int64_t> counter(0);
372     name_ = strings::StrCat("_", counter.fetch_add(1), "_", ndef.name());
373   }
374   return OkStatus();
375 }
376 
DebugString() const377 string ContainerInfo::DebugString() const {
378   return strings::StrCat("[", container(), ",", name(), ",",
379                          resource_is_private_to_kernel() ? "private" : "public",
380                          "]");
381 }
382 
HandleFromInput(OpKernelContext * ctx,int input)383 const ResourceHandle& HandleFromInput(OpKernelContext* ctx, int input) {
384   return ctx->input(input).flat<ResourceHandle>()(0);
385 }
386 
HandleFromInput(OpKernelContext * ctx,StringPiece input,ResourceHandle * handle)387 Status HandleFromInput(OpKernelContext* ctx, StringPiece input,
388                        ResourceHandle* handle) {
389   const Tensor* tensor;
390   TF_RETURN_IF_ERROR(ctx->input(input, &tensor));
391   *handle = tensor->flat<ResourceHandle>()(0);
392   return OkStatus();
393 }
394 
LookupResource(OpKernelContext * ctx,const ResourceHandle & p,ResourceBase ** value)395 Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
396                       ResourceBase** value) {
397   TF_RETURN_IF_ERROR(internal::ValidateDevice(ctx, p));
398   if (p.IsRefCounting()) {
399     TF_ASSIGN_OR_RETURN(*value, p.GetResource<ResourceBase>());
400     (*value)->Ref();
401     return OkStatus();
402   }
403   return ctx->resource_manager()->Lookup(p, value);
404 }
405 
DeleteResource(OpKernelContext * ctx,const ResourceHandle & p)406 Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) {
407   TF_RETURN_IF_ERROR(internal::ValidateDevice(ctx, p));
408   if (p.IsRefCounting()) {
409     return OkStatus();
410   }
411   return ctx->resource_manager()->Delete(p);
412 }
413 
414 }  //  end namespace tensorflow
415