xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/quantized_pooling_ops.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/nn_ops.cc.
17 
18 #define EIGEN_USE_THREADS
19 
20 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
21 #include "tensorflow/core/framework/numeric_op.h"
22 #include "tensorflow/core/framework/op_kernel.h"
23 #include "tensorflow/core/framework/op_requires.h"
24 #include "tensorflow/core/framework/tensor.h"
25 #include "tensorflow/core/framework/tensor_shape.h"
26 #include "tensorflow/core/kernels/ops_util.h"
27 #include "tensorflow/core/kernels/pooling_ops_common.h"
28 #include "tensorflow/core/lib/core/errors.h"
29 #include "tensorflow/core/platform/errors.h"
30 #include "tensorflow/core/platform/logging.h"
31 #include "tensorflow/core/util/padding.h"
32 #include "tensorflow/core/util/tensor_format.h"
33 
namespace tensorflow {

// All kernels in this file are CPU-only and run on Eigen's thread pool.
typedef Eigen::ThreadPoolDevice CPUDevice;
37 
38 template <typename Device, typename T>
39 class QuantizedAvgPoolingOp : public OpKernel {
40  public:
QuantizedAvgPoolingOp(OpKernelConstruction * context)41   explicit QuantizedAvgPoolingOp(OpKernelConstruction* context)
42       : OpKernel(context) {
43     OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
44     OP_REQUIRES(context, ksize_.size() == 4,
45                 errors::InvalidArgument("Sliding window ksize field must "
46                                         "specify 4 dimensions"));
47     OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
48     OP_REQUIRES(context, stride_.size() == 4,
49                 errors::InvalidArgument("Sliding window strides field must "
50                                         "specify 4 dimensions"));
51     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
52     OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
53                 errors::Unimplemented(
54                     "Pooling is not yet supported on the batch dimension."));
55   }
56 
Compute(OpKernelContext * context)57   void Compute(OpKernelContext* context) override {
58     const Tensor& tensor_in = context->input(0);
59     PoolParameters params{context,
60                           ksize_,
61                           stride_,
62                           padding_,
63                           /*explicit_paddings=*/{},
64                           FORMAT_NHWC,
65                           tensor_in.shape()};
66     if (!context->status().ok()) {
67       return;
68     }
69 
70     const Tensor& min_input_tensor = context->input(1);
71     const Tensor& max_input_tensor = context->input(2);
72     OP_REQUIRES(context, TensorShapeUtils::IsScalar(min_input_tensor.shape()),
73                 errors::InvalidArgument(
74                     "min_input shape must be rank 0 but is rank ",
75                     min_input_tensor.dims(),
76                     ", received shape: ", min_input_tensor.shape()));
77     OP_REQUIRES(context, TensorShapeUtils::IsScalar(max_input_tensor.shape()),
78                 errors::InvalidArgument(
79                     "max_input shape must be rank 0 but is rank ",
80                     max_input_tensor.dims(),
81                     ", received shape: ", max_input_tensor.shape()));
82     const float min_input = context->input(1).scalar<float>()();
83     const float max_input = context->input(2).scalar<float>()();
84 
85     OP_REQUIRES(context, params.depth_window == 1,
86                 errors::Unimplemented("Non-spatial pooling is not "
87                                       "yet supported. Volunteers? :)"));
88 
89     OP_REQUIRES(context, tensor_in.dims() == 4,
90                 errors::InvalidArgument("tensor_in must be 4-dimensional"));
91 
92     Tensor* output = nullptr;
93     OP_REQUIRES_OK(context, context->allocate_output(
94                                 0, params.forward_output_shape(), &output));
95     const int32_t highest = static_cast<int32>(Eigen::NumTraits<T>::highest());
96     const int32_t lowest = static_cast<int32>(Eigen::NumTraits<T>::lowest());
97 
98     // TODO(vrv): Switch this to the Eigen::Tensor version of
99     // SpatialAvgPooling once that version is running quickly.
100     Tensor int32_output(DT_INT32, params.forward_output_shape());
101     // Cast input to int32 tensor and call SpatialAvgPool.
102     Tensor int32_input(DT_INT32, tensor_in.shape());
103     int32_input.flat<int32>() = tensor_in.flat<T>().template cast<int32>();
104     SpatialAvgPool<Device, int32>(context, &int32_output, int32_input, params,
105                                   padding_);
106 
107     // Clamp the int32 output back into quantized space.
108     output->flat<T>() = int32_output.flat<int32>()
109                             .cwiseMax(lowest)
110                             .cwiseMin(highest)
111                             .template cast<T>();
112 
113     Tensor* output_min = nullptr;
114     OP_REQUIRES_OK(context, context->allocate_output(1, {}, &output_min));
115     output_min->flat<float>()(0) = min_input;
116     Tensor* output_max = nullptr;
117     OP_REQUIRES_OK(context, context->allocate_output(2, {}, &output_max));
118     output_max->flat<float>()(0) = max_input;
119   }
120 
121  private:
122   std::vector<int32> ksize_;
123   std::vector<int32> stride_;
124   Padding padding_;
125 };
126 
127 template <typename Device, typename T>
128 class QuantizedMaxPoolingOp : public MaxPoolingOp<Device, T> {
129  public:
QuantizedMaxPoolingOp(OpKernelConstruction * context)130   explicit QuantizedMaxPoolingOp(OpKernelConstruction* context)
131       : MaxPoolingOp<Device, T>(context) {}
132 
Compute(OpKernelContext * context)133   void Compute(OpKernelContext* context) override {
134     const Tensor& min_input_tensor = context->input(1);
135     const Tensor& max_input_tensor = context->input(2);
136     OP_REQUIRES(context, TensorShapeUtils::IsScalar(min_input_tensor.shape()),
137                 errors::InvalidArgument(
138                     "min_input shape must be rank 0 but is rank ",
139                     min_input_tensor.dims(),
140                     ", received shape: ", min_input_tensor.shape()));
141     OP_REQUIRES(context, TensorShapeUtils::IsScalar(max_input_tensor.shape()),
142                 errors::InvalidArgument(
143                     "max_input shape must be rank 0 but is rank ",
144                     max_input_tensor.dims(),
145                     ", received shape: ", max_input_tensor.shape()));
146     const float min_input = context->input(1).scalar<float>()();
147     const float max_input = context->input(2).scalar<float>()();
148     MaxPoolingOp<Device, T>::Compute(context);
149     Tensor* output_min = nullptr;
150     OP_REQUIRES_OK(context, context->allocate_output(1, {}, &output_min));
151     output_min->flat<float>()(0) = min_input;
152     Tensor* output_max = nullptr;
153     OP_REQUIRES_OK(context, context->allocate_output(2, {}, &output_max));
154     output_max->flat<float>()(0) = max_input;
155   }
156 };
157 
// CPU registrations. The reference kernels support quint8 only.
REGISTER_KERNEL_BUILDER(
    Name("QuantizedAvgPool").Device(DEVICE_CPU).TypeConstraint<quint8>("T"),
    QuantizedAvgPoolingOp<CPUDevice, quint8>);

REGISTER_KERNEL_BUILDER(
    Name("QuantizedMaxPool").Device(DEVICE_CPU).TypeConstraint<quint8>("T"),
    QuantizedMaxPoolingOp<CPUDevice, quint8>);

// In oneDNN (INTEL_MKL) builds, additionally register qint8 variants of the
// same template kernels.
#ifdef INTEL_MKL
REGISTER_KERNEL_BUILDER(
    Name("QuantizedAvgPool").Device(DEVICE_CPU).TypeConstraint<qint8>("T"),
    QuantizedAvgPoolingOp<CPUDevice, qint8>);

REGISTER_KERNEL_BUILDER(
    Name("QuantizedMaxPool").Device(DEVICE_CPU).TypeConstraint<qint8>("T"),
    QuantizedMaxPoolingOp<CPUDevice, qint8>);
#endif
175 
176 }  // namespace tensorflow
177