xref: /aosp_15_r20/external/tensorflow/tensorflow/core/framework/resource_op_kernel.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_RESOURCE_OP_KERNEL_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_RESOURCE_OP_KERNEL_H_
18 
19 #include <string>
20 
21 #include "tensorflow/core/framework/op_kernel.h"
22 #include "tensorflow/core/framework/resource_mgr.h"
23 #include "tensorflow/core/framework/tensor_shape.h"
24 #include "tensorflow/core/platform/logging.h"
25 #include "tensorflow/core/platform/mutex.h"
26 #include "tensorflow/core/platform/thread_annotations.h"
27 #include "tensorflow/core/platform/types.h"
28 
29 namespace tensorflow {
30 
31 // ResourceOpKernel<T> is a virtual base class for resource op implementing
32 // interface type T. The inherited op looks up the resource name (determined by
33 // ContainerInfo), and creates a new resource if necessary.
34 //
35 // Requirements:
36 //  - Op must be marked as stateful.
37 //  - Op must have `container` and `shared_name` attributes. Empty `container`
38 //  means using the default container. Empty `shared_name` means private
39 //  resource.
40 //  - Subclass must override CreateResource().
41 //  - Subclass is encouraged to override VerifyResource().
42 template <typename T>
43 class ResourceOpKernel : public OpKernel {
44  public:
ResourceOpKernel(OpKernelConstruction * context)45   explicit ResourceOpKernel(OpKernelConstruction* context) : OpKernel(context) {
46     has_resource_type_ = (context->output_type(0) == DT_RESOURCE);
47     if (!has_resource_type_) {
48       // The resource variant of the op may be placed on non-CPU devices, but
49       // this allocation is always on the host. Fortunately we don't need it in
50       // the resource case.
51       OP_REQUIRES_OK(context, context->allocate_temp(
52                                   DT_STRING, TensorShape({2}), &tensor_));
53     }
54   }
55 
56   // The resource is deleted from the resource manager only when it is private
57   // to kernel. Ideally the resource should be deleted when it is no longer held
58   // by anyone, but it would break backward compatibility.
~ResourceOpKernel()59   ~ResourceOpKernel() override {
60     if (resource_ != nullptr) {
61       resource_->Unref();
62       if (cinfo_.resource_is_private_to_kernel()) {
63         if (!cinfo_.resource_manager()
64                  ->template Delete<T>(cinfo_.container(), cinfo_.name())
65                  .ok()) {
66           // Do nothing; the resource can have been deleted by session resets.
67         }
68       }
69     }
70   }
71 
Compute(OpKernelContext * context)72   void Compute(OpKernelContext* context) override TF_LOCKS_EXCLUDED(mu_) {
73     mutex_lock l(mu_);
74     if (resource_ == nullptr) {
75       ResourceMgr* mgr = context->resource_manager();
76       OP_REQUIRES_OK(context, cinfo_.Init(mgr, def()));
77 
78       T* resource;
79       OP_REQUIRES_OK(context,
80                      mgr->LookupOrCreate<T>(
81                          cinfo_.container(), cinfo_.name(), &resource,
82                          [this](T** ret) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
83                            Status s = CreateResource(ret);
84                            if (!s.ok() && *ret != nullptr) {
85                              CHECK((*ret)->Unref());
86                            }
87                            return s;
88                          }));
89 
90       Status s = VerifyResource(resource);
91       if (TF_PREDICT_FALSE(!s.ok())) {
92         resource->Unref();
93         context->SetStatus(s);
94         return;
95       }
96 
97       if (!has_resource_type_) {
98         auto h = tensor_.template flat<tstring>();
99         h(0) = cinfo_.container();
100         h(1) = cinfo_.name();
101       }
102       resource_ = resource;
103     }
104     if (has_resource_type_) {
105       OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
106                                   context, 0, cinfo_.container(), cinfo_.name(),
107                                   TypeIndex::Make<T>()));
108     } else {
109       context->set_output_ref(0, &mu_, &tensor_);
110     }
111   }
112 
113  protected:
114   // Variables accessible from subclasses.
115   mutex mu_;
116   ContainerInfo cinfo_ TF_GUARDED_BY(mu_);
117   T* resource_ TF_GUARDED_BY(mu_) = nullptr;
118 
119  private:
120   // Must return a T descendant allocated with new that ResourceOpKernel will
121   // take ownership of.
122   virtual Status CreateResource(T** resource)
123       TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) = 0;
124 
125   // During the first Compute(), resource is either created or looked up using
126   // shared_name. In the latter case, the resource found should be verified if
127   // it is compatible with this op's configuration. The verification may fail in
128   // cases such as two graphs asking queues of the same shared name to have
129   // inconsistent capacities.
VerifyResource(T * resource)130   virtual Status VerifyResource(T* resource) { return OkStatus(); }
131 
132   Tensor tensor_ TF_GUARDED_BY(mu_);
133 
134   // Is the output of the operator of type DT_RESOURCE?
135   bool has_resource_type_;
136 };
137 }  // namespace tensorflow
138 
139 #endif  // TENSORFLOW_CORE_FRAMEWORK_RESOURCE_OP_KERNEL_H_
140