/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_FRAMEWORK_VARIANT_OP_REGISTRY_H_
#define TENSORFLOW_CORE_FRAMEWORK_VARIANT_OP_REGISTRY_H_

#include <string>
#include <unordered_set>
#include <vector>

#define EIGEN_USE_THREADS

#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/type_index.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/abi.h"
namespace tensorflow {

class OpKernelContext;
// A global UnaryVariantOpRegistry is used to hold callback functions
// for different variant types, to be used by ops such as Shape, Rank,
// and Size, by Variant decoding, etc.

enum VariantUnaryOp {
  INVALID_VARIANT_UNARY_OP = 0,
  ZEROS_LIKE_VARIANT_UNARY_OP = 1,
  CONJ_VARIANT_UNARY_OP = 2,
};

const char* VariantUnaryOpToString(VariantUnaryOp op);

enum VariantBinaryOp {
  INVALID_VARIANT_BINARY_OP = 0,
  ADD_VARIANT_BINARY_OP = 1,
};

const char* VariantBinaryOpToString(VariantBinaryOp op);

enum VariantDeviceCopyDirection {
  INVALID_DEVICE_COPY_DIRECTION = 0,
  HOST_TO_DEVICE = 1,
  DEVICE_TO_HOST = 2,
  DEVICE_TO_DEVICE = 3,
};

class UnaryVariantOpRegistry;
extern UnaryVariantOpRegistry* UnaryVariantOpRegistryGlobal();

class UnaryVariantOpRegistry {
 public:
  typedef std::function<bool(Variant*)> VariantDecodeFn;
  typedef std::function<Status(OpKernelContext*, const Variant&, Variant*)>
      VariantUnaryOpFn;
  typedef std::function<Status(OpKernelContext*, const Variant&, const Variant&,
                               Variant*)>
      VariantBinaryOpFn;

  // An AsyncTensorDeviceCopyFn is the copier function passed as the
  // third argument ("copier") to a user-provided device copy callback.
  //
  // Expected inputs:
  //   from: A Tensor on the host (if performing a cpu->gpu copy), or on
  //         the device (if performing a gpu->cpu or gpu->gpu copy).
  //   to: An empty/uninitialized tensor.  It will be updated upon
  //       successful return of the function with the correct dtype and shape.
  //       However, the copied data will not be available until the compute
  //       stream has been synchronized.
  //
  // Returns:
  //   The status upon memory allocation / initialization of the
  //   "to" tensor, and enqueue of the copy onto the compute stream.
  //   Any failure of the copy itself will update the underlying
  //   stream status and propagate through the runtime independent
  //   of the caller.
  typedef std::function<Status(const Tensor& from, Tensor* to)>
      AsyncTensorDeviceCopyFn;

  // The AsyncVariantDeviceCopyFn is the signature of the 'device_copy_fn'
  // expected to be passed to the registration macro
  // INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION.
  typedef std::function<Status(const Variant& from, Variant* to,
                               AsyncTensorDeviceCopyFn copy_fn)>
      AsyncVariantDeviceCopyFn;
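
  // For illustration only, a minimal sketch (not part of this API) of an
  // AsyncVariantDeviceCopyFn for a hypothetical type "MyType" holding a
  // single tensor member 'tensor', forwarding the transfer to the copier:
  //
  //   Status MyTypeDeviceCopyFn(
  //       const Variant& from, Variant* to,
  //       UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn copier) {
  //     const MyType& src = *from.get<MyType>();
  //     *to = MyType();
  //     MyType* dst = to->get<MyType>();
  //     // Enqueue the tensor copy; data is ready after stream sync.
  //     return copier(src.tensor, &dst->tensor);
  //   }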

  // Add a decode function to the registry.
  void RegisterDecodeFn(const std::string& type_name,
                        const VariantDecodeFn& decode_fn);

  // Returns nullptr if no decode function was found for the given TypeName.
  VariantDecodeFn* GetDecodeFn(StringPiece type_name);

  // Add a device copy function for the given copy direction to the registry.
  void RegisterDeviceCopyFn(const VariantDeviceCopyDirection direction,
                            const TypeIndex& type_index,
                            const AsyncVariantDeviceCopyFn& device_copy_fn) {
    AsyncVariantDeviceCopyFn* existing = GetDeviceCopyFn(direction, type_index);
    CHECK_EQ(existing, nullptr)
        << "UnaryVariantDeviceCopy for direction: " << direction
        << " and type_index: " << port::MaybeAbiDemangle(type_index.name())
        << " already registered";
    device_copy_fns.insert(
        std::pair<std::pair<VariantDeviceCopyDirection, TypeIndex>,
                  AsyncVariantDeviceCopyFn>(
            std::make_pair(direction, type_index), device_copy_fn));
  }

  // Returns nullptr if no copy function was found for the given
  // TypeName and direction.
  AsyncVariantDeviceCopyFn* GetDeviceCopyFn(
      const VariantDeviceCopyDirection direction, const TypeIndex& type_index) {
    auto found = device_copy_fns.find(std::make_pair(direction, type_index));
    if (found == device_copy_fns.end()) return nullptr;
    return &found->second;
  }

  // Add a unary op function to the registry.
  void RegisterUnaryOpFn(VariantUnaryOp op, const std::string& device,
                         const TypeIndex& type_index,
                         const VariantUnaryOpFn& unary_op_fn) {
    VariantUnaryOpFn* existing = GetUnaryOpFn(op, device, type_index);
    CHECK_EQ(existing, nullptr)
        << "VariantUnaryOpFn for type_index: "
        << port::MaybeAbiDemangle(type_index.name())
        << " already registered for device type: " << device;
    unary_op_fns.insert(std::pair<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn>(
        {op, GetPersistentStringPiece(device), type_index}, unary_op_fn));
  }

  // Returns nullptr if no unary op function was found for the given
  // op, device, and TypeName.
  VariantUnaryOpFn* GetUnaryOpFn(VariantUnaryOp op, StringPiece device,
                                 const TypeIndex& type_index) {
    auto found = unary_op_fns.find({op, device, type_index});
    if (found == unary_op_fns.end()) return nullptr;
    return &found->second;
  }

  // Add a binary op function to the registry.
  void RegisterBinaryOpFn(VariantBinaryOp op, const std::string& device,
                          const TypeIndex& type_index,
                          const VariantBinaryOpFn& add_fn) {
    VariantBinaryOpFn* existing = GetBinaryOpFn(op, device, type_index);
    CHECK_EQ(existing, nullptr)
        << "VariantBinaryOpFn for type_index: "
        << port::MaybeAbiDemangle(type_index.name())
        << " already registered for device type: " << device;
    binary_op_fns.insert(
        std::pair<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn>(
            {op, GetPersistentStringPiece(device), type_index}, add_fn));
  }

  // Returns nullptr if no binary op function was found for the given
  // op, device, and TypeName.
  VariantBinaryOpFn* GetBinaryOpFn(VariantBinaryOp op, StringPiece device,
                                   const TypeIndex& type_index) {
    auto found = binary_op_fns.find({op, device, type_index});
    if (found == binary_op_fns.end()) return nullptr;
    return &found->second;
  }

  // Get a pointer to a global UnaryVariantOpRegistry object.
  static UnaryVariantOpRegistry* Global() {
    return UnaryVariantOpRegistryGlobal();
  }

  // Get a pointer to a global persistent string storage object.
  // ISO/IEC C++ working draft N4296 clarifies that insertion into an
  // std::unordered_set does not invalidate memory locations of
  // *values* inside the set (though it may invalidate existing
  // iterators).  In other words, one may safely point a StringPiece to
  // a value in the set without that StringPiece being invalidated by
  // future insertions.
  static std::unordered_set<string>* PersistentStringStorage();
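
  // For illustration only, a hedged sketch of the guarantee described
  // above (the local names here are hypothetical, not part of this API):
  //
  //   std::unordered_set<string>* storage =
  //       UnaryVariantOpRegistry::PersistentStringStorage();
  //   StringPiece sp(*storage->insert("CPU").first);
  //   storage->insert("GPU");  // 'sp' still points to valid memory.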

 private:
  struct TypeIndexHash {
    std::size_t operator()(const TypeIndex& x) const { return x.hash_code(); }
  };

  gtl::FlatMap<StringPiece, VariantDecodeFn, StringPieceHasher> decode_fns;

  // Map std::pair<Direction, TypeIndex> to function.
  struct PairHash {
    template <typename Direction>
    std::size_t operator()(const std::pair<Direction, TypeIndex>& x) const {
      // The hash of an enum is just its value as a std::size_t.
      std::size_t ret = static_cast<std::size_t>(std::get<0>(x));
      ret = Hash64Combine(ret, std::get<1>(x).hash_code());
      return ret;
    }
  };

  gtl::FlatMap<std::pair<VariantDeviceCopyDirection, TypeIndex>,
               AsyncVariantDeviceCopyFn, PairHash>
      device_copy_fns;

  // Map (op enum, device string, TypeIndex) to function.
  //
  // A plain std::tuple breaks here by falling victim to "too perfect
  // forwarding"; see
  // https://stackoverflow.com/questions/44475317/variadic-template-issue
  // and references therein, so FuncTuple is used as the key instead.
  template <typename Op>
  struct FuncTuple {
    FuncTuple(const Op& op, const StringPiece& dev, const TypeIndex& type_index)
        : op_type_(op), device_(dev), type_index_(type_index) {}
    Op op_type_;
    StringPiece device_;
    TypeIndex type_index_;
  };
  // Friend declaration for operator==; needed for clang.
  template <typename Op>
  friend bool operator==(const FuncTuple<Op>& l, const FuncTuple<Op>& r);
  struct TupleHash {
    template <typename Op>
    std::size_t operator()(
        const std::tuple<Op, StringPiece, TypeIndex>& x) const {
      // The hash of an enum is just its value as a std::size_t.
      std::size_t ret = static_cast<std::size_t>(std::get<0>(x));
      ret = Hash64Combine(ret, sp_hasher_(std::get<1>(x)));
      ret = Hash64Combine(ret, std::get<2>(x).hash_code());
      return ret;
    }

    template <typename Op>
    std::size_t operator()(const FuncTuple<Op>& x) const {
      // The hash of an enum is just its value as a std::size_t.
      std::size_t ret = static_cast<std::size_t>(x.op_type_);
      ret = Hash64Combine(ret, sp_hasher_(x.device_));
      ret = Hash64Combine(ret, x.type_index_.hash_code());
      return ret;
    }
    StringPieceHasher sp_hasher_;
  };
  gtl::FlatMap<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn, TupleHash>
      unary_op_fns;
  gtl::FlatMap<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn, TupleHash>
      binary_op_fns;

  // Find or insert a string into a persistent string storage
  // container; return the StringPiece pointing to the permanent string
  // location.
  static StringPiece GetPersistentStringPiece(const std::string& str) {
    const auto string_storage = PersistentStringStorage();
    auto found = string_storage->find(str);
    if (found == string_storage->end()) {
      auto inserted = string_storage->insert(str);
      return StringPiece(*inserted.first);
    } else {
      return StringPiece(*found);
    }
  }
};
template <typename Op>
inline bool operator==(const UnaryVariantOpRegistry::FuncTuple<Op>& lhs,
                       const UnaryVariantOpRegistry::FuncTuple<Op>& rhs) {
  return (lhs.op_type_ == rhs.op_type_) && (lhs.device_ == rhs.device_) &&
         (lhs.type_index_ == rhs.type_index_);
}

// Decodes the Variant whose data_type has a registered decode
// function.  Returns false if the Variant does not have a registered
// decode function, or if the decoding function fails.
//
// REQUIRES:
//   variant is not null.
//
bool DecodeUnaryVariant(Variant* variant);
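
// For illustration only, a minimal sketch of decoding (assuming 'v' wraps
// an encoded VariantTensorDataProto for a registered type):
//
//   Variant v = ...;  // holds a VariantTensorDataProto
//   if (!DecodeUnaryVariant(&v)) {
//     // No decode function registered for v's type, or decoding failed.
//   }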

// Copies a variant between CPU<->GPU, or between GPU<->GPU.
// The variant 'from' must have a registered DeviceCopyFn for the
// given direction.  The returned variant 'to' will have
// (some subset of its) tensors stored on the destination device,
// according to the DeviceCopyFn registered for the given direction.
// Returns an Internal error if the Variant does not have a registered
// DeviceCopyFn for the given direction, or if initiating the copy
// fails.
//
// REQUIRES:
//   'to' is not null.
//
Status VariantDeviceCopy(
    const VariantDeviceCopyDirection direction, const Variant& from,
    Variant* to,
    const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy_fn);
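
// For illustration only, a hedged sketch of invoking VariantDeviceCopy with
// a copier that enqueues the transfer on a compute stream; 'EnqueueCopy' is
// a hypothetical helper, not part of this API:
//
//   Variant on_device;
//   TF_RETURN_IF_ERROR(VariantDeviceCopy(
//       HOST_TO_DEVICE, on_host, &on_device,
//       [](const Tensor& from, Tensor* to) { return EnqueueCopy(from, to); }));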

// Sets *v_out = unary_op(v).  The variant v must have a registered
// UnaryOp function for the given Device.  Returns an Internal error
// if v does not have a registered unary_op function for this device, or if
// UnaryOp fails.
//
// REQUIRES:
//   v_out is not null.
//
template <typename Device>
Status UnaryOpVariant(OpKernelContext* ctx, VariantUnaryOp op, const Variant& v,
                      Variant* v_out) {
  const std::string& device = DeviceName<Device>::value;
  UnaryVariantOpRegistry::VariantUnaryOpFn* unary_op_fn =
      UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeId());
  if (unary_op_fn == nullptr) {
    return errors::Internal("No variant unary_op function found for op ",
                            VariantUnaryOpToString(op),
                            " Variant type_name: ", v.TypeName(),
                            " for device type: ", device);
  }
  return (*unary_op_fn)(ctx, v, v_out);
}
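
// For illustration only, a minimal sketch of calling UnaryOpVariant from a
// CPU kernel; 'CPUDevice' is assumed to be the usual kernel-local alias for
// Eigen::ThreadPoolDevice, and the variant's type must have a zeros_like
// function registered via the macros below:
//
//   Variant out;
//   OP_REQUIRES_OK(ctx, UnaryOpVariant<CPUDevice>(
//                           ctx, ZEROS_LIKE_VARIANT_UNARY_OP, v, &out));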

// Sets *out = binary_op(a, b).  The variants a and b must be the same type
// and have a registered binary_op function for the given Device.  Returns an
// Internal error if a and b are not the same type_name, if a does not have
// a registered op function for this device, or if BinaryOp fails.
//
// REQUIRES:
//   out is not null.
//
template <typename Device>
Status BinaryOpVariants(OpKernelContext* ctx, VariantBinaryOp op,
                        const Variant& a, const Variant& b, Variant* out) {
  if (a.TypeId() != b.TypeId()) {
    return errors::Internal(
        "BinaryOpVariants: Variants a and b have different "
        "type ids.  Type names: '",
        a.TypeName(), "' vs. '", b.TypeName(), "'");
  }
  const std::string& device = DeviceName<Device>::value;
  UnaryVariantOpRegistry::VariantBinaryOpFn* binary_op_fn =
      UnaryVariantOpRegistry::Global()->GetBinaryOpFn(op, device, a.TypeId());
  if (binary_op_fn == nullptr) {
    return errors::Internal("No variant binary_op function found for op ",
                            VariantBinaryOpToString(op),
                            " Variant type_name: '", a.TypeName(),
                            "' for device type: ", device);
  }
  return (*binary_op_fn)(ctx, a, b, out);
}
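
// For illustration only, a minimal sketch of adding two variants of the same
// registered type on CPU (mirroring the UnaryOpVariant example above):
//
//   Variant sum;
//   OP_REQUIRES_OK(ctx, BinaryOpVariants<CPUDevice>(
//                           ctx, ADD_VARIANT_BINARY_OP, a, b, &sum));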

namespace variant_op_registry_fn_registration {

template <typename T>
class UnaryVariantDecodeRegistration {
 public:
  UnaryVariantDecodeRegistration(const std::string& type_name) {
    // The Variant is passed by pointer because it should be mutable:
    // the get<VariantTensorDataProto>() call below may Decode the
    // variant, which is a self-mutating behavior.  The variant is not
    // modified in any other way.
    UnaryVariantOpRegistry::Global()->RegisterDecodeFn(
        type_name, [type_name](Variant* v) -> bool {
          DCHECK_NE(v, nullptr);
          VariantTensorDataProto* t = v->get<VariantTensorDataProto>();
          if (t == nullptr) {
            return false;
          }
          Variant decoded = T();
          VariantTensorData data(std::move(*t));
          if (!decoded.Decode(std::move(data))) {
            return false;
          }
          std::swap(decoded, *v);
          return true;
        });
  }
};

template <typename T>
class UnaryVariantDeviceCopyRegistration {
 public:
  typedef std::function<Status(const T& t, T* t_out,
                               UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn)>
      LocalVariantDeviceCopyFn;
  UnaryVariantDeviceCopyRegistration(
      const VariantDeviceCopyDirection direction, const TypeIndex& type_index,
      const LocalVariantDeviceCopyFn& device_copy_fn) {
    const std::string type_index_name =
        port::MaybeAbiDemangle(type_index.name());
    UnaryVariantOpRegistry::Global()->RegisterDeviceCopyFn(
        direction, type_index,
        [type_index_name, device_copy_fn](
            const Variant& from, Variant* to,
            UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn
                device_copy_tensor_fn) -> Status {
          DCHECK_NE(to, nullptr);
          *to = T();
          if (from.get<T>() == nullptr) {
            return errors::Internal(
                "VariantDeviceCopyFn: Could not access object, type_index: ",
                type_index_name);
          }
          const T& t = *from.get<T>();
          T* t_out = to->get<T>();
          return device_copy_fn(t, t_out, device_copy_tensor_fn);
        });
  }
};

template <typename T>
class UnaryVariantUnaryOpRegistration {
  typedef std::function<Status(OpKernelContext* ctx, const T& t, T* t_out)>
      LocalVariantUnaryOpFn;

 public:
  UnaryVariantUnaryOpRegistration(VariantUnaryOp op, const std::string& device,
                                  const TypeIndex& type_index,
                                  const LocalVariantUnaryOpFn& unary_op_fn) {
    const std::string type_index_name =
        port::MaybeAbiDemangle(type_index.name());
    UnaryVariantOpRegistry::Global()->RegisterUnaryOpFn(
        op, device, type_index,
        [type_index_name, unary_op_fn](OpKernelContext* ctx, const Variant& v,
                                       Variant* v_out) -> Status {
          DCHECK_NE(v_out, nullptr);
          *v_out = T();
          if (v.get<T>() == nullptr) {
            return errors::Internal(
                "VariantUnaryOpFn: Could not access object, type_index: ",
                type_index_name);
          }
          const T& t = *v.get<T>();
          T* t_out = v_out->get<T>();
          return unary_op_fn(ctx, t, t_out);
        });
  }
};

template <typename T>
class UnaryVariantBinaryOpRegistration {
  typedef std::function<Status(OpKernelContext* ctx, const T& a, const T& b,
                               T* out)>
      LocalVariantBinaryOpFn;

 public:
  UnaryVariantBinaryOpRegistration(VariantBinaryOp op,
                                   const std::string& device,
                                   const TypeIndex& type_index,
                                   const LocalVariantBinaryOpFn& binary_op_fn) {
    const std::string type_index_name =
        port::MaybeAbiDemangle(type_index.name());
    UnaryVariantOpRegistry::Global()->RegisterBinaryOpFn(
        op, device, type_index,
        [type_index_name, binary_op_fn](OpKernelContext* ctx, const Variant& a,
                                        const Variant& b,
                                        Variant* out) -> Status {
          DCHECK_NE(out, nullptr);
          *out = T();
          if (a.get<T>() == nullptr) {
            return errors::Internal(
                "VariantBinaryOpFn: Could not access object 'a', type_index: ",
                type_index_name);
          }
          if (b.get<T>() == nullptr) {
            return errors::Internal(
                "VariantBinaryOpFn: Could not access object 'b', type_index: ",
                type_index_name);
          }
          const T& t_a = *a.get<T>();
          const T& t_b = *b.get<T>();
          T* t_out = out->get<T>();
          return binary_op_fn(ctx, t_a, t_b, t_out);
        });
  }
};

}  // namespace variant_op_registry_fn_registration

// Register a unary decode variant function for the given type.
#define REGISTER_UNARY_VARIANT_DECODE_FUNCTION(T, type_name) \
  REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ_HELPER(__COUNTER__, T, type_name)

#define REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ_HELPER(ctr, T, type_name) \
  REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ(ctr, T, type_name)

#define REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ(ctr, T, type_name) \
  static ::tensorflow::variant_op_registry_fn_registration::           \
      UnaryVariantDecodeRegistration<T>                                \
          register_unary_variant_op_decoder_fn_##ctr(type_name)
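
// For illustration only, a minimal sketch of registering a decode function
// for a hypothetical type "MyVariantType" that implements Encode/Decode
// (T must be default-constructible):
//
//   REGISTER_UNARY_VARIANT_DECODE_FUNCTION(MyVariantType, "MyVariantType");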

// ****** NOTE ******
// FOR INTERNAL USE ONLY.  IF YOU USE THIS WE MAY BREAK YOUR CODE.
// ****** NOTE ******
//
// Register a device copy variant function for the given copy
// direction and type; where direction is the enum
// VariantDeviceCopyDirection, and the device_copy_fn has signature:
//
//   Status device_copy_fn(
//     const T& t, T* t_out,
//     const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copier);
//
// And device_copy_fn calls copier 0 or more times.  For details on
// the behavior of the copier function, see the comments at the
// declaration of UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn.
//
// Note, the device_copy_fn may choose to keep some tensors
// on host, e.g. by assigning to->tensor = from.tensor (assuming
// from.tensor is already on host); or by setting
//   to->tensor = Tensor(cpu_allocator(), ...)
// and manually updating its values.
//
// If this is the case, the CopyFns for HOST_TO_DEVICE,
// DEVICE_TO_HOST, and DEVICE_TO_DEVICE must perform host-to-host
// copies in a consistent manner.  For example, one must always
// manually copy any "always on host" tensors in all directions instead of e.g.
//   - performing a host-to-host copy in one direction,
//   - using the provided copier function in the reverse direction.
// Doing the latter will cause program failures.
//
// ****** NOTE ******
// FOR INTERNAL USE ONLY.  IF YOU USE THIS WE MAY BREAK YOUR CODE.
// ****** NOTE ******
#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(T, direction,   \
                                                             device_copy_fn) \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER(          \
      __COUNTER__, T, direction, TypeIndex::Make<T>(), device_copy_fn)

#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \
    ctr, T, direction, type_index, device_copy_fn)                        \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ(              \
      ctr, T, direction, type_index, device_copy_fn)

#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ( \
    ctr, T, direction, type_index, device_copy_fn)                 \
  static variant_op_registry_fn_registration::                     \
      UnaryVariantDeviceCopyRegistration<T>                        \
          register_unary_variant_op_device_copy_fn_##ctr(          \
              direction, type_index, device_copy_fn)
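
// For illustration only (this is an internal-use macro), a hedged sketch
// registering a host-to-device copy for the hypothetical "MyVariantType"
// with a single tensor member 't':
//
//   static Status MyCopy(
//       const MyVariantType& from, MyVariantType* to,
//       const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copier) {
//     return copier(from.t, &to->t);
//   }
//   INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(
//       MyVariantType, VariantDeviceCopyDirection::HOST_TO_DEVICE, MyCopy);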

// Register a unary_op variant function with the signature:
//    Status UnaryOpFn(OpKernelContext* ctx, const T& t, T* t_out);
// to Variants having TypeIndex type_index, for device string device,
// for VariantUnaryOp enum op.
#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(op, device, T,     \
                                                 unary_op_function) \
  REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER(             \
      __COUNTER__, op, device, T, TypeIndex::Make<T>(), unary_op_function)

#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER(       \
    ctr, op, device, T, type_index, unary_op_function)              \
  REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(ctr, op, device, T, \
                                                type_index, unary_op_function)

#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(                     \
    ctr, op, device, T, type_index, unary_op_function)                     \
  static ::tensorflow::variant_op_registry_fn_registration::               \
      UnaryVariantUnaryOpRegistration<T>                                   \
          register_unary_variant_op_unary_fn_##ctr(op, device, type_index, \
                                                   unary_op_function)
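
// For illustration only, a hedged sketch registering a zeros_like function
// for the hypothetical "MyVariantType" on CPU ('MyZerosLike' is assumed,
// not part of this API):
//
//   static Status MyZerosLike(OpKernelContext* ctx, const MyVariantType& t,
//                             MyVariantType* t_out) { ... }
//   REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(
//       ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, MyVariantType, MyZerosLike);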

// Register a binary_op variant function with the signature:
//    Status BinaryOpFn(OpKernelContext* ctx, const T& a, const T& b, T* out);
// to Variants having TypeIndex type_index, for device string device,
// for VariantBinaryOp enum op.
#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(op, device, T,      \
                                                  binary_op_function) \
  REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER(              \
      __COUNTER__, op, device, T, TypeIndex::Make<T>(), binary_op_function)

#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \
    ctr, op, device, T, type_index, binary_op_function)        \
  REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ(              \
      ctr, op, device, T, type_index, binary_op_function)

#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ(                      \
    ctr, op, device, T, type_index, binary_op_function)                      \
  static ::tensorflow::variant_op_registry_fn_registration::                 \
      UnaryVariantBinaryOpRegistration<T>                                    \
          register_unary_variant_op_binary_fn_##ctr(op, device, type_index,  \
                                                    binary_op_function)
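
// For illustration only, a hedged sketch registering an add function for
// the hypothetical "MyVariantType" on CPU ('MyAdd' is assumed, not part of
// this API):
//
//   static Status MyAdd(OpKernelContext* ctx, const MyVariantType& a,
//                       const MyVariantType& b, MyVariantType* out) { ... }
//   REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(
//       ADD_VARIANT_BINARY_OP, DEVICE_CPU, MyVariantType, MyAdd);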

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_VARIANT_OP_REGISTRY_H_