/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if defined(INTEL_MKL) && defined(ENABLE_MKL)
#include "tensorflow/cc/ops/nn_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/devices.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/remapper.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/mkl_util.h"

namespace tensorflow {
namespace grappler {

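// Tests for oneDNN (MKL)-specific graph rewrites performed by the remapper
// optimizer.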
class MklRemapperTest : public GrapplerTest {
 public:
  const string kAddNOp = "AddN";
  const string kAddOp = "Add";
  const string kAddV2Op = "AddV2";

 protected:
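  // Builds Conv2D -> BiasAdd -> {AddN, AddV2, Add} (optionally followed by an
  // activation) and runs the remapper. When the extra Add input does not need
  // broadcasting, the whole pattern is expected to fuse into one _FusedConv2D.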
  void FuseConv2DWithBiasAndAddNOrAdd(const string& data_format,
                                      const string& activation, string add_op,
                                      bool add_with_bcast) {
    using ::tensorflow::ops::Placeholder;

    tensorflow::Scope s = tensorflow::Scope::NewRootScope();

    auto input_shape = (data_format == "NHWC")
                           ? ops::Placeholder::Shape({8, 32, 32, 3})
                           : ops::Placeholder::Shape({8, 3, 32, 32});
    auto input_shape_addn = ops::Placeholder::Shape({});
    if (data_format == "NHWC") {
      if (add_with_bcast)
        input_shape_addn = ops::Placeholder::Shape({128});
      else
        input_shape_addn = ops::Placeholder::Shape({8, 32, 32, 128});
    } else {
      if (add_with_bcast)
        input_shape_addn = ops::Placeholder::Shape({32});
      else
        input_shape_addn = ops::Placeholder::Shape({8, 128, 32, 32});
    }
    auto filter_shape = ops::Placeholder::Shape({1, 1, 3, 128});
    auto bias_shape = ops::Placeholder::Shape({128});

    auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
    auto input_addn =
        Placeholder(s.WithOpName("input_addn"), DT_FLOAT, input_shape_addn);
    auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape);
    auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape);

    std::vector<int> strides = {1, 1, 1, 1};
    auto conv =
        ops::Conv2D(s.WithOpName("conv"), input, filter, strides, "SAME",
                    ops::Conv2D::Attrs().DataFormat(data_format));
    auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), conv, bias,
                                 ops::BiasAdd::Attrs().DataFormat(data_format));

    auto addfetch = [&](::tensorflow::Input addop) {
      auto activate = s.WithOpName("activation");
      auto fetch = s.WithOpName("fetch");
      if (activation == "Relu") {
        ops::Identity(fetch, ops::Relu(activate, addop));
      } else if (activation == "Relu6") {
        ops::Identity(fetch, ops::Relu6(activate, addop));
      } else if (activation == "Elu") {
        ops::Identity(fetch, ops::Elu(activate, addop));
      } else if (activation == "LeakyRelu") {
        ops::Identity(fetch, ops::internal::LeakyRelu(activate, addop));
      } else {
        DCHECK(activation == "None");
        ops::Identity(fetch, addop);
      }
    };

    if (add_op == kAddNOp) {
      auto addn = ops::AddN(s.WithOpName(add_op),
                            std::initializer_list<Input>{input_addn, bias_add});
      addfetch(addn);
    } else if (add_op == kAddV2Op) {
      auto add = ops::AddV2(s.WithOpName(add_op), input_addn, bias_add);
      addfetch(add);
    } else {
      auto add = ops::Add(s.WithOpName(add_op), input_addn, bias_add);
      addfetch(add);
    }
    auto input_tensor = GenerateRandomTensor<DT_FLOAT>(
        TensorShape(input_shape.shape_.dim_sizes()));
    auto input_addn_tensor = GenerateRandomTensor<DT_FLOAT>(
        TensorShape(input_shape_addn.shape_.dim_sizes()));
    auto filter_tensor = GenerateRandomTensor<DT_FLOAT>(
        TensorShape(filter_shape.shape_.dim_sizes()));
    auto bias_tensor = GenerateRandomTensor<DT_FLOAT>(
        TensorShape(bias_shape.shape_.dim_sizes()));

    GrapplerItem item;
    item.fetch = {"fetch"};
    item.feed = {{"input", input_tensor},
                 {"filter", filter_tensor},
                 {"bias", bias_tensor},
                 {"input_addn", input_addn_tensor}};
    TF_CHECK_OK(s.ToGraphDef(&item.graph));

    // Place all nodes on CPU.
    for (int i = 0; i < item.graph.node_size(); ++i) {
      item.graph.mutable_node(i)->set_device("/device:CPU:0");
    }

    // Set the rewriter config to AGGRESSIVE so that the Placeholder shapes can
    // be used to check that an Add whose inputs have the same shape gets fused
    // with Conv2D. The AGGRESSIVE setting is not required for the fusion
    // itself.
    Remapper optimizer(RewriterConfig::AGGRESSIVE);
    GraphDef output;
    TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));

    bool check_fusion = !add_with_bcast;
    int found = 0;
    for (const NodeDef& node : output.node()) {
      auto fetch_node_name = activation != "None" ? "activation" : add_op;
      if (node.name() == fetch_node_name) {
        if (check_fusion) {
          EXPECT_EQ("_FusedConv2D", node.op());
          EXPECT_EQ("input", node.input(0));
          EXPECT_EQ("filter", node.input(1));

          EXPECT_EQ(2, node.attr().at("num_args").i());
          EXPECT_EQ("bias", node.input(2));
          EXPECT_EQ("input_addn", node.input(3));

          const auto fused_ops = node.attr().at("fused_ops").list().s();
          if (activation != "None") {
            EXPECT_EQ(3, fused_ops.size());
            EXPECT_EQ("BiasAdd", fused_ops[0]);
            EXPECT_EQ("Add", fused_ops[1]);
            EXPECT_EQ(activation, fused_ops[2]);
          } else {
            EXPECT_EQ(2, fused_ops.size());
            EXPECT_EQ("BiasAdd", fused_ops[0]);
            EXPECT_EQ("Add", fused_ops[1]);
          }
        } else {
          if (activation != "None") {
            EXPECT_EQ(node.op(), activation);
            ASSERT_EQ(node.input_size(), 1);
            EXPECT_EQ(node.input(0), add_op);
          } else {
            EXPECT_EQ(node.op(), add_op);
            ASSERT_EQ(node.input_size(), 2);
          }
        }
        found++;
      }
    }
    EXPECT_EQ(1, found);

    auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
    auto tensors = EvaluateNodes(output, item.fetch, item.feed);
    EXPECT_EQ(1, tensors_expected.size());
    EXPECT_EQ(1, tensors.size());
    // Using relative tolerance since oneDNN could produce different results
    // when float32 numbers need to be rounded during accumulation.
    test::ExpectClose(tensors_expected[0], tensors[0], 0, 1e-6);
  }
};

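// Instantiates the Conv2D + BiasAdd + Add fusion test above for each
// combination of data format, add op, activation, and broadcast setting.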
#define CREATE_CONV2DFUSION_TEST(data_format, addop, activation, bcast)                          \
  TEST_F(                                                                                        \
      MklRemapperTest,                                                                           \
      FuseConv2DWithBiasAnd##addop##_##data_format##_activation##activation##_addbcast##bcast) { \
    FuseConv2DWithBiasAndAddNOrAdd(#data_format, #activation, #addop, bcast);                    \
  }

#define CREATE_CONV2DFUSION_ADD_ACTIVATION_TEST(data_format, addop, bcast) \
  CREATE_CONV2DFUSION_TEST(data_format, addop, Relu, bcast);               \
  CREATE_CONV2DFUSION_TEST(data_format, addop, Relu6, bcast);              \
  CREATE_CONV2DFUSION_TEST(data_format, addop, Elu, bcast);                \
  CREATE_CONV2DFUSION_TEST(data_format, addop, LeakyRelu, bcast);          \
  CREATE_CONV2DFUSION_TEST(data_format, addop, None, bcast);

#define CREATE_CONV2DFUSION_ADD_NOBCAST_TEST(addop)            \
  CREATE_CONV2DFUSION_ADD_ACTIVATION_TEST(NHWC, addop, false); \
  CREATE_CONV2DFUSION_ADD_ACTIVATION_TEST(NCHW, addop, false);

CREATE_CONV2DFUSION_ADD_NOBCAST_TEST(AddN);

#define CREATE_CONV2DFUSION_ADD_BCAST_TEST(addop)              \
  CREATE_CONV2DFUSION_ADD_ACTIVATION_TEST(NHWC, addop, false); \
  CREATE_CONV2DFUSION_ADD_ACTIVATION_TEST(NCHW, addop, false); \
  CREATE_CONV2DFUSION_ADD_ACTIVATION_TEST(NHWC, addop, true);  \
  CREATE_CONV2DFUSION_ADD_ACTIVATION_TEST(NCHW, addop, true);

CREATE_CONV2DFUSION_ADD_BCAST_TEST(Add);
CREATE_CONV2DFUSION_ADD_BCAST_TEST(AddV2);

#undef CREATE_CONV2DFUSION_ADD_NOBCAST_TEST
#undef CREATE_CONV2DFUSION_ADD_BCAST_TEST
#undef CREATE_CONV2DFUSION_ADD_ACTIVATION_TEST
#undef CREATE_CONV2DFUSION_TEST

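// Builds DepthwiseConv2dNative -> BiasAdd (optionally followed by an
// activation) and expects the remapper to rewrite the pattern into a single
// _FusedDepthwiseConv2dNative node.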
#define REGISTER_TEST(NAME, T, INPUT)                                         \
  TEST_F(MklRemapperTest, NAME##_##T) {                                       \
    using ::tensorflow::ops::Placeholder;                                     \
                                                                              \
    for (const string& activation : {"Relu", "Relu6", "Elu", "None"}) {       \
      tensorflow::Scope s = tensorflow::Scope::NewRootScope();                \
                                                                              \
      auto input_shape = Placeholder::Shape({8, 32, 32, 3});                  \
      auto filter_shape = Placeholder::Shape({1, 1, 3, 1});                   \
      auto bias_shape = Placeholder::Shape({3});                              \
                                                                              \
      auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape); \
      auto filter =                                                           \
          Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape);        \
      auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape);    \
                                                                              \
      std::vector<int> strides = {1, 1, 1, 1};                                \
      auto conv = ops::DepthwiseConv2dNative(s.WithOpName("depthwise_conv"),  \
                                             input, filter, strides, "SAME"); \
      auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), conv, bias);     \
                                                                              \
      ops::Identity fetch = [&]() -> ops::Identity {                          \
        auto activate = s.WithOpName("activation");                           \
        auto fetch = s.WithOpName("fetch");                                   \
                                                                              \
        if (activation == "Relu") {                                           \
          return ops::Identity(fetch, ops::Relu(activate, bias_add));         \
        } else if (activation == "Relu6") {                                   \
          return ops::Identity(fetch, ops::Relu6(activate, bias_add));        \
        } else if (activation == "Elu") {                                     \
          return ops::Identity(fetch, ops::Elu(activate, bias_add));          \
        }                                                                     \
                                                                              \
        DCHECK(activation == "None");                                         \
        return ops::Identity(fetch, bias_add);                                \
      }();                                                                    \
                                                                              \
      auto input_t = GenerateRandomTensor<DT_FLOAT>({8, 32, 32, 3});          \
      auto filter_t = GenerateRandomTensor<DT_FLOAT>({1, 1, 3, 1});           \
      auto bias_t = GenerateRandomTensor<DT_FLOAT>({3});                      \
                                                                              \
      GrapplerItem item;                                                      \
      item.fetch = {"fetch"};                                                 \
      item.feed = {                                                           \
          {"input", input_t}, {"filter", filter_t}, {"bias", bias_t}};        \
      TF_CHECK_OK(s.ToGraphDef(&item.graph));                                 \
                                                                              \
      for (int i = 0; i < item.graph.node_size(); ++i) {                      \
        item.graph.mutable_node(i)->set_device("/device:CPU:0");              \
      }                                                                       \
                                                                              \
      Remapper optimizer(RewriterConfig::ON);                                 \
      GraphDef output;                                                        \
      TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));                \
                                                                              \
      int found = 0;                                                          \
      for (const NodeDef& node : output.node()) {                             \
        if (node.name() != "bias_add" && node.name() != "activation")         \
          continue;                                                           \
                                                                              \
        EXPECT_EQ(node.op(), "_FusedDepthwiseConv2dNative");                  \
        ASSERT_EQ(node.input_size(), 3);                                      \
        EXPECT_EQ(node.input(0), "input");                                    \
        EXPECT_EQ(node.input(1), "filter");                                   \
                                                                              \
        EXPECT_EQ(node.attr().at("num_args").i(), 1);                         \
        EXPECT_EQ(node.input(2), "bias");                                     \
                                                                              \
        const auto fused_ops = node.attr().at("fused_ops").list().s();        \
        if (node.name() == "bias_add") {                                      \
          ASSERT_EQ(fused_ops.size(), 1);                                     \
          EXPECT_EQ(fused_ops[0], "BiasAdd");                                 \
          found++;                                                            \
        }                                                                     \
        if (node.name() == "activation") {                                    \
          ASSERT_EQ(fused_ops.size(), 2);                                     \
          EXPECT_EQ(fused_ops[0], "BiasAdd");                                 \
          EXPECT_EQ(fused_ops[1], activation);                                \
          found++;                                                            \
        }                                                                     \
      }                                                                       \
      EXPECT_EQ(found, 1);                                                    \
                                                                              \
      auto tensors_expected =                                                 \
          EvaluateNodes(item.graph, item.fetch, item.feed);                   \
      ASSERT_EQ(tensors_expected.size(), 1);                                  \
      auto tensors = EvaluateNodes(output, item.fetch, item.feed);            \
      ASSERT_EQ(tensors.size(), 1);                                           \
      test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);   \
    }                                                                         \
  }
REGISTER_TEST_ALL_TYPES(FuseDepthwiseConv2DWithBiasAndActivation);
#undef REGISTER_TEST

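// FusedBatchNormV3 followed by Relu should be rewritten to _FusedBatchNormEx
// with activation_mode "Relu". When a side-input Add sits between the batch
// norm and the Relu, the pattern is expected to stay unfused on CPU.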
TEST_F(MklRemapperTest, FuseBatchNormWithRelu) {
  using ::tensorflow::ops::Placeholder;

  for (bool is_training : {true, false}) {
    for (bool has_side_input : {true, false}) {
      tensorflow::Scope s = tensorflow::Scope::NewRootScope();

      const int num_channels = 24;

      TensorShape channel_shape({num_channels});
      TensorShape empty_shape({0});

      auto input =
          Placeholder(s.WithOpName("input"), DT_FLOAT,
                      ops::Placeholder::Shape({2, 8, 8, num_channels}));
      auto input_cast = ops::Cast(s.WithOpName("input_cast"), input, DT_FLOAT);
      auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT);
      auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT);
      auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT);
      auto var = Placeholder(s.WithOpName("var"), DT_FLOAT);

      float epsilon = 0.1f;
      auto fbn =
          ops::FusedBatchNormV3(s.WithOpName("fused_batch_norm"), input_cast,
                                scale, offset, mean, var,
                                ops::FusedBatchNormV3::IsTraining(is_training)
                                    .Epsilon(epsilon)
                                    .DataFormat("NHWC"));

      if (has_side_input) {
        auto side_input =
            Placeholder(s.WithOpName("side_input"), DT_FLOAT,
                        ops::Placeholder::Shape({2, 8, 8, num_channels}));
        auto side_input_cast =
            ops::Cast(s.WithOpName("side_input_cast"), side_input, DT_FLOAT);
        auto add = ops::Add(s.WithOpName("add"), fbn.y, side_input_cast);
        auto relu = ops::Relu(s.WithOpName("relu"), add);
      } else {
        auto relu = ops::Relu(s.WithOpName("relu"), fbn.y);
      }

      auto input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, num_channels});
      auto scale_t = GenerateRandomTensor<DT_FLOAT>(channel_shape);
      auto offset_t = GenerateRandomTensor<DT_FLOAT>(channel_shape);
      auto mean_t = GenerateRandomTensor<DT_FLOAT>(is_training ? empty_shape
                                                               : channel_shape);
      auto var_t = GenerateRandomTensor<DT_FLOAT>(is_training ? empty_shape
                                                              : channel_shape);
      auto side_input_t =
          GenerateRandomTensor<DT_FLOAT>({2, 8, 8, num_channels});

      GrapplerItem item;
      item.fetch = {"relu"};
      if (has_side_input)
        item.feed = {{"input", input_t},   {"scale", scale_t},
                     {"offset", offset_t}, {"mean", mean_t},
                     {"var", var_t},       {"side_input", side_input_t}};
      else
        item.feed = {{"input", input_t},
                     {"scale", scale_t},
                     {"offset", offset_t},
                     {"mean", mean_t},
                     {"var", var_t}};
      TF_ASSERT_OK(s.ToGraphDef(&item.graph));

      // Place all nodes on CPU.
      for (int i = 0; i < item.graph.node_size(); ++i) {
        item.graph.mutable_node(i)->set_device("/device:CPU:0");
      }

      Remapper optimizer(RewriterConfig::AGGRESSIVE);
      GraphDef output;
      TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));

      int found = 0;
      if (has_side_input) {
        for (const NodeDef& node : output.node()) {
          if (node.name() == "add") {
            EXPECT_EQ(node.op(), "Add");
            ASSERT_EQ(node.input_size(), 2);
            EXPECT_EQ(node.input(0), "fused_batch_norm");
            EXPECT_EQ(node.input(1), "side_input_cast");
            found++;
          }
          if (node.name() == "relu") {
            EXPECT_EQ(node.op(), "Relu");
            ASSERT_EQ(node.input_size(), 1);
            EXPECT_EQ(node.input(0), "add");
            found++;
          }
          if (node.name() == "fused_batch_norm") {
            EXPECT_EQ(node.op(), "FusedBatchNormV3");
            ASSERT_EQ(node.input_size(), 5);
            EXPECT_EQ(node.input(0), "input_cast");
            EXPECT_EQ(node.input(1), "scale");
            EXPECT_EQ(node.input(2), "offset");
            EXPECT_EQ(node.input(3), "mean");
            EXPECT_EQ(node.input(4), "var");
            found++;
          }
        }
        EXPECT_EQ(found, 3);
      } else {
        for (const NodeDef& node : output.node()) {
          if (node.name() == "relu") {
            EXPECT_EQ(node.op(), "Identity");
            ASSERT_EQ(node.input_size(), 1);
            EXPECT_EQ(node.input(0), "fused_batch_norm");
            found++;
          }
          if (node.name() == "fused_batch_norm") {
            EXPECT_EQ(node.op(), "_FusedBatchNormEx");
            ASSERT_EQ(node.input_size(), 5);
            EXPECT_EQ(node.input(0), "input_cast");
            EXPECT_EQ(node.input(1), "scale");
            EXPECT_EQ(node.input(2), "offset");
            EXPECT_EQ(node.input(3), "mean");
            EXPECT_EQ(node.input(4), "var");

            auto attr = node.attr();
            EXPECT_EQ(attr["num_side_inputs"].i(), 0);
            EXPECT_EQ(attr["activation_mode"].s(), "Relu");
            found++;
          }
        }
        EXPECT_EQ(found, 2);
      }

      auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
      ASSERT_EQ(tensors_expected.size(), 1);
      auto tensors = EvaluateNodes(output, item.fetch, item.feed);
      ASSERT_EQ(tensors.size(), 1);
      test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
    }
  }
}

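// MatMul followed by a bias addition (BiasAdd, or an equivalent Add/AddV2)
// and then another Add should fuse into a single _FusedMatMul with
// fused_ops {"BiasAdd", "Add"}.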
TEST_F(MklRemapperTest, FuseMatMulWithBiasAddAndAdd) {
  using ::tensorflow::ops::Placeholder;

  for (const string& add_op : {"BiasAdd", "AddV2", "Add"}) {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();

    auto input_shape = ops::Placeholder::Shape({4, 32});
    auto input_shape_add = ops::Placeholder::Shape({4, 8});
    auto filter_shape = ops::Placeholder::Shape({32, 8});
    auto bias_shape = ops::Placeholder::Shape({8});

    auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
    auto input_add =
        Placeholder(s.WithOpName("input_add"), DT_FLOAT, input_shape_add);
    auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape);
    auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape);

    auto matmul = ops::MatMul(s.WithOpName("matmul"), input, filter);
    Output bias_add;
    if (add_op == "BiasAdd")
      bias_add = ops::BiasAdd(s.WithOpName("bias_add"), matmul, bias);
    else if (add_op == "AddV2")
      bias_add = ops::AddV2(s.WithOpName("bias_add"), matmul, bias);
    else if (add_op == "Add")
      bias_add = ops::Add(s.WithOpName("bias_add"), bias, matmul);

    auto fetch = s.WithOpName("fetch");
    auto add = ops::Add(s.WithOpName("add"), bias_add, input_add);

    ops::Identity(fetch, add);

    auto input_tensor = GenerateRandomTensor<DT_FLOAT>(
        TensorShape(input_shape.shape_.dim_sizes()));
    auto input_add_tensor = GenerateRandomTensor<DT_FLOAT>(
        TensorShape(input_shape_add.shape_.dim_sizes()));
    auto filter_tensor = GenerateRandomTensor<DT_FLOAT>(
        TensorShape(filter_shape.shape_.dim_sizes()));
    auto bias_tensor = GenerateRandomTensor<DT_FLOAT>(
        TensorShape(bias_shape.shape_.dim_sizes()));

    GrapplerItem item;
    item.fetch = {"fetch"};
    item.feed = {{"input", input_tensor},
                 {"filter", filter_tensor},
                 {"bias", bias_tensor},
                 {"input_add", input_add_tensor}};
    TF_CHECK_OK(s.ToGraphDef(&item.graph));

    // Place all nodes on CPU.
    for (int i = 0; i < item.graph.node_size(); ++i) {
      item.graph.mutable_node(i)->set_device("/device:CPU:0");
    }

    Remapper optimizer(RewriterConfig::AGGRESSIVE);
    GraphDef output;
    TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));

    int found = 0;
    for (const NodeDef& node : output.node()) {
      auto fetch_node_name = "add";
      if (node.name() == fetch_node_name) {
        EXPECT_EQ("_FusedMatMul", node.op());
        EXPECT_EQ("input", node.input(0));
        EXPECT_EQ("filter", node.input(1));
        EXPECT_EQ(2, node.attr().at("num_args").i());
        EXPECT_EQ("bias", node.input(2));
        EXPECT_EQ("input_add", node.input(3));

        const auto fused_ops = node.attr().at("fused_ops").list().s();
        EXPECT_EQ(2, fused_ops.size());
        EXPECT_EQ("BiasAdd", fused_ops[0]);
        EXPECT_EQ("Add", fused_ops[1]);
        found++;
      }
    }
    EXPECT_EQ(1, found);

    auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
    auto tensors = EvaluateNodes(output, item.fetch, item.feed);
    EXPECT_EQ(1, tensors_expected.size());
    EXPECT_EQ(1, tensors.size());
    test::ExpectClose(tensors_expected[0], tensors[0], 0, 1e-6);
  }
}

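// Checks that an Add/AddV2 of a DepthwiseConv2dNative output and a 1-D bias
// tensor is recognized as a bias addition and fused into
// _FusedDepthwiseConv2dNative, optionally together with an activation.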
class RelpaceAddWithBiasAddTest : public GrapplerTest {
 public:
  const string kAddOp = "Add";
  const string kAddV2Op = "AddV2";

 protected:
  template <DataType DTYPE>
  void RelpaceAddWithBiasAddDepthwiseConv2D(const string& add_op) {
    using ::tensorflow::ops::Placeholder;

    for (const string& activation : {"None", "Relu", "Relu6", "Elu"}) {
      tensorflow::Scope s = tensorflow::Scope::NewRootScope();

      auto input_shape = Placeholder::Shape({8, 32, 32, 3});
      auto filter_shape = Placeholder::Shape({1, 1, 3, 128});
      auto bias_shape = Placeholder::Shape({128 * 3});

      auto input = Placeholder(s.WithOpName("input"), DTYPE, input_shape);
      auto filter = Placeholder(s.WithOpName("filter"), DTYPE, filter_shape);
      auto bias = Placeholder(s.WithOpName("bias"), DTYPE, bias_shape);

      std::vector<int> strides = {1, 1, 1, 1};
      auto conv = ops::DepthwiseConv2dNative(s.WithOpName("depthwise_conv"),
                                             input, filter, strides, "SAME");

      Output bias_add;
      if (add_op == kAddV2Op) {
        bias_add = ops::AddV2(s.WithOpName(add_op), conv, bias);
      } else {
        bias_add = ops::Add(s.WithOpName(add_op), bias, conv);
      }

      ops::Identity fetch = [&]() -> ops::Identity {
        auto activate = s.WithOpName("activation");
        auto fetch = s.WithOpName("fetch");

        if (activation == "Relu") {
          return ops::Identity(fetch, ops::Relu(activate, bias_add));
        } else if (activation == "Relu6") {
          return ops::Identity(fetch, ops::Relu6(activate, bias_add));
        } else if (activation == "Elu") {
          return ops::Identity(fetch, ops::Elu(activate, bias_add));
        }

        return ops::Identity(fetch, bias_add);
      }();

      auto input_t = GenerateRandomTensor<DTYPE>({8, 32, 32, 3});
      auto filter_t = GenerateRandomTensor<DTYPE>({1, 1, 3, 128});
      auto bias_t = GenerateRandomTensor<DTYPE>({128 * 3});

      GrapplerItem item;
      item.fetch = {"fetch"};
      item.feed = {{"input", input_t}, {"filter", filter_t}, {"bias", bias_t}};
      TF_ASSERT_OK(s.ToGraphDef(&item.graph));

      // Place all nodes on CPU.
      for (int i = 0; i < item.graph.node_size(); ++i) {
        item.graph.mutable_node(i)->set_device("/device:CPU:0");
      }

      Remapper optimizer(RewriterConfig::AGGRESSIVE);
      GraphDef output;
      TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));

      int found = 0;
      for (const NodeDef& node : output.node()) {
        if (node.name() == "activation") {
          EXPECT_EQ(node.op(), "_FusedDepthwiseConv2dNative");
          ASSERT_GE(node.input_size(), 3);
          EXPECT_EQ(node.input(0), "input");
          EXPECT_EQ(node.input(1), "filter");
          EXPECT_EQ(node.attr().at("num_args").i(), 1);
          EXPECT_EQ(node.input(2), "bias");

          const auto fused_ops = node.attr().at("fused_ops").list().s();
          ASSERT_EQ(fused_ops.size(), 2);
          EXPECT_EQ(fused_ops[0], "BiasAdd");
          EXPECT_EQ(fused_ops[1], activation);

          found++;
        } else if (node.name() == add_op) {
          EXPECT_EQ(node.op(), "_FusedDepthwiseConv2dNative");
          ASSERT_GE(node.input_size(), 3);
          EXPECT_EQ(node.input(0), "input");
          EXPECT_EQ(node.input(1), "filter");
          EXPECT_EQ(node.attr().at("num_args").i(), 1);
          EXPECT_EQ(node.input(2), "bias");

          const auto fused_ops = node.attr().at("fused_ops").list().s();
          ASSERT_EQ(fused_ops.size(), 1);
          EXPECT_EQ(fused_ops[0], "BiasAdd");
          found++;
        }
      }
      EXPECT_EQ(found, 1);

      auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
      ASSERT_EQ(tensors_expected.size(), 1);
      auto tensors = EvaluateNodes(output, item.fetch, item.feed);
      ASSERT_EQ(tensors.size(), 1);

      if (DTYPE == DT_BFLOAT16)
        test::ExpectClose(tensors[0], tensors_expected[0], 1e-2, 1e-2);
      else
        test::ExpectClose(tensors[0], tensors_expected[0], 1e-6);
    }
  }
};

#define CREATE_REPLACEADDWITHBIASADD_TEST_1(ops, addop, dtype)              \
  TEST_F(RelpaceAddWithBiasAddTest, RelpaceAddWithBiasAdd##ops##_##addop) { \
    RelpaceAddWithBiasAddDepthwiseConv2D<dtype>(#addop);                    \
  }
CREATE_REPLACEADDWITHBIASADD_TEST_1(DepthConv2D, AddV2, DT_FLOAT);
CREATE_REPLACEADDWITHBIASADD_TEST_1(DepthConv2D, Add, DT_FLOAT);

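// Builds the erf-based exact GELU, 0.5 * x * (1 + erf(x / sqrt(2))), on top of
// MatMul + BiasAdd and expects the whole subgraph to fuse into _FusedMatMul
// with fused_ops {"BiasAdd", "GeluExact"}.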
class FusedMatMulBiasAddAndGeluTest : public GrapplerTest {
 public:
  template <DataType DTYPE>
  void RunTest() {
    using ::tensorflow::ops::Placeholder;

    tensorflow::Scope s = tensorflow::Scope::NewRootScope();

    auto lhs_shape = ops::Placeholder::Shape({8, 32});
    auto rhs_shape = ops::Placeholder::Shape({32, 64});
    auto bias_shape = ops::Placeholder::Shape({64});

    auto lhs = Placeholder(s.WithOpName("lhs"), DTYPE, lhs_shape);
    auto rhs = Placeholder(s.WithOpName("rhs"), DTYPE, rhs_shape);
    auto bias = Placeholder(s.WithOpName("bias"), DTYPE, bias_shape);

    auto matmul = ops::MatMul(s.WithOpName("matmul"), lhs, rhs);
    auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), matmul, bias);

    // Build the exact (erf-based) Gelu from smaller ops.
    auto square_root_one_half =
        ops::Const(s.WithOpName("square_root_one_half"), {0.707106f}, {});
    auto bias_add_times_square_root_one_half =
        ops::Mul(s.WithOpName("bias_add_times_square_root_one_half"), bias_add,
                 square_root_one_half);
    auto erf =
        ops::Erf(s.WithOpName("erf"), bias_add_times_square_root_one_half);
    auto one = ops::Const(s.WithOpName("one"), {1.0f}, {});
    auto erf_plus_one = ops::AddV2(s.WithOpName("one_plus_erf"), erf, one);
    auto one_half = ops::Const(s.WithOpName("one_half"), {0.5f}, {});
    auto erf_plus_one_times_one_half = ops::Mul(
        s.WithOpName("erf_plus_one_times_one_half"), erf_plus_one, one_half);
    auto gelu = ops::Mul(s.WithOpName("fusion_output"),
                         erf_plus_one_times_one_half, bias_add);
    auto fetch = ops::Identity(s.WithOpName("fetch"), gelu);

    auto lhs_t = GenerateTensorWithSetRandom<DTYPE>({8, 32});
    auto rhs_t = GenerateTensorWithSetRandom<DTYPE>({32, 64});
    auto bias_t = GenerateTensorWithSetRandom<DTYPE>({64});

    GrapplerItem item;
    item.fetch = {"fetch"};
    item.feed = {{"lhs", lhs_t}, {"rhs", rhs_t}, {"bias", bias_t}};
    TF_ASSERT_OK(s.ToGraphDef(&item.graph));

    // Place all nodes on CPU.
    for (int i = 0; i < item.graph.node_size(); ++i) {
      item.graph.mutable_node(i)->set_device("/device:CPU:0");
    }

    Remapper optimizer(RewriterConfig::ON);
    GraphDef optimized_graph;
    TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &optimized_graph));
    int found = 0;
    for (const NodeDef& node : optimized_graph.node()) {
      if (node.name() == "fusion_output") {
        EXPECT_EQ(node.op(), "_FusedMatMul");
        ASSERT_GE(node.input_size(), 3);
        EXPECT_EQ(node.input(0), "lhs");
        EXPECT_EQ(node.input(1), "rhs");
        EXPECT_EQ(node.input(2), "bias");
        EXPECT_EQ(node.attr().at("num_args").i(), 1);
        const auto fused_ops = node.attr().at("fused_ops").list().s();
        ASSERT_EQ(fused_ops.size(), 2);
        EXPECT_EQ(fused_ops[0], "BiasAdd");
        EXPECT_EQ(fused_ops[1], "GeluExact");
        found++;
      }
    }
    EXPECT_EQ(1, found);

    // Evaluate result without remapper fusion
    auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
    ASSERT_EQ(tensors_expected.size(), 1);

    auto tensors_evaluated =
        EvaluateNodes(optimized_graph, item.fetch, item.feed);
    ASSERT_EQ(tensors_evaluated.size(), 1);
    test::ExpectClose(tensors_evaluated[0], tensors_expected[0], 1e-6);
  }
};

// Gelu has two implementations: (1) exact and (2) approximate. The exact form
// cannot be used with bfloat16 numerics since Erf is not yet supported in
// bfloat16, so gelu-exact is tested here for float32 only. The
// gelu-approximate test lives in tensorflow/python/grappler/remapper_test.py,
// since that pattern is changed by other optimizers before the remapper runs.
TEST_F(FusedMatMulBiasAddAndGeluTest, Float32GeluExact) { RunTest<DT_FLOAT>(); }

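// BatchMatMulV2 followed by a Mul with a scalar scale and an AddV2 with an
// addend should fuse into _MklFusedBatchMatMulV2 with fused_ops {"Mul", "Add"}.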
class MklFusedBatchMatMul : public MklRemapperTest {
 public:
  template <typename T>
  void VerifyFused(bool adjx, bool adjy) {
    using ::tensorflow::ops::Placeholder;
    using normal_generator = Eigen::internal::NormalRandomGenerator<T>;

    int b0 = 2;
    int b1 = 2;
    int m = 32;
    int k = 16;
    int n = 64;

    tensorflow::Scope s = tensorflow::Scope::NewRootScope();

    auto input_shape =
        adjx ? TensorShape({b0, b1, k, m}) : TensorShape({b0, b1, m, k});
    auto weight_shape =
        adjy ? TensorShape({b0, b1, n, k}) : TensorShape({b0, b1, k, n});
    auto add_shape = TensorShape({b0, 1, m, n});

    auto input_placeholder_shape = ops::Placeholder::Shape(input_shape);
    auto weight_placeholder_shape = ops::Placeholder::Shape(weight_shape);
    auto add_placeholder_shape = ops::Placeholder::Shape(add_shape);

    auto input = Placeholder(s.WithOpName("input"), DataTypeToEnum<T>::v(),
                             input_placeholder_shape);
    auto weight = Placeholder(s.WithOpName("weight"), DataTypeToEnum<T>::v(),
                              weight_placeholder_shape);
    auto addend = Placeholder(s.WithOpName("addend"), DataTypeToEnum<T>::v(),
                              add_placeholder_shape);

    auto batchmatmul =
        ops::BatchMatMulV2(s.WithOpName("batchmatmul"), input, weight,
                           ops::BatchMatMulV2::Attrs().AdjX(adjx).AdjY(adjy));
    auto scale_const = ops::Const(s.WithOpName("scale_const"), {0.1f});
    auto scale =
        ops::Cast(s.WithOpName("scale"), scale_const, DataTypeToEnum<T>::v());
    auto mul = ops::Multiply(s.WithOpName("mul"), batchmatmul, scale);
    auto add = ops::AddV2(s.WithOpName("add"), mul, addend);
    auto fetch = ops::Identity(s.WithOpName("fetch"), add);

    Tensor input_t = Tensor(DataTypeToEnum<T>::v(), input_shape);
    Tensor weight_t = Tensor(DataTypeToEnum<T>::v(), weight_shape);
    Tensor add_t = Tensor(DataTypeToEnum<T>::v(), add_shape);
    input_t.flat<T>() =
        input_t.flat<T>().template setRandom<normal_generator>();
    weight_t.flat<T>() =
        weight_t.flat<T>().template setRandom<normal_generator>();
    add_t.flat<T>() = add_t.flat<T>().template setRandom<normal_generator>();

    GrapplerItem item;
    item.fetch = {"fetch"};
    item.feed = {{"input", input_t}, {"weight", weight_t}, {"addend", add_t}};
    TF_CHECK_OK(s.ToGraphDef(&item.graph));

    // Place all nodes on CPU.
    for (int i = 0; i < item.graph.node_size(); ++i) {
      item.graph.mutable_node(i)->set_device("/device:CPU:0");
    }

    Remapper optimizer(RewriterConfig::ON);
    GraphDef output;
    TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));

    int found = 0;
    for (const NodeDef& node : output.node()) {
      if (node.name() == "add") {
        EXPECT_EQ("_MklFusedBatchMatMulV2", node.op());
        EXPECT_EQ("input", node.input(0));
        EXPECT_EQ("weight", node.input(1));
        EXPECT_EQ("scale", node.input(2));
        EXPECT_EQ("addend", node.input(3));
        const auto fused_ops = node.attr().at("fused_ops").list().s();
        EXPECT_EQ(2, fused_ops.size());
        EXPECT_EQ("Mul", fused_ops[0]);
        found++;
        EXPECT_EQ("Add", fused_ops[1]);
        found++;
      }
    }
    EXPECT_EQ(2, found);

    auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
    auto tensors = EvaluateNodes(output, item.fetch, item.feed);
    std::is_same<T, float>::value
        ? test::ExpectClose(tensors_expected[0], tensors[0], 1e-6, 1e-6)
        : test::ExpectClose(tensors_expected[0], tensors[0], 1e-2, 1e-2);
  }
};

TEST_F(MklFusedBatchMatMul, MulAndAdd) {
  for (const auto adjx : {false, true})
    for (const auto adjy : {false, true}) {
      this->VerifyFused<float>(adjx, adjy);
      this->VerifyFused<bfloat16>(adjx, adjy);
    }
}

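// A Mul of x and Sigmoid(x), with the operands in either order, should be
// rewritten to a single _MklSwish node.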
class MklRemapperSwishTest : public GrapplerTest {
 protected:
  template <DataType DTYPE>
  void RunTest() {
    using ::tensorflow::ops::Placeholder;

    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    auto mul_shape = ops::Placeholder::Shape({64, 64});

    // We will test four situations:
    //  1. y = x * sigmoid(x)
    //  2. y = sigmoid(x) * x
    //  3. y = sigmoid(x) * sigmoid(sigmoid(x))
    //  4. y = sigmoid(sigmoid(x)) * sigmoid(x)
    auto input = Placeholder(s.WithOpName("input"), DTYPE, mul_shape);
    auto sigmoid1 = ops::Sigmoid(s.WithOpName("sigmoid1"), input);
    auto sigmoid2 = ops::Sigmoid(s.WithOpName("sigmoid2"), input);
    auto sigmoid3_1 = ops::Sigmoid(s.WithOpName("sigmoid3_1"), input);
    auto sigmoid3_2 = ops::Sigmoid(s.WithOpName("sigmoid3_2"), sigmoid3_1);
    auto sigmoid4_1 = ops::Sigmoid(s.WithOpName("sigmoid4_1"), input);
    auto sigmoid4_2 = ops::Sigmoid(s.WithOpName("sigmoid4_2"), sigmoid4_1);
    auto mul1 = ops::Mul(s.WithOpName("mul1"), input, sigmoid1);
    auto mul2 = ops::Mul(s.WithOpName("mul2"), sigmoid2, input);
    auto mul3 = ops::Mul(s.WithOpName("mul3"), sigmoid3_1, sigmoid3_2);
    auto mul4 = ops::Mul(s.WithOpName("mul4"), sigmoid4_2, sigmoid4_1);
    auto fetch1 = ops::Identity(s.WithOpName("fetch1"), mul1);
    auto fetch2 = ops::Identity(s.WithOpName("fetch2"), mul2);
    auto fetch3 = ops::Identity(s.WithOpName("fetch3"), mul3);
    auto fetch4 = ops::Identity(s.WithOpName("fetch4"), mul4);
    auto mul_t = GenerateTensorWithSetRandom<DTYPE>({64, 64});

    GrapplerItem item;
    item.fetch = {"fetch1", "fetch2", "fetch3", "fetch4"};
    item.feed = {{"input", mul_t}};
    TF_ASSERT_OK(s.ToGraphDef(&item.graph));

    // Place all nodes on CPU.
    for (int i = 0; i < item.graph.node_size(); ++i) {
      item.graph.mutable_node(i)->set_device("/device:CPU:0");
    }

    Remapper optimizer(RewriterConfig::ON);
    GraphDef output;
    TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));

    int found = 0;
    for (const NodeDef& node : output.node()) {
      if (node.name() == "mul1") {
        EXPECT_EQ(node.op(), "_MklSwish");
        ASSERT_EQ(node.input_size(), 1);
        EXPECT_EQ(node.input(0), "input");
        ++found;
      }
      if (node.name() == "mul2") {
        EXPECT_EQ(node.op(), "_MklSwish");
        ASSERT_EQ(node.input_size(), 1);
        EXPECT_EQ(node.input(0), "input");
        ++found;
      }
      // mul3 is not replaced with swish because of a limitation of the
      // pattern matcher with commutative ops.
      if (node.name() == "mul4") {
        EXPECT_EQ(node.op(), "_MklSwish");
        ASSERT_EQ(node.input_size(), 1);
        EXPECT_EQ(node.input(0), "sigmoid4_1");
        ++found;
      }
    }
    EXPECT_EQ(found, 3);

    auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
    ASSERT_EQ(tensors_expected.size(), 4);
    auto tensors = EvaluateNodes(output, item.fetch, item.feed);
    ASSERT_EQ(tensors.size(), 4);
    float atol = 1e-6, rtol = 1e-6;
    if (DTYPE == DT_BFLOAT16) {
      atol = 1e-2;
      rtol = 1e-2;
    }
    test::ExpectClose(tensors[0], tensors_expected[0], atol, rtol);
    test::ExpectClose(tensors[1], tensors_expected[1], atol, rtol);
    test::ExpectClose(tensors[2], tensors_expected[2], atol, rtol);
    test::ExpectClose(tensors[3], tensors_expected[3], atol, rtol);
  }
};

TEST_F(MklRemapperSwishTest, F32) { RunTest<DT_FLOAT>(); }
TEST_F(MklRemapperSwishTest, BF16) { RunTest<DT_BFLOAT16>(); }

}  // namespace grappler
}  // namespace tensorflow
#endif  // INTEL_MKL && ENABLE_MKL