/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if defined(INTEL_MKL)
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/image_ops.h"
#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/nn_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/mkl_graph_util.h"
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"

namespace tensorflow {

// Helper class for converting MKL tensors to TF tensors and comparing to
// expected values.

static const uint8 dummy_tensor[] = {0, 0, 0, 0, 0, 0, 0, 0};
static const TensorShape dummy_shape({8});
// Set the default padding value for the FusedConv tests.
// The padding type will be `SAME` if the padding value is
// kInvalidPaddingValue; otherwise it will be `EXPLICIT` for the Mkl ops and
// `VALID` for the Eigen ops.
static const int kInvalidPaddingValue = -1;
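// For example, ConvolutionAndReluWithOnePad below passes padding = 1, which
// takes the EXPLICIT path in the fused Mkl op and the Pad-plus-VALID path in
// the unfused Eigen graph; tests that omit the padding argument take the
// SAME path in both.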

using BiasAddGraphRunner =
    std::function<void(const Tensor& input_data, const Tensor& filter_data,
                       const Tensor& bias_data, Tensor* out)>;

using FusedGraphRunner = std::function<void(
    const Tensor& input_data, const Tensor& filter_data,
    const Tensor& bias_data, const std::vector<string>& fused_ops, Tensor* out,
    const int padding)>;

using FusedMatMulRunner =
    std::function<void(const Tensor& input_data, const Tensor& filter_data,
                       const Tensor& bias_data,
                       const std::vector<string>& fused_ops, Tensor* out)>;

template <typename T>
class CommonTestUtilities : public OpsTestBase {
 public:
  void PerformConversion(DataType dtype, const Tensor& tensor,
                         const Tensor& mkl_meta_tensor, Tensor* output) {
    // Create an MKL to TF conversion node and execute it
    TF_EXPECT_OK(NodeDefBuilder("mkl_to_tf_op", "_MklToTf")
                     .Input(FakeInput(dtype))     // Input
                     .Input(FakeInput(DT_UINT8))  // Mkl second tensor
                     .Attr("T", dtype)
                     .Attr("_kernel", "MklLayoutDependentOp")
                     .Finalize(node_def()));
    TF_EXPECT_OK(InitOp());
    AddInputFromArray<T>(tensor.shape(), tensor.flat<T>());
    AddInputFromArray<uint8>(mkl_meta_tensor.shape(),
                             mkl_meta_tensor.flat<uint8>());
    TF_ASSERT_OK(RunOpKernel());

    *output = *GetOutput(0);
  }

  // Runs a TensorFlow graph defined by the root scope, and fetches the result
  // of the 'fetch' node into the output tensor.
  static void RunAndFetch(const tensorflow::Scope& root, const string& fetch,
                          Tensor* output) {
    tensorflow::GraphDef graph;
    TF_ASSERT_OK(root.ToGraphDef(&graph));

    std::unique_ptr<tensorflow::Session> session(
        tensorflow::NewSession(tensorflow::SessionOptions()));
    TF_ASSERT_OK(session->Create(graph));

    std::vector<Tensor> unfused_tensors;
    TF_ASSERT_OK(session->Run({}, {fetch}, {}, &unfused_tensors));

    *output = unfused_tensors[0];
  }

  void ConvertAndCompare(DataType dtype, const Tensor& tensor,
                         const Tensor& mkl_meta_tensor,
                         const Tensor& expected) {
    Tensor output;
    PerformConversion(dtype, tensor, mkl_meta_tensor, &output);
    test::ExpectTensorNear<T>(expected, output, 1e-5);
  }

  void ConvertAndCompareIntegral(DataType dtype, const Tensor& tensor,
                                 const Tensor& mkl_meta_tensor,
                                 const Tensor& expected) {
    Tensor output;
    PerformConversion(dtype, tensor, mkl_meta_tensor, &output);
    test::ExpectTensorEqual<T>(expected, output);
  }
  void TestBody() {}

  static void VerifyBiasAddTensorsClose(int depth, int image_width,
                                        int image_height, int image_batch_count,
                                        int filter_size, int filter_count,
                                        const BiasAddGraphRunner& run_default,
                                        const BiasAddGraphRunner& run_fused) {
    DataType dtype = DataTypeToEnum<T>::v();

    Tensor image(dtype, {image_batch_count, image_height, image_width, depth});
    image.flat<T>() = image.flat<T>().template setRandom<random_gen_>();

    Tensor filter(dtype, {filter_size, filter_size, depth, filter_count});
    filter.flat<T>() = filter.flat<T>().template setRandom<random_gen_>();

    const int bias_size = filter_count;
    Tensor bias(dtype, {bias_size});
    bias.flat<T>() = bias.flat<T>().template setRandom<random_gen_>();

    Tensor conv_2d;
    Tensor fused_conv_2d;

    run_default(image, filter, bias, &conv_2d);
    run_fused(image, filter, bias, &fused_conv_2d);

    ASSERT_EQ(conv_2d.dtype(), fused_conv_2d.dtype());
    ASSERT_EQ(conv_2d.shape(), fused_conv_2d.shape());

    test::ExpectClose(conv_2d, fused_conv_2d, 1e-5);
  }

  static void VerifyFusedTensorsClose(
      int depth, int image_width, int image_height, int image_batch_count,
      int filter_size, int filter_count, int bias_size,
      const std::vector<string>& fused_ops, const FusedGraphRunner& run_default,
      const FusedGraphRunner& run_fused,
      const int padding = kInvalidPaddingValue) {
    DataType dtype = DataTypeToEnum<T>::v();

    Tensor image(dtype, {image_batch_count, image_height, image_width, depth});
    image.flat<T>() = image.flat<T>().template setRandom<random_gen_>();

    Tensor filter(dtype, {filter_size, filter_size, depth, filter_count});
    filter.flat<T>() = filter.flat<T>().template setRandom<random_gen_>();

    Tensor bias(dtype, {bias_size});
    bias.flat<T>() = bias.flat<T>().template setRandom<random_gen_>();

    Tensor conv_2d;
    Tensor fused_conv_2d;

    run_default(image, filter, bias, fused_ops, &conv_2d, padding);
    run_fused(image, filter, bias, fused_ops, &fused_conv_2d, padding);

    ASSERT_EQ(conv_2d.dtype(), fused_conv_2d.dtype());
    ASSERT_EQ(conv_2d.shape(), fused_conv_2d.shape());

    test::ExpectClose(conv_2d, fused_conv_2d, 1e-5);
  }

  static void VerifyFusedMatrixClose(int depth, int batch, int weight_count,
                                     const std::vector<string>& fused_ops,
                                     const FusedMatMulRunner& run_default,
                                     const FusedMatMulRunner& run_fused) {
    DataType dtype = DataTypeToEnum<T>::v();

    Tensor input(dtype, {batch, depth});
    input.flat<T>() = input.flat<T>().template setRandom<random_gen_>();

    Tensor weight(dtype, {depth, weight_count});
    weight.flat<T>() = weight.flat<T>().template setRandom<random_gen_>();

    Tensor bias(dtype, {weight_count});
    bias.flat<T>() = bias.flat<T>().template setRandom<random_gen_>();

    Tensor output;
    Tensor fused_output;

    run_default(input, weight, bias, fused_ops, &output);
    run_fused(input, weight, bias, fused_ops, &fused_output);

    ASSERT_EQ(output.dtype(), fused_output.dtype());
    ASSERT_EQ(output.shape(), fused_output.shape());

    test::ExpectClose(output, fused_output, 1e-5);
  }

 private:
  using random_gen_ = Eigen::internal::NormalRandomGenerator<T>;
};
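
// A minimal usage sketch of the helper above (hypothetical, not one of the
// registered tests below): the two runners compute Conv2D + BiasAdd along
// different paths, and VerifyBiasAddTensorsClose checks that the results
// match on random inputs.
//
//   const BiasAddGraphRunner run_default =
//       [](const Tensor& input, const Tensor& filter, const Tensor& bias,
//          Tensor* out) {
//         auto root = tensorflow::Scope::NewRootScope();
//         auto conv = ops::Conv2D(
//             root.WithOpName("conv"),
//             ops::Const(root.WithOpName("input"), Input::Initializer(input)),
//             ops::Const(root.WithOpName("filter"),
//                        Input::Initializer(filter)),
//             {1, 1, 1, 1}, "SAME");
//         auto with_bias = ops::BiasAdd(
//             root.WithOpName("with_bias"), conv,
//             ops::Const(root.WithOpName("bias"), Input::Initializer(bias)));
//         CommonTestUtilities<float>::RunAndFetch(root, "with_bias", out);
//       };
//   // run_fused would build and run the corresponding fused Mkl node.
//   CommonTestUtilities<float>::VerifyBiasAddTensorsClose(
//       /*depth=*/3, /*image_width=*/32, /*image_height=*/32,
//       /*image_batch_count=*/8, /*filter_size=*/1, /*filter_count=*/12,
//       run_default, run_fused);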

// Testing MKL's fused convolution ops

template <typename T>
class MklFusedConv2DOpTest : public OpsTestBase {
 protected:
  static constexpr int kDepth = 3;
  static constexpr int kImageWidth = 32;
  static constexpr int kImageHeight = 32;
  static constexpr int kImageBatchCount = 8;

  void RunConv2DUnfused(const Tensor& input_data, const Tensor& filter_data,
                        const Tensor& bias_data,
                        const std::vector<string>& fused_ops, Tensor* output,
                        const int padding, int stride = 1) {
    auto root = tensorflow::Scope::NewRootScope();
    auto input_data_op =
        ops::Const(root.WithOpName("input"), Input::Initializer(input_data));

    if (padding != kInvalidPaddingValue) {
      Tensor padding_data(DT_INT32, {4, 2});
      test::FillValues<int32>(&padding_data,
                              {0, 0, padding, padding, padding, padding, 0, 0});
      input_data_op = ops::Pad(root.WithOpName("pad"), input_data_op,
                               ops::Const(root.WithOpName("padding_data"),
                                          Input::Initializer(padding_data)));
    }

    Output next_op = ops::Conv2D(
        root.WithOpName("conv"), input_data_op,
        ops::Const(root.WithOpName("filter"), Input::Initializer(filter_data)),
        {1, stride, stride, 1},
        padding == kInvalidPaddingValue ? "SAME" : "VALID");

    string last_op = "";
    if (std::find(fused_ops.begin(), fused_ops.end(), "BiasAdd") !=
        fused_ops.end()) {
      last_op = "with_bias";
      next_op = ops::BiasAdd(
          root.WithOpName(last_op), next_op,
          ops::Const(root.WithOpName("bias"), Input::Initializer(bias_data)));
    }

    if (std::find(fused_ops.begin(), fused_ops.end(), "Add") !=
        fused_ops.end()) {
      last_op = "with_add";
      next_op = ops::AddN(root.WithOpName("with_add"),
                          std::initializer_list<Input>{next_op, input_data_op});
    }

    if (std::find(fused_ops.begin(), fused_ops.end(), "Relu") !=
        fused_ops.end()) {
      last_op = "with_relu";
      next_op = ops::Relu(root.WithOpName(last_op), next_op);
    }

    if (std::find(fused_ops.begin(), fused_ops.end(), "Relu6") !=
        fused_ops.end()) {
      last_op = "with_relu6";
      next_op = ops::Relu6(root.WithOpName(last_op), next_op);
    }

    if (std::find(fused_ops.begin(), fused_ops.end(), "Elu") !=
        fused_ops.end()) {
      last_op = "with_elu";
      next_op = ops::Elu(root.WithOpName(last_op), next_op);
    }

    if (std::find(fused_ops.begin(), fused_ops.end(), "LeakyRelu") !=
        fused_ops.end()) {
      last_op = "with_leakyrelu";
      next_op = ops::internal::LeakyRelu(root.WithOpName(last_op), next_op);
    }

    CommonTestUtilities<T>::RunAndFetch(root, last_op, output);
  }

  void RunMklFusedConv2DOp(const Tensor& image, const Tensor& filter,
                           const std::vector<Tensor>& args,
                           const std::vector<string>& fused_ops, Tensor* output,
                           const int padding, int stride = 1) {
    DataType dtype = DataTypeToEnum<T>::v();
    int num_args = static_cast<int>(args.size());

    NodeDefBuilder builder =
        NodeDefBuilder("fused_conv_op", "_MklNativeFusedConv2D")
            .Input(FakeInput(dtype))
            .Input(FakeInput(dtype))
            .Input(FakeInput(num_args, dtype))
            .Attr("T", dtype)
            .Attr("num_args", num_args)
            .Attr("strides", {1, stride, stride, 1})
            .Attr("padding",
                  padding == kInvalidPaddingValue ? "SAME" : "EXPLICIT")
            .Attr("fused_ops", fused_ops)
            .Attr("_kernel", "MklNameChangeOp");

    if (padding != kInvalidPaddingValue)
      builder.Attr("explicit_paddings",
                   {0, 0, padding, padding, padding, padding, 0, 0});

    TF_EXPECT_OK(builder.Finalize(node_def()));
    TF_EXPECT_OK(InitOp());

    AddInputFromArray<T>(image.shape(), image.flat<T>());
    AddInputFromArray<T>(filter.shape(), filter.flat<T>());
    for (const Tensor& arg : args)
      AddInputFromArray<T>(arg.shape(), arg.flat<T>());
    TF_ASSERT_OK(RunOpKernel());

    // Fetch the output tensor for comparison against the unfused graph.
    const Tensor& output_tensor = *GetOutput(0);
    CommonTestUtilities<T> test_util;
    *output = output_tensor;
  }

  // Verifies that computing the unfused ops in a graph yields results
  // identical to FusedConv2D.
  void VerifyFusedConv2D(int filter_size, int filter_count,
                         const std::vector<string>& fused_ops,
                         const int padding = kInvalidPaddingValue,
                         int depth = kDepth, int image_width = kImageWidth,
                         int image_height = kImageHeight,
                         int image_batch_count = kImageBatchCount) {
    const FusedGraphRunner run_default =
        [this](const Tensor& input_data, const Tensor& filter_data,
               const Tensor& bias_data, const std::vector<string>& fused_ops,
               Tensor* out, const int padding) {
          RunConv2DUnfused(input_data, filter_data, bias_data, fused_ops, out,
                           padding);
        };

    const FusedGraphRunner run_fused =
        [this](const Tensor& input_data, const Tensor& filter_data,
               const Tensor& bias_data, const std::vector<string>& fused_ops,
               Tensor* out, const int padding) {
          std::vector<Tensor> fused_input = {bias_data};
          if (std::find(fused_ops.begin(), fused_ops.end(), "Add") !=
              fused_ops.end()) {
            fused_input.push_back(input_data);
          }
          RunMklFusedConv2DOp(input_data, filter_data, fused_input, fused_ops,
                              out, padding);
        };

    const int bias_size = filter_count;
    CommonTestUtilities<T>::VerifyFusedTensorsClose(
        depth, image_width, image_height, image_batch_count, filter_size,
        filter_count, bias_size, fused_ops, run_default, run_fused, padding);
  }
};

template <typename T>
class MklFusedConv2DWithBiasOpTest : public MklFusedConv2DOpTest<T> {};

TYPED_TEST_SUITE_P(MklFusedConv2DWithBiasOpTest);

// -------------------------------------------------------------------------- //
// Conv2D + BiasAdd + {Activation}                                            //
// -------------------------------------------------------------------------- //

TYPED_TEST_P(MklFusedConv2DWithBiasOpTest, OneByOneConvolution) {
  const int kFilterSize = 1;
  const int kFilterCount = 12;
  this->VerifyFusedConv2D(kFilterSize, kFilterCount, {"BiasAdd"});
}

TYPED_TEST_P(MklFusedConv2DWithBiasOpTest, SpatialConvolution) {
  const int kFilterSize = 3;
  const int kFilterCount = 12;
  this->VerifyFusedConv2D(kFilterSize, kFilterCount, {"BiasAdd"});
}

TYPED_TEST_P(MklFusedConv2DWithBiasOpTest, OneByOneConvolutionAndRelu) {
  const int kFilterSize = 1;
  const int kFilterCount = 12;
  this->VerifyFusedConv2D(kFilterSize, kFilterCount, {"BiasAdd", "Relu"});
}

TYPED_TEST_P(MklFusedConv2DWithBiasOpTest, SpatialConvolutionAndRelu) {
  const int kFilterSize = 3;
  const int kFilterCount = 12;
  this->VerifyFusedConv2D(kFilterSize, kFilterCount, {"BiasAdd", "Relu"});
}

TYPED_TEST_P(MklFusedConv2DWithBiasOpTest, OneByOneConvolutionAndRelu6) {
  const int kFilterSize = 1;
  const int kFilterCount = 12;
  this->VerifyFusedConv2D(kFilterSize, kFilterCount, {"BiasAdd", "Relu6"});
}

TYPED_TEST_P(MklFusedConv2DWithBiasOpTest, SpatialConvolutionAndRelu6) {
  const int kFilterSize = 3;
  const int kFilterCount = 12;
  this->VerifyFusedConv2D(kFilterSize, kFilterCount, {"BiasAdd", "Relu6"});
}

TYPED_TEST_P(MklFusedConv2DWithBiasOpTest, OneByOneConvolutionAndElu) {
  const int kFilterSize = 1;
  const int kFilterCount = 12;
  this->VerifyFusedConv2D(kFilterSize, kFilterCount, {"BiasAdd", "Elu"});
}

TYPED_TEST_P(MklFusedConv2DWithBiasOpTest, SpatialConvolutionAndElu) {
  const int kFilterSize = 3;
  const int kFilterCount = 12;
  this->VerifyFusedConv2D(kFilterSize, kFilterCount, {"BiasAdd", "Elu"});
}

TYPED_TEST_P(MklFusedConv2DWithBiasOpTest, OneByOneConvolutionAndLeakyRelu) {
  const int kFilterSize = 1;
  const int kFilterCount = 12;
  this->VerifyFusedConv2D(kFilterSize, kFilterCount, {"BiasAdd", "LeakyRelu"});
}

TYPED_TEST_P(MklFusedConv2DWithBiasOpTest, SpatialConvolutionAndLeakyRelu) {
  const int kFilterSize = 3;
  const int kFilterCount = 12;
  this->VerifyFusedConv2D(kFilterSize, kFilterCount, {"BiasAdd", "LeakyRelu"});
}

TYPED_TEST_P(MklFusedConv2DWithBiasOpTest, OneByOneConvolutionAndAdd) {
  const int kFilterSize = 1;
  const int kFilterCount = 3;
  this->VerifyFusedConv2D(kFilterSize, kFilterCount, {"BiasAdd", "Add"});
}

TYPED_TEST_P(MklFusedConv2DWithBiasOpTest, SpatialConvolutionAndAdd) {
  const int kFilterSize = 3;
  const int kFilterCount = 3;
  this->VerifyFusedConv2D(kFilterSize, kFilterCount, {"BiasAdd", "Add"});
}

TYPED_TEST_P(MklFusedConv2DWithBiasOpTest, OneByOneConvolutionAndAddRelu) {
  const int kFilterSize = 1;
  const int kFilterCount = 3;
  this->VerifyFusedConv2D(kFilterSize, kFilterCount,
                          {"BiasAdd", "Add", "Relu"});
}

TYPED_TEST_P(MklFusedConv2DWithBiasOpTest, SpatialConvolutionAndAddRelu) {
  const int kFilterSize = 3;
  const int kFilterCount = 3;
  this->VerifyFusedConv2D(kFilterSize, kFilterCount,
                          {"BiasAdd", "Add", "Relu"});
}

TYPED_TEST_P(MklFusedConv2DWithBiasOpTest, OneByOneConvolutionAndAddRelu6) {
  const int kFilterSize = 1;
  const int kFilterCount = 3;
  this->VerifyFusedConv2D(kFilterSize, kFilterCount,
                          {"BiasAdd", "Add", "Relu6"});
}

TYPED_TEST_P(MklFusedConv2DWithBiasOpTest, SpatialConvolutionAndAddRelu6) {
  const int kFilterSize = 3;
  const int kFilterCount = 3;
  this->VerifyFusedConv2D(kFilterSize, kFilterCount,
                          {"BiasAdd", "Add", "Relu6"});
}

TYPED_TEST_P(MklFusedConv2DWithBiasOpTest, OneByOneConvolutionAndAddElu) {
  const int kFilterSize = 1;
  const int kFilterCount = 3;
  this->VerifyFusedConv2D(kFilterSize, kFilterCount, {"BiasAdd", "Add", "Elu"});
}

TYPED_TEST_P(MklFusedConv2DWithBiasOpTest, SpatialConvolutionAndAddElu) {
  const int kFilterSize = 3;
  const int kFilterCount = 3;
  this->VerifyFusedConv2D(kFilterSize, kFilterCount, {"BiasAdd", "Add", "Elu"});
}

TYPED_TEST_P(MklFusedConv2DWithBiasOpTest, OneByOneConvolutionAndAddLeakyRelu) {
  const int kFilterSize = 1;
  const int kFilterCount = 3;
  this->VerifyFusedConv2D(kFilterSize, kFilterCount,
                          {"BiasAdd", "Add", "LeakyRelu"});
}

TYPED_TEST_P(MklFusedConv2DWithBiasOpTest, SpatialConvolutionAndAddLeakyRelu) {
  const int kFilterSize = 3;
  const int kFilterCount = 3;
  this->VerifyFusedConv2D(kFilterSize, kFilterCount,
                          {"BiasAdd", "Add", "LeakyRelu"});
}

TYPED_TEST_P(MklFusedConv2DWithBiasOpTest, ConvolutionAndReluWithZeroPad) {
  const int kFilterSize = 3;
  const int kFilterCount = 3;
  const int padding = 0;
  this->VerifyFusedConv2D(kFilterSize, kFilterCount, {"BiasAdd", "Relu"},
                          padding);
}

TYPED_TEST_P(MklFusedConv2DWithBiasOpTest, ConvolutionAndReluWithOnePad) {
  const int kFilterSize = 3;
  const int kFilterCount = 3;
  const int padding = 1;
  this->VerifyFusedConv2D(kFilterSize, kFilterCount, {"BiasAdd", "Relu"},
                          padding);
}

REGISTER_TYPED_TEST_SUITE_P(
    MklFusedConv2DWithBiasOpTest, OneByOneConvolution, SpatialConvolution,
    OneByOneConvolutionAndRelu, SpatialConvolutionAndRelu,
    OneByOneConvolutionAndRelu6, SpatialConvolutionAndRelu6,
    OneByOneConvolutionAndElu, SpatialConvolutionAndElu,
    OneByOneConvolutionAndLeakyRelu, SpatialConvolutionAndLeakyRelu,
    OneByOneConvolutionAndAdd, SpatialConvolutionAndAdd,
    OneByOneConvolutionAndAddRelu, SpatialConvolutionAndAddRelu,
    OneByOneConvolutionAndAddRelu6, SpatialConvolutionAndAddRelu6,
    OneByOneConvolutionAndAddElu, SpatialConvolutionAndAddElu,
    OneByOneConvolutionAndAddLeakyRelu, SpatialConvolutionAndAddLeakyRelu,
    ConvolutionAndReluWithZeroPad, ConvolutionAndReluWithOnePad);

using MklFusedBiasAddDataTypes = ::testing::Types<float>;
INSTANTIATE_TYPED_TEST_SUITE_P(Test, MklFusedConv2DWithBiasOpTest,
                               MklFusedBiasAddDataTypes);

// Testing MKL's fused depthwise convolution ops
template <typename T>
class MklFusedDepthwiseConv2DOpTest : public OpsTestBase {
 protected:
  static constexpr int kDepth = 3;
  static constexpr int kImageWidth = 32;
  static constexpr int kImageHeight = 32;
  static constexpr int kImageBatchCount = 8;

  void RunDepthwiseConv2DUnfused(const Tensor& input_data,
                                 const Tensor& filter_data,
                                 const Tensor& bias_data,
                                 const std::vector<string>& fused_ops,
                                 Tensor* output, int stride = 1) {
    auto root = tensorflow::Scope::NewRootScope();
    auto input_data_op =
        ops::Const(root.WithOpName("input"), Input::Initializer(input_data));
    Output next_op = ops::DepthwiseConv2dNative(
        root.WithOpName("depthwise_conv"), input_data_op,
        ops::Const(root.WithOpName("filter"), Input::Initializer(filter_data)),
        {1, stride, stride, 1}, "SAME");

    string last_op = "";
    if (std::find(fused_ops.begin(), fused_ops.end(), "BiasAdd") !=
        fused_ops.end()) {
      last_op = "with_bias";
      next_op = ops::BiasAdd(
          root.WithOpName(last_op), next_op,
          ops::Const(root.WithOpName("bias"), Input::Initializer(bias_data)));
    }

    if (std::find(fused_ops.begin(), fused_ops.end(), "Relu") !=
        fused_ops.end()) {
      last_op = "with_relu";
      next_op = ops::Relu(root.WithOpName(last_op), next_op);
    }

    if (std::find(fused_ops.begin(), fused_ops.end(), "Relu6") !=
        fused_ops.end()) {
      last_op = "with_relu6";
      next_op = ops::Relu6(root.WithOpName(last_op), next_op);
    }

    if (std::find(fused_ops.begin(), fused_ops.end(), "Elu") !=
        fused_ops.end()) {
      last_op = "with_elu";
      next_op = ops::Elu(root.WithOpName(last_op), next_op);
    }

    CommonTestUtilities<T>::RunAndFetch(root, last_op, output);
  }

  void RunMklFusedDepthwiseConv2DOp(const Tensor& image, const Tensor& filter,
                                    const std::vector<Tensor>& args,
                                    const std::vector<string>& fused_ops,
                                    Tensor* output, int stride = 1) {
    DataType dtype = DataTypeToEnum<T>::v();
    int num_args = static_cast<int>(args.size());

    TF_EXPECT_OK(NodeDefBuilder("fused_depthwise_conv_op",
                                "_MklNativeFusedDepthwiseConv2dNative")
                     .Input(FakeInput(dtype))
                     .Input(FakeInput(dtype))
                     .Input(FakeInput(num_args, dtype))
                     .Attr("T", dtype)
                     .Attr("num_args", num_args)
                     .Attr("strides", {1, stride, stride, 1})
                     .Attr("padding", "SAME")
                     .Attr("fused_ops", fused_ops)
                     .Attr("_kernel", "MklNameChangeOp")
                     .Finalize(node_def()));

    TF_EXPECT_OK(InitOp());

    AddInputFromArray<T>(image.shape(), image.flat<T>());
    AddInputFromArray<T>(filter.shape(), filter.flat<T>());
    for (const Tensor& arg : args)
      AddInputFromArray<T>(arg.shape(), arg.flat<T>());
    TF_ASSERT_OK(RunOpKernel());

    // Fetch the output tensor for comparison against the unfused graph.
    const Tensor& output_tensor = *GetOutput(0);
    CommonTestUtilities<T> test_util;
    *output = output_tensor;
  }

  // Verifies that computing the unfused ops in a graph yields results
  // identical to FusedDepthwiseConv2D.
  void VerifyFusedDepthwiseConv2D(int filter_size, int filter_count,
                                  int bias_size,
                                  const std::vector<string>& fused_ops,
                                  int depth = kDepth,
                                  int image_width = kImageWidth,
                                  int image_height = kImageHeight,
                                  int image_batch_count = kImageBatchCount) {
    const FusedGraphRunner run_default =
        [this](const Tensor& input_data, const Tensor& filter_data,
               const Tensor& bias_data, const std::vector<string>& fused_ops,
               Tensor* out, const int padding) {
          RunDepthwiseConv2DUnfused(input_data, filter_data, bias_data,
                                    fused_ops, out);
        };

    const FusedGraphRunner run_fused =
        [this](const Tensor& input_data, const Tensor& filter_data,
               const Tensor& bias_data, const std::vector<string>& fused_ops,
               Tensor* out, const int padding) {
          std::vector<Tensor> fused_input = {bias_data};
          RunMklFusedDepthwiseConv2DOp(input_data, filter_data, fused_input,
                                       fused_ops, out);
        };

    CommonTestUtilities<T>::VerifyFusedTensorsClose(
        depth, image_width, image_height, image_batch_count, filter_size,
        filter_count, bias_size, fused_ops, run_default, run_fused);
  }
};

template <typename T>
class MklFusedDepthwiseConv2DWithBiasOpTest
    : public MklFusedDepthwiseConv2DOpTest<T> {};

TYPED_TEST_SUITE_P(MklFusedDepthwiseConv2DWithBiasOpTest);

// -------------------------------------------------------------------------- //
// DepthwiseConv2D + BiasAdd + {Activation}                                   //
// -------------------------------------------------------------------------- //

TYPED_TEST_P(MklFusedDepthwiseConv2DWithBiasOpTest, OneByOneConvolution) {
  const int kFilterSize = 1;
  const int kFilterCount = 1;
  const int kBiasSize = 3;
  this->VerifyFusedDepthwiseConv2D(kFilterSize, kFilterCount, kBiasSize,
                                   {"BiasAdd"});
}

TYPED_TEST_P(MklFusedDepthwiseConv2DWithBiasOpTest, SpatialConvolution) {
  const int kFilterSize = 3;
  const int kFilterCount = 1;
  const int kBiasSize = 3;
  this->VerifyFusedDepthwiseConv2D(kFilterSize, kFilterCount, kBiasSize,
                                   {"BiasAdd"});
}

TYPED_TEST_P(MklFusedDepthwiseConv2DWithBiasOpTest,
             OneByOneConvolutionAndRelu) {
  const int kFilterSize = 1;
  const int kFilterCount = 1;
  const int kBiasSize = 3;
  this->VerifyFusedDepthwiseConv2D(kFilterSize, kFilterCount, kBiasSize,
                                   {"BiasAdd", "Relu"});
}

TYPED_TEST_P(MklFusedDepthwiseConv2DWithBiasOpTest, SpatialConvolutionAndRelu) {
  const int kFilterSize = 3;
  const int kFilterCount = 1;
  const int kBiasSize = 3;
  this->VerifyFusedDepthwiseConv2D(kFilterSize, kFilterCount, kBiasSize,
                                   {"BiasAdd", "Relu"});
}

TYPED_TEST_P(MklFusedDepthwiseConv2DWithBiasOpTest,
             OneByOneConvolutionAndRelu6) {
  const int kFilterSize = 1;
  const int kFilterCount = 1;
  const int kBiasSize = 3;
  this->VerifyFusedDepthwiseConv2D(kFilterSize, kFilterCount, kBiasSize,
                                   {"BiasAdd", "Relu6"});
}

TYPED_TEST_P(MklFusedDepthwiseConv2DWithBiasOpTest,
             SpatialConvolutionAndRelu6) {
  const int kFilterSize = 3;
  const int kFilterCount = 1;
  const int kBiasSize = 3;
  this->VerifyFusedDepthwiseConv2D(kFilterSize, kFilterCount, kBiasSize,
                                   {"BiasAdd", "Relu6"});
}

TYPED_TEST_P(MklFusedDepthwiseConv2DWithBiasOpTest, OneByOneConvolutionAndElu) {
  const int kFilterSize = 1;
  const int kFilterCount = 1;
  const int kBiasSize = 3;
  this->VerifyFusedDepthwiseConv2D(kFilterSize, kFilterCount, kBiasSize,
                                   {"BiasAdd", "Elu"});
}

TYPED_TEST_P(MklFusedDepthwiseConv2DWithBiasOpTest, SpatialConvolutionAndElu) {
  const int kFilterSize = 3;
  const int kFilterCount = 1;
  const int kBiasSize = 3;
  this->VerifyFusedDepthwiseConv2D(kFilterSize, kFilterCount, kBiasSize,
                                   {"BiasAdd", "Elu"});
}

REGISTER_TYPED_TEST_SUITE_P(
    MklFusedDepthwiseConv2DWithBiasOpTest, OneByOneConvolution,
    SpatialConvolution, OneByOneConvolutionAndRelu, SpatialConvolutionAndRelu,
    OneByOneConvolutionAndRelu6, SpatialConvolutionAndRelu6,
    OneByOneConvolutionAndElu, SpatialConvolutionAndElu);

using MklFusedBiasAddDataTypes = ::testing::Types<float>;
INSTANTIATE_TYPED_TEST_SUITE_P(Test, MklFusedDepthwiseConv2DWithBiasOpTest,
                               MklFusedBiasAddDataTypes);

// Testing fusion of pad and convolution
template <typename T>
class FusedPadConvOpTest : public OpsTestBase {
 public:
  void Run(const string data_format) {
    DataType dtype = DataTypeToEnum<T>::v();

    // The bfloat16 FusedPadConv op is only supported on AVX512,
    // so skip the test if the CPU instruction set is AVX2 or earlier.
    if ((dtype == DT_BFLOAT16) && !tensorflow::port::TestCPUFeature(
                                      tensorflow::port::CPUFeature::AVX512F))
      return;

    const int depth = 1;
    const int image_width = 4;
    const int image_height = 3;
    const int image_batch_count = 1;
    const int stride = 1;

    Tensor image, expected;
    if (data_format == "NHWC") {
      image =
          Tensor(dtype, {image_batch_count, image_height, image_width, depth});
    } else {
      image =
          Tensor(dtype, {image_batch_count, depth, image_height, image_width});
    }
    test::FillValues<T>(&image, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});

    const int kFilterSize = 3;
    const int kFilterCount = 1;
    Tensor filter(dtype, {kFilterSize, kFilterSize, depth, kFilterCount});
    test::FillValues<T>(&filter, {1, 4, 7, 2, 5, 8, 3, 6, 9});

    const int padding_height = 4;
    const int padding_width = 2;
    Tensor padding(DT_INT32, {padding_height, padding_width});
    if (data_format == "NHWC") {
      test::FillValues<int32>(&padding, {0, 0, 3, 4, 1, 2, 0, 0});
    } else {
      test::FillValues<int32>(&padding, {0, 0, 0, 0, 3, 4, 1, 2});
    }

    if (data_format == "NHWC") {
      expected = Tensor(dtype, TensorShape({1, 8, 5, 1}));
    } else {
      expected = Tensor(dtype, TensorShape({1, 1, 8, 5}));
    }
    test::FillValues<T>(
        &expected,
        {0,  0,   0,   0,   0,   24, 42,  60,  33,  12,  105, 150, 183, 95,
         32, 235, 312, 357, 178, 56, 187, 234, 261, 121, 32,  106, 126, 138,
         59, 12,  0,   0,   0,   0,  0,   0,   0,   0,   0,   0});
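    // Note: padding the 3x4 input by (3, 4) rows and (1, 2) columns yields a
    // 10x7 image, so a VALID 3x3 convolution at stride 1 produces the 8x5
    // output checked above.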

    // Create a fused pad+conv2d node
    TF_EXPECT_OK(NodeDefBuilder("fused_pad_conv_op", "_MklNativePadWithConv2D")
                     .Input(FakeInput(dtype))     // Input
                     .Input(FakeInput(dtype))     // Filter
                     .Input(FakeInput(DT_INT32))  // Padding
                     .Attr("padding", "VALID")
                     .Attr("data_format", data_format)
                     .Attr("T", dtype)
                     .Attr("strides", {1, stride, stride, 1})
                     .Attr("_kernel", "MklNameChangeOp")
                     .Finalize(node_def()));
    TF_EXPECT_OK(InitOp());

    // Set up inputs and execute
    AddInputFromArray<T>(image.shape(), image.flat<T>());
    AddInputFromArray<T>(filter.shape(), filter.flat<T>());
    AddInputFromArray<int32>(padding.shape(), padding.flat<int32>());
    TF_ASSERT_OK(RunOpKernel());

    // Compare output to expected results
    const Tensor& first = *GetOutput(0);
    CommonTestUtilities<T> test_util;
    test::ExpectTensorEqual<T>(expected, first);
  }
};

TYPED_TEST_SUITE_P(FusedPadConvOpTest);

TYPED_TEST_P(FusedPadConvOpTest, PaddingConvTest) { this->Run("NHWC"); }

TYPED_TEST_P(FusedPadConvOpTest, PaddingConvTestNchw) { this->Run("NCHW"); }

REGISTER_TYPED_TEST_SUITE_P(FusedPadConvOpTest, PaddingConvTest,
                            PaddingConvTestNchw);

using FusedPadConvDataTypes = ::testing::Types<float, bfloat16>;
INSTANTIATE_TYPED_TEST_SUITE_P(Test, FusedPadConvOpTest, FusedPadConvDataTypes);

class FilterCacheTest : public OpsTestBase {
 public:
  template <typename T>
  void Run(DataType dtype, Tensor& image, Tensor& filter, Tensor& expected,
           const bool is_filter_const) {
    const int stride = 1;

    TF_EXPECT_OK(NodeDefBuilder("conv2d_filter_cache", "_MklNativeConv2D")
                     .Input(FakeInput(dtype))  // Input
                     .Input(FakeInput(dtype))  // Filter
                     .Attr("padding", "VALID")
                     .Attr("data_format", "NHWC")
                     .Attr("is_filter_const", is_filter_const)
                     .Attr("T", dtype)
                     .Attr("strides", {1, stride, stride, 1})
                     .Attr("_kernel", "MklNameChangeOp")
                     .Finalize(node_def()));

    TF_EXPECT_OK(InitOp());

    // Set up inputs and execute
    AddInputFromArray<T>(image.shape(), image.flat<T>());
    AddInputFromArray<T>(filter.shape(), filter.flat<T>());

    TF_ASSERT_OK(RunOpKernel());

    // Compare outputs to expected results
    const Tensor& output = *GetOutput(0);
    CommonTestUtilities<T> conv_comp;
    test::ExpectTensorEqual<T>(expected, output);

    // TODO(intel-tf): For now, we rely on internal performance tests to
    // determine if filter data is being cached and reused.
    // However, we still need to add a check here to determine if this is
    // still the case by inspecting the contents of the persistent tensor.
    TF_ASSERT_OK(RunOpKernel());

    // Compare output to expected results
    const Tensor& output_new = *GetOutput(0);
    CommonTestUtilities<T> conv_comp_new;
    test::ExpectTensorEqual<T>(expected, output_new);
  }
};

TEST_F(FilterCacheTest, Conv2DFilterCacheTest) {
  const int depth = 1;
  const int image_width = 4;
  const int image_height = 3;
  const int image_batch_count = 1;
  Tensor image(DT_FLOAT, {image_batch_count, image_height, image_width, depth});
  test::FillValues<float>(&image, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});

  const int kFilterSize = 3;
  const int kFilterCount = 1;
  Tensor filter(DT_FLOAT, {kFilterSize, kFilterSize, depth, kFilterCount});
  test::FillValues<float>(&filter, {1, 4, 7, 2, 5, 8, 3, 6, 9});

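  // Hand-computed expectation: with VALID padding the 3x3 filter fits the
  // 3x4 image at two positions, giving
  //   1*1 + 2*4 + 3*7 + 5*2 + 6*5 + 7*8 + 9*3 + 10*6 + 11*9 = 312,
  //   2*1 + 3*4 + 4*7 + 6*2 + 7*5 + 8*8 + 10*3 + 11*6 + 12*9 = 357.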
  Tensor expected(DT_FLOAT, TensorShape({1, 1, 2, 1}));
  test::FillValues<float>(&expected, {312, 357});

  Run<float>(DT_FLOAT, image, filter, expected, true);
}

// Testing fusion of MatMul and BiasAdd
template <typename T>
class MklFusedMatMulOpTest : public OpsTestBase {
 private:
  void RunMklFusedMatMulOp(const Tensor& input, const Tensor& weight,
                           const std::vector<Tensor>& args,
                           const std::vector<string>& fused_ops,
                           Tensor* output) {
    DataType dtype = DataTypeToEnum<T>::v();
    const int num_args = args.size();
    TF_EXPECT_OK(NodeDefBuilder("MklFusedMatMul", "_MklNativeFusedMatMul")
                     .Input(FakeInput(dtype))
                     .Input(FakeInput(dtype))
                     .Input(FakeInput(num_args, dtype))
                     .Attr("T", dtype)
                     .Attr("transpose_a", false)
                     .Attr("transpose_b", false)
                     .Attr("num_args", num_args)
                     .Attr("fused_ops", fused_ops)
                     .Attr("epsilon", 0.0001)
                     .Attr("_kernel", "MklNameChangeOp")
                     .Finalize(node_def()));

    TF_EXPECT_OK(InitOp());

    AddInputFromArray<T>(input.shape(), input.flat<T>());
    AddInputFromArray<T>(weight.shape(), weight.flat<T>());
    for (const Tensor& arg : args)
      AddInputFromArray<T>(arg.shape(), arg.flat<T>());

    TF_ASSERT_OK(RunOpKernel());

    const Tensor& output_tensor = *GetOutput(0);
    *output = output_tensor;
  }

 protected:
  void VerifyFusedMatMul(const int kBatch, const int kInputChannel,
                         const int kOutputChannel,
                         const std::vector<string>& fused_ops) {
    const FusedMatMulRunner run_default =
        [this](const Tensor& input, const Tensor& weight, const Tensor& bias,
               const std::vector<string>& fused_ops, Tensor* output) {
          auto root = tensorflow::Scope::NewRootScope();
          auto input_op =
              ops::Const(root.WithOpName("input"), Input::Initializer(input));
          Output next_op = ops::MatMul(root.WithOpName("matmul"), input_op,
                                       ops::Const(root.WithOpName("weight"),
                                                  Input::Initializer(weight)));

          string last_op = "";
          if (std::find(fused_ops.begin(), fused_ops.end(), "BiasAdd") !=
              fused_ops.end()) {
            last_op = "with_bias";
            next_op = ops::BiasAdd(
                root.WithOpName(last_op), next_op,
                ops::Const(root.WithOpName("bias"), Input::Initializer(bias)));
          }

          if (std::find(fused_ops.begin(), fused_ops.end(), "Relu") !=
              fused_ops.end()) {
            last_op = "with_relu";
            next_op = ops::Relu(root.WithOpName(last_op), next_op);
          }

          if (std::find(fused_ops.begin(), fused_ops.end(), "Relu6") !=
              fused_ops.end()) {
            last_op = "with_relu6";
            next_op = ops::Relu6(root.WithOpName(last_op), next_op);
          }

          if (std::find(fused_ops.begin(), fused_ops.end(), "Elu") !=
              fused_ops.end()) {
            last_op = "with_elu";
            next_op = ops::Elu(root.WithOpName(last_op), next_op);
          }

          if (std::find(fused_ops.begin(), fused_ops.end(), "Tanh") !=
              fused_ops.end()) {
            last_op = "with_tanh";
            next_op = ops::Tanh(root.WithOpName(last_op), next_op);
          }

          if (std::find(fused_ops.begin(), fused_ops.end(), "Sigmoid") !=
              fused_ops.end()) {
            last_op = "with_Sigmoid";
            next_op = ops::Sigmoid(root.WithOpName(last_op), next_op);
          }

          if (std::find(fused_ops.begin(), fused_ops.end(), "Add") !=
              fused_ops.end()) {
            last_op = "with_add";
            next_op = ops::Add(root.WithOpName("with_add"), next_op, input_op);
          }

          if (std::find(fused_ops.begin(), fused_ops.end(), "LeakyRelu") !=
              fused_ops.end()) {
            last_op = "with_leakyrelu";
            next_op =
                ops::internal::LeakyRelu(root.WithOpName(last_op), next_op);
          }

          CommonTestUtilities<T>::RunAndFetch(root, last_op, output);
        };

    const FusedMatMulRunner run_fused =
        [this](const Tensor& input, const Tensor& weight, const Tensor& bias,
               const std::vector<string>& fused_ops, Tensor* output) {
          std::vector<Tensor> fused_input = {bias};
          if (std::find(fused_ops.begin(), fused_ops.end(), "Add") !=
              fused_ops.end()) {
            fused_input.push_back(input);
          }
          RunMklFusedMatMulOp(input, weight, fused_input, fused_ops, output);
        };

    CommonTestUtilities<T>::VerifyFusedMatrixClose(kInputChannel, kBatch,
                                                   kOutputChannel, fused_ops,
                                                   run_default, run_fused);
  }
};

TYPED_TEST_SUITE_P(MklFusedMatMulOpTest);

TYPED_TEST_P(MklFusedMatMulOpTest, WithBias) {
  const int batch = 3;
  const int input_channel = 4;
  const int output_channel = 5;

  this->VerifyFusedMatMul(batch, input_channel, output_channel, {"BiasAdd"});
}

TYPED_TEST_P(MklFusedMatMulOpTest, WithBiasAndRelu) {
  const int batch = 3;
  const int input_channel = 4;
  const int output_channel = 5;

  this->VerifyFusedMatMul(batch, input_channel, output_channel,
                          {"BiasAdd", "Relu"});
}

TYPED_TEST_P(MklFusedMatMulOpTest, WithBiasAndRelu6) {
  const int batch = 3;
  const int input_channel = 4;
  const int output_channel = 5;

  this->VerifyFusedMatMul(batch, input_channel, output_channel,
                          {"BiasAdd", "Relu6"});
}

TYPED_TEST_P(MklFusedMatMulOpTest, WithBiasAndElu) {
  const int batch = 3;
  const int input_channel = 4;
  const int output_channel = 5;

  this->VerifyFusedMatMul(batch, input_channel, output_channel,
                          {"BiasAdd", "Elu"});
}

TYPED_TEST_P(MklFusedMatMulOpTest, WithBiasAndTanh) {
  const int batch = 3;
  const int input_channel = 4;
  const int output_channel = 5;

  this->VerifyFusedMatMul(batch, input_channel, output_channel,
                          {"BiasAdd", "Tanh"});
}

TYPED_TEST_P(MklFusedMatMulOpTest, WithBiasAndSigmoid) {
  const int batch = 3;
  const int input_channel = 4;
  const int output_channel = 5;

  this->VerifyFusedMatMul(batch, input_channel, output_channel,
                          {"BiasAdd", "Sigmoid"});
}

TYPED_TEST_P(MklFusedMatMulOpTest, WithBiasAndAdd) {
  const int batch = 3;
  const int input_channel = 4;
  const int output_channel = 4;

  this->VerifyFusedMatMul(batch, input_channel, output_channel,
                          {"BiasAdd", "Add"});
}

TYPED_TEST_P(MklFusedMatMulOpTest, WithBiasAndLeakyRelu) {
  const int batch = 3;
  const int input_channel = 4;
  const int output_channel = 5;

  this->VerifyFusedMatMul(batch, input_channel, output_channel,
                          {"BiasAdd", "LeakyRelu"});
}

REGISTER_TYPED_TEST_SUITE_P(MklFusedMatMulOpTest,  //
                            WithBias,              //
                            WithBiasAndRelu,       //
                            WithBiasAndRelu6,      //
                            WithBiasAndElu,        //
                            WithBiasAndTanh,       //
                            WithBiasAndLeakyRelu,  //
                            WithBiasAndSigmoid,    //
                            WithBiasAndAdd);

using MklFusedMatMulDataTypes = ::testing::Types<float>;
INSTANTIATE_TYPED_TEST_SUITE_P(Test, MklFusedMatMulOpTest,
                               MklFusedMatMulDataTypes);

// Test the correctness of MklFusedMatMul weight cache.
// Weight is cached only when the input filter (weight) is constant.
class MklFusedMatMulCacheTest : public OpsTestBase {
 public:
  void Run(const bool is_filter_const) {
    const int num_args = 1;
    const std::vector<string>& fused_ops = {"BiasAdd"};

    TF_ASSERT_OK(NodeDefBuilder("MklFusedMatMul", "_MklNativeFusedMatMul")
                     .Input(FakeInput(DT_FLOAT))
                     .Input(FakeInput(DT_FLOAT))
                     .Input(FakeInput(num_args, DT_FLOAT))
                     .Attr("T", DT_FLOAT)
                     .Attr("transpose_a", false)
                     .Attr("transpose_b", false)
                     .Attr("num_args", num_args)
                     .Attr("is_filter_const", is_filter_const)
                     .Attr("fused_ops", fused_ops)
                     .Attr("epsilon", 0.0001)
                     .Attr("_kernel", "MklNameChangeOp")
                     .Finalize(node_def()));

    TF_EXPECT_OK(InitOp());
    // The input shape (1, 3) is chosen so that oneDNN expects the weight in
    // OI format rather than the IO format it uses for batch sizes > 1.
    // The A matrix is:
    // |  1 |  2 |  3 |
    AddInputFromArray<float>(TensorShape({1, 3}), {1, 2, 3});
    // B matrix is:
    // |  7 |  8 |  9 | 10 |
    // | 11 | 12 | 13 | 14 |
    // | 15 | 16 | 17 | 18 |
    AddInputFromArray<float>(TensorShape({3, 4}),
                             {7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18});
    // Bias vector.
    AddInputFromArray<float>(TensorShape({4}), {1, 2, 3, 4});

    using KernelType = MklDnnMatMulOpBase<float, float>;
    // Before the first kernel execution, the weight cache should be empty.
    EXPECT_TRUE(static_cast<KernelType*>(this->kernel_.get())
                    ->IsWeightCacheEmpty(this->context_.get()));

    TF_ASSERT_OK(RunOpKernel());

    // Final result after Bias addition:
    // | 75 | 82 | 89 | 96 |
    Tensor expected(DT_FLOAT, TensorShape({1, 4}));
    test::FillValues<float>(&expected, {75, 82, 89, 96});
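    // Sanity check of the expected values (hand-computed): the first entry is
    // 1*7 + 2*11 + 3*15 + bias(1) = 75, and the rest follow the same pattern.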

    const Tensor& output = *GetOutput(0);
    CommonTestUtilities<float> test_util;
    test::ExpectTensorNear<float>(expected, output, 1e-5);

    // After the first kernel execution, the weight is cached if
    // is_filter_const is true; otherwise the weight cache remains empty.
    EXPECT_TRUE(static_cast<KernelType*>(this->kernel_.get())
                    ->IsWeightCacheEmpty(this->context_.get()) !=
                is_filter_const);
  }
};

// Test that a const filter can be cached.
TEST_F(MklFusedMatMulCacheTest, WeightCachedTrue) { Run(true); }

// Test that a non-const filter cannot be cached.
TEST_F(MklFusedMatMulCacheTest, WeightCachedFalse) { Run(false); }

class BiasCacheTest : public OpsTestBase {
 public:
  template <typename T>
  void Run(DataType dtype, Tensor& image, Tensor& filter, Tensor& bias,
           Tensor& min_input, Tensor& max_input, Tensor& min_filter,
           Tensor& max_filter, Tensor& min_output, Tensor& max_output,
           Tensor& expected, const bool is_filter_const) {
    const int stride = 1;

    TF_EXPECT_OK(
        NodeDefBuilder("quantized_conv2d_bias_cache",
                       "_MklQuantizedConv2DWithBiasAndReluAndRequantize")
            .Input(FakeInput(dtype))     // Input
            .Input(FakeInput(DT_QINT8))  // Filter
            .Input(FakeInput(DT_FLOAT))  // Bias
            .Input(FakeInput(DT_FLOAT))  // Min-input
            .Input(FakeInput(DT_FLOAT))  // Max-input
            .Input(FakeInput(DT_FLOAT))  // Min-filter
            .Input(FakeInput(DT_FLOAT))  // Max-filter
            .Input(FakeInput(DT_FLOAT))  // Min-output
            .Input(FakeInput(DT_FLOAT))  // Max-output
            .Attr("Tinput", DT_QUINT8)
            .Attr("Tfilter", DT_QINT8)
            .Attr("Tbias", DT_FLOAT)
            .Attr("out_type", DT_QUINT8)
            .Attr("data_format", "NHWC")
            .Attr("strides", {1, stride, stride, 1})
            .Attr("is_filter_const", is_filter_const)
            .Attr("is_bias_const", true)
            .Attr("padding", "VALID")
            .Attr("_kernel", "QuantizedMklOp")
            .Finalize(node_def()));
    TF_EXPECT_OK(InitOp());

1195     // Setting up inputs and execute
1196     AddInputFromArray<quint8>(image.shape(), image.flat<quint8>());
1197     AddInputFromArray<qint8>(filter.shape(), filter.flat<qint8>());
1198     AddInputFromArray<float>(bias.shape(), bias.flat<float>());
1199     AddInputFromArray<float>(min_input.shape(), min_input.flat<float>());
1200     AddInputFromArray<float>(max_input.shape(), max_input.flat<float>());
1201     AddInputFromArray<float>(min_filter.shape(), min_filter.flat<float>());
1202     AddInputFromArray<float>(max_filter.shape(), max_filter.flat<float>());
1203     AddInputFromArray<float>(min_output.shape(), min_output.flat<float>());
1204     AddInputFromArray<float>(max_output.shape(), max_output.flat<float>());
1205 
1206     TF_ASSERT_OK(RunOpKernel());
1207 
1208     // Compare outputs to expected results
1209     const Tensor& output = *GetOutput(0);
1210     test::ExpectTensorEqual<quint8>(expected, output);
1211 
1212     // TODO(intel-tf): For now, we rely on internal performance tests to
1213     // determine if filter data is being cached and reused.
1214     // However, we still need to add a check here to determine if this is
1215     // still the case by inspecting the contents of the persistent tensor.
1216     TF_ASSERT_OK(RunOpKernel());
1217 
1218     // Compare output to expected results
1219     const Tensor& output_new = *GetOutput(0);
1220     test::ExpectTensorEqual<quint8>(expected, output_new);
1221   }
1222 };

TEST_F(BiasCacheTest, Conv2DBiasCacheTest) {
  const int depth = 1;
  const int image_width = 4;
  const int image_height = 3;
  const int image_batch_count = 1;

  Tensor image(DT_QUINT8,
               {image_batch_count, image_height, image_width, depth});
  test::FillValues<quint8>(&image, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});

  const int kFilterSize = 3;
  const int kFilterCount = 1;
  Tensor filter(DT_QINT8, {kFilterSize, kFilterSize, depth, kFilterCount});
  test::FillValues<qint8>(&filter, {1, 4, 7, 2, 5, 8, 3, 6, 9});

  Tensor bias(DT_FLOAT, {kFilterCount});
  test::FillValues<float>(&bias, {1});

  Tensor min_input(DT_FLOAT, {1});
  test::FillValues<float>(&min_input, {1});

  Tensor max_input(DT_FLOAT, {1});
  test::FillValues<float>(&max_input, {1});

  Tensor min_filter(DT_FLOAT, {1});
  test::FillValues<float>(&min_filter, {1});

  Tensor max_filter(DT_FLOAT, {1});
  test::FillValues<float>(&max_filter, {1});

  Tensor min_output(DT_FLOAT, {1});
  test::FillValues<float>(&min_output, {1});

  Tensor max_output(DT_FLOAT, {1});
  test::FillValues<float>(&max_output, {1});

  Tensor expected(DT_QUINT8, TensorShape({1, 1, 2, 1}));
  test::FillValues<quint8>(&expected, {255, 255});
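  // The float results of this VALID convolution are 313 and 358, e.g.
  // 1*1 + 2*4 + 3*7 + 5*2 + 6*5 + 7*8 + 9*3 + 10*6 + 11*9 + 1 = 313 for the
  // first output, so both saturate to the quint8 maximum of 255 on
  // requantization.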

  Run<float>(DT_QUINT8, image, filter, bias, min_input, max_input, min_filter,
             max_filter, min_output, max_output, expected, true);
}
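
// A plain-C++ float reference for the convolution in the test above (an
// illustrative sketch, not exercised by the tests): it reproduces the
// pre-requantize values 313 and 358 that saturate to 255 in quint8.
static inline void ReferenceValidConv3x3(const float image[3][4],
                                         const float filter[3][3], float bias,
                                         float out[2]) {
  // A VALID 3x3 convolution over a 3x4 single-channel image yields 1x2.
  for (int ow = 0; ow < 2; ++ow) {
    float acc = bias;
    for (int h = 0; h < 3; ++h) {
      for (int w = 0; w < 3; ++w) {
        acc += image[h][w + ow] * filter[h][w];
      }
    }
    out[ow] = acc;
  }
}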

// Tests the fusion of Pad with FusedConv2D (Pad + Conv2D + BiasAdd [+ Relu]).
template <typename T>
class MklPadWithFusedConv2DOpTest : public OpsTestBase {
 protected:
  static constexpr int kDepth = 3;
  static constexpr int kImageWidth = 30;
  static constexpr int kImageHeight = 28;
  static constexpr int kImageBatchCount = 8;

  // 0: top pad, 1: bottom pad, 2: left pad, 3: right pad
  int padding_list_[4];
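  // For example, SetPaddingList(2, 2, 1, 1) pads two rows above and below the
  // image and one column on each side before the convolution runs.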

  // Verifies that computing Pad+Conv2D+BiasAdd in a graph is identical to
  // FusedConv2D.
  void VerifyPadAndConv2DWithBias(int filter_size, int filter_count,
                                  int depth = kDepth,
                                  int image_width = kImageWidth,
                                  int image_height = kImageHeight,
                                  int image_batch_count = kImageBatchCount) {
    const BiasAddGraphRunner run_default = [this](const Tensor& input_data,
                                                  const Tensor& filter_data,
                                                  const Tensor& bias_data,
                                                  Tensor* out) {
      RunMklPadWithFusedConv2DAndBias(input_data, filter_data, bias_data, out);
    };

    const BiasAddGraphRunner run_fused =
        [this](const Tensor& input_data, const Tensor& filter_data,
               const Tensor& bias_data, Tensor* out) {
          RunMklFusedConv2DWithPadOp(input_data, filter_data, {bias_data},
                                     {"BiasAdd"}, out);
        };

    CommonTestUtilities<T>::VerifyBiasAddTensorsClose(
        depth, image_width, image_height, image_batch_count, filter_size,
        filter_count, run_default, run_fused);
  }

  // Verifies that computing Pad+Conv2D+BiasAdd+Relu in a graph is identical to
  // FusedConv2D.
  void VerifyPadAndConv2DWithBiasRelu(
      int filter_size, int filter_count, int depth = kDepth,
      int image_width = kImageWidth, int image_height = kImageHeight,
      int image_batch_count = kImageBatchCount) {
    const BiasAddGraphRunner run_default =
        [this](const Tensor& input_data, const Tensor& filter_data,
               const Tensor& bias_data, Tensor* out) {
          RunMklPadWithFusedConv2DAndBiasRelu(input_data, filter_data,
                                              bias_data, out);
        };

    const BiasAddGraphRunner run_fused =
        [this](const Tensor& input_data, const Tensor& filter_data,
               const Tensor& bias_data, Tensor* out) {
          RunMklFusedConv2DWithPadOp(input_data, filter_data, {bias_data},
                                     {"BiasAdd", "Relu"}, out);
        };

    CommonTestUtilities<T>::VerifyBiasAddTensorsClose(
        depth, image_width, image_height, image_batch_count, filter_size,
        filter_count, run_default, run_fused);
  }

  void RunMklPadWithFusedConv2DAndBias(const Tensor& input_data,
                                       const Tensor& filter_data,
                                       const Tensor& bias_data, Tensor* output,
                                       int stride = 1) {
    auto root = tensorflow::Scope::NewRootScope();

    // FusedConv2D only supports the NHWC format, so we use NHWC here.
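    // Each row of the 4x2 paddings constant is the (before, after) padding
    // for one NHWC dimension: batch, height, width, channels.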
    auto padding = ops::Const(root.WithOpName("padding"),
                              {0, 0, padding_list_[0], padding_list_[1],
                               padding_list_[2], padding_list_[3], 0, 0},
                              {4, 2});
    auto pad = ops::Pad(
        root.WithOpName("pad"),
        ops::Const(root.WithOpName("input"), Input::Initializer(input_data)),
        padding);

    auto conv = ops::Conv2D(
        root.WithOpName("conv"), pad,
        ops::Const(root.WithOpName("filter"), Input::Initializer(filter_data)),
        {1, stride, stride, 1}, "VALID");

    auto with_bias = ops::BiasAdd(
        root.WithOpName("with_bias"), conv,
        ops::Const(root.WithOpName("bias"), Input::Initializer(bias_data)));

    CommonTestUtilities<T>::RunAndFetch(root, "with_bias", output);
  }

  void RunMklPadWithFusedConv2DAndBiasRelu(const Tensor& input_data,
                                           const Tensor& filter_data,
                                           const Tensor& bias_data,
                                           Tensor* output, int stride = 1) {
    auto root = tensorflow::Scope::NewRootScope();

    // FusedConv2D only supports the NHWC format, so we use NHWC here.
    auto padding = ops::Const(root.WithOpName("padding"),
                              {0, 0, padding_list_[0], padding_list_[1],
                               padding_list_[2], padding_list_[3], 0, 0},
                              {4, 2});
    auto pad = ops::Pad(
        root.WithOpName("pad"),
        ops::Const(root.WithOpName("input"), Input::Initializer(input_data)),
        padding);

    auto conv = ops::Conv2D(
        root.WithOpName("conv"), pad,
        ops::Const(root.WithOpName("filter"), Input::Initializer(filter_data)),
        {1, stride, stride, 1}, "VALID");

    auto with_bias = ops::BiasAdd(
        root.WithOpName("with_bias"), conv,
        ops::Const(root.WithOpName("bias"), Input::Initializer(bias_data)));

    auto with_relu = ops::Relu(root.WithOpName("with_relu"), with_bias);

    CommonTestUtilities<T>::RunAndFetch(root, "with_relu", output);
  }

  void RunMklFusedConv2DWithPadOp(const Tensor& image, const Tensor& filter,
                                  const std::vector<Tensor>& args,
                                  const std::vector<string>& fused_ops,
                                  Tensor* output, int stride = 1) {
    DataType dtype = DataTypeToEnum<T>::v();
    const int num_args = static_cast<int>(args.size());
    Tensor padding(DT_INT32, {4, 2});
    test::FillValues<int32>(
        &padding, {0, 0, padding_list_[0], padding_list_[1], padding_list_[2],
                   padding_list_[3], 0, 0});

    TF_EXPECT_OK(
        NodeDefBuilder("pad_fused_conv_op", "_MklNativePadWithFusedConv2D")
            .Input(FakeInput(dtype))
            .Input(FakeInput(dtype))
            .Input(FakeInput(num_args, dtype))
            .Input(FakeInput(DT_INT32))
            .Attr("T", dtype)
            .Attr("num_args", num_args)
            .Attr("strides", {1, stride, stride, 1})
            .Attr("padding", "VALID")
            .Attr("fused_ops", fused_ops)
            .Attr("_kernel", "MklNameChangeOp")
            .Finalize(node_def()));

    TF_EXPECT_OK(InitOp());

    AddInputFromArray<T>(image.shape(), image.flat<T>());
    AddInputFromArray<T>(filter.shape(), filter.flat<T>());
    for (const Tensor& arg : args)
      AddInputFromArray<T>(arg.shape(), arg.flat<T>());
    AddInputFromArray<int32>(padding.shape(), padding.flat<int32>());
    TF_ASSERT_OK(RunOpKernel());

    // Fetch the output tensor; the caller performs the comparison.
    const Tensor& output_tensor = *GetOutput(0);
    CommonTestUtilities<T> test_util;
    *output = output_tensor;
  }

 public:
  void SetPaddingList(int top, int bottom, int left, int right) {
    padding_list_[0] = top;
    padding_list_[1] = bottom;
    padding_list_[2] = left;
    padding_list_[3] = right;
  }
};

TYPED_TEST_SUITE_P(MklPadWithFusedConv2DOpTest);

TYPED_TEST_P(MklPadWithFusedConv2DOpTest, WithBiasAndRoundPad) {
  const int kFilterSize = 1;
  const int kFilterCount = 12;
  this->SetPaddingList(2, 2, 1, 1);
  this->VerifyPadAndConv2DWithBias(kFilterSize, kFilterCount);
}

TYPED_TEST_P(MklPadWithFusedConv2DOpTest, WithBiasAndPartialPad) {
  const int kFilterSize = 1;
  const int kFilterCount = 12;
  this->SetPaddingList(4, 0, 2, 0);
  this->VerifyPadAndConv2DWithBias(kFilterSize, kFilterCount);
}

TYPED_TEST_P(MklPadWithFusedConv2DOpTest, WithBiasReluAndRoundPad) {
  const int kFilterSize = 1;
  const int kFilterCount = 12;
  this->SetPaddingList(2, 2, 1, 1);
  this->VerifyPadAndConv2DWithBiasRelu(kFilterSize, kFilterCount);
}

TYPED_TEST_P(MklPadWithFusedConv2DOpTest, WithBiasReluAndPartialPad) {
  const int kFilterSize = 1;
  const int kFilterCount = 12;
  this->SetPaddingList(4, 0, 2, 0);
  this->VerifyPadAndConv2DWithBiasRelu(kFilterSize, kFilterCount);
}

REGISTER_TYPED_TEST_SUITE_P(MklPadWithFusedConv2DOpTest,  //
                            WithBiasAndRoundPad,          //
                            WithBiasAndPartialPad,        //
                            WithBiasReluAndRoundPad,      //
                            WithBiasReluAndPartialPad);

using MklPadWithFusedConv2DDataTypes = ::testing::Types<float>;
INSTANTIATE_TYPED_TEST_SUITE_P(Test, MklPadWithFusedConv2DOpTest,
                               MklPadWithFusedConv2DDataTypes);
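
// Only float is instantiated today. Assuming matching MKL kernel
// registrations exist (which should be verified), coverage could be widened
// by extending the type list, e.g.:
//   using MklPadWithFusedConv2DDataTypes = ::testing::Types<float, bfloat16>;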

}  // namespace tensorflow
#endif  // INTEL_MKL