xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/cast_op_impl.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_
17 #define TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_
18 
19 #define EIGEN_USE_THREADS
20 
21 #include "tensorflow/core/framework/op_kernel.h"
22 #include "tensorflow/core/kernels/cast_op.h"
23 
24 namespace tensorflow {
25 
26 namespace functor {
27 
28 CAST_FUNCTORS(Eigen::ThreadPoolDevice);
29 
30 
31 }  // namespace functor
32 
// Type-list "curry" helpers: each expands FN(arg0, arg1, T) once per type T in
// its list, with the expansions separated by ';'.
//
//   CURRY_TYPES3_NO_HALF  - the 13 base types below (no Eigen::half, no
//                           bfloat16). Its last expansion has NO trailing ';',
//                           so the call site supplies it.
//   CURRY_TYPES3_NO_BF16  - CURRY_TYPES3_NO_HALF plus Eigen::half.
//   CURRY_TYPES3          - CURRY_TYPES3_NO_BF16 plus bfloat16.
//
// The ';' after CURRY_TYPES3_NO_HALF(...) inside CURRY_TYPES3_NO_BF16 is
// required. Without it the std::complex<double> and Eigen::half expansions run
// together, which only compiles when FN expands to a brace-terminated statement
// (such as CAST_CASE's `if (...) { ... }`).
//
// CURRY_TYPES3_NO_BF16 and CURRY_TYPES3 end with ';' themselves. That is kept
// for compatibility with existing call sites, so a ';' written after them just
// adds an empty statement/declaration.
#define CURRY_TYPES3_NO_HALF(FN, arg0, arg1) \
  FN(arg0, arg1, bool);                      \
  FN(arg0, arg1, uint8);                     \
  FN(arg0, arg1, uint16);                    \
  FN(arg0, arg1, uint32);                    \
  FN(arg0, arg1, uint64);                    \
  FN(arg0, arg1, int8);                      \
  FN(arg0, arg1, int16);                     \
  FN(arg0, arg1, int32);                     \
  FN(arg0, arg1, int64_t);                   \
  FN(arg0, arg1, float);                     \
  FN(arg0, arg1, double);                    \
  FN(arg0, arg1, std::complex<float>);       \
  FN(arg0, arg1, std::complex<double>)

#define CURRY_TYPES3_NO_BF16(FN, arg0, arg1) \
  CURRY_TYPES3_NO_HALF(FN, arg0, arg1);      \
  FN(arg0, arg1, Eigen::half);

#define CURRY_TYPES3(FN, arg0, arg1)   \
  CURRY_TYPES3_NO_BF16(FN, arg0, arg1) \
  FN(arg0, arg1, bfloat16);
55 
// Emits one case of a cast-dispatch chain.
//
// If `dst_dtype` equals DataTypeToEnum<OUT>::value, the enclosing function
// returns a captureless lambda. `dst_dtype` is not a macro parameter; it must
// be visible where the macro expands. When invoked, the lambda applies
// functor::CastFunctor<DEVICE, OUT, IN> to the flat views of the input and
// output tensors on the context's DEVICE and forwards `truncate`.
//
// The expansion is a brace-terminated `if` with no `else`, so cases can be
// placed back to back. Code after the last case handles the no-match path.
#define CAST_CASE(DEVICE, IN, OUT)                                       \
  if (dst_dtype == DataTypeToEnum<OUT>::value) {                         \
    return [](OpKernelContext* ctx, const Tensor& input, Tensor* output, \
              bool truncate) {                                           \
      functor::CastFunctor<DEVICE, OUT, IN> cast_functor;                \
      const auto& device = ctx->eigen_device<DEVICE>();                  \
      cast_functor(device, output->flat<OUT>(), input.flat<IN>(),        \
                   truncate);                                            \
    };                                                                   \
  }
65 
#if defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)

// Types not yet supported by the MLIR-generated kernels. Expands
// FN(arg0, arg1, T) for each such T, separated by ';'. The last expansion has
// no trailing ';', so the call site supplies it.
#define CURRY_SUBSET_TYPES3(FN, arg0, arg1) \
  FN(arg0, arg1, std::complex<float>);      \
  FN(arg0, arg1, std::complex<double>)

#endif  // MLIR_GENERATED_GPU_KERNELS_ENABLED
75 
76 // The functions below are implemented in the cast_op_impl_*.cc files.
77 CastFunctorType GetCpuCastFromBool(DataType dst_dtype);
78 
79 CastFunctorType GetCpuCastFromUint8(DataType dst_dtype);
80 
81 CastFunctorType GetCpuCastFromUint16(DataType dst_dtype);
82 
83 CastFunctorType GetCpuCastFromInt8(DataType dst_dtype);
84 
85 CastFunctorType GetCpuCastFromUint32(DataType dst_dtype);
86 
87 CastFunctorType GetCpuCastFromUint64(DataType dst_dtype);
88 
89 CastFunctorType GetCpuCastFromInt8(DataType dst_dtype);
90 
91 CastFunctorType GetCpuCastFromInt16(DataType dst_dtype);
92 
93 CastFunctorType GetCpuCastFromInt32(DataType dst_dtype);
94 
95 CastFunctorType GetCpuCastFromInt64(DataType dst_dtype);
96 
97 CastFunctorType GetCpuCastFromHalf(DataType dst_dtype);
98 
99 CastFunctorType GetCpuCastFromFloat(DataType dst_dtype);
100 
101 CastFunctorType GetCpuCastFromDouble(DataType dst_dtype);
102 
103 CastFunctorType GetCpuCastFromComplex64(DataType dst_dtype);
104 
105 CastFunctorType GetCpuCastFromComplex128(DataType dst_dtype);
106 
107 CastFunctorType GetCpuCastFromBfloat(DataType dst_dtype);
108 
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
// GPU counterparts of the CPU getters: one per source type, each yielding the
// CastFunctorType that casts from that source type to `dst_dtype`. They are
// only declared when building with CUDA or ROCm.
CastFunctorType GetGpuCastFromBool(DataType dst_dtype);

// Unsigned integer sources.
CastFunctorType GetGpuCastFromUint8(DataType dst_dtype);
CastFunctorType GetGpuCastFromUint16(DataType dst_dtype);
CastFunctorType GetGpuCastFromUint32(DataType dst_dtype);
CastFunctorType GetGpuCastFromUint64(DataType dst_dtype);

// Signed integer sources.
CastFunctorType GetGpuCastFromInt8(DataType dst_dtype);
CastFunctorType GetGpuCastFromInt16(DataType dst_dtype);
CastFunctorType GetGpuCastFromInt32(DataType dst_dtype);
CastFunctorType GetGpuCastFromInt64(DataType dst_dtype);

// Floating-point and complex sources.
CastFunctorType GetGpuCastFromHalf(DataType dst_dtype);
CastFunctorType GetGpuCastFromFloat(DataType dst_dtype);
CastFunctorType GetGpuCastFromDouble(DataType dst_dtype);
CastFunctorType GetGpuCastFromComplex64(DataType dst_dtype);
CastFunctorType GetGpuCastFromComplex128(DataType dst_dtype);
CastFunctorType GetGpuCastFromBfloat(DataType dst_dtype);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
143 
144 
145 }  // namespace tensorflow
146 
147 #endif  // TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_
148