/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_CAST_OP_H_
#define TENSORFLOW_CORE_KERNELS_CAST_OP_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bfloat16.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/types.h"

// Note that the GPU cast functor templates need to be instantiated, unlike the
// CPU ones, and hence their specializations are different from those for CPUs.
#ifdef SPECIALIZE_FOR_GPUS
#define SPECIALIZE_CAST(DEVICE, OUT_TYPE, IN_OUT)                   \
  template <typename Device>                                        \
  struct CastFunctor<Device, OUT_TYPE, IN_OUT> {                    \
    void operator()(const Device& d,                                \
                    typename TTypes<OUT_TYPE>::Flat out_tensor,     \
                    typename TTypes<IN_OUT>::ConstFlat in_tensor,   \
                    bool truncate = false) {                        \
      if (truncate) {                                               \
        out_tensor.device(d) =                                      \
            in_tensor.unaryExpr(LSBZeroSetter<IN_OUT, OUT_TYPE>())  \
                .template cast<OUT_TYPE>();                         \
      } else {                                                      \
        out_tensor.device(d) = in_tensor.template cast<OUT_TYPE>(); \
      }                                                             \
    }                                                               \
  };                                                                \
  template struct CastFunctor<DEVICE, OUT_TYPE, IN_OUT>;
#else
#define SPECIALIZE_CAST(DEVICE, OUT_TYPE, IN_OUT)                   \
  template <>                                                       \
  struct CastFunctor<DEVICE, OUT_TYPE, IN_OUT> {                    \
    void operator()(const DEVICE& d,                                \
                    typename TTypes<OUT_TYPE>::Flat out_tensor,     \
                    typename TTypes<IN_OUT>::ConstFlat in_tensor,   \
                    bool truncate = false) {                        \
      if (truncate) {                                               \
        out_tensor.device(d) =                                      \
            in_tensor.unaryExpr(LSBZeroSetter<IN_OUT, OUT_TYPE>())  \
                .template cast<OUT_TYPE>();                         \
      } else {                                                      \
        out_tensor.device(d) = in_tensor.template cast<OUT_TYPE>(); \
      }                                                             \
    }                                                               \
  };
#endif

#define CAST_FUNCTORS(devname)                                        \
  SPECIALIZE_CAST(devname, float, double)                             \
  SPECIALIZE_CAST(devname, float, std::complex<double>)               \
  SPECIALIZE_CAST(devname, std::complex<float>, std::complex<double>) \
  SPECIALIZE_CAST(devname, std::complex<float>, double)               \
  SPECIALIZE_CAST(devname, Eigen::half, double)                       \
  SPECIALIZE_CAST(devname, Eigen::half, float)                        \
  SPECIALIZE_CAST(devname, Eigen::half, std::complex<double>)         \
  SPECIALIZE_CAST(devname, Eigen::half, std::complex<float>)          \
  SPECIALIZE_CAST(devname, bfloat16, float)                           \
  template <typename OUT_TYPE, typename IN_OUT>                       \
  struct CastFunctor<devname, OUT_TYPE, IN_OUT> {                     \
    void operator()(const devname& d,                                 \
                    typename TTypes<OUT_TYPE>::Flat out_tensor,       \
                    typename TTypes<IN_OUT>::ConstFlat in_tensor,     \
                    bool truncate = false) {                          \
      out_tensor.device(d) = in_tensor.template cast<OUT_TYPE>();     \
    }                                                                 \
  };
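
// Example (an illustrative sketch, not part of this header): a translation
// unit that provides the CPU specializations would typically expand this
// macro for a concrete Eigen device inside the functor namespace, e.g.
//
//   namespace tensorflow {
//   namespace functor {
//   CAST_FUNCTORS(Eigen::ThreadPoolDevice);
//   }  // namespace functor
//   }  // namespace tensorflow
//
// GPU translation units would define SPECIALIZE_FOR_GPUS before including
// this header so that the truncating specializations above are explicitly
// instantiated.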

#if defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)
// If MLIR kernels are enabled, most of the specialized casts are not needed.
// We still need SPECIALIZE_CAST(devname, Eigen::half, float) because it is
// used in depthwise_conv_grad_op.cc, and SPECIALIZE_CAST(devname, float,
// double) because it is used in resize_bilinear_op.cc. The bfloat16/float
// specialization is kept as well.
#define CAST_FUNCTORS_SUBSET(devname)                                 \
  SPECIALIZE_CAST(devname, float, double)                             \
  SPECIALIZE_CAST(devname, Eigen::half, float)                        \
  SPECIALIZE_CAST(devname, bfloat16, float)                           \
  template <typename OUT_TYPE, typename IN_OUT>                       \
  struct CastFunctor<devname, OUT_TYPE, IN_OUT> {                     \
    void operator()(const devname& d,                                 \
                    typename TTypes<OUT_TYPE>::Flat out_tensor,       \
                    typename TTypes<IN_OUT>::ConstFlat in_tensor,     \
                    bool truncate = false) {                          \
      out_tensor.device(d) = in_tensor.template cast<OUT_TYPE>();     \
    }                                                                 \
  };
#endif

namespace tensorflow {

typedef std::function<void(OpKernelContext*, const Tensor&, Tensor*,
                           bool trunc)>
    CastFunctorType;
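
// A CastFunctorType value is what the Cast kernels store in work_. A minimal
// sketch of one entry (assuming a CPU device and a double -> float cast; the
// real entries are generated elsewhere) could look like:
//
//   CastFunctorType work = [](OpKernelContext* ctx, const Tensor& inp,
//                             Tensor* out, bool truncate) {
//     functor::CastFunctor<Eigen::ThreadPoolDevice, float, double> func;
//     func(ctx->eigen_device<Eigen::ThreadPoolDevice>(), out->flat<float>(),
//          inp.flat<double>(), truncate);
//   };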

// Common base class of Cast kernels
class CastOpBase : public OpKernel {
 public:
  explicit CastOpBase(OpKernelConstruction* ctx);

  void Compute(OpKernelContext* ctx) override;

 protected:
  DataType src_dtype_;
  DataType dst_dtype_;
  DataType external_src_dtype_;
  DataType external_dst_dtype_;
  bool use_truncation_;
  CastFunctorType work_ = nullptr;
  Status Unimplemented();

  TF_DISALLOW_COPY_AND_ASSIGN(CastOpBase);
};

// CPU implementation of Cast
class CpuCastOp : public CastOpBase {
 public:
  explicit CpuCastOp(OpKernelConstruction* ctx);

 private:
  Status Prepare();
};

namespace functor {

template <typename I>
constexpr int MantissaWidth() {
  return std::numeric_limits<I>::digits;
}

template <>
constexpr int MantissaWidth<Eigen::half>() {
  // Remember, there's 1 hidden bit
  return 10 + 1;
}

template <>
constexpr int MantissaWidth<bfloat16>() {
  // Remember, there's 1 hidden bit
  return 7 + 1;
}
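
// For reference, the widths this returns (hidden bit included, matching
// std::numeric_limits<T>::digits for the IEEE types):
//   double -> 53, float -> 24, Eigen::half -> 11, bfloat16 -> 8.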

template <typename Device, typename Tout, typename Tin>
void Cast(const Device& d, typename TTypes<Tout>::Flat o,
          typename TTypes<Tin>::ConstFlat i) {
  o.device(d) = i.template cast<Tout>();
}

template <typename Device, typename Tout, typename Tin>
struct CastFunctor {
  void operator()(const Device& d, typename TTypes<Tout>::Flat o,
                  typename TTypes<Tin>::ConstFlat i, bool truncate = false);
};

// Only enable LSBZeroSetterHelper for 64 and 32 bit input data types.
// Specialize for others if needed in future.
template <typename I>
typename std::enable_if<sizeof(I) == 8, void>::type EIGEN_DEVICE_FUNC
    EIGEN_STRONG_INLINE static LSBZeroSetterHelper(I& t, int n) {
  // Only zero the bits for non-NaNs.
  // For NaNs, let the non-truncation version handle it.
  if (!std::isnan(t)) {
    uint64_t* p = reinterpret_cast<uint64_t*>(&t);
    *p &= (0xFFFFFFFFFFFFFFFF << n);
  }
}

template <typename I>
typename std::enable_if<sizeof(I) == 4, void>::type EIGEN_DEVICE_FUNC
    EIGEN_STRONG_INLINE static LSBZeroSetterHelper(I& t, int n) {
  // Only zero the bits for non-NaNs.
  // For NaNs, let the non-truncation version handle it.
  if (!std::isnan(t)) {
    uint32_t* p = reinterpret_cast<uint32_t*>(&t);
    *p &= (0xFFFFFFFF << n);
  }
}
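
// Worked example of the masking (float input, so the sizeof(I) == 4 overload
// applies): for a float -> bfloat16 truncating cast, n = 24 - 8 = 16 and the
// mask is 0xFFFFFFFF << 16 == 0xFFFF0000, so the low 16 bits of the float's
// bit pattern are cleared before the conversion.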

// Set n least significant bits to 0
template <typename I, typename O>
struct LSBZeroSetter {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const I operator()(const I& a) const {
    constexpr int bits = MantissaWidth<I>() - MantissaWidth<O>();
    static_assert(
        bits > 0,
        "The output type must have fewer mantissa bits than the input type\n");
    I t = a;
    LSBZeroSetterHelper(t, bits);
    return t;
  }
};
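
// A minimal usage sketch (mirroring the truncate branch of SPECIALIZE_CAST
// above; `in`, `out`, and `d` are hypothetical flat tensors and an Eigen
// device, not names defined in this header):
//
//   out.device(d) = in.unaryExpr(LSBZeroSetter<float, bfloat16>())
//                       .cast<bfloat16>();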

template <typename I, typename O>
struct LSBZeroSetter<std::complex<I>, std::complex<O>> {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex<I> operator()(
      const std::complex<I>& a) const {
    constexpr int bits = MantissaWidth<I>() - MantissaWidth<O>();
    static_assert(
        bits > 0,
        "The output type must have fewer mantissa bits than the input type\n");
    I re = std::real(a);
    I img = std::imag(a);
    LSBZeroSetterHelper(re, bits);
    LSBZeroSetterHelper(img, bits);
    std::complex<I> toReturn(re, img);
    return toReturn;
  }
};

template <typename I, typename O>
struct LSBZeroSetter<std::complex<I>, O> {
  // Sets the lowest (MantissaWidth<I>() - MantissaWidth<O>()) bits of both
  // the real and imaginary components to 0.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex<I> operator()(
      const std::complex<I>& a) const {
    constexpr int bits = MantissaWidth<I>() - MantissaWidth<O>();
    static_assert(
        bits > 0,
        "The output type must have fewer mantissa bits than the input type\n");
    I re = std::real(a);
    I img = std::imag(a);
    LSBZeroSetterHelper(re, bits);
    LSBZeroSetterHelper(img, bits);
    std::complex<I> toReturn(re, img);
    return toReturn;
  }
};

}  // end namespace functor
}  // end namespace tensorflow

namespace Eigen {
namespace internal {

// Eigen can't convert to/from complex numbers, because it is limited to cases
// that can be static_casted. But numpy is able to cast to/from complex, which
// we want to replicate. So we add specializations for complex here.
template <typename From, typename To>
struct scalar_cast_op<std::complex<From>, To> {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE To
  operator()(const std::complex<From>& a) const {
    // Replicate numpy behavior of returning just the real part
    return static_cast<To>(a.real());
  }
};

template <typename From, typename To>
struct scalar_cast_op<From, std::complex<To>> {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<To> operator()(
      const From& a) const {
    // Replicate numpy behavior of setting the imaginary part to 0
    return std::complex<To>(static_cast<To>(a), To(0));
  }
};

template <typename From, typename To>
struct scalar_cast_op<std::complex<From>, std::complex<To>> {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex<To> operator()(
      const std::complex<From>& a) const {
    return std::complex<To>(static_cast<To>(a.real()),
                            static_cast<To>(a.imag()));
  }
};
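
// Illustrative behavior of these specializations (matching the numpy
// conventions noted above; the calls below are hypothetical, not code in
// this header):
//
//   scalar_cast_op<std::complex<float>, float>()({3.0f, 4.0f});  // -> 3.0f
//   scalar_cast_op<float, std::complex<double>>()(2.5f);         // -> (2.5, 0)
//   scalar_cast_op<std::complex<double>, std::complex<float>>()(
//       {1.0, -2.0});                                         // -> (1.0f, -2.0f)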

template <typename From, typename To>
struct functor_traits_complex_impl {
  enum { Cost = NumTraits<To>::AddCost, PacketAccess = false };
};

template <typename From, typename To>
struct functor_traits<scalar_cast_op<std::complex<From>, To>>
    : functor_traits_complex_impl<std::complex<From>, To> {};
template <typename From, typename To>
struct functor_traits<scalar_cast_op<From, std::complex<To>>>
    : functor_traits_complex_impl<From, std::complex<To>> {};
// Needed to avoid ambiguous partial specialization
template <typename From, typename To>
struct functor_traits<scalar_cast_op<std::complex<From>, std::complex<To>>>
    : functor_traits_complex_impl<std::complex<From>, std::complex<To>> {};

}  // namespace internal
}  // namespace Eigen

#endif  // TENSORFLOW_CORE_KERNELS_CAST_OP_H_