xref: /aosp_15_r20/external/federated-compute/fcp/tensorflow/host_object.h (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2019 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef FCP_TENSORFLOW_HOST_OBJECT_H_
18 #define FCP_TENSORFLOW_HOST_OBJECT_H_
19 
20 #include <memory>
21 #include <optional>
22 #include <utility>
23 
24 #include "absl/base/thread_annotations.h"
25 #include "absl/container/flat_hash_map.h"
26 #include "absl/synchronization/mutex.h"
27 #include "fcp/base/random_token.h"
28 #include "fcp/base/unique_value.h"
29 
30 namespace fcp {
31 
32 /**
33  * Op-kernels are instantiated by TensorFlow, and can only be parameterized by
34  * graph 'attrs' and tensor inputs. So, op-kernels which access the 'outside
35  * world' tend to use ambient, process-global resources - for example, consider
36  * op-kernels which interpret a string tensor as a filesystem path.
37  *
38  * In some uses, we'd like to parameterize an op-kernel on some 'host'-side,
39  * non-Tensor objects (for example, a virtual filesystem) at the site of
40  * Session::Run (i.e. maintaining functional composition).
41  *
42  * This file defines a mechanism to register 'host objects' (in a
43  * HostObjectRegistry) outside of a session, pass them to Session::Run, and
44  * refer to them inside of the graph (and op-kernel implementations) using
45  * DT_STRING scalars ('tokens'). We could instead use DT_VARIANT tensors (which
46  * can wrap C++ objects directly), but DT_STRING is much more convenient to
47  * marshal (for example, Python's Session::Run wrapper accepts Python strings
48  * for placeholder bindings, but not existing Tensor objects).
49  *
 * To register a host object:
 *   Call HostObjectRegistry<I>::Register for some interface type 'I'. This
 *   returns a HostObjectRegistration object, which de-registers the object on
 *   destruction.
 * To pass in a host object:
 *   Bind the token() (from the HostObjectRegistration) to some placeholder
 *   when calling Session::Run.
56  * To access a host object in an op-kernel:
57  *   Use HostObjectRegistry<I>::TryLookup (the op should take a DT_STRING scalar
58  *   for the token to use).
59  */
60 
61 namespace host_object_internal {
62 
63 /**
64  * HostObjectRegistry implementation for a particular interface type.
65  *
66  * For each I, HostObjectRegistry<I> defines a HostObjectRegistryImpl with
67  * static storage duration.
68  */
69 class HostObjectRegistryImpl {
70  public:
71   std::optional<std::shared_ptr<void>> TryLookup(RandomToken token);
72   void Register(RandomToken token, std::shared_ptr<void> p);
73   void Unregister(RandomToken token);
74  private:
75   absl::Mutex mutex_;
76   absl::flat_hash_map<RandomToken, std::shared_ptr<void>> objects_
77       ABSL_GUARDED_BY(mutex_);
78 };
79 
80 }  // namespace host_object_internal
81 
82 /**
83  * Active registration of a host object, under token(). To reference this object
84  * in a TensorFlow graph, pass in token() as a DT_STRING tensor.
85  *
86  * De-registers when destructed. Note that the registered object *may* stay
87  * alive; an op-kernel can retain an std::shared_ptr ref from TryLookup.
88  */
89 class HostObjectRegistration final {
90  public:
91   HostObjectRegistration(HostObjectRegistration&&) = default;
92   HostObjectRegistration& operator=(HostObjectRegistration&&) = default;
93 
~HostObjectRegistration()94   ~HostObjectRegistration() {
95     if (token_.has_value()) {
96       registry_->Unregister(*token_);
97     }
98   }
99 
100   /**
101    * Token under which the object is registered. It can be passed into a graph
102    * (as a string tensor) and used to look up the object.
103    */
token()104   RandomToken token() const { return *token_; }
105 
106  private:
107   template<typename T>
108   friend class HostObjectRegistry;
109 
HostObjectRegistration(host_object_internal::HostObjectRegistryImpl * registry,RandomToken token)110   HostObjectRegistration(host_object_internal::HostObjectRegistryImpl* registry,
111                          RandomToken token)
112       : registry_(registry), token_(token) {}
113 
114   host_object_internal::HostObjectRegistryImpl* registry_;
115   UniqueValue<RandomToken> token_;
116 };
117 
118 /**
119  * Registry of host objects, for a particular interface type.
120  * See file remarks.
121  */
122 template<typename T>
123 class HostObjectRegistry {
124  public:
125   /**
126    * Registers the provided host object, yielding a new HostObjectRegistration
127    * with a unique token(). The object is de-registered when the
128    * HostObjectRegistration is destructed.
129    */
Register(std::shared_ptr<T> p)130   static HostObjectRegistration Register(std::shared_ptr<T> p) {
131     RandomToken token = RandomToken::Generate();
132     GetImpl()->Register(token, std::move(p));
133     return HostObjectRegistration(GetImpl(), token);
134   }
135 
136   /**
137    * Looks up a host object. Returns std::nullopt if nothing is currently
138    * registered for the provided token (and interface T).
139    */
TryLookup(RandomToken token)140   static std::optional<std::shared_ptr<T>> TryLookup(RandomToken token) {
141     std::optional<std::shared_ptr<void>> maybe_p = GetImpl()->TryLookup(token);
142     if (maybe_p.has_value()) {
143       std::shared_ptr<void> p = *std::move(maybe_p);
144       return std::static_pointer_cast<T>(std::move(p));
145     } else {
146       return std::nullopt;
147     }
148   }
149 
150  private:
151   HostObjectRegistry();
152 
GetImpl()153   static host_object_internal::HostObjectRegistryImpl* GetImpl() {
154     static auto* global_registry =
155         new host_object_internal::HostObjectRegistryImpl();
156     return global_registry;
157   }
158 };
159 
160 }  // namespace fcp
161 
162 #endif  // FCP_TENSORFLOW_HOST_OBJECT_H_
163