/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements quantized eight-bit versions of the convolution operations.

#include <algorithm>
#include <vector>

#include "tensorflow/core/platform/errors.h"

#define EIGEN_USE_THREADS

#define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
#include "public/gemmlowp.h"
#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/conv_ops.h"
#include "tensorflow/core/kernels/meta_support.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/kernels/reference_gemm.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/padding.h"

namespace tensorflow {

// This functor implements the convolution operation in as simple a form as
// possible. It won't give great performance, but it is very useful for
// stepping through and instrumenting for debugging, creating minimal benchmarks
// to prototype with, and sharing with teams that want to run this outside of
// our environment.
// With that in mind, I've avoided using anything except pretty standard C++
// types. This is especially noticeable in the data access through raw array
// indexing. It's deliberate in this case though, since it makes the underlying
// memory order very explicit, which is important for both inspecting memory
// contents during debugging and for specifying what we expect to others.
// The memory layout of the data is, from biggest stride to smallest:
// input_data = [input_batches, input_height, input_width, input_depth]
// filter_data = [filter_height, filter_width, input_depth, filter_count]
// output_data = [input_batches, output_height, output_width, filter_count]
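// As a concrete illustration of that layout, the element input_data[b, y, x, d]
// lives at the flat offset
//   (((b * input_height) + y) * input_width + x) * input_depth + d,
// which is exactly the indexing arithmetic used in the loops below.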
template <class T1, class T2, class T3>
class ReferenceConvFunctor {
 public:
  void operator()(OpKernelContext* context, const T1* input_data,
                  int input_batches, int input_height, int input_width,
                  int input_depth, int input_offset, const T2* filter_data,
                  int filter_height, int filter_width, int filter_count,
                  int filter_offset, int stride, Padding padding,
                  T3* output_data, int output_height, int output_width,
                  int output_shift, int output_offset, int output_mult) {
    // Set up some constants we need for the output down-shifting and
    // saturation.
    const int32_t highest = static_cast<int32>(Eigen::NumTraits<T3>::highest());
    const int32_t lowest = static_cast<int32>(Eigen::NumTraits<T3>::lowest());

    // When we're converting the 32 bit accumulator to a lower bit depth, we
    // need to add on 0.5 in fixed-point terms to make the operation round half
    // up towards positive infinity, rather than a floor.
    // We also need to watch out for the case when there's no down shift,
    // because a left shift by a negative number gives undefined results.
    const int32_t rounding = (output_shift < 1) ? 0 : (1 << (output_shift - 1));
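    // For example, with output_shift == 3 the accumulator is divided by 8 and
    // rounding == 4 (that is, 0.5 in 3-bit fixed point), so an accumulator of
    // 20 becomes (20 + 4) >> 3 == 3 rather than the truncated 20 >> 3 == 2.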

    // The two different padding modes we support can be a bit confusing. SAME
    // means we're trying to produce an output image that's the same size as the
    // input. It's complicated by stride, which shrinks the output image by a
    // factor, but it means we end up sampling from outside the borders of the
    // input. These out-of-bounds values are read as zeroes. VALID means only
    // produce output values where the filters can read all their values from
    // within the input image. It effectively removes the margins of the output
    // image compared to the one produced by SAME. Stride complicates this
    // definition though, because it can result in the right and bottom filter
    // patches sampling from outside the borders if it's greater than 1.
    // Most of the logic for sorting this all out is done before this function,
    // when we calculate the output size, but the positioning of the origin of
    // the filters is different between the two modes, since SAME positions the
    // first filter off the edge of the input.
    int filter_left_offset;
    int filter_top_offset;
    if (padding == VALID) {
      filter_left_offset =
          ((output_width - 1) * stride + filter_width - input_width + 1) / 2;
      filter_top_offset =
          ((output_height - 1) * stride + filter_height - input_height + 1) / 2;
    } else {
      filter_left_offset =
          ((output_width - 1) * stride + filter_width - input_width) / 2;
      filter_top_offset =
          ((output_height - 1) * stride + filter_height - input_height) / 2;
    }
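    // Worked example (illustrative numbers): for SAME padding with
    // input_width == 5, filter_width == 3, and stride == 1, the output width
    // is also 5, so filter_left_offset == ((5 - 1) * 1 + 3 - 5) / 2 == 1,
    // meaning the first patch starts one pixel off the left edge. For VALID
    // padding with the same shapes the output width is 3 and the offset works
    // out to zero, so every patch stays inside the input.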

    // If we've got multiple images in our input, work through each of them.
    for (int batch = 0; batch < input_batches; ++batch) {
      // Walk through all the output image values, sliding the filter to
      // different positions in the input.
      for (int out_y = 0; out_y < output_height; ++out_y) {
        for (int out_x = 0; out_x < output_width; ++out_x) {
          // Each filter kernel produces one output channel.
          for (int out_channel = 0; out_channel < filter_count; ++out_channel) {
            // We're going to calculate a single output value, which means we
            // need to multiply a three dimensional kernel of weights against
            // the current location within the input image.
            /*
              *-------------------------------...
              |\ ^
              | \in_depth
              |  \ v
              |   *-------------------------------...
              |   |            ^
              |   |       in_y_origin
              |   |            v   \
              |   |<in_x_origin>*---*^
              |   |            \|   |filter_height
              .   |             *---*v
              .   |             <--->
                  .         filter_width
                  .
            */
            const int in_x_origin = (out_x * stride) - filter_left_offset;
            const int in_y_origin = (out_y * stride) - filter_top_offset;
            int32_t total = 0;
            for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
              for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
                for (int in_channel = 0; in_channel < input_depth;
                     ++in_channel) {
                  const int in_x = in_x_origin + filter_x;
                  const int in_y = in_y_origin + filter_y;
                  int32_t input_value;
                  // If the location is outside the bounds of the input image,
                  // use zero as a default value.
                  if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
                      (in_y < input_height)) {
                    const T1 input_source_value =
                        input_data[(batch * input_height * input_width *
                                    input_depth) +
                                   (in_y * input_width * input_depth) +
                                   (in_x * input_depth) + in_channel];
                    // We're promoting the T1 type to a higher bit depth here as
                    // we do the subtraction.
                    input_value =
                        static_cast<int32>(input_source_value) - input_offset;
                  } else {
                    input_value = 0;
                  }
                  const T2 filter_source_value =
                      filter_data[(filter_y * filter_width * input_depth *
                                   filter_count) +
                                  (filter_x * input_depth * filter_count) +
                                  (in_channel * filter_count) + out_channel];
                  // Another promotion to 32 bit, as above.
                  const int32_t filter_value =
                      static_cast<int32>(filter_source_value) - filter_offset;
                  total += (input_value * filter_value);
                }
              }
            }
            // Here we're applying scale factors to compress the 32 bit
            // accumulated total to a potentially lower bit depth.
            const int32_t output =
                ((((total + output_offset) * output_mult) + rounding) >>
                 output_shift);
            // We need to saturate the results against the largest and smallest
            // values that can be represented in this type.
            const int32_t top_clamped_output = std::min(output, highest);
            const int32_t clamped_output = std::max(top_clamped_output, lowest);
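            // Note that for the qint32 output type registered below, highest
            // and lowest span the full int32 range, so this clamp is a no-op;
            // it only has an effect if T3 is a narrower type such as quint8.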
            output_data[(batch * output_height * output_width * filter_count) +
                        (out_y * output_width * filter_count) +
                        (out_x * filter_count) + out_channel] = clamped_output;
          }
        }
      }
    }
  }
};

// We don't want to allocate a buffer to hold all the patches if the size is
// going to be extremely large, so break it into chunks if it's bigger than
// a limit. Each chunk will be processed serially, so we can refill the
// buffer for the next chunk and reuse it, keeping maximum memory size down.
// In this case, we've picked 1 megabyte as a reasonable limit, from
// experimentation.
const size_t kMaxChunkSize = (1 * 1024 * 1024);

// Implements convolution as a two-stage process, first packing the patches of
// the input image into columns (im2col) and then running GEMM to produce the
// final result.
template <class T1, class T2, class T3>
class Im2ColConvFunctor {
 public:
  void operator()(OpKernelContext* context, const T1* input_data,
                  int input_batches, int input_height, int input_width,
                  int input_depth, int input_offset, const T2* filter_data,
                  int filter_height, int filter_width, int filter_count,
                  int filter_offset, int stride, Padding padding,
                  T3* output_data, int output_height, int output_width,
                  int output_shift, int output_offset, int output_mult) {
    if (input_offset < 0) {
      // Only log the first few occurrences of this warning.
      static int warning_count = 0;
      if (warning_count < 10) {
        ++warning_count;
        LOG(WARNING)
            << "For kernel '" << context->op_kernel().name() << "' from input '"
            << context->op_kernel().requested_input(0)
            << "': Zero is not representable in the quantized range used by the"
            << " input. This means QuantizedConv2d has to fall back to a slow"
            << " implementation, since the border of zero values can't be"
            << " represented easily. You should try to construct graphs that"
            << " avoid this situation.";
      }
      ReferenceConvFunctor<T1, T2, T3> conv_functor;
      conv_functor(context, input_data, input_batches, input_height,
                   input_width, input_depth, input_offset, filter_data,
                   filter_height, filter_width, filter_count, filter_offset,
                   stride, padding, output_data, output_height, output_width,
                   output_shift, output_offset, output_mult);
      return;
    }

    OP_REQUIRES(
        context, output_width > 0,
        errors::InvalidArgument("output_width must be strictly positive"));
    OP_REQUIRES(
        context, output_height > 0,
        errors::InvalidArgument("output_height must be strictly positive"));
    int filter_left_offset;
    int filter_top_offset;
    if (padding == VALID) {
      filter_left_offset =
          ((output_width - 1) * stride + filter_width - input_width + 1) / 2;
      filter_top_offset =
          ((output_height - 1) * stride + filter_height - input_height + 1) / 2;
    } else {
      filter_left_offset =
          ((output_width - 1) * stride + filter_width - input_width) / 2;
      filter_top_offset =
          ((output_height - 1) * stride + filter_height - input_height) / 2;
    }

    // The im2col buffer has # of patches rows, and # of filter values cols.
    // It's laid out like this, in row major order in memory:
    //        < filter value count >
    //   ^   +---------------------+
    // patch |                     |
    // count |                     |
    //   v   +---------------------+
    // Each patch row contains a filter_width x filter_height patch of the
    // input, with the depth channel as the most contiguous in memory, followed
    // by the width, then the height. This is the standard memory order in the
    // image world if it helps to visualize it.
    const int filter_value_count = filter_width * filter_height * input_depth;
    OP_REQUIRES(context, filter_value_count > 0,
                errors::InvalidArgument(
                    "filter patch must contain at least one element"));
    const int64_t patches_per_chunk =
        kMaxChunkSize / (filter_value_count * sizeof(T1));
    const int64_t chunk_value_count =
        (kMaxChunkSize + (sizeof(T1) - 1)) / sizeof(T1);
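    // To make the chunking concrete (illustrative numbers): with T1 == quint8
    // (one byte per value) and a 3x3 filter over 64 input channels,
    // filter_value_count == 3 * 3 * 64 == 576, so patches_per_chunk ==
    // 1048576 / 576 == 1820 patches fit in each one-megabyte chunk, and
    // chunk_value_count == 1048576 buffer elements.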
    // Because memory allocation is very expensive on mobile platforms, try to
    // allocate a persistent buffer that will be kept around between calls. We
    // use TensorFlow's resource management to ensure that the memory will be
    // released when the session is over.
    Im2ColBufferResource<T1, chunk_value_count>* im2col_buffer_resource;
    std::function<Status(Im2ColBufferResource<T1, chunk_value_count>**)>
        creator = [](Im2ColBufferResource<T1, chunk_value_count>** resource) {
#ifdef _MSC_VER
          // MSVC complains about the capture of chunk_value_count, which oddly
          // works fine in conv_ops_using_gemm.cc for example.
          // Define chunk_value_count inside the lambda for now.
          const int64 chunk_value_count =
              (kMaxChunkSize + (sizeof(T1) - 1)) / sizeof(T1);
#endif
          *resource = new Im2ColBufferResource<T1, chunk_value_count>();
          return OkStatus();
        };
    OP_REQUIRES_OK(context, context->resource_manager()->LookupOrCreate(
                                "Conv2d", "im2col_buffer",
                                &im2col_buffer_resource, creator));
    // This means that multiple ops can't be run simultaneously on different
    // threads, because we have a single shared resource. The platforms this is
    // aimed at have intra-op parallelism as their focus though, so it shouldn't
    // be an issue.
    mutex_lock lock_buffer(im2col_buffer_resource->mu);
    core::ScopedUnref unref_buffer(im2col_buffer_resource);
    T1* im2col_buffer = im2col_buffer_resource->data;

    const int64_t patch_count = (input_batches * output_height * output_width);
    const int64_t chunk_count =
        (patch_count + (patches_per_chunk - 1)) / patches_per_chunk;
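    // chunk_count is patch_count divided by patches_per_chunk, rounded up: for
    // example (illustrative numbers), 5000 patches at 1820 patches per chunk
    // gives (5000 + 1819) / 1820 == 3 chunks, the last only partially full.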

    for (int64_t chunk_index = 0; chunk_index < chunk_count; ++chunk_index) {
      const int64_t patch_index_start = chunk_index * patches_per_chunk;
      const int64_t patch_index_end =
          std::min(patch_index_start + patches_per_chunk, patch_count);
      for (int64_t patch_index = patch_index_start;
           patch_index < patch_index_end; ++patch_index) {
        const int64_t batch = patch_index / (output_height * output_width);
        const int64_t out_y = (patch_index / output_width) % output_height;
        const int64_t out_x = patch_index % output_width;
        const T1* input_batch_start =
            input_data + (batch * input_height * input_width * input_depth);
        const int in_y_origin = (out_y * stride) - filter_top_offset;
        const int in_x_origin = (out_x * stride) - filter_left_offset;
        const int patch_index_within_chunk = patch_index % patches_per_chunk;
        T1* im2col_patch_start =
            im2col_buffer + (patch_index_within_chunk * filter_value_count);
        for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
          const int in_y = in_y_origin + filter_y;
          T1* im2col_row_start =
              im2col_patch_start + (filter_y * filter_width * input_depth);
          // If we're off the top or the bottom of the input, fill the whole
          // row with the quantized representation of zero, which is
          // input_offset. (memset fills bytes, which matches the one-byte
          // quint8 type this kernel is registered for.)
          if ((in_y < 0) || (in_y >= input_height)) {
            // On Android, memset and memcpy are significantly faster than the
            // more modern std::fill and std::copy equivalents.
            memset(im2col_row_start, input_offset,
                   (filter_width * input_depth));
          } else {
            // What we're doing here is trying to copy and fill the im2col
            // buffer as efficiently as possible, using functions to set or
            // duplicate values en masse. We know we don't have to worry about
            // vertical edges because we dealt with that case above, so we
            // just need to handle filters that overlap the left or right
            // edges. Here's what that looks like:
            //
            // < left_zero_count > < center_copy_count > < right_zero_count >
            // +------------------+---------------------+--------------------+
            // |     (filter)     |       (image)       |      (filter)      |
            // +------------------+---------------------+--------------------+
            // in_x_origin        0                 input_width       in_x_end
            //
            // In reality it's unlikely that a filter patch will be wider
            // than an input, but this shows all the edge cases.
            // We use memset() to set the left and right sections to zeroes
            // and memcpy() to copy over the input data for the center. These
            // are preferred to std::fill and std::copy because they're much
            // faster on Android.
            const int in_x_end = in_x_origin + filter_width;
            const int left_zero_count = std::max(0, 0 - in_x_origin);
            const int right_zero_count = std::max(0, in_x_end - input_width);
            const int center_copy_count =
                filter_width - (left_zero_count + right_zero_count);
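            // For example (illustrative numbers): with in_x_origin == -2,
            // filter_width == 3, and input_width == 5, in_x_end == 1, so
            // left_zero_count == 2, right_zero_count == 0, and
            // center_copy_count == 1; two columns of padding are written and
            // one column is copied from the image.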
            if (left_zero_count > 0) {
              T1* im2col_left_start = im2col_row_start;
              memset(im2col_left_start, input_offset,
                     (left_zero_count * input_depth));
            }
            if (center_copy_count > 0) {
              const T1* input_row_start =
                  input_batch_start + (in_y * input_width * input_depth) +
                  (std::max(0, in_x_origin) * input_depth);
              T1* im2col_center_start =
                  im2col_row_start + (left_zero_count * input_depth);
              memcpy(im2col_center_start, input_row_start,
                     (center_copy_count * input_depth));
            }
            if (right_zero_count > 0) {
              T1* im2col_right_start =
                  im2col_row_start +
                  ((left_zero_count + center_copy_count) * input_depth);
              memset(im2col_right_start, input_offset,
                     (right_zero_count * input_depth));
            }
          }
        }
      }
      // Now that we've assembled a set of image patches into a matrix, apply
      // a GEMM matrix multiply of the patches as rows, times the filter
      // weights in columns, to get partial results in the output matrix.
      const int how_many_patches = patch_index_end - patch_index_start;
      const bool transpose_a = false;
      const bool transpose_b = false;
      const bool transpose_c = false;
      const int m = how_many_patches;
      const int n = filter_count;
      const int k = filter_value_count;
      const int lda = filter_value_count;
      const int ldb = filter_count;
      const int ldc = filter_count;
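      // In other words, this chunk computes its [m x n] output block as the
      // product of the [m x k] im2col matrix (one patch per row) and the
      // [k x n] filter matrix (one output channel per column), all stored
      // row-major with the leading dimensions given above.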
      T3* chunk_output_data = output_data + (patch_index_start * filter_count);

      if (meta::IsSupportedAndEnabled() && std::is_same<T1, quint8>() &&
          std::is_same<T2, quint8>() && std::is_same<T3, qint32>() &&
          (output_offset == 0) && (output_mult == 1) && (output_shift == 0) &&
          (transpose_c == false) && (k <= 2048)) {
        meta::QuantizedGemm(context, transpose_a, transpose_b, im2col_buffer,
                            filter_data, chunk_output_data, m, n, k,
                            -input_offset, -filter_offset, lda, ldb, ldc);
      } else if (std::is_same<T1, quint8>() && std::is_same<T2, quint8>() &&
                 std::is_same<T3, qint32>() && (output_offset == 0) &&
                 (output_mult == 1) && (output_shift == 0)) {
        // The gemmlowp optimized library only works for a particular set of
        // data types, so check if we meet those requirements and fall back to a
        // slower reference implementation if not.
        const uint8* im2col_data_as_uint8 = &(im2col_buffer->value);
        const uint8* filter_data_as_uint8 = &(filter_data->value);
        int32* output_data_as_int32 = &(chunk_output_data->value);
        // All of the transpose_* variables are currently compile-time consts,
        // so we could just hard-code these values too, but that would break if
        // anybody changed those values in the future (e.g. to match the ability
        // of MatMul to specify them as attributes). We're using a verbose
        // approach of deriving the order values from the transpose variables to
        // be able to catch any changes like that.
        static const gemmlowp::MapOrder ResultOrder =
            !transpose_c ? gemmlowp::MapOrder::RowMajor
                         : gemmlowp::MapOrder::ColMajor;
        static const gemmlowp::MapOrder LhsOrder =
            !transpose_a ? gemmlowp::MapOrder::RowMajor
                         : gemmlowp::MapOrder::ColMajor;
        static const gemmlowp::MapOrder RhsOrder =
            !transpose_b ? gemmlowp::MapOrder::RowMajor
                         : gemmlowp::MapOrder::ColMajor;
        gemmlowp::MatrixMap<const std::uint8_t, LhsOrder> lhs(
            im2col_data_as_uint8, m, k, lda);
        gemmlowp::MatrixMap<const std::uint8_t, RhsOrder> rhs(
            filter_data_as_uint8, k, n, ldb);
        gemmlowp::MatrixMap<std::int32_t, ResultOrder> result(
            output_data_as_int32, m, n, ldc);
        const std::tuple<> empty_pipeline = {};

        auto& worker_threads =
            *(context->device()->tensorflow_cpu_worker_threads());
        TensorflowGemmContext context(worker_threads.num_threads,
                                      worker_threads.workers);
        gemmlowp::GemmWithOutputPipeline<std::uint8_t, std::int32_t,
                                         gemmlowp::DefaultL8R8BitDepthParams>(
            &context, lhs, rhs, &result, -input_offset, -filter_offset,
            empty_pipeline);
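        // gemmlowp adds its lhs_offset and rhs_offset arguments to every
        // element before multiplying, so passing the negated quantization
        // offsets here is what subtracts the zero points, matching the
        // (input - input_offset) * (filter - filter_offset) arithmetic in the
        // reference functor above.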
        // Since gemmlowp uses assembly to write to the output, msan won't
        // detect the output buffer as written to, so we mark it manually.
        TF_ANNOTATE_MEMORY_IS_INITIALIZED(output_data_as_int32,
                                          m * n * sizeof(int32));
      } else {
        ReferenceGemm<T1, T2, T3>(
            transpose_a, transpose_b, transpose_c, m, n, k, im2col_buffer,
            input_offset, lda, filter_data, filter_offset, ldb,
            chunk_output_data, output_shift, output_offset, output_mult, ldc);
      }
    }
  }
};

template <class T1, class T2, class T3,
          template <class TF1, class TF2, class TF3> class ConvFunctor>
class QuantizedConv2DOp : public OpKernel {
 public:
  explicit QuantizedConv2DOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    OP_REQUIRES(context, strides_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(context, strides_[1] == strides_[2],
                errors::InvalidArgument(
                    "Current implementation only supports equal length "
                    "strides in the row and column dimensions."));
    OP_REQUIRES(
        context, (strides_[0] == 1 && strides_[3] == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    std::vector<int32> dilations;
    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations));
    OP_REQUIRES(context, dilations.size() == 4,
                errors::InvalidArgument("Dilations field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(context, dilations[1] == 1 && dilations[2] == 1,
                errors::InvalidArgument(
                    "Current implementation only supports a dilation rate "
                    "of 1 in the row and column dimensions."));
    OP_REQUIRES(context, (dilations[0] == 1 && dilations[3] == 1),
                errors::InvalidArgument(
                    "Current implementation does not yet support "
                    "dilations in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    // Input tensor is of the following dimensions:
    // [ batch, in_rows, in_cols, in_depth ]
    const Tensor& input = context->input(0);

    // Input filter is of the following dimensions:
    // [ filter_rows, filter_cols, in_depth, out_depth ]
    const Tensor& filter = context->input(1);

    // For 2D convolution, there should be 4 dimensions.
    OP_REQUIRES(context, input.dims() == 4,
                errors::InvalidArgument("input must be rank 4 but is rank ",
                                        input.shape().dims()));
    OP_REQUIRES(context, filter.dims() == 4,
                errors::InvalidArgument("filter must be rank 4 but is rank ",
                                        filter.shape().dims()));

    OP_REQUIRES(context, TensorShapeUtils::IsScalar(context->input(2).shape()),
                errors::InvalidArgument("min_input must be rank 0 but is rank ",
                                        context->input(2).shape().dims()));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(context->input(3).shape()),
                errors::InvalidArgument("max_input must be rank 0 but is rank ",
                                        context->input(3).shape().dims()));
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(context->input(4).shape()),
        errors::InvalidArgument("min_filter must be rank 0 but is rank ",
                                context->input(4).shape().dims()));
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(context->input(5).shape()),
        errors::InvalidArgument("max_filter must be rank 0 but is rank ",
                                context->input(5).shape().dims()));

    const float min_input = context->input(2).flat<float>()(0);
    const float max_input = context->input(3).flat<float>()(0);
    const float min_filter = context->input(4).flat<float>()(0);
    const float max_filter = context->input(5).flat<float>()(0);
    const int32_t offset_input =
        FloatToQuantizedUnclamped<T1>(0.0f, min_input, max_input);
    const int32_t offset_filter =
        FloatToQuantizedUnclamped<T2>(0.0f, min_filter, max_filter);
    const int32_t offset_output = 0;
    const int32_t mult_output = 1;
    const int32_t shift_output = 0;
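    // offset_input and offset_filter are the quantized values that represent
    // the real number 0.0f. For example (illustrative range), a quint8 input
    // with min_input == -1.0f and max_input == 3.0f has a step of 4.0f / 255,
    // so zero maps to round(1.0f * 255 / 4.0f) == 64. Since the output is
    // qint32 with offset 0, multiplier 1, and shift 0, the accumulator is
    // passed through unscaled.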

    // The last dimension for input is in_depth. It must be the same as the
    // filter's in_depth.
    const int64_t in_depth = input.dim_size(3);
    OP_REQUIRES(context, in_depth == filter.dim_size(2),
                errors::InvalidArgument(
                    "input and filter must have the same depth: ", in_depth,
                    " vs ", filter.dim_size(2)));

    // The last dimension for filter is out_depth.
    const int64_t out_depth = filter.dim_size(3);

    // The second dimension for input is rows/height.
    // The first dimension for filter is rows/height.
    const int64_t input_rows = input.dim_size(1);
    const int64_t filter_rows = filter.dim_size(0);

    // The third dimension for input is columns/width.
    // The second dimension for filter is columns/width.
    const int64_t input_cols = input.dim_size(2);
    const int64_t filter_cols = filter.dim_size(1);

    // The first dimension for input is batch.
    const int64_t batch = input.dim_size(0);

    // For now we take the stride from the second dimension only (we
    // assume row = col stride, and do not support striding on the
    // batch or depth dimension).
    const int stride = strides_[1];

    int64_t out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
    OP_REQUIRES_OK(context,
                   GetWindowedOutputSize(input_rows, filter_rows, stride,
                                         padding_, &out_rows, &pad_rows));
    OP_REQUIRES_OK(context,
                   GetWindowedOutputSize(input_cols, filter_cols, stride,
                                         padding_, &out_cols, &pad_cols));
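    // GetWindowedOutputSize implements the usual convolution shape rules:
    // SAME gives ceil(input / stride) and VALID gives
    // ceil((input - filter + 1) / stride). For example, a 28-wide input with
    // a 3-wide filter and stride 2 produces out_cols of 14 under SAME and 13
    // under VALID.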
    CHECK_GT(batch, 0);
    CHECK_GT(out_rows, 0);
    CHECK_GT(out_cols, 0);
    CHECK_GT(out_depth, 0);
    TensorShape out_shape({batch, out_rows, out_cols, out_depth});

    // Output tensor is of the following dimensions:
    // [ in_batch, out_rows, out_cols, out_depth ]
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));

    // This will call different implementations (e.g. reference or optimized)
    // depending on the template parameter.
    ConvFunctor<T1, T2, T3> conv_functor;
    conv_functor(context, input.flat<T1>().data(), batch, input_rows,
                 input_cols, in_depth, offset_input, filter.flat<T2>().data(),
                 filter_rows, filter_cols, out_depth, offset_filter, stride,
                 padding_, output->flat<T3>().data(), out_rows, out_cols,
                 shift_output, offset_output, mult_output);

    float min_output_value;
    float max_output_value;
    QuantizationRangeForMultiplication<T1, T2, T3>(
        min_input, max_input, min_filter, max_filter, &min_output_value,
        &max_output_value);
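    // The output range is chosen so that one step of the qint32 result equals
    // the product of one input step and one filter step; min_output_value and
    // max_output_value are then that step size scaled by the int32 lowest and
    // highest values, so the full accumulator range is representable.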

    Tensor* output_min = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(1, {}, &output_min));
    output_min->flat<float>()(0) = min_output_value;

    Tensor* output_max = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(2, {}, &output_max));
    output_max->flat<float>()(0) = max_output_value;
  }

 private:
  std::vector<int32> strides_;
  Padding padding_;
};

// Right now we only support taking two eight-bit inputs, and returning the
// results as signed 32-bit integers.
REGISTER_KERNEL_BUILDER(
    Name("QuantizedConv2D")
        .Device(DEVICE_CPU)
        .TypeConstraint<quint8>("Tinput")
        .TypeConstraint<quint8>("Tfilter")
        .TypeConstraint<qint32>("out_type"),
    QuantizedConv2DOp<quint8, quint8, qint32, Im2ColConvFunctor>);

}  // namespace tensorflow