1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_VARIANT_OP_REGISTRY_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_VARIANT_OP_REGISTRY_H_
18
19 #include <string>
20 #include <unordered_set>
21 #include <vector>
22
23 #define EIGEN_USE_THREADS
24
25 #include "tensorflow/core/framework/tensor.pb.h"
26 #include "tensorflow/core/framework/type_index.h"
27 #include "tensorflow/core/framework/types.h"
28 #include "tensorflow/core/framework/variant.h"
29 #include "tensorflow/core/framework/variant_encode_decode.h"
30 #include "tensorflow/core/lib/gtl/flatmap.h"
31 #include "tensorflow/core/lib/hash/hash.h"
32 #include "tensorflow/core/platform/abi.h"
33
34 namespace tensorflow {
35
class OpKernelContext;
class UnaryVariantOpRegistry;

// Unary operations that can be dispatched on a Variant through the
// UnaryVariantOpRegistry.
enum VariantUnaryOp {
  INVALID_VARIANT_UNARY_OP = 0,
  ZEROS_LIKE_VARIANT_UNARY_OP = 1,
  CONJ_VARIANT_UNARY_OP = 2,
};

// Returns a printable name for `op`; used when building registry error
// messages.
const char* VariantUnaryOpToString(VariantUnaryOp op);

// Binary operations that can be dispatched on a pair of Variants through the
// UnaryVariantOpRegistry.
enum VariantBinaryOp {
  INVALID_VARIANT_BINARY_OP = 0,
  ADD_VARIANT_BINARY_OP = 1,
};

// Returns a printable name for `op`; used when building registry error
// messages.
const char* VariantBinaryOpToString(VariantBinaryOp op);

// Direction of a Variant copy carried out by a registered device copy
// function.
enum VariantDeviceCopyDirection {
  INVALID_DEVICE_COPY_DIRECTION = 0,
  HOST_TO_DEVICE = 1,
  DEVICE_TO_HOST = 2,
  DEVICE_TO_DEVICE = 3,
};

// A global UnaryVariantOpRegistry is used to hold callback functions for
// different variant types, to be used by ShapeOp, RankOp, SizeOp, decoding,
// etc. This accessor returns that global instance.
UnaryVariantOpRegistry* UnaryVariantOpRegistryGlobal();
65
66 class UnaryVariantOpRegistry {
67 public:
68 typedef std::function<bool(Variant*)> VariantDecodeFn;
69 typedef std::function<Status(OpKernelContext*, const Variant&, Variant*)>
70 VariantUnaryOpFn;
71 typedef std::function<Status(OpKernelContext*, const Variant&, const Variant&,
72 Variant*)>
73 VariantBinaryOpFn;
74
75 // An AsyncTensorDeviceCopyFn is a function provided to
76 // the user-provided DeviceCopyFn callback as the third argument ("copier").
77 //
78 // Expected inputs:
79 // from: A Tensor on the host (if performing cpu->gpu copy), or
80 // device (if performing gpu->cpu or gpu->gpu copy).
81 // to: An empty/uninitialized tensor. It will be updated upon
82 // successful return of the function with the correct dtype and shape.
83 // However, the copied data will not be available until the compute
84 // stream has been synchronized.
85 //
86 // Returns:
87 // The status upon memory allocation / initialization of the
88 // "to" tensor, and enqueue of the copy onto the compute stream.
89 // Any failure of the copy itself will update the underlying
90 // stream status and propagate through the runtime independent
91 // of the caller.
92 typedef std::function<Status(const Tensor& from, Tensor* to)>
93 AsyncTensorDeviceCopyFn;
94
95 // The AsyncVariantDeviceCopyFn is the signature of the 'device_copy_fn'
96 // expected to be passed to the registration macro
97 // INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION.
98 typedef std::function<Status(const Variant& from, Variant* to,
99 AsyncTensorDeviceCopyFn copy_fn)>
100 AsyncVariantDeviceCopyFn;
101
102 // Add a decode function to the registry.
103 void RegisterDecodeFn(const std::string& type_name,
104 const VariantDecodeFn& decode_fn);
105
106 // Returns nullptr if no decode function was found for the given TypeName.
107 VariantDecodeFn* GetDecodeFn(StringPiece type_name);
108
109 // Add a copy-to-GPU function to the registry.
RegisterDeviceCopyFn(const VariantDeviceCopyDirection direction,const TypeIndex & type_index,const AsyncVariantDeviceCopyFn & device_copy_fn)110 void RegisterDeviceCopyFn(const VariantDeviceCopyDirection direction,
111 const TypeIndex& type_index,
112 const AsyncVariantDeviceCopyFn& device_copy_fn) {
113 AsyncVariantDeviceCopyFn* existing = GetDeviceCopyFn(direction, type_index);
114 CHECK_EQ(existing, nullptr)
115 << "UnaryVariantDeviceCopy for direction: " << direction
116 << " and type_index: " << port::MaybeAbiDemangle(type_index.name())
117 << " already registered";
118 device_copy_fns.insert(
119 std::pair<std::pair<VariantDeviceCopyDirection, TypeIndex>,
120 AsyncVariantDeviceCopyFn>(
121 std::make_pair(direction, type_index), device_copy_fn));
122 }
123
124 // Returns nullptr if no copy function was found for the given
125 // TypeName and direction.
GetDeviceCopyFn(const VariantDeviceCopyDirection direction,const TypeIndex & type_index)126 AsyncVariantDeviceCopyFn* GetDeviceCopyFn(
127 const VariantDeviceCopyDirection direction, const TypeIndex& type_index) {
128 auto found = device_copy_fns.find(std::make_pair(direction, type_index));
129 if (found == device_copy_fns.end()) return nullptr;
130 return &found->second;
131 }
132
133 // Add a unary op function to the registry.
RegisterUnaryOpFn(VariantUnaryOp op,const std::string & device,const TypeIndex & type_index,const VariantUnaryOpFn & unary_op_fn)134 void RegisterUnaryOpFn(VariantUnaryOp op, const std::string& device,
135 const TypeIndex& type_index,
136 const VariantUnaryOpFn& unary_op_fn) {
137 VariantUnaryOpFn* existing = GetUnaryOpFn(op, device, type_index);
138 CHECK_EQ(existing, nullptr)
139 << "Unary VariantUnaryOpFn for type_index: "
140 << port::MaybeAbiDemangle(type_index.name())
141 << " already registered for device type: " << device;
142 unary_op_fns.insert(std::pair<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn>(
143 {op, GetPersistentStringPiece(device), type_index}, unary_op_fn));
144 }
145
146 // Returns nullptr if no unary op function was found for the given
147 // op, device, and TypeName.
GetUnaryOpFn(VariantUnaryOp op,StringPiece device,const TypeIndex & type_index)148 VariantUnaryOpFn* GetUnaryOpFn(VariantUnaryOp op, StringPiece device,
149 const TypeIndex& type_index) {
150 auto found = unary_op_fns.find({op, device, type_index});
151 if (found == unary_op_fns.end()) return nullptr;
152 return &found->second;
153 }
154
155 // Add a binary op function to the registry.
RegisterBinaryOpFn(VariantBinaryOp op,const std::string & device,const TypeIndex & type_index,const VariantBinaryOpFn & add_fn)156 void RegisterBinaryOpFn(VariantBinaryOp op, const std::string& device,
157 const TypeIndex& type_index,
158 const VariantBinaryOpFn& add_fn) {
159 VariantBinaryOpFn* existing = GetBinaryOpFn(op, device, type_index);
160 CHECK_EQ(existing, nullptr)
161 << "Unary VariantBinaryOpFn for type_index: "
162 << port::MaybeAbiDemangle(type_index.name())
163 << " already registered for device type: " << device;
164 binary_op_fns.insert(
165 std::pair<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn>(
166 {op, GetPersistentStringPiece(device), type_index}, add_fn));
167 }
168
169 // Returns nullptr if no binary op function was found for the given
170 // op, device and TypeName.
GetBinaryOpFn(VariantBinaryOp op,StringPiece device,const TypeIndex & type_index)171 VariantBinaryOpFn* GetBinaryOpFn(VariantBinaryOp op, StringPiece device,
172 const TypeIndex& type_index) {
173 auto found = binary_op_fns.find({op, device, type_index});
174 if (found == binary_op_fns.end()) return nullptr;
175 return &found->second;
176 }
177
178 // Get a pointer to a global UnaryVariantOpRegistry object
Global()179 static UnaryVariantOpRegistry* Global() {
180 return UnaryVariantOpRegistryGlobal();
181 }
182
183 // Get a pointer to a global persistent string storage object.
184 // ISO/IEC C++ working draft N4296 clarifies that insertion into an
185 // std::unordered_set does not invalidate memory locations of
186 // *values* inside the set (though it may invalidate existing
187 // iterators). In other words, one may safely point a StringPiece to
188 // a value in the set without that StringPiece being invalidated by
189 // future insertions.
190 static std::unordered_set<string>* PersistentStringStorage();
191
192 private:
193 struct TypeIndexHash {
operatorTypeIndexHash194 std::size_t operator()(const TypeIndex& x) const { return x.hash_code(); }
195 };
196
197 gtl::FlatMap<StringPiece, VariantDecodeFn, StringPieceHasher> decode_fns;
198
199 // Map std::pair<Direction, type_name> to function.
200 struct PairHash {
201 template <typename Direction>
operatorPairHash202 std::size_t operator()(const std::pair<Direction, TypeIndex>& x) const {
203 // The hash of an enum is just its value as a std::size_t.
204 std::size_t ret = static_cast<std::size_t>(std::get<0>(x));
205 ret = Hash64Combine(ret, std::get<1>(x).hash_code());
206 return ret;
207 }
208 };
209
210 gtl::FlatMap<std::pair<VariantDeviceCopyDirection, TypeIndex>,
211 AsyncVariantDeviceCopyFn, PairHash>
212 device_copy_fns;
213
214 // Map std::tuple<Op, device, type_name> to function.
215
216 // this breaks by falling victim to "too perfect forwarding"
217 // see https://stackoverflow.com/questions/44475317/variadic-template-issue
218 // and references therein
219 template <typename Op>
220 struct FuncTuple {
FuncTupleFuncTuple221 FuncTuple(const Op& op, const StringPiece& dev, const TypeIndex& type_index)
222 : op_type_(op), device_(dev), type_index_(type_index) {}
223 Op op_type_;
224 StringPiece device_;
225 TypeIndex type_index_;
226 };
227 // friend declaration for operator==
228 // needed for clang
229 template <typename Op>
230 friend bool operator==(const FuncTuple<Op>& l, const FuncTuple<Op>& r);
231 struct TupleHash {
232 template <typename Op>
operatorTupleHash233 std::size_t operator()(
234 const std::tuple<Op, StringPiece, TypeIndex>& x) const {
235 // The hash of an enum is just its value as a std::size_t.
236 std::size_t ret = static_cast<std::size_t>(std::get<0>(x));
237 ret = Hash64Combine(ret, sp_hasher_(std::get<1>(x)));
238 ret = Hash64Combine(ret, std::get<2>(x).hash_code());
239 return ret;
240 }
241
242 template <typename Op>
operatorTupleHash243 std::size_t operator()(const FuncTuple<Op>& x) const {
244 // The hash of an enum is just its value as a std::size_t.
245 std::size_t ret = static_cast<std::size_t>(x.op_type_);
246 ret = Hash64Combine(ret, sp_hasher_(x.device_));
247 ret = Hash64Combine(ret, x.type_index_.hash_code());
248 return ret;
249 }
250 StringPieceHasher sp_hasher_;
251 };
252 gtl::FlatMap<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn, TupleHash>
253 unary_op_fns;
254 gtl::FlatMap<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn, TupleHash>
255 binary_op_fns;
256
257 // Find or insert a string into a persistent string storage
258 // container; return the StringPiece pointing to the permanent string
259 // location.
GetPersistentStringPiece(const std::string & str)260 static StringPiece GetPersistentStringPiece(const std::string& str) {
261 const auto string_storage = PersistentStringStorage();
262 auto found = string_storage->find(str);
263 if (found == string_storage->end()) {
264 auto inserted = string_storage->insert(str);
265 return StringPiece(*inserted.first);
266 } else {
267 return StringPiece(*found);
268 }
269 }
270 };
271 template <typename Op>
272 inline bool operator==(const UnaryVariantOpRegistry::FuncTuple<Op>& lhs,
273 const UnaryVariantOpRegistry::FuncTuple<Op>& rhs) {
274 return (lhs.op_type_ == rhs.op_type_) && (lhs.device_ == rhs.device_) &&
275 (lhs.type_index_ == rhs.type_index_);
276 }
277
278 // Decodes the Variant whose data_type has a registered decode
279 // function. Returns an Internal error if the Variant does not have a
280 // registered decode function, or if the decoding function fails.
281 //
282 // REQUIRES:
283 // variant is not null.
284 //
285 bool DecodeUnaryVariant(Variant* variant);
286
287 // Copies a variant between CPU<->GPU, or between GPU<->GPU.
288 // The variant 'from' must have a registered DeviceCopyFn for the
289 // given direction. The returned variant 'to' will have
290 // (some subset of its) tensors stored on destination according to the
291 // registered DeviceCopyFn function for the given direction. Returns
292 // an Internal error if the Variant does not have a registered
293 // DeviceCopyFn function for the given direction, or if initiating the
294 // copy fails.
295 //
296 // REQUIRES:
297 // 'to' is not null.
298 //
299 Status VariantDeviceCopy(
300 const VariantDeviceCopyDirection direction, const Variant& from,
301 Variant* to,
302 const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy_fn);
303
304 // Sets *v_out = unary_op(v). The variant v must have a registered
305 // UnaryOp function for the given Device. Returns an Internal error
306 // if v does not have a registered unary_op function for this device, or if
307 // UnaryOp fails.
308 //
309 // REQUIRES:
310 // v_out is not null.
311 //
312 template <typename Device>
UnaryOpVariant(OpKernelContext * ctx,VariantUnaryOp op,const Variant & v,Variant * v_out)313 Status UnaryOpVariant(OpKernelContext* ctx, VariantUnaryOp op, const Variant& v,
314 Variant* v_out) {
315 const std::string& device = DeviceName<Device>::value;
316 UnaryVariantOpRegistry::VariantUnaryOpFn* unary_op_fn =
317 UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeId());
318 if (unary_op_fn == nullptr) {
319 return errors::Internal("No unary variant unary_op function found for op ",
320 VariantUnaryOpToString(op),
321 " Variant type_name: ", v.TypeName(),
322 " for device type: ", device);
323 }
324 return (*unary_op_fn)(ctx, v, v_out);
325 }
326
327 // Sets *out = binary_op(a, b). The variants a and b must be the same type
328 // and have a registered binary_op function for the given Device. Returns an
329 // Internal error if a and b are not the same type_name or if
330 // if a does not have a registered op function for this device, or if
331 // BinaryOp fails.
332 //
333 // REQUIRES:
334 // out is not null.
335 //
336 template <typename Device>
BinaryOpVariants(OpKernelContext * ctx,VariantBinaryOp op,const Variant & a,const Variant & b,Variant * out)337 Status BinaryOpVariants(OpKernelContext* ctx, VariantBinaryOp op,
338 const Variant& a, const Variant& b, Variant* out) {
339 if (a.TypeId() != b.TypeId()) {
340 return errors::Internal(
341 "BinaryOpVariants: Variants a and b have different "
342 "type ids. Type names: '",
343 a.TypeName(), "' vs. '", b.TypeName(), "'");
344 }
345 const std::string& device = DeviceName<Device>::value;
346 UnaryVariantOpRegistry::VariantBinaryOpFn* binary_op_fn =
347 UnaryVariantOpRegistry::Global()->GetBinaryOpFn(op, device, a.TypeId());
348 if (binary_op_fn == nullptr) {
349 return errors::Internal("No unary variant binary_op function found for op ",
350 VariantBinaryOpToString(op),
351 " Variant type_name: '", a.TypeName(),
352 "' for device type: ", device);
353 }
354 return (*binary_op_fn)(ctx, a, b, out);
355 }
356
357 namespace variant_op_registry_fn_registration {
358
359 template <typename T>
360 class UnaryVariantDecodeRegistration {
361 public:
UnaryVariantDecodeRegistration(const std::string & type_name)362 UnaryVariantDecodeRegistration(const std::string& type_name) {
363 // The Variant is passed by pointer because it should be
364 // mutable: get below may Decode the variant, which
365 // is a self-mutating behavior. The variant is not modified in
366 // any other way.
367 UnaryVariantOpRegistry::Global()->RegisterDecodeFn(
368 type_name, [type_name](Variant* v) -> bool {
369 DCHECK_NE(v, nullptr);
370 VariantTensorDataProto* t = v->get<VariantTensorDataProto>();
371 if (t == nullptr) {
372 return false;
373 }
374 Variant decoded = T();
375 VariantTensorData data(std::move(*t));
376 if (!decoded.Decode(std::move(data))) {
377 return false;
378 }
379 std::swap(decoded, *v);
380 return true;
381 });
382 }
383 };
384
385 template <typename T>
386 class UnaryVariantDeviceCopyRegistration {
387 public:
388 typedef std::function<Status(const T& t, T* t_out,
389 UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn)>
390 LocalVariantDeviceCopyFn;
UnaryVariantDeviceCopyRegistration(const VariantDeviceCopyDirection direction,const TypeIndex & type_index,const LocalVariantDeviceCopyFn & device_copy_fn)391 UnaryVariantDeviceCopyRegistration(
392 const VariantDeviceCopyDirection direction, const TypeIndex& type_index,
393 const LocalVariantDeviceCopyFn& device_copy_fn) {
394 const std::string type_index_name =
395 port::MaybeAbiDemangle(type_index.name());
396 UnaryVariantOpRegistry::Global()->RegisterDeviceCopyFn(
397 direction, type_index,
398 [type_index_name, device_copy_fn](
399 const Variant& from, Variant* to,
400 UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn
401 device_copy_tensor_fn) -> Status {
402 DCHECK_NE(to, nullptr);
403 *to = T();
404 if (from.get<T>() == nullptr) {
405 return errors::Internal(
406 "VariantCopyToGPUFn: Could not access object, type_index: ",
407 type_index_name);
408 }
409 const T& t = *from.get<T>();
410 T* t_out = to->get<T>();
411 return device_copy_fn(t, t_out, device_copy_tensor_fn);
412 });
413 }
414 };
415
416 template <typename T>
417 class UnaryVariantUnaryOpRegistration {
418 typedef std::function<Status(OpKernelContext* ctx, const T& t, T* t_out)>
419 LocalVariantUnaryOpFn;
420
421 public:
UnaryVariantUnaryOpRegistration(VariantUnaryOp op,const std::string & device,const TypeIndex & type_index,const LocalVariantUnaryOpFn & unary_op_fn)422 UnaryVariantUnaryOpRegistration(VariantUnaryOp op, const std::string& device,
423 const TypeIndex& type_index,
424 const LocalVariantUnaryOpFn& unary_op_fn) {
425 const std::string type_index_name =
426 port::MaybeAbiDemangle(type_index.name());
427 UnaryVariantOpRegistry::Global()->RegisterUnaryOpFn(
428 op, device, type_index,
429 [type_index_name, unary_op_fn](OpKernelContext* ctx, const Variant& v,
430 Variant* v_out) -> Status {
431 DCHECK_NE(v_out, nullptr);
432 *v_out = T();
433 if (v.get<T>() == nullptr) {
434 return errors::Internal(
435 "VariantUnaryOpFn: Could not access object, type_index: ",
436 type_index_name);
437 }
438 const T& t = *v.get<T>();
439 T* t_out = v_out->get<T>();
440 return unary_op_fn(ctx, t, t_out);
441 });
442 }
443 };
444
445 template <typename T>
446 class UnaryVariantBinaryOpRegistration {
447 typedef std::function<Status(OpKernelContext* ctx, const T& a, const T& b,
448 T* out)>
449 LocalVariantBinaryOpFn;
450
451 public:
UnaryVariantBinaryOpRegistration(VariantBinaryOp op,const std::string & device,const TypeIndex & type_index,const LocalVariantBinaryOpFn & binary_op_fn)452 UnaryVariantBinaryOpRegistration(VariantBinaryOp op,
453 const std::string& device,
454 const TypeIndex& type_index,
455 const LocalVariantBinaryOpFn& binary_op_fn) {
456 const std::string type_index_name =
457 port::MaybeAbiDemangle(type_index.name());
458 UnaryVariantOpRegistry::Global()->RegisterBinaryOpFn(
459 op, device, type_index,
460 [type_index_name, binary_op_fn](OpKernelContext* ctx, const Variant& a,
461 const Variant& b,
462 Variant* out) -> Status {
463 DCHECK_NE(out, nullptr);
464 *out = T();
465 if (a.get<T>() == nullptr) {
466 return errors::Internal(
467 "VariantBinaryOpFn: Could not access object 'a', type_index: ",
468 type_index_name);
469 }
470 if (b.get<T>() == nullptr) {
471 return errors::Internal(
472 "VariantBinaryOpFn: Could not access object 'b', type_index: ",
473 type_index_name);
474 }
475 const T& t_a = *a.get<T>();
476 const T& t_b = *b.get<T>();
477 T* t_out = out->get<T>();
478 return binary_op_fn(ctx, t_a, t_b, t_out);
479 });
480 }
481 };
482
483 }; // namespace variant_op_registry_fn_registration
484
// Registers a decode function for Variants holding a T, keyed by `type_name`.
// Expands to a static UnaryVariantDecodeRegistration<T> variable whose
// constructor performs the registration. __COUNTER__ is routed through the
// _UNIQ_HELPER layer so that it is expanded to a number before being pasted
// with ##, giving every registration in a translation unit its own variable
// name.
#define REGISTER_UNARY_VARIANT_DECODE_FUNCTION(T, type_name) \
  REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ_HELPER(__COUNTER__, T, type_name)

#define REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ_HELPER(counter, T, \
                                                           type_name)   \
  REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ(counter, T, type_name)

#define REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ(counter, T, type_name) \
  static ::tensorflow::variant_op_registry_fn_registration::               \
      UnaryVariantDecodeRegistration<T>                                    \
          register_unary_variant_op_decoder_fn_##counter(type_name)
496
// ****** NOTE ******
// FOR INTERNAL USE ONLY.  IF YOU USE THIS WE MAY BREAK YOUR CODE.
// ****** NOTE ******
//
// Register a device copy variant function for the given copy
// direction and type; where direction is the enum
// VariantDeviceCopyDirection, and the device_copy_fn has signature:
//
//   Status device_copy_fn(
//     const T& t, T* t_out,
//     const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copier);
//
// And device_copy_fn calls copier 0 or more times.  For details on
// the behavior of the copier function, see the comments at the
// declaration of UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn.
//
// Note, the device_copy_fn may choose to keep some tensors
// on host, e.g. by assigning to->tensor = from.tensor (assuming
// from.tensor is already on host); or by setting
//   to->tensor = Tensor(cpu_allocator(), ...)
// and manually updating its values.
//
// If this is the case, the CopyFns for HOST_TO_DEVICE,
// DEVICE_TO_HOST, and DEVICE_TO_DEVICE must perform host-to-host
// copies in a consistent manner.  For example, one must always
// manually copy any "always on host" tensors in all directions instead of e.g.
//   - performing a host-to-host copy in one direction,
//   - using the provided copier function in the reverse direction.
// Doing the latter will cause program failures.
//
// All names in the expansion are fully qualified, so the macro may be used
// inside or outside namespace tensorflow.
//
// ****** NOTE ******
// FOR INTERNAL USE ONLY.  IF YOU USE THIS WE MAY BREAK YOUR CODE.
// ****** NOTE ******
#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(T, direction,   \
                                                             device_copy_fn) \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER(          \
      __COUNTER__, T, direction, ::tensorflow::TypeIndex::Make<T>(),         \
      device_copy_fn)

#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \
    ctr, T, direction, type_index, device_copy_fn)                        \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ(              \
      ctr, T, direction, type_index, device_copy_fn)

#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ( \
    ctr, T, direction, type_index, device_copy_fn)                 \
  static ::tensorflow::variant_op_registry_fn_registration::       \
      UnaryVariantDeviceCopyRegistration<T>                        \
          register_unary_variant_op_device_copy_fn_##ctr(          \
              direction, type_index, device_copy_fn)
546
// Register a unary unary_op variant function with the signature:
//    Status UnaryOpFn(OpKernelContext* ctx, const T& t, T* t_out);
// to Variants holding a T (TypeIndex::Make<T>()), for device string device,
// for VariantUnaryOp enum op.
// All names in the expansion are fully qualified, so the macro may be used
// inside or outside namespace tensorflow.
#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(op, device, T,         \
                                                 unary_op_function)     \
  REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER(                 \
      __COUNTER__, op, device, T, ::tensorflow::TypeIndex::Make<T>(),   \
      unary_op_function)

#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER(       \
    ctr, op, device, T, type_index, unary_op_function)              \
  REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(ctr, op, device, T, \
                                                type_index, unary_op_function)

#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(                       \
    ctr, op, device, T, type_index, unary_op_function)                       \
  static ::tensorflow::variant_op_registry_fn_registration::                 \
      UnaryVariantUnaryOpRegistration<T>                                     \
          register_unary_variant_op_unary_op_fn_##ctr(op, device, type_index, \
                                                      unary_op_function)
567
// Register a binary_op variant function with the signature:
//    Status BinaryOpFn(OpKernelContext* ctx, const T& a, const T& b, T* out);
// to Variants holding a T (TypeIndex::Make<T>()), for device string device,
// for VariantBinaryOp enum op.
// All names in the expansion are fully qualified, so the macro may be used
// inside or outside namespace tensorflow.
#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(op, device, T,        \
                                                  binary_op_function)   \
  REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER(                \
      __COUNTER__, op, device, T, ::tensorflow::TypeIndex::Make<T>(),   \
      binary_op_function)

#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \
    ctr, op, device, T, type_index, binary_op_function)        \
  REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ(              \
      ctr, op, device, T, type_index, binary_op_function)

#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ(                        \
    ctr, op, device, T, type_index, binary_op_function)                        \
  static ::tensorflow::variant_op_registry_fn_registration::                   \
      UnaryVariantBinaryOpRegistration<T>                                      \
          register_unary_variant_op_binary_op_fn_##ctr(op, device, type_index, \
                                                       binary_op_function)
588
589 } // end namespace tensorflow
590
591 #endif // TENSORFLOW_CORE_FRAMEWORK_VARIANT_OP_REGISTRY_H_
592