xref: /aosp_15_r20/external/tensorflow/tensorflow/core/framework/variant_op_registry.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/core/framework/variant_op_registry.h"
17 
18 #include <string>
19 
20 #include "tensorflow/core/framework/register_types.h"
21 #include "tensorflow/core/framework/type_index.h"
22 #include "tensorflow/core/framework/variant.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/platform/logging.h"
25 #include "tensorflow/core/public/version.h"
26 
27 namespace tensorflow {
28 
VariantUnaryOpToString(VariantUnaryOp op)29 const char* VariantUnaryOpToString(VariantUnaryOp op) {
30   switch (op) {
31     case INVALID_VARIANT_UNARY_OP:
32       return "INVALID";
33     case ZEROS_LIKE_VARIANT_UNARY_OP:
34       return "ZEROS_LIKE";
35     case CONJ_VARIANT_UNARY_OP:
36       return "CONJ";
37   }
38 }
39 
VariantBinaryOpToString(VariantBinaryOp op)40 const char* VariantBinaryOpToString(VariantBinaryOp op) {
41   switch (op) {
42     case INVALID_VARIANT_BINARY_OP:
43       return "INVALID";
44     case ADD_VARIANT_BINARY_OP:
45       return "ADD";
46   }
47 }
48 
PersistentStringStorage()49 std::unordered_set<string>* UnaryVariantOpRegistry::PersistentStringStorage() {
50   static std::unordered_set<string>* string_storage =
51       new std::unordered_set<string>();
52   return string_storage;
53 }
54 
55 // Get a pointer to a global UnaryVariantOpRegistry object
UnaryVariantOpRegistryGlobal()56 UnaryVariantOpRegistry* UnaryVariantOpRegistryGlobal() {
57   static UnaryVariantOpRegistry* global_unary_variant_op_registry = nullptr;
58 
59   if (global_unary_variant_op_registry == nullptr) {
60     global_unary_variant_op_registry = new UnaryVariantOpRegistry;
61   }
62   return global_unary_variant_op_registry;
63 }
64 
GetDecodeFn(StringPiece type_name)65 UnaryVariantOpRegistry::VariantDecodeFn* UnaryVariantOpRegistry::GetDecodeFn(
66     StringPiece type_name) {
67   auto found = decode_fns.find(type_name);
68   if (found == decode_fns.end()) return nullptr;
69   return &found->second;
70 }
71 
RegisterDecodeFn(const string & type_name,const VariantDecodeFn & decode_fn)72 void UnaryVariantOpRegistry::RegisterDecodeFn(
73     const string& type_name, const VariantDecodeFn& decode_fn) {
74   CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantDecode";
75   VariantDecodeFn* existing = GetDecodeFn(type_name);
76   CHECK_EQ(existing, nullptr)
77       << "Unary VariantDecodeFn for type_name: " << type_name
78       << " already registered";
79   decode_fns.insert(std::pair<StringPiece, VariantDecodeFn>(
80       GetPersistentStringPiece(type_name), decode_fn));
81 }
82 
DecodeUnaryVariant(Variant * variant)83 bool DecodeUnaryVariant(Variant* variant) {
84   CHECK_NOTNULL(variant);
85   if (variant->TypeName().empty()) {
86     VariantTensorDataProto* t = variant->get<VariantTensorDataProto>();
87     if (t == nullptr || !t->metadata().empty() || !t->tensors().empty()) {
88       // Malformed variant.
89       return false;
90     } else {
91       // Serialization of an empty Variant.
92       variant->clear();
93       return true;
94     }
95   }
96   UnaryVariantOpRegistry::VariantDecodeFn* decode_fn =
97       UnaryVariantOpRegistry::Global()->GetDecodeFn(variant->TypeName());
98   if (decode_fn == nullptr) {
99     return false;
100   }
101   const string type_name = variant->TypeName();
102   bool decoded = (*decode_fn)(variant);
103   if (!decoded) return false;
104   if (variant->TypeName() != type_name) {
105     LOG(ERROR) << "DecodeUnaryVariant: Variant type_name before decoding was: "
106                << type_name
107                << " but after decoding was: " << variant->TypeName()
108                << ".  Treating this as a failure.";
109     return false;
110   }
111   return true;
112 }
113 
114 // Add some basic registrations for use by others, e.g., for testing.
115 
116 #define REGISTER_VARIANT_DECODE_TYPE(T) \
117   REGISTER_UNARY_VARIANT_DECODE_FUNCTION(T, TF_STR(T));
118 
119 // No encode/decode registered for std::complex<> and Eigen::half
120 // objects yet.
121 REGISTER_VARIANT_DECODE_TYPE(int);
122 REGISTER_VARIANT_DECODE_TYPE(float);
123 REGISTER_VARIANT_DECODE_TYPE(bool);
124 REGISTER_VARIANT_DECODE_TYPE(double);
125 
126 #undef REGISTER_VARIANT_DECODE_TYPE
127 
VariantDeviceCopy(const VariantDeviceCopyDirection direction,const Variant & from,Variant * to,const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn & copy_fn)128 Status VariantDeviceCopy(
129     const VariantDeviceCopyDirection direction, const Variant& from,
130     Variant* to,
131     const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy_fn) {
132   UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn* device_copy_fn =
133       UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(direction,
134                                                         from.TypeId());
135   if (device_copy_fn == nullptr) {
136     return errors::Internal(
137         "No unary variant device copy function found for direction: ",
138         direction, " and Variant type_index: ",
139         port::MaybeAbiDemangle(from.TypeId().name()));
140   }
141   return (*device_copy_fn)(from, to, copy_fn);
142 }
143 
144 namespace {
145 template <typename T>
DeviceCopyPrimitiveType(const T & in,T * out,const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn & copier)146 Status DeviceCopyPrimitiveType(
147     const T& in, T* out,
148     const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copier) {
149   // Dummy copy, we don't actually bother copying to the device and back for
150   // testing.
151   *out = in;
152   return OkStatus();
153 }
154 }  // namespace
155 
156 #define REGISTER_VARIANT_DEVICE_COPY_TYPE(T)            \
157   INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
158       T, VariantDeviceCopyDirection::HOST_TO_DEVICE,    \
159       DeviceCopyPrimitiveType<T>);                      \
160   INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
161       T, VariantDeviceCopyDirection::DEVICE_TO_HOST,    \
162       DeviceCopyPrimitiveType<T>);                      \
163   INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
164       T, VariantDeviceCopyDirection::DEVICE_TO_DEVICE,  \
165       DeviceCopyPrimitiveType<T>);
166 
167 // No zeros_like registered for std::complex<> or Eigen::half objects yet.
168 REGISTER_VARIANT_DEVICE_COPY_TYPE(int);
169 REGISTER_VARIANT_DEVICE_COPY_TYPE(float);
170 REGISTER_VARIANT_DEVICE_COPY_TYPE(double);
171 REGISTER_VARIANT_DEVICE_COPY_TYPE(bool);
172 
173 #undef REGISTER_VARIANT_DEVICE_COPY_TYPE
174 
175 namespace {
176 template <typename T>
ZerosLikeVariantPrimitiveType(OpKernelContext * ctx,const T & t,T * t_out)177 Status ZerosLikeVariantPrimitiveType(OpKernelContext* ctx, const T& t,
178                                      T* t_out) {
179   *t_out = T(0);
180   return OkStatus();
181 }
182 }  // namespace
183 
184 #define REGISTER_VARIANT_ZEROS_LIKE_TYPE(T)                             \
185   REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, \
186                                            DEVICE_CPU, T,               \
187                                            ZerosLikeVariantPrimitiveType<T>);
188 
189 // No zeros_like registered for std::complex<> or Eigen::half objects yet.
190 REGISTER_VARIANT_ZEROS_LIKE_TYPE(int);
191 REGISTER_VARIANT_ZEROS_LIKE_TYPE(float);
192 REGISTER_VARIANT_ZEROS_LIKE_TYPE(double);
193 REGISTER_VARIANT_ZEROS_LIKE_TYPE(bool);
194 
195 #undef REGISTER_VARIANT_ZEROS_LIKE_TYPE
196 
197 namespace {
198 template <typename T>
AddVariantPrimitiveType(OpKernelContext * ctx,const T & a,const T & b,T * out)199 Status AddVariantPrimitiveType(OpKernelContext* ctx, const T& a, const T& b,
200                                T* out) {
201   *out = a + b;
202   return OkStatus();
203 }
204 }  // namespace
205 
206 #define REGISTER_VARIANT_ADD_TYPE(T)                                           \
207   REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU, \
208                                             T, AddVariantPrimitiveType<T>);
209 
210 // No add registered for std::complex<> or Eigen::half objects yet.
211 REGISTER_VARIANT_ADD_TYPE(int);
212 REGISTER_VARIANT_ADD_TYPE(float);
213 REGISTER_VARIANT_ADD_TYPE(double);
214 REGISTER_VARIANT_ADD_TYPE(bool);
215 
216 #undef REGISTER_VARIANT_ADD_TYPE
217 
218 }  // namespace tensorflow
219