1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/variant_op_registry.h"
17
18 #include <string>
19
20 #include "tensorflow/core/framework/register_types.h"
21 #include "tensorflow/core/framework/type_index.h"
22 #include "tensorflow/core/framework/variant.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/platform/logging.h"
25 #include "tensorflow/core/public/version.h"
26
27 namespace tensorflow {
28
VariantUnaryOpToString(VariantUnaryOp op)29 const char* VariantUnaryOpToString(VariantUnaryOp op) {
30 switch (op) {
31 case INVALID_VARIANT_UNARY_OP:
32 return "INVALID";
33 case ZEROS_LIKE_VARIANT_UNARY_OP:
34 return "ZEROS_LIKE";
35 case CONJ_VARIANT_UNARY_OP:
36 return "CONJ";
37 }
38 }
39
VariantBinaryOpToString(VariantBinaryOp op)40 const char* VariantBinaryOpToString(VariantBinaryOp op) {
41 switch (op) {
42 case INVALID_VARIANT_BINARY_OP:
43 return "INVALID";
44 case ADD_VARIANT_BINARY_OP:
45 return "ADD";
46 }
47 }
48
PersistentStringStorage()49 std::unordered_set<string>* UnaryVariantOpRegistry::PersistentStringStorage() {
50 static std::unordered_set<string>* string_storage =
51 new std::unordered_set<string>();
52 return string_storage;
53 }
54
55 // Get a pointer to a global UnaryVariantOpRegistry object
UnaryVariantOpRegistryGlobal()56 UnaryVariantOpRegistry* UnaryVariantOpRegistryGlobal() {
57 static UnaryVariantOpRegistry* global_unary_variant_op_registry = nullptr;
58
59 if (global_unary_variant_op_registry == nullptr) {
60 global_unary_variant_op_registry = new UnaryVariantOpRegistry;
61 }
62 return global_unary_variant_op_registry;
63 }
64
GetDecodeFn(StringPiece type_name)65 UnaryVariantOpRegistry::VariantDecodeFn* UnaryVariantOpRegistry::GetDecodeFn(
66 StringPiece type_name) {
67 auto found = decode_fns.find(type_name);
68 if (found == decode_fns.end()) return nullptr;
69 return &found->second;
70 }
71
RegisterDecodeFn(const string & type_name,const VariantDecodeFn & decode_fn)72 void UnaryVariantOpRegistry::RegisterDecodeFn(
73 const string& type_name, const VariantDecodeFn& decode_fn) {
74 CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantDecode";
75 VariantDecodeFn* existing = GetDecodeFn(type_name);
76 CHECK_EQ(existing, nullptr)
77 << "Unary VariantDecodeFn for type_name: " << type_name
78 << " already registered";
79 decode_fns.insert(std::pair<StringPiece, VariantDecodeFn>(
80 GetPersistentStringPiece(type_name), decode_fn));
81 }
82
DecodeUnaryVariant(Variant * variant)83 bool DecodeUnaryVariant(Variant* variant) {
84 CHECK_NOTNULL(variant);
85 if (variant->TypeName().empty()) {
86 VariantTensorDataProto* t = variant->get<VariantTensorDataProto>();
87 if (t == nullptr || !t->metadata().empty() || !t->tensors().empty()) {
88 // Malformed variant.
89 return false;
90 } else {
91 // Serialization of an empty Variant.
92 variant->clear();
93 return true;
94 }
95 }
96 UnaryVariantOpRegistry::VariantDecodeFn* decode_fn =
97 UnaryVariantOpRegistry::Global()->GetDecodeFn(variant->TypeName());
98 if (decode_fn == nullptr) {
99 return false;
100 }
101 const string type_name = variant->TypeName();
102 bool decoded = (*decode_fn)(variant);
103 if (!decoded) return false;
104 if (variant->TypeName() != type_name) {
105 LOG(ERROR) << "DecodeUnaryVariant: Variant type_name before decoding was: "
106 << type_name
107 << " but after decoding was: " << variant->TypeName()
108 << ". Treating this as a failure.";
109 return false;
110 }
111 return true;
112 }
113
// Add some basic registrations for use by others, e.g., for testing.

// Registers a decode function for Variants holding a bare primitive `T`,
// keyed by TF_STR(T).
// NOTE(review): TF_STR presumably stringizes its argument (e.g. "int") —
// confirm against its definition.
// The macro body already ends in ';', so each call site's trailing ';' adds a
// harmless empty declaration.
#define REGISTER_VARIANT_DECODE_TYPE(T) \
  REGISTER_UNARY_VARIANT_DECODE_FUNCTION(T, TF_STR(T));

// No encode/decode registered for std::complex<> and Eigen::half
// objects yet.
REGISTER_VARIANT_DECODE_TYPE(int);
REGISTER_VARIANT_DECODE_TYPE(float);
REGISTER_VARIANT_DECODE_TYPE(bool);
REGISTER_VARIANT_DECODE_TYPE(double);

// The helper macro is local to this file.
#undef REGISTER_VARIANT_DECODE_TYPE
127
VariantDeviceCopy(const VariantDeviceCopyDirection direction,const Variant & from,Variant * to,const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn & copy_fn)128 Status VariantDeviceCopy(
129 const VariantDeviceCopyDirection direction, const Variant& from,
130 Variant* to,
131 const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy_fn) {
132 UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn* device_copy_fn =
133 UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(direction,
134 from.TypeId());
135 if (device_copy_fn == nullptr) {
136 return errors::Internal(
137 "No unary variant device copy function found for direction: ",
138 direction, " and Variant type_index: ",
139 port::MaybeAbiDemangle(from.TypeId().name()));
140 }
141 return (*device_copy_fn)(from, to, copy_fn);
142 }
143
144 namespace {
145 template <typename T>
DeviceCopyPrimitiveType(const T & in,T * out,const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn & copier)146 Status DeviceCopyPrimitiveType(
147 const T& in, T* out,
148 const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copier) {
149 // Dummy copy, we don't actually bother copying to the device and back for
150 // testing.
151 *out = in;
152 return OkStatus();
153 }
154 } // namespace
155
// Registers DeviceCopyPrimitiveType<T> for every copy direction
// (HOST_TO_DEVICE, DEVICE_TO_HOST and DEVICE_TO_DEVICE), so Variants holding a
// bare T can pass through VariantDeviceCopy, e.g. in tests.
#define REGISTER_VARIANT_DEVICE_COPY_TYPE(T)             \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
      T, VariantDeviceCopyDirection::HOST_TO_DEVICE,     \
      DeviceCopyPrimitiveType<T>);                       \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
      T, VariantDeviceCopyDirection::DEVICE_TO_HOST,     \
      DeviceCopyPrimitiveType<T>);                       \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
      T, VariantDeviceCopyDirection::DEVICE_TO_DEVICE,   \
      DeviceCopyPrimitiveType<T>);

// No device copy registered for std::complex<> or Eigen::half objects yet.
REGISTER_VARIANT_DEVICE_COPY_TYPE(int);
REGISTER_VARIANT_DEVICE_COPY_TYPE(float);
REGISTER_VARIANT_DEVICE_COPY_TYPE(double);
REGISTER_VARIANT_DEVICE_COPY_TYPE(bool);

// The helper macro is local to this file.
#undef REGISTER_VARIANT_DEVICE_COPY_TYPE
174
175 namespace {
176 template <typename T>
ZerosLikeVariantPrimitiveType(OpKernelContext * ctx,const T & t,T * t_out)177 Status ZerosLikeVariantPrimitiveType(OpKernelContext* ctx, const T& t,
178 T* t_out) {
179 *t_out = T(0);
180 return OkStatus();
181 }
182 } // namespace
183
// Registers ZerosLikeVariantPrimitiveType<T> as the DEVICE_CPU implementation
// of ZEROS_LIKE_VARIANT_UNARY_OP for Variants holding a bare T.
#define REGISTER_VARIANT_ZEROS_LIKE_TYPE(T)                             \
  REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, \
                                           DEVICE_CPU, T,               \
                                           ZerosLikeVariantPrimitiveType<T>);

// No zeros_like registered for std::complex<> or Eigen::half objects yet.
REGISTER_VARIANT_ZEROS_LIKE_TYPE(int);
REGISTER_VARIANT_ZEROS_LIKE_TYPE(float);
REGISTER_VARIANT_ZEROS_LIKE_TYPE(double);
REGISTER_VARIANT_ZEROS_LIKE_TYPE(bool);

// The helper macro is local to this file.
#undef REGISTER_VARIANT_ZEROS_LIKE_TYPE
196
197 namespace {
198 template <typename T>
AddVariantPrimitiveType(OpKernelContext * ctx,const T & a,const T & b,T * out)199 Status AddVariantPrimitiveType(OpKernelContext* ctx, const T& a, const T& b,
200 T* out) {
201 *out = a + b;
202 return OkStatus();
203 }
204 } // namespace
205
// Registers AddVariantPrimitiveType<T> as the DEVICE_CPU implementation of
// ADD_VARIANT_BINARY_OP for Variants holding a bare T.
#define REGISTER_VARIANT_ADD_TYPE(T)                                            \
  REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU, \
                                            T, AddVariantPrimitiveType<T>);

// No add registered for std::complex<> or Eigen::half objects yet.
REGISTER_VARIANT_ADD_TYPE(int);
REGISTER_VARIANT_ADD_TYPE(float);
REGISTER_VARIANT_ADD_TYPE(double);
REGISTER_VARIANT_ADD_TYPE(bool);

// The helper macro is local to this file.
#undef REGISTER_VARIANT_ADD_TYPE
217
218 } // namespace tensorflow
219