/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#ifdef INTEL_MKL

#include <algorithm>
#include <vector>

#include "tensorflow/core/kernels/mkl/mkl_conv_ops.h"
#include "tensorflow/core/util/use_cudnn.h"
#include "tensorflow/core/util/work_sharder.h"
#ifdef DNNL_AARCH64_USE_ACL
#include "tensorflow/core/platform/mutex.h"
#endif

using dnnl::convolution_backward_weights;
using dnnl::memory;
using dnnl::prop_kind;
using dnnl::stream;

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

using ConvBwdFilterDesc = dnnl::convolution_backward_weights::desc;
using ConvBwdFilterPd = dnnl::convolution_backward_weights::primitive_desc;

struct MklConvBwdFilterParams {
  memory::dims src_dims;
  memory::dims diff_filter_dims;
  memory::dims diff_bias_dims;
  memory::dims diff_dst_dims;
  memory::dims strides;
  MklTensorFormat tf_fmt;
  bool native_format;
  memory::dims dilations;
  memory::dims padding_left;
  memory::dims padding_right;

  MklConvBwdFilterParams(memory::dims src_dims, memory::dims diff_filter_dims,
                         memory::dims diff_bias_dims,
                         memory::dims diff_dst_dims, memory::dims strides,
                         MklTensorFormat tf_fmt, bool native_format,
                         memory::dims dilations, memory::dims padding_left,
                         memory::dims padding_right)
      : src_dims(src_dims),
        diff_filter_dims(diff_filter_dims),
        diff_bias_dims(diff_bias_dims),
        diff_dst_dims(diff_dst_dims),
        strides(strides),
        tf_fmt(tf_fmt),
        native_format(native_format),
        dilations(dilations),
        padding_left(padding_left),
        padding_right(padding_right) {}
};
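
// Illustrative sketch (not used by this file): populating the params for a
// hypothetical 3x3 NCHW Conv2D backward-filter computation; every dimension
// value below is made up for demonstration.
//
//   memory::dims src_dims = {32, 3, 224, 224};        // NCHW input
//   memory::dims diff_filter_dims = {64, 3, 3, 3};    // OIHW filter gradient
//   memory::dims diff_bias_dims = {};                 // empty => no bias
//   memory::dims diff_dst_dims = {32, 64, 222, 222};  // NCHW output gradient
//   MklConvBwdFilterParams params(
//       src_dims, diff_filter_dims, diff_bias_dims, diff_dst_dims,
//       /*strides=*/{1, 1}, MklTensorFormat::FORMAT_NCHW,
//       /*native_format=*/true, /*dilations=*/{0, 0},
//       /*padding_left=*/{0, 0}, /*padding_right=*/{0, 0});
//
// Note that dilations are in oneDNN convention here (0 means undilated); the
// kernel below subtracts 1 from each TensorFlow dilation before building
// these params.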

template <typename T>
class MklConvBwdFilterPrimitive : public MklPrimitive {
 public:
  explicit MklConvBwdFilterPrimitive(
      const MklConvBwdFilterParams& convBwdFilterDims)
      : MklPrimitive(engine(engine::kind::cpu, 0)) {
    // Create convolution backward filter primitive.
    if (context_.conv_bwd_filter == nullptr) {
      Setup(convBwdFilterDims);
    }
  }

  ~MklConvBwdFilterPrimitive() {}

  // Convolution backward weights execution with bias.
  //   src_data:         input data buffer for src
  //   diff_filter_data: output data buffer for diff_filter
  //   diff_bias_data:   output data buffer for diff_bias
  //   diff_dst_data:    input data buffer for diff_dst
  void Execute(const T* src_data, const T* diff_filter_data,
               const T* diff_bias_data, const T* diff_dst_data,
               std::shared_ptr<stream> bwd_filter_stream) {
#ifdef DNNL_AARCH64_USE_ACL
    mutex_lock lock(primitive_execution_mu_);
#endif
#ifndef ENABLE_ONEDNN_OPENMP
    // TODO(intel-tf): Create a common function and avoid the duplicate code
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(src_data)), *bwd_filter_stream);
    context_.diff_filter_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(diff_filter_data)),
        *bwd_filter_stream);
    if (diff_bias_data != nullptr) {
      context_.diff_bias_mem->set_data_handle(
          static_cast<void*>(const_cast<T*>(diff_bias_data)),
          *bwd_filter_stream);
    }
    context_.diff_dst_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(diff_dst_data)), *bwd_filter_stream);
#else
    context_.src_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(src_data)));
    context_.diff_filter_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(diff_filter_data)));
    if (diff_bias_data != nullptr) {
      context_.diff_bias_mem->set_data_handle(
          static_cast<void*>(const_cast<T*>(diff_bias_data)));
    }
    context_.diff_dst_mem->set_data_handle(
        static_cast<void*>(const_cast<T*>(diff_dst_data)));
#endif  // !ENABLE_ONEDNN_OPENMP
    execute_primitives(context_.bwd_filter_primitives, bwd_filter_stream,
                       context_.bwd_filter_primitives_args);

    context_.src_mem->set_data_handle(DummyData);
    context_.diff_filter_mem->set_data_handle(DummyData);
    if (diff_bias_data != nullptr) {
      context_.diff_bias_mem->set_data_handle(DummyData);
    }
    context_.diff_dst_mem->set_data_handle(DummyData);
  }

  // Convolution backward weights without bias.
  //   src_data:         input data buffer of src
  //   diff_filter_data: output data buffer of diff_filter
  //   diff_dst_data:    input data buffer of diff_dst
  void Execute(const T* src_data, const T* diff_filter_data,
               const T* diff_dst_data,
               std::shared_ptr<stream> bwd_filter_stream) {
    Execute(src_data, diff_filter_data, nullptr, diff_dst_data,
            bwd_filter_stream);
  }
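
  // Illustrative sketch (hypothetical buffers): a typical no-bias call.
  // `src`, `diff_filter`, and `diff_dst` stand for caller-owned buffers of
  // type T, and `cpu_stream` for a std::shared_ptr<stream> created on this
  // primitive's engine.
  //
  //   conv_bwd_filter->Execute(src, diff_filter, diff_dst, cpu_stream);
  //
  // Buffers are only borrowed for the duration of the call; Execute resets
  // the memory handles to DummyData before returning.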

  std::shared_ptr<ConvBwdFilterPd> GetPrimitiveDesc() const {
    return context_.bwd_filter_pd;
  }

 private:
  // Primitive reuse context for Conv2D backward filter op.
  struct ConvBwdFilterContext {
    // oneDNN memory for inputs and outputs.
    std::shared_ptr<dnnl::memory> src_mem;
    std::shared_ptr<dnnl::memory> diff_filter_mem;
    std::shared_ptr<dnnl::memory> diff_bias_mem;
    std::shared_ptr<dnnl::memory> diff_dst_mem;

    // Primitive descriptor and descriptor for convolution backward filter.
    std::shared_ptr<ConvBwdFilterPd> bwd_filter_pd;
    std::shared_ptr<ConvBwdFilterDesc> bwd_filter_desc;

    // Primitive descriptor and descriptor for convolution forward.
    std::shared_ptr<ConvFwdPd> fwd_pd;
    std::shared_ptr<ConvFwdDesc> fwd_desc;

    // Convolution backward filter primitive.
    std::shared_ptr<dnnl::primitive> conv_bwd_filter;

    // Memory descriptors: forward & backward share the same memory descriptors
    std::shared_ptr<dnnl::memory::desc> src_md;
    std::shared_ptr<dnnl::memory::desc> diff_filter_md;
    std::shared_ptr<dnnl::memory::desc> diff_bias_md;
    std::shared_ptr<dnnl::memory::desc> diff_dst_md;

    // oneDNN pipeline for executing primitives.
    std::shared_ptr<dnnl::stream> bwd_filter_stream;
    std::vector<dnnl::primitive> bwd_filter_primitives;
    std::vector<MemoryArgsMap> bwd_filter_primitives_args;

    ConvBwdFilterContext()
        : src_mem(nullptr),
          diff_filter_mem(nullptr),
          diff_bias_mem(nullptr),
          diff_dst_mem(nullptr),
          bwd_filter_desc(nullptr),
          fwd_pd(nullptr),
          fwd_desc(nullptr),
          src_md(nullptr),
          diff_filter_md(nullptr),
          diff_bias_md(nullptr),
          diff_dst_md(nullptr) {}
  };

  void Setup(const MklConvBwdFilterParams& convBwdFilterDims) {
    memory::format_tag user_data_fmt;
    if (convBwdFilterDims.native_format) {
      user_data_fmt =
          MklTensorFormatToMklDnnDataFormat(convBwdFilterDims.tf_fmt);
    } else {
      // Create memory descriptors for convolution backward filter without any
      // specific format so that oneDNN can pick an appropriate one depending
      // on the input parameters.
      user_data_fmt = memory::format_tag::any;
    }
    context_.src_md.reset(new memory::desc({convBwdFilterDims.src_dims},
                                           MklDnnType<T>(), user_data_fmt));

    context_.diff_dst_md.reset(new memory::desc(
        {convBwdFilterDims.diff_dst_dims}, MklDnnType<T>(), user_data_fmt));

    context_.diff_filter_md.reset(
        new memory::desc({convBwdFilterDims.diff_filter_dims}, MklDnnType<T>(),
                         memory::format_tag::any));

    if (!convBwdFilterDims.diff_bias_dims.empty())
      context_.diff_bias_md.reset(
          new memory::desc({convBwdFilterDims.diff_bias_dims}, MklDnnType<T>(),
                           memory::format_tag::x));

    // Create descriptor and primitive descriptor for convolution forward.
    context_.fwd_desc.reset(new ConvFwdDesc(
        prop_kind::forward, dnnl::algorithm::convolution_direct,
        *context_.src_md, *context_.diff_filter_md, *context_.diff_dst_md,
        convBwdFilterDims.strides, convBwdFilterDims.dilations,
        convBwdFilterDims.padding_left, convBwdFilterDims.padding_right));
    context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_));

    // Create descriptor and primitive descriptor for convolution bwd filter.
    if (!convBwdFilterDims.diff_bias_dims.empty()) {
      context_.bwd_filter_desc.reset(new ConvBwdFilterDesc(
          dnnl::algorithm::convolution_direct, *context_.src_md,
          *context_.diff_filter_md, *context_.diff_bias_md,
          *context_.diff_dst_md, convBwdFilterDims.strides,
          convBwdFilterDims.dilations, convBwdFilterDims.padding_left,
          convBwdFilterDims.padding_right));
    } else {
      context_.bwd_filter_desc.reset(new ConvBwdFilterDesc(
          dnnl::algorithm::convolution_direct, *context_.src_md,
          *context_.diff_filter_md, *context_.diff_dst_md,
          convBwdFilterDims.strides, convBwdFilterDims.dilations,
          convBwdFilterDims.padding_left, convBwdFilterDims.padding_right));
    }
    context_.bwd_filter_pd.reset(new ConvBwdFilterPd(
        *context_.bwd_filter_desc, cpu_engine_, *context_.fwd_pd));

    auto bwd_filter_pd = context_.bwd_filter_pd.get();

    // Create memory using dummy data.
    context_.src_mem.reset(
        new memory(bwd_filter_pd->src_desc(), cpu_engine_, DummyData));
    context_.diff_filter_mem.reset(
        new memory(bwd_filter_pd->diff_weights_desc(), cpu_engine_, DummyData));
    context_.diff_dst_mem.reset(
        new memory(bwd_filter_pd->diff_dst_desc(), cpu_engine_, DummyData));

    // Create convolution backward filter primitive and add it to the net.
    if (!convBwdFilterDims.diff_bias_dims.empty()) {
      context_.diff_bias_mem.reset(
          new memory({{convBwdFilterDims.diff_bias_dims},
                      MklDnnType<T>(),
                      memory::format_tag::x},
                     cpu_engine_, DummyData));
      context_.conv_bwd_filter.reset(
          new convolution_backward_weights(*context_.bwd_filter_pd));
      context_.bwd_filter_primitives_args.push_back(
          {{DNNL_ARG_SRC, *context_.src_mem},
           {DNNL_ARG_DIFF_WEIGHTS, *context_.diff_filter_mem},
           {DNNL_ARG_DIFF_BIAS, *context_.diff_bias_mem},
           {DNNL_ARG_DIFF_DST, *context_.diff_dst_mem}});
    } else {
      context_.conv_bwd_filter.reset(
          new convolution_backward_weights(*context_.bwd_filter_pd));
      context_.bwd_filter_primitives_args.push_back(
          {{DNNL_ARG_SRC, *context_.src_mem},
           {DNNL_ARG_DIFF_WEIGHTS, *context_.diff_filter_mem},
           {DNNL_ARG_DIFF_DST, *context_.diff_dst_mem}});
    }
    context_.bwd_filter_primitives.push_back(*context_.conv_bwd_filter);
  }

  struct ConvBwdFilterContext context_;

#ifdef DNNL_AARCH64_USE_ACL
  mutex primitive_execution_mu_;
#endif
};
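
// Illustrative sketch (hypothetical caller): comparing a user-layout
// memory::desc against the layouts the primitive actually chose, to decide
// whether a reorder is needed. This mirrors what Compute() does below;
// `user_src_md` is an assumed caller-provided descriptor.
//
//   auto pd = conv_bwd_filter->GetPrimitiveDesc();
//   if (user_src_md != pd->src_desc()) {
//     // Reorder the user's src buffer into the primitive's preferred
//     // layout before calling Execute().
//   }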

template <typename T>
class MklConvBwdFilterPrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
  static MklConvBwdFilterPrimitive<T>* Get(
      const MklConvBwdFilterParams& convBwdFilterDims, bool do_not_cache) {
    MklConvBwdFilterPrimitive<T>* conv_bwd_filter = nullptr;

    if (do_not_cache) { /* Always create a new primitive. */
      conv_bwd_filter = new MklConvBwdFilterPrimitive<T>(convBwdFilterDims);
    } else {
      // Look into the pool for a reusable primitive.
      conv_bwd_filter = dynamic_cast<MklConvBwdFilterPrimitive<T>*>(
          MklConvBwdFilterPrimitiveFactory<T>::GetInstance().GetConvBwdFilter(
              convBwdFilterDims));

      if (conv_bwd_filter == nullptr) {
        conv_bwd_filter = new MklConvBwdFilterPrimitive<T>(convBwdFilterDims);
        MklConvBwdFilterPrimitiveFactory<T>::GetInstance().SetConvBwdFilter(
            convBwdFilterDims, conv_bwd_filter);
      }
    }

    return conv_bwd_filter;
  }

 private:
  MklConvBwdFilterPrimitiveFactory() {}
  ~MklConvBwdFilterPrimitiveFactory() {}

  static MklConvBwdFilterPrimitiveFactory& GetInstance() {
    static MklConvBwdFilterPrimitiveFactory instance_;
    return instance_;
  }

  static string CreateKey(const MklConvBwdFilterParams& convBwdFilterDims) {
    string prefix = "conv_bwd_filter";
    FactoryKeyCreator key_creator;
    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(convBwdFilterDims.src_dims);
    key_creator.AddAsKey(convBwdFilterDims.diff_filter_dims);
    key_creator.AddAsKey(convBwdFilterDims.diff_bias_dims);
    key_creator.AddAsKey(convBwdFilterDims.diff_dst_dims);
    key_creator.AddAsKey(convBwdFilterDims.strides);
    key_creator.AddAsKey(convBwdFilterDims.dilations);
    key_creator.AddAsKey(convBwdFilterDims.padding_left);
    key_creator.AddAsKey(convBwdFilterDims.padding_right);
    if (convBwdFilterDims.native_format) {
      key_creator.AddAsKey(convBwdFilterDims.tf_fmt);
    }
    return key_creator.GetKey();
  }
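
  // Primitives are memoized on the exact configuration: two convolutions
  // that differ in any dimension, stride, dilation, or padding get distinct
  // keys and therefore distinct cached primitives. A sketch with
  // hypothetical params p1 and p2 that differ only in stride:
  //
  //   string k1 = CreateKey(p1);  // e.g. strides {1, 1}
  //   string k2 = CreateKey(p2);  // e.g. strides {2, 2}
  //   DCHECK_NE(k1, k2);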

  MklPrimitive* GetConvBwdFilter(
      const MklConvBwdFilterParams& convBwdFilterDims) {
    string key = CreateKey(convBwdFilterDims);
    return this->GetOp(key);
  }

  void SetConvBwdFilter(const MklConvBwdFilterParams& convBwdFilterDims,
                        MklPrimitive* op) {
    string key = CreateKey(convBwdFilterDims);
    this->SetOp(key, op);
  }
};
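
// Illustrative sketch (hypothetical caller): fetching a primitive from the
// factory and cleaning it up when caching is disabled, mirroring what
// Compute() does below. `params` is an MklConvBwdFilterParams built by the
// caller; `src`, `diff_filter`, `diff_dst`, and `cpu_stream` are as in the
// Execute() sketch above.
//
//   bool do_not_cache = MklPrimitiveFactory<float>::IsPrimitiveMemOptEnabled();
//   MklConvBwdFilterPrimitive<float>* prim =
//       MklConvBwdFilterPrimitiveFactory<float>::Get(params, do_not_cache);
//   prim->Execute(src, diff_filter, diff_dst, cpu_stream);
//   if (do_not_cache) delete prim;  // cached primitives stay in the factory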

template <typename Device, class T, bool bias_enabled, bool is_depthwise,
          bool native_format>
class MklConvCustomBackpropFilterOp
    : public MklConvBackpropCommonOp<Device, T, is_depthwise> {
 public:
  explicit MklConvCustomBackpropFilterOp(OpKernelConstruction* context)
      : MklConvBackpropCommonOp<Device, T, is_depthwise>(context) {}

  ~MklConvCustomBackpropFilterOp() {}

  void Compute(OpKernelContext* context) {
    try {
      // Input tensors.
      const Tensor& src_tensor = MklGetInput(context, kInputIdx);
      const Tensor& filter_tensor = MklGetInput(context, kFilterIdx);
      const Tensor& diff_dst_tensor = MklGetInput(context, kDiffDstIdx);

      MklDnnShape src_mkl_shape, filter_mkl_shape, diff_dst_mkl_shape;
      GetMklShape(context, kInputIdx, &src_mkl_shape, native_format);
      GetMklShape(context, kFilterIdx, &filter_mkl_shape, native_format);
      GetMklShape(context, kDiffDstIdx, &diff_dst_mkl_shape, native_format);
      // Allow operator-specific sanity checking of shapes.
      ValidateMklShapes(src_mkl_shape, filter_mkl_shape, diff_dst_mkl_shape);

      // Allow operator-specific generation of shapes.
      // E.g., Conv2DBackpropFilter gets the filter as filter_sizes, a 1-D
      // tensor containing the shape of the filter, so filter.shape() is not
      // the correct way to get the filter shape. These operator-specific
      // calls allow this class to handle that case.
      TensorShape src_tf_shape = MakeInputTfShape(context, src_tensor);
      const string& op_type = this->type_string();
      if ((op_type.find("3D") != std::string::npos) &&
          (op_type.find("V2") != std::string::npos)) {
        OP_REQUIRES(context, TensorShapeUtils::IsVector(filter_tensor.shape()),
                    errors::InvalidArgument(
                        "filter_sizes shape must be rank 1 but is rank ",
                        filter_tensor.shape().dims()));
      }
      TensorShape filter_tf_shape = MakeFilterTfShape(context, filter_tensor);
      TensorShape diff_dst_tf_shape =
          GetTfShape(context, kDiffDstIdx, native_format);

      // Corner cases: output with 0 elements and 0 batch size.
      Tensor* diff_filter_tensor = nullptr;
      if (src_tf_shape.num_elements() == 0 ||
          filter_tf_shape.num_elements() == 0 ||
          diff_dst_tf_shape.num_elements() == 0) {
        MklDnnShape diff_filter_mkl_shape;
        diff_filter_mkl_shape.SetMklTensor(false);
        TensorShape diff_filter_tf_shape =
            GetOutputTfShape(src_tf_shape, filter_tf_shape, diff_dst_tf_shape);
        const int kOutputIdx = 0;
        AllocateOutputSetMklShape(context, kOutputIdx, &diff_filter_tensor,
                                  diff_filter_tf_shape, diff_filter_mkl_shape,
                                  native_format);
        DCHECK(diff_filter_tensor != nullptr);

        // If the output tensor has more than 0 elements, zero them out.
        auto diff_filter_data = diff_filter_tensor->flat<T>().data();
        for (size_t i = 0; i < diff_filter_tf_shape.num_elements(); ++i) {
          diff_filter_data[i] = static_cast<T>(0);
        }
        return;
      }

      // By default, all dims are in MKL order except those suffixed with
      // `tf_order`.
      memory::dims diff_dst_dims, fwd_src_dims, fwd_filter_dims;
      memory::dims padding_left, padding_right, dilations, strides;
      memory::dims fwd_dst_dims, fwd_dst_dims_tf_order;

      // Get forward convolution parameters.
      bool is_grouped_convolution = false;
      MklDnnConvUtil conv_util(context, this->strides_, this->padding_,
                               this->data_format_, this->dilations_);
      conv_util.GetConvFwdSizesInMklOrder(
          src_tf_shape, filter_tf_shape, &fwd_src_dims, &fwd_filter_dims,
          &strides, &dilations, &fwd_dst_dims_tf_order, &fwd_dst_dims,
          &padding_left, &padding_right, &is_grouped_convolution, false,
          is_depthwise);
      if (!context->status().ok()) return;

      bool is_conv2d = (this->strides_.size() == 4);

      auto tf_fmt = is_conv2d
                        ? TFDataFormatToMklDnnDataFormat(this->data_format_)
                        : TFDataFormatToMklDnn3DDataFormat(this->data_format_);
      auto mkl_fmt_tag = MklTensorFormatToMklDnnDataFormat(tf_fmt);
      OP_REQUIRES(context, mkl_fmt_tag != memory::format_tag::undef,
                  errors::InvalidArgument("Invalid data format"));

      auto fwd_src_md =
          src_mkl_shape.IsMklTensor()
              ? src_mkl_shape.GetMklLayout()
              : memory::desc(fwd_src_dims, MklDnnType<T>(), mkl_fmt_tag);

      conv_util.GetInputSizeInMklOrder(diff_dst_tf_shape, &diff_dst_dims);
      if (!context->status().ok()) return;

      auto diff_dst_md =
          diff_dst_mkl_shape.IsMklTensor()
              ? diff_dst_mkl_shape.GetMklLayout()
              : memory::desc(diff_dst_dims, MklDnnType<T>(), mkl_fmt_tag);

      memory::dims diff_bias_dims = {};
      int64 depth = 0;
      if (bias_enabled) {
        TensorShape obp_tf_shape = GetTfShape(context, 2, native_format);
        depth = (this->data_format_ == FORMAT_NCHW)
                    ? obp_tf_shape.dim_size(1)
                    : obp_tf_shape.dim_size(is_conv2d ? 3 : 4);
        diff_bias_dims = {static_cast<int>(depth)};
      }

      // The default dilation factor for each dimension is 1 in TF and
      // 0 in oneDNN.
      for (int i = 0; i < dilations.size(); ++i) --dilations[i];
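      // For example, a TensorFlow dilation rate of 2 (sample every other
      // input element) corresponds to a oneDNN dilation of 1 (one skipped
      // element between filter taps), so the decrement above maps
      // {2, 2} -> {1, 1}.
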
      MklConvBwdFilterParams convBwdFilterDims(
          fwd_src_dims, fwd_filter_dims, diff_bias_dims, diff_dst_dims, strides,
          tf_fmt, native_format, dilations, padding_left, padding_right);

      // oneDNN allocates large buffers when a conv gradient filter primitive
      // is created, so we do not cache conv backward primitives when the env
      // variable TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE is set to true.
      bool do_not_cache = MklPrimitiveFactory<T>::IsPrimitiveMemOptEnabled();

      MklConvBwdFilterPrimitive<T>* conv_bwd_filter =
          MklConvBwdFilterPrimitiveFactory<T>::Get(convBwdFilterDims,
                                                   do_not_cache);

      // Allocate output tensors: diff_filter and, with bias, diff_bias.
      auto diff_filter_dims = GetOutputDims(fwd_src_dims, fwd_filter_dims);

      MklDnnShape diff_filter_mkl_shape;
      diff_filter_mkl_shape.SetMklTensor(false);

      if (is_conv2d) {
        if (!is_depthwise && !is_grouped_convolution) {
          // Conv2D: diff_filter_dims is in OIHW format.
          TensorShape diff_filter_tf_shape(
              {diff_filter_dims[MklDnnDims::Dim_H],
               diff_filter_dims[MklDnnDims::Dim_W],
               diff_filter_dims[MklDnnDims::Dim_I],
               diff_filter_dims[MklDnnDims::Dim_O]});
          AllocateOutputSetMklShape(context, 0, &diff_filter_tensor,
                                    diff_filter_tf_shape, diff_filter_mkl_shape,
                                    native_format);
        } else if (is_depthwise) {
          // Depthwise Conv2D: diff_filter_dims is in GOIHW format.
          //                  | TensorFlow       | oneDNN
          // ----------------------------------------------------------------
          // filter_out_depth | depth_multiplier | depth_multiplier *
          //                  |                  | group_count
          // ----------------------------------------------------------------
          // filter_in_depth  | in_depth         | in_depth / group_count
          // For depthwise convolution, group_count == in_depth, so
          // G = original I and I = 1. GOIHW is the oneDNN format; to recover
          // the TF format (HWIO), and since G = original I, we emit HWGO.
          TensorShape diff_filter_tf_shape(
              {diff_filter_dims[MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_H],
               diff_filter_dims[MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_W],
               diff_filter_dims[MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_G],
               diff_filter_dims
                   [MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_O]});
          AllocateOutputSetMklShape(context, 0, &diff_filter_tensor,
                                    diff_filter_tf_shape, diff_filter_mkl_shape,
                                    native_format);
        } else {
          // Grouped convolution: group_count == in_depth / filter_in_depth,
          // so G = in_depth / filter_in_depth and O = original O /
          // group_count. GOIHW is the oneDNN format; to recover the TF
          // format (HWIO), we emit O as original O, i.e. O * G.
          TensorShape diff_filter_tf_shape(
              {diff_filter_dims[MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_H],
               diff_filter_dims[MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_W],
               diff_filter_dims[MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_I],
               diff_filter_dims[MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_O] *
                   diff_filter_dims
                       [MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_G]});
          AllocateOutputSetMklShape(context, 0, &diff_filter_tensor,
                                    diff_filter_tf_shape, diff_filter_mkl_shape,
                                    native_format);
        }
      } else {
        // Conv3D: diff_filter_dims is in OIDHW format.
        TensorShape diff_filter_tf_shape(
            {diff_filter_dims[MklDnnDims3D::Dim3d_D],
             diff_filter_dims[MklDnnDims3D::Dim3d_H],
             diff_filter_dims[MklDnnDims3D::Dim3d_W],
             diff_filter_dims[MklDnnDims3D::Dim3d_I],
             diff_filter_dims[MklDnnDims3D::Dim3d_O]});
        AllocateOutputSetMklShape(context, 0, &diff_filter_tensor,
                                  diff_filter_tf_shape, diff_filter_mkl_shape,
                                  native_format);
      }

      Tensor* diff_bias_tensor = nullptr;
      if (bias_enabled) {
        TensorShape diff_bias_shape({depth});
        AllocateBiasGradTensor(context, diff_bias_shape, &diff_bias_tensor);
      }

      // Check if src and diff_dst need to be reordered.
      T* src_data = nullptr;
      MklDnnData<T> src(&cpu_engine_);
      auto bwd_filter_pd = conv_bwd_filter->GetPrimitiveDesc();
      if (fwd_src_md != bwd_filter_pd->src_desc()) {
        src.SetUsrMem(fwd_src_md, &src_tensor);
        src.CheckReorderToOpMem(bwd_filter_pd->src_desc(), cpu_engine_,
                                context);
        src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
      } else {
        src_data = static_cast<T*>(const_cast<T*>(src_tensor.flat<T>().data()));
      }

      T* diff_dst_data = nullptr;
      MklDnnData<T> diff_dst(&cpu_engine_);
      if (diff_dst_md != bwd_filter_pd->diff_dst_desc()) {
        diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
        diff_dst.CheckReorderToOpMem(bwd_filter_pd->diff_dst_desc(),
                                     cpu_engine_, context);
        diff_dst_data = static_cast<T*>(diff_dst.GetOpMem().get_data_handle());
      } else {
        diff_dst_data =
            static_cast<T*>(const_cast<T*>(diff_dst_tensor.flat<T>().data()));
      }

      DCHECK(!diff_filter_mkl_shape.IsMklTensor());
      // The output layout is TensorFlow's filter layout:
      //   Conv2D: HWIO; Conv3D: DHWIO; depthwise and grouped Conv: HWIGO.
      auto diff_filter_format =
          (is_depthwise || is_grouped_convolution)
              ? memory::format_tag::hwigo
              : ((this->strides_.size() == 4) ? memory::format_tag::hwio
                                              : memory::format_tag::dhwio);
      auto diff_filter_md =
          memory::desc(diff_filter_dims, MklDnnType<T>(), diff_filter_format);

      // Convert diff_filter (output) back to TF layout if needed
      // (i.e., reorder op memory back to user memory).
      MklDnnData<T> diff_filter(&cpu_engine_);
      bool diff_filter_reorder_required = false;
      T* diff_filter_data = nullptr;
      if (diff_filter_md != bwd_filter_pd->diff_weights_desc()) {
        // Allocate the diff_filter tensor in TensorFlow layout.
        diff_filter.SetUsrMem(diff_filter_dims, diff_filter_format,
                              diff_filter_tensor);
        diff_filter_reorder_required = true;
        diff_filter.PrepareReorderToUserMemIfReq(
            bwd_filter_pd->diff_weights_desc());
        diff_filter_data =
            static_cast<T*>(diff_filter.GetOpMem().get_data_handle());
      } else {
        diff_filter_data = static_cast<T*>(
            const_cast<T*>(diff_filter_tensor->flat<T>().data()));
      }

      // Execute convolution backward filter.
      std::shared_ptr<stream> bwd_cpu_stream;
      MklDnnThreadPool eigen_tp(context);
      bwd_cpu_stream.reset(
          CreateStream(&eigen_tp, conv_bwd_filter->GetEngine()));
      if (bias_enabled) {
        T* diff_bias_data =
            static_cast<T*>(const_cast<T*>(diff_bias_tensor->flat<T>().data()));
        conv_bwd_filter->Execute(src_data, diff_filter_data, diff_bias_data,
                                 diff_dst_data, bwd_cpu_stream);
      } else {
        conv_bwd_filter->Execute(src_data, diff_filter_data, diff_dst_data,
                                 bwd_cpu_stream);
      }

      // Reorder diff_filter back to TensorFlow layout if necessary.
      if (diff_filter_reorder_required) {
        diff_filter.InsertReorderToUserMem(context);
      }

      // Delete the primitive since it is not cached.
      if (do_not_cache) delete conv_bwd_filter;
    } catch (dnnl::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
    }
  }

 private:
  const int kInputIdx = 0, kFilterIdx = 1, kDiffDstIdx = 2;
  const int kDilationH = 0, kDilationW = 1;

  engine cpu_engine_ = engine(engine::kind::cpu, 0);

  // Assert that input shapes are valid.
  void ValidateMklShapes(const MklDnnShape& input_mkl_shape,
                         const MklDnnShape& filter_mkl_shape,
                         const MklDnnShape& obp_mkl_shape) {
    CHECK(!filter_mkl_shape.IsMklTensor())
        << "ConvBackpropFilter: filter should not be in MKL Layout";
  }

  // Get the TensorFlow shape of the input tensor.
  TensorShape MakeInputTfShape(OpKernelContext* context,
                               const Tensor& input_tensor) {
    size_t input_idx = 0;
    return GetTfShape(context, input_idx, native_format);
  }

  // Get the TensorFlow shape of the filter tensor.
  TensorShape MakeFilterTfShape(OpKernelContext* context,
                                const Tensor& filter_tensor) {
    TensorShape filter_tf_shape;
    CHECK_EQ(TensorShapeUtils::IsVector(filter_tensor.shape()), true);
    CHECK_EQ(TensorShapeUtils::MakeShape(filter_tensor.vec<int32>(),
                                         &filter_tf_shape)
                 .ok(),
             true);
    return filter_tf_shape;
  }

  // Get the TensorFlow shape of the output tensor (diff_filter),
  // which is the same as the shape of the filter.
  TensorShape GetOutputTfShape(const TensorShape& input_shape,
                               const TensorShape& filter_shape,
                               const TensorShape& outbprop_shape) {
    return filter_shape;
  }

  // Get the shape of the output (diff_filter) in oneDNN order, computed
  // from the input shape (fwd_input_dims) and the filter shape
  // (fwd_filter_dims).
  const memory::dims& GetOutputDims(const memory::dims& fwd_input_dims,
                                    const memory::dims& fwd_filter_dims) {
    return fwd_filter_dims;
  }

  void AllocateOutputTensor(OpKernelContext* context,
                            const memory::dims& output_dims_mkl_order,
                            Tensor** output_tensor) {
    DCHECK(output_tensor != nullptr);

    // For BackpropFilter, we convert the output tensor back to TensorFlow
    // layout, because BackpropFilter is typically the last operator in the
    // graph emitting the filter gradient, which is then fed to an
    // ApplyGradient method to update the filter. It may be possible to
    // eliminate this conversion by forwarding the filter in MKL layout,
    // if ApplyGradient ever supports MKL layout propagation.
    MklDnnShape output_mkl_shape;
    output_mkl_shape.SetMklTensor(false);
    // output_dims_mkl_order is in OIHW format.
    // Allocate the TF tensor's shape in HWIO format.
    TensorShape output_tf_shape({output_dims_mkl_order[MklDnnDims::Dim_H],
                                 output_dims_mkl_order[MklDnnDims::Dim_W],
                                 output_dims_mkl_order[MklDnnDims::Dim_I],
                                 output_dims_mkl_order[MklDnnDims::Dim_O]});
    AllocateOutputSetMklShape(context, 0, output_tensor, output_tf_shape,
                              output_mkl_shape);
  }

  void AllocateBiasGradTensor(OpKernelContext* context,
                              const TensorShape& bias_grad_shape,
                              Tensor** bias_grad_tensor) {
    DCHECK(bias_grad_tensor);

    MklDnnShape bias_grad_mkl_shape;
    bias_grad_mkl_shape.SetMklTensor(false);
    AllocateOutputSetMklShape(context, 1, bias_grad_tensor, bias_grad_shape,
                              bias_grad_mkl_shape, native_format);
  }
};

#define REGISTER_MKL_FILTER_KERNELS(T)                                   \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("_MklConv2DBackpropFilter")                                   \
          .Device(DEVICE_CPU)                                            \
          .TypeConstraint<T>("T")                                        \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),           \
      MklConvCustomBackpropFilterOp<CPUDevice, T, false, false, false>); \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("_MklConv2DBackpropFilterWithBias")                           \
          .Device(DEVICE_CPU)                                            \
          .TypeConstraint<T>("T")                                        \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),           \
      MklConvCustomBackpropFilterOp<CPUDevice, T, true, false, false>);  \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("_MklDepthwiseConv2dNativeBackpropFilter")                    \
          .Device(DEVICE_CPU)                                            \
          .TypeConstraint<T>("T")                                        \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),           \
      MklConvCustomBackpropFilterOp<CPUDevice, T, false, true, false>);  \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("__MklDummyConv2DBackpropFilterWithBias")                     \
          .Device(DEVICE_CPU)                                            \
          .TypeConstraint<T>("T")                                        \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),           \
      MklDummyOp<CPUDevice, T>);                                         \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("_MklConv3DBackpropFilterV2")                                 \
          .Device(DEVICE_CPU)                                            \
          .TypeConstraint<T>("T")                                        \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),           \
      MklConvCustomBackpropFilterOp<CPUDevice, T, false, false, false>); \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("_MklNativeConv2DBackpropFilter")                             \
          .Device(DEVICE_CPU)                                            \
          .TypeConstraint<T>("T")                                        \
          .Label(mkl_op_registry::kMklNameChangeOpLabel),                \
      MklConvCustomBackpropFilterOp<CPUDevice, T, false, false, true>);  \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("_MklNativeDepthwiseConv2dNativeBackpropFilter")              \
          .Device(DEVICE_CPU)                                            \
          .TypeConstraint<T>("T")                                        \
          .Label(mkl_op_registry::kMklNameChangeOpLabel),                \
      MklConvCustomBackpropFilterOp<CPUDevice, T, false, true, true>);   \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("_MklNativeConv3DBackpropFilterV2")                           \
          .Device(DEVICE_CPU)                                            \
          .TypeConstraint<T>("T")                                        \
          .Label(mkl_op_registry::kMklNameChangeOpLabel),                \
      MklConvCustomBackpropFilterOp<CPUDevice, T, false, false, true>);  \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("_MklNativeConv2DBackpropFilterWithBias")                     \
          .Device(DEVICE_CPU)                                            \
          .TypeConstraint<T>("T")                                        \
          .Label(mkl_op_registry::kMklNameChangeOpLabel),                \
      MklConvCustomBackpropFilterOp<CPUDevice, T, true, false, true>);

TF_CALL_float(REGISTER_MKL_FILTER_KERNELS);
TF_CALL_bfloat16(REGISTER_MKL_FILTER_KERNELS);

#undef REGISTER_MKL_FILTER_KERNELS

}  // namespace tensorflow

#endif  // INTEL_MKL