1 /* 2 * Copyright 2019 Google LLC 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef FCP_TENSORFLOW_HOST_OBJECT_H_ 18 #define FCP_TENSORFLOW_HOST_OBJECT_H_ 19 20 #include <memory> 21 #include <optional> 22 #include <utility> 23 24 #include "absl/base/thread_annotations.h" 25 #include "absl/container/flat_hash_map.h" 26 #include "absl/synchronization/mutex.h" 27 #include "fcp/base/random_token.h" 28 #include "fcp/base/unique_value.h" 29 30 namespace fcp { 31 32 /** 33 * Op-kernels are instantiated by TensorFlow, and can only be parameterized by 34 * graph 'attrs' and tensor inputs. So, op-kernels which access the 'outside 35 * world' tend to use ambient, process-global resources - for example, consider 36 * op-kernels which interpret a string tensor as a filesystem path. 37 * 38 * In some uses, we'd like to parameterize an op-kernel on some 'host'-side, 39 * non-Tensor objects (for example, a virtual filesystem) at the site of 40 * Session::Run (i.e. maintaining functional composition). 41 * 42 * This file defines a mechanism to register 'host objects' (in a 43 * HostObjectRegistry) outside of a session, pass them to Session::Run, and 44 * refer to them inside of the graph (and op-kernel implementations) using 45 * DT_STRING scalars ('tokens'). We could instead use DT_VARIANT tensors (which 46 * can wrap C++ objects directly), but DT_STRING is much more convenient to 47 * marshal (for example, Python's Session::Run wrapper accepts Python strings 48 * for placeholder bindings, but not existing Tensor objects). 49 * 50 * To register a host object: 51 * Use HostObjectRegistry<I> for some interface type 'I'. This returns a 52 * HostObjectRegistration object, which de-registers on destruction. 53 * To pass in a host object: 54 * Bind the token() (from the HostObjectRegistration) to some placeholder, 55 * when calling Session::Run. 56 * To access a host object in an op-kernel: 57 * Use HostObjectRegistry<I>::TryLookup (the op should take a DT_STRING scalar 58 * for the token to use). 59 */ 60 61 namespace host_object_internal { 62 63 /** 64 * HostObjectRegistry implementation for a particular interface type. 65 * 66 * For each I, HostObjectRegistry<I> defines a HostObjectRegistryImpl with 67 * static storage duration. 68 */ 69 class HostObjectRegistryImpl { 70 public: 71 std::optional<std::shared_ptr<void>> TryLookup(RandomToken token); 72 void Register(RandomToken token, std::shared_ptr<void> p); 73 void Unregister(RandomToken token); 74 private: 75 absl::Mutex mutex_; 76 absl::flat_hash_map<RandomToken, std::shared_ptr<void>> objects_ 77 ABSL_GUARDED_BY(mutex_); 78 }; 79 80 } // namespace host_object_internal 81 82 /** 83 * Active registration of a host object, under token(). To reference this object 84 * in a TensorFlow graph, pass in token() as a DT_STRING tensor. 85 * 86 * De-registers when destructed. Note that the registered object *may* stay 87 * alive; an op-kernel can retain an std::shared_ptr ref from TryLookup. 88 */ 89 class HostObjectRegistration final { 90 public: 91 HostObjectRegistration(HostObjectRegistration&&) = default; 92 HostObjectRegistration& operator=(HostObjectRegistration&&) = default; 93 ~HostObjectRegistration()94 ~HostObjectRegistration() { 95 if (token_.has_value()) { 96 registry_->Unregister(*token_); 97 } 98 } 99 100 /** 101 * Token under which the object is registered. It can be passed into a graph 102 * (as a string tensor) and used to look up the object. 103 */ token()104 RandomToken token() const { return *token_; } 105 106 private: 107 template<typename T> 108 friend class HostObjectRegistry; 109 HostObjectRegistration(host_object_internal::HostObjectRegistryImpl * registry,RandomToken token)110 HostObjectRegistration(host_object_internal::HostObjectRegistryImpl* registry, 111 RandomToken token) 112 : registry_(registry), token_(token) {} 113 114 host_object_internal::HostObjectRegistryImpl* registry_; 115 UniqueValue<RandomToken> token_; 116 }; 117 118 /** 119 * Registry of host objects, for a particular interface type. 120 * See file remarks. 121 */ 122 template<typename T> 123 class HostObjectRegistry { 124 public: 125 /** 126 * Registers the provided host object, yielding a new HostObjectRegistration 127 * with a unique token(). The object is de-registered when the 128 * HostObjectRegistration is destructed. 129 */ Register(std::shared_ptr<T> p)130 static HostObjectRegistration Register(std::shared_ptr<T> p) { 131 RandomToken token = RandomToken::Generate(); 132 GetImpl()->Register(token, std::move(p)); 133 return HostObjectRegistration(GetImpl(), token); 134 } 135 136 /** 137 * Looks up a host object. Returns std::nullopt if nothing is currently 138 * registered for the provided token (and interface T). 139 */ TryLookup(RandomToken token)140 static std::optional<std::shared_ptr<T>> TryLookup(RandomToken token) { 141 std::optional<std::shared_ptr<void>> maybe_p = GetImpl()->TryLookup(token); 142 if (maybe_p.has_value()) { 143 std::shared_ptr<void> p = *std::move(maybe_p); 144 return std::static_pointer_cast<T>(std::move(p)); 145 } else { 146 return std::nullopt; 147 } 148 } 149 150 private: 151 HostObjectRegistry(); 152 GetImpl()153 static host_object_internal::HostObjectRegistryImpl* GetImpl() { 154 static auto* global_registry = 155 new host_object_internal::HostObjectRegistryImpl(); 156 return global_registry; 157 } 158 }; 159 160 } // namespace fcp 161 162 #endif // FCP_TENSORFLOW_HOST_OBJECT_H_ 163