/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op_requires.h"
#define EIGEN_USE_THREADS

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/quantize_and_dequantize_op.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {
namespace {

using CpuDevice = ::Eigen::ThreadPoolDevice;
using GpuDevice = ::Eigen::GpuDevice;
using ::tensorflow::errors::InvalidArgument;

}  // namespace

// Simulate quantization precision loss in a float tensor by:
// 1. Quantizing the tensor to fixed-point numbers, which should match the
//    target quantization method when it is used in inference.
// 2. Dequantizing it back to floating-point numbers for the following ops,
//    most likely matmul.
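//
// A rough sketch of the arithmetic (quantize_and_dequantize_op.h holds the
// authoritative implementation): with n = num_bits, the quantized range is
// [-2^(n-1), 2^(n-1) - 1] for signed input and [0, 2^n - 1] for unsigned
// input (narrow_range raises the signed minimum to -2^(n-1) + 1). A scale is
// chosen so that [input_min, input_max] maps onto that range, and each
// element x becomes roughly round(x * scale) / scale -- quantized, then
// immediately dequantized.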
template <typename Device, typename T>
class QuantizeAndDequantizeV2Op : public OpKernel {
 public:
  explicit QuantizeAndDequantizeV2Op(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("signed_input", &signed_input_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits_));
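    // The upper bound on num_bits below presumably keeps the 2^num_bits
    // quantized range representable in a signed 64-bit integer.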
    OP_REQUIRES(ctx, num_bits_ > 0 && num_bits_ < (signed_input_ ? 62 : 63),
                InvalidArgument("num_bits is out of range: ", num_bits_,
                                " with signed_input_ ", signed_input_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("range_given", &range_given_));

    string round_mode_string;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("round_mode", &round_mode_string));
    OP_REQUIRES(
        ctx,
        (round_mode_string == "HALF_UP" || round_mode_string == "HALF_TO_EVEN"),
        InvalidArgument("Round mode string must be "
                        "'HALF_UP' or "
                        "'HALF_TO_EVEN', is '" +
                        round_mode_string + "'"));
    if (round_mode_string == "HALF_UP") {
      round_mode_ = ROUND_HALF_UP;
    } else if (round_mode_string == "HALF_TO_EVEN") {
      round_mode_ = ROUND_HALF_TO_EVEN;
    }
    OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input = ctx->input(0);
    OP_REQUIRES(ctx, axis_ >= -1,
                InvalidArgument("Axis must be at least -1. Found ", axis_));
    OP_REQUIRES(ctx, (axis_ == -1 || axis_ < input.shape().dims()),
                InvalidArgument("Shape must be at least rank ", axis_ + 1,
                                " but is rank ", input.shape().dims()));
    const int depth = (axis_ == -1) ? 1 : input.dim_size(axis_);
    Tensor input_min_tensor;
    Tensor input_max_tensor;
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
    if (range_given_) {
      input_min_tensor = ctx->input(1);
      input_max_tensor = ctx->input(2);
      if (axis_ == -1) {
        auto min_val = input_min_tensor.scalar<T>()();
        auto max_val = input_max_tensor.scalar<T>()();
        OP_REQUIRES(ctx, min_val <= max_val,
                    InvalidArgument("Invalid range: input_min ", min_val,
                                    " > input_max ", max_val));
      } else {
        OP_REQUIRES(
            ctx, input_min_tensor.dim_size(0) == depth,
            InvalidArgument("input_min_tensor has incorrect size, was ",
                            input_min_tensor.dim_size(0), " expected ", depth,
                            " to match dim ", axis_, " of the input ",
                            input_min_tensor.shape()));
        OP_REQUIRES(
            ctx, input_max_tensor.dim_size(0) == depth,
            InvalidArgument("input_max_tensor has incorrect size, was ",
                            input_max_tensor.dim_size(0), " expected ", depth,
                            " to match dim ", axis_, " of the input ",
                            input_max_tensor.shape()));
      }
    } else {
      auto range_shape = (axis_ == -1) ? TensorShape({}) : TensorShape({depth});
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             range_shape, &input_min_tensor));
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             range_shape, &input_max_tensor));
    }

    if (axis_ == -1) {
      functor::QuantizeAndDequantizeOneScaleFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(), input.flat<T>(), signed_input_, num_bits_,
        range_given_, &input_min_tensor, &input_max_tensor, round_mode_,
        narrow_range_, output->flat<T>());
    } else {
      functor::QuantizeAndDequantizePerChannelFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(),
        input.template flat_inner_outer_dims<T, 3>(axis_ - 1), signed_input_,
        num_bits_, range_given_, &input_min_tensor, &input_max_tensor,
        round_mode_, narrow_range_,
        output->template flat_inner_outer_dims<T, 3>(axis_ - 1));
    }
  }

 private:
  int num_bits_;
  int axis_;
  QuantizerRoundMode round_mode_;
  bool signed_input_;
  bool range_given_;
  bool narrow_range_;
};

// Implementation of QuantizeAndDequantizeV4GradientOp.
// When back-propagating the error through a quantized layer, the following
// paper gives evidence that clipped-ReLU is better than non-clipped:
// "Deep Learning with Low Precision by Half-wave Gaussian Quantization"
// http://zpascal.net/cvpr2017/Cai_Deep_Learning_With_CVPR_2017_paper.pdf
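//
// Informally (a sketch, not an authoritative spec): the incoming gradient is
// passed through where the input lies inside [input_min, input_max] and
// zeroed outside it, while the input_min / input_max gradients accumulate the
// gradient of the elements clipped below and above the range, respectively.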
template <typename Device, typename T>
class QuantizeAndDequantizeV4GradientOp : public OpKernel {
 public:
  explicit QuantizeAndDequantizeV4GradientOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& gradient = ctx->input(0);
    const Tensor& input = ctx->input(1);
    Tensor* input_backprop = nullptr;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(0, input.shape(), &input_backprop));
    OP_REQUIRES(ctx, axis_ >= -1,
                InvalidArgument("Axis must be at least -1. Found ", axis_));
    OP_REQUIRES(ctx, (axis_ == -1 || axis_ < input.shape().dims()),
                InvalidArgument(
                    "Axis should be -1 or a nonnegative value less than ",
                    input.shape().dims(), ", but the given axis value was ",
                    axis_));

    OP_REQUIRES(ctx, input.IsSameSize(gradient),
                InvalidArgument("gradient and input must be the same size"));
    const int depth = (axis_ == -1) ? 1 : input.dim_size(axis_);
    const Tensor& input_min_tensor = ctx->input(2);
    OP_REQUIRES(ctx,
                input_min_tensor.dims() == 0 || input_min_tensor.dims() == 1,
                InvalidArgument(
                    "Input min tensor must have dimension 0 or 1. Received ",
                    input_min_tensor.dims(), "."));
    const Tensor& input_max_tensor = ctx->input(3);
    OP_REQUIRES(ctx,
                input_max_tensor.dims() == 0 || input_max_tensor.dims() == 1,
                InvalidArgument(
                    "Input max tensor must have dimension 0 or 1. Received ",
                    input_max_tensor.dims(), "."));
    if (axis_ != -1) {
      OP_REQUIRES(ctx, input_min_tensor.dim_size(0) == depth,
                  InvalidArgument("min has incorrect size, expected ", depth,
                                  " was ", input_min_tensor.dim_size(0)));
      OP_REQUIRES(ctx, input_max_tensor.dim_size(0) == depth,
                  InvalidArgument("max has incorrect size, expected ", depth,
                                  " was ", input_max_tensor.dim_size(0)));
    }

    TensorShape min_max_shape(input_min_tensor.shape());
    Tensor* input_min_backprop;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(1, min_max_shape, &input_min_backprop));

    Tensor* input_max_backprop;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(2, min_max_shape, &input_max_backprop));

    if (axis_ == -1) {
      OP_REQUIRES(
          ctx, TensorShapeUtils::IsScalar(input_min_tensor.shape()),
          InvalidArgument("input_min must be a scalar if axis is unspecified"));
      OP_REQUIRES(
          ctx, TensorShapeUtils::IsScalar(input_max_tensor.shape()),
          InvalidArgument("input_max must be a scalar if axis is unspecified"));
      functor::QuantizeAndDequantizeOneScaleGradientFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(), gradient.template flat<T>(),
        input.template flat<T>(), input_min_tensor.scalar<T>(),
        input_max_tensor.scalar<T>(), input_backprop->template flat<T>(),
        input_min_backprop->template scalar<T>(),
        input_max_backprop->template scalar<T>());
    } else {
      functor::QuantizeAndDequantizePerChannelGradientFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(),
        gradient.template flat_inner_outer_dims<T, 3>(axis_ - 1),
        input.template flat_inner_outer_dims<T, 3>(axis_ - 1),
        &input_min_tensor, &input_max_tensor,
        input_backprop->template flat_inner_outer_dims<T, 3>(axis_ - 1),
        input_min_backprop->template flat<T>(),
        input_max_backprop->template flat<T>());
    }
  }

 private:
  int axis_;
};

// Simulate quantization precision loss in a float tensor by:
// 1. Quantizing the tensor to fixed-point numbers, which should match the
//    target quantization method when it is used in inference.
// 2. Dequantizing it back to floating-point numbers for the following ops,
//    most likely matmul.
// Almost identical to QuantizeAndDequantizeV2Op, except that num_bits is a
// tensor.
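// Because num_bits is a tensor input rather than an attr, it is read and
// validated on every Compute call, so the bit width may change from step to
// step.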
template <typename Device, typename T>
class QuantizeAndDequantizeV3Op : public OpKernel {
 public:
  explicit QuantizeAndDequantizeV3Op(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("signed_input", &signed_input_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("range_given", &range_given_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input = ctx->input(0);
    // Also reject axis_ values below -1, which would otherwise reach
    // dim_size() with an invalid index.
    OP_REQUIRES(ctx, axis_ >= -1,
                InvalidArgument("Axis must be at least -1. Found ", axis_));
    OP_REQUIRES(ctx, axis_ < input.dims(),
                InvalidArgument(
                    "Axis requested is larger than input dimensions. Axis: ",
                    axis_, " Input Dimensions: ", input.dims()));
    const int depth = (axis_ == -1) ? 1 : input.dim_size(axis_);
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));

    // Get num_bits and validate.
    const Tensor num_bits_tensor = ctx->input(3);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(num_bits_tensor.shape()),
                InvalidArgument("Invalid shape. The `num_bits` tensor should "
                                "be a scalar. Got dimensions: ",
                                num_bits_tensor.dims()));

    const int num_bits_val = num_bits_tensor.scalar<int32>()();
    OP_REQUIRES(ctx,
                num_bits_val > 0 && num_bits_val < (signed_input_ ? 62 : 63),
                InvalidArgument("num_bits is out of range: ", num_bits_val,
                                " with `signed_input_` ", signed_input_));

    Tensor input_min_tensor;
    Tensor input_max_tensor;
    if (range_given_) {
      input_min_tensor = ctx->input(1);
      input_max_tensor = ctx->input(2);
      if (axis_ == -1) {
        const auto min_val = input_min_tensor.scalar<T>()();
        const auto max_val = input_max_tensor.scalar<T>()();
        OP_REQUIRES(ctx, min_val <= max_val,
                    InvalidArgument("Invalid range: input_min ", min_val,
                                    " > input_max ", max_val));
      } else {
        OP_REQUIRES(
            ctx, input_min_tensor.dim_size(0) == depth,
            InvalidArgument("input_min_tensor has incorrect size, was ",
                            input_min_tensor.dim_size(0), " expected ", depth,
                            " to match dim ", axis_, " of the input ",
                            input_min_tensor.shape()));
        OP_REQUIRES(
            ctx, input_max_tensor.dim_size(0) == depth,
            InvalidArgument("input_max_tensor has incorrect size, was ",
                            input_max_tensor.dim_size(0), " expected ", depth,
                            " to match dim ", axis_, " of the input ",
                            input_max_tensor.shape()));
      }
    } else {
      auto range_shape = (axis_ == -1) ? TensorShape({}) : TensorShape({depth});
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             range_shape, &input_min_tensor));
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             range_shape, &input_max_tensor));
    }

    if (axis_ == -1) {
      functor::QuantizeAndDequantizeOneScaleFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(), input.flat<T>(), signed_input_,
        num_bits_val, range_given_, &input_min_tensor, &input_max_tensor,
        ROUND_HALF_TO_EVEN, narrow_range_, output->flat<T>());
    } else {
      functor::QuantizeAndDequantizePerChannelFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(),
        input.template flat_inner_outer_dims<T, 3>(axis_ - 1), signed_input_,
        num_bits_val, range_given_, &input_min_tensor, &input_max_tensor,
        ROUND_HALF_TO_EVEN, narrow_range_,
        output->template flat_inner_outer_dims<T, 3>(axis_ - 1));
    }
  }

 private:
  int axis_;
  bool signed_input_;
  bool range_given_;
  bool narrow_range_;
};

// DEPRECATED: Use QuantizeAndDequantizeV2Op.
template <typename Device, typename T>
class QuantizeAndDequantizeOp : public OpKernel {
 public:
  explicit QuantizeAndDequantizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("signed_input", &signed_input_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits_));
    OP_REQUIRES(ctx, num_bits_ > 0 && num_bits_ < (signed_input_ ? 62 : 63),
                InvalidArgument("num_bits is out of range: ", num_bits_,
                                " with signed_input_ ", signed_input_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("range_given", &range_given_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("input_min", &input_min_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("input_max", &input_max_));
    if (range_given_) {
      OP_REQUIRES(ctx, input_min_ <= input_max_,
                  InvalidArgument("Invalid range: input_min ", input_min_,
                                  " > input_max ", input_max_));
    }
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input = ctx->input(0);

    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));

    // One global scale.
    Tensor input_min_tensor(DataTypeToEnum<T>::value, TensorShape());
    Tensor input_max_tensor(DataTypeToEnum<T>::value, TensorShape());
    // Initialize the tensors with the values in the Attrs.
    input_min_tensor.template scalar<T>()() = static_cast<T>(input_min_);
    input_max_tensor.template scalar<T>()() = static_cast<T>(input_max_);

    functor::QuantizeAndDequantizeOneScaleFunctor<Device, T> functor;
    functor(ctx->eigen_device<Device>(), input.flat<T>(), signed_input_,
            num_bits_, range_given_, &input_min_tensor, &input_max_tensor,
            ROUND_HALF_TO_EVEN, /*narrow_range=*/false, output->flat<T>());
  }

 private:
  bool signed_input_;
  int num_bits_;
  bool range_given_;
  float input_min_;
  float input_max_;
};

// Specializations for CpuDevice.

namespace functor {
template <typename T>
struct QuantizeAndDequantizeOneScaleFunctor<CpuDevice, T> {
  void operator()(const CpuDevice& d, typename TTypes<T>::ConstVec input,
                  const bool signed_input, const int num_bits,
                  const bool range_given, Tensor* input_min_tensor,
                  Tensor* input_max_tensor, QuantizerRoundMode round_mode,
                  bool narrow_range, typename TTypes<T>::Vec out) {
    QuantizeAndDequantizeOneScaleImpl<CpuDevice, T>::Compute(
        d, input, signed_input, num_bits, range_given, input_min_tensor,
        input_max_tensor, round_mode, narrow_range, out);
  }
};

template <typename T>
struct QuantizeAndDequantizePerChannelFunctor<CpuDevice, T> {
  void operator()(const CpuDevice& d, typename TTypes<T, 3>::ConstTensor input,
                  bool signed_input, int num_bits, bool range_given,
                  Tensor* input_min_tensor, Tensor* input_max_tensor,
                  QuantizerRoundMode round_mode, bool narrow_range,
                  typename TTypes<T, 3>::Tensor out) {
    QuantizeAndDequantizePerChannelImpl<CpuDevice, T>::Compute(
        d, input, signed_input, num_bits, range_given, input_min_tensor,
        input_max_tensor, round_mode, narrow_range, out);
  }
};

template <typename T>
struct QuantizeAndDequantizeOneScaleGradientFunctor<CpuDevice, T> {
  void operator()(const CpuDevice& d, typename TTypes<T>::ConstFlat gradient,
                  typename TTypes<T>::ConstFlat input,
                  typename TTypes<T>::ConstScalar input_min_tensor,
                  typename TTypes<T>::ConstScalar input_max_tensor,
                  typename TTypes<T>::Flat input_backprop,
                  typename TTypes<T>::Scalar input_min_backprop,
                  typename TTypes<T>::Scalar input_max_backprop) {
    QuantizeAndDequantizeOneScaleGradientImpl<CpuDevice, T>::Compute(
        d, gradient, input, input_min_tensor, input_max_tensor, input_backprop,
        input_min_backprop, input_max_backprop);
  }
};

template <typename T>
struct QuantizeAndDequantizePerChannelGradientFunctor<CpuDevice, T> {
  void operator()(const CpuDevice& d,
                  typename TTypes<T, 3>::ConstTensor gradient,
                  typename TTypes<T, 3>::ConstTensor input,
                  const Tensor* input_min_tensor,
                  const Tensor* input_max_tensor,
                  typename TTypes<T, 3>::Tensor input_backprop,
                  typename TTypes<T>::Flat input_min_backprop,
                  typename TTypes<T>::Flat input_max_backprop) {
    QuantizeAndDequantizePerChannelGradientImpl<CpuDevice, T>::Compute(
        d, gradient, input, input_min_tensor, input_max_tensor, input_backprop,
        input_min_backprop, input_max_backprop);
  }
};

template struct functor::QuantizeAndDequantizeOneScaleGradientFunctor<CpuDevice,
                                                                      float>;
template struct functor::QuantizeAndDequantizePerChannelGradientFunctor<
    CpuDevice, double>;

}  // namespace functor

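// Note that QuantizeAndDequantizeV4 is registered against the same kernel
// class as QuantizeAndDequantizeV2; within this file the only difference is
// that V4 has an associated gradient op, QuantizeAndDequantizeV4Grad.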
#define REGISTER_CPU_KERNEL(T)                                                 \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV2")                      \
                              .Device(DEVICE_CPU)                              \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV2Op<CpuDevice, T>);            \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV3")                      \
                              .Device(DEVICE_CPU)                              \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV3Op<CpuDevice, T>);            \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV4")                      \
                              .Device(DEVICE_CPU)                              \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV2Op<CpuDevice, T>);            \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV4Grad")                  \
                              .Device(DEVICE_CPU)                              \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV4GradientOp<CpuDevice, T>);    \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("QuantizeAndDequantize").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      QuantizeAndDequantizeOp<CpuDevice, T>);
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
#define REGISTER_GPU_KERNEL(T)                                                 \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV2")                      \
                              .Device(DEVICE_GPU)                              \
                              .HostMemory("input_min")                         \
                              .HostMemory("input_max")                         \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV2Op<GpuDevice, T>);            \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV3")                      \
                              .Device(DEVICE_GPU)                              \
                              .HostMemory("input_min")                         \
                              .HostMemory("input_max")                         \
                              .HostMemory("num_bits")                          \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV3Op<GpuDevice, T>);            \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV4")                      \
                              .Device(DEVICE_GPU)                              \
                              .HostMemory("input_min")                         \
                              .HostMemory("input_max")                         \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV2Op<GpuDevice, T>);            \
  REGISTER_KERNEL_BUILDER(Name("QuantizeAndDequantizeV4Grad")                  \
                              .Device(DEVICE_GPU)                              \
                              .HostMemory("input_min")                         \
                              .HostMemory("input_max")                         \
                              .TypeConstraint<T>("T"),                         \
                          QuantizeAndDequantizeV4GradientOp<GpuDevice, T>);    \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("QuantizeAndDequantize").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      QuantizeAndDequantizeOp<GpuDevice, T>);
TF_CALL_float(REGISTER_GPU_KERNEL);
TF_CALL_double(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
}  // namespace tensorflow