/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/deep_conv2d.h"

#include <stdlib.h>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/winograd_transform.h"
#include "tensorflow/core/util/work_sharder.h"
namespace tensorflow {

// DeepConv2D is a Conv2D implementation specialized for deep convolutions
// (i.e. a large 'in_depth' * 'out_depth' product; see cost models below for
// details).
//
// DeepConv2D is implemented by computing the following equation:
//
//   y = C[Ad * Bg]
//
//   C: output transform matrix
//   A: input data transform matrix
//   B: filter transform matrix
//   d: vectorized data tile
//   g: vectorized filter tile
//   y: vectorized output tile
//
// The transform matrices and input, filter and output tile sizes are all
// specified by the DeepConv2DTransform implementation selected at the
// start of the DeepConv2D call, based on convolution parameters.
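//
// For example, with the Winograd transform used below (for a 3x3 base filter:
// 4x4 input tiles, 2x2 output tiles), 'd' is a 16-element input tile per
// depth, 'g' is the 9-element filter tile, A and B lift both into the
// 16-element transform domain, 'Ad * Bg' is an element-wise product there,
// and C maps the result back to a 4-element output tile.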

// Approximate cost models for direct and deep convolutions.
static int64_t GetDeepConvCost(int input_tile_rows, int input_tile_cols,
                               int out_tile_rows, int out_tile_cols,
                               int in_depth, int out_depth, int out_rows,
                               int out_cols) {
  // Input transform cost.
  const int64_t input_tile_spatial_size = input_tile_rows * input_tile_cols;
  const int64_t input_transform_cost =
      input_tile_spatial_size * input_tile_spatial_size * in_depth;

  // Element-wise products (each product is a MatMul across depth).
  const int64_t product_cost = input_tile_spatial_size * in_depth * out_depth;

  // Output transform cost.
  const int64_t output_tile_spatial_size = out_tile_rows * out_tile_cols;
  const int64_t output_transform_cost =
      output_tile_spatial_size * input_tile_spatial_size * out_depth;

  // Calculate number of input tiles to process.
  const int64_t row_tiles = (out_rows + out_tile_rows - 1) / out_tile_rows;
  const int64_t col_tiles = (out_cols + out_tile_cols - 1) / out_tile_cols;
  const int64_t num_tiles = row_tiles * col_tiles;

  // Return total cost.
  return num_tiles *
         (input_transform_cost + product_cost + output_transform_cost);
}

static int64_t GetDirectConvCost(int filter_rows, int filter_cols, int in_depth,
                                 int out_depth, int out_rows, int out_cols) {
  return filter_rows * filter_cols * in_depth * out_depth * out_rows * out_cols;
}
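
// As an illustrative check of these models (example numbers, not computed
// here): with 4x4 input tiles, 2x2 output tiles, in_depth = out_depth = 64
// and a 32x32 output, GetDeepConvCost returns
// (16 * 16) * (16*16*64 + 16*64*64 + 4*16*64) = 256 * 86016 ~= 22.0M, while
// GetDirectConvCost returns 3*3*64*64*32*32 ~= 37.7M, so the deep path wins
// for this shape.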

// Reads environment variable 'env_var_name'.
// Returns 'true' if environment variable is enabled, false otherwise.
static bool ReadBoolFromEnvVar(const char* env_var_name, bool default_val) {
  const char* tf_env_var_val = getenv(env_var_name);
  if (tf_env_var_val != nullptr) {
    StringPiece tf_env_var_val_str(tf_env_var_val);
    if (tf_env_var_val_str == "0") {
      return false;
    }
    return true;
  }
  return default_val;
}
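
// For example, a user can opt in to deep convolutions for a process with:
//   export TF_USE_DEEP_CONV2D=1
// Any value other than "0" counts as enabled; unset falls back to
// 'default_val' (false at the call site below).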

// Returns true if convolution can be computed efficiently by DeepConv2D,
// returns false otherwise.
// TODO(andydavis) Add support for other filter sizes and strides.
// TODO(andydavis) Add support for autotuning.
bool CanUseDeepConv2D(int stride_rows, int stride_cols, int filter_rows,
                      int filter_cols, int in_depth, int out_depth,
                      int out_rows, int out_cols) {
  // Check if convolution parameters are supported.
  // TODO(andydavis) Add support for multiple filter sizes and strides.
  if (stride_rows > 1 || stride_cols > 1 || filter_rows != 3 ||
      filter_cols != 3) {
    return false;
  }

  // Check if deep convolution is enabled by environment variable.
  // NOTE: If this environment variable name changes, update conv_ops_test.py.
  if (!ReadBoolFromEnvVar("TF_USE_DEEP_CONV2D", false)) {
    return false;
  }

  // Check if flop cost of deep convolution is less than direct convolution.
  WinogradTransform<float> t;
  const int64_t deep_conv_cost = GetDeepConvCost(
      t.input_shape().rows, t.input_shape().cols, t.output_shape().rows,
      t.output_shape().cols, in_depth, out_depth, out_rows, out_cols);
  const int64_t direct_conv_cost = GetDirectConvCost(
      filter_rows, filter_cols, in_depth, out_depth, out_rows, out_cols);

  VLOG(2) << "CanUseDeepConv2D"
          << " deep_conv_cost: " << deep_conv_cost
          << " direct_conv_cost: " << direct_conv_cost << " deep_direct_ratio: "
          << (static_cast<float>(deep_conv_cost) /
              static_cast<float>(direct_conv_cost))
          << " use_deep_conv: " << (deep_conv_cost < direct_conv_cost);
  return deep_conv_cost < direct_conv_cost;
}

typedef Eigen::ThreadPoolDevice CPUDevice;

// Copies data from 'filter_in' to 'filter_buf' along 'in_depth' dimension.
//
// filter_in:
//   [filter_rows, filter_cols, in_depth, out_depth]
//
// filter_buf:
//   [base_filter_rows, base_filter_cols, in_depth]
//
template <typename T>
struct CopyFilterDepth {
  void operator()(const Conv2DArgs& args, const T* filter_in, T* filter_buf) {
    typedef typename Eigen::internal::packet_traits<T>::type Packet;
    static constexpr int64_t kPacketSize = (sizeof(Packet) / sizeof(T));

    const int64_t vectorized_size = args.in_depth / kPacketSize;
    const int64_t scalar_size = args.in_depth % kPacketSize;
    const int64_t input_stride = args.out_depth * kPacketSize;

    // Copy vectorized portion of depth dimension.
    for (int64_t d = 0; d < vectorized_size; ++d) {
      auto v = Eigen::internal::pgather<T, Packet>(filter_in + d * input_stride,
                                                   args.out_depth);
      Eigen::internal::pstoreu<T>(filter_buf + d * kPacketSize, v);
    }
    // Copy scalar portion of inner dimension.
    const int64_t in_scalar_base = vectorized_size * input_stride;
    const int64_t buf_scalar_base = vectorized_size * kPacketSize;
    for (int64_t d = 0; d < scalar_size; ++d) {
      filter_buf[buf_scalar_base + d] =
          filter_in[in_scalar_base + d * args.out_depth];
    }
  }
};
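
// Illustrative example of the split above: with T = float and an 8-wide
// packet (kPacketSize = 8, hardware dependent), in_depth = 20 yields
// vectorized_size = 2 and scalar_size = 4, so two pgather/pstoreu pairs move
// 16 depth values (gathered with stride 'out_depth' from 'filter_in') and
// the final 4 values are copied one at a time.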

// Computes transform of 'num_filters' from 'filter_in' starting at 'od_start'.
// Intermediate results (i.e. output of MatMul('transform_matrix', 'filter_in'))
// are stored in 'out_buffer'. The final result is copied from 'out_buffer' to
// 'filter_out' at the coordinate stride required by the transformed filter
// data layout.
//
// filter_in:
//   [base_filter_rows, base_filter_cols, num_filters, shard_rows, shard_cols,
//    in_depth]
//
// filter_out:
//   [tile_rows, tile_cols, out_depth, shard_rows, shard_cols, in_depth]
//
// transform_matrix:
//   [tile_spatial_size, base_filter_spatial_size]
//
// out_buffer:
//   [tile_spatial_size, num_filters, shard_rows, shard_cols, in_depth]

template <typename T>
struct ComputeFilterRangeTransform {
  typedef typename Eigen::internal::packet_traits<T>::type Packet;
  static constexpr int64_t kPacketSize = (sizeof(Packet) / sizeof(T));

  typedef Eigen::Map<
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      MatrixMap;
  typedef Eigen::Map<
      const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      ConstMatrixMap;

  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const int64_t od_start, const int64_t num_filters,
                  const int64_t shard_rows, const int64_t shard_cols,
                  const T* filter_in, const int64_t in_stride,
                  const int64_t out_stride, const T* transform_matrix,
                  T* out_buffer, T* filter_out) {
    namespace ei = Eigen::internal;

    const int64_t in_depth = args.in_depth;
    const int64_t base_filter_rows = transform->filter_shape().rows;
    const int64_t base_filter_cols = transform->filter_shape().cols;
    const int64_t base_filter_spatial_size =
        base_filter_rows * base_filter_cols;
    const int64_t tile_rows = transform->input_shape().rows;
    const int64_t tile_cols = transform->input_shape().cols;
    const int64_t tile_spatial_size = tile_rows * tile_cols;

    // Compute transform of 'num_filters' by 'transform_matrix'.
    ConstMatrixMap A(transform_matrix, tile_spatial_size,
                     base_filter_spatial_size);
    ConstMatrixMap B(filter_in, base_filter_spatial_size, in_stride);
    MatrixMap C(out_buffer, tile_spatial_size, in_stride);

    C.noalias() = A * B;

    // Copy 'out_buffer' to 'filter_out' at required filter output stride.
    const int64_t scalar_size = in_depth % kPacketSize;
    const int64_t vectorized_size = in_depth / kPacketSize;

    const int64_t shard_stride = args.in_depth;
    const int64_t out_depth_stride = shard_rows * shard_cols * shard_stride;

    for (int64_t od = 0; od < num_filters; ++od) {
      const int64_t out_depth_buf_base = od * out_depth_stride;
      const int64_t out_depth_base = (od_start + od) * out_depth_stride;

      // TODO(andydavis) Shard filters that are multiples of base filter sizes.
      for (int64_t s_r = 0; s_r < shard_rows; ++s_r) {
        for (int64_t s_c = 0; s_c < shard_cols; ++s_c) {
          const int64_t shard_base = shard_stride * (s_r * shard_cols + s_c);

          for (int64_t i = 0; i < tile_spatial_size; ++i) {
            const int64_t in_base =
                i * in_stride + out_depth_buf_base + shard_base;
            const int64_t out_base =
                i * out_stride + out_depth_base + shard_base;
            // Copy vectorized portion of 'in_depth'.
            for (int64_t d = 0; d < vectorized_size; ++d) {
              auto v =
                  ei::ploadu<Packet>(out_buffer + in_base + d * kPacketSize);
              ei::pstoreu<T>(filter_out + out_base + d * kPacketSize, v);
            }
            // Copy scalar portion of 'in_depth'.
            const int64_t scalar_base = vectorized_size * kPacketSize;
            for (int64_t d = 0; d < scalar_size; ++d) {
              filter_out[out_base + scalar_base + d] =
                  out_buffer[in_base + scalar_base + d];
            }
          }
        }
      }
    }
  }
};
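
// For orientation (illustrative 4x4/3x3 Winograd shapes): the MatMul above is
//   C[16 x in_stride] = A[16 x 9] * B[9 x in_stride],
// lifting each of the in_stride = num_filters * shard_rows * shard_cols *
// in_depth filter columns from 9 base-filter coefficients to 16
// transform-domain coefficients in a single product.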

// Transforms 'num_filters' from 'filter_in', starting at 'od_start'.
// For each filter in 'num_filters', copies data for all filter shards from
// 'filter_in' into 'filter_buf', adding zero-padding as needed.
// Calls ComputeFilterRangeTransform to compute filter transform of data
// in 'filter_buf' by 'transform_matrix', storing the result in 'filter_out'.
//
// filter_in:
//   [filter_rows, filter_cols, in_depth, out_depth]
//
// filter_out:
//   [tile_rows, tile_cols, out_depth, shard_rows, shard_cols, in_depth]
//
// filter_buffer:
//   [base_filter_rows, base_filter_cols, num_filters, shard_rows, shard_cols,
//    in_depth]
//
// transform_matrix:
//   [tile_spatial_size, base_filter_spatial_size]
//
// out_buffer:
//   [tile_spatial_size, num_filters, shard_rows, shard_cols, in_depth]
//

template <typename T>
struct TransformFilterRange {
  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const int64_t od_start, const int64_t od_limit,
                  const T* filter_in, const T* transform_matrix, T* out_buffer,
                  T* filter_buf, T* filter_out) {
    const int64_t num_filters = od_limit - od_start;
    const int64_t base_filter_rows = transform->filter_shape().rows;
    const int64_t base_filter_cols = transform->filter_shape().cols;
    const int64_t base_filter_spatial_size =
        base_filter_rows * base_filter_cols;

    // Compute number of filter shards.
    const int64_t residual_row =
        std::max(int64_t{0}, args.filter_rows - base_filter_rows);
    const int64_t shard_rows = 1 + (residual_row + 2 - 1) / 2;

    const int64_t residual_col =
        std::max(int64_t{0}, args.filter_cols - base_filter_cols);
    const int64_t shard_cols = 1 + (residual_col + 2 - 1) / 2;
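
    // The shard count is 1 + ceil(residual / 2). For example, a 3x3 filter
    // with a 3x3 base filter has residual 0 and uses a single 1x1 shard grid,
    // while a (currently unsupported) 5x5 filter would have residual 2 per
    // dimension and be split across a 2x2 grid of base-filter-sized shards.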

    // Compute strides to be used for input and output IO.
    const int64_t shard_stride = args.in_depth;
    const int64_t out_depth_stride = shard_rows * shard_cols * shard_stride;
    const int64_t coord_stride = out_depth_stride * args.out_depth;
    const int64_t filter_buf_stride =
        num_filters * shard_rows * shard_cols * args.in_depth;
    const int64_t tile_stride_rows = transform->output_shape().rows;
    const int64_t tile_stride_cols = transform->output_shape().cols;

    const int64_t filter_buf_size = base_filter_spatial_size * num_filters *
                                    shard_rows * shard_cols * args.in_depth;
    memset(filter_buf, 0, sizeof(T) * filter_buf_size);

    // Copy filter range into 'filter_buf'.
    for (int64_t od = 0; od < num_filters; ++od) {
      const int64_t out_depth_base = od * out_depth_stride;

      // TODO(andydavis) Shard filters that are multiples of base filter sizes.
      for (int64_t s_r = 0; s_r < shard_rows; ++s_r) {
        const int64_t row_offset = s_r == 0 ? 0 : 1;

        for (int64_t s_c = 0; s_c < shard_cols; ++s_c) {
          const int64_t col_offset = s_c == 0 ? 0 : 1;
          const int64_t f_r_start = s_r * tile_stride_rows;
          const int64_t f_c_start = s_c * tile_stride_cols;

          const int64_t shard_base = shard_stride * (s_r * shard_cols + s_c);

          for (int64_t b_r = row_offset; b_r < base_filter_rows; ++b_r) {
            const int64_t f_r = f_r_start + b_r;
            if (f_r >= args.filter_rows) continue;

            for (int64_t b_c = col_offset; b_c < base_filter_cols; ++b_c) {
              const int64_t f_c = f_c_start + b_c;
              if (f_c >= args.filter_cols) continue;

              const int64_t in_index =
                  args.out_depth *
                      (args.in_depth * (f_r * args.filter_cols + f_c)) +
                  (od_start + od);

              const int64_t buf_index =
                  filter_buf_stride * (b_r * base_filter_cols + b_c) +
                  out_depth_base + shard_base;

              CopyFilterDepth<T>()(args, filter_in + in_index,
                                   filter_buf + buf_index);
            }
          }
        }
      }
    }

    // Compute filter transform of data in 'filter_buf' by 'transform_matrix'.
    // Intermediate results are stored in 'out_buffer'.
    // Final results are stored in 'filter_out'.
    ComputeFilterRangeTransform<T>()(args, transform, od_start, num_filters,
                                     shard_rows, shard_cols, filter_buf,
                                     filter_buf_stride, coord_stride,
                                     transform_matrix, out_buffer, filter_out);
  }
};

// Transforms all filters from 'filter_in', storing result in 'filter_out'.
//
// filter_in:
//   [filter_rows, filter_cols, in_depth, out_depth]
//
// filter_out:
//   [tile_rows, tile_cols, out_depth, shard_rows, shard_cols, in_depth]
//
template <typename T>
struct TransformFilters {
  void operator()(OpKernelContext* ctx, const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const int64_t filter_shards_row,
                  const int64_t filter_shards_col, const T* filter_in,
                  T* filter_out) {
    const int64_t in_depth = args.in_depth;
    const int64_t out_depth = args.out_depth;

    const int64_t tile_rows = transform->input_shape().rows;
    const int64_t tile_cols = transform->input_shape().cols;
    const int64_t tile_spatial_size = tile_rows * tile_cols;

    const int64_t base_filter_rows = transform->filter_shape().rows;
    const int64_t base_filter_cols = transform->filter_shape().cols;
    const int64_t base_filter_spatial_size =
        base_filter_rows * base_filter_cols;

    const int64_t filter_shards_total = filter_shards_row * filter_shards_col;

    // Calculate filter transform batch based on cache/filter sizes.

    // Cache budget (based on L2 cache size = 256KB).
    // TODO(andydavis) Read cache size from system.
    const int64_t cache_size = (256LL << 10) / sizeof(T);

    // Fixed cost.
    const int64_t filter_transform_matrix_size =
        tile_spatial_size * base_filter_spatial_size;

    // Per-filter costs.
    const int64_t filter_total_size =
        base_filter_spatial_size * in_depth * filter_shards_total;

    const int64_t filter_transform_buffer_size =
        base_filter_spatial_size * filter_shards_total * in_depth;

    const int64_t filter_out_buf_size =
        tile_spatial_size * filter_shards_total * in_depth;

    // Total per-filter costs.
    const int64_t per_filter_cost =
        filter_total_size + filter_transform_buffer_size + filter_out_buf_size;

    // Remove fixed cost and divide by per-filter cost.
    const int64_t num_filters_cache =
        std::max(int64_t{1},
                 (cache_size - filter_transform_matrix_size) / per_filter_cost);
    const int64_t num_filters_transform =
        std::min(out_depth, num_filters_cache);
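
    // Illustrative sizing (assuming T = float, 4x4 tiles, a 3x3 base filter,
    // in_depth = 64 and a 1x1 shard grid): cache_size = 256KB / 4 = 65536
    // elements, fixed cost = 16 * 9 = 144, per-filter cost =
    // (9 + 9 + 16) * 64 = 2176, so (65536 - 144) / 2176 = 30 filters are
    // transformed per batch.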

    // Allocate buffer for filter transform matrix:
    //   [tile_spatial_size, base_filter_spatial_size]
    Tensor filter_transform_matrix;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_temp(
                 DataTypeToEnum<T>::value,
                 TensorShape({tile_spatial_size, base_filter_spatial_size}),
                 &filter_transform_matrix));
    T* transform_matrix = filter_transform_matrix.template flat<T>().data();
    transform->GetFilterTransformMatrix(
        tile_spatial_size, base_filter_spatial_size, transform_matrix);

    auto shard = [&ctx, &args, &transform, &base_filter_rows, &base_filter_cols,
                  &num_filters_transform, &in_depth, &filter_shards_row,
                  &filter_shards_col, &tile_spatial_size, &filter_in,
                  &transform_matrix,
                  &filter_out](int64_t start, int64_t limit) {
      // Allocate buffer for pre-processed filter:
      //   [base_filter_rows, base_filter_cols, num_filters_transform, in_depth]
      //
      Tensor filter_transform_buffer;
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_temp(
                         DataTypeToEnum<T>::value,
                         TensorShape({base_filter_rows, base_filter_cols,
                                      num_filters_transform, filter_shards_row,
                                      filter_shards_col, in_depth}),
                         &filter_transform_buffer));
      T* filter_buf = filter_transform_buffer.template flat<T>().data();

      // Allocate buffer for output filter transform matrix:
      //   [tile_rows, tile_cols, out_depth, shard_rows, shard_cols, in_depth]
      Tensor filter_output_buffer;
      OP_REQUIRES_OK(
          ctx,
          ctx->allocate_temp(
              DataTypeToEnum<T>::value,
              TensorShape({tile_spatial_size, num_filters_transform,
                           filter_shards_row, filter_shards_col, in_depth}),
              &filter_output_buffer));
      T* out_buffer = filter_output_buffer.template flat<T>().data();

      const int64_t num_filters = limit - start;
      const int64_t od_unroll = num_filters_transform;
      const int64_t od_unroll_limit = (num_filters / od_unroll) * od_unroll;

      for (int64_t od = start; od < od_unroll_limit; od += od_unroll) {
        TransformFilterRange<T>()(args, transform, od, od + od_unroll,
                                  filter_in, transform_matrix, out_buffer,
                                  filter_buf, filter_out);
      }

      if (od_unroll_limit < limit) {
        TransformFilterRange<T>()(args, transform, od_unroll_limit, limit,
                                  filter_in, transform_matrix, out_buffer,
                                  filter_buf, filter_out);
      }
    };
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());

    const int64_t shard_cost = args.filter_rows * args.filter_cols * in_depth *
                               filter_shards_total * tile_spatial_size;
    // TODO(andydavis) Resolve performance of multi-threaded filter transforms.
    Shard(1, worker_threads.workers, out_depth, shard_cost, shard);
  }
};

// Packs transformed filters stored in 'lhs_input' into 'lhs_block' in a
// gemm-kernel friendly data layout.
//
// Data layout for 'lhs_block':
//   [out_depth, shard_rows, shard_cols, in_depth].

template <typename T>
class GemmFilterPacker {
 public:
  typedef Eigen::internal::const_blas_data_mapper<T, int64_t, Eigen::RowMajor>
      LhsMapper;
  typedef Eigen::internal::gebp_traits<T, T> Traits;
  Eigen::internal::gemm_pack_lhs<
      T, int64_t, LhsMapper, Traits::mr, Traits::LhsProgress,
      typename Traits::LhsPacket4Packing, Eigen::RowMajor>
      pack_lhs;

  GemmFilterPacker(const int64_t rows, const int64_t depth, const T* lhs_input,
                   T* lhs_block)
      : rows_(rows),
        depth_(depth),
        lhs_block_(lhs_block),
        lhs_mapper_(lhs_input, depth_) {}

  void Run() { pack_lhs(lhs_block_, lhs_mapper_, depth_, rows_); }

 private:
  const int64_t rows_;
  const int64_t depth_;
  T* lhs_block_;
  LhsMapper lhs_mapper_;
};

// Packs transformed filter stored in 'filter_transform_data' into
// 'packed_filters' to be used by GemmState.
template <typename T>
struct PackFilters {
  void operator()(OpKernelContext* ctx, const Conv2DArgs& args,
                  const int64_t tile_spatial_size,
                  const int64_t filter_shards_row,
                  const int64_t filter_shards_col,
                  const T* filter_transform_data,
                  std::vector<Tensor>* packed_filters) {
    const int64_t in_depth = args.in_depth;
    const int64_t out_depth = args.out_depth;
    const int64_t num_filters =
        filter_shards_row * filter_shards_col * out_depth;

    auto shard = [&ctx, &packed_filters, &filter_transform_data, &in_depth,
                  &out_depth, &filter_shards_row, &filter_shards_col,
                  &num_filters](int64_t start, int64_t limit) {
      const int64_t filter_coord_stride = num_filters * in_depth;
      for (int64_t i = start; i < limit; ++i) {
        // Allocate filter buffer [out_depth, shard_rows, shard_cols, in_depth].
        OP_REQUIRES_OK(
            ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                    TensorShape({out_depth, filter_shards_row,
                                                 filter_shards_col, in_depth}),
                                    &(*packed_filters)[i]));
        T* packed_filter = (*packed_filters)[i].template flat<T>().data();
        // Pack filters.
        GemmFilterPacker<T> packer(
            num_filters, in_depth,
            filter_transform_data + i * filter_coord_stride, packed_filter);
        packer.Run();
      }
    };
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, tile_spatial_size,
          num_filters * in_depth, shard);
  }
};

// Computes the product of filters stored in 'lhs_block' and input tiles
// stored in 'rhs_block', storing output in 'out_buffer'.
//
// Data layout for 'lhs_block':
//   [out_depth, shard_rows, shard_cols, in_depth].
//
// Data layout for 'rhs_block':
//   [num_tiles, in_depth]
//
// Data layout for 'out_buffer':
//   [num_tiles, out_depth, shard_rows, shard_cols]
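//
// Viewed as a GEMM (illustrative shapes): for each of the tile_spatial_size
// transform coordinates, ComputeConv2D below constructs a GemmState computing
//   out[num_filters x num_tiles] =
//       lhs[num_filters x in_depth] * rhs[in_depth x num_tiles],
// where num_filters = out_depth * shard_rows * shard_cols.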

template <typename T>
class GemmState {
 public:
  typedef Eigen::internal::const_blas_data_mapper<T, int64_t, Eigen::ColMajor>
      RhsMapper;
  typedef Eigen::internal::blas_data_mapper<T, int64_t, Eigen::ColMajor>
      OutputMapper;
  typedef Eigen::internal::gebp_traits<T, T> Traits;

  Eigen::internal::gemm_pack_rhs<T, int64_t, RhsMapper, Traits::nr,
                                 Eigen::ColMajor>
      pack_rhs;
  Eigen::internal::gebp_kernel<T, T, int64_t, OutputMapper, Traits::mr,
                               Traits::nr, false, false>
      gebp;

  GemmState(const int64_t rows, const int64_t cols, const int64_t depth,
            const int64_t out_buffer_size, const T* lhs_block,
            const T* rhs_input, T* rhs_block, T* out_buffer)
      : rows_(rows),
        cols_(cols),
        depth_(depth),
        out_buffer_size_(out_buffer_size),
        lhs_block_(lhs_block),
        rhs_block_(rhs_block),
        out_buffer_(out_buffer),
        rhs_mapper_(rhs_input, depth_),
        out_mapper_(out_buffer, rows_) {}

  void PackRhs() { pack_rhs(rhs_block_, rhs_mapper_, depth_, cols_); }

  void Compute() {
    memset(out_buffer_, 0, sizeof(T) * out_buffer_size_);
    gebp(out_mapper_, lhs_block_, rhs_block_, rows_, depth_, cols_, 1.0);
  }

 private:
  const int64_t rows_;
  const int64_t cols_;
  const int64_t depth_;
  const int64_t out_buffer_size_;
  const T* lhs_block_;
  T* rhs_block_;
  T* out_buffer_;
  RhsMapper rhs_mapper_;
  OutputMapper out_mapper_;
};

// Copies an input tile from 'input' into 'tile_buffer'.
//
// input:
//   [in_rows, in_cols, in_depth]
//
// tile_buffer:
//   [tile_rows, tile_cols, num_tiles, in_depth]

template <typename T>
struct CopyInputTile {
  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const int64_t num_tiles, const int64_t in_r_start,
                  const int64_t in_c_start, const T* input, T* tile_buffer) {
    typedef typename Eigen::internal::packet_traits<T>::type Packet;
    static const int64_t kPacketSize = (sizeof(Packet) / sizeof(T));

    const int64_t tile_rows = transform->input_shape().rows;
    const int64_t tile_cols = transform->input_shape().cols;
    const int64_t coord_stride = num_tiles * args.in_depth;

    // Calculate vectorized and scalar (residual) lengths for 'in_depth'.
    const int64_t input_vectorized_size =
        (args.in_depth / kPacketSize) * kPacketSize;
    const int64_t input_scalar_size = args.in_depth % kPacketSize;

    for (int64_t r = 0; r < tile_rows; ++r) {
      const int64_t in_r = in_r_start + r;
      if (in_r < 0 || in_r >= args.in_rows) continue;

      for (int64_t c = 0; c < tile_cols; ++c) {
        const int64_t in_c = in_c_start + c;
        if (in_c < 0 || in_c >= args.in_cols) continue;

        auto* in = input + (in_r * args.in_cols + in_c) * args.in_depth;
        // Row-major [tile_rows, tile_cols, ...] indexing: the row stride is
        // 'tile_cols' (equal to 'tile_rows' for the square tiles used here).
        auto* tile = tile_buffer + coord_stride * (r * tile_cols + c);
        // Copy vectorized portion of depth dimension.
        for (int64_t d = 0; d < input_vectorized_size; d += kPacketSize) {
          auto v = Eigen::internal::ploadu<Packet>(in + d);
          Eigen::internal::pstoreu<T>(tile, v);
          tile += kPacketSize;
        }
        // Copy scalar portion of inner dimension.
        for (int64_t d = 0; d < input_scalar_size; ++d) {
          tile[d] = in[input_vectorized_size + d];
        }
      }
    }
  }
};

// Transforms 'num_tiles' tiles from 'input' by 'transform_matrix', storing the
// final result in 'tile_transform'.
// Intermediate results are stored in 'tile_buffer'.
//
// input:
//   [in_rows, in_cols, in_depth]
// tile_buffer:
//   [tile_rows, tile_cols, num_tiles, in_depth]
// tile_transform_matrix:
//   [tile_spatial_size, tile_spatial_size]
// tile_transform:
//   [tile_rows, tile_cols, num_tiles, in_depth]

template <typename T>
struct TransformInputTiles {
  typedef Eigen::Map<
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      MatrixMap;
  typedef Eigen::Map<
      const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      ConstMatrixMap;

  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const int64_t num_tiles, const int64_t in_r_start,
                  const int64_t in_c_start, const T* input,
                  const T* transform_matrix, T* tile_buffer,
                  T* tile_transform) {
    const int64_t tile_rows = transform->input_shape().rows;
    const int64_t tile_cols = transform->input_shape().cols;
    const int64_t tile_spatial_size = tile_rows * tile_cols;
    const int64_t tile_stride_cols = transform->output_shape().cols;
    const int64_t coord_stride = num_tiles * args.in_depth;
    const int64_t num_tiles_stride = args.in_depth;

    memset(tile_buffer, 0, sizeof(T) * tile_spatial_size * coord_stride);
    const int64_t in_r = in_r_start;
    for (int64_t t = 0; t < num_tiles; ++t) {
      const int64_t num_tiles_base = t * num_tiles_stride;
      const int64_t in_c = in_c_start + t * tile_stride_cols;
      CopyInputTile<T>()(args, transform, num_tiles, in_r, in_c, input,
                         tile_buffer + num_tiles_base);
    }

    ConstMatrixMap A(transform_matrix, tile_spatial_size, tile_spatial_size);
    ConstMatrixMap B(tile_buffer, tile_spatial_size, coord_stride);
    MatrixMap C(tile_transform, tile_spatial_size, coord_stride);

    C.noalias() = A * B;
  }
};

// Transforms output tiles from buffer by 'out_transform_matrix', storing
// final result in 'output' (intermediate results stored in 'out_buffer').
//
// out_buffer:
//   [tile_rows, tile_cols, num_tiles, out_depth, shard_rows, shard_cols]
//
// output transform buffer:
//  [out_tile_rows, out_tile_cols, num_tiles, out_depth, shard_rows, shard_cols]
//
// output:
//   [out_rows, out_cols, out_depth]
//

template <typename T>
struct TransformOutputTile {
  typedef Eigen::Map<
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      MatrixMap;
  typedef Eigen::Map<
      const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      ConstMatrixMap;

  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const int64_t num_tiles, const int64_t in_r,
                  const int64_t in_c, const int64_t filter_shards_row,
                  const int64_t filter_shards_col,
                  const T* out_transform_matrix, const T* out_buffer,
                  T* out_transform_buffer, T* output) {
    const int64_t tile_rows = transform->input_shape().rows;
    const int64_t tile_cols = transform->input_shape().cols;
    const int64_t tile_spatial_size = tile_rows * tile_cols;

    const int64_t out_buf_stride =
        num_tiles * args.out_depth * filter_shards_row * filter_shards_col;

    const int64_t out_tile_rows = transform->output_shape().rows;
    const int64_t out_tile_cols = transform->output_shape().cols;
    const int64_t out_tile_spatial_size = out_tile_rows * out_tile_cols;

    // Compute output transform.
    ConstMatrixMap A(out_transform_matrix, out_tile_spatial_size,
                     tile_spatial_size);
    ConstMatrixMap B(out_buffer, tile_spatial_size, out_buf_stride);
    MatrixMap C(out_transform_buffer, out_tile_spatial_size, out_buf_stride);

    C.noalias() = A * B;

    const int64_t tile_stride_rows = transform->output_shape().rows;
    const int64_t tile_stride_cols = transform->output_shape().cols;

    const int64_t out_depth_stride = filter_shards_row * filter_shards_col;
    const int64_t num_tiles_stride = args.out_depth * out_depth_stride;

    // Copy transformed output from 'out_transform_buffer' to proper index
    // in 'output'. Note that some outputs at boundaries can be discarded.
    for (int64_t t = 0; t < num_tiles; ++t) {
      const int64_t tile_base = t * num_tiles_stride;

      for (int64_t od = 0; od < args.out_depth; ++od) {
        const int64_t out_depth_base = od * out_depth_stride;

        // TODO(andydavis) Update filter sharding scheme in the next CL.
        for (int64_t sr = 0; sr < filter_shards_row; ++sr) {
          for (int64_t sc = 0; sc < filter_shards_col; ++sc) {
            const int64_t shard_base = sr * filter_shards_col + sc;
            const int64_t out_buf_base =
                tile_base + out_depth_base + shard_base;

            // Calculate output indices and outputs to drop (if needed).
            const int64_t out_r_start =
                in_r + args.pad_rows - sr * tile_stride_rows;
            // NOTE: The tile index 't' appears in the calculation of
            // 'out_c_start' because 'num_tiles' progresses along the column
            // dimension.
            const int64_t out_c_start = (in_c + t * tile_stride_cols) +
                                        args.pad_cols - sc * tile_stride_cols;

            if (out_r_start < 0 || out_r_start >= args.out_rows ||
                out_c_start < 0 || out_c_start >= args.out_cols) {
              continue;  // Skip un-needed outputs.
            }

            // Increment output if not first filter shard.
            const bool inc_output = !(sr == 0 && sc == 0);

            for (int64_t ot_row = 0; ot_row < out_tile_rows; ++ot_row) {
              const int64_t out_r = out_r_start + ot_row;
              if (out_r >= args.out_rows) continue;

              for (int64_t ot_col = 0; ot_col < out_tile_cols; ++ot_col) {
                const int64_t out_c = out_c_start + ot_col;
                if (out_c >= args.out_cols) continue;

                // Calculate out tile index.
                const int64_t out_buf_index = ot_row * out_tile_cols + ot_col;
                // Read output value from buffer.
                const T out_val =
                    out_transform_buffer[out_buf_base +
                                         out_buf_index * out_buf_stride];
                // Calculate output index.
                const int64_t output_index =
                    args.out_depth * (out_r * args.out_cols + out_c) + od;
                // Update output.
                if (inc_output) {
                  output[output_index] += out_val;
                } else {
                  output[output_index] = out_val;
                }
              }
            }
          }
        }
      }
    }
  }
};
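
// Illustratively (4x4/3x3 Winograd shapes again): the output transform above
// computes C[4 x out_buf_stride] = A[4 x 16] * B[16 x out_buf_stride],
// mapping each column of 16 transform-domain values back to a 2x2 output
// tile, which is then scattered into 'output' with boundary tiles cropped.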

template <typename T>
struct Conv2DState {
  Conv2DState(const int64_t tile_spatial_size, const int64_t filter_shards_row,
              const int64_t filter_shards_col, const T* input,
              const T* tile_transform_matrix, const T* output_transform_matrix,
              T* buffer1, T* buffer2, T* packed_tile_buffer,
              T* gemm_output_buffer)
      : tile_spatial_size(tile_spatial_size),
        filter_shards_row(filter_shards_row),
        filter_shards_col(filter_shards_col),
        input(input),
        tile_transform_matrix(tile_transform_matrix),
        output_transform_matrix(output_transform_matrix),
        buffer1(buffer1),
        buffer2(buffer2),
        packed_tile_buffer(packed_tile_buffer),
        gemm_output_buffer(gemm_output_buffer) {}

  const int64_t tile_spatial_size;
  const int64_t filter_shards_row;
  const int64_t filter_shards_col;
  const T* input;
  const T* tile_transform_matrix;
  const T* output_transform_matrix;
  T* buffer1;
  T* buffer2;
  T* packed_tile_buffer;
  T* gemm_output_buffer;
};

// Computes Conv2D for 'num_tiles' input tiles from 'input' starting at
// (in_r, in_c), storing the results of the computation in 'output'.
// Details:
// *) Transforms 'num_tiles' input tiles into 'tile_transform_buffer'.
// *) Computes point-wise MatMuls of 'num_tiles' input tiles with all filters.
// *) Transforms output tiles, and stores result to 'output'.

// TODO(andydavis) Maybe pass Conv2DState into TransformInput/Output functions.
template <typename T>
struct ComputeConv2D {
  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const Conv2DState<T>& cs, const int64_t in_r,
                  const int64_t in_c, const int64_t num_tiles,
                  const std::vector<Tensor>& packed_filters, const T* input,
                  T* output) {
    // Transform input tiles.
    TransformInputTiles<T>()(args, transform, num_tiles, in_r, in_c, input,
                             cs.tile_transform_matrix, cs.buffer1, cs.buffer2);

    // Compute element-wise product (each a MatMul): input tiles X filters.
    const int64_t in_depth = args.in_depth;
    const int64_t out_depth = args.out_depth;
    const int64_t num_filters =
        cs.filter_shards_row * cs.filter_shards_col * out_depth;
    const int64_t tile_coord_stride = num_tiles * in_depth;
    const int64_t gemm_out_buf_size = num_tiles * num_filters;
    const int64_t gemm_out_buf_bytes = gemm_out_buf_size * sizeof(T);

    for (int64_t i = 0; i < cs.tile_spatial_size; ++i) {
      GemmState<T> gemm(num_filters, num_tiles, in_depth, gemm_out_buf_size,
                        packed_filters[i].template flat<T>().data(),
                        cs.buffer2 + i * tile_coord_stride,
                        cs.packed_tile_buffer, cs.gemm_output_buffer);
      // Pack tile buffer.
      gemm.PackRhs();
      // Compute product.
      gemm.Compute();
      // Copy to larger output buffer without alignment requirements.
      memcpy(cs.buffer1 + i * gemm_out_buf_size, cs.gemm_output_buffer,
             gemm_out_buf_bytes);
    }

    // Transform output.
    TransformOutputTile<T>()(args, transform, num_tiles, in_r, in_c,
                             cs.filter_shards_row, cs.filter_shards_col,
                             cs.output_transform_matrix, cs.buffer1, cs.buffer2,
                             output);
  }
};

namespace functor {

// Conv2D operation specialized for deep convolutions (i.e. large
// in_depth * out_depth).
// Details:
// *) Transforms and packs filters from 'filter' in parallel.
// *) Computes Conv2D parallelized across 'batch' dimension.
//   *) Each thread loops over images in its batch shard, copying 'num_tiles'
//      input tiles into a local buffer, and computing the Conv2D output of
//      these tiles by all filters.

// TODO(andydavis) Improve the performance of boundary cases where the input
// tile extends past the limit, and wasted outputs are computed. This overhead
// is at most 2/n, where 'n' is max(out_rows, out_cols), and so is worse
// for smaller spatial sizes.
// TODO(andydavis) Improve the performance of sharded filters.
template <typename T>
struct DeepConv2D<CPUDevice, T> {
  void operator()(OpKernelContext* ctx, const Conv2DArgs& args, const T* input,
                  const T* filter, T* output) {
    // TODO(andydavis) Add function to select transform based on conv params.
    std::unique_ptr<DeepConv2DTransform<T>> transform(new WinogradTransform<T>);

    const int64_t in_depth = args.in_depth;
    const int64_t out_depth = args.out_depth;

    const int64_t tile_rows = transform->input_shape().rows;
    const int64_t tile_cols = transform->input_shape().cols;
    const int64_t tile_spatial_size = tile_rows * tile_cols;

    const int64_t out_tile_rows = transform->output_shape().rows;
    const int64_t out_tile_cols = transform->output_shape().cols;
    const int64_t out_tile_spatial_size = out_tile_rows * out_tile_cols;

    const int64_t base_filter_rows = transform->filter_shape().rows;
    const int64_t base_filter_cols = transform->filter_shape().cols;

    const int64_t filter_residual_row =
        std::max(int64_t{0}, args.filter_rows - base_filter_rows);
    const int64_t filter_shards_row = 1 + (filter_residual_row + 2 - 1) / 2;

    const int64_t filter_residual_col =
        std::max(int64_t{0}, args.filter_cols - base_filter_cols);
    const int64_t filter_shards_col = 1 + (filter_residual_col + 2 - 1) / 2;

    // Allocate buffer for transformed filters.
    Tensor filter_transform;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_temp(
                 DataTypeToEnum<T>::value,
                 TensorShape({tile_rows, tile_cols, out_depth,
                              filter_shards_row, filter_shards_col, in_depth}),
                 &filter_transform));
    T* filter_transform_data = filter_transform.template flat<T>().data();

    // Transform filters.
    TransformFilters<T>()(ctx, args, transform.get(), filter_shards_row,
                          filter_shards_col, filter, filter_transform_data);

    // Pack filters.
    std::vector<Tensor> packed_filters(tile_spatial_size);
    PackFilters<T>()(ctx, args, tile_spatial_size, filter_shards_row,
                     filter_shards_col, filter_transform_data, &packed_filters);

    // Allocate buffer for tile transform matrix.
    Tensor tile_transform_matrix_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(
                            DataTypeToEnum<T>::value,
                            TensorShape({tile_spatial_size, tile_spatial_size}),
                            &tile_transform_matrix_tensor));
    T* tile_transform_matrix =
        tile_transform_matrix_tensor.template flat<T>().data();
    transform->GetInputTransformMatrix(tile_spatial_size, tile_spatial_size,
                                       tile_transform_matrix);

    // Allocate buffer for output transform matrix.
    Tensor output_transform_matrix_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                           TensorShape({out_tile_spatial_size,
                                                        tile_spatial_size}),
                                           &output_transform_matrix_tensor));
    T* output_transform_matrix =
        output_transform_matrix_tensor.template flat<T>().data();
    transform->GetOutputTransformMatrix(
        out_tile_spatial_size, tile_spatial_size, output_transform_matrix);

    auto shard = [&ctx, &args, &transform, &packed_filters, &in_depth,
                  out_depth, out_tile_rows, out_tile_cols, filter_shards_row,
                  filter_shards_col, tile_spatial_size, &input,
                  &tile_transform_matrix, &output_transform_matrix,
                  &output](int64_t batch_start, int64_t batch_limit) {
      const int64_t row_tiles =
          (args.out_rows + out_tile_rows - 1) / out_tile_rows +
          filter_shards_row - 1;
      const int64_t col_tiles =
          (args.out_cols + out_tile_cols - 1) / out_tile_cols +
          filter_shards_col - 1;

      // Calculate number of tiles to process together.
      const int64_t filter_shard_size = filter_shards_row * filter_shards_col;
      const int64_t out_tile_spatial_size = out_tile_rows * out_tile_cols;

      // Cache budget (based on L2 cache size = 256KB).
      // TODO(andydavis) Read cache size from the system.
      const int64_t cache_size = (256LL << 10) / sizeof(T);

      // Fixed costs.
      const int64_t tile_transform_matrix_size =
          tile_spatial_size * tile_spatial_size;
      const int64_t output_transform_matrix_size =
          out_tile_spatial_size * tile_spatial_size;
      // Calculate cache reserve size.
      const int64_t filter_depth_size =
          in_depth * out_depth * filter_shard_size;
      const bool small_filter = ((filter_depth_size * 100) / cache_size) <= 25;
      const int64_t cache_reserve_size =
          small_filter ? filter_depth_size : 1024;
      // Calculate total fixed cost.
      const int64_t total_fixed_cost = tile_transform_matrix_size +
                                       output_transform_matrix_size +
                                       cache_reserve_size;

      // Per-tile costs.
      const int64_t buffer1_per_tile_size =
          tile_spatial_size * std::max(in_depth, out_depth * filter_shard_size);
      const int64_t buffer2_per_tile_size =
          std::max(tile_spatial_size * in_depth,
                   out_tile_spatial_size * out_depth * filter_shard_size);
      const int64_t packed_tile_per_tile_size = in_depth;
      const int64_t gemm_out_per_tile_size = out_depth * filter_shard_size;
      const int64_t total_per_tile_cost =
          buffer1_per_tile_size + buffer2_per_tile_size +
          packed_tile_per_tile_size + gemm_out_per_tile_size;

      const int64_t num_tiles_cache = std::max(
          int64_t{4}, (cache_size - total_fixed_cost) / total_per_tile_cost);
      const int64_t num_tiles = std::min(num_tiles_cache, col_tiles);
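
      // Illustrative sizing (assuming T = float, 4x4/2x2 tiles, in_depth =
      // out_depth = 64 and a 1x1 shard grid): cache_size = 65536 elements,
      // fixed cost = 256 + 64 + 4096 (small filter, reserved in full) = 4416,
      // per-tile cost = 1024 + 1024 + 64 + 64 = 2176, giving
      // num_tiles_cache = (65536 - 4416) / 2176 = 28 tiles per batch.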

      // Allocate temporary buffer 'buffer1', which is first used for copying
      // input tiles, then re-used to buffer gemm output. Calculate the
      // required buffer size for 'buffer1', based on max buffer size required
      // between copying input tiles and buffering gemm product output.
      //   buffer1: [max(buf1_tile_size, buf1_out_size)]
      const int64_t buffer1_tile_size =
          tile_spatial_size * num_tiles * in_depth;
      const int64_t buffer1_out_size =
          tile_spatial_size * num_tiles * out_depth * filter_shard_size;
      const int64_t buffer1_size =
          std::max(buffer1_tile_size, buffer1_out_size);
      Tensor buffer1_tensor;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             TensorShape({buffer1_size}),
                                             &buffer1_tensor));
      T* buffer1 = buffer1_tensor.template flat<T>().data();

      // Allocate temporary buffer 'buffer2', which is first used for
      // transformed input tiles, then re-used for transformed output tiles.
      // Calculate required buffer size for 'buffer2' as max required buffer
      // between input and output transform buffer sizes.
      const int64_t buffer2_tile_transform_size =
          tile_spatial_size * num_tiles * in_depth;
      const int64_t buffer2_out_transform_size =
          out_tile_spatial_size * num_tiles * out_depth * filter_shard_size;
      const int64_t buffer2_size =
          std::max(buffer2_tile_transform_size, buffer2_out_transform_size);
      Tensor buffer2_tensor;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             TensorShape({buffer2_size}),
                                             &buffer2_tensor));
      T* buffer2 = buffer2_tensor.template flat<T>().data();

      // Allocate temporary buffer to store packed tiles for one coordinate.
      // packed tile buffer: [num_tiles, in_depth].
      Tensor packed_tile_tensor;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             TensorShape({num_tiles, in_depth}),
                                             &packed_tile_tensor));
      T* packed_tile_buffer = packed_tile_tensor.template flat<T>().data();

      // Allocate temporary buffer for gemm output.
      // gemm output buffer [num_tiles, out_depth, shard_rows, shard_cols].
      Tensor gemm_output_tensor;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             TensorShape({num_tiles, out_depth,
                                                          filter_shards_row,
                                                          filter_shards_col}),
                                             &gemm_output_tensor));
      T* gemm_output_buffer = gemm_output_tensor.template flat<T>().data();

      // Capture state needed for ComputeConv2D inner loop.
      Conv2DState<T> conv_state(tile_spatial_size, filter_shards_row,
                                filter_shards_col, input, tile_transform_matrix,
                                output_transform_matrix, buffer1, buffer2,
                                packed_tile_buffer, gemm_output_buffer);

      const int64_t row_pad = args.pad_rows;
      const int64_t col_pad = args.pad_cols;
      const int64_t unroll_col_limit = (col_tiles / num_tiles) * num_tiles;

      const int64_t input_image_size = args.in_rows * args.in_cols * in_depth;
      const int64_t output_image_size =
          args.out_rows * args.out_cols * out_depth;

      const int64_t tile_stride_rows = transform->output_shape().rows;
      const int64_t tile_stride_cols = transform->output_shape().cols;

      for (int64_t b = batch_start; b < batch_limit; ++b) {
        const int64_t in_base = b * input_image_size;
        const int64_t out_base = b * output_image_size;

        for (int64_t tile_r = 0; tile_r < row_tiles; ++tile_r) {
          const int64_t in_r = tile_r * tile_stride_rows - row_pad;

          // Process unrolled tiles.
          for (int64_t tile_c = 0; tile_c < unroll_col_limit;
               tile_c += num_tiles) {
            const int64_t in_c = tile_c * tile_stride_cols - col_pad;
            ComputeConv2D<T>()(args, transform.get(), conv_state, in_r, in_c,
                               num_tiles, packed_filters, input + in_base,
                               output + out_base);
          }
          // Process remaining tiles.
          if (unroll_col_limit < col_tiles) {
            const int64_t rem_tiles = col_tiles - unroll_col_limit;
            const int64_t in_c = unroll_col_limit * tile_stride_cols - col_pad;
            ComputeConv2D<T>()(args, transform.get(), conv_state, in_r, in_c,
                               rem_tiles, packed_filters, input + in_base,
                               output + out_base);
          }
        }
      }
    };
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    const int64_t shard_cost = args.out_rows * args.out_cols * args.out_depth *
                               tile_spatial_size * args.in_depth;
    Shard(worker_threads.num_threads, worker_threads.workers, args.batch,
          shard_cost, shard);
  }
};

}  // namespace functor

template struct functor::DeepConv2D<CPUDevice, float>;

}  // namespace tensorflow