1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_ 17 #define TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_ 18 19 #define EIGEN_USE_THREADS 20 21 #include "tensorflow/core/framework/op_kernel.h" 22 #include "tensorflow/core/kernels/cast_op.h" 23 24 namespace tensorflow { 25 26 namespace functor { 27 28 CAST_FUNCTORS(Eigen::ThreadPoolDevice); 29 30 31 } // namespace functor 32 33 #define CURRY_TYPES3_NO_HALF(FN, arg0, arg1) \ 34 FN(arg0, arg1, bool); \ 35 FN(arg0, arg1, uint8); \ 36 FN(arg0, arg1, uint16); \ 37 FN(arg0, arg1, uint32); \ 38 FN(arg0, arg1, uint64); \ 39 FN(arg0, arg1, int8); \ 40 FN(arg0, arg1, int16); \ 41 FN(arg0, arg1, int32); \ 42 FN(arg0, arg1, int64_t); \ 43 FN(arg0, arg1, float); \ 44 FN(arg0, arg1, double); \ 45 FN(arg0, arg1, std::complex<float>); \ 46 FN(arg0, arg1, std::complex<double>) 47 48 #define CURRY_TYPES3_NO_BF16(FN, arg0, arg1) \ 49 CURRY_TYPES3_NO_HALF(FN, arg0, arg1) \ 50 FN(arg0, arg1, Eigen::half); 51 52 #define CURRY_TYPES3(FN, arg0, arg1) \ 53 CURRY_TYPES3_NO_BF16(FN, arg0, arg1) \ 54 FN(arg0, arg1, bfloat16); 55 56 #define CAST_CASE(DEVICE, IN, OUT) \ 57 if (DataTypeToEnum<OUT>::value == dst_dtype) { \ 58 return [](OpKernelContext* ctx, const Tensor& inp, Tensor* out, \ 59 bool truncate) { \ 60 functor::CastFunctor<DEVICE, OUT, IN> func; \ 61 func(ctx->eigen_device<DEVICE>(), out->flat<OUT>(), inp.flat<IN>(), \ 62 truncate); \ 63 }; \ 64 } 65 66 #if defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) 67 68 // The subset of types which are currently not supported yet with the MLIR 69 // generated kernels. 70 #define CURRY_SUBSET_TYPES3(FN, arg0, arg1) \ 71 FN(arg0, arg1, std::complex<float>); \ 72 FN(arg0, arg1, std::complex<double>) 73 74 #endif 75 76 // The functions below are implemented in the cast_op_impl_*.cc files. 77 CastFunctorType GetCpuCastFromBool(DataType dst_dtype); 78 79 CastFunctorType GetCpuCastFromUint8(DataType dst_dtype); 80 81 CastFunctorType GetCpuCastFromUint16(DataType dst_dtype); 82 83 CastFunctorType GetCpuCastFromInt8(DataType dst_dtype); 84 85 CastFunctorType GetCpuCastFromUint32(DataType dst_dtype); 86 87 CastFunctorType GetCpuCastFromUint64(DataType dst_dtype); 88 89 CastFunctorType GetCpuCastFromInt8(DataType dst_dtype); 90 91 CastFunctorType GetCpuCastFromInt16(DataType dst_dtype); 92 93 CastFunctorType GetCpuCastFromInt32(DataType dst_dtype); 94 95 CastFunctorType GetCpuCastFromInt64(DataType dst_dtype); 96 97 CastFunctorType GetCpuCastFromHalf(DataType dst_dtype); 98 99 CastFunctorType GetCpuCastFromFloat(DataType dst_dtype); 100 101 CastFunctorType GetCpuCastFromDouble(DataType dst_dtype); 102 103 CastFunctorType GetCpuCastFromComplex64(DataType dst_dtype); 104 105 CastFunctorType GetCpuCastFromComplex128(DataType dst_dtype); 106 107 CastFunctorType GetCpuCastFromBfloat(DataType dst_dtype); 108 109 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ 110 (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) 111 // Same, for GPU. 112 CastFunctorType GetGpuCastFromBool(DataType dst_dtype); 113 114 CastFunctorType GetGpuCastFromUint8(DataType dst_dtype); 115 116 CastFunctorType GetGpuCastFromUint16(DataType dst_dtype); 117 118 CastFunctorType GetGpuCastFromInt8(DataType dst_dtype); 119 120 CastFunctorType GetGpuCastFromUint32(DataType dst_dtype); 121 122 CastFunctorType GetGpuCastFromUint64(DataType dst_dtype); 123 124 CastFunctorType GetGpuCastFromInt16(DataType dst_dtype); 125 126 CastFunctorType GetGpuCastFromInt32(DataType dst_dtype); 127 128 CastFunctorType GetGpuCastFromInt64(DataType dst_dtype); 129 130 CastFunctorType GetGpuCastFromHalf(DataType dst_dtype); 131 132 CastFunctorType GetGpuCastFromFloat(DataType dst_dtype); 133 134 CastFunctorType GetGpuCastFromDouble(DataType dst_dtype); 135 136 CastFunctorType GetGpuCastFromComplex64(DataType dst_dtype); 137 138 CastFunctorType GetGpuCastFromComplex128(DataType dst_dtype); 139 140 CastFunctorType GetGpuCastFromBfloat(DataType dst_dtype); 141 142 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM 143 144 145 } // namespace tensorflow 146 147 #endif // TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_ 148