/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements quantized eight-bit versions of the convolution operations.

#include <algorithm>
#include <vector>

#include "tensorflow/core/platform/errors.h"

#define EIGEN_USE_THREADS

#define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
#include "public/gemmlowp.h"
#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/conv_ops.h"
#include "tensorflow/core/kernels/meta_support.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/kernels/reference_gemm.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/padding.h"

namespace tensorflow {

// This functor implements the convolution operation in as simple a form as
// possible. It won't give great performance, but it is very useful for
// stepping through and instrumenting for debugging, creating minimal
// benchmarks to prototype with, and sharing with teams that want to run this
// outside of our environment.
// With that in mind, I've avoided using anything except pretty standard C++
// types. This is especially noticeable in the data access through raw array
// indexing. It's deliberate in this case though, since it makes the underlying
// memory order very explicit, which is important for both inspecting memory
// contents during debugging and for specifying what we expect to others.
// The memory layout of the data is, from biggest stride to smallest:
// input_data = [input_batches, input_height, input_width, input_depth]
// filter_data = [filter_height, filter_width, input_depth, filter_count]
// output_data = [input_batches, output_height, output_width, filter_count]
template <class T1, class T2, class T3>
class ReferenceConvFunctor {
 public:
  void operator()(OpKernelContext* context, const T1* input_data,
                  int input_batches, int input_height, int input_width,
                  int input_depth, int input_offset, const T2* filter_data,
                  int filter_height, int filter_width, int filter_count,
                  int filter_offset, int stride, Padding padding,
                  T3* output_data, int output_height, int output_width,
                  int output_shift, int output_offset, int output_mult) {
    // Set up some constants we need for the output down-shifting and
    // saturation.
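    // As a concrete example: when T3 is qint32 these bounds span the full
    // int32 range and the clamp below is a no-op; for a hypothetical
    // eight-bit output type they would be 0 and 255 and the clamp becomes
    // meaningful.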
    const int32_t highest =
        static_cast<int32>(Eigen::NumTraits<T3>::highest());
    const int32_t lowest = static_cast<int32>(Eigen::NumTraits<T3>::lowest());

    // When we're converting the 32 bit accumulator to a lower bit depth, we
    // need to add on 0.5 in fixed-point terms to make the operation round half
    // up towards positive infinity, rather than a floor.
    // We also need to watch out for the case when there's no down shift,
    // because a left shift by a negative number gives undefined results.
    const int32_t rounding = (output_shift < 1) ? 0 : (1 << (output_shift - 1));

    // The two different padding modes we support can be a bit confusing. SAME
    // means we're trying to produce an output image that's the same size as
    // the input. It's complicated by stride, which shrinks the output image by
    // a factor, but it means we end up sampling from outside the borders of
    // the input. These out-of-bounds values are read as zeroes. VALID means
    // only produce output values where the filters can read all their values
    // from within the input image. It effectively removes the margins of the
    // output image compared to the one produced by SAME. Stride complicates
    // this definition though, because it can result in the right and bottom
    // filter patches sampling from outside the borders if it's greater than 1.
    // Most of the logic for sorting this all out is done before this function,
    // when we calculate the output size, but the positioning of the origin of
    // the filters is different between the two modes, since SAME positions the
    // first filter off the edge of the input.
    int filter_left_offset;
    int filter_top_offset;
    if (padding == VALID) {
      filter_left_offset =
          ((output_width - 1) * stride + filter_width - input_width + 1) / 2;
      filter_top_offset =
          ((output_height - 1) * stride + filter_height - input_height + 1) /
          2;
    } else {
      filter_left_offset =
          ((output_width - 1) * stride + filter_width - input_width) / 2;
      filter_top_offset =
          ((output_height - 1) * stride + filter_height - input_height) / 2;
    }

    // If we've got multiple images in our input, work through each of them.
    for (int batch = 0; batch < input_batches; ++batch) {
      // Walk through all the output image values, sliding the filter to
      // different positions in the input.
      for (int out_y = 0; out_y < output_height; ++out_y) {
        for (int out_x = 0; out_x < output_width; ++out_x) {
          // Each filter kernel produces one output channel.
          for (int out_channel = 0; out_channel < filter_count;
               ++out_channel) {
            // We're going to calculate a single output value, which means we
            // need to multiply a three dimensional kernel of weights against
            // the current location within the input image.
            /*
              *-------------------------------...
              |\ ^
              | \in_depth
              |  \ v
              |   *-------------------------------...
              |   |            ^
              |   |       in_y_origin
              |   |            v   \
              |   |<in_x_origin>*---*^
              |   |            \|   |filter_height
              .   |             *---*v
              .   |             <--->
              .         filter_width
              .
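              Note that in_x_origin and in_y_origin below can be negative when
              SAME padding places the first filter partly off the top-left
              edge of the input; those out-of-range taps contribute zero to
              the total.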
            */
            const int in_x_origin = (out_x * stride) - filter_left_offset;
            const int in_y_origin = (out_y * stride) - filter_top_offset;
            int32_t total = 0;
            for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
              for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
                for (int in_channel = 0; in_channel < input_depth;
                     ++in_channel) {
                  const int in_x = in_x_origin + filter_x;
                  const int in_y = in_y_origin + filter_y;
                  int32_t input_value;
                  // If the location is outside the bounds of the input image,
                  // use zero as a default value.
                  if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
                      (in_y < input_height)) {
                    const T1 input_source_value =
                        input_data[(batch * input_height * input_width *
                                    input_depth) +
                                   (in_y * input_width * input_depth) +
                                   (in_x * input_depth) + in_channel];
                    // We're promoting the T1 type to a higher bit depth here
                    // as we do the subtraction.
                    input_value =
                        static_cast<int32>(input_source_value) - input_offset;
                  } else {
                    input_value = 0;
                  }
                  const T2 filter_source_value =
                      filter_data[(filter_y * filter_width * input_depth *
                                   filter_count) +
                                  (filter_x * input_depth * filter_count) +
                                  (in_channel * filter_count) + out_channel];
                  // Another promotion to 32 bit, as above.
                  const int32_t filter_value =
                      static_cast<int32>(filter_source_value) - filter_offset;
                  total += (input_value * filter_value);
                }
              }
            }
            // Here we're applying scale factors to compress the 32 bit
            // accumulated total to a potentially lower bit depth.
            const int32_t output =
                ((((total + output_offset) * output_mult) + rounding) >>
                 output_shift);
            // We need to saturate the results against the largest and smallest
            // values that can be represented in this type.
            const int32_t top_clamped_output = std::min(output, highest);
            const int32_t clamped_output = std::max(top_clamped_output, lowest);
            output_data[(batch * output_height * output_width * filter_count) +
                        (out_y * output_width * filter_count) +
                        (out_x * filter_count) + out_channel] = clamped_output;
          }
        }
      }
    }
  }
};

// We don't want to allocate a buffer to hold all the patches if the size is
// going to be extremely large, so break it into chunks if it's bigger than
// a limit. Each chunk will be processed serially, so we can refill the
// buffer for the next chunk and reuse it, keeping maximum memory size down.
// In this case, we've picked 1 megabyte as a reasonable limit, from
// experimentation.
const size_t kMaxChunkSize = (1 * 1024 * 1024);

// Implements convolution as a two stage process, first packing the patches of
// the input image into columns (im2col) and then running GEMM to produce the
// final result.
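// As a rough sketch of the shapes involved (assuming a single chunk): the
// im2col matrix is
//   [input_batches * output_height * output_width,
//    filter_height * filter_width * input_depth]
// and the filter matrix is
//   [filter_height * filter_width * input_depth, filter_count],
// so one GEMM yields the [patches, filter_count] output for the whole batch.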
template <class T1, class T2, class T3>
class Im2ColConvFunctor {
 public:
  void operator()(OpKernelContext* context, const T1* input_data,
                  int input_batches, int input_height, int input_width,
                  int input_depth, int input_offset, const T2* filter_data,
                  int filter_height, int filter_width, int filter_count,
                  int filter_offset, int stride, Padding padding,
                  T3* output_data, int output_height, int output_width,
                  int output_shift, int output_offset, int output_mult) {
    if (input_offset < 0) {
      // Only log the first few occurrences of this warning.
      static int warning_count = 0;
      if (warning_count < 10) {
        ++warning_count;
        LOG(WARNING)
            << "For kernel '" << context->op_kernel().name()
            << "' from input '" << context->op_kernel().requested_input(0)
            << "': Zero is not representable in the quantized range used by the"
            << " input. This means QuantizedConv2d has to fall back to a slow"
            << " implementation, since the border of zero values can't be"
            << " represented easily. You should try to construct graphs that"
            << " avoid this situation.";
      }
      ReferenceConvFunctor<T1, T2, T3> conv_functor;
      conv_functor(context, input_data, input_batches, input_height,
                   input_width, input_depth, input_offset, filter_data,
                   filter_height, filter_width, filter_count, filter_offset,
                   stride, padding, output_data, output_height, output_width,
                   output_shift, output_offset, output_mult);
      return;
    }

    OP_REQUIRES(
        context, output_width > 0,
        errors::InvalidArgument("output_width must be strictly positive"));
    OP_REQUIRES(
        context, output_height > 0,
        errors::InvalidArgument("output_height must be strictly positive"));
    int filter_left_offset;
    int filter_top_offset;
    if (padding == VALID) {
      filter_left_offset =
          ((output_width - 1) * stride + filter_width - input_width + 1) / 2;
      filter_top_offset =
          ((output_height - 1) * stride + filter_height - input_height + 1) /
          2;
    } else {
      filter_left_offset =
          ((output_width - 1) * stride + filter_width - input_width) / 2;
      filter_top_offset =
          ((output_height - 1) * stride + filter_height - input_height) / 2;
    }

    // The im2col buffer has (# of patches) rows and (filter value count) cols.
    // It's laid out like this, in row major order in memory:
    //         < filter value count >
    //   ^     +---------------------+
    // patch   |                     |
    // count   |                     |
    //   v     +---------------------+
    // Each patch row contains a filter_width x filter_height patch of the
    // input, with the depth channel as the most contiguous in memory, followed
    // by the width, then the height. This is the standard memory order in the
    // image world if it helps to visualize it.
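    // For example (numbers for illustration only): a 3x3 filter over 8 input
    // channels gives filter_value_count = 72 below, so with the 1 MB
    // kMaxChunkSize and one-byte quantized values each chunk holds roughly
    // 14,500 patches.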
    const int filter_value_count = filter_width * filter_height * input_depth;
    OP_REQUIRES(context, filter_value_count > 0,
                errors::InvalidArgument(
                    "filter patch must contain at least one element"));
    const int64_t patches_per_chunk =
        kMaxChunkSize / (filter_value_count * sizeof(T1));
    const int64_t chunk_value_count =
        (kMaxChunkSize + (sizeof(T1) - 1)) / sizeof(T1);
    // TODO(petewarden) - Memory allocation can be very slow on Android. Can we
    // optimize this by keeping the scratch buffer around?
    // Because memory allocation is very expensive on mobile platforms, try to
    // allocate a persistent buffer that will be kept around between calls. We
    // use TensorFlow's resource management to ensure that the memory will be
    // released when the session is over.
    Im2ColBufferResource<T1, chunk_value_count>* im2col_buffer_resource;
    std::function<Status(Im2ColBufferResource<T1, chunk_value_count>**)>
        creator = [](Im2ColBufferResource<T1, chunk_value_count>** resource) {
#ifdef _MSC_VER
          // MSVC complains about the capture of chunk_value_count which oddly
          // works fine in conv_ops_using_gemm.cc for example.
          // Define chunk_value_count inside the lambda for now.
          const int64 chunk_value_count =
              (kMaxChunkSize + (sizeof(T1) - 1)) / sizeof(T1);
#endif
          *resource = new Im2ColBufferResource<T1, chunk_value_count>();
          return OkStatus();
        };
    OP_REQUIRES_OK(context, context->resource_manager()->LookupOrCreate(
                                "Conv2d", "im2col_buffer",
                                &im2col_buffer_resource, creator));
    // This means that multiple ops can't be run simultaneously on different
    // threads, because we have a single shared resource. The platforms this is
    // aimed at have intra-op parallelism as their focus though, so it
    // shouldn't be an issue.
    mutex_lock lock_buffer(im2col_buffer_resource->mu);
    core::ScopedUnref unref_buffer(im2col_buffer_resource);
    T1* im2col_buffer = im2col_buffer_resource->data;

    const int64_t patch_count = (input_batches * output_height * output_width);
    const int64_t chunk_count =
        (patch_count + (patches_per_chunk - 1)) / patches_per_chunk;

    for (int64_t chunk_index = 0; chunk_index < chunk_count; ++chunk_index) {
      const int64_t patch_index_start = chunk_index * patches_per_chunk;
      const int64_t patch_index_end =
          std::min(patch_index_start + patches_per_chunk, patch_count);
      for (int64_t patch_index = patch_index_start;
           patch_index < patch_index_end; ++patch_index) {
        const int64_t batch = patch_index / (output_height * output_width);
        const int64_t out_y = (patch_index / output_width) % output_height;
        const int64_t out_x = patch_index % output_width;
        const T1* input_batch_start =
            input_data + (batch * input_height * input_width * input_depth);
        const int in_y_origin = (out_y * stride) - filter_top_offset;
        const int in_x_origin = (out_x * stride) - filter_left_offset;
        const int patch_index_within_chunk = patch_index % patches_per_chunk;
        T1* im2col_patch_start =
            im2col_buffer + (patch_index_within_chunk * filter_value_count);
        for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
          const int in_y = in_y_origin + filter_y;
          T1* im2col_row_start =
              im2col_patch_start + (filter_y * filter_width * input_depth);
          // If we're off the top or the bottom of the input, fill the
          // whole row with zeroes.
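          // Note that the byte we write is input_offset, which is the
          // quantized encoding of the real value 0.0f; this is why the
          // functor falls back to the reference path when zero isn't
          // representable (input_offset < 0).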
          if ((in_y < 0) || (in_y >= input_height)) {
            // On Android, memset and memcpy are significantly faster than the
            // more modern std::fill and std::copy equivalents.
            memset(im2col_row_start, input_offset,
                   (filter_width * input_depth));
          } else {
            // What we're doing here is trying to copy and fill the im2col
            // buffer as efficiently as possible, using functions to set or
            // duplicate values en masse. We know we don't have to worry about
            // vertical edges because we dealt with that case above, so we
            // just need to handle filters that overlap the left or right
            // edges. Here's what that looks like:
            //
            // < left_zero_count > < center_copy_count > < right_zero_count >
            // +------------------+---------------------+--------------------+
            // |     (filter)     |       (image)       |      (filter)      |
            // +------------------+---------------------+--------------------+
            // in_x_origin        0                 input_width       in_x_end
            //
            // In reality it's unlikely that a filter patch will be wider
            // than an input, but this shows all the edge cases.
            // We use memset() to set the left and right sections to zeroes
            // and memcpy() to copy over the input data for the center. These
            // are preferred to std::fill and std::copy because they're much
            // faster on Android.
            const int in_x_end = in_x_origin + filter_width;
            const int left_zero_count = std::max(0, 0 - in_x_origin);
            const int right_zero_count = std::max(0, in_x_end - input_width);
            const int center_copy_count =
                filter_width - (left_zero_count + right_zero_count);
            if (left_zero_count > 0) {
              T1* im2col_left_start = im2col_row_start;
              memset(im2col_left_start, input_offset,
                     (left_zero_count * input_depth));
            }
            if (center_copy_count > 0) {
              const T1* input_row_start =
                  input_batch_start + (in_y * input_width * input_depth) +
                  (std::max(0, in_x_origin) * input_depth);
              T1* im2col_center_start =
                  im2col_row_start + (left_zero_count * input_depth);
              memcpy(im2col_center_start, input_row_start,
                     (center_copy_count * input_depth));
            }
            if (right_zero_count > 0) {
              T1* im2col_right_start =
                  im2col_row_start +
                  ((left_zero_count + center_copy_count) * input_depth);
              memset(im2col_right_start, input_offset,
                     (right_zero_count * input_depth));
            }
          }
        }
      }
      // Now we've assembled a set of image patches into a matrix, apply a
      // GEMM matrix multiply of the patches as rows, times the filter
      // weights in columns, to get partial results in the output matrix.
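      // All three code paths below (the meta kernels, gemmlowp, and the
      // reference GEMM) compute the same uint8 x uint8 -> int32 product for
      // this chunk; they differ only in speed and in which parameter
      // combinations they can handle.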
      const int how_many_patches = patch_index_end - patch_index_start;
      const bool transpose_a = false;
      const bool transpose_b = false;
      const bool transpose_c = false;
      const int m = how_many_patches;
      const int n = filter_count;
      const int k = filter_value_count;
      const int lda = filter_value_count;
      const int ldb = filter_count;
      const int ldc = filter_count;
      T3* chunk_output_data = output_data + (patch_index_start * filter_count);

      if (meta::IsSupportedAndEnabled() && std::is_same<T1, quint8>() &&
          std::is_same<T2, quint8>() && std::is_same<T3, qint32>() &&
          (output_offset == 0) && (output_mult == 1) && (output_shift == 0) &&
          (transpose_c == false) && (k <= 2048)) {
        meta::QuantizedGemm(context, transpose_a, transpose_b, im2col_buffer,
                            filter_data, chunk_output_data, m, n, k,
                            -input_offset, -filter_offset, lda, ldb, ldc);
      } else if (std::is_same<T1, quint8>() && std::is_same<T2, quint8>() &&
                 std::is_same<T3, qint32>() && (output_offset == 0) &&
                 (output_mult == 1) && (output_shift == 0)) {
        // The gemmlowp optimized library only works for a particular set of
        // data types, so check if we meet those requirements and fall back to
        // a slower reference implementation if not.
        const uint8* im2col_data_as_uint8 = &(im2col_buffer->value);
        const uint8* filter_data_as_uint8 = &(filter_data->value);
        int32* output_data_as_int32 = &(chunk_output_data->value);
        // All of the transpose_* variables are currently compile-time consts,
        // so we could just hard-code these values too, but that would break
        // if anybody changed those values in the future (e.g. to match the
        // ability of MatMul to specify them as attributes). We're using a
        // verbose approach of deriving the order values from the transpose
        // variables to be able to catch any changes like that.
        static const gemmlowp::MapOrder ResultOrder =
            !transpose_c ? gemmlowp::MapOrder::RowMajor
                         : gemmlowp::MapOrder::ColMajor;
        static const gemmlowp::MapOrder LhsOrder =
            !transpose_a ? gemmlowp::MapOrder::RowMajor
                         : gemmlowp::MapOrder::ColMajor;
        static const gemmlowp::MapOrder RhsOrder =
            !transpose_b ? gemmlowp::MapOrder::RowMajor
                         : gemmlowp::MapOrder::ColMajor;
        gemmlowp::MatrixMap<const std::uint8_t, LhsOrder> lhs(
            im2col_data_as_uint8, m, k, lda);
        gemmlowp::MatrixMap<const std::uint8_t, RhsOrder> rhs(
            filter_data_as_uint8, k, n, ldb);
        gemmlowp::MatrixMap<std::int32_t, ResultOrder> result(
            output_data_as_int32, m, n, ldc);
        const std::tuple<> empty_pipeline = {};

        auto& worker_threads =
            *(context->device()->tensorflow_cpu_worker_threads());
        TensorflowGemmContext context(worker_threads.num_threads,
                                      worker_threads.workers);
        gemmlowp::GemmWithOutputPipeline<std::uint8_t, std::int32_t,
                                         gemmlowp::DefaultL8R8BitDepthParams>(
            &context, lhs, rhs, &result, -input_offset, -filter_offset,
            empty_pipeline);
        // Since gemmlowp uses assembly to write to the output, msan won't
        // detect the output buffer as written to, so we mark it manually.
        TF_ANNOTATE_MEMORY_IS_INITIALIZED(output_data_as_int32,
                                          m * n * sizeof(int32));
      } else {
        ReferenceGemm<T1, T2, T3>(
            transpose_a, transpose_b, transpose_c, m, n, k, im2col_buffer,
            input_offset, lda, filter_data, filter_offset, ldb,
            chunk_output_data, output_shift, output_offset, output_mult, ldc);
      }
    }
  }
};

template <class T1, class T2, class T3,
          template <class TF1, class TF2, class TF3> class ConvFunctor>
class QuantizedConv2DOp : public OpKernel {
 public:
  explicit QuantizedConv2DOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    OP_REQUIRES(context, strides_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(context, strides_[1] == strides_[2],
                errors::InvalidArgument(
                    "Current implementation only supports equal length "
                    "strides in the row and column dimensions."));
    OP_REQUIRES(
        context, (strides_[0] == 1 && strides_[3] == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    std::vector<int32> dilations;
    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations));
    OP_REQUIRES(context, dilations.size() == 4,
                errors::InvalidArgument("Dilations field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(context, dilations[1] == 1 && dilations[2] == 1,
                errors::InvalidArgument(
                    "Current implementation only supports dilated rate as 1 "
                    "in the row and column dimensions."));
    OP_REQUIRES(context, (dilations[0] == 1 && dilations[3] == 1),
                errors::InvalidArgument(
                    "Current implementation does not yet support "
                    "dilations in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    // Input tensor is of the following dimensions:
    // [ batch, in_rows, in_cols, in_depth ]
    const Tensor& input = context->input(0);

    // Input filter is of the following dimensions:
    // [ filter_rows, filter_cols, in_depth, out_depth ]
    const Tensor& filter = context->input(1);

    // For 2D convolution, there should be 4 dimensions.
    OP_REQUIRES(context, input.dims() == 4,
                errors::InvalidArgument("input must be rank 4 but is rank ",
                                        input.shape().dims()));
    OP_REQUIRES(context, filter.dims() == 4,
                errors::InvalidArgument("filter must be rank 4 but is rank ",
                                        filter.shape().dims()));

    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(context->input(2).shape()),
        errors::InvalidArgument("min_input must be rank 0 but is rank ",
                                context->input(2).shape().dims()));
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(context->input(3).shape()),
        errors::InvalidArgument("max_input must be rank 0 but is rank ",
                                context->input(3).shape().dims()));
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(context->input(4).shape()),
        errors::InvalidArgument("min_filter must be rank 0 but is rank ",
                                context->input(4).shape().dims()));
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(context->input(5).shape()),
        errors::InvalidArgument("max_filter must be rank 0 but is rank ",
                                context->input(5).shape().dims()));

    const float min_input = context->input(2).flat<float>()(0);
    const float max_input = context->input(3).flat<float>()(0);
    const float min_filter = context->input(4).flat<float>()(0);
    const float max_filter = context->input(5).flat<float>()(0);
    const int32_t offset_input =
        FloatToQuantizedUnclamped<T1>(0.0f, min_input, max_input);
    const int32_t offset_filter =
        FloatToQuantizedUnclamped<T2>(0.0f, min_filter, max_filter);
    const int32_t offset_output = 0;
    const int32_t mult_output = 1;
    const int32_t shift_output = 0;

    // The last dimension for input is in_depth. It must be the same as the
    // filter's in_depth.
    const int64_t in_depth = input.dim_size(3);
    OP_REQUIRES(context, in_depth == filter.dim_size(2),
                errors::InvalidArgument(
                    "input and filter must have the same depth: ", in_depth,
                    " vs ", filter.dim_size(2)));

    // The last dimension for filter is out_depth.
    const int64_t out_depth = filter.dim_size(3);

    // The second dimension for input is rows/height.
    // The first dimension for filter is rows/height.
    const int64_t input_rows = input.dim_size(1);
    const int64_t filter_rows = filter.dim_size(0);

    // The third dimension for input is columns/width.
    // The second dimension for filter is columns/width.
    const int64_t input_cols = input.dim_size(2);
    const int64_t filter_cols = filter.dim_size(1);

    // The first dimension for input is batch.
    const int64_t batch = input.dim_size(0);

    // For now we take the stride from the second dimension only (we
    // assume row = col stride, and do not support striding on the
    // batch or depth dimension).
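    // The output dimensions computed below follow the usual convolution
    // arithmetic: SAME gives out = ceil(in / stride), VALID gives
    // out = ceil((in - filter + 1) / stride).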
    const int stride = strides_[1];

    int64_t out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
    OP_REQUIRES_OK(context,
                   GetWindowedOutputSize(input_rows, filter_rows, stride,
                                         padding_, &out_rows, &pad_rows));
    OP_REQUIRES_OK(context,
                   GetWindowedOutputSize(input_cols, filter_cols, stride,
                                         padding_, &out_cols, &pad_cols));
    CHECK_GT(batch, 0);
    CHECK_GT(out_rows, 0);
    CHECK_GT(out_cols, 0);
    CHECK_GT(out_depth, 0);
    TensorShape out_shape({batch, out_rows, out_cols, out_depth});

    // Output tensor is of the following dimensions:
    // [ in_batch, out_rows, out_cols, out_depth ]
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));

    // This will call different implementations (e.g. reference or optimized)
    // depending on the template parameter.
    ConvFunctor<T1, T2, T3> conv_functor;
    conv_functor(context, input.flat<T1>().data(), batch, input_rows,
                 input_cols, in_depth, offset_input, filter.flat<T2>().data(),
                 filter_rows, filter_cols, out_depth, offset_filter, stride,
                 padding_, output->flat<T3>().data(), out_rows, out_cols,
                 shift_output, offset_output, mult_output);

    float min_output_value;
    float max_output_value;
    QuantizationRangeForMultiplication<T1, T2, T3>(
        min_input, max_input, min_filter, max_filter, &min_output_value,
        &max_output_value);

    Tensor* output_min = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(1, {}, &output_min));
    output_min->flat<float>()(0) = min_output_value;

    Tensor* output_max = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(2, {}, &output_max));
    output_max->flat<float>()(0) = max_output_value;
  }

 private:
  std::vector<int32> strides_;
  Padding padding_;
};

// Right now we only support taking two eight bit inputs, and returning the
// results as signed 32-bit integers.
REGISTER_KERNEL_BUILDER(
    Name("QuantizedConv2D")
        .Device(DEVICE_CPU)
        .TypeConstraint<quint8>("Tinput")
        .TypeConstraint<quint8>("Tfilter")
        .TypeConstraint<qint32>("out_type"),
    QuantizedConv2DOp<quint8, quint8, qint32, Im2ColConvFunctor>);

}  // namespace tensorflow