xref: /aosp_15_r20/external/tensorflow/tensorflow/core/framework/resource_mgr.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_
18 
19 #include <memory>
20 #include <string>
21 #include <typeindex>
22 #include <typeinfo>
23 #include <unordered_map>
24 
25 #include "absl/container/flat_hash_map.h"
26 #include "absl/types/variant.h"
27 #include "tensorflow/core/framework/common_shape_fns.h"
28 #include "tensorflow/core/framework/device_attributes.pb.h"
29 #include "tensorflow/core/framework/op_kernel.h"
30 #include "tensorflow/core/framework/resource_base.h"
31 #include "tensorflow/core/framework/resource_handle.h"
32 #include "tensorflow/core/framework/tensor.h"
33 #include "tensorflow/core/framework/tensor_shape.h"
34 #include "tensorflow/core/framework/tensor_types.h"
35 #include "tensorflow/core/framework/type_index.h"
36 #include "tensorflow/core/framework/variant_tensor_data.h"
37 #include "tensorflow/core/lib/core/errors.h"
38 #include "tensorflow/core/lib/hash/hash.h"
39 #include "tensorflow/core/platform/logging.h"
40 #include "tensorflow/core/platform/macros.h"
41 #include "tensorflow/core/platform/mutex.h"
42 #include "tensorflow/core/platform/thread_annotations.h"
43 
44 namespace tensorflow {
45 
46 // A ResourceMgr instance keeps track of named and typed resources
47 // grouped into containers.
48 //
49 // Each named resource is
50 // registered with ResourceMgr under a named "container" name. At any
51 // time, there is at most one instance of a resource given the container
52 // name, the resource type and the resource name.
53 //
54 // All resources for a given container can be dropped by one call of
55 // Cleanup().
56 //
57 // E.g.,
58 //   struct MyVar : public ResourceBase {
59 //     mutex mu;
60 //     Tensor val;
61 //   }
62 //
63 //   ResourceMgr rm;
64 //
65 //   // Create a var.
66 //   MyVar* my_var = new MyVar;
67 //   my_var->val = Tensor(DT_FLOAT, my_shape);
68 //   my_var->val.flat<float>().setZeros();   // 0 initialized.
69 //   ctx->SetStatus(rm.Create("my_container", "my_name", my_var));
70 //
//   // += a variable.
//   MyVar* my_var = nullptr;
//   Status s = rm.Lookup("my_container", "my_name", &my_var);
//   if (s.ok()) {
//     my_var->val.flat<float>() += grad;
//     my_var->Unref();   // Or use ScopedUnref().
//   }
//   ctx->SetStatus(s);
79 
80 // Container used for per-step resources.
81 class ScopedStepContainer {
82  public:
83   // step_id: the unique ID of this step. Doesn't have to be sequential, just
84   // has to be unique.
85   // cleanup: callback to delete a container of this name.
86   // prefix: optional string prefix to disambiguate step containers.
ScopedStepContainer(const int64_t step_id,std::function<void (const string &)> cleanup)87   ScopedStepContainer(const int64_t step_id,
88                       std::function<void(const string&)> cleanup)
89       : step_id_(step_id),
90         container_(strings::StrCat("__per_step_", step_id)),
91         cleanup_(cleanup),
92         dirty_(false) {}
93 
ScopedStepContainer(const int64_t step_id,std::function<void (const string &)> cleanup,const std::string & prefix)94   ScopedStepContainer(const int64_t step_id,
95                       std::function<void(const string&)> cleanup,
96                       const std::string& prefix)
97       : step_id_(step_id),
98         container_(strings::StrCat("__", prefix, "_per_step_", step_id)),
99         cleanup_(cleanup),
100         dirty_(false) {}
101 
~ScopedStepContainer()102   ~ScopedStepContainer() { CleanUp(); }
103 
CleanUp()104   void CleanUp() TF_NO_THREAD_SAFETY_ANALYSIS {
105     // NOTE(mrry): Avoid acquiring the mutex in the case that the container is
106     // clean.
107     if (dirty_) {
108       mutex_lock ml(mu_);
109       cleanup_(container_);
110       dirty_ = false;
111     }
112   }
113 
114   // Pass through functions for resource lookup and creation. We do this to
115   // ensure that we can appropriately set the dirty_ bit in the
116   // ScopedStepContainer if the name of the container is used to create
117   // resources.
118 
119   // Pass through to MakeResourceHandle with the container name
120   template <typename T>
121   ResourceHandle MakeResourceHandle(
122       const std::string& name, const DeviceBase& device) TF_MUST_USE_RESULT;
123   // Pass through to ResourceMgr::Create with the container name
124   template <typename T>
125   Status Create(ResourceMgr* rm, const std::string& name,
126                 T* resource) TF_MUST_USE_RESULT;
127   // Pass through to ResourceMgr::Delete with the container name
128   template <typename T>
129   Status Delete(ResourceMgr* rm, const std::string& name) TF_MUST_USE_RESULT;
130   // Pass through to ResourceMgr::Lookup with the container name
131   template <typename T>
132   Status Lookup(ResourceMgr* rm, const std::string& name,
133                 T** resource) const TF_MUST_USE_RESULT;
134   // Pass through to ResourceMgr::LookupOrCreate with the container name
135   template <typename T>
136   Status LookupOrCreate(ResourceMgr* rm, const std::string& name, T** resource,
137                         std::function<Status(T**)> creator) TF_MUST_USE_RESULT;
StepId()138   int64_t StepId() const { return step_id_; }
139 
140  private:
141   const int64_t step_id_;
142   const std::string container_;
143   const std::function<void(const string&)> cleanup_;
144   mutex mu_;
145   mutable std::atomic<bool> dirty_ TF_GUARDED_BY(mu_);
146 };
147 
148 class ResourceMgr {
149  public:
150   ResourceMgr();
151   explicit ResourceMgr(const std::string& default_container);
152   ~ResourceMgr();
153 
154   // Returns the default container name for *this.
default_container()155   const std::string& default_container() const { return default_container_; }
156 
157   // Creates a resource "name" in the "container".  The caller transfers
158   // the ownership of one ref on "resource" to *this, regardless of whether this
159   // operation succeeds or fails.
160   //
161   // REQUIRES: std::is_base_of<ResourceBase, T>
162   // REQUIRES: resource != nullptr.
163   template <typename T>
164   Status Create(const std::string& container, const std::string& name,
165                 T* resource) TF_MUST_USE_RESULT;
166 
167   // Creates a unowned resource "name" in the "container".  The caller does NOT
168   // transfer the ownership of any ref on "resource" to *this, regardless of
169   // whether this operation succeeds or fails.
170   //
171   // After the resource is destroyed, lookups from the manager fail.
172   // The caller must call this->Delete() on the name to free up the memory
173   // entry of the name.
174   //
175   // REQUIRES: std::is_base_of<ResourceBase, T>
176   // REQUIRES: resource != nullptr.
177   template <typename T>
178   Status CreateUnowned(const std::string& container, const std::string& name,
179                        T* resource) TF_MUST_USE_RESULT;
180 
181   // If "container" has a resource "name", returns it in "*resource" and
182   // the caller takes the ownership of one ref on "*resource".
183   //
184   // REQUIRES: std::is_base_of<ResourceBase, T>
185   // REQUIRES: resource != nullptr
186   template <typename T, bool use_dynamic_cast = false>
187   Status Lookup(const std::string& container, const std::string& name,
188                 T** resource) const TF_MUST_USE_RESULT;
189 
190   // If the resource manager has a resource matching "handle", returns it in
191   // "*resource" and the caller takes the ownership of one ref on "*resource".
192   //
193   // REQUIRES: resource != nullptr
194   Status Lookup(const ResourceHandle& handle,
195                 ResourceBase** resource) const TF_MUST_USE_RESULT;
196 
197   // Similar to Lookup, but looks up multiple resources at once, with only a
198   // single lock acquisition.  If containers_and_names[i] is uninitialized
199   // then this function does not modify resources[i].
200   template <typename T, bool use_dynamic_cast = false>
201   Status LookupMany(absl::Span<std::pair<const string*, const string*> const>
202                         containers_and_names,
203                     std::vector<std::unique_ptr<T, core::RefCountDeleter>>*
204                         resources) const TF_MUST_USE_RESULT;
205 
206   // If "container" has a resource "name", returns it in
207   // "*resource". Otherwise, invokes creator() to create the resource.
208   // The caller takes the ownership of one ref on "*resource".
209   //
210   // WARNING: creator() must not call any methods on ResourceMgr during its
211   // execution, because a non-reentrant lock is held during the creator() call
212   // in order to guarantee atomicity of LookupOrCreate().
213   //
214   // REQUIRES: std::is_base_of<ResourceBase, T>
215   // REQUIRES: resource != nullptr
216   template <typename T, bool use_dynamic_cast = false>
217   Status LookupOrCreate(const std::string& container, const std::string& name,
218                         T** resource,
219                         std::function<Status(T**)> creator) TF_MUST_USE_RESULT;
220 
221   // Deletes the resource "name" from the "container".
222   //
223   // REQUIRES: std::is_base_of<ResourceBase, T>
224   template <typename T>
225   Status Delete(const std::string& container,
226                 const std::string& name) TF_MUST_USE_RESULT;
227 
228   // Deletes the resource pointed by "handle".
229   Status Delete(const ResourceHandle& handle) TF_MUST_USE_RESULT;
230 
231   // Deletes all resources from the "container" and removes the container.
232   Status Cleanup(const std::string& container) TF_MUST_USE_RESULT;
233 
234   // Deletes all resources in all containers.
235   void Clear();
236 
237   // Returns a text description for all resources.
238   std::string DebugString() const;
239 
240  private:
241   typedef std::pair<uint64, StringPiece> Key;
242   struct KeyHash {
operatorKeyHash243     std::size_t operator()(const Key& k) const {
244       return Hash64(k.second.data(), k.second.size(), k.first);
245     }
246   };
247   struct KeyEqual {
operatorKeyEqual248     bool operator()(const Key& x, const Key& y) const {
249       return (x.second == y.second) && (x.first == y.first);
250     }
251   };
252   struct ResourceAndName {
253     absl::variant<core::RefCountPtr<ResourceBase>, core::WeakPtr<ResourceBase>>
254         resource;
255     std::unique_ptr<std::string> name;
256 
257     ResourceAndName();
258     explicit ResourceAndName(const string& name);
259     ResourceAndName(ResourceAndName&& other) noexcept;
260     ~ResourceAndName();
261 
262     ResourceAndName& operator=(ResourceAndName&&) noexcept;
263 
264     // Returns a strong reference to resource, or nullptr if the resource is
265     // no longer valid.
266     core::RefCountPtr<ResourceBase> GetResource() const;
267 
268    private:
269     TF_DISALLOW_COPY_AND_ASSIGN(ResourceAndName);
270   };
271   typedef absl::flat_hash_map<Key, ResourceAndName, KeyHash, KeyEqual>
272       Container;
273 
274   const std::string default_container_;
275   mutable mutex mu_;
276   absl::flat_hash_map<string, Container*> containers_ TF_GUARDED_BY(mu_);
277 
278   template <typename T, bool use_dynamic_cast = false>
279   Status LookupInternal(const std::string& container, const std::string& name,
280                         T** resource) const
281       TF_SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
282   Status LookupInternal(const std::string& container, uint64 type_hash_code,
283                         const std::string& name, ResourceBase** resource) const
284       TF_SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
285 
286   Status DoCreate(const std::string& container, TypeIndex type,
287                   const std::string& name, ResourceBase* resource,
288                   bool owns_resource)
289       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
290 
291   Status DoLookup(const std::string& container, TypeIndex type,
292                   const std::string& name, ResourceBase** resource) const
293       TF_SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
294   Status DoLookup(const std::string& container, uint64 type_hash_code,
295                   const std::string& type_name,
296                   const std::string& resource_name,
297                   ResourceBase** resource) const
298       TF_SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
299 
300   Status DoDelete(const std::string& container, uint64 type_hash_code,
301                   const std::string& resource_name,
302                   const std::string& type_name) TF_MUST_USE_RESULT;
303   Status DoDelete(const std::string& container, TypeIndex type,
304                   const std::string& resource_name) TF_MUST_USE_RESULT;
305 
306   // Pops the ResourceAndName entry. The entry is moved from the list to
307   // the output argument `resource_and_name`.
308   Status PopResourceAndName(
309       const std::string& container, uint64 type_hash_code,
310       const std::string& resource_name, const std::string& type_name,
311       ResourceAndName& resource_and_name) TF_MUST_USE_RESULT;
312   // Inserts the type name for 'hash_code' into the hash_code to type name map.
313   Status InsertDebugTypeName(uint64 hash_code, const std::string& type_name)
314       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
315 
316   // Returns the type name for the 'hash_code'.
317   // Returns "<unknown>" if a resource with such a type was never inserted into
318   // the container.
319   const char* DebugTypeName(uint64 hash_code) const
320       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
321 
322   // Map from type hash_code to type name.
323   std::unordered_map<uint64, string> debug_type_names_ TF_GUARDED_BY(mu_);
324 
325   TF_DISALLOW_COPY_AND_ASSIGN(ResourceMgr);
326 };
327 
328 // Makes a resource handle with the specified type for a given container /
329 // name.
330 ResourceHandle MakeResourceHandle(
331     const std::string& container, const std::string& name,
332     const DeviceBase& device, const TypeIndex& type_index,
333     const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {},
334     const absl::optional<ManagedStackTrace>& definition_stack_trace = {})
335     TF_MUST_USE_RESULT;
336 
337 template <typename T>
338 ResourceHandle MakeResourceHandle(
339     OpKernelContext* ctx, const std::string& container, const std::string& name,
340     const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {},
341     const absl::optional<ManagedStackTrace>& definition_stack_trace = {}) {
342   return MakeResourceHandle(container.empty()
343                                 ? ctx->resource_manager()->default_container()
344                                 : container,
345                             name, *ctx->device(), TypeIndex::Make<T>(),
346                             dtypes_and_shapes, definition_stack_trace);
347 }
348 
349 template <typename T>
350 ResourceHandle MakeResourceHandle(
351     OpKernelConstruction* ctx, const std::string& container,
352     const std::string& name,
353     const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {},
354     const absl::optional<ManagedStackTrace>& definition_stack_trace = {}) {
355   return MakeResourceHandle(container.empty()
356                                 ? ctx->resource_manager()->default_container()
357                                 : container,
358                             name, *ctx->device(), TypeIndex::Make<T>(),
359                             dtypes_and_shapes, definition_stack_trace);
360 }
361 
362 Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index,
363                                   const std::string& container,
364                                   const std::string& name,
365                                   const TypeIndex& type_index);
366 
367 // Returns a resource handle from a numbered op input.
368 const ResourceHandle& HandleFromInput(OpKernelContext* ctx, int input);
369 Status HandleFromInput(OpKernelContext* ctx, StringPiece input,
370                        ResourceHandle* handle);
371 
372 // Create a resource pointed by a given resource handle.
373 //
374 // If successful, the caller transfers the ownership of one ref on `resource` to
375 // `ctx->resource_mgr()`.
376 template <typename T>
377 Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value);
378 
379 // Looks up a resource pointed by a given resource handle.
380 //
381 // If the lookup is successful, the caller takes the ownership of one ref on
382 // `*value`, and must call its `Unref()` method when it has finished using it.
383 template <typename T, bool use_dynamic_cast = false>
384 Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, T** value);
385 
386 // Looks up a resource pointed by a given resource handle.
387 //
388 // Prefer usage of LookupResource taking `core::RefCountPtr` to avoid
389 // requiring the caller to explicitly call `Unref()`.
390 template <typename T>
391 Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
392                       core::RefCountPtr<T>* value);
393 
394 // Looks up multiple resources pointed by a sequence of resource handles.  If
395 // p[i] is uninitialized then values[i] is unmodified.
396 template <typename T>
397 Status LookupResources(OpKernelContext* ctx, absl::Span<ResourceHandle const> p,
398                        std::vector<core::RefCountPtr<T>>* values);
399 
400 // Looks up or creates a resource.
401 //
402 // If successful, the caller takes the ownership of one ref on `*value`, and
403 // must call its `Unref()` method when it has finished using it. If the
404 // `creator` is invoked, its reference on the created resource is transferred
405 // to `ctx->resource_mgr()`.
406 //
407 // Prefer usage of LookupOrCreateResource taking `core::RefCountPtr` to avoid
408 // requiring the caller to explicitly call `Unref()`.
409 template <typename T>
410 Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
411                               T** value, std::function<Status(T**)> creator);
412 
413 // Looks up or creates a resource.
414 template <typename T>
415 Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
416                               core::RefCountPtr<T>* value,
417                               std::function<Status(T**)> creator);
418 
419 // Destroys a resource pointed by a given resource handle.
420 template <typename T>
421 Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p);
422 
423 // Same as above, but uses the hash code of the type directly.
424 // The type name information will be missing in the debug output when the
425 // resource is not present in the container.
426 Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p);
427 
428 // Policy helper to decide which container/shared_name to use for a
429 // stateful kernel that accesses shared resource.
430 class ContainerInfo {
431  public:
432   // Analyze the node attribute of 'ndef' and decides the container and
433   // resource name the kernel should use for accessing the shared
434   // resource.
435   //
436   // 'ndef' is expected to have node attribute "container" and
437   // "shared_name". Returns non-OK if they are not provided or they are
438   // invalid.
439   //
440   // The policy is as following:
441   // * If the attribute "container" is non-empty, it is used as is.
442   //   Otherwise, uses the resource manager's default container.
443   // * If the attribute "shared_name" is non-empty, it is used as is.
444   //   Otherwise, if "use_node_name_as_default" is true, the kernel's
445   //   node name is used as the resource name. Otherwise, a string
446   //   unique to this process is used.
447   Status Init(ResourceMgr* rmgr, const NodeDef& ndef,
448               bool use_node_name_as_default);
Init(ResourceMgr * rmgr,const NodeDef & ndef)449   Status Init(ResourceMgr* rmgr, const NodeDef& ndef) {
450     return Init(rmgr, ndef, false);
451   }
452 
453   // The policy decides that the kernel should access the resource in
454   // resource_manager(), the resource is in the container() and its
455   // name is name().  If resource_is_private_to_kernel() is true, the
456   // kernel should delete the resource when the kernel is deleted.
resource_manager()457   ResourceMgr* resource_manager() const { return rmgr_; }
container()458   const std::string& container() const { return container_; }
name()459   const std::string& name() const { return name_; }
resource_is_private_to_kernel()460   bool resource_is_private_to_kernel() const {
461     return resource_is_private_to_kernel_;
462   }
463 
464   // Returns a readable string for *this.
465   std::string DebugString() const;
466 
467  private:
468   ResourceMgr* rmgr_ = nullptr;
469   std::string container_;
470   std::string name_;
471   bool resource_is_private_to_kernel_ = false;
472 };
473 
474 // Helper for kernels to obtain 'resource' from the
475 // ctx->resource_manager().
476 //
477 // "input_name" specifies the kernel's ref input which gives a string
478 // tensor with two elements, which specifies the container and
479 // resource name.
480 //
481 // Returns OK if the resource is found and transfers one ref of
482 // *resource to the caller. Otherwise, returns an error.
483 template <typename T>
484 Status GetResourceFromContext(OpKernelContext* ctx,
485                               const std::string& input_name, T** resource);
486 
487 // Utility op kernel to check if a handle to resource type T is initialized.
488 template <typename T>
489 class IsResourceInitialized : public OpKernel {
490  public:
IsResourceInitialized(OpKernelConstruction * c)491   explicit IsResourceInitialized(OpKernelConstruction* c) : OpKernel(c) {}
492 
493   void Compute(OpKernelContext* ctx) override;
494 };
495 
// Registers an op named "<Type>HandleOp" whose only output is a resource
// handle to a resource of the given type. The op takes "container" and
// "shared_name" string attributes (both defaulting to empty) and has scalar
// output shape.
// TODO(apassos): figure out how to get non-cpu-allocated tensors to work
// through constant folding so this doesn't have to be marked as stateful.
#define REGISTER_RESOURCE_HANDLE_OP(Type) \
  REGISTER_OP(#Type "HandleOp")           \
      .Attr("container: string = ''")     \
      .Attr("shared_name: string = ''")   \
      .Output("resource: resource")       \
      .SetIsStateful()                    \
      .SetShapeFn(tensorflow::shape_inference::ScalarShape)
507 
508 // Utility op kernel to produce a handle to a resource of type T.
509 template <typename T>
510 class ResourceHandleOp : public OpKernel {
511  public:
512   explicit ResourceHandleOp(OpKernelConstruction* context);
513 
514   void Compute(OpKernelContext* ctx) override;
515 
IsExpensive()516   bool IsExpensive() override { return false; }
517 
518  private:
519   std::string container_;
520   std::string name_;
521   mutex mutex_;
522   Tensor resource_;
523   std::atomic<bool> initialized_{false};
524 };
525 
526 // Utility op kernel to produce a handle to a resource of type T.
527 template <typename T>
528 class ResourceHandlesOp : public OpKernel {
529  public:
530   explicit ResourceHandlesOp(OpKernelConstruction* context);
531 
532   void Compute(OpKernelContext* ctx) override;
533 
IsExpensive()534   bool IsExpensive() override { return false; }
535 
536  private:
537   std::vector<string> containers_;
538   std::vector<string> names_;
539   mutex mutex_;
540   std::vector<Tensor> resources_;
541   std::atomic<bool> initialized_{false};
542 };
543 
// Registers ResourceHandleOp<Type> as the CPU kernel for the op named
// "<Type>HandleOp", i.e. the op that produces a handle to a resource of the
// specified type.
#define REGISTER_RESOURCE_HANDLE_KERNEL(Type)                        \
  REGISTER_KERNEL_BUILDER(Name(#Type "HandleOp").Device(DEVICE_CPU), \
                          ResourceHandleOp<Type>)
549 
550 // This class is used to guarantee that an anonymous resource is deleted
551 // (irrespective of whether a resource deleter op is called explicitly or
552 // the execution encounters an error before the op runs).
553 //
554 // This is achieved by wrapping an instance of this class into a variant
555 // tensor which is passed as an input to a resource deleter op. If the
556 // execution encounters an error before the op runs, the tensor will be
557 // destroyed, essentially triggering the iterator deletion.
558 // NOTE: This is not a feature-complete implementation of the DT_VARIANT
559 // specification. In particular, we cannot serialize the `ResourceMgr`
560 // object, so the `Encode()` and `Decode()` methods are not implemented.
561 class ResourceDeleter {
562  public:
ResourceDeleter()563   ResourceDeleter() : deleter_() {}
564 
ResourceDeleter(ResourceHandle handle,ResourceMgr * resource_manager)565   ResourceDeleter(ResourceHandle handle, ResourceMgr* resource_manager)
566       : deleter_(std::make_shared<Helper>(handle, resource_manager)) {}
567 
ResourceDeleter(ResourceDeleter && rhs)568   ResourceDeleter(ResourceDeleter&& rhs) : deleter_(std::move(rhs.deleter_)) {
569     VLOG(3) << "ResourceDeleter move constructor called.";
570   }
571 
ResourceDeleter(const ResourceDeleter & rhs)572   ResourceDeleter(const ResourceDeleter& rhs) : deleter_(rhs.deleter_) {
573     VLOG(3) << "ResourceDeleter copy constructor called.";
574   }
575 
576   ResourceDeleter& operator=(const ResourceDeleter& rhs) = delete;
577 
578   ResourceDeleter& operator=(ResourceDeleter&& rhs) = default;
579 
~ResourceDeleter()580   virtual ~ResourceDeleter() {
581     VLOG(3) << "ResourceDeleter destructor called.";
582   }
583 
Encode(VariantTensorData *)584   void Encode(VariantTensorData*) const {
585     LOG(ERROR) << "The Encode() method is not implemented for ResourceDeleter "
586                   "objects.";
587   }
588 
Decode(const VariantTensorData &)589   bool Decode(const VariantTensorData&) {
590     LOG(ERROR) << "The Decode() method is not implemented for ResourceDeleter "
591                   "objects";
592     return false;  // Not supported.
593   }
594 
595  private:
596   // Helper that performs reference counting for the parent class and deletes
597   // the iterator resource when the refcount goes to zero.
598   //
599   // NOTE: The object is borrowing a pointer to the resource manager.
600   // Consequently, the tensor containing this object should not escape the
601   // function in which was created (so that it is guaranteed that the resource
602   // manager will outlive it).
603   struct Helper {
HelperHelper604     Helper(ResourceHandle handle, ResourceMgr* resource_manager)
605         : handle(handle), resource_manager(resource_manager) {}
606 
607     Helper(const Helper& rhs) = delete;
608     Helper(Helper&& rhs) = delete;
609 
~HelperHelper610     ~Helper() {
611       VLOG(3) << "Deleting Resource: " << handle.DebugString();
612       resource_manager->Delete(handle).IgnoreError();
613     }
614 
615     ResourceHandle handle;
616     ResourceMgr* resource_manager;  // not owned
617   };
618 
619   std::shared_ptr<Helper> deleter_;
620 };
621 
622 // Implementation details below.
623 
624 template <typename T>
CheckDeriveFromResourceBase()625 void CheckDeriveFromResourceBase() {
626   static_assert(std::is_base_of<ResourceBase, T>::value,
627                 "T must derive from ResourceBase");
628 }
629 
630 template <typename T>
Create(const std::string & container,const std::string & name,T * resource)631 Status ResourceMgr::Create(const std::string& container,
632                            const std::string& name, T* resource) {
633   CheckDeriveFromResourceBase<T>();
634   CHECK(resource != nullptr);
635   mutex_lock l(mu_);
636   return DoCreate(container, TypeIndex::Make<T>(), name, resource,
637                   /* owns_resource */ true);
638 }
639 
640 template <typename T>
CreateUnowned(const std::string & container,const std::string & name,T * resource)641 Status ResourceMgr::CreateUnowned(const std::string& container,
642                                   const std::string& name, T* resource) {
643   CheckDeriveFromResourceBase<T>();
644   mutex_lock l(mu_);
645   return DoCreate(container, TypeIndex::Make<T>(), name, resource,
646                   /* owns_resource */ false);
647 }
648 
649 template <typename T, bool use_dynamic_cast>
Lookup(const std::string & container,const std::string & name,T ** resource)650 Status ResourceMgr::Lookup(const std::string& container,
651                            const std::string& name, T** resource) const {
652   CheckDeriveFromResourceBase<T>();
653   tf_shared_lock l(mu_);
654   return LookupInternal<T, use_dynamic_cast>(container, name, resource);
655 }
656 
657 template <typename T, bool use_dynamic_cast>
LookupMany(absl::Span<std::pair<const string *,const string * > const> containers_and_names,std::vector<std::unique_ptr<T,core::RefCountDeleter>> * resources)658 Status ResourceMgr::LookupMany(
659     absl::Span<std::pair<const string*, const string*> const>
660         containers_and_names,
661     std::vector<std::unique_ptr<T, core::RefCountDeleter>>* resources) const {
662   CheckDeriveFromResourceBase<T>();
663   tf_shared_lock l(mu_);
664   resources->resize(containers_and_names.size());
665   for (size_t i = 0; i < containers_and_names.size(); ++i) {
666     T* resource;
667     Status s = LookupInternal<T, use_dynamic_cast>(
668         *containers_and_names[i].first, *containers_and_names[i].second,
669         &resource);
670     if (s.ok()) {
671       (*resources)[i].reset(resource);
672     }
673   }
674   return OkStatus();
675 }
676 
677 // Simple wrapper to allow conditional dynamic / static casts.
678 template <typename T, bool use_dynamic_cast>
679 struct TypeCastFunctor {
CastTypeCastFunctor680   static T* Cast(ResourceBase* r) { return static_cast<T*>(r); }
681 };
682 
683 template <typename T>
684 struct TypeCastFunctor<T, true> {
685   static T* Cast(ResourceBase* r) { return dynamic_cast<T*>(r); }
686 };
687 
688 template <typename T, bool use_dynamic_cast>
689 Status ResourceMgr::LookupInternal(const std::string& container,
690                                    const std::string& name,
691                                    T** resource) const {
692   ResourceBase* found = nullptr;
693   Status s = DoLookup(container, TypeIndex::Make<T>(), name, &found);
694   if (s.ok()) {
695     // It's safe to down cast 'found' to T* since
696     // typeid(T).hash_code() is part of the map key.
697     *resource = TypeCastFunctor<T, use_dynamic_cast>::Cast(found);
698   }
699   return s;
700 }
701 
702 template <typename T, bool use_dynamic_cast>
703 Status ResourceMgr::LookupOrCreate(const std::string& container,
704                                    const std::string& name, T** resource,
705                                    std::function<Status(T**)> creator) {
706   CheckDeriveFromResourceBase<T>();
707   *resource = nullptr;
708   Status s;
709   {
710     tf_shared_lock l(mu_);
711     s = LookupInternal<T, use_dynamic_cast>(container, name, resource);
712     if (s.ok()) return s;
713   }
714   mutex_lock l(mu_);
715   s = LookupInternal<T, use_dynamic_cast>(container, name, resource);
716   if (s.ok()) return s;
717   TF_RETURN_IF_ERROR(creator(resource));
718   s = DoCreate(container, TypeIndex::Make<T>(), name, *resource,
719                /* owns_resource */ true);
720   if (!s.ok()) {
721     return errors::Internal("LookupOrCreate failed unexpectedly");
722   }
723   (*resource)->Ref();
724   return s;
725 }
726 
727 template <typename T>
728 Status ResourceMgr::Delete(const std::string& container,
729                            const std::string& name) {
730   CheckDeriveFromResourceBase<T>();
731   return DoDelete(container, TypeIndex::Make<T>(), name);
732 }
733 
734 template <typename T>
735 Status GetResourceFromContext(OpKernelContext* ctx,
736                               const std::string& input_name, T** resource) {
737   DataType dtype;
738   TF_RETURN_IF_ERROR(ctx->input_dtype(input_name, &dtype));
739   if (dtype == DT_RESOURCE) {
740     const Tensor* handle;
741     TF_RETURN_IF_ERROR(ctx->input(input_name, &handle));
742     return LookupResource(ctx, handle->scalar<ResourceHandle>()(), resource);
743   }
744   std::string container;
745   std::string shared_name;
746   {
747     mutex* mu;
748     TF_RETURN_IF_ERROR(ctx->input_ref_mutex(input_name, &mu));
749     mutex_lock l(*mu);
750     Tensor tensor;
751     TF_RETURN_IF_ERROR(ctx->mutable_input(input_name, &tensor, true));
752     if (tensor.NumElements() != 2) {
753       return errors::InvalidArgument(
754           "Resource handle must have 2 elements, but had shape: ",
755           tensor.shape().DebugString());
756     }
757     container = tensor.flat<tstring>()(0);
758     shared_name = tensor.flat<tstring>()(1);
759   }
760   return ctx->resource_manager()->Lookup(container, shared_name, resource);
761 }
762 
763 namespace internal {
764 
765 Status ValidateDevice(OpKernelContext* ctx, const ResourceHandle& p);
766 
767 template <typename T>
768 Status ValidateDeviceAndType(OpKernelContext* ctx, const ResourceHandle& p) {
769   TF_RETURN_IF_ERROR(internal::ValidateDevice(ctx, p));
770   TF_RETURN_IF_ERROR(p.ValidateType<T>());
771   return OkStatus();
772 }
773 
774 }  // namespace internal
775 
776 // Creates the resource pointed at by "p". The caller transfers the ownership of
777 // one ref on "*value" to the resource manager in "ctx", regardless of whether
778 // this operation succeeds or fails.
779 template <typename T>
780 Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value) {
781   TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
782   return ctx->resource_manager()->Create(p.container(), p.name(), value);
783 }
784 
785 // Finds the resource as "*value" from the handle. If the handle is
786 // ref-counting, returns the resource owned by the handle. Otherwise, looks up
787 // the resource matching "p" from resource manager associated with ctx.
788 // Always returns a new reference to the resource in "*value". The caller shall
789 // call (*value)->Unref().
790 template <typename T, bool use_dynamic_cast>
791 Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
792                       T** value) {
793   TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
794   if (p.IsRefCounting()) {
795     TF_ASSIGN_OR_RETURN(*value, p.GetResource<T>());
796     // Transfers out a new reference.
797     (*value)->Ref();
798     return OkStatus();
799   }
800 
801   return ctx->resource_manager()->Lookup<T, use_dynamic_cast>(p.container(),
802                                                               p.name(), value);
803 }
804 
805 // Finds the resource as "*value" from the handle. This is a type-erased
806 // variant of LookupResource above.
807 Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
808                       ResourceBase** value);
809 
810 // If the resource manager in "ctx" has a resource matching "p", returns it in
811 // "*value".
812 template <typename T>
813 Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
814                       core::RefCountPtr<T>* value) {
815   T* raw_ptr = nullptr;
816   TF_RETURN_IF_ERROR(LookupResource<T, false>(ctx, p, &raw_ptr));
817   value->reset(raw_ptr);
818 
819   return OkStatus();
820 }
821 
822 // Similar to Lookup, but looks up multiple resources at once, with only a
823 // single lock acquisition.
824 template <typename T>
825 Status LookupResources(OpKernelContext* ctx,
826                        absl::Span<ResourceHandle const* const> p,
827                        std::vector<core::RefCountPtr<T>>* values) {
828   std::vector<std::pair<const string*, const string*>> containers_and_names(
829       p.size());
830   for (size_t i = 0; i < p.size(); ++i) {
831     TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, *p[i]));
832     containers_and_names[i] = {&p[i]->container(), &p[i]->name()};
833   }
834   return ctx->resource_manager()->LookupMany(containers_and_names, values);
835 }
836 
837 // If the resource manager in "ctx" has a resource pointed at by "p", returns
838 // it in "*value". Otherwise, invokes creator() to create the resource.
839 // The caller takes the ownership of one ref on "*value".
840 //
841 // WARNING: creator() must not call any methods on the resource manager during
842 // its execution, because a non-reentrant lock is held during the creator() call
843 // in order to guarantee atomicity of LookupOrCreateResource().
844 template <typename T>
845 Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
846                               T** value, std::function<Status(T**)> creator) {
847   TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
848   return ctx->resource_manager()->LookupOrCreate(p.container(), p.name(), value,
849                                                  creator);
850 }
851 
852 // If the resource manager in "ctx" has a resource pointed at by "p", returns
853 // it in "*value". Otherwise, invokes creator() to create the resource.
854 //
855 // WARNING: creator() must not call any methods on the resource manager during
856 // its execution, because a non-reentrant lock is held during the creator() call
857 // in order to guarantee atomicity of LookupOrCreateResource().
858 template <typename T>
859 Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
860                               core::RefCountPtr<T>* value,
861                               std::function<Status(T**)> creator) {
862   T* raw_ptr = nullptr;
863   TF_RETURN_IF_ERROR(LookupOrCreateResource<T>(ctx, p, &raw_ptr, creator));
864   value->reset(raw_ptr);
865 
866   return OkStatus();
867 }
868 
869 // Deletes the resource pointed by "p", using the resource manager in "ctx".
870 template <typename T>
871 Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) {
872   TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
873   // This is a noop because ResourceMgr does not hold a reference.
874   // NOTE(feyu): if we can convert all resources handle to ref-counting, then
875   // DeleteResource can be removed.
876   if (p.IsRefCounting()) {
877     return OkStatus();
878   }
879   return ctx->resource_manager()->Delete<T>(p.container(), p.name());
880 }
881 
882 // Deletes the resource pointed by "p", using the resource manager in "ctx".
883 Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p);
884 
885 template <typename T>
886 void IsResourceInitialized<T>::Compute(OpKernelContext* ctx) {
887   Tensor* output;
888   OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {}, &output));
889   T* object;
890   bool found;
891   if (LookupResource(ctx, HandleFromInput(ctx, 0), &object).ok()) {
892     found = true;
893     object->Unref();
894   } else {
895     found = false;
896   }
897 
898   output->flat<bool>()(0) = found;
899 }
900 
901 template <typename T>
902 ResourceHandleOp<T>::ResourceHandleOp(OpKernelConstruction* context)
903     : OpKernel(context) {
904   OP_REQUIRES_OK(context, context->GetAttr("container", &container_));
905   OP_REQUIRES_OK(context, context->GetAttr("shared_name", &name_));
906 }
907 
908 template <typename T>
909 void ResourceHandleOp<T>::Compute(OpKernelContext* ctx) {
910   if (name_ == ResourceHandle::ANONYMOUS_NAME) {
911     AllocatorAttributes attr;
912     attr.set_on_host(true);
913     Tensor handle;
914     OP_REQUIRES_OK(
915         ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}), &handle, attr));
916     handle.scalar<ResourceHandle>()() = MakeResourceHandle<T>(
917         ctx, container_, name_, /*dtypes_and_shapes=*/{}, ctx->stack_trace());
918     ctx->set_output(0, handle);
919   } else {
920     if (!initialized_.load()) {
921       mutex_lock ml(mutex_);
922       // Checking again to see if another thread has initialized the resource.
923       if (!initialized_.load()) {
924         AllocatorAttributes attr;
925         attr.set_on_host(true);
926         OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}),
927                                                &resource_, attr));
928         resource_.scalar<ResourceHandle>()() =
929             MakeResourceHandle<T>(ctx, container_, name_,
930                                   /*dtypes_and_shapes=*/{}, ctx->stack_trace());
931         initialized_.store(true);
932       }
933     }
934     ctx->set_output(0, resource_);
935   }
936 }
937 
938 template <typename T>
939 ResourceHandlesOp<T>::ResourceHandlesOp(OpKernelConstruction* context)
940     : OpKernel(context) {
941   int n;
942   OP_REQUIRES_OK(context, context->GetAttr("N", &n));
943   OP_REQUIRES_OK(context, context->GetAttr("containers", &containers_));
944   OP_REQUIRES_OK(context, context->GetAttr("shared_names", &names_));
945   OP_REQUIRES(
946       context, containers_.size() == n,
947       errors::InvalidArgument("Number of containers (", containers_.size(),
948                               ") must be equal to N (", n, ")"));
949   OP_REQUIRES(context, names_.size() == n,
950               errors::InvalidArgument("Number of names (", containers_.size(),
951                                       ") must be equal to N (", n, ")"));
952   resources_.resize(n);
953 }
954 
955 template <typename T>
956 void ResourceHandlesOp<T>::Compute(OpKernelContext* ctx) {
957   if (!initialized_.load()) {
958     mutex_lock ml(mutex_);
959     // Checking again to see if another thread has initialized the resource.
960     if (!initialized_.load()) {
961       AllocatorAttributes attr;
962       attr.set_on_host(true);
963       for (size_t i = 0; i < resources_.size(); ++i) {
964         OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}),
965                                                &resources_[i], attr));
966         ResourceHandle h =
967             MakeResourceHandle<T>(ctx, containers_[i], names_[i]);
968         resources_[i].template scalar<ResourceHandle>()() = h;
969       }
970       initialized_.store(true);
971     }
972   }
973   for (size_t i = 0; i < resources_.size(); ++i) {
974     ctx->set_output(i, resources_[i]);
975   }
976 }
977 
978 template <typename T>
979 ResourceHandle ScopedStepContainer::MakeResourceHandle(
980     const std::string& name, const DeviceBase& device) {
981   mutex_lock ml(mu_);
982   dirty_ = true;
983   return tensorflow::MakeResourceHandle(container_, name, device,
984                                         TypeIndex::Make<T>(), {});
985 }
986 
987 template <typename T>
988 Status ScopedStepContainer::Lookup(ResourceMgr* rm, const std::string& name,
989                                    T** resource) const {
990   return rm->Lookup<T>(container_, name, resource);
991 }
992 
993 template <typename T>
994 Status ScopedStepContainer::LookupOrCreate(ResourceMgr* rm,
995                                            const std::string& name,
996                                            T** resource,
997                                            std::function<Status(T**)> creator) {
998   mutex_lock ml(mu_);
999   dirty_ = true;
1000   return rm->LookupOrCreate<T>(container_, name, resource, creator);
1001 }
1002 
1003 template <typename T>
1004 Status ScopedStepContainer::Create(ResourceMgr* rm, const std::string& name,
1005                                    T* resource) {
1006   mutex_lock ml(mu_);
1007   dirty_ = true;
1008   return rm->Create<T>(container_, name, resource);
1009 }
1010 
1011 template <typename T>
1012 Status ScopedStepContainer::Delete(ResourceMgr* rm, const std::string& name) {
1013   return rm->Delete<T>(container_, name);
1014 }
1015 
1016 }  //  end namespace tensorflow
1017 
1018 #endif  // TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_
1019