/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/cast_op.h"

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

#include "tensorflow/core/kernels/cast_op_impl.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

#define CURRY_TYPES2(FN, arg0)   \
  FN(arg0, bool);                \
  FN(arg0, uint8);               \
  FN(arg0, uint16);              \
  FN(arg0, uint32);              \
  FN(arg0, uint64);              \
  FN(arg0, int8);                \
  FN(arg0, int16);               \
  FN(arg0, int32);               \
  FN(arg0, int64_t);             \
  FN(arg0, Eigen::half);         \
  FN(arg0, float);               \
  FN(arg0, double);              \
  FN(arg0, std::complex<float>); \
  FN(arg0, std::complex<double>)
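
// CURRY_TYPES2(FN, arg0) invokes FN(arg0, T) once for every destination type T
// listed above. For illustration, CURRY_TYPES2(REGISTER_CAST_GPU, bool) used
// below expands to:
//   REGISTER_CAST_GPU(bool, bool);
//   REGISTER_CAST_GPU(bool, uint8);
//   ...
//   REGISTER_CAST_GPU(bool, std::complex<double>);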

CastOpBase::CastOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("SrcT", &external_src_dtype_));

  OP_REQUIRES_OK(ctx, ctx->GetAttr("DstT", &external_dst_dtype_));

  OP_REQUIRES_OK(ctx, ctx->GetAttr("Truncate", &use_truncation_));

  // Quantized data types use the same underlying format as their non-quantized
  // versions, so we use the non-quantized implementation for casting.
  if (external_dst_dtype_ == DT_QUINT8) {
    dst_dtype_ = DT_UINT8;
  } else if (external_dst_dtype_ == DT_QINT8) {
    dst_dtype_ = DT_INT8;
  } else if (external_dst_dtype_ == DT_QINT32) {
    dst_dtype_ = DT_INT32;
  } else if (external_dst_dtype_ == DT_QINT16) {
    dst_dtype_ = DT_INT16;
  } else if (external_dst_dtype_ == DT_QUINT16) {
    dst_dtype_ = DT_UINT16;
  } else {
    dst_dtype_ = external_dst_dtype_;
  }

  if (external_src_dtype_ == DT_QUINT8) {
    src_dtype_ = DT_UINT8;
  } else if (external_src_dtype_ == DT_QINT8) {
    src_dtype_ = DT_INT8;
  } else if (external_src_dtype_ == DT_QINT32) {
    src_dtype_ = DT_INT32;
  } else if (external_src_dtype_ == DT_QINT16) {
    src_dtype_ = DT_INT16;
  } else if (external_src_dtype_ == DT_QUINT16) {
    src_dtype_ = DT_UINT16;
  } else {
    src_dtype_ = external_src_dtype_;
  }
}

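// Example: a Cast op with SrcT = DT_QINT8 and DstT = DT_QINT32 is carried out
// by the plain int8 -> int32 kernel: Compute() below bitcasts the input to
// DT_INT8 before calling work_ and relabels the output as DT_QINT32 afterwards.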
void CastOpBase::Compute(OpKernelContext* ctx) {
  const Tensor& inp = ctx->input(0);
  if (work_ == nullptr) {
    ctx->set_output(0, inp);
  } else if (external_src_dtype_ != src_dtype_ ||
             external_dst_dtype_ != dst_dtype_) {
    Tensor in;
    // If the type is a quantized type, we need to do a bitcast, since
    // src_dtype_ differs from external_src_dtype_.
    OP_REQUIRES_OK(ctx, in.BitcastFrom(inp, src_dtype_, inp.shape()));
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in.shape(), &out));
    out->set_dtype(dst_dtype_);
    work_(ctx, in, out, use_truncation_);
    out->set_dtype(external_dst_dtype_);
  } else {
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out));
    work_(ctx, inp, out, use_truncation_);
  }
}

Status CastOpBase::Unimplemented() {
  return errors::Unimplemented("Cast ", DataTypeString(external_src_dtype_),
                               " to ", DataTypeString(external_dst_dtype_),
                               " is not supported");
}

CpuCastOp::CpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
  OP_REQUIRES_OK(ctx, Prepare());
}

Status CpuCastOp::Prepare() {
  if (external_src_dtype_ == external_dst_dtype_) {
    work_ = nullptr;  // Identity
    return OkStatus();
  }
  if (src_dtype_ == DT_BOOL) {
    work_ = GetCpuCastFromBool(dst_dtype_);
  } else if (src_dtype_ == DT_UINT8) {
    work_ = GetCpuCastFromUint8(dst_dtype_);
  } else if (src_dtype_ == DT_UINT16) {
    work_ = GetCpuCastFromUint16(dst_dtype_);
  } else if (src_dtype_ == DT_UINT32) {
    work_ = GetCpuCastFromUint32(dst_dtype_);
  } else if (src_dtype_ == DT_UINT64) {
    work_ = GetCpuCastFromUint64(dst_dtype_);
  } else if (src_dtype_ == DT_INT8) {
    work_ = GetCpuCastFromInt8(dst_dtype_);
  } else if (src_dtype_ == DT_INT16) {
    work_ = GetCpuCastFromInt16(dst_dtype_);
  } else if (src_dtype_ == DT_INT32) {
    work_ = GetCpuCastFromInt32(dst_dtype_);
  } else if (src_dtype_ == DT_INT64) {
    work_ = GetCpuCastFromInt64(dst_dtype_);
  } else if (src_dtype_ == DT_HALF) {
    work_ = GetCpuCastFromHalf(dst_dtype_);
  } else if (src_dtype_ == DT_FLOAT) {
    work_ = GetCpuCastFromFloat(dst_dtype_);
  } else if (src_dtype_ == DT_DOUBLE) {
    work_ = GetCpuCastFromDouble(dst_dtype_);
  } else if (src_dtype_ == DT_COMPLEX64) {
    work_ = GetCpuCastFromComplex64(dst_dtype_);
  } else if (src_dtype_ == DT_COMPLEX128) {
    work_ = GetCpuCastFromComplex128(dst_dtype_);
  } else if (src_dtype_ == DT_BFLOAT16) {
    work_ = GetCpuCastFromBfloat(dst_dtype_);
  }

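  // Each GetCpuCastFrom* helper (declared in cast_op_impl.h) is expected to
  // return nullptr when dst_dtype_ is not a supported destination, in which
  // case the check at the end of this function reports Unimplemented.
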
  // TODO(sesse): If CPU casting to or from Eigen::half ever becomes a
  // bottleneck, we could probably implement specialized support for
  // vectorized versions (not least based on F16C for Haswell or newer).

  return work_ == nullptr ? Unimplemented() : OkStatus();
}

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
class GpuCastOp : public CastOpBase {
 public:
  explicit GpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
    OP_REQUIRES_OK(ctx, Prepare());
  }

 private:
  Status Prepare() {
    if (external_src_dtype_ == external_dst_dtype_) {
      work_ = nullptr;  // Identity
      return OkStatus();
    }
    if (src_dtype_ == DT_BOOL) {
      work_ = GetGpuCastFromBool(dst_dtype_);
    } else if (src_dtype_ == DT_UINT8) {
      work_ = GetGpuCastFromUint8(dst_dtype_);
    } else if (src_dtype_ == DT_UINT16) {
      work_ = GetGpuCastFromUint16(dst_dtype_);
    } else if (src_dtype_ == DT_UINT32) {
      work_ = GetGpuCastFromUint32(dst_dtype_);
    } else if (src_dtype_ == DT_UINT64) {
      work_ = GetGpuCastFromUint64(dst_dtype_);
    } else if (src_dtype_ == DT_INT8) {
      work_ = GetGpuCastFromInt8(dst_dtype_);
    } else if (src_dtype_ == DT_INT16) {
      work_ = GetGpuCastFromInt16(dst_dtype_);
    } else if (src_dtype_ == DT_INT32) {
      work_ = GetGpuCastFromInt32(dst_dtype_);
    } else if (src_dtype_ == DT_INT64) {
      work_ = GetGpuCastFromInt64(dst_dtype_);
    } else if (src_dtype_ == DT_HALF) {
      work_ = GetGpuCastFromHalf(dst_dtype_);
    } else if (src_dtype_ == DT_FLOAT) {
      work_ = GetGpuCastFromFloat(dst_dtype_);
    } else if (src_dtype_ == DT_DOUBLE) {
      work_ = GetGpuCastFromDouble(dst_dtype_);
    } else if (src_dtype_ == DT_COMPLEX64) {
      work_ = GetGpuCastFromComplex64(dst_dtype_);
    } else if (src_dtype_ == DT_COMPLEX128) {
      work_ = GetGpuCastFromComplex128(dst_dtype_);
    } else if (src_dtype_ == DT_BFLOAT16) {
      work_ = GetGpuCastFromBfloat(dst_dtype_);
    }

    return work_ == nullptr ? Unimplemented() : OkStatus();
  }
};
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef CAST_CASE

REGISTER_KERNEL_BUILDER(Name("Cast").Device(DEVICE_CPU), CpuCastOp);

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
#define REGISTER_CAST_GPU(srctype, dsttype)                    \
  REGISTER_KERNEL_BUILDER(Name("Cast")                         \
                              .TypeConstraint<srctype>("SrcT") \
                              .TypeConstraint<dsttype>("DstT") \
                              .Device(DEVICE_GPU),             \
                          GpuCastOp)
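
// For illustration, REGISTER_CAST_GPU(float, bfloat16) below registers
// GpuCastOp as the "Cast" kernel constrained to SrcT = float and
// DstT = bfloat16 on GPU.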

#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)
CURRY_TYPES2(REGISTER_CAST_GPU, bool);
CURRY_TYPES2(REGISTER_CAST_GPU, int8);
CURRY_TYPES2(REGISTER_CAST_GPU, int16);
CURRY_TYPES2(REGISTER_CAST_GPU, int32);
CURRY_TYPES2(REGISTER_CAST_GPU, int64);
CURRY_TYPES2(REGISTER_CAST_GPU, uint8);
CURRY_TYPES2(REGISTER_CAST_GPU, uint16);
CURRY_TYPES2(REGISTER_CAST_GPU, uint32);
CURRY_TYPES2(REGISTER_CAST_GPU, uint64);
CURRY_TYPES2(REGISTER_CAST_GPU, Eigen::half);
CURRY_TYPES2(REGISTER_CAST_GPU, float);
CURRY_TYPES2(REGISTER_CAST_GPU, double);
CURRY_TYPES2(REGISTER_CAST_GPU, std::complex<float>);
CURRY_TYPES2(REGISTER_CAST_GPU, std::complex<double>);
#endif  // !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED)

REGISTER_CAST_GPU(float, bfloat16);
REGISTER_CAST_GPU(bfloat16, float);

#undef REGISTER_CAST_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef CURRY_TYPES2

// HostCast differs from Cast in that its input and output are in host memory.
REGISTER_KERNEL_BUILDER(Name("_HostCast").Device(DEVICE_CPU), CpuCastOp);
REGISTER_KERNEL_BUILDER(
    Name("_HostCast").Device(DEVICE_DEFAULT).HostMemory("x").HostMemory("y"),
    CpuCastOp);
}  // end namespace tensorflow